Skip to content

Expose run_cli() method on Agent #1642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ab9c522
Add agent CLI functionality
AndrewHannigan May 4, 2025
79ab6fc
add tests
AndrewHannigan May 4, 2025
217def2
update tests
AndrewHannigan May 4, 2025
6b586d9
simplify changes
AndrewHannigan May 4, 2025
b019fe4
add cli dependency group to base pyproject.toml
AndrewHannigan May 4, 2025
5fdf352
require extras
AndrewHannigan May 4, 2025
8c083fc
simplify
AndrewHannigan May 4, 2025
ca77f06
add test
AndrewHannigan May 4, 2025
d72b214
use cli_agent in testing
AndrewHannigan May 4, 2025
bb3fb3f
fix test
AndrewHannigan May 4, 2025
1d2864e
fix test name
AndrewHannigan May 4, 2025
977d9fb
add --agent flag
AndrewHannigan May 5, 2025
17174a6
update docs
AndrewHannigan May 5, 2025
c5a58c1
Add test for --agent flag
AndrewHannigan May 5, 2025
b5a2f6b
add non agent check
AndrewHannigan May 5, 2025
e8f8740
add bad module variable test
AndrewHannigan May 5, 2025
abb734e
Update docs
AndrewHannigan May 5, 2025
6c40d84
Revert "Update docs"
AndrewHannigan May 5, 2025
1eab0d4
minor fix
AndrewHannigan May 5, 2025
d58c1f8
remove case
AndrewHannigan May 5, 2025
7729a71
fix docs
AndrewHannigan May 5, 2025
e07d301
single quotes
AndrewHannigan May 5, 2025
69089eb
Merge branch 'main' into add-agent-cli
AndrewHannigan May 6, 2025
ab6508f
fix failing test in unrelated module
AndrewHannigan May 6, 2025
9e9f67a
Revert "fix failing test in unrelated module"
AndrewHannigan May 6, 2025
551011e
raise helpful warning if missing a tty
AndrewHannigan May 6, 2025
3b799b7
fixes
AndrewHannigan May 6, 2025
f48e891
fixes
AndrewHannigan May 6, 2025
7b8b061
modify docs
AndrewHannigan May 6, 2025
651879c
cleanup unused imprt
AndrewHannigan May 6, 2025
36074b3
fix trailing whitespace
AndrewHannigan May 6, 2025
71d26ea
minor edit
AndrewHannigan May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,18 @@ If you have [uv](https://docs.astral.sh/uv/) installed, the quickest way to run
```bash
uvx --from pydantic-ai pai
```

### Custom Agents

You can specify a custom agent using the `--agent` flag with a module path and variable name:

```bash
pai --agent mymodule.submodule:my_agent "What's the weather today?"
```

The format must be `module:variable` where:
- `module` is the importable Python module path
- `variable` is the name of the Agent instance in that module


Additionally, you can directly launch CLI mode from an `Agent` instance using `Agent.run_cli()`.
28 changes: 24 additions & 4 deletions pydantic_ai_slim/pydantic_ai/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import asyncio
import importlib
import sys
from asyncio import CancelledError
from collections.abc import Sequence
Expand Down Expand Up @@ -83,7 +84,7 @@ def cli_system_prompt() -> str:
The user is running {sys.platform}."""


def cli(args_list: Sequence[str] | None = None) -> int:
def cli(args_list: Sequence[str] | None = None, agent: Agent[None, str] = cli_agent) -> int:
parser = argparse.ArgumentParser(
prog='pai',
description=f"""\
Expand All @@ -108,6 +109,11 @@ def cli(args_list: Sequence[str] | None = None) -> int:
# e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
qualified_model_names = [n for n in get_literal_values(KnownModelName.__value__) if ':' in n]
arg.completer = argcomplete.ChoicesCompleter(qualified_model_names) # type: ignore[reportPrivateUsage]
parser.add_argument(
'-a',
'--agent',
help='Custom Agent to use, in format "module:variable", e.g. "mymodule.submodule:my_agent"',
)
parser.add_argument(
'-l',
'--list-models',
Expand Down Expand Up @@ -139,8 +145,22 @@ def cli(args_list: Sequence[str] | None = None) -> int:
console.print(f' {model}', highlight=False)
return 0

# Load custom agent if specified
if args.agent:
try:
module_path, variable_name = args.agent.split(':')
module = importlib.import_module(module_path)
agent = getattr(module, variable_name)
if not isinstance(agent, Agent):
console.print(f'[red]Error: {args.agent} is not an Agent instance[/red]')
return 1
console.print(f'[green]Using custom agent:[/green] [magenta]{args.agent}[/magenta]', highlight=False)
except ValueError:
console.print('[red]Error: Agent must be specified in "module:variable" format[/red]')
return 1

try:
cli_agent.model = infer_model(args.model)
agent.model = infer_model(args.model)
except UserError as e:
console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]')
return 1
Expand All @@ -155,7 +175,7 @@ def cli(args_list: Sequence[str] | None = None) -> int:

if prompt := cast(str, args.prompt):
try:
asyncio.run(ask_agent(cli_agent, prompt, stream, console, code_theme))
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
except KeyboardInterrupt:
pass
return 0
Expand All @@ -164,7 +184,7 @@ def cli(args_list: Sequence[str] | None = None) -> int:
# doing this instead of `PromptSession[Any](history=` allows mocking of PromptSession in tests
session: PromptSession[Any] = PromptSession(history=FileHistory(str(history)))
try:
return asyncio.run(run_chat(session, stream, cli_agent, console, code_theme))
return asyncio.run(run_chat(session, stream, agent, console, code_theme))
except KeyboardInterrupt: # pragma: no cover
return 0

Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,6 +1665,12 @@ async def run_mcp_servers(self) -> AsyncIterator[None]:
finally:
await exit_stack.aclose()

def run_cli(self: Agent[None, str]) -> None:
"""Run the agent in a CLI loop."""
from pydantic_ai._cli import cli

cli(agent=self)


@dataclasses.dataclass(repr=False)
class AgentRun(Generic[AgentDepsT, OutputDataT]):
Expand Down
79 changes: 78 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_cli_help(capfd: CaptureFixture[str]):

assert capfd.readouterr().out.splitlines() == snapshot(
[
'usage: pai [-h] [-m [MODEL]] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]',
'usage: pai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]',
'',
IsStr(),
'',
Expand All @@ -56,6 +56,8 @@ def test_cli_help(capfd: CaptureFixture[str]):
' -h, --help show this help message and exit',
' -m [MODEL], --model [MODEL]',
' Model to use, in format "<provider>:<model>" e.g. "openai:gpt-4o". Defaults to "openai:gpt-4o".',
' -a AGENT, --agent AGENT',
' Custom Agent to use, in format "module:variable", e.g. "mymodule.submodule:my_agent"',
' -l, --list-models List all available models and exit',
' -t [CODE_THEME], --code-theme [CODE_THEME]',
' Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "monokai".',
Expand All @@ -72,6 +74,72 @@ def test_invalid_model(capfd: CaptureFixture[str]):
)


def test_agent_flag(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')

# Create a dynamic module using types.ModuleType
import types

test_module = types.ModuleType('test_module')

# Create and add agent to the module
test_agent = Agent()
test_agent.model = TestModel(custom_output_text='Hello from custom agent')
setattr(test_module, 'custom_agent', test_agent)

# Register the module in sys.modules
sys.modules['test_module'] = test_module

try:
# Mock ask_agent to avoid actual execution but capture the agent
mock_ask = mocker.patch('pydantic_ai._cli.ask_agent')

# Test CLI with custom agent
assert cli(['--agent', 'test_module:custom_agent', 'hello']) == 0

# Verify the output contains the custom agent message
assert 'Using custom agent: test_module:custom_agent' in capfd.readouterr().out

# Verify ask_agent was called with our custom agent
mock_ask.assert_called_once()
assert mock_ask.call_args[0][0] is test_agent

finally:
# Clean up by removing the module from sys.modules
if 'test_module' in sys.modules:
del sys.modules['test_module']


def test_agent_flag_non_agent(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')

# Create a dynamic module using types.ModuleType
import types

test_module = types.ModuleType('test_module')

# Create and add agent to the module
test_agent = 'Not an Agent object'
setattr(test_module, 'custom_agent', test_agent)

# Register the module in sys.modules
sys.modules['test_module'] = test_module

try:
assert cli(['--agent', 'test_module:custom_agent', 'hello']) == 1
assert 'is not an Agent' in capfd.readouterr().out

finally:
# Clean up by removing the module from sys.modules
if 'test_module' in sys.modules:
del sys.modules['test_module']


def test_agent_flag_bad_module_variable_path(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
assert cli(['--agent', 'bad_path', 'hello']) == 1
assert 'Agent must be specified in "module:variable" format' in capfd.readouterr().out


def test_list_models(capfd: CaptureFixture[str]):
assert cli(['--list-models']) == 0
output = capfd.readouterr().out.splitlines()
Expand Down Expand Up @@ -189,3 +257,12 @@ def test_code_theme_dark(mocker: MockerFixture, env: TestEnv):
mock_run_chat.assert_awaited_once_with(
IsInstance(PromptSession), True, IsInstance(Agent), IsInstance(Console), 'monokai'
)


def test_agent_run_cli(mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
cli_agent.run_cli()
mock_run_chat.assert_awaited_once_with(
IsInstance(PromptSession), True, IsInstance(Agent), IsInstance(Console), 'monokai'
)