#4 added hf support, still running on CPU
onadegibert committed Jul 5, 2024
1 parent c6ac4b9 commit 57e0036
Showing 5 changed files with 117 additions and 76 deletions.
84 changes: 47 additions & 37 deletions Snakefile
@@ -53,6 +53,8 @@ if backward_pretrained:
backward_dir = backward_pretrained
elif opusmt_backward:
do_train_backward = False
elif huggingface:
do_train_backward = False
else:
# don't evaluate pretrained model
results.extend(expand(f'{eval_backward_dir}/{{langpair}}/{{dataset}}.metrics',dataset=eval_datasets, langpair=langpairs))
@@ -351,11 +353,12 @@ rule merge_devset:
log: f"{log_dir}/merge_devset.log"
conda: "envs/base.yml"
threads: workflow.cores
input: expand(f"{original}/{{langpair}}/devset.{{lang}}.gz", langpair=langpairs, lang=['source.langtagged', 'target']),
input: expand(f"{original}/{{langpair}}/devset.{{lang}}.gz", langpair=langpairs, lang=['target']), #removed source.langtagged from here, to deal with huggingface strategy
bin=ancient(deduper)
output: src=f"{original}/devset.source.gz",trg=f"{original}/devset.target.gz"
params: prefix_input=f"{original}/*/devset", prefix_output=f"{original}/devset"
shell: '''cat $(echo {params.prefix_input}.source.langtagged.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.source.gz"
shell: '''[ ! -f {params.prefix_input}.source.langtagged.gz ] && cp {params.prefix_input}.source.gz {params.prefix_input}.source.langtagged.gz
cat $(echo {params.prefix_input}.source.langtagged.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.source.gz"
cat $(echo {params.prefix_input}.target.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.target.gz" '''

if do_train_backward:
@@ -454,19 +457,6 @@ if augment_corpus:
"{input.src1}" "{input.src2}" "{input.trg1}" "{input.trg2}" "{output.res_src}" "{output.res_trg}" "" \
>> {log} 2>&1'''


rule add_lang_tag_corpus_src:
message: "Adding language tag id for corpus translation"
log: f"{log_dir}/add_langid_corpus_{{langpair}}.log"
conda: "envs/base.yml"
threads: workflow.cores
input: f"{clean_corpus_prefix}.source.gz", model_dir=f"{final_teacher_dir}0-0/" # BEWARE: only works for one model per language pair
output: f"{clean_corpus_prefix}.source.langtagged.gz"
params: prefix=f"{clean_corpus_prefix}",
trg_three_letter=lambda wildcards: Language.get(wildcards.langpair.split('-')[1]).to_alpha3(),
suffix="source"
shell: '''bash pipeline/clean/add-lang-tag.sh "{params.trg_three_letter}" "{params.prefix}" "{o2m_teacher}" "{params.suffix}" "{input.model_dir}" >> {log} 2>&1'''

if do_train_backward:
rule add_lang_tag_corpus_backward:
message: "Adding language tag id for backward model training"
@@ -514,36 +504,56 @@ if do_train_backward:
params: prefix_input=f"{original}/*/devset", prefix_output=f"{original}/devset"
shell: '''cat $(echo {params.prefix_input}.target.langtagged.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.target.langtagged.gz" '''

rule add_lang_tag_devset:
message: "Adding language tag id for devset"
log: f"{log_dir}/add_langid_devset_{{langpair}}.log"
conda: "envs/base.yml"
threads: workflow.cores
input: f"{original}/{{langpair}}/devset.source.gz", model_dir=f"{final_teacher_dir}0-0/" # BEWARE: only works for one model per language pair
output: f"{original}/{{langpair}}/devset.source.langtagged.gz"
params: output_dir=f"{original}/{{langpair}}/", prefix=f"{original}/{{langpair}}/devset",
trg_three_letter=lambda wildcards: Language.get(wildcards.langpair.split('-')[1]).to_alpha3(),
suffix="source"
shell: '''bash pipeline/clean/add-lang-tag.sh "{params.trg_three_letter}" "{params.prefix}" "{o2m_teacher}" "{params.suffix}" "{input.model_dir}" >> {log} 2>&1'''

rule merge_corpus:
message: "Merging clean parallel datasets"
log: f"{log_dir}/merge_corpus.log"
conda: "envs/base.yml"
threads: workflow.cores
input: expand(f"{clean_corpus_prefix}.{{lang}}.gz", langpair=langpairs, lang=['source.langtagged', 'target']),
input: expand(f"{clean_corpus_src}",langpair=langpairs),
expand(f"{clean_corpus_trg}", langpair=langpairs),
bin=ancient(deduper)
output: src=f"{teacher_corpus}.source.gz",trg=f"{teacher_corpus}.target.gz"
params: prefix_input = f"{teacher_corpus}".replace('corpus', ''), prefix_output=f"{teacher_corpus}"
shell: '''cat $(echo {params.prefix_input}*/corpus.source.langtagged.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.source.gz"
cat $(echo {params.prefix_input}*/corpus.target.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.target.gz" '''
shell: '''
if ls {params.prefix_input}*/corpus.source.langtagged.gz 1> /dev/null 2>&1; then
cat $(echo {params.prefix_input}*/corpus.source.langtagged.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.source.gz"
else
cat $(echo {params.prefix_input}*/corpus.source.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.source.gz"
fi
cat $(echo {params.prefix_input}*/corpus.target.gz | tr ' ' '\n' | tr '\n' ' ') > "{params.prefix_output}.target.gz"
'''

# Three options for teacher: 1. download opus-mt model, 2. train teacher with pipeline, 3. path to pretrained teacher model
# TODO: make it possible to combine any of the above options, i.e. use opus-mt models, train new teachers,
# and use pretrained models all in the same run. There should probably be a model list where all the models
# to use are defined, with prefixes (opusmt_, train_, pretrained_, nllb_ etc.) determining how each model is
# created/used/connected to (e.g. in the case of external APIs).
if 'opusmt-teacher' in config['experiment']:

rule add_lang_tag_corpus_src:
message: "Adding language tag id for corpus translation"
log: f"{log_dir}/add_langid_corpus_{{langpair}}.log"
conda: "envs/base.yml"
threads: workflow.cores
input: f"{clean_corpus_prefix}.source.gz", model_dir=f"{final_teacher_dir}0-0/" # BEWARE: only works for one model per language pair
output: f"{clean_corpus_prefix}.source.langtagged.gz"
params: prefix=f"{clean_corpus_prefix}",
trg_three_letter=lambda wildcards: Language.get(wildcards.langpair.split('-')[1]).to_alpha3(),
suffix="source"
shell: '''bash pipeline/clean/add-lang-tag.sh "{params.trg_three_letter}" "{params.prefix}" "{o2m_teacher}" "{params.suffix}" "{input.model_dir}" >> {log} 2>&1'''

rule add_lang_tag_devset:
message: "Adding language tag id for devset"
log: f"{log_dir}/add_langid_devset_{{langpair}}.log"
conda: "envs/base.yml"
threads: workflow.cores
input: f"{original}/{{langpair}}/devset.source.gz", model_dir=f"{final_teacher_dir}0-0/" # BEWARE: only works for one model per language pair
output: f"{original}/{{langpair}}/devset.source.langtagged.gz"
params: output_dir=f"{original}/{{langpair}}/", prefix=f"{original}/{{langpair}}/devset",
trg_three_letter=lambda wildcards: Language.get(wildcards.langpair.split('-')[1]).to_alpha3(),
suffix="source"
shell: '''bash pipeline/clean/add-lang-tag.sh "{params.trg_three_letter}" "{params.prefix}" "{o2m_teacher}" "{params.suffix}" "{input.model_dir}" >> {log} 2>&1'''

if not isinstance(opusmt_teacher[0],dict):
rule download_teacher_models:
message: "Downloading OPUS-MT teacher model for {wildcards.langpair}"
@@ -623,7 +633,7 @@ checkpoint split_corpus:
log: f"{log_dir}/split_corpus_{{langpair}}.log"
conda: "envs/base.yml"
threads: 1
input: corpus_src=f"{clean_corpus_prefix}.source.langtagged.gz",corpus_trg=f"{clean_corpus_prefix}.target.gz"
input: corpus_src=clean_corpus_src, corpus_trg=clean_corpus_trg
output: output_dir=directory(f"{translated}/{{langpair}}/corpus"), file=f"{translated}/{{langpair}}/corpus/file.00"
shell: '''bash pipeline/translate/split-corpus.sh \
{input.corpus_src} {input.corpus_trg} {output.output_dir} {split_length} >> {log} 2>&1'''
@@ -685,8 +695,8 @@ else:
translated_mono_src_extension = ".out"
deseg_nbest_file = teacher_target_file

if hf_teacher:
# Configuration for the evaluation module
if huggingface:
# Configuration for the huggingface module
hf_config = {
"log_dir": log_dir,
"hf_teacher": hf_teacher,
@@ -822,7 +832,7 @@ rule merge_translated:
resources: mem_mb=64000
#group 'mono_src'
input:
src1=f"{clean_corpus_prefix}.source.langtagged.gz",
src1=clean_corpus_src,
src2=f"{clean}/{{langpair}}/mono.{src}.gz",
trg1=lambda wildcards: expand(f"{translated}/{{langpair}}/corpus.{{model_index}}.target.gz",model_index=model_indices, allow_missing=True),
trg2=lambda wildcards: expand(f"{translated}/{{langpair}}/mono.{{model_index}}.{trg}.gz",model_index=model_indices, allow_missing=True),
@@ -903,8 +913,8 @@ rule add_lang_tag_corpus_src_for_student:
log: f"{log_dir}/add_langid_corpus_{{langpair}}_student.log"
conda: "envs/base.yml"
threads: workflow.cores
input: expand(f"{filtered}/{{langpair}}/corpus.{{lang}}.gz", langpair=langpairs, lang=['source', 'target'])
output: f"{filtered}/{{langpair}}/corpus.source.langtagged.gz"
input: expand(f"{train_student_dir}/corpus.{{lang}}.gz", langpair=langpairs, lang=['source', 'target'])
output: f"{filtered}/{{langpair}}/corpus.source.langtagged.gz",f"{filtered}/{{langpair}}/corpus.target.gz"
params: prefix=f"{filtered}/{{langpair}}/corpus",
trg_three_letter=lambda wildcards: Language.get(wildcards.langpair.split('-')[1]).to_alpha3(),
suffix="source"
@@ -1202,4 +1212,4 @@ module evaluate:
snakefile: "rules/evaluate.smk"
config: eval_config

use rule * from evaluate as *
use rule * from evaluate as *
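Note on the merge_devset and merge_corpus changes above: when no *.source.langtagged.gz files exist (the Hugging Face path adds no language tags), the rules now fall back to the plain *.source.gz files. A minimal Python sketch of that fallback, with hypothetical paths, not code from this commit:

# Sketch of the fallback implemented in the merge_corpus shell block above.
import glob
import os
import shutil

def merge_sources(prefix_input: str, prefix_output: str) -> None:
    # Prefer language-tagged sources (OPUS-MT teachers); fall back to plain sources (Hugging Face teachers).
    tagged = sorted(glob.glob(f"{prefix_input}*/corpus.source.langtagged.gz"))
    sources = tagged or sorted(glob.glob(f"{prefix_input}*/corpus.source.gz"))
    os.makedirs(os.path.dirname(f"{prefix_output}.source.gz") or ".", exist_ok=True)
    with open(f"{prefix_output}.source.gz", "wb") as out:
        for path in sources:
            with open(path, "rb") as part:  # concatenated gzip members form a valid gzip stream
                shutil.copyfileobj(part, out)

merge_sources("data/clean/", "data/clean/corpus")  # hypothetical directory layout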
6 changes: 4 additions & 2 deletions configs/config.hf.yml
@@ -14,10 +14,12 @@ experiment:

parallel-max-sentences: 10000000
split-length: 1000000


one2many-student: True

best-model: perplexity
spm-sample-size: 1000000

huggingface:
model: "facebook/nllb-200-distilled-600M"
task: translation # if not in config, assumes "translation" by default
69 changes: 42 additions & 27 deletions pipeline/translate/translate_hf.py
@@ -1,12 +1,9 @@
import argparse
import os
from transformers import pipeline
import time
import torch

# Make sure we have a GPU
print(torch.cuda.is_available())
print(torch.cuda.device_count())
#print(torch.cuda.get_device_name(0))

def parse_args():
parser = argparse.ArgumentParser(description="Translate text using Hugging Face pipeline.")
parser.add_argument('filein', type=str, help='Input file name')
@@ -24,17 +21,19 @@ def main():
os.environ['HF_HOME'] = args.modeldir

print(f"Translating {args.filein} from {args.src} to {args.trg} with {args.modelname}...")

from transformers import pipeline # imported here because the cache directory (HF_HOME) must be set before transformers is imported


print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPUs available:", torch.cuda.device_count())

# Initialize the translation pipeline with cache_dir
pipe = pipeline(
task=args.task,
model=args.modelname,
num_beams=8,
num_return_sequences=8,
device_map="auto",
batch_size=32, max_length=150
max_length=150
)

if "nllb" in args.modelname:
@@ -49,15 +48,16 @@ def main():
else:
print(f"Source language found: {src_lang}")
print(f"Target language found: {trg_lang}")

pipe = pipeline(
task=args.task,
model=args.modelname,
num_beams=8,
num_return_sequences=8,
device_map="auto",
batch_size=32,
src_lang=src_lang,
tgt_lang=trg_lang
tgt_lang=trg_lang,
max_length=150
)

# Read the input text
@@ -66,23 +66,38 @@ def main():

if args.prompt:
# Modify the input text based on the prompt
with open(args.filein, 'r', encoding='utf-8') as infile:
text = [args.prompt.replace('<sourcetext>', line.strip()) for line in infile]
# Show an example of how the prompt is added to the input text
print(f"Added prompt like this:\n{text[0]}")
else:
with open(args.filein, 'r', encoding='utf-8') as infile:
text = infile.readlines()
text = [args.prompt.replace('<sourcetext>', line.strip()) for line in text]
# Show an example of how the prompt is added to the input text
print(f"Added prompt like this:\n{text[0]}")

# Prepare for batch processing
batch_size = 32

# Open the output file in append mode
with open(args.fileout, 'a', encoding='utf-8') as outfile:
start_time = time.time() # Start time
# Perform the translation with progress print statements
for i in range(0, len(text), batch_size):
batch = text[i:i+batch_size]
translated_batch = pipe(batch)

key = list(translated_batch[0][0].keys())[0] # Depending on the task, this may be either "translation_text" or "generated_text"

# Write each translated sentence to the output file incrementally, keeping a global sentence index
for j, sentence in enumerate(translated_batch):
for translation in sentence:
outfile.write(f"{i + j} ||| {translation[key]}\n")

# Perform the translation
translations = pipe(text)
key = list(translations[0][0].keys())[0] # Depending on the task, this may be either "translation_text" or "generated_text"
# Print progress every 50 sentences
if i % 50 == 0:
print(f"Translated {i} sentences...")
end_time = time.time() # End time
total_time = end_time - start_time
translations_per_second = len(text) / total_time if total_time > 0 else float('inf')

# Write the results to the output file
with open(args.fileout, 'w', encoding='utf-8') as outfile:
for i, sentence in enumerate(translations):
for translation in sentence:
outfile.write(f"{i} ||| {translation[key]}\n")
# Final progress print
print(f"Translation complete. Translating {len(text)} sentences took {total_time} seconds.")
print(f"{translations_per_second:.2f} translations/second")

if __name__ == "__main__":
main()
main()
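To make the rewritten translate_hf.py loop concrete, here is a small self-contained sketch of the prompt substitution via the <sourcetext> placeholder and of the "index ||| candidate" output layout. The dummy translate_batch function stands in for the Hugging Face pipeline call; the file name, prompt text and input lines are illustrative:

# Toy stand-in for pipe(batch): each sentence gets two candidates instead of eight.
def translate_batch(batch):
    return [[{"translation_text": f"{s} (candidate {k})"} for k in range(2)] for s in batch]

prompt = "Translate into German: <sourcetext>"  # illustrative prompt containing the placeholder
lines = ["Good morning.", "How are you?"]
text = [prompt.replace("<sourcetext>", line.strip()) for line in lines]

batch_size = 32
with open("corpus.nbest", "w", encoding="utf-8") as outfile:
    for i in range(0, len(text), batch_size):
        translated_batch = translate_batch(text[i:i + batch_size])
        key = list(translated_batch[0][0].keys())[0]     # "translation_text" or "generated_text"
        for j, sentence in enumerate(translated_batch):  # i + j keeps a global sentence index
            for cand in sentence:
                outfile.write(f"{i + j} ||| {cand[key]}\n")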
10 changes: 9 additions & 1 deletion rules/configuration.smk
@@ -217,6 +217,9 @@ else:
teacher_corpus = f'{clean}/corpus'

clean_corpus_src = f'{clean_corpus_prefix}.source.gz'
if opusmt_teacher:
clean_corpus_src = f'{clean_corpus_prefix}.source.langtagged.gz'

clean_corpus_trg = f'{clean_corpus_prefix}.target.gz'

# opustrainer
@@ -236,4 +239,9 @@ else:
if "huggingface" in config["experiment"]:
hf_teacher = config['experiment']['huggingface'].get('model')
hf_task = config['experiment']['huggingface'].get('task',"translation")
hf_prompt = config['experiment']['huggingface'].get('prompt',"")
hf_prompt = config['experiment']['huggingface'].get('prompt',"")
huggingface = True
train_student_dir = f"{merged}/{{langpair}}"
else:
huggingface = False
train_student_dir = f"{filtered}/{{langpair}}"
24 changes: 15 additions & 9 deletions rules/translate_hf.smk
@@ -7,15 +7,21 @@ rule translate_corpus_hf:
threads: config["gpus_num"] * 2
resources: gpu=config["gpus_num"]
input:
teacher=config["hf_teacher"],
file=config["teacher_source_file"],
model_dir=config["final_teacher_dir"],
task=config["task"]
file=config["teacher_source_file"]
output: file=config["teacher_target_file"]
params: src_three_letter=lambda wildcards: Language.get(wildcards.config["langpair"].split('-')[0]).to_alpha3(),
trg_three_letter=lambda wildcards: Language.get(wildcards.config["langpair"].split('-')[1]).to_alpha3(),
prompt=config["prompt"]
params: src_three_letter=lambda wildcards: Language.get(wildcards.langpair.split('-')[0]).to_alpha3(),
trg_three_letter=lambda wildcards: Language.get(wildcards.langpair.split('-')[1]).to_alpha3(),
prompt=config["prompt"],
model_dir=config["final_teacher_dir"],
teacher=config["hf_teacher"],
task=config["task"]
# Hacky way to deal with optional prompt
shell: '''
PROMPT_ARG=""
if [ ! -z "{params.prompt}" ]; then
PROMPT_ARG="--prompt '{params.prompt}'"
fi
python pipeline/translate/translate_hf.py \
"{input.file}" "{output.file}" "{input.teacher}" "{input.model_dir}" "{params.src_three_letter}" "{params.trg_three_letter}" "{input.task}" {params.prompt and f'--prompt {params.prompt}'} >> {log} 2>&1
'''
"{input.file}" "{output.file}" "{params.teacher}" "{params.model_dir}" "{params.src_three_letter}" "{params.trg_three_letter}" "{params.task}" $PROMPT_ARG >> {log} 2>&1
'''
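The PROMPT_ARG shell workaround above appends --prompt only when a prompt is configured. The same decision can be sketched in Python, following the argument order of the shell command above; all concrete values here are placeholders, not part of this commit:

# Sketch only: assemble the translate_hf.py call with an optional --prompt flag.
# Passing an argument list to subprocess sidesteps the shell-quoting issues the PROMPT_ARG hack works around.
import subprocess

def build_translate_cmd(filein, fileout, model, model_dir, src, trg, task, prompt=""):
    cmd = ["python", "pipeline/translate/translate_hf.py",
           filein, fileout, model, model_dir, src, trg, task]
    if prompt:  # only add the flag when a prompt is configured
        cmd += ["--prompt", prompt]
    return cmd

cmd = build_translate_cmd("corpus/file.00", "corpus/file.00.out",
                          "facebook/nllb-200-distilled-600M", "models/teacher-base0-0/",
                          "eng", "deu", "translation")
print(cmd)  # subprocess.run(cmd, check=True) would launch the translation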
