Skip to content

Commit

Permalink
POC select graphql fields
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Jan 19, 2024
1 parent 634a2cc commit 8bdb512
Show file tree
Hide file tree
Showing 8 changed files with 441 additions and 54 deletions.
174 changes: 173 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ pyarrow = [
pytest = {version=">=7.2.1", optional = true}
pytest-durations = {version = ">=1.2.0", optional = true}

# GraphQL
graphql-query = {version = ">=1.2.0", optional = true}

[tool.poetry.extras]
docs = [
"sphinx",
Expand All @@ -112,6 +115,7 @@ docs = [
"sphinx-notfound-page",
"sphinx-reredirects",
]
graphql = ["graphql-query"]
s3 = ["fs-s3fs"]
testing = [
"pytest",
Expand Down
98 changes: 79 additions & 19 deletions samples/sample_tap_countries/countries_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,36 @@
from __future__ import annotations

import abc
import sys
import typing as t

from graphql_query import Field, Operation, Query

from singer_sdk import typing as th
from singer_sdk.helpers._catalog import FieldTree, selection_to_tree
from singer_sdk.helpers._compat import importlib_resources
from singer_sdk.streams.graphql import GraphQLStream

if t.TYPE_CHECKING:
from singer_sdk._singerlib.catalog import SelectionMask

SCHEMAS_DIR = importlib_resources.files(__package__) / "schemas"


def _tree_to_fields(tree: FieldTree) -> list[Field]:
"""Convert a tree to a list of GraphQL fields."""
return [
Field(name=name, fields=_tree_to_fields(subtree))
for name, subtree in tree.items()
]


def selection_to_fields(selection: SelectionMask) -> list[Field]:
"""Convert a selection mask to a GraphQL query."""
tree = selection_to_tree(selection)
return _tree_to_fields(tree)


class CountriesAPIStream(GraphQLStream, metaclass=abc.ABCMeta):
"""Sample tap test for countries.
Expand All @@ -31,25 +53,25 @@ class CountriesStream(CountriesAPIStream):

name = "countries"
primary_keys = ("code",)
query = """
countries {
code
name
native
phone
continent {
code
name
}
capital
currency
languages {
code
name
}
emoji
}
"""
# query = """
# countries {
# code
# name
# native
# phone
# continent {
# code
# name
# }
# capital
# currency
# languages {
# code
# name
# }
# emoji
# }
# """
schema = th.PropertiesList(
th.Property("code", th.StringType),
th.Property("name", th.StringType),
Expand All @@ -76,6 +98,44 @@ class CountriesStream(CountriesAPIStream):
),
).to_dict()

@property
def query(self) -> str:
"""Return the GraphQL query string."""
countries = Query(
name="countries",
fields=selection_to_fields(self.mask),
# fields=[
# Field(name="code"),
# Field(name="name"),
# Field(name="native"),
# Field(name="phone"),
# Field(name="capital"),
# Field(name="currency"),
# Field(name="emoji"),
# Field(
# name="continent",
# fields=[
# Field(name="code"),
# Field(name="name"),
# ],
# ),
# Field(
# name="languages",
# fields=[
# # Field(name="code"),
# # Field(name="name"),
# ],
# ),
# ],
)
print(countries, file=sys.stderr)
return Operation(
type="query",
queries=[
countries,
],
).render()


class ContinentsStream(CountriesAPIStream):
"""Continents stream from the Countries API."""
Expand Down
86 changes: 83 additions & 3 deletions singer_sdk/_singerlib/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
import logging
import sys
import typing as t
from dataclasses import dataclass, fields

Expand Down Expand Up @@ -31,7 +32,7 @@ def __missing__(self, breadcrumb: Breadcrumb) -> bool:
Returns:
True if the breadcrumb is selected, False otherwise.
"""
return self[breadcrumb[:-2]] if len(breadcrumb) >= 2 else True # noqa: PLR2004
return self[breadcrumb[:-2]] if len(breadcrumb) >= 2 else True


@dataclass
Expand Down Expand Up @@ -153,6 +154,65 @@ def root(self) -> StreamMetadata:
"""
return self[()] # type: ignore[return-value]

@staticmethod
def _get_object_fields_metadata(
properties: dict[str, dict[str, t.Any]],
breadcrumb: Breadcrumb,
) -> t.Generator[tuple[Breadcrumb, Metadata], None, None]:
"""Get metadata for nested fields in a schema.
Args:
schema: Schema.
breadcrumb: Breadcrumb to check.
metadata: Metadata object.
Returns:
Metadata mapping.
"""
for field_name, field_schema in properties.items():
field_breadcrumb = (*breadcrumb, "properties", field_name)
print(field_breadcrumb, file=sys.stderr)
yield (
field_breadcrumb,
Metadata(inclusion=Metadata.InclusionType.AVAILABLE),
)
if "object" in field_schema.get("type"):
yield from MetadataMapping._get_object_fields_metadata(
field_schema.get("properties", {}),
field_breadcrumb,
)
if "array" in field_schema.get("type"):
yield from MetadataMapping._get_array_fields_metadata(
field_schema,
field_breadcrumb,
)

@staticmethod
def _get_array_fields_metadata(
field_schema: dict[str, t.Any],
breadcrumb: Breadcrumb,
) -> t.Generator[tuple[Breadcrumb, Metadata], None, None]:
"""Get metadata for nested fields in a schema.
Args:
schema: Schema.
breadcrumb: Breadcrumb to check.
metadata: Metadata object.
Returns:
Metadata mapping.
"""
if "object" in field_schema.get("type"):
yield from MetadataMapping._get_object_fields_metadata(
field_schema.get("properties", {}),
breadcrumb,
)
if "array" in field_schema.get("type"):
yield from MetadataMapping._get_array_fields_metadata(
field_schema.get("items", {}),
breadcrumb,
)

@classmethod
def get_standard_metadata(
cls: type[MetadataMapping],
Expand Down Expand Up @@ -191,7 +251,10 @@ def get_standard_metadata(
if schema_name:
root.schema_name = schema_name

for field_name in schema.get("properties", {}):
properties = schema.get("properties", {})

for field_name, field_schema in properties.items():
breadcrumb = ("properties", field_name)
if (
key_properties
and field_name in key_properties
Expand All @@ -201,7 +264,24 @@ def get_standard_metadata(
else:
entry = Metadata(inclusion=Metadata.InclusionType.AVAILABLE)

mapping[("properties", field_name)] = entry
mapping[breadcrumb] = entry

print(breadcrumb, field_schema, file=sys.stderr)
if "object" in field_schema.get("type", []):
mapping.update(
cls._get_object_fields_metadata(
field_schema.get("properties", {}),
breadcrumb,
),
)

if "array" in field_schema.get("type", []):
mapping.update(
cls._get_array_fields_metadata(
field_schema.get("items", {}),
breadcrumb,
),
)

mapping[()] = root

Expand Down
40 changes: 40 additions & 0 deletions singer_sdk/helpers/_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
if t.TYPE_CHECKING:
from logging import Logger

from typing_extensions import TypeAlias

from singer_sdk._singerlib import Catalog, SelectionMask


_MAX_LRU_CACHE = 500
FieldTree: TypeAlias = t.Dict[str, "FieldTree"]


@cached(max_size=_MAX_LRU_CACHE)
Expand Down Expand Up @@ -142,3 +146,39 @@ def set_catalog_stream_selected(

md_entry = catalog_entry.metadata[breadcrumb]
md_entry.selected = selected


def selection_to_tree(selection: SelectionMask) -> FieldTree:
"""Convert a SelectionMask to a tree consisting of selected fields.
Args:
selection: Selection mask dictionary.
Returns:
A dictionary tree with the selected fields.
>>> selection = {
... ("properties", "code",): True,
... ("properties", "name",): True,
... ("properties", "emoji",): True,
... ("properties", "languages",): False,
... ("properties", "languages", "properties", "code"): True,
... ("properties", "languages", "properties", "name"): False,
... ("properties", "continent",): True,
... ("properties", "continent", "properties", "code"): True,
... ("properties", "continent", "properties", "name"): False,
... }
>>> selection_to_tree(selection)
{'code': {}, 'name': {}, 'emoji': {}, 'languages': {'code': {}}, 'continent': {'code': {}}}
""" # noqa: E501
fields = {}
for crumb, selected in selection.items():
if not crumb:
continue
if selected:
top = crumb[1]
if top not in fields:
fields[top] = {}
fields[top] = selection_to_tree({crumb[2:]: selected})

return fields
49 changes: 32 additions & 17 deletions singer_sdk/helpers/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import typing as t
from enum import Enum
from functools import lru_cache
from functools import lru_cache, singledispatch

import pendulum

Expand Down Expand Up @@ -469,26 +469,41 @@ def _conform_record_data_types( # noqa: PLR0912
return output_object, unmapped_properties


def _conform_primitive_property( # noqa: PLR0911
def _conform_primitive_property(
elem: t.Any, # noqa: ANN401
property_schema: dict,
) -> t.Any: # noqa: ANN401
"""Converts a primitive (i.e. not object or array) to a json compatible type."""
if isinstance(elem, (datetime.datetime, pendulum.DateTime)):
return to_json_compatible(elem)
if isinstance(elem, datetime.date):
return f"{elem.isoformat()}T00:00:00+00:00"
if isinstance(elem, datetime.timedelta):
epoch = datetime.datetime.fromtimestamp(0, UTC)
timedelta_from_epoch = epoch + elem
if timedelta_from_epoch.tzinfo is None:
timedelta_from_epoch = timedelta_from_epoch.replace(tzinfo=UTC)
return timedelta_from_epoch.isoformat()
if isinstance(elem, datetime.time):
return str(elem)
if isinstance(elem, bytes):
# for BIT value, treat 0 as False and anything else as True
return elem != b"\x00" if is_boolean_type(property_schema) else elem.hex()
if is_boolean_type(property_schema):
return None if elem is None else elem != 0
return elem


@_conform_primitive_property.register
def _(elem: datetime.datetime | pendulum.DateTime, _property_schema: dict) -> str:
return to_json_compatible(elem)


@_conform_primitive_property.register
def _(elem: datetime.date, _property_schema: dict) -> str:
return f"{elem.isoformat()}T00:00:00+00:00"


@_conform_primitive_property.register
def _(elem: datetime.timedelta, _property_schema: dict) -> str:
epoch = datetime.datetime.fromtimestamp(0, UTC)
timedelta_from_epoch = epoch + elem
if timedelta_from_epoch.tzinfo is None:
timedelta_from_epoch = timedelta_from_epoch.replace(tzinfo=UTC)
return timedelta_from_epoch.isoformat()


@_conform_primitive_property.register
def _(elem: datetime.time, _property_schema: dict) -> str:
return str(elem)


@_conform_primitive_property.register
def _(elem: bytes, property_schema: dict) -> bool | str:
# for BIT value, treat 0 as False and anything else as True
return elem != b"\x00" if is_boolean_type(property_schema) else elem.hex()
Loading

0 comments on commit 8bdb512

Please sign in to comment.