diff --git a/backend/src/routes/v2/model/getModelsSearch.ts b/backend/src/routes/v2/model/getModelsSearch.ts index b961ff52c..e5cb98ca5 100644 --- a/backend/src/routes/v2/model/getModelsSearch.ts +++ b/backend/src/routes/v2/model/getModelsSearch.ts @@ -20,6 +20,8 @@ export const getModelsSearchSchema = z.object({ search: z.string().optional().default(''), allowTemplating: strictCoerceBoolean(z.boolean().optional()), schemaId: z.string().optional(), + currentPage: z.coerce.number().optional(), + pageSize: z.coerce.number().optional(), }), }) @@ -35,7 +37,7 @@ registerPath({ content: { 'application/json': { schema: z.object({ - models: z.array( + results: z.array( z.object({ id: z.string().openapi({ example: 'yolo-abcdef' }), name: z.string().openapi({ example: 'Yolo v4' }), @@ -46,6 +48,7 @@ registerPath({ schemaId: z.string().optional(), }), ), + totalEntries: z.number(), }), }, }, @@ -61,8 +64,9 @@ export interface ModelSearchResult { kind: EntryKindKeys } -interface GetModelsResponse { +export interface GetModelsResponse { models: Array + totalEntries: number } export const getModelsSearch = [ @@ -70,7 +74,7 @@ export const getModelsSearch = [ async (req: Request, res: Response) => { req.audit = AuditInfo.SearchModels const { - query: { kind, libraries, filters, search, task, allowTemplating, schemaId }, + query: { kind, libraries, filters, search, task, allowTemplating, schemaId, currentPage, pageSize }, } = parse(req, getModelsSearchSchema) const foundModels = await searchModels( @@ -82,8 +86,10 @@ export const getModelsSearch = [ task, allowTemplating, schemaId, + currentPage, + pageSize, ) - const models = foundModels.map((model) => ({ + const models = foundModels.models.map((model) => ({ id: model.id, name: model.name, description: model.description, @@ -93,6 +99,6 @@ export const getModelsSearch = [ await audit.onSearchModel(req, models) - return res.json({ models }) + return res.json({ models, totalEntries: foundModels.totalEntries }) }, ] diff --git a/backend/src/scripts/example_schemas/minimal_model_schema.json b/backend/src/scripts/example_schemas/minimal_model_schema.json index 9a72e1ae2..81ee5941d 100644 --- a/backend/src/scripts/example_schemas/minimal_model_schema.json +++ b/backend/src/scripts/example_schemas/minimal_model_schema.json @@ -84,7 +84,8 @@ "questionThree": { "title": "Question three", "description": "This is number question", - "type": "number" + "type": "string", + "format": "date" }, "questionFour": { "title": "Question four", diff --git a/backend/src/services/model.ts b/backend/src/services/model.ts index 68a0f92c5..c4d3fd7a1 100644 --- a/backend/src/services/model.ts +++ b/backend/src/services/model.ts @@ -27,6 +27,12 @@ export type CreateModelParams = Pick< ModelInterface, 'name' | 'description' | 'visibility' | 'settings' | 'kind' | 'collaborators' > + +type ModelSearchResult = { + models: Array + totalEntries: number +} + export async function createModel(user: UserInterface, modelParams: CreateModelParams) { const modelId = convertStringToId(modelParams.name) @@ -114,7 +120,9 @@ export async function searchModels( task?: string, allowTemplating?: boolean, schemaId?: string, -): Promise> { + currentPage?: number, + pageSize?: number, +): Promise { const query: any = {} if (kind) { @@ -164,21 +172,17 @@ export async function searchModels( } } - let cursor = ModelModel - // Find only matching documents - .find(query) + const results = await ModelModel.find(query).sort(!search ? { updatedAt: -1 } : { score: { $meta: 'textScore' } }) - if (!search) { - // Sort by last updated - cursor = cursor.sort({ updatedAt: -1 }) - } else { - // Sort by text search - cursor = cursor.sort({ score: { $meta: 'textScore' } }) - } - - const results = await cursor + // As we need to authenticate which models the user has permission to, we need to fetch all models that + // match the query first, and then filter them based on pagination. const auths = await authorisation.models(user, results, ModelAction.View) - return results.filter((_, i) => auths[i].success) + const authorisedResults = results.filter((_, i) => auths[i].success) + if (!pageSize || !currentPage) { + return { models: authorisedResults, totalEntries: authorisedResults.length } + } + const paginatedResults = authorisedResults.slice((currentPage - 1) * pageSize, currentPage * pageSize) + return { models: paginatedResults, totalEntries: authorisedResults.length } } export async function getModelCard( diff --git a/backend/test/routes/model/__snapshots__/getModelsSearch.spec.ts.snap b/backend/test/routes/model/__snapshots__/getModelsSearch.spec.ts.snap index 52d6454d8..1b578deab 100644 --- a/backend/test/routes/model/__snapshots__/getModelsSearch.spec.ts.snap +++ b/backend/test/routes/model/__snapshots__/getModelsSearch.spec.ts.snap @@ -6,10 +6,12 @@ exports[`routes > model > getModelsSearch > 200 > ok 1`] = ` { "description": "description", "id": "test", + "kind": "model", "name": "name", "tags": [], }, ], + "totalEntries": 1, } `; @@ -18,7 +20,7 @@ exports[`routes > model > getModelsSearch > audit > expected call 1`] = ` { "description": "description", "id": "test", - "kind": undefined, + "kind": "model", "name": "name", "tags": [], }, diff --git a/backend/test/routes/model/getModelsSearch.spec.ts b/backend/test/routes/model/getModelsSearch.spec.ts index b770903a8..035c7184f 100644 --- a/backend/test/routes/model/getModelsSearch.spec.ts +++ b/backend/test/routes/model/getModelsSearch.spec.ts @@ -2,21 +2,37 @@ import qs from 'qs' import { describe, expect, test, vi } from 'vitest' import audit from '../../../src/connectors/audit/__mocks__/index.js' -import { getModelsSearchSchema } from '../../../src/routes/v2/model/getModelsSearch.js' +import { EntryKind } from '../../../src/models/Model.js' +import { + GetModelsResponse, + getModelsSearchSchema, + ModelSearchResult, +} from '../../../src/routes/v2/model/getModelsSearch.js' import { createFixture, testGet } from '../../testUtils/routes.js' vi.mock('../../../src/utils/user.js') vi.mock('../../../src/connectors/audit/index.js') +const mockedModelResult: ModelSearchResult = { + id: 'test', + name: 'name', + description: 'description', + tags: ['tag'], + kind: EntryKind.Model, +} +const mockedResults: GetModelsResponse = { + models: [mockedModelResult], + totalEntries: 1, +} + vi.mock('../../../src/services/model.js', () => ({ - searchModels: vi.fn(() => [{ id: 'test', name: 'name', description: 'description', tags: ['tag'] }]), + searchModels: vi.fn(() => mockedResults), })) describe('routes > model > getModelsSearch', () => { test('200 > ok', async () => { const fixture = createFixture(getModelsSearchSchema) const res = await testGet(`/api/v2/models/search?${qs.stringify(fixture)}`) - expect(res.statusCode).toBe(200) expect(res.body).matchSnapshot() }) diff --git a/frontend/actions/model.ts b/frontend/actions/model.ts index 4a037cbd9..0258d77fc 100644 --- a/frontend/actions/model.ts +++ b/frontend/actions/model.ts @@ -6,6 +6,11 @@ import { ErrorInfo, fetcher } from '../utils/fetcher' const emptyModelList = [] +export interface EntrySearchResults { + models: EntrySearchResult[] + totalEntries: number +} + export interface EntrySearchResult { id: string name: string @@ -26,6 +31,8 @@ export function useListModels( search = '', allowTemplating?: boolean, schemaId?: string, + currentPage?: number, + pageSize?: number, ) { const queryParams = { ...(kind && { kind }), @@ -35,17 +42,18 @@ export function useListModels( ...(search && { search }), ...(allowTemplating && { allowTemplating }), ...(schemaId && { schemaId }), + ...(currentPage && { currentPage }), + ...(pageSize && { pageSize }), } - const { data, isLoading, error, mutate } = useSWR< - { - models: EntrySearchResult[] - }, - ErrorInfo - >(Object.entries(queryParams).length > 0 ? `/api/v2/models/search?${qs.stringify(queryParams)}` : null, fetcher) + const { data, isLoading, error, mutate } = useSWR( + Object.entries(queryParams).length > 0 ? `/api/v2/models/search?${qs.stringify(queryParams)}` : null, + fetcher, + ) return { mutateModels: mutate, models: data ? data.models : emptyModelList, + totalModels: data ? data.totalEntries : 0, isModelsLoading: isLoading, isModelsError: error, } diff --git a/frontend/pages/index.tsx b/frontend/pages/index.tsx index 22fc42e8a..f20669d2e 100644 --- a/frontend/pages/index.tsx +++ b/frontend/pages/index.tsx @@ -20,6 +20,7 @@ import { useRouter } from 'next/router' import React, { ChangeEvent, Fragment, useCallback, useEffect, useState } from 'react' import ChipSelector from 'src/common/ChipSelector' import Loading from 'src/common/Loading' +import PaginationSelector from 'src/common/PaginationSelector' import Title from 'src/common/Title' import useDebounce from 'src/hooks/useDebounce' import EntryList from 'src/marketplace/EntryList' @@ -39,30 +40,55 @@ export default function Marketplace() { const [selectedTask, setSelectedTask] = useState('') const [selectedTypes, setSelectedTypes] = useState([]) const [selectedTab, setSelectedTab] = useState(EntryKind.MODEL) + const [currentPage, setCurrentPage] = useState(1) + const [pageSize, setPageSize] = useState(10) const debouncedFilter = useDebounce(filter, 250) - const { models, isModelsError, isModelsLoading } = useListModels( + const { models, totalModels, isModelsError, isModelsLoading } = useListModels( EntryKind.MODEL, selectedTypes, selectedTask, selectedLibraries, debouncedFilter, + undefined, + undefined, + currentPage, + pageSize, ) const { models: dataCards, + totalModels: totalDataCards, isModelsError: isDataCardsError, isModelsLoading: isDataCardsLoading, - } = useListModels(EntryKind.DATA_CARD, selectedTypes, selectedTask, selectedLibraries, debouncedFilter) + } = useListModels( + EntryKind.DATA_CARD, + selectedTypes, + selectedTask, + selectedLibraries, + debouncedFilter, + undefined, + undefined, + currentPage, + pageSize, + ) const theme = useTheme() const router = useRouter() - const { filter: filterFromQuery, task: taskFromQuery, libraries: librariesFromQuery } = router.query + const { + filter: filterFromQuery, + task: taskFromQuery, + libraries: librariesFromQuery, + currentPage: currentPageFromQuery, + pageSize: pageSizeFromQuery, + } = router.query useEffect(() => { - if (filterFromQuery) setFilter(filterFromQuery as string) - if (taskFromQuery) setSelectedTask(taskFromQuery as string) + if (filterFromQuery && typeof filterFromQuery === 'string') setFilter(filterFromQuery) + if (taskFromQuery && typeof taskFromQuery === 'string') setSelectedTask(taskFromQuery) + if (currentPageFromQuery && typeof currentPageFromQuery === 'string') setCurrentPage(parseInt(currentPageFromQuery)) + if (pageSizeFromQuery && typeof pageSizeFromQuery === 'string') setPageSize(parseInt(pageSizeFromQuery)) if (librariesFromQuery) { let librariesAsArray: string[] = [] if (typeof librariesFromQuery === 'string') { @@ -72,7 +98,7 @@ export default function Marketplace() { } setSelectedLibraries([...librariesAsArray]) } - }, [filterFromQuery, taskFromQuery, librariesFromQuery]) + }, [filterFromQuery, taskFromQuery, librariesFromQuery, currentPageFromQuery, pageSizeFromQuery]) const handleSelectedTypesOnChange = useCallback((selected: string[]) => { if (selected.length > 0) { @@ -98,6 +124,24 @@ export default function Marketplace() { [router], ) + const handleCurrentPageChange = useCallback( + (newPage: number) => { + if (newPage > 0) { + setCurrentPage(newPage) + updateQueryParams('currentPage', newPage.toString()) + } + }, + [updateQueryParams], + ) + + const handlePageSizeChange = useCallback( + (newValue: number) => { + setPageSize(newValue) + updateQueryParams('pageSize', newValue.toString()) + }, + [updateQueryParams], + ) + const handleFilterChange = useCallback( (e: ChangeEvent) => { setFilter(e.target.value) @@ -221,38 +265,47 @@ export default function Marketplace() { setSelectedTab(EntryKind.MODEL)} /> setSelectedTab(EntryKind.DATA_CARD)} /> - {isModelsLoading && } - {!isModelsLoading && selectedTab === EntryKind.MODEL && ( -
- -
- )} - {!isDataCardsLoading && selectedTab === EntryKind.DATA_CARD && ( -
- -
- )} + + {(isModelsLoading || isDataCardsLoading) && } + {!isModelsLoading && selectedTab === EntryKind.MODEL && ( +
+ +
+ )} + {!isDataCardsLoading && selectedTab === EntryKind.DATA_CARD && ( +
+ +
+ )} + handleCurrentPageChange(newValue)} + totalEntries={totalModels} + pageSize={pageSize} + onPageSizeChange={(newValue) => handlePageSizeChange(newValue)} + /> +
diff --git a/frontend/src/common/PaginationSelector.tsx b/frontend/src/common/PaginationSelector.tsx new file mode 100644 index 000000000..c43028ba1 --- /dev/null +++ b/frontend/src/common/PaginationSelector.tsx @@ -0,0 +1,121 @@ +import { ArrowBack, ArrowForward } from '@mui/icons-material' +import { + FormControl, + IconButton, + InputLabel, + MenuItem, + Select, + SelectChangeEvent, + Stack, + Typography, +} from '@mui/material' +import { useCallback, useEffect, useMemo } from 'react' + +interface PaginationSelectorProps { + currentPage: number + onCurrentPageChange: (newValue: number) => void + totalEntries: number + pageSize: number + onPageSizeChange: (newValue: number) => void +} + +export default function PaginationSelector({ + currentPage, + onCurrentPageChange, + totalEntries, + pageSize, + onPageSizeChange, +}: PaginationSelectorProps) { + const lastPage = useMemo(() => { + return Math.ceil(totalEntries / parseInt(pageSize.toString())) + }, [pageSize, totalEntries]) + + useEffect(() => { + if (currentPage > lastPage) { + onCurrentPageChange(lastPage) + } + }, [currentPage, lastPage, onCurrentPageChange]) + + const handleBackPage = useCallback(() => { + if (currentPage > 1) { + onCurrentPageChange(currentPage - 1) + } + }, [currentPage, onCurrentPageChange]) + + const handleForwardPage = useCallback(() => { + if (currentPage < lastPage) { + onCurrentPageChange(currentPage + 1) + } + }, [currentPage, lastPage, onCurrentPageChange]) + + const handleManualPageChange = useCallback( + (event: SelectChangeEvent) => { + onCurrentPageChange(parseInt(event.target.value)) + }, + [onCurrentPageChange], + ) + + const handlePageSizeChange = useCallback( + (event: SelectChangeEvent) => { + onPageSizeChange(parseInt(event.target.value)) + }, + [onPageSizeChange], + ) + + const pageNumberOptions = useMemo(() => { + return [...Array(lastPage | 0)].map((_, index) => { + return ( + + {index + 1} + + ) + }) + }, [lastPage]) + + return ( + + + + + Page: + + of {Math.round(lastPage)} + + Page size + + + + + + + ) +}