Skip to content

Commit

Permalink
Merge pull request caikit#86 from gabe-l-hart/DataclassToProto
Browse files Browse the repository at this point in the history
Dataclass to proto
  • Loading branch information
gabe-l-hart authored May 3, 2023
2 parents 6f24890 + 7e331d4 commit 3fb5769
Show file tree
Hide file tree
Showing 23 changed files with 631 additions and 496 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ import_heading_stdlib=Standard
import_heading_thirdparty=Third Party
import_heading_firstparty=First Party
import_heading_localfolder=Local
known_firstparty=alog,aconfig,jtd_to_proto,import_tracker
known_firstparty=alog,aconfig,py_to_proto,import_tracker
known_localfolder=caikit,sample_lib,tests
303 changes: 119 additions & 184 deletions caikit/core/data_model/dataobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,99 +18,64 @@


# Standard
from datetime import datetime
from enum import Enum
from functools import update_wrapper
from types import ModuleType
from typing import Callable, Dict, List, Type, Union
from typing import Any, Callable, List, Type, Union, get_args, get_origin
import dataclasses
import importlib
import sys
import types
import typing

# Third Party
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

# First Party
from jtd_to_proto.jtd_to_proto import JTD_TO_PROTO_TYPES
from jtd_to_proto.validation import is_valid_jtd
from py_to_proto.dataclass_to_proto import DataclassConverter
import alog
import jtd_to_proto
import py_to_proto

# Local
from ..toolkit.errors import error_handler
from . import enums
from .base import DataBase, _DataBaseMetaClass
from .streams.data_stream import DataStream

## Globals #####################################################################

# Logging channel for this module and an error handler bound to that channel
log = alog.use_channel("SCHEMA")
error = error_handler.get(log)

# Type defs for input schemas to a dataobject. A schema value is a string, a
# nested schema value (the alias is self-referential via a forward reference),
# or a DataBase subclass; a schema definition maps names to such values.
_SCHEMA_VALUE_TYPE = Union[str, "_SCHEMA_VALUE_TYPE", Type[DataBase]]
_SCHEMA_DEF_TYPE = Dict[str, _SCHEMA_VALUE_TYPE]

# Registry of auto-generated protos so that they can be rendered to .proto
# (the dataobject decorator appends each generated proto class to this list)
_AUTO_GEN_PROTO_CLASSES = []

# Reserved keywords in JTD. _to_jtd_schema treats a dict containing any of
# these keys as already being structured like a JTD element.
_JTD_KEYWORDS = [
    "elements",
    "type",
    "properties",
    "enum",
    "values",
    "optionalProperties",
    "additionalProperties",
    "discriminator",
]

# Type defs for schemas passed to jtd_to_proto
_JTD_VALUE_TYPE = Union[str, "_JTD_VALUE_TYPE", _descriptor.Descriptor]
_JTD_DEF_TYPE = Dict[str, _JTD_VALUE_TYPE]

# Python type -> jtd name
# (_to_jtd_schema wraps a match from this table as {"type": <jtd name>})
_NATIVE_TYPE_TO_JTD = {
    str: "string",
    int: "int64",
    float: "float64",
    bytes: "bytes",
    bool: "boolean",
    datetime: "timestamp",
}
# Special attribute used to indicate which defaults are user provided. The
# dataobject decorator stores a {field name: default} dict under this attribute
# on the wrapped class, and _DataobjectConverter.get_optional_field_names reads
# it back.
_USER_DEFINED_DEFAULTS = "__user_defined_defaults__"

## Public ######################################################################

# Common package prefix, used as the package name when the caller of
# dataobject does not supply one
CAIKIT_DATA_MODEL = "caikit_data_model"


def dataobject(
schema: _SCHEMA_DEF_TYPE,
package: str = CAIKIT_DATA_MODEL,
) -> Callable[[Type], Type[DataBase]]:
"""The @schema decorator can be used to define a Data Model object's schema
inline with the definition of the python class rather than needing to bind
to a pre-compiled protobufs class. For example:
def dataobject(*args, **kwargs) -> Callable[[Type], Type[DataBase]]:
"""The @dataobject decorator can be used to define a Data Model object's
schema inline with the definition of the python class rather than needing to
bind to a pre-compiled protobufs class. For example:
@dataobject(
package="foo.bar",
schema={"foo": str, "bar": int},
)
@dataobject("foo.bar")
@dataclass
class MyDataObject:
'''My Custom Data Object'''
foo: str
bar: int
NOTE: The wrapped class must NOT inherit directly from DataBase. That
inheritance will be added by this decorator, but if it is written
directly, the metaclass that links protobufs to the class will be called
before this decorator can auto-gen the protobufs class.
Args:
schema: _SCHEMA_DEF_TYPE
The full schema definition dict
package: str
The package name to use for the generated protobufs class
Expand All @@ -128,18 +93,53 @@ def decorator(cls: Type) -> Type[DataBase]:
cls.__name__,
)

# Create the message class from the schema
jtd_def = _to_jtd_schema(schema)
log.debug3("JTD Def for %s: %s", cls.__name__, jtd_def)
proto_class = jtd_to_proto.descriptor_to_message_class(
jtd_to_proto.jtd_to_proto(
# Add the package to the kwargs
kwargs.setdefault("package", package)

# If there's a schema in the keyword args, use jtd_to_proto
schema = kwargs.pop("schema", None)
if schema is not None:
log.debug("Using JTD To Proto")
kwargs.setdefault("validate_jtd", True)
descriptor = py_to_proto.jtd_to_proto(
name=cls.__name__,
package=package,
jtd_def=jtd_def,
jtd_def=schema,
**kwargs,
)
)
# pylint: disable=unused-variable,global-variable-not-assigned
global _AUTO_GEN_PROTO_CLASSES
# If it's already a dataclass, convert it directly
else:
log.debug("Using dataclass/enum to proto on dataclass")

# If it's not an enum, fill in any missing field defaults as None
# and make sure it's a dataclass
if not issubclass(cls, Enum):
log.debug2("Wrapping data class %s", cls)
user_defined_defaults = {}
for annotation in getattr(cls, "__annotations__", {}):
user_defined_default = getattr(cls, annotation, dataclasses.MISSING)
if user_defined_default == dataclasses.MISSING:
log.debug3("Filling in None default for %s.%s", cls, annotation)
setattr(cls, annotation, None)
else:
user_defined_defaults[annotation] = user_defined_default
# If the current __init__ is auto-generated by dataclass, remove
# it so that a new one is created with the new defaults. This is
# a little hard to detect across different python versions, so
# the most reliable way is to assume that the only place
# __annotations__ are added to the __init__ function itself are
# in dataclass. If cls is either not a dataclass or is a
# dataclass with a non-default __init__, there will not be
# annotations
if getattr(cls.__init__, "__annotations__", None):
log.debug3("Resetting default dataclass init")
delattr(cls, "__init__")
cls = dataclasses.dataclass(cls)
setattr(cls, _USER_DEFINED_DEFAULTS, user_defined_defaults)

descriptor = _dataobject_to_proto(dataclass_=cls, **kwargs)

# Create the message class from the dataclass
proto_class = py_to_proto.descriptor_to_message_class(descriptor)
_AUTO_GEN_PROTO_CLASSES.append(proto_class)

# Add enums to the global enums module
Expand All @@ -152,51 +152,9 @@ def decorator(cls: Type) -> Type[DataBase]:
if isinstance(proto_class, type):
wrapper_class = _make_data_model_class(proto_class, cls)
else:
ck_enum = enums.EnumBase(proto_class)

# Handling enums with the @dataobject decorator is quite tricky
# because unlike a message which is represented as a `class` in
# python, an enum is represented as an INSTANCE of a `class`. This
# means that the naive implementation of this decorator would apply
# to a `class`, but return an object that is NOT a `class` (e.g.
# isinstance(MyEnum, type) == False). This is bad for two distinct
# reasons:
#
# 1. It's confusing to see code written with a decorator around a
# `class` and have the resulting thing NOT be a class
# 2. It makes it harder to add additional functionality to the Enum
# by defining custom methods on your "enum class"
#
# To get around this, we need to "bind" the instance of the EnumBase
# to a net-new `class`! This class will function as a singleton
# wrapper around the EnumBase instance, but will allow the decorator
# to return a true `class` and allow user-defined methods on that
# class to persist through the decorator.
# pylint: disable=unused-variable
class EnumBindingMeta(type):
def __new__(mcs, name, bases, attrs):
attrs.update(vars(ck_enum))
for method in ["toYAML", "toJSON", "toDict"]:
attrs[method] = getattr(ck_enum, method)
attrs["_proto_enum"] = proto_class
attrs["_singleton_inst"] = ck_enum
bases = tuple(list(bases) + [_EnumBaseSentinel])
return super().__new__(mcs, name, bases, attrs)

def __call__(cls):
return ck_enum

def __str__(cls):
return ck_enum.__str__()

def __repr__(cls):
return ck_enum.__repr__()

class _Dummy(cls, metaclass=EnumBindingMeta):
pass

update_wrapper(_Dummy, cls, updated=())
wrapper_class = _Dummy
enums.import_enum(proto_class, cls)
setattr(cls, "_proto_enum", proto_class)
wrapper_class = cls

# Attach the proto class to the protobufs module
parent_mod_name = getattr(cls, "__module__", "").rpartition(".")[0]
Expand All @@ -214,6 +172,19 @@ class _Dummy(cls, metaclass=EnumBindingMeta):
# Return the merged data class
return wrapper_class

# If called without the function invocation, fill in the default argument
if args and callable(args[0]):
assert not kwargs, "This shouldn't happen!"
package = CAIKIT_DATA_MODEL
return decorator(args[0])

# Pull the package as an arg or a keyword arg
if args:
package = args[0]
if "package" in kwargs:
raise TypeError("Got multiple values for argument 'package'")
else:
package = kwargs.get("package", CAIKIT_DATA_MODEL)
return decorator


Expand All @@ -232,89 +203,53 @@ def render_dataobject_protos(interfaces_dir: str):
## Implementation Details ######################################################


class _EnumBaseSentinel:
    """This base class is used to provide a common base class for enum wrapper
    classes so that they can be identified generically (e.g. with an
    issubclass check against this sentinel)
    """
def _dataobject_to_proto(*args, **kwargs):
    """Convert a dataclass/enum to a protobuf descriptor using the
    data-object-aware converter.

    All positional and keyword arguments are forwarded untouched to
    _DataobjectConverter.

    Returns:
        The ``descriptor`` attribute of the constructed converter
    """
    converter = _DataobjectConverter(*args, **kwargs)
    return converter.descriptor


# pylint: disable=too-many-return-statements
def _to_jtd_schema(
input_schema: _SCHEMA_DEF_TYPE, is_inside_properties_dict: bool = False
) -> _JTD_DEF_TYPE:
"""Recursive helper that will convert an input schema to a fully fleshed out
JTD schema
class _DataobjectConverter(DataclassConverter):
"""Augment the dataclass converter to be able to pull descriptors from
existing data objects
"""
try:
# Unwrap optional to base type if applicable
input_schema = _unwrap_optional_type(input_schema)

# If it's a reference to an EnumBase, de-alias to that enum's EnumDescriptor
# NOTE: This must come before the check for dict since EnumBase instances
# are themselves dicts
if isinstance(input_schema, enums.EnumBase) or (
isinstance(input_schema, type)
and issubclass(input_schema, _EnumBaseSentinel)
):
return {"type": input_schema._proto_enum.DESCRIPTOR}

if isinstance(input_schema, dict):
# If this dict is already a JTD schema, return it as is
if is_valid_jtd(input_schema, valid_types=JTD_TO_PROTO_TYPES.keys()):
return input_schema

# If the dict is structured as a JTD element already, recurse on the
# values
if any(keyword in input_schema for keyword in _JTD_KEYWORDS):
return {
k: _to_jtd_schema(v, "properties" in k.lower())
for k, v in input_schema.items()
}

# If not, assume it's a flat properties dict
# Check to make sure we don't re-wrap *properties
translated_dict = {k: _to_jtd_schema(v) for k, v in input_schema.items()}
return (
{"properties": translated_dict}
if not is_inside_properties_dict
else translated_dict
)

# If it's a reference to another data model object, de-alias to that
# object's underlying proto descriptor
if isinstance(input_schema, type) and issubclass(input_schema, DataBase):
return {"type": input_schema.get_proto_class().DESCRIPTOR}

# If it's a native type, wrap it as a "type" element
if input_schema in _NATIVE_TYPE_TO_JTD:
return {"type": _NATIVE_TYPE_TO_JTD[input_schema]}

# If it's a list or data stream, wrap it with "elements":
if typing.get_origin(input_schema) in [list, DataStream]:
# type_ could be caikit.core.data_model.streams.data_stream.DataStream[int]
return {"elements": _to_jtd_schema(typing.get_args(input_schema)[0])}

# All other cases are invalid!
raise ValueError(f"Invalid input schema: {input_schema}")

except ValueError:
log.error("Invalid schema: %s", input_schema)
raise


def _unwrap_optional_type(type_: typing.Any) -> typing.Any:
"""Unwrap an Optional[T] type, or return the type as-is if it is not an optional
NB: Optional[T] is expressed as Union[T, None]
This function checks for Unions of [T, None] and returns T, or raises if the union
contains more types, as those need to be handled differently (and are not yet supported)
"""
if typing.get_origin(type_) != Union:
return type_
possible_types = set(typing.get_args(type_))
possible_types.discard(type(None))
if len(possible_types) == 1:
return list(possible_types)[0]
raise ValueError(f"Invalid input schema, cannot handle unions yet: {type_}")
def get_concrete_type(self, entry: Any) -> Any:
    """Treat data model classes and enums as concrete types in addition to
    whatever the parent converter already considers concrete.

    Args:
        entry: Any
            The (possibly wrapped) type to inspect

    Returns:
        concrete: Any
            The original entry if, once unwrapped, it is a DataBase subclass
            or carries a ``_proto_enum`` attribute; otherwise the parent
            class's answer for entry
    """
    resolved = self._resolve_wrapped_type(entry)
    is_data_model_class = isinstance(resolved, type) and issubclass(
        resolved, DataBase
    )
    # NOTE: The un-resolved entry is what gets returned, not the resolved type
    if is_data_model_class or hasattr(resolved, "_proto_enum"):
        return entry
    return super().get_concrete_type(entry)

def get_descriptor(self, entry: Any) -> Any:
    """Look up the protobuf descriptor for an entry, unpacking data model
    classes and enums to the descriptors they are bound to.

    Args:
        entry: Any
            The (possibly wrapped) type to get a descriptor for

    Returns:
        descriptor: Any
            The DESCRIPTOR of the bound proto class for a DataBase subclass,
            the DESCRIPTOR of the ``_proto_enum`` attribute if one is present,
            otherwise the parent class's result for the unwrapped entry
    """
    resolved = self._resolve_wrapped_type(entry)
    if isinstance(resolved, type) and issubclass(resolved, DataBase):
        return resolved._proto_class.DESCRIPTOR
    bound_enum = getattr(resolved, "_proto_enum", None)
    if bound_enum is None:
        # Not a data model object, so defer to the base converter
        return super().get_descriptor(resolved)
    return bound_enum.DESCRIPTOR

def get_optional_field_names(self, entry: Any) -> List[str]:
    """Get the names of any fields which are optional. This will be any
    field that has a user-defined default or is marked as Optional[]

    Args:
        entry: Any
            The dataclass whose ``__dataclass_fields__`` are inspected. If
            it carries the user-defined-defaults attribute, every field
            named there is reported as optional.

    Returns:
        optional_fields: List[str]
            Field names with a user-defined default (in the order they were
            recorded), followed by any remaining fields whose type hint is a
            Union that includes None
    """
    optional_fields = list(getattr(entry, _USER_DEFINED_DEFAULTS, {}))
    for field_name, field in entry.__dataclass_fields__.items():
        # _is_python_optional yields None for a non-Union hint and False for
        # a Union that does not contain None. Comparing the result against
        # None would therefore flag every Union (e.g. Union[int, str]) as
        # optional, so only a truthy result counts.
        if field_name not in optional_fields and self._is_python_optional(
            field.type
        ):
            optional_fields.append(field_name)
    return optional_fields

@staticmethod
def _is_python_optional(entry: Any) -> Any:
"""Detect if this type is a python optional"""
if get_origin(entry) is Union:
args = get_args(entry)
return type(None) in args


def _get_all_enums(
Expand Down
Loading

0 comments on commit 3fb5769

Please sign in to comment.