-
Notifications
You must be signed in to change notification settings - Fork 4
/
modal_prm_reward.py
134 lines (117 loc) · 4.52 KB
/
modal_prm_reward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import modal

# Container image for the PRM scoring service: slim Debian plus the inference
# stack. "batched" is the dynamic request-batching library imported lazily by
# BatchProcessor.
image = (
    modal.Image.debian_slim()
    .pip_install([
        "torch",
        "transformers",
        "accelerate",
        "batched",
    ])
)

# Previous app name, kept for reference:
# app = modal.App("mirrorqwen-prm", image=image)
app = modal.App("mirrorqwen-prm-st", image=image)

# Imports placed under image.imports() are intended for the container image;
# Modal uses this context manager so that packages missing on the local machine
# (e.g. torch/transformers) do not raise ImportError when this file is loaded
# locally. NOTE(review): `Tuple` and `os` appear unused in the visible code.
with image.imports():
    from typing import List, Dict, Tuple
    import asyncio
    import torch
    from time import perf_counter as pc
    from transformers import pipeline
    import os
class BatchProcessor:
    """Factory for dynamically batched wrappers around a list-in/list-out callable."""

    def __init__(self):
        # Imported lazily: `batched` is installed in the container image, and
        # presumably is not required on the machine that merely loads this file.
        import batched
        self.batched = batched

    def create_batch_processor(
        self,
        pipeline_func,
        *,
        batch_size: int = 256,
        timeout_ms: float = 200.0,
        small_batch_threshold: int = 4,
    ):
        """Wrap ``pipeline_func`` with ``batched.dynamically`` so calls are grouped.

        Args:
            pipeline_func: Callable that takes a list of prompts and returns a
                list of results, one per prompt.
            batch_size: Maximum batch size handed to ``batched.dynamically``.
            timeout_ms: Batch-collection timeout (milliseconds) handed to
                ``batched.dynamically``.
            small_batch_threshold: Passed through unchanged to
                ``batched.dynamically``; see that library's docs for semantics.

        Returns:
            The function produced by the ``batched.dynamically`` decorator. The
            defaults reproduce the settings this service has always used.
        """
        @self.batched.dynamically(
            batch_size=batch_size,
            timeout_ms=timeout_ms,
            small_batch_threshold=small_batch_threshold,
        )
        def _process_batch(prompts: List[str]) -> List[Dict]:
            return pipeline_func(prompts)

        return _process_batch
@app.cls(
    # Other GPU types that have been used for this service:
    # gpu=modal.gpu.T4(),
    # gpu=modal.gpu.H100(),
    # gpu=modal.gpu.A100(),
    gpu=modal.gpu.A10G(),
    container_idle_timeout=120,
    allow_concurrent_inputs=1000,
    secrets=[
        modal.Secret.from_name("hf-token"),
    ],
)
class Embedder:
    """Modal service that scores text with the checkpoint named by ``model_id``.

    The model is loaded as a Hugging Face ``sentiment-analysis`` (text
    classification) pipeline, and concurrent HTTP requests are funnelled through
    a dynamic batcher so they can share forward passes.
    """

    model_id = "rawsh/mirrorqwen2.5-0.5b-prm"
    # Pinned revisions of model_id, one per training round:
    # revision = "894341fbd81d0c1abdd98b4e0630de932aa63c6f" # base
    # revision = "42e07d1b708282ac2aae338050d8116f8c69398d" # st0
    # revision = "65f4a7601dffacc40e0ef7fa4733d346c926bd18" # st1 v1
    # revision = "80da7ccc4f107e0cb6bf937d61be4702badfb96b" # st1 v2
    # revision = "4d618515c90069993f4b32e4201783efdeebbc22" # st2
    # revision = "b052380b619e5c62ce9f407522362f5caf7b8346" # st3
    # note: orpo 1 st for prm used strong/weak to generate samples.
    # inference pair to gen data for orpo 2 was orpo 1 policy + st0
    # revision = "e49e4ca7c847194be48c42c52ad8f871da204300" # orpo2
    revision = "ecae5a74ef094d6e839dcb2a32500c36e6786ad1"  # orpo3
    device = "cuda"
    print(model_id)

    @modal.build()
    def build(self):
        """Load the pinned model once during the image build.

        The resulting pipeline is discarded; presumably the point is to fetch
        the weights at build time so ``setup`` does not download them on every
        cold start.
        """
        print("build")
        with torch.device(self.device):
            print("[build] loading model")
            start = pc()
            pipeline(
                "sentiment-analysis",
                model=self.model_id,
                revision=self.revision,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device=self.device,
            )
            elapsed = pc() - start
            print(f"[build] loading model took {elapsed} seconds")

    @modal.enter()
    def setup(self):
        """Load the scoring pipeline on container start and attach the batcher."""
        print("setup")
        # The default-device context makes tensors created during model
        # construction land on the target device directly.
        with torch.device(self.device):
            print("[setup] loading model")
            start = pc()
            self.pipeline = pipeline(
                "sentiment-analysis",
                model=self.model_id,
                revision=self.revision,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                device=self.device,
                batch_size=256,
            )
            elapsed = pc() - start
            print(f"[setup] loading model took {elapsed} seconds")

        # Concurrent requests are grouped by the batcher before reaching the pipeline.
        self._process_batch = BatchProcessor().create_batch_processor(self.pipeline)

    @modal.web_endpoint(method="POST", docs=True)
    async def score_output(self, inp: dict):
        """Score one prompt or a list of prompts.

        Request body: ``{"prompt": "<text>"}`` or ``{"prompt": ["<text>", ...]}``.

        Returns:
            For a single string, that prompt's pipeline result. For a list, a
            list of results in input order (``[]`` for an empty list). On any
            failure, ``{"error": "<message>"}``.
            NOTE(review): the per-prompt result shape depends on the model's
            remote code; for stock text-classification it is a label/score dict.
        """
        if "prompt" not in inp:
            # Match the endpoint's error convention instead of raising KeyError (HTTP 500).
            return {"error": "missing required field 'prompt'"}

        prompt = inp["prompt"]
        single = isinstance(prompt, str)
        prompts = [prompt] if single else prompt

        if isinstance(prompts, list) and not prompts:
            # Nothing to score; skip the batcher entirely.
            return []

        try:
            results = await self._process_batch.acall(prompts)
        except Exception as e:
            # Endpoint boundary: log the failure server-side, report it to the caller.
            print(f"[score_output] scoring failed: {e!r}")
            return {"error": str(e)}

        return results[0] if single else results
@app.local_entrypoint()
async def main():
    """Smoke-test the deployed ``Embedder.score_output`` endpoint.

    Each response is sent as its own concurrent remote call, so the container's
    dynamic batcher can group them. The score for each response is printed.
    """
    embedder = Embedder()

    # Candidate answers to: 'What are some synonyms for the word "beautiful"?'
    # Only the answer text is sent; the endpoint scores the "prompt" field as-is.
    response1 = 'Nicely, Beautifully, Handsome, Stunning, Wonderful, Gorgeous, Pretty, Stunning, Elegant'
    response2 = 'bad'
    responses = [response1, response2]

    inputs = [{"prompt": response} for response in responses]

    # `.remote.aio(...)` is Modal's awaitable remote call; a bare method call
    # does not go through the remote-invocation API.
    results = await asyncio.gather(*[
        embedder.score_output.remote.aio(inp) for inp in inputs
    ])

    for response, result in zip(responses, results):
        print(f"Response: {response}\nResult: {result}\n")

    # NOTE: batching statistics live on the container-side batcher created in
    # Embedder.setup(); that attribute does not exist on this local handle,
    # so it cannot be printed from here.