Skip to content

Commit

Permalink
Add model loading for other algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack17432 committed Sep 5, 2023
1 parent a1a1035 commit 88f697f
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions cares_reinforcement_learning/util/NetworkFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ def create_DQN(args):

network = Network(args["observation_size"], args["action_num"], args["lr"])

if "network_file_path" in args:
network.load_state_dict(torch.load(args["network_file_path"]))

agent = DQN(
network=network,
gamma=args["gamma"],
Expand All @@ -21,6 +24,9 @@ def create_DuelingDQN(args):

network = DuelingNetwork(args["observation_size"], args["action_num"], args["lr"])

if "network_file_path" in args:
network.load_state_dict(torch.load(args["network_file_path"]))

agent = DQN(
network=network,
gamma=args["gamma"],
Expand All @@ -35,6 +41,9 @@ def create_DDQN(args):

network = Network(args["observation_size"], args["action_num"], args["lr"])

if "network_file_path" in args:
network.load_state_dict(torch.load(args["network_file_path"]))

agent = DoubleDQN(
network=network,
gamma=args["gamma"],
Expand All @@ -52,6 +61,10 @@ def create_PPO(args):
actor = Actor(args["observation_size"], args["action_num"], args["actor_lr"])
critic = Critic(args["observation_size"], args["critic_lr"])

if "actor_file_path" in args and "critic_file_path" in args:
actor.load_state_dict(torch.load(args["actor_file_path"]))
critic.load_state_dict(torch.load(args["critic_file_path"]))

agent = PPO(
actor_network=actor,
critic_network=critic,
Expand All @@ -70,6 +83,10 @@ def create_SAC(args):
actor = Actor(args["observation_size"], args["action_num"], args["actor_lr"])
critic = Critic(args["observation_size"], args["action_num"], args["critic_lr"])

if "actor_file_path" in args and "critic_file_path" in args:
actor.load_state_dict(torch.load(args["actor_file_path"]))
critic.load_state_dict(torch.load(args["critic_file_path"]))

agent = SAC(
actor_network=actor,
critic_network=critic,
Expand All @@ -89,6 +106,10 @@ def create_DDPG(args):
actor = Actor(args["observation_size"], args["action_num"], args["actor_lr"])
critic = Critic(args["observation_size"], args["action_num"], args["critic_lr"])

if "actor_file_path" in args and "critic_file_path" in args:
actor.load_state_dict(torch.load(args["actor_file_path"]))
critic.load_state_dict(torch.load(args["critic_file_path"]))

agent = DDPG(
actor_network=actor,
critic_network=critic,
Expand Down

0 comments on commit 88f697f

Please sign in to comment.