Skip to content

Commit 5c20494

Browse files
committed
Change training schedule
1 parent 7b4baf7 commit 5c20494

File tree

12 files changed

+52
-46
lines changed

12 files changed

+52
-46
lines changed

diploma_thesis/agents/utils/rl/dqn.py

+24-22
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,26 @@
1-
from dataclasses import dataclass
21
from typing import Dict
32

43
import tensordict
54
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
65

76
from agents.utils.memory import NotReadyException
8-
from agents.utils.rl.rl import *
97
from agents.utils.return_estimator import ValueFetchMethod
8+
from agents.utils.rl.rl import *
109

1110

1211
class DeepQTrainer(RLTrainer):
1312
@dataclass
1413
class Configuration:
1514
decay: float = 0.99
1615
update_steps: int = 10
16+
epochs: int = 1
1717

1818
@staticmethod
1919
def from_cli(parameters: Dict):
2020
return DeepQTrainer.Configuration(
2121
decay=parameters.get('decay', 0.99),
22-
update_steps=parameters.get('update_steps', 20)
22+
update_steps=parameters.get('update_steps', 20),
23+
epochs=parameters.get('epochs', 1)
2324
)
2425

2526
def __init__(self, configuration: Configuration, *args, **kwargs):
@@ -37,34 +38,35 @@ def configure(self, model: Policy):
3738
self._target_model = AveragedModel(model.clone(), avg_fn=avg_fn).to(self.device)
3839

3940
def __train__(self, model: Policy):
40-
try:
41-
batch, info = self.storage.sample(device=self.device)
42-
except NotReadyException:
43-
return
41+
for _ in range(self.configuration.epochs):
42+
try:
43+
batch, info = self.storage.sample(device=self.device)
44+
except NotReadyException:
45+
return
4446

45-
with torch.no_grad():
46-
q_values = self.estimate_q(model, batch)
47+
with torch.no_grad():
48+
q_values = self.estimate_q(model, batch)
4749

48-
def compute_loss():
49-
actions = self.__get_action_values__(model, batch.state, batch.action)
50+
def compute_loss():
51+
actions = self.__get_action_values__(model, batch.state, batch.action)
5052

51-
loss_ = self.loss(actions, q_values)
52-
td_error_ = torch.square(actions - q_values)
53+
loss_ = self.loss(actions, q_values)
54+
td_error_ = torch.square(actions - q_values)
5355

54-
entropy = torch.distributions.Categorical(logits=actions).entropy().mean()
56+
entropy = torch.distributions.Categorical(logits=actions).entropy().mean()
5557

56-
return loss_, (td_error_, entropy)
58+
return loss_, (td_error_, entropy)
5759

58-
loss, result = self.step(compute_loss, self.optimizer)
60+
loss, result = self.step(compute_loss, self.optimizer)
5961

60-
td_error, entropy = result
61-
td_error_mean = td_error.mean()
62+
td_error, entropy = result
63+
td_error_mean = td_error.mean()
6264

63-
self.record_loss(loss)
64-
self.record_loss(td_error_mean, key='td_error')
65-
self.record_loss(entropy, key='entropy')
65+
self.record_loss(loss)
66+
self.record_loss(td_error_mean, key='td_error')
67+
self.record_loss(entropy, key='entropy')
6668

67-
print(f'loss: {loss}, td_error: {td_error_mean}, entropy: {entropy}')
69+
print(f'loss: {loss}, td_error: {td_error_mean}, entropy: {entropy}')
6870

6971
with torch.no_grad():
7072
if self.optimizer.step_count % self.configuration.update_steps == 0:

diploma_thesis/configuration/experiments/jsp/GRAPH-NN/experiments/1 (DQN)/experiment.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ graph: &graph
1919
is_work_center_set_in_shop_floor_connected: False
2020

2121
default_mods: &default_mods
22-
- 'util/infrastructure/cuda.yml'
22+
# - 'util/infrastructure/cuda.yml'
2323
- 'util/train_schedule/on_store_64.yml'
2424

2525
###############################################################################################

diploma_thesis/configuration/experiments/jsp/GRAPH-NN/flexible_machine.yml

+2-1
Original file line number | Diff line number | Diff line change
@@ -54,12 +54,13 @@ parameters:
5454
parameters:
5555
decay: 1.0
5656
update_steps: 20
57+
epochs: 5
5758

5859
memory:
5960
kind: 'replay'
6061
parameters:
6162
size: 2048
62-
batch_size: 128
63+
batch_size: 256
6364
prefetch: 8
6465

6566
loss:

diploma_thesis/configuration/experiments/jsp/GRAPH-NN/flexible_marl_machine.yml

+2-1
Original file line number | Diff line number | Diff line change
@@ -55,12 +55,13 @@ parameters:
5555
parameters:
5656
decay: 1.0
5757
update_steps: 20
58+
epochs: 5
5859

5960
memory:
6061
kind: 'replay'
6162
parameters:
6263
size: 2048
63-
batch_size: 128
64+
batch_size: 256
6465
prefetch: 8
6566

6667
loss:

diploma_thesis/configuration/experiments/jsp/GRAPH-NN/machine.yml

+3-2
Original file line number | Diff line number | Diff line change
@@ -54,12 +54,13 @@ parameters:
5454
parameters:
5555
decay: 1.0
5656
update_steps: 20
57+
epochs: 5
5758

5859
memory:
5960
kind: 'replay'
6061
parameters:
61-
size: 1024
62-
batch_size: 128
62+
size: 2048
63+
batch_size: 256
6364
prefetch: 8
6465

6566
loss:

diploma_thesis/configuration/experiments/jsp/GRAPH-NN/marl_machine.yml

+2-1
Original file line number | Diff line number | Diff line change
@@ -55,12 +55,13 @@ parameters:
5555
parameters:
5656
decay: 1.0
5757
update_steps: 20
58+
epochs: 5
5859

5960
memory:
6061
kind: 'replay'
6162
parameters:
6263
size: 2048
63-
batch_size: 128
64+
batch_size: 256
6465
prefetch: 8
6566

6667
loss:

diploma_thesis/configuration/experiments/jsp/MARL-DQN/experiment/2/experiment.yml

+2-2
Original file line number | Diff line number | Diff line change
@@ -125,11 +125,11 @@ parallel_multi_source_run: &parallel_multi_source_run
125125
parameters:
126126
mods:
127127
__inout_factory__:
128-
- [ ['utilization/90.yml'] ]
128+
- [ ['utilization/80.yml'] ]
129129
nested:
130130
parameters:
131131
dispatch:
132-
seed: [ [ 0, 1, 2 ] ]
132+
seed: [ [ 0, 1, 2, 3 ] ]
133133

134134

135135
###############################################################################################

diploma_thesis/configuration/experiments/jsp/MARL-DQN/experiment/2/parallel_run.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -20,4 +20,4 @@ parameters:
2020
train_interval: 100
2121
max_training_steps: 0
2222

23-
n_workers: 3
23+
n_workers: 4

diploma_thesis/configuration/experiments/jsp/MARL-DQN/machine.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ parameters:
6868

6969
optimizer:
7070
model:
71-
kind: 'sgd'
71+
kind: 'adam_w'
7272
parameters:
7373
lr: 0.001
7474

diploma_thesis/configuration/experiments/jsp/MARL-DQN/marl_machine.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ parameters:
6868

6969
optimizer:
7070
model:
71-
kind: 'sgd'
71+
kind: 'adam_w'
7272
parameters:
7373
lr: 0.001
7474

Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11

22
rules:
33
- 'spt'
4-
- 'lwkr'
4+
- 'cr'
55
- 'ms'
66
- 'winq'

diploma_thesis/configuration/experiments/jsp/tournament.yml

+12-12
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ task:
44
n_workers: 8
55
n_threads: 8
66
name: 'rules'
7-
output_dir: 'results/jsp/experiments/tournaments/jsp/1. MARL-DQN'
7+
output_dir: 'results/jsp/experiments/tournaments/jsp/Static Rules'
88
store_run_statistics: True
99
log_run: False
1010
update: True
@@ -26,12 +26,12 @@ task:
2626
# prefix: 'preferred'
2727
# path: 'results/jsp/experiments/1_4/preferred'
2828

29-
- kind: 'persisted_agents'
30-
parameters:
31-
prefix: ''
32-
path: 'results/jsp/experiments/1. MARL-DQN'
33-
depth: 5
34-
#
29+
# - kind: 'persisted_agents'
30+
# parameters:
31+
# prefix: ''
32+
# path: 'results/jsp/experiments/1. MARL-DQN'
33+
# depth: 5
34+
##
3535
# - kind: 'persisted_agents'
3636
# parameters:
3737
# prefix: 'flexible_with_graph'
@@ -97,11 +97,11 @@ task:
9797
# prefix: 'marl_all'
9898
# path: 'results/jsp/experiments/1_2/4'
9999
#
100-
# - kind: 'static'
101-
# parameters:
102-
# scheduling: 'all'
103-
# routing:
104-
# - 'ct'
100+
- kind: 'static'
101+
parameters:
102+
scheduling: 'all'
103+
routing:
104+
- 'ct'
105105

106106

107107
# - kind: 'persisted_agents'

0 commit comments

Comments
 (0)