Skip to content

Commit

Permalink
Add gradio inference to frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
waleko committed Nov 22, 2023
1 parent 8bb42de commit 32b50fb
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 5 deletions.
1 change: 1 addition & 0 deletions web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"tailwindcss-animate": "^1.0.7"
},
"devDependencies": {
"@gradio/client": "^0.8.1",
"@types/node": "^20.9.0",
"@types/react": "^18.2.15",
"@types/react-dom": "^18.2.7",
Expand Down
47 changes: 47 additions & 0 deletions web/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 10 additions & 5 deletions web/src/components/TranslationModule.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import TranslationTextarea from "./TranslationTextarea";
import { Controller } from "@/highlighting/Controller.ts";
import { TranslationData } from "@/highlighting/context.ts";
import { Skeleton } from "@/components/ui/skeleton.tsx";
import { MockTranslationService } from "@/service/mock-translation.ts";
import { GradioTranslation } from "@/service/gradio-translation.ts";

export default function TranslationModule() {
const [inputValue, inputSetValue] = useState("");
Expand All @@ -24,16 +24,17 @@ export default function TranslationModule() {
);

const [isLoading, setIsLoading] = useState(false);
const translationService = new MockTranslationService();
const translationService = new GradioTranslation();

function translate() {
if (inputValue === "") {
if (inputValue === "" || isLoading || isTranslated) {
return;
}
console.log("To be translated: " + inputValue);
setIsLoading(true);
const result = translationService.predict(inputValue);
result.then((tokens) => {
console.log("Translated: " + tokens);
setIsLoading(false);
setIsTranslated(true);
const translationData = new TranslationData(tokens, controller);
Expand All @@ -54,10 +55,14 @@ export default function TranslationModule() {
onChange={inputSetValue}
elementId={"textToTranslate"}
placeholder={"Enter text here..."}
readOnly={isTranslated}
readOnly={isTranslated || isLoading}
/>
<div className="mt-4">
<Button className="bg-primary" onClick={translate}>
<Button
className="bg-primary"
onClick={translate}
disabled={isTranslated || isLoading}
>
Translate
</Button>
</div>
Expand Down
100 changes: 100 additions & 0 deletions web/src/service/gradio-translation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import {
Token,
TokenAttentionScore,
TokenProbScore,
} from "@/highlighting/context.ts";
import TranslationService from "@/service/translation-service.ts";
import { client } from "@gradio/client";

export class GradioTranslation implements TranslationService {
private gradioClient: Promise<any>;

constructor() {
this.gradioClient = client(
"https://waleko-gradio-transformer-en-ru.hf.space/--replicas/42v9l/",
{},
);
}

healthCheck(): Promise<boolean> {
return Promise.resolve(true);
}

async predict(text: string): Promise<[Token[], Token[]]> {
// @ts-ignore
const client = await this.gradioClient;
const result = await client.predict("/predict", [text]);

const obj = result.data[0];
console.log(obj);

const inputTokens = this.populateTokenInfo(
text,
obj.input_tokens,
obj.cross_attention,
true,
);
const outputTokens = this.populateTokenInfo(
obj.output_text,
obj.output_tokens,
obj.cross_attention,
false,
obj.output_scores,
);

console.log("inputTokens", inputTokens);
console.log("outputTokens", outputTokens);
return [inputTokens, outputTokens];
}

private populateTokenInfo(
text: string,
tokens: string[],
attention_matrix: number[][],
is_input: boolean,
scores?: [string, number][][],
): Token[] {
// Fix tokens (find them in text and add whitespace)
const fixed_tokens = [];
let last_index = 0;
for (const token of tokens) {
const index = text.indexOf(token, last_index);
if (index === -1) {
throw new Error(`Token ${token} not found in text ${text}`);
}
fixed_tokens.push(text.substring(last_index, index + token.length));
last_index = index + token.length;
}
// Build tokens
const token_infos = [];
for (let i = 0; i < tokens.length; i++) {
const token = fixed_tokens[i];
// Build attention scores
const attention_scores = [];
if (is_input) {
for (let j = 0; j < attention_matrix.length; j++) {
attention_scores.push(
new TokenAttentionScore(j, attention_matrix[j][i]),
);
}
} else {
for (let j = 0; j < attention_matrix[i].length; j++) {
attention_scores.push(
new TokenAttentionScore(j, attention_matrix[i][j]),
);
}
}
// Build prob scores
let prob_scores = undefined;
if (scores !== undefined) {
prob_scores = [];
for (const [token, score] of scores[i]) {
prob_scores.push(new TokenProbScore(token, score));
}
}
// Add token
token_infos.push(new Token(token, attention_scores, prob_scores));
}
return token_infos;
}
}

0 comments on commit 32b50fb

Please sign in to comment.