From 5ad342198fbab7ff2c62034c9e71b8c17820472e Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Mon, 2 Sep 2024 15:25:31 +0100 Subject: [PATCH] Quick-fix to platform --- stride/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stride/__init__.py b/stride/__init__.py index a500c79..8666b69 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -132,7 +132,7 @@ async def loop(worker, shot_id): if using_gpu: deviceid = devices[worker.indices[1] % num_gpus] if platform in ['nvidia-acc', 'nvidia-cuda']: - devito_args = _kwargs.get('devito_args', {}) + devito_args = _kwargs.get('devito_args', {}).copy() devito_args['deviceid'] = deviceid _kwargs['devito_args'] = devito_args elif platform == 'gpu': @@ -313,7 +313,7 @@ async def loop(worker, shot_id): if using_gpu: deviceid = devices[worker.indices[1] % num_gpus] if platform in ['nvidia-acc', 'nvidia-cuda']: - devito_args = _kwargs.get('devito_args', {}) + devito_args = _kwargs.get('devito_args', {}).copy() devito_args['deviceid'] = deviceid _kwargs['devito_args'] = devito_args elif platform == 'gpu': @@ -399,8 +399,8 @@ async def loop(worker, shot_id): if using_gpu: deviceid = devices[worker.indices[1] % num_gpus] - if platform == 'nvidia-acc': - devito_args = _kwargs.get('devito_args', {}) + if platform in ['nvidia-acc', 'nvidia-cuda']: + devito_args = _kwargs.get('devito_args', {}).copy() devito_args['deviceid'] = deviceid _kwargs['devito_args'] = devito_args elif platform == 'gpu':