diff --git a/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java b/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java index 15b83108c8..f4aefe6c79 100644 --- a/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java +++ b/server/controller/src/test/java/ai/starwhale/mlops/api/DataStoreControllerTest.java @@ -22,16 +22,19 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import ai.starwhale.mlops.api.protocol.datastore.ColumnDesc; +import ai.starwhale.mlops.api.protocol.datastore.CreateCheckpointRequest; import ai.starwhale.mlops.api.protocol.datastore.ListTablesRequest; import ai.starwhale.mlops.api.protocol.datastore.QueryTableRequest; import ai.starwhale.mlops.api.protocol.datastore.RecordDesc; @@ -41,6 +44,7 @@ import ai.starwhale.mlops.api.protocol.datastore.TableQueryFilterDesc; import ai.starwhale.mlops.api.protocol.datastore.TableQueryOperandDesc; import ai.starwhale.mlops.api.protocol.datastore.UpdateTableRequest; +import ai.starwhale.mlops.datastore.Checkpoint; import ai.starwhale.mlops.datastore.ColumnHintsDesc; import ai.starwhale.mlops.datastore.ColumnSchemaDesc; import ai.starwhale.mlops.datastore.DataStore; @@ -1698,4 +1702,34 @@ public void testExportTable(int limit, int expectedCall) throws IOException { verify(outputStream, times(expectedCall)).write("hello".getBytes()); } + @Test + public void testCheckpoint() { + DataStoreController controller = new DataStoreController(); + DataStore dataStore = mock(DataStore.class); + controller.setDataStore(dataStore); + var createReq = CreateCheckpointRequest.builder().table("t1").userData("foo").build(); + when(dataStore.createCheckpoint(createReq)).thenReturn( + Checkpoint.builder().userData("foo").revision(7L).build()); + var createResp = controller.createCheckpoint(createReq); + + var checkpoint = createResp.getBody().getData(); + assertEquals("foo", checkpoint.getUserData()); + assertEquals("7", checkpoint.getId()); + + when(dataStore.getCheckpoints("t1")).thenReturn(List.of( + Checkpoint.builder().userData("foo").revision(7L).build(), + Checkpoint.builder().userData("bar").revision(8L).build())); + + var listResp = controller.getCheckpoints("t1"); + var checkpoints = listResp.getBody().getData(); + assertEquals(2, checkpoints.size()); + assertEquals("foo", checkpoints.get(0).getUserData()); + assertEquals("7", checkpoints.get(0).getId()); + assertEquals("bar", checkpoints.get(1).getUserData()); + assertEquals("8", checkpoints.get(1).getId()); + + doNothing().when(dataStore).deleteCheckpoint("t1", 7L); + controller.deleteCheckpoint("t1", "7"); + verify(dataStore).deleteCheckpoint("t1", 7L); + } }