models/gpt2/test_jax.py failed #49

Open
wangkuiyi opened this issue Jan 27, 2023 · 0 comments
Problem

Running the GPT-2 test with the following command fails for the 'iree' backend, while the 'cpu' backend passes:

python iree-jax/models/gpt2/test_jax.py

The full output, including the traceback, is attached at the end of this issue.

Reproduce

  1. Build the IREE compiler and runtime with Python bindings, following https://iree-org.github.io/iree/building-from-source/python-bindings-and-importers/#building-python-bindings
  2. Install the IREE compiler and runtime Python bindings and iree.jax from source, as described in "How to install iree-jax" #47 (comment).
  3. Install the dependencies of iree-jax/models/gpt2.
    conda install absl-py transformers h5py
    
  4. Run the test.
    python iree-jax/models/gpt2/test_jax.py
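
When reproducing, it may help to record which package versions are installed, in case the failure depends on the jax and jaxlib versions in the environment. The following is a minimal sketch using only the standard library. The distribution names in `PACKAGES` are assumptions: a source build of IREE may register its Python packages under different names, or not register them at all, in which case they are reported as not found.

```python
from importlib import metadata

# Assumed distribution names; adjust to match the actual environment.
PACKAGES = ["jax", "jaxlib", "absl-py", "transformers", "h5py",
            "iree-compiler", "iree-runtime"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not found'}")
```

Pasting this output next to the log makes it easier for others to compare their setup.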
    

0:49 $ python iree-jax/models/gpt2/test_jax.py
Running tests under Python 3.10.8: /Users/y/miniforge3/envs/iree-jax/bin/python
[ RUN      ] GPT2RealWeightsTest.test_batch_one0 ('cpu')
I0126 20:50:00.234236 8235580032 xla_bridge.py:170] Remote TPU is not linked into jax; skipping remote TPU.
I0126 20:50:00.234334 8235580032 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0126 20:50:00.234373 8235580032 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0126 20:50:00.234399 8235580032 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0126 20:50:00.234518 8235580032 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0126 20:50:00.234560 8235580032 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/nn/functions.py:376: DeprecationWarning: jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.
  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
[       OK ] GPT2RealWeightsTest.test_batch_one0 ('cpu')
[ RUN      ] GPT2RealWeightsTest.test_batch_one1 ('iree')
/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/nn/functions.py:376: DeprecationWarning: jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.
  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
I0126 20:50:07.604442 8235580032 binaries.py:182] Invoke IREE Tool: /Users/y/w/iree-ios/build/compiler/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvm-embedded-linker-path=/Users/y/w/iree-ios/build/compiler/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-triple=arm64-apple-darwin21.5.0
[  FAILED  ] GPT2RealWeightsTest.test_batch_one1 ('iree')
======================================================================
ERROR: test_batch_one1 ('iree') (__main__.GPT2RealWeightsTest)
GPT2RealWeightsTest.test_batch_one1 ('iree')
test_batch_one('iree')
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 72, in <module>
    absltest.main()
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2060, in main
    _run_in_app(run_tests, args, kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2165, in _run_in_app
    app.run(main=main_function)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2163, in main_function
    function(argv, args, kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2561, in run_tests
    result = _run_and_get_tests_result(
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2527, in _run_and_get_tests_result
    test_program = unittest.TestProgram(*args, **kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/_pretty_print_reporter.py", line 82, in run
    return super(TextTestRunner, self).run(test)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
    return test_method(self, *testcase_params)
  File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
    kv, x0 = encode(params, kv, prompt, 0, t)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/api.py", line 565, in cache_miss
    out_flat = call_bind_continuation(execute(*args_flat))
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2108, in __call__
    input_bufs = self.in_handler(args)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1888, in __call__
    return self.handler(input_buffers)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 413, in shard_args
    return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 413, in <listcomp>
    return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 392, in _shard_arg
    return shard_arg_handlers[type(arg)](arg, devices, arg_indices)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
    return [buf if buf.device() == d else buf.copy_to_device(d)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
    return [buf if buf.device() == d else buf.copy_to_device(d)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]

Invoked with: DeviceArray([[-0.11010301, -0.03926672,  0.03310751, ..., -0.1363697 ,
               0.01506208,  0.04531523],
             [ 0.04034033, -0.04861503,  0.04624869, ...,  0.08605453,
               0.00253983,  0.04318958],
             [-0.12746179,  0.04793796,  0.18410145, ...,  0.08991534,
              -0.12972379, -0.08785918],
             ...,
             [-0.04453601, -0.05483596,  0.01225674, ...,  0.10435229,
               0.09783269, -0.06952604],
             [ 0.1860082 ,  0.01665728,  0.04611587, ..., -0.09625227,
               0.07847701, -0.02245961],
             [ 0.05135201, -0.02768905,  0.0499369 , ...,  0.00704835,
               0.15519823,  0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126f37ee0>

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
    return test_method(self, *testcase_params)
  File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
    kv, x0 = encode(params, kv, prompt, 0, t)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
    return [buf if buf.device() == d else buf.copy_to_device(d)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
    return [buf if buf.device() == d else buf.copy_to_device(d)
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]

Invoked with: DeviceArray([[-0.11010301, -0.03926672,  0.03310751, ..., -0.1363697 ,
               0.01506208,  0.04531523],
             [ 0.04034033, -0.04861503,  0.04624869, ...,  0.08605453,
               0.00253983,  0.04318958],
             [-0.12746179,  0.04793796,  0.18410145, ...,  0.08991534,
              -0.12972379, -0.08785918],
             ...,
             [-0.04453601, -0.05483596,  0.01225674, ...,  0.10435229,
               0.09783269, -0.06952604],
             [ 0.1860082 ,  0.01665728,  0.04611587, ..., -0.09625227,
               0.07847701, -0.02245961],
             [ 0.05135201, -0.02768905,  0.0499369 , ...,  0.00704835,
               0.15519823,  0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126f37ee0>

----------------------------------------------------------------------
Ran 2 tests in 9.993s

FAILED (errors=1)
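
Reading the final TypeError: `copy_to_device` accepts only a `jaxlib.xla_extension.Device` as its argument, but it was invoked with a `jax._src.iree.IreeDevice`. The toy model below illustrates that kind of mismatch in plain Python. It is a sketch, not JAX code: `XlaDevice`, `IreeDevice`, `Buffer`, and `shard_arg` are hypothetical stand-ins, and the loop in `shard_arg` is only modeled on the expression shown at jax/_src/array.py line 645 in the traceback.

```python
class XlaDevice:
    """Hypothetical stand-in for jaxlib.xla_extension.Device."""


class IreeDevice:
    """Hypothetical stand-in for jax._src.iree.IreeDevice.

    It is deliberately not a subclass of XlaDevice.
    """


class Buffer:
    """Hypothetical stand-in for the buffer type named in the traceback."""

    def __init__(self, device):
        self._device = device

    def device(self):
        return self._device

    def copy_to_device(self, target):
        # Mimic the strict argument type check of a compiled binding.
        if not isinstance(target, XlaDevice):
            raise TypeError(
                "incompatible function arguments: expected XlaDevice, "
                f"got {type(target).__name__}")
        return Buffer(target)


def shard_arg(bufs, devices):
    # Keep a buffer that already lives on the target device, copy otherwise.
    return [buf if buf.device() == d else buf.copy_to_device(d)
            for buf, d in zip(bufs, devices)]


if __name__ == "__main__":
    cpu = XlaDevice()
    buf = Buffer(cpu)
    shard_arg([buf], [cpu])  # same device, so no copy is attempted
    try:
        shard_arg([buf], [IreeDevice()])
    except TypeError as exc:
        print(f"TypeError: {exc}")
```

In this model the copy is attempted only when the buffer is not already on the target device, which matches the log: the array is a `DeviceArray` and the requested device is an `IreeDevice`.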