diff --git a/curl_cffi/requests/__init__.py b/curl_cffi/requests/__init__.py index 9a3465ad..2c4b054b 100644 --- a/curl_cffi/requests/__init__.py +++ b/curl_cffi/requests/__init__.py @@ -58,6 +58,7 @@ def request( impersonate: Optional[Union[str, BrowserType]] = None, thread: Optional[ThreadType] = None, default_headers: Optional[bool] = None, + default_encoding: Union[str, Callable[[bytes], str]] = "utf-8", curl_options: Optional[dict] = None, http_version: Optional[CurlHttpVersion] = None, debug: bool = False, @@ -90,6 +91,8 @@ def request( impersonate: which browser version to impersonate. thread: work with other thread implementations. choices: eventlet, gevent. default_headers: whether to set default browser headers. + default_encoding: encoding for decoding response content if charset is not found in headers. + Defaults to "utf-8". Can be set to a callable for automatic detection. curl_options: extra curl options to use. http_version: limiting http version, http2 will be tries by default. debug: print extra curl debug info. @@ -122,6 +125,7 @@ def request( content_callback=content_callback, impersonate=impersonate, default_headers=default_headers, + default_encoding=default_encoding, http_version=http_version, interface=interface, multipart=multipart, diff --git a/curl_cffi/requests/models.py b/curl_cffi/requests/models.py index 87670ffe..5615cbd9 100644 --- a/curl_cffi/requests/models.py +++ b/curl_cffi/requests/models.py @@ -1,14 +1,17 @@ import queue +import re import warnings from concurrent.futures import Future from json import loads -from typing import Any, Awaitable, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from .. import Curl from .cookies import Cookies from .errors import RequestsError from .headers import Headers +CHARSET_RE = re.compile(r"charset=([\w-]+)") + def clear_queue(q: queue.Queue): with q.mutex: @@ -40,7 +43,8 @@ class Response: cookies: response cookies. elapsed: how many seconds the request cost. encoding: http body encoding. - charset: alias for encoding. + charset_encoding: encoding specified by the Content-Type header. + default_encoding: user-defined encoding used for decoding content if charset is not found in headers. redirect_count: how many redirects happened. redirect_url: the final redirected url. http_version: http version used. @@ -58,8 +62,7 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non self.headers = Headers() self.cookies = Cookies() self.elapsed = 0.0 - self.encoding = "utf-8" - self.charset = self.encoding + self.default_encoding: Union[str, Callable[[bytes], str]] = "utf-8" self.redirect_count = 0 self.redirect_url = "" self.http_version = 0 @@ -70,15 +73,38 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non self.astream_task: Optional[Awaitable] = None self.quit_now = None + @property + def encoding(self) -> str: + if not hasattr(self, "_encoding"): + encoding = self.charset_encoding + if encoding is None: + if isinstance(self.default_encoding, str): + encoding = self.default_encoding + elif callable(self.default_encoding): + encoding = self.default_encoding(self.content) + self._encoding = encoding or "utf-8" + return self._encoding + + @property + def charset_encoding(self) -> Optional[str]: + """Return the encoding, as specified by the Content-Type header.""" + content_type = self.headers.get("Content-Type") + if content_type: + charset_match = CHARSET_RE.search(content_type) + return charset_match.group(1) if charset_match else None + return None + def _decode(self, content: bytes) -> str: try: - return content.decode(self.charset, errors="replace") + return content.decode(self.encoding, errors="replace") except (UnicodeDecodeError, LookupError): return content.decode("utf-8-sig") @property def text(self) -> str: - return self._decode(self.content) + if not hasattr(self, "_text"): + self._text = self._decode(self.content) + return self._text def raise_for_status(self): """Raise an error if status code is not in [200, 400)""" diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index fdaab078..a3808cdf 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -1,7 +1,6 @@ import asyncio import math import queue -import re import threading import warnings from concurrent.futures import ThreadPoolExecutor @@ -25,8 +24,6 @@ ) from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urljoin, urlparse -from charset_normalizer import detect - from .. import AsyncCurl, Curl, CurlError, CurlHttpVersion, CurlInfo, CurlOpt from ..curl import CURL_WRITEFUNC_ERROR, CurlMime from .cookies import Cookies, CookieTypes, CurlMorsel @@ -57,7 +54,6 @@ class ProxySpec(TypedDict, total=False): else: ProxySpec = Dict[str, str] -CHARSET_RE = re.compile(r"charset=([\w-]+)") ThreadType = Literal["eventlet", "gevent"] @@ -207,6 +203,7 @@ def __init__( max_redirects: int = -1, impersonate: Optional[Union[str, BrowserType]] = None, default_headers: bool = True, + default_encoding: Union[str, Callable[[bytes], str]] = "utf-8", curl_options: Optional[dict] = None, curl_infos: Optional[list] = None, http_version: Optional[CurlHttpVersion] = None, @@ -226,6 +223,7 @@ def __init__( self.max_redirects = max_redirects self.impersonate = impersonate self.default_headers = default_headers + self.default_encoding = default_encoding self.curl_options = curl_options or {} self.curl_infos = curl_infos or [] self.http_version = http_version @@ -540,7 +538,7 @@ def qput(chunk): return req, buffer, header_buffer, q, header_recved, quit_now - def _parse_response(self, curl, buffer, header_buffer): + def _parse_response(self, curl, buffer, header_buffer, default_encoding): c = curl rsp = Response(c) rsp.url = cast(bytes, c.getinfo(CurlInfo.EFFECTIVE_URL)).decode() @@ -576,11 +574,7 @@ def _parse_response(self, curl, buffer, header_buffer): rsp.cookies = self.cookies # print("Cookies after extraction", self.cookies) - content_type = rsp.headers.get("Content-Type", default="") - charset_match = CHARSET_RE.search(content_type) - charset = charset_match.group(1) if charset_match else detect(rsp.content)["encoding"] - rsp.charset = rsp.encoding = charset or "utf-8" - + rsp.default_encoding = default_encoding rsp.elapsed = cast(float, c.getinfo(CurlInfo.TOTAL_TIME)) rsp.redirect_count = cast(int, c.getinfo(CurlInfo.REDIRECT_COUNT)) rsp.redirect_url = cast(bytes, c.getinfo(CurlInfo.REDIRECT_URL)).decode() @@ -630,6 +624,8 @@ def __init__( max_redirects: max redirect counts, default unlimited(-1). impersonate: which browser version to impersonate in the session. interface: which interface use in request to server. + default_encoding: encoding for decoding response content if charset is not found in headers. + Defaults to "utf-8". Can be set to a callable for automatic detection. Notes: This class can be used as a context manager. @@ -758,6 +754,7 @@ def request( content_callback: Optional[Callable] = None, impersonate: Optional[Union[str, BrowserType]] = None, default_headers: Optional[bool] = None, + default_encoding: Union[str, Callable[[bytes], str]] = "utf-8", http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, cert: Optional[Union[str, Tuple[str, str]]] = None, @@ -816,7 +813,7 @@ def perform(): try: c.perform() except CurlError as e: - rsp = self._parse_response(c, buffer, header_buffer) + rsp = self._parse_response(c, buffer, header_buffer, default_encoding) rsp.request = req cast(queue.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp)) finally: @@ -834,7 +831,7 @@ def cleanup(fut): # Wait for the first chunk cast(threading.Event, header_recved).wait() - rsp = self._parse_response(c, buffer, header_buffer) + rsp = self._parse_response(c, buffer, header_buffer, default_encoding) header_parsed.set() # Raise the exception if something wrong happens when receiving the header. @@ -859,11 +856,11 @@ def cleanup(fut): else: c.perform() except CurlError as e: - rsp = self._parse_response(c, buffer, header_buffer) + rsp = self._parse_response(c, buffer, header_buffer, default_encoding) rsp.request = req raise RequestsError(str(e), e.code, rsp) from e else: - rsp = self._parse_response(c, buffer, header_buffer) + rsp = self._parse_response(c, buffer, header_buffer, default_encoding) rsp.request = req return rsp finally: @@ -910,6 +907,8 @@ def __init__( allow_redirects: whether to allow redirection. max_redirects: max redirect counts, default unlimited(-1). impersonate: which browser version to impersonate in the session. + default_encoding: encoding for decoding response content if charset is not found in headers. + Defaults to "utf-8". Can be set to a callable for automatic detection. Notes: This class can be used as a context manager, and it's recommended to use via @@ -1034,6 +1033,7 @@ async def request( content_callback: Optional[Callable] = None, impersonate: Optional[Union[str, BrowserType]] = None, default_headers: Optional[bool] = None, + default_encoding: Union[str, Callable[[bytes], str]] = "utf-8", http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, cert: Optional[Union[str, Tuple[str, str]]] = None, @@ -1084,7 +1084,7 @@ async def perform(): try: await task except CurlError as e: - rsp = self._parse_response(curl, buffer, header_buffer) + rsp = self._parse_response(curl, buffer, header_buffer, default_encoding) rsp.request = req cast(asyncio.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp)) finally: @@ -1104,7 +1104,7 @@ def cleanup(fut): # Unlike threads, coroutines does not use preemptive scheduling. # For asyncio, there is no need for a header_parsed event, the # _parse_response will execute in the foreground, no background tasks running. - rsp = self._parse_response(curl, buffer, header_buffer) + rsp = self._parse_response(curl, buffer, header_buffer, default_encoding) first_element = _peek_aio_queue(cast(asyncio.Queue, q)) if isinstance(first_element, RequestsError): @@ -1123,11 +1123,11 @@ def cleanup(fut): await task # print(curl.getinfo(CurlInfo.CAINFO)) except CurlError as e: - rsp = self._parse_response(curl, buffer, header_buffer) + rsp = self._parse_response(curl, buffer, header_buffer, default_encoding) rsp.request = req raise RequestsError(str(e), e.code, rsp) from e else: - rsp = self._parse_response(curl, buffer, header_buffer) + rsp = self._parse_response(curl, buffer, header_buffer, default_encoding) rsp.request = req return rsp finally: diff --git a/pyproject.toml b/pyproject.toml index b883880d..c470772e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ license = { file = "LICENSE" } dependencies = [ "cffi>=1.12.0", "certifi>=2024.2.2", - "charset_normalizer>=3.3.2,<4", ] readme = "README.md" requires-python = ">=3.8" @@ -28,6 +27,7 @@ classifiers = [ [project.optional-dependencies] dev = [ "autoflake==1.4", + "charset_normalizer>=3.3.2,<4", "coverage==6.4.1", "cryptography==38.0.3", "flake8==6.0.0", @@ -51,6 +51,7 @@ build = [ "wheel", ] test = [ + "charset_normalizer>=3.3.2,<4", "cryptography==38.0.3", "httpx==0.23.1", "types-certifi==2021.10.8.2", diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py index 93233d86..862b1b3b 100644 --- a/tests/unittest/conftest.py +++ b/tests/unittest/conftest.py @@ -168,7 +168,6 @@ async def hello_world_gbk(scope, receive, send): async def hello_world_windows1251(scope, receive, send): - # test encoding detection when charset is not specified in content-type header await send( { "type": "http.response.start", diff --git a/tests/unittest/test_requests.py b/tests/unittest/test_requests.py index 8f447996..57a7165b 100644 --- a/tests/unittest/test_requests.py +++ b/tests/unittest/test_requests.py @@ -4,6 +4,7 @@ from io import BytesIO import pytest +from charset_normalizer import detect from curl_cffi import CurlOpt, requests from curl_cffi.const import CurlECode, CurlInfo @@ -109,12 +110,22 @@ def test_headers(server): def test_charset_parse(server): r = requests.get(str(server.url.copy_with(path="/gbk"))) - assert r.charset == "gbk" + assert r.encoding == "gbk" -def test_charset_detection(server): - r = requests.get(str(server.url.copy_with(path="/windows1251"))) - assert r.charset == "windows-1251" +def test_charset_default_encoding(server): + r = requests.get( + str(server.url.copy_with(path="/windows1251")), default_encoding="windows-1251" + ) + assert r.encoding == "windows-1251" + + +def test_charset_default_encoding_autodetect(server): + def autodetect(content): + return detect(content).get("encoding") + + r = requests.get(str(server.url.copy_with(path="/windows1251")), default_encoding=autodetect) + assert r.encoding == "windows-1251" def test_content_type_header_with_json(server):