[feat] stream content where possible (#349)
aquamatthias authored Mar 22, 2024
1 parent be74794 commit d9a490b
Showing 8 changed files with 317 additions and 221 deletions.
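The diff below replaces the single buffered _perform helper with two paths: a plain _request that still reads the whole response body, and a new _stream helper decorated with @asynccontextmanager that wraps httpx.AsyncClient.stream, so ND-JSON results are consumed incrementally instead of being buffered in memory. A minimal sketch of that streaming pattern, assuming a bare httpx client; stream_ndjson and its error handling are illustrative stand-ins for the real _stream, _check_response, and AsyncIteratorWithContext:

import json
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator

from httpx import AsyncClient


@asynccontextmanager
async def stream_ndjson(client: AsyncClient, url: str, query: str) -> AsyncGenerator[AsyncIterator[Any], None]:
    # The HTTP connection stays open for the lifetime of this context manager,
    # so callers must finish iterating inside the async-with block.
    async with client.stream("POST", url, content=query) as response:
        if response.is_error:
            await response.aread()  # a streamed body must be read before .text is available
            raise RuntimeError(f"Inventory error {response.status_code}: {response.text}")

        async def lines() -> AsyncIterator[Any]:
            async for line in response.aiter_lines():
                if line:
                    yield json.loads(line)

        yield lines()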
137 changes: 85 additions & 52 deletions fixbackend/inventory/inventory_client.py
@@ -21,6 +21,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import logging
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import (
Optional,
@@ -36,6 +37,8 @@
Generic,
Callable,
Any,
AsyncGenerator,
AsyncContextManager,
)

from fixcloudutils.service import Service
@@ -102,7 +105,27 @@ def __init__(self, inventory_url: str, client: AsyncClient) -> None:
self.inventory_url = inventory_url
self.client = client

async def _perform(
def _check_response(
self,
response: Response,
expected_media_types: Optional[Union[str, Set[str]]],
allowed_error_codes: Optional[Set[int]],
) -> None:
if response.is_error and (allowed_error_codes is None or response.status_code in allowed_error_codes):
if response.status_code == 401:
raise GraphDatabaseForbidden(401, response.text)
elif response.status_code == 400 and "[HTTP 401][ERR 11]" in response.text:
raise GraphDatabaseNotAvailable(503, response.text)
elif response.status_code == 404 and "NoSuchGraph" in response.text:
raise NoSuchGraph(404, response.text)
else:
raise InventoryException(response.status_code, response.text)
if expected_media_types is not None and not response.is_error:
media_type, *params = response.headers.get("content-type", "").split(";")
emt = {expected_media_types} if isinstance(expected_media_types, str) else expected_media_types
assert media_type in emt, f"Expected content type {expected_media_types}, but got {media_type}"

async def _request(
self,
method: str,
path: str,
@@ -129,67 +152,81 @@ async def _perform(
# If the request takes longer than the defined timeout, we define this as client error (4xx)
raise InventoryRequestTookTooLong(408, f"Request took too long: {e}") from e
else:
if response.is_error and (allowed_error_codes is None or response.status_code in allowed_error_codes):
if response.status_code == 401:
raise GraphDatabaseForbidden(401, response.text)
elif response.status_code == 400 and "[HTTP 401][ERR 11]" in response.text:
raise GraphDatabaseNotAvailable(503, response.text)
elif response.status_code == 404 and "NoSuchGraph" in response.text:
raise NoSuchGraph(404, response.text)
else:
raise InventoryException(response.status_code, response.text)
if expected_media_types is not None and not response.is_error:
media_type, *params = response.headers.get("content-type", "").split(";")
emt = {expected_media_types} if isinstance(expected_media_types, str) else expected_media_types
assert media_type in emt, f"Expected content type {expected_media_types}, but got {media_type}"
self._check_response(response, expected_media_types, allowed_error_codes)
return response

@asynccontextmanager
async def _stream(
self,
method: str,
path: str,
*,
params: Optional[QueryParamTypes] = None,
headers: Optional[Dict[str, str]] = None,
content: Optional[str] = None,
json: Optional[Json] = None,
expected_media_types: Optional[Union[str, Set[str]]] = None,
allowed_error_codes: Optional[Set[int]] = None,
) -> AsyncGenerator[AsyncIteratorWithContext[Json], None]:
try:
async with self.client.stream(
method, self.inventory_url + path, params=params, headers=headers, content=content, json=json
) as response:
self._check_response(response, expected_media_types, allowed_error_codes)
yield AsyncIteratorWithContext(response)

except ConnectError as e:
log.exception(f"Can not connect to inventory: {e}")
raise InventoryException(502, f"Can not connect to inventory: {e}") from e
except ReadTimeout as e:
log.warning(f"Request took too long: {e}")
# If the request takes longer than the defined timeout, we define this as client error (4xx)
raise InventoryRequestTookTooLong(408, f"Request took too long: {e}") from e

async def create_database(self, access: GraphDatabaseAccess, *, graph: str = DefaultGraph) -> None:
log.info(f"Create new database for tenant: {access.workspace_id}")
# Create a new database and an empty graph
await self._perform(
await self._request(
"POST",
f"/graph/{graph}",
read_content=True,
headers=self.__headers(access, content_type=MediaTypeText, FixGraphDbCreateDatabase="true"),
)

async def execute_single(
def execute_single(
self, access: GraphDatabaseAccess, command: str, *, env: Optional[Dict[str, str]] = None
) -> AsyncIteratorWithContext[JsonElement]:
) -> AsyncContextManager[AsyncIteratorWithContext[JsonElement]]:
log.info(f"Execute command: {command}")
response = await self._perform(
return self._stream( # type: ignore
"POST",
"/cli/execute",
content=command,
params=env,
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def search(
def search(
self,
access: GraphDatabaseAccess,
query: str,
*,
graph: str = DefaultGraph,
section: str = DefaultSection,
with_edges: bool = False,
) -> AsyncIteratorWithContext[Json]:
) -> AsyncContextManager[AsyncIteratorWithContext[Json]]:
list_or_graph = "graph" if with_edges else "list"
log.info(f"Search {list_or_graph} with query: {query}")
response = await self._perform(
return self._stream(
"POST",
f"/graph/{graph}/search/{list_or_graph}",
content=query,
params={"section": section},
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def search_history(
def search_history(
self,
access: GraphDatabaseAccess,
query: str,
@@ -199,7 +236,7 @@ async def search_history(
change: Optional[List[HistoryChange]] = None,
graph: str = DefaultGraph,
section: str = DefaultSection,
) -> AsyncIteratorWithContext[Json]:
) -> AsyncContextManager[AsyncIteratorWithContext[Json]]:
log.info(f"Search list with query: {query}")
params: Dict[str, str] = {"section": section}
if before:
@@ -208,34 +245,32 @@ async def search_history(
params["after"] = utc_str(after)
if change:
params["change"] = ",".join(c.value for c in change)
response = await self._perform(
return self._stream(
"POST",
f"/graph/{graph}/search/history/list",
content=query,
params=params,
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def aggregate(
def aggregate(
self,
access: GraphDatabaseAccess,
query: str,
*,
graph: str = DefaultGraph,
section: str = DefaultSection,
) -> AsyncIteratorWithContext[Json]:
) -> AsyncContextManager[AsyncIteratorWithContext[Json]]:
log.info(f"Aggregate with query: {query}")
response = await self._perform(
return self._stream(
"POST",
f"/graph/{graph}/search/aggregate",
content=query,
params={"section": section},
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def benchmarks(
self,
@@ -258,7 +293,7 @@ async def benchmarks(
params["with_checks"] = with_checks
if ids_only is not None:
params["ids_only"] = ids_only
response = await self._perform(
response = await self._request(
"GET",
"/report/benchmarks",
params=params,
@@ -297,7 +332,7 @@ async def checks(
params["id"] = ",".join(check_ids)
if ids_only is not None:
params["ids_only"] = ids_only
response = await self._perform(
response = await self._request(
"GET",
"/report/checks",
params=params,
@@ -317,9 +352,10 @@ async def delete_account(
) -> None:
log.info(f"Delete account {account_id} from cloud {cloud}")
query = f'is(account) and id=={account_id} and /ancestors.cloud.reported.name=="{cloud}" limit 1'
async for node in await self.search(access, query):
node_id = node["id"]
await self._perform("DELETE", f"/graph/{graph}/node/{node_id}", headers=self.__headers(access))
async with self.search(access, query) as results:
async for node in results:
node_id = node["id"]
await self._request("DELETE", f"/graph/{graph}/node/{node_id}", headers=self.__headers(access))

async def complete_property_path(
self,
@@ -333,7 +369,7 @@ async def complete_property_path(
f"Complete property path path={request.path}, prop={request.prop}, kinds={len(request.kinds or [])}, "
f"fuzzy={request.fuzzy}, skip={request.skip}, limit={request.limit}"
)
response = await self._perform(
response = await self._request(
"POST",
f"/graph/{graph}/property/path/complete",
json=request.model_dump(),
@@ -345,7 +381,7 @@ async def complete_property_path(
count = int(response.headers.get("Total-Count", "0"))
return count, cast(Dict[str, str], response.json())

async def possible_values(
def possible_values(
self,
access: GraphDatabaseAccess,
*,
@@ -357,7 +393,7 @@ async def possible_values(
count: bool = False,
graph: str = DefaultGraph,
section: str = DefaultSection,
) -> AsyncIteratorWithContext[JsonElement]:
) -> AsyncContextManager[AsyncIteratorWithContext[JsonElement]]:
log.info(f"Get possible values with query: {query}, prop_or_predicate: {prop_or_predicate} on detail: {detail}")
params = {
"section": section,
@@ -366,21 +402,19 @@ async def possible_values(
"skip": str(skip),
"count": json.dumps(count),
}
response = await self._perform(
return self._stream( # type: ignore
"POST",
f"/graph/{graph}/property/{detail}",
content=query,
params=params,
headers=self.__headers(access, accept=MediaTypeNdJson, content_type=MediaTypeText),
expected_media_types=ExpectMediaTypeNdJson,
read_content=True,
)
return AsyncIteratorWithContext(response)

async def resource(self, access: GraphDatabaseAccess, *, id: NodeId, graph: str = DefaultGraph) -> Optional[Json]:
log.info(f"Get resource with id: {id}")
headers = self.__headers(access, accept=MediaTypeJson, content_type=MediaTypeText)
response = await self._perform(
response = await self._request(
"GET",
f"/graph/{graph}/node/{id}",
headers=headers,
Expand Down Expand Up @@ -421,7 +455,7 @@ async def model(
"with_relatives": json.dumps(with_relatives),
"with_metadata": json.dumps(with_metadata),
}
response = await self._perform(
response = await self._request(
"GET",
f"/graph/{graph}/model",
params=params,
@@ -431,7 +465,7 @@ async def model(
)
return cast(List[Json], response.json())

async def timeseries(
def timeseries(
self,
access: GraphDatabaseAccess,
name: str,
@@ -441,7 +475,7 @@ async def timeseries(
group: Optional[Set[str]] = None,
filter_group: Optional[List[str]] = None,
granularity: Optional[int | timedelta] = None,
) -> AsyncIteratorWithContext[Json]:
) -> AsyncContextManager[AsyncIteratorWithContext[Json]]:
log.info(
f"Get timeseries with name: {name}, start: {start}, end: {end}, "
f"group: {group}, filter: {filter_group}, granularity: {granularity}"
@@ -459,14 +493,13 @@ async def timeseries(
if granularity:
value = granularity if isinstance(granularity, int) else f"{granularity.total_seconds()}s"
body["granularity"] = value
response = await self._perform(
return self._stream(
"POST",
f"/timeseries/{name}",
json=body,
headers=headers,
expected_media_types=ExpectMediaTypeNdJson,
)
return AsyncIteratorWithContext(response)

async def update_node(
self,
@@ -478,7 +511,7 @@ async def update_node(
section: str = DefaultSection,
) -> Json:
log.info(f"Update node with id: {id}")
response = await self._perform(
response = await self._request(
"PATCH",
f"/graph/{graph}/node/{node_id}",
json=patch,
@@ -490,7 +523,7 @@ async def update_node(
return cast(Json, response.json())

async def config(self, access: GraphDatabaseAccess, config_id: str) -> Json:
response = await self._perform(
response = await self._request(
"GET",
f"/config/{config_id}",
headers=self.__headers(access, accept=MediaTypeJson, content_type=MediaTypeJson),
@@ -502,7 +535,7 @@ async def config(self, access: GraphDatabaseAccess, config_id: str) -> Json:
async def update_config(
self, access: GraphDatabaseAccess, config_id: str, update: Json, *, patch: bool = False
) -> Json:
response = await self._perform(
response = await self._request(
"PATCH" if patch else "PUT",
f"/config/{config_id}",
json=update,
@@ -530,7 +563,7 @@ async def call_json(
)
if extra_headers is not None:
headers.update(extra_headers)
response = await self._perform(
response = await self._request(
method,
path,
json=body,
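After this change, search, search_history, aggregate, possible_values, execute_single, and timeseries return an async context manager instead of an awaitable iterator, so call sites move from "async for node in await client.search(...)" to the nested form shown in delete_account above. A usage sketch; the class name InventoryClient and the count_accounts helper are assumptions for illustration:

async def count_accounts(client: InventoryClient, access: GraphDatabaseAccess) -> int:
    # All iteration has to happen inside the async-with block: leaving it
    # closes the underlying streamed HTTP response.
    total = 0
    async with client.search(access, "is(account)") as results:
        async for _node in results:
            total += 1
    return total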

