Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove use of http2_server in replicator aws proxy #62

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 232 additions & 20 deletions aws-replicator/aws_replicator/client/auth_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,27 @@
import subprocess
import sys
from functools import cache
from io import BytesIO
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse, urlunparse

import boto3
import requests
from botocore.awsrequest import AWSPreparedRequest
from botocore.awsrequest import AWSPreparedRequest, AWSResponse
from botocore.httpchecksum import resolve_checksum_context
from botocore.model import OperationModel
from localstack import config
from localstack import config as localstack_config
from localstack.aws.chain import HandlerChain
from localstack.aws.chain import RequestContext as AwsRequestContext
from localstack.aws.gateway import Gateway
from localstack.aws.protocol.parser import create_parser
from localstack.aws.spec import load_service
from localstack.config import external_service_url
from localstack.constants import AWS_REGION_US_EAST_1, DOCKER_IMAGE_NAME_PRO
from localstack.http import Request
from localstack.http import Response as HttpResponse
from localstack.http.hypercorn import GatewayServer
from localstack.utils.aws.aws_responses import requests_response
from localstack.utils.bootstrap import setup_logging
from localstack.utils.collections import select_attributes
Expand All @@ -37,8 +44,6 @@
from aws_replicator.config import HANDLER_PATH_PROXIES
from aws_replicator.shared.models import AddProxyRequest, ProxyConfig

from .http2_server import run_server

LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
if config.DEBUG:
Expand All @@ -57,6 +62,207 @@
DEFAULT_BIND_HOST = "127.0.0.1"


class AwsProxyHandler:
"""
A handler for an AWS Handler chain that attempts to forward the request using a specific boto3 session.
This can be used to proxy incoming requests to real AWS.
"""

def __init__(self, session: boto3.Session = None):
self.session = session or boto3.Session()

def __call__(self, chain: HandlerChain, context: AwsRequestContext, response: HttpResponse):
# prepare the API invocation parameters
LOG.info(
"Received %s.%s = %s",
context.service.service_name,
context.operation.name,
context.service_request,
)

# make the actual API call against upstream AWS (will also calculate a new auth signature)
try:
aws_response = self._make_aws_api_call(context)
except Exception:
LOG.exception(
"Exception while proxying %s.%s to AWS",
context.service.service_name,
context.operation.name,
)
raise

# tell the handler chain to respond
LOG.info(
"AWS Response %s.%s: url=%s status_code=%s, headers=%s, content=%s",
context.service.service_name,
context.operation.name,
aws_response.url,
aws_response.status_code,
aws_response.headers,
aws_response.content,
)
chain.respond(aws_response.status_code, aws_response.content, dict(aws_response.headers))

def _make_aws_api_call(self, context: AwsRequestContext) -> AWSResponse:
# TODO: reconcile with AwsRequestProxy from localstack, and other forwarder tools
# create a real AWS client
client = self.session.client(context.service.service_name, region_name=context.region)
operation_model = context.operation

# prepare API request parameters as expected by boto
api_params = {k: v for k, v in context.service_request.items() if v is not None}

# this is a stripped down version of botocore's client._make_api_call to immediately get the HTTP
# response instead of a parsed response.
request_context = {
"client_region": client.meta.region_name,
"client_config": client.meta.config,
"has_streaming_input": operation_model.has_streaming_input,
"auth_type": operation_model.auth_type,
}

(
endpoint_url,
additional_headers,
properties,
) = client._resolve_endpoint_ruleset(operation_model, api_params, request_context)
if properties:
# Pass arbitrary endpoint info with the Request
# for use during construction.
request_context["endpoint_properties"] = properties

request_dict = client._convert_to_request_dict(
api_params=api_params,
operation_model=operation_model,
endpoint_url=endpoint_url,
context=request_context,
headers=additional_headers,
)
resolve_checksum_context(request_dict, operation_model, api_params)

if operation_model.has_streaming_input:
request_dict["body"] = request_dict["body"].read()

self._adjust_request_dict(context.service.service_name, request_dict)

if operation_model.has_streaming_input:
request_dict["body"] = BytesIO(request_dict["body"])

LOG.info("Making AWS request %s", request_dict)
http, _ = client._endpoint.make_request(operation_model, request_dict)

http: AWSResponse

# for some elusive reasons, these header modifications are needed (were part of http2_server)
http.headers.pop("Date", None)
http.headers.pop("Server", None)
if operation_model.has_streaming_output:
http.headers.pop("Content-Length", None)

return http

def _adjust_request_dict(self, service_name: str, request_dict: Dict):
"""Apply minor fixes to the request dict, which seem to be required in the current setup."""
# TODO: replacing localstack-specific URLs, IDs, etc, should ideally be done in a more generalized
# way.

req_body = request_dict.get("body")

# TODO: fix for switch between path/host addressing
# Note: the behavior seems to be different across botocore versions. Seems to be working
# with 1.29.97 (fix below not required) whereas newer versions like 1.29.151 require the fix.
if service_name == "s3":
body_str = run_safe(lambda: to_str(req_body)) or ""

request_url = request_dict["url"]
url_parsed = list(urlparse(request_url))
path_parts = url_parsed[2].strip("/").split("/")
bucket_subdomain_prefix = f"://{path_parts[0]}.s3."
if bucket_subdomain_prefix in request_url:
prefix = f"/{path_parts[0]}"
url_parsed[2] = url_parsed[2].removeprefix(prefix)
request_dict["url_path"] = request_dict["url_path"].removeprefix(prefix)
# replace empty path with "/" (seems required for signature calculation)
request_dict["url_path"] = request_dict["url_path"] or "/"
url_parsed[2] = url_parsed[2] or "/"
# re-construct final URL
request_dict["url"] = urlunparse(url_parsed)

# TODO: this custom fix should not be required - investigate and remove!
if "<CreateBucketConfiguration" in body_str and "LocationConstraint" not in body_str:
region = request_dict["context"]["client_region"]
if region == AWS_REGION_US_EAST_1:
request_dict["body"] = ""
else:
request_dict["body"] = (
'<CreateBucketConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">'
f"<LocationConstraint>{region}</LocationConstraint></CreateBucketConfiguration>"
)

if service_name == "sqs" and isinstance(req_body, dict):
account_id = self._query_account_id_from_aws()
if "QueueUrl" in req_body:
queue_name = req_body["QueueUrl"].split("/")[-1]
req_body["QueueUrl"] = f"https://queue.amazonaws.com/{account_id}/{queue_name}"
if "QueueOwnerAWSAccountId" in req_body:
req_body["QueueOwnerAWSAccountId"] = account_id
if service_name == "sqs" and request_dict.get("url"):
req_json = run_safe(lambda: json.loads(body_str)) or {}
account_id = self._query_account_id_from_aws()
queue_name = req_json.get("QueueName")
if account_id and queue_name:
request_dict["url"] = f"https://queue.amazonaws.com/{account_id}/{queue_name}"
req_json["QueueOwnerAWSAccountId"] = account_id
request_dict["body"] = to_bytes(json.dumps(req_json))

def _fix_headers(self, request: Request, service_name: str):
if service_name == "s3":
# fix the Host header, to avoid bucket addressing issues
host = request.headers.get("Host") or ""
regex = r"^(https?://)?([0-9.]+|localhost)(:[0-9]+)?"
if re.match(regex, host):
request.headers["Host"] = re.sub(regex, r"\1s3.localhost.localstack.cloud", host)
request.headers.pop("Content-Length", None)
request.headers.pop("x-localstack-request-url", None)
request.headers.pop("X-Forwarded-For", None)
request.headers.pop("X-Localstack-Tgt-Api", None)
request.headers.pop("X-Moto-Account-Id", None)
request.headers.pop("Remote-Addr", None)

@cache
def _query_account_id_from_aws(self) -> str:
sts_client = self.session.client("sts")
result = sts_client.get_caller_identity()
return result["Account"]


class AwsProxyGateway(Gateway):
"""
A handler chain that receives AWS requests, and proxies them transparently to upstream AWS using real
credentials. It de-constructs the incoming request, and creates a new request signed with the AWS
credentials configured in the environment.
"""

def __init__(self) -> None:
from localstack.aws import handlers

super().__init__(
request_handlers=[
handlers.parse_service_name,
handlers.content_decoder,
handlers.add_region_from_header,
handlers.add_account_id,
handlers.parse_service_request,
AwsProxyHandler(),
],
exception_handlers=[
handlers.log_exception,
handlers.handle_internal_failure,
],
context_class=AwsRequestContext,
)


class AuthProxyAWS(Server):
def __init__(self, config: ProxyConfig, port: int = None):
self.config = config
Expand All @@ -65,9 +271,13 @@ def __init__(self, config: ProxyConfig, port: int = None):

def do_run(self):
self.register_in_instance()

bind_host = self.config.get("bind_host") or DEFAULT_BIND_HOST
proxy = run_server(port=self.port, bind_addresses=[bind_host], handler=self.proxy_request)
proxy.join()
srv = GatewayServer(AwsProxyGateway(), localstack_config.HostAndPort(bind_host, self.port))
srv.start()
srv.join()
# proxy = run_server(port=self.port, bind_addresses=[bind_host], handler=self.proxy_request)
# proxy.join()

def proxy_request(self, request: Request, data: bytes) -> Response:
parsed = self._extract_region_and_service(request.headers)
Expand Down Expand Up @@ -214,20 +424,23 @@ def _parse_aws_request(

def _adjust_request_dict(self, service_name: str, request_dict: Dict):
"""Apply minor fixes to the request dict, which seem to be required in the current setup."""

# TODO: replacing localstack-specific URLs, IDs, etc, should ideally be done in a more generalized
# way.
req_body = request_dict.get("body")
body_str = run_safe(lambda: to_str(req_body)) or ""

# TODO: this custom fix should not be required - investigate and remove!
if "<CreateBucketConfiguration" in body_str and "LocationConstraint" not in body_str:
region = request_dict["context"]["client_region"]
if region == AWS_REGION_US_EAST_1:
request_dict["body"] = ""
else:
request_dict["body"] = (
'<CreateBucketConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">'
f"<LocationConstraint>{region}</LocationConstraint></CreateBucketConfiguration>"
)

if service_name == "s3":
body_str = run_safe(lambda: to_str(req_body)) or ""

# TODO: this custom fix should not be required - investigate and remove!
if "<CreateBucketConfiguration" in body_str and "LocationConstraint" not in body_str:
region = request_dict["context"]["client_region"]
if region == AWS_REGION_US_EAST_1:
request_dict["body"] = ""
else:
request_dict["body"] = (
'<CreateBucketConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">'
f"<LocationConstraint>{region}</LocationConstraint></CreateBucketConfiguration>"
)

if service_name == "sqs" and isinstance(req_body, dict):
account_id = self._query_account_id_from_aws()
Expand Down Expand Up @@ -327,8 +540,7 @@ def start_aws_auth_proxy_in_container(
command = [
"bash",
"-c",
# TODO: manually installing quart/h11/hypercorn as a dirty quick fix for now. To be fixed!
f"{venv_activate}; pip install h11 hypercorn quart; pip install --upgrade --no-deps '{CLI_PIP_PACKAGE}'",
f"{venv_activate}; pip install --upgrade --no-deps '{CLI_PIP_PACKAGE}'",
]
DOCKER_CLIENT.exec_in_container(container_name, command=command)

Expand Down
30 changes: 24 additions & 6 deletions aws-replicator/aws_replicator/server/aws_request_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from localstack.utils.net import get_addressable_container_host
from localstack.utils.strings import to_str, truncate
from requests.structures import CaseInsensitiveDict
from rolo.proxy import forward

try:
from localstack.testing.config import TEST_AWS_ACCESS_KEY_ID
Expand All @@ -37,15 +38,16 @@ def __call__(self, chain: HandlerChain, context: RequestContext, response: Respo
return

# forward request to proxy
response = self.forward_request(context, proxy)
response_ = self.forward_request(context, proxy)

if response is None:
if response_ is None:
return

# Remove `Transfer-Encoding` header (which could be set to 'chunked'), to prevent client timeouts
response_.headers.pop("Transfer-Encoding", None)

# set response details, then stop handler chain to return response
chain.response.data = response.raw_content
chain.response.status_code = response.status_code
chain.response.headers.update(dict(response.headers))
response.update_from(response_)
chain.stop()

def select_proxy(self, context: RequestContext) -> Optional[ProxyInstance]:
Expand Down Expand Up @@ -126,6 +128,22 @@ def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requ
port = proxy["port"]
request = context.request
target_host = get_addressable_container_host(default_local_hostname=LOCALHOST)

try:
LOG.info("Forwarding request: %s", context)
response = forward(request, f"http://{target_host}:{port}")
LOG.info(
"Received response: status=%s headers=%s body=%s",
response.status_code,
response.headers,
response.data,
)
except Exception:
LOG.exception("Exception while forwarding request")
raise

return response

url = f"http://{target_host}:{port}{request.path}?{to_str(request.query_string)}"

# inject Auth header, to ensure we're passing the right region to the proxy (e.g., for Cognito InitiateAuth)
Expand Down Expand Up @@ -158,7 +176,7 @@ def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requ
)
except requests.exceptions.ConnectionError:
# remove unreachable proxy
LOG.info("Removing unreachable AWS forward proxy due to connection issue: %s", url)
LOG.exception("Removing unreachable AWS forward proxy due to connection issue: %s", url)
self.PROXY_INSTANCES.pop(port, None)
return result

Expand Down
4 changes: 4 additions & 0 deletions aws-replicator/aws_replicator/server/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class AwsReplicatorExtension(Extension):
name = "aws-replicator"

def on_extension_load(self):
logging.getLogger("aws_replicator").setLevel(
logging.DEBUG if config.DEBUG else logging.INFO
)

if config.GATEWAY_SERVER == "twisted":
LOG.warning(
"AWS resource replicator: The aws-replicator extension currently requires hypercorn as "
Expand Down
Loading