From 4057ad33ce0ea5e33d196679cb4b1cadc6db88eb Mon Sep 17 00:00:00 2001 From: CanisMinor Date: Wed, 18 Dec 2024 00:19:50 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20refactor=20the?= =?UTF-8?q?=20drizzle=20code=20style=20(#5058)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ refactor: Update drizzle code style * ♻️ refactor: Fix some drizzle-orm/expressions import * 💄 style: 替换为箭头函数 * Update topic.ts --------- Co-authored-by: Arvin Xu --- .../dataImporter/__tests__/index.test.ts | 29 ++--- .../repositories/dataImporter/index.ts | 77 +++++------- .../server/models/__tests__/_test_template.ts | 2 +- .../server/models/__tests__/agent.test.ts | 2 +- .../server/models/__tests__/asyncTask.test.ts | 2 +- .../server/models/__tests__/chunk.test.ts | 2 +- .../server/models/__tests__/file.test.ts | 2 +- .../models/__tests__/knowledgeBase.test.ts | 3 +- .../server/models/__tests__/message.test.ts | 107 ++++++---------- .../server/models/__tests__/nextauth.test.ts | 2 +- .../server/models/__tests__/session.test.ts | 2 +- .../models/__tests__/sessionGroup.test.ts | 3 +- .../server/models/__tests__/topic.test.ts | 2 +- .../server/models/__tests__/user.test.ts | 2 +- src/database/server/models/_template.ts | 4 +- src/database/server/models/agent.ts | 42 +++---- src/database/server/models/asyncTask.ts | 4 +- src/database/server/models/chunk.ts | 28 ++--- src/database/server/models/embedding.ts | 2 +- src/database/server/models/file.ts | 18 ++- src/database/server/models/knowledgeBase.ts | 10 +- src/database/server/models/message.ts | 119 ++++++++---------- src/database/server/models/plugin.ts | 4 +- src/database/server/models/ragEval/dataset.ts | 4 +- .../server/models/ragEval/datasetRecord.ts | 11 +- .../server/models/ragEval/evaluation.ts | 5 +- .../server/models/ragEval/evaluationRecord.ts | 4 +- src/database/server/models/session.ts | 73 +++++------ src/database/server/models/sessionGroup.ts | 8 +- src/database/server/models/thread.ts | 4 +- src/database/server/models/topic.ts | 101 +++++++-------- src/database/server/models/user.ts | 24 ++-- .../utils/streams/azureOpenai.test.ts | 1 - src/libs/next-auth/adapter/index.ts | 2 +- src/server/routers/lambda/chunk.ts | 4 +- 35 files changed, 309 insertions(+), 400 deletions(-) diff --git a/src/database/repositories/dataImporter/__tests__/index.test.ts b/src/database/repositories/dataImporter/__tests__/index.test.ts index 559fb56f4daf..769b4b83265b 100644 --- a/src/database/repositories/dataImporter/__tests__/index.test.ts +++ b/src/database/repositories/dataImporter/__tests__/index.test.ts @@ -1,8 +1,7 @@ // @vitest-environment node -import { eq, inArray } from 'drizzle-orm'; +import { eq, inArray } from 'drizzle-orm/expressions'; import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { getTestDBInstance } from '@/database/server/core/dbForTest'; import { agents, agentsToSessions, @@ -12,6 +11,7 @@ import { topics, users, } from '@/database/schemas'; +import { getTestDBInstance } from '@/database/server/core/dbForTest'; import { CURRENT_CONFIG_VERSION } from '@/migrations'; import { ImporterEntryData } from '@/types/importer'; @@ -60,8 +60,7 @@ describe('DataImporter', () => { it('should skip existing session groups and return correct result', async () => { await serverDB .insert(sessionGroups) - .values({ clientId: 'group1', name: 'Existing Group', userId }) - .execute(); + .values({ clientId: 'group1', name: 'Existing Group', userId }); const data: ImporterEntryData = { version: CURRENT_CONFIG_VERSION, @@ -141,7 +140,7 @@ describe('DataImporter', () => { }); it('should skip existing sessions and return correct result', async () => { - await serverDB.insert(sessions).values({ clientId: 'session1', userId }).execute(); + await serverDB.insert(sessions).values({ clientId: 'session1', userId }); const data: ImporterEntryData = { version: CURRENT_CONFIG_VERSION, @@ -477,10 +476,7 @@ describe('DataImporter', () => { }); it('should skip existing topics and return correct result', async () => { - await serverDB - .insert(topics) - .values({ clientId: 'topic1', title: 'Existing Topic', userId }) - .execute(); + await serverDB.insert(topics).values({ clientId: 'topic1', title: 'Existing Topic', userId }); const data: ImporterEntryData = { version: CURRENT_CONFIG_VERSION, @@ -616,15 +612,12 @@ describe('DataImporter', () => { }); it('should skip existing messages and return correct result', async () => { - await serverDB - .insert(messages) - .values({ - clientId: 'msg1', - content: 'Existing Message', - role: 'user', - userId, - }) - .execute(); + await serverDB.insert(messages).values({ + clientId: 'msg1', + content: 'Existing Message', + role: 'user', + userId, + }); const data: ImporterEntryData = { version: CURRENT_CONFIG_VERSION, diff --git a/src/database/repositories/dataImporter/index.ts b/src/database/repositories/dataImporter/index.ts index d8b28dd81699..67d4c610476e 100644 --- a/src/database/repositories/dataImporter/index.ts +++ b/src/database/repositories/dataImporter/index.ts @@ -1,5 +1,5 @@ -import { eq, inArray, sql } from 'drizzle-orm'; -import { and } from 'drizzle-orm/expressions'; +import { sql } from 'drizzle-orm'; +import { and, eq, inArray } from 'drizzle-orm/expressions'; import { agents, @@ -71,8 +71,7 @@ export class DataImporterRepos { set: { updatedAt: new Date() }, target: [sessionGroups.clientId, sessionGroups.userId], }) - .returning({ clientId: sessionGroups.clientId, id: sessionGroups.id }) - .execute(); + .returning({ clientId: sessionGroups.clientId, id: sessionGroups.id }); sessionGroupResult.added = mapArray.length - query.length; @@ -109,8 +108,7 @@ export class DataImporterRepos { set: { updatedAt: new Date() }, target: [sessions.clientId, sessions.userId], }) - .returning({ clientId: sessions.clientId, id: sessions.id }) - .execute(); + .returning({ clientId: sessions.clientId, id: sessions.id }); // get the session client-server id map sessionIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id])); @@ -133,18 +131,14 @@ export class DataImporterRepos { userId: this.userId, })), ) - .returning({ id: agents.id }) - .execute(); + .returning({ id: agents.id }); - await trx - .insert(agentsToSessions) - .values( - shouldInsertSessionAgents.map(({ id }, index) => ({ - agentId: agentMapArray[index].id, - sessionId: sessionIdMap[id], - })), - ) - .execute(); + await trx.insert(agentsToSessions).values( + shouldInsertSessionAgents.map(({ id }, index) => ({ + agentId: agentMapArray[index].id, + sessionId: sessionIdMap[id], + })), + ); } } @@ -178,8 +172,7 @@ export class DataImporterRepos { set: { updatedAt: new Date() }, target: [topics.clientId, topics.userId], }) - .returning({ clientId: topics.clientId, id: topics.id }) - .execute(); + .returning({ clientId: topics.clientId, id: topics.id }); topicIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id])); @@ -230,7 +223,7 @@ export class DataImporterRepos { for (let i = 0; i < inertValues.length; i += BATCH_SIZE) { const batch = inertValues.slice(i, i + BATCH_SIZE); - await trx.insert(messages).values(batch).execute(); + await trx.insert(messages).values(batch); } console.timeEnd('insert messages'); @@ -265,7 +258,7 @@ export class DataImporterRepos { .filter(Boolean); if (parentIdUpdates.length > 0) { - const updateQuery = trx + await trx .update(messages) .set({ parentId: sql`CASE ${sql.join(parentIdUpdates)} END`, @@ -281,42 +274,34 @@ export class DataImporterRepos { // const SQL = updateQuery.toSQL(); // console.log('sql:', SQL.sql); // console.log('params:', SQL.params); - - await updateQuery.execute(); } console.timeEnd('execute updates parentId'); // 4. insert message plugins const pluginInserts = shouldInsertMessages.filter((msg) => msg.plugin); if (pluginInserts.length > 0) { - await trx - .insert(messagePlugins) - .values( - pluginInserts.map((msg) => ({ - apiName: msg.plugin?.apiName, - arguments: msg.plugin?.arguments, - id: messageIdMap[msg.id], - identifier: msg.plugin?.identifier, - state: msg.pluginState, - toolCallId: msg.tool_call_id, - type: msg.plugin?.type, - })), - ) - .execute(); + await trx.insert(messagePlugins).values( + pluginInserts.map((msg) => ({ + apiName: msg.plugin?.apiName, + arguments: msg.plugin?.arguments, + id: messageIdMap[msg.id], + identifier: msg.plugin?.identifier, + state: msg.pluginState, + toolCallId: msg.tool_call_id, + type: msg.plugin?.type, + })), + ); } // 5. insert message translate const translateInserts = shouldInsertMessages.filter((msg) => msg.extra?.translate); if (translateInserts.length > 0) { - await trx - .insert(messageTranslates) - .values( - translateInserts.map((msg) => ({ - id: messageIdMap[msg.id], - ...msg.extra?.translate, - })), - ) - .execute(); + await trx.insert(messageTranslates).values( + translateInserts.map((msg) => ({ + id: messageIdMap[msg.id], + ...msg.extra?.translate, + })), + ); } // TODO: 未来需要处理 TTS 和图片的插入 (目前存在 file 的部分,不方便处理) diff --git a/src/database/server/models/__tests__/_test_template.ts b/src/database/server/models/__tests__/_test_template.ts index 0a96afd0a72b..3a2d2ecbbc1b 100644 --- a/src/database/server/models/__tests__/_test_template.ts +++ b/src/database/server/models/__tests__/_test_template.ts @@ -1,5 +1,5 @@ // @vitest-environment node -import { eq } from 'drizzle-orm'; +import { eq } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; diff --git a/src/database/server/models/__tests__/agent.test.ts b/src/database/server/models/__tests__/agent.test.ts index c7e5b46c9b54..60d0ea39d452 100644 --- a/src/database/server/models/__tests__/agent.test.ts +++ b/src/database/server/models/__tests__/agent.test.ts @@ -1,5 +1,5 @@ // @vitest-environment node -import { eq } from 'drizzle-orm'; +import { eq } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; diff --git a/src/database/server/models/__tests__/asyncTask.test.ts b/src/database/server/models/__tests__/asyncTask.test.ts index 5efafaf3214a..30f9e5937f9a 100644 --- a/src/database/server/models/__tests__/asyncTask.test.ts +++ b/src/database/server/models/__tests__/asyncTask.test.ts @@ -1,5 +1,5 @@ // @vitest-environment node -import { eq } from 'drizzle-orm'; +import { eq } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; diff --git a/src/database/server/models/__tests__/chunk.test.ts b/src/database/server/models/__tests__/chunk.test.ts index e027ef362346..4f36ea17a0eb 100644 --- a/src/database/server/models/__tests__/chunk.test.ts +++ b/src/database/server/models/__tests__/chunk.test.ts @@ -1,5 +1,5 @@ // @vitest-environment node -import { eq } from 'drizzle-orm'; +import { eq } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; diff --git a/src/database/server/models/__tests__/file.test.ts b/src/database/server/models/__tests__/file.test.ts index 611b75a22fc5..53dabfe59e63 100644 --- a/src/database/server/models/__tests__/file.test.ts +++ b/src/database/server/models/__tests__/file.test.ts @@ -1,5 +1,5 @@ // @vitest-environment node -import { eq, inArray } from 'drizzle-orm'; +import { eq, inArray } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; diff --git a/src/database/server/models/__tests__/knowledgeBase.test.ts b/src/database/server/models/__tests__/knowledgeBase.test.ts index 008ac19c58a2..e9bc96e63069 100644 --- a/src/database/server/models/__tests__/knowledgeBase.test.ts +++ b/src/database/server/models/__tests__/knowledgeBase.test.ts @@ -1,6 +1,5 @@ // @vitest-environment node -import { eq } from 'drizzle-orm'; -import { and, desc } from 'drizzle-orm/expressions'; +import { and, eq } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; diff --git a/src/database/server/models/__tests__/message.test.ts b/src/database/server/models/__tests__/message.test.ts index cdbe22dabb00..db98503b0055 100644 --- a/src/database/server/models/__tests__/message.test.ts +++ b/src/database/server/models/__tests__/message.test.ts @@ -1,4 +1,4 @@ -import { eq } from 'drizzle-orm'; +import { eq } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; @@ -514,11 +514,7 @@ describe('MessageModel', () => { await messageModel.create({ role: 'user', content: 'new message', sessionId: '1' }); // 断言结果 - const result = await serverDB - .select() - .from(messages) - .where(eq(messages.userId, userId)) - .execute(); + const result = await serverDB.select().from(messages).where(eq(messages.userId, userId)); expect(result).toHaveLength(1); expect(result[0].content).toBe('new message'); }); @@ -549,11 +545,7 @@ describe('MessageModel', () => { }); // 断言结果 - const result = await serverDB - .select() - .from(messages) - .where(eq(messages.userId, userId)) - .execute(); + const result = await serverDB.select().from(messages).where(eq(messages.userId, userId)); expect(result[0].id).toBeDefined(); expect(result[0].id).toHaveLength(18); }); @@ -582,8 +574,7 @@ describe('MessageModel', () => { const pluginResult = await serverDB .select() .from(messagePlugins) - .where(eq(messagePlugins.id, result.id)) - .execute(); + .where(eq(messagePlugins.id, result.id)); expect(pluginResult).toHaveLength(1); expect(pluginResult[0].identifier).toBe('plugin1'); }); @@ -650,8 +641,7 @@ describe('MessageModel', () => { const pluginResult = await serverDB .select() .from(messagePlugins) - .where(eq(messagePlugins.id, result.id)) - .execute(); + .where(eq(messagePlugins.id, result.id)); expect(pluginResult).toHaveLength(1); expect(pluginResult[0].identifier).toBe('lobe-web-browsing'); expect(pluginResult[0].state!).toMatchObject(state); @@ -670,11 +660,7 @@ describe('MessageModel', () => { await messageModel.batchCreate(newMessages); // 断言结果 - const result = await serverDB - .select() - .from(messages) - .where(eq(messages.userId, userId)) - .execute(); + const result = await serverDB.select().from(messages).where(eq(messages.userId, userId)); expect(result).toHaveLength(2); expect(result[0].content).toBe('message 1'); expect(result[1].content).toBe('message 2'); @@ -692,7 +678,7 @@ describe('MessageModel', () => { await messageModel.update('1', { content: 'updated message' }); // 断言结果 - const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute(); + const result = await serverDB.select().from(messages).where(eq(messages.id, '1')); expect(result[0].content).toBe('updated message'); }); @@ -706,7 +692,7 @@ describe('MessageModel', () => { await messageModel.update('1', { content: 'updated message' }); // 断言结果 - const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute(); + const result = await serverDB.select().from(messages).where(eq(messages.id, '1')); expect(result[0].content).toBe('message 1'); }); @@ -745,7 +731,7 @@ describe('MessageModel', () => { }); // 断言结果 - const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute(); + const result = await serverDB.select().from(messages).where(eq(messages.id, '1')); expect(result[0].tools[0].arguments).toBe( '{"query":"2024 杭州暴雨","searchEngines":["duckduckgo","google","brave"]}', ); @@ -763,7 +749,7 @@ describe('MessageModel', () => { await messageModel.deleteMessage('1'); // 断言结果 - const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute(); + const result = await serverDB.select().from(messages).where(eq(messages.id, '1')); expect(result).toHaveLength(0); }); @@ -783,14 +769,13 @@ describe('MessageModel', () => { await messageModel.deleteMessage('1'); // 断言结果 - const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute(); + const result = await serverDB.select().from(messages).where(eq(messages.id, '1')); expect(result).toHaveLength(0); const result2 = await serverDB .select() .from(messagePlugins) - .where(eq(messagePlugins.id, '2')) - .execute(); + .where(eq(messagePlugins.id, '2')); expect(result2).toHaveLength(0); }); @@ -805,7 +790,7 @@ describe('MessageModel', () => { await messageModel.deleteMessage('1'); // 断言结果 - const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute(); + const result = await serverDB.select().from(messages).where(eq(messages.id, '1')); expect(result).toHaveLength(1); }); }); @@ -822,9 +807,9 @@ describe('MessageModel', () => { await messageModel.deleteMessages(['1', '2']); // 断言结果 - const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute(); + const result = await serverDB.select().from(messages).where(eq(messages.id, '1')); expect(result).toHaveLength(0); - const result2 = await serverDB.select().from(messages).where(eq(messages.id, '2')).execute(); + const result2 = await serverDB.select().from(messages).where(eq(messages.id, '2')); expect(result2).toHaveLength(0); }); @@ -839,7 +824,7 @@ describe('MessageModel', () => { await messageModel.deleteMessages(['1', '2']); // 断言结果 - const result = await serverDB.select().from(messages).where(eq(messages.id, '1')).execute(); + const result = await serverDB.select().from(messages).where(eq(messages.id, '1')); expect(result).toHaveLength(1); }); }); @@ -857,18 +842,12 @@ describe('MessageModel', () => { await messageModel.deleteAllMessages(); // 断言结果 - const result = await serverDB - .select() - .from(messages) - .where(eq(messages.userId, userId)) - .execute(); + const result = await serverDB.select().from(messages).where(eq(messages.userId, userId)); + expect(result).toHaveLength(0); - const otherResult = await serverDB - .select() - .from(messages) - .where(eq(messages.userId, '456')) - .execute(); + const otherResult = await serverDB.select().from(messages).where(eq(messages.userId, '456')); + expect(otherResult).toHaveLength(1); }); }); @@ -887,11 +866,8 @@ describe('MessageModel', () => { await messageModel.updatePluginState('1', { key2: 'value2' }); // 断言结果 - const result = await serverDB - .select() - .from(messagePlugins) - .where(eq(messagePlugins.id, '1')) - .execute(); + const result = await serverDB.select().from(messagePlugins).where(eq(messagePlugins.id, '1')); + expect(result[0].state).toEqual({ key1: 'value1', key2: 'value2' }); }); @@ -916,11 +892,8 @@ describe('MessageModel', () => { await messageModel.updateMessagePlugin('1', { identifier: 'plugin2' }); // 断言结果 - const result = await serverDB - .select() - .from(messagePlugins) - .where(eq(messagePlugins.id, '1')) - .execute(); + const result = await serverDB.select().from(messagePlugins).where(eq(messagePlugins.id, '1')); + expect(result[0].identifier).toEqual('plugin2'); }); @@ -950,8 +923,8 @@ describe('MessageModel', () => { const result = await serverDB .select() .from(messageTranslates) - .where(eq(messageTranslates.id, '1')) - .execute(); + .where(eq(messageTranslates.id, '1')); + expect(result).toHaveLength(1); expect(result[0].content).toBe('translated message 1'); }); @@ -974,8 +947,8 @@ describe('MessageModel', () => { const result = await serverDB .select() .from(messageTranslates) - .where(eq(messageTranslates.id, '1')) - .execute(); + .where(eq(messageTranslates.id, '1')); + expect(result[0].content).toBe('updated translated message 1'); }); }); @@ -991,11 +964,8 @@ describe('MessageModel', () => { await messageModel.updateTTS('1', { contentMd5: 'md5', file: 'f1', voice: 'voice1' }); // 断言结果 - const result = await serverDB - .select() - .from(messageTTS) - .where(eq(messageTTS.id, '1')) - .execute(); + const result = await serverDB.select().from(messageTTS).where(eq(messageTTS.id, '1')); + expect(result).toHaveLength(1); expect(result[0].voice).toBe('voice1'); }); @@ -1015,11 +985,8 @@ describe('MessageModel', () => { await messageModel.updateTTS('1', { voice: 'updated voice1' }); // 断言结果 - const result = await serverDB - .select() - .from(messageTTS) - .where(eq(messageTTS.id, '1')) - .execute(); + const result = await serverDB.select().from(messageTTS).where(eq(messageTTS.id, '1')); + expect(result[0].voice).toBe('updated voice1'); }); }); @@ -1037,8 +1004,8 @@ describe('MessageModel', () => { const result = await serverDB .select() .from(messageTranslates) - .where(eq(messageTranslates.id, '1')) - .execute(); + .where(eq(messageTranslates.id, '1')); + expect(result).toHaveLength(0); }); }); @@ -1053,11 +1020,7 @@ describe('MessageModel', () => { await messageModel.deleteMessageTTS('1'); // 断言结果 - const result = await serverDB - .select() - .from(messageTTS) - .where(eq(messageTTS.id, '1')) - .execute(); + const result = await serverDB.select().from(messageTTS).where(eq(messageTTS.id, '1')); expect(result).toHaveLength(0); }); }); diff --git a/src/database/server/models/__tests__/nextauth.test.ts b/src/database/server/models/__tests__/nextauth.test.ts index d6fdd57120fe..ecd4d034bcda 100644 --- a/src/database/server/models/__tests__/nextauth.test.ts +++ b/src/database/server/models/__tests__/nextauth.test.ts @@ -5,7 +5,7 @@ import type { AdapterUser, VerificationToken, } from '@auth/core/adapters'; -import { eq } from 'drizzle-orm'; +import { eq } from 'drizzle-orm/expressions'; import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest'; import { diff --git a/src/database/server/models/__tests__/session.test.ts b/src/database/server/models/__tests__/session.test.ts index 923504d25ab3..abe6a2ce1dac 100644 --- a/src/database/server/models/__tests__/session.test.ts +++ b/src/database/server/models/__tests__/session.test.ts @@ -1,4 +1,4 @@ -import { and, eq, inArray } from 'drizzle-orm'; +import { and, eq, inArray } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { DEFAULT_AGENT_CONFIG } from '@/const/settings'; diff --git a/src/database/server/models/__tests__/sessionGroup.test.ts b/src/database/server/models/__tests__/sessionGroup.test.ts index 2341eecdc09a..33285f76c297 100644 --- a/src/database/server/models/__tests__/sessionGroup.test.ts +++ b/src/database/server/models/__tests__/sessionGroup.test.ts @@ -1,6 +1,5 @@ // @vitest-environment node -import { eq } from 'drizzle-orm'; -import { desc } from 'drizzle-orm/expressions'; +import { eq } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; diff --git a/src/database/server/models/__tests__/topic.test.ts b/src/database/server/models/__tests__/topic.test.ts index 358d1ec4bca2..a15c0bd98709 100644 --- a/src/database/server/models/__tests__/topic.test.ts +++ b/src/database/server/models/__tests__/topic.test.ts @@ -1,4 +1,4 @@ -import { eq, inArray } from 'drizzle-orm'; +import { eq, inArray } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { getTestDBInstance } from '@/database/server/core/dbForTest'; diff --git a/src/database/server/models/__tests__/user.test.ts b/src/database/server/models/__tests__/user.test.ts index 9496a9fedb4e..168af993854e 100644 --- a/src/database/server/models/__tests__/user.test.ts +++ b/src/database/server/models/__tests__/user.test.ts @@ -1,4 +1,4 @@ -import { eq } from 'drizzle-orm'; +import { eq } from 'drizzle-orm/expressions'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { INBOX_SESSION_ID } from '@/const/session'; diff --git a/src/database/server/models/_template.ts b/src/database/server/models/_template.ts index 3e69153d2989..0d692720baf0 100644 --- a/src/database/server/models/_template.ts +++ b/src/database/server/models/_template.ts @@ -45,10 +45,10 @@ export class TemplateModel { }); }; - async update(id: string, value: Partial) { + update = async (id: string, value: Partial) => { return this.db .update(sessionGroups) .set({ ...value, updatedAt: new Date() }) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); - } + }; } diff --git a/src/database/server/models/agent.ts b/src/database/server/models/agent.ts index 8b6a388d9d88..e7767119be6a 100644 --- a/src/database/server/models/agent.ts +++ b/src/database/server/models/agent.ts @@ -19,15 +19,15 @@ export class AgentModel { this.db = db; } - async getAgentConfigById(id: string) { + getAgentConfigById = async (id: string) => { const agent = await this.db.query.agents.findFirst({ where: eq(agents.id, id) }); const knowledge = await this.getAgentAssignedKnowledge(id); return { ...agent, ...knowledge }; - } + }; - async getAgentAssignedKnowledge(id: string) { + getAgentAssignedKnowledge = async (id: string) => { const knowledgeBaseResult = await this.db .select({ enabled: agentsKnowledgeBases.enabled, knowledgeBases }) .from(agentsKnowledgeBases) @@ -52,12 +52,12 @@ export class AgentModel { enabled: item.enabled, })), }; - } + }; /** * Find agent by session id */ - async findBySessionId(sessionId: string) { + findBySessionId = async (sessionId: string) => { const item = await this.db.query.agentsToSessions.findFirst({ where: eq(agentsToSessions.sessionId, sessionId), }); @@ -66,22 +66,19 @@ export class AgentModel { const agentId = item.agentId; return this.getAgentConfigById(agentId); - } + }; createAgentKnowledgeBase = async ( agentId: string, knowledgeBaseId: string, enabled: boolean = true, ) => { - return this.db - .insert(agentsKnowledgeBases) - .values({ - agentId, - enabled, - knowledgeBaseId, - userId: this.userId, - }) - .execute(); + return this.db.insert(agentsKnowledgeBases).values({ + agentId, + enabled, + knowledgeBaseId, + userId: this.userId, + }); }; deleteAgentKnowledgeBase = async (agentId: string, knowledgeBaseId: string) => { @@ -93,8 +90,7 @@ export class AgentModel { eq(agentsKnowledgeBases.knowledgeBaseId, knowledgeBaseId), eq(agentsKnowledgeBases.userId, this.userId), ), - ) - .execute(); + ); }; toggleKnowledgeBase = async (agentId: string, knowledgeBaseId: string, enabled?: boolean) => { @@ -107,8 +103,7 @@ export class AgentModel { eq(agentsKnowledgeBases.knowledgeBaseId, knowledgeBaseId), eq(agentsKnowledgeBases.userId, this.userId), ), - ) - .execute(); + ); }; createAgentFiles = async (agentId: string, fileIds: string[], enabled: boolean = true) => { @@ -134,8 +129,7 @@ export class AgentModel { .insert(agentsFiles) .values( needToInsertFileIds.map((fileId) => ({ agentId, enabled, fileId, userId: this.userId })), - ) - .execute(); + ); }; deleteAgentFile = async (agentId: string, fileId: string) => { @@ -147,8 +141,7 @@ export class AgentModel { eq(agentsFiles.fileId, fileId), eq(agentsFiles.userId, this.userId), ), - ) - .execute(); + ); }; toggleFile = async (agentId: string, fileId: string, enabled?: boolean) => { @@ -161,7 +154,6 @@ export class AgentModel { eq(agentsFiles.fileId, fileId), eq(agentsFiles.userId, this.userId), ), - ) - .execute(); + ); }; } diff --git a/src/database/server/models/asyncTask.ts b/src/database/server/models/asyncTask.ts index 85067856bd42..0bec8f6ec97b 100644 --- a/src/database/server/models/asyncTask.ts +++ b/src/database/server/models/asyncTask.ts @@ -64,7 +64,7 @@ export class AsyncTaskModel { /** * make the task status to be `error` if the task is not finished in 20 seconds */ - async checkTimeoutTasks(ids: string[]) { + checkTimeoutTasks = async (ids: string[]) => { const tasks = await this.db .select({ id: asyncTasks.id }) .from(asyncTasks) @@ -93,5 +93,5 @@ export class AsyncTaskModel { ), ); } - } + }; } diff --git a/src/database/server/models/chunk.ts b/src/database/server/models/chunk.ts index ed34796cc6a9..5194d8cafdfe 100644 --- a/src/database/server/models/chunk.ts +++ b/src/database/server/models/chunk.ts @@ -76,7 +76,7 @@ export class ChunkModel { }); }; - async findByFileId(id: string, page = 0) { + findByFileId = async (id: string, page = 0) => { const data = await this.db .select({ abstract: chunks.abstract, @@ -100,9 +100,9 @@ export class ChunkModel { return { ...item, metadata, pageNumber: metadata?.pageNumber } as FileChunk; }); - } + }; - async getChunksTextByFileId(id: string): Promise<{ id: string; text: string }[]> { + getChunksTextByFileId = async (id: string): Promise<{ id: string; text: string }[]> => { const data = await this.db .select() .from(chunks) @@ -113,9 +113,9 @@ export class ChunkModel { .map((item) => item.chunks) .map((chunk) => ({ id: chunk.id, text: this.mapChunkText(chunk) })) .filter((chunk) => chunk.text) as { id: string; text: string }[]; - } + }; - async countByFileIds(ids: string[]) { + countByFileIds = async (ids: string[]) => { if (ids.length === 0) return []; return this.db @@ -126,9 +126,9 @@ export class ChunkModel { .from(fileChunks) .where(inArray(fileChunks.fileId, ids)) .groupBy(fileChunks.fileId); - } + }; - async countByFileId(ids: string) { + countByFileId = async (ids: string) => { const data = await this.db .select({ count: count(fileChunks.chunkId), @@ -139,16 +139,16 @@ export class ChunkModel { .groupBy(fileChunks.fileId); return data[0]?.count ?? 0; - } + }; - async semanticSearch({ + semanticSearch = async ({ embedding, fileIds, }: { embedding: number[]; fileIds: string[] | undefined; query: string; - }) { + }) => { const similarity = sql`1 - (${cosineDistance(embeddings.embeddings, embedding)})`; const data = await this.db @@ -174,16 +174,16 @@ export class ChunkModel { ...item, metadata: item.metadata as ChunkMetadata, })); - } + }; - async semanticSearchForChat({ + semanticSearchForChat = async ({ embedding, fileIds, }: { embedding: number[]; fileIds: string[] | undefined; query: string; - }) { + }) => { const similarity = sql`1 - (${cosineDistance(embeddings.embeddings, embedding)})`; const hasFiles = fileIds && fileIds.length > 0; @@ -219,7 +219,7 @@ export class ChunkModel { text: this.mapChunkText(item), }; }); - } + }; private mapChunkText = (chunk: { metadata: any; text: string | null; type: string | null }) => { let text = chunk.text; diff --git a/src/database/server/models/embedding.ts b/src/database/server/models/embedding.ts index fb650d259221..a62755992890 100644 --- a/src/database/server/models/embedding.ts +++ b/src/database/server/models/embedding.ts @@ -50,7 +50,7 @@ export class EmbeddingModel { }); }; - countUsage = async () => { + countUsage = async (): Promise => { const result = await this.db .select({ count: count(), diff --git a/src/database/server/models/file.ts b/src/database/server/models/file.ts index 6747dcc68ebe..0d2b02b82b38 100644 --- a/src/database/server/models/file.ts +++ b/src/database/server/models/file.ts @@ -1,5 +1,5 @@ -import { asc, count, eq, ilike, inArray, notExists, or, sum } from 'drizzle-orm'; -import { and, desc, like } from 'drizzle-orm/expressions'; +import { count, sum } from 'drizzle-orm'; +import { and, asc, desc, eq, ilike, inArray, like, notExists, or } from 'drizzle-orm/expressions'; import type { PgTransaction } from 'drizzle-orm/pg-core'; import { LobeChatDatabase } from '@/database/type'; @@ -276,12 +276,11 @@ export class FileModel { return result[0].count; }; - async update(id: string, value: Partial) { - return this.db + update = async (id: string, value: Partial) => + this.db .update(files) .set({ ...value, updatedAt: new Date() }) .where(and(eq(files.id, id), eq(files.userId, this.userId))); - } /** * get the corresponding file type prefix according to FilesTabs @@ -306,17 +305,16 @@ export class FileModel { } }; - async findByNames(fileNames: string[]) { - return this.db.query.files.findMany({ + findByNames = async (fileNames: string[]) => + this.db.query.files.findMany({ where: and( or(...fileNames.map((name) => like(files.name, `${name}%`))), eq(files.userId, this.userId), ), }); - } // 抽象出通用的删除 chunks 方法 - private async deleteFileChunks(trx: PgTransaction, fileIds: string[]) { + private deleteFileChunks = async (trx: PgTransaction, fileIds: string[]) => { const BATCH_SIZE = 1000; // 每批处理的数量 // 1. 获取所有关联的 chunk IDs @@ -339,5 +337,5 @@ export class FileModel { } return chunkIds; - } + }; } diff --git a/src/database/server/models/knowledgeBase.ts b/src/database/server/models/knowledgeBase.ts index e7b808411632..ab3e3d62c71a 100644 --- a/src/database/server/models/knowledgeBase.ts +++ b/src/database/server/models/knowledgeBase.ts @@ -80,16 +80,14 @@ export class KnowledgeBaseModel { }; // update - async update(id: string, value: Partial) { - return this.db + update = async (id: string, value: Partial) => + this.db .update(knowledgeBases) .set({ ...value, updatedAt: new Date() }) .where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId))); - } - static async findById(db: LobeChatDatabase, id: string) { - return db.query.knowledgeBases.findFirst({ + static findById = async (db: LobeChatDatabase, id: string) => + db.query.knowledgeBases.findFirst({ where: eq(knowledgeBases.id, id), }); - } } diff --git a/src/database/server/models/message.ts b/src/database/server/models/message.ts index 276761bacdac..7fa7a477df22 100644 --- a/src/database/server/models/message.ts +++ b/src/database/server/models/message.ts @@ -46,10 +46,10 @@ export class MessageModel { } // **************** Query *************** // - async query( + query = async ( { current = 0, pageSize = 1000, sessionId, topicId }: QueryMessageParams = {}, options: { postProcessUrl?: (path: string | null) => Promise } = {}, - ): Promise { + ): Promise => { const offset = current * pageSize; // 1. get basic messages @@ -212,15 +212,15 @@ export class MessageModel { }; }, ); - } + }; - async findById(id: string) { + findById = async (id: string) => { return this.db.query.messages.findFirst({ where: and(eq(messages.id, id), eq(messages.userId, this.userId)), }); - } + }; - async findMessageQueriesById(messageId: string) { + findMessageQueriesById = async (messageId: string) => { const result = await this.db .select({ embeddings: embeddings.embeddings, @@ -236,47 +236,43 @@ export class MessageModel { if (result.length === 0) return undefined; return result[0]; - } + }; - async queryAll(): Promise { + queryAll = async (): Promise => { return this.db .select() .from(messages) .orderBy(messages.createdAt) - .where(eq(messages.userId, this.userId)) - - .execute(); - } + .where(eq(messages.userId, this.userId)); + }; - async queryBySessionId(sessionId?: string | null): Promise { + queryBySessionId = async (sessionId?: string | null): Promise => { return this.db.query.messages.findMany({ orderBy: [asc(messages.createdAt)], where: and(eq(messages.userId, this.userId), this.matchSession(sessionId)), }); - } + }; - async queryByKeyword(keyword: string): Promise { + queryByKeyword = async (keyword: string): Promise => { if (!keyword) return []; - return this.db.query.messages.findMany({ orderBy: [desc(messages.createdAt)], where: and(eq(messages.userId, this.userId), like(messages.content, `%${keyword}%`)), }); - } + }; - async count() { + count = async (): Promise => { const result = await this.db .select({ - count: count(), + count: count(messages.id), }) .from(messages) - .where(eq(messages.userId, this.userId)) - .execute(); + .where(eq(messages.userId, this.userId)); return result[0].count; - } + }; - async countToday() { + countToday = async (): Promise => { const today = new Date(); today.setHours(0, 0, 0, 0); const tomorrow = new Date(today); @@ -284,7 +280,7 @@ export class MessageModel { const result = await this.db .select({ - count: count(), + count: count(messages.id), }) .from(messages) .where( @@ -293,15 +289,14 @@ export class MessageModel { gte(messages.createdAt, today), lt(messages.createdAt, tomorrow), ), - ) - .execute(); + ); return result[0].count; - } + }; // **************** Create *************** // - async create( + create = async ( { fromModel, fromProvider, @@ -313,7 +308,7 @@ export class MessageModel { ...message }: CreateMessageParams, id: string = this.genId(), - ): Promise { + ): Promise => { return this.db.transaction(async (trx) => { const [item] = (await trx .insert(messages) @@ -358,31 +353,31 @@ export class MessageModel { return item; }); - } + }; - async batchCreate(newMessages: MessageItem[]) { + batchCreate = async (newMessages: MessageItem[]) => { const messagesToInsert = newMessages.map((m) => { return { ...m, userId: this.userId }; }); return this.db.insert(messages).values(messagesToInsert); - } + }; - async createMessageQuery(params: NewMessageQuery) { + createMessageQuery = async (params: NewMessageQuery) => { const result = await this.db.insert(messageQueries).values(params).returning(); return result[0]; - } + }; // **************** Update *************** // - async update(id: string, message: Partial) { + update = async (id: string, message: Partial) => { return this.db .update(messages) .set(message) .where(and(eq(messages.id, id), eq(messages.userId, this.userId))); - } + }; - async updatePluginState(id: string, state: Record) { + updatePluginState = async (id: string, state: Record) => { const item = await this.db.query.messagePlugins.findFirst({ where: eq(messagePlugins.id, id), }); @@ -392,18 +387,18 @@ export class MessageModel { .update(messagePlugins) .set({ state: merge(item.state || {}, state) }) .where(eq(messagePlugins.id, id)); - } + }; - async updateMessagePlugin(id: string, value: Partial) { + updateMessagePlugin = async (id: string, value: Partial) => { const item = await this.db.query.messagePlugins.findFirst({ where: eq(messagePlugins.id, id), }); if (!item) throw new Error('Plugin not found'); return this.db.update(messagePlugins).set(value).where(eq(messagePlugins.id, id)); - } + }; - async updateTranslate(id: string, translate: Partial) { + updateTranslate = async (id: string, translate: Partial) => { const result = await this.db.query.messageTranslates.findFirst({ where: and(eq(messageTranslates.id, id)), }); @@ -415,9 +410,9 @@ export class MessageModel { // or just update the existing one return this.db.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id)); - } + }; - async updateTTS(id: string, tts: Partial) { + updateTTS = async (id: string, tts: Partial) => { const result = await this.db.query.messageTTS.findFirst({ where: and(eq(messageTTS.id, id)), }); @@ -434,11 +429,11 @@ export class MessageModel { .update(messageTTS) .set({ contentMd5: tts.contentMd5, fileId: tts.file, voice: tts.voice }) .where(eq(messageTTS.id, id)); - } + }; // **************** Delete *************** // - async deleteMessage(id: string) { + deleteMessage = async (id: string) => { return this.db.transaction(async (tx) => { // 1. 查询要删除的 message 的完整信息 const message = await tx @@ -460,8 +455,7 @@ export class MessageModel { const res = await tx .select({ id: messagePlugins.id }) .from(messagePlugins) - .where(inArray(messagePlugins.toolCallId, toolCallIds)) - .execute(); + .where(inArray(messagePlugins.toolCallId, toolCallIds)); relatedMessageIds = res.map((row) => row.id); } @@ -472,28 +466,24 @@ export class MessageModel { // 5. 删除所有相关的 message await tx.delete(messages).where(inArray(messages.id, messageIdsToDelete)); }); - } + }; - async deleteMessages(ids: string[]) { - return this.db + deleteMessages = async (ids: string[]) => + this.db .delete(messages) .where(and(eq(messages.userId, this.userId), inArray(messages.id, ids))); - } - async deleteMessageTranslate(id: string) { - return this.db.delete(messageTranslates).where(and(eq(messageTranslates.id, id))); - } + deleteMessageTranslate = async (id: string) => + this.db.delete(messageTranslates).where(and(eq(messageTranslates.id, id))); - async deleteMessageTTS(id: string) { - return this.db.delete(messageTTS).where(and(eq(messageTTS.id, id))); - } + deleteMessageTTS = async (id: string) => + this.db.delete(messageTTS).where(and(eq(messageTTS.id, id))); - async deleteMessageQuery(id: string) { - return this.db.delete(messageQueries).where(and(eq(messageQueries.id, id))); - } + deleteMessageQuery = async (id: string) => + this.db.delete(messageQueries).where(and(eq(messageQueries.id, id))); - async deleteMessagesBySession(sessionId?: string | null, topicId?: string | null) { - return this.db + deleteMessagesBySession = async (sessionId?: string | null, topicId?: string | null) => + this.db .delete(messages) .where( and( @@ -502,11 +492,10 @@ export class MessageModel { this.matchTopic(topicId), ), ); - } - async deleteAllMessages() { + deleteAllMessages = async () => { return this.db.delete(messages).where(eq(messages.userId, this.userId)); - } + }; // **************** Helper *************** // diff --git a/src/database/server/models/plugin.ts b/src/database/server/models/plugin.ts index 22aa9bbc544b..8aa24dc6c866 100644 --- a/src/database/server/models/plugin.ts +++ b/src/database/server/models/plugin.ts @@ -60,10 +60,10 @@ export class PluginModel { }); }; - async update(id: string, value: Partial) { + update = async (id: string, value: Partial) => { return this.db .update(installedPlugins) .set({ ...value, updatedAt: new Date() }) .where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId))); - } + }; } diff --git a/src/database/server/models/ragEval/dataset.ts b/src/database/server/models/ragEval/dataset.ts index 18c9a7ca7deb..b2d449b51599 100644 --- a/src/database/server/models/ragEval/dataset.ts +++ b/src/database/server/models/ragEval/dataset.ts @@ -1,7 +1,7 @@ -import { and, desc, eq } from 'drizzle-orm'; +import { and, desc, eq } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; import { NewEvalDatasetsItem, evalDatasets } from '@/database/schemas'; +import { serverDB } from '@/database/server'; import { RAGEvalDataSetItem } from '@/types/eval'; export class EvalDatasetModel { diff --git a/src/database/server/models/ragEval/datasetRecord.ts b/src/database/server/models/ragEval/datasetRecord.ts index 6b61fa75f974..97ad9bcb2a04 100644 --- a/src/database/server/models/ragEval/datasetRecord.ts +++ b/src/database/server/models/ragEval/datasetRecord.ts @@ -1,11 +1,7 @@ -import { and, eq, inArray } from 'drizzle-orm'; +import { and, eq, inArray } from 'drizzle-orm/expressions'; +import { NewEvalDatasetRecordsItem, evalDatasetRecords, files } from '@/database/schemas'; import { serverDB } from '@/database/server'; -import { - NewEvalDatasetRecordsItem, - evalDatasetRecords, - files, -} from '@/database/schemas'; import { EvalDatasetRecordRefFile } from '@/types/eval'; export class EvalDatasetRecordModel { @@ -50,8 +46,7 @@ export class EvalDatasetRecordModel { const fileItems = await serverDB .select({ fileType: files.fileType, id: files.id, name: files.name }) .from(files) - .where(and(inArray(files.id, fileList), eq(files.userId, this.userId))) - .execute(); + .where(and(inArray(files.id, fileList), eq(files.userId, this.userId))); return list.map((item) => { return { diff --git a/src/database/server/models/ragEval/evaluation.ts b/src/database/server/models/ragEval/evaluation.ts index 73f5264c95cf..04d70f37e2d7 100644 --- a/src/database/server/models/ragEval/evaluation.ts +++ b/src/database/server/models/ragEval/evaluation.ts @@ -1,12 +1,13 @@ -import { SQL, and, count, desc, eq, inArray } from 'drizzle-orm'; +import { SQL, count } from 'drizzle-orm'; +import { and, desc, eq, inArray } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; import { NewEvalEvaluationItem, evalDatasets, evalEvaluation, evaluationRecords, } from '@/database/schemas'; +import { serverDB } from '@/database/server'; import { EvalEvaluationStatus, RAGEvalEvaluationItem } from '@/types/eval'; export class EvalEvaluationModel { diff --git a/src/database/server/models/ragEval/evaluationRecord.ts b/src/database/server/models/ragEval/evaluationRecord.ts index 8edfa0b193a5..7e652798d686 100644 --- a/src/database/server/models/ragEval/evaluationRecord.ts +++ b/src/database/server/models/ragEval/evaluationRecord.ts @@ -1,7 +1,7 @@ -import { and, eq } from 'drizzle-orm'; +import { and, eq } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; import { NewEvaluationRecordsItem, evaluationRecords } from '@/database/schemas'; +import { serverDB } from '@/database/server'; export class EvaluationRecordModel { private userId: string; diff --git a/src/database/server/models/session.ts b/src/database/server/models/session.ts index 494e128e748e..74a4ae87dc39 100644 --- a/src/database/server/models/session.ts +++ b/src/database/server/models/session.ts @@ -31,7 +31,7 @@ export class SessionModel { } // **************** Query *************** // - async query({ current = 0, pageSize = 9999 } = {}) { + query = async ({ current = 0, pageSize = 9999 } = {}) => { const offset = current * pageSize; return this.db.query.sessions.findMany({ @@ -41,9 +41,9 @@ export class SessionModel { where: and(eq(sessions.userId, this.userId), not(eq(sessions.slug, INBOX_SESSION_ID))), with: { agentsToSessions: { columns: {}, with: { agent: true } }, group: true }, }); - } + }; - async queryWithGroups(): Promise { + queryWithGroups = async (): Promise => { // 查询所有会话 const result = await this.query(); @@ -56,9 +56,9 @@ export class SessionModel { sessionGroups: groups as unknown as ChatSessionList['sessionGroups'], sessions: result.map((item) => this.mapSessionItem(item as any)), }; - } + }; - async queryByKeyword(keyword: string) { + queryByKeyword = async (keyword: string) => { if (!keyword) return []; const keywordLowerCase = keyword.toLowerCase(); @@ -66,11 +66,11 @@ export class SessionModel { const data = await this.findSessionsByKeywords({ keyword: keywordLowerCase }); return data.map((item) => this.mapSessionItem(item as any)); - } + }; - async findByIdOrSlug( + findByIdOrSlug = async ( idOrSlug: string, - ): Promise<(SessionItem & { agent: AgentItem }) | undefined> { + ): Promise<(SessionItem & { agent: AgentItem }) | undefined> => { const result = await this.db.query.sessions.findFirst({ where: and( or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)), @@ -82,23 +82,22 @@ export class SessionModel { if (!result) return; return { ...result, agent: (result?.agentsToSessions?.[0] as any)?.agent } as any; - } + }; - async count() { + count = async (): Promise => { const result = await this.db .select({ - count: count(), + count: count(sessions.id), }) .from(sessions) - .where(eq(sessions.userId, this.userId)) - .execute(); + .where(eq(sessions.userId, this.userId)); return result[0].count; - } + }; // **************** Create *************** // - async create({ + create = async ({ id = idGenerator('sessions'), type = 'agent', session = {}, @@ -110,7 +109,7 @@ export class SessionModel { session?: Partial; slug?: string; type: 'agent' | 'group'; - }): Promise { + }): Promise => { return this.db.transaction(async (trx) => { const newAgents = await trx .insert(agents) @@ -143,9 +142,9 @@ export class SessionModel { return result[0]; }); - } + }; - async createInbox() { + createInbox = async () => { const item = await this.db.query.sessions.findFirst({ where: and(eq(sessions.userId, this.userId), eq(sessions.slug, INBOX_SESSION_ID)), }); @@ -158,9 +157,9 @@ export class SessionModel { slug: INBOX_SESSION_ID, type: 'agent', }); - } + }; - async batchCreate(newSessions: NewSession[]) { + batchCreate = async (newSessions: NewSession[]) => { const sessionsToInsert = newSessions.map((s) => { return { ...s, @@ -170,9 +169,9 @@ export class SessionModel { }); return this.db.insert(sessions).values(sessionsToInsert); - } + }; - async duplicate(id: string, newTitle?: string) { + duplicate = async (id: string, newTitle?: string) => { const result = await this.findByIdOrSlug(id); if (!result) return; @@ -193,49 +192,49 @@ export class SessionModel { }, type: 'agent', }); - } + }; // **************** Delete *************** // /** * Delete a session, also delete all messages and topics associated with it. */ - async delete(id: string) { + delete = async (id: string) => { return this.db .delete(sessions) .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId))); - } + }; /** * Batch delete sessions, also delete all messages and topics associated with them. */ - async batchDelete(ids: string[]) { + batchDelete = async (ids: string[]) => { return this.db .delete(sessions) .where(and(inArray(sessions.id, ids), eq(sessions.userId, this.userId))); - } + }; - async deleteAll() { + deleteAll = async () => { return this.db.delete(sessions).where(eq(sessions.userId, this.userId)); - } + }; // **************** Update *************** // - async update(id: string, data: Partial) { + update = async (id: string, data: Partial) => { return this.db .update(sessions) .set(data) .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId))) .returning(); - } + }; - async updateConfig(id: string, data: Partial) { + updateConfig = async (id: string, data: Partial) => { if (Object.keys(data).length === 0) return; return this.db .update(agents) .set(data) .where(and(eq(agents.id, id), eq(agents.userId, this.userId))); - } + }; // **************** Helper *************** // @@ -266,7 +265,11 @@ export class SessionModel { } as any; }; - async findSessionsByKeywords(params: { current?: number; keyword: string; pageSize?: number }) { + findSessionsByKeywords = async (params: { + current?: number; + keyword: string; + pageSize?: number; + }) => { const { keyword, pageSize = 9999, current = 0 } = params; const offset = current * pageSize; const results = await this.db.query.agents.findMany({ @@ -290,5 +293,5 @@ export class SessionModel { return results.map((item) => item.agentsToSessions[0].session); } catch {} return []; - } + }; } diff --git a/src/database/server/models/sessionGroup.ts b/src/database/server/models/sessionGroup.ts index 840daeedcff8..cd93c028dfd1 100644 --- a/src/database/server/models/sessionGroup.ts +++ b/src/database/server/models/sessionGroup.ts @@ -46,14 +46,14 @@ export class SessionGroupModel { }); }; - async update(id: string, value: Partial) { + update = async (id: string, value: Partial) => { return this.db .update(sessionGroups) .set({ ...value, updatedAt: new Date() }) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); - } + }; - async updateOrder(sortMap: { id: string; sort: number }[]) { + updateOrder = async (sortMap: { id: string; sort: number }[]) => { await this.db.transaction(async (tx) => { const updates = sortMap.map(({ id, sort }) => { return tx @@ -64,7 +64,7 @@ export class SessionGroupModel { await Promise.all(updates); }); - } + }; private genId = () => idGenerator('sessionGroups'); } diff --git a/src/database/server/models/thread.ts b/src/database/server/models/thread.ts index 582059ff97d4..83e46eb46818 100644 --- a/src/database/server/models/thread.ts +++ b/src/database/server/models/thread.ts @@ -71,10 +71,10 @@ export class ThreadModel { }); }; - async update(id: string, value: Partial) { + update = async (id: string, value: Partial) => { return this.db .update(threads) .set({ ...value, updatedAt: new Date() }) .where(and(eq(threads.id, id), eq(threads.userId, this.userId))); - } + }; } diff --git a/src/database/server/models/topic.ts b/src/database/server/models/topic.ts index 61ad2bb62c33..a94e1d8d5286 100644 --- a/src/database/server/models/topic.ts +++ b/src/database/server/models/topic.ts @@ -29,46 +29,42 @@ export class TopicModel { } // **************** Query *************** // - async query({ current = 0, pageSize = 9999, sessionId }: QueryTopicParams = {}) { + query = async ({ current = 0, pageSize = 9999, sessionId }: QueryTopicParams = {}) => { const offset = current * pageSize; - - return ( - this.db - .select({ - createdAt: topics.createdAt, - favorite: topics.favorite, - historySummary: topics.historySummary, - id: topics.id, - metadata: topics.metadata, - title: topics.title, - updatedAt: topics.updatedAt, - }) - .from(topics) - .where(and(eq(topics.userId, this.userId), this.matchSession(sessionId))) - // In boolean sorting, false is considered "smaller" than true. - // So here we use desc to ensure that topics with favorite as true are in front. - .orderBy(desc(topics.favorite), desc(topics.updatedAt)) - .limit(pageSize) - .offset(offset) - ); - } - - async findById(id: string) { + return this.db + .select({ + createdAt: topics.createdAt, + favorite: topics.favorite, + historySummary: topics.historySummary, + id: topics.id, + metadata: topics.metadata, + title: topics.title, + updatedAt: topics.updatedAt, + }) + .from(topics) + .where(and(eq(topics.userId, this.userId), this.matchSession(sessionId))) + // In boolean sorting, false is considered "smaller" than true. + // So here we use desc to ensure that topics with favorite as true are in front. + .orderBy(desc(topics.favorite), desc(topics.updatedAt)) + .limit(pageSize) + .offset(offset); + }; + + findById = async (id: string) => { return this.db.query.topics.findFirst({ where: and(eq(topics.id, id), eq(topics.userId, this.userId)), }); - } + }; - async queryAll(): Promise { + queryAll = async (): Promise => { return this.db .select() .from(topics) .orderBy(topics.updatedAt) - .where(eq(topics.userId, this.userId)) - .execute(); - } + .where(eq(topics.userId, this.userId)); + }; - async queryByKeyword(keyword: string, sessionId?: string | null): Promise { + queryByKeyword = async (keyword: string, sessionId?: string | null): Promise => { if (!keyword) return []; const keywordLowerCase = keyword.toLowerCase(); @@ -92,26 +88,25 @@ export class TopicModel { ), ), }); - } + }; - async count() { + count = async (): Promise => { const result = await this.db .select({ - count: count(), + count: count(topics.id), }) .from(topics) - .where(eq(topics.userId, this.userId)) - .execute(); + .where(eq(topics.userId, this.userId)); return result[0].count; - } + }; // **************** Create *************** // - async create( + create = async ( { messages: messageIds, ...params }: CreateTopicParams, id: string = this.genId(), - ): Promise { + ): Promise => { return this.db.transaction(async (tx) => { // 在 topics 表中插入新的 topic const [topic] = await tx @@ -133,9 +128,9 @@ export class TopicModel { return topic; }); - } + }; - async batchCreate(topicParams: (CreateTopicParams & { id?: string })[]) { + batchCreate = async (topicParams: (CreateTopicParams & { id?: string })[]) => { // 开始一个事务 return this.db.transaction(async (tx) => { // 在 topics 表中批量插入新的 topics @@ -167,9 +162,9 @@ export class TopicModel { return createdTopics; }); - } + }; - async duplicate(topicId: string, newTitle?: string) { + duplicate = async (topicId: string, newTitle?: string) => { return this.db.transaction(async (tx) => { // find original topic const originalTopic = await tx.query.topics.findFirst({ @@ -217,48 +212,48 @@ export class TopicModel { topic: duplicatedTopic, }; }); - } + }; // **************** Delete *************** // /** * Delete a session, also delete all messages and topics associated with it. */ - async delete(id: string) { + delete = async (id: string) => { return this.db.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId))); - } + }; /** * Deletes multiple topics based on the sessionId. */ - async batchDeleteBySessionId(sessionId?: string | null) { + batchDeleteBySessionId = async (sessionId?: string | null) => { return this.db .delete(topics) .where(and(this.matchSession(sessionId), eq(topics.userId, this.userId))); - } + }; /** * Deletes multiple topics and all messages associated with them in a transaction. */ - async batchDelete(ids: string[]) { + batchDelete = async (ids: string[]) => { return this.db .delete(topics) .where(and(inArray(topics.id, ids), eq(topics.userId, this.userId))); - } + }; - async deleteAll() { + deleteAll = async () => { return this.db.delete(topics).where(eq(topics.userId, this.userId)); - } + }; // **************** Update *************** // - async update(id: string, data: Partial) { + update = async (id: string, data: Partial) => { return this.db .update(topics) .set({ ...data, updatedAt: new Date() }) .where(and(eq(topics.id, id), eq(topics.userId, this.userId))) .returning(); - } + }; // **************** Helper *************** // diff --git a/src/database/server/models/user.ts b/src/database/server/models/user.ts index 361a064eef14..178824219b5c 100644 --- a/src/database/server/models/user.ts +++ b/src/database/server/models/user.ts @@ -26,7 +26,7 @@ export class UserModel { this.db = db; } - async getUserState() { + getUserState = async () => { const result = await this.db .select({ isOnboarded: users.isOnboarded, @@ -81,20 +81,20 @@ export class UserModel { settings, userId: this.userId, }; - } + }; - async updateUser(value: Partial) { + updateUser = async (value: Partial) => { return this.db .update(users) .set({ ...value, updatedAt: new Date() }) .where(eq(users.id, this.userId)); - } + }; - async deleteSetting() { + deleteSetting = async () => { return this.db.delete(userSettings).where(eq(userSettings.id, this.userId)); - } + }; - async updateSetting(value: Partial) { + updateSetting = async (value: Partial) => { const { keyVaults, ...res } = value; // Encrypt keyVaults @@ -120,9 +120,9 @@ export class UserModel { } return this.db.update(userSettings).set(newValue).where(eq(userSettings.id, this.userId)); - } + }; - async updatePreference(value: Partial) { + updatePreference = async (value: Partial) => { const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) }); if (!user) return; @@ -130,9 +130,9 @@ export class UserModel { .update(users) .set({ preference: merge(user.preference, value) }) .where(eq(users.id, this.userId)); - } + }; - async updateGuide(value: Partial) { + updateGuide = async (value: Partial) => { const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) }); if (!user) return; @@ -141,7 +141,7 @@ export class UserModel { .update(users) .set({ preference: { ...prevPreference, guide: merge(prevPreference.guide || {}, value) } }) .where(eq(users.id, this.userId)); - } + }; // Static method diff --git a/src/libs/agent-runtime/utils/streams/azureOpenai.test.ts b/src/libs/agent-runtime/utils/streams/azureOpenai.test.ts index 51b7f85e5296..292ace65d5e9 100644 --- a/src/libs/agent-runtime/utils/streams/azureOpenai.test.ts +++ b/src/libs/agent-runtime/utils/streams/azureOpenai.test.ts @@ -1,4 +1,3 @@ -import { desc } from 'drizzle-orm/expressions'; import { describe, expect, it, vi } from 'vitest'; import { AzureOpenAIStream } from './azureOpenai'; diff --git a/src/libs/next-auth/adapter/index.ts b/src/libs/next-auth/adapter/index.ts index 6c5925110481..9cbb3ef4ab41 100644 --- a/src/libs/next-auth/adapter/index.ts +++ b/src/libs/next-auth/adapter/index.ts @@ -4,7 +4,7 @@ import type { AdapterUser, VerificationToken, } from '@auth/core/adapters'; -import { and, eq } from 'drizzle-orm'; +import { and, eq } from 'drizzle-orm/expressions'; import type { NeonDatabase } from 'drizzle-orm/neon-serverless'; import { Adapter, AdapterAccount } from 'next-auth/adapters'; diff --git a/src/server/routers/lambda/chunk.ts b/src/server/routers/lambda/chunk.ts index 07425224af47..c663c74d31e6 100644 --- a/src/server/routers/lambda/chunk.ts +++ b/src/server/routers/lambda/chunk.ts @@ -1,14 +1,14 @@ -import { inArray } from 'drizzle-orm'; +import { inArray } from 'drizzle-orm/expressions'; import { z } from 'zod'; import { DEFAULT_EMBEDDING_MODEL } from '@/const/settings'; +import { knowledgeBaseFiles } from '@/database/schemas'; import { serverDB } from '@/database/server'; import { AsyncTaskModel } from '@/database/server/models/asyncTask'; import { ChunkModel } from '@/database/server/models/chunk'; import { EmbeddingModel } from '@/database/server/models/embedding'; import { FileModel } from '@/database/server/models/file'; import { MessageModel } from '@/database/server/models/message'; -import { knowledgeBaseFiles } from '@/database/schemas'; import { ModelProvider } from '@/libs/agent-runtime'; import { authedProcedure, router } from '@/libs/trpc'; import { keyVaults } from '@/libs/trpc/middleware/keyVaults';