diff --git a/src/behavior_generation_lecture_python/mdp/policy.py b/src/behavior_generation_lecture_python/mdp/policy.py index 19b712e..d0d1f64 100644 --- a/src/behavior_generation_lecture_python/mdp/policy.py +++ b/src/behavior_generation_lecture_python/mdp/policy.py @@ -29,7 +29,7 @@ def __init__(self, sizes: List[int], actions: List): torch.manual_seed(1337) self.net = multi_layer_perceptron(sizes=sizes) self.actions = actions - self._actions_tensor = torch.as_tensor(actions, dtype=torch.float32).view( + self._actions_tensor = torch.tensor(actions, dtype=torch.long).view( len(actions), -1 )