Skip to content

Commit

Permalink
⚡️ perf: support to use signedUrl for S3 if bucket is not public-read (
Browse files Browse the repository at this point in the history
…lobehub#4254)

* 当未把s3存储桶设成公共读时,需要每次生成带签名的预览地址,默认有效期2小时。

* If the bucket is not set to public-read, the preview address needs to be regenerated each time.

* If the bucket is not set to public-read, the preview address needs to be regenerated each time; fix the test TS file; add a param to define when the S3 preview URL expires.

* fix ci test

---------

Co-authored-by: Arvin Xu <[email protected]>
  • Loading branch information
vual and arvinxx authored Nov 3, 2024
1 parent ee31f1d commit 7204296
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 21 deletions.
2 changes: 2 additions & 0 deletions src/config/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export const getFileConfig = () => {
S3_BUCKET: process.env.S3_BUCKET,
S3_ENABLE_PATH_STYLE: process.env.S3_ENABLE_PATH_STYLE === '1',
S3_ENDPOINT: process.env.S3_ENDPOINT,
S3_PREVIEW_URL_EXPIRE_IN: parseInt(process.env.S3_PREVIEW_URL_EXPIRE_IN || '7200'),
S3_PUBLIC_DOMAIN,
S3_REGION: process.env.S3_REGION,
S3_SECRET_ACCESS_KEY: process.env.S3_SECRET_ACCESS_KEY,
Expand All @@ -46,6 +47,7 @@ export const getFileConfig = () => {
S3_ENABLE_PATH_STYLE: z.boolean(),

S3_ENDPOINT: z.string().url().optional(),
S3_PREVIEW_URL_EXPIRE_IN: z.number(),
S3_PUBLIC_DOMAIN: z.string().url().optional(),
S3_REGION: z.string().optional(),
S3_SECRET_ACCESS_KEY: z.string().optional(),
Expand Down
10 changes: 6 additions & 4 deletions src/database/server/models/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ export class MessageModel {
.leftJoin(files, eq(files.id, messagesFiles.fileId))
.where(inArray(messagesFiles.messageId, messageIds));

const relatedFileList = rawRelatedFileList.map((file) => ({
...file,
url: getFullFileUrl(file.url),
}));
const relatedFileList = await Promise.all(
rawRelatedFileList.map(async (file) => ({
...file,
url: await getFullFileUrl(file.url),
})),
);

const imageList = relatedFileList.filter((i) => (i.fileType || '').startsWith('image'));
const fileList = relatedFileList.filter((i) => !(i.fileType || '').startsWith('image'));
Expand Down
11 changes: 11 additions & 0 deletions src/server/modules/S3/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ export class S3 {
return getSignedUrl(this.client, command, { expiresIn: 3600 });
}

/**
 * Produce a pre-signed URL that can be used to read (preview) an object.
 *
 * A `GetObjectCommand` targeting `this.bucket` and the given key is signed
 * with `this.client` through `getSignedUrl`.
 *
 * @param key - Object key inside the bucket.
 * @param expiresIn - Optional lifetime forwarded as the `expiresIn` signing
 *   option (presumably seconds, per the AWS presigner — confirm). When it is
 *   `null`/`undefined`, `fileEnv.S3_PREVIEW_URL_EXPIRE_IN` is used instead.
 * @returns The signed URL string resolved by `getSignedUrl`.
 */
public async createPreSignedUrlForPreview(key: string, expiresIn?: number): Promise<string> {
  const getObject = new GetObjectCommand({ Bucket: this.bucket, Key: key });

  // Only a missing value falls back to the configured default; an explicit 0 is kept.
  const lifetime = expiresIn ?? fileEnv.S3_PREVIEW_URL_EXPIRE_IN;

  return getSignedUrl(this.client, getObject, { expiresIn: lifetime });
}

public async uploadContent(path: string, content: string) {
const command = new PutObjectCommand({
ACL: this.setAcl ? 'public-read' : undefined,
Expand Down
20 changes: 12 additions & 8 deletions src/server/routers/lambda/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export const fileRouter = router({
url: input.url,
});

return { id, url: getFullFileUrl(input.url) };
return { id, url: await getFullFileUrl(input.url) };
}),
findById: fileProcedure
.input(
Expand All @@ -69,7 +69,7 @@ export const fileRouter = router({
const item = await ctx.fileModel.findById(input.id);
if (!item) throw new TRPCError({ code: 'BAD_REQUEST', message: 'File not found' });

return { ...item, url: getFullFileUrl(item?.url) };
return { ...item, url: await getFullFileUrl(item?.url) };
}),

getFileItemById: fileProcedure
Expand Down Expand Up @@ -102,7 +102,7 @@ export const fileRouter = router({
embeddingError: embeddingTask?.error,
embeddingStatus: embeddingTask?.status as AsyncTaskStatus,
finishEmbedding: embeddingTask?.status === AsyncTaskStatus.Success,
url: getFullFileUrl(item.url!),
url: await getFullFileUrl(item.url!),
};
}),

Expand All @@ -124,23 +124,27 @@ export const fileRouter = router({
AsyncTaskType.Embedding,
);

return fileList.map(({ chunkTaskId, embeddingTaskId, ...item }): FileListItem => {
const resultFiles = [] as any[];
for (const { chunkTaskId, embeddingTaskId, ...item } of fileList as any[]) {
const chunkTask = chunkTaskId ? chunkTasks.find((task) => task.id === chunkTaskId) : null;
const embeddingTask = embeddingTaskId
? embeddingTasks.find((task) => task.id === embeddingTaskId)
: null;

return {
const fileItem = {
...item,
chunkCount: chunks.find((chunk) => chunk.id === item.id)?.count ?? null,
chunkingError: chunkTask?.error ?? null,
chunkingStatus: chunkTask?.status as AsyncTaskStatus,
embeddingError: embeddingTask?.error ?? null,
embeddingStatus: embeddingTask?.status as AsyncTaskStatus,
finishEmbedding: embeddingTask?.status === AsyncTaskStatus.Success,
url: getFullFileUrl(item.url!),
};
});
url: await getFullFileUrl(item.url!),
} as FileListItem;
resultFiles.push(fileItem);
}

return resultFiles;
}),

removeAllFiles: fileProcedure.mutation(async ({ ctx }) => {
Expand Down
2 changes: 1 addition & 1 deletion src/server/routers/lambda/ragEval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ export const ragEvalRouter = router({
// 保存数据
await ctx.evaluationModel.update(input.id, {
status: EvalEvaluationStatus.Success,
evalRecordsUrl: getFullFileUrl(path),
evalRecordsUrl: await getFullFileUrl(path),
});
}

Expand Down
15 changes: 8 additions & 7 deletions src/server/utils/files.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const config = {
S3_ENABLE_PATH_STYLE: false,
S3_PUBLIC_DOMAIN: 'https://example.com',
S3_BUCKET: 'my-bucket',
S3_SET_ACL: true,
};

vi.mock('@/config/file', () => ({
Expand All @@ -17,20 +18,20 @@ vi.mock('@/config/file', () => ({
}));

describe('getFullFileUrl', () => {
it('should return empty string for null or undefined input', () => {
expect(getFullFileUrl(null)).toBe('');
expect(getFullFileUrl(undefined)).toBe('');
it('should return empty string for null or undefined input', async () => {
expect(await getFullFileUrl(null)).toBe('');
expect(await getFullFileUrl(undefined)).toBe('');
});

it('should return correct URL when S3_ENABLE_PATH_STYLE is false', () => {
it('should return correct URL when S3_ENABLE_PATH_STYLE is false', async () => {
const url = 'path/to/file.jpg';
expect(getFullFileUrl(url)).toBe('https://example.com/path/to/file.jpg');
expect(await getFullFileUrl(url)).toBe('https://example.com/path/to/file.jpg');
});

it('should return correct URL when S3_ENABLE_PATH_STYLE is true', () => {
it('should return correct URL when S3_ENABLE_PATH_STYLE is true', async () => {
config.S3_ENABLE_PATH_STYLE = true;
const url = 'path/to/file.jpg';
expect(getFullFileUrl(url)).toBe('https://example.com/my-bucket/path/to/file.jpg');
expect(await getFullFileUrl(url)).toBe('https://example.com/my-bucket/path/to/file.jpg');
config.S3_ENABLE_PATH_STYLE = false;
});
});
9 changes: 8 additions & 1 deletion src/server/utils/files.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import urlJoin from 'url-join';

import { fileEnv } from '@/config/file';
import { S3 } from '@/server/modules/S3';

export const getFullFileUrl = (url?: string | null) => {
export const getFullFileUrl = async (url?: string | null, expiresIn?: number) => {
if (!url) return '';

// If bucket is not set public read, the preview address needs to be regenerated each time
if (!fileEnv.S3_SET_ACL) {
const s3 = new S3();
return await s3.createPreSignedUrlForPreview(url, expiresIn);
}

if (fileEnv.S3_ENABLE_PATH_STYLE) {
return urlJoin(fileEnv.S3_PUBLIC_DOMAIN!, fileEnv.S3_BUCKET!, url);
}
Expand Down

0 comments on commit 7204296

Please sign in to comment.