diff --git a/solver.py b/solver.py index 7d65766..30ab53d 100644 --- a/solver.py +++ b/solver.py @@ -415,7 +415,7 @@ def viz_traverse(self, limit=3, inter=2/3, loc=-1): for i, key in enumerate(Z.keys()): for j, val in enumerate(interpolation): save_image(tensor=gifs[i][j].cpu(), - filename=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)), + fp=os.path.join(output_dir, '{}_{}.jpg'.format(key, j)), nrow=self.z_dim, pad_value=1) grid2gif(os.path.join(output_dir, key+'*.jpg'),