diff --git a/trainers/train_densenet.py b/trainers/train_densenet.py index 41094ad..1c241d6 100644 --- a/trainers/train_densenet.py +++ b/trainers/train_densenet.py @@ -56,7 +56,7 @@ def train(epoch): #net = torch.nn.DataParallel(net, device_ids=[0, 1]) net = torch.nn.DataParallel(net) # resnet.load_state_dict(torch.load('../models/simplenet_v3.pth')) - train_data_set = train_jpg_loader(72, transform=Compose( + train_data_set = train_jpg_loader(32, transform=Compose( [ Scale(256), @@ -68,7 +68,7 @@ def train(epoch): Normalize(mean, std) ] )) - validation_data_set = validation_jpg_loader(64, transform=Compose( + validation_data_set = validation_jpg_loader(32, transform=Compose( [ Scale(224), ToTensor(), @@ -83,6 +83,7 @@ def train(epoch): training_loss = 0.0 for batch_index, (target_x, target_y) in enumerate(train_data_set): if torch.cuda.is_available(): + #print("[*] CUDA is available") target_x, target_y = target_x.cuda(), target_y.cuda() net.train() target_x, target_y = Variable(target_x), Variable(target_y)