diff --git a/web/src/components/InvisibleMark.tsx b/web/src/components/InvisibleMark.tsx index 4fd2a1c..5caf0b1 100644 --- a/web/src/components/InvisibleMark.tsx +++ b/web/src/components/InvisibleMark.tsx @@ -1,4 +1,4 @@ -import { TokenProbScore, TranslationData } from "@/highlighting/context.ts"; +import { TokenProbScore } from "@/highlighting/dataclasses.ts"; import { HoverCard, @@ -6,31 +6,61 @@ import { HoverCardTrigger, } from "@/components/ui/hover-card.tsx"; import SynonymsList from "@/components/SynonymsList.tsx"; +import { useEffect, useState } from "react"; +import { Controller } from "@/highlighting/Controller.ts"; + +export class MarkId { + constructor(index: number, pos: number) { + this.index = index; + this.pos = pos; + } + + index: number; + pos: number; + + toInt() { + const sign = this.index == 0 ? 1 : -1; + return sign * (this.pos + 1); + } + + equals(other: MarkId) { + return this.index == other.index && this.pos == other.pos; + } +} export default function InvisibleMark( synonyms: TokenProbScore[] | undefined, - translationData: TranslationData, - highlight_other: any | undefined, - translation_index: number, + mark_id: MarkId, + controller: Controller, ) { return (props: any) => { let isOpened = false; let isHovered = false; - // FIXME: still buggy + const [isKilled, setIsKilled] = useState(false); + function updateHighlights() { - console.log("updateHighlights: " + isOpened + " " + isHovered); if (isOpened || isHovered) { - if (highlight_other) { - translationData.controller.setHighlights[1 - translation_index]( - highlight_other, - ); - } + controller.activateMark(mark_id); } else { - translationData.updateDefault(); + controller.deactivateMark(mark_id); } } + const baseClassName = "my-invisible text-primary"; + const [className, setClassName] = useState(baseClassName); + + function setAttention(attn_style: string) { + if (attn_style == "killed") { + setIsKilled(true); + } + setClassName(baseClassName + " " + attn_style); + } + + useEffect(() => { + controller.registerMark(mark_id, setAttention); + }, []); + return ( <> { isHovered = true; updateHighlights(); @@ -54,7 +84,7 @@ export default function InvisibleMark( {props.children} - {synonyms && ( + {synonyms && !isKilled && ( diff --git a/web/src/components/SynonymsList.tsx b/web/src/components/SynonymsList.tsx index f57032c..06844f2 100644 --- a/web/src/components/SynonymsList.tsx +++ b/web/src/components/SynonymsList.tsx @@ -1,4 +1,4 @@ -import { TokenProbScore } from "@/highlighting/context.ts"; +import { TokenProbScore } from "@/highlighting/dataclasses.ts"; export default function SynonymsList({ synonyms, diff --git a/web/src/components/TranslationModule.tsx b/web/src/components/TranslationModule.tsx index f21f52b..56ee159 100644 --- a/web/src/components/TranslationModule.tsx +++ b/web/src/components/TranslationModule.tsx @@ -5,7 +5,7 @@ import { Button } from "@/components/ui/button"; import { useState } from "react"; import TranslationTextarea from "./TranslationTextarea"; import { Controller } from "@/highlighting/Controller.ts"; -import { TranslationData } from "@/highlighting/context.ts"; +import { TranslationData } from "@/highlighting/dataclasses.ts"; import { Skeleton } from "@/components/ui/skeleton.tsx"; import { GradioTranslation } from "@/service/gradio-translation.ts"; @@ -18,28 +18,38 @@ export default function TranslationModule() { const [isTranslated, setIsTranslated] = useState(false); - const controller = new Controller( - [inputSetValue, outputSetValue], - [inputSetHighlight, outputSetHighlight], + const [controller, _] = useState( + new Controller( + [inputSetValue, outputSetValue], + [inputSetHighlight, outputSetHighlight], + ), ); const [isLoading, setIsLoading] = useState(false); const translationService = new GradioTranslation(); function translate() { - if (inputValue === "" || isLoading || isTranslated) { + if (inputValue === "" || isLoading) { return; } console.log("To be translated: " + inputValue); setIsLoading(true); - const result = translationService.predict(inputValue); - result.then((tokens) => { - console.log("Translated: " + tokens); + + // Translation + try { + const result = translationService.predict(inputValue); + result.then((tokens) => { + console.log("Translated: " + tokens); + setIsLoading(false); + setIsTranslated(true); + const translationData = new TranslationData(tokens); + controller.setTranslationData(translationData); + }); + } catch (e) { + console.error(e); setIsLoading(false); - setIsTranslated(true); - const translationData = new TranslationData(tokens, controller); - translationData.updateDefault(); - }); + alert("Translation failed. Please try again."); + } } function clearInput() { @@ -50,6 +60,13 @@ export default function TranslationModule() { setIsTranslated(false); } + function handleInput(inputText: string) { + inputSetValue(inputText); + if (isTranslated && inputText !== inputValue) { + controller.sendGlobalKill(); + } + } + return ( <> {/* Content */} @@ -60,16 +77,16 @@ export default function TranslationModule() {
{" "} diff --git a/web/src/components/highlights.css b/web/src/components/highlights.css index 9357a50..5001d24 100644 --- a/web/src/components/highlights.css +++ b/web/src/components/highlights.css @@ -1,28 +1,23 @@ /* attention highlights gradations */ .attn5 { - border-radius: 3px; - background-color: #98b7f8e0; + background-color: #98b7f8e0 !important; } .attn4 { - border-radius: 3px; - background-color: #98b7f8c0; + background-color: #98b7f8c0 !important; } .attn3 { - border-radius: 3px; - background-color: #98b7f8a0; + background-color: #98b7f8a0 !important; } .attn2 { - border-radius: 3px; - background-color: #98b7f880; + background-color: #98b7f880 !important; } .attn1 { - border-radius: 3px; - background-color: #98b7f860; + background-color: #98b7f860 !important; } mark.my-invisible { @@ -34,3 +29,8 @@ mark.my-invisible { mark.my-invisible:hover { background-color: #98b7f8e0; } + +mark.killed:hover { + background-color: transparent !important; + cursor: text !important; +} diff --git a/web/src/highlighting/Controller.ts b/web/src/highlighting/Controller.ts index 04a7472..e81469b 100644 --- a/web/src/highlighting/Controller.ts +++ b/web/src/highlighting/Controller.ts @@ -1,3 +1,10 @@ +import { + Token, + TokenAttentionScore, + TranslationData, +} from "@/highlighting/dataclasses.ts"; +import InvisibleMark, { MarkId } from "@/components/InvisibleMark.tsx"; + export class Controller { /** * Controller class to handle highlight and text updates. @@ -6,4 +13,178 @@ export class Controller { public setInputs: [(text: string) => void, (text: string) => void], public setHighlights: [(highlight: any) => void, (highlight: any) => void], ) {} + + setTranslationData(translationData: TranslationData) { + this.translationData = translationData; + + this.attentionScores = [[], []]; + this.attentionSetters.clear(); + this.isGlobalKill = false; + + this.setInputs[0](translationData.getText(translationData.tokens[0])); + this.setInputs[1](translationData.getText(translationData.tokens[1])); + this.setHighlights[0](this.getHighlighter(0)); + this.setHighlights[1](this.getHighlighter(1)); + } + + /** + * Sends a global kill signal to all marks in the attention setters. + * This is used to disable all marks styling and logic + * when the user edits the text. + * + * @return {void} + */ + sendGlobalKill(): void { + console.log("sending global kill"); + this.isGlobalKill = true; + + // send kill signal to all marks + for (const setter of this.attentionSetters.values()) { + setter("killed"); + } + } + + /** + * Returns highlighter object to be used when no tokens are hovered over. + * All tokens are inside transparent spans, with hover detection on the spans. + * + * @returns The highlighter object. + */ + private getHighlighter(index: number): any[] { + const curTokens = this.translationData.tokens[index]; + + let curTokenRanges = this.tokenRanges(curTokens); + + return curTokens.map((token, pos) => { + // Push the attention scores for the current token to the attentionScores array. + this.attentionScores[index].push(token.attentionScores ?? []); + // Return smart mark object for the token + return { + highlight: curTokenRanges[pos], + component: InvisibleMark( + token.probScores, + new MarkId(index, pos), + this, + ), + }; + }); + } + + /** + * Registers a mark with a given mark ID and a setter function for handling attention styles. + * + * @param {MarkId} mark_id - The ID of the mark to register. + * @param {function} setAttention - The setter function that allows setting attention styles for the mark. + * @return {void} + */ + registerMark( + mark_id: MarkId, + setAttention: (attn_style: string) => void, + ): void { + this.attentionSetters.set(mark_id.toInt(), setAttention); + } + + /** + * Activates a mark by setting the attention style for related tokens in the other textarea.s + * If isGlobalKill flag is set, the mark activation is skipped. + * + * @param {MarkId} mark_id - The ID of the mark to activate. + */ + activateMark(mark_id: MarkId) { + if (this.isGlobalKill) { + return; + } + + const index = mark_id.index; + const pos = mark_id.pos; + + // Set the attention style for the current token. + const curTokenAttentionScores = this.attentionScores[index][pos]; + + for (const tas of curTokenAttentionScores) { + const style = this.attentionStyle(tas.attentionScore); + const other_mark_id = new MarkId(1 - index, tas.tokenIndex); + const setAttention = this.attentionSetters.get(other_mark_id.toInt()); + if (style !== undefined && setAttention !== undefined) { + setAttention(style); + this.lastModifier.set(other_mark_id.toInt(), mark_id); + } + } + } + + /** + * Deactivates the specified mark by setting the attention style for related tokens in the other textarea. + * + * @param {MarkId} mark_id - The ID of the mark to deactivate. + */ + deactivateMark(mark_id: MarkId) { + if (this.isGlobalKill) { + return; + } + + const index = mark_id.index; + const pos = mark_id.pos; + + // Set the attention style for the current token. + const curTokenAttentionScores = this.attentionScores[index][pos]; + + for (const tas of curTokenAttentionScores) { + const other_mark_id = new MarkId(1 - index, tas.tokenIndex); + const setAttention = this.attentionSetters.get(other_mark_id.toInt()); + const lastModifier = this.lastModifier.get(other_mark_id.toInt()); + if ( + setAttention !== undefined && + lastModifier !== undefined && + lastModifier.equals(mark_id) + ) { + setAttention(""); + } + } + } + + private tokenRanges(otherTokens: Token[]): [number, number][] { + let otherTokenRanges: [number, number][] = []; + let curIndex = 0; + for (const token of otherTokens) { + const tokenLength = token.text.length; + otherTokenRanges.push([curIndex, curIndex + tokenLength]); + curIndex += tokenLength; + } + return otherTokenRanges; + } + + /** + * Returns the attention style based on the given attention score. + * + * @param {number} attention_score - The attention score that determines the style. + * @return {string|undefined} The attention style, or undefined if no style is applicable. + * @private + */ + private attentionStyle(attention_score: number): string | undefined { + const style_thresholds: [string, number][] = [ + ["attn5", 0.5], + ["attn4", 0.4], + ["attn3", 0.3], + ["attn2", 0.2], + ["attn1", 0.1], + ]; + for (const [style, threshold] of style_thresholds) { + if (attention_score >= threshold) { + return style; + } + } + return undefined; + } + + // @ts-ignore + private translationData: TranslationData = undefined; + // Attention scores for the tokens in the two textareas. + private attentionScores: [TokenAttentionScore[][], TokenAttentionScore[][]] = + [[], []]; + // Attention setters for the tokens in the two textareas. + private attentionSetters = new Map void>(); + // Last modifier for the tokens in the two textareas. + private lastModifier = new Map(); + // Global kill flag to disable all marks. + private isGlobalKill = false; } diff --git a/web/src/highlighting/context.ts b/web/src/highlighting/context.ts deleted file mode 100644 index 0b4278d..0000000 --- a/web/src/highlighting/context.ts +++ /dev/null @@ -1,157 +0,0 @@ -import { Controller } from "@/highlighting/Controller.ts"; -import InvisibleMark from "@/components/InvisibleMark.tsx"; - -export class TokenAttentionScore { - /** - * Holder for a token in the cross-attention matrix. It contains the token's - * index and the attention score. - * @param tokenIndex The index of the token in the other textarea. - * @param attentionScore The attention score for the token. - */ - constructor( - public tokenIndex: number, - public attentionScore: number, - ) {} -} - -export class TokenProbScore { - /** - * Holder for a token and its probability to be displayed in the synonym UI. - * - * @param {string} tokenText - The text of the token. - * @param {number} prob - The probability associated with the token. - */ - constructor( - public tokenText: string, - public prob: number, - ) {} -} - -export class Token { - /** - * Token class to hold all token information from the model. - * @param text The text of the token. - * @param attentionScores Cross-attention scores for the token. - * @param probScores Probabilities for the synonyms of the token. - */ - constructor( - public text: string, - public attentionScores?: TokenAttentionScore[], - public probScores?: TokenProbScore[], - ) {} -} - -/** - * Data class to hold the service data from the model. - */ -export class TranslationData { - /** - * Translation data class to hold the service data from the model. - * @param tokens Input and Output tokens from the model. - * @param controller Controller class to handle highlight and text updates. - */ - constructor( - public tokens: [Token[], Token[]], - public controller: Controller, - ) {} - - /** - * Returns the concatenated text from the tokens - * @param tokens The tokens to concatenate - * @returns The concatenated text from the tokens - */ - private getText(tokens: Token[]): string { - return tokens.map((token) => token.text).join(""); - } - - /** - * Returns highlighter object to be used when no tokens are hovered over. - * All tokens are inside transparent spans, with hover detection on the spans. - * On hover, the global highlighter is updated. - * - * @returns The highlighter object. - */ - private getDefaultHighlighter(index: number): any[] { - const curTokens = this.tokens[index]; - const otherTokens = this.tokens[1 - index]; - - let curTokenRanges = this.tokenRanges(curTokens); - let otherTokenRanges = this.tokenRanges(otherTokens); - - return curTokens.map((token, i) => { - const hoveredHighlight = token.attentionScores - ?.map((score) => { - const className = this.attentionStyle(score.attentionScore); - if (className === undefined) { - return undefined; - } - return { - highlight: otherTokenRanges[score.tokenIndex], - className: className, - }; - }) - .filter((highlight) => highlight?.className !== undefined); - return { - highlight: curTokenRanges[i], - component: InvisibleMark( - token.probScores, - this, - hoveredHighlight, - index, - ), - }; - }); - } - - private tokenRanges(otherTokens: Token[]): [number, number][] { - let otherTokenRanges: [number, number][] = []; - let curIndex = 0; - for (const token of otherTokens) { - const tokenLength = token.text.length; - otherTokenRanges.push([curIndex, curIndex + tokenLength]); - curIndex += tokenLength; - } - return otherTokenRanges; - } - - public updateDefault() { - console.log( - "Texts: ", - this.getText(this.tokens[0]), - this.getText(this.tokens[1]), - ); - console.log( - "Highlights: ", - this.getDefaultHighlighter(0), - this.getDefaultHighlighter(1), - ); - - this.controller.setInputs[0](this.getText(this.tokens[0])); - this.controller.setInputs[1](this.getText(this.tokens[1])); - this.controller.setHighlights[0](this.getDefaultHighlighter(0)); - this.controller.setHighlights[1](this.getDefaultHighlighter(1)); - } - - /** - * Returns the attention style based on the given attention score. - * - * @param {number} attention_score - The attention score that determines the style. - * @return {string|undefined} The attention style, or undefined if no style is applicable. - * @private - */ - private attentionStyle(attention_score: number): string | undefined { - const style_thresholds: [string, number][] = [ - ["attn5", 0.5], - ["attn4", 0.4], - ["attn3", 0.3], - ["attn2", 0.2], - ["attn1", 0.1], - ]; - for (const [style, threshold] of style_thresholds) { - if (attention_score >= threshold) { - return style; - } - } - return undefined; - } -} diff --git a/web/src/highlighting/dataclasses.ts b/web/src/highlighting/dataclasses.ts new file mode 100644 index 0000000..92b7253 --- /dev/null +++ b/web/src/highlighting/dataclasses.ts @@ -0,0 +1,54 @@ +export class TokenAttentionScore { + /** + * Holder for a token in the cross-attention matrix. It contains the token's + * index and the attention score. + * @param tokenIndex The index of the token in the other textarea. + * @param attentionScore The attention score for the token. + */ + constructor( + public tokenIndex: number, + public attentionScore: number, + ) {} +} + +export class TokenProbScore { + /** + * Holder for a token and its probability to be displayed in the synonym UI. + * + * @param {string} tokenText - The text of the token. + * @param {number} prob - The probability associated with the token. + */ + constructor( + public tokenText: string, + public prob: number, + ) {} +} + +export class Token { + /** + * Token class to hold all token information from the model. + * @param text The text of the token. + * @param attentionScores Cross-attention scores for the token. + * @param probScores Probabilities for the synonyms of the token. + */ + constructor( + public text: string, + public attentionScores?: TokenAttentionScore[], + public probScores?: TokenProbScore[], + ) {} +} + +/** + * Data class to hold the service data from the model. + */ +export class TranslationData { + /** + * Translation data class to hold the service data from the model. + * @param tokens Input and Output tokens from the model. + */ + constructor(public tokens: [Token[], Token[]]) {} + + public getText(tokens: Token[]): string { + return tokens.map((token) => token.text).join(""); + } +} diff --git a/web/src/service/gradio-translation.ts b/web/src/service/gradio-translation.ts index b1de90a..79e32dd 100644 --- a/web/src/service/gradio-translation.ts +++ b/web/src/service/gradio-translation.ts @@ -2,7 +2,7 @@ import { Token, TokenAttentionScore, TokenProbScore, -} from "@/highlighting/context.ts"; +} from "@/highlighting/dataclasses.ts"; import TranslationService from "@/service/translation-service.ts"; import { client } from "@gradio/client"; diff --git a/web/src/service/mock-translation.ts b/web/src/service/mock-translation.ts index 4f13a83..633bad0 100644 --- a/web/src/service/mock-translation.ts +++ b/web/src/service/mock-translation.ts @@ -2,7 +2,7 @@ import { Token, TokenAttentionScore, TokenProbScore, -} from "@/highlighting/context.ts"; +} from "@/highlighting/dataclasses.ts"; import TranslationService from "@/service/translation-service.ts"; export class MockTranslationService implements TranslationService { diff --git a/web/src/service/translation-service.ts b/web/src/service/translation-service.ts index c6215f7..c891989 100644 --- a/web/src/service/translation-service.ts +++ b/web/src/service/translation-service.ts @@ -1,4 +1,4 @@ -import { Token } from "@/highlighting/context.ts"; +import { Token } from "@/highlighting/dataclasses.ts"; /** * Represents a TranslationService.