Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call PDL from Python API #7

Merged
merged 10 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/sdk/hello.pdl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
document:
- Hello,
- model: watsonx/ibm/granite-20b-code-instruct
parameters:
stop:
- '!'
include_stop_sequence: true
- "\n"
21 changes: 21 additions & 0 deletions examples/sdk/hello_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pdl.pdl import exec_dict

hello = {
"document": [
"Hello,",
{
"model": "watsonx/ibm/granite-20b-code-instruct",
"parameters": {"stop": ["!"], "include_stop_sequence": True},
},
"\n",
]
}


def main():
result = exec_dict(hello)
print(result)


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions examples/sdk/hello_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pdl.pdl import exec_file


def main():
result = exec_file("./hello.pdl")
print(result)


if __name__ == "__main__":
main()
26 changes: 26 additions & 0 deletions examples/sdk/hello_prog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pdl.pdl import exec_program
from pdl.pdl_ast import DocumentBlock, LitellmModelBlock, LitellmParameters, Program

hello = Program(
DocumentBlock(
document=[
"Hello,",
LitellmModelBlock(
model="watsonx/ibm/granite-20b-code-instruct",
parameters=LitellmParameters(
stop=["!"], include_stop_sequence=True # pyright: ignore
),
),
"\n",
]
)
)


def main():
result = exec_program(hello)
print(result)


if __name__ == "__main__":
main()
21 changes: 21 additions & 0 deletions examples/sdk/hello_str.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pdl.pdl import exec_str

HELLO = """
document:
- Hello,
- model: watsonx/ibm/granite-20b-code-instruct
parameters:
stop:
- '!'
include_stop_sequence: true
- "\n"
"""


def main():
result = exec_str(HELLO)
print(result)


if __name__ == "__main__":
main()
116 changes: 115 additions & 1 deletion pdl/pdl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,125 @@
import argparse
import json
from typing import Any, Optional, TypedDict

import yaml
from pydantic.json_schema import models_json_schema

from . import pdl_interpreter
from .pdl_ast import PdlBlock, PdlBlocks, Program
from .pdl_ast import (
LocationType,
PdlBlock,
PdlBlocks,
Program,
RoleType,
ScopeType,
empty_block_location,
)
from .pdl_interpreter import InterpreterState, process_prog
from .pdl_parser import parse_file, parse_str


class InterpreterConfig(TypedDict, total=False):
"""Configuration parameters of the PDL interpreter."""

yield_output: bool
"""Print the program messages during the execution.
"""
batch: int
"""Execution type:
- 0: streaming
- 1: non-streaming
"""
role: RoleType
"""Default role.
"""


def exec_program(
prog: Program,
config: Optional[InterpreterConfig] = None,
scope: Optional[ScopeType] = None,
loc: Optional[LocationType] = None,
):
"""Execute a PDL program given as a value of type `pdl.pdl_ast.Program`.

Args:
prog: Program to execute.
config: Interpreter configuration. Defaults to None.
scope: Environment defining the initial variables in scope to execute the program. Defaults to None.
loc: Source code location mapping. Defaults to None.

Returns:
Return the final result.
"""
config = config or {}
state = InterpreterState(**config)
scope = scope or {}
loc = loc or empty_block_location
result = process_prog(state, scope, prog, loc)
return result


def exec_dict(
prog: dict[str, Any],
config: Optional[InterpreterConfig] = None,
scope: Optional[ScopeType] = None,
loc: Optional[LocationType] = None,
):
"""Execute a PDL program given as a dictionary.

Args:
prog: Program to execute.
config: Interpreter configuration. Defaults to None.
scope: Environment defining the initial variables in scope to execute the program. Defaults to None.
loc: Source code location mapping. Defaults to None.

Returns:
Return the final result.
"""
program = Program.model_validate(prog)
result = exec_program(program, config, scope, loc)
return result


def exec_str(
prog: str,
config: Optional[InterpreterConfig] = None,
scope: Optional[ScopeType] = None,
):
"""Execute a PDL program given as YAML string.

Args:
prog: Program to execute.
config: Interpreter configuration. Defaults to None.
scope: Environment defining the initial variables in scope to execute the program. Defaults to None.

Returns:
Return the final result.
"""
program, loc = parse_str(prog)
result = exec_program(program, config, scope, loc)
return result


def exec_file(
prog: str,
config: Optional[InterpreterConfig] = None,
scope: Optional[ScopeType] = None,
):
"""Execute a PDL program given as YAML file.

Args:
prog: Program to execute.
config: Interpreter configuration. Defaults to None.
scope: Environment defining the initial variables in scope to execute the program. Defaults to None.

Returns:
Return the final result.
"""
program, loc = parse_file(prog)
result = exec_program(program, config, scope, loc)
return result


def main():
Expand Down
129 changes: 69 additions & 60 deletions pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .pdl_dumper import block_to_dict, dump_yaml
from .pdl_llms import BamModel, LitellmModel, WatsonxModel
from .pdl_location_utils import append, get_loc_string
from .pdl_parser import PDLParseError, parse_program
from .pdl_parser import PDLParseError, parse_file
from .pdl_scheduler import ModelCallMessage, OutputMessage, YieldMessage, schedule
from .pdl_schema_validator import type_check_args, type_check_spec

Expand Down Expand Up @@ -105,9 +105,8 @@ def generate(
if log_file is None:
log_file = "log.txt"
try:
prog, line_table = parse_program(pdl_file)
prog, loc = parse_file(pdl_file)
state = InterpreterState(yield_output=True)
loc = LocationType(path=[], file=pdl_file, table=line_table)
_, _, _, trace = process_prog(state, initial_scope, prog, loc)
with open(log_file, "w", encoding="utf-8") as log_fp:
for line in state.log:
Expand Down Expand Up @@ -515,61 +514,10 @@ def step_block_body(
result = closure
background = []
trace = closure.model_copy(update={})
case CallBlock(call=f):
result = None
background = []
args, errors = process_expr(scope, block.args, append(loc, "args"))
if len(errors) != 0:
trace = handle_error(
block, append(loc, "args"), None, errors, block.model_copy()
)
closure_expr, errors = process_expr(scope, block.call, append(loc, "call"))
if len(errors) != 0:
trace = handle_error(
block, append(loc, "call"), None, errors, block.model_copy()
)
closure = get_var(closure_expr, scope)
if closure is None:
trace = handle_error(
block,
append(loc, "call"),
f"Function is undefined: {f}",
[],
block.model_copy(),
)
else:
argsloc = append(loc, "args")
type_errors = type_check_args(args, closure.function, argsloc)
if len(type_errors) > 0:
trace = handle_error(
block,
argsloc,
f"Type errors during function call to {f}",
type_errors,
block.model_copy(),
)
else:
f_body = closure.returns
f_scope = closure.scope | {"context": scope["context"]} | args
funloc = LocationType(
file=closure.location.file,
path=closure.location.path + ["return"],
table=loc.table,
)
result, background, _, f_trace = yield from step_blocks(
IterationType.SEQUENCE, state, f_scope, f_body, funloc
)
trace = block.model_copy(update={"trace": f_trace})
if closure.spec is not None:
errors = type_check_spec(result, closure.spec, funloc)
if len(errors) > 0:
trace = handle_error(
block,
loc,
f"Type errors in result of function call to {f}",
errors,
trace,
)
case CallBlock():
result, background, scope, trace = yield from step_call(
state, scope, block, loc
)
case EmptyBlock():
result = ""
background = []
Expand Down Expand Up @@ -1074,6 +1022,68 @@ def call_python(code: str, scope: dict) -> Any:
return result


def step_call(
state: InterpreterState, scope: ScopeType, block: CallBlock, loc: LocationType
) -> Generator[
YieldMessage, Any, tuple[Any, Messages, ScopeType, CallBlock | ErrorBlock]
]:
result = None
background: Messages = []
args, errors = process_expr(scope, block.args, append(loc, "args"))
if len(errors) != 0:
trace = handle_error(
block, append(loc, "args"), None, errors, block.model_copy()
)
closure_expr, errors = process_expr(scope, block.call, append(loc, "call"))
if len(errors) != 0:
trace = handle_error(
block, append(loc, "call"), None, errors, block.model_copy()
)
closure = get_var(closure_expr, scope)
if closure is None:
trace = handle_error(
block,
append(loc, "call"),
f"Function is undefined: {block.call}",
[],
block.model_copy(),
)
else:
argsloc = append(loc, "args")
type_errors = type_check_args(args, closure.function, argsloc)
if len(type_errors) > 0:
trace = handle_error(
block,
argsloc,
f"Type errors during function call to {closure_expr}",
type_errors,
block.model_copy(),
)
else:
f_body = closure.returns
f_scope = closure.scope | {"context": scope["context"]} | args
funloc = LocationType(
file=closure.location.file,
path=closure.location.path + ["return"],
table=loc.table,
)
result, background, _, f_trace = yield from step_blocks(
IterationType.SEQUENCE, state, f_scope, f_body, funloc
)
trace = block.model_copy(update={"trace": f_trace})
if closure.spec is not None:
errors = type_check_spec(result, closure.spec, funloc)
if len(errors) > 0:
trace = handle_error(
block,
loc,
f"Type errors in result of function call to {closure_expr}",
errors,
trace,
)
return result, background, scope, trace


def process_input(
state: InterpreterState, scope: ScopeType, block: ReadBlock, loc: LocationType
) -> tuple[str, Messages, ScopeType, ReadBlock | ErrorBlock]:
Expand Down Expand Up @@ -1121,8 +1131,7 @@ def step_include(
YieldMessage, Any, tuple[Any, Messages, ScopeType, IncludeBlock | ErrorBlock]
]:
try:
prog, line_table = parse_program(block.include)
newloc = LocationType(file=block.include, path=[], table=line_table)
prog, newloc = parse_file(block.include)
result, background, scope, trace = yield from step_block(
state, scope, prog.root, newloc
)
Expand Down
4 changes: 2 additions & 2 deletions pdl/pdl_location_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def get_paths(
return ret


def get_line_map(file) -> dict[str, int]:
def get_line_map(prog: str) -> dict[str, int]:
indentation = []
fields = []
is_array_item = []
for line in file.readlines(): # line numbers are off by one
for line in prog.split("\n"): # line numbers are off by one
fields.append(
line.strip().split(":")[0].replace("-", "").strip()
if line.find(":") != -1
Expand Down
Loading