LitGPT Python API draft (#1459)
Co-authored-by: awaelchli <[email protected]>
rasbt and awaelchli authored Jun 7, 2024
1 parent 67e9164 commit 0bb34ab
Showing 2 changed files with 109 additions and 0 deletions.
tutorials/developer-docs/README.md (1 addition):

LitGPT developer documentation files.

tutorials/developer-docs/python-api.md (108 additions):
# LitGPT High-level Python API

This is a work-in-progress draft for a high-level LitGPT Python API.

&nbsp;
## Model loading & saving

The `LLM.load` command loads an `llm` object, which contains both the model object (a PyTorch module) and a preprocessor.

```python
from litgpt import LLM

llm = LLM.load(
    source="url | local_path",
    # high-level users only need to care about this option:
    memory_reduction="none | medium | strong",
    # advanced options for technical users:
    hub="hf | local | other",
    quantize="bnb.nf4",
    precision="bf16-true",
    device="auto | cuda | cpu",
)
```

Here,

- `llm.model` contains the PyTorch module, and
- `llm.preprocessor.tokenizer` contains the tokenizer.
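The `memory_reduction`, `quantize`, and `precision` options trade accuracy for lower memory use. As a rough illustration (the byte counts per parameter are standard for these formats, but the exact mapping to LitGPT options is an assumption of this sketch), weight memory scales linearly with bytes per parameter:

```python
# Approximate bytes per parameter for common weight formats.
BYTES_PER_PARAM = {"fp32": 4.0, "bf16-true": 2.0, "bnb.nf4": 0.5}

def model_memory_gb(n_params: float, precision: str) -> float:
    """Rough weight-memory estimate in GiB for a given precision.

    Ignores activations, KV cache, and optimizer state, so actual
    usage at runtime will be higher.
    """
    return n_params * BYTES_PER_PARAM[precision] / 1024**3
```

For a 7B-parameter model, this gives roughly 26 GiB at fp32, 13 GiB at bf16, and 3.3 GiB with nf4 quantization, which is why quantized loading matters on consumer GPUs.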

The `llm.save` command saves the model weights, tokenizer, and configuration information.


```python
llm.save(checkpoint_dir, format="lightning | ollama | hf")
```


&nbsp;
## Inference / Chat

```python
response = llm.generate(
    prompt="What do Llamas eat?",
    temperature=0.1,
    top_p=0.8,
    ...
)
```
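The `top_p` argument enables nucleus sampling: only the smallest set of tokens whose cumulative probability reaches `top_p` is kept before sampling. As a conceptual illustration (not LitGPT's actual implementation), the filtering step can be sketched in plain Python:

```python
def top_p_filter(probs: dict[str, float], top_p: float) -> dict[str, float]:
    """Keep the smallest set of highest-probability tokens whose cumulative
    probability reaches top_p, then renormalize the kept probabilities."""
    total = 0.0
    kept = {}
    # Walk tokens from most to least probable until the mass threshold is hit.
    for token, p in sorted(probs.items(), key=lambda kv: -kv[1]):
        kept[token] = p
        total += p
        if total >= top_p:
            break
    return {token: p / total for token, p in kept.items()}
```

With `top_p=0.8` and token probabilities `{"grass": 0.6, "hay": 0.3, "rocks": 0.1}`, only `grass` and `hay` survive; low-probability tail tokens like `rocks` are never sampled.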


&nbsp;
## Dataset

The `llm.prepare_dataset` command prepares a dataset for training.

```python
llm.download_dataset(
    URL,
    ...
)
```

```python
dataset = llm.prepare_dataset(
    path,
    task="pretrain | instruction_finetune",
    test_portion=0.1,
    ...
)
```
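The `test_portion` argument holds out a fraction of samples for evaluation. A minimal sketch of what such a split could do (the shuffling, seeding, and rounding behavior here are assumptions, not LitGPT's specification):

```python
import random

def split_dataset(samples: list, test_portion: float, seed: int = 42):
    """Shuffle samples deterministically, then split off test_portion
    of them as a held-out test set. Returns (train, test)."""
    rng = random.Random(seed)
    shuffled = samples[:]          # copy so the caller's list is untouched
    rng.shuffle(shuffled)
    n_test = int(len(shuffled) * test_portion)
    return shuffled[n_test:], shuffled[:n_test]
```

With `test_portion=0.1` and 100 samples, this yields 90 training and 10 test samples, with no sample appearing in both splits.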

&nbsp;
## Training


```python
llm.instruction_finetune(
    config=None,
    dataset=dataset,
    max_iter=10,
    method="full | lora | adapter | adapter_v2"
)
```
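The `method` argument selects how many parameters are trained. With `method="lora"`, each adapted linear layer trains only two low-rank matrices A (d_in × r) and B (r × d_out) instead of the full weight matrix, which is where the memory savings come from. A back-of-the-envelope sketch (layer sizes are illustrative, not tied to any specific LitGPT model):

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters for a LoRA adapter on one linear layer:
    two low-rank factors, A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)

def full_finetune_params(d_in: int, d_out: int) -> int:
    """Trainable parameters when updating the full weight matrix."""
    return d_in * d_out
```

For a 4096 × 4096 projection with rank 8, LoRA trains 65,536 parameters versus 16,777,216 for full fine-tuning, a ~256x reduction on that layer.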

```python
llm.pretrain(config=None, dataset=dataset, max_iter=10, ...)
```

&nbsp;
## Serving


```python
llm.serve(port=8000)
```

Then in another Python session:

```python
import requests

response = requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "Fix typos in the following sentence: Exampel input"},
)

print(response.json()["output"])
```
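Per the request/response pair above, the `/predict` endpoint takes a JSON body with a `"prompt"` key and returns one with an `"output"` key. That contract can be sketched as a framework-agnostic handler (a hypothetical helper for illustration, not part of the proposed API):

```python
def predict_handler(payload: dict, generate) -> dict:
    """Sketch of the /predict endpoint contract: read the prompt from the
    JSON payload, run the model's generate function, and wrap the result
    in the {"output": ...} shape the client example expects."""
    return {"output": generate(payload["prompt"])}
```

Keeping the handler this thin makes it easy to plug into whatever HTTP framework `llm.serve` ends up using, and to unit-test the endpoint logic with a stub in place of the model.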
