From a375cba3cb11158a305b566b4096da5c15e35d46 Mon Sep 17 00:00:00 2001 From: Marius Wichtner <2mawi2@gmail.com> Date: Sun, 19 Mar 2023 17:12:00 +0100 Subject: [PATCH] Stream the model suggestion directly into the terminal --- copilot/main.py | 67 ++++++++++++++------------------------ copilot/open_ai_adapter.py | 45 +++++++++++++++++++++++++ copilot/parse_args.py | 5 +++ copilot/strip.py | 10 ++++++ copilot/test_copilot.py | 4 +-- 5 files changed, 87 insertions(+), 44 deletions(-) create mode 100644 copilot/open_ai_adapter.py create mode 100644 copilot/strip.py diff --git a/copilot/main.py b/copilot/main.py index 4b0a1ce..af9729a 100644 --- a/copilot/main.py +++ b/copilot/main.py @@ -1,7 +1,5 @@ # Program to serve as a terminal copilot for the user -import enum import sys -import argparse import subprocess import openai import pyperclip @@ -9,10 +7,10 @@ from urllib.parse import quote import platform import json -import re from copilot import history from conversation import Conversation +from open_ai_adapter import request_cmds, stream_cmd_into_terminal from parse_os import parse_operating_system, OperatingSystem from parse_args import parse_terminal_copilot_args from messages_builder import Context, build_conversation @@ -30,7 +28,6 @@ def is_unix_system(): def main(): args = parse_terminal_copilot_args() - if args.verbose: print("Verbose mode enabled") @@ -76,21 +73,30 @@ def main(): print("To set the environment variable, run:") print("export OPENAI_API_KEY=") sys.exit(1) - cmds = request_cmds(conversation, n=int(args.count) if args.json and args.count else 1) if args.json: + cmds = request_cmds(conversation, n=int(args.count) if args.json and args.count else 1) print(json.dumps({ "commands": cmds, "explainshell_links": list(map(get_explainshell_link, cmds)) })) else: - show_command_options(conversation, cmds[0]) + cmds = fetch_and_print_cmd(conversation, args) + show_command_options(conversation, cmds, args) + + +def fetch_and_print_cmd(conversation, args): + if args.no_stream: + cmds = request_cmds(conversation, n=int(args.count) if args.json and args.count else 1)[0] + print(f"\033[94m> {cmds}\033[0m") + else: + cmds = stream_cmd_into_terminal(conversation) + return cmds -def show_command_options(conversation: Conversation, cmd): +def show_command_options(conversation: Conversation, cmd, args): operating_system = platform.system() - print(f"\033[94m> {cmd}\033[0m") options = ["execute", "refine", "copy", "explainshell", "show more options"] if is_unix_system(): @@ -108,9 +114,9 @@ def show_command_options(conversation: Conversation, cmd): menu_entry_index = options.index(answers["menu_entry_index"]) if menu_entry_index == 0: - execute(conversation, cmd) + execute(conversation, cmd, args) elif menu_entry_index == 1: - refine_command(conversation, cmd) + refine_command(conversation, cmd, args) elif menu_entry_index == 2: print("> copied") pyperclip.copy(cmd) @@ -130,7 +136,7 @@ def read_input(): return input("> ") -def refine_command(conversation: Conversation, cmd): +def refine_command(conversation: Conversation, cmd, args): refinement = read_input() conversation.messages.append({"role": "assistant", "content": cmd}) refinement_command = f"""The user requires a command for the following prompt: `{refinement}`. @@ -138,11 +144,11 @@ def refine_command(conversation: Conversation, cmd): Do not add any text in front of it and do not add any text after it. The command the user is looking for is: `""" conversation.messages.append({"role": "user", "content": refinement_command}) - cmd = request_cmds(conversation, n=1)[0] - show_command_options(conversation, cmd) + cmd = fetch_and_print_cmd(conversation, args) + show_command_options(conversation, cmd, args) -def execute(conversation: Conversation, cmd): +def execute(conversation: Conversation, cmd, args): try: result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='utf-8') out = result.stdout @@ -150,23 +156,23 @@ def execute(conversation: Conversation, cmd): print(out) print(error) if error != "" and error is not None: - refine_failed_command(conversation, cmd, error) + refine_failed_command(conversation, cmd, error, args) else: history.save(cmd) except Exception as e: print(e) - refine_failed_command(conversation, cmd, str(e)) + refine_failed_command(conversation, cmd, str(e), args) -def refine_failed_command(conversation: Conversation, cmd, error): +def refine_failed_command(conversation: Conversation, cmd, error, args): error = error[:300] conversation.messages.append({"role": "assistant", "content": cmd}) failed_command = f"The last suggested command of the assistant failed with the error: `{error}`." \ f"The corrected command (and only the command) is:`" conversation.messages.append({"role": "user", "content": failed_command}) - cmd = request_cmds(conversation, n=1)[0] print("The last command failed. Here is suggested corrected command:") - show_command_options(conversation, cmd) + cmd = fetch_and_print_cmd(conversation, args) + show_command_options(conversation, cmd, args) def show_more_cmd_options(conversation: Conversation): @@ -194,33 +200,10 @@ def show_more_cmd_options(conversation: Conversation): show_command_options(conversation, cmds[cmd_menu_entry_index]) -def request_cmds(conversation: Conversation, n=1): - response = openai.ChatCompletion.create( - model=conversation.model.value, - messages=conversation.messages, - temperature=0, - max_tokens=1000, - top_p=0.2, - stop=["`"], - frequency_penalty=0, - presence_penalty=0, - n=n, - ) - choices = response.choices - cmds = strip(choices) - if len(cmds) > 1: - cmds = list(dict.fromkeys(cmds)) - return cmds - - def get_explainshell_link(cmd): return "https://explainshell.com/explain?cmd=" + quote(cmd) -def strip(choices): - return [re.sub('`[^`]*(`|$)', r'\1', choice.message.content) for choice in choices] - - def git_info(): git_installed = ( subprocess.run(["which", "git"], capture_output=True).returncode == 0 diff --git a/copilot/open_ai_adapter.py b/copilot/open_ai_adapter.py new file mode 100644 index 0000000..932da12 --- /dev/null +++ b/copilot/open_ai_adapter.py @@ -0,0 +1,45 @@ +import sys + +import openai + +from conversation import Conversation +from strip import strip_cmd, strip_choices + + +def _create_chat_completion(conversation, n, stream=False): + return openai.ChatCompletion.create( + model=conversation.model.value, + messages=conversation.messages, + temperature=0, + max_tokens=1000, + top_p=0.2, + stop=["`"], + frequency_penalty=0, + presence_penalty=0, + n=n, + stream=stream, + ) + + +def request_cmds(conversation: Conversation, n=1): + response = _create_chat_completion(conversation, n) + choices = response.choices + cmds = strip_choices(choices) + if len(cmds) > 1: + cmds = list(dict.fromkeys(cmds)) + return cmds + + +def stream_cmd_into_terminal(conversation: Conversation) -> str: + response = _create_chat_completion(conversation, n=1, stream=True) + print(f"\033[94m> ", end='') + cmd = "" + for chunk in response: + if "content" in chunk["choices"][0]["delta"]: + cmd_delta = chunk["choices"][0]["delta"]["content"] + cmd_delta = strip_cmd(cmd_delta) + print(cmd_delta, end='') + sys.stdout.flush() + cmd += cmd_delta + print("\033[0m") + return strip_cmd(cmd) diff --git a/copilot/parse_args.py b/copilot/parse_args.py index 811cc80..e018e80 100644 --- a/copilot/parse_args.py +++ b/copilot/parse_args.py @@ -36,5 +36,10 @@ def parse_terminal_copilot_args(): "-m", "--model", type=argparse_model_type, default=Model.GPT_35_TURBO, help="The model to use. Defaults to gpt-3.5-turbo." ) + parser.add_argument( + "-ns", "--no-stream", action="store_true", + default=False, + help="Disable streaming the command into the terminal (by default, streaming is enabled)." + ) args = parser.parse_args() return args diff --git a/copilot/strip.py b/copilot/strip.py new file mode 100644 index 0000000..6dbbbc2 --- /dev/null +++ b/copilot/strip.py @@ -0,0 +1,10 @@ +import re + + +def strip_cmd(cmd): + return re.sub('`[^`]*(`|$)', r'\1', cmd) + + +def strip_choices(choices): + return [strip_cmd(choice.message.content) for choice in choices] + diff --git a/copilot/test_copilot.py b/copilot/test_copilot.py index 15491c4..ffc54a7 100644 --- a/copilot/test_copilot.py +++ b/copilot/test_copilot.py @@ -74,7 +74,7 @@ def test_model_should_decline_unrelated_requests(self, mock_terminal_menu, fake_ self.assertIn(expected_command, output) def execute_prompt(self, fake_stdout, prompt): - sys.argv = ["copilot", prompt] + sys.argv = ["copilot", prompt, "--no-stream"] main() output = fake_stdout.getvalue() return output @@ -95,7 +95,7 @@ def test_model_should_refine_command(self, mock_terminal_menu, fake_stdout, mock terminal_menu_mock = MagicMock() terminal_menu_mock.show.side_effect = [REFINE, NO_EXECUTION] mock_terminal_menu.return_value = terminal_menu_mock - sys.argv = ["copilot", prompt] + sys.argv = ["copilot", prompt, "--no-stream"] # act main() output = fake_stdout.getvalue()