From 32b50fb539ccfedfe37c2111ff1434c1568e5d6f Mon Sep 17 00:00:00 2001 From: Alexander Kovrigin Date: Wed, 22 Nov 2023 22:12:00 +0100 Subject: [PATCH] Add gradio inference to frontend --- web/package.json | 1 + web/pnpm-lock.yaml | 47 +++++++++++ web/src/components/TranslationModule.tsx | 15 ++-- web/src/service/gradio-translation.ts | 100 +++++++++++++++++++++++ 4 files changed, 158 insertions(+), 5 deletions(-) create mode 100644 web/src/service/gradio-translation.ts diff --git a/web/package.json b/web/package.json index 82172f8..7d87623 100644 --- a/web/package.json +++ b/web/package.json @@ -32,6 +32,7 @@ "tailwindcss-animate": "^1.0.7" }, "devDependencies": { + "@gradio/client": "^0.8.1", "@types/node": "^20.9.0", "@types/react": "^18.2.15", "@types/react-dom": "^18.2.7", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 8b6e27f..0887e21 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -67,6 +67,9 @@ dependencies: version: 1.0.7(tailwindcss@3.3.5) devDependencies: + '@gradio/client': + specifier: ^0.8.1 + version: 0.8.1 '@types/node': specifier: ^20.9.0 version: 20.9.0 @@ -659,6 +662,17 @@ packages: react: 18.2.0 dev: false + /@gradio/client@0.8.1: + resolution: {integrity: sha512-qMFINw6MbubLYQiLorwGhyapH7bUSthH+7tevGgKOPuQWoW5TliPi95VEFRUQ+d2Y/vXqWj0+hubO94RO7v/7w==} + engines: {node: '>=18.0.0'} + dependencies: + bufferutil: 4.0.8 + semiver: 1.1.0 + ws: 8.14.2(bufferutil@4.0.8) + transitivePeerDependencies: + - utf-8-validate + dev: true + /@humanwhocodes/config-array@0.11.13: resolution: {integrity: sha512-JSBDMiDKSzQVngfRjOdFXgFfklaXI4K9nLF49Auh21lmBWRLIK3+xTErTWD4KU54pb6coM6ESE7Awz/FNU3zgQ==} engines: {node: '>=10.10.0'} @@ -1719,6 +1733,14 @@ packages: update-browserslist-db: 1.0.13(browserslist@4.22.1) dev: true + /bufferutil@4.0.8: + resolution: {integrity: sha512-4T53u4PdgsXqKaIctwF8ifXlRTTmEPJ8iEPWFdGZvcf7sbwYo6FKFEX9eNNAnzFZ7EzJAQ3CJeOtCRA4rDp7Pw==} + engines: {node: '>=6.14.2'} + requiresBuild: true + dependencies: + node-gyp-build: 4.7.0 + dev: true + /bulma@0.9.4: resolution: {integrity: sha512-86FlT5+1GrsgKbPLRRY7cGDg8fsJiP/jzTqXXVqiUZZ2aZT8uemEOHlU1CDU+TxklPEZ11HZNNWclRBBecP4CQ==} dev: false @@ -2456,6 +2478,11 @@ packages: whatwg-url: 5.0.0 dev: false + /node-gyp-build@4.7.0: + resolution: {integrity: sha512-PbZERfeFdrHQOOXiAKOY0VPbykZy90ndPKk0d+CFDegTKmWp1VgOTz2xACVbr1BjCWxrQp68CXtvNsveFhqDJg==} + hasBin: true + dev: true + /node-releases@2.0.13: resolution: {integrity: sha512-uYr7J37ae/ORWdZeQ1xxMJe3NtdmqMC/JZK+geofDrkLUApKRHPd18/TxtBOJ4A0/+uUIliorNrfYV6s1b02eQ==} dev: true @@ -2796,6 +2823,11 @@ packages: loose-envify: 1.4.0 dev: false + /semiver@1.1.0: + resolution: {integrity: sha512-QNI2ChmuioGC1/xjyYwyZYADILWyW6AmS1UH6gDj/SFUUUS4MBAWs/7mxnkRPc/F4iHezDP+O8t0dO8WHiEOdg==} + engines: {node: '>=6'} + dev: true + /semver@6.3.1: resolution: {integrity: sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==} hasBin: true @@ -3102,6 +3134,21 @@ packages: /wrappy@1.0.2: resolution: {integrity: sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==} + /ws@8.14.2(bufferutil@4.0.8): + resolution: {integrity: sha512-wEBG1ftX4jcglPxgFCMJmZ2PLtSbJ2Peg6TmpJFTbe9GZYOQCDPdMYu/Tm0/bGZkw8paZnJY45J4K2PZrLYq8g==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + dependencies: + bufferutil: 4.0.8 + dev: true + /yallist@3.1.1: resolution: {integrity: sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==} dev: true diff --git a/web/src/components/TranslationModule.tsx b/web/src/components/TranslationModule.tsx index 6c04ba7..4211b33 100644 --- a/web/src/components/TranslationModule.tsx +++ b/web/src/components/TranslationModule.tsx @@ -7,7 +7,7 @@ import TranslationTextarea from "./TranslationTextarea"; import { Controller } from "@/highlighting/Controller.ts"; import { TranslationData } from "@/highlighting/context.ts"; import { Skeleton } from "@/components/ui/skeleton.tsx"; -import { MockTranslationService } from "@/service/mock-translation.ts"; +import { GradioTranslation } from "@/service/gradio-translation.ts"; export default function TranslationModule() { const [inputValue, inputSetValue] = useState(""); @@ -24,16 +24,17 @@ export default function TranslationModule() { ); const [isLoading, setIsLoading] = useState(false); - const translationService = new MockTranslationService(); + const translationService = new GradioTranslation(); function translate() { - if (inputValue === "") { + if (inputValue === "" || isLoading || isTranslated) { return; } console.log("To be translated: " + inputValue); setIsLoading(true); const result = translationService.predict(inputValue); result.then((tokens) => { + console.log("Translated: " + tokens); setIsLoading(false); setIsTranslated(true); const translationData = new TranslationData(tokens, controller); @@ -54,10 +55,14 @@ export default function TranslationModule() { onChange={inputSetValue} elementId={"textToTranslate"} placeholder={"Enter text here..."} - readOnly={isTranslated} + readOnly={isTranslated || isLoading} />
-
diff --git a/web/src/service/gradio-translation.ts b/web/src/service/gradio-translation.ts new file mode 100644 index 0000000..b1de90a --- /dev/null +++ b/web/src/service/gradio-translation.ts @@ -0,0 +1,100 @@ +import { + Token, + TokenAttentionScore, + TokenProbScore, +} from "@/highlighting/context.ts"; +import TranslationService from "@/service/translation-service.ts"; +import { client } from "@gradio/client"; + +export class GradioTranslation implements TranslationService { + private gradioClient: Promise; + + constructor() { + this.gradioClient = client( + "https://waleko-gradio-transformer-en-ru.hf.space/--replicas/42v9l/", + {}, + ); + } + + healthCheck(): Promise { + return Promise.resolve(true); + } + + async predict(text: string): Promise<[Token[], Token[]]> { + // @ts-ignore + const client = await this.gradioClient; + const result = await client.predict("/predict", [text]); + + const obj = result.data[0]; + console.log(obj); + + const inputTokens = this.populateTokenInfo( + text, + obj.input_tokens, + obj.cross_attention, + true, + ); + const outputTokens = this.populateTokenInfo( + obj.output_text, + obj.output_tokens, + obj.cross_attention, + false, + obj.output_scores, + ); + + console.log("inputTokens", inputTokens); + console.log("outputTokens", outputTokens); + return [inputTokens, outputTokens]; + } + + private populateTokenInfo( + text: string, + tokens: string[], + attention_matrix: number[][], + is_input: boolean, + scores?: [string, number][][], + ): Token[] { + // Fix tokens (find them in text and add whitespace) + const fixed_tokens = []; + let last_index = 0; + for (const token of tokens) { + const index = text.indexOf(token, last_index); + if (index === -1) { + throw new Error(`Token ${token} not found in text ${text}`); + } + fixed_tokens.push(text.substring(last_index, index + token.length)); + last_index = index + token.length; + } + // Build tokens + const token_infos = []; + for (let i = 0; i < tokens.length; i++) { + const token = fixed_tokens[i]; + // Build attention scores + const attention_scores = []; + if (is_input) { + for (let j = 0; j < attention_matrix.length; j++) { + attention_scores.push( + new TokenAttentionScore(j, attention_matrix[j][i]), + ); + } + } else { + for (let j = 0; j < attention_matrix[i].length; j++) { + attention_scores.push( + new TokenAttentionScore(j, attention_matrix[i][j]), + ); + } + } + // Build prob scores + let prob_scores = undefined; + if (scores !== undefined) { + prob_scores = []; + for (const [token, score] of scores[i]) { + prob_scores.push(new TokenProbScore(token, score)); + } + } + // Add token + token_infos.push(new Token(token, attention_scores, prob_scores)); + } + return token_infos; + } +}