Skip to content

Commit

Permalink
[algorithms.dreamver_v3] fix: 修正漏れ
Browse files Browse the repository at this point in the history
  • Loading branch information
pocokhc committed Jan 22, 2024
1 parent ba698d0 commit ddffda3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 39 deletions.
40 changes: 23 additions & 17 deletions srl/algorithms/dreamer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,18 @@
import tensorflow_probability as tfp
from tensorflow import keras

from srl.base.define import (DoneTypes, EnvObservationTypes, RLBaseTypes,
RLTypes)
from srl.base.define import DoneTypes, EnvObservationTypes, RLBaseTypes, RLTypes
from srl.base.exception import UndefinedError
from srl.base.rl.base import RLParameter, RLTrainer, RLWorker
from srl.base.rl.config import RLConfig
from srl.base.rl.processor import Processor
from srl.base.rl.registration import register
from srl.base.rl.worker_run import WorkerRun
from srl.rl.functions import common
from srl.rl.memories.experience_replay_buffer import (
ExperienceReplayBuffer, ExperienceReplayBufferConfig)
from srl.rl.models.tf.distributions.bernoulli_dist_block import \
BernoulliDistBlock
from srl.rl.models.tf.distributions.categorical_dist_block import (
CategoricalDistBlock, CategoricalGradDist)
from srl.rl.models.tf.distributions.categorical_gumbel_dist_block import \
CategoricalGumbelDistBlock
from srl.rl.memories.experience_replay_buffer import ExperienceReplayBuffer, ExperienceReplayBufferConfig
from srl.rl.models.tf.distributions.bernoulli_dist_block import BernoulliDistBlock
from srl.rl.models.tf.distributions.categorical_dist_block import CategoricalDistBlock, CategoricalGradDist
from srl.rl.models.tf.distributions.categorical_gumbel_dist_block import CategoricalGumbelDistBlock
from srl.rl.models.tf.distributions.linear_block import Linear, LinearBlock
from srl.rl.models.tf.distributions.normal_dist_block import NormalDistBlock
from srl.rl.models.tf.distributions.twohot_dist_block import TwoHotDistBlock
Expand Down Expand Up @@ -1432,7 +1427,7 @@ def _f(arr1, arr2, i):

return

# @tf.function
@tf.function
def _compute_horizon_step(self, stoch, deter, feat, is_critic: bool = False):
# featはアクション後の状態でQに近いイメージ
horizon_feat = [feat]
Expand Down Expand Up @@ -1495,13 +1490,8 @@ def _compute_horizon_step(self, stoch, deter, feat, is_critic: bool = False):
if self.config.actor_loss_type == "dreamer_v1":
# Vの最大化
actor_loss = -tf.reduce_mean(tf.reduce_sum(horizon_V[1:], axis=0))
elif self.config.actor_loss_type in ["dreamer_v2", "dreamer_v3"]:
elif self.config.actor_loss_type == "dreamer_v2":
adv = horizon_V[1:]
if self.config.actor_loss_type == "dreamer_v3":
# パーセンタイルの計算
d5 = tfp.stats.percentile(adv, 5)
d95 = tfp.stats.percentile(adv, 95)
adv = adv / tf.maximum(1.0, d95 - d5)

if self.config.action_type == RLTypes.DISCRETE:
# dynamics backprop 最大化
Expand All @@ -1524,6 +1514,22 @@ def _compute_horizon_step(self, stoch, deter, feat, is_critic: bool = False):
# entropyの最大化
entropy_loss = -self.config.entropy_rate * tf.reduce_mean(entropy)

actor_loss = act_v_loss + entropy_loss

elif self.config.actor_loss_type == "dreamer_v3":
adv = horizon_V[1:]

# パーセンタイルの計算
d5 = tfp.stats.percentile(adv, 5)
d95 = tfp.stats.percentile(adv, 95)
adv = adv / tf.maximum(1.0, d95 - d5)

# dynamics backprop 最大化
act_v_loss = -tf.reduce_mean(tf.reduce_sum(adv, axis=0))

# entropyの最大化
entropy_loss = -self.config.entropy_rate * tf.reduce_mean(entropy)

actor_loss = act_v_loss + entropy_loss
else:
raise UndefinedError(self.config.actor_loss_type)
Expand Down
49 changes: 27 additions & 22 deletions tests/algorithms_/test_dreamer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ def test_compute_V_discount():
).numpy()
print(horizon_V)
y = [
[[4 + discount * 6.339], [1 + discount * 1]],
[[3 + discount * 3.71], [1]],
[[2 + discount * 1.9], [0]],
[[1 + discount * 1], [0]],
[[4 + discount * 7.068], [1 + discount * 1]],
[[3 + discount * 4.52], [1]],
[[2 + discount * 2.8], [0]],
[[1 + discount * 2], [0]],
]
print(y)
assert horizon_V.shape == (4, 2, 1)
Expand All @@ -96,7 +96,7 @@ def test_compute_V_dreamer_v1(horizon_ewa_disclam):

horizon_reward, horizon_v, horizon_cont = _setup_compute_V()
horizon_V = _compute_V(
"dreamer_v1",
"ewa",
horizon_reward,
horizon_v,
horizon_cont,
Expand All @@ -107,10 +107,10 @@ def test_compute_V_dreamer_v1(horizon_ewa_disclam):
print(horizon_V, horizon_V.shape)

y = [
[[4 + discount * 6.339], [1 + discount * 1]],
[[3 + discount * 3.71], [1]],
[[2 + discount * 1.9], [0]],
[[1 + discount * 1], [0]],
[[4 + discount * 7.068], [1 + discount * 1]],
[[3 + discount * 4.52], [1]],
[[2 + discount * 2.8], [0]],
[[1 + discount * 2], [0]],
]
print(y)
y1 = []
Expand All @@ -136,7 +136,7 @@ def test_compute_V_dreamer_v2(horizon_h_target):

horizon_reward, horizon_v, horizon_cont = _setup_compute_V()
horizon_V = _compute_V(
"dreamer_v2",
"h-return",
horizon_reward,
horizon_v,
horizon_cont,
Expand All @@ -146,15 +146,20 @@ def test_compute_V_dreamer_v2(horizon_h_target):
).numpy()
print(horizon_V, horizon_V.shape)

y = [
[
[4 + discount * ((1 - horizon_h_target) * 1 + horizon_h_target * 6.02949)],
[1 + discount * ((1 - horizon_h_target) * 2 + horizon_h_target * 1)],
],
[[3 + discount * ((1 - horizon_h_target) * 1 + horizon_h_target * 3.629)], [1]],
[[2 + discount * ((1 - horizon_h_target) * 1 + horizon_h_target * 1.9)], [0]],
[[1 + discount * 1], [0]],
]
_a = 1 + discount * 2
_b = 0
y = [[[_a], [_b]]]
horizon_reward = horizon_reward.numpy()
horizon_v = horizon_v.numpy()
for i in reversed(range(3)):
_a = horizon_reward[i][0][0] + discount * ((1 - horizon_h_target) * horizon_v[i][0][0] + horizon_h_target * _a)
if i == 1:
_b = 1
elif i == 0:
_b = horizon_reward[i][1][0] + discount * (
(1 - horizon_h_target) * horizon_v[i][1][0] + horizon_h_target * _b
)
y.insert(0, [[_a], [_b]])
print(y)

assert horizon_V.shape == (4, 2, 1)
Expand All @@ -163,7 +168,7 @@ def test_compute_V_dreamer_v2(horizon_h_target):

@pytest.mark.parametrize("normalization_type", ["none", "layer"])
@pytest.mark.parametrize("resize_type", ["stride", "stride3", "max"])
@pytest.mark.parametrize("dist_type", ["mse", "normal"])
@pytest.mark.parametrize("dist_type", ["linear", "normal"])
def test_image_enc_dec(normalization_type, resize_type, dist_type):
from srl.algorithms.dreamer_v3 import _ImageDecoder, _ImageEncoder

Expand All @@ -186,6 +191,6 @@ def test_image_enc_dec(normalization_type, resize_type, dist_type):
# test_compute_V_simple()
# test_compute_V_discount()
# test_compute_V_dreamer_v1(0.1)
# test_compute_V_dreamer_v2(0.9)
test_compute_V_dreamer_v2(0.9)

test_image_enc_dec("none", "max", "mse")
# test_image_enc_dec("none", "max", "mse")

0 comments on commit ddffda3

Please sign in to comment.