$ python iree-jax/models/gpt2/test_jax.py
Running tests under Python 3.10.8: /Users/y/miniforge3/envs/iree-jax/bin/python
[ RUN ] GPT2RealWeightsTest.test_batch_one0 ('cpu')
I0126 20:50:00.234236 8235580032 xla_bridge.py:170] Remote TPU is not linked into jax; skipping remote TPU.
I0126 20:50:00.234334 8235580032 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0126 20:50:00.234373 8235580032 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0126 20:50:00.234399 8235580032 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0126 20:50:00.234518 8235580032 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0126 20:50:00.234560 8235580032 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/nn/functions.py:376: DeprecationWarning: jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.
warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
[ OK ] GPT2RealWeightsTest.test_batch_one0 ('cpu')
[ RUN ] GPT2RealWeightsTest.test_batch_one1 ('iree')
/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/nn/functions.py:376: DeprecationWarning: jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.
warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
I0126 20:50:07.604442 8235580032 binaries.py:182] Invoke IREE Tool: /Users/y/w/iree-ios/build/compiler/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvm-embedded-linker-path=/Users/y/w/iree-ios/build/compiler/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-triple=arm64-apple-darwin21.5.0
[ FAILED ] GPT2RealWeightsTest.test_batch_one1 ('iree')
======================================================================
ERROR: test_batch_one1 ('iree') (__main__.GPT2RealWeightsTest)
GPT2RealWeightsTest.test_batch_one1 ('iree')
test_batch_one('iree')
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 72, in <module>
absltest.main()
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2060, in main
_run_in_app(run_tests, args, kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2165, in _run_in_app
app.run(main=main_function)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2163, in main_function
function(argv, args, kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2561, in run_tests
result = _run_and_get_tests_result(
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2527, in _run_and_get_tests_result
test_program = unittest.TestProgram(*args, **kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/main.py", line 101, in __init__
self.runTests()
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/main.py", line 271, in runTests
self.result = testRunner.run(self.test)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/_pretty_print_reporter.py", line 82, in run
return super(TextTestRunner, self).run(test)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/runner.py", line 184, in run
test(result)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 650, in __call__
return self.run(*args, **kwds)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
return test_method(self, *testcase_params)
File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
kv, x0 = encode(params, kv, prompt, 0, t)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/api.py", line 565, in cache_miss
out_flat = call_bind_continuation(execute(*args_flat))
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2108, in __call__
input_bufs = self.in_handler(args)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1888, in __call__
return self.handler(input_buffers)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 413, in shard_args
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 413, in <listcomp>
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 392, in _shard_arg
return shard_arg_handlers[type(arg)](arg, devices, arg_indices)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
return [buf if buf.device() == d else buf.copy_to_device(d)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
return [buf if buf.device() == d else buf.copy_to_device(d)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: (): incompatible function arguments. The following argument types are supported:
1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]
Invoked with: DeviceArray([[-0.11010301, -0.03926672, 0.03310751, ..., -0.1363697 ,
0.01506208, 0.04531523],
[ 0.04034033, -0.04861503, 0.04624869, ..., 0.08605453,
0.00253983, 0.04318958],
[-0.12746179, 0.04793796, 0.18410145, ..., 0.08991534,
-0.12972379, -0.08785918],
...,
[-0.04453601, -0.05483596, 0.01225674, ..., 0.10435229,
0.09783269, -0.06952604],
[ 0.1860082 , 0.01665728, 0.04611587, ..., -0.09625227,
0.07847701, -0.02245961],
[ 0.05135201, -0.02768905, 0.0499369 , ..., 0.00704835,
0.15519823, 0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126f37ee0>
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
return test_method(self, *testcase_params)
File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
kv, x0 = encode(params, kv, prompt, 0, t)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
return [buf if buf.device() == d else buf.copy_to_device(d)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
return [buf if buf.device() == d else buf.copy_to_device(d)
TypeError: (): incompatible function arguments. The following argument types are supported:
1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]
Invoked with: DeviceArray([[-0.11010301, -0.03926672, 0.03310751, ..., -0.1363697 ,
0.01506208, 0.04531523],
[ 0.04034033, -0.04861503, 0.04624869, ..., 0.08605453,
0.00253983, 0.04318958],
[-0.12746179, 0.04793796, 0.18410145, ..., 0.08991534,
-0.12972379, -0.08785918],
...,
[-0.04453601, -0.05483596, 0.01225674, ..., 0.10435229,
0.09783269, -0.06952604],
[ 0.1860082 , 0.01665728, 0.04611587, ..., -0.09625227,
0.07847701, -0.02245961],
[ 0.05135201, -0.02768905, 0.0499369 , ..., 0.00704835,
0.15519823, 0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126f37ee0>
----------------------------------------------------------------------
Ran 2 tests in 9.993s
FAILED (errors=1)
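Problem
I tried to run the GPT-2 test with python iree-jax/models/gpt2/test_jax.py. The cpu case (test_batch_one0) passes, but the iree case (test_batch_one1) fails with "TypeError: (): incompatible function arguments" when JAX tries to copy an input DeviceArray to the IreeDevice (buf.copy_to_device(d) in jax/_src/array.py).
The full error log is attached.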
Reproduce
Install iree.jax from source code, following How to install iree-jax #47 (comment). Then run the test in iree-jax/models/gpt2:
$ python iree-jax/models/gpt2/test_jax.py