diff --git a/stride/__init__.py b/stride/__init__.py index fab679b..a500c79 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -131,7 +131,7 @@ async def loop(worker, shot_id): if using_gpu: deviceid = devices[worker.indices[1] % num_gpus] - if platform == 'nvidia-acc': + if platform in ['nvidia-acc', 'nvidia-cuda']: devito_args = _kwargs.get('devito_args', {}) devito_args['deviceid'] = deviceid _kwargs['devito_args'] = devito_args @@ -312,7 +312,7 @@ async def loop(worker, shot_id): if using_gpu: deviceid = devices[worker.indices[1] % num_gpus] - if platform == 'nvidia-acc': + if platform in ['nvidia-acc', 'nvidia-cuda']: devito_args = _kwargs.get('devito_args', {}) devito_args['deviceid'] = deviceid _kwargs['devito_args'] = devito_args