diff --git a/docetl/operations/hf_outlines.py b/docetl/operations/hf_outlines.py index ba07e74c..b2d3d241 100644 --- a/docetl/operations/hf_outlines.py +++ b/docetl/operations/hf_outlines.py @@ -1,106 +1,60 @@ from typing import Any, Dict, List, Optional, Tuple -from pydantic import BaseModel +from pydantic import BaseModel, create_model from docetl.operations.base import BaseOperation from outlines import generate, models -from transformers import AutoModelForCausalLM, AutoTokenizer import json class HuggingFaceMapOperation(BaseOperation): class schema(BaseOperation.schema): + name: str type: str = "hf_map" model_path: str - use_local_model: bool = False - device: str = "cuda" output_schema: Dict[str, Any] prompt_template: str - batch_size: Optional[int] = 10 max_tokens: int = 4096 - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, config: Dict[str, Any], runner=None, *args, **kwargs): + super().__init__( + config=config, + default_model=config.get('default_model', config['model_path']), + max_threads=config.get('max_threads', 1), + runner=runner + ) + + self.model = models.transformers( + self.config["model_path"] + ) + + # Create a dynamic Pydantic model from the output schema + field_definitions = { + k: (eval(v) if isinstance(v, str) else v, ...) + for k, v in self.config["output_schema"].items() + } + output_model = create_model('OutputModel', **field_definitions) - if self.config["use_local_model"]: - llm = AutoModelForCausalLM.from_pretrained( - self.config["model_path"], - device_map=self.config["device"] - ) - tokenizer = AutoTokenizer.from_pretrained(self.config["model_path"]) - self.model = models.Transformers(llm, tokenizer) - self.tokenizer = tokenizer - else: - self.model = models.transformers( - self.config["model_path"], - device=self.config["device"] - ) - self.tokenizer = self.model.tokenizer - - output_model = BaseModel.model_validate(self.config["output_schema"]) self.processor = generate.json( self.model, - output_model, - max_tokens=self.config["max_tokens"] + output_model ) def syntax_check(self) -> None: """Validate the operation configuration.""" - config = self.schema(**self.config) - - if not config.model_path: - raise ValueError("model_path is required") - - if not config.output_schema: - raise ValueError("output_schema is required") - - if not config.prompt_template: - raise ValueError("prompt_template is required") - - def create_prompt(self, item: Dict[str, Any]) -> str: - """Create a prompt from the template and input data.""" - messages = [ - { - 'role': 'user', - 'content': self.config["prompt_template"] - }, - { - 'role': 'assistant', - 'content': "I understand and will process the input as requested." - }, - { - 'role': 'user', - 'content': str(item) - } - ] - return self.tokenizer.apply_chat_template( - messages, - tokenize=False - ) + self.schema(**self.config) def process_item(self, item: Dict[str, Any]) -> Dict[str, Any]: """Process a single item through the model.""" - prompt = self.create_prompt(item) try: - result = self.processor(prompt) + result = self.processor(self.config["prompt_template"] + "\n" + str(item)) result_dict = result.model_dump() final_dict = {**item, **result_dict} - return json.loads(json.dumps(final_dict, indent=2)) + return final_dict except Exception as e: self.console.print(f"Error processing item: {e}") - return json.loads(json.dumps(item, indent=2)) + return item - def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: + @classmethod + def execute(cls, config: Dict[str, Any], input_data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], float]: """Execute the operation on the input data.""" - if self.status: - self.status.stop() - - results = [] - batch_size = self.config.get("batch_size", 10) - - for i in range(0, len(input_data), batch_size): - batch = input_data[i:i + batch_size] - batch_results = [self.process_item(item) for item in batch] - results.extend(batch_results) - - if self.status: - self.status.start() - + instance = cls(config) + results = [instance.process_item(item) for item in input_data] return results, 0.0 \ No newline at end of file diff --git a/tests/test_hf_outlines.py b/tests/test_hf_outlines.py index 40c561ff..ef9fc943 100644 --- a/tests/test_hf_outlines.py +++ b/tests/test_hf_outlines.py @@ -1,103 +1,82 @@ import pytest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock from docetl.operations.hf_outlines import HuggingFaceMapOperation +@pytest.fixture +def mock_runner(): + return Mock() + @pytest.fixture def sample_config(): return { + "name": "test_hf_operation", "type": "hf_map", - "model_path": "microsoft/Phi-3-mini-4k-instruct", - "use_local_model": False, - "device": "cuda", + "model_path": "meta-llama/Llama-3.2-1B-Instruct", "output_schema": { "first_name": "str", - "last_name": "str", - "order_number": "str", - "department": "str" + "last_name": "str" }, "prompt_template": "Extract customer information from this text", - "batch_size": 2, "max_tokens": 4096 } @pytest.fixture -def mock_processor_output(): +def research_config(): + return { + "name": "research_analyzer", + "type": "hf_map", + "model_path": "meta-llama/Llama-3.2-1B-Instruct", + "output_schema": { + "title": "str", + "authors": "list", + "methodology": "str", + "findings": "list", + "limitations": "list", + "future_work": "list" + }, + "prompt_template": "Analyze the following research paper abstract.\nExtract key components and summarize findings.", + "max_tokens": 4096 + } + +@pytest.fixture +def mock_research_output(): class MockOutput: def model_dump(self): return { - "first_name": "John", - "last_name": "Doe", - "order_number": "12345", - "department": "Sales" + "title": "Deep Learning in Natural Language Processing", + "authors": ["John Smith", "Jane Doe"], + "methodology": "Comparative analysis of transformer architectures", + "findings": [ + "Improved accuracy by 15%", + "Reduced training time by 30%" + ], + "limitations": [ + "Limited dataset size", + "Computational constraints" + ], + "future_work": [ + "Extend to multilingual models", + "Optimize for edge devices" + ] } return MockOutput() -@pytest.fixture -def sample_input_data(): - return [ - {"message": "Customer John Doe ordered item #12345"}, - {"message": "Customer Jane Smith from Sales department"} - ] - -def test_initialization_remote_model(sample_config): - with patch('outlines.models.transformers') as mock_transformers: - operation = HuggingFaceMapOperation(sample_config) - assert operation.config == sample_config - assert operation.config["use_local_model"] is False - assert mock_transformers.called - -def test_initialization_local_model(sample_config): - sample_config["use_local_model"] = True - with patch('transformers.AutoModelForCausalLM.from_pretrained') as mock_model, \ - patch('transformers.AutoTokenizer.from_pretrained') as mock_tokenizer: - operation = HuggingFaceMapOperation(sample_config) - assert operation.config["use_local_model"] is True - assert mock_model.called - assert mock_tokenizer.called - -@pytest.mark.parametrize("device", ["cuda", "cpu"]) -def test_device_configuration(sample_config, device): - sample_config["device"] = device - with patch('outlines.models.transformers'): - operation = HuggingFaceMapOperation(sample_config) - assert operation.config["device"] == device - -def test_syntax_check(sample_config): - with patch('outlines.models.transformers'): - operation = HuggingFaceMapOperation(sample_config) - operation.syntax_check() - -@pytest.mark.parametrize("missing_field", [ - "model_path", - "output_schema", - "prompt_template" -]) -def test_syntax_check_missing_fields(sample_config, missing_field): - with patch('outlines.models.transformers'): - invalid_config = sample_config.copy() - invalid_config[missing_field] = "" - operation = HuggingFaceMapOperation(invalid_config) - with pytest.raises(ValueError): - operation.syntax_check() - -def test_create_prompt(sample_config): - with patch('outlines.models.transformers') as mock_transformers: - mock_tokenizer = Mock() - mock_tokenizer.apply_chat_template.return_value = "mocked prompt" - mock_transformers.return_value.tokenizer = mock_tokenizer - - operation = HuggingFaceMapOperation(sample_config) - test_item = {"message": "test message"} - prompt = operation.create_prompt(test_item) - - assert isinstance(prompt, str) - assert mock_tokenizer.apply_chat_template.called - -def test_process_item(sample_config, mock_processor_output): - with patch('outlines.models.transformers'): - operation = HuggingFaceMapOperation(sample_config) - operation.processor = Mock(return_value=mock_processor_output) +def test_process_item(sample_config, mock_runner): + mock_model = MagicMock() + + class MockOutput: + def model_dump(self): + return { + "first_name": "John", + "last_name": "Doe" + } + + mock_processor = Mock(return_value=MockOutput()) + + with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \ + patch('outlines.generate.json', return_value=mock_processor): + operation = HuggingFaceMapOperation(sample_config, runner=mock_runner) test_item = {"message": "test message"} result = operation.process_item(test_item) @@ -106,42 +85,50 @@ def test_process_item(sample_config, mock_processor_output): assert "last_name" in result assert "message" in result -def test_process_item_error_handling(sample_config): - with patch('outlines.models.transformers'): - operation = HuggingFaceMapOperation(sample_config) - operation.processor = Mock(side_effect=Exception("Test error")) +def test_research_paper_analysis(research_config, mock_research_output, mock_runner): + mock_model = MagicMock() + mock_processor = Mock(return_value=mock_research_output) + + with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \ + patch('outlines.generate.json', return_value=mock_processor): - test_item = {"message": "test message"} + operation = HuggingFaceMapOperation(research_config, runner=mock_runner) + test_item = { + "abstract": """ + This paper presents a comprehensive analysis of deep learning approaches + in natural language processing. We compare various transformer architectures + and their performance on standard NLP tasks. + """ + } result = operation.process_item(test_item) + # Verify structure and types assert isinstance(result, dict) - assert "message" in result - -def test_execute(sample_config, sample_input_data): - with patch('outlines.models.transformers'): - operation = HuggingFaceMapOperation(sample_config) - operation.process_item = Mock(return_value={"processed": True}) - - results, timing = operation.execute(sample_input_data) + assert "title" in result + assert isinstance(result["title"], str) + assert "authors" in result + assert isinstance(result["authors"], list) + assert "methodology" in result + assert isinstance(result["methodology"], str) + assert "findings" in result + assert isinstance(result["findings"], list) + assert len(result["findings"]) > 0 + assert "limitations" in result + assert isinstance(result["limitations"], list) + assert "future_work" in result + assert isinstance(result["future_work"], list) - assert len(results) == len(sample_input_data) - assert isinstance(timing, float) + # Verify original input is preserved + assert "abstract" in result -def test_batch_processing(sample_config, sample_input_data): - with patch('outlines.models.transformers'): - operation = HuggingFaceMapOperation(sample_config) - operation.process_item = Mock(return_value={"processed": True}) +def test_execute(sample_config, mock_runner): + mock_model = MagicMock() + mock_processor = Mock(return_value={"first_name": "John", "last_name": "Doe"}) + + with patch('outlines.models.transformers', return_value=mock_model) as mock_transformers, \ + patch('outlines.generate.json', return_value=mock_processor): - # Test with different batch sizes - sample_config["batch_size"] = 1 - results1, _ = operation.execute(sample_input_data) - assert len(results1) == len(sample_input_data) - - sample_config["batch_size"] = 2 - results2, _ = operation.execute(sample_input_data) - assert len(results2) == len(sample_input_data) - -def test_max_tokens_configuration(sample_config): - with patch('outlines.models.transformers'): - operation = HuggingFaceMapOperation(sample_config) - assert operation.config["max_tokens"] == 4096 \ No newline at end of file + input_data = [{"message": "test message"}] + results, timing = HuggingFaceMapOperation.execute(sample_config, input_data) + assert len(results) == 1 + assert isinstance(timing, float) \ No newline at end of file