Skip to content

Commit

Permalink
stub -> app
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Apr 24, 2024
1 parent a883c1f commit 6c816f6
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .common import stub
from .common import app
from .train import train, launch
from .inference import Inference
4 changes: 2 additions & 2 deletions src/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import PurePosixPath

from modal import Stub, Image, Volume
from modal import App, Image, Volume

APP_NAME = "example-axolotl"

Expand Down Expand Up @@ -32,7 +32,7 @@
"torch==2.1.2",
)

stub = Stub(APP_NAME)
app = App(APP_NAME)

# Volumes for pre-trained models and training runs.
pretrained_volume = Volume.from_name("example-pretrained-vol", create_if_missing=True)
Expand Down
6 changes: 3 additions & 3 deletions src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import modal
from fastapi.responses import StreamingResponse

from .common import stub, vllm_image, VOLUME_CONFIG
from .common import app, vllm_image, VOLUME_CONFIG

N_INFERENCE_GPU = 2

Expand All @@ -21,7 +21,7 @@ def get_model_path_from_run(path: Path) -> Path:
return path / yaml.safe_load(f.read())["output_dir"] / "merged"


@stub.cls(
@app.cls(
gpu=modal.gpu.H100(count=N_INFERENCE_GPU),
image=vllm_image,
volumes=VOLUME_CONFIG,
Expand Down Expand Up @@ -103,7 +103,7 @@ async def web(self, input: str):
return StreamingResponse(self._stream(input), media_type="text/event-stream")


@stub.local_entrypoint()
@app.local_entrypoint()
def inference_main(run_name: str = "", prompt: str = ""):
if prompt:
for chunk in Inference(run_name).completion.remote_gen(prompt):
Expand Down
10 changes: 5 additions & 5 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os

from .common import (
stub,
app,
axolotl_image,
VOLUME_CONFIG,
)
Expand Down Expand Up @@ -46,7 +46,7 @@ def run_cmd(cmd: str, run_folder: str):
VOLUME_CONFIG["/runs"].commit()


@stub.function(
@app.function(
image=axolotl_image,
gpu=GPU_CONFIG,
volumes=VOLUME_CONFIG,
Expand All @@ -70,7 +70,7 @@ def train(run_folder: str, output_dir: str):
return merge_handle


@stub.function(image=axolotl_image, volumes=VOLUME_CONFIG, timeout=3600 * 24)
@app.function(image=axolotl_image, volumes=VOLUME_CONFIG, timeout=3600 * 24)
def merge(run_folder: str, output_dir: str):
import shutil

Expand All @@ -86,7 +86,7 @@ def merge(run_folder: str, output_dir: str):
VOLUME_CONFIG["/runs"].commit()


@stub.function(image=axolotl_image, timeout=60 * 30, volumes=VOLUME_CONFIG)
@app.function(image=axolotl_image, timeout=60 * 30, volumes=VOLUME_CONFIG)
def launch(config_raw: str, data_raw: str):
from huggingface_hub import snapshot_download
import yaml
Expand Down Expand Up @@ -131,7 +131,7 @@ def launch(config_raw: str, data_raw: str):
return run_name, train_handle


@stub.local_entrypoint()
@app.local_entrypoint()
def main(
config: str,
data: str,
Expand Down

0 comments on commit 6c816f6

Please sign in to comment.