Skip to content

Commit

Permalink
Added support for more torch.cuda APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
TejaX-Alaghari committed Nov 27, 2024
1 parent 6ae8d31 commit 7280da2
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 0 deletions.
10 changes: 10 additions & 0 deletions clang/test/dpct/python_migration/case_006/case_006_torch_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: rm -rf %T && mkdir -p %T
// RUN: cd %T
// RUN: cp %S/input.py ./input.py
// RUN: dpct -in-root ./ -out-root out ./input.py --migrate-build-script-only
// RUN: echo "begin" > %T/diff.txt
// RUN: diff --strip-trailing-cr %S/expected.py %T/out/input.py >> %T/diff.txt
// RUN: echo "end" >> %T/diff.txt

// CHECK: begin
// CHECK-NEXT: end
14 changes: 14 additions & 0 deletions clang/test/dpct/python_migration/case_006/expected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from torch import xpu

devs = torch.xpu.device_count()
devs = xpu.device_count()

d_cap = torch.xpu.get_device_capability()
d_cap = xpu.get_device_capability()
d0_cap = torch.xpu.get_device_capability(devs[0])
d0_cap = xpu.get_device_capability(devs[0])

arch_list = ['']
arch_list = ['']

cuda_ver = torch.version.xpu
14 changes: 14 additions & 0 deletions clang/test/dpct/python_migration/case_006/input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from torch import cuda

devs = torch.cuda.device_count()
devs = cuda.device_count()

d_cap = torch.cuda.get_device_capability()
d_cap = cuda.get_device_capability()
d0_cap = torch.cuda.get_device_capability(devs[0])
d0_cap = cuda.get_device_capability(devs[0])

arch_list = torch.cuda.get_arch_list()
arch_list = cuda.get_arch_list()

cuda_ver = torch.version.cuda
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@
In: CUDAExtension
Out: DPCPPExtension

- Rule: rule_from_torch_import_cuda
Kind: PythonRule
Priority: Fallback
MatchMode: Partial
PythonSyntax: from_torch_import_cuda
In: from torch import cuda
Out: from torch import xpu

- Rule: rule_cuda_is_available
Kind: PythonRule
Priority: Fallback
Expand Down Expand Up @@ -137,3 +145,40 @@
MatchMode: Full
In: ${arg}.cu
Out: ${arg}.cpp

- Rule: rule_cuda_device_count
Kind: PythonRule
Priority: Fallback
MatchMode: Partial
PythonSyntax: cuda_device_count
In: cuda.device_count
Out: xpu.device_count

- Rule: rule_cuda_get_device_capability
Kind: PythonRule
Priority: Fallback
MatchMode: Partial
PythonSyntax: cuda_get_device_capability
In: cuda.get_device_capability
Out: xpu.get_device_capability

- Rule: rule_cuda_get_arch_list
Kind: PythonRule
Priority: Fallback
MatchMode: Partial
PythonSyntax: cuda_get_arch_list
In: ${torch_prefix}cuda.get_arch_list()
Out: ${torch_prefix}['']
Subrules:
torch_prefix:
MatchMode: Full
In: torch.
Out: ""

- Rule: rule_version_cuda
Kind: PythonRule
Priority: Fallback
MatchMode: Partial
PythonSyntax: version_cuda
In: version.cuda
Out: version.xpu

0 comments on commit 7280da2

Please sign in to comment.