diff --git a/CHANGELOG.md b/CHANGELOG.md index 04fde257b450..cca23a1dec77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,31 @@ # Changelog +### [Version 1.35.10](https://github.com/lobehub/lobe-chat/compare/v1.35.9...v1.35.10) + +Released on **2024-12-03** + +#### ♻ Code Refactoring + +- **misc**: Refactor the server db model implement. + +
+ +
+Improvements and Fixes + +#### Code refactoring + +- **misc**: Refactor the server db model implement, closes [#4878](https://github.com/lobehub/lobe-chat/issues/4878) ([3814853](https://github.com/lobehub/lobe-chat/commit/3814853)) + +
+ +
+ +[![](https://img.shields.io/badge/-BACK_TO_TOP-151515?style=flat-square)](#readme-top) + +
+ ### [Version 1.35.9](https://github.com/lobehub/lobe-chat/compare/v1.35.8...v1.35.9) Released on **2024-12-03** diff --git a/changelog/v1.json b/changelog/v1.json index 04d0169d7bc9..6c349e43a780 100644 --- a/changelog/v1.json +++ b/changelog/v1.json @@ -1,4 +1,11 @@ [ + { + "children": { + "improvements": ["Refactor the server db model implement."] + }, + "date": "2024-12-03", + "version": "1.35.10" + }, { "children": {}, "date": "2024-12-03", diff --git a/package.json b/package.json index 0bf2673fca33..f758e3fd829c 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@lobehub/chat", - "version": "1.35.9", + "version": "1.35.10", "description": "Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.", "keywords": [ "framework", diff --git a/src/app/(main)/repos/[id]/@menu/default.tsx b/src/app/(main)/repos/[id]/@menu/default.tsx index 2d25320f15ea..2c96eec9fd7d 100644 --- a/src/app/(main)/repos/[id]/@menu/default.tsx +++ b/src/app/(main)/repos/[id]/@menu/default.tsx @@ -1,6 +1,7 @@ import { notFound } from 'next/navigation'; import { Flexbox } from 'react-layout-kit'; +import { serverDB } from '@/database/server'; import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase'; import Head from './Head'; @@ -14,7 +15,7 @@ type Props = { params: Params }; const MenuPage = async ({ params }: Props) => { const id = params.id; - const item = await KnowledgeBaseModel.findById(params.id); + const item = await KnowledgeBaseModel.findById(serverDB, params.id); if (!item) return notFound(); diff --git a/src/app/(main)/repos/[id]/page.tsx b/src/app/(main)/repos/[id]/page.tsx index 24662ac70071..fce88618a3df 100644 --- a/src/app/(main)/repos/[id]/page.tsx +++ b/src/app/(main)/repos/[id]/page.tsx @@ -1,5 +1,6 @@ import { redirect } from 'next/navigation'; +import { serverDB } from '@/database/server'; import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase'; import FileManager from '@/features/FileManager'; @@ -10,7 +11,7 @@ interface Params { type Props = { params: Params }; export default async ({ params }: Props) => { - const item = await KnowledgeBaseModel.findById(params.id); + const item = await KnowledgeBaseModel.findById(serverDB, params.id); if (!item) return redirect('/repos'); diff --git a/src/database/schemas/topic.ts b/src/database/schemas/topic.ts index 705682fcf99b..3d83636ea0e6 100644 --- a/src/database/schemas/topic.ts +++ b/src/database/schemas/topic.ts @@ -3,6 +3,8 @@ import { boolean, jsonb, pgTable, text, unique } from 'drizzle-orm/pg-core'; import { createInsertSchema } from 'drizzle-zod'; import { idGenerator } from '@/database/utils/idGenerator'; +import { ChatTopicMetadata } from '@/types/topic'; + import { timestamps, timestamptz } from './_helpers'; import { sessions } from './session'; import { users } from './user'; @@ -21,7 +23,7 @@ export const topics = pgTable( .notNull(), clientId: text('client_id'), historySummary: text('history_summary'), - metadata: jsonb('metadata'), + metadata: jsonb('metadata').$type(), ...timestamps, }, (t) => ({ diff --git a/src/database/server/core/dbForTest.ts b/src/database/server/core/dbForTest.ts index c4194f24edeb..e84d2ed11151 100644 --- a/src/database/server/core/dbForTest.ts +++ b/src/database/server/core/dbForTest.ts @@ -11,6 +11,8 @@ import { serverDBEnv } from '@/config/db'; import * as schema from '../../schemas'; +const migrationsFolder = join(__dirname, '../../migrations'); + export const getTestDBInstance = async () => { let connectionString = serverDBEnv.DATABASE_TEST_URL; @@ -23,9 +25,7 @@ export const getTestDBInstance = async () => { const db = nodeDrizzle(client, { schema }); - await nodeMigrator.migrate(db, { - migrationsFolder: join(__dirname, '../../migrations'), - }); + await nodeMigrator.migrate(db, { migrationsFolder }); return db; } @@ -37,9 +37,7 @@ export const getTestDBInstance = async () => { const db = neonDrizzle(client, { schema }); - await migrator.migrate(db, { - migrationsFolder: join(__dirname, '../migrations'), - }); + await migrator.migrate(db, { migrationsFolder }); return db; }; diff --git a/src/database/server/models/__tests__/_test_template.ts b/src/database/server/models/__tests__/_test_template.ts index ee43feaa9907..0a96afd0a72b 100644 --- a/src/database/server/models/__tests__/_test_template.ts +++ b/src/database/server/models/__tests__/_test_template.ts @@ -1,6 +1,6 @@ // @vitest-environment node import { eq } from 'drizzle-orm'; -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; @@ -9,14 +9,8 @@ import { SessionGroupModel } from '../sessionGroup'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'session-group-model-test-user-id'; -const sessionGroupModel = new SessionGroupModel(userId); +const sessionGroupModel = new SessionGroupModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -74,7 +68,7 @@ describe('SessionGroupModel', () => { await sessionGroupModel.create({ name: 'Test Group 1' }); await sessionGroupModel.create({ name: 'Test Group 333' }); - const anotherSessionGroupModel = new SessionGroupModel('user2'); + const anotherSessionGroupModel = new SessionGroupModel(serverDB, 'user2'); await anotherSessionGroupModel.create({ name: 'Test Group 2' }); await sessionGroupModel.deleteAll(); diff --git a/src/database/server/models/__tests__/agent.test.ts b/src/database/server/models/__tests__/agent.test.ts index 3312e8087ef9..c7e5b46c9b54 100644 --- a/src/database/server/models/__tests__/agent.test.ts +++ b/src/database/server/models/__tests__/agent.test.ts @@ -1,6 +1,6 @@ // @vitest-environment node import { eq } from 'drizzle-orm'; -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; @@ -18,14 +18,8 @@ import { AgentModel } from '../agent'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'agent-model-test-user-id'; -const agentModel = new AgentModel(userId); +const agentModel = new AgentModel(serverDB, userId); const knowledgeBase = { id: 'kb1', userId, name: 'knowledgeBase' }; const fileList = [ diff --git a/src/database/server/models/__tests__/asyncTask.test.ts b/src/database/server/models/__tests__/asyncTask.test.ts index 8c51f5dc3c49..5efafaf3214a 100644 --- a/src/database/server/models/__tests__/asyncTask.test.ts +++ b/src/database/server/models/__tests__/asyncTask.test.ts @@ -10,14 +10,8 @@ import { ASYNC_TASK_TIMEOUT, AsyncTaskModel } from '../asyncTask'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'async-task-model-test-user-id'; -const asyncTaskModel = new AsyncTaskModel(userId); +const asyncTaskModel = new AsyncTaskModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); diff --git a/src/database/server/models/__tests__/chunk.test.ts b/src/database/server/models/__tests__/chunk.test.ts index 3cf20b1bae89..e027ef362346 100644 --- a/src/database/server/models/__tests__/chunk.test.ts +++ b/src/database/server/models/__tests__/chunk.test.ts @@ -1,30 +1,18 @@ // @vitest-environment node import { eq } from 'drizzle-orm'; -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; +import { uuid } from '@/utils/uuid'; -import { - chunks, - embeddings, - fileChunks, - files, - unstructuredChunks, - users, -} from '../../../schemas'; +import { chunks, embeddings, fileChunks, files, unstructuredChunks, users } from '../../../schemas'; import { ChunkModel } from '../chunk'; import { codeEmbedding, designThinkingQuery, designThinkingQuery2 } from './fixtures/embedding'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'chunk-model-test-user-id'; -const chunkModel = new ChunkModel(userId); +const chunkModel = new ChunkModel(serverDB, userId); const sharedFileList = [ { id: '1', @@ -79,6 +67,27 @@ describe('ChunkModel', () => { expect(createdChunks[0]).toMatchObject(params[0]); expect(createdChunks[1]).toMatchObject(params[1]); }); + + // 测试空参数场景 + it('should handle empty params array', async () => { + const result = await chunkModel.bulkCreate([], '1'); + expect(result).toHaveLength(0); + }); + + // 测试事务回滚 + it('should rollback transaction on error', async () => { + const invalidParams = [ + { text: 'Chunk 1', userId }, + { index: 'abc', userId }, // 这会导致错误 + ] as any; + + await expect(chunkModel.bulkCreate(invalidParams, '1')).rejects.toThrow(); + + const createdChunks = await serverDB.query.chunks.findMany({ + where: eq(chunks.userId, userId), + }); + expect(createdChunks).toHaveLength(0); + }); }); describe('delete', () => { @@ -191,6 +200,41 @@ describe('ChunkModel', () => { expect(result[1].id).toBe(chunk2.id); expect(result[0].similarity).toBeGreaterThan(result[1].similarity); }); + // 补充无文件 ID 的搜索场景 + it('should perform semantic search without fileIds', async () => { + const [chunk1, chunk2] = await serverDB + .insert(chunks) + .values([ + { text: 'Test Chunk 1', userId }, + { text: 'Test Chunk 2', userId }, + ]) + .returning(); + + await serverDB.insert(embeddings).values([ + { chunkId: chunk1.id, embeddings: designThinkingQuery, userId }, + { chunkId: chunk2.id, embeddings: codeEmbedding, userId }, + ]); + + const result = await chunkModel.semanticSearch({ + embedding: designThinkingQuery2, + fileIds: undefined, + query: 'design thinking', + }); + + expect(result).toBeDefined(); + expect(result).toHaveLength(2); + }); + + // 测试空结果场景 + it('should return empty array when no matches found', async () => { + const result = await chunkModel.semanticSearch({ + embedding: designThinkingQuery, + fileIds: ['non-existent-file'], + query: 'no matches', + }); + + expect(result).toHaveLength(0); + }); }); describe('bulkCreateUnstructuredChunks', () => { @@ -391,5 +435,100 @@ content in Table html is below: ...
`); }); + + it('should handle null text', () => { + const chunk = { + text: null, + type: 'Text', + metadata: {}, + }; + + const result = chunkModel['mapChunkText'](chunk); + expect(result).toBeNull(); + }); + + it('should handle missing metadata for Table type', () => { + const chunk = { + text: 'Table text', + type: 'Table', + metadata: {}, + }; + + const result = chunkModel['mapChunkText'](chunk); + expect(result).toContain('Table text'); + expect(result).toContain('content in Table html is below:'); + expect(result).toContain('undefined'); // metadata.text_as_html is undefined + }); + }); + + describe('findById', () => { + it('should find a chunk by id', async () => { + // Create a test chunk + const [chunk] = await serverDB + .insert(chunks) + .values({ text: 'Test Chunk', userId }) + .returning(); + + const result = await chunkModel.findById(chunk.id); + + expect(result).toBeDefined(); + expect(result?.id).toBe(chunk.id); + expect(result?.text).toBe('Test Chunk'); + }); + + it('should return null for non-existent id', async () => { + const result = await chunkModel.findById(uuid()); + expect(result).toBeUndefined(); + }); + }); + + describe('semanticSearchForChat', () => { + // 测试空文件 ID 列表场景 + it('should return empty array when fileIds is empty', async () => { + const result = await chunkModel.semanticSearchForChat({ + embedding: designThinkingQuery, + fileIds: [], + query: 'test', + }); + + expect(result).toHaveLength(0); + }); + + // 测试结果限制 + it('should limit results to 5 items', async () => { + const fileId = '1'; + // Create 6 chunks + const chunkResult = await serverDB + .insert(chunks) + .values( + Array(6) + .fill(0) + .map((_, i) => ({ text: `Test Chunk ${i}`, userId })), + ) + .returning(); + + await serverDB.insert(fileChunks).values( + chunkResult.map((chunk) => ({ + fileId, + chunkId: chunk.id, + })), + ); + + await serverDB.insert(embeddings).values( + chunkResult.map((chunk) => ({ + chunkId: chunk.id, + embeddings: designThinkingQuery, + userId, + })), + ); + + const result = await chunkModel.semanticSearchForChat({ + embedding: designThinkingQuery2, + fileIds: [fileId], + query: 'test', + }); + + expect(result).toHaveLength(5); + }); }); }); diff --git a/src/database/server/models/__tests__/file.test.ts b/src/database/server/models/__tests__/file.test.ts index 4fcea38cd482..9d94b433e28c 100644 --- a/src/database/server/models/__tests__/file.test.ts +++ b/src/database/server/models/__tests__/file.test.ts @@ -2,27 +2,14 @@ import { eq, inArray } from 'drizzle-orm'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; -import { getServerDBConfig, serverDBEnv } from '@/config/db'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; import { FilesTabs, SortType } from '@/types/files'; -import { - files, - globalFiles, - knowledgeBaseFiles, - knowledgeBases, - users, -} from '../../../schemas'; +import { files, globalFiles, knowledgeBaseFiles, knowledgeBases, users } from '../../../schemas'; import { FileModel } from '../file'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - let DISABLE_REMOVE_GLOBAL_FILE = false; vi.mock('@/config/db', async () => ({ @@ -38,7 +25,7 @@ vi.mock('@/config/db', async () => ({ })); const userId = 'file-model-test-user-id'; -const fileModel = new FileModel(userId); +const fileModel = new FileModel(serverDB, userId); const knowledgeBase = { id: 'kb1', userId, name: 'knowledgeBase' }; beforeEach(async () => { @@ -603,4 +590,125 @@ describe('FileModel', () => { expect(size).toBe(3500); }); }); + + describe('findByNames', () => { + it('should find files by names', async () => { + // 准备测试数据 + const fileList = [ + { + name: 'test1.txt', + url: 'https://example.com/test1.txt', + size: 100, + fileType: 'text/plain', + userId, + }, + { + name: 'test2.txt', + url: 'https://example.com/test2.txt', + size: 200, + fileType: 'text/plain', + userId, + }, + { + name: 'other.txt', + url: 'https://example.com/other.txt', + size: 300, + fileType: 'text/plain', + userId, + }, + ]; + + await serverDB.insert(files).values(fileList); + + // 测试查找文件 + const result = await fileModel.findByNames(['test1', 'test2']); + expect(result).toHaveLength(2); + expect(result.map((f) => f.name)).toContain('test1.txt'); + expect(result.map((f) => f.name)).toContain('test2.txt'); + }); + + it('should return empty array when no files match names', async () => { + const result = await fileModel.findByNames(['nonexistent']); + expect(result).toHaveLength(0); + }); + + it('should only find files belonging to current user', async () => { + // 准备测试数据 + await serverDB.insert(files).values([ + { + name: 'test1.txt', + url: 'https://example.com/test1.txt', + size: 100, + fileType: 'text/plain', + userId, + }, + { + name: 'test2.txt', + url: 'https://example.com/test2.txt', + size: 200, + fileType: 'text/plain', + userId: 'user2', // 不同用户的文件 + }, + ]); + + const result = await fileModel.findByNames(['test']); + expect(result).toHaveLength(1); + expect(result[0].name).toBe('test1.txt'); + }); + }); + + describe('deleteGlobalFile', () => { + it('should delete global file by hashId', async () => { + // 准备测试数据 + const globalFile = { + hashId: 'test-hash', + fileType: 'text/plain', + size: 100, + url: 'https://example.com/global-file.txt', + metadata: { key: 'value' }, + }; + + await serverDB.insert(globalFiles).values(globalFile); + + // 执行删除操作 + await fileModel.deleteGlobalFile('test-hash'); + + // 验证文件已被删除 + const result = await serverDB.query.globalFiles.findFirst({ + where: eq(globalFiles.hashId, 'test-hash'), + }); + expect(result).toBeUndefined(); + }); + + it('should not throw error when deleting non-existent global file', async () => { + // 删除不存在的文件不应抛出错误 + await expect(fileModel.deleteGlobalFile('non-existent-hash')).resolves.not.toThrow(); + }); + + it('should only delete specified global file', async () => { + // 准备测试数据 + const globalFiles1 = { + hashId: 'hash1', + fileType: 'text/plain', + size: 100, + url: 'https://example.com/file1.txt', + }; + const globalFiles2 = { + hashId: 'hash2', + fileType: 'text/plain', + size: 200, + url: 'https://example.com/file2.txt', + }; + + await serverDB.insert(globalFiles).values([globalFiles1, globalFiles2]); + + // 删除一个文件 + await fileModel.deleteGlobalFile('hash1'); + + // 验证只有指定文件被删除 + const remainingFiles = await serverDB.query.globalFiles.findMany(); + expect(remainingFiles).toHaveLength(1); + expect(remainingFiles[0].hashId).toBe('hash2'); + }); + }); }); diff --git a/src/database/server/models/__tests__/knowledgeBase.test.ts b/src/database/server/models/__tests__/knowledgeBase.test.ts index 38bea2345a62..008ac19c58a2 100644 --- a/src/database/server/models/__tests__/knowledgeBase.test.ts +++ b/src/database/server/models/__tests__/knowledgeBase.test.ts @@ -1,7 +1,7 @@ // @vitest-environment node import { eq } from 'drizzle-orm'; import { and, desc } from 'drizzle-orm/expressions'; -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; @@ -17,14 +17,8 @@ import { KnowledgeBaseModel } from '../knowledgeBase'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'session-group-model-test-user-id'; -const knowledgeBaseModel = new KnowledgeBaseModel(userId); +const knowledgeBaseModel = new KnowledgeBaseModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -82,7 +76,7 @@ describe('KnowledgeBaseModel', () => { await knowledgeBaseModel.create({ name: 'Test Group 1' }); await knowledgeBaseModel.create({ name: 'Test Group 333' }); - const anotherSessionGroupModel = new KnowledgeBaseModel('user2'); + const anotherSessionGroupModel = new KnowledgeBaseModel(serverDB, 'user2'); await anotherSessionGroupModel.create({ name: 'Test Group 2' }); await knowledgeBaseModel.deleteAll(); @@ -235,7 +229,7 @@ describe('KnowledgeBaseModel', () => { it('should find a knowledge base by id without user restriction', async () => { const { id } = await knowledgeBaseModel.create({ name: 'Test Group' }); - const group = await KnowledgeBaseModel.findById(id); + const group = await KnowledgeBaseModel.findById(serverDB, id); expect(group).toMatchObject({ id, name: 'Test Group', @@ -244,10 +238,10 @@ describe('KnowledgeBaseModel', () => { }); it('should find a knowledge base created by another user', async () => { - const anotherKnowledgeBaseModel = new KnowledgeBaseModel('user2'); + const anotherKnowledgeBaseModel = new KnowledgeBaseModel(serverDB, 'user2'); const { id } = await anotherKnowledgeBaseModel.create({ name: 'Another User Group' }); - const group = await KnowledgeBaseModel.findById(id); + const group = await KnowledgeBaseModel.findById(serverDB, id); expect(group).toMatchObject({ id, name: 'Another User Group', diff --git a/src/database/server/models/__tests__/message.test.ts b/src/database/server/models/__tests__/message.test.ts index 262082d3320d..653a41d9b950 100644 --- a/src/database/server/models/__tests__/message.test.ts +++ b/src/database/server/models/__tests__/message.test.ts @@ -2,10 +2,16 @@ import { eq } from 'drizzle-orm'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; +import { uuid } from '@/utils/uuid'; import { + chunks, + embeddings, + fileChunks, files, messagePlugins, + messageQueries, + messageQueryChunks, messageTTS, messageTranslates, messages, @@ -15,17 +21,12 @@ import { users, } from '../../../schemas'; import { MessageModel } from '../message'; +import { codeEmbedding } from './fixtures/embedding'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'message-db'; -const messageModel = new MessageModel(userId); +const messageModel = new MessageModel(serverDB, userId); beforeEach(async () => { // 在每个测试用例之前,清空表 @@ -273,6 +274,93 @@ describe('MessageModel', () => { const result3 = await messageModel.query({ current: 2, pageSize: 2 }); expect(result3).toHaveLength(0); }); + + // 补充测试复杂查询场景 + it('should handle complex query with multiple joins and file chunks', async () => { + await serverDB.transaction(async (trx) => { + const chunk1Id = uuid(); + const query1Id = uuid(); + // 创建基础消息 + await trx.insert(messages).values({ + id: 'msg1', + userId, + role: 'user', + content: 'test message', + createdAt: new Date('2023-01-01'), + }); + + // 创建文件 + await trx.insert(files).values([ + { + id: 'file1', + userId, + name: 'test.txt', + url: 'test-url', + fileType: 'text/plain', + size: 100, + }, + ]); + + // 创建文件块 + await trx.insert(chunks).values({ + id: chunk1Id, + text: 'chunk content', + }); + + // 关联消息和文件 + await trx.insert(messagesFiles).values({ + messageId: 'msg1', + fileId: 'file1', + }); + + // 创建文件块关联 + await trx.insert(fileChunks).values({ + fileId: 'file1', + chunkId: chunk1Id, + }); + + // 创建消息查询 + await trx.insert(messageQueries).values({ + id: query1Id, + messageId: 'msg1', + userQuery: 'original query', + rewriteQuery: 'rewritten query', + }); + + // 创建消息查询块关联 + await trx.insert(messageQueryChunks).values({ + messageId: 'msg1', + queryId: query1Id, + chunkId: chunk1Id, + similarity: '0.95', + }); + }); + + const result = await messageModel.query(); + + expect(result).toHaveLength(1); + expect(result[0].chunksList).toHaveLength(1); + expect(result[0].chunksList[0]).toMatchObject({ + text: 'chunk content', + similarity: 0.95, + }); + }); + + it('should return empty arrays for files and chunks if none exist', async () => { + await serverDB.insert(messages).values({ + id: 'msg1', + userId, + role: 'user', + content: 'test message', + }); + + const result = await messageModel.query(); + + expect(result).toHaveLength(1); + expect(result[0].fileList).toEqual([]); + expect(result[0].imageList).toEqual([]); + expect(result[0].chunksList).toEqual([]); + }); }); describe('queryAll', () => { @@ -1021,4 +1109,139 @@ describe('MessageModel', () => { expect(result).toBe(2); }); }); + + describe('findMessageQueriesById', () => { + it('should return undefined for non-existent message query', async () => { + const result = await messageModel.findMessageQueriesById('non-existent-id'); + expect(result).toBeUndefined(); + }); + + it('should return message query with embeddings', async () => { + const query1Id = uuid(); + const embeddings1Id = uuid(); + + await serverDB.transaction(async (trx) => { + await trx.insert(messages).values({ id: 'msg1', userId, role: 'user', content: 'abc' }); + + await trx.insert(embeddings).values({ + id: embeddings1Id, + embeddings: codeEmbedding, + }); + + await trx.insert(messageQueries).values({ + id: query1Id, + messageId: 'msg1', + userQuery: 'test query', + rewriteQuery: 'rewritten query', + embeddingsId: embeddings1Id, + }); + }); + + const result = await messageModel.findMessageQueriesById('msg1'); + + expect(result).toBeDefined(); + expect(result).toMatchObject({ + id: query1Id, + userQuery: 'test query', + rewriteQuery: 'rewritten query', + embeddings: codeEmbedding, + }); + }); + }); + + describe('deleteMessagesBySession', () => { + it('should delete messages by session ID', async () => { + await serverDB.insert(sessions).values([ + { id: 'session1', userId }, + { id: 'session2', userId }, + ]); + + await serverDB.insert(messages).values([ + { + id: '1', + userId, + sessionId: 'session1', + role: 'user', + content: 'message 1', + }, + { + id: '2', + userId, + sessionId: 'session1', + role: 'assistant', + content: 'message 2', + }, + { + id: '3', + userId, + sessionId: 'session2', + role: 'user', + content: 'message 3', + }, + ]); + + await messageModel.deleteMessagesBySession('session1'); + + const remainingMessages = await serverDB + .select() + .from(messages) + .where(eq(messages.userId, userId)); + + expect(remainingMessages).toHaveLength(1); + expect(remainingMessages[0].id).toBe('3'); + }); + + it('should delete messages by session ID and topic ID', async () => { + await serverDB.insert(sessions).values([{ id: 'session1', userId }]); + await serverDB.insert(topics).values([ + { id: 'topic1', sessionId: 'session1', userId }, + { id: 'topic2', sessionId: 'session1', userId }, + ]); + + await serverDB.insert(messages).values([ + { + id: '1', + userId, + sessionId: 'session1', + topicId: 'topic1', + role: 'user', + content: 'message 1', + }, + { + id: '2', + userId, + sessionId: 'session1', + topicId: 'topic2', + role: 'assistant', + content: 'message 2', + }, + ]); + + await messageModel.deleteMessagesBySession('session1', 'topic1'); + + const remainingMessages = await serverDB + .select() + .from(messages) + .where(eq(messages.userId, userId)); + + expect(remainingMessages).toHaveLength(1); + expect(remainingMessages[0].id).toBe('2'); + }); + }); + + describe('genId', () => { + it('should generate unique message IDs', () => { + const model = new MessageModel(serverDB, userId); + // @ts-ignore - accessing private method for testing + const id1 = model.genId(); + // @ts-ignore - accessing private method for testing + const id2 = model.genId(); + + expect(id1).toHaveLength(18); + expect(id2).toHaveLength(18); + expect(id1).not.toBe(id2); + expect(id1).toMatch(/^msg_/); + expect(id2).toMatch(/^msg_/); + }); + }); }); diff --git a/src/database/server/models/__tests__/nextauth.test.ts b/src/database/server/models/__tests__/nextauth.test.ts index 6038f11f1c73..82699ed0542a 100644 --- a/src/database/server/models/__tests__/nextauth.test.ts +++ b/src/database/server/models/__tests__/nextauth.test.ts @@ -8,7 +8,6 @@ import type { import { eq } from 'drizzle-orm'; import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest'; -import { getTestDBInstance } from '@/database/server/core/dbForTest'; import { nextauthAccounts, nextauthAuthenticators, @@ -16,17 +15,12 @@ import { nextauthVerificationTokens, users, } from '@/database/schemas'; +import { getTestDBInstance } from '@/database/server/core/dbForTest'; import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter'; let serverDB = await getTestDBInstance(); let nextAuthAdapter = LobeNextAuthDbAdapter(serverDB); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'user-db'; const user: AdapterUser = { id: userId, diff --git a/src/database/server/models/__tests__/plugin.test.ts b/src/database/server/models/__tests__/plugin.test.ts index 449cac8cd73b..8ea58e06dc90 100644 --- a/src/database/server/models/__tests__/plugin.test.ts +++ b/src/database/server/models/__tests__/plugin.test.ts @@ -8,14 +8,8 @@ import { PluginModel } from '../plugin'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'plugin-db'; -const pluginModel = new PluginModel(userId); +const pluginModel = new PluginModel(serverDB, userId); beforeEach(async () => { await serverDB.transaction(async (trx) => { diff --git a/src/database/server/models/__tests__/session.test.ts b/src/database/server/models/__tests__/session.test.ts index bccea65f9277..923504d25ab3 100644 --- a/src/database/server/models/__tests__/session.test.ts +++ b/src/database/server/models/__tests__/session.test.ts @@ -1,7 +1,9 @@ -import { eq, inArray } from 'drizzle-orm'; +import { and, eq, inArray } from 'drizzle-orm'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { DEFAULT_AGENT_CONFIG } from '@/const/settings'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; +import { idGenerator } from '@/database/utils/idGenerator'; import { NewSession, @@ -14,19 +16,12 @@ import { topics, users, } from '../../../schemas'; -import { idGenerator } from '@/database/utils/idGenerator'; import { SessionModel } from '../session'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'session-user'; -const sessionModel = new SessionModel(userId); +const sessionModel = new SessionModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -259,7 +254,13 @@ describe('SessionModel', () => { ]); await serverDB.insert(agents).values([ - { id: 'agent-1', userId, model: 'gpt-3.5-turbo', title: 'Agent 1', description: 'Description with Keyword' }, + { + id: 'agent-1', + userId, + model: 'gpt-3.5-turbo', + title: 'Agent 1', + description: 'Description with Keyword', + }, { id: 'agent-2', userId, model: 'gpt-4', title: 'Agent 2' }, ]); @@ -338,7 +339,7 @@ describe('SessionModel', () => { }); }); - describe.skip('batchCreate', () => { + describe('batchCreate', () => { it('should batch create sessions', async () => { // 调用 batchCreate 方法 const sessions: NewSession[] = [ @@ -624,4 +625,161 @@ describe('SessionModel', () => { expect(await serverDB.select().from(sessions).where(eq(sessions.id, '3'))).toHaveLength(1); }); }); + + // 在原有的 describe('SessionModel') 中添加以下测试套件 + + describe('createInbox', () => { + it('should create inbox session if not exists', async () => { + const inbox = await sessionModel.createInbox(); + + expect(inbox).toBeDefined(); + expect(inbox?.slug).toBe('inbox'); + + // verify agent config + const session = await sessionModel.findByIdOrSlug('inbox'); + expect(session?.agent).toBeDefined(); + expect(session?.agent.model).toBe(DEFAULT_AGENT_CONFIG.model); + }); + + it('should not create duplicate inbox session', async () => { + // Create first inbox + await sessionModel.createInbox(); + + // Try to create another inbox + const duplicateInbox = await sessionModel.createInbox(); + + // Should return undefined as inbox already exists + expect(duplicateInbox).toBeUndefined(); + + // Verify only one inbox exists + const sessions = await serverDB.query.sessions.findMany(); + + const inboxSessions = sessions.filter((s) => s.slug === 'inbox'); + expect(inboxSessions).toHaveLength(1); + }); + }); + + describe('deleteAll', () => { + it('should delete all sessions for current user', async () => { + // Create test data + await serverDB.insert(sessions).values([ + { id: '1', userId }, + { id: '2', userId }, + { id: '3', userId }, + ]); + + // Create sessions for another user that should not be deleted + await serverDB.insert(users).values([{ id: 'other-user' }]); + await serverDB.insert(sessions).values([ + { id: '4', userId: 'other-user' }, + { id: '5', userId: 'other-user' }, + ]); + + await sessionModel.deleteAll(); + + // Verify all sessions for current user are deleted + const remainingSessions = await serverDB + .select() + .from(sessions) + .where(eq(sessions.userId, userId)); + expect(remainingSessions).toHaveLength(0); + + // Verify other user's sessions are not deleted + const otherUserSessions = await serverDB + .select() + .from(sessions) + .where(eq(sessions.userId, 'other-user')); + expect(otherUserSessions).toHaveLength(2); + }); + + it('should delete associated data when deleting all sessions', async () => { + // Create test data with associated records + await serverDB.transaction(async (trx) => { + await trx.insert(sessions).values([ + { id: '1', userId }, + { id: '2', userId }, + ]); + + await trx.insert(topics).values([ + { id: 't1', sessionId: '1', userId }, + { id: 't2', sessionId: '2', userId }, + ]); + + await trx.insert(messages).values([ + { id: 'm1', sessionId: '1', userId, role: 'user' }, + { id: 'm2', sessionId: '2', userId, role: 'assistant' }, + ]); + }); + + await sessionModel.deleteAll(); + + // Verify all associated data is deleted + const remainingTopics = await serverDB.select().from(topics).where(eq(topics.userId, userId)); + expect(remainingTopics).toHaveLength(0); + + const remainingMessages = await serverDB + .select() + .from(messages) + .where(eq(messages.userId, userId)); + expect(remainingMessages).toHaveLength(0); + }); + }); + + describe('updateConfig', () => { + it('should update agent config', async () => { + // Create test agent + const agentId = 'test-agent'; + await serverDB.insert(agents).values({ + id: agentId, + userId, + model: 'gpt-3.5-turbo', + title: 'Original Title', + }); + + // Update config + await sessionModel.updateConfig(agentId, { + model: 'gpt-4', + title: 'Updated Title', + description: 'New description', + }); + + // Verify update + const updatedAgent = await serverDB + .select() + .from(agents) + .where(and(eq(agents.id, agentId), eq(agents.userId, userId))); + + expect(updatedAgent[0]).toMatchObject({ + model: 'gpt-4', + title: 'Updated Title', + description: 'New description', + }); + }); + + it('should not update config for other users agents', async () => { + // Create agent for another user + const agentId = 'other-agent'; + await serverDB.insert(users).values([{ id: 'other-user' }]); + await serverDB.insert(agents).values({ + id: agentId, + userId: 'other-user', + model: 'gpt-3.5-turbo', + title: 'Original Title', + }); + + // Try to update other user's agent + await sessionModel.updateConfig(agentId, { + model: 'gpt-4', + title: 'Updated Title', + }); + + // Verify no changes were made + const agent = await serverDB.select().from(agents).where(eq(agents.id, agentId)); + + expect(agent[0]).toMatchObject({ + model: 'gpt-3.5-turbo', + title: 'Original Title', + }); + }); + }); }); diff --git a/src/database/server/models/__tests__/sessionGroup.test.ts b/src/database/server/models/__tests__/sessionGroup.test.ts index a94be4447526..2341eecdc09a 100644 --- a/src/database/server/models/__tests__/sessionGroup.test.ts +++ b/src/database/server/models/__tests__/sessionGroup.test.ts @@ -10,14 +10,8 @@ import { SessionGroupModel } from '../sessionGroup'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'session-group-model-test-user-id'; -const sessionGroupModel = new SessionGroupModel(userId); +const sessionGroupModel = new SessionGroupModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -75,7 +69,7 @@ describe('SessionGroupModel', () => { await sessionGroupModel.create({ name: 'Test Group 1' }); await sessionGroupModel.create({ name: 'Test Group 333' }); - const anotherSessionGroupModel = new SessionGroupModel('user2'); + const anotherSessionGroupModel = new SessionGroupModel(serverDB, 'user2'); await anotherSessionGroupModel.create({ name: 'Test Group 2' }); await sessionGroupModel.deleteAll(); diff --git a/src/database/server/models/__tests__/topic.test.ts b/src/database/server/models/__tests__/topic.test.ts index ef205f117986..358d1ec4bca2 100644 --- a/src/database/server/models/__tests__/topic.test.ts +++ b/src/database/server/models/__tests__/topic.test.ts @@ -8,15 +8,9 @@ import { CreateTopicParams, TopicModel } from '../topic'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'topic-user-test'; const sessionId = 'topic-session'; -const topicModel = new TopicModel(userId); +const topicModel = new TopicModel(serverDB, userId); describe('TopicModel', () => { beforeEach(async () => { diff --git a/src/database/server/models/__tests__/user.test.ts b/src/database/server/models/__tests__/user.test.ts index 1004e5590cc1..9496a9fedb4e 100644 --- a/src/database/server/models/__tests__/user.test.ts +++ b/src/database/server/models/__tests__/user.test.ts @@ -13,15 +13,9 @@ import { UserModel } from '../user'; let serverDB = await getTestDBInstance(); -vi.mock('@/database/server/core/db', async () => ({ - get serverDB() { - return serverDB; - }, -})); - const userId = 'user-db'; const userEmail = 'user@example.com'; -const userModel = new UserModel(); +const userModel = new UserModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -44,14 +38,14 @@ describe('UserModel', () => { email: 'test@example.com', }; - await UserModel.createUser(params); + await UserModel.createUser(serverDB, params); const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) }); expect(user).not.toBeNull(); expect(user?.username).toBe('testuser'); expect(user?.email).toBe('test@example.com'); - const sessionModel = new SessionModel(userId); + const sessionModel = new SessionModel(serverDB, userId); const inbox = await sessionModel.findByIdOrSlug(INBOX_SESSION_ID); expect(inbox).not.toBeNull(); }); @@ -61,7 +55,7 @@ describe('UserModel', () => { it('should delete a user', async () => { await serverDB.insert(users).values({ id: userId }); - await UserModel.deleteUser(userId); + await UserModel.deleteUser(serverDB, userId); const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) }); expect(user).toBeUndefined(); @@ -72,7 +66,7 @@ describe('UserModel', () => { it('should find a user by ID', async () => { await serverDB.insert(users).values({ id: userId, username: 'testuser' }); - const user = await UserModel.findById(userId); + const user = await UserModel.findById(serverDB, userId); expect(user).not.toBeNull(); expect(user?.id).toBe(userId); @@ -84,7 +78,7 @@ describe('UserModel', () => { it('should find a user by email', async () => { await serverDB.insert(users).values({ id: userId, email: userEmail }); - const user = await UserModel.findByEmail(userEmail); + const user = await UserModel.findByEmail(serverDB, userEmail); expect(user).not.toBeNull(); expect(user?.id).toBe(userId); @@ -107,7 +101,7 @@ describe('UserModel', () => { keyVaults: encryptedKeyVaults, }); - const state = await userModel.getUserState(userId); + const state = await userModel.getUserState(); expect(state.userId).toBe(userId); expect(state.preference).toEqual(preference); @@ -115,7 +109,9 @@ describe('UserModel', () => { }); it('should throw an error if user not found', async () => { - await expect(userModel.getUserState('invalid-user-id')).rejects.toThrow('user not found'); + const userModel = new UserModel(serverDB, 'invalid-user-id'); + + await expect(userModel.getUserState()).rejects.toThrow('user not found'); }); }); @@ -123,7 +119,7 @@ describe('UserModel', () => { it('should update user fields', async () => { await serverDB.insert(users).values({ id: userId, username: 'oldname' }); - await userModel.updateUser(userId, { username: 'newname' }); + await userModel.updateUser({ username: 'newname' }); const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId), @@ -137,7 +133,7 @@ describe('UserModel', () => { await serverDB.insert(users).values({ id: userId }); await serverDB.insert(userSettings).values({ id: userId }); - await userModel.deleteSetting(userId); + await userModel.deleteSetting(); const settings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, userId), @@ -155,7 +151,7 @@ describe('UserModel', () => { } as UserSettings; await serverDB.insert(users).values({ id: userId }); - await userModel.updateSetting(userId, settings); + await userModel.updateSetting(settings); const updatedSettings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, userId), @@ -178,7 +174,7 @@ describe('UserModel', () => { const newSettings = { general: { fontSize: 16, language: 'zh-CN', themeMode: 'dark' }, } as UserSettings; - await userModel.updateSetting(userId, newSettings); + await userModel.updateSetting(newSettings); const updatedSettings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, userId), @@ -195,7 +191,7 @@ describe('UserModel', () => { const newPreference: Partial = { guide: { topic: true, moveSettingsToAvatar: true }, }; - await userModel.updatePreference(userId, newPreference); + await userModel.updatePreference(newPreference); const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) }); expect(updatedUser?.preference).toEqual({ ...preference, ...newPreference }); @@ -212,10 +208,49 @@ describe('UserModel', () => { moveSettingsToAvatar: true, uploadFileInKnowledgeBase: true, }; - await userModel.updateGuide(userId, newGuide); + await userModel.updateGuide(newGuide); const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) }); expect(updatedUser?.preference).toEqual({ ...preference, guide: newGuide }); }); }); + + describe('getUserApiKeys', () => { + it('should get and decrypt user API keys', async () => { + const keyVaults = { openai: { apiKey: 'test-key' } }; + const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); + const encryptedKeyVaults = await gateKeeper.encrypt(JSON.stringify(keyVaults)); + + const userId = 'user-api-id'; + + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(userSettings).values({ + id: userId, + keyVaults: encryptedKeyVaults, + }); + + const result = await UserModel.getUserApiKeys(serverDB, userId); + expect(result).toEqual(keyVaults); + }); + + it('should throw error when user not found', async () => { + await expect(UserModel.getUserApiKeys(serverDB, 'non-existent-id')).rejects.toThrow( + 'user not found', + ); + }); + + it('should handle decrypt failure and return empty object', async () => { + const userId = 'user-api-test-id'; + // 模拟解密失败的情况 + const invalidEncryptedData = 'invalid:-encrypted-:data'; + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(userSettings).values({ + id: userId, + keyVaults: invalidEncryptedData, + }); + + const result = await UserModel.getUserApiKeys(serverDB, userId); + expect(result).toEqual({}); + }); + }); }); diff --git a/src/database/server/models/_template.ts b/src/database/server/models/_template.ts index 82c0efa3da2d..a802cbcec71e 100644 --- a/src/database/server/models/_template.ts +++ b/src/database/server/models/_template.ts @@ -1,19 +1,21 @@ import { eq } from 'drizzle-orm'; import { and, desc } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { NewSessionGroup, SessionGroupItem, sessionGroups } from '../../schemas'; export class TemplateModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: NewSessionGroup) => { - const [result] = await serverDB + const [result] = await this.db .insert(sessionGroups) .values({ ...params, userId: this.userId }) .returning(); @@ -22,30 +24,30 @@ export class TemplateModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(sessionGroups) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); + return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); }; query = async () => { - return serverDB.query.sessionGroups.findMany({ + return this.db.query.sessionGroups.findMany({ orderBy: [desc(sessionGroups.updatedAt)], where: eq(sessionGroups.userId, this.userId), }); }; findById = async (id: string) => { - return serverDB.query.sessionGroups.findFirst({ + return this.db.query.sessionGroups.findFirst({ where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)), }); }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(sessionGroups) .set({ ...value, updatedAt: new Date() }) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); diff --git a/src/database/server/models/agent.ts b/src/database/server/models/agent.ts index 34c627016c0b..03a779aabce1 100644 --- a/src/database/server/models/agent.ts +++ b/src/database/server/models/agent.ts @@ -1,7 +1,8 @@ import { inArray } from 'drizzle-orm'; import { and, desc, eq } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; + import { agents, agentsFiles, @@ -13,12 +14,15 @@ import { export class AgentModel { private userId: string; - constructor(userId: string) { + private db: LobeChatDatabase; + + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } async getAgentConfigById(id: string) { - const agent = await serverDB.query.agents.findFirst({ where: eq(agents.id, id) }); + const agent = await this.db.query.agents.findFirst({ where: eq(agents.id, id) }); const knowledge = await this.getAgentAssignedKnowledge(id); @@ -26,14 +30,14 @@ export class AgentModel { } async getAgentAssignedKnowledge(id: string) { - const knowledgeBaseResult = await serverDB + const knowledgeBaseResult = await this.db .select({ enabled: agentsKnowledgeBases.enabled, knowledgeBases }) .from(agentsKnowledgeBases) .where(eq(agentsKnowledgeBases.agentId, id)) .orderBy(desc(agentsKnowledgeBases.createdAt)) .leftJoin(knowledgeBases, eq(knowledgeBases.id, agentsKnowledgeBases.knowledgeBaseId)); - const fileResult = await serverDB + const fileResult = await this.db .select({ enabled: agentsFiles.enabled, files }) .from(agentsFiles) .where(eq(agentsFiles.agentId, id)) @@ -56,7 +60,7 @@ export class AgentModel { * Find agent by session id */ async findBySessionId(sessionId: string) { - const item = await serverDB.query.agentsToSessions.findFirst({ + const item = await this.db.query.agentsToSessions.findFirst({ where: eq(agentsToSessions.sessionId, sessionId), }); if (!item) return; @@ -71,7 +75,7 @@ export class AgentModel { knowledgeBaseId: string, enabled: boolean = true, ) => { - return serverDB + return this.db .insert(agentsKnowledgeBases) .values({ agentId, @@ -83,7 +87,7 @@ export class AgentModel { }; deleteAgentKnowledgeBase = async (agentId: string, knowledgeBaseId: string) => { - return serverDB + return this.db .delete(agentsKnowledgeBases) .where( and( @@ -96,7 +100,7 @@ export class AgentModel { }; toggleKnowledgeBase = async (agentId: string, knowledgeBaseId: string, enabled?: boolean) => { - return serverDB + return this.db .update(agentsKnowledgeBases) .set({ enabled }) .where( @@ -111,7 +115,7 @@ export class AgentModel { createAgentFiles = async (agentId: string, fileIds: string[], enabled: boolean = true) => { // Exclude the fileIds that already exist in agentsFiles, and then insert them - const existingFiles = await serverDB + const existingFiles = await this.db .select({ id: agentsFiles.fileId }) .from(agentsFiles) .where( @@ -128,7 +132,7 @@ export class AgentModel { if (needToInsertFileIds.length === 0) return; - return serverDB + return this.db .insert(agentsFiles) .values( needToInsertFileIds.map((fileId) => ({ agentId, enabled, fileId, userId: this.userId })), @@ -137,7 +141,7 @@ export class AgentModel { }; deleteAgentFile = async (agentId: string, fileId: string) => { - return serverDB + return this.db .delete(agentsFiles) .where( and( @@ -150,7 +154,7 @@ export class AgentModel { }; toggleFile = async (agentId: string, fileId: string, enabled?: boolean) => { - return serverDB + return this.db .update(agentsFiles) .set({ enabled }) .where( diff --git a/src/database/server/models/asyncTask.ts b/src/database/server/models/asyncTask.ts index 48ed6b210909..5eb90fec31d7 100644 --- a/src/database/server/models/asyncTask.ts +++ b/src/database/server/models/asyncTask.ts @@ -1,7 +1,7 @@ import { eq, inArray, lt } from 'drizzle-orm'; import { and } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { AsyncTaskError, AsyncTaskErrorType, @@ -16,13 +16,15 @@ export const ASYNC_TASK_TIMEOUT = 298 * 1000; export class AsyncTaskModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: Pick): Promise => { - const data = await serverDB + const data = await this.db .insert(asyncTasks) .values({ ...params, userId: this.userId }) .returning(); @@ -31,17 +33,17 @@ export class AsyncTaskModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(asyncTasks) .where(and(eq(asyncTasks.id, id), eq(asyncTasks.userId, this.userId))); }; findById = async (id: string) => { - return serverDB.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) }); + return this.db.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) }); }; update(taskId: string, value: Partial) { - return serverDB + return this.db .update(asyncTasks) .set({ ...value, updatedAt: new Date() }) .where(and(eq(asyncTasks.id, taskId))); @@ -52,7 +54,7 @@ export class AsyncTaskModel { if (taskIds.length > 0) { await this.checkTimeoutTasks(taskIds); - chunkTasks = await serverDB.query.asyncTasks.findMany({ + chunkTasks = await this.db.query.asyncTasks.findMany({ where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type)), }); } @@ -64,7 +66,7 @@ export class AsyncTaskModel { * make the task status to be `error` if the task is not finished in 20 seconds */ async checkTimeoutTasks(ids: string[]) { - const tasks = await serverDB + const tasks = await this.db .select({ id: asyncTasks.id }) .from(asyncTasks) .where( @@ -76,7 +78,7 @@ export class AsyncTaskModel { ); if (tasks.length > 0) { - await serverDB + await this.db .update(asyncTasks) .set({ error: new AsyncTaskError( diff --git a/src/database/server/models/chunk.ts b/src/database/server/models/chunk.ts index a00c434b4b2a..6b554fd58435 100644 --- a/src/database/server/models/chunk.ts +++ b/src/database/server/models/chunk.ts @@ -2,7 +2,7 @@ import { asc, cosineDistance, count, eq, inArray, sql } from 'drizzle-orm'; import { and, desc, isNull } from 'drizzle-orm/expressions'; import { chunk } from 'lodash-es'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { ChunkMetadata, FileChunk } from '@/types/chunk'; import { @@ -18,12 +18,17 @@ import { export class ChunkModel { private userId: string; - constructor(userId: string) { + private db: LobeChatDatabase; + + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } bulkCreate = async (params: NewChunkItem[], fileId: string) => { - return serverDB.transaction(async (trx) => { + return this.db.transaction(async (trx) => { + if (params.length === 0) return []; + const result = await trx.insert(chunks).values(params).returning(); const fileChunksData = result.map((chunk) => ({ chunkId: chunk.id, fileId })); @@ -37,15 +42,15 @@ export class ChunkModel { }; bulkCreateUnstructuredChunks = async (params: NewUnstructuredChunkItem[]) => { - return serverDB.insert(unstructuredChunks).values(params); + return this.db.insert(unstructuredChunks).values(params); }; delete = async (id: string) => { - return serverDB.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId))); + return this.db.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId))); }; deleteOrphanChunks = async () => { - const orphanedChunks = await serverDB + const orphanedChunks = await this.db .select({ chunkId: chunks.id }) .from(chunks) .leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) @@ -56,7 +61,7 @@ export class ChunkModel { const list = chunk(ids, 500); - await serverDB.transaction(async (trx) => { + await this.db.transaction(async (trx) => { await Promise.all( list.map(async (chunkIds) => { await trx.delete(chunks).where(inArray(chunks.id, chunkIds)); @@ -66,13 +71,13 @@ export class ChunkModel { }; findById = async (id: string) => { - return serverDB.query.chunks.findFirst({ + return this.db.query.chunks.findFirst({ where: and(eq(chunks.id, id)), }); }; async findByFileId(id: string, page = 0) { - const data = await serverDB + const data = await this.db .select({ abstract: chunks.abstract, createdAt: chunks.createdAt, @@ -98,7 +103,7 @@ export class ChunkModel { } async getChunksTextByFileId(id: string): Promise<{ id: string; text: string }[]> { - const data = await serverDB + const data = await this.db .select() .from(chunks) .innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) @@ -113,7 +118,7 @@ export class ChunkModel { async countByFileIds(ids: string[]) { if (ids.length === 0) return []; - return serverDB + return this.db .select({ count: count(fileChunks.chunkId), id: fileChunks.fileId, @@ -124,7 +129,7 @@ export class ChunkModel { } async countByFileId(ids: string) { - const data = await serverDB + const data = await this.db .select({ count: count(fileChunks.chunkId), id: fileChunks.fileId, @@ -146,7 +151,7 @@ export class ChunkModel { }) { const similarity = sql`1 - (${cosineDistance(embeddings.embeddings, embedding)})`; - const data = await serverDB + const data = await this.db .select({ fileId: fileChunks.fileId, fileName: files.name, @@ -185,7 +190,7 @@ export class ChunkModel { if (!hasFiles) return []; - const result = await serverDB + const result = await this.db .select({ fileId: files.id, fileName: files.name, diff --git a/src/database/server/models/embedding.ts b/src/database/server/models/embedding.ts index 45f5980ebf88..123e5e3a6083 100644 --- a/src/database/server/models/embedding.ts +++ b/src/database/server/models/embedding.ts @@ -1,19 +1,21 @@ import { count, eq } from 'drizzle-orm'; import { and } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { NewEmbeddingsItem, embeddings } from '../../schemas'; export class EmbeddingModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (value: Omit) => { - const [item] = await serverDB + const [item] = await this.db .insert(embeddings) .values({ ...value, userId: this.userId }) .returning(); @@ -22,7 +24,7 @@ export class EmbeddingModel { }; bulkCreate = async (values: Omit[]) => { - return serverDB + return this.db .insert(embeddings) .values(values.map((item) => ({ ...item, userId: this.userId }))) .onConflictDoNothing({ @@ -31,25 +33,25 @@ export class EmbeddingModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(embeddings) .where(and(eq(embeddings.id, id), eq(embeddings.userId, this.userId))); }; query = async () => { - return serverDB.query.embeddings.findMany({ + return this.db.query.embeddings.findMany({ where: eq(embeddings.userId, this.userId), }); }; findById = async (id: string) => { - return serverDB.query.embeddings.findFirst({ + return this.db.query.embeddings.findFirst({ where: and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)), }); }; countUsage = async () => { - const result = await serverDB + const result = await this.db .select({ count: count(), }) diff --git a/src/database/server/models/file.ts b/src/database/server/models/file.ts index 8a01bfe71400..7c1f4ec6e5d2 100644 --- a/src/database/server/models/file.ts +++ b/src/database/server/models/file.ts @@ -3,7 +3,7 @@ import { and, desc, like } from 'drizzle-orm/expressions'; import type { PgTransaction } from 'drizzle-orm/pg-core'; import { serverDBEnv } from '@/config/db'; -import { serverDB } from '@/database/server/core/db'; +import { LobeChatDatabase } from '@/database/type'; import { FilesTabs, QueryFileListParams, SortType } from '@/types/files'; import { @@ -20,13 +20,15 @@ import { export class FileModel { private readonly userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: Omit & { knowledgeBaseId?: string }) => { - const result = await serverDB.transaction(async (trx) => { + const result = await this.db.transaction(async (trx) => { const result = await trx .insert(files) .values({ ...params, userId: this.userId }) @@ -47,11 +49,11 @@ export class FileModel { }; createGlobalFile = async (file: Omit) => { - return serverDB.insert(globalFiles).values(file).returning(); + return this.db.insert(globalFiles).values(file).returning(); }; checkHash = async (hash: string) => { - const item = await serverDB.query.globalFiles.findFirst({ + const item = await this.db.query.globalFiles.findFirst({ where: eq(globalFiles.hashId, hash), }); if (!item) return { isExist: false }; @@ -71,7 +73,7 @@ export class FileModel { const fileHash = file.fileHash!; - return await serverDB.transaction(async (trx) => { + return await this.db.transaction(async (trx) => { // 1. 删除相关的 chunks await this.deleteFileChunks(trx as any, [id]); @@ -96,11 +98,11 @@ export class FileModel { }; deleteGlobalFile = async (hashId: string) => { - return serverDB.delete(globalFiles).where(eq(globalFiles.hashId, hashId)); + return this.db.delete(globalFiles).where(eq(globalFiles.hashId, hashId)); }; countUsage = async () => { - const result = await serverDB + const result = await this.db .select({ totalSize: sum(files.size), }) @@ -114,7 +116,7 @@ export class FileModel { const fileList = await this.findByIds(ids); const hashList = fileList.map((file) => file.fileHash!); - return await serverDB.transaction(async (trx) => { + return await this.db.transaction(async (trx) => { // 1. 删除相关的 chunks await this.deleteFileChunks(trx as any, ids); @@ -159,7 +161,7 @@ export class FileModel { }; clear = async () => { - return serverDB.delete(files).where(eq(files.userId, this.userId)); + return this.db.delete(files).where(eq(files.userId, this.userId)); }; query = async ({ @@ -198,7 +200,7 @@ export class FileModel { } // 3. build query - let query = serverDB + let query = this.db .select({ chunkTaskId: files.chunkTaskId, createdAt: files.createdAt, @@ -230,7 +232,7 @@ export class FileModel { whereClause = and( whereClause, notExists( - serverDB.select().from(knowledgeBaseFiles).where(eq(knowledgeBaseFiles.fileId, files.id)), + this.db.select().from(knowledgeBaseFiles).where(eq(knowledgeBaseFiles.fileId, files.id)), ), ); } @@ -240,19 +242,19 @@ export class FileModel { }; findByIds = async (ids: string[]) => { - return serverDB.query.files.findMany({ + return this.db.query.files.findMany({ where: and(inArray(files.id, ids), eq(files.userId, this.userId)), }); }; findById = async (id: string) => { - return serverDB.query.files.findFirst({ + return this.db.query.files.findFirst({ where: and(eq(files.id, id), eq(files.userId, this.userId)), }); }; countFilesByHash = async (hash: string) => { - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -263,7 +265,7 @@ export class FileModel { }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(files) .set({ ...value, updatedAt: new Date() }) .where(and(eq(files.id, id), eq(files.userId, this.userId))); @@ -293,7 +295,7 @@ export class FileModel { }; async findByNames(fileNames: string[]) { - return serverDB.query.files.findMany({ + return this.db.query.files.findMany({ where: and( or(...fileNames.map((name) => like(files.name, `${name}%`))), eq(files.userId, this.userId), diff --git a/src/database/server/models/knowledgeBase.ts b/src/database/server/models/knowledgeBase.ts index 844558667a5b..6c765ab51603 100644 --- a/src/database/server/models/knowledgeBase.ts +++ b/src/database/server/models/knowledgeBase.ts @@ -1,22 +1,24 @@ import { eq, inArray } from 'drizzle-orm'; import { and, desc } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { KnowledgeBaseItem } from '@/types/knowledgeBase'; import { NewKnowledgeBase, knowledgeBaseFiles, knowledgeBases } from '../../schemas'; export class KnowledgeBaseModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } // create create = async (params: Omit) => { - const [result] = await serverDB + const [result] = await this.db .insert(knowledgeBases) .values({ ...params, userId: this.userId }) .returning(); @@ -25,7 +27,7 @@ export class KnowledgeBaseModel { }; addFilesToKnowledgeBase = async (id: string, fileIds: string[]) => { - return serverDB + return this.db .insert(knowledgeBaseFiles) .values(fileIds.map((fileId) => ({ fileId, knowledgeBaseId: id, userId: this.userId }))) .returning(); @@ -33,17 +35,17 @@ export class KnowledgeBaseModel { // delete delete = async (id: string) => { - return serverDB + return this.db .delete(knowledgeBases) .where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId)); + return this.db.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId)); }; removeFilesFromKnowledgeBase = async (knowledgeBaseId: string, ids: string[]) => { - return serverDB.delete(knowledgeBaseFiles).where( + return this.db.delete(knowledgeBaseFiles).where( and( eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId), inArray(knowledgeBaseFiles.fileId, ids), @@ -53,7 +55,7 @@ export class KnowledgeBaseModel { }; // query query = async () => { - const data = await serverDB + const data = await this.db .select({ avatar: knowledgeBases.avatar, createdAt: knowledgeBases.createdAt, @@ -73,21 +75,21 @@ export class KnowledgeBaseModel { }; findById = async (id: string) => { - return serverDB.query.knowledgeBases.findFirst({ + return this.db.query.knowledgeBases.findFirst({ where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)), }); }; // update async update(id: string, value: Partial) { - return serverDB + return this.db .update(knowledgeBases) .set({ ...value, updatedAt: new Date() }) .where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId))); } - static async findById(id: string) { - return serverDB.query.knowledgeBases.findFirst({ + static async findById(db: LobeChatDatabase, id: string) { + return db.query.knowledgeBases.findFirst({ where: eq(knowledgeBases.id, id), }); } diff --git a/src/database/server/models/message.ts b/src/database/server/models/message.ts index e4e43e3b01a3..000a8f53c064 100644 --- a/src/database/server/models/message.ts +++ b/src/database/server/models/message.ts @@ -1,8 +1,8 @@ import { count } from 'drizzle-orm'; import { and, asc, desc, eq, gte, inArray, isNull, like, lt } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server/core/db'; import { idGenerator } from '@/database/utils/idGenerator'; +import { LobeChatDatabase } from '@/database/type'; import { getFullFileUrl } from '@/server/utils/files'; import { ChatFileItem, @@ -39,9 +39,11 @@ export interface QueryMessageParams { export class MessageModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } // **************** Query *************** // @@ -54,7 +56,7 @@ export class MessageModel { const offset = current * pageSize; // 1. get basic messages - const result = await serverDB + const result = await this.db .select({ /* eslint-disable sort-keys-fix/sort-keys-fix*/ id: messages.id, @@ -115,7 +117,7 @@ export class MessageModel { if (messageIds.length === 0) return []; // 2. get relative files - const rawRelatedFileList = await serverDB + const rawRelatedFileList = await this.db .select({ fileType: files.fileType, id: messagesFiles.fileId, @@ -139,7 +141,7 @@ export class MessageModel { const fileList = relatedFileList.filter((i) => !(i.fileType || '').startsWith('image')); // 3. get relative file chunks - const chunksList = await serverDB + const chunksList = await this.db .select({ fileId: files.id, fileType: files.fileType, @@ -157,7 +159,7 @@ export class MessageModel { .where(inArray(messageQueryChunks.messageId, messageIds)); // 3. get relative message query - const messageQueriesList = await serverDB + const messageQueriesList = await this.db .select({ id: messageQueries.id, messageId: messageQueries.messageId, @@ -216,13 +218,13 @@ export class MessageModel { } async findById(id: string) { - return serverDB.query.messages.findFirst({ + return this.db.query.messages.findFirst({ where: and(eq(messages.id, id), eq(messages.userId, this.userId)), }); } async findMessageQueriesById(messageId: string) { - const result = await serverDB + const result = await this.db .select({ embeddings: embeddings.embeddings, id: messageQueries.id, @@ -240,7 +242,7 @@ export class MessageModel { } async queryAll(): Promise { - return serverDB + return this.db .select() .from(messages) .orderBy(messages.createdAt) @@ -250,7 +252,7 @@ export class MessageModel { } async queryBySessionId(sessionId?: string | null): Promise { - return serverDB.query.messages.findMany({ + return this.db.query.messages.findMany({ orderBy: [asc(messages.createdAt)], where: and(eq(messages.userId, this.userId), this.matchSession(sessionId)), }); @@ -259,14 +261,14 @@ export class MessageModel { async queryByKeyword(keyword: string): Promise { if (!keyword) return []; - return serverDB.query.messages.findMany({ + return this.db.query.messages.findMany({ orderBy: [desc(messages.createdAt)], where: and(eq(messages.userId, this.userId), like(messages.content, `%${keyword}%`)), }); } async count() { - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -283,7 +285,7 @@ export class MessageModel { const tomorrow = new Date(today); tomorrow.setDate(tomorrow.getDate() + 1); - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -315,7 +317,7 @@ export class MessageModel { }: CreateMessageParams, id: string = this.genId(), ): Promise { - return serverDB.transaction(async (trx) => { + return this.db.transaction(async (trx) => { const [item] = (await trx .insert(messages) .values({ @@ -366,72 +368,72 @@ export class MessageModel { return { ...m, userId: this.userId }; }); - return serverDB.insert(messages).values(messagesToInsert); + return this.db.insert(messages).values(messagesToInsert); } async createMessageQuery(params: NewMessageQuery) { - const result = await serverDB.insert(messageQueries).values(params).returning(); + const result = await this.db.insert(messageQueries).values(params).returning(); return result[0]; } // **************** Update *************** // async update(id: string, message: Partial) { - return serverDB + return this.db .update(messages) .set(message) .where(and(eq(messages.id, id), eq(messages.userId, this.userId))); } async updatePluginState(id: string, state: Record) { - const item = await serverDB.query.messagePlugins.findFirst({ + const item = await this.db.query.messagePlugins.findFirst({ where: eq(messagePlugins.id, id), }); if (!item) throw new Error('Plugin not found'); - return serverDB + return this.db .update(messagePlugins) .set({ state: merge(item.state || {}, state) }) .where(eq(messagePlugins.id, id)); } async updateMessagePlugin(id: string, value: Partial) { - const item = await serverDB.query.messagePlugins.findFirst({ + const item = await this.db.query.messagePlugins.findFirst({ where: eq(messagePlugins.id, id), }); if (!item) throw new Error('Plugin not found'); - return serverDB.update(messagePlugins).set(value).where(eq(messagePlugins.id, id)); + return this.db.update(messagePlugins).set(value).where(eq(messagePlugins.id, id)); } async updateTranslate(id: string, translate: Partial) { - const result = await serverDB.query.messageTranslates.findFirst({ + const result = await this.db.query.messageTranslates.findFirst({ where: and(eq(messageTranslates.id, id)), }); // If the message does not exist in the translate table, insert it if (!result) { - return serverDB.insert(messageTranslates).values({ ...translate, id }); + return this.db.insert(messageTranslates).values({ ...translate, id }); } // or just update the existing one - return serverDB.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id)); + return this.db.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id)); } async updateTTS(id: string, tts: Partial) { - const result = await serverDB.query.messageTTS.findFirst({ + const result = await this.db.query.messageTTS.findFirst({ where: and(eq(messageTTS.id, id)), }); // If the message does not exist in the translate table, insert it if (!result) { - return serverDB + return this.db .insert(messageTTS) .values({ contentMd5: tts.contentMd5, fileId: tts.file, id, voice: tts.voice }); } // or just update the existing one - return serverDB + return this.db .update(messageTTS) .set({ contentMd5: tts.contentMd5, fileId: tts.file, voice: tts.voice }) .where(eq(messageTTS.id, id)); @@ -440,7 +442,7 @@ export class MessageModel { // **************** Delete *************** // async deleteMessage(id: string) { - return serverDB.transaction(async (tx) => { + return this.db.transaction(async (tx) => { // 1. 查询要删除的 message 的完整信息 const message = await tx .select() @@ -476,25 +478,25 @@ export class MessageModel { } async deleteMessages(ids: string[]) { - return serverDB + return this.db .delete(messages) .where(and(eq(messages.userId, this.userId), inArray(messages.id, ids))); } async deleteMessageTranslate(id: string) { - return serverDB.delete(messageTranslates).where(and(eq(messageTranslates.id, id))); + return this.db.delete(messageTranslates).where(and(eq(messageTranslates.id, id))); } async deleteMessageTTS(id: string) { - return serverDB.delete(messageTTS).where(and(eq(messageTTS.id, id))); + return this.db.delete(messageTTS).where(and(eq(messageTTS.id, id))); } async deleteMessageQuery(id: string) { - return serverDB.delete(messageQueries).where(and(eq(messageQueries.id, id))); + return this.db.delete(messageQueries).where(and(eq(messageQueries.id, id))); } async deleteMessagesBySession(sessionId?: string | null, topicId?: string | null) { - return serverDB + return this.db .delete(messages) .where( and( @@ -506,7 +508,7 @@ export class MessageModel { } async deleteAllMessages() { - return serverDB.delete(messages).where(eq(messages.userId, this.userId)); + return this.db.delete(messages).where(eq(messages.userId, this.userId)); } // **************** Helper *************** // diff --git a/src/database/server/models/plugin.ts b/src/database/server/models/plugin.ts index 7b07c9a8a873..5b5293082c9b 100644 --- a/src/database/server/models/plugin.ts +++ b/src/database/server/models/plugin.ts @@ -1,20 +1,22 @@ import { and, desc, eq } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { InstalledPluginItem, NewInstalledPlugin, installedPlugins } from '../../schemas'; export class PluginModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async ( params: Pick, ) => { - const [result] = await serverDB + const [result] = await this.db .insert(installedPlugins) .values({ ...params, createdAt: new Date(), updatedAt: new Date(), userId: this.userId }) .returning(); @@ -23,17 +25,17 @@ export class PluginModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(installedPlugins) .where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(installedPlugins).where(eq(installedPlugins.userId, this.userId)); + return this.db.delete(installedPlugins).where(eq(installedPlugins.userId, this.userId)); }; query = async () => { - return serverDB + return this.db .select({ createdAt: installedPlugins.createdAt, customParams: installedPlugins.customParams, @@ -49,13 +51,13 @@ export class PluginModel { }; findById = async (id: string) => { - return serverDB.query.installedPlugins.findFirst({ + return this.db.query.installedPlugins.findFirst({ where: and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId)), }); }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(installedPlugins) .set({ ...value, updatedAt: new Date() }) .where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId))); diff --git a/src/database/server/models/session.ts b/src/database/server/models/session.ts index a41b78f89ca7..de9e61fcbec9 100644 --- a/src/database/server/models/session.ts +++ b/src/database/server/models/session.ts @@ -1,10 +1,11 @@ import { Column, asc, count, inArray, like, sql } from 'drizzle-orm'; -import { and, desc, eq, isNull, not, or } from 'drizzle-orm/expressions'; +import { and, desc, eq, not, or } from 'drizzle-orm/expressions'; import { appEnv } from '@/config/app'; import { INBOX_SESSION_ID } from '@/const/session'; import { DEFAULT_AGENT_CONFIG } from '@/const/settings'; -import { serverDB } from '@/database/server/core/db'; +import { LobeChatDatabase } from '@/database/type'; +import { idGenerator } from '@/database/utils/idGenerator'; import { parseAgentConfig } from '@/server/globalConfig/parseDefaultAgent'; import { ChatSessionList, LobeAgentSession } from '@/types/session'; import { merge } from '@/utils/merge'; @@ -19,20 +20,21 @@ import { sessionGroups, sessions, } from '../../schemas'; -import { idGenerator } from '@/database/utils/idGenerator'; export class SessionModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } // **************** Query *************** // async query({ current = 0, pageSize = 9999 } = {}) { const offset = current * pageSize; - return serverDB.query.sessions.findMany({ + return this.db.query.sessions.findMany({ limit: pageSize, offset, orderBy: [desc(sessions.updatedAt)], @@ -45,7 +47,7 @@ export class SessionModel { // 查询所有会话 const result = await this.query(); - const groups = await serverDB.query.sessionGroups.findMany({ + const groups = await this.db.query.sessionGroups.findMany({ orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)], where: eq(sessions.userId, this.userId), }); @@ -69,7 +71,7 @@ export class SessionModel { async findByIdOrSlug( idOrSlug: string, ): Promise<(SessionItem & { agent: AgentItem }) | undefined> { - const result = await serverDB.query.sessions.findFirst({ + const result = await this.db.query.sessions.findFirst({ where: and( or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)), eq(sessions.userId, this.userId), @@ -83,7 +85,7 @@ export class SessionModel { } async count() { - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -109,7 +111,7 @@ export class SessionModel { slug?: string; type: 'agent' | 'group'; }): Promise { - return serverDB.transaction(async (trx) => { + return this.db.transaction(async (trx) => { const newAgents = await trx .insert(agents) .values({ @@ -144,7 +146,7 @@ export class SessionModel { } async createInbox() { - const item = await serverDB.query.sessions.findFirst({ + const item = await this.db.query.sessions.findFirst({ where: and(eq(sessions.userId, this.userId), eq(sessions.slug, INBOX_SESSION_ID)), }); if (item) return; @@ -167,7 +169,7 @@ export class SessionModel { }; }); - return serverDB.insert(sessions).values(sessionsToInsert); + return this.db.insert(sessions).values(sessionsToInsert); } async duplicate(id: string, newTitle?: string) { @@ -199,7 +201,7 @@ export class SessionModel { * Delete a session, also delete all messages and topics associated with it. */ async delete(id: string) { - return serverDB + return this.db .delete(sessions) .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId))); } @@ -208,18 +210,18 @@ export class SessionModel { * Batch delete sessions, also delete all messages and topics associated with them. */ async batchDelete(ids: string[]) { - return serverDB + return this.db .delete(sessions) .where(and(inArray(sessions.id, ids), eq(sessions.userId, this.userId))); } async deleteAll() { - return serverDB.delete(sessions).where(eq(sessions.userId, this.userId)); + return this.db.delete(sessions).where(eq(sessions.userId, this.userId)); } // **************** Update *************** // async update(id: string, data: Partial) { - return serverDB + return this.db .update(sessions) .set(data) .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId))) @@ -227,7 +229,7 @@ export class SessionModel { } async updateConfig(id: string, data: Partial) { - return serverDB + return this.db .update(agents) .set(data) .where(and(eq(agents.id, id), eq(agents.userId, this.userId))); @@ -262,65 +264,22 @@ export class SessionModel { } as any; }; - async findSessions(params: { - current?: number; - group?: string; - keyword?: string; - pageSize?: number; - pinned?: boolean; - }) { - const { pinned, keyword, group, pageSize = 9999, current = 0 } = params; - - const offset = current * pageSize; - return serverDB.query.sessions.findMany({ - limit: pageSize, - offset, - orderBy: [desc(sessions.updatedAt)], - where: and( - eq(sessions.userId, this.userId), - pinned !== undefined ? eq(sessions.pinned, pinned) : eq(sessions.userId, this.userId), - keyword - ? or( - like( - sql`lower(${sessions.title})` as unknown as Column, - `%${keyword.toLowerCase()}%`, - ), - like( - sql`lower(${sessions.description})` as unknown as Column, - `%${keyword.toLowerCase()}%`, - ), - ) - : eq(sessions.userId, this.userId), - group ? eq(sessions.groupId, group) : isNull(sessions.groupId), - ), - - with: { agentsToSessions: { columns: {}, with: { agent: true } }, group: true }, - }); - } - - async findSessionsByKeywords(params: { - current?: number; - keyword: string; - pageSize?: number; - }) { + async findSessionsByKeywords(params: { current?: number; keyword: string; pageSize?: number }) { const { keyword, pageSize = 9999, current = 0 } = params; const offset = current * pageSize; - const results = await serverDB.query.agents.findMany({ + const results = await this.db.query.agents.findMany({ limit: pageSize, offset, orderBy: [desc(agents.updatedAt)], where: and( eq(agents.userId, this.userId), or( - like( - sql`lower(${agents.title})` as unknown as Column, - `%${keyword.toLowerCase()}%`, - ), + like(sql`lower(${agents.title})` as unknown as Column, `%${keyword.toLowerCase()}%`), like( sql`lower(${agents.description})` as unknown as Column, `%${keyword.toLowerCase()}%`, ), - ) + ), ), with: { agentsToSessions: { columns: {}, with: { session: true } } }, }); @@ -328,6 +287,6 @@ export class SessionModel { // @ts-expect-error return results.map((item) => item.agentsToSessions[0].session); } catch {} - return [] + return []; } } diff --git a/src/database/server/models/sessionGroup.ts b/src/database/server/models/sessionGroup.ts index 16df2a2fab37..14d2dc48a25c 100644 --- a/src/database/server/models/sessionGroup.ts +++ b/src/database/server/models/sessionGroup.ts @@ -1,20 +1,22 @@ import { eq } from 'drizzle-orm'; import { and, asc, desc } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { idGenerator } from '@/database/utils/idGenerator'; import { SessionGroupItem, sessionGroups } from '../../schemas'; export class SessionGroupModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: { name: string; sort?: number }) => { - const [result] = await serverDB + const [result] = await this.db .insert(sessionGroups) .values({ ...params, id: this.genId(), userId: this.userId }) .returning(); @@ -23,37 +25,37 @@ export class SessionGroupModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(sessionGroups) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); + return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); }; query = async () => { - return serverDB.query.sessionGroups.findMany({ + return this.db.query.sessionGroups.findMany({ orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)], where: eq(sessionGroups.userId, this.userId), }); }; findById = async (id: string) => { - return serverDB.query.sessionGroups.findFirst({ + return this.db.query.sessionGroups.findFirst({ where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)), }); }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(sessionGroups) .set({ ...value, updatedAt: new Date() }) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); } async updateOrder(sortMap: { id: string; sort: number }[]) { - await serverDB.transaction(async (tx) => { + await this.db.transaction(async (tx) => { const updates = sortMap.map(({ id, sort }) => { return tx .update(sessionGroups) diff --git a/src/database/server/models/thread.ts b/src/database/server/models/thread.ts index 791510470e17..9d65e698ee25 100644 --- a/src/database/server/models/thread.ts +++ b/src/database/server/models/thread.ts @@ -1,7 +1,7 @@ import { eq } from 'drizzle-orm'; import { and, desc } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { CreateThreadParams, ThreadStatus } from '@/types/topic'; import { ThreadItem, threads } from '../../schemas'; @@ -20,14 +20,16 @@ const queryColumns = { export class ThreadModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: CreateThreadParams) => { // @ts-ignore - const [result] = await serverDB + const [result] = await this.db .insert(threads) .values({ ...params, status: ThreadStatus.Active, userId: this.userId }) .onConflictDoNothing() @@ -37,15 +39,15 @@ export class ThreadModel { }; delete = async (id: string) => { - return serverDB.delete(threads).where(and(eq(threads.id, id), eq(threads.userId, this.userId))); + return this.db.delete(threads).where(and(eq(threads.id, id), eq(threads.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(threads).where(eq(threads.userId, this.userId)); + return this.db.delete(threads).where(eq(threads.userId, this.userId)); }; query = async () => { - const data = await serverDB + const data = await this.db .select(queryColumns) .from(threads) .where(eq(threads.userId, this.userId)) @@ -55,7 +57,7 @@ export class ThreadModel { }; queryByTopicId = async (topicId: string) => { - const data = await serverDB + const data = await this.db .select(queryColumns) .from(threads) .where(and(eq(threads.topicId, topicId), eq(threads.userId, this.userId))) @@ -65,13 +67,13 @@ export class ThreadModel { }; findById = async (id: string) => { - return serverDB.query.threads.findFirst({ + return this.db.query.threads.findFirst({ where: and(eq(threads.id, id), eq(threads.userId, this.userId)), }); }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(threads) .set({ ...value, updatedAt: new Date() }) .where(and(eq(threads.id, id), eq(threads.userId, this.userId))); diff --git a/src/database/server/models/topic.ts b/src/database/server/models/topic.ts index 0e09973d5727..ad914d614fe6 100644 --- a/src/database/server/models/topic.ts +++ b/src/database/server/models/topic.ts @@ -1,7 +1,7 @@ import { Column, count, inArray, sql } from 'drizzle-orm'; import { and, desc, eq, exists, isNull, like, or } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server/core/db'; +import { LobeChatDatabase } from '@/database/type'; import { NewMessage, TopicItem, messages, topics } from '../../schemas'; import { idGenerator } from '@/database/utils/idGenerator'; @@ -21,9 +21,11 @@ interface QueryTopicParams { export class TopicModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } // **************** Query *************** // @@ -31,7 +33,7 @@ export class TopicModel { const offset = current * pageSize; return ( - serverDB + this.db .select({ createdAt: topics.createdAt, favorite: topics.favorite, @@ -52,13 +54,13 @@ export class TopicModel { } async findById(id: string) { - return serverDB.query.topics.findFirst({ + return this.db.query.topics.findFirst({ where: and(eq(topics.id, id), eq(topics.userId, this.userId)), }); } async queryAll(): Promise { - return serverDB + return this.db .select() .from(topics) .orderBy(topics.updatedAt) @@ -74,7 +76,7 @@ export class TopicModel { const matchKeyword = (field: any) => like(sql`lower(${field})` as unknown as Column, `%${keywordLowerCase}%`); - return serverDB.query.topics.findMany({ + return this.db.query.topics.findMany({ orderBy: [desc(topics.updatedAt)], where: and( eq(topics.userId, this.userId), @@ -82,15 +84,10 @@ export class TopicModel { or( matchKeyword(topics.title), exists( - serverDB + this.db .select() .from(messages) - .where( - and( - eq(messages.topicId, topics.id), - matchKeyword(messages.content) - ) - ), + .where(and(eq(messages.topicId, topics.id), matchKeyword(messages.content))), ), ), ), @@ -98,7 +95,7 @@ export class TopicModel { } async count() { - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -115,7 +112,7 @@ export class TopicModel { { messages: messageIds, ...params }: CreateTopicParams, id: string = this.genId(), ): Promise { - return serverDB.transaction(async (tx) => { + return this.db.transaction(async (tx) => { // 在 topics 表中插入新的 topic const [topic] = await tx .insert(topics) @@ -140,7 +137,7 @@ export class TopicModel { async batchCreate(topicParams: (CreateTopicParams & { id?: string })[]) { // 开始一个事务 - return serverDB.transaction(async (tx) => { + return this.db.transaction(async (tx) => { // 在 topics 表中批量插入新的 topics const createdTopics = await tx .insert(topics) @@ -173,7 +170,7 @@ export class TopicModel { } async duplicate(topicId: string, newTitle?: string) { - return serverDB.transaction(async (tx) => { + return this.db.transaction(async (tx) => { // find original topic const originalTopic = await tx.query.topics.findFirst({ where: and(eq(topics.id, topicId), eq(topics.userId, this.userId)), @@ -228,14 +225,14 @@ export class TopicModel { * Delete a session, also delete all messages and topics associated with it. */ async delete(id: string) { - return serverDB.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId))); + return this.db.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId))); } /** * Deletes multiple topics based on the sessionId. */ async batchDeleteBySessionId(sessionId?: string | null) { - return serverDB + return this.db .delete(topics) .where(and(this.matchSession(sessionId), eq(topics.userId, this.userId))); } @@ -244,19 +241,19 @@ export class TopicModel { * Deletes multiple topics and all messages associated with them in a transaction. */ async batchDelete(ids: string[]) { - return serverDB + return this.db .delete(topics) .where(and(inArray(topics.id, ids), eq(topics.userId, this.userId))); } async deleteAll() { - return serverDB.delete(topics).where(eq(topics.userId, this.userId)); + return this.db.delete(topics).where(eq(topics.userId, this.userId)); } // **************** Update *************** // async update(id: string, data: Partial) { - return serverDB + return this.db .update(topics) .set({ ...data, updatedAt: new Date() }) .where(and(eq(topics.id, id), eq(topics.userId, this.userId))) diff --git a/src/database/server/models/user.ts b/src/database/server/models/user.ts index 2fc92d1f1d2a..74b4130805d5 100644 --- a/src/database/server/models/user.ts +++ b/src/database/server/models/user.ts @@ -2,7 +2,7 @@ import { TRPCError } from '@trpc/server'; import { eq } from 'drizzle-orm'; import { DeepPartial } from 'utility-types'; -import { serverDB } from '@/database/server/core/db'; +import { LobeChatDatabase } from '@/database/type'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import { UserGuide, UserPreference } from '@/types/user'; import { UserKeyVaults, UserSettings } from '@/types/user/settings'; @@ -18,38 +18,16 @@ export class UserNotFoundError extends TRPCError { } export class UserModel { - static createUser = async (params: NewUser) => { - // if user already exists, skip creation - if (params.id) { - const user = await serverDB.query.users.findFirst({ where: eq(users.id, params.id) }); - if (!!user) return; - } - - const [user] = await serverDB - .insert(users) - .values({ ...params }) - .returning(); - - // Create an inbox session for the user - const model = new SessionModel(user.id); - - await model.createInbox(); - }; + private userId: string; + private db: LobeChatDatabase; - static deleteUser = async (id: string) => { - return serverDB.delete(users).where(eq(users.id, id)); - }; - - static findById = async (id: string) => { - return serverDB.query.users.findFirst({ where: eq(users.id, id) }); - }; - - static findByEmail = async (email: string) => { - return serverDB.query.users.findFirst({ where: eq(users.email, email) }); - }; + constructor(db: LobeChatDatabase, userId: string) { + this.userId = userId; + this.db = db; + } - getUserState = async (id: string) => { - const result = await serverDB + async getUserState() { + const result = await this.db .select({ isOnboarded: users.isOnboarded, preference: users.preference, @@ -63,7 +41,7 @@ export class UserModel { settingsTool: userSettings.tool, }) .from(users) - .where(eq(users.id, id)) + .where(eq(users.id, this.userId)) .leftJoin(userSettings, eq(users.id, userSettings.id)); if (!result || !result[0]) { @@ -82,7 +60,7 @@ export class UserModel { try { decryptKeyVaults = JSON.parse(plaintext); } catch (e) { - console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e); + console.error(`Failed to parse keyVaults ,userId: ${this.userId}. Error:`, e); } } } @@ -101,54 +79,22 @@ export class UserModel { isOnboarded: state.isOnboarded, preference: state.preference as UserPreference, settings, - userId: id, + userId: this.userId, }; - }; - - static getUserApiKeys = async (id: string) => { - const result = await serverDB - .select({ - settingsKeyVaults: userSettings.keyVaults, - }) - .from(userSettings) - .where(eq(userSettings.id, id)); - - if (!result || !result[0]) { - throw new UserNotFoundError(); - } - - const state = result[0]; - - // Decrypt keyVaults - let decryptKeyVaults = {}; - if (state.settingsKeyVaults) { - const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults); - - if (wasAuthentic) { - try { - decryptKeyVaults = JSON.parse(plaintext); - } catch (e) { - console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e); - } - } - } - - return decryptKeyVaults as UserKeyVaults; - }; + } - async updateUser(id: string, value: Partial) { - return serverDB + async updateUser(value: Partial) { + return this.db .update(users) .set({ ...value, updatedAt: new Date() }) - .where(eq(users.id, id)); + .where(eq(users.id, this.userId)); } - async deleteSetting(id: string) { - return serverDB.delete(userSettings).where(eq(userSettings.id, id)); + async deleteSetting() { + return this.db.delete(userSettings).where(eq(userSettings.id, this.userId)); } - async updateSetting(id: string, value: Partial) { + async updateSetting(value: Partial) { const { keyVaults, ...res } = value; // Encrypt keyVaults @@ -165,33 +111,99 @@ export class UserModel { const newValue = { ...res, keyVaults: encryptedKeyVaults }; // update or create user settings - const settings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, id) }); + const settings = await this.db.query.userSettings.findFirst({ + where: eq(users.id, this.userId), + }); if (!settings) { - await serverDB.insert(userSettings).values({ id, ...newValue }); + await this.db.insert(userSettings).values({ id: this.userId, ...newValue }); return; } - return serverDB.update(userSettings).set(newValue).where(eq(userSettings.id, id)); + return this.db.update(userSettings).set(newValue).where(eq(userSettings.id, this.userId)); } - async updatePreference(id: string, value: Partial) { - const user = await serverDB.query.users.findFirst({ where: eq(users.id, id) }); + async updatePreference(value: Partial) { + const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) }); if (!user) return; - return serverDB + return this.db .update(users) .set({ preference: merge(user.preference, value) }) - .where(eq(users.id, id)); + .where(eq(users.id, this.userId)); } - async updateGuide(id: string, value: Partial) { - const user = await serverDB.query.users.findFirst({ where: eq(users.id, id) }); + async updateGuide(value: Partial) { + const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) }); if (!user) return; const prevPreference = (user.preference || {}) as UserPreference; - return serverDB + return this.db .update(users) .set({ preference: { ...prevPreference, guide: merge(prevPreference.guide || {}, value) } }) - .where(eq(users.id, id)); + .where(eq(users.id, this.userId)); } + + // Static method + + static createUser = async (db: LobeChatDatabase, params: NewUser) => { + // if user already exists, skip creation + if (params.id) { + const user = await db.query.users.findFirst({ where: eq(users.id, params.id) }); + if (!!user) return; + } + + const [user] = await db + .insert(users) + .values({ ...params }) + .returning(); + + // Create an inbox session for the user + const model = new SessionModel(db, user.id); + + await model.createInbox(); + }; + + static deleteUser = async (db: LobeChatDatabase, id: string) => { + return db.delete(users).where(eq(users.id, id)); + }; + + static findById = async (db: LobeChatDatabase, id: string) => { + return db.query.users.findFirst({ where: eq(users.id, id) }); + }; + + static findByEmail = async (db: LobeChatDatabase, email: string) => { + return db.query.users.findFirst({ where: eq(users.email, email) }); + }; + + static getUserApiKeys = async (db: LobeChatDatabase, id: string) => { + const result = await db + .select({ + settingsKeyVaults: userSettings.keyVaults, + }) + .from(userSettings) + .where(eq(userSettings.id, id)); + + if (!result || !result[0]) { + throw new UserNotFoundError(); + } + + const state = result[0]; + + // Decrypt keyVaults + let decryptKeyVaults = {}; + if (state.settingsKeyVaults) { + const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); + const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults); + + if (wasAuthentic) { + try { + decryptKeyVaults = JSON.parse(plaintext); + } catch (e) { + console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e); + } + } + } + + return decryptKeyVaults as UserKeyVaults; + }; } diff --git a/src/database/type.ts b/src/database/type.ts new file mode 100644 index 000000000000..cf9308763134 --- /dev/null +++ b/src/database/type.ts @@ -0,0 +1,7 @@ +import type { NeonDatabase } from 'drizzle-orm/neon-serverless'; + +import * as schema from './schemas'; + +export type LobeChatDatabaseSchema = typeof schema; + +export type LobeChatDatabase = NeonDatabase; diff --git a/src/libs/next-auth/adapter/index.ts b/src/libs/next-auth/adapter/index.ts index 5eeb5e4ab824..ac9f8ae2caaa 100644 --- a/src/libs/next-auth/adapter/index.ts +++ b/src/libs/next-auth/adapter/index.ts @@ -33,8 +33,6 @@ const { * @returns {Adapter} */ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Adapter { - const userModel = new UserModel(); - return { async createAuthenticator(authenticator): Promise { const result = await serverDB @@ -55,10 +53,10 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad async createUser(user): Promise { const { id, name, email, emailVerified, image, providerAccountId } = user; // return the user if it already exists - let existingUser = await UserModel.findByEmail(email); + let existingUser = await UserModel.findByEmail(serverDB, email); // If the user is not found by email, try to find by providerAccountId if (!existingUser && providerAccountId) { - existingUser = await UserModel.findById(providerAccountId); + existingUser = await UserModel.findById(serverDB, providerAccountId); } if (existingUser) { const adapterUser = mapLobeUserToAdapterUser(existingUser); @@ -66,6 +64,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad } // create a new user if it does not exist await UserModel.createUser( + serverDB, mapAdapterUserToLobeUser({ email, emailVerified, @@ -91,10 +90,10 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad return; }, async deleteUser(id): Promise { - const user = await UserModel.findById(id); + const user = await UserModel.findById(serverDB, id); if (!user) throw new Error('NextAuth: Delete User not found'); - await UserModel.deleteUser(id); + await UserModel.deleteUser(serverDB, id); return; }, @@ -145,7 +144,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad }, async getUser(id): Promise { - const lobeUser = await UserModel.findById(id); + const lobeUser = await UserModel.findById(serverDB, id); if (!lobeUser) return null; return mapLobeUserToAdapterUser(lobeUser); }, @@ -170,7 +169,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad }, async getUserByEmail(email): Promise { - const lobeUser = await UserModel.findByEmail(email); + const lobeUser = await UserModel.findByEmail(serverDB, email); return lobeUser ? mapLobeUserToAdapterUser(lobeUser) : null; }, @@ -228,10 +227,11 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad }, async updateUser(user): Promise { - const lobeUser = await UserModel.findById(user?.id); + const lobeUser = await UserModel.findById(serverDB, user?.id); if (!lobeUser) throw new Error('NextAuth: User not found'); + const userModel = new UserModel(serverDB, user.id); - const updatedUser = await userModel.updateUser(user.id, { + const updatedUser = await userModel.updateUser({ ...partialMapAdapterUserToLobeUser(user), }); if (!updatedUser) throw new Error('NextAuth: Failed to update user'); diff --git a/src/libs/trpc/async/asyncAuth.ts b/src/libs/trpc/async/asyncAuth.ts index 3ebf3248a273..69b52d5d5441 100644 --- a/src/libs/trpc/async/asyncAuth.ts +++ b/src/libs/trpc/async/asyncAuth.ts @@ -1,6 +1,7 @@ import { TRPCError } from '@trpc/server'; import { serverDBEnv } from '@/config/db'; +import { serverDB } from '@/database/server'; import { UserModel } from '@/database/server/models/user'; import { asyncTrpc } from './init'; @@ -12,7 +13,7 @@ export const asyncAuth = asyncTrpc.middleware(async (opts) => { throw new TRPCError({ code: 'UNAUTHORIZED' }); } - const result = await UserModel.findById(ctx.userId); + const result = await UserModel.findById(serverDB, ctx.userId); if (!result) { throw new TRPCError({ code: 'UNAUTHORIZED', message: 'user is invalid' }); diff --git a/src/server/routers/async/file.ts b/src/server/routers/async/file.ts index c9071564245c..8bbee1e9e1fe 100644 --- a/src/server/routers/async/file.ts +++ b/src/server/routers/async/file.ts @@ -5,6 +5,7 @@ import { z } from 'zod'; import { fileEnv } from '@/config/file'; import { DEFAULT_EMBEDDING_MODEL } from '@/const/settings'; +import { serverDB } from '@/database/server'; import { ASYNC_TASK_TIMEOUT, AsyncTaskModel } from '@/database/server/models/asyncTask'; import { ChunkModel } from '@/database/server/models/chunk'; import { EmbeddingModel } from '@/database/server/models/embedding'; @@ -28,11 +29,11 @@ const fileProcedure = asyncAuthedProcedure.use(async (opts) => { return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.userId), - chunkModel: new ChunkModel(ctx.userId), + asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId), + chunkModel: new ChunkModel(serverDB, ctx.userId), chunkService: new ChunkService(ctx.userId), - embeddingModel: new EmbeddingModel(ctx.userId), - fileModel: new FileModel(ctx.userId), + embeddingModel: new EmbeddingModel(serverDB, ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/async/ragEval.ts b/src/server/routers/async/ragEval.ts index 5fc637337f23..de4190ee668e 100644 --- a/src/server/routers/async/ragEval.ts +++ b/src/server/routers/async/ragEval.ts @@ -4,6 +4,7 @@ import { z } from 'zod'; import { chainAnswerWithContext } from '@/chains/answerWithContext'; import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings'; +import { serverDB } from '@/database/server'; import { ChunkModel } from '@/database/server/models/chunk'; import { EmbeddingModel } from '@/database/server/models/embedding'; import { FileModel } from '@/database/server/models/file'; @@ -24,13 +25,13 @@ const ragEvalProcedure = asyncAuthedProcedure.use(async (opts) => { return opts.next({ ctx: { - chunkModel: new ChunkModel(ctx.userId), + chunkModel: new ChunkModel(serverDB, ctx.userId), chunkService: new ChunkService(ctx.userId), datasetRecordModel: new EvalDatasetRecordModel(ctx.userId), - embeddingModel: new EmbeddingModel(ctx.userId), + embeddingModel: new EmbeddingModel(serverDB, ctx.userId), evalRecordModel: new EvaluationRecordModel(ctx.userId), evaluationModel: new EvalEvaluationModel(ctx.userId), - fileModel: new FileModel(ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/_template.ts b/src/server/routers/lambda/_template.ts index 10d5fae784dd..9a3b96b7b2bf 100644 --- a/src/server/routers/lambda/_template.ts +++ b/src/server/routers/lambda/_template.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { SessionGroupModel } from '@/database/server/models/sessionGroup'; import { insertSessionGroupSchema } from '@/database/schemas'; import { authedProcedure, router } from '@/libs/trpc'; @@ -10,7 +11,7 @@ const sessionProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - sessionGroupModel: new SessionGroupModel(ctx.userId), + sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/agent.ts b/src/server/routers/lambda/agent.ts index 1ce31bef76b0..a3735ac732ef 100644 --- a/src/server/routers/lambda/agent.ts +++ b/src/server/routers/lambda/agent.ts @@ -2,6 +2,7 @@ import { z } from 'zod'; import { INBOX_SESSION_ID } from '@/const/session'; import { DEFAULT_AGENT_CONFIG } from '@/const/settings'; +import { serverDB } from '@/database/server'; import { AgentModel } from '@/database/server/models/agent'; import { FileModel } from '@/database/server/models/file'; import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase'; @@ -16,10 +17,10 @@ const agentProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - agentModel: new AgentModel(ctx.userId), - fileModel: new FileModel(ctx.userId), - knowledgeBaseModel: new KnowledgeBaseModel(ctx.userId), - sessionModel: new SessionModel(ctx.userId), + agentModel: new AgentModel(serverDB, ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), + knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId), + sessionModel: new SessionModel(serverDB, ctx.userId), }, }); }); @@ -87,7 +88,7 @@ export const agentRouter = router({ // if there is no session for user, create one if (!item) { // if there is no user, return default config - const user = await UserModel.findById(ctx.userId); + const user = await UserModel.findById(serverDB, ctx.userId); if (!user) return DEFAULT_AGENT_CONFIG; const res = await ctx.sessionModel.createInbox(); diff --git a/src/server/routers/lambda/chunk.ts b/src/server/routers/lambda/chunk.ts index d9818febe52b..07425224af47 100644 --- a/src/server/routers/lambda/chunk.ts +++ b/src/server/routers/lambda/chunk.ts @@ -21,12 +21,12 @@ const chunkProcedure = authedProcedure.use(keyVaults).use(async (opts) => { return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.userId), - chunkModel: new ChunkModel(ctx.userId), + asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId), + chunkModel: new ChunkModel(serverDB, ctx.userId), chunkService: new ChunkService(ctx.userId), - embeddingModel: new EmbeddingModel(ctx.userId), - fileModel: new FileModel(ctx.userId), - messageModel: new MessageModel(ctx.userId), + embeddingModel: new EmbeddingModel(serverDB, ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), + messageModel: new MessageModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/file.ts b/src/server/routers/lambda/file.ts index 8b73cc715ca7..02b4366aa9bc 100644 --- a/src/server/routers/lambda/file.ts +++ b/src/server/routers/lambda/file.ts @@ -1,6 +1,7 @@ import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { AsyncTaskModel } from '@/database/server/models/asyncTask'; import { ChunkModel } from '@/database/server/models/chunk'; import { FileModel } from '@/database/server/models/file'; @@ -15,9 +16,9 @@ const fileProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.userId), - chunkModel: new ChunkModel(ctx.userId), - fileModel: new FileModel(ctx.userId), + asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId), + chunkModel: new ChunkModel(serverDB, ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/knowledgeBase.ts b/src/server/routers/lambda/knowledgeBase.ts index 9ccfc01ebfd2..ac0e8456689b 100644 --- a/src/server/routers/lambda/knowledgeBase.ts +++ b/src/server/routers/lambda/knowledgeBase.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase'; import { insertKnowledgeBasesSchema } from '@/database/schemas'; import { authedProcedure, router } from '@/libs/trpc'; @@ -10,7 +11,7 @@ const knowledgeBaseProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - knowledgeBaseModel: new KnowledgeBaseModel(ctx.userId), + knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/message.ts b/src/server/routers/lambda/message.ts index fb95768d3035..dcc665c4f786 100644 --- a/src/server/routers/lambda/message.ts +++ b/src/server/routers/lambda/message.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { MessageModel } from '@/database/server/models/message'; import { updateMessagePluginSchema } from '@/database/schemas'; import { authedProcedure, publicProcedure, router } from '@/libs/trpc'; @@ -12,7 +13,7 @@ const messageProcedure = authedProcedure.use(async (opts) => { const { ctx } = opts; return opts.next({ - ctx: { messageModel: new MessageModel(ctx.userId) }, + ctx: { messageModel: new MessageModel(serverDB, ctx.userId) }, }); }); @@ -54,6 +55,7 @@ export const messageRouter = router({ return ctx.messageModel.queryBySessionId(input.sessionId); }), + // TODO: 未来这部分方法也需要使用 authedProcedure getMessages: publicProcedure .input( z.object({ @@ -66,7 +68,7 @@ export const messageRouter = router({ .query(async ({ input, ctx }) => { if (!ctx.userId) return []; - const messageModel = new MessageModel(ctx.userId); + const messageModel = new MessageModel(serverDB, ctx.userId); return messageModel.query(input); }), diff --git a/src/server/routers/lambda/plugin.ts b/src/server/routers/lambda/plugin.ts index 3c691c51df0c..13880b1ff32e 100644 --- a/src/server/routers/lambda/plugin.ts +++ b/src/server/routers/lambda/plugin.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { PluginModel } from '@/database/server/models/plugin'; import { authedProcedure, publicProcedure, router } from '@/libs/trpc'; import { LobeTool } from '@/types/tool'; @@ -8,7 +9,7 @@ const pluginProcedure = authedProcedure.use(async (opts) => { const { ctx } = opts; return opts.next({ - ctx: { pluginModel: new PluginModel(ctx.userId) }, + ctx: { pluginModel: new PluginModel(serverDB, ctx.userId) }, }); }); @@ -61,10 +62,11 @@ export const pluginRouter = router({ return data.identifier; }), + // TODO: 未来这部分方法也需要使用 authedProcedure getPlugins: publicProcedure.query(async ({ ctx }): Promise => { if (!ctx.userId) return []; - const pluginModel = new PluginModel(ctx.userId); + const pluginModel = new PluginModel(serverDB, ctx.userId); return pluginModel.query(); }), diff --git a/src/server/routers/lambda/ragEval.ts b/src/server/routers/lambda/ragEval.ts index 33b33a715944..150abe4a4e1c 100644 --- a/src/server/routers/lambda/ragEval.ts +++ b/src/server/routers/lambda/ragEval.ts @@ -6,6 +6,7 @@ import pMap from 'p-map'; import { z } from 'zod'; import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings'; +import { serverDB } from '@/database/server'; import { FileModel } from '@/database/server/models/file'; import { EvalDatasetModel, @@ -34,7 +35,7 @@ const ragEvalProcedure = authedProcedure.use(keyVaults).use(async (opts) => { return opts.next({ ctx: { datasetModel: new EvalDatasetModel(ctx.userId), - fileModel: new FileModel(ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), datasetRecordModel: new EvalDatasetRecordModel(ctx.userId), evaluationModel: new EvalEvaluationModel(ctx.userId), evaluationRecordModel: new EvaluationRecordModel(ctx.userId), diff --git a/src/server/routers/lambda/session.ts b/src/server/routers/lambda/session.ts index af98cd197b53..65d94538e66e 100644 --- a/src/server/routers/lambda/session.ts +++ b/src/server/routers/lambda/session.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { SessionModel } from '@/database/server/models/session'; import { SessionGroupModel } from '@/database/server/models/sessionGroup'; import { insertAgentSchema, insertSessionSchema } from '@/database/schemas'; @@ -15,8 +16,8 @@ const sessionProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - sessionGroupModel: new SessionGroupModel(ctx.userId), - sessionModel: new SessionModel(ctx.userId), + sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId), + sessionModel: new SessionModel(serverDB, ctx.userId), }, }); }); @@ -84,7 +85,7 @@ export const sessionRouter = router({ sessions: [], }; - const sessionModel = new SessionModel(ctx.userId); + const sessionModel = new SessionModel(serverDB, ctx.userId); return sessionModel.queryWithGroups(); }), diff --git a/src/server/routers/lambda/sessionGroup.ts b/src/server/routers/lambda/sessionGroup.ts index 10d5fae784dd..9a3b96b7b2bf 100644 --- a/src/server/routers/lambda/sessionGroup.ts +++ b/src/server/routers/lambda/sessionGroup.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { SessionGroupModel } from '@/database/server/models/sessionGroup'; import { insertSessionGroupSchema } from '@/database/schemas'; import { authedProcedure, router } from '@/libs/trpc'; @@ -10,7 +11,7 @@ const sessionProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - sessionGroupModel: new SessionGroupModel(ctx.userId), + sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/thread.ts b/src/server/routers/lambda/thread.ts index 32e5a5452fec..05cb3eec0846 100644 --- a/src/server/routers/lambda/thread.ts +++ b/src/server/routers/lambda/thread.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { MessageModel } from '@/database/server/models/message'; import { ThreadModel } from '@/database/server/models/thread'; import { insertThreadSchema } from '@/database/schemas'; @@ -11,8 +12,8 @@ const threadProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - messageModel: new MessageModel(ctx.userId), - threadModel: new ThreadModel(ctx.userId), + messageModel: new MessageModel(serverDB, ctx.userId), + threadModel: new ThreadModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/topic.ts b/src/server/routers/lambda/topic.ts index c8b44949b472..5dd07111fb9b 100644 --- a/src/server/routers/lambda/topic.ts +++ b/src/server/routers/lambda/topic.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { TopicModel } from '@/database/server/models/topic'; import { authedProcedure, publicProcedure, router } from '@/libs/trpc'; import { BatchTaskResult } from '@/types/service'; @@ -8,7 +9,7 @@ const topicProcedure = authedProcedure.use(async (opts) => { const { ctx } = opts; return opts.next({ - ctx: { topicModel: new TopicModel(ctx.userId) }, + ctx: { topicModel: new TopicModel(serverDB, ctx.userId) }, }); }); @@ -78,6 +79,7 @@ export const topicRouter = router({ return ctx.topicModel.queryAll(); }), + // TODO: this procedure should be used with authedProcedure getTopics: publicProcedure .input( z.object({ @@ -89,7 +91,7 @@ export const topicRouter = router({ .query(async ({ input, ctx }) => { if (!ctx.userId) return []; - const topicModel = new TopicModel(ctx.userId); + const topicModel = new TopicModel(serverDB, ctx.userId); return topicModel.query(input); }), diff --git a/src/server/routers/lambda/user.ts b/src/server/routers/lambda/user.ts index 9660a955e691..f7522004940b 100644 --- a/src/server/routers/lambda/user.ts +++ b/src/server/routers/lambda/user.ts @@ -3,6 +3,7 @@ import { currentUser } from '@clerk/nextjs/server'; import { z } from 'zod'; import { enableClerk } from '@/const/auth'; +import { serverDB } from '@/database/server'; import { MessageModel } from '@/database/server/models/message'; import { SessionModel } from '@/database/server/models/session'; import { UserModel, UserNotFoundError } from '@/database/server/models/user'; @@ -12,7 +13,7 @@ import { UserGuideSchema, UserInitializationState, UserPreference } from '@/type const userProcedure = authedProcedure.use(async (opts) => { return opts.next({ - ctx: { userModel: new UserModel() }, + ctx: { userModel: new UserModel(serverDB, opts.ctx.userId) }, }); }); @@ -23,7 +24,7 @@ export const userRouter = router({ // get or create first-time user while (!state) { try { - state = await ctx.userModel.getUserState(ctx.userId); + state = await ctx.userModel.getUserState(); } catch (error) { if (enableClerk && error instanceof UserNotFoundError) { const user = await currentUser(); @@ -56,10 +57,10 @@ export const userRouter = router({ } } - const messageModel = new MessageModel(ctx.userId); + const messageModel = new MessageModel(serverDB, ctx.userId); const messageCount = await messageModel.count(); - const sessionModel = new SessionModel(ctx.userId); + const sessionModel = new SessionModel(serverDB, ctx.userId); const sessionCount = await sessionModel.count(); return { @@ -77,25 +78,25 @@ export const userRouter = router({ }), makeUserOnboarded: userProcedure.mutation(async ({ ctx }) => { - return ctx.userModel.updateUser(ctx.userId, { isOnboarded: true }); + return ctx.userModel.updateUser({ isOnboarded: true }); }), resetSettings: userProcedure.mutation(async ({ ctx }) => { - return ctx.userModel.deleteSetting(ctx.userId); + return ctx.userModel.deleteSetting(); }), updateGuide: userProcedure.input(UserGuideSchema).mutation(async ({ ctx, input }) => { - return ctx.userModel.updateGuide(ctx.userId, input); + return ctx.userModel.updateGuide(input); }), updatePreference: userProcedure.input(z.any()).mutation(async ({ ctx, input }) => { - return ctx.userModel.updatePreference(ctx.userId, input); + return ctx.userModel.updatePreference(input); }), updateSettings: userProcedure .input(z.object({}).passthrough()) .mutation(async ({ ctx, input }) => { - return ctx.userModel.updateSetting(ctx.userId, input); + return ctx.userModel.updateSetting(input); }), }); diff --git a/src/server/services/chunk/index.ts b/src/server/services/chunk/index.ts index 3ef1ba13275d..09c8e1d820b6 100644 --- a/src/server/services/chunk/index.ts +++ b/src/server/services/chunk/index.ts @@ -1,4 +1,5 @@ import { JWTPayload } from '@/const/auth'; +import { serverDB } from '@/database/server'; import { AsyncTaskModel } from '@/database/server/models/asyncTask'; import { FileModel } from '@/database/server/models/file'; import { ChunkContentParams, ContentChunk } from '@/server/modules/ContentChunk'; @@ -21,8 +22,8 @@ export class ChunkService { this.chunkClient = new ContentChunk(); - this.fileModel = new FileModel(userId); - this.asyncTaskModel = new AsyncTaskModel(userId); + this.fileModel = new FileModel(serverDB, userId); + this.asyncTaskModel = new AsyncTaskModel(serverDB, userId); } async chunkContent(params: ChunkContentParams) { diff --git a/src/server/services/nextAuthUser/index.ts b/src/server/services/nextAuthUser/index.ts index e832a6d04aa1..85f972aab576 100644 --- a/src/server/services/nextAuthUser/index.ts +++ b/src/server/services/nextAuthUser/index.ts @@ -7,11 +7,9 @@ import { pino } from '@/libs/logger'; import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter'; export class NextAuthUserService { - userModel; adapter; constructor() { - this.userModel = new UserModel(); this.adapter = LobeNextAuthDbAdapter(serverDB); } @@ -29,8 +27,10 @@ export class NextAuthUserService { // 2. If found, Update user data from provider if (user?.id) { + const userModel = new UserModel(serverDB, user.id); + // Perform update - await this.userModel.updateUser(user.id, { + await userModel.updateUser({ avatar: data?.avatar, email: data?.email, fullName: data?.fullName, diff --git a/src/server/services/user/index.ts b/src/server/services/user/index.ts index 1bc3480fed74..ed5ee2395099 100644 --- a/src/server/services/user/index.ts +++ b/src/server/services/user/index.ts @@ -1,13 +1,14 @@ import { UserJSON } from '@clerk/backend'; import { NextResponse } from 'next/server'; +import { serverDB } from '@/database/server'; import { UserModel } from '@/database/server/models/user'; import { pino } from '@/libs/logger'; export class UserService { createUser = async (id: string, params: UserJSON) => { // Check if user already exists - const res = await UserModel.findById(id); + const res = await UserModel.findById(serverDB, id); // If user already exists, skip creating a new user if (res) @@ -27,7 +28,7 @@ export class UserService { /* ↑ cloud slot ↑ */ // 2. create user in database - await UserModel.createUser({ + await UserModel.createUser(serverDB, { avatar: params.image_url, clerkCreatedAt: new Date(params.created_at), email: email?.email_address, @@ -49,7 +50,7 @@ export class UserService { if (id) { pino.info('delete user due to clerk webhook'); - await UserModel.deleteUser(id); + await UserModel.deleteUser(serverDB, id); return NextResponse.json({ message: 'user deleted' }, { status: 200 }); } else { @@ -61,10 +62,10 @@ export class UserService { updateUser = async (id: string, params: UserJSON) => { pino.info('updating user due to clerk webhook'); - const userModel = new UserModel(); + const userModel = new UserModel(serverDB, id); // Check if user already exists - const res = await UserModel.findById(id); + const res = await UserModel.findById(serverDB, id); // If user not exists, skip update the user if (!res) @@ -79,7 +80,7 @@ export class UserService { const email = params.email_addresses.find((e) => e.id === params.primary_email_address_id); const phone = params.phone_numbers.find((e) => e.id === params.primary_phone_number_id); - await userModel.updateUser(id, { + await userModel.updateUser({ avatar: params.image_url, email: email?.email_address, firstName: params.first_name,