# Forked from google/nvidia_libs_test (cuda_configure.bzl).
"""Build rule generator for locally installed CUDA toolkit and cuDNN SDK."""
def _get_env_var(repository_ctx, name, default):
    """Returns the value of environment variable `name`, or `default` if unset.

    A variable that is present but set to the empty string is returned as the
    empty string; only a missing variable falls back to `default`.

    Args:
      repository_ctx: The repository context; its `os.environ` dict is consulted.
      name: Name of the environment variable to look up.
      default: Value to return when `name` is not in the environment.

    Returns:
      The variable's value, or `default` when it is not set.
    """
    return repository_ctx.os.environ.get(name, default)
def _write_nvcc_wrapper(repository_ctx):
    """Writes `nvcc.sh`, a shell wrapper around the symlinked toolkit's nvcc.

    The wrapper invokes `cuda/bin/nvcc` inside this repository with
    `--compiler-options=-fPIC` and `--include-path=<repository root>`, adds
    `--compiler-bindir=$CC` when `CC` is set and non-empty at run time, and
    forwards all of its own arguments unchanged.

    Args:
      repository_ctx: The repository context passed to the rule implementation.
    """

    # The template starts directly with "#!" so the shebang is the file's first
    # line; otherwise the kernel ignores it.
    # "$@" forwards arguments without re-splitting ones that contain spaces.
    # ${CC:+"..."} is expanded inline (not via an intermediate variable) so a
    # CC path containing spaces stays a single word, and vanishes when CC is
    # unset or empty.
    repository_ctx.file("nvcc.sh", """#!/bin/bash
repo_path="%s"
exec "$repo_path/cuda/bin/nvcc" ${CC:+"--compiler-bindir=$CC"} --compiler-options=-fPIC --include-path="$repo_path" "$@"
""" % repository_ctx.path("."))

def _write_build_file(repository_ctx, cudnn_lib_dir):
    """Writes the repository's BUILD file declaring the CUDA/cuDNN targets.

    Args:
      repository_ctx: The repository context passed to the rule implementation.
      cudnn_lib_dir: Directory, relative to the `cudnn` symlink, that contains
        `libcudnn_static.a`.
    """

    # NOTE(review): `cuda_util` below depends on `:cuda_util_compile`, which is
    # not declared anywhere in this generated BUILD file, so building
    # `cuda_util` cannot succeed as written. Confirm whether the target should
    # be added here or `cuda_util` dropped.
    repository_ctx.file("BUILD", """
package(default_visibility = ["//visibility:public"])

sh_binary(
    name = "nvcc",
    srcs = ["nvcc.sh"],
)

# The *_headers cc_library rules below aren't cc_inc_library rules because
# dependent targets would only see the first one.

cc_library(
    name = "cuda_headers",
    hdrs = glob(
        include = ["cuda/include/**/*.h*"],
        exclude = ["cuda/include/cudnn.h"]
    ),
    # Allows including CUDA headers with angle brackets.
    includes = ["cuda/include"],
)

cc_library(
    name = "cuda",
    srcs = ["cuda/lib64/stubs/libcuda.so"],
    linkopts = ["-ldl"],
)

cc_library(
    name = "cuda_runtime",
    srcs = ["cuda/lib64/libcudart_static.a"],
    deps = [":cuda"],
    linkopts = ["-lrt"],
)

cc_library(
    name = "curand_static",
    srcs = [
        "cuda/lib64/libcurand_static.a",
    ],
    deps = [
        ":culibos",
    ],
)

cc_library(
    name = "cupti_headers",
    hdrs = glob(["cuda/extras/CUPTI/include/**/*.h"]),
    # Allows including CUPTI headers with angle brackets.
    includes = ["cuda/extras/CUPTI/include"],
)

cc_library(
    name = "cupti",
    srcs = glob(["cuda/extras/CUPTI/lib64/libcupti.so*"]),
)

cc_library(
    name = "cudnn",
    srcs = [
        "cudnn/%s/libcudnn_static.a",
        "cuda/lib64/libcublas_static.a",
    ] + glob(["cuda/lib64/libcublasLt_static.a"]),
    hdrs = ["cudnn/include/cudnn.h"],
    deps = [
        ":cuda",
        ":cuda_headers",
        ":culibos",
    ],
)

cc_library(
    name = "culibos",
    srcs = ["cuda/lib64/libculibos.a"],
)

cc_library(
    name = "cuda_util",
    deps = [":cuda_util_compile"],
)
""" % cudnn_lib_dir)

def _impl(repository_ctx):
    """Implementation of `cuda_configure`: materializes the local CUDA/cuDNN repo.

    Reads CUDA_PATH (default "/usr/local/cuda"), CUDNN_PATH (default: the
    CUDA path) and CUDNN_LIB_DIR (default "lib64") from the environment. It
    symlinks the two installations into the repository as `cuda` and `cudnn`,
    then writes the `nvcc.sh` wrapper and the BUILD file.

    Args:
      repository_ctx: The repository context supplied by Bazel.
    """
    cuda_path = _get_env_var(repository_ctx, "CUDA_PATH", "/usr/local/cuda")
    cudnn_path = _get_env_var(repository_ctx, "CUDNN_PATH", cuda_path)
    cudnn_lib_dir = _get_env_var(repository_ctx, "CUDNN_LIB_DIR", "lib64")

    print("Using CUDA from %s\n" % cuda_path)
    print("Using cuDNN from %s\n" % cudnn_path)

    # The generated BUILD file refers to files through these symlinks
    # (`cuda/...`, `cudnn/...`), relative to the repository root.
    repository_ctx.symlink(cuda_path, "cuda")
    repository_ctx.symlink(cudnn_path, "cudnn")

    _write_nvcc_wrapper(repository_ctx)
    _write_build_file(repository_ctx, cudnn_lib_dir)
# Repository rule exposing a locally installed CUDA toolkit and cuDNN SDK as an
# external repository with the targets written by `_impl`.
#
# Environment variables read by `_impl`:
#   CUDA_PATH      Root of the CUDA toolkit (default: /usr/local/cuda).
#   CUDNN_PATH     Root of the cuDNN SDK (default: the CUDA toolkit root).
#   CUDNN_LIB_DIR  Directory under the cuDNN root holding libcudnn_static.a
#                  (default: lib64).
# Every variable `_impl` reads must be listed in `environ` so that changing it
# re-runs the rule instead of leaving a stale repository behind.
cuda_configure = repository_rule(
    implementation = _impl,
    environ = ["CUDA_PATH", "CUDNN_PATH", "CUDNN_LIB_DIR"],
)