Skip to content

Commit

Permalink
docs: Add sum per language for task counts (#1468)
Browse files Browse the repository at this point in the history
* add sum per lang

* add sort by sum option

* make lint
  • Loading branch information
isaac-chung authored Nov 18, 2024
1 parent 8bb4a29 commit 2fb6fe7
Show file tree
Hide file tree
Showing 2 changed files with 1,066 additions and 1,061 deletions.
19 changes: 12 additions & 7 deletions docs/create_tasks_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def create_tasks_table(tasks: list[mteb.AbsTask]) -> str:
return table


def create_task_lang_table(tasks: list[mteb.AbsTask]) -> str:
def create_task_lang_table(tasks: list[mteb.AbsTask], sort_by_sum=False) -> str:
table_dict = {}
## Group by language. If it is a multilingual dataset, 1 is added to all languages present.
for task in tasks:
Expand All @@ -82,22 +82,27 @@ def create_task_lang_table(tasks: list[mteb.AbsTask]) -> str:
## Wrangle for polars
pl_table_dict = []
for lang, d in table_dict.items():
d.update({"lang": lang})
d.update({"0-lang": lang}) # for sorting columns
pl_table_dict.append(d)

df = pl.DataFrame(pl_table_dict).sort(by="lang")
df = pl.DataFrame(pl_table_dict).sort(by="0-lang")
df = df.with_columns(sum=pl.sum_horizontal(get_args(TASK_TYPE)))
df = df.select(sorted(df.columns))
if sort_by_sum:
df = df.sort(by="sum", descending=True)

total = df.sum()

task_names_md = " | ".join(sorted(get_args(TASK_TYPE)))
horizontal_line_md = "---|---" * len(sorted(get_args(TASK_TYPE)))
horizontal_line_md = "---|---" * (len(sorted(get_args(TASK_TYPE))) + 1)
table = f"""
| Language | {task_names_md} |
| Language | {task_names_md} | Sum |
|{horizontal_line_md}|
"""

for row in df.iter_rows():
table += f"| {row[-1]} "
for num in row[:-1]:
table += f"| {row[0]} "
for num in row[1:]:
table += f"| {num} "
table += "|\n"

Expand Down
Loading

0 comments on commit 2fb6fe7

Please sign in to comment.