Skip to content

Commit 63c265d

Browse files
committed
Fixes
1 parent 2a39b04 commit 63c265d

File tree

17 files changed

+149521
-42
lines changed

17 files changed

+149521
-42
lines changed

diploma_thesis/agents/utils/memory/memory.py

+7
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def is_filled(self):
5050
class NotReadyException(BaseException):
5151
pass
5252

53+
stores = 0
5354

5455
class Memory(Generic[_Configuration], metaclass=ABCMeta):
5556

@@ -63,6 +64,12 @@ def store(self, records: List[Record] | List[List[Record]]):
6364

6465
self.buffer.extend(records)
6566
else:
67+
global stores
68+
69+
stores += 1
70+
71+
print(f'Stores {stores} {len(records)}')
72+
6673
self.buffer.extend(records)
6774

6875
def sample(self, return_info: bool = False) -> List[Record]:

diploma_thesis/agents/utils/memory/replay_memory.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from torchrl.data import TensorDictReplayBuffer
1+
from torchrl.data import TensorDictReplayBuffer, SamplerWithoutReplacement
22

33
from .memory import *
44
from .memory import Configuration as MemoryConfiguration
@@ -23,7 +23,7 @@ def from_cli(cls, parameters: Dict):
2323
class ReplayMemory(Memory[Configuration]):
2424

2525
def __make_buffer__(self) -> ReplayBuffer | TensorDictReplayBuffer:
26-
sampler = self.configuration.sampler.make() if self.configuration.sampler else RandomSampler()
26+
sampler = self.configuration.sampler.make() if self.configuration.sampler else SamplerWithoutReplacement()
2727
cls = None
2828

2929
params = dict(

diploma_thesis/agents/utils/nn/layers/linear.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,16 @@ def __init__(self,
3838
self.__build__()
3939

4040
def initialize_parameters(self, input) -> None:
41-
self.linear.initialize_parameters(input)
42-
43-
if isinstance(self.linear, torch.nn.Linear):
44-
match self.initialization:
45-
case Initialization.orthogonal.value:
46-
torch.nn.init.orthogonal_(self.linear.weight)
47-
torch.nn.init.zeros_(self.linear.bias)
48-
case _:
49-
pass
41+
if self.linear.has_uninitialized_params():
42+
self.linear.initialize_parameters(input)
43+
44+
if isinstance(self.linear, torch.nn.Linear):
45+
match self.initialization:
46+
case Initialization.orthogonal.value:
47+
torch.nn.init.orthogonal_(self.linear.weight)
48+
torch.nn.init.zeros_(self.linear.bias)
49+
case _:
50+
pass
5051

5152
def forward(self, batch: torch.FloatTensor) -> torch.FloatTensor:
5253
batch = self.linear(batch)

diploma_thesis/agents/utils/rl/ddqn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ def estimate_q(self, model: Policy, batch: Record | tensordict.TensorDictBase):
1414

1515
target = self.__get_action_values__(self.target_model, batch.next_state, best_actions)
1616

17-
q = batch.reward.squeeze() + self.return_estimator.discount_factor * target * (1 - batch.done.squeeze().int())
17+
q = batch.reward.squeeze() + self.return_estimator.discount_factor * target #* (1 - batch.done.squeeze().int())
1818

1919
return q

0 commit comments

Comments
 (0)