Add mention of cuDNN in README
murrellb authored Nov 30, 2024
1 parent c05ba5d commit 687192c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion README.md
@@ -141,10 +141,14 @@ generate(model, prompt,

## CUDA GPU

In addition to `JSON3` and `Jjama3`, ensure `CUDA.jl` and `Flux.jl` are installed.

```julia
using CUDA, Flux, JSON3, Jjama3
```

You might also need to install `cuDNN.jl` and run `using cuDNN` on some systems.

For sampling, you can pass `device = gpu` to the `generate` function:

```julia
@@ -168,4 +172,4 @@ train_toks = encode(tkn, "This is a test.")
gpu_train_toks = gpu(train_toks)

forward_loss(model, gpu_train_toks[1:end-1,:], gpu_train_toks[2:end,:])
```
