Skip to content

Commit

Permalink
Merge pull request #364 from mitmedialab/dev
Browse files Browse the repository at this point in the history
Update main
  • Loading branch information
pmalacho-mit authored Jun 20, 2024
2 parents 02fffc7 + 0023dec commit f83d392
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions extensions/src/teachableMachine/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export default class teachableMachine extends extension({
iconURL: "teachable-machine-blocks.png",
insetIconURL: "teachable-machine-blocks-small.svg",
tags: ["Dancing with AI", "Made by PRG"]
}) {
}, "indicators") {
lastUpdate: number;
maxConfidence: number;
modelConfidences: {};
Expand Down Expand Up @@ -166,18 +166,24 @@ export default class teachableMachine extends extension({
}

async startPredicting(modelDataUrl) {
if (!this.predictionState[modelDataUrl]) {
try {
this.predictionState[modelDataUrl] = {};
// https://github.com/googlecreativelab/teachablemachine-community/tree/master/libraries/image
const { model, type } = await this.initModel(modelDataUrl);
this.predictionState[modelDataUrl].modelType = type;
this.predictionState[modelDataUrl].model = model;
this.runtime.requestToolboxExtensionsUpdate();
} catch (e) {
this.predictionState[modelDataUrl] = {};
console.log("Model initialization failure!", e);
}
const alreadyLoaded = Boolean(this.predictionState[modelDataUrl]);
try {
const indicator = await this.indicate({
type: "warning",
msg: alreadyLoaded ? "Updating model" : "Loading model"
});
this.predictionState[modelDataUrl] = {};
// https://github.com/googlecreativelab/teachablemachine-community/tree/master/libraries/image
const { model, type } = await this.initModel(modelDataUrl);
this.predictionState[modelDataUrl].modelType = type;
this.predictionState[modelDataUrl].model = model;
this.runtime.requestToolboxExtensionsUpdate();
indicator.close();
this.indicateFor({ type: "success", msg: "Model loaded" }, 1);
} catch (e) {
this.predictionState[modelDataUrl] = {};
console.log("Model initialization failure!", e);
this.indicateFor({ type: "error", msg: "Unable to load model." }, 1);
}
}

Expand All @@ -195,8 +201,9 @@ export default class teachableMachine extends extension({
}

async initModel(modelUrl) {
const modelURL = modelUrl + "model.json";
const metadataURL = modelUrl + "metadata.json";
const avoidCache = `?x=${Date.now()}`;
const modelURL = modelUrl + "model.json" + avoidCache;
const metadataURL = modelUrl + "metadata.json" + avoidCache;
const customMobileNet = await tmImage.load(modelURL, metadataURL);
if ((customMobileNet as any)._metadata.hasOwnProperty('tfjsSpeechCommandsVersion')) {
// customMobileNet.dispose(); // too early to dispose
Expand All @@ -217,24 +224,29 @@ export default class teachableMachine extends extension({
const customPoseNet = await tmPose.load(modelURL, metadataURL);
return { model: customPoseNet, type: this.ModelType.POSE };
} else {
console.log(customMobileNet.getMetadata(), customMobileNet.getTotalClasses(), customMobileNet.getClassLabels());
return { model: customMobileNet, type: this.ModelType.IMAGE };
}
}

useModel(url) {
try {
const modelUrl = this.modelArgumentToURL(url);
this.getPredictionStateOrStartPredicting(modelUrl);
this.getPredictionStateOrStartPredicting(modelUrl, true);
this.updateStageModel(modelUrl);
} catch (e) {
this.teachableImageModel = null;
}
}

modelArgumentToURL(modelArg) {
return modelArg.startsWith('https://teachablemachine.withgoogle.com/models/') ?
modelArg :
`https://teachablemachine.withgoogle.com/models/${modelArg}/`;
modelArgumentToURL(modelArg: string) {
const endpointProvidedFromInterface = "https://teachablemachine.withgoogle.com/models/";
// NOTE: It's possible Google will change this endpoint in the future, and that will break this extension.
// TODO: https://github.com/mitmedialab/prg-extension-boilerplate/issues/343
const redirectEndpoint = "https://storage.googleapis.com/tm-model/";
return modelArg.startsWith(endpointProvidedFromInterface)
? modelArg.replace(endpointProvidedFromInterface, redirectEndpoint)
: redirectEndpoint + modelArg + "/";
}

updateStageModel(modelUrl) {
Expand All @@ -245,9 +257,9 @@ export default class teachableMachine extends extension({
}
}

getPredictionStateOrStartPredicting(modelUrl) {
getPredictionStateOrStartPredicting(modelUrl, override = false) {
const hasPredictionState = this.predictionState.hasOwnProperty(modelUrl);
if (!hasPredictionState) {
if (!hasPredictionState || override) {
this.startPredicting(modelUrl);
return null;
}
Expand Down

0 comments on commit f83d392

Please sign in to comment.