Skip to content

Commit

Permalink
Merge pull request #946 from singnet/training
Browse files Browse the repository at this point in the history
Training
  • Loading branch information
MarinaFedy authored Nov 20, 2024
2 parents c61e5da + ad5de4a commit 53fe230
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 75 deletions.
3 changes: 3 additions & 0 deletions src/Redux/actionCreators/ServiceDetailsActions.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { fetchAuthenticatedUser } from "./UserActions";
import { loaderActions } from "./";
import { LoaderContent } from "../../utility/constants/LoaderContent";
import { isEmpty } from "lodash";
import { resetCurrentModelDetails, resetModelList } from "./ServiceTrainingActions";

export const UPDATE_SERVICE_DETAILS = "UPDATE_SERVICE_DETAILS";
export const RESET_SERVICE_DETAILS = "RESET_SERVICE_DETAILS";
Expand Down Expand Up @@ -37,6 +38,8 @@ export const fetchServiceDetails = (orgId, serviceId) => async (dispatch) => {
try {
dispatch(loaderActions.startAppLoader(LoaderContent.FETCH_SERVICE_DETAILS));
dispatch(resetServiceDetails);
dispatch(resetCurrentModelDetails());
dispatch(resetModelList());
const serviceDetails = await fetchServiceDetailsAPI(orgId, serviceId);
dispatch(fetchServiceDetailsSuccess(serviceDetails));
} catch (error) {
Expand Down
27 changes: 14 additions & 13 deletions src/Redux/actionCreators/ServiceTrainingActions.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import { getServiceClient } from "./SDKActions";

export const SET_MODEL_DETAILS = "SET_MODEL_DETAILS";
export const SET_MODELS_LIST = "SET_MODELS_LIST";
export const CLEAN_MODEL_DETAILS = "CLEAN_MODEL_DETAILS";
export const RESET_MODEL_DETAILS = "RESET_MODEL_DETAILS";
export const RESET_MODEL_LIST = "RESET_MODEL_LIST";

export const setCurrentModelDetails = (currentModelDetails) => (dispatch) => {
dispatch({ type: SET_MODEL_DETAILS, payload: currentModelDetails });
Expand All @@ -15,13 +16,15 @@ export const setModelsList = (modelsList) => (dispatch) => {
dispatch({ type: SET_MODELS_LIST, payload: modelsList });
};

export const cleanCurrentModelDetails = () => (dispatch) => {
dispatch({ type: CLEAN_MODEL_DETAILS });
export const resetCurrentModelDetails = () => (dispatch) => {
dispatch({ type: RESET_MODEL_DETAILS });
};

export const createModel = (organizationId, serviceId, address, newModelParams) => async (dispatch) => {
console.log("createModel", organizationId, serviceId, address, newModelParams);
export const resetModelList = () => (dispatch) => {
dispatch({ type: RESET_MODEL_LIST });
};

export const createModel = (organizationId, serviceId, address, newModelParams) => async (dispatch) => {
try {
dispatch(startAppLoader(LoaderContent.CREATE_TRAINING_MODEL));
const serviceName = getServiceNameFromTrainingMethod(newModelParams?.trainingMethod);
Expand All @@ -31,6 +34,7 @@ export const createModel = (organizationId, serviceId, address, newModelParams)
serviceName,
description: newModelParams?.trainingModelDescription,
publicAccess: !newModelParams?.isRestrictAccessModel,
dataLink: newModelParams.dataLink,
address: newModelParams?.isRestrictAccessModel ? newModelParams?.accessAddresses : [],
};

Expand Down Expand Up @@ -89,7 +93,7 @@ export const deleteModel =
const serviceClient = await dispatch(getServiceClient(organizationId, serviceId));
await serviceClient.deleteModel(params);
await dispatch(getTrainingModels(organizationId, serviceId, address));
dispatch(cleanCurrentModelDetails());
dispatch(resetCurrentModelDetails());
} catch (error) {
// TODO
} finally {
Expand All @@ -113,8 +117,6 @@ const getServiceNameFromTrainingMethod = (trainingMethod) => {
export const getTrainingModelStatus =
({ organizationId, serviceId, modelId, name, method, address }) =>
async (dispatch) => {
console.log("getTrainingModels: ", organizationId, serviceId, modelId, method, name, address);

try {
dispatch(startAppLoader(LoaderContent.FETCH_TRAINING_EXISTING_MODEL));
const serviceClient = await dispatch(getServiceClient(organizationId, serviceId));
Expand All @@ -124,9 +126,8 @@ export const getTrainingModelStatus =
name,
address,
};
const existingModelStatus = await serviceClient.getModelStatus(params);
console.log("existingModelStatus: ", existingModelStatus);
return existingModelStatus;
const numberModelStatus = await serviceClient.getModelStatus(params);
return modelStatus[numberModelStatus];
} catch (err) {
// TODO
} finally {
Expand Down Expand Up @@ -159,8 +160,8 @@ export const getTrainingModels = (organizationId, serviceId, address) => async (
address,
};

const numberModelStatus = await dispatch(getTrainingModelStatus(getModelStatusParams));
return { ...model, status: modelStatus[numberModelStatus] };
const modelStatus = await dispatch(getTrainingModelStatus(getModelStatusParams));
return { ...model, status: modelStatus };
})
);

Expand Down
15 changes: 13 additions & 2 deletions src/Redux/reducers/ServiceTrainingReducer.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import { serviceTrainingActions } from "../actionCreators";

export const modelStatus = {
IN_PROGRESS: "IN_PROGRESS",
COMPLETED: "COMPLETED",
CREATED: "CREATED",
ERRORED: "ERRORED",
DELETED: "DELETED",
};

const currentModellInitialState = {
modelId: "",
methodName: "",
Expand All @@ -15,20 +23,23 @@ const currentModellInitialState = {

const trainingModelInitialState = {
currentModel: currentModellInitialState,
modelsList: [],
modelsList: undefined,
};

const serviceTrainingReducer = (state = trainingModelInitialState, action) => {
switch (action.type) {
case serviceTrainingActions.SET_MODEL_DETAILS: {
return { ...state, currentModel: action.payload };
}
case serviceTrainingActions.CLEAN_MODEL_DETAILS: {
case serviceTrainingActions.RESET_MODEL_DETAILS: {
return { ...state, currentModel: currentModellInitialState };
}
case serviceTrainingActions.SET_MODELS_LIST: {
return { ...state, modelsList: action.payload };
}
case serviceTrainingActions.RESET_MODEL_LIST: {
return { ...state, modelsList: trainingModelInitialState.modelsList };
}
default: {
return state;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,20 @@ import AlertBox from "../../../../../common/AlertBox";
import StyledButton from "../../../../../common/StyledButton";
import StyledLinearProgress from "../../../../../common/StyledLinearProgress";
import { useStyles } from "./styles";
import { getIsTrainingAvailable } from "../../../../../../Redux/actionCreators/ServiceDetailsActions";
import { useDispatch, useSelector } from "react-redux";
import { getTrainingModels } from "../../../../../../Redux/actionCreators/ServiceTrainingActions";
import { currentServiceDetails } from "../../../../../../Redux/reducers/ServiceDetailsReducer";
import { isUndefined } from "lodash";
import { updateMetamaskWallet } from "../../../../../../Redux/actionCreators/UserActions";

const ActiveSession = ({ classes, freeCallsRemaining, handleComplete, freeCallsAllowed, isServiceAvailable }) => {
const dispatch = useDispatch();
const { detailsTraining } = useSelector((state) => state.serviceDetailsReducer);
const { org_id, service_id } = useSelector((state) => currentServiceDetails(state));
const { modelsList } = useSelector((state) => state.serviceTrainingReducer);
const isLoggedIn = useSelector((state) => state.userReducer.login.isLoggedIn);

const [showTooltip, setShowTooltip] = useState(false);

const progressValue = () => (freeCallsRemaining / freeCallsAllowed) * 100;
Expand All @@ -22,6 +34,13 @@ const ActiveSession = ({ classes, freeCallsRemaining, handleComplete, freeCallsA
setShowTooltip(false);
};

const isTrainingAvailable = getIsTrainingAvailable(detailsTraining, isLoggedIn);

const handleRequestModels = async () => {
const address = await dispatch(updateMetamaskWallet());
await dispatch(getTrainingModels(org_id, service_id, address));
};

return (
<div className={classes.activeSessionContainer}>
<AlertBox
Expand All @@ -42,8 +61,16 @@ const ActiveSession = ({ classes, freeCallsRemaining, handleComplete, freeCallsA
onClose={handleTooltipClose}
classes={{ tooltip: classes.tooltip }}
>
<div>
<div className={classes.activeSectionButtons}>
<StyledButton type="blue" btnText="run for free" onClick={handleComplete} disabled={!isServiceAvailable} />
{isTrainingAvailable && isUndefined(modelsList) && (
<StyledButton
type="transparent"
btnText="request my models"
onClick={handleRequestModels}
disabled={!isServiceAvailable}
/>
)}
</div>
</Tooltip>
</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,9 @@ export const useStyles = (theme) => ({
tooltip: {
fontSize: 14,
},
activeSectionButtons: {
display: "flex",
gap: 20,
justifyContent: "center",
},
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { createServiceClient, callTypes } from "../../../../utility/sdk";
import ThirdPartyServiceErrorBoundary from "./ThirdPartyServiceErrorBoundary";
import { channelInfo } from "../../../../Redux/reducers/UserReducer";
import { isEmpty } from "lodash";
import { modelStatus } from "../../../../Redux/reducers/ServiceTrainingReducer";

class ThirdPartyAIService extends Component {
state = {
Expand Down Expand Up @@ -60,12 +61,14 @@ class ThirdPartyAIService extends Component {
if (isEmpty(modelsList)) {
return [];
}
return modelsList.map((model) => {
return {
value: model.modelId,
label: model.modelName,
};
});
return modelsList
.filter((model) => model.status === modelStatus.COMPLETED)
.map((model) => {
return {
value: model.modelId,
label: model.modelName,
};
});
}

render() {
Expand Down
51 changes: 35 additions & 16 deletions src/components/ServiceDetails/ExistingModel/ModelDetails/index.js
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
import React, { useState } from "react";
import { useDispatch } from "react-redux";
import { useDispatch, useSelector } from "react-redux";

import { withStyles } from "@mui/styles";
import { useStyles } from "./styles";

import Button from "@mui/material/Button";
// import EditIcon from "@mui/icons-material/Edit";
import DeleteIcon from "@mui/icons-material/Delete";
import NearMeOutlinedIcon from "@mui/icons-material/NearMeOutlined";
import Box from "@mui/material/Box";
import Typography from "@mui/material/Typography";
import Modal from "@mui/material/Modal";
import StyledButton from "../../../common/StyledButton";
import { setCurrentModelDetails, deleteModel } from "../../../../Redux/actionCreators/ServiceTrainingActions";
import {
setCurrentModelDetails,
deleteModel,
getTrainingModelStatus,
setModelsList,
} from "../../../../Redux/actionCreators/ServiceTrainingActions";
import { useLocation, useNavigate, useParams } from "react-router-dom";

export const modelStatus = {
IN_PROGRESS: "IN_PROGRESS",
COMPLETED: "COMPLETED",
CREATED: "CREATED",
ERRORED: "ERRORED",
DELETED: "DELETED",
};
import { modelStatus } from "../../../../Redux/reducers/ServiceTrainingReducer";

const ModelDetails = ({ classes, openEditModel, model, address }) => {
const dispatch = useDispatch();
const navigate = useNavigate();
const location = useLocation();

const { modelsList } = useSelector((state) => state.serviceTrainingReducer);
const { orgId, serviceId } = useParams();

const [open, setOpen] = useState(false);
Expand All @@ -51,6 +49,27 @@ const ModelDetails = ({ classes, openEditModel, model, address }) => {
navigate(location.pathname.split("tab/")[0] + "tab/" + 0); //TODO
};

const handleGetModelStatus = async () => {
const getModelStatusParams = {
organizationId: orgId,
serviceId,
modelId: model.modelId,
name: model.serviceName,
method: model.methodName,
address,
};

const newModelStatus = await dispatch(getTrainingModelStatus(getModelStatusParams));
const updatedModelList = modelsList.map((modelBeforeUpdating) => {
let modelForUpdating = modelBeforeUpdating;
if (modelForUpdating.modelId === model.modelId) {
modelForUpdating.status = newModelStatus;
}
return modelForUpdating;
});
await dispatch(setModelsList(updatedModelList));
};

return (
<>
<div className={classes.modelDetailsContainer}>
Expand Down Expand Up @@ -80,11 +99,11 @@ const ModelDetails = ({ classes, openEditModel, model, address }) => {
</div>
</div>
<div className={classes.actionButtons}>
<Button className={classes.inferenceBtn} onClick={handleSetModel} disabled={!isInferenceAvailable}>
<NearMeOutlinedIcon />
<span>Inference</span>
</Button>
<div>
<div className={classes.actionButtonsGroup}>
<StyledButton btnText="Inference" disabled={!isInferenceAvailable} onClick={handleSetModel} />
<StyledButton type="transparentBlueBorder" btnText="Get status" onClick={handleGetModelStatus} />
</div>
<div className={classes.actionButtonsGroup}>
{/* <Button className={classes.updateBtn} onClick={handleEditModel}>
<EditIcon />
<span>Edit</span>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { modelStatus } from "./index";
import { modelStatus } from "../../../../Redux/reducers/ServiceTrainingReducer";

export const useStyles = (theme) => ({
modelDetailsContainer: {
Expand Down Expand Up @@ -86,6 +86,10 @@ export const useStyles = (theme) => ({
textTransform: "capitalize",
},
},
actionButtonsGroup: {
display: "flex",
gap: 20,
},
updateBtn: { color: theme.palette.text.darkShadedGray },
inferenceBtn: {
background: theme.palette.text.primary,
Expand Down
17 changes: 7 additions & 10 deletions src/components/ServiceDetails/ExistingModel/index.js
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
import React, { Fragment, useEffect, useState } from "react";
import { useDispatch, useSelector } from "react-redux";
import { withStyles } from "@mui/styles";
import ModelDetails from "./ModelDetails";
import { useStyles } from "./styles";
import ConnectMetamask from "../ConnectMetamask";
import ModelDetails from "./ModelDetails";
import { loaderActions, userActions } from "../../../Redux/actionCreators";
import { LoaderContent } from "../../../utility/constants/LoaderContent";
import { currentServiceDetails } from "../../../Redux/reducers/ServiceDetailsReducer";
import Typography from "@mui/material/Typography";
import AlertBox, { alertTypes } from "../../common/AlertBox";
import Card from "../../common/Card";
import { getTrainingModels } from "../../../Redux/actionCreators/ServiceTrainingActions";
import { isEmpty } from "lodash";
import { isUndefined } from "lodash";
import StyledButton from "../../common/StyledButton";

const ExistingModel = ({ classes, openEditModel }) => {
const { org_id, service_id } = useSelector((state) => currentServiceDetails(state));
const { modelsList } = useSelector((state) => state.serviceTrainingReducer);
const { address } = useSelector((state) => state.userReducer.wallet);

const [existingModels, setExistingModels] = useState(modelsList);
const [metamaskConnected, setMetamaskConnected] = useState(!isEmpty(address));
const [alert, setAlert] = useState({});
const dispatch = useDispatch();

Expand All @@ -33,18 +31,17 @@ const ExistingModel = ({ classes, openEditModel }) => {
dispatch(loaderActions.startAppLoader(LoaderContent.CONNECT_METAMASK));
const address = await dispatch(userActions.updateMetamaskWallet());
await dispatch(getTrainingModels(org_id, service_id, address));
setMetamaskConnected(true);
} catch (error) {
setAlert({ type: alertTypes.ERROR, message: "Unable to fetch existing models. Please try again" });
dispatch(loaderActions.stopAppLoader());
}
};

const ModelList = () => {
if (!existingModels.length) {
if (!existingModels?.length) {
return (
<div className={classes.noDataFoundTxt}>
<Typography>No data found</Typography>
<h2>No data found</h2>
</div>
);
}
Expand All @@ -62,11 +59,11 @@ const ExistingModel = ({ classes, openEditModel }) => {
<Card
header="Existing Model"
children={
metamaskConnected ? (
!isUndefined(modelsList) ? (
<ModelList />
) : (
<Fragment>
<ConnectMetamask handleConnectMM={handleConnectMM} />
<StyledButton btnText="Get models" onClick={handleConnectMM} />
<AlertBox type={alert.type} message={alert.message} />
</Fragment>
)
Expand Down
Loading

0 comments on commit 53fe230

Please sign in to comment.