Add automatic downloading to CLI (#1571)
rasbt authored Jul 10, 2024
1 parent d49ce98 commit e71b938
Showing 12 changed files with 68 additions and 52 deletions.
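In effect, this commit removes the separate `litgpt download` step from the documented workflows: any command that takes a checkpoint directory now also accepts a plain model name and fetches the weights on first use. A before/after sketch, based on the README changes below:

```bash
# Before this commit: explicit download, then use
litgpt download microsoft/phi-2
litgpt chat microsoft/phi-2

# After this commit: the checkpoint is downloaded automatically on first use;
# a local checkpoint directory still works exactly as before
litgpt chat microsoft/phi-2
```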
22 changes: 7 additions & 15 deletions README.md
@@ -201,20 +201,17 @@ Finetuning is the process of taking a pretrained AI model and further training i
# 0) setup your dataset
curl -L https://huggingface.co/datasets/ksaw008/finance_alpaca/resolve/main/finance_alpaca.json -o my_custom_dataset.json

# 1) Download a pretrained model
litgpt download microsoft/phi-2

# 2) Finetune the model
# 1) Download and finetune the model
litgpt finetune microsoft/phi-2 \
--data JSON \
--data.json_path my_custom_dataset.json \
--data.val_split_fraction 0.1 \
--out_dir out/custom-model

# 3) Test the model
# 2) Test the model
litgpt chat out/custom-model/final

# 4) Deploy the model
# 3) Deploy the model
litgpt serve out/custom-model/final
```

@@ -238,7 +235,6 @@ Deploy a pretrained or finetuned LLM to use it in real-world applications. Deploy

```bash
# deploy an out-of-the-box LLM
litgpt download microsoft/phi-2
litgpt serve microsoft/phi-2

# deploy your own trained model
@@ -306,11 +302,10 @@ litgpt chat microsoft/phi-2
 

```bash
# 1) Download the LLM
# 1) List all supported LLMs
litgpt download list
litgpt download microsoft/phi-2

# 2) Test the model
# 2) Download and use the model
litgpt chat microsoft/phi-2

>> Prompt: What do Llamas eat?
@@ -393,10 +388,7 @@ mkdir -p custom_texts
curl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt
curl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt

# 1) Download a pretrained model
litgpt download EleutherAI/pythia-160m

# 2) Continue pretraining the model
# 1) Download and continue pretraining a model
litgpt pretrain EleutherAI/pythia-160m \
--tokenizer_dir EleutherAI/pythia-160m \
--initial_checkpoint_dir EleutherAI/pythia-160m \
@@ -405,7 +397,7 @@ litgpt pretrain EleutherAI/pythia-160m \
--train.max_tokens 10_000_000 \
--out_dir out/custom-model

# 3) Test the model
# 2) Test the model
litgpt chat out/custom-model/final
```

15 changes: 2 additions & 13 deletions litgpt/api.py
@@ -1,7 +1,6 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
#
# This file implements the LitGPT Python API
import os
from pathlib import Path
from typing import Any, List, Literal, Optional, Union

@@ -16,8 +15,8 @@
from litgpt.chat.base import generate as stream_generate_fn
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.utils import (
auto_download_checkpoint,
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
@@ -109,17 +108,7 @@ def load(
allowed_init = {"pretrained", "random"}

if init == "pretrained":
from litgpt.scripts.download import download_from_hub # Moved here due to the circular import issue in LitGPT that we need to solve some time

checkpoint_dir = extend_checkpoint_dir(Path(model))
try:
check_valid_checkpoint_dir(checkpoint_dir, verbose=False, raise_error=True)
except FileNotFoundError:
if not access_token:
access_token = os.getenv("HF_TOKEN")
download_from_hub(repo_id=model, access_token=access_token)

checkpoint_dir = Path("checkpoints") / model
checkpoint_dir = auto_download_checkpoint(model_name=model, access_token=access_token)
config = Config.from_file(checkpoint_dir / "model_config.yaml")

elif init == "random":
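The Python API picks up the same behavior: with `init="pretrained"`, `load` now delegates checkpoint resolution (and a one-time download) to `auto_download_checkpoint`. A minimal sketch, assuming the `LLM.load` entry point this module exposes:

```python
from litgpt import LLM

# Resolves "microsoft/phi-2" to a local checkpoint directory,
# downloading it into ./checkpoints/ first if it is not present.
llm = LLM.load("microsoft/phi-2")
```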
10 changes: 6 additions & 4 deletions litgpt/chat/base.py
@@ -18,9 +18,9 @@
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.utils import (
auto_download_checkpoint,
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint
)
@@ -176,11 +176,13 @@ def main(
precision: Optional[str] = None,
compile: bool = False,
multiline: bool = False,
access_token: Optional[str] = None,
) -> None:
"""Chat with a model.
Args:
checkpoint_dir: The checkpoint directory to load.
checkpoint_dir: A local path to a directory containing the model weights or a valid model name.
You can get a list of valid model names via the `litgpt download list` command.
top_k: The number of top most probable tokens to consider in the sampling process.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
@@ -205,8 +207,8 @@
precision: Indicates the Fabric precision setting to use.
compile: Whether to use compilation to speed up token generation. Will increase startup time.
multiline: Whether to support multiline input prompts.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
pprint(locals())

precision = precision or get_default_supported_precision(training=False)
@@ -229,7 +231,7 @@
print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
merge_lora(checkpoint_dir)

check_valid_checkpoint_dir(checkpoint_dir)
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
config = Config.from_file(checkpoint_dir / "model_config.yaml")

with fabric.init_module(empty_init=True):
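Since LitGPT generates its CLI from this signature, `checkpoint_dir` can now be a model name and the new `access_token` argument should surface as a flag. A hedged sketch (model names and the token placeholder are illustrative; `HF_TOKEN` from the environment is used as a fallback):

```bash
# Chat with a model by name; it is downloaded on first use.
litgpt chat microsoft/phi-2

# For gated repositories, pass a Hugging Face token explicitly:
litgpt chat meta-llama/Meta-Llama-3-8B-Instruct --access_token <hf_token>
```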
10 changes: 5 additions & 5 deletions litgpt/deploy/serve.py
@@ -16,7 +16,7 @@
from litgpt.chat.base import generate as stream_generate
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.utils import (
extend_checkpoint_dir,
auto_download_checkpoint,
get_default_supported_precision,
load_checkpoint
)
@@ -173,7 +173,8 @@ def run_server(
devices: int = 1,
accelerator: str = "auto",
port: int = 8000,
stream: bool = False
stream: bool = False,
access_token: Optional[str] = None,
) -> None:
"""Serve a LitGPT model using LitServe.
@@ -207,12 +208,11 @@ def run_server(
The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
port: The network port number on which the model is configured to be served.
stream: Whether to stream the responses.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())

check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

if not stream:
server = LitServer(
SimpleLitAPI(
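`litgpt serve` follows suit; a short sketch (the port value is the documented default):

```bash
# Serve a model by name; the checkpoint is fetched automatically if missing.
litgpt serve microsoft/phi-2 --port 8000
```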
6 changes: 4 additions & 2 deletions litgpt/eval/evaluate.py
@@ -8,7 +8,7 @@
import torch

from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint
from litgpt.utils import copy_config_files, extend_checkpoint_dir
from litgpt.utils import copy_config_files, auto_download_checkpoint


def prepare_results(results, save_filepath, print_results=True):
@@ -37,6 +37,7 @@ def convert_and_evaluate(
limit: Optional[float] = None,
seed: int = 1234,
save_filepath: Optional[Path] = None,
access_token: Optional[str] = None,
) -> None:
"""Evaluate a model with the LM Evaluation Harness.
@@ -55,6 +56,7 @@ def convert_and_evaluate(
seed: Random seed.
save_filepath: The file where the results will be saved.
Saves to `out_dir/results.json` by default.
access_token: Optional API token to access models with restrictions.
"""
if tasks is None:
from lm_eval.tasks import TaskManager
@@ -68,7 +70,7 @@
)
return

checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())

if not (isinstance(batch_size, int) and batch_size > 0) and not (isinstance(batch_size, str) and batch_size.startswith("auto")):
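Evaluation gains the same convenience; a sketch, with an illustrative LM Evaluation Harness task name:

```bash
# Downloads the checkpoint if needed, converts it, and runs the harness.
litgpt evaluate microsoft/phi-2 --tasks hellaswag --out_dir evaluate_model/
```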
6 changes: 4 additions & 2 deletions litgpt/finetune/adapter.py
@@ -24,12 +24,12 @@
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
auto_download_checkpoint,
CycleIterator,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
init_out_dir,
instantiate_torch_optimizer,
@@ -62,6 +62,7 @@ def setup(
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model using the Adapter method.
@@ -79,8 +80,9 @@ def setup(
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
6 changes: 4 additions & 2 deletions litgpt/finetune/adapter_v2.py
@@ -24,12 +24,12 @@
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
auto_download_checkpoint,
CycleIterator,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
init_out_dir,
instantiate_torch_optimizer,
@@ -62,6 +62,7 @@ def setup(
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model using the Adapter V2 method.
@@ -79,8 +80,9 @@ def setup(
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
6 changes: 4 additions & 2 deletions litgpt/finetune/full.py
@@ -20,12 +20,12 @@
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
auto_download_checkpoint,
CycleIterator,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
find_resume_path,
get_default_supported_precision,
load_checkpoint,
@@ -58,6 +58,7 @@ def setup(
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model.
@@ -77,8 +78,9 @@ def setup(
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
6 changes: 4 additions & 2 deletions litgpt/finetune/lora.py
@@ -25,12 +25,12 @@
from litgpt.scripts.merge_lora import merge_lora
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
auto_download_checkpoint,
CycleIterator,
check_valid_checkpoint_dir,
choose_logger,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
init_out_dir,
@@ -72,6 +72,7 @@ def setup(
optimizer: Union[str, Dict] = "AdamW",
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
access_token: Optional[str] = None,
) -> None:
"""Finetune a model using the LoRA method.
@@ -98,8 +99,9 @@ def setup(
optimizer: An optimizer name (such as "AdamW") or config.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
access_token: Optional API token to access models with restrictions.
"""
checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
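All four finetuning entry points (adapter, adapter_v2, full, lora) receive identical treatment: the checkpoint argument may now be a model name rather than a pre-downloaded path. A sketch using the default Alpaca data module these diffs show:

```bash
# The base checkpoint is downloaded automatically if absent.
litgpt finetune microsoft/phi-2 --data Alpaca --out_dir out/phi-2-finetuned
```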
20 changes: 20 additions & 0 deletions litgpt/utils.py
@@ -26,6 +26,7 @@
from torch.serialization import normalize_storage_type
from typing_extensions import Self


if TYPE_CHECKING:
from litgpt import GPT, Config

@@ -561,3 +562,22 @@ def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_71
"with more than 1B parameters on a CPU can be slow, it is recommended to switch to a GPU."
)
return size


def auto_download_checkpoint(model_name, access_token=None):
from litgpt.scripts.download import download_from_hub # moved here due to circular import issue

checkpoint_dir = extend_checkpoint_dir(Path(model_name))
try:
check_valid_checkpoint_dir(checkpoint_dir, verbose=False, raise_error=True)
except FileNotFoundError as e:
if access_token is None:
access_token = os.getenv("HF_TOKEN")

if checkpoint_dir.parts[0] != "checkpoints" and not checkpoint_dir.is_absolute():
download_from_hub(repo_id=str(model_name), access_token=access_token)
checkpoint_dir = Path("checkpoints") / checkpoint_dir
else:
raise e

return checkpoint_dir
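As the new helper reads: it returns the directory untouched when it already holds a valid checkpoint, downloads relative non-`checkpoints/` paths that look like model names, and re-raises otherwise. A hypothetical caller-side sketch (paths illustrative):

```python
from pathlib import Path

from litgpt.utils import auto_download_checkpoint

# Not present locally: downloaded from the Hub, then
# Path("checkpoints/microsoft/phi-2") is returned.
checkpoint_dir = auto_download_checkpoint(model_name="microsoft/phi-2")

# Already a valid local checkpoint: returned as-is, with no network access.
checkpoint_dir = auto_download_checkpoint(model_name=Path("out/custom-model/final"))
```

One subtlety visible in the code: when `access_token` is `None`, the `HF_TOKEN` environment variable is used, so gated downloads work without threading a token through every caller.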
1 change: 0 additions & 1 deletion tests/test_chat.py
@@ -2,7 +2,6 @@
import os
import re
import subprocess
import sys
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from itertools import repeat