diff --git a/extensions/src/teachableMachine/index.ts b/extensions/src/teachableMachine/index.ts index 13d4e35d4..24677794f 100644 --- a/extensions/src/teachableMachine/index.ts +++ b/extensions/src/teachableMachine/index.ts @@ -25,7 +25,7 @@ export default class teachableMachine extends extension({ iconURL: "teachable-machine-blocks.png", insetIconURL: "teachable-machine-blocks-small.svg", tags: ["Dancing with AI", "Made by PRG"] -}) { +}, "indicators") { lastUpdate: number; maxConfidence: number; modelConfidences: {}; @@ -166,18 +166,24 @@ export default class teachableMachine extends extension({ } async startPredicting(modelDataUrl) { - if (!this.predictionState[modelDataUrl]) { - try { - this.predictionState[modelDataUrl] = {}; - // https://github.com/googlecreativelab/teachablemachine-community/tree/master/libraries/image - const { model, type } = await this.initModel(modelDataUrl); - this.predictionState[modelDataUrl].modelType = type; - this.predictionState[modelDataUrl].model = model; - this.runtime.requestToolboxExtensionsUpdate(); - } catch (e) { - this.predictionState[modelDataUrl] = {}; - console.log("Model initialization failure!", e); - } + const alreadyLoaded = Boolean(this.predictionState[modelDataUrl]); + try { + const indicator = await this.indicate({ + type: "warning", + msg: alreadyLoaded ? "Updating model" : "Loading model" + }); + this.predictionState[modelDataUrl] = {}; + // https://github.com/googlecreativelab/teachablemachine-community/tree/master/libraries/image + const { model, type } = await this.initModel(modelDataUrl); + this.predictionState[modelDataUrl].modelType = type; + this.predictionState[modelDataUrl].model = model; + this.runtime.requestToolboxExtensionsUpdate(); + indicator.close(); + this.indicateFor({ type: "success", msg: "Model loaded" }, 1); + } catch (e) { + this.predictionState[modelDataUrl] = {}; + console.log("Model initialization failure!", e); + this.indicateFor({ type: "error", msg: "Unable to load model." }, 1); } } @@ -195,8 +201,9 @@ export default class teachableMachine extends extension({ } async initModel(modelUrl) { - const modelURL = modelUrl + "model.json"; - const metadataURL = modelUrl + "metadata.json"; + const avoidCache = `?x=${Date.now()}`; + const modelURL = modelUrl + "model.json" + avoidCache; + const metadataURL = modelUrl + "metadata.json" + avoidCache; const customMobileNet = await tmImage.load(modelURL, metadataURL); if ((customMobileNet as any)._metadata.hasOwnProperty('tfjsSpeechCommandsVersion')) { // customMobileNet.dispose(); // too early to dispose @@ -217,6 +224,7 @@ export default class teachableMachine extends extension({ const customPoseNet = await tmPose.load(modelURL, metadataURL); return { model: customPoseNet, type: this.ModelType.POSE }; } else { + console.log(customMobileNet.getMetadata(), customMobileNet.getTotalClasses(), customMobileNet.getClassLabels()); return { model: customMobileNet, type: this.ModelType.IMAGE }; } } @@ -224,17 +232,21 @@ export default class teachableMachine extends extension({ useModel(url) { try { const modelUrl = this.modelArgumentToURL(url); - this.getPredictionStateOrStartPredicting(modelUrl); + this.getPredictionStateOrStartPredicting(modelUrl, true); this.updateStageModel(modelUrl); } catch (e) { this.teachableImageModel = null; } } - modelArgumentToURL(modelArg) { - return modelArg.startsWith('https://teachablemachine.withgoogle.com/models/') ? - modelArg : - `https://teachablemachine.withgoogle.com/models/${modelArg}/`; + modelArgumentToURL(modelArg: string) { + const endpointProvidedFromInterface = "https://teachablemachine.withgoogle.com/models/"; + // NOTE: It's possible Google will change this endpoint in the future, and that will break this extension. + // TODO: https://github.com/mitmedialab/prg-extension-boilerplate/issues/343 + const redirectEndpoint = "https://storage.googleapis.com/tm-model/"; + return modelArg.startsWith(endpointProvidedFromInterface) + ? modelArg.replace(endpointProvidedFromInterface, redirectEndpoint) + : redirectEndpoint + modelArg + "/"; } updateStageModel(modelUrl) { @@ -245,9 +257,9 @@ export default class teachableMachine extends extension({ } } - getPredictionStateOrStartPredicting(modelUrl) { + getPredictionStateOrStartPredicting(modelUrl, override = false) { const hasPredictionState = this.predictionState.hasOwnProperty(modelUrl); - if (!hasPredictionState) { + if (!hasPredictionState || override) { this.startPredicting(modelUrl); return null; }