diff --git a/tests/test_train.py b/tests/test_train.py index 9343488..da0fb01 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -32,3 +32,19 @@ def test_validate_targets(): targets = [[1, 2, 3], [4, 5, 6]] vocab_size = 10 assert validate_targets(targets, vocab_size) == True + + +def test_train(): + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + sequences = [[1, 2, 3], [4, 5, 6]] + targets = [[2, 3, 4], [5, 6, 7]] + dataset = CustomDataset(sequences, targets) + train_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn) + val_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn) + device = torch.device("cpu") + model = LongRoPEModel( + d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536 + ).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + criterion = torch.nn.CrossEntropyLoss() + train(model, train_loader, val_loader, optimizer, criterion, device, epochs=1)