Skip to content

Commit

Permalink
Hacky architecture fig for demo
Browse files Browse the repository at this point in the history
  • Loading branch information
Waino committed Mar 11, 2024
1 parent 6260d4a commit 13d844b
Showing 1 changed file with 105 additions and 22 deletions.
127 changes: 105 additions & 22 deletions tools/demo/mammoth_demo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,87 @@
from dataclasses import dataclass
import requests
import streamlit as st
import streamlit as st # type: ignore

st.set_page_config(layout="wide")

MAMMOTH = '🦣'
FAT_UNDER = '▁'

ARCHITECTURE_HTML = """
<h3>Decoder</h3>
<div class="arch">
<div class="layer langspec">
<div class="compo compo-en">en</div>
<div class="compo compo-fr">fr</div>
<div class="compo compo-ru">ru</div>
</div>
<div class="layer taskspec">
<div class="compo compo-defmod">defmod</div>
<div class="compo compo-pargen">pargen</div>
<div class="compo compo-texsim">texsim</div>
<div class="compo compo-translate">translate</div>
</div>
<div class="layer langspec">
<div class="compo compo-en">en</div>
<div class="compo compo-fr">fr</div>
<div class="compo compo-ru">ru</div>
</div>
</div>
<h3>Encoder</h3>
<div class="arch">
<div class="layer full">
<div class="compo">fully shared</div>
</div>
</div>
<style>
.arch {
display: grid;
row-gap: 5px;
}
.layer {
display: grid;
grid-template-columns: repeat(3, 1fr);
column-gap: 5px;
}
.taskspec {
grid-template-columns: repeat(4, 1fr);
}
.langspec {
grid-template-columns: repeat(3, 1fr);
}
.full {
grid-template-columns: 1fr;
}
.compo {
display: grid;
border: 2px solid gray;
text-align: center;
}
.compo-__LANG__ {
border: 5px solid green;
font-weight: bold;
}
.compo-__TASK__ {
border: 5px solid red;
font-weight: bold;
}
.full div {
border: 5px solid blue;
}
</style>
"""


def render(template, model_task):
task, lang = model_task.split('_')
if task == 'translate':
_, lang = lang.split('-')
template = template.replace('__TASK__', task)
template = template.replace('__LANG__', lang)
return template


@dataclass
class ModelSpecs:
id: int
Expand All @@ -13,32 +90,38 @@ class ModelSpecs:

@staticmethod
def format_model(model):
suffix = ' <<' if model.loaded else ''
return f'{model.task}{suffix}'
return model.task


class Translator:
def __call__(self):
st.title(f'{MAMMOTH} MAMMOTH translation demo')
with st.form('Translation demo'):
model = st.selectbox(
'Model',
st.session_state.models,
format_func=ModelSpecs.format_model,
)
source = st.text_area(
'Source text',
height=None,
)
submitted = st.form_submit_button('▶️ Translate')
if submitted:
target_text = self.submit(source, model.id)
else:
target_text = ''
st.text_area(
'Target text',
value=target_text,
height=None,
col1, col2 = st.columns([0.6, 0.4], gap="large")
with col1:
with st.form('Translation demo'):
model = st.selectbox(
'Model',
st.session_state.models,
format_func=ModelSpecs.format_model,
)
source = st.text_area(
'Source text',
height=None,
)
submitted = st.form_submit_button('▶️ Translate')
if submitted:
target_text = self.submit(source, model.id)
else:
target_text = ''
st.text_area(
'Target text',
value=target_text,
height=None,
)
with col2:
st.markdown(
render(ARCHITECTURE_HTML, model.task),
unsafe_allow_html=True,
)

def submit(self, query, model):
Expand Down

0 comments on commit 13d844b

Please sign in to comment.