ValueError: Expected parameter loc (Tensor of shape (1, 2)) of distribution Normal(loc: torch.Size([1, 2]), scale: torch.Size([1, 2])) to satisfy the constraint Real(), but found inval
#12
Open
ghLcd9dG opened this issue
Nov 30, 2024
· 0 comments
Hi, thanks for your good work. I tried to run egs_trade/rl/a001_proto_sb3/main.py, but got the error below.
I am sure that the data was downloaded and handled properly.
/data///mconda/envs/csn/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py:150: UserWarning: You are trying to run PPO on the GPU, but it is primarily intended to run on the CPU when not using a CNN policy (you are using ActorCriticPolicy which should be a MlpPolicy). See https://github.com/DLR-RM/stable-baselines3/issues/1245 for more info. You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU.Note: The model will train, but the GPU utilization will be poor and the training might take longer than on CPU.
warnings.warn(
Traceback (most recent call last):
File "/home//code_nlpl/ai_quant_trade/egs_trade/rl/a001_proto_sb3/main.py", line 231, in <module>
main()
File "/home//code_nlpl/ai_quant_trade/egs_trade/rl/a001_proto_sb3/main.py", line 224, in main
test_a_stock_trade(train_path, test_path, test_stock_code, init_account_balance, out_path=p_out)
File "/home//code_nlpl/ai_quant_trade/egs_trade/rl/a001_proto_sb3/main.py", line 108, in test_a_stock_trade
mdl.train(stock_file)
File "/home//code_nlpl/ai_quant_trade/egs_trade/rl/a001_proto_sb3/main.py", line 56, in train
self._model.learn(total_timesteps=int(1e4))
File "/data///mconda/envs/csn/lib/python3.9/site-packages/stable_baselines3/ppo/ppo.py", line 311, in learn
return super().learn(
File "/data///mconda/envs/csn/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 323, in learn
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
File "/data///mconda/envs/csn/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 202, in collect_rollouts
actions, values, log_probs = self.policy(obs_tensor)
File "/data///mconda/envs/csn/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data///mconda/envs/csn/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/data///mconda/envs/csn/lib/python3.9/site-packages/stable_baselines3/common/policies.py", line 654, in forward
distribution = self._get_action_dist_from_latent(latent_pi)
File "/data///mconda/envs/csn/lib/python3.9/site-packages/stable_baselines3/common/policies.py", line 694, in _get_action_dist_from_latent
return self.action_dist.proba_distribution(mean_actions, self.log_std)
File "/data///mconda/envs/csn/lib/python3.9/site-packages/stable_baselines3/common/distributions.py", line 164, in proba_distribution
self.distribution = Normal(mean_actions, action_std)
File "/data///mconda/envs/csn/lib/python3.9/site-packages/torch/distributions/normal.py", line 56, in __init__
super().__init__(batch_shape, validate_args=validate_args)
File "/data///mconda/envs/csn/lib/python3.9/site-packages/torch/distributions/distribution.py", line 68, in __init__
raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (1, 2)) of distribution Normal(loc: torch.Size([1, 2]), scale: torch.Size([1, 2])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan]], device='cuda:0')
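The trace shows only that the policy's mean actions (`loc`) were already NaN when `Normal` was constructed; it does not show where the NaN came from. One common source, though not confirmed here, is a non-finite value in the observations fed to the policy, for example from a division by zero during feature scaling. A minimal sketch for checking that, assuming the observations can be dumped as rows of floats (the `find_non_finite` helper and the sample `obs` batch are hypothetical and not part of the repository):

```python
import math
from typing import Iterable, List, Sequence, Tuple


def find_non_finite(rows: Iterable[Sequence[float]]) -> List[Tuple[int, int, float]]:
    """Return (row, column, value) for every NaN or infinite entry."""
    bad = []
    for i, row in enumerate(rows):
        for j, value in enumerate(row):
            # math.isfinite is False for NaN, +inf and -inf.
            if not math.isfinite(value):
                bad.append((i, j, value))
    return bad


# Hypothetical observation batch; real values would come from the environment.
obs = [
    [0.12, 0.98, 1.00],
    [0.15, float("nan"), 1.00],
    [float("inf"), 0.97, 1.00],
]

for row, col, value in find_non_finite(obs):
    print(f"non-finite value {value} at row {row}, column {col}")
```

If the observations turn out to be clean, the NaN may instead come from the network weights after an unstable update. Stable-Baselines3 also ships a `VecCheckNan` wrapper in `stable_baselines3.common.vec_env` that can raise as soon as a NaN or inf appears in observations, rewards, or actions, which may help narrow this down.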