Skip to content

Commit

Permalink
Merge pull request #8 from hmasdev/add-printer-selection
Browse files Browse the repository at this point in the history
Update printer functions and game configuration
  • Loading branch information
hmasdev authored Sep 2, 2024
2 parents d7e00ab + 43da93c commit 81a47f8
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 7 deletions.
18 changes: 18 additions & 0 deletions tests/utils/test_printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from werewolf.utils.printer import (
_print_dict,
create_print_func,
)


@pytest.mark.parametrize(
'key,expected',
list(_print_dict.items()),
)
def test_create_print_func(key, expected):
assert create_print_func(key) is expected


def test_create_print_func_invalid_key():
with pytest.raises(ValueError):
create_print_func('invalid_key')
12 changes: 7 additions & 5 deletions werewolf/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Iterable

import autogen
import click
from langchain_openai import ChatOpenAI

from .const import DEFAULT_MODEL, EGameMaster
from .config import GameConfig
from .game_master.base import BaseGameMaster
from .utils.openai import create_chat_openai_model
from .utils.printer import create_print_func


def game(
Expand All @@ -23,10 +23,12 @@ def game(
open_game: bool = False,
config_list=[{'model': DEFAULT_MODEL}],
llm: ChatOpenAI | str | None = None,
printer: str = 'click.echo',
log_file: str = 'werewolf.log',
logger: logging.Logger = logging.getLogger(__name__),
):
# preparation
print_func = create_print_func(printer)
llm = create_chat_openai_model(llm)
master = BaseGameMaster.instantiate(
EGameMaster.Default, # TODO
Expand Down Expand Up @@ -62,7 +64,7 @@ def game(
days = range(len(master.alive_players))
for _ in days:
# announce day
click.echo(f'=============================== Day {master.day} (Daytime) ================================') # noqa
print_func(f'=============================== Day {master.day} (Daytime) ================================') # noqa
master.announce(
'\n'.join([
f'Day {master.day}: Daytime.',
Expand All @@ -77,7 +79,7 @@ def game(
votes = master.daytime_discussion()
# exclude from the game
excluded_result = master.exclude_players_following_votes(votes)
click.echo('\n'.join([
print_func('\n'.join([
'============================== Excluded result ==============================', # noqa
str(excluded_result),
'=============================================================================', # noqa
Expand All @@ -89,7 +91,7 @@ def game(
break
# announce day
# announce night
click.echo(f'================================ Day {master.day} (Nighttime) ================================') # noqa
print_func(f'================================ Day {master.day} (Nighttime) ================================') # noqa
master.announce(
'\n'.join([
f'Day {master.day}: Nighttime.',
Expand All @@ -104,7 +106,7 @@ def game(
votes = master.nighttime_action()
# exclude from the game
excluded_result = master.exclude_players_following_votes(votes, announce_votes=False) # noqa
click.echo('\n'.join([
print_func('\n'.join([
'============================== Excluded result ==============================', # noqa
str(excluded_result),
'=============================================================================', # noqa
Expand Down
13 changes: 11 additions & 2 deletions werewolf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .const import DEFAULT_MODEL, ESpeakerSelectionMethod
from .game import game
from .utils.printer import KEYS_FOR_PRINTER, create_print_func


@click.command()
Expand All @@ -20,6 +21,7 @@
@click.option('-o', '--open-game', is_flag=True, help='Whether to open game or not.') # noqa
@click.option('-l', '--log', default=None, help='The log file name. Default is werewolf%Y%m%d%H%M%S.log') # noqa
@click.option('-m', '--model', default=DEFAULT_MODEL, help=f'The model name. Default is {DEFAULT_MODEL}.') # noqa
@click.option('-p', '--printer', default='click.echo', help=f'The printer name. The valid values is in {KEYS_FOR_PRINTER}. Default is click.echo.') # noqa
@click.option('--sub-model', default=DEFAULT_MODEL, help=f'The sub-model name. Default is {DEFAULT_MODEL}.') # noqa
@click.option('--log-level', default='WARNING', help='The log level, DEBUG, INFO, WARNING, ERROR or CRITICAL. Default is WARNING.') # noqa
@click.option('--debug', is_flag=True, help='Whether to show debug logs or not.') # noqa
Expand All @@ -34,16 +36,21 @@ def main(
open_game: bool,
log: str,
model: str,
printer: str,
sub_model: str,
log_level: str,
debug: bool,
):
load_dotenv(),
if os.environ.get('OPENAI_API_KEY') is None:
raise ValueError('You must set OPENAI_API_KEY in your environment variables or .env file.') # noqa

printer_func = create_print_func(printer)

log = f'werewolf{dt.now().strftime("%Y%m%d%H%M%S")}.log' if log is None else log # noqa
logging.basicConfig(level=logging.DEBUG if debug else getattr(logging, log_level.upper(), 'WARNING')) # type: ignore # noqa
config_list = [{'model': model}]

result = game(
n_players=n_players,
n_werewolves=n_werewolves,
Expand All @@ -53,12 +60,14 @@ def main(
speaker_selection_method=speaker_selection_method.value,
include_human=include_human,
open_game=open_game,
printer=printer,
log_file=log,
config_list=config_list,
llm=sub_model,
)
click.echo('================================ Game Result ================================') # noqa
click.echo(result)

printer_func('================================ Game Result ================================') # noqa
printer_func(result)


if __name__ == "__main__":
Expand Down
31 changes: 31 additions & 0 deletions werewolf/utils/printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from functools import partial
import logging
from typing import Callable
import click


_print_dict: dict[str, Callable[..., None]] = {
'print': print,
'click.echo': click.echo,
'logging.info': logging.info,
}

KEYS_FOR_PRINTER: tuple[str, ...] = tuple(_print_dict.keys())


def create_print_func(key: str, **kwargs) -> Callable[..., None]:
"""Create a print function.
Args:
key (str): key of the print function
**kwargs: keyword arguments to pass to the print function
Returns:
Callable[..., None]: print function
"""
try:
if kwargs:
return partial(_print_dict[key], **kwargs)
return _print_dict[key]
except KeyError:
raise ValueError(f'Invalid key: {key}. Valid keys are {list(KEYS_FOR_PRINTER)}') # noqa

0 comments on commit 81a47f8

Please sign in to comment.