Skip to content

Commit

Permalink
caching updates
Browse files Browse the repository at this point in the history
  • Loading branch information
fulpm committed Sep 26, 2024
1 parent 2bfd64e commit 8dbc9b0
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
2 changes: 1 addition & 1 deletion codegen/templates/client/trainee.njk
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* NOTE: This file is auto generated, do not modify manually.
*/
import type { Trainee, Session } from "@/types";
import type { Session, Trainee } from "@/types";
import type * as schemas from "@/types/schemas";
import { AbstractHowsoClient } from "./base";

Expand Down
2 changes: 1 addition & 1 deletion src/client/trainee.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* NOTE: This file is auto generated, do not modify manually.
*/
import type { Trainee, Session } from "@/types";
import type { Session, Trainee } from "@/types";
import type * as schemas from "@/types/schemas";
import { AbstractHowsoClient } from "./base";

Expand Down
39 changes: 28 additions & 11 deletions src/client/worker/client.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { Session, Trainee } from "@/types";
import type { FeatureAttributesIndex, Session, Trainee } from "@/types";
import type * as schemas from "@/types/schemas";
import {
AmalgamError,
Expand Down Expand Up @@ -158,8 +158,8 @@ export class HowsoWorkerClient extends TraineeClient {
*/
protected async getTraineeFromEngine(traineeId: string): Promise<Trainee> {
const [metadata, features] = await Promise.all([
this.execute<Record<string, any>>(traineeId, "get_metadata", {}),
this.execute<Record<string, schemas.FeatureAttributes>>(traineeId, "get_feature_attributes", {}),
this.execute<any>(traineeId, "get_metadata", {}),
this.execute<FeatureAttributesIndex>(traineeId, "get_feature_attributes", {}),
]);
if (!metadata?.payload) {
throw new HowsoError(`Trainee "${traineeId}" not found.`, "not_found");
Expand Down Expand Up @@ -345,7 +345,7 @@ export class HowsoWorkerClient extends TraineeClient {
await this.execute(traineeId, "set_metadata", { metadata });

// Set the feature attributes
const { payload: feature_attributes } = await this.execute<Record<string, schemas.FeatureAttributes>>(
const { payload: feature_attributes } = await this.execute<FeatureAttributesIndex>(
traineeId,
"set_feature_attributes",
{
Expand Down Expand Up @@ -456,11 +456,6 @@ export class HowsoWorkerClient extends TraineeClient {
}, []);
}

/**
* Set the Trainee's feature attributes.
* @param traineeId The Trainee identifier.
* @param request The operation parameters.
*/
public async setFeatureAttributes(traineeId: string, request: schemas.SetFeatureAttributesRequest) {
const response = await super.setFeatureAttributes(traineeId, request);
// Also update cached Trainee
Expand All @@ -471,8 +466,30 @@ export class HowsoWorkerClient extends TraineeClient {
return response;
}

public async addFeature(traineeId: string, request: schemas.AddFeatureRequest) {
const response = await super.addFeature(traineeId, request);
// Also update cached Trainee
const trainee = this.cache.get(traineeId)?.trainee;
if (trainee) {
const { payload: features } = await this.getFeatureAttributes(traineeId);
trainee.features = features;
}
return response;
}

public async removeFeature(traineeId: string, request: schemas.RemoveFeatureRequest) {
const response = await super.removeFeature(traineeId, request);
// Also update cached Trainee
const trainee = this.cache.get(traineeId)?.trainee;
if (trainee) {
const { payload: features } = await this.getFeatureAttributes(traineeId);
trainee.features = features;
}
return response;
}

/**
* Batch train data into the Trainee.
* Train data into the Trainee using batched requests to the Engine.
* @param traineeId The Trainee identifier.
* @param request The train parameters.
*/
Expand All @@ -492,7 +509,7 @@ export class HowsoWorkerClient extends TraineeClient {
async function* (this: HowsoWorkerClient, size: number) {
let offset = 0;
while (offset < cases.length) {
await this.execute<any | null>(trainee.id, "train", {
await this.train(trainee.id, {
...rest,
cases: cases.slice(offset, offset + size),
});
Expand Down

0 comments on commit 8dbc9b0

Please sign in to comment.