Skip to content

Commit

Permalink
commit change
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuck Tang committed Aug 9, 2024
1 parent dc79330 commit 32e1eed
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions docker/generate_build_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
import yaml

PRODUCTION_PYTHON_VERSION = '3.11'
PRODUCTION_PYTORCH_VERSION = '2.3.1'
PRODUCTION_PYTORCH_VERSION = '2.4.0'


def _get_torchvision_version(pytorch_version: str):
if pytorch_version == '2.4.0':
return '0.19.0'
if pytorch_version == '2.3.1':
return '0.18.1'
if pytorch_version == '2.2.2':
return '0.17.2'
if pytorch_version == '2.1.2':
return '0.16.2'
raise ValueError(f'Invalid pytorch_version: {pytorch_version}')


Expand All @@ -42,12 +42,12 @@ def _get_cuda_version(pytorch_version: str, use_cuda: bool):
# From https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/
if not use_cuda:
return ''
if pytorch_version == '2.4.0':
return '12.1.4'
if pytorch_version == '2.3.1':
return '12.1.1'
if pytorch_version == '2.2.2':
return '12.1.1'
if pytorch_version == '2.1.2':
return '12.1.1'
raise ValueError(f'Invalid pytorch_version: {pytorch_version}')


Expand Down Expand Up @@ -167,7 +167,7 @@ def _write_table(table_tag: str, table_contents: str):


def _main():
python_pytorch_versions = [('3.11', '2.3.1'), ('3.11', '2.2.2'), ('3.10', '2.1.2')]
python_pytorch_versions = [('3.11', '2.4.0'), ('3.11', '2.3.1'), ('3.11', '2.2.2')]
cuda_options = [True, False]
stages = ['pytorch_stage']
interconnects = ['mellanox', 'EFA'] # mellanox is default, EFA needed for AWS
Expand Down

0 comments on commit 32e1eed

Please sign in to comment.