diff --git a/tools/demo/mammoth_demo.py b/tools/demo/mammoth_demo.py index 4a9eda3f..4c1fbdec 100644 --- a/tools/demo/mammoth_demo.py +++ b/tools/demo/mammoth_demo.py @@ -1,10 +1,87 @@ from dataclasses import dataclass import requests -import streamlit as st +import streamlit as st # type: ignore + +st.set_page_config(layout="wide") MAMMOTH = '🦣' FAT_UNDER = '▁' +ARCHITECTURE_HTML = """ +

Decoder

+
+
+
en
+
fr
+
ru
+
+
+
defmod
+
pargen
+
texsim
+
translate
+
+
+
en
+
fr
+
ru
+
+
+

Encoder

+
+
+
fully shared
+
+
+ + +""" + + +def render(template, model_task): + task, lang = model_task.split('_') + if task == 'translate': + _, lang = lang.split('-') + template = template.replace('__TASK__', task) + template = template.replace('__LANG__', lang) + return template + + @dataclass class ModelSpecs: id: int @@ -13,32 +90,38 @@ class ModelSpecs: @staticmethod def format_model(model): - suffix = ' <<' if model.loaded else '' - return f'{model.task}{suffix}' + return model.task class Translator: def __call__(self): st.title(f'{MAMMOTH} MAMMOTH translation demo') - with st.form('Translation demo'): - model = st.selectbox( - 'Model', - st.session_state.models, - format_func=ModelSpecs.format_model, - ) - source = st.text_area( - 'Source text', - height=None, - ) - submitted = st.form_submit_button('▶️ Translate') - if submitted: - target_text = self.submit(source, model.id) - else: - target_text = '' - st.text_area( - 'Target text', - value=target_text, - height=None, + col1, col2 = st.columns([0.6, 0.4], gap="large") + with col1: + with st.form('Translation demo'): + model = st.selectbox( + 'Model', + st.session_state.models, + format_func=ModelSpecs.format_model, + ) + source = st.text_area( + 'Source text', + height=None, + ) + submitted = st.form_submit_button('▶️ Translate') + if submitted: + target_text = self.submit(source, model.id) + else: + target_text = '' + st.text_area( + 'Target text', + value=target_text, + height=None, + ) + with col2: + st.markdown( + render(ARCHITECTURE_HTML, model.task), + unsafe_allow_html=True, ) def submit(self, query, model):