Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pagination #1530

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions backend/src/routes/v2/model/getModelsSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ export const getModelsSearchSchema = z.object({
search: z.string().optional().default(''),
allowTemplating: strictCoerceBoolean(z.boolean().optional()),
schemaId: z.string().optional(),
currentPage: z.coerce.number().optional(),
pageSize: z.coerce.number().optional(),
}),
})

Expand All @@ -35,7 +37,7 @@ registerPath({
content: {
'application/json': {
schema: z.object({
models: z.array(
results: z.array(
z.object({
id: z.string().openapi({ example: 'yolo-abcdef' }),
name: z.string().openapi({ example: 'Yolo v4' }),
Expand All @@ -46,6 +48,7 @@ registerPath({
schemaId: z.string().optional(),
}),
),
totalEntries: z.number(),
}),
},
},
Expand All @@ -61,16 +64,17 @@ export interface ModelSearchResult {
kind: EntryKindKeys
}

interface GetModelsResponse {
export interface GetModelsResponse {
models: Array<ModelSearchResult>
totalEntries: number
}

export const getModelsSearch = [
bodyParser.json(),
async (req: Request, res: Response<GetModelsResponse>) => {
req.audit = AuditInfo.SearchModels
const {
query: { kind, libraries, filters, search, task, allowTemplating, schemaId },
query: { kind, libraries, filters, search, task, allowTemplating, schemaId, currentPage, pageSize },
} = parse(req, getModelsSearchSchema)

const foundModels = await searchModels(
Expand All @@ -82,8 +86,10 @@ export const getModelsSearch = [
task,
allowTemplating,
schemaId,
currentPage,
pageSize,
)
const models = foundModels.map((model) => ({
const models = foundModels.models.map((model) => ({
id: model.id,
name: model.name,
description: model.description,
Expand All @@ -93,6 +99,6 @@ export const getModelsSearch = [

await audit.onSearchModel(req, models)

return res.json({ models })
return res.json({ models, totalEntries: foundModels.totalEntries })
},
]
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@
"questionThree": {
"title": "Question three",
"description": "This is number question",
"type": "number"
"type": "string",
"format": "date"
},
"questionFour": {
"title": "Question four",
Expand Down
32 changes: 18 additions & 14 deletions backend/src/services/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ export type CreateModelParams = Pick<
ModelInterface,
'name' | 'description' | 'visibility' | 'settings' | 'kind' | 'collaborators'
>

type ModelSearchResult = {
models: Array<ModelInterface>
totalEntries: number
}

export async function createModel(user: UserInterface, modelParams: CreateModelParams) {
const modelId = convertStringToId(modelParams.name)

Expand Down Expand Up @@ -114,7 +120,9 @@ export async function searchModels(
task?: string,
allowTemplating?: boolean,
schemaId?: string,
): Promise<Array<ModelInterface>> {
currentPage?: number,
pageSize?: number,
): Promise<ModelSearchResult> {
const query: any = {}

if (kind) {
Expand Down Expand Up @@ -164,21 +172,17 @@ export async function searchModels(
}
}

let cursor = ModelModel
// Find only matching documents
.find(query)
const results = await ModelModel.find(query).sort(!search ? { updatedAt: -1 } : { score: { $meta: 'textScore' } })

if (!search) {
// Sort by last updated
cursor = cursor.sort({ updatedAt: -1 })
} else {
// Sort by text search
cursor = cursor.sort({ score: { $meta: 'textScore' } })
}

const results = await cursor
// As we need to authenticate which models the user has permission to, we need to fetch all models that
// match the query first, and then filter them based on pagination.
const auths = await authorisation.models(user, results, ModelAction.View)
return results.filter((_, i) => auths[i].success)
const authorisedResults = results.filter((_, i) => auths[i].success)
if (!pageSize || !currentPage) {
return { models: authorisedResults, totalEntries: authorisedResults.length }
}
const paginatedResults = authorisedResults.slice((currentPage - 1) * pageSize, currentPage * pageSize)
return { models: paginatedResults, totalEntries: authorisedResults.length }
}

export async function getModelCard(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ exports[`routes > model > getModelsSearch > 200 > ok 1`] = `
{
"description": "description",
"id": "test",
"kind": "model",
"name": "name",
"tags": [],
},
],
"totalEntries": 1,
}
`;

Expand All @@ -18,7 +20,7 @@ exports[`routes > model > getModelsSearch > audit > expected call 1`] = `
{
"description": "description",
"id": "test",
"kind": undefined,
"kind": "model",
"name": "name",
"tags": [],
},
Expand Down
22 changes: 19 additions & 3 deletions backend/test/routes/model/getModelsSearch.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,37 @@ import qs from 'qs'
import { describe, expect, test, vi } from 'vitest'

import audit from '../../../src/connectors/audit/__mocks__/index.js'
import { getModelsSearchSchema } from '../../../src/routes/v2/model/getModelsSearch.js'
import { EntryKind } from '../../../src/models/Model.js'
import {
GetModelsResponse,
getModelsSearchSchema,
ModelSearchResult,
} from '../../../src/routes/v2/model/getModelsSearch.js'
import { createFixture, testGet } from '../../testUtils/routes.js'

vi.mock('../../../src/utils/user.js')
vi.mock('../../../src/connectors/audit/index.js')

const mockedModelResult: ModelSearchResult = {
id: 'test',
name: 'name',
description: 'description',
tags: ['tag'],
kind: EntryKind.Model,
}
const mockedResults: GetModelsResponse = {
models: [mockedModelResult],
totalEntries: 1,
}

vi.mock('../../../src/services/model.js', () => ({
searchModels: vi.fn(() => [{ id: 'test', name: 'name', description: 'description', tags: ['tag'] }]),
searchModels: vi.fn(() => mockedResults),
}))

describe('routes > model > getModelsSearch', () => {
test('200 > ok', async () => {
const fixture = createFixture(getModelsSearchSchema)
const res = await testGet(`/api/v2/models/search?${qs.stringify(fixture)}`)

expect(res.statusCode).toBe(200)
expect(res.body).matchSnapshot()
})
Expand Down
20 changes: 14 additions & 6 deletions frontend/actions/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ import { ErrorInfo, fetcher } from '../utils/fetcher'

const emptyModelList = []

export interface EntrySearchResults {
models: EntrySearchResult[]
totalEntries: number
}

export interface EntrySearchResult {
id: string
name: string
Expand All @@ -26,6 +31,8 @@ export function useListModels(
search = '',
allowTemplating?: boolean,
schemaId?: string,
currentPage?: number,
pageSize?: number,
) {
const queryParams = {
...(kind && { kind }),
Expand All @@ -35,17 +42,18 @@ export function useListModels(
...(search && { search }),
...(allowTemplating && { allowTemplating }),
...(schemaId && { schemaId }),
...(currentPage && { currentPage }),
...(pageSize && { pageSize }),
}
const { data, isLoading, error, mutate } = useSWR<
{
models: EntrySearchResult[]
},
ErrorInfo
>(Object.entries(queryParams).length > 0 ? `/api/v2/models/search?${qs.stringify(queryParams)}` : null, fetcher)
const { data, isLoading, error, mutate } = useSWR<EntrySearchResults, ErrorInfo>(
Object.entries(queryParams).length > 0 ? `/api/v2/models/search?${qs.stringify(queryParams)}` : null,
fetcher,
)

return {
mutateModels: mutate,
models: data ? data.models : emptyModelList,
totalModels: data ? data.totalEntries : 0,
isModelsLoading: isLoading,
isModelsError: error,
}
Expand Down
Loading
Loading