diff --git a/front/admin/cli.ts b/front/admin/cli.ts index 12db8309225f..8602e1312bbf 100644 --- a/front/admin/cli.ts +++ b/front/admin/cli.ts @@ -4,19 +4,13 @@ import { Storage } from "@google-cloud/storage"; import parseArgs from "minimist"; import readline from "readline"; -import { subscriptionForWorkspace } from "@app/lib/auth"; -import { - DataSource, - EventSchema, - Membership, - User, - Workspace, -} from "@app/lib/models"; +import { DataSource, EventSchema, User, Workspace } from "@app/lib/models"; import { FREE_UPGRADED_PLAN_CODE } from "@app/lib/plans/plan_codes"; import { internalSubscribeWorkspaceToFreeNoPlan, internalSubscribeWorkspaceToFreePlan, } from "@app/lib/plans/subscription"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import { generateModelSId } from "@app/lib/utils"; import logger from "@app/logger/logger"; @@ -25,105 +19,6 @@ const { DUST_DATA_SOURCES_BUCKET = "", SERVICE_ACCOUNT } = process.env; // `cli` takes an object type and a command as first two arguments and then a list of arguments. const workspace = async (command: string, args: parseArgs.ParsedArgs) => { switch (command) { - case "find": { - if (!args.name) { - throw new Error("Missing --name argument"); - } - - const workspaces = await Workspace.findAll({ - where: { - name: args.name, - }, - }); - - workspaces.forEach((w) => { - console.log(`> wId='${w.sId}' name='${w.name}'`); - }); - return; - } - - case "show": { - if (!args.wId) { - throw new Error("Missing --wId argument"); - } - - const w = await Workspace.findOne({ - where: { - sId: args.wId, - }, - }); - - if (!w) { - throw new Error(`Workspace not found: wId='${args.wId}'`); - } - - console.log(`workspace:`); - console.log(` wId: ${w.sId}`); - console.log(` name: ${w.name}`); - - const subscription = await subscriptionForWorkspace(w.sId); - const plan = subscription.plan; - console.log(` plan:`); - console.log(` limits:`); - console.log(` dataSources:`); - console.log(` count: ${plan.limits.dataSources.count}`); - console.log(` documents:`); - console.log( - ` count: ${plan.limits.dataSources.documents.count}` - ); - console.log( - ` sizeMb: ${plan.limits.dataSources.documents.sizeMb}` - ); - console.log( - ` managed Slack: ${plan.limits.connections.isSlackAllowed}` - ); - console.log( - ` managed Notion: ${plan.limits.connections.isNotionAllowed}` - ); - console.log( - ` managed Github: ${plan.limits.connections.isGithubAllowed}` - ); - console.log( - ` managed Intercom: ${plan.limits.connections.isIntercomAllowed}` - ); - console.log( - ` managed Google Drive: ${plan.limits.connections.isGoogleDriveAllowed}` - ); - - const dataSources = await DataSource.findAll({ - where: { - workspaceId: w.id, - }, - }); - - console.log("Data sources:"); - dataSources.forEach((ds) => { - console.log(` - name: ${ds.name} provider: ${ds.connectorProvider}`); - }); - - const memberships = await Membership.findAll({ - where: { - workspaceId: w.id, - }, - }); - const users = await User.findAll({ - where: { - id: memberships.map((m) => m.userId), - }, - }); - - console.log("Users:"); - users.forEach((u) => { - console.log( - ` - userId: ${u.id} email: ${u.email} role: ${ - memberships.find((m) => m.userId === u.id)?.role - }` - ); - }); - - return; - } - case "create": { if (!args.name) { throw new Error("Missing --name argument"); @@ -182,98 +77,10 @@ const workspace = async (command: string, args: parseArgs.ParsedArgs) => { return; } - case "add-user": { - if (!args.wId) { - throw new Error("Missing --wId argument"); - } - if (!args.userId) { - throw new Error("Missing --userId argument"); - } - if (!args.role) { - throw new Error("Missing --role argument"); - } - if (!["admin", "builder", "user"].includes(args.role)) { - throw new Error(`Invalid --role: ${args.role}`); - } - const role = args.role as "admin" | "builder" | "user"; - - const w = await Workspace.findOne({ - where: { - sId: args.wId, - }, - }); - if (!w) { - throw new Error(`Workspace not found: wId='${args.wId}'`); - } - const u = await User.findOne({ - where: { - id: args.userId, - }, - }); - if (!u) { - throw new Error(`User not found: userId='${args.userId}'`); - } - await Membership.create({ - role, - workspaceId: w.id, - userId: u.id, - startAt: new Date(), - }); - return; - } - - case "change-role": { - if (!args.wId) { - throw new Error("Missing --wId argument"); - } - if (!args.userId) { - throw new Error("Missing --userId argument"); - } - if (!args.role) { - throw new Error("Missing --role argument"); - } - if (!["admin", "builder", "user", "revoked"].includes(args.role)) { - throw new Error(`Invalid --role: ${args.role}`); - } - const role = args.role as "admin" | "builder" | "user" | "revoked"; - - const w = await Workspace.findOne({ - where: { - sId: args.wId, - }, - }); - if (!w) { - throw new Error(`Workspace not found: wId='${args.wId}'`); - } - const u = await User.findOne({ - where: { - id: args.userId, - }, - }); - if (!u) { - throw new Error(`User not found: userId='${args.userId}'`); - } - const m = await Membership.findOne({ - where: { - workspaceId: w.id, - userId: u.id, - }, - }); - if (!m) { - throw new Error( - `User is not a member of workspace: userId='${args.userId}' wId='${args.wId}'` - ); - } - - m.role = role; - await m.save(); - return; - } - default: console.log(`Unknown workspace command: ${command}`); console.log( - "Possible values: `find`, `show`, `create`, `set-limits`, `add-user`, `change-role`, `upgrade`, `downgrade`" + "Possible values: `find`, `show`, `create`, `set-limits`, `upgrade`, `downgrade`" ); } }; @@ -319,10 +126,8 @@ const user = async (command: string, args: parseArgs.ParsedArgs) => { console.log(` name: ${u.name}`); console.log(` email: ${u.email}`); - const memberships = await Membership.findAll({ - where: { - userId: u.id, - }, + const memberships = await MembershipResource.getLatestMemberships({ + userIds: [u.id], }); const workspaces = await Workspace.findAll({ @@ -338,7 +143,9 @@ const user = async (command: string, args: parseArgs.ParsedArgs) => { console.log(` - wId: ${w.sId}`); console.log(` name: ${w.name}`); if (m) { - console.log(` role: ${m.role}`); + console.log(` role: ${m.isRevoked() ? "revoked" : m.role}`); + console.log(` startAt: ${m.startAt}`); + console.log(` endAt: ${m.endAt}`); } }); diff --git a/front/admin/db.ts b/front/admin/db.ts index 593c47e7659d..7a881e17e946 100644 --- a/front/admin/db.ts +++ b/front/admin/db.ts @@ -17,7 +17,6 @@ import { EventSchema, ExtractedEvent, Key, - Membership, MembershipInvitation, Mention, Message, @@ -47,6 +46,7 @@ import { import { ConversationClassification } from "@app/lib/models/conversation_classification"; import { FeatureFlag } from "@app/lib/models/feature_flag"; import { ContentFragmentModel } from "@app/lib/resources/storage/models/content_fragment"; +import { MembershipModel } from "@app/lib/resources/storage/models/membership"; import { TemplateModel } from "@app/lib/resources/storage/models/templates"; async function main() { @@ -54,7 +54,7 @@ async function main() { await UserMetadata.sync({ alter: true }); await Workspace.sync({ alter: true }); await WorkspaceHasDomain.sync({ alter: true }); - await Membership.sync({ alter: true }); + await MembershipModel.sync({ alter: true }); await MembershipInvitation.sync({ alter: true }); await App.sync({ alter: true }); await Dataset.sync({ alter: true }); diff --git a/front/lib/amplitude/node/index.ts b/front/lib/amplitude/node/index.ts index cbf86c6ba2b8..762b66d5aabb 100644 --- a/front/lib/amplitude/node/index.ts +++ b/front/lib/amplitude/node/index.ts @@ -24,9 +24,9 @@ import { import { isGlobalAgentId } from "@app/lib/api/assistant/global_agents"; import type { Authenticator } from "@app/lib/auth"; import { subscriptionForWorkspace } from "@app/lib/auth"; -import { Membership } from "@app/lib/models"; import { User, Workspace } from "@app/lib/models"; import { countActiveSeatsInWorkspace } from "@app/lib/plans/workspace_usage"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; let BACKEND_CLIENT: Ampli | null = null; @@ -54,10 +54,8 @@ export function getBackendClient() { export async function trackUserMemberships(userId: ModelId) { const amplitude = getBackendClient(); const user = await User.findByPk(userId); - const memberships = await Membership.findAll({ - where: { - userId: userId, - }, + const memberships = await MembershipResource.getActiveMemberships({ + userIds: [userId], }); const groups: string[] = []; for (const membership of memberships) { diff --git a/front/lib/api/workspace.ts b/front/lib/api/workspace.ts index cd2a72629cfb..7b84059ffaa1 100644 --- a/front/lib/api/workspace.ts +++ b/front/lib/api/workspace.ts @@ -1,5 +1,6 @@ import type { LightWorkspaceType, + MembershipRoleType, ModelId, RoleType, SubscriptionType, @@ -8,15 +9,11 @@ import type { WorkspaceSegmentationType, WorkspaceType, } from "@dust-tt/types"; -import { Op } from "sequelize"; import type { Authenticator } from "@app/lib/auth"; -import { - Membership, - User, - Workspace, - WorkspaceHasDomain, -} from "@app/lib/models"; +import { User, Workspace, WorkspaceHasDomain } from "@app/lib/models"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; +import { renderLightWorkspaceType } from "@app/lib/workspace"; export async function getWorkspaceInfos( wId: string @@ -107,9 +104,11 @@ export async function getMembers( auth: Authenticator, { roles, + activeOnly, userIds, }: { - roles?: RoleType[]; + roles?: MembershipRoleType[]; + activeOnly?: boolean; userIds?: ModelId[]; } = {} ): Promise { @@ -118,20 +117,17 @@ export async function getMembers( return []; } - const whereClause: { - workspaceId: ModelId; - userId?: ModelId[]; - role?: RoleType[]; - } = userIds - ? { workspaceId: owner.id, userId: userIds } - : { workspaceId: owner.id }; - if (roles) { - whereClause.role = roles; - } - - const memberships = await Membership.findAll({ - where: whereClause, - }); + const memberships = activeOnly + ? await MembershipResource.getActiveMemberships({ + workspace: owner, + roles, + userIds, + }) + : await MembershipResource.getLatestMemberships({ + workspace: owner, + roles, + userIds, + }); const users = await User.findAll({ where: { @@ -142,7 +138,7 @@ export async function getMembers( return users.map((u) => { const m = memberships.find((m) => m.userId === u.id); let role = "none" as RoleType; - if (m) { + if (m && !m.isRevoked()) { switch (m.role) { case "admin": case "builder": @@ -171,33 +167,16 @@ export async function getMembers( export async function getMembersCount( auth: Authenticator, - { activeOnly }: { activeOnly?: boolean } = {} + { activeOnly = false }: { activeOnly?: boolean } = {} ): Promise { const owner = auth.workspace(); if (!owner) { return 0; } - return getMembersCountForWorkspace(owner, { activeOnly }); -} - -export async function getMembersCountForWorkspace( - workspace: WorkspaceType | Workspace, - { activeOnly }: { activeOnly?: boolean } = {} -): Promise { - const whereClause = activeOnly - ? { - role: { - [Op.ne]: "revoked", - }, - } - : {}; - - return Membership.count({ - where: { - workspaceId: workspace.id, - ...whereClause, - }, + return MembershipResource.getMembersCountForWorkspace({ + workspace: owner, + activeOnly, }); } @@ -222,9 +201,11 @@ export async function evaluateWorkspaceSeatAvailability( return true; } - const activeMembersCount = await getMembersCountForWorkspace(workspace, { - activeOnly: true, - }); + const activeMembersCount = + await MembershipResource.getMembersCountForWorkspace({ + workspace: renderLightWorkspaceType({ workspace }), + activeOnly: true, + }); return activeMembersCount < maxUsers; } diff --git a/front/lib/auth.ts b/front/lib/auth.ts index 214875737ae3..54e9b96b419d 100644 --- a/front/lib/auth.ts +++ b/front/lib/auth.ts @@ -29,7 +29,6 @@ import { isValidSession } from "@app/lib/iam/provider"; import { FeatureFlag, Key, - Membership, Plan, Subscription, User, @@ -39,7 +38,9 @@ import type { PlanAttributes } from "@app/lib/plans/free_plans"; import { FREE_NO_PLAN_DATA } from "@app/lib/plans/free_plans"; import { isUpgraded } from "@app/lib/plans/plan_codes"; import { getTrialVersionForPlan, isTrial } from "@app/lib/plans/trial"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import { new_id } from "@app/lib/utils"; +import { renderLightWorkspaceType } from "@app/lib/workspace"; import logger from "@app/logger/logger"; import { renderSubscriptionFromModels } from "./plans/subscription"; @@ -129,13 +130,13 @@ export class Authenticator { if (user && workspace) { [role, subscription, flags] = await Promise.all([ (async (): Promise => { - const membership = await Membership.findOne({ - where: { + const membership = + await MembershipResource.getActiveMembershipOfUserInWorkspace({ userId: user.id, - workspaceId: workspace.id, - }, - }); + workspace: renderLightWorkspaceType({ workspace }), + }); return membership && + // TODO(@fontanierh): get rid of the check ? ["admin", "builder", "user"].includes(membership.role) ? (membership.role as RoleType) : "none"; diff --git a/front/lib/document_tracker.ts b/front/lib/document_tracker.ts index 0b20201c352c..28d4df591ac1 100644 --- a/front/lib/document_tracker.ts +++ b/front/lib/document_tracker.ts @@ -1,7 +1,9 @@ import { CoreAPI } from "@dust-tt/types"; import { literal, Op } from "sequelize"; -import { DataSource, Membership, TrackedDocument, User } from "@app/lib/models"; +import { Authenticator } from "@app/lib/auth"; +import { DataSource, TrackedDocument, User, Workspace } from "@app/lib/models"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import logger from "@app/logger/logger"; export async function updateTrackedDocuments( @@ -18,6 +20,21 @@ export async function updateTrackedDocuments( `Data source with id ${dataSourceId} has no workspace id set` ); } + const workspaceModel = await Workspace.findByPk(dataSource.workspaceId); + if (!workspaceModel) { + throw new Error( + `Could not find workspace with id ${dataSource.workspaceId}` + ); + } + const auth = await Authenticator.internalBuilderForWorkspace( + workspaceModel.sId + ); + const owner = auth.workspace(); + if (!owner) { + throw new Error( + `Could not find workspace with id ${dataSource.workspaceId}` + ); + } const hasExistingTrackedDocs = !!(await TrackedDocument.count({ where: { @@ -75,13 +92,8 @@ export async function updateTrackedDocuments( : []; // restrict to users in the workspace - const memberships = await Membership.findAll({ - where: { - userId: { - [Op.in]: users.map((user) => user.id), - }, - workspaceId: dataSource.workspaceId, - }, + const memberships = await MembershipResource.getActiveMemberships({ + workspace: owner, }); const userIdsInWorkspace = new Set( memberships.map((membership) => membership.userId) diff --git a/front/lib/iam/memberships.ts b/front/lib/iam/memberships.ts deleted file mode 100644 index 0c8c0e2e8c78..000000000000 --- a/front/lib/iam/memberships.ts +++ /dev/null @@ -1,15 +0,0 @@ -import { Op } from "sequelize"; - -import type { User } from "@app/lib/models"; -import { Membership } from "@app/lib/models"; - -export async function getActiveMembershipsForUser(userId: User["id"]) { - return Membership.findAll({ - where: { - userId, - role: { - [Op.ne]: "revoked", - }, - }, - }); -} diff --git a/front/lib/iam/session.ts b/front/lib/iam/session.ts index feb5456d6ca1..ee2eba67c656 100644 --- a/front/lib/iam/session.ts +++ b/front/lib/iam/session.ts @@ -6,7 +6,6 @@ import type { PreviewData, } from "next"; import type { ParsedUrlQuery } from "querystring"; -import { Op } from "sequelize"; import { Authenticator, getSession } from "@app/lib/auth"; import type { SessionWithUser } from "@app/lib/iam/provider"; @@ -15,7 +14,8 @@ import { fetchUserFromSession, maybeUpdateFromExternalUser, } from "@app/lib/iam/users"; -import { Membership, Workspace } from "@app/lib/models"; +import { Workspace } from "@app/lib/models"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import logger from "@app/logger/logger"; import { withGetServerSidePropsLogging } from "@app/logger/withlogging"; @@ -36,11 +36,8 @@ export async function getUserFromSession( return null; } - const memberships = await Membership.findAll({ - where: { - userId: user.id, - role: { [Op.in]: ["admin", "builder", "user"] }, - }, + const memberships = await MembershipResource.getActiveMemberships({ + userIds: [user.id], }); const workspaces = await Workspace.findAll({ where: { diff --git a/front/lib/models/index.ts b/front/lib/models/index.ts index 00c019f8ff4e..b21499b6386d 100644 --- a/front/lib/models/index.ts +++ b/front/lib/models/index.ts @@ -39,7 +39,6 @@ import { Plan, Subscription } from "@app/lib/models/plan"; import { User, UserMetadata } from "@app/lib/models/user"; import { Key, - Membership, MembershipInvitation, Workspace, WorkspaceHasDomain, @@ -68,7 +67,6 @@ export { FeatureFlag, GlobalAgentSettings, Key, - Membership, MembershipInvitation, Mention, Message, diff --git a/front/lib/models/workspace.ts b/front/lib/models/workspace.ts index b98261ed9262..c374da92d970 100644 --- a/front/lib/models/workspace.ts +++ b/front/lib/models/workspace.ts @@ -1,8 +1,4 @@ -import type { - MembershipRoleType, - RoleType, - WorkspaceSegmentationType, -} from "@dust-tt/types"; +import type { RoleType, WorkspaceSegmentationType } from "@dust-tt/types"; import type { CreationOptional, ForeignKey, @@ -131,72 +127,6 @@ Workspace.hasMany(WorkspaceHasDomain, { }); WorkspaceHasDomain.belongsTo(Workspace); -export class Membership extends Model< - InferAttributes, - InferCreationAttributes -> { - declare id: CreationOptional; - declare createdAt: CreationOptional; - declare updatedAt: CreationOptional; - - declare role: MembershipRoleType; - declare startAt: Date | null; - declare endAt: Date | null; - - declare userId: ForeignKey; - declare workspaceId: ForeignKey; -} -Membership.init( - { - id: { - type: DataTypes.INTEGER, - autoIncrement: true, - primaryKey: true, - }, - createdAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - updatedAt: { - type: DataTypes.DATE, - allowNull: false, - defaultValue: DataTypes.NOW, - }, - role: { - type: DataTypes.STRING, - allowNull: false, - }, - startAt: { - type: DataTypes.DATE, - allowNull: true, - }, - endAt: { - type: DataTypes.DATE, - allowNull: true, - }, - }, - { - modelName: "membership", - sequelize: frontSequelize, - indexes: [ - { fields: ["userId", "role"] }, - { fields: ["startAt"] }, - { fields: ["endAt"] }, - ], - } -); -User.hasMany(Membership, { - foreignKey: { allowNull: false }, - onDelete: "CASCADE", -}); -Workspace.hasMany(Membership, { - foreignKey: { allowNull: false }, - onDelete: "CASCADE", -}); -Membership.belongsTo(Workspace); -Membership.belongsTo(User); - export class MembershipInvitation extends Model< InferAttributes, InferCreationAttributes diff --git a/front/lib/plans/workspace_usage.ts b/front/lib/plans/workspace_usage.ts index 6433f10e3662..a471345551ef 100644 --- a/front/lib/plans/workspace_usage.ts +++ b/front/lib/plans/workspace_usage.ts @@ -1,8 +1,10 @@ import type { WorkspaceType } from "@dust-tt/types"; -import { Op, QueryTypes } from "sequelize"; +import { QueryTypes } from "sequelize"; -import { Membership, Workspace } from "@app/lib/models/workspace"; +import { Workspace } from "@app/lib/models/workspace"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import { frontSequelize } from "@app/lib/resources/storage"; +import { renderLightWorkspaceType } from "@app/lib/workspace"; export async function countActiveSeatsInWorkspace( workspaceId: string @@ -15,13 +17,9 @@ export async function countActiveSeatsInWorkspace( if (!workspace) { throw new Error(`Workspace not found for sId: ${workspaceId}`); } - return Membership.count({ - where: { - workspaceId: workspace.id, - role: { - [Op.notIn]: ["none", "revoked"], - }, - }, + return MembershipResource.getMembersCountForWorkspace({ + workspace: renderLightWorkspaceType({ workspace }), + activeOnly: true, }); } diff --git a/front/lib/resources/base_resource.ts b/front/lib/resources/base_resource.ts index 60efe2949a65..e731e4c71431 100644 --- a/front/lib/resources/base_resource.ts +++ b/front/lib/resources/base_resource.ts @@ -43,21 +43,4 @@ export abstract class BaseResource { } abstract delete(transaction?: Transaction): Promise>; - - async update( - blob: Partial>, - transaction?: Transaction - ): Promise<[affectedCount: number]> { - const [affectedCount, affectedRows] = await this.model.update(blob, { - // @ts-expect-error TS cannot infer the presence of 'id' in Sequelize models, but our models always include 'id'. - where: { - id: this.id, - }, - transaction, - returning: true, - }); - // Update the current instance with the new values to avoid stale data - Object.assign(this, affectedRows[0].get()); - return [affectedCount]; - } } diff --git a/front/lib/resources/content_fragment_resource.ts b/front/lib/resources/content_fragment_resource.ts index 072fcb0a5338..b0a633e206c7 100644 --- a/front/lib/resources/content_fragment_resource.ts +++ b/front/lib/resources/content_fragment_resource.ts @@ -108,6 +108,22 @@ export class ContentFragmentResource extends BaseResource }, }; } + + async update( + blob: Partial>, + transaction?: Transaction + ): Promise<[affectedCount: number]> { + const [affectedCount, affectedRows] = await this.model.update(blob, { + where: { + id: this.id, + }, + transaction, + returning: true, + }); + // Update the current instance with the new values to avoid stale data + Object.assign(this, affectedRows[0].get()); + return [affectedCount]; + } } // TODO(2024-03-22 pr): Move as method of message resource after migration of diff --git a/front/lib/resources/membership_resource.ts b/front/lib/resources/membership_resource.ts new file mode 100644 index 000000000000..9f5ad3d10fc5 --- /dev/null +++ b/front/lib/resources/membership_resource.ts @@ -0,0 +1,390 @@ +import type { + LightWorkspaceType, + MembershipRoleType, + RequireAtLeastOne, + Result, +} from "@dust-tt/types"; +import { Err, Ok } from "@dust-tt/types"; +import type { + Attributes, + InferAttributes, + ModelStatic, + Transaction, + WhereOptions, +} from "sequelize"; +import { Op } from "sequelize"; + +import { BaseResource } from "@app/lib/resources/base_resource"; +import { MembershipModel } from "@app/lib/resources/storage/models/membership"; +import type { ReadonlyAttributesType } from "@app/lib/resources/storage/types"; +import logger from "@app/logger/logger"; + +type GetMembershipsOptions = RequireAtLeastOne<{ + userIds: number[]; + workspace: LightWorkspaceType; +}> & { + roles?: MembershipRoleType[]; + transaction?: Transaction; +}; + +// Attributes are marked as read-only to reflect the stateless nature of our Resource. +// This design will be moved up to BaseResource once we transition away from Sequelize. +// eslint-disable-next-line @typescript-eslint/no-empty-interface +export interface MembershipResource + extends ReadonlyAttributesType {} +export class MembershipResource extends BaseResource { + static model: ModelStatic = MembershipModel; + + constructor( + model: ModelStatic, + blob: Attributes + ) { + super(MembershipModel, blob); + } + + static async getActiveMemberships({ + userIds, + workspace, + roles, + transaction, + }: GetMembershipsOptions): Promise { + if (!workspace && !userIds?.length) { + throw new Error("At least one of workspace or userIds must be provided."); + } + const whereClause: WhereOptions> = { + startAt: { + [Op.lte]: new Date(), + }, + endAt: { + [Op.or]: [{ [Op.eq]: null }, { [Op.gte]: new Date() }], + }, + }; + + if (userIds) { + whereClause.userId = userIds; + } + if (workspace) { + whereClause.workspaceId = workspace.id; + } + if (roles) { + whereClause.role = { + [Op.in]: roles, + }; + } + + const memberships = await MembershipModel.findAll({ + where: whereClause, + transaction, + }); + + return memberships.map( + (membership) => new MembershipResource(MembershipModel, membership.get()) + ); + } + + static async getLatestMemberships({ + userIds, + workspace, + roles, + transaction, + }: GetMembershipsOptions): Promise { + const orderedResourcesFromModels = (resources: MembershipModel[]) => + resources + .sort((a, b) => a.startAt.getTime() - b.startAt.getTime()) + .map( + (resource) => new MembershipResource(MembershipModel, resource.get()) + ); + + const where: WhereOptions> = { + role: roles, + userId: userIds ? { [Op.in]: userIds } : undefined, + workspaceId: workspace ? workspace.id : undefined, + }; + + if (!workspace && !userIds?.length) { + throw new Error("At least one of workspace or userIds must be provided."); + } + if (userIds && !userIds.length) return []; + + // Get all the memberships matching the criteria. + const memberships = await MembershipModel.findAll({ + where, + order: [["startAt", "DESC"]], + transaction, + }); + // Then, we only keep the latest membership for each (user, workspace). + const latestMembershipByUserAndWorkspace = new Map< + string, + MembershipModel + >(); + for (const m of memberships) { + const key = `${m.userId}__${m.workspaceId}`; + const latest = latestMembershipByUserAndWorkspace.get(key); + if (!latest || latest.startAt < m.startAt) { + latestMembershipByUserAndWorkspace.set(key, m); + } + } + + return orderedResourcesFromModels( + Array.from(latestMembershipByUserAndWorkspace.values()) + ); + } + + static async getLatestMembershipOfUserInWorkspace({ + userId, + workspace, + transaction, + }: { + userId: number; + workspace: LightWorkspaceType; + transaction?: Transaction; + }): Promise { + const memberships = await this.getLatestMemberships({ + userIds: [userId], + workspace, + transaction, + }); + if (memberships.length === 0) { + return null; + } + if (memberships.length > 1) { + logger.error( + { + panic: true, + userId, + workspaceId: workspace.id, + memberships, + }, + "Unreachable: Found multiple latest memberships for user in workspace." + ); + throw new Error( + `Unreachable: Found multiple latest memberships for user ${userId} in workspace ${workspace.id}` + ); + } + return memberships[0]; + } + + static async getActiveMembershipOfUserInWorkspace({ + userId, + workspace, + transaction, + }: { + userId: number; + workspace: LightWorkspaceType; + transaction?: Transaction; + }): Promise { + const memberships = await this.getActiveMemberships({ + userIds: [userId], + workspace, + transaction, + }); + if (memberships.length === 0) { + return null; + } + if (memberships.length > 1) { + logger.error( + { + panic: true, + userId, + workspaceId: workspace.id, + memberships, + }, + "Unreachable: Found multiple active memberships for user in workspace." + ); + throw new Error( + `Unreachable: Found multiple active memberships for user ${userId} in workspace ${workspace.id}` + ); + } + return memberships[0]; + } + + static async getMembersCountForWorkspace({ + workspace, + activeOnly, + transaction, + }: { + workspace: LightWorkspaceType; + activeOnly: boolean; + transaction?: Transaction; + }): Promise { + const where: WhereOptions> = activeOnly + ? { + endAt: { + [Op.or]: [{ [Op.eq]: null }, { [Op.gt]: new Date() }], + }, + startAt: { + [Op.lte]: new Date(), + }, + } + : {}; + + where.workspaceId = workspace.id; + + return MembershipModel.count({ + where, + distinct: true, + col: "userId", + transaction, + }); + } + + static async createMembership({ + userId, + workspace, + role, + startAt = new Date(), + transaction, + }: { + userId: number; + workspace: LightWorkspaceType; + role: MembershipRoleType; + startAt?: Date; + transaction?: Transaction; + }): Promise { + if (startAt > new Date()) { + throw new Error("Cannot create a membership in the future"); + } + if ( + await MembershipModel.count({ + where: { + userId, + workspaceId: workspace.id, + endAt: { + [Op.or]: [{ [Op.eq]: null }, { [Op.gt]: startAt }], + }, + }, + transaction, + }) + ) { + throw new Error( + `User ${userId} already has an active membership in workspace ${workspace.id}` + ); + } + const newMembership = await MembershipModel.create( + { + startAt, + userId, + workspaceId: workspace.id, + role, + }, + { transaction } + ); + + return new MembershipResource(MembershipModel, newMembership.get()); + } + + static async revokeMembership({ + userId, + workspace, + endAt = new Date(), + transaction, + }: { + userId: number; + workspace: LightWorkspaceType; + endAt?: Date; + transaction?: Transaction; + }): Promise< + Result< + undefined, + { + type: "not_found" | "already_revoked"; + } + > + > { + const membership = await this.getLatestMembershipOfUserInWorkspace({ + userId, + workspace, + transaction, + }); + if (!membership) { + return new Err({ type: "not_found" }); + } + if (endAt < membership.startAt) { + throw new Error("endAt must be after startAt"); + } + if (membership.endAt) { + return new Err({ type: "already_revoked" }); + } + await MembershipModel.update( + { endAt }, + { where: { id: membership.id }, transaction } + ); + return new Ok(undefined); + } + + static async updateMembershipRole({ + userId, + workspace, + newRole, + allowTerminated = false, + transaction, + }: { + userId: number; + workspace: LightWorkspaceType; + newRole: Exclude; + // If true, allow updating the role of a terminated membership (which will also un-terminate it). + allowTerminated?: boolean; + transaction?: Transaction; + }): Promise< + Result< + void, + { + type: "not_found" | "already_on_role" | "membership_already_terminated"; + } + > + > { + const membership = await this.getLatestMembershipOfUserInWorkspace({ + userId, + workspace, + transaction, + }); + if (membership?.endAt && !allowTerminated) { + return new Err({ type: "membership_already_terminated" }); + } + if (!membership) { + return new Err({ type: "not_found" }); + } + + // If the membership is not terminated, we update the role in place. + // TODO(@fontanierh): check if we want to terminate + create a new membership with new role instead ? + if (!membership.endAt) { + if (membership.role === newRole) { + return new Err({ type: "already_on_role" }); + } + await MembershipModel.update( + { role: newRole }, + { where: { id: membership.id }, transaction } + ); + } else { + // If the last membership was terminated, we create a new membership with the new role. + await this.createMembership({ + userId, + workspace, + role: newRole, + startAt: new Date(), + transaction, + }); + } + + return new Ok(undefined); + } + + async delete(transaction?: Transaction): Promise> { + try { + await this.model.destroy({ + where: { + id: this.id, + }, + transaction, + }); + + return new Ok(undefined); + } catch (err) { + return new Err(err as Error); + } + } + + isRevoked(referenceDate: Date = new Date()): boolean { + return !!this.endAt && this.endAt < referenceDate; + } +} diff --git a/front/lib/resources/storage/models/membership.ts b/front/lib/resources/storage/models/membership.ts new file mode 100644 index 000000000000..efbeffc829d3 --- /dev/null +++ b/front/lib/resources/storage/models/membership.ts @@ -0,0 +1,77 @@ +import type { MembershipRoleType } from "@dust-tt/types"; +import type { + CreationOptional, + ForeignKey, + InferAttributes, + InferCreationAttributes, +} from "sequelize"; +import { DataTypes, Model } from "sequelize"; + +import { User, Workspace } from "@app/lib/models"; +import { frontSequelize } from "@app/lib/resources/storage"; + +export class MembershipModel extends Model< + InferAttributes, + InferCreationAttributes +> { + declare id: CreationOptional; + declare createdAt: CreationOptional; + declare updatedAt: CreationOptional; + + declare role: MembershipRoleType; + declare startAt: Date; + declare endAt: Date | null; + + declare userId: ForeignKey; + declare workspaceId: ForeignKey; +} +MembershipModel.init( + { + id: { + type: DataTypes.INTEGER, + autoIncrement: true, + primaryKey: true, + }, + createdAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + updatedAt: { + type: DataTypes.DATE, + allowNull: false, + defaultValue: DataTypes.NOW, + }, + role: { + type: DataTypes.STRING, + allowNull: false, + }, + startAt: { + type: DataTypes.DATE, + allowNull: false, + }, + endAt: { + type: DataTypes.DATE, + allowNull: true, + }, + }, + { + modelName: "membership", + sequelize: frontSequelize, + indexes: [ + { fields: ["userId", "role"] }, + { fields: ["startAt"] }, + { fields: ["endAt"] }, + ], + } +); +User.hasMany(MembershipModel, { + foreignKey: { allowNull: false }, + onDelete: "CASCADE", +}); +Workspace.hasMany(MembershipModel, { + foreignKey: { allowNull: false }, + onDelete: "CASCADE", +}); +MembershipModel.belongsTo(Workspace); +MembershipModel.belongsTo(User); diff --git a/front/lib/resources/template_resource.ts b/front/lib/resources/template_resource.ts index 2f053f08a042..80b822923057 100644 --- a/front/lib/resources/template_resource.ts +++ b/front/lib/resources/template_resource.ts @@ -87,6 +87,22 @@ export class TemplateResource extends BaseResource { } } + async update( + blob: Partial>, + transaction?: Transaction + ): Promise<[affectedCount: number]> { + const [affectedCount, affectedRows] = await this.model.update(blob, { + where: { + id: this.id, + }, + transaction, + returning: true, + }); + // Update the current instance with the new values to avoid stale data + Object.assign(this, affectedRows[0].get()); + return [affectedCount]; + } + isPublished() { return this.visibility === "published"; } diff --git a/front/lib/workspace.ts b/front/lib/workspace.ts new file mode 100644 index 000000000000..799242ba5f01 --- /dev/null +++ b/front/lib/workspace.ts @@ -0,0 +1,23 @@ +import type { + LightWorkspaceType, + RoleType, + WorkspaceType, +} from "@dust-tt/types"; + +import type { Workspace } from "@app/lib/models"; + +export function renderLightWorkspaceType({ + workspace, + role = "none", +}: { + workspace: Workspace | WorkspaceType | LightWorkspaceType; + role?: RoleType; +}): LightWorkspaceType { + return { + id: workspace.id, + sId: workspace.sId, + name: workspace.name, + segmentation: workspace.segmentation, + role, + }; +} diff --git a/front/mailing/20240308-weekly-update.ts b/front/mailing/20240308-weekly-update.ts index 7e3216ba65dd..f2438c3c498b 100644 --- a/front/mailing/20240308-weekly-update.ts +++ b/front/mailing/20240308-weekly-update.ts @@ -113,7 +113,7 @@ JOIN "memberships" "m" ON "u"."id" = "m"."userId" JOIN "workspaces" "w" ON "m"."workspaceId" = "w"."id" JOIN "subscriptions" "s" ON "w"."id" = "s"."workspaceId" WHERE "s"."status" = 'active' -AND "m"."role" != 'revoked'; +AND ("m"."startAt" <= NOW()) AND ("m"."endAt" IS NULL OR "m"."endAt" >= NOW()); ` ); diff --git a/front/migrations/20230413_workspaces_memberships.ts b/front/migrations/20230413_workspaces_memberships.ts index 2a3baf55ac27..5d13da26dff4 100644 --- a/front/migrations/20230413_workspaces_memberships.ts +++ b/front/migrations/20230413_workspaces_memberships.ts @@ -1,4 +1,5 @@ -import { Membership, User, Workspace } from "@app/lib/models"; +import { User, Workspace } from "@app/lib/models"; +import { MembershipModel } from "@app/lib/resources/storage/models/membership"; import { new_id } from "@app/lib/utils"; async function main() { @@ -15,7 +16,7 @@ async function main() { await Promise.all( chunk.map((u) => { return (async () => { - const m = await Membership.findOne({ + const m = await MembershipModel.findOne({ where: { userId: u.id, }, @@ -31,7 +32,7 @@ async function main() { type: "personal", }); - await Membership.create({ + await MembershipModel.create({ role: "admin", userId: u.id, workspaceId: w.id, diff --git a/front/migrations/20231204_author_backfill.ts b/front/migrations/20231204_author_backfill.ts index 032c93d04fe0..bdfe5c934ba6 100644 --- a/front/migrations/20231204_author_backfill.ts +++ b/front/migrations/20231204_author_backfill.ts @@ -1,4 +1,5 @@ -import { AgentConfiguration, Membership, User } from "@app/lib/models"; +import { AgentConfiguration, User } from "@app/lib/models"; +import { MembershipModel } from "@app/lib/resources/storage/models/membership"; async function main() { console.log("Starting author backfill"); @@ -34,7 +35,7 @@ async function backfillAuthor(workspaceId: number) { const author = await User.findOne({ include: [ { - model: Membership, + model: MembershipModel, where: { role: "admin", workspaceId, diff --git a/front/migrations/20231219_imageUrl_backfill.ts b/front/migrations/20231219_imageUrl_backfill.ts index 995eb9298f9a..38b5052f4267 100644 --- a/front/migrations/20231219_imageUrl_backfill.ts +++ b/front/migrations/20231219_imageUrl_backfill.ts @@ -1,4 +1,5 @@ -import { Membership, User, UserMessage, Workspace } from "@app/lib/models"; +import { User, UserMessage, Workspace } from "@app/lib/models"; +import { MembershipModel } from "@app/lib/resources/storage/models/membership"; async function main() { console.log("Starting imageUrl backfill"); @@ -36,7 +37,7 @@ async function backfillImageUrl(workspaceId: number) { }, include: [ { - model: Membership, + model: MembershipModel, where: { workspaceId, }, diff --git a/front/migrations/20240329_membership_start_end_date.ts b/front/migrations/20240329_membership_start_end_date.ts index cb6a28230439..96fbb5ceb4d5 100644 --- a/front/migrations/20240329_membership_start_end_date.ts +++ b/front/migrations/20240329_membership_start_end_date.ts @@ -1,11 +1,4 @@ -import { Op } from "sequelize"; - -import { Message } from "@app/lib/models"; -import { ContentFragmentResource } from "@app/lib/resources/content_fragment_resource"; import { frontSequelize } from "@app/lib/resources/storage"; -import { ContentFragmentModel } from "@app/lib/resources/storage/models/content_fragment"; - -const { LIVE } = process.env; async function main() { // For every membership object, we set a startAt equal to the `createdAt` field of the membership object. diff --git a/front/migrations/20240403_memberships_no_more_revoked.ts b/front/migrations/20240403_memberships_no_more_revoked.ts new file mode 100644 index 000000000000..d20b46b8795a --- /dev/null +++ b/front/migrations/20240403_memberships_no_more_revoked.ts @@ -0,0 +1,20 @@ +import { frontSequelize } from "@app/lib/resources/storage"; + +async function main() { + // We no longer rely on the "revoked" role for memberships. + await frontSequelize.query(` + UPDATE memberships + SET "role" = 'member' + WHERE role = 'revoked'; + `); +} + +main() + .then(() => { + console.log("Done"); + process.exit(0); + }) + .catch((err) => { + console.error(err); + process.exit(1); + }); diff --git a/front/pages/api/login.ts b/front/pages/api/login.ts index 718f51830b29..e4e58b31adef 100644 --- a/front/pages/api/login.ts +++ b/front/pages/api/login.ts @@ -14,7 +14,6 @@ import { getPendingMembershipInvitationForToken, markInvitationAsConsumed, } from "@app/lib/iam/invitations"; -import { getActiveMembershipsForUser } from "@app/lib/iam/memberships"; import type { SessionWithUser } from "@app/lib/iam/provider"; import { getUserFromSession } from "@app/lib/iam/session"; import { createOrUpdateUser } from "@app/lib/iam/users"; @@ -23,8 +22,10 @@ import { findWorkspaceWithVerifiedDomain, } from "@app/lib/iam/workspaces"; import type { MembershipInvitation, User } from "@app/lib/models"; -import { Membership, Workspace } from "@app/lib/models"; +import { Workspace } from "@app/lib/models"; import { updateWorkspacePerSeatSubscriptionUsage } from "@app/lib/plans/subscription"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; +import { renderLightWorkspaceType } from "@app/lib/workspace"; import logger from "@app/logger/logger"; import { apiError, withLogging } from "@app/logger/withlogging"; @@ -71,24 +72,22 @@ async function handleMembershipInvite( ); } - const m = await Membership.findOne({ - where: { - userId: user.id, - workspaceId: membershipInvite.workspaceId, - }, + const m = await MembershipResource.getLatestMembershipOfUserInWorkspace({ + userId: user.id, + workspace: renderLightWorkspaceType({ workspace }), }); - if (m?.role === "revoked") { + if (m?.isRevoked()) { return new Err( new AuthFlowError( - "Your access to the workspace has been revoked, please contact the workspace admin to update your role." + "Your access to the workspace has expired, please contact the workspace admin to update your role." ) ); } if (!m) { await createAndLogMembership({ - workspace: workspace, + workspace, userId: user.id, role: membershipInvite.initialRole, }); @@ -102,7 +101,7 @@ async function handleMembershipInvite( function canJoinTargetWorkspace( targetWorkspaceId: string | undefined, workspace: Workspace | undefined, - activeMemberships: Membership[] + activeMemberships: MembershipResource[] ) { // If there is no target workspace id, return true. if (!targetWorkspaceId) { @@ -133,7 +132,9 @@ async function handleEnterpriseSignUpFlow( }> { // Combine queries to optimize database calls. const [activeMemberships, workspace] = await Promise.all([ - getActiveMembershipsForUser(user.id), + MembershipResource.getActiveMemberships({ + userIds: [user.id], + }), Workspace.findOne({ where: { sId: enterpriseConnectionWorkspaceId, @@ -151,12 +152,11 @@ async function handleEnterpriseSignUpFlow( return { flow: "unauthorized", workspace: null }; } - const membership = await Membership.findOne({ - where: { + const membership = + await MembershipResource.getLatestMembershipOfUserInWorkspace({ userId: user.id, - workspaceId: workspace.id, - }, - }); + workspace: renderLightWorkspaceType({ workspace }), + }); // Create membership if it does not exist. if (!membership) { @@ -165,7 +165,7 @@ async function handleEnterpriseSignUpFlow( userId: user.id, role: "user", }); - } else if (membership.role === "revoked") { + } else if (membership.isRevoked()) { return { flow: "unauthorized", workspace: null }; } @@ -188,7 +188,10 @@ async function handleRegularSignupFlow( SSOEnforcedError > > { - const activeMemberships = await getActiveMembershipsForUser(user.id); + const activeMemberships = await MembershipResource.getActiveMemberships({ + userIds: [user.id], + }); + // Return early if the user is already a member of a workspace and is not attempting to join another one. if (activeMemberships.length !== 0 && !targetWorkspaceId) { return new Ok({ @@ -237,14 +240,12 @@ async function handleRegularSignupFlow( return new Ok({ flow: "no-auto-join", workspace: null }); } - const m = await Membership.findOne({ - where: { - userId: user.id, - workspaceId: existingWorkspace.id, - }, + const m = await MembershipResource.getLatestMembershipOfUserInWorkspace({ + userId: user.id, + workspace: renderLightWorkspaceType({ workspace: existingWorkspace }), }); - if (m?.role === "revoked") { + if (m?.isRevoked()) { return new Ok({ flow: "revoked", workspace: null }); } @@ -406,11 +407,10 @@ export async function createAndLogMembership({ workspace: Workspace; role: ActiveRoleType; }) { - const m = await Membership.create({ + const m = await MembershipResource.createMembership({ role: role, userId: userId, - workspaceId: workspace.id, - startAt: new Date(), + workspace: renderLightWorkspaceType({ workspace }), }); trackUserMemberships(m.userId).catch(logger.error); diff --git a/front/pages/api/poke/workspaces/[wId]/revoke.ts b/front/pages/api/poke/workspaces/[wId]/revoke.ts index 3efd0056a388..28d575a05ceb 100644 --- a/front/pages/api/poke/workspaces/[wId]/revoke.ts +++ b/front/pages/api/poke/workspaces/[wId]/revoke.ts @@ -1,9 +1,10 @@ import type { WithAPIErrorReponse } from "@dust-tt/types"; +import { assertNever } from "@dust-tt/types"; import type { NextApiRequest, NextApiResponse } from "next"; import { Authenticator, getSession } from "@app/lib/auth"; -import { Membership } from "@app/lib/models"; import { updateWorkspacePerSeatSubscriptionUsage } from "@app/lib/plans/subscription"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import { apiError, withLogging } from "@app/logger/withlogging"; export type RevokeUserResponseBody = { @@ -43,27 +44,27 @@ async function handler( }, }); } - - const m = await Membership.findOne({ - where: { - userId, - workspaceId: owner.id, - }, + const revokeResult = await MembershipResource.revokeMembership({ + userId, + workspace: owner, }); - - if (!m) { - return apiError(req, res, { - status_code: 404, - api_error: { - type: "workspace_user_not_found", - message: "Could not find the membership.", - }, - }); + if (revokeResult.isErr()) { + switch (revokeResult.error.type) { + case "not_found": + return apiError(req, res, { + status_code: 404, + api_error: { + type: "workspace_user_not_found", + message: "Could not find the membership.", + }, + }); + case "already_revoked": + // Should not happen, but we ignore. + break; + default: + assertNever(revokeResult.error.type); + } } - - await m.update({ - role: "revoked", - }); await updateWorkspacePerSeatSubscriptionUsage({ workspaceId: owner.sId, }); diff --git a/front/pages/api/poke/workspaces/index.ts b/front/pages/api/poke/workspaces/index.ts index 5f5e24de04af..66f79ca4172f 100644 --- a/front/pages/api/poke/workspaces/index.ts +++ b/front/pages/api/poke/workspaces/index.ts @@ -4,14 +4,9 @@ import type { FindOptions, WhereOptions } from "sequelize"; import { Op } from "sequelize"; import { Authenticator, getSession } from "@app/lib/auth"; -import { - Membership, - Plan, - Subscription, - User, - Workspace, -} from "@app/lib/models"; +import { Plan, Subscription, User, Workspace } from "@app/lib/models"; import { FREE_TEST_PLAN_CODE } from "@app/lib/plans/plan_codes"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import { isEmailValid } from "@app/lib/utils"; import { apiError, withLogging } from "@app/logger/withlogging"; @@ -135,13 +130,8 @@ async function handler( }, }); if (users.length) { - const memberships = await Membership.findAll({ - where: { - userId: { - [Op.in]: users.map((u) => u.id), - }, - }, - attributes: ["workspaceId"], + const memberships = await MembershipResource.getLatestMemberships({ + userIds: users.map((u) => u.id), }); if (memberships.length) { conditions.push({ diff --git a/front/pages/api/v1/w/[wId]/members/emails.ts b/front/pages/api/v1/w/[wId]/members/emails.ts index 8e98e7313f9a..733611ee1420 100644 --- a/front/pages/api/v1/w/[wId]/members/emails.ts +++ b/front/pages/api/v1/w/[wId]/members/emails.ts @@ -1,4 +1,4 @@ -import type { RoleType, WithAPIErrorReponse } from "@dust-tt/types"; +import type { WithAPIErrorReponse } from "@dust-tt/types"; import type { NextApiRequest, NextApiResponse } from "next"; import { getMembers } from "@app/lib/api/workspace"; @@ -38,11 +38,9 @@ async function handler( switch (req.method) { case "GET": - const roles: RoleType[] | undefined = activeOnly - ? ["admin", "builder", "user"] - : undefined; - - const allMembers = await getMembers(auth, { roles }); + const allMembers = await getMembers(auth, { + activeOnly: !!activeOnly, + }); return res.status(200).json({ emails: allMembers.map((m) => m.email) }); diff --git a/front/pages/api/w/[wId]/invitations/index.ts b/front/pages/api/w/[wId]/invitations/index.ts index 266956e0ff95..2c9a1a008104 100644 --- a/front/pages/api/w/[wId]/invitations/index.ts +++ b/front/pages/api/w/[wId]/invitations/index.ts @@ -13,11 +13,9 @@ import { updateOrCreateInvitation, } from "@app/lib/api/invitation"; import { getPendingInvitations } from "@app/lib/api/invitation"; -import { - getMembers, - getMembersCountForWorkspace, -} from "@app/lib/api/workspace"; +import { getMembers } from "@app/lib/api/workspace"; import { Authenticator, getSession } from "@app/lib/auth"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import { isEmailValid } from "@app/lib/utils"; import logger from "@app/logger/logger"; import { apiError, withLogging } from "@app/logger/withlogging"; @@ -127,7 +125,10 @@ async function handler( const { maxUsers } = subscription.plan.limits.users; const availableSeats = maxUsers - - (await getMembersCountForWorkspace(owner, { activeOnly: true })); + (await MembershipResource.getMembersCountForWorkspace({ + workspace: owner, + activeOnly: true, + })); if (maxUsers !== -1 && availableSeats < invitationRequests.length) { return apiError(req, res, { status_code: 400, diff --git a/front/pages/api/w/[wId]/members/[userId]/index.ts b/front/pages/api/w/[wId]/members/[userId]/index.ts index 05cbf620d496..b8b86456e91f 100644 --- a/front/pages/api/w/[wId]/members/[userId]/index.ts +++ b/front/pages/api/w/[wId]/members/[userId]/index.ts @@ -1,9 +1,11 @@ import type { UserType, WithAPIErrorReponse } from "@dust-tt/types"; +import { assertNever, isMembershipRoleType } from "@dust-tt/types"; import type { NextApiRequest, NextApiResponse } from "next"; import { Authenticator, getSession } from "@app/lib/auth"; -import { Membership, User } from "@app/lib/models"; +import { User } from "@app/lib/models"; import { updateWorkspacePerSeatSubscriptionUsage } from "@app/lib/plans/subscription"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import { apiError, withLogging } from "@app/logger/withlogging"; export type PostMemberResponseBody = { @@ -59,11 +61,9 @@ async function handler( id: userId, }, }), - Membership.findOne({ - where: { - userId: userId, - workspaceId: owner.id, - }, + MembershipResource.getLatestMembershipOfUserInWorkspace({ + userId, + workspace: owner, }), ]); @@ -79,28 +79,73 @@ async function handler( switch (req.method) { case "POST": - if ( - !req.body || - !req.body.role || - !["admin", "builder", "user", "revoked"].includes(req.body.role) - ) { - return apiError(req, res, { - status_code: 400, - api_error: { - type: "invalid_request_error", - message: "The request body is invalid, expects { role: string }.", - }, + // TODO(@fontanierh): use DELETE for revoking membership + if (req.body.role === "revoked") { + const revokeResult = await MembershipResource.revokeMembership({ + userId, + workspace: owner, }); - } + if (revokeResult.isErr()) { + switch (revokeResult.error.type) { + case "not_found": + return apiError(req, res, { + status_code: 404, + api_error: { + type: "workspace_user_not_found", + message: "Could not find the membership.", + }, + }); + case "already_revoked": + // Should not happen, but we ignore. + break; + default: + assertNever(revokeResult.error.type); + } + } - await membership.update({ - role: req.body.role, - endAt: req.body.role === "revoked" ? new Date() : null, - }); - if (req.body.role === "revoked") { await updateWorkspacePerSeatSubscriptionUsage({ workspaceId: owner.sId, }); + } else { + const role = req.body.role; + if (!isMembershipRoleType(role)) { + return apiError(req, res, { + status_code: 400, + api_error: { + type: "invalid_request_error", + message: + "The request body is invalid, expects { role: 'admin' | 'builder' | 'user' }.", + }, + }); + } + const updateRoleResult = await MembershipResource.updateMembershipRole({ + userId, + workspace: owner, + newRole: role, + // We allow to re-activate a terminated membership when updating the role here. + allowTerminated: true, + }); + if (updateRoleResult.isErr()) { + switch (updateRoleResult.error.type) { + case "not_found": + return apiError(req, res, { + status_code: 404, + api_error: { + type: "workspace_user_not_found", + message: "Could not find the membership.", + }, + }); + case "membership_already_terminated": + // This cannot happen because we allow updating terminated memberships + // by setting `allowTerminated` to true. + throw new Error("Unreachable."); + case "already_on_role": + // Should not happen, but we ignore. + break; + default: + assertNever(updateRoleResult.error.type); + } + } } const w = { ...owner }; diff --git a/front/pages/api/w/[wId]/workspace-analytics.ts b/front/pages/api/w/[wId]/workspace-analytics.ts index 56f346146f93..3e44292451a2 100644 --- a/front/pages/api/w/[wId]/workspace-analytics.ts +++ b/front/pages/api/w/[wId]/workspace-analytics.ts @@ -100,7 +100,7 @@ async function getAnalytics( FROM "memberships" JOIN "workspaces" ON "memberships"."workspaceId" = "workspaces"."id" WHERE "workspaces"."sId" = :wId - AND "memberships"."role" <> 'revoked'; + AND "memberships"."startAt" <= NOW() AND ("memberships"."endAt" IS NULL OR "memberships"."endAt" >= NOW()); `, { replacements: { diff --git a/front/pages/api/w/[wId]/workspace-usage.ts b/front/pages/api/w/[wId]/workspace-usage.ts index 6c809aa9e15f..1e938d22c050 100644 --- a/front/pages/api/w/[wId]/workspace-usage.ts +++ b/front/pages/api/w/[wId]/workspace-usage.ts @@ -86,7 +86,7 @@ async function handler( case "all": return { startDate: new Date("2020-01-01"), - endDate: new Date(), + endDate: endOfMonth(new Date()), }; case "month": const date = new Date(`${query.start}-01`); diff --git a/front/pages/no-workspace.tsx b/front/pages/no-workspace.tsx index 6d603fd50fe3..909443540062 100644 --- a/front/pages/no-workspace.tsx +++ b/front/pages/no-workspace.tsx @@ -13,7 +13,8 @@ import { getUserFromSession, withDefaultUserAuthPaywallWhitelisted, } from "@app/lib/iam/session"; -import { Membership, Workspace, WorkspaceHasDomain } from "@app/lib/models"; +import { Workspace, WorkspaceHasDomain } from "@app/lib/models"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import logger from "@app/logger/logger"; // Fetch workspace details for scenarios where auto-join is disabled. @@ -41,8 +42,10 @@ async function fetchWorkspaceDetails( async function fetchRevokedWorkspace( user: UserTypeWithWorkspaces ): Promise { - const memberships = await Membership.findAll({ - where: { userId: user.id }, + // TODO(@fontanierh): this doesn't look very solid as it will start to behave + // weirdly if a user has multiple revoked memberships. + const memberships = await MembershipResource.getLatestMemberships({ + userIds: [user.id], }); if (!memberships.length) { diff --git a/front/pages/w/[wId]/join.tsx b/front/pages/w/[wId]/join.tsx index e8a69d9ee8dc..9e881b6ca88e 100644 --- a/front/pages/w/[wId]/join.tsx +++ b/front/pages/w/[wId]/join.tsx @@ -101,7 +101,7 @@ export const getServerSideProps = makeGetServerSidePropsRequirementsWrapper({ return { props: { onboardingType: onboardingType, - workspace: workspace, + workspace, signUpCallbackUrl: signUpCallbackUrl, baseUrl: URL, gaTrackingId: GA_TRACKING_ID, diff --git a/front/poke/temporal/activities.ts b/front/poke/temporal/activities.ts index 7075ea2176b6..6bfe072a56f2 100644 --- a/front/poke/temporal/activities.ts +++ b/front/poke/temporal/activities.ts @@ -18,7 +18,6 @@ import { Dataset, DataSource, Key, - Membership, MembershipInvitation, Message, Provider, @@ -41,6 +40,7 @@ import { MessageReaction, } from "@app/lib/models/assistant/conversation"; import { ContentFragmentResource } from "@app/lib/resources/content_fragment_resource"; +import { MembershipResource } from "@app/lib/resources/membership_resource"; import { frontSequelize } from "@app/lib/resources/storage"; import logger from "@app/logger/logger"; @@ -436,19 +436,17 @@ export async function deleteMembersActivity({ transaction: t, }); - const memberships = await Membership.findAll({ - where: { - workspaceId: workspace.id, - }, + const memberships = await MembershipResource.getLatestMemberships({ + workspace, + transaction: t, }); if (memberships.length === 1) { // We also delete the user if it has no other workspace. const membership = memberships[0]; - const membershipsOfUser = await Membership.findAll({ - where: { - userId: membership.userId, - }, + const membershipsOfUser = await MembershipResource.getLatestMemberships({ + userIds: [membership.userId], + transaction: t, }); if (membershipsOfUser.length === 1) { const user = await User.findOne({ @@ -463,15 +461,15 @@ export async function deleteMembersActivity({ }, transaction: t, }); - await membership.destroy({ transaction: t }); + await membership.delete(t); await user.destroy({ transaction: t }); } } - } - - for (const membership of memberships) { - logger.info(`[Workspace delete] Deleting Membership ${membership.id}`); - await membership.destroy({ transaction: t }); + } else { + for (const membership of memberships) { + logger.info(`[Workspace delete] Deleting Membership ${membership.id}`); + await membership.delete(t); + } } }); } diff --git a/types/src/front/memberships.ts b/types/src/front/memberships.ts index 100115dd69f0..2ca0a35e7cc6 100644 --- a/types/src/front/memberships.ts +++ b/types/src/front/memberships.ts @@ -1 +1,7 @@ -export type MembershipRoleType = "admin" | "builder" | "user" | "revoked"; +const MEMBERSHIP_ROLE_TYPES = ["admin", "builder", "user"] as const; +export type MembershipRoleType = (typeof MEMBERSHIP_ROLE_TYPES)[number]; +export function isMembershipRoleType( + value: unknown +): value is MembershipRoleType { + return MEMBERSHIP_ROLE_TYPES.includes(value as MembershipRoleType); +} diff --git a/types/src/index.ts b/types/src/index.ts index fd217abb6283..e16470c126fe 100644 --- a/types/src/index.ts +++ b/types/src/index.ts @@ -59,6 +59,7 @@ export * from "./shared/model_id"; export * from "./shared/nango_errors"; export * from "./shared/rate_limiter"; export * from "./shared/result"; +export * from "./shared/typescipt_utils"; export * from "./shared/user_operation"; export * from "./shared/utils/assert_never"; export * from "./shared/utils/config"; diff --git a/types/src/shared/typescipt_utils.ts b/types/src/shared/typescipt_utils.ts index dc597ae106a8..29e154da7522 100644 --- a/types/src/shared/typescipt_utils.ts +++ b/types/src/shared/typescipt_utils.ts @@ -4,3 +4,11 @@ export type ExtractSpecificKeys = T extends any [P in K]: T[P]; } : never; + +export type RequireAtLeastOne = Pick< + T, + Exclude +> & + { + [K in Keys]-?: Required> & Partial>>; + }[Keys];