@@ -448,27 +448,49 @@ def test_iter():
448
448
pytest .raises (TypeError , lambda : iter (ones ((3 , 3 ))))
449
449
450
450
@pytest .mark .parametrize ("api_version" , ['2021.12' , '2022.12' , '2023.12' ])
451
- def dlpack_2023_12 (api_version ):
451
+ def test_dlpack_2023_12 (api_version ):
452
452
if api_version == '2021.12' :
453
453
with pytest .warns (UserWarning ):
454
454
set_array_api_strict_flags (api_version = api_version )
455
455
else :
456
456
set_array_api_strict_flags (api_version = api_version )
457
457
458
458
a = asarray ([1 , 2 , 3 ], dtype = int8 )
459
-
460
- # Do not error
459
+ # Never an error
461
460
a .__dlpack__ ()
462
- a .__dlpack__ (dl_device = CPU_DEVICE )
463
- a .__dlpack__ (dl_device = None )
464
- a .__dlpack__ (max_version = (1 , 0 ))
465
- a .__dlpack__ (max_version = None )
466
- a .__dlpack__ (copy = False )
467
- a .__dlpack__ (copy = True )
468
- a .__dlpack__ (copy = None )
469
-
470
- x = np .from_dlpack (a )
471
- assert isinstance (x , np .ndarray )
472
- assert x .dtype == np .int8
473
- assert x .shape == (3 ,)
474
- assert np .all (x == np .asarray ([1 , 2 , 3 ]))
461
+
462
+ if api_version < '2023.12' :
463
+ pytest .raises (ValueError , lambda :
464
+ a .__dlpack__ (dl_device = a .__dlpack_device__ ()))
465
+ pytest .raises (ValueError , lambda :
466
+ a .__dlpack__ (dl_device = None ))
467
+ pytest .raises (ValueError , lambda :
468
+ a .__dlpack__ (max_version = (1 , 0 )))
469
+ pytest .raises (ValueError , lambda :
470
+ a .__dlpack__ (max_version = None ))
471
+ pytest .raises (ValueError , lambda :
472
+ a .__dlpack__ (copy = False ))
473
+ pytest .raises (ValueError , lambda :
474
+ a .__dlpack__ (copy = True ))
475
+ pytest .raises (ValueError , lambda :
476
+ a .__dlpack__ (copy = None ))
477
+ elif np .lib .NumpyVersion (np .__version__ ) < '2.1.0' :
478
+ pytest .raises (NotImplementedError , lambda :
479
+ a .__dlpack__ (dl_device = CPU_DEVICE ))
480
+ a .__dlpack__ (dl_device = None )
481
+ pytest .raises (NotImplementedError , lambda :
482
+ a .__dlpack__ (max_version = (1 , 0 )))
483
+ a .__dlpack__ (max_version = None )
484
+ pytest .raises (NotImplementedError , lambda :
485
+ a .__dlpack__ (copy = False ))
486
+ pytest .raises (NotImplementedError , lambda :
487
+ a .__dlpack__ (copy = True ))
488
+ a .__dlpack__ (copy = None )
489
+ else :
490
+ a .__dlpack__ (dl_device = a .__dlpack_device__ ())
491
+ a .__dlpack__ (dl_device = None )
492
+ a .__dlpack__ (max_version = (1 , 0 ))
493
+ a .__dlpack__ (max_version = None )
494
+ a .__dlpack__ (copy = False )
495
+ a .__dlpack__ (copy = True )
496
+ a .__dlpack__ (copy = None )
0 commit comments