Skip to content

Commit

Permalink
feat: allow passing env vars
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Pokorný <[email protected]>
  • Loading branch information
JanPokorny committed Jan 3, 2025
1 parent 99dd207 commit b9822a3
Show file tree
Hide file tree
Showing 7 changed files with 365 additions and 378 deletions.
10 changes: 6 additions & 4 deletions executor/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use actix_web::{middleware::Logger, web, App, Error, HttpResponse, HttpServer};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::collections::{HashSet, HashMap};
use std::env;
use std::path::Path;
use std::time::{Duration, SystemTime};
Expand All @@ -30,6 +30,7 @@ use std::time::UNIX_EPOCH;
struct ExecuteRequest {
source_code: String,
timeout: Option<u64>,
env: Option<HashMap<String, String>>, // New field for environment variables
}

#[derive(Serialize)]
Expand Down Expand Up @@ -148,11 +149,12 @@ async fn execute(payload: web::Json<ExecuteRequest>) -> Result<HttpResponse, Err
tokio::fs::rename(source_dir.path().join("script.py"), source_dir.path().join("script.xsh")).await?;

let timeout = Duration::from_secs(payload.timeout.unwrap_or(60));
let mut cmd = Command::new("xonsh");
cmd.arg(source_dir.path().join("script.xsh"));
if let Some(env) = &payload.env { cmd.envs(env); }
let (stdout, stderr, exit_code) = tokio::time::timeout(
timeout,
Command::new("xonsh") // TODO: manually switch between python and shell for ~80ms perf gain
.arg(source_dir.path().join("script.xsh"))
.output(),
cmd.output(),
)
.await
.map(|r| {
Expand Down
680 changes: 314 additions & 366 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ python = "^3.12"

aiorun = "^2024.8.1"
anyio = "^4.6.2.post1"
frozendict = "^2.4.6"
grpcio = "^1.66.2"
grpcio-reflection = "^1.66.2"
protobuf = "5.27.2" # NOTE: needs to be in sync with generated code
protobuf = "5.28.1" # NOTE: needs to be in sync with generated code
protovalidate = "^0.4.0" # NOTE: breaking change in 0.5.0 w.r.t. generated code
pydantic = "^2.9.2"
pydantic-settings = "^2.5.2"
Expand Down
7 changes: 4 additions & 3 deletions src/code_interpreter/services/custom_tool_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ async def execute(
self,
tool_source_code: str,
tool_input_json: str,
env: typing.Mapping[str, str] = {},
) -> typing.Any:
"""
Execute the given custom tool with the given input.
Expand All @@ -171,8 +172,7 @@ async def execute(
*imports, function_def = ast.parse(clean_tool_source_code).body

result = await self.code_executor.execute(
source_code=f"""
# Import all tool dependencies here -- to aid the dependency detection
source_code=f"""# Import all tool dependencies here -- to aid the dependency detection
{"\n".join(ast.unparse(node) for node in imports if isinstance(node, (ast.Import, ast.ImportFrom)))}
import pydantic
Expand All @@ -185,7 +185,8 @@ async def execute(
result = pydantic.TypeAdapter(inner_globals[{repr(function_def.name)}]).validate_json({repr(tool_input_json)})
print(json.dumps(result))
""",
""",
env=env,
)

if result.exit_code != 0:
Expand Down
6 changes: 5 additions & 1 deletion src/code_interpreter/services/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@

class ExecuteRequest(BaseModel):
source_code: str
files: Dict[AbsolutePath, Hash]
files: Dict[AbsolutePath, Hash] = {}
env: Dict[str, str] = {}


class ExecuteResponse(BaseModel):
Expand All @@ -62,6 +63,7 @@ class ParseCustomToolErrorResponse(BaseModel):
class ExecuteCustomToolRequest(BaseModel):
tool_source_code: str
tool_input_json: str
env: Dict[str, str] = {}


class ExecuteCustomToolResponse(BaseModel):
Expand Down Expand Up @@ -95,6 +97,7 @@ async def execute(
result = await code_executor.execute(
source_code=request.source_code,
files=request.files,
env=request.env,
)
except Exception as e:
logger.exception("Error executing code")
Expand Down Expand Up @@ -143,6 +146,7 @@ async def execute_custom_tool(
result = await custom_tool_executor.execute(
tool_input_json=request.tool_input_json,
tool_source_code=request.tool_source_code,
env=request.env,
)
logger.info("Executed custom tool with result %s", result)
return ExecuteCustomToolResponse(tool_output_json=json.dumps(result))
Expand Down
5 changes: 3 additions & 2 deletions src/code_interpreter/services/kubernetes_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncGenerator, Mapping
from frozendict import frozendict
from pydantic import validate_call
from tenacity import (
retry,
Expand Down Expand Up @@ -82,7 +81,8 @@ def __init__(
async def execute(
self,
source_code: str,
files: Mapping[AbsolutePath, Hash] = frozendict(),
files: Mapping[AbsolutePath, Hash] = {},
env: Mapping[str, str] = {},
) -> Result:
"""
Executes the given Python source code in a Kubernetes pod.
Expand Down Expand Up @@ -118,6 +118,7 @@ async def upload_file(file_path, file_hash):
f"http://{executor_pod_ip}:8000/execute",
json={
"source_code": source_code,
"env": env,
},
)
).json()
Expand Down
32 changes: 32 additions & 0 deletions test/e2e/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,21 @@ def test_create_file_in_interpreter(http_client: httpx.Client, config: Config):
assert not response_json["files"]


def test_execute_with_env(http_client: httpx.Client):
request_data = {
"source_code": "import os\nprint('Hello ' + os.environ['MY_NAME'])",
"files": {},
"env": {
"MY_NAME": "John Doe"
}
}
response = http_client.post("/v1/execute", json=request_data)
assert response.status_code == 200
response_json = response.json()
assert response_json["stdout"].strip() == "Hello John Doe"



def test_parse_custom_tool_success(http_client: httpx.Client):
response = http_client.post(
"/v1/parse-custom-tool",
Expand Down Expand Up @@ -268,3 +283,20 @@ def test_execute_custom_tool_error(http_client: httpx.Client):
assert response.status_code == 400
response_json = response.json()
assert "division by zero" in response_json["stderr"]


def test_execute_custom_tool_with_env(http_client: httpx.Client):
response = http_client.post(
"/v1/execute-custom-tool",
json={
"tool_source_code": "import os\ndef greet() -> str:\n return 'Hello ' + os.environ['MY_NAME']",
"tool_input_json": '{}',
"env": {
"MY_NAME": "John Doe"
}
},
)

assert response.status_code == 200
response_json = response.json()
assert json.loads(response_json["tool_output_json"]) == "Hello John Doe"

0 comments on commit b9822a3

Please sign in to comment.