Migration to Tensorflow2 #144

Open
wants to merge 6 commits into
base: master
Conversation

woj-i
Contributor

@woj-i woj-i commented Oct 5, 2020

Automatic and manual changes for migrating from TensorFlow 1 to TensorFlow 2.

Because there are no tests, I checked some notebooks from the examples:

  • very basic environment setup: works
  • setting up environment basic: works except for the last cell (a problem requesting data from the data server)
  • setting up environment full: works
  • guided a3c: runtime error

I need some help to finish this PR.
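Since the notebooks are the only check available, the manual run could be approximated with a small smoke test. The following is a minimal sketch, not part of this PR: it assumes the notebooks contain no IPython magics (`%...` or `!...` lines), and the helper names `load_notebook` and `first_failing_cell` are hypothetical. It executes the code cells in order and reports the first one that raises.

```python
import json


def load_notebook(path):
    """Read an .ipynb file, which is plain JSON on disk."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def first_failing_cell(notebook):
    """Run code cells in order; return (cell_index, error_text) for the first failure, else None."""
    namespace = {}  # shared between cells, like a notebook kernel
    for index, cell in enumerate(notebook.get("cells", [])):
        if cell.get("cell_type") != "code":
            continue
        source = cell.get("source", "")
        if isinstance(source, list):  # source may be a list of lines or one string
            source = "".join(source)
        try:
            exec(compile(source, f"<cell {index}>", "exec"), namespace)
        except Exception as exc:
            return index, f"{type(exc).__name__}: {exc}"
    return None


# Tiny in-memory notebook standing in for load_notebook("some_example.ipynb").
demo = {"cells": [
    {"cell_type": "markdown", "source": ["# title"]},
    {"cell_type": "code", "source": ["x = 1\n", "y = x + 1"]},
    {"cell_type": "code", "source": "raise RuntimeError('tf.device does not support functions.')"},
]}
print(first_failing_cell(demo))
```

For real notebooks with magics or rich output, running them through `jupyter nbconvert --execute` is probably the more robust route.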

@woj-i woj-i mentioned this pull request Oct 5, 2020
@Kismuz
Owner

Kismuz commented Oct 6, 2020

@woj-i, do I understand correctly that the issue is:
guided a3c - runtime error,
or are there other points as well?

@woj-i
Contributor Author

woj-i commented Oct 6, 2020

I think it's more than that, but there are no tests to point to the affected places. Currently I get this error (for the guided_a3c notebook):

  File "/home/wind/repositories/gitHub/btgym/btgym/algorithms/aac.py", line 408, in __init__
    with tf.device(tf.compat.v1.train.replica_device_setter(1, worker_device=self.worker_device)):
  File "/home/wind/repositories/gitHub/btgym/venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 5273, in device_v2
    raise RuntimeError("tf.device does not support functions.")
RuntimeError: tf.device does not support functions.

@cmal

cmal commented Oct 26, 2020

@woj-i Same error. Have you solved it?

@cmal

cmal commented Oct 26, 2020

Changing tf.device to tf.compat.v1.device will solve this.
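To find other call sites that may hit the same RuntimeError, a static scan can stand in for the missing tests. This is a rough sketch using only the standard library. The function names are hypothetical, and the heuristic is an assumption: it flags any `tf.device(...)` call whose first argument is itself a call (such as `replica_device_setter(...)`), which may also flag calls that merely return a device string.

```python
import ast


def dotted_name(node):
    """Return 'a.b.c' for a Name/Attribute chain, or None for anything else."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return None


def find_tf_device_with_call_argument(source):
    """Return sorted line numbers of tf.device(...) calls whose first argument is a call."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Call)
                and dotted_name(node.func) == "tf.device"
                and node.args
                and isinstance(node.args[0], ast.Call)):
            hits.append(node.lineno)
    return sorted(hits)


# Sample source: only the first `with` statement (line 2) should be flagged.
SAMPLE = '''
with tf.device(tf.compat.v1.train.replica_device_setter(1, worker_device=dev)):
    pass
with tf.device("/cpu:0"):
    pass
with tf.compat.v1.device(tf.compat.v1.train.replica_device_setter(1)):
    pass
'''
print(find_tf_device_with_call_argument(SAMPLE))  # prints [2]
```

Running this over each `.py` file under `btgym/` would give a list of lines to review by hand.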

@cmal

cmal commented Oct 26, 2020

I found that btgym/research/casual_conv/networks.py has the following code:

    alignments = attention_mechanism(
        query_state,
        attention_mechanism.initial_alignments(tf.shape(inputs)[0], dtype=tf.float32)
    )

LuongAttention changed significantly in TensorFlow 2.0, so we should update this code.
But I do not understand the logic here. Could anyone help?
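As a starting point for the logic, here is a plain-Python sketch of what Luong-style (multiplicative) attention computes. It is an illustration under assumptions, not the TensorFlow implementation: it assumes unscaled dot-product scoring, keys that are already projected to the query depth, and a softmax that ignores the previous alignments. If that last assumption holds, the `initial_alignments(...)` argument in the snippet above only supplies a correctly shaped placeholder state.

```python
import math


def luong_alignments(query, keys):
    """Softmax over dot-product scores between one query vector and each key vector."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(score - peak) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


# Three time steps of memory (keys) and one query vector, all of depth 2.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query = [2.0, 0.0]
weights = luong_alignments(query, keys)
print(weights)  # one weight per time step; the weights sum to 1
```

The result has one weight per time step of the memory, so for a batch the alignments would have shape [batch_size, time_steps], matching the shape that `initial_alignments` produces.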

@cmal

cmal commented Oct 26, 2020

Also, lstm_network in algorithms/nn/networks.py needs some updates: DropoutWrapper causes an error in TensorFlow 2, so the cell's dropout parameter should be used instead, according to https://stackoverflow.com/questions/62989175/layernormlstmcell-object-has-no-attribute-zero-state-in-tf-2-2 and tensorflow/tensorflow#29129

@cmal

cmal commented Oct 27, 2020

And this

Traceback (most recent call last):
  File "/Users/cmal/btgym/btgym/algorithms/aac.py", line 409, in __init__
    self.network = pi_global = self._make_policy('global')
  File "/Users/cmal/btgym/btgym/algorithms/aac.py", line 842, in _make_policy
    network = self.policy_class(**self.policy_kwargs)
  File "/Users/cmal/btgym/btgym/algorithms/policy/stacked_lstm.py", line 300, in __init__
    **kwargs,
  File "/Users/cmal/btgym/btgym/algorithms/nn/networks.py", line 148, in lstm_network
    lstm_init_state = lstm.zero_state(1, dtype=tf.float32)
  File "/Users/cmal/py3tf2.3/lib/python3.7/site-packages/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py", line 1273, in zero_state
    return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
  File "/Users/cmal/py3tf2.3/lib/python3.7/site-packages/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py", line 1273, in <genexpr>
    return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
  File "/Users/cmal/py3tf2.3/lib/python3.7/site-packages/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_wrapper_impl.py", line 203, in zero_state
    return self.cell.zero_state(batch_size, dtype)
AttributeError: 'LayerNormLSTMCell' object has no attribute 'zero_state'
[2020-10-27 04:19:46.238595] ERROR: Worker_0: Base class __init__() exception occurred.
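The traceback shows a legacy wrapper calling `zero_state` on a cell that does not define it. One possible shape for the call in lstm_network is a small adapter that uses whichever API the cell provides. This is a sketch with stand-in classes rather than real TensorFlow cells. It assumes Keras-style cells expose `get_initial_state(batch_size=..., dtype=...)`, and the name `initial_state` is hypothetical. Note that it only helps once the legacy wrappers are removed, because those wrappers call `cell.zero_state` internally.

```python
class LegacyStyleCell:
    """Stand-in for a TF1-style RNNCell, which exposes zero_state(batch_size, dtype)."""

    def zero_state(self, batch_size, dtype):
        return ("legacy", batch_size, dtype)


class KerasStyleCell:
    """Stand-in for a Keras-style cell, which exposes get_initial_state instead."""

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        return ("keras", batch_size, dtype)


def initial_state(cell, batch_size, dtype):
    """Return the cell's initial state through whichever API the cell provides."""
    if hasattr(cell, "zero_state"):  # TF1-style cell
        return cell.zero_state(batch_size, dtype)
    if hasattr(cell, "get_initial_state"):  # Keras-style cell (assumed signature)
        return cell.get_initial_state(batch_size=batch_size, dtype=dtype)
    raise TypeError(f"{type(cell).__name__} exposes no known initial-state method")


print(initial_state(LegacyStyleCell(), 1, "float32"))
print(initial_state(KerasStyleCell(), 1, "float32"))
```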

@woj-i
Contributor Author

woj-i commented Oct 29, 2020

Hey @cmal! Thank you for your help! You are very welcome to PR your changes to my tensorflow2 branch https://github.com/woj-i/btgym/tree/tensorflow2

@Kismuz
Owner

Kismuz commented Nov 2, 2020

It seems install_requires needs refactoring to resolve conflicts.

@Arno989

Arno989 commented Jan 4, 2021

Is there any progress on this upgrade PR? I want to use this environment for my bachelor thesis in the coming weeks, and perhaps I could help speed things up a bit, since I have so far only worked with TF 2.0.

@woj-i
Contributor Author

woj-i commented Jan 5, 2021

No progress from my side. @Arno989 you're welcome to contribute to this PR.

5 participants