diff --git a/tools/demo/mammoth_demo.py b/tools/demo/mammoth_demo.py
index 4a9eda3f..4c1fbdec 100644
--- a/tools/demo/mammoth_demo.py
+++ b/tools/demo/mammoth_demo.py
@@ -1,10 +1,87 @@
from dataclasses import dataclass
import requests
-import streamlit as st
+import streamlit as st # type: ignore
+
+st.set_page_config(layout="wide")
MAMMOTH = '🦣'
FAT_UNDER = '▁'
+ARCHITECTURE_HTML = """
+
Decoder
+
+
+
+
defmod
+
pargen
+
texsim
+
translate
+
+
+
+Encoder
+
+
+
+"""
+
+
+def render(template, model_task):
+ task, lang = model_task.split('_')
+ if task == 'translate':
+ _, lang = lang.split('-')
+ template = template.replace('__TASK__', task)
+ template = template.replace('__LANG__', lang)
+ return template
+
+
@dataclass
class ModelSpecs:
id: int
@@ -13,32 +90,38 @@ class ModelSpecs:
@staticmethod
def format_model(model):
- suffix = ' <<' if model.loaded else ''
- return f'{model.task}{suffix}'
+ return model.task
class Translator:
def __call__(self):
st.title(f'{MAMMOTH} MAMMOTH translation demo')
- with st.form('Translation demo'):
- model = st.selectbox(
- 'Model',
- st.session_state.models,
- format_func=ModelSpecs.format_model,
- )
- source = st.text_area(
- 'Source text',
- height=None,
- )
- submitted = st.form_submit_button('▶️ Translate')
- if submitted:
- target_text = self.submit(source, model.id)
- else:
- target_text = ''
- st.text_area(
- 'Target text',
- value=target_text,
- height=None,
+ col1, col2 = st.columns([0.6, 0.4], gap="large")
+ with col1:
+ with st.form('Translation demo'):
+ model = st.selectbox(
+ 'Model',
+ st.session_state.models,
+ format_func=ModelSpecs.format_model,
+ )
+ source = st.text_area(
+ 'Source text',
+ height=None,
+ )
+ submitted = st.form_submit_button('▶️ Translate')
+ if submitted:
+ target_text = self.submit(source, model.id)
+ else:
+ target_text = ''
+ st.text_area(
+ 'Target text',
+ value=target_text,
+ height=None,
+ )
+ with col2:
+ st.markdown(
+ render(ARCHITECTURE_HTML, model.task),
+ unsafe_allow_html=True,
)
def submit(self, query, model):