Skip to content

Commit

Permalink
fix: Ruff formatting, fix type issue
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinand-c committed Dec 9, 2023
1 parent f138a57 commit 4e91e7d
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 35 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,22 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python }}
- name: Lint with flake8
run: |
python -m pip install flake8
flake8 graphene_mongo --count --show-source --statistics
- name: Install dependencies
run: |
python -m pip install poetry
poetry config virtualenvs.create false
poetry install --with dev
- name: Lint
run: |
make lint
- name: Run Tests
run: make test
- name: Build Package
Expand Down
18 changes: 13 additions & 5 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,15 @@ def reference_resolver(root, *args, **kwargs):
if model in to_resolve_models:
futures.append(
pool.submit(
get_reference_objects, (model, object_id_list, registry, args)
get_reference_objects,
(model, object_id_list, registry, args),
)
)
else:
futures.append(
pool.submit(
get_non_querying_object, (model, object_id_list, registry, args)
get_non_querying_object,
(model, object_id_list, registry, args),
)
)
result = list()
Expand Down Expand Up @@ -325,7 +327,9 @@ async def reference_resolver_async(root, *args, **kwargs):
base_type = type(base_type)

return graphene.List(
base_type, description=get_field_description(field, registry), required=field.required
base_type,
description=get_field_description(field, registry),
required=field.required,
)


Expand Down Expand Up @@ -625,7 +629,9 @@ def dynamic_type():
return None
if isinstance(field, mongoengine.EmbeddedDocumentField):
return graphene.Field(
_type, description=get_field_description(field, registry), required=field.required
_type,
description=get_field_description(field, registry),
required=field.required,
)
field_resolver = None
required = False
Expand Down Expand Up @@ -755,5 +761,7 @@ def convert_field_to_enum(field, registry=None, executor: ExecutorEnum = Executo
registry.register_enum(field._enum_cls)
_type = registry.get_type_for_enum(field._enum_cls)
return graphene.Field(
_type, description=get_field_description(field, registry), required=field.required
_type,
description=get_field_description(field, registry),
required=field.required,
)
20 changes: 13 additions & 7 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def get_queryset(
reference_obj = reference_fields[arg_name].document_type(pk=arg)
hydrated_references[arg_name] = reference_obj
elif arg_name in self.model._fields_ordered and isinstance(
getattr(self.model, arg_name), mongoengine.fields.GenericReferenceField
getattr(self.model, arg_name),
mongoengine.fields.GenericReferenceField,
):
try:
reference_obj = get_document(
Expand All @@ -306,7 +307,8 @@ def get_queryset(
reference_obj = get_document(arg["_cls"])(pk=arg["_ref"].id)
hydrated_references[arg_name] = reference_obj
elif "__near" in arg_name and isinstance(
getattr(self.model, arg_name.split("__")[0]), mongoengine.fields.PointField
getattr(self.model, arg_name.split("__")[0]),
mongoengine.fields.PointField,
):
location = args.pop(arg_name, None)
hydrated_references[arg_name] = location["coordinates"]
Expand Down Expand Up @@ -383,7 +385,8 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
elif field_name in _root._fields_ordered and not (
isinstance(_root._fields[field_name].field, mongoengine.EmbeddedDocumentField)
or isinstance(
_root._fields[field_name].field, mongoengine.GenericEmbeddedDocumentField
_root._fields[field_name].field,
mongoengine.GenericEmbeddedDocumentField,
)
):
if getattr(_root, field_name, []) is not None:
Expand Down Expand Up @@ -464,13 +467,16 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
elif (
isinstance(getattr(self.model, key), mongoengine.fields.ReferenceField)
or isinstance(
getattr(self.model, key), mongoengine.fields.GenericReferenceField
getattr(self.model, key),
mongoengine.fields.GenericReferenceField,
)
or isinstance(
getattr(self.model, key), mongoengine.fields.LazyReferenceField
getattr(self.model, key),
mongoengine.fields.LazyReferenceField,
)
or isinstance(
getattr(self.model, key), mongoengine.fields.CachedReferenceField
getattr(self.model, key),
mongoengine.fields.CachedReferenceField,
)
):
if not isinstance(args_copy[key], ObjectId):
Expand Down Expand Up @@ -603,7 +609,7 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
connection_fields = [
field
for field in self.fields
if type(self.fields[field]) == MongoengineConnectionField
if isinstance(self.fields[field], MongoengineConnectionField)
]

def filter_connection(x):
Expand Down
14 changes: 9 additions & 5 deletions graphene_mongo/fields_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
elif field_name in _root._fields_ordered and not (
isinstance(_root._fields[field_name].field, mongoengine.EmbeddedDocumentField)
or isinstance(
_root._fields[field_name].field, mongoengine.GenericEmbeddedDocumentField
_root._fields[field_name].field,
mongoengine.GenericEmbeddedDocumentField,
)
):
if getattr(_root, field_name, []) is not None:
Expand Down Expand Up @@ -160,13 +161,16 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
elif (
isinstance(getattr(self.model, key), mongoengine.fields.ReferenceField)
or isinstance(
getattr(self.model, key), mongoengine.fields.GenericReferenceField
getattr(self.model, key),
mongoengine.fields.GenericReferenceField,
)
or isinstance(
getattr(self.model, key), mongoengine.fields.LazyReferenceField
getattr(self.model, key),
mongoengine.fields.LazyReferenceField,
)
or isinstance(
getattr(self.model, key), mongoengine.fields.CachedReferenceField
getattr(self.model, key),
mongoengine.fields.CachedReferenceField,
)
):
if not isinstance(args_copy[key], ObjectId):
Expand Down Expand Up @@ -295,7 +299,7 @@ async def chained_resolver(self, resolver, is_partial, root, info, **args):
connection_fields = [
field
for field in self.fields
if type(self.fields[field]) == AsyncMongoengineConnectionField
if isinstance(self.fields[field], AsyncMongoengineConnectionField)
]

def filter_connection(x):
Expand Down
7 changes: 5 additions & 2 deletions graphene_mongo/tests/test_relay_query_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,8 @@ async def test_should_get_queryset_returns_dict_filters_async(fixtures):
class Query(graphene.ObjectType):
node = Node.Field()
articles = AsyncMongoengineConnectionField(
nodes_async.ArticleAsyncNode, get_queryset=lambda *_, **__: {"headline": "World"}
nodes_async.ArticleAsyncNode,
get_queryset=lambda *_, **__: {"headline": "World"},
)

query = """
Expand Down Expand Up @@ -1032,7 +1033,9 @@ class Query(graphene.ObjectType):


@pytest.mark.asyncio
async def test_should_filter_mongoengine_queryset_by_id_and_other_fields_async(fixtures):
async def test_should_filter_mongoengine_queryset_by_id_and_other_fields_async(
fixtures,
):
class Query(graphene.ObjectType):
players = AsyncMongoengineConnectionField(nodes_async.PlayerAsyncNode)

Expand Down
6 changes: 5 additions & 1 deletion graphene_mongo/tests/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from . import models
from ..types import MongoengineObjectType, MongoengineInterfaceType, MongoengineInputType
from ..types import (
MongoengineObjectType,
MongoengineInterfaceType,
MongoengineInputType,
)
from graphene.types.union import Union


Expand Down
14 changes: 10 additions & 4 deletions graphene_mongo/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

from .converter import convert_mongoengine_field
from .registry import Registry, get_global_registry, get_inputs_registry
from .utils import get_model_fields, is_valid_mongoengine_model, get_query_fields, ExecutorEnum
from .utils import (
get_model_fields,
is_valid_mongoengine_model,
get_query_fields,
ExecutorEnum,
)


def construct_fields(
Expand Down Expand Up @@ -254,9 +259,10 @@ def resolve_id(self, info):
MongoengineObjectType, MongoengineObjectTypeOptions = create_graphene_generic_class(
ObjectType, ObjectTypeOptions
)
MongoengineInterfaceType, MongoengineInterfaceTypeOptions = create_graphene_generic_class(
Interface, InterfaceOptions
)
(
MongoengineInterfaceType,
MongoengineInterfaceTypeOptions,
) = create_graphene_generic_class(Interface, InterfaceOptions)
MongoengineInputType, MongoengineInputTypeOptions = create_graphene_generic_class(
InputObjectType, InputObjectTypeOptions
)
Expand Down
12 changes: 8 additions & 4 deletions graphene_mongo/types_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,17 @@ def resolve_id(self, info):
return AsyncGrapheneMongoengineGenericType, AsyncMongoengineGenericObjectTypeOptions


AsyncMongoengineObjectType, AsyncMongoengineObjectTypeOptions = create_graphene_generic_class_async(
ObjectType, ObjectTypeOptions
)
(
AsyncMongoengineObjectType,
AsyncMongoengineObjectTypeOptions,
) = create_graphene_generic_class_async(ObjectType, ObjectTypeOptions)

(
AsyncMongoengineInterfaceType,
MongoengineInterfaceTypeOptions,
) = create_graphene_generic_class_async(Interface, InterfaceOptions)

AsyncGrapheneMongoengineObjectTypes = (AsyncMongoengineObjectType, AsyncMongoengineInterfaceType)
AsyncGrapheneMongoengineObjectTypes = (
AsyncMongoengineObjectType,
AsyncMongoengineInterfaceType,
)
11 changes: 9 additions & 2 deletions graphene_mongo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,18 @@ def find_skip_and_limit(first, last, after, before, count=None):


def connection_from_iterables(
edges, start_offset, has_previous_page, has_next_page, connection_type, edge_type, pageinfo_type
edges,
start_offset,
has_previous_page,
has_next_page,
connection_type,
edge_type,
pageinfo_type,
):
edges_items = [
edge_type(
node=node, cursor=offset_to_cursor((0 if start_offset is None else start_offset) + i)
node=node,
cursor=offset_to_cursor((0 if start_offset is None else start_offset) + i),
)
for i, node in enumerate(edges)
]
Expand Down

0 comments on commit 4e91e7d

Please sign in to comment.