Commit 1d72d51

rawsh authored Sep 14, 2024
1 parent b351d4a
Showing 4 changed files with 97 additions and 41 deletions.
7 changes: 7 additions & 0 deletions Dockerfile
@@ -0,0 +1,7 @@
FROM vespaengine/vespa:latest

RUN echo '#!/usr/bin/env bash' > /entry.sh
RUN echo 'exec "$@"' >> /entry.sh
RUN chmod +x /entry.sh
# ENTRYPOINT /entry.sh
ENTRYPOINT tail -f /dev/null
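
Note: the stock vespaengine/vespa image starts Vespa via its own entrypoint, so overriding ENTRYPOINT with `tail -f /dev/null` keeps the container alive but idle and lets Modal decide what runs inside it. The same effect could plausibly be achieved without a separate Dockerfile; the sketch below assumes Modal's Image.from_registry and dockerfile_commands builders, which are not part of this commit.

import modal

# Hypothetical alternative to the Dockerfile above: pull the base image and
# swap in an idle, exec-form entrypoint (exec form avoids the shell-form
# caveat that signals are not forwarded to the child process).
vespa_image = (
    modal.Image.from_registry("vespaengine/vespa:latest", add_python="3.11")
    .dockerfile_commands(['ENTRYPOINT ["tail", "-f", "/dev/null"]'])
)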
Binary file modified __pycache__/modal_reward.cpython-312.pyc
Binary file not shown.
93 changes: 52 additions & 41 deletions modal_reward.py
@@ -16,11 +16,13 @@
from transformers import AutoModelForSequenceClassification, AutoTokenizer


# @modal.web_endpoint(method="POST", docs=True)
@app.cls(
-    # gpu=modal.gpu.A10G(),
-    gpu=modal.gpu.T4(),
-    enable_memory_snapshot=True,
-    # volumes={"/my_vol": modal.Volume.from_name("my-test-volume")}
+    gpu=modal.gpu.L4(),
+    # gpu=modal.gpu.T4(),
+    # enable_memory_snapshot=True,
+    # volumes={"/my_vol": modal.Volume.from_name("my-test-volume")},
+    container_idle_timeout=10
)
class Embedder:

@@ -31,30 +33,39 @@ class Embedder:
    def build(self):
        # cache
        print("build")
-        # dtype = torch.bfloat16
-        dtype = torch.float16
-        model = AutoModelForSequenceClassification.from_pretrained(self.model_id, device_map="auto",
-            trust_remote_code=True, torch_dtype=dtype)
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True)
-        # torch.compile(model)
+        dtype = torch.bfloat16
+        # dtype = torch.float16
+        with torch.device("cuda"):
+            model = AutoModelForSequenceClassification.from_pretrained(self.model_id,
+                trust_remote_code=True, torch_dtype=dtype, use_safetensors=True)

-    @modal.enter(snap=True)
-    def load(self):
-        # Create a memory snapshot with the model loaded in CPU memory.
-        print("save state")
-        # dtype = torch.bfloat16
-        dtype = torch.float16
-        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id, device_map="cpu",
-            trust_remote_code=True, torch_dtype=dtype)
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True)
+    # @modal.enter(snap=True)
+    # def load(self):
+    #     # Create a memory snapshot with the model loaded in CPU memory.
+    #     print("save state")

-    @modal.enter(snap=False)
+    # @modal.enter(snap=False)
+    @modal.enter()
    def setup(self):
        # Move the model to a GPU before doing any work.
        print("loaded from snapshot")
-        self.model = self.model.to(self.device)
+        dtype = torch.bfloat16
+        with torch.device("cuda"):
+            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id,
+                trust_remote_code=True, torch_dtype=dtype, use_safetensors=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True)

+    # @modal.enter()
+    # def setup(self):
+    #     # Move the model to a GPU before doing any work.
+    #     print("loaded from snapshot")
+    #     dtype = torch.float16
+    #     self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id, device_map="auto",
+    #         trust_remote_code=True, torch_dtype=dtype, low_cpu_mem_usage=True)
+    #     self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, use_fast=True)

-    @modal.method()
+    # @modal.method()
+    @modal.web_endpoint(method="POST", docs=True)
    def score_output(self, messages: List[Dict[str, str]]):
        print("batched")
        input_ids = self.tokenizer.apply_chat_template(
@@ -71,24 +82,24 @@ def score_output(self, messages: List[Dict[str, str]]):
        return float_output.item()


-@app.function()
-@modal.web_endpoint(method="POST", docs=True)
-async def run(messages: List[Dict[str, str]]):
-    result = await Embedder().score_output.remote.aio(messages)
-    print(messages, result)
-    return {"result": result}
+# @app.function()
+# @modal.web_endpoint(method="POST", docs=True)
+# async def run(messages: List[Dict[str, str]]):
+#     result = await Embedder().score_output.remote.aio(messages)
+#     print(messages, result)
+#     return {"result": result}


-@app.local_entrypoint()
-async def main():
-    # score the messages
-    prompt = 'What are some synonyms for the word "beautiful"?'
-    response1 = 'Nicely, Beautifully, Handsome, Stunning, Wonderful, Gorgeous, Pretty, Stunning, Elegant'
-    response2 = 'bad'
-    messages1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}]
-    messages2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}]
-    m1 = Embedder().score_output.remote.aio(messages1)
-    m2 = Embedder().score_output.remote.aio(messages2)
-    res = await asyncio.gather(*[m1,m2])
-    print(response1, res[0])
-    print(response2, res[1])
+# @app.local_entrypoint()
+# async def main():
+#     # score the messages
+#     prompt = 'What are some synonyms for the word "beautiful"?'
+#     response1 = 'Nicely, Beautifully, Handsome, Stunning, Wonderful, Gorgeous, Pretty, Stunning, Elegant'
+#     response2 = 'bad'
+#     messages1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}]
+#     messages2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}]
+#     m1 = Embedder().score_output(messages1)
+#     m2 = Embedder().score_output(messages2)
+#     res = await asyncio.gather(*[m1,m2])
+#     print(response1, res[0])
+#     print(response2, res[1])
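
With score_output now exposed via @modal.web_endpoint rather than @modal.method, the natural caller is plain HTTP rather than the commented-out .remote wrappers. A minimal client sketch, assuming a deployed app; the URL shape below is a placeholder, since Modal prints the real endpoint URL on deploy:

import requests

messages = [
    {"role": "user", "content": 'What are some synonyms for the word "beautiful"?'},
    {"role": "assistant", "content": "Gorgeous, stunning, lovely"},
]
# Placeholder URL: substitute the endpoint Modal prints for this app.
resp = requests.post(
    "https://<workspace>--<app-name>-score-output.modal.run",
    json=messages,
)
print(resp.json())  # the reward score as a single float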
38 changes: 38 additions & 0 deletions modal_vespa.py
@@ -0,0 +1,38 @@
import modal


# vespa_image = modal.Image.from_registry("vespaengine/vespa", add_python="3.11")
vespa_image = modal.Image.from_dockerfile("Dockerfile", add_python="3.11")
app = modal.App("dankvespa", image=vespa_image)

@modal.web_endpoint(method="POST", docs=True)
@app.cls(
enable_memory_snapshot=True,
# volumes={"/my_vol": modal.Volume.from_name("my-test-volume")}
)
class Vespa:
    @modal.build()
    def build(self):
        # cache
        print("build")

    @modal.enter(snap=True)
    def load(self):
        # Create a memory snapshot with the model loaded in CPU memory.
        print("save state")

    @modal.enter(snap=False)
    def setup(self):
        # Move the model to a GPU before doing any work.
        print("loaded from snapshot")

    @modal.method()
    def search(self, query: str):
        print("search")


@app.local_entrypoint()
async def main():
    # score the messages
    m1 = await Vespa().search.remote.aio("test")
    print(m1)
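
The search method in modal_vespa.py is still a stub. If Vespa is actually started inside the container, one plausible implementation is to proxy the query to Vespa's default HTTP search API on localhost:8080; this is an assumption about the eventual setup, not something this commit does:

import requests

def search(query: str):
    # Query Vespa's standard /search/ endpoint with a YQL userQuery.
    resp = requests.post(
        "http://localhost:8080/search/",
        json={"yql": "select * from sources * where userQuery()", "query": query},
    )
    resp.raise_for_status()
    return resp.json()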
