Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] stream content where possible #349

Merged
merged 1 commit into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 85 additions & 52 deletions fixbackend/inventory/inventory_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import (
Optional,
Expand All @@ -36,6 +37,8 @@
Generic,
Callable,
Any,
AsyncGenerator,
AsyncContextManager,
)

from fixcloudutils.service import Service
Expand Down Expand Up @@ -102,7 +105,27 @@ def __init__(self, inventory_url: str, client: AsyncClient) -> None:
self.inventory_url = inventory_url
self.client = client

async def _perform(
def _check_response(
self,
response: Response,
expected_media_types: Optional[Union[str, Set[str]]],
allowed_error_codes: Optional[Set[int]],
) -> None:
if response.is_error and (allowed_error_codes is None or response.status_code in allowed_error_codes):
if response.status_code == 401:
raise GraphDatabaseForbidden(401, response.text)
elif response.status_code == 400 and "[HTTP 401][ERR 11]" in response.text:
raise GraphDatabaseNotAvailable(503, response.text)
elif response.status_code == 404 and "NoSuchGraph" in response.text:
raise NoSuchGraph(404, response.text)
else:
raise InventoryException(response.status_code, response.text)
if expected_media_types is not None and not response.is_error:
media_type, *params = response.headers.get("content-type", "").split(";")
emt = {expected_media_types} if isinstance(expected_media_types, str) else expected_media_types
assert media_type in emt, f"Expected content type {expected_media_types}, but got {media_type}"

async def _request(
self,
method: str,
path: str,
Expand All @@ -129,67 +152,81 @@ async def _perform(
# If the request takes longer than the defined timeout, we define this as client error (4xx)
raise InventoryRequestTookTooLong(408, f"Request took too long: {e}") from e
else:
if response.is_error and (allowed_error_codes is None or response.status_code in allowed_error_codes):
if response.status_code == 401:
raise GraphDatabaseForbidden(401, response.text)
elif response.status_code == 400 and "[HTTP 401][ERR 11]" in response.text:
raise GraphDatabaseNotAvailable(503, response.text)
elif response.status_code == 404 and "NoSuchGraph" in response.text:
raise NoSuchGraph(404, response.text)
else:
raise InventoryException(response.status_code, response.text)
if expected_media_types is not None and not response.is_error:
media_type, *params = response.headers.get("content-type", "").split(";")
emt = {expected_media_types} if isinstance(expected_media_types, str) else expected_media_types
assert media_type in emt, f"Expected content type {expected_media_types}, but got {media_type}"
self._check_response(response, expected_media_types, allowed_error_codes)
return response

@asynccontextmanager
async def _stream(
self,
method: str,
path: str,
*,
params: Optional[QueryParamTypes] = None,
headers: Optional[Dict[str, str]] = None,
content: Optional[str] = None,
json: Optional[Json] = None,
expected_media_types: Optional[Union[str, Set[str]]] = None,
allowed_error_codes: Optional[Set[int]] = None,
) -> AsyncGenerator[AsyncIteratorWithContext[Json], None]:
try:
async with self.client.stream(
method, self.inventory_url + path, params=params, headers=headers, content=content, json=json
) as response:
self._check_response(response, expected_media_types, allowed_error_codes)
yield AsyncIteratorWithContext(response)

except ConnectError as e:
log.exception(f"Can not connect to inventory: {e}")
raise InventoryException(502, f"Can not connect to inventory: {e}") from e
except ReadTimeout as e:
log.warning(f"Request took too long: {e}")
# If the request takes longer than the defined timeout, we define this as client error (4xx)
raise InventoryRequestTookTooLong(408, f"Request took too long: {e}") from e

async def create_database(self, access: GraphDatabaseAccess, *, graph: str = DefaultGraph) -> None:
log.info(f"Create new database for tenant: {access.workspace_id}")
# Create a new database and an empty graph
await self._perform(
await self._request(
"POST",
f"/graph/{graph}",
read_content=True,
headers=self.__headers(access, content_type=MediaTypeText, FixGraphDbCreateDatabase="true"),
)

async def execute_single(
def execute_single(
self, access: GraphDatabaseAccess, command: str, *, env: Optional[Dict[str, str]] = None
) -> AsyncIteratorWithContext[JsonElement]:
) -> AsyncContextManager[AsyncIteratorWithContext[JsonElement]]:
log.info(f"Execute command: {command}")
response = await self._perform(
return self._stream( # type: ignore
"POST",
"/cli/execute",
content=command,
params=env,
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def search(
def search(
self,
access: GraphDatabaseAccess,
query: str,
*,
graph: str = DefaultGraph,
section: str = DefaultSection,
with_edges: bool = False,
) -> AsyncIteratorWithContext[Json]:
) -> AsyncContextManager[AsyncIteratorWithContext[Json]]:
list_or_graph = "graph" if with_edges else "list"
log.info(f"Search {list_or_graph} with query: {query}")
response = await self._perform(
return self._stream(
"POST",
f"/graph/{graph}/search/{list_or_graph}",
content=query,
params={"section": section},
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def search_history(
def search_history(
self,
access: GraphDatabaseAccess,
query: str,
Expand All @@ -199,7 +236,7 @@ async def search_history(
change: Optional[List[HistoryChange]] = None,
graph: str = DefaultGraph,
section: str = DefaultSection,
) -> AsyncIteratorWithContext[Json]:
) -> AsyncContextManager[AsyncIteratorWithContext[Json]]:
log.info(f"Search list with query: {query}")
params: Dict[str, str] = {"section": section}
if before:
Expand All @@ -208,34 +245,32 @@ async def search_history(
params["after"] = utc_str(after)
if change:
params["change"] = ",".join(c.value for c in change)
response = await self._perform(
return self._stream(
"POST",
f"/graph/{graph}/search/history/list",
content=query,
params=params,
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def aggregate(
def aggregate(
self,
access: GraphDatabaseAccess,
query: str,
*,
graph: str = DefaultGraph,
section: str = DefaultSection,
) -> AsyncIteratorWithContext[Json]:
) -> AsyncContextManager[AsyncIteratorWithContext[Json]]:
log.info(f"Aggregate with query: {query}")
response = await self._perform(
return self._stream(
"POST",
f"/graph/{graph}/search/aggregate",
content=query,
params={"section": section},
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def benchmarks(
self,
Expand All @@ -258,7 +293,7 @@ async def benchmarks(
params["with_checks"] = with_checks
if ids_only is not None:
params["ids_only"] = ids_only
response = await self._perform(
response = await self._request(
"GET",
"/report/benchmarks",
params=params,
Expand Down Expand Up @@ -297,7 +332,7 @@ async def checks(
params["id"] = ",".join(check_ids)
if ids_only is not None:
params["ids_only"] = ids_only
response = await self._perform(
response = await self._request(
"GET",
"/report/checks",
params=params,
Expand All @@ -317,9 +352,10 @@ async def delete_account(
) -> None:
log.info(f"Delete account {account_id} from cloud {cloud}")
query = f'is(account) and id=={account_id} and /ancestors.cloud.reported.name=="{cloud}" limit 1'
async for node in await self.search(access, query):
node_id = node["id"]
await self._perform("DELETE", f"/graph/{graph}/node/{node_id}", headers=self.__headers(access))
async with self.search(access, query) as results:
async for node in results:
node_id = node["id"]
await self._request("DELETE", f"/graph/{graph}/node/{node_id}", headers=self.__headers(access))

async def complete_property_path(
self,
Expand All @@ -333,7 +369,7 @@ async def complete_property_path(
f"Complete property path path={request.path}, prop={request.prop}, kinds={len(request.kinds or [])}, "
f"fuzzy={request.fuzzy}, skip={request.skip}, limit={request.limit}"
)
response = await self._perform(
response = await self._request(
"POST",
f"/graph/{graph}/property/path/complete",
json=request.model_dump(),
Expand All @@ -345,7 +381,7 @@ async def complete_property_path(
count = int(response.headers.get("Total-Count", "0"))
return count, cast(Dict[str, str], response.json())

async def possible_values(
def possible_values(
self,
access: GraphDatabaseAccess,
*,
Expand All @@ -357,7 +393,7 @@ async def possible_values(
count: bool = False,
graph: str = DefaultGraph,
section: str = DefaultSection,
) -> AsyncIteratorWithContext[JsonElement]:
) -> AsyncContextManager[AsyncIteratorWithContext[JsonElement]]:
log.info(f"Get possible values with query: {query}, prop_or_predicate: {prop_or_predicate} on detail: {detail}")
params = {
"section": section,
Expand All @@ -366,21 +402,19 @@ async def possible_values(
"skip": str(skip),
"count": json.dumps(count),
}
response = await self._perform(
return self._stream( # type: ignore
"POST",
f"/graph/{graph}/property/{detail}",
content=query,
params=params,
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
read_content=True,
)
return AsyncIteratorWithContext(response)

async def resource(self, access: GraphDatabaseAccess, *, id: NodeId, graph: str = DefaultGraph) -> Optional[Json]:
log.info(f"Get resource with id: {id}")
headers = self.__headers(access, accept=MediaTypeJson, content_type=MediaTypeText)
response = await self._perform(
response = await self._request(
"GET",
f"/graph/{graph}/node/{id}",
headers=headers,
Expand Down Expand Up @@ -421,7 +455,7 @@ async def model(
"with_relatives": json.dumps(with_relatives),
"with_metadata": json.dumps(with_metadata),
}
response = await self._perform(
response = await self._request(
"GET",
f"/graph/{graph}/model",
params=params,
Expand All @@ -431,7 +465,7 @@ async def model(
)
return cast(List[Json], response.json())

async def timeseries(
def timeseries(
self,
access: GraphDatabaseAccess,
name: str,
Expand All @@ -441,7 +475,7 @@ async def timeseries(
group: Optional[Set[str]] = None,
filter_group: Optional[List[str]] = None,
granularity: Optional[int | timedelta] = None,
) -> AsyncIteratorWithContext[Json]:
) -> AsyncContextManager[AsyncIteratorWithContext[Json]]:
log.info(
f"Get timeseries with name: {name}, start: {start}, end: {end}, "
f"group: {group}, filter: {filter_group}, granularity: {granularity}"
Expand All @@ -459,14 +493,13 @@ async def timeseries(
if granularity:
value = granularity if isinstance(granularity, int) else f"{granularity.total_seconds()}s"
body["granularity"] = value
response = await self._perform(
return self._stream(
"POST",
f"/timeseries/{name}",
json=body,
headers=headers,
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def update_node(
self,
Expand All @@ -478,7 +511,7 @@ async def update_node(
section: str = DefaultSection,
) -> Json:
log.info(f"Update node with id: {id}")
response = await self._perform(
response = await self._request(
"PATCH",
f"/graph/{graph}/node/{node_id}",
json=patch,
Expand All @@ -490,7 +523,7 @@ async def update_node(
return cast(Json, response.json())

async def config(self, access: GraphDatabaseAccess, config_id: str) -> Json:
response = await self._perform(
response = await self._request(
"GET",
f"/config/{config_id}",
headers=self.__headers(access, accept=MediaTypeJson, content_type=MediaTypeJson),
Expand All @@ -502,7 +535,7 @@ async def config(self, access: GraphDatabaseAccess, config_id: str) -> Json:
async def update_config(
self, access: GraphDatabaseAccess, config_id: str, update: Json, *, patch: bool = False
) -> Json:
response = await self._perform(
response = await self._request(
"PATCH" if patch else "PUT",
f"/config/{config_id}",
json=update,
Expand Down Expand Up @@ -530,7 +563,7 @@ async def call_json(
)
if extra_headers is not None:
headers.update(extra_headers)
response = await self._perform(
response = await self._request(
method,
path,
json=body,
Expand Down
Loading
Loading