Skip to content

Commit

Permalink
Fix deserialization error for LRO which has discriminator (#2628)
Browse files Browse the repository at this point in the history
* code

* fix for legacy test

* inv and black

* fix mypy

* fix pyright error

* fix pylint

* inv

* update

* review

* fix

* fix

* Fix test

* fix multiapi  test

* disable deserialize for all initial operation

* review

* inv

* update changelog

* inv

* inv

* force initial operation to return stream

* revert extra changes in builder_serializer

* regen

* regen lropaging

* regen with load_body for aiohttp

* fix

* inv

* use pipeline_response.http_response for legacy

* fix test

* inv

* read in response

* inv

* fix multiapi test

* inv

* fix pyright

---------

Co-authored-by: iscai-msft <[email protected]>
  • Loading branch information
msyyc and iscai-msft authored Jun 11, 2024
1 parent 76be680 commit 76168e4
Show file tree
Hide file tree
Showing 871 changed files with 5,481 additions and 15,927 deletions.
8 changes: 8 additions & 0 deletions .chronus/changes/deserialization-fix-2024-4-24-16-48-41.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
changeKind: fix
packages:
- "@autorest/python"
- "@azure-tools/typespec-python"
---

Fix deserialization error for lro when return type has discriminator and succeed in initial response
8 changes: 8 additions & 0 deletions packages/autorest.python/autorest/codegen/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def __init__(
self.has_etag: bool = self.yaml_data.get("hasEtag", False)
self.cross_language_definition_id: Optional[str] = self.yaml_data.get("crossLanguageDefinitionId")

@property
def stream_value(self) -> Union[str, bool]:
return (
f'kwargs.pop("stream", {self.has_stream_response})'
if self.expose_stream_keyword and self.has_response_body
else self.has_stream_response
)

@property
def has_form_data_body(self):
return self.parameters.has_form_data_body
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,13 +560,8 @@ def example_template(self, builder: OperationType) -> List[str]:

def make_pipeline_call(self, builder: OperationType) -> List[str]:
type_ignore = self.async_mode and builder.group_name == "" # is in a mixin
stream_value = (
f'kwargs.pop("stream", {builder.has_stream_response})'
if builder.expose_stream_keyword and builder.has_response_body
else builder.has_stream_response
)
return [
f"_stream = {stream_value}",
f"_stream = {builder.stream_value}",
f"pipeline_response: PipelineResponse = {self._call_method}self._client.{self.pipeline_name}.run( "
+ f"{'# type: ignore' if type_ignore else ''} # pylint: disable=protected-access",
" _request,",
Expand Down Expand Up @@ -925,7 +920,7 @@ def response_headers_and_deserialization(
if self.code_model.options["models_mode"] == "msrest":
deserialize_code.append("deserialized = self._deserialize(")
deserialize_code.append(f" '{response.serialization_type}',{pylint_disable}")
deserialize_code.append(" pipeline_response")
deserialize_code.append(" pipeline_response.http_response")
deserialize_code.append(")")
elif self.code_model.options["models_mode"] == "dpg":
if builder.has_stream_response:
Expand Down Expand Up @@ -964,12 +959,11 @@ def response_headers_and_deserialization(
def handle_error_response(self, builder: OperationType) -> List[str]:
async_await = "await " if self.async_mode else ""
retval = [f"if response.status_code not in {str(builder.success_status_codes)}:"]
retval.extend(
[
" if _stream:",
f" {async_await} response.read() # Load the body in memory and close the socket",
]
)
response_read = f" {async_await}response.read() # Load the body in memory and close the socket"
if builder.stream_value is True: # _stream is True so no need to judge it
retval.append(response_read)
elif isinstance(builder.stream_value, str): # _stream is not sure, so we need to judge it
retval.extend([" if _stream:", f" {response_read}"])
type_ignore = " # type: ignore" if _need_type_ignore(builder) else ""
retval.append(
f" map_error(status_code=response.status_code, response=response, error_map=error_map){type_ignore}"
Expand Down Expand Up @@ -1218,12 +1212,15 @@ def _extract_data_callback(self, builder: PagingOperationType) -> List[str]:
response = builder.responses[0]
deserialized = "pipeline_response.http_response.json()"
if self.code_model.options["models_mode"] == "msrest":
suffix = ".http_response" if hasattr(builder, "initial_operation") else ""
deserialize_type = response.serialization_type
pylint_disable = " # pylint: disable=protected-access"
if isinstance(response.type, ModelType) and not response.type.internal:
deserialize_type = f'"{response.serialization_type}"'
pylint_disable = ""
deserialized = f"self._deserialize(\n {deserialize_type},{pylint_disable}\n pipeline_response\n)"
deserialized = (
f"self._deserialize(\n {deserialize_type},{pylint_disable}\n pipeline_response{suffix}\n)"
)
retval.append(f" deserialized = {deserialized}")
elif self.code_model.options["models_mode"] == "dpg":
# we don't want to generate paging models for DPG
Expand Down Expand Up @@ -1318,6 +1315,8 @@ def initial_call(self, builder: LROOperationType) -> List[str]:
retval.append(" params=_params,")
retval.append(" **kwargs")
retval.append(" )")
retval.append(f" {'await ' if self.async_mode else ''}raw_result.http_response.read() # type: ignore")

retval.append("kwargs.pop('error_map', None)")
return retval

Expand Down
7 changes: 7 additions & 0 deletions packages/autorest.python/autorest/preprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,18 @@ def update_lro_operation(
yaml_data: Dict[str, Any],
is_overload: bool = False,
) -> None:
def convert_initial_operation_response_type(data: Dict[str, Any]) -> None:
for response in data.get("responses", []):
response["type"] = KNOWN_TYPES["binary"]

self.update_operation(code_model, yaml_data, is_overload=is_overload)
self.update_operation(code_model, yaml_data["initialOperation"], is_overload=is_overload)
convert_initial_operation_response_type(yaml_data["initialOperation"])
self._update_lro_operation_helper(yaml_data)
for overload in yaml_data.get("overloads", []):
self._update_lro_operation_helper(overload)
self.update_operation(code_model, overload["initialOperation"], is_overload=True)
convert_initial_operation_response_type(overload["initialOperation"])

def update_paging_operation(
self,
Expand Down Expand Up @@ -466,6 +472,7 @@ def update_operation_groups(self, code_model: Dict[str, Any], client: Dict[str,
def update_yaml(self, yaml_data: Dict[str, Any]) -> None:
"""Convert in place the YAML str."""
self.update_types(yaml_data["types"])
yaml_data["types"] += KNOWN_TYPES.values()
for client in yaml_data["clients"]:
self.update_client(client)
self.update_operation_groups(yaml_data, client)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ async def head200(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -130,8 +128,6 @@ async def head204(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -173,8 +169,6 @@ async def head404(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def head200(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -150,8 +148,6 @@ def head204(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -193,8 +189,6 @@ def head404(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ async def head200(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -130,8 +128,6 @@ async def head204(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -173,8 +169,6 @@ async def head404(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def head200(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [200, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -150,8 +148,6 @@ def head204(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down Expand Up @@ -193,8 +189,6 @@ def head404(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return
response = pipeline_response.http_response

if response.status_code not in [204, 404]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# --------------------------------------------------------------------------
from io import IOBase
import sys
from typing import Any, Callable, Dict, IO, Iterable, Optional, Type, TypeVar, Union, cast, overload
from typing import Any, Callable, Dict, IO, Iterable, Iterator, Optional, Type, TypeVar, Union, cast, overload

from my.library import CustomDefaultPollingMethod, CustomPager, CustomPoller

Expand Down Expand Up @@ -74,7 +74,9 @@ def build_polling_paging_example_basic_paging_request(**kwargs: Any) -> HttpRequ

class PollingPagingExampleOperationsMixin(PollingPagingExampleMixinABC):

def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = None, **kwargs: Any) -> Optional[JSON]:
def _basic_polling_initial(
self, product: Optional[Union[JSON, IO[bytes]]] = None, **kwargs: Any
) -> Iterator[bytes]:
error_map: MutableMapping[int, Type[HttpResponseError]] = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
Expand All @@ -87,7 +89,7 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non
_params = kwargs.pop("params", {}) or {}

content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[Optional[JSON]] = kwargs.pop("cls", None)
cls: ClsType[Iterator[bytes]] = kwargs.pop("cls", None)

content_type = content_type or "application/json"
_json = None
Expand All @@ -109,30 +111,28 @@ def _basic_polling_initial(self, product: Optional[Union[JSON, IO[bytes]]] = Non
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = True
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200, 204]:
if _stream:
response.read() # Load the body in memory and close the socket
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
else:
deserialized = None
deserialized = response.iter_bytes()

if response.status_code == 204:
deserialized = response.iter_bytes()

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
return cls(pipeline_response, cast(Iterator[bytes], deserialized), {}) # type: ignore

return deserialized # type: ignore
return cast(Iterator[bytes], deserialized) # type: ignore

@overload
def begin_basic_polling(
Expand Down Expand Up @@ -245,6 +245,7 @@ def begin_basic_polling(
params=_params,
**kwargs
)
raw_result.http_response.read() # type: ignore
kwargs.pop("error_map", None)

def get_long_running_output(pipeline_response):
Expand Down Expand Up @@ -336,8 +337,6 @@ def get_next(next_link=None):
response = pipeline_response.http_response

if response.status_code not in [200]:
if _stream:
response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# --------------------------------------------------------------------------
from io import IOBase
import sys
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload
from typing import Any, AsyncIterable, AsyncIterator, Callable, Dict, IO, Optional, Type, TypeVar, Union, cast, overload

from my.library.aio import AsyncCustomDefaultPollingMethod, AsyncCustomPager, AsyncCustomPoller

Expand Down Expand Up @@ -47,7 +47,7 @@ class PollingPagingExampleOperationsMixin(PollingPagingExampleMixinABC):

async def _basic_polling_initial(
self, product: Optional[Union[JSON, IO[bytes]]] = None, **kwargs: Any
) -> Optional[JSON]:
) -> AsyncIterator[bytes]:
error_map: MutableMapping[int, Type[HttpResponseError]] = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
Expand All @@ -60,7 +60,7 @@ async def _basic_polling_initial(
_params = kwargs.pop("params", {}) or {}

content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
cls: ClsType[Optional[JSON]] = kwargs.pop("cls", None)
cls: ClsType[AsyncIterator[bytes]] = kwargs.pop("cls", None)

content_type = content_type or "application/json"
_json = None
Expand All @@ -82,30 +82,28 @@ async def _basic_polling_initial(
)
_request.url = self._client.format_url(_request.url)

_stream = False
_stream = True
pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200, 204]:
if _stream:
await response.read() # Load the body in memory and close the socket
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

deserialized = None
if response.status_code == 200:
if response.content:
deserialized = response.json()
else:
deserialized = None
deserialized = response.iter_bytes()

if response.status_code == 204:
deserialized = response.iter_bytes()

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore
return cls(pipeline_response, cast(AsyncIterator[bytes], deserialized), {}) # type: ignore

return deserialized # type: ignore
return cast(AsyncIterator[bytes], deserialized) # type: ignore

@overload
async def begin_basic_polling(
Expand Down Expand Up @@ -218,6 +216,7 @@ async def begin_basic_polling(
params=_params,
**kwargs
)
await raw_result.http_response.read() # type: ignore
kwargs.pop("error_map", None)

def get_long_running_output(pipeline_response):
Expand Down Expand Up @@ -313,8 +312,6 @@ async def get_next(next_link=None):
response = pipeline_response.http_response

if response.status_code not in [200]:
if _stream:
await response.read() # Load the body in memory and close the socket
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

Expand Down
Loading

0 comments on commit 76168e4

Please sign in to comment.