Skip to content

Commit

Permalink
Quick-fix to platform
Browse files Browse the repository at this point in the history
  • Loading branch information
ccuetom committed Sep 2, 2024
1 parent 6ce10d2 commit 5ad3421
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions stride/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async def loop(worker, shot_id):
if using_gpu:
deviceid = devices[worker.indices[1] % num_gpus]
if platform in ['nvidia-acc', 'nvidia-cuda']:
devito_args = _kwargs.get('devito_args', {})
devito_args = _kwargs.get('devito_args', {}).copy()
devito_args['deviceid'] = deviceid
_kwargs['devito_args'] = devito_args
elif platform == 'gpu':
Expand Down Expand Up @@ -313,7 +313,7 @@ async def loop(worker, shot_id):
if using_gpu:
deviceid = devices[worker.indices[1] % num_gpus]
if platform in ['nvidia-acc', 'nvidia-cuda']:
devito_args = _kwargs.get('devito_args', {})
devito_args = _kwargs.get('devito_args', {}).copy()
devito_args['deviceid'] = deviceid
_kwargs['devito_args'] = devito_args
elif platform == 'gpu':
Expand Down Expand Up @@ -399,8 +399,8 @@ async def loop(worker, shot_id):

if using_gpu:
deviceid = devices[worker.indices[1] % num_gpus]
if platform == 'nvidia-acc':
devito_args = _kwargs.get('devito_args', {})
if platform in ['nvidia-acc', 'nvidia-cuda']:
devito_args = _kwargs.get('devito_args', {}).copy()
devito_args['deviceid'] = deviceid
_kwargs['devito_args'] = devito_args
elif platform == 'gpu':
Expand Down

0 comments on commit 5ad3421

Please sign in to comment.