forked from ucbepic/docetl
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 4 changed files with 504 additions and 0 deletions.
@@ -0,0 +1,115 @@
from typing import Any, Dict, List, Optional, Tuple
from pydantic import create_model
from docetl.operations.base import BaseOperation
from outlines import generate, models
import llama_cpp
import json


class LlamaCppMapOperation(BaseOperation):
    class schema(BaseOperation.schema):
        type: str = "llama_cpp_map"
        model_path: str
        model_file: str
        output_schema: Dict[str, Any]
        prompt_template: str
        batch_size: Optional[int] = 10
        n_gpu_layers: int = -1
        flash_attn: bool = True
        n_ctx: int = 8192

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Reuse the Hugging Face tokenizer that ships with the model so chat
        # templating below matches the model's expected prompt format.
        self.tokenizer = llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
            self.config["model_path"]
        )
        # Optional settings are read with .get() so the schema defaults apply
        # even when the keys are omitted from the config dict.
        self.model = models.llamacpp(
            self.config["model_path"],
            self.config["model_file"],
            tokenizer=self.tokenizer,
            n_gpu_layers=self.config.get("n_gpu_layers", -1),
            flash_attn=self.config.get("flash_attn", True),
            n_ctx=self.config.get("n_ctx", 8192),
            verbose=False
        )

        # Build a Pydantic model dynamically from the output schema (a mapping
        # of field name to type name, e.g. {"first_name": "str"}); Outlines'
        # structured generation expects a model class, not a plain dict.
        type_map = {"str": str, "int": int, "float": float, "bool": bool}
        output_model = create_model(
            "OutputModel",
            **{
                name: (type_map.get(type_name, str), ...)
                for name, type_name in self.config["output_schema"].items()
            },
        )
        self.processor = generate.json(
            self.model,
            output_model,
            max_tokens=4096
        )

    def syntax_check(self) -> None:
        """Validate the operation configuration."""
        config = self.schema(**self.config)

        if not config.model_path:
            raise ValueError("model_path is required")

        if not config.model_file:
            raise ValueError("model_file is required for llama_cpp models")

        if not config.output_schema:
            raise ValueError("output_schema is required")

        if not config.prompt_template:
            raise ValueError("prompt_template is required")

    def create_prompt(self, item: Dict[str, Any]) -> str:
        """Create a prompt from the template and input data."""
        messages = [
            {
                'role': 'user',
                'content': self.config["prompt_template"]
            },
            {
                'role': 'assistant',
                'content': "I understand and will process the input as requested."
            },
            {
                'role': 'user',
                'content': str(item)
            }
        ]
        return self.tokenizer.hf_tokenizer.apply_chat_template(
            messages,
            tokenize=False
        )

    def process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single item through the Outlines model."""
        prompt = self.create_prompt(item)
        try:
            result = self.processor(prompt)
            result_dict = result.model_dump()
            # Merge the structured output into the original item; the JSON
            # round-trip guarantees the merged record is serializable.
            final_dict = {**item, **result_dict}
            return json.loads(json.dumps(final_dict, indent=2))
        except Exception as e:
            self.console.print(f"Error processing item: {e}")
            return json.loads(json.dumps(item, indent=2))

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        """Execute the operation on the input data."""
        if self.status:
            self.status.stop()

        results = []
        batch_size = self.config.get("batch_size", 10)

        for i in range(0, len(input_data), batch_size):
            batch = input_data[i:i + batch_size]
            batch_results = [self.process_item(item) for item in batch]
            results.extend(batch_results)

        if self.status:
            self.status.start()

        # Local inference has no per-call API cost, so report 0.0.
        return results, 0.0
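For reference, a standalone sketch of the dynamic output-model construction the constructor above relies on. The build_output_model helper name is hypothetical, the field names mirror the test fixture further down, and it assumes the type names in output_schema are limited to str/int/float/bool.

from pydantic import create_model

def build_output_model(output_schema):
    # output_schema maps field names to type-name strings, e.g. {"first_name": "str"}.
    type_map = {"str": str, "int": int, "float": float, "bool": bool}
    return create_model(
        "OutputModel",
        **{name: (type_map.get(type_name, str), ...) for name, type_name in output_schema.items()},
    )

# Mirrors part of the schema used in the tests below.
OutputModel = build_output_model({"first_name": "str", "last_name": "str"})
print(OutputModel.model_json_schema())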
@@ -0,0 +1,106 @@
from typing import Any, Dict, List, Optional, Tuple
from pydantic import create_model
from docetl.operations.base import BaseOperation
from outlines import generate, models
from transformers import AutoModelForCausalLM, AutoTokenizer
import json


class HuggingFaceMapOperation(BaseOperation):
    class schema(BaseOperation.schema):
        type: str = "hf_map"
        model_path: str
        use_local_model: bool = False
        device: str = "cuda"
        output_schema: Dict[str, Any]
        prompt_template: str
        batch_size: Optional[int] = 10
        max_tokens: int = 4096

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.config.get("use_local_model", False):
            # Load the weights and tokenizer directly, then hand them to Outlines.
            llm = AutoModelForCausalLM.from_pretrained(
                self.config["model_path"],
                device_map=self.config.get("device", "cuda")
            )
            tokenizer = AutoTokenizer.from_pretrained(self.config["model_path"])
            self.model = models.Transformers(llm, tokenizer)
            self.tokenizer = tokenizer
        else:
            # Let Outlines download and wrap the model by name.
            self.model = models.transformers(
                self.config["model_path"],
                device=self.config.get("device", "cuda")
            )
            # Keep a plain Hugging Face tokenizer for chat templating, since
            # Outlines wraps its own tokenizer internally.
            self.tokenizer = AutoTokenizer.from_pretrained(self.config["model_path"])

        # Build a Pydantic model dynamically from the output schema (a mapping
        # of field name to type name, e.g. {"first_name": "str"}).
        type_map = {"str": str, "int": int, "float": float, "bool": bool}
        output_model = create_model(
            "OutputModel",
            **{
                name: (type_map.get(type_name, str), ...)
                for name, type_name in self.config["output_schema"].items()
            },
        )
        self.processor = generate.json(
            self.model,
            output_model,
            max_tokens=self.config.get("max_tokens", 4096)
        )

    def syntax_check(self) -> None:
        """Validate the operation configuration."""
        config = self.schema(**self.config)

        if not config.model_path:
            raise ValueError("model_path is required")

        if not config.output_schema:
            raise ValueError("output_schema is required")

        if not config.prompt_template:
            raise ValueError("prompt_template is required")

    def create_prompt(self, item: Dict[str, Any]) -> str:
        """Create a prompt from the template and input data."""
        messages = [
            {
                'role': 'user',
                'content': self.config["prompt_template"]
            },
            {
                'role': 'assistant',
                'content': "I understand and will process the input as requested."
            },
            {
                'role': 'user',
                'content': str(item)
            }
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False
        )

    def process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single item through the model."""
        prompt = self.create_prompt(item)
        try:
            result = self.processor(prompt)
            result_dict = result.model_dump()
            # Merge the structured output into the original item; the JSON
            # round-trip guarantees the merged record is serializable.
            final_dict = {**item, **result_dict}
            return json.loads(json.dumps(final_dict, indent=2))
        except Exception as e:
            self.console.print(f"Error processing item: {e}")
            return json.loads(json.dumps(item, indent=2))

    def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
        """Execute the operation on the input data."""
        if self.status:
            self.status.stop()

        results = []
        batch_size = self.config.get("batch_size", 10)

        for i in range(0, len(input_data), batch_size):
            batch = input_data[i:i + batch_size]
            batch_results = [self.process_item(item) for item in batch]
            results.extend(batch_results)

        if self.status:
            self.status.start()

        # Local inference has no per-call API cost, so report 0.0.
        return results, 0.0
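Both operations wrap the same Outlines structured-generation pattern. Below is a minimal, self-contained sketch of that pattern; the model name is an illustrative small instruct model from the Hugging Face Hub, not one used by this commit.

from pydantic import BaseModel
from outlines import generate, models

class Customer(BaseModel):
    first_name: str
    last_name: str

# Any chat-tuned causal LM works here; CPU keeps the sketch portable.
model = models.transformers("HuggingFaceTB/SmolLM-135M-Instruct", device="cpu")
generator = generate.json(model, Customer)
result = generator("Extract the customer: John Doe placed order #12345.", max_tokens=256)
print(result.first_name, result.last_name)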
@@ -0,0 +1,136 @@
import pytest
from unittest.mock import Mock, patch
from docetl.operations.cpp_outlines import LlamaCppMapOperation


@pytest.fixture
def sample_config():
    return {
        "type": "llama_cpp_map",
        "model_path": "/path/to/local/model",
        "model_file": "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf",
        "output_schema": {
            "first_name": "str",
            "last_name": "str",
            "order_number": "str",
            "department": "str"
        },
        "prompt_template": "Extract customer information from this text",
        "batch_size": 2,
        "n_gpu_layers": -1,
        "flash_attn": True,
        "n_ctx": 8192
    }


@pytest.fixture
def mock_processor_output():
    class MockOutput:
        def model_dump(self):
            return {
                "first_name": "John",
                "last_name": "Doe",
                "order_number": "12345",
                "department": "Sales"
            }
    return MockOutput()


@pytest.fixture
def sample_input_data():
    return [
        {"message": "Customer John Doe ordered item #12345"},
        {"message": "Customer Jane Smith from Sales department"}
    ]


def test_initialization(sample_config):
    with patch('llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained'), \
         patch('outlines.models.llamacpp'):
        operation = LlamaCppMapOperation(sample_config)
        assert operation.config == sample_config
        assert operation.config["n_gpu_layers"] == -1
        assert operation.config["flash_attn"] is True


def test_syntax_check(sample_config):
    with patch('llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained'), \
         patch('outlines.models.llamacpp'):
        operation = LlamaCppMapOperation(sample_config)
        operation.syntax_check()  # Should not raise error


@pytest.mark.parametrize("missing_field", [
    "model_path",
    "model_file",
    "output_schema",
    "prompt_template"
])
def test_syntax_check_missing_fields(sample_config, missing_field):
    with patch('llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained'), \
         patch('outlines.models.llamacpp'):
        invalid_config = sample_config.copy()
        invalid_config[missing_field] = ""
        operation = LlamaCppMapOperation(invalid_config)
        with pytest.raises(ValueError):
            operation.syntax_check()


def test_create_prompt(sample_config):
    with patch('llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained') as mock_tokenizer_class:
        mock_tokenizer = Mock()
        mock_tokenizer.hf_tokenizer.apply_chat_template.return_value = "mocked prompt"
        mock_tokenizer_class.return_value = mock_tokenizer

        with patch('outlines.models.llamacpp'):
            operation = LlamaCppMapOperation(sample_config)
            test_item = {"message": "test message"}
            prompt = operation.create_prompt(test_item)

            assert isinstance(prompt, str)
            assert mock_tokenizer.hf_tokenizer.apply_chat_template.called


def test_process_item(sample_config, mock_processor_output):
    with patch('llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained'), \
         patch('outlines.models.llamacpp'):
        operation = LlamaCppMapOperation(sample_config)
        operation.processor = Mock(return_value=mock_processor_output)

        test_item = {"message": "test message"}
        result = operation.process_item(test_item)

        assert isinstance(result, dict)
        assert "first_name" in result
        assert "last_name" in result
        assert "message" in result


def test_process_item_error_handling(sample_config):
    with patch('llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained'), \
         patch('outlines.models.llamacpp'):
        operation = LlamaCppMapOperation(sample_config)
        operation.processor = Mock(side_effect=Exception("Test error"))

        test_item = {"message": "test message"}
        result = operation.process_item(test_item)

        assert isinstance(result, dict)
        assert "message" in result


def test_execute(sample_config, sample_input_data):
    with patch('llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained'), \
         patch('outlines.models.llamacpp'):
        operation = LlamaCppMapOperation(sample_config)
        operation.process_item = Mock(return_value={"processed": True})

        results, timing = operation.execute(sample_input_data)

        assert len(results) == len(sample_input_data)
        assert isinstance(timing, float)


def test_batch_processing(sample_config, sample_input_data):
    with patch('llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained'), \
         patch('outlines.models.llamacpp'):
        operation = LlamaCppMapOperation(sample_config)
        operation.process_item = Mock(return_value={"processed": True})

        # Test with different batch sizes
        sample_config["batch_size"] = 1
        results1, _ = operation.execute(sample_input_data)
        assert len(results1) == len(sample_input_data)

        sample_config["batch_size"] = 2
        results2, _ = operation.execute(sample_input_data)
        assert len(results2) == len(sample_input_data)
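Because the constructor also builds the structured generator via outlines.generate.json, a test fixture may additionally want to patch that call so construction stays fully offline. A hedged sketch, reusing the patch targets from the tests above; make_mocked_operation is a hypothetical helper.

from unittest.mock import Mock, patch
from docetl.operations.cpp_outlines import LlamaCppMapOperation

def make_mocked_operation(config):
    # Patch the tokenizer download, the model loader, and the structured
    # generator so no model files are touched during construction.
    with patch("llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained"), \
         patch("outlines.models.llamacpp"), \
         patch("outlines.generate.json", return_value=Mock()):
        return LlamaCppMapOperation(config)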