Skip to content

Commit

Permalink
update training test file
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 10, 2024
1 parent 15d18d7 commit bc4a9c8
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,16 @@ def test_collate_fn():
inputs, targets = collate_fn(batch)
assert inputs.shape == (2, 3)
assert targets.shape == (2, 3)


def test_preprocess_data():
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
data = "This is a test."
sequences = preprocess_data(data, tokenizer, max_length=10, overlap=5)
assert len(sequences) > 0


def test_validate_targets():
targets = [[1, 2, 3], [4, 5, 6]]
vocab_size = 10
assert validate_targets(targets, vocab_size) == True

0 comments on commit bc4a9c8

Please sign in to comment.