
Commit

Make default_encoding optional
deedy5 committed Apr 1, 2024
1 parent 1aa8c23 commit 4d7aaaa
Showing 6 changed files with 71 additions and 30 deletions.
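
In short: when the Content-Type header declares no charset, the response body is now decoded with a caller-chosen default_encoding, either a codec name or a callable mapping the raw bytes to one, instead of always running charset_normalizer. A minimal sketch of the string form (hypothetical URL; assumes a build that includes this commit):

from curl_cffi import requests

# Used only when the server omits "charset=..." in Content-Type;
# a charset declared in the header still takes precedence.
r = requests.get("https://example.com/legacy", default_encoding="windows-1251")
print(r.encoding)  # "windows-1251" unless the header said otherwise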
4 changes: 4 additions & 0 deletions curl_cffi/requests/__init__.py
@@ -58,6 +58,7 @@ def request(
impersonate: Optional[Union[str, BrowserType]] = None,
thread: Optional[ThreadType] = None,
default_headers: Optional[bool] = None,
default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
curl_options: Optional[dict] = None,
http_version: Optional[CurlHttpVersion] = None,
debug: bool = False,
@@ -90,6 +91,8 @@ def request(
impersonate: which browser version to impersonate.
thread: work with other thread implementations. choices: eventlet, gevent.
default_headers: whether to set default browser headers.
default_encoding: encoding for decoding response content if charset is not found in headers.
Defaults to "utf-8". Can be set to a callable for automatic detection.
curl_options: extra curl options to use.
http_version: limiting http version, http2 will be tried by default.
debug: print extra curl debug info.
@@ -122,6 +125,7 @@ def request(
content_callback=content_callback,
impersonate=impersonate,
default_headers=default_headers,
default_encoding=default_encoding,
http_version=http_version,
interface=interface,
multipart=multipart,
38 changes: 32 additions & 6 deletions curl_cffi/requests/models.py
@@ -1,14 +1,17 @@
import queue
import re
import warnings
from concurrent.futures import Future
from json import loads
from typing import Any, Awaitable, Dict, List, Optional
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union

from .. import Curl
from .cookies import Cookies
from .errors import RequestsError
from .headers import Headers

CHARSET_RE = re.compile(r"charset=([\w-]+)")


def clear_queue(q: queue.Queue):
with q.mutex:
@@ -40,7 +43,8 @@ class Response:
cookies: response cookies.
elapsed: how many seconds the request took.
encoding: http body encoding.
charset: alias for encoding.
charset_encoding: encoding specified by the Content-Type header.
default_encoding: user-defined encoding used for decoding content if charset is not found in headers.
redirect_count: how many redirects happened.
redirect_url: the final redirected url.
http_version: http version used.
@@ -58,8 +62,7 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non
self.headers = Headers()
self.cookies = Cookies()
self.elapsed = 0.0
self.encoding = "utf-8"
self.charset = self.encoding
self.default_encoding: Union[str, Callable[[bytes], str]] = "utf-8"
self.redirect_count = 0
self.redirect_url = ""
self.http_version = 0
@@ -70,15 +73,38 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non
self.astream_task: Optional[Awaitable] = None
self.quit_now = None

@property
def encoding(self) -> str:
if not hasattr(self, "_encoding"):
encoding = self.charset_encoding
if encoding is None:
if isinstance(self.default_encoding, str):
encoding = self.default_encoding
elif callable(self.default_encoding):
encoding = self.default_encoding(self.content)
self._encoding = encoding or "utf-8"
return self._encoding

@property
def charset_encoding(self) -> Optional[str]:
"""Return the encoding, as specified by the Content-Type header."""
content_type = self.headers.get("Content-Type")
if content_type:
charset_match = CHARSET_RE.search(content_type)
return charset_match.group(1) if charset_match else None
return None

def _decode(self, content: bytes) -> str:
try:
return content.decode(self.charset, errors="replace")
return content.decode(self.encoding, errors="replace")
except (UnicodeDecodeError, LookupError):
return content.decode("utf-8-sig")

@property
def text(self) -> str:
return self._decode(self.content)
if not hasattr(self, "_text"):
self._text = self._decode(self.content)
return self._text

def raise_for_status(self):
"""Raise an error if status code is not in [200, 400)"""
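
The new lazy encoding property resolves in this order: charset from the Content-Type header, then the user-supplied default_encoding (calling it on the body when it is a callable), then "utf-8". A standalone sketch of that resolution logic, reusing the same regex as the diff:

import re
from typing import Callable, Optional, Union

CHARSET_RE = re.compile(r"charset=([\w-]+)")

def resolve_encoding(
    content_type: Optional[str],
    content: bytes,
    default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
) -> str:
    # 1. a charset declared in the Content-Type header wins
    if content_type:
        match = CHARSET_RE.search(content_type)
        if match:
            return match.group(1)
    # 2. otherwise fall back to the user-supplied default,
    #    calling it on the raw body when it is a callable
    if callable(default_encoding):
        return default_encoding(content) or "utf-8"
    # 3. final fallback mirrors the property's `encoding or "utf-8"`
    return default_encoding or "utf-8"

assert resolve_encoding("text/html; charset=gbk", b"") == "gbk"
assert resolve_encoding("text/html", b"abc", "windows-1251") == "windows-1251"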
36 changes: 18 additions & 18 deletions curl_cffi/requests/session.py
@@ -1,7 +1,6 @@
import asyncio
import math
import queue
import re
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor
@@ -25,8 +24,6 @@
)
from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urljoin, urlparse

from charset_normalizer import detect

from .. import AsyncCurl, Curl, CurlError, CurlHttpVersion, CurlInfo, CurlOpt
from ..curl import CURL_WRITEFUNC_ERROR, CurlMime
from .cookies import Cookies, CookieTypes, CurlMorsel
@@ -57,7 +54,6 @@ class ProxySpec(TypedDict, total=False):
else:
ProxySpec = Dict[str, str]

CHARSET_RE = re.compile(r"charset=([\w-]+)")
ThreadType = Literal["eventlet", "gevent"]


@@ -207,6 +203,7 @@ def __init__(
max_redirects: int = -1,
impersonate: Optional[Union[str, BrowserType]] = None,
default_headers: bool = True,
default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
curl_options: Optional[dict] = None,
curl_infos: Optional[list] = None,
http_version: Optional[CurlHttpVersion] = None,
@@ -226,6 +223,7 @@
self.max_redirects = max_redirects
self.impersonate = impersonate
self.default_headers = default_headers
self.default_encoding = default_encoding
self.curl_options = curl_options or {}
self.curl_infos = curl_infos or []
self.http_version = http_version
@@ -540,7 +538,7 @@ def qput(chunk):

return req, buffer, header_buffer, q, header_recved, quit_now

def _parse_response(self, curl, buffer, header_buffer):
def _parse_response(self, curl, buffer, header_buffer, default_encoding):
c = curl
rsp = Response(c)
rsp.url = cast(bytes, c.getinfo(CurlInfo.EFFECTIVE_URL)).decode()
@@ -576,11 +574,7 @@ def _parse_response(self, curl, buffer, header_buffer):
rsp.cookies = self.cookies
# print("Cookies after extraction", self.cookies)

content_type = rsp.headers.get("Content-Type", default="")
charset_match = CHARSET_RE.search(content_type)
charset = charset_match.group(1) if charset_match else detect(rsp.content)["encoding"]
rsp.charset = rsp.encoding = charset or "utf-8"

rsp.default_encoding = default_encoding
rsp.elapsed = cast(float, c.getinfo(CurlInfo.TOTAL_TIME))
rsp.redirect_count = cast(int, c.getinfo(CurlInfo.REDIRECT_COUNT))
rsp.redirect_url = cast(bytes, c.getinfo(CurlInfo.REDIRECT_URL)).decode()
@@ -630,6 +624,8 @@ def __init__(
max_redirects: max redirect count, default unlimited (-1).
impersonate: which browser version to impersonate in the session.
interface: which interface to use in requests to the server.
default_encoding: encoding for decoding response content if charset is not found in headers.
Defaults to "utf-8". Can be set to a callable for automatic detection.
Notes:
This class can be used as a context manager.
@@ -758,6 +754,7 @@ def request(
content_callback: Optional[Callable] = None,
impersonate: Optional[Union[str, BrowserType]] = None,
default_headers: Optional[bool] = None,
default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
@@ -816,7 +813,7 @@ def perform():
try:
c.perform()
except CurlError as e:
rsp = self._parse_response(c, buffer, header_buffer)
rsp = self._parse_response(c, buffer, header_buffer, default_encoding)
rsp.request = req
cast(queue.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp))
finally:
@@ -834,7 +831,7 @@ def cleanup(fut):

# Wait for the first chunk
cast(threading.Event, header_recved).wait()
rsp = self._parse_response(c, buffer, header_buffer)
rsp = self._parse_response(c, buffer, header_buffer, default_encoding)
header_parsed.set()

# Raise the exception if something wrong happens when receiving the header.
@@ -859,11 +856,11 @@ def cleanup(fut):
else:
c.perform()
except CurlError as e:
rsp = self._parse_response(c, buffer, header_buffer)
rsp = self._parse_response(c, buffer, header_buffer, default_encoding)
rsp.request = req
raise RequestsError(str(e), e.code, rsp) from e
else:
rsp = self._parse_response(c, buffer, header_buffer)
rsp = self._parse_response(c, buffer, header_buffer, default_encoding)
rsp.request = req
return rsp
finally:
@@ -910,6 +907,8 @@ def __init__(
allow_redirects: whether to allow redirection.
max_redirects: max redirect count, default unlimited (-1).
impersonate: which browser version to impersonate in the session.
default_encoding: encoding for decoding response content if charset is not found in headers.
Defaults to "utf-8". Can be set to a callable for automatic detection.
Notes:
This class can be used as a context manager, and it's recommended to use via
@@ -1034,6 +1033,7 @@ async def request(
content_callback: Optional[Callable] = None,
impersonate: Optional[Union[str, BrowserType]] = None,
default_headers: Optional[bool] = None,
default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
@@ -1084,7 +1084,7 @@ async def perform():
try:
await task
except CurlError as e:
rsp = self._parse_response(curl, buffer, header_buffer)
rsp = self._parse_response(curl, buffer, header_buffer, default_encoding)
rsp.request = req
cast(asyncio.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp))
finally:
@@ -1104,7 +1104,7 @@ def cleanup(fut):
# Unlike threads, coroutines do not use preemptive scheduling.
# For asyncio, there is no need for a header_parsed event, the
# _parse_response will execute in the foreground, no background tasks running.
rsp = self._parse_response(curl, buffer, header_buffer)
rsp = self._parse_response(curl, buffer, header_buffer, default_encoding)

first_element = _peek_aio_queue(cast(asyncio.Queue, q))
if isinstance(first_element, RequestsError):
@@ -1123,11 +1123,11 @@
await task
# print(curl.getinfo(CurlInfo.CAINFO))
except CurlError as e:
rsp = self._parse_response(curl, buffer, header_buffer)
rsp = self._parse_response(curl, buffer, header_buffer, default_encoding)
rsp.request = req
raise RequestsError(str(e), e.code, rsp) from e
else:
rsp = self._parse_response(curl, buffer, header_buffer)
rsp = self._parse_response(curl, buffer, header_buffer, default_encoding)
rsp.request = req
return rsp
finally:
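
Both Session.request and AsyncSession.request accept the same parameter and pass it to _parse_response, which simply attaches it to the Response for the lazy encoding property to use. A brief per-call sketch (hypothetical URL):

from curl_cffi import requests

with requests.Session() as s:
    # The fallback rides along on the response and is only consulted
    # when the Content-Type header carries no charset= attribute.
    r = s.request(
        "GET",
        "https://example.com/legacy",  # hypothetical URL
        default_encoding="windows-1251",
    )
    print(r.encoding, len(r.text))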
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -7,7 +7,6 @@ license = { file = "LICENSE" }
dependencies = [
"cffi>=1.12.0",
"certifi>=2024.2.2",
"charset_normalizer>=3.3.2,<4",
]
readme = "README.md"
requires-python = ">=3.8"
@@ -28,6 +27,7 @@ classifiers = [
[project.optional-dependencies]
dev = [
"autoflake==1.4",
"charset_normalizer>=3.3.2,<4",
"coverage==6.4.1",
"cryptography==38.0.3",
"flake8==6.0.0",
@@ -51,6 +51,7 @@ build = [
"wheel",
]
test = [
"charset_normalizer>=3.3.2,<4",
"cryptography==38.0.3",
"httpx==0.23.1",
"types-certifi==2021.10.8.2",
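
With charset_normalizer moved from the core dependencies into the dev and test extras, automatic detection becomes opt-in: install it separately and wrap its detect() as the callable. A sketch, assuming charset_normalizer is installed on its own:

# Requires a separate `pip install charset_normalizer`; after this change
# it is no longer pulled in by curl_cffi itself.
from charset_normalizer import detect

from curl_cffi import requests

def autodetect(content: bytes) -> str:
    # detect() returns a chardet-style dict whose "encoding" may be None,
    # so fall back to utf-8 explicitly.
    return detect(content).get("encoding") or "utf-8"

r = requests.get("https://example.com/legacy", default_encoding=autodetect)  # hypothetical URL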
1 change: 0 additions & 1 deletion tests/unittest/conftest.py
@@ -168,7 +168,6 @@ async def hello_world_gbk(scope, receive, send):


async def hello_world_windows1251(scope, receive, send):
# test encoding detection when charset is not specified in content-type header
await send(
{
"type": "http.response.start",
19 changes: 15 additions & 4 deletions tests/unittest/test_requests.py
@@ -4,6 +4,7 @@
from io import BytesIO

import pytest
from charset_normalizer import detect

from curl_cffi import CurlOpt, requests
from curl_cffi.const import CurlECode, CurlInfo
@@ -109,12 +110,22 @@ def test_headers(server):

def test_charset_parse(server):
r = requests.get(str(server.url.copy_with(path="/gbk")))
assert r.charset == "gbk"
assert r.encoding == "gbk"


def test_charset_detection(server):
r = requests.get(str(server.url.copy_with(path="/windows1251")))
assert r.charset == "windows-1251"
def test_charset_default_encoding(server):
r = requests.get(
str(server.url.copy_with(path="/windows1251")), default_encoding="windows-1251"
)
assert r.encoding == "windows-1251"


def test_charset_default_encoding_autodetect(server):
def autodetect(content):
return detect(content).get("encoding")

r = requests.get(str(server.url.copy_with(path="/windows1251")), default_encoding=autodetect)
assert r.encoding == "windows-1251"


def test_content_type_header_with_json(server):
