Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Pass arguments to http client #554

Merged
merged 15 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/code/targets/7_http_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@

# For AOAI the response content is located in the path choices[0].message.content - for other responses this should be in the documentation or you can manually test the output to find the right path
parsing_function = get_http_target_json_response_callback_function(key="choices[0].message.content")
http_prompt_target = HTTPTarget(http_request=raw_http_request, callback_function=parsing_function)
# httpx AsyncClient parameters can be passed as kwargs to HTTPTarget, for example the timeout below
http_prompt_target = HTTPTarget(http_request=raw_http_request, callback_function=parsing_function, timeout=20.0)

# Note, a converter is used to format the prompt to be json safe without new lines/carriage returns, etc
with PromptSendingOrchestrator(
Expand Down
7 changes: 5 additions & 2 deletions pyrit/prompt_target/http_target/http_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import logging
import re
from typing import Callable
from typing import Callable, Optional, Any

from pyrit.models import construct_response_from_request, PromptRequestPiece, PromptRequestResponse
from pyrit.prompt_target import PromptTarget
Expand All @@ -26,6 +26,7 @@ class HTTPTarget(PromptTarget):
use_tls: (bool): whether to use TLS or not. Default is True
callback_function (function): function to parse HTTP response.
These are the customizable functions which determine how to parse the output
client_kwargs: (dict): additional keyword arguments to pass to the HTTP client
"""

def __init__(
Expand All @@ -34,12 +35,14 @@ def __init__(
prompt_regex_string: str = "{PROMPT}",
use_tls: bool = True,
callback_function: Callable = None,
**client_kwargs: Optional[Any],
) -> None:

self.http_request = http_request
self.callback_function = callback_function
self.prompt_regex_string = prompt_regex_string
self.use_tls = use_tls
self.client_kwargs = client_kwargs or {}

async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
"""
Expand All @@ -66,7 +69,7 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
if http_version and "HTTP/2" in http_version:
http2_version = True

async with httpx.AsyncClient(http2=http2_version) as client:
async with httpx.AsyncClient(http2=http2_version, **self.client_kwargs) as client:
response = await client.request(
method=http_method,
url=url,
Expand Down
20 changes: 20 additions & 0 deletions tests/target/test_http_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,26 @@ async def test_send_prompt_async(mock_request, mock_http_target, mock_http_respo
)


@pytest.mark.asyncio
@patch("httpx.AsyncClient")
async def test_send_prompt_async_client_kwargs(mock_async_client):
# Create client_kwargs to test
client_kwargs = {"timeout": 10, "verify": False}
sample_request = "GET /test HTTP/1.1\nHost: example.com\n\n"
# Create instance of HTTPTarget with client_kwargs
# Use **client_kwargs to pass them as keyword arguments
http_target = HTTPTarget(http_request=sample_request, **client_kwargs)
prompt_request = MagicMock()
prompt_request.request_pieces = [MagicMock(converted_value="")]
mock_response = MagicMock()
mock_response.content = b"Response content"
instance = mock_async_client.return_value.__aenter__.return_value
instance.request.return_value = mock_response
await http_target.send_prompt_async(prompt_request=prompt_request)

mock_async_client.assert_called_with(http2=False, timeout=10, verify=False)


@pytest.mark.asyncio
async def test_send_prompt_async_validation(mock_http_target):
# Create an invalid prompt request (missing request_pieces)
Expand Down
Loading