Skip to content

Commit

Permalink
Merge pull request #98 from mirumee/custom_config_file
Browse files Browse the repository at this point in the history
Custom config file
  • Loading branch information
mat-sop authored Mar 17, 2023
2 parents b7e0dc3 + 5828dea commit 0888340
Show file tree
Hide file tree
Showing 18 changed files with 318 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- Added `ScalarsDefinitionsGenerator` and `PackageGenerator` plugin hooks.
- Added support for `[tool.ariadne-codegen]` section key. Deprecated `[ariadne-codegen]`.
- Added support for environment variables to remote schema headers values.
- Added `--config` argument to `ariadne-codegen` script, to support reading configuration from custom path.


## 0.3.0 (2023-02-21)
Expand Down
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ $ pip install ariadne-codegen

## Configuration

`ariadne-codegen` reads configuration from `[tool.ariadne-codegen]` section in your `pyproject.toml`'.
`ariadne-codegen` reads configuration from `[tool.ariadne-codegen]` section in your `pyproject.toml`'. You can use other configuration file with `--config` option, eg. `ariadne-codegen --config custom_file.toml`

Required settings:

Expand Down Expand Up @@ -191,6 +191,15 @@ class ListUsersUsers(BaseModel, UsersMixin):
```


## Multiple clients

To generate multiple different clients you can store config for each in different file, then provide path to config file by `--config` option, eg.
```
ariadne-codegen --config clientA.toml
ariadne-codegen --config clientB.toml
```


## Generated code dependencies

Generated code requires:
Expand Down
10 changes: 7 additions & 3 deletions ariadne_codegen/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,14 @@ def get_config_file_path(file_name: str = "pyproject.toml") -> Path:
return directory.joinpath(file_name).resolve()


def get_config_dict() -> Dict:
def get_config_dict(config_file_name: Optional[str] = None) -> Dict:
"""Get config dict."""
config_path = get_config_file_path()
return toml.load(config_path)
if config_file_name:
config_file_path = get_config_file_path(config_file_name)
else:
config_file_path = get_config_file_path()

return toml.load(config_file_path)


def parse_config_dict(config_dict: Dict) -> Settings:
Expand Down
5 changes: 3 additions & 2 deletions ariadne_codegen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

@click.command()
@click.version_option()
def main():
config_dict = get_config_dict()
@click.option("--config", default=None, help="Path to custom configuration file.")
def main(config=None):
config_dict = get_config_dict(config)
settings = parse_config_dict(config_dict)
if settings.schema_path:
schema = get_graphql_schema_from_path(settings.schema_path)
Expand Down
5 changes: 5 additions & 0 deletions tests/main/custom_config_file/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[tool.ariadne-codegen]
queries_path = "queries.graphql"
schema_path = "schema.graphql"
target_package_name = "custom_config_client"
include_comments = false
23 changes: 23 additions & 0 deletions tests/main/custom_config_file/expected_client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .async_base_client import AsyncBaseClient
from .base_model import BaseModel
from .client import Client
from .exceptions import (
GraphQLClientError,
GraphQLClientGraphQLError,
GraphQLClientGraphQLMultiError,
GraphQLClientHttpError,
GraphQlClientInvalidResponseError,
)
from .test import Test

__all__ = [
"AsyncBaseClient",
"BaseModel",
"Client",
"GraphQLClientError",
"GraphQLClientGraphQLError",
"GraphQLClientGraphQLMultiError",
"GraphQLClientHttpError",
"GraphQlClientInvalidResponseError",
"Test",
]
80 changes: 80 additions & 0 deletions tests/main/custom_config_file/expected_client/async_base_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Dict, Optional, TypeVar, cast

import httpx
from pydantic import BaseModel

from .exceptions import (
GraphQLClientGraphQLMultiError,
GraphQLClientHttpError,
GraphQlClientInvalidResponseError,
)

Self = TypeVar("Self", bound="AsyncBaseClient")


class AsyncBaseClient:
def __init__(
self,
url: str = "",
headers: Optional[Dict[str, str]] = None,
http_client: Optional[httpx.AsyncClient] = None,
) -> None:
self.url = url
self.headers = headers

self.http_client = (
http_client if http_client else httpx.AsyncClient(headers=headers)
)

async def __aenter__(self: Self) -> Self:
return self

async def __aexit__(
self,
exc_type: object,
exc_val: object,
exc_tb: object,
) -> None:
await self.http_client.aclose()

async def execute(
self, query: str, variables: Optional[Dict[str, Any]] = None
) -> httpx.Response:
payload: Dict[str, Any] = {"query": query}
if variables:
payload["variables"] = self._convert_dict_to_json_serializable(variables)
return await self.http_client.post(url=self.url, json=payload)

def get_data(self, response: httpx.Response) -> dict[str, Any]:
if not response.is_success:
raise GraphQLClientHttpError(
status_code=response.status_code, response=response
)

try:
response_json = response.json()
except ValueError as exc:
raise GraphQlClientInvalidResponseError(response=response) from exc

if (not isinstance(response_json, dict)) or ("data" not in response_json):
raise GraphQlClientInvalidResponseError(response=response)

data = response_json["data"]
errors = response_json.get("errors")

if errors:
raise GraphQLClientGraphQLMultiError.from_errors_dicts(
errors_dicts=errors, data=data
)

return cast(dict[str, Any], data)

def _convert_dict_to_json_serializable(
self, dict_: Dict[str, Any]
) -> Dict[str, Any]:
return {
key: value
if not isinstance(value, BaseModel)
else value.dict(by_alias=True)
for key, value in dict_.items()
}
30 changes: 30 additions & 0 deletions tests/main/custom_config_file/expected_client/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any, Dict

from pydantic import BaseModel as PydanticBaseModel
from pydantic.class_validators import validator
from pydantic.fields import ModelField

from .scalars import SCALARS_PARSE_FUNCTIONS, SCALARS_SERIALIZE_FUNCTIONS


class BaseModel(PydanticBaseModel):
class Config:
allow_population_by_field_name = True
validate_assignment = True
arbitrary_types_allowed = True

# pylint: disable=no-self-argument
@validator("*", pre=True)
def decode_custom_scalars(cls, value: Any, field: ModelField) -> Any:
decode = SCALARS_PARSE_FUNCTIONS.get(field.type_)
if decode and callable(decode):
return decode(value)
return value

def dict(self, **kwargs: Any) -> Dict[str, Any]:
dict_ = super().dict(**kwargs)
for key, value in dict_.items():
serialize = SCALARS_SERIALIZE_FUNCTIONS.get(type(value))
if serialize and callable(serialize):
dict_[key] = serialize(value)
return dict_
23 changes: 23 additions & 0 deletions tests/main/custom_config_file/expected_client/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Any, List, Optional

from .async_base_client import AsyncBaseClient
from .test import Test


def gql(q: str) -> str:
return q


class Client(AsyncBaseClient):
async def test(self) -> Test:
query = gql(
"""
query test {
testQuery
}
"""
)
variables: dict[str, object] = {}
response = await self.execute(query=query, variables=variables)
data = self.get_data(response)
return Test.parse_obj(data)
Empty file.
71 changes: 71 additions & 0 deletions tests/main/custom_config_file/expected_client/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Any, Dict, List, Optional

import httpx


class GraphQLClientError(Exception):
"""Base exception."""


class GraphQLClientHttpError(GraphQLClientError):
def __init__(self, status_code: int, response: httpx.Response) -> None:
self.status_code = status_code
self.response = response

def __str__(self) -> str:
return f"HTTP status code: {self.status_code}"


class GraphQlClientInvalidResponseError(GraphQLClientError):
def __init__(self, response: httpx.Response) -> None:
self.response = response

def __str__(self) -> str:
return "Invalid response format."


class GraphQLClientGraphQLError(GraphQLClientError):
def __init__(
self,
message: str,
locations: Optional[List[Dict[str, int]]] = None,
path: Optional[List[str]] = None,
extensions: Optional[Dict[str, object]] = None,
orginal: Optional[Dict[str, object]] = None,
):
self.message = message
self.locations = locations
self.path = path
self.extensions = extensions
self.orginal = orginal

def __str__(self) -> str:
return self.message

@classmethod
def from_dict(cls, error: dict[str, Any]) -> "GraphQLClientGraphQLError":
return cls(
message=error["message"],
locations=error.get("locations"),
path=error.get("path"),
extensions=error.get("extensions"),
orginal=error,
)


class GraphQLClientGraphQLMultiError(GraphQLClientError):
def __init__(self, errors: List[GraphQLClientGraphQLError], data: dict[str, Any]):
self.errors = errors
self.data = data

def __str__(self) -> str:
return "; ".join(str(e) for e in self.errors)

@classmethod
def from_errors_dicts(
cls, errors_dicts: List[dict[str, Any]], data: dict[str, Any]
) -> "GraphQLClientGraphQLMultiError":
return cls(
errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts],
data=data,
)
Empty file.
4 changes: 4 additions & 0 deletions tests/main/custom_config_file/expected_client/scalars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from typing import Any, Callable, Dict

SCALARS_PARSE_FUNCTIONS: Dict[Any, Callable[[str], Any]] = {}
SCALARS_SERIALIZE_FUNCTIONS: Dict[Any, Callable[[Any], str]] = {}
10 changes: 10 additions & 0 deletions tests/main/custom_config_file/expected_client/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pydantic import Field

from .base_model import BaseModel


class Test(BaseModel):
test_query: str = Field(alias="testQuery")


Test.update_forward_refs()
3 changes: 3 additions & 0 deletions tests/main/custom_config_file/queries.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
query test {
testQuery
}
7 changes: 7 additions & 0 deletions tests/main/custom_config_file/schema.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
schema {
query: Query
}

type Query {
testQuery: String!
}
34 changes: 32 additions & 2 deletions tests/main/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from importlib.metadata import version
from pathlib import Path
from typing import List

import httpx
import pytest
Expand All @@ -18,8 +19,7 @@
def project_dir(request, tmp_path):
pyproject_path, files_to_copy = request.param
tmp_path.joinpath("pyproject.toml").write_text(pyproject_path.read_text())
for file_ in files_to_copy:
tmp_path.joinpath(file_.name).write_text(file_.read_text())
copy_files(files_to_copy, tmp_path)

old_cwd = Path.cwd()
os.chdir(tmp_path)
Expand All @@ -29,6 +29,11 @@ def project_dir(request, tmp_path):
os.chdir(old_cwd)


def copy_files(files_to_copy: List[Path], target_dir: Path):
for file_ in files_to_copy:
target_dir.joinpath(file_.name).write_text(file_.read_text())


def assert_the_same_files_in_directories(dir1: Path, dir2: Path):
files1 = [f for f in dir1.glob("*") if f.name != "__pycache__"]
assert [f.name for f in files1] == [
Expand Down Expand Up @@ -209,3 +214,28 @@ def test_main_uses_remote_schema_url_and_remote_schema_headers(
assert mocked_post.called_with(
url="http://test/graphql/", headers={"header1": "value1", "header2": "value2"}
)


def test_main_can_read_config_from_provided_file(tmp_path):
old_cwd = Path.cwd()
files_to_copy = (
Path(__file__).parent / "custom_config_file" / "config.toml",
Path(__file__).parent / "custom_config_file" / "queries.graphql",
Path(__file__).parent / "custom_config_file" / "schema.graphql",
)
copy_files(files_to_copy, tmp_path)
expected_client_path = (
Path(__file__).parent / "custom_config_file" / "expected_client"
)
package_name = "custom_config_client"

os.chdir(tmp_path)
result = CliRunner().invoke(
main, args="--config config.toml", catch_exceptions=False
)
os.chdir(old_cwd)

assert result.exit_code == 0
package_path = tmp_path / package_name
assert package_path.is_dir()
assert_the_same_files_in_directories(package_path, expected_client_path)
Loading

0 comments on commit 0888340

Please sign in to comment.