diff --git a/CHANGELOG.md b/CHANGELOG.md
index 04fde257b450..cca23a1dec77 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,31 @@
# Changelog
+### [Version 1.35.10](https://github.com/lobehub/lobe-chat/compare/v1.35.9...v1.35.10)
+
+Released on **2024-12-03**
+
+#### ♻ Code Refactoring
+
+- **misc**: Refactor the server db model implement.
+
+
+
+
+Improvements and Fixes
+
+#### Code refactoring
+
+- **misc**: Refactor the server db model implement, closes [#4878](https://github.com/lobehub/lobe-chat/issues/4878) ([3814853](https://github.com/lobehub/lobe-chat/commit/3814853))
+
+
+
+
+
+[![](https://img.shields.io/badge/-BACK_TO_TOP-151515?style=flat-square)](#readme-top)
+
+
+
### [Version 1.35.9](https://github.com/lobehub/lobe-chat/compare/v1.35.8...v1.35.9)
Released on **2024-12-03**
diff --git a/changelog/v1.json b/changelog/v1.json
index 04d0169d7bc9..6c349e43a780 100644
--- a/changelog/v1.json
+++ b/changelog/v1.json
@@ -1,4 +1,11 @@
[
+ {
+ "children": {
+ "improvements": ["Refactor the server db model implement."]
+ },
+ "date": "2024-12-03",
+ "version": "1.35.10"
+ },
{
"children": {},
"date": "2024-12-03",
diff --git a/package.json b/package.json
index 0bf2673fca33..f758e3fd829c 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "@lobehub/chat",
- "version": "1.35.9",
+ "version": "1.35.10",
"description": "Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.",
"keywords": [
"framework",
diff --git a/src/app/(main)/repos/[id]/@menu/default.tsx b/src/app/(main)/repos/[id]/@menu/default.tsx
index 2d25320f15ea..2c96eec9fd7d 100644
--- a/src/app/(main)/repos/[id]/@menu/default.tsx
+++ b/src/app/(main)/repos/[id]/@menu/default.tsx
@@ -1,6 +1,7 @@
import { notFound } from 'next/navigation';
import { Flexbox } from 'react-layout-kit';
+import { serverDB } from '@/database/server';
import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase';
import Head from './Head';
@@ -14,7 +15,7 @@ type Props = { params: Params };
const MenuPage = async ({ params }: Props) => {
const id = params.id;
- const item = await KnowledgeBaseModel.findById(params.id);
+ const item = await KnowledgeBaseModel.findById(serverDB, params.id);
if (!item) return notFound();
diff --git a/src/app/(main)/repos/[id]/page.tsx b/src/app/(main)/repos/[id]/page.tsx
index 24662ac70071..fce88618a3df 100644
--- a/src/app/(main)/repos/[id]/page.tsx
+++ b/src/app/(main)/repos/[id]/page.tsx
@@ -1,5 +1,6 @@
import { redirect } from 'next/navigation';
+import { serverDB } from '@/database/server';
import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase';
import FileManager from '@/features/FileManager';
@@ -10,7 +11,7 @@ interface Params {
type Props = { params: Params };
export default async ({ params }: Props) => {
- const item = await KnowledgeBaseModel.findById(params.id);
+ const item = await KnowledgeBaseModel.findById(serverDB, params.id);
if (!item) return redirect('/repos');
diff --git a/src/database/schemas/topic.ts b/src/database/schemas/topic.ts
index 705682fcf99b..3d83636ea0e6 100644
--- a/src/database/schemas/topic.ts
+++ b/src/database/schemas/topic.ts
@@ -3,6 +3,8 @@ import { boolean, jsonb, pgTable, text, unique } from 'drizzle-orm/pg-core';
import { createInsertSchema } from 'drizzle-zod';
import { idGenerator } from '@/database/utils/idGenerator';
+import { ChatTopicMetadata } from '@/types/topic';
+
import { timestamps, timestamptz } from './_helpers';
import { sessions } from './session';
import { users } from './user';
@@ -21,7 +23,7 @@ export const topics = pgTable(
.notNull(),
clientId: text('client_id'),
historySummary: text('history_summary'),
- metadata: jsonb('metadata'),
+ metadata: jsonb('metadata').$type(),
...timestamps,
},
(t) => ({
diff --git a/src/database/server/core/dbForTest.ts b/src/database/server/core/dbForTest.ts
index c4194f24edeb..e84d2ed11151 100644
--- a/src/database/server/core/dbForTest.ts
+++ b/src/database/server/core/dbForTest.ts
@@ -11,6 +11,8 @@ import { serverDBEnv } from '@/config/db';
import * as schema from '../../schemas';
+const migrationsFolder = join(__dirname, '../../migrations');
+
export const getTestDBInstance = async () => {
let connectionString = serverDBEnv.DATABASE_TEST_URL;
@@ -23,9 +25,7 @@ export const getTestDBInstance = async () => {
const db = nodeDrizzle(client, { schema });
- await nodeMigrator.migrate(db, {
- migrationsFolder: join(__dirname, '../../migrations'),
- });
+ await nodeMigrator.migrate(db, { migrationsFolder });
return db;
}
@@ -37,9 +37,7 @@ export const getTestDBInstance = async () => {
const db = neonDrizzle(client, { schema });
- await migrator.migrate(db, {
- migrationsFolder: join(__dirname, '../migrations'),
- });
+ await migrator.migrate(db, { migrationsFolder });
return db;
};
diff --git a/src/database/server/models/__tests__/_test_template.ts b/src/database/server/models/__tests__/_test_template.ts
index ee43feaa9907..0a96afd0a72b 100644
--- a/src/database/server/models/__tests__/_test_template.ts
+++ b/src/database/server/models/__tests__/_test_template.ts
@@ -1,6 +1,6 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
-import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDBInstance } from '@/database/server/core/dbForTest';
@@ -9,14 +9,8 @@ import { SessionGroupModel } from '../sessionGroup';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'session-group-model-test-user-id';
-const sessionGroupModel = new SessionGroupModel(userId);
+const sessionGroupModel = new SessionGroupModel(serverDB, userId);
beforeEach(async () => {
await serverDB.delete(users);
@@ -74,7 +68,7 @@ describe('SessionGroupModel', () => {
await sessionGroupModel.create({ name: 'Test Group 1' });
await sessionGroupModel.create({ name: 'Test Group 333' });
- const anotherSessionGroupModel = new SessionGroupModel('user2');
+ const anotherSessionGroupModel = new SessionGroupModel(serverDB, 'user2');
await anotherSessionGroupModel.create({ name: 'Test Group 2' });
await sessionGroupModel.deleteAll();
diff --git a/src/database/server/models/__tests__/agent.test.ts b/src/database/server/models/__tests__/agent.test.ts
index 3312e8087ef9..c7e5b46c9b54 100644
--- a/src/database/server/models/__tests__/agent.test.ts
+++ b/src/database/server/models/__tests__/agent.test.ts
@@ -1,6 +1,6 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
-import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDBInstance } from '@/database/server/core/dbForTest';
@@ -18,14 +18,8 @@ import { AgentModel } from '../agent';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'agent-model-test-user-id';
-const agentModel = new AgentModel(userId);
+const agentModel = new AgentModel(serverDB, userId);
const knowledgeBase = { id: 'kb1', userId, name: 'knowledgeBase' };
const fileList = [
diff --git a/src/database/server/models/__tests__/asyncTask.test.ts b/src/database/server/models/__tests__/asyncTask.test.ts
index 8c51f5dc3c49..5efafaf3214a 100644
--- a/src/database/server/models/__tests__/asyncTask.test.ts
+++ b/src/database/server/models/__tests__/asyncTask.test.ts
@@ -10,14 +10,8 @@ import { ASYNC_TASK_TIMEOUT, AsyncTaskModel } from '../asyncTask';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'async-task-model-test-user-id';
-const asyncTaskModel = new AsyncTaskModel(userId);
+const asyncTaskModel = new AsyncTaskModel(serverDB, userId);
beforeEach(async () => {
await serverDB.delete(users);
diff --git a/src/database/server/models/__tests__/chunk.test.ts b/src/database/server/models/__tests__/chunk.test.ts
index 3cf20b1bae89..e027ef362346 100644
--- a/src/database/server/models/__tests__/chunk.test.ts
+++ b/src/database/server/models/__tests__/chunk.test.ts
@@ -1,30 +1,18 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
-import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDBInstance } from '@/database/server/core/dbForTest';
+import { uuid } from '@/utils/uuid';
-import {
- chunks,
- embeddings,
- fileChunks,
- files,
- unstructuredChunks,
- users,
-} from '../../../schemas';
+import { chunks, embeddings, fileChunks, files, unstructuredChunks, users } from '../../../schemas';
import { ChunkModel } from '../chunk';
import { codeEmbedding, designThinkingQuery, designThinkingQuery2 } from './fixtures/embedding';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'chunk-model-test-user-id';
-const chunkModel = new ChunkModel(userId);
+const chunkModel = new ChunkModel(serverDB, userId);
const sharedFileList = [
{
id: '1',
@@ -79,6 +67,27 @@ describe('ChunkModel', () => {
expect(createdChunks[0]).toMatchObject(params[0]);
expect(createdChunks[1]).toMatchObject(params[1]);
});
+
+ // 测试空参数场景
+ it('should handle empty params array', async () => {
+ const result = await chunkModel.bulkCreate([], '1');
+ expect(result).toHaveLength(0);
+ });
+
+ // 测试事务回滚
+ it('should rollback transaction on error', async () => {
+ const invalidParams = [
+ { text: 'Chunk 1', userId },
+ { index: 'abc', userId }, // 这会导致错误
+ ] as any;
+
+ await expect(chunkModel.bulkCreate(invalidParams, '1')).rejects.toThrow();
+
+ const createdChunks = await serverDB.query.chunks.findMany({
+ where: eq(chunks.userId, userId),
+ });
+ expect(createdChunks).toHaveLength(0);
+ });
});
describe('delete', () => {
@@ -191,6 +200,41 @@ describe('ChunkModel', () => {
expect(result[1].id).toBe(chunk2.id);
expect(result[0].similarity).toBeGreaterThan(result[1].similarity);
});
+ // 补充无文件 ID 的搜索场景
+ it('should perform semantic search without fileIds', async () => {
+ const [chunk1, chunk2] = await serverDB
+ .insert(chunks)
+ .values([
+ { text: 'Test Chunk 1', userId },
+ { text: 'Test Chunk 2', userId },
+ ])
+ .returning();
+
+ await serverDB.insert(embeddings).values([
+ { chunkId: chunk1.id, embeddings: designThinkingQuery, userId },
+ { chunkId: chunk2.id, embeddings: codeEmbedding, userId },
+ ]);
+
+ const result = await chunkModel.semanticSearch({
+ embedding: designThinkingQuery2,
+ fileIds: undefined,
+ query: 'design thinking',
+ });
+
+ expect(result).toBeDefined();
+ expect(result).toHaveLength(2);
+ });
+
+ // 测试空结果场景
+ it('should return empty array when no matches found', async () => {
+ const result = await chunkModel.semanticSearch({
+ embedding: designThinkingQuery,
+ fileIds: ['non-existent-file'],
+ query: 'no matches',
+ });
+
+ expect(result).toHaveLength(0);
+ });
});
describe('bulkCreateUnstructuredChunks', () => {
@@ -391,5 +435,100 @@ content in Table html is below:
`);
});
+
+ it('should handle null text', () => {
+ const chunk = {
+ text: null,
+ type: 'Text',
+ metadata: {},
+ };
+
+ const result = chunkModel['mapChunkText'](chunk);
+ expect(result).toBeNull();
+ });
+
+ it('should handle missing metadata for Table type', () => {
+ const chunk = {
+ text: 'Table text',
+ type: 'Table',
+ metadata: {},
+ };
+
+ const result = chunkModel['mapChunkText'](chunk);
+ expect(result).toContain('Table text');
+ expect(result).toContain('content in Table html is below:');
+ expect(result).toContain('undefined'); // metadata.text_as_html is undefined
+ });
+ });
+
+ describe('findById', () => {
+ it('should find a chunk by id', async () => {
+ // Create a test chunk
+ const [chunk] = await serverDB
+ .insert(chunks)
+ .values({ text: 'Test Chunk', userId })
+ .returning();
+
+ const result = await chunkModel.findById(chunk.id);
+
+ expect(result).toBeDefined();
+ expect(result?.id).toBe(chunk.id);
+ expect(result?.text).toBe('Test Chunk');
+ });
+
+ it('should return null for non-existent id', async () => {
+ const result = await chunkModel.findById(uuid());
+ expect(result).toBeUndefined();
+ });
+ });
+
+ describe('semanticSearchForChat', () => {
+ // 测试空文件 ID 列表场景
+ it('should return empty array when fileIds is empty', async () => {
+ const result = await chunkModel.semanticSearchForChat({
+ embedding: designThinkingQuery,
+ fileIds: [],
+ query: 'test',
+ });
+
+ expect(result).toHaveLength(0);
+ });
+
+ // 测试结果限制
+ it('should limit results to 5 items', async () => {
+ const fileId = '1';
+ // Create 6 chunks
+ const chunkResult = await serverDB
+ .insert(chunks)
+ .values(
+ Array(6)
+ .fill(0)
+ .map((_, i) => ({ text: `Test Chunk ${i}`, userId })),
+ )
+ .returning();
+
+ await serverDB.insert(fileChunks).values(
+ chunkResult.map((chunk) => ({
+ fileId,
+ chunkId: chunk.id,
+ })),
+ );
+
+ await serverDB.insert(embeddings).values(
+ chunkResult.map((chunk) => ({
+ chunkId: chunk.id,
+ embeddings: designThinkingQuery,
+ userId,
+ })),
+ );
+
+ const result = await chunkModel.semanticSearchForChat({
+ embedding: designThinkingQuery2,
+ fileIds: [fileId],
+ query: 'test',
+ });
+
+ expect(result).toHaveLength(5);
+ });
});
});
diff --git a/src/database/server/models/__tests__/file.test.ts b/src/database/server/models/__tests__/file.test.ts
index 4fcea38cd482..9d94b433e28c 100644
--- a/src/database/server/models/__tests__/file.test.ts
+++ b/src/database/server/models/__tests__/file.test.ts
@@ -2,27 +2,14 @@
import { eq, inArray } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
-import { getServerDBConfig, serverDBEnv } from '@/config/db';
import { getTestDBInstance } from '@/database/server/core/dbForTest';
import { FilesTabs, SortType } from '@/types/files';
-import {
- files,
- globalFiles,
- knowledgeBaseFiles,
- knowledgeBases,
- users,
-} from '../../../schemas';
+import { files, globalFiles, knowledgeBaseFiles, knowledgeBases, users } from '../../../schemas';
import { FileModel } from '../file';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
let DISABLE_REMOVE_GLOBAL_FILE = false;
vi.mock('@/config/db', async () => ({
@@ -38,7 +25,7 @@ vi.mock('@/config/db', async () => ({
}));
const userId = 'file-model-test-user-id';
-const fileModel = new FileModel(userId);
+const fileModel = new FileModel(serverDB, userId);
const knowledgeBase = { id: 'kb1', userId, name: 'knowledgeBase' };
beforeEach(async () => {
@@ -603,4 +590,125 @@ describe('FileModel', () => {
expect(size).toBe(3500);
});
});
+
+ describe('findByNames', () => {
+ it('should find files by names', async () => {
+ // 准备测试数据
+ const fileList = [
+ {
+ name: 'test1.txt',
+ url: 'https://example.com/test1.txt',
+ size: 100,
+ fileType: 'text/plain',
+ userId,
+ },
+ {
+ name: 'test2.txt',
+ url: 'https://example.com/test2.txt',
+ size: 200,
+ fileType: 'text/plain',
+ userId,
+ },
+ {
+ name: 'other.txt',
+ url: 'https://example.com/other.txt',
+ size: 300,
+ fileType: 'text/plain',
+ userId,
+ },
+ ];
+
+ await serverDB.insert(files).values(fileList);
+
+ // 测试查找文件
+ const result = await fileModel.findByNames(['test1', 'test2']);
+ expect(result).toHaveLength(2);
+ expect(result.map((f) => f.name)).toContain('test1.txt');
+ expect(result.map((f) => f.name)).toContain('test2.txt');
+ });
+
+ it('should return empty array when no files match names', async () => {
+ const result = await fileModel.findByNames(['nonexistent']);
+ expect(result).toHaveLength(0);
+ });
+
+ it('should only find files belonging to current user', async () => {
+ // 准备测试数据
+ await serverDB.insert(files).values([
+ {
+ name: 'test1.txt',
+ url: 'https://example.com/test1.txt',
+ size: 100,
+ fileType: 'text/plain',
+ userId,
+ },
+ {
+ name: 'test2.txt',
+ url: 'https://example.com/test2.txt',
+ size: 200,
+ fileType: 'text/plain',
+ userId: 'user2', // 不同用户的文件
+ },
+ ]);
+
+ const result = await fileModel.findByNames(['test']);
+ expect(result).toHaveLength(1);
+ expect(result[0].name).toBe('test1.txt');
+ });
+ });
+
+ describe('deleteGlobalFile', () => {
+ it('should delete global file by hashId', async () => {
+ // 准备测试数据
+ const globalFile = {
+ hashId: 'test-hash',
+ fileType: 'text/plain',
+ size: 100,
+ url: 'https://example.com/global-file.txt',
+ metadata: { key: 'value' },
+ };
+
+ await serverDB.insert(globalFiles).values(globalFile);
+
+ // 执行删除操作
+ await fileModel.deleteGlobalFile('test-hash');
+
+ // 验证文件已被删除
+ const result = await serverDB.query.globalFiles.findFirst({
+ where: eq(globalFiles.hashId, 'test-hash'),
+ });
+ expect(result).toBeUndefined();
+ });
+
+ it('should not throw error when deleting non-existent global file', async () => {
+ // 删除不存在的文件不应抛出错误
+ await expect(fileModel.deleteGlobalFile('non-existent-hash')).resolves.not.toThrow();
+ });
+
+ it('should only delete specified global file', async () => {
+ // 准备测试数据
+ const globalFiles1 = {
+ hashId: 'hash1',
+ fileType: 'text/plain',
+ size: 100,
+ url: 'https://example.com/file1.txt',
+ };
+ const globalFiles2 = {
+ hashId: 'hash2',
+ fileType: 'text/plain',
+ size: 200,
+ url: 'https://example.com/file2.txt',
+ };
+
+ await serverDB.insert(globalFiles).values([globalFiles1, globalFiles2]);
+
+ // 删除一个文件
+ await fileModel.deleteGlobalFile('hash1');
+
+ // 验证只有指定文件被删除
+ const remainingFiles = await serverDB.query.globalFiles.findMany();
+ expect(remainingFiles).toHaveLength(1);
+ expect(remainingFiles[0].hashId).toBe('hash2');
+ });
+ });
});
diff --git a/src/database/server/models/__tests__/knowledgeBase.test.ts b/src/database/server/models/__tests__/knowledgeBase.test.ts
index 38bea2345a62..008ac19c58a2 100644
--- a/src/database/server/models/__tests__/knowledgeBase.test.ts
+++ b/src/database/server/models/__tests__/knowledgeBase.test.ts
@@ -1,7 +1,7 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { and, desc } from 'drizzle-orm/expressions';
-import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDBInstance } from '@/database/server/core/dbForTest';
@@ -17,14 +17,8 @@ import { KnowledgeBaseModel } from '../knowledgeBase';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'session-group-model-test-user-id';
-const knowledgeBaseModel = new KnowledgeBaseModel(userId);
+const knowledgeBaseModel = new KnowledgeBaseModel(serverDB, userId);
beforeEach(async () => {
await serverDB.delete(users);
@@ -82,7 +76,7 @@ describe('KnowledgeBaseModel', () => {
await knowledgeBaseModel.create({ name: 'Test Group 1' });
await knowledgeBaseModel.create({ name: 'Test Group 333' });
- const anotherSessionGroupModel = new KnowledgeBaseModel('user2');
+ const anotherSessionGroupModel = new KnowledgeBaseModel(serverDB, 'user2');
await anotherSessionGroupModel.create({ name: 'Test Group 2' });
await knowledgeBaseModel.deleteAll();
@@ -235,7 +229,7 @@ describe('KnowledgeBaseModel', () => {
it('should find a knowledge base by id without user restriction', async () => {
const { id } = await knowledgeBaseModel.create({ name: 'Test Group' });
- const group = await KnowledgeBaseModel.findById(id);
+ const group = await KnowledgeBaseModel.findById(serverDB, id);
expect(group).toMatchObject({
id,
name: 'Test Group',
@@ -244,10 +238,10 @@ describe('KnowledgeBaseModel', () => {
});
it('should find a knowledge base created by another user', async () => {
- const anotherKnowledgeBaseModel = new KnowledgeBaseModel('user2');
+ const anotherKnowledgeBaseModel = new KnowledgeBaseModel(serverDB, 'user2');
const { id } = await anotherKnowledgeBaseModel.create({ name: 'Another User Group' });
- const group = await KnowledgeBaseModel.findById(id);
+ const group = await KnowledgeBaseModel.findById(serverDB, id);
expect(group).toMatchObject({
id,
name: 'Another User Group',
diff --git a/src/database/server/models/__tests__/message.test.ts b/src/database/server/models/__tests__/message.test.ts
index 262082d3320d..653a41d9b950 100644
--- a/src/database/server/models/__tests__/message.test.ts
+++ b/src/database/server/models/__tests__/message.test.ts
@@ -2,10 +2,16 @@ import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { getTestDBInstance } from '@/database/server/core/dbForTest';
+import { uuid } from '@/utils/uuid';
import {
+ chunks,
+ embeddings,
+ fileChunks,
files,
messagePlugins,
+ messageQueries,
+ messageQueryChunks,
messageTTS,
messageTranslates,
messages,
@@ -15,17 +21,12 @@ import {
users,
} from '../../../schemas';
import { MessageModel } from '../message';
+import { codeEmbedding } from './fixtures/embedding';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'message-db';
-const messageModel = new MessageModel(userId);
+const messageModel = new MessageModel(serverDB, userId);
beforeEach(async () => {
// 在每个测试用例之前,清空表
@@ -273,6 +274,93 @@ describe('MessageModel', () => {
const result3 = await messageModel.query({ current: 2, pageSize: 2 });
expect(result3).toHaveLength(0);
});
+
+ // 补充测试复杂查询场景
+ it('should handle complex query with multiple joins and file chunks', async () => {
+ await serverDB.transaction(async (trx) => {
+ const chunk1Id = uuid();
+ const query1Id = uuid();
+ // 创建基础消息
+ await trx.insert(messages).values({
+ id: 'msg1',
+ userId,
+ role: 'user',
+ content: 'test message',
+ createdAt: new Date('2023-01-01'),
+ });
+
+ // 创建文件
+ await trx.insert(files).values([
+ {
+ id: 'file1',
+ userId,
+ name: 'test.txt',
+ url: 'test-url',
+ fileType: 'text/plain',
+ size: 100,
+ },
+ ]);
+
+ // 创建文件块
+ await trx.insert(chunks).values({
+ id: chunk1Id,
+ text: 'chunk content',
+ });
+
+ // 关联消息和文件
+ await trx.insert(messagesFiles).values({
+ messageId: 'msg1',
+ fileId: 'file1',
+ });
+
+ // 创建文件块关联
+ await trx.insert(fileChunks).values({
+ fileId: 'file1',
+ chunkId: chunk1Id,
+ });
+
+ // 创建消息查询
+ await trx.insert(messageQueries).values({
+ id: query1Id,
+ messageId: 'msg1',
+ userQuery: 'original query',
+ rewriteQuery: 'rewritten query',
+ });
+
+ // 创建消息查询块关联
+ await trx.insert(messageQueryChunks).values({
+ messageId: 'msg1',
+ queryId: query1Id,
+ chunkId: chunk1Id,
+ similarity: '0.95',
+ });
+ });
+
+ const result = await messageModel.query();
+
+ expect(result).toHaveLength(1);
+ expect(result[0].chunksList).toHaveLength(1);
+ expect(result[0].chunksList[0]).toMatchObject({
+ text: 'chunk content',
+ similarity: 0.95,
+ });
+ });
+
+ it('should return empty arrays for files and chunks if none exist', async () => {
+ await serverDB.insert(messages).values({
+ id: 'msg1',
+ userId,
+ role: 'user',
+ content: 'test message',
+ });
+
+ const result = await messageModel.query();
+
+ expect(result).toHaveLength(1);
+ expect(result[0].fileList).toEqual([]);
+ expect(result[0].imageList).toEqual([]);
+ expect(result[0].chunksList).toEqual([]);
+ });
});
describe('queryAll', () => {
@@ -1021,4 +1109,139 @@ describe('MessageModel', () => {
expect(result).toBe(2);
});
});
+
+ describe('findMessageQueriesById', () => {
+ it('should return undefined for non-existent message query', async () => {
+ const result = await messageModel.findMessageQueriesById('non-existent-id');
+ expect(result).toBeUndefined();
+ });
+
+ it('should return message query with embeddings', async () => {
+ const query1Id = uuid();
+ const embeddings1Id = uuid();
+
+ await serverDB.transaction(async (trx) => {
+ await trx.insert(messages).values({ id: 'msg1', userId, role: 'user', content: 'abc' });
+
+ await trx.insert(embeddings).values({
+ id: embeddings1Id,
+ embeddings: codeEmbedding,
+ });
+
+ await trx.insert(messageQueries).values({
+ id: query1Id,
+ messageId: 'msg1',
+ userQuery: 'test query',
+ rewriteQuery: 'rewritten query',
+ embeddingsId: embeddings1Id,
+ });
+ });
+
+ const result = await messageModel.findMessageQueriesById('msg1');
+
+ expect(result).toBeDefined();
+ expect(result).toMatchObject({
+ id: query1Id,
+ userQuery: 'test query',
+ rewriteQuery: 'rewritten query',
+ embeddings: codeEmbedding,
+ });
+ });
+ });
+
+ describe('deleteMessagesBySession', () => {
+ it('should delete messages by session ID', async () => {
+ await serverDB.insert(sessions).values([
+ { id: 'session1', userId },
+ { id: 'session2', userId },
+ ]);
+
+ await serverDB.insert(messages).values([
+ {
+ id: '1',
+ userId,
+ sessionId: 'session1',
+ role: 'user',
+ content: 'message 1',
+ },
+ {
+ id: '2',
+ userId,
+ sessionId: 'session1',
+ role: 'assistant',
+ content: 'message 2',
+ },
+ {
+ id: '3',
+ userId,
+ sessionId: 'session2',
+ role: 'user',
+ content: 'message 3',
+ },
+ ]);
+
+ await messageModel.deleteMessagesBySession('session1');
+
+ const remainingMessages = await serverDB
+ .select()
+ .from(messages)
+ .where(eq(messages.userId, userId));
+
+ expect(remainingMessages).toHaveLength(1);
+ expect(remainingMessages[0].id).toBe('3');
+ });
+
+ it('should delete messages by session ID and topic ID', async () => {
+ await serverDB.insert(sessions).values([{ id: 'session1', userId }]);
+ await serverDB.insert(topics).values([
+ { id: 'topic1', sessionId: 'session1', userId },
+ { id: 'topic2', sessionId: 'session1', userId },
+ ]);
+
+ await serverDB.insert(messages).values([
+ {
+ id: '1',
+ userId,
+ sessionId: 'session1',
+ topicId: 'topic1',
+ role: 'user',
+ content: 'message 1',
+ },
+ {
+ id: '2',
+ userId,
+ sessionId: 'session1',
+ topicId: 'topic2',
+ role: 'assistant',
+ content: 'message 2',
+ },
+ ]);
+
+ await messageModel.deleteMessagesBySession('session1', 'topic1');
+
+ const remainingMessages = await serverDB
+ .select()
+ .from(messages)
+ .where(eq(messages.userId, userId));
+
+ expect(remainingMessages).toHaveLength(1);
+ expect(remainingMessages[0].id).toBe('2');
+ });
+ });
+
+ describe('genId', () => {
+ it('should generate unique message IDs', () => {
+ const model = new MessageModel(serverDB, userId);
+ // @ts-ignore - accessing private method for testing
+ const id1 = model.genId();
+ // @ts-ignore - accessing private method for testing
+ const id2 = model.genId();
+
+ expect(id1).toHaveLength(18);
+ expect(id2).toHaveLength(18);
+ expect(id1).not.toBe(id2);
+ expect(id1).toMatch(/^msg_/);
+ expect(id2).toMatch(/^msg_/);
+ });
+ });
});
diff --git a/src/database/server/models/__tests__/nextauth.test.ts b/src/database/server/models/__tests__/nextauth.test.ts
index 6038f11f1c73..82699ed0542a 100644
--- a/src/database/server/models/__tests__/nextauth.test.ts
+++ b/src/database/server/models/__tests__/nextauth.test.ts
@@ -8,7 +8,6 @@ import type {
import { eq } from 'drizzle-orm';
import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest';
-import { getTestDBInstance } from '@/database/server/core/dbForTest';
import {
nextauthAccounts,
nextauthAuthenticators,
@@ -16,17 +15,12 @@ import {
nextauthVerificationTokens,
users,
} from '@/database/schemas';
+import { getTestDBInstance } from '@/database/server/core/dbForTest';
import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter';
let serverDB = await getTestDBInstance();
let nextAuthAdapter = LobeNextAuthDbAdapter(serverDB);
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'user-db';
const user: AdapterUser = {
id: userId,
diff --git a/src/database/server/models/__tests__/plugin.test.ts b/src/database/server/models/__tests__/plugin.test.ts
index 449cac8cd73b..8ea58e06dc90 100644
--- a/src/database/server/models/__tests__/plugin.test.ts
+++ b/src/database/server/models/__tests__/plugin.test.ts
@@ -8,14 +8,8 @@ import { PluginModel } from '../plugin';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'plugin-db';
-const pluginModel = new PluginModel(userId);
+const pluginModel = new PluginModel(serverDB, userId);
beforeEach(async () => {
await serverDB.transaction(async (trx) => {
diff --git a/src/database/server/models/__tests__/session.test.ts b/src/database/server/models/__tests__/session.test.ts
index bccea65f9277..923504d25ab3 100644
--- a/src/database/server/models/__tests__/session.test.ts
+++ b/src/database/server/models/__tests__/session.test.ts
@@ -1,7 +1,9 @@
-import { eq, inArray } from 'drizzle-orm';
+import { and, eq, inArray } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
+import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
import { getTestDBInstance } from '@/database/server/core/dbForTest';
+import { idGenerator } from '@/database/utils/idGenerator';
import {
NewSession,
@@ -14,19 +16,12 @@ import {
topics,
users,
} from '../../../schemas';
-import { idGenerator } from '@/database/utils/idGenerator';
import { SessionModel } from '../session';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'session-user';
-const sessionModel = new SessionModel(userId);
+const sessionModel = new SessionModel(serverDB, userId);
beforeEach(async () => {
await serverDB.delete(users);
@@ -259,7 +254,13 @@ describe('SessionModel', () => {
]);
await serverDB.insert(agents).values([
- { id: 'agent-1', userId, model: 'gpt-3.5-turbo', title: 'Agent 1', description: 'Description with Keyword' },
+ {
+ id: 'agent-1',
+ userId,
+ model: 'gpt-3.5-turbo',
+ title: 'Agent 1',
+ description: 'Description with Keyword',
+ },
{ id: 'agent-2', userId, model: 'gpt-4', title: 'Agent 2' },
]);
@@ -338,7 +339,7 @@ describe('SessionModel', () => {
});
});
- describe.skip('batchCreate', () => {
+ describe('batchCreate', () => {
it('should batch create sessions', async () => {
// 调用 batchCreate 方法
const sessions: NewSession[] = [
@@ -624,4 +625,161 @@ describe('SessionModel', () => {
expect(await serverDB.select().from(sessions).where(eq(sessions.id, '3'))).toHaveLength(1);
});
});
+
+ // 在原有的 describe('SessionModel') 中添加以下测试套件
+
+ describe('createInbox', () => {
+ it('should create inbox session if not exists', async () => {
+ const inbox = await sessionModel.createInbox();
+
+ expect(inbox).toBeDefined();
+ expect(inbox?.slug).toBe('inbox');
+
+ // verify agent config
+ const session = await sessionModel.findByIdOrSlug('inbox');
+ expect(session?.agent).toBeDefined();
+ expect(session?.agent.model).toBe(DEFAULT_AGENT_CONFIG.model);
+ });
+
+ it('should not create duplicate inbox session', async () => {
+ // Create first inbox
+ await sessionModel.createInbox();
+
+ // Try to create another inbox
+ const duplicateInbox = await sessionModel.createInbox();
+
+ // Should return undefined as inbox already exists
+ expect(duplicateInbox).toBeUndefined();
+
+ // Verify only one inbox exists
+ const sessions = await serverDB.query.sessions.findMany();
+
+ const inboxSessions = sessions.filter((s) => s.slug === 'inbox');
+ expect(inboxSessions).toHaveLength(1);
+ });
+ });
+
+ describe('deleteAll', () => {
+ it('should delete all sessions for current user', async () => {
+ // Create test data
+ await serverDB.insert(sessions).values([
+ { id: '1', userId },
+ { id: '2', userId },
+ { id: '3', userId },
+ ]);
+
+ // Create sessions for another user that should not be deleted
+ await serverDB.insert(users).values([{ id: 'other-user' }]);
+ await serverDB.insert(sessions).values([
+ { id: '4', userId: 'other-user' },
+ { id: '5', userId: 'other-user' },
+ ]);
+
+ await sessionModel.deleteAll();
+
+ // Verify all sessions for current user are deleted
+ const remainingSessions = await serverDB
+ .select()
+ .from(sessions)
+ .where(eq(sessions.userId, userId));
+ expect(remainingSessions).toHaveLength(0);
+
+ // Verify other user's sessions are not deleted
+ const otherUserSessions = await serverDB
+ .select()
+ .from(sessions)
+ .where(eq(sessions.userId, 'other-user'));
+ expect(otherUserSessions).toHaveLength(2);
+ });
+
+ it('should delete associated data when deleting all sessions', async () => {
+ // Create test data with associated records
+ await serverDB.transaction(async (trx) => {
+ await trx.insert(sessions).values([
+ { id: '1', userId },
+ { id: '2', userId },
+ ]);
+
+ await trx.insert(topics).values([
+ { id: 't1', sessionId: '1', userId },
+ { id: 't2', sessionId: '2', userId },
+ ]);
+
+ await trx.insert(messages).values([
+ { id: 'm1', sessionId: '1', userId, role: 'user' },
+ { id: 'm2', sessionId: '2', userId, role: 'assistant' },
+ ]);
+ });
+
+ await sessionModel.deleteAll();
+
+ // Verify all associated data is deleted
+ const remainingTopics = await serverDB.select().from(topics).where(eq(topics.userId, userId));
+ expect(remainingTopics).toHaveLength(0);
+
+ const remainingMessages = await serverDB
+ .select()
+ .from(messages)
+ .where(eq(messages.userId, userId));
+ expect(remainingMessages).toHaveLength(0);
+ });
+ });
+
+ describe('updateConfig', () => {
+ it('should update agent config', async () => {
+ // Create test agent
+ const agentId = 'test-agent';
+ await serverDB.insert(agents).values({
+ id: agentId,
+ userId,
+ model: 'gpt-3.5-turbo',
+ title: 'Original Title',
+ });
+
+ // Update config
+ await sessionModel.updateConfig(agentId, {
+ model: 'gpt-4',
+ title: 'Updated Title',
+ description: 'New description',
+ });
+
+ // Verify update
+ const updatedAgent = await serverDB
+ .select()
+ .from(agents)
+ .where(and(eq(agents.id, agentId), eq(agents.userId, userId)));
+
+ expect(updatedAgent[0]).toMatchObject({
+ model: 'gpt-4',
+ title: 'Updated Title',
+ description: 'New description',
+ });
+ });
+
+ it('should not update config for other users agents', async () => {
+ // Create agent for another user
+ const agentId = 'other-agent';
+ await serverDB.insert(users).values([{ id: 'other-user' }]);
+ await serverDB.insert(agents).values({
+ id: agentId,
+ userId: 'other-user',
+ model: 'gpt-3.5-turbo',
+ title: 'Original Title',
+ });
+
+ // Try to update other user's agent
+ await sessionModel.updateConfig(agentId, {
+ model: 'gpt-4',
+ title: 'Updated Title',
+ });
+
+ // Verify no changes were made
+ const agent = await serverDB.select().from(agents).where(eq(agents.id, agentId));
+
+ expect(agent[0]).toMatchObject({
+ model: 'gpt-3.5-turbo',
+ title: 'Original Title',
+ });
+ });
+ });
});
diff --git a/src/database/server/models/__tests__/sessionGroup.test.ts b/src/database/server/models/__tests__/sessionGroup.test.ts
index a94be4447526..2341eecdc09a 100644
--- a/src/database/server/models/__tests__/sessionGroup.test.ts
+++ b/src/database/server/models/__tests__/sessionGroup.test.ts
@@ -10,14 +10,8 @@ import { SessionGroupModel } from '../sessionGroup';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'session-group-model-test-user-id';
-const sessionGroupModel = new SessionGroupModel(userId);
+const sessionGroupModel = new SessionGroupModel(serverDB, userId);
beforeEach(async () => {
await serverDB.delete(users);
@@ -75,7 +69,7 @@ describe('SessionGroupModel', () => {
await sessionGroupModel.create({ name: 'Test Group 1' });
await sessionGroupModel.create({ name: 'Test Group 333' });
- const anotherSessionGroupModel = new SessionGroupModel('user2');
+ const anotherSessionGroupModel = new SessionGroupModel(serverDB, 'user2');
await anotherSessionGroupModel.create({ name: 'Test Group 2' });
await sessionGroupModel.deleteAll();
diff --git a/src/database/server/models/__tests__/topic.test.ts b/src/database/server/models/__tests__/topic.test.ts
index ef205f117986..358d1ec4bca2 100644
--- a/src/database/server/models/__tests__/topic.test.ts
+++ b/src/database/server/models/__tests__/topic.test.ts
@@ -8,15 +8,9 @@ import { CreateTopicParams, TopicModel } from '../topic';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'topic-user-test';
const sessionId = 'topic-session';
-const topicModel = new TopicModel(userId);
+const topicModel = new TopicModel(serverDB, userId);
describe('TopicModel', () => {
beforeEach(async () => {
diff --git a/src/database/server/models/__tests__/user.test.ts b/src/database/server/models/__tests__/user.test.ts
index 1004e5590cc1..9496a9fedb4e 100644
--- a/src/database/server/models/__tests__/user.test.ts
+++ b/src/database/server/models/__tests__/user.test.ts
@@ -13,15 +13,9 @@ import { UserModel } from '../user';
let serverDB = await getTestDBInstance();
-vi.mock('@/database/server/core/db', async () => ({
- get serverDB() {
- return serverDB;
- },
-}));
-
const userId = 'user-db';
const userEmail = 'user@example.com';
-const userModel = new UserModel();
+const userModel = new UserModel(serverDB, userId);
beforeEach(async () => {
await serverDB.delete(users);
@@ -44,14 +38,14 @@ describe('UserModel', () => {
email: 'test@example.com',
};
- await UserModel.createUser(params);
+ await UserModel.createUser(serverDB, params);
const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
expect(user).not.toBeNull();
expect(user?.username).toBe('testuser');
expect(user?.email).toBe('test@example.com');
- const sessionModel = new SessionModel(userId);
+ const sessionModel = new SessionModel(serverDB, userId);
const inbox = await sessionModel.findByIdOrSlug(INBOX_SESSION_ID);
expect(inbox).not.toBeNull();
});
@@ -61,7 +55,7 @@ describe('UserModel', () => {
it('should delete a user', async () => {
await serverDB.insert(users).values({ id: userId });
- await UserModel.deleteUser(userId);
+ await UserModel.deleteUser(serverDB, userId);
const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
expect(user).toBeUndefined();
@@ -72,7 +66,7 @@ describe('UserModel', () => {
it('should find a user by ID', async () => {
await serverDB.insert(users).values({ id: userId, username: 'testuser' });
- const user = await UserModel.findById(userId);
+ const user = await UserModel.findById(serverDB, userId);
expect(user).not.toBeNull();
expect(user?.id).toBe(userId);
@@ -84,7 +78,7 @@ describe('UserModel', () => {
it('should find a user by email', async () => {
await serverDB.insert(users).values({ id: userId, email: userEmail });
- const user = await UserModel.findByEmail(userEmail);
+ const user = await UserModel.findByEmail(serverDB, userEmail);
expect(user).not.toBeNull();
expect(user?.id).toBe(userId);
@@ -107,7 +101,7 @@ describe('UserModel', () => {
keyVaults: encryptedKeyVaults,
});
- const state = await userModel.getUserState(userId);
+ const state = await userModel.getUserState();
expect(state.userId).toBe(userId);
expect(state.preference).toEqual(preference);
@@ -115,7 +109,9 @@ describe('UserModel', () => {
});
it('should throw an error if user not found', async () => {
- await expect(userModel.getUserState('invalid-user-id')).rejects.toThrow('user not found');
+ const userModel = new UserModel(serverDB, 'invalid-user-id');
+
+ await expect(userModel.getUserState()).rejects.toThrow('user not found');
});
});
@@ -123,7 +119,7 @@ describe('UserModel', () => {
it('should update user fields', async () => {
await serverDB.insert(users).values({ id: userId, username: 'oldname' });
- await userModel.updateUser(userId, { username: 'newname' });
+ await userModel.updateUser({ username: 'newname' });
const updatedUser = await serverDB.query.users.findFirst({
where: eq(users.id, userId),
@@ -137,7 +133,7 @@ describe('UserModel', () => {
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(userSettings).values({ id: userId });
- await userModel.deleteSetting(userId);
+ await userModel.deleteSetting();
const settings = await serverDB.query.userSettings.findFirst({
where: eq(users.id, userId),
@@ -155,7 +151,7 @@ describe('UserModel', () => {
} as UserSettings;
await serverDB.insert(users).values({ id: userId });
- await userModel.updateSetting(userId, settings);
+ await userModel.updateSetting(settings);
const updatedSettings = await serverDB.query.userSettings.findFirst({
where: eq(users.id, userId),
@@ -178,7 +174,7 @@ describe('UserModel', () => {
const newSettings = {
general: { fontSize: 16, language: 'zh-CN', themeMode: 'dark' },
} as UserSettings;
- await userModel.updateSetting(userId, newSettings);
+ await userModel.updateSetting(newSettings);
const updatedSettings = await serverDB.query.userSettings.findFirst({
where: eq(users.id, userId),
@@ -195,7 +191,7 @@ describe('UserModel', () => {
const newPreference: Partial = {
guide: { topic: true, moveSettingsToAvatar: true },
};
- await userModel.updatePreference(userId, newPreference);
+ await userModel.updatePreference(newPreference);
const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
expect(updatedUser?.preference).toEqual({ ...preference, ...newPreference });
@@ -212,10 +208,49 @@ describe('UserModel', () => {
moveSettingsToAvatar: true,
uploadFileInKnowledgeBase: true,
};
- await userModel.updateGuide(userId, newGuide);
+ await userModel.updateGuide(newGuide);
const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) });
expect(updatedUser?.preference).toEqual({ ...preference, guide: newGuide });
});
});
+
+ describe('getUserApiKeys', () => {
+ it('should get and decrypt user API keys', async () => {
+ const keyVaults = { openai: { apiKey: 'test-key' } };
+ const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
+ const encryptedKeyVaults = await gateKeeper.encrypt(JSON.stringify(keyVaults));
+
+ const userId = 'user-api-id';
+
+ await serverDB.insert(users).values({ id: userId });
+ await serverDB.insert(userSettings).values({
+ id: userId,
+ keyVaults: encryptedKeyVaults,
+ });
+
+ const result = await UserModel.getUserApiKeys(serverDB, userId);
+ expect(result).toEqual(keyVaults);
+ });
+
+ it('should throw error when user not found', async () => {
+ await expect(UserModel.getUserApiKeys(serverDB, 'non-existent-id')).rejects.toThrow(
+ 'user not found',
+ );
+ });
+
+ it('should handle decrypt failure and return empty object', async () => {
+ const userId = 'user-api-test-id';
+ // 模拟解密失败的情况
+ const invalidEncryptedData = 'invalid:-encrypted-:data';
+ await serverDB.insert(users).values({ id: userId });
+ await serverDB.insert(userSettings).values({
+ id: userId,
+ keyVaults: invalidEncryptedData,
+ });
+
+ const result = await UserModel.getUserApiKeys(serverDB, userId);
+ expect(result).toEqual({});
+ });
+ });
});
diff --git a/src/database/server/models/_template.ts b/src/database/server/models/_template.ts
index 82c0efa3da2d..a802cbcec71e 100644
--- a/src/database/server/models/_template.ts
+++ b/src/database/server/models/_template.ts
@@ -1,19 +1,21 @@
import { eq } from 'drizzle-orm';
import { and, desc } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
import { NewSessionGroup, SessionGroupItem, sessionGroups } from '../../schemas';
export class TemplateModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
create = async (params: NewSessionGroup) => {
- const [result] = await serverDB
+ const [result] = await this.db
.insert(sessionGroups)
.values({ ...params, userId: this.userId })
.returning();
@@ -22,30 +24,30 @@ export class TemplateModel {
};
delete = async (id: string) => {
- return serverDB
+ return this.db
.delete(sessionGroups)
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
};
deleteAll = async () => {
- return serverDB.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId));
+ return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId));
};
query = async () => {
- return serverDB.query.sessionGroups.findMany({
+ return this.db.query.sessionGroups.findMany({
orderBy: [desc(sessionGroups.updatedAt)],
where: eq(sessionGroups.userId, this.userId),
});
};
findById = async (id: string) => {
- return serverDB.query.sessionGroups.findFirst({
+ return this.db.query.sessionGroups.findFirst({
where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)),
});
};
async update(id: string, value: Partial) {
- return serverDB
+ return this.db
.update(sessionGroups)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
diff --git a/src/database/server/models/agent.ts b/src/database/server/models/agent.ts
index 34c627016c0b..03a779aabce1 100644
--- a/src/database/server/models/agent.ts
+++ b/src/database/server/models/agent.ts
@@ -1,7 +1,8 @@
import { inArray } from 'drizzle-orm';
import { and, desc, eq } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
+
import {
agents,
agentsFiles,
@@ -13,12 +14,15 @@ import {
export class AgentModel {
private userId: string;
- constructor(userId: string) {
+ private db: LobeChatDatabase;
+
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
async getAgentConfigById(id: string) {
- const agent = await serverDB.query.agents.findFirst({ where: eq(agents.id, id) });
+ const agent = await this.db.query.agents.findFirst({ where: eq(agents.id, id) });
const knowledge = await this.getAgentAssignedKnowledge(id);
@@ -26,14 +30,14 @@ export class AgentModel {
}
async getAgentAssignedKnowledge(id: string) {
- const knowledgeBaseResult = await serverDB
+ const knowledgeBaseResult = await this.db
.select({ enabled: agentsKnowledgeBases.enabled, knowledgeBases })
.from(agentsKnowledgeBases)
.where(eq(agentsKnowledgeBases.agentId, id))
.orderBy(desc(agentsKnowledgeBases.createdAt))
.leftJoin(knowledgeBases, eq(knowledgeBases.id, agentsKnowledgeBases.knowledgeBaseId));
- const fileResult = await serverDB
+ const fileResult = await this.db
.select({ enabled: agentsFiles.enabled, files })
.from(agentsFiles)
.where(eq(agentsFiles.agentId, id))
@@ -56,7 +60,7 @@ export class AgentModel {
* Find agent by session id
*/
async findBySessionId(sessionId: string) {
- const item = await serverDB.query.agentsToSessions.findFirst({
+ const item = await this.db.query.agentsToSessions.findFirst({
where: eq(agentsToSessions.sessionId, sessionId),
});
if (!item) return;
@@ -71,7 +75,7 @@ export class AgentModel {
knowledgeBaseId: string,
enabled: boolean = true,
) => {
- return serverDB
+ return this.db
.insert(agentsKnowledgeBases)
.values({
agentId,
@@ -83,7 +87,7 @@ export class AgentModel {
};
deleteAgentKnowledgeBase = async (agentId: string, knowledgeBaseId: string) => {
- return serverDB
+ return this.db
.delete(agentsKnowledgeBases)
.where(
and(
@@ -96,7 +100,7 @@ export class AgentModel {
};
toggleKnowledgeBase = async (agentId: string, knowledgeBaseId: string, enabled?: boolean) => {
- return serverDB
+ return this.db
.update(agentsKnowledgeBases)
.set({ enabled })
.where(
@@ -111,7 +115,7 @@ export class AgentModel {
createAgentFiles = async (agentId: string, fileIds: string[], enabled: boolean = true) => {
// Exclude the fileIds that already exist in agentsFiles, and then insert them
- const existingFiles = await serverDB
+ const existingFiles = await this.db
.select({ id: agentsFiles.fileId })
.from(agentsFiles)
.where(
@@ -128,7 +132,7 @@ export class AgentModel {
if (needToInsertFileIds.length === 0) return;
- return serverDB
+ return this.db
.insert(agentsFiles)
.values(
needToInsertFileIds.map((fileId) => ({ agentId, enabled, fileId, userId: this.userId })),
@@ -137,7 +141,7 @@ export class AgentModel {
};
deleteAgentFile = async (agentId: string, fileId: string) => {
- return serverDB
+ return this.db
.delete(agentsFiles)
.where(
and(
@@ -150,7 +154,7 @@ export class AgentModel {
};
toggleFile = async (agentId: string, fileId: string, enabled?: boolean) => {
- return serverDB
+ return this.db
.update(agentsFiles)
.set({ enabled })
.where(
diff --git a/src/database/server/models/asyncTask.ts b/src/database/server/models/asyncTask.ts
index 48ed6b210909..5eb90fec31d7 100644
--- a/src/database/server/models/asyncTask.ts
+++ b/src/database/server/models/asyncTask.ts
@@ -1,7 +1,7 @@
import { eq, inArray, lt } from 'drizzle-orm';
import { and } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
import {
AsyncTaskError,
AsyncTaskErrorType,
@@ -16,13 +16,15 @@ export const ASYNC_TASK_TIMEOUT = 298 * 1000;
export class AsyncTaskModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
create = async (params: Pick): Promise => {
- const data = await serverDB
+ const data = await this.db
.insert(asyncTasks)
.values({ ...params, userId: this.userId })
.returning();
@@ -31,17 +33,17 @@ export class AsyncTaskModel {
};
delete = async (id: string) => {
- return serverDB
+ return this.db
.delete(asyncTasks)
.where(and(eq(asyncTasks.id, id), eq(asyncTasks.userId, this.userId)));
};
findById = async (id: string) => {
- return serverDB.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) });
+ return this.db.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) });
};
update(taskId: string, value: Partial) {
- return serverDB
+ return this.db
.update(asyncTasks)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(asyncTasks.id, taskId)));
@@ -52,7 +54,7 @@ export class AsyncTaskModel {
if (taskIds.length > 0) {
await this.checkTimeoutTasks(taskIds);
- chunkTasks = await serverDB.query.asyncTasks.findMany({
+ chunkTasks = await this.db.query.asyncTasks.findMany({
where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type)),
});
}
@@ -64,7 +66,7 @@ export class AsyncTaskModel {
* make the task status to be `error` if the task is not finished in 20 seconds
*/
async checkTimeoutTasks(ids: string[]) {
- const tasks = await serverDB
+ const tasks = await this.db
.select({ id: asyncTasks.id })
.from(asyncTasks)
.where(
@@ -76,7 +78,7 @@ export class AsyncTaskModel {
);
if (tasks.length > 0) {
- await serverDB
+ await this.db
.update(asyncTasks)
.set({
error: new AsyncTaskError(
diff --git a/src/database/server/models/chunk.ts b/src/database/server/models/chunk.ts
index a00c434b4b2a..6b554fd58435 100644
--- a/src/database/server/models/chunk.ts
+++ b/src/database/server/models/chunk.ts
@@ -2,7 +2,7 @@ import { asc, cosineDistance, count, eq, inArray, sql } from 'drizzle-orm';
import { and, desc, isNull } from 'drizzle-orm/expressions';
import { chunk } from 'lodash-es';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
import { ChunkMetadata, FileChunk } from '@/types/chunk';
import {
@@ -18,12 +18,17 @@ import {
export class ChunkModel {
private userId: string;
- constructor(userId: string) {
+ private db: LobeChatDatabase;
+
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
bulkCreate = async (params: NewChunkItem[], fileId: string) => {
- return serverDB.transaction(async (trx) => {
+ return this.db.transaction(async (trx) => {
+ if (params.length === 0) return [];
+
const result = await trx.insert(chunks).values(params).returning();
const fileChunksData = result.map((chunk) => ({ chunkId: chunk.id, fileId }));
@@ -37,15 +42,15 @@ export class ChunkModel {
};
bulkCreateUnstructuredChunks = async (params: NewUnstructuredChunkItem[]) => {
- return serverDB.insert(unstructuredChunks).values(params);
+ return this.db.insert(unstructuredChunks).values(params);
};
delete = async (id: string) => {
- return serverDB.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId)));
+ return this.db.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId)));
};
deleteOrphanChunks = async () => {
- const orphanedChunks = await serverDB
+ const orphanedChunks = await this.db
.select({ chunkId: chunks.id })
.from(chunks)
.leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
@@ -56,7 +61,7 @@ export class ChunkModel {
const list = chunk(ids, 500);
- await serverDB.transaction(async (trx) => {
+ await this.db.transaction(async (trx) => {
await Promise.all(
list.map(async (chunkIds) => {
await trx.delete(chunks).where(inArray(chunks.id, chunkIds));
@@ -66,13 +71,13 @@ export class ChunkModel {
};
findById = async (id: string) => {
- return serverDB.query.chunks.findFirst({
+ return this.db.query.chunks.findFirst({
where: and(eq(chunks.id, id)),
});
};
async findByFileId(id: string, page = 0) {
- const data = await serverDB
+ const data = await this.db
.select({
abstract: chunks.abstract,
createdAt: chunks.createdAt,
@@ -98,7 +103,7 @@ export class ChunkModel {
}
async getChunksTextByFileId(id: string): Promise<{ id: string; text: string }[]> {
- const data = await serverDB
+ const data = await this.db
.select()
.from(chunks)
.innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
@@ -113,7 +118,7 @@ export class ChunkModel {
async countByFileIds(ids: string[]) {
if (ids.length === 0) return [];
- return serverDB
+ return this.db
.select({
count: count(fileChunks.chunkId),
id: fileChunks.fileId,
@@ -124,7 +129,7 @@ export class ChunkModel {
}
async countByFileId(ids: string) {
- const data = await serverDB
+ const data = await this.db
.select({
count: count(fileChunks.chunkId),
id: fileChunks.fileId,
@@ -146,7 +151,7 @@ export class ChunkModel {
}) {
const similarity = sql`1 - (${cosineDistance(embeddings.embeddings, embedding)})`;
- const data = await serverDB
+ const data = await this.db
.select({
fileId: fileChunks.fileId,
fileName: files.name,
@@ -185,7 +190,7 @@ export class ChunkModel {
if (!hasFiles) return [];
- const result = await serverDB
+ const result = await this.db
.select({
fileId: files.id,
fileName: files.name,
diff --git a/src/database/server/models/embedding.ts b/src/database/server/models/embedding.ts
index 45f5980ebf88..123e5e3a6083 100644
--- a/src/database/server/models/embedding.ts
+++ b/src/database/server/models/embedding.ts
@@ -1,19 +1,21 @@
import { count, eq } from 'drizzle-orm';
import { and } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
import { NewEmbeddingsItem, embeddings } from '../../schemas';
export class EmbeddingModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
create = async (value: Omit) => {
- const [item] = await serverDB
+ const [item] = await this.db
.insert(embeddings)
.values({ ...value, userId: this.userId })
.returning();
@@ -22,7 +24,7 @@ export class EmbeddingModel {
};
bulkCreate = async (values: Omit[]) => {
- return serverDB
+ return this.db
.insert(embeddings)
.values(values.map((item) => ({ ...item, userId: this.userId })))
.onConflictDoNothing({
@@ -31,25 +33,25 @@ export class EmbeddingModel {
};
delete = async (id: string) => {
- return serverDB
+ return this.db
.delete(embeddings)
.where(and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)));
};
query = async () => {
- return serverDB.query.embeddings.findMany({
+ return this.db.query.embeddings.findMany({
where: eq(embeddings.userId, this.userId),
});
};
findById = async (id: string) => {
- return serverDB.query.embeddings.findFirst({
+ return this.db.query.embeddings.findFirst({
where: and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)),
});
};
countUsage = async () => {
- const result = await serverDB
+ const result = await this.db
.select({
count: count(),
})
diff --git a/src/database/server/models/file.ts b/src/database/server/models/file.ts
index 8a01bfe71400..7c1f4ec6e5d2 100644
--- a/src/database/server/models/file.ts
+++ b/src/database/server/models/file.ts
@@ -3,7 +3,7 @@ import { and, desc, like } from 'drizzle-orm/expressions';
import type { PgTransaction } from 'drizzle-orm/pg-core';
import { serverDBEnv } from '@/config/db';
-import { serverDB } from '@/database/server/core/db';
+import { LobeChatDatabase } from '@/database/type';
import { FilesTabs, QueryFileListParams, SortType } from '@/types/files';
import {
@@ -20,13 +20,15 @@ import {
export class FileModel {
private readonly userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
create = async (params: Omit & { knowledgeBaseId?: string }) => {
- const result = await serverDB.transaction(async (trx) => {
+ const result = await this.db.transaction(async (trx) => {
const result = await trx
.insert(files)
.values({ ...params, userId: this.userId })
@@ -47,11 +49,11 @@ export class FileModel {
};
createGlobalFile = async (file: Omit) => {
- return serverDB.insert(globalFiles).values(file).returning();
+ return this.db.insert(globalFiles).values(file).returning();
};
checkHash = async (hash: string) => {
- const item = await serverDB.query.globalFiles.findFirst({
+ const item = await this.db.query.globalFiles.findFirst({
where: eq(globalFiles.hashId, hash),
});
if (!item) return { isExist: false };
@@ -71,7 +73,7 @@ export class FileModel {
const fileHash = file.fileHash!;
- return await serverDB.transaction(async (trx) => {
+ return await this.db.transaction(async (trx) => {
// 1. 删除相关的 chunks
await this.deleteFileChunks(trx as any, [id]);
@@ -96,11 +98,11 @@ export class FileModel {
};
deleteGlobalFile = async (hashId: string) => {
- return serverDB.delete(globalFiles).where(eq(globalFiles.hashId, hashId));
+ return this.db.delete(globalFiles).where(eq(globalFiles.hashId, hashId));
};
countUsage = async () => {
- const result = await serverDB
+ const result = await this.db
.select({
totalSize: sum(files.size),
})
@@ -114,7 +116,7 @@ export class FileModel {
const fileList = await this.findByIds(ids);
const hashList = fileList.map((file) => file.fileHash!);
- return await serverDB.transaction(async (trx) => {
+ return await this.db.transaction(async (trx) => {
// 1. 删除相关的 chunks
await this.deleteFileChunks(trx as any, ids);
@@ -159,7 +161,7 @@ export class FileModel {
};
clear = async () => {
- return serverDB.delete(files).where(eq(files.userId, this.userId));
+ return this.db.delete(files).where(eq(files.userId, this.userId));
};
query = async ({
@@ -198,7 +200,7 @@ export class FileModel {
}
// 3. build query
- let query = serverDB
+ let query = this.db
.select({
chunkTaskId: files.chunkTaskId,
createdAt: files.createdAt,
@@ -230,7 +232,7 @@ export class FileModel {
whereClause = and(
whereClause,
notExists(
- serverDB.select().from(knowledgeBaseFiles).where(eq(knowledgeBaseFiles.fileId, files.id)),
+ this.db.select().from(knowledgeBaseFiles).where(eq(knowledgeBaseFiles.fileId, files.id)),
),
);
}
@@ -240,19 +242,19 @@ export class FileModel {
};
findByIds = async (ids: string[]) => {
- return serverDB.query.files.findMany({
+ return this.db.query.files.findMany({
where: and(inArray(files.id, ids), eq(files.userId, this.userId)),
});
};
findById = async (id: string) => {
- return serverDB.query.files.findFirst({
+ return this.db.query.files.findFirst({
where: and(eq(files.id, id), eq(files.userId, this.userId)),
});
};
countFilesByHash = async (hash: string) => {
- const result = await serverDB
+ const result = await this.db
.select({
count: count(),
})
@@ -263,7 +265,7 @@ export class FileModel {
};
async update(id: string, value: Partial) {
- return serverDB
+ return this.db
.update(files)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(files.id, id), eq(files.userId, this.userId)));
@@ -293,7 +295,7 @@ export class FileModel {
};
async findByNames(fileNames: string[]) {
- return serverDB.query.files.findMany({
+ return this.db.query.files.findMany({
where: and(
or(...fileNames.map((name) => like(files.name, `${name}%`))),
eq(files.userId, this.userId),
diff --git a/src/database/server/models/knowledgeBase.ts b/src/database/server/models/knowledgeBase.ts
index 844558667a5b..6c765ab51603 100644
--- a/src/database/server/models/knowledgeBase.ts
+++ b/src/database/server/models/knowledgeBase.ts
@@ -1,22 +1,24 @@
import { eq, inArray } from 'drizzle-orm';
import { and, desc } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
import { KnowledgeBaseItem } from '@/types/knowledgeBase';
import { NewKnowledgeBase, knowledgeBaseFiles, knowledgeBases } from '../../schemas';
export class KnowledgeBaseModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
// create
create = async (params: Omit) => {
- const [result] = await serverDB
+ const [result] = await this.db
.insert(knowledgeBases)
.values({ ...params, userId: this.userId })
.returning();
@@ -25,7 +27,7 @@ export class KnowledgeBaseModel {
};
addFilesToKnowledgeBase = async (id: string, fileIds: string[]) => {
- return serverDB
+ return this.db
.insert(knowledgeBaseFiles)
.values(fileIds.map((fileId) => ({ fileId, knowledgeBaseId: id, userId: this.userId })))
.returning();
@@ -33,17 +35,17 @@ export class KnowledgeBaseModel {
// delete
delete = async (id: string) => {
- return serverDB
+ return this.db
.delete(knowledgeBases)
.where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)));
};
deleteAll = async () => {
- return serverDB.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId));
+ return this.db.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId));
};
removeFilesFromKnowledgeBase = async (knowledgeBaseId: string, ids: string[]) => {
- return serverDB.delete(knowledgeBaseFiles).where(
+ return this.db.delete(knowledgeBaseFiles).where(
and(
eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId),
inArray(knowledgeBaseFiles.fileId, ids),
@@ -53,7 +55,7 @@ export class KnowledgeBaseModel {
};
// query
query = async () => {
- const data = await serverDB
+ const data = await this.db
.select({
avatar: knowledgeBases.avatar,
createdAt: knowledgeBases.createdAt,
@@ -73,21 +75,21 @@ export class KnowledgeBaseModel {
};
findById = async (id: string) => {
- return serverDB.query.knowledgeBases.findFirst({
+ return this.db.query.knowledgeBases.findFirst({
where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)),
});
};
// update
async update(id: string, value: Partial) {
- return serverDB
+ return this.db
.update(knowledgeBases)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)));
}
- static async findById(id: string) {
- return serverDB.query.knowledgeBases.findFirst({
+ static async findById(db: LobeChatDatabase, id: string) {
+ return db.query.knowledgeBases.findFirst({
where: eq(knowledgeBases.id, id),
});
}
diff --git a/src/database/server/models/message.ts b/src/database/server/models/message.ts
index e4e43e3b01a3..000a8f53c064 100644
--- a/src/database/server/models/message.ts
+++ b/src/database/server/models/message.ts
@@ -1,8 +1,8 @@
import { count } from 'drizzle-orm';
import { and, asc, desc, eq, gte, inArray, isNull, like, lt } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server/core/db';
import { idGenerator } from '@/database/utils/idGenerator';
+import { LobeChatDatabase } from '@/database/type';
import { getFullFileUrl } from '@/server/utils/files';
import {
ChatFileItem,
@@ -39,9 +39,11 @@ export interface QueryMessageParams {
export class MessageModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
// **************** Query *************** //
@@ -54,7 +56,7 @@ export class MessageModel {
const offset = current * pageSize;
// 1. get basic messages
- const result = await serverDB
+ const result = await this.db
.select({
/* eslint-disable sort-keys-fix/sort-keys-fix*/
id: messages.id,
@@ -115,7 +117,7 @@ export class MessageModel {
if (messageIds.length === 0) return [];
// 2. get relative files
- const rawRelatedFileList = await serverDB
+ const rawRelatedFileList = await this.db
.select({
fileType: files.fileType,
id: messagesFiles.fileId,
@@ -139,7 +141,7 @@ export class MessageModel {
const fileList = relatedFileList.filter((i) => !(i.fileType || '').startsWith('image'));
// 3. get relative file chunks
- const chunksList = await serverDB
+ const chunksList = await this.db
.select({
fileId: files.id,
fileType: files.fileType,
@@ -157,7 +159,7 @@ export class MessageModel {
.where(inArray(messageQueryChunks.messageId, messageIds));
// 3. get relative message query
- const messageQueriesList = await serverDB
+ const messageQueriesList = await this.db
.select({
id: messageQueries.id,
messageId: messageQueries.messageId,
@@ -216,13 +218,13 @@ export class MessageModel {
}
async findById(id: string) {
- return serverDB.query.messages.findFirst({
+ return this.db.query.messages.findFirst({
where: and(eq(messages.id, id), eq(messages.userId, this.userId)),
});
}
async findMessageQueriesById(messageId: string) {
- const result = await serverDB
+ const result = await this.db
.select({
embeddings: embeddings.embeddings,
id: messageQueries.id,
@@ -240,7 +242,7 @@ export class MessageModel {
}
async queryAll(): Promise {
- return serverDB
+ return this.db
.select()
.from(messages)
.orderBy(messages.createdAt)
@@ -250,7 +252,7 @@ export class MessageModel {
}
async queryBySessionId(sessionId?: string | null): Promise {
- return serverDB.query.messages.findMany({
+ return this.db.query.messages.findMany({
orderBy: [asc(messages.createdAt)],
where: and(eq(messages.userId, this.userId), this.matchSession(sessionId)),
});
@@ -259,14 +261,14 @@ export class MessageModel {
async queryByKeyword(keyword: string): Promise {
if (!keyword) return [];
- return serverDB.query.messages.findMany({
+ return this.db.query.messages.findMany({
orderBy: [desc(messages.createdAt)],
where: and(eq(messages.userId, this.userId), like(messages.content, `%${keyword}%`)),
});
}
async count() {
- const result = await serverDB
+ const result = await this.db
.select({
count: count(),
})
@@ -283,7 +285,7 @@ export class MessageModel {
const tomorrow = new Date(today);
tomorrow.setDate(tomorrow.getDate() + 1);
- const result = await serverDB
+ const result = await this.db
.select({
count: count(),
})
@@ -315,7 +317,7 @@ export class MessageModel {
}: CreateMessageParams,
id: string = this.genId(),
): Promise {
- return serverDB.transaction(async (trx) => {
+ return this.db.transaction(async (trx) => {
const [item] = (await trx
.insert(messages)
.values({
@@ -366,72 +368,72 @@ export class MessageModel {
return { ...m, userId: this.userId };
});
- return serverDB.insert(messages).values(messagesToInsert);
+ return this.db.insert(messages).values(messagesToInsert);
}
async createMessageQuery(params: NewMessageQuery) {
- const result = await serverDB.insert(messageQueries).values(params).returning();
+ const result = await this.db.insert(messageQueries).values(params).returning();
return result[0];
}
// **************** Update *************** //
async update(id: string, message: Partial) {
- return serverDB
+ return this.db
.update(messages)
.set(message)
.where(and(eq(messages.id, id), eq(messages.userId, this.userId)));
}
async updatePluginState(id: string, state: Record) {
- const item = await serverDB.query.messagePlugins.findFirst({
+ const item = await this.db.query.messagePlugins.findFirst({
where: eq(messagePlugins.id, id),
});
if (!item) throw new Error('Plugin not found');
- return serverDB
+ return this.db
.update(messagePlugins)
.set({ state: merge(item.state || {}, state) })
.where(eq(messagePlugins.id, id));
}
async updateMessagePlugin(id: string, value: Partial) {
- const item = await serverDB.query.messagePlugins.findFirst({
+ const item = await this.db.query.messagePlugins.findFirst({
where: eq(messagePlugins.id, id),
});
if (!item) throw new Error('Plugin not found');
- return serverDB.update(messagePlugins).set(value).where(eq(messagePlugins.id, id));
+ return this.db.update(messagePlugins).set(value).where(eq(messagePlugins.id, id));
}
async updateTranslate(id: string, translate: Partial) {
- const result = await serverDB.query.messageTranslates.findFirst({
+ const result = await this.db.query.messageTranslates.findFirst({
where: and(eq(messageTranslates.id, id)),
});
// If the message does not exist in the translate table, insert it
if (!result) {
- return serverDB.insert(messageTranslates).values({ ...translate, id });
+ return this.db.insert(messageTranslates).values({ ...translate, id });
}
// or just update the existing one
- return serverDB.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id));
+ return this.db.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id));
}
async updateTTS(id: string, tts: Partial) {
- const result = await serverDB.query.messageTTS.findFirst({
+ const result = await this.db.query.messageTTS.findFirst({
where: and(eq(messageTTS.id, id)),
});
// If the message does not exist in the translate table, insert it
if (!result) {
- return serverDB
+ return this.db
.insert(messageTTS)
.values({ contentMd5: tts.contentMd5, fileId: tts.file, id, voice: tts.voice });
}
// or just update the existing one
- return serverDB
+ return this.db
.update(messageTTS)
.set({ contentMd5: tts.contentMd5, fileId: tts.file, voice: tts.voice })
.where(eq(messageTTS.id, id));
@@ -440,7 +442,7 @@ export class MessageModel {
// **************** Delete *************** //
async deleteMessage(id: string) {
- return serverDB.transaction(async (tx) => {
+ return this.db.transaction(async (tx) => {
// 1. 查询要删除的 message 的完整信息
const message = await tx
.select()
@@ -476,25 +478,25 @@ export class MessageModel {
}
async deleteMessages(ids: string[]) {
- return serverDB
+ return this.db
.delete(messages)
.where(and(eq(messages.userId, this.userId), inArray(messages.id, ids)));
}
async deleteMessageTranslate(id: string) {
- return serverDB.delete(messageTranslates).where(and(eq(messageTranslates.id, id)));
+ return this.db.delete(messageTranslates).where(and(eq(messageTranslates.id, id)));
}
async deleteMessageTTS(id: string) {
- return serverDB.delete(messageTTS).where(and(eq(messageTTS.id, id)));
+ return this.db.delete(messageTTS).where(and(eq(messageTTS.id, id)));
}
async deleteMessageQuery(id: string) {
- return serverDB.delete(messageQueries).where(and(eq(messageQueries.id, id)));
+ return this.db.delete(messageQueries).where(and(eq(messageQueries.id, id)));
}
async deleteMessagesBySession(sessionId?: string | null, topicId?: string | null) {
- return serverDB
+ return this.db
.delete(messages)
.where(
and(
@@ -506,7 +508,7 @@ export class MessageModel {
}
async deleteAllMessages() {
- return serverDB.delete(messages).where(eq(messages.userId, this.userId));
+ return this.db.delete(messages).where(eq(messages.userId, this.userId));
}
// **************** Helper *************** //
diff --git a/src/database/server/models/plugin.ts b/src/database/server/models/plugin.ts
index 7b07c9a8a873..5b5293082c9b 100644
--- a/src/database/server/models/plugin.ts
+++ b/src/database/server/models/plugin.ts
@@ -1,20 +1,22 @@
import { and, desc, eq } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
import { InstalledPluginItem, NewInstalledPlugin, installedPlugins } from '../../schemas';
export class PluginModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
create = async (
params: Pick,
) => {
- const [result] = await serverDB
+ const [result] = await this.db
.insert(installedPlugins)
.values({ ...params, createdAt: new Date(), updatedAt: new Date(), userId: this.userId })
.returning();
@@ -23,17 +25,17 @@ export class PluginModel {
};
delete = async (id: string) => {
- return serverDB
+ return this.db
.delete(installedPlugins)
.where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId)));
};
deleteAll = async () => {
- return serverDB.delete(installedPlugins).where(eq(installedPlugins.userId, this.userId));
+ return this.db.delete(installedPlugins).where(eq(installedPlugins.userId, this.userId));
};
query = async () => {
- return serverDB
+ return this.db
.select({
createdAt: installedPlugins.createdAt,
customParams: installedPlugins.customParams,
@@ -49,13 +51,13 @@ export class PluginModel {
};
findById = async (id: string) => {
- return serverDB.query.installedPlugins.findFirst({
+ return this.db.query.installedPlugins.findFirst({
where: and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId)),
});
};
async update(id: string, value: Partial) {
- return serverDB
+ return this.db
.update(installedPlugins)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId)));
diff --git a/src/database/server/models/session.ts b/src/database/server/models/session.ts
index a41b78f89ca7..de9e61fcbec9 100644
--- a/src/database/server/models/session.ts
+++ b/src/database/server/models/session.ts
@@ -1,10 +1,11 @@
import { Column, asc, count, inArray, like, sql } from 'drizzle-orm';
-import { and, desc, eq, isNull, not, or } from 'drizzle-orm/expressions';
+import { and, desc, eq, not, or } from 'drizzle-orm/expressions';
import { appEnv } from '@/config/app';
import { INBOX_SESSION_ID } from '@/const/session';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
-import { serverDB } from '@/database/server/core/db';
+import { LobeChatDatabase } from '@/database/type';
+import { idGenerator } from '@/database/utils/idGenerator';
import { parseAgentConfig } from '@/server/globalConfig/parseDefaultAgent';
import { ChatSessionList, LobeAgentSession } from '@/types/session';
import { merge } from '@/utils/merge';
@@ -19,20 +20,21 @@ import {
sessionGroups,
sessions,
} from '../../schemas';
-import { idGenerator } from '@/database/utils/idGenerator';
export class SessionModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
// **************** Query *************** //
async query({ current = 0, pageSize = 9999 } = {}) {
const offset = current * pageSize;
- return serverDB.query.sessions.findMany({
+ return this.db.query.sessions.findMany({
limit: pageSize,
offset,
orderBy: [desc(sessions.updatedAt)],
@@ -45,7 +47,7 @@ export class SessionModel {
// 查询所有会话
const result = await this.query();
- const groups = await serverDB.query.sessionGroups.findMany({
+ const groups = await this.db.query.sessionGroups.findMany({
orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)],
where: eq(sessions.userId, this.userId),
});
@@ -69,7 +71,7 @@ export class SessionModel {
async findByIdOrSlug(
idOrSlug: string,
): Promise<(SessionItem & { agent: AgentItem }) | undefined> {
- const result = await serverDB.query.sessions.findFirst({
+ const result = await this.db.query.sessions.findFirst({
where: and(
or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)),
eq(sessions.userId, this.userId),
@@ -83,7 +85,7 @@ export class SessionModel {
}
async count() {
- const result = await serverDB
+ const result = await this.db
.select({
count: count(),
})
@@ -109,7 +111,7 @@ export class SessionModel {
slug?: string;
type: 'agent' | 'group';
}): Promise {
- return serverDB.transaction(async (trx) => {
+ return this.db.transaction(async (trx) => {
const newAgents = await trx
.insert(agents)
.values({
@@ -144,7 +146,7 @@ export class SessionModel {
}
async createInbox() {
- const item = await serverDB.query.sessions.findFirst({
+ const item = await this.db.query.sessions.findFirst({
where: and(eq(sessions.userId, this.userId), eq(sessions.slug, INBOX_SESSION_ID)),
});
if (item) return;
@@ -167,7 +169,7 @@ export class SessionModel {
};
});
- return serverDB.insert(sessions).values(sessionsToInsert);
+ return this.db.insert(sessions).values(sessionsToInsert);
}
async duplicate(id: string, newTitle?: string) {
@@ -199,7 +201,7 @@ export class SessionModel {
* Delete a session, also delete all messages and topics associated with it.
*/
async delete(id: string) {
- return serverDB
+ return this.db
.delete(sessions)
.where(and(eq(sessions.id, id), eq(sessions.userId, this.userId)));
}
@@ -208,18 +210,18 @@ export class SessionModel {
* Batch delete sessions, also delete all messages and topics associated with them.
*/
async batchDelete(ids: string[]) {
- return serverDB
+ return this.db
.delete(sessions)
.where(and(inArray(sessions.id, ids), eq(sessions.userId, this.userId)));
}
async deleteAll() {
- return serverDB.delete(sessions).where(eq(sessions.userId, this.userId));
+ return this.db.delete(sessions).where(eq(sessions.userId, this.userId));
}
// **************** Update *************** //
async update(id: string, data: Partial) {
- return serverDB
+ return this.db
.update(sessions)
.set(data)
.where(and(eq(sessions.id, id), eq(sessions.userId, this.userId)))
@@ -227,7 +229,7 @@ export class SessionModel {
}
async updateConfig(id: string, data: Partial) {
- return serverDB
+ return this.db
.update(agents)
.set(data)
.where(and(eq(agents.id, id), eq(agents.userId, this.userId)));
@@ -262,65 +264,22 @@ export class SessionModel {
} as any;
};
- async findSessions(params: {
- current?: number;
- group?: string;
- keyword?: string;
- pageSize?: number;
- pinned?: boolean;
- }) {
- const { pinned, keyword, group, pageSize = 9999, current = 0 } = params;
-
- const offset = current * pageSize;
- return serverDB.query.sessions.findMany({
- limit: pageSize,
- offset,
- orderBy: [desc(sessions.updatedAt)],
- where: and(
- eq(sessions.userId, this.userId),
- pinned !== undefined ? eq(sessions.pinned, pinned) : eq(sessions.userId, this.userId),
- keyword
- ? or(
- like(
- sql`lower(${sessions.title})` as unknown as Column,
- `%${keyword.toLowerCase()}%`,
- ),
- like(
- sql`lower(${sessions.description})` as unknown as Column,
- `%${keyword.toLowerCase()}%`,
- ),
- )
- : eq(sessions.userId, this.userId),
- group ? eq(sessions.groupId, group) : isNull(sessions.groupId),
- ),
-
- with: { agentsToSessions: { columns: {}, with: { agent: true } }, group: true },
- });
- }
-
- async findSessionsByKeywords(params: {
- current?: number;
- keyword: string;
- pageSize?: number;
- }) {
+ async findSessionsByKeywords(params: { current?: number; keyword: string; pageSize?: number }) {
const { keyword, pageSize = 9999, current = 0 } = params;
const offset = current * pageSize;
- const results = await serverDB.query.agents.findMany({
+ const results = await this.db.query.agents.findMany({
limit: pageSize,
offset,
orderBy: [desc(agents.updatedAt)],
where: and(
eq(agents.userId, this.userId),
or(
- like(
- sql`lower(${agents.title})` as unknown as Column,
- `%${keyword.toLowerCase()}%`,
- ),
+ like(sql`lower(${agents.title})` as unknown as Column, `%${keyword.toLowerCase()}%`),
like(
sql`lower(${agents.description})` as unknown as Column,
`%${keyword.toLowerCase()}%`,
),
- )
+ ),
),
with: { agentsToSessions: { columns: {}, with: { session: true } } },
});
@@ -328,6 +287,6 @@ export class SessionModel {
// @ts-expect-error
return results.map((item) => item.agentsToSessions[0].session);
} catch {}
- return []
+ return [];
}
}
diff --git a/src/database/server/models/sessionGroup.ts b/src/database/server/models/sessionGroup.ts
index 16df2a2fab37..14d2dc48a25c 100644
--- a/src/database/server/models/sessionGroup.ts
+++ b/src/database/server/models/sessionGroup.ts
@@ -1,20 +1,22 @@
import { eq } from 'drizzle-orm';
import { and, asc, desc } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
import { idGenerator } from '@/database/utils/idGenerator';
import { SessionGroupItem, sessionGroups } from '../../schemas';
export class SessionGroupModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
create = async (params: { name: string; sort?: number }) => {
- const [result] = await serverDB
+ const [result] = await this.db
.insert(sessionGroups)
.values({ ...params, id: this.genId(), userId: this.userId })
.returning();
@@ -23,37 +25,37 @@ export class SessionGroupModel {
};
delete = async (id: string) => {
- return serverDB
+ return this.db
.delete(sessionGroups)
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
};
deleteAll = async () => {
- return serverDB.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId));
+ return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId));
};
query = async () => {
- return serverDB.query.sessionGroups.findMany({
+ return this.db.query.sessionGroups.findMany({
orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)],
where: eq(sessionGroups.userId, this.userId),
});
};
findById = async (id: string) => {
- return serverDB.query.sessionGroups.findFirst({
+ return this.db.query.sessionGroups.findFirst({
where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)),
});
};
async update(id: string, value: Partial) {
- return serverDB
+ return this.db
.update(sessionGroups)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
}
async updateOrder(sortMap: { id: string; sort: number }[]) {
- await serverDB.transaction(async (tx) => {
+ await this.db.transaction(async (tx) => {
const updates = sortMap.map(({ id, sort }) => {
return tx
.update(sessionGroups)
diff --git a/src/database/server/models/thread.ts b/src/database/server/models/thread.ts
index 791510470e17..9d65e698ee25 100644
--- a/src/database/server/models/thread.ts
+++ b/src/database/server/models/thread.ts
@@ -1,7 +1,7 @@
import { eq } from 'drizzle-orm';
import { and, desc } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server';
+import { LobeChatDatabase } from '@/database/type';
import { CreateThreadParams, ThreadStatus } from '@/types/topic';
import { ThreadItem, threads } from '../../schemas';
@@ -20,14 +20,16 @@ const queryColumns = {
export class ThreadModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
create = async (params: CreateThreadParams) => {
// @ts-ignore
- const [result] = await serverDB
+ const [result] = await this.db
.insert(threads)
.values({ ...params, status: ThreadStatus.Active, userId: this.userId })
.onConflictDoNothing()
@@ -37,15 +39,15 @@ export class ThreadModel {
};
delete = async (id: string) => {
- return serverDB.delete(threads).where(and(eq(threads.id, id), eq(threads.userId, this.userId)));
+ return this.db.delete(threads).where(and(eq(threads.id, id), eq(threads.userId, this.userId)));
};
deleteAll = async () => {
- return serverDB.delete(threads).where(eq(threads.userId, this.userId));
+ return this.db.delete(threads).where(eq(threads.userId, this.userId));
};
query = async () => {
- const data = await serverDB
+ const data = await this.db
.select(queryColumns)
.from(threads)
.where(eq(threads.userId, this.userId))
@@ -55,7 +57,7 @@ export class ThreadModel {
};
queryByTopicId = async (topicId: string) => {
- const data = await serverDB
+ const data = await this.db
.select(queryColumns)
.from(threads)
.where(and(eq(threads.topicId, topicId), eq(threads.userId, this.userId)))
@@ -65,13 +67,13 @@ export class ThreadModel {
};
findById = async (id: string) => {
- return serverDB.query.threads.findFirst({
+ return this.db.query.threads.findFirst({
where: and(eq(threads.id, id), eq(threads.userId, this.userId)),
});
};
async update(id: string, value: Partial) {
- return serverDB
+ return this.db
.update(threads)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(threads.id, id), eq(threads.userId, this.userId)));
diff --git a/src/database/server/models/topic.ts b/src/database/server/models/topic.ts
index 0e09973d5727..ad914d614fe6 100644
--- a/src/database/server/models/topic.ts
+++ b/src/database/server/models/topic.ts
@@ -1,7 +1,7 @@
import { Column, count, inArray, sql } from 'drizzle-orm';
import { and, desc, eq, exists, isNull, like, or } from 'drizzle-orm/expressions';
-import { serverDB } from '@/database/server/core/db';
+import { LobeChatDatabase } from '@/database/type';
import { NewMessage, TopicItem, messages, topics } from '../../schemas';
import { idGenerator } from '@/database/utils/idGenerator';
@@ -21,9 +21,11 @@ interface QueryTopicParams {
export class TopicModel {
private userId: string;
+ private db: LobeChatDatabase;
- constructor(userId: string) {
+ constructor(db: LobeChatDatabase, userId: string) {
this.userId = userId;
+ this.db = db;
}
// **************** Query *************** //
@@ -31,7 +33,7 @@ export class TopicModel {
const offset = current * pageSize;
return (
- serverDB
+ this.db
.select({
createdAt: topics.createdAt,
favorite: topics.favorite,
@@ -52,13 +54,13 @@ export class TopicModel {
}
async findById(id: string) {
- return serverDB.query.topics.findFirst({
+ return this.db.query.topics.findFirst({
where: and(eq(topics.id, id), eq(topics.userId, this.userId)),
});
}
async queryAll(): Promise {
- return serverDB
+ return this.db
.select()
.from(topics)
.orderBy(topics.updatedAt)
@@ -74,7 +76,7 @@ export class TopicModel {
const matchKeyword = (field: any) =>
like(sql`lower(${field})` as unknown as Column, `%${keywordLowerCase}%`);
- return serverDB.query.topics.findMany({
+ return this.db.query.topics.findMany({
orderBy: [desc(topics.updatedAt)],
where: and(
eq(topics.userId, this.userId),
@@ -82,15 +84,10 @@ export class TopicModel {
or(
matchKeyword(topics.title),
exists(
- serverDB
+ this.db
.select()
.from(messages)
- .where(
- and(
- eq(messages.topicId, topics.id),
- matchKeyword(messages.content)
- )
- ),
+ .where(and(eq(messages.topicId, topics.id), matchKeyword(messages.content))),
),
),
),
@@ -98,7 +95,7 @@ export class TopicModel {
}
async count() {
- const result = await serverDB
+ const result = await this.db
.select({
count: count(),
})
@@ -115,7 +112,7 @@ export class TopicModel {
{ messages: messageIds, ...params }: CreateTopicParams,
id: string = this.genId(),
): Promise {
- return serverDB.transaction(async (tx) => {
+ return this.db.transaction(async (tx) => {
// 在 topics 表中插入新的 topic
const [topic] = await tx
.insert(topics)
@@ -140,7 +137,7 @@ export class TopicModel {
async batchCreate(topicParams: (CreateTopicParams & { id?: string })[]) {
// 开始一个事务
- return serverDB.transaction(async (tx) => {
+ return this.db.transaction(async (tx) => {
// 在 topics 表中批量插入新的 topics
const createdTopics = await tx
.insert(topics)
@@ -173,7 +170,7 @@ export class TopicModel {
}
async duplicate(topicId: string, newTitle?: string) {
- return serverDB.transaction(async (tx) => {
+ return this.db.transaction(async (tx) => {
// find original topic
const originalTopic = await tx.query.topics.findFirst({
where: and(eq(topics.id, topicId), eq(topics.userId, this.userId)),
@@ -228,14 +225,14 @@ export class TopicModel {
* Delete a session, also delete all messages and topics associated with it.
*/
async delete(id: string) {
- return serverDB.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId)));
+ return this.db.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId)));
}
/**
* Deletes multiple topics based on the sessionId.
*/
async batchDeleteBySessionId(sessionId?: string | null) {
- return serverDB
+ return this.db
.delete(topics)
.where(and(this.matchSession(sessionId), eq(topics.userId, this.userId)));
}
@@ -244,19 +241,19 @@ export class TopicModel {
* Deletes multiple topics and all messages associated with them in a transaction.
*/
async batchDelete(ids: string[]) {
- return serverDB
+ return this.db
.delete(topics)
.where(and(inArray(topics.id, ids), eq(topics.userId, this.userId)));
}
async deleteAll() {
- return serverDB.delete(topics).where(eq(topics.userId, this.userId));
+ return this.db.delete(topics).where(eq(topics.userId, this.userId));
}
// **************** Update *************** //
async update(id: string, data: Partial) {
- return serverDB
+ return this.db
.update(topics)
.set({ ...data, updatedAt: new Date() })
.where(and(eq(topics.id, id), eq(topics.userId, this.userId)))
diff --git a/src/database/server/models/user.ts b/src/database/server/models/user.ts
index 2fc92d1f1d2a..74b4130805d5 100644
--- a/src/database/server/models/user.ts
+++ b/src/database/server/models/user.ts
@@ -2,7 +2,7 @@ import { TRPCError } from '@trpc/server';
import { eq } from 'drizzle-orm';
import { DeepPartial } from 'utility-types';
-import { serverDB } from '@/database/server/core/db';
+import { LobeChatDatabase } from '@/database/type';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { UserGuide, UserPreference } from '@/types/user';
import { UserKeyVaults, UserSettings } from '@/types/user/settings';
@@ -18,38 +18,16 @@ export class UserNotFoundError extends TRPCError {
}
export class UserModel {
- static createUser = async (params: NewUser) => {
- // if user already exists, skip creation
- if (params.id) {
- const user = await serverDB.query.users.findFirst({ where: eq(users.id, params.id) });
- if (!!user) return;
- }
-
- const [user] = await serverDB
- .insert(users)
- .values({ ...params })
- .returning();
-
- // Create an inbox session for the user
- const model = new SessionModel(user.id);
-
- await model.createInbox();
- };
+ private userId: string;
+ private db: LobeChatDatabase;
- static deleteUser = async (id: string) => {
- return serverDB.delete(users).where(eq(users.id, id));
- };
-
- static findById = async (id: string) => {
- return serverDB.query.users.findFirst({ where: eq(users.id, id) });
- };
-
- static findByEmail = async (email: string) => {
- return serverDB.query.users.findFirst({ where: eq(users.email, email) });
- };
+ constructor(db: LobeChatDatabase, userId: string) {
+ this.userId = userId;
+ this.db = db;
+ }
- getUserState = async (id: string) => {
- const result = await serverDB
+ async getUserState() {
+ const result = await this.db
.select({
isOnboarded: users.isOnboarded,
preference: users.preference,
@@ -63,7 +41,7 @@ export class UserModel {
settingsTool: userSettings.tool,
})
.from(users)
- .where(eq(users.id, id))
+ .where(eq(users.id, this.userId))
.leftJoin(userSettings, eq(users.id, userSettings.id));
if (!result || !result[0]) {
@@ -82,7 +60,7 @@ export class UserModel {
try {
decryptKeyVaults = JSON.parse(plaintext);
} catch (e) {
- console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e);
+ console.error(`Failed to parse keyVaults ,userId: ${this.userId}. Error:`, e);
}
}
}
@@ -101,54 +79,22 @@ export class UserModel {
isOnboarded: state.isOnboarded,
preference: state.preference as UserPreference,
settings,
- userId: id,
+ userId: this.userId,
};
- };
-
- static getUserApiKeys = async (id: string) => {
- const result = await serverDB
- .select({
- settingsKeyVaults: userSettings.keyVaults,
- })
- .from(userSettings)
- .where(eq(userSettings.id, id));
-
- if (!result || !result[0]) {
- throw new UserNotFoundError();
- }
-
- const state = result[0];
-
- // Decrypt keyVaults
- let decryptKeyVaults = {};
- if (state.settingsKeyVaults) {
- const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
- const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults);
-
- if (wasAuthentic) {
- try {
- decryptKeyVaults = JSON.parse(plaintext);
- } catch (e) {
- console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e);
- }
- }
- }
-
- return decryptKeyVaults as UserKeyVaults;
- };
+ }
- async updateUser(id: string, value: Partial) {
- return serverDB
+ async updateUser(value: Partial) {
+ return this.db
.update(users)
.set({ ...value, updatedAt: new Date() })
- .where(eq(users.id, id));
+ .where(eq(users.id, this.userId));
}
- async deleteSetting(id: string) {
- return serverDB.delete(userSettings).where(eq(userSettings.id, id));
+ async deleteSetting() {
+ return this.db.delete(userSettings).where(eq(userSettings.id, this.userId));
}
- async updateSetting(id: string, value: Partial) {
+ async updateSetting(value: Partial) {
const { keyVaults, ...res } = value;
// Encrypt keyVaults
@@ -165,33 +111,99 @@ export class UserModel {
const newValue = { ...res, keyVaults: encryptedKeyVaults };
// update or create user settings
- const settings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, id) });
+ const settings = await this.db.query.userSettings.findFirst({
+ where: eq(users.id, this.userId),
+ });
if (!settings) {
- await serverDB.insert(userSettings).values({ id, ...newValue });
+ await this.db.insert(userSettings).values({ id: this.userId, ...newValue });
return;
}
- return serverDB.update(userSettings).set(newValue).where(eq(userSettings.id, id));
+ return this.db.update(userSettings).set(newValue).where(eq(userSettings.id, this.userId));
}
- async updatePreference(id: string, value: Partial) {
- const user = await serverDB.query.users.findFirst({ where: eq(users.id, id) });
+ async updatePreference(value: Partial) {
+ const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) });
if (!user) return;
- return serverDB
+ return this.db
.update(users)
.set({ preference: merge(user.preference, value) })
- .where(eq(users.id, id));
+ .where(eq(users.id, this.userId));
}
- async updateGuide(id: string, value: Partial) {
- const user = await serverDB.query.users.findFirst({ where: eq(users.id, id) });
+ async updateGuide(value: Partial) {
+ const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) });
if (!user) return;
const prevPreference = (user.preference || {}) as UserPreference;
- return serverDB
+ return this.db
.update(users)
.set({ preference: { ...prevPreference, guide: merge(prevPreference.guide || {}, value) } })
- .where(eq(users.id, id));
+ .where(eq(users.id, this.userId));
}
+
+ // Static method
+
+ static createUser = async (db: LobeChatDatabase, params: NewUser) => {
+ // if user already exists, skip creation
+ if (params.id) {
+ const user = await db.query.users.findFirst({ where: eq(users.id, params.id) });
+ if (!!user) return;
+ }
+
+ const [user] = await db
+ .insert(users)
+ .values({ ...params })
+ .returning();
+
+ // Create an inbox session for the user
+ const model = new SessionModel(db, user.id);
+
+ await model.createInbox();
+ };
+
+ static deleteUser = async (db: LobeChatDatabase, id: string) => {
+ return db.delete(users).where(eq(users.id, id));
+ };
+
+ static findById = async (db: LobeChatDatabase, id: string) => {
+ return db.query.users.findFirst({ where: eq(users.id, id) });
+ };
+
+ static findByEmail = async (db: LobeChatDatabase, email: string) => {
+ return db.query.users.findFirst({ where: eq(users.email, email) });
+ };
+
+ static getUserApiKeys = async (db: LobeChatDatabase, id: string) => {
+ const result = await db
+ .select({
+ settingsKeyVaults: userSettings.keyVaults,
+ })
+ .from(userSettings)
+ .where(eq(userSettings.id, id));
+
+ if (!result || !result[0]) {
+ throw new UserNotFoundError();
+ }
+
+ const state = result[0];
+
+ // Decrypt keyVaults
+ let decryptKeyVaults = {};
+ if (state.settingsKeyVaults) {
+ const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
+ const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults);
+
+ if (wasAuthentic) {
+ try {
+ decryptKeyVaults = JSON.parse(plaintext);
+ } catch (e) {
+ console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e);
+ }
+ }
+ }
+
+ return decryptKeyVaults as UserKeyVaults;
+ };
}
diff --git a/src/database/type.ts b/src/database/type.ts
new file mode 100644
index 000000000000..cf9308763134
--- /dev/null
+++ b/src/database/type.ts
@@ -0,0 +1,7 @@
+import type { NeonDatabase } from 'drizzle-orm/neon-serverless';
+
+import * as schema from './schemas';
+
+export type LobeChatDatabaseSchema = typeof schema;
+
+export type LobeChatDatabase = NeonDatabase;
diff --git a/src/libs/next-auth/adapter/index.ts b/src/libs/next-auth/adapter/index.ts
index 5eeb5e4ab824..ac9f8ae2caaa 100644
--- a/src/libs/next-auth/adapter/index.ts
+++ b/src/libs/next-auth/adapter/index.ts
@@ -33,8 +33,6 @@ const {
* @returns {Adapter}
*/
export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Adapter {
- const userModel = new UserModel();
-
return {
async createAuthenticator(authenticator): Promise {
const result = await serverDB
@@ -55,10 +53,10 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad
async createUser(user): Promise {
const { id, name, email, emailVerified, image, providerAccountId } = user;
// return the user if it already exists
- let existingUser = await UserModel.findByEmail(email);
+ let existingUser = await UserModel.findByEmail(serverDB, email);
// If the user is not found by email, try to find by providerAccountId
if (!existingUser && providerAccountId) {
- existingUser = await UserModel.findById(providerAccountId);
+ existingUser = await UserModel.findById(serverDB, providerAccountId);
}
if (existingUser) {
const adapterUser = mapLobeUserToAdapterUser(existingUser);
@@ -66,6 +64,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad
}
// create a new user if it does not exist
await UserModel.createUser(
+ serverDB,
mapAdapterUserToLobeUser({
email,
emailVerified,
@@ -91,10 +90,10 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad
return;
},
async deleteUser(id): Promise {
- const user = await UserModel.findById(id);
+ const user = await UserModel.findById(serverDB, id);
if (!user) throw new Error('NextAuth: Delete User not found');
- await UserModel.deleteUser(id);
+ await UserModel.deleteUser(serverDB, id);
return;
},
@@ -145,7 +144,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad
},
async getUser(id): Promise {
- const lobeUser = await UserModel.findById(id);
+ const lobeUser = await UserModel.findById(serverDB, id);
if (!lobeUser) return null;
return mapLobeUserToAdapterUser(lobeUser);
},
@@ -170,7 +169,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad
},
async getUserByEmail(email): Promise {
- const lobeUser = await UserModel.findByEmail(email);
+ const lobeUser = await UserModel.findByEmail(serverDB, email);
return lobeUser ? mapLobeUserToAdapterUser(lobeUser) : null;
},
@@ -228,10 +227,11 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad
},
async updateUser(user): Promise {
- const lobeUser = await UserModel.findById(user?.id);
+ const lobeUser = await UserModel.findById(serverDB, user?.id);
if (!lobeUser) throw new Error('NextAuth: User not found');
+ const userModel = new UserModel(serverDB, user.id);
- const updatedUser = await userModel.updateUser(user.id, {
+ const updatedUser = await userModel.updateUser({
...partialMapAdapterUserToLobeUser(user),
});
if (!updatedUser) throw new Error('NextAuth: Failed to update user');
diff --git a/src/libs/trpc/async/asyncAuth.ts b/src/libs/trpc/async/asyncAuth.ts
index 3ebf3248a273..69b52d5d5441 100644
--- a/src/libs/trpc/async/asyncAuth.ts
+++ b/src/libs/trpc/async/asyncAuth.ts
@@ -1,6 +1,7 @@
import { TRPCError } from '@trpc/server';
import { serverDBEnv } from '@/config/db';
+import { serverDB } from '@/database/server';
import { UserModel } from '@/database/server/models/user';
import { asyncTrpc } from './init';
@@ -12,7 +13,7 @@ export const asyncAuth = asyncTrpc.middleware(async (opts) => {
throw new TRPCError({ code: 'UNAUTHORIZED' });
}
- const result = await UserModel.findById(ctx.userId);
+ const result = await UserModel.findById(serverDB, ctx.userId);
if (!result) {
throw new TRPCError({ code: 'UNAUTHORIZED', message: 'user is invalid' });
diff --git a/src/server/routers/async/file.ts b/src/server/routers/async/file.ts
index c9071564245c..8bbee1e9e1fe 100644
--- a/src/server/routers/async/file.ts
+++ b/src/server/routers/async/file.ts
@@ -5,6 +5,7 @@ import { z } from 'zod';
import { fileEnv } from '@/config/file';
import { DEFAULT_EMBEDDING_MODEL } from '@/const/settings';
+import { serverDB } from '@/database/server';
import { ASYNC_TASK_TIMEOUT, AsyncTaskModel } from '@/database/server/models/asyncTask';
import { ChunkModel } from '@/database/server/models/chunk';
import { EmbeddingModel } from '@/database/server/models/embedding';
@@ -28,11 +29,11 @@ const fileProcedure = asyncAuthedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- asyncTaskModel: new AsyncTaskModel(ctx.userId),
- chunkModel: new ChunkModel(ctx.userId),
+ asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
+ chunkModel: new ChunkModel(serverDB, ctx.userId),
chunkService: new ChunkService(ctx.userId),
- embeddingModel: new EmbeddingModel(ctx.userId),
- fileModel: new FileModel(ctx.userId),
+ embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
+ fileModel: new FileModel(serverDB, ctx.userId),
},
});
});
diff --git a/src/server/routers/async/ragEval.ts b/src/server/routers/async/ragEval.ts
index 5fc637337f23..de4190ee668e 100644
--- a/src/server/routers/async/ragEval.ts
+++ b/src/server/routers/async/ragEval.ts
@@ -4,6 +4,7 @@ import { z } from 'zod';
import { chainAnswerWithContext } from '@/chains/answerWithContext';
import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
+import { serverDB } from '@/database/server';
import { ChunkModel } from '@/database/server/models/chunk';
import { EmbeddingModel } from '@/database/server/models/embedding';
import { FileModel } from '@/database/server/models/file';
@@ -24,13 +25,13 @@ const ragEvalProcedure = asyncAuthedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- chunkModel: new ChunkModel(ctx.userId),
+ chunkModel: new ChunkModel(serverDB, ctx.userId),
chunkService: new ChunkService(ctx.userId),
datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
- embeddingModel: new EmbeddingModel(ctx.userId),
+ embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
evalRecordModel: new EvaluationRecordModel(ctx.userId),
evaluationModel: new EvalEvaluationModel(ctx.userId),
- fileModel: new FileModel(ctx.userId),
+ fileModel: new FileModel(serverDB, ctx.userId),
},
});
});
diff --git a/src/server/routers/lambda/_template.ts b/src/server/routers/lambda/_template.ts
index 10d5fae784dd..9a3b96b7b2bf 100644
--- a/src/server/routers/lambda/_template.ts
+++ b/src/server/routers/lambda/_template.ts
@@ -1,5 +1,6 @@
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { SessionGroupModel } from '@/database/server/models/sessionGroup';
import { insertSessionGroupSchema } from '@/database/schemas';
import { authedProcedure, router } from '@/libs/trpc';
@@ -10,7 +11,7 @@ const sessionProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- sessionGroupModel: new SessionGroupModel(ctx.userId),
+ sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
},
});
});
diff --git a/src/server/routers/lambda/agent.ts b/src/server/routers/lambda/agent.ts
index 1ce31bef76b0..a3735ac732ef 100644
--- a/src/server/routers/lambda/agent.ts
+++ b/src/server/routers/lambda/agent.ts
@@ -2,6 +2,7 @@ import { z } from 'zod';
import { INBOX_SESSION_ID } from '@/const/session';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
+import { serverDB } from '@/database/server';
import { AgentModel } from '@/database/server/models/agent';
import { FileModel } from '@/database/server/models/file';
import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase';
@@ -16,10 +17,10 @@ const agentProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- agentModel: new AgentModel(ctx.userId),
- fileModel: new FileModel(ctx.userId),
- knowledgeBaseModel: new KnowledgeBaseModel(ctx.userId),
- sessionModel: new SessionModel(ctx.userId),
+ agentModel: new AgentModel(serverDB, ctx.userId),
+ fileModel: new FileModel(serverDB, ctx.userId),
+ knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId),
+ sessionModel: new SessionModel(serverDB, ctx.userId),
},
});
});
@@ -87,7 +88,7 @@ export const agentRouter = router({
// if there is no session for user, create one
if (!item) {
// if there is no user, return default config
- const user = await UserModel.findById(ctx.userId);
+ const user = await UserModel.findById(serverDB, ctx.userId);
if (!user) return DEFAULT_AGENT_CONFIG;
const res = await ctx.sessionModel.createInbox();
diff --git a/src/server/routers/lambda/chunk.ts b/src/server/routers/lambda/chunk.ts
index d9818febe52b..07425224af47 100644
--- a/src/server/routers/lambda/chunk.ts
+++ b/src/server/routers/lambda/chunk.ts
@@ -21,12 +21,12 @@ const chunkProcedure = authedProcedure.use(keyVaults).use(async (opts) => {
return opts.next({
ctx: {
- asyncTaskModel: new AsyncTaskModel(ctx.userId),
- chunkModel: new ChunkModel(ctx.userId),
+ asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
+ chunkModel: new ChunkModel(serverDB, ctx.userId),
chunkService: new ChunkService(ctx.userId),
- embeddingModel: new EmbeddingModel(ctx.userId),
- fileModel: new FileModel(ctx.userId),
- messageModel: new MessageModel(ctx.userId),
+ embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
+ fileModel: new FileModel(serverDB, ctx.userId),
+ messageModel: new MessageModel(serverDB, ctx.userId),
},
});
});
diff --git a/src/server/routers/lambda/file.ts b/src/server/routers/lambda/file.ts
index 8b73cc715ca7..02b4366aa9bc 100644
--- a/src/server/routers/lambda/file.ts
+++ b/src/server/routers/lambda/file.ts
@@ -1,6 +1,7 @@
import { TRPCError } from '@trpc/server';
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { AsyncTaskModel } from '@/database/server/models/asyncTask';
import { ChunkModel } from '@/database/server/models/chunk';
import { FileModel } from '@/database/server/models/file';
@@ -15,9 +16,9 @@ const fileProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- asyncTaskModel: new AsyncTaskModel(ctx.userId),
- chunkModel: new ChunkModel(ctx.userId),
- fileModel: new FileModel(ctx.userId),
+ asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
+ chunkModel: new ChunkModel(serverDB, ctx.userId),
+ fileModel: new FileModel(serverDB, ctx.userId),
},
});
});
diff --git a/src/server/routers/lambda/knowledgeBase.ts b/src/server/routers/lambda/knowledgeBase.ts
index 9ccfc01ebfd2..ac0e8456689b 100644
--- a/src/server/routers/lambda/knowledgeBase.ts
+++ b/src/server/routers/lambda/knowledgeBase.ts
@@ -1,5 +1,6 @@
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase';
import { insertKnowledgeBasesSchema } from '@/database/schemas';
import { authedProcedure, router } from '@/libs/trpc';
@@ -10,7 +11,7 @@ const knowledgeBaseProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- knowledgeBaseModel: new KnowledgeBaseModel(ctx.userId),
+ knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId),
},
});
});
diff --git a/src/server/routers/lambda/message.ts b/src/server/routers/lambda/message.ts
index fb95768d3035..dcc665c4f786 100644
--- a/src/server/routers/lambda/message.ts
+++ b/src/server/routers/lambda/message.ts
@@ -1,5 +1,6 @@
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { MessageModel } from '@/database/server/models/message';
import { updateMessagePluginSchema } from '@/database/schemas';
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
@@ -12,7 +13,7 @@ const messageProcedure = authedProcedure.use(async (opts) => {
const { ctx } = opts;
return opts.next({
- ctx: { messageModel: new MessageModel(ctx.userId) },
+ ctx: { messageModel: new MessageModel(serverDB, ctx.userId) },
});
});
@@ -54,6 +55,7 @@ export const messageRouter = router({
return ctx.messageModel.queryBySessionId(input.sessionId);
}),
+ // TODO: 未来这部分方法也需要使用 authedProcedure
getMessages: publicProcedure
.input(
z.object({
@@ -66,7 +68,7 @@ export const messageRouter = router({
.query(async ({ input, ctx }) => {
if (!ctx.userId) return [];
- const messageModel = new MessageModel(ctx.userId);
+ const messageModel = new MessageModel(serverDB, ctx.userId);
return messageModel.query(input);
}),
diff --git a/src/server/routers/lambda/plugin.ts b/src/server/routers/lambda/plugin.ts
index 3c691c51df0c..13880b1ff32e 100644
--- a/src/server/routers/lambda/plugin.ts
+++ b/src/server/routers/lambda/plugin.ts
@@ -1,5 +1,6 @@
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { PluginModel } from '@/database/server/models/plugin';
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
import { LobeTool } from '@/types/tool';
@@ -8,7 +9,7 @@ const pluginProcedure = authedProcedure.use(async (opts) => {
const { ctx } = opts;
return opts.next({
- ctx: { pluginModel: new PluginModel(ctx.userId) },
+ ctx: { pluginModel: new PluginModel(serverDB, ctx.userId) },
});
});
@@ -61,10 +62,11 @@ export const pluginRouter = router({
return data.identifier;
}),
+ // TODO: 未来这部分方法也需要使用 authedProcedure
getPlugins: publicProcedure.query(async ({ ctx }): Promise => {
if (!ctx.userId) return [];
- const pluginModel = new PluginModel(ctx.userId);
+ const pluginModel = new PluginModel(serverDB, ctx.userId);
return pluginModel.query();
}),
diff --git a/src/server/routers/lambda/ragEval.ts b/src/server/routers/lambda/ragEval.ts
index 33b33a715944..150abe4a4e1c 100644
--- a/src/server/routers/lambda/ragEval.ts
+++ b/src/server/routers/lambda/ragEval.ts
@@ -6,6 +6,7 @@ import pMap from 'p-map';
import { z } from 'zod';
import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
+import { serverDB } from '@/database/server';
import { FileModel } from '@/database/server/models/file';
import {
EvalDatasetModel,
@@ -34,7 +35,7 @@ const ragEvalProcedure = authedProcedure.use(keyVaults).use(async (opts) => {
return opts.next({
ctx: {
datasetModel: new EvalDatasetModel(ctx.userId),
- fileModel: new FileModel(ctx.userId),
+ fileModel: new FileModel(serverDB, ctx.userId),
datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
evaluationModel: new EvalEvaluationModel(ctx.userId),
evaluationRecordModel: new EvaluationRecordModel(ctx.userId),
diff --git a/src/server/routers/lambda/session.ts b/src/server/routers/lambda/session.ts
index af98cd197b53..65d94538e66e 100644
--- a/src/server/routers/lambda/session.ts
+++ b/src/server/routers/lambda/session.ts
@@ -1,5 +1,6 @@
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { SessionModel } from '@/database/server/models/session';
import { SessionGroupModel } from '@/database/server/models/sessionGroup';
import { insertAgentSchema, insertSessionSchema } from '@/database/schemas';
@@ -15,8 +16,8 @@ const sessionProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- sessionGroupModel: new SessionGroupModel(ctx.userId),
- sessionModel: new SessionModel(ctx.userId),
+ sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
+ sessionModel: new SessionModel(serverDB, ctx.userId),
},
});
});
@@ -84,7 +85,7 @@ export const sessionRouter = router({
sessions: [],
};
- const sessionModel = new SessionModel(ctx.userId);
+ const sessionModel = new SessionModel(serverDB, ctx.userId);
return sessionModel.queryWithGroups();
}),
diff --git a/src/server/routers/lambda/sessionGroup.ts b/src/server/routers/lambda/sessionGroup.ts
index 10d5fae784dd..9a3b96b7b2bf 100644
--- a/src/server/routers/lambda/sessionGroup.ts
+++ b/src/server/routers/lambda/sessionGroup.ts
@@ -1,5 +1,6 @@
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { SessionGroupModel } from '@/database/server/models/sessionGroup';
import { insertSessionGroupSchema } from '@/database/schemas';
import { authedProcedure, router } from '@/libs/trpc';
@@ -10,7 +11,7 @@ const sessionProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- sessionGroupModel: new SessionGroupModel(ctx.userId),
+ sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
},
});
});
diff --git a/src/server/routers/lambda/thread.ts b/src/server/routers/lambda/thread.ts
index 32e5a5452fec..05cb3eec0846 100644
--- a/src/server/routers/lambda/thread.ts
+++ b/src/server/routers/lambda/thread.ts
@@ -1,5 +1,6 @@
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { MessageModel } from '@/database/server/models/message';
import { ThreadModel } from '@/database/server/models/thread';
import { insertThreadSchema } from '@/database/schemas';
@@ -11,8 +12,8 @@ const threadProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
- messageModel: new MessageModel(ctx.userId),
- threadModel: new ThreadModel(ctx.userId),
+ messageModel: new MessageModel(serverDB, ctx.userId),
+ threadModel: new ThreadModel(serverDB, ctx.userId),
},
});
});
diff --git a/src/server/routers/lambda/topic.ts b/src/server/routers/lambda/topic.ts
index c8b44949b472..5dd07111fb9b 100644
--- a/src/server/routers/lambda/topic.ts
+++ b/src/server/routers/lambda/topic.ts
@@ -1,5 +1,6 @@
import { z } from 'zod';
+import { serverDB } from '@/database/server';
import { TopicModel } from '@/database/server/models/topic';
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
import { BatchTaskResult } from '@/types/service';
@@ -8,7 +9,7 @@ const topicProcedure = authedProcedure.use(async (opts) => {
const { ctx } = opts;
return opts.next({
- ctx: { topicModel: new TopicModel(ctx.userId) },
+ ctx: { topicModel: new TopicModel(serverDB, ctx.userId) },
});
});
@@ -78,6 +79,7 @@ export const topicRouter = router({
return ctx.topicModel.queryAll();
}),
+ // TODO: this procedure should be used with authedProcedure
getTopics: publicProcedure
.input(
z.object({
@@ -89,7 +91,7 @@ export const topicRouter = router({
.query(async ({ input, ctx }) => {
if (!ctx.userId) return [];
- const topicModel = new TopicModel(ctx.userId);
+ const topicModel = new TopicModel(serverDB, ctx.userId);
return topicModel.query(input);
}),
diff --git a/src/server/routers/lambda/user.ts b/src/server/routers/lambda/user.ts
index 9660a955e691..f7522004940b 100644
--- a/src/server/routers/lambda/user.ts
+++ b/src/server/routers/lambda/user.ts
@@ -3,6 +3,7 @@ import { currentUser } from '@clerk/nextjs/server';
import { z } from 'zod';
import { enableClerk } from '@/const/auth';
+import { serverDB } from '@/database/server';
import { MessageModel } from '@/database/server/models/message';
import { SessionModel } from '@/database/server/models/session';
import { UserModel, UserNotFoundError } from '@/database/server/models/user';
@@ -12,7 +13,7 @@ import { UserGuideSchema, UserInitializationState, UserPreference } from '@/type
const userProcedure = authedProcedure.use(async (opts) => {
return opts.next({
- ctx: { userModel: new UserModel() },
+ ctx: { userModel: new UserModel(serverDB, opts.ctx.userId) },
});
});
@@ -23,7 +24,7 @@ export const userRouter = router({
// get or create first-time user
while (!state) {
try {
- state = await ctx.userModel.getUserState(ctx.userId);
+ state = await ctx.userModel.getUserState();
} catch (error) {
if (enableClerk && error instanceof UserNotFoundError) {
const user = await currentUser();
@@ -56,10 +57,10 @@ export const userRouter = router({
}
}
- const messageModel = new MessageModel(ctx.userId);
+ const messageModel = new MessageModel(serverDB, ctx.userId);
const messageCount = await messageModel.count();
- const sessionModel = new SessionModel(ctx.userId);
+ const sessionModel = new SessionModel(serverDB, ctx.userId);
const sessionCount = await sessionModel.count();
return {
@@ -77,25 +78,25 @@ export const userRouter = router({
}),
makeUserOnboarded: userProcedure.mutation(async ({ ctx }) => {
- return ctx.userModel.updateUser(ctx.userId, { isOnboarded: true });
+ return ctx.userModel.updateUser({ isOnboarded: true });
}),
resetSettings: userProcedure.mutation(async ({ ctx }) => {
- return ctx.userModel.deleteSetting(ctx.userId);
+ return ctx.userModel.deleteSetting();
}),
updateGuide: userProcedure.input(UserGuideSchema).mutation(async ({ ctx, input }) => {
- return ctx.userModel.updateGuide(ctx.userId, input);
+ return ctx.userModel.updateGuide(input);
}),
updatePreference: userProcedure.input(z.any()).mutation(async ({ ctx, input }) => {
- return ctx.userModel.updatePreference(ctx.userId, input);
+ return ctx.userModel.updatePreference(input);
}),
updateSettings: userProcedure
.input(z.object({}).passthrough())
.mutation(async ({ ctx, input }) => {
- return ctx.userModel.updateSetting(ctx.userId, input);
+ return ctx.userModel.updateSetting(input);
}),
});
diff --git a/src/server/services/chunk/index.ts b/src/server/services/chunk/index.ts
index 3ef1ba13275d..09c8e1d820b6 100644
--- a/src/server/services/chunk/index.ts
+++ b/src/server/services/chunk/index.ts
@@ -1,4 +1,5 @@
import { JWTPayload } from '@/const/auth';
+import { serverDB } from '@/database/server';
import { AsyncTaskModel } from '@/database/server/models/asyncTask';
import { FileModel } from '@/database/server/models/file';
import { ChunkContentParams, ContentChunk } from '@/server/modules/ContentChunk';
@@ -21,8 +22,8 @@ export class ChunkService {
this.chunkClient = new ContentChunk();
- this.fileModel = new FileModel(userId);
- this.asyncTaskModel = new AsyncTaskModel(userId);
+ this.fileModel = new FileModel(serverDB, userId);
+ this.asyncTaskModel = new AsyncTaskModel(serverDB, userId);
}
async chunkContent(params: ChunkContentParams) {
diff --git a/src/server/services/nextAuthUser/index.ts b/src/server/services/nextAuthUser/index.ts
index e832a6d04aa1..85f972aab576 100644
--- a/src/server/services/nextAuthUser/index.ts
+++ b/src/server/services/nextAuthUser/index.ts
@@ -7,11 +7,9 @@ import { pino } from '@/libs/logger';
import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter';
export class NextAuthUserService {
- userModel;
adapter;
constructor() {
- this.userModel = new UserModel();
this.adapter = LobeNextAuthDbAdapter(serverDB);
}
@@ -29,8 +27,10 @@ export class NextAuthUserService {
// 2. If found, Update user data from provider
if (user?.id) {
+ const userModel = new UserModel(serverDB, user.id);
+
// Perform update
- await this.userModel.updateUser(user.id, {
+ await userModel.updateUser({
avatar: data?.avatar,
email: data?.email,
fullName: data?.fullName,
diff --git a/src/server/services/user/index.ts b/src/server/services/user/index.ts
index 1bc3480fed74..ed5ee2395099 100644
--- a/src/server/services/user/index.ts
+++ b/src/server/services/user/index.ts
@@ -1,13 +1,14 @@
import { UserJSON } from '@clerk/backend';
import { NextResponse } from 'next/server';
+import { serverDB } from '@/database/server';
import { UserModel } from '@/database/server/models/user';
import { pino } from '@/libs/logger';
export class UserService {
createUser = async (id: string, params: UserJSON) => {
// Check if user already exists
- const res = await UserModel.findById(id);
+ const res = await UserModel.findById(serverDB, id);
// If user already exists, skip creating a new user
if (res)
@@ -27,7 +28,7 @@ export class UserService {
/* ↑ cloud slot ↑ */
// 2. create user in database
- await UserModel.createUser({
+ await UserModel.createUser(serverDB, {
avatar: params.image_url,
clerkCreatedAt: new Date(params.created_at),
email: email?.email_address,
@@ -49,7 +50,7 @@ export class UserService {
if (id) {
pino.info('delete user due to clerk webhook');
- await UserModel.deleteUser(id);
+ await UserModel.deleteUser(serverDB, id);
return NextResponse.json({ message: 'user deleted' }, { status: 200 });
} else {
@@ -61,10 +62,10 @@ export class UserService {
updateUser = async (id: string, params: UserJSON) => {
pino.info('updating user due to clerk webhook');
- const userModel = new UserModel();
+ const userModel = new UserModel(serverDB, id);
// Check if user already exists
- const res = await UserModel.findById(id);
+ const res = await UserModel.findById(serverDB, id);
// If user not exists, skip update the user
if (!res)
@@ -79,7 +80,7 @@ export class UserService {
const email = params.email_addresses.find((e) => e.id === params.primary_email_address_id);
const phone = params.phone_numbers.find((e) => e.id === params.primary_phone_number_id);
- await userModel.updateUser(id, {
+ await userModel.updateUser({
avatar: params.image_url,
email: email?.email_address,
firstName: params.first_name,