Skip to content

Commit

Permalink
Adding experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
tatianacv committed Nov 30, 2023
1 parent db9f985 commit 7b6717b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ public void testTf3()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
testTf3("tf2_test_callbacks.py", "get_dataset");
testTf3("tf2_test_callbacks2.py", "replica_fn");
testTf3("tf2_test_callbacks3.py", "get_dataset");
}

public void testTf3(String file, String callback_function)
Expand Down
21 changes: 18 additions & 3 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -794,13 +794,16 @@

<package name="tensorflow/distribute">
<class name="MirroredStrategy" allocatable="true">
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="self model">
<method name="do" descriptor="()LRoot;" numArgs="3" paramNames="self function">
<new def="x" class="Ltensorflow/distribute/distribute_datasets_from_function/distribute_datasets_from_function" />
<putfield class="LRoot" field="distribute_datasets_from_function" fieldType="LRoot" ref="self" value="x" />
<return value="arg0" />
<new def="y" class="Ltensorflow/distribute/run/run" />
<putfield class="LRoot" field="run" fieldType="LRoot" ref="self" value="y" />
<new def="y" class="Ltensorflow/distribute/experimental_distribute_datasets_from_function/experimental_distribute_datasets_from_function" />
<putfield class="LRoot" field="experimental_distribute_datasets_from_function" fieldType="LRoot" ref="self" value="y" />
<return value="arg1" />
<new def="z" class="Ltensorflow/distribute/run/run" />
<putfield class="LRoot" field="run" fieldType="LRoot" ref="self" value="z" />
<return value="arg2" />
</method>
</class>
</package>
Expand All @@ -817,6 +820,18 @@
</class>
</package>

<package name="tensorflow/distribute/experimental_distribute_datasets_from_function">
<class name="experimental_distribute_datasets_from_function" allocatable="true">
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="self dataset_fn">
<new def="x" class="Lobject" />
<putfield class="LRoot" field="MirroredStrategy" fieldType="LRoot" ref="self" value="x" />
<putfield class="LRoot" field="$callback" fieldType="LRoot" ref="x" value="dataset_fn" />
<call class="LRoot" name="do" descriptor="()LRoot;" type="virtual" arg0="dataset_fn" arg1="1" numArgs="2" def="xx" />
<return value="xx" />
</method>
</class>
</package>

<package name="tensorflow/distribute/run">
<class name="run" allocatable="true">
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="self fn">
Expand Down
15 changes: 15 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_callbacks3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tensorflow as tf

# Testing API https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy#distribute_datasets_from_function
# Making sure that function `get_dataset` is in the CG

def get_dataset(input_context):
batch_size = input_context.get_per_replica_batch_size(2)
return tf.data.Dataset.range(4).batch(batch_size)

global_batch_size = 2

strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])

input_context = tf.distribute.InputContext()
dist_dataset = strategy.experimental_distribute_datasets_from_function(get_dataset)

0 comments on commit 7b6717b

Please sign in to comment.