From 4e6e260dbd5ebdab38be975e991dd37c5e567409 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Sun, 19 Nov 2023 02:03:48 +0800 Subject: [PATCH] add more options for askchat --- askchat/__init__.py | 2 +- askchat/askchat.py | 55 ++++++++++++++++++++++++++++++++++++++++----- setup.py | 2 +- 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/askchat/__init__.py b/askchat/__init__.py index dd8db72..0d3e0a8 100644 --- a/askchat/__init__.py +++ b/askchat/__init__.py @@ -2,6 +2,6 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '0.0.4' +__version__ = '0.1.0' from .askchat import ask \ No newline at end of file diff --git a/askchat/askchat.py b/askchat/askchat.py index be6c4e3..76ecf09 100644 --- a/askchat/askchat.py +++ b/askchat/askchat.py @@ -1,10 +1,11 @@ """Main module.""" from chattool import Chat, debug_log -import asyncio +import asyncio, os, uuid from argparse import ArgumentParser import askchat from pprint import pprint VERSION = askchat.__version__ +CONFIG_FILE = os.path.expanduser("~/.askrc") # print the response in a typewriter way async def show_resp(chat): @@ -36,6 +37,9 @@ def main(): parser.add_argument('--valid-models', action='store_true', help='Print valid models that contain "gpt" in their names') parser.add_argument('--all-valid-models', action='store_true', help='Print all valid models') parser.add_argument('-m', '--model', default=None, help='Model name') + parser.add_argument('--base-url', default=None, help='base url of the api(without suffix `/v1`)') + parser.add_argument("--api-key", default=None, help="API key") + parser.add_argument('--generate-config', action="store_true", help="Generate a configuration file by environment table.") args = parser.parse_args() # show debug log if args.debug: @@ -50,11 +54,52 @@ def main(): print('All valid models:') pprint(Chat().get_valid_models(gpt_only=False)) return - # get message and model - model, msg = args.model, args.message + if args.generate_config is not None: + api_key = os.environ.get("OPENAI_API_KEY", "") + base_url = os.environ.get("OPENAI_API_BASE_URL", "") + model = os.environ.get("OPENAI_API_MODEL", "") + if os.path.exists(CONFIG_FILE): + # create a temporary file + os.makedirs("/tmp", exist_ok=True) + tmp_file = os.path.join("/tmp", str(uuid.uuid4())[:8] + ".askrc") + # move the old config file to a temporary file + os.rename(CONFIG_FILE, tmp_file) + print(f"Moved old config file to {tmp_file}") + with open(CONFIG_FILE, "w") as f: + # description for the config file + f.write("#!/bin/bash\n" +\ + "# Description: This is a configuration file for askchat.\n" +\ + "# Author: Rex Wang\n" +\ + "# Current version: " + VERSION + "\n\n") + # write the environment table + f.write("# Your API key\n") + f.write(f"OPENAI_API_KEY={api_key}\n") + f.write("# The base url of the API (without suffix /v1)\n") + f.write(f"OPENAI_API_BASE_URL={base_url}\n") + f.write("# The model name. You can use `askchat --all-valid-models` to see the valid models.\n") + f.write(f"OPENAI_API_MODEL={model}\n") + print("Created config file at", CONFIG_FILE) + return + + # get message, model, and base url + msg = args.message if isinstance(msg, list): msg = ' '.join(msg) assert len(msg.strip()), 'Please specify message' - # call - chat = Chat(msg, model=model) + # read para from config or args + if os.path.exists(CONFIG_FILE): + with open(CONFIG_FILE) as f: + lines = f.readlines() + for line in lines: + if line.startswith("OPENAI_API_KEY="): + api_key = line.split("=")[-1].strip() + elif line.startswith("OPENAI_API_BASE_URL="): + base_url = line.split("=")[-1].strip() + elif line.startswith("OPENAI_API_MODEL="): + model = line.split("=")[-1].strip() + api_key = args.api_key if hasattr(args, "api_key") else api_key + base_url = args.base_url if hasattr(args, "base_url") else base_url + model = args.model if hasattr(args, "model") else model + # call the function + chat = Chat(msg, model=model, base_url=base_url, api_key=api_key) asyncio.run(show_resp(chat)) \ No newline at end of file diff --git a/setup.py b/setup.py index 1c85d3e..805319d 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ from setuptools import setup, find_packages -VERSION = '0.0.5' +VERSION = '0.1.0' with open('README.md') as readme_file: readme = readme_file.read()