Skip to content

Commit

Permalink
Fix cutlass python library with cuda 12.6.2.post1 (#1942)
Browse files Browse the repository at this point in the history
* Fix `cutlass` python library with cuda `12.6.2.post1`

Previously we had this error:
```
  File "/storage/home/cutlass/python/cutlass/backend/operation.py", line 39, in <listcomp>
    _version_splits = [int(x) for x in __version__.split("rc")[0].split(".")]
                       ^^^^^^
ValueError: invalid literal for int() with base 10: 'post1'
```

* Update sm90_utils.py

* Update generator.py

* Update python/cutlass_library/generator.py

Co-authored-by: Jack Kosaian <[email protected]>

* Update python/cutlass_library/sm90_utils.py

Co-authored-by: Jack Kosaian <[email protected]>

---------

Co-authored-by: Jack Kosaian <[email protected]>
  • Loading branch information
danthe3rd and jackkosaian authored Nov 18, 2024
1 parent 8aa95db commit b0e09d7
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/cutlass/backend/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from cutlass.backend.utils.device import device_cc

_version_splits = [int(x) for x in __version__.split("rc")[0].split(".")]
_version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
_supports_cluster_launch = None


Expand Down
2 changes: 1 addition & 1 deletion python/cutlass_library/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):

# Update cuda_version based on parsed string
if semantic_ver_string != '':
for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]):
for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]):
if i < len(cuda_version):
cuda_version[i] = x
else:
Expand Down
2 changes: 1 addition & 1 deletion python/cutlass_library/sm90_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):

# Update cuda_version based on parsed string
if semantic_ver_string != '':
for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]):
for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]):
if i < len(cuda_version):
cuda_version[i] = x
else:
Expand Down

0 comments on commit b0e09d7

Please sign in to comment.