Skip to content

Commit

Permalink
Merge pull request #12 from longtermrisk/bugfix/sft-trainer
Browse files: browse the repository at this point in the history
fix sft_trainer bug, don't buffer run logs for so long
  • Loading branch information
nielsrolf authored Feb 12, 2025
2 parents 73d38e5 + 92596bc commit 3b5acc1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
12 changes: 7 additions & 5 deletions openweights/worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,16 @@ def _execute_job(self, job):
stderr=subprocess.STDOUT,
cwd=tmp_dir,
env=env,
- preexec_fn=os.setsid # Allow us to send signals to the process group
+ preexec_fn=os.setsid, # Allow us to send signals to the process group
+ bufsize=1, # Line buffered
+ universal_newlines=True # Text mode
)

# Stream logs to both file and stdout
- for line in iter(self.current_process.stdout.readline, b''):
- decoded = line.decode().rstrip('\n')
- print(decoded)
- log_file.write(decoded + '\n')
+ for line in iter(self.current_process.stdout.readline, ''):
+ print(line.rstrip('\n'), flush=True) # Immediate stdout flush
+ log_file.write(line)
+ log_file.flush() # Force immediate write to file

self.current_process.wait()

Expand Down
2 changes: 1 addition & 1 deletion openweights/worker/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def apply_chat_template(examples):
instruction_part, response_part = get_instruct_response_part(tokenizer)
trainer_kwargs['data_collator'] = DataCollatorForSeq2Seq(tokenizer = tokenizer)
trainer = train_on_responses_only(
- SFTTrainer(**trainer),
+ SFTTrainer(**trainer_kwargs),
instruction_part=instruction_part,
response_part=response_part
)
Expand Down

0 comments on commit 3b5acc1

Please sign in to comment.