
Commit

Fix bug
[email protected] committed Jul 21, 2022
1 parent 481c2c8 commit edfb59f
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/baseline/trainer.py
@@ -198,14 +198,16 @@ def load_data(self) -> List[DataLoader]:
         self._test_steps = ceil(
             self._test_size / self.batch_size)  # total number of test steps
 
-        self.check_val_every_n_steps = ceil(
-            self.check_val_every_n_epoch * self._train_steps)  # validate every this many steps
 
         # For distributed training, divide the step counts by the total number of nodes
         self._train_steps = ceil(self._train_steps / self.ddp_nodes_num)
         self._dev_steps = ceil(self._dev_steps / self.ddp_nodes_num)
         self._test_steps = ceil(self._test_steps / self.ddp_nodes_num)
+
+        self.check_val_every_n_steps = ceil(
+            self.check_val_every_n_epoch * self._train_steps)  # validate every this many steps
+
         if self.check_val_every_n_steps < 10:
             self.check_val_every_n_steps = 10
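This is the bug the commit message refers to: `check_val_every_n_steps` was derived from `self._train_steps` before that count was divided by `ddp_nodes_num`, so under distributed training the validation interval was sized to the global step count while each node only runs its per-node share of steps. A minimal sketch of the effect, with hypothetical numbers that do not appear in the repo (10,000 samples, batch size 10, 4 nodes):

    from math import ceil

    # Hypothetical numbers for illustration only.
    train_steps = ceil(10000 / 10)        # 1000 global steps per epoch
    ddp_nodes_num = 4                     # each node runs 250 steps per epoch
    check_val_every_n_epoch = 0.5         # i.e. validate twice per epoch

    # Old ordering: interval computed from the global step count.
    interval_before = ceil(check_val_every_n_epoch * train_steps)   # 500

    # New ordering: divide by the node count first, as in this commit.
    train_steps = ceil(train_steps / ddp_nodes_num)                 # 250
    interval_after = ceil(check_val_every_n_epoch * train_steps)    # 125

    # Before the fix, a node running 250 steps per epoch validated every
    # 500 steps (once per two epochs); after it, every 125 steps.
    print(interval_before, interval_after)  # 500 125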

@@ -220,23 +222,23 @@ def load_data(self) -> List[DataLoader]:
         if self.ddp_local_rank != -1:
             # If using distributed training, wrap train_ds in a DistributedSampler
             train_ds = torch.utils.data.dataloader.DataLoader(
-                train_ds, sampler=DistributedSampler(train_ds), batch_size=self.batch_size, num_workers=4)
+                train_ds, sampler=DistributedSampler(train_ds), batch_size=self.batch_size, num_workers=8)
 
             dev_ds = torch.utils.data.dataloader.DataLoader(
-                dev_ds, batch_size=self.batch_size, shuffle=False, num_workers=4)
+                dev_ds, batch_size=self.batch_size, shuffle=False, num_workers=8)
 
             if test_ds is not None:
                 test_ds = torch.utils.data.dataloader.DataLoader(
-                    test_ds, batch_size=self.batch_size, shuffle=False, num_workers=4)
+                    test_ds, batch_size=self.batch_size, shuffle=False, num_workers=8)
 
         else:
             train_ds = torch.utils.data.dataloader.DataLoader(
-                train_ds, batch_size=self.batch_size, shuffle=True, num_workers=4)
+                train_ds, batch_size=self.batch_size, shuffle=True, num_workers=8)
             dev_ds = torch.utils.data.dataloader.DataLoader(
-                dev_ds, batch_size=self.batch_size, shuffle=False, num_workers=4)
+                dev_ds, batch_size=self.batch_size, shuffle=False, num_workers=8)
             if test_ds is not None:
                 test_ds = torch.utils.data.dataloader.DataLoader(
-                    test_ds, batch_size=self.batch_size, shuffle=False, num_workers=4)
+                    test_ds, batch_size=self.batch_size, shuffle=False, num_workers=8)
 
         return [train_ds, dev_ds, test_ds]
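The remaining changes in this hunk raise `num_workers` from 4 to 8 on every DataLoader, giving more background worker processes to prefetch batches. The commit hard-codes the value; a common alternative, not what this commit does, is to cap the worker count by the CPUs actually available to the process. A sketch, where `suggested_num_workers` is a hypothetical helper that does not exist in this repo:

    import os
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Hypothetical helper (not part of this repo): derive a worker count
    # from the CPUs available to this process instead of hard-coding it.
    def suggested_num_workers(max_workers: int = 8) -> int:
        if hasattr(os, "sched_getaffinity"):  # Linux: respects affinity/cgroup limits
            cpus = len(os.sched_getaffinity(0))
        else:
            cpus = os.cpu_count() or 1
        return min(max_workers, cpus)

    # Usage sketch on a dummy dataset.
    train_ds = TensorDataset(torch.randn(64, 3))
    loader = DataLoader(train_ds, batch_size=32, shuffle=True,
                        num_workers=suggested_num_workers())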

@@ -276,7 +278,7 @@ def load_suite(self):
         scheduler = get_linear_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=self._warmup_steps,
-            num_training_steps=self._train_steps
+            num_training_steps=len(self.train_ds)*self.epochs
         ) if self._warmup_steps != -1 else None
         return model, optimizer, scheduler
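This last change fixes the schedule length: judging by the computation in `load_data`, `self._train_steps` counts only a single epoch's (per-node) steps, but a linear warmup/decay schedule needs the total number of optimizer steps. Since `len()` of a DataLoader is its number of batches per epoch, `len(self.train_ds) * self.epochs` covers the whole run. A minimal sketch with hypothetical numbers (250 batches per epoch, 3 epochs, 50 warmup steps; the repo's real values are not shown here):

    import torch
    from transformers import get_linear_schedule_with_warmup

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    batches_per_epoch = 250                   # stands in for len(self.train_ds)
    epochs = 3                                # stands in for self.epochs
    total_steps = batches_per_epoch * epochs  # 750 optimizer steps overall

    scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=50,                  # stands in for self._warmup_steps
        num_training_steps=total_steps,
    )
    # Sized to one epoch (250 steps), as before the fix, the learning rate
    # would decay to zero after the first epoch and stay there afterwards.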
