Skip to content

Commit

Permalink
Everything is hooked up
Browse files Browse the repository at this point in the history
  • Loading branch information
pmalacho-mit committed Nov 27, 2023
1 parent a0f4f6a commit 235c5ed
Show file tree
Hide file tree
Showing 34 changed files with 207 additions and 1,804 deletions.
14 changes: 11 additions & 3 deletions extensions/src/common/extension/mixins/base/scratchInfo/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { castToType } from "$common/cast";
import CustomArgumentManager from "$common/extension/mixins/configurable/customArguments/CustomArgumentManager";
import { ArgumentType, BlockType } from "$common/types/enums";
import { BlockOperation, ValueOf, Menu, ExtensionMetadata, ExtensionBlockMetadata, ExtensionMenuMetadata, DynamicMenu, BlockMetadata, } from "$common/types";
import { BlockOperation, ValueOf, Menu, ExtensionMetadata, ExtensionBlockMetadata, ExtensionMenuMetadata, DynamicMenu, BlockMetadata, BlockUtilityWithID, } from "$common/types";
import { registerButtonCallback } from "$common/ui";
import { isString, typesafeCall, } from "$common/utils";
import type BlockUtility from "$root/packages/scratch-vm/src/engine/block-utility";
Expand All @@ -12,11 +12,17 @@ import { convertToArgumentInfo, extractArgs, zipArgs } from "./args";
import { convertToDisplayText } from "./text";
import { CustomizableExtensionConstructor, MinimalExtensionInstance, } from "..";
import { ExtensionInstanceWithFunctionality } from "../..";
import { blockIDKey } from "$common/globals";

export const getImplementationName = (opcode: string) => `internal_${opcode}`;

const inlineImageAccessError = "ERROR: This argument represents an inline image and should not be accessed.";

const isBlockUtilityWithID = (query: any): query is BlockUtilityWithID => query?.[blockIDKey] !== undefined;
const nonBlockContextError = "Block method was not given a block utility, and thus was likely called by something OTHER THAN the Scratch Runtime. NOTE: You cannot call block methods directly from within your class due to how block methods are converted to work with scratch. Consider abstracting the logic to a seperate, non-block method which can be invoked directly."
const checkForBlockContext = (blockUtility: BlockUtilityWithID) => isBlockUtilityWithID(blockUtility) ? void 0 : console.error(nonBlockContextError);


/**
* Wraps a blocks operation so that the arguments passed from Scratch are first extracted and then passed as indices in a parameter array.
* @param _this What will be bound to the 'this' context of the underlying operation
Expand All @@ -29,7 +35,8 @@ export const wrapOperation = <T extends MinimalExtensionInstance>(
operation: BlockOperation,
args: { name: string, type: ValueOf<typeof ArgumentType>, handler: Handler }[]
) => _this.supports("customArguments")
? function (this: ExtensionInstanceWithFunctionality<["customArguments"]>, argsFromScratch: Record<string, any>, blockUtility: BlockUtility) {
? function (this: ExtensionInstanceWithFunctionality<["customArguments"]>, argsFromScratch: Record<string, any>, blockUtility: BlockUtilityWithID) {
checkForBlockContext(blockUtility);
const castedArguments = args.map(({ name, type, handler }) => {
if (type === ArgumentType.Image) return inlineImageAccessError;
const param = argsFromScratch[name];
Expand All @@ -44,7 +51,8 @@ export const wrapOperation = <T extends MinimalExtensionInstance>(
});
return operation.call(_this, ...castedArguments, blockUtility);
}
: function (this: T, argsFromScratch: Record<string, any>, blockUtility: BlockUtility) {
: function (this: T, argsFromScratch: Record<string, any>, blockUtility: BlockUtilityWithID) {
checkForBlockContext(blockUtility);
const castedArguments = args.map(({ name, type, handler }) =>
type === ArgumentType.Image
? inlineImageAccessError
Expand Down
4 changes: 2 additions & 2 deletions extensions/src/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ export const getTextFromMenuItem = <T>(item: MenuItem<T>) => typeof item === "ob

export async function fetchWithTimeout(
resource: FetchParams["request"],
options: FetchParams["options"] & { timeout: number }
options: FetchParams["options"] & { timeoutMs: number }
) {
const { timeout } = options;
const { timeoutMs: timeout } = options;

const controller = new AbortController();
const id = setTimeout(() => controller.abort(), timeout);
Expand Down
4 changes: 0 additions & 4 deletions extensions/src/textClassification/Editor.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
export let extension: Extension;
export let close: () => void;
export const onClose = () => {
console.log("closed!");
};
const invoke: ReactiveInvoke<Extension> = (functionName, ...args) =>
reactiveInvoke((extension = extension), functionName, args);
Expand Down
47 changes: 47 additions & 0 deletions extensions/src/textClassification/ImportExport.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<script lang="ts">
import type Extension from ".";
import PrimaryButton from "$common/components/PrimaryButton.svelte";
import { color } from "$common";
export let extension: Extension;
export let close: () => void;
let input: HTMLInputElement;
let file: File = null;
const _export = () => {
extension.exportClassifier();
close();
};
const _import = async () => {
const success = await extension.importClassifier(file);
if (!success) return alert(`Failed to import ${file.name}`);
extension.buildCustomDeepModel();
close();
};
</script>

<div style:background-color={color.ui.white}>
<PrimaryButton on:click={_export}>Export Classifier</PrimaryButton>
<PrimaryButton on:click={_import} disabled={!file}>
Import Classifier
<input
bind:this={input}
type="file"
accept=".json"
on:change={(e) => (file = e.currentTarget.files[0])}
/>
</PrimaryButton>
<PrimaryButton on:click={close}>Done</PrimaryButton>
</div>

<style>
div {
display: flex;
flex-direction: row;
justify-content: space-between;
gap: 1rem;
padding: 1rem;
}
</style>
174 changes: 84 additions & 90 deletions extensions/src/textClassification/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/// <reference types="dom-speech-recognition" />
import { Environment, extension, ExtensionMenuDisplayDetails, block, wrapClamp, fetchWithTimeout, RuntimeEvent, buttonBlock, untilTimePassed } from "$common";
import { Environment, extension, ExtensionMenuDisplayDetails, wrapClamp, fetchWithTimeout, RuntimeEvent, buttonBlock, } from "$common";
import type BlockUtility from "$scratch-vm/engine/block-utility";
import { legacyFullSupport, info, } from "./legacy";
import { getState, setState, tryCopyStateToClone } from "./state";
import { State, getState, setState, tryCopyStateToClone } from "./state";
import { getSynthesisURL } from "./services/synthesis";
import timer from "./timer";
import voices, { Voice } from "./voices";
Expand All @@ -22,7 +22,7 @@ const details: ExtensionMenuDisplayDetails = {
description: "Create a text classification model for use in a Scratch project!",
};

const defaultLabel = "No labels";
const defaultLabels = ["No labels"];

export default class TextClassification extends extension(details, "legacySupport", "ui", "indicators") {
labels: string[] = [];
Expand Down Expand Up @@ -53,12 +53,15 @@ export default class TextClassification extends extension(details, "legacySuppor
@buttonBlock("Edit Model")
editButton() { this.openUI("Editor", "Edit Text Model") }

@buttonBlock("Load / Save Model")
saveLoadButton() { this.openUI("ImportExport", "Save / Load Text Model") }

@legacyBlock.ifTextMatchesClass((self) => ({
argumentMethods: {
1: {
getItems: () => {
const { labels } = self;
return labels?.length > 0 ? labels : [defaultLabel];
return labels?.length > 0 ? labels : defaultLabels;
}
}
}
Expand Down Expand Up @@ -100,40 +103,13 @@ export default class TextClassification extends extension(details, "legacySuppor

@legacyBlock.speakText()
async speakText(text: string, { target }: BlockUtility) {
const locale = 'en-US';
const { audioEngine } = this.runtime;
const { currentVoice } = getState(target);
const { gender, playbackRate } = voices[currentVoice];
const encoded = encodeURIComponent(JSON.stringify(text).substring(0, 128));
const endpoint = getSynthesisURL({ gender, locale, text: encoded });

await new Promise<void>(async (resolve) => {
try {
const response = await fetchWithTimeout(endpoint, { timeout: 40 });
if (!response.ok) return console.warn(response.statusText);
const sound = { data: response.body as unknown as Buffer };
const soundPlayer = await audioEngine.decodeSoundPlayer(sound);
this.soundPlayers.set(soundPlayer.id, soundPlayer);
soundPlayer.setPlaybackRate(playbackRate);
const chain = audioEngine.createEffectChain();
chain.set('volume', 250);
soundPlayer.connect(chain);
soundPlayer.play();
soundPlayer.on('stop', () => {
this.soundPlayers.delete(soundPlayer.id);
resolve();
});
}
catch (error) {
console.warn(error);
}
});
await this.speak(text, getState(target));
}

@legacyBlock.askSpeechRecognition()
async askSpeechRecognition(prompt: string, util: BlockUtility) {
await this.speakText(prompt, util);
this.recognizeSpeech();
async askSpeechRecognition(prompt: string, { target }: BlockUtility) {
await this.speak(prompt, getState(target));
await this.recognizeSpeech();
}

@legacyBlock.getRecognizedSpeech()
Expand All @@ -151,14 +127,15 @@ export default class TextClassification extends extension(details, "legacySuppor
return voiceItems[voiceIndex].value;
}

return voiceItems.find(({ value, text }) => value === reported || text === reported)?.value
return voiceItems.find(({ value, text }) => value === reported || text === reported)?.value ?? "SQUEAK";
}
}
}
})
setVoice(voice: Voice, { target }: BlockUtility) {
const state = getState(target);
state.currentVoice = voice ?? state.currentVoice;
setState(target, state);
}

@legacyBlock.onHeardSound()
Expand Down Expand Up @@ -193,9 +170,9 @@ export default class TextClassification extends extension(details, "legacySuppor
}

async buildCustomDeepModel() {
const indicator = await this.indicate({ msg: "wait .. loading model", type: "warning", });
const identifier = Symbol();
this.currentModelIdentifier = identifier;
const indicator = await this.indicate({ msg: "wait .. loading model", type: "warning", });
const isCurrent = () => this.currentModelIdentifier === identifier;
const result = await build(this.labels, this.modelData, isCurrent);

Expand All @@ -210,19 +187,84 @@ export default class TextClassification extends extension(details, "legacySuppor
}
}

importClassifier(file: File) {
return new Promise<boolean>((resolve) => {
if (!file) return resolve(false);
const reader = new FileReader();
reader.onload = ({ target: { result } }) => {
console.log(result);
try {
const data = JSON.parse(result as string) as Record<string, string[]>;
this.modelData = new Map(Object.entries(data));
this.labels = [...this.modelData.keys()];
resolve(true);
} catch (err) {
console.error(`Incorrect document form: ${file.name}: ${err}`);
this.modelData = new Map();
this.labels = [];
resolve(false);
}
};
reader.readAsText(file);
});
}

exportClassifier() {
const serialized = JSON.stringify(Object.fromEntries(this.modelData));
const data = `text/json;charset=utf-8,${encodeURIComponent(serialized)}`;
const anchor = document.createElement('a');
anchor.setAttribute("href", "data:" + data);
anchor.setAttribute("download", "classifier-info.json");
anchor.click();
}

/** End UI Methods */

/** Begin Private Methods */

private async speak(text: string, { currentVoice }: State) {
const locale = 'en-US';
const { audioEngine } = this.runtime;
const { gender, playbackRate } = voices[currentVoice];
const encoded = encodeURIComponent(text.substring(0, 128));
const endpoint = getSynthesisURL({ gender, locale, text: encoded });

await new Promise<void>(async (resolve) => {
try {
const response = await fetchWithTimeout(endpoint, { timeoutMs: 40000 });
if (!response.ok) return console.warn(response.statusText);
const buffer = await response.arrayBuffer();
const soundPlayer = await audioEngine.decodeSoundPlayer({ data: { buffer } });
this.soundPlayers.set(soundPlayer.id, soundPlayer);
soundPlayer.setPlaybackRate(playbackRate);
const chain = audioEngine.createEffectChain();
chain.set('volume', 250);
soundPlayer.connect(chain);
soundPlayer.play();
soundPlayer.on('stop', () => {
this.soundPlayers.delete(soundPlayer.id);
resolve();
});
}
catch (error) {
console.warn(error);
}
});
}

private async getToxicityModel() {
if (this.toxicityModel) return this.toxicityModel;
const msg = await this.indicate({ msg: "Loading toxicity model", type: "warning", });
try {
this.toxicityModel ??= await loadToxicity(0.1, toxicityLabelItems.map(({ value }) => value));
console.log('loaded Toxicity model');
this.toxicityModel = await loadToxicity(0.1, toxicityLabelItems.map(({ value }) => value));
msg.close();
this.indicateFor({ msg: "Toxicity model loaded!", type: "success", }, 2);
return this.toxicityModel;
}
catch (error) {
console.log('Failed to load toxicity model', error);
msg.close();
this.indicateFor({ msg: "Failed to load toxicity model", type: "error", }, 2);
}
return this.toxicityModel;
}

private getLoudness() {
Expand Down Expand Up @@ -263,9 +305,7 @@ export default class TextClassification extends extension(details, "legacySuppor
? Math.round(filtered[0].results[0].probabilities[returnPositive ? 1 : 0] * 100)
: 0;
}
catch (error) {
console.log('Failed to classify text', error);
}
catch (error) { console.error('Failed to classify text', error); }
}

private async getConfidence(text: string) {
Expand All @@ -280,50 +320,4 @@ export default class TextClassification extends extension(details, "legacySuppor
const { label } = await this.customPredictor(newText);
return label;
}

/** End Private Methods */

private uiEventsTODO() {
/*
// Listen for model editing events emitted by the text modal
this.runtime.on('NEW_EXAMPLES', (examples, label) => {
this.newExamples(examples, label);
});
this.runtime.on('NEW_LABEL', (label) => {
this.newLabel(label);
});
this.runtime.on('DELETE_EXAMPLE', (label, exampleNum) => {
this.deleteExample(label, exampleNum);
});
this.runtime.on('RENAME_LABEL', (oldName, newName) => {
this.renameLabel(oldName, newName);
});
this.runtime.on('DELETE_LABEL', (label) => {
this.clearAllWithLabel({ LABEL: label });
});
this.runtime.on('CLEAR_ALL_LABELS', () => {
if (!this.labelListEmpty && confirm('Are you sure you want to clear all labels?')) { //confirm with alert dialogue before clearing the model
let labels = [...this.labelList];
for (var i = 0; i < labels.length; i++) {
this.clearAllWithLabel({ LABEL: labels[i] });
}
//this.clearAll(); this crashed Scratch for some reason
}
});
//Listen for model editing events emitted by the classifier modal
this.runtime.on('EXPORT_CLASSIFIER', () => {
this.exportClassifier();
});
this.runtime.on('LOAD_CLASSIFIER', () => {
console.log("load");
this.loadClassifier();
});
this.runtime.on('DONE', () => {
console.log("DONE");
this.buildCustomDeepModel();
});*/
}
}
2 changes: 1 addition & 1 deletion extensions/src/textClassification/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export const build = async (
) => {
const { length } = labels;

if (length < 2) return { error: "No classes inputted" };
if (length < 2) return { error: "2 or more classes required" };

const model = sequential();

Expand Down
Loading

0 comments on commit 235c5ed

Please sign in to comment.