Skip to content

Commit

Permalink
Improve serialization (proper serialization for Pydantic models)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughcrt committed Dec 2, 2024
1 parent 2b4ad9f commit aa1218d
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 104 deletions.
9 changes: 6 additions & 3 deletions lunary/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from inspect import signature
import warnings, traceback, logging, copy, time, chevron, aiohttp, copy
import traceback, logging, copy, time, chevron, aiohttp, copy
from functools import wraps


Expand All @@ -9,10 +9,11 @@
from datetime import datetime, timezone
from typing import Optional, Any, Callable, Union
import jsonpickle
from pydantic import BaseModel
import humps

from .exceptions import *
from .parsers import default_input_parser, default_output_parser, filter_params, method_input_parser
from .parsers import default_input_parser, default_output_parser, filter_params, method_input_parser, PydanticHandler
from .openai_utils import OpenAIUtils
from .event_queue import EventQueue
from .thread import Thread
Expand Down Expand Up @@ -45,6 +46,7 @@
class LunaryException(Exception):
pass

jsonpickle.handlers.register(BaseModel, PydanticHandler, base=True)

def get_parent():
parent = parent_ctx.get()
Expand Down Expand Up @@ -662,7 +664,7 @@ def wrapper(*args, **kwargs):

parsed_input = {"input": input_value}
else:
raw_input = default_input_parser(args, kwargs)
raw_input = default_input_parser(*args, **kwargs)
parsed_input = {"input": raw_input}

return wrap(
Expand Down Expand Up @@ -726,6 +728,7 @@ def wrapper(self, *args, **kwargs):
)(self, *args, **kwargs)
return wrapper
return decorator

def tool(name=None, user_id=None, user_props=None, tags=None, app_id=None):
def decorator(fn):
return wrap(
Expand Down
8 changes: 8 additions & 0 deletions lunary/parsers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Any, Dict
import jsonpickle
from pydantic import BaseModel, Field



def default_input_parser(*args, **kwargs):
def serialize(args, kwargs):
Expand Down Expand Up @@ -38,6 +42,10 @@ def serialize(args, kwargs):
def default_output_parser(output, *args, **kwargs):
return {"output": getattr(output, "content", output), "tokensUsage": None}

class PydanticHandler(jsonpickle.handlers.BaseHandler):
def flatten(self, obj, data):
"""Convert Pydantic model to a JSON-friendly dict using model_dump_json()"""
return jsonpickle.loads(obj.model_dump_json())

PARAMS_TO_CAPTURE = [
"frequency_penalty",
Expand Down
Loading

0 comments on commit aa1218d

Please sign in to comment.