Skip to content

Commit

Permalink
Merge pull request #8920 from OpenMined/add_custom_headers
Browse files Browse the repository at this point in the history
Allow User to set custom headers
  • Loading branch information
IonesioJunior authored Jun 20, 2024
2 parents e043645 + fab77ab commit d5476da
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
43 changes: 39 additions & 4 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from ..service.user.user_service import UserService
from ..types.grid_url import GridURL
from ..types.syft_object import SYFT_OBJECT_VERSION_2
from ..types.syft_object import SYFT_OBJECT_VERSION_3
from ..types.uid import UID
from ..util.logger import debug
from ..util.telemetry import instrument
Expand Down Expand Up @@ -129,7 +130,7 @@ class Routes(Enum):


@serializable(attrs=["proxy_target_uid", "url"])
class HTTPConnection(NodeConnection):
class HTTPConnectionV2(NodeConnection):
__canonical_name__ = "HTTPConnection"
__version__ = SYFT_OBJECT_VERSION_2

Expand All @@ -138,6 +139,18 @@ class HTTPConnection(NodeConnection):
routes: type[Routes] = Routes
session_cache: Session | None = None


@serializable(attrs=["proxy_target_uid", "url"])
class HTTPConnection(NodeConnection):
__canonical_name__ = "HTTPConnection"
__version__ = SYFT_OBJECT_VERSION_3

url: GridURL
proxy_target_uid: UID | None = None
routes: type[Routes] = Routes
session_cache: Session | None = None
headers: dict[str, str] | None = None

@field_validator("url", mode="before")
@classmethod
def make_url(cls, v: Any) -> Any:
Expand All @@ -147,6 +160,9 @@ def make_url(cls, v: Any) -> Any:
else v
)

def set_headers(self, headers: dict[str, str]) -> None:
self.headers = headers

def with_proxy(self, proxy_target_uid: UID) -> Self:
return HTTPConnection(url=self.url, proxy_target_uid=proxy_target_uid)

Expand Down Expand Up @@ -184,7 +200,11 @@ def session(self) -> Session:
def _make_get(self, path: str, params: dict | None = None) -> bytes:
url = self.url.with_path(path)
response = self.session.get(
str(url), verify=verify_tls(), proxies={}, params=params
str(url),
headers=self.headers,
verify=verify_tls(),
proxies={},
params=params,
)
if response.status_code != 200:
raise requests.ConnectionError(
Expand All @@ -204,7 +224,12 @@ def _make_post(
) -> bytes:
url = self.url.with_path(path)
response = self.session.post(
str(url), verify=verify_tls(), json=json, proxies={}, data=data
str(url),
headers=self.headers,
verify=verify_tls(),
json=json,
proxies={},
data=data,
)
if response.status_code != 200:
raise requests.ConnectionError(
Expand All @@ -219,7 +244,7 @@ def _make_post(
def stream_data(self, credentials: SyftSigningKey) -> Response:
url = self.url.with_path(self.routes.STREAM.value)
response = self.session.get(
str(url), verify=verify_tls(), proxies={}, stream=True
str(url), verify=verify_tls(), proxies={}, stream=True, headers=self.headers
)
return response

Expand Down Expand Up @@ -309,6 +334,7 @@ def make_call(self, signed_call: SignedSyftAPICall) -> Any | SyftError:
response = requests.post( # nosec
url=str(self.api_url),
data=msg_bytes,
headers=self.headers,
)

if response.status_code != 200:
Expand Down Expand Up @@ -530,6 +556,15 @@ def post_init(self) -> None:
self.metadata.supported_protocols
)

def set_headers(self, headers: dict[str, str]) -> None | SyftError:
if isinstance(self.connection, HTTPConnection):
self.connection.set_headers(headers)
return None
return SyftError( # type: ignore
message="Incompatible connection type."
+ f"Expected HTTPConnection, got {type(self.connection)}"
)

def _get_communication_protocol(
self, protocols_supported_by_server: list
) -> int | str:
Expand Down
7 changes: 7 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@
"action": "remove"
}
},
"HTTPConnection": {
"3": {
"version": 3,
"hash": "54b452bb4ab76691ac1e704b62e7bcec740850fea00805145259b37973ecd0f4",
"action": "add"
}
},
"UserCode": {
"4": {
"version": 4,
Expand Down

0 comments on commit d5476da

Please sign in to comment.