[Dev] Add triton example (#21)
lshmouse authored Oct 8, 2024
1 parent 3f6442e commit a1d8ed1
Showing 4 changed files with 51 additions and 2 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -34,10 +34,10 @@ bazel test --config=cpplint //experimental/cpp_example/...
- [TODO] NCCL and examples

#### Autonomous System
- [TODO] rules_ros and examples: https://github.com/ApexAI/rules_ros
- [DONE] rules_ros and examples: https://github.com/ApexAI/rules_ros
- [TODO] nuscenes dataset and dataloader
- [TODO] rerun.io example
- [TODO] mcap example
- [DONE] mcap example

#### AI Models
- [TODO] NLP(BERT, GPT, etc.)
@@ -57,6 +57,7 @@ bazel test --config=cpplint //experimental/cpp_example/...
- [TODO] GPU mem test tools
- [TODO] nvcc network performance test tools
- [TODO] pytorch profiling tools
- [TODO] MLIR https://github.com/llvm/torch-mlir/blob/main/docs/development.md#bazel-build

### References
- Thanks to jiaming: https://github.com/storypku
10 changes: 10 additions & 0 deletions experimental/triton_example/BUILD
@@ -0,0 +1,10 @@
load("@pip//:requirements.bzl", "requirement")
load("@rules_python//python:defs.bzl", "py_binary")

py_binary(
    name = "triton_example",
    srcs = ["triton_example.py"],
    main = "triton_example.py",
    deps = [
        # triton_example.py imports triton and torch from the pip requirements
        requirement("torch"),
        requirement("triton"),
    ],
)
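Assuming the repository's existing Bazel and pip setup, the new target can be launched from the workspace root with a command like the following (a sketch, not part of this commit; a CUDA-capable GPU is needed at runtime):

```shell
# Hypothetical invocation of the new example target.
bazel run //experimental/triton_example:triton_example
```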
2 changes: 2 additions & 0 deletions experimental/triton_example/README.md
@@ -0,0 +1,2 @@
## Triton from OpenAI
https://openai.com/index/triton/
36 changes: 36 additions & 0 deletions experimental/triton_example/triton_example.py
@@ -0,0 +1,36 @@
import triton
import triton.language as tl

@triton.jit
def softmax(Y, stride_ym, stride_yn, X, stride_xm, stride_xn, M, N):
    # row index
    m = tl.program_id(0)
    # col indices: this specific kernel only works for matrices that
    # have fewer than BLOCK_SIZE columns
    BLOCK_SIZE: tl.constexpr = 1024
    n = tl.arange(0, BLOCK_SIZE)
    # the memory addresses of all the elements
    # that we want to load can be computed as follows
    X = X + m * stride_xm + n * stride_xn
    # load input data; pad out-of-bounds elements with -inf
    # so they contribute nothing after exponentiation
    x = tl.load(X, mask=n < N, other=-float('inf'))
    # compute numerically-stable softmax: subtract the row max
    # before exponentiating to avoid overflow
    z = x - tl.max(x, axis=0)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # write back to Y
    Y = Y + m * stride_ym + n * stride_yn
    tl.store(Y, y, mask=n < N)

import torch

# Allocate input/output tensors
X = torch.normal(0, 1, size=(583, 931), device='cuda')
Y = torch.empty_like(X)
# SPMD launch grid: one program instance per row
grid = (X.shape[0], )
# enqueue GPU kernel
softmax[grid](Y, Y.stride(0), Y.stride(1),
              X, X.stride(0), X.stride(1),
              X.shape[0], X.shape[1])
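Since the kernel above requires a CUDA device, the numerically-stable softmax it computes can be sanity-checked on CPU with plain NumPy. This is a sketch that is not part of this commit, and `softmax_rows` is a hypothetical helper name:

```python
import numpy as np

def softmax_rows(x):
    # Mirror the kernel's math: subtract each row's max before
    # exponentiating (numerical stability), then normalize.
    z = x - x.max(axis=1, keepdims=True)
    num = np.exp(z)
    return num / num.sum(axis=1, keepdims=True)

x = np.random.normal(0, 1, size=(583, 931))
y = softmax_rows(x)
assert y.shape == (583, 931)
assert np.allclose(y.sum(axis=1), 1.0)  # each row sums to 1
```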
