From 859c06e439d0ba34b54f0fe7fd8489dbb51b0924 Mon Sep 17 00:00:00 2001 From: Alexey Potapov Date: Mon, 6 Nov 2023 13:54:12 +0300 Subject: [PATCH 1/2] support for spaces in llm-gate --- python/sandbox/neurospace/llm_gate.py | 140 +++++++++++------- .../sandbox/neurospace/test_guide_func.metta | 20 +++ .../neurospace/test_guide_prompt.metta | 5 + 3 files changed, 115 insertions(+), 50 deletions(-) create mode 100644 python/sandbox/neurospace/test_guide_func.metta create mode 100644 python/sandbox/neurospace/test_guide_prompt.metta diff --git a/python/sandbox/neurospace/llm_gate.py b/python/sandbox/neurospace/llm_gate.py index 6377f1dd5..4dce509cc 100644 --- a/python/sandbox/neurospace/llm_gate.py +++ b/python/sandbox/neurospace/llm_gate.py @@ -27,60 +27,100 @@ def atom2msg(atom): return v.replace("\\n", "\n") return repr(atom) -def get_message_list(msg_atoms): - ''' - Convert atoms to ChatGPT messages and flatten a possibly nested message list - ''' - messages = [] - for msg in msg_atoms: - if isinstance(msg, ExpressionAtom): - ch = msg.get_children() - if len(ch) == 0: - continue - if ch[0].get_name() == 'Messages': - messages += get_message_list(ch[1:]) - else: - messages += [{"role": ch[0].get_name(), "content": atom2msg(ch[1])}] - else: - raise TypeError("Messages should be tagged by the role") - return messages - -def llm(metta: MeTTa, *args): +def get_llm_args(metta: MeTTa, prompt_space: SpaceRef, *args): messages = [] functions = [] - msgs = None + msg_atoms = [] + def __msg_update(m, f, a): + nonlocal messages, functions, msg_atoms + messages += m + functions += f + msg_atoms += [a] for arg in args: - if isinstance(arg, ExpressionAtom): + if isinstance(arg, GroundedAtom) and \ + isinstance(arg.get_object(), SpaceRef): + # FIXME? This will overwrites the current prompt_space if it is set. + # It is convenient to have it here to successfully execute + # (llm &prompt (Functions fn)), when fn is defined in &prompt. + # But (function fn) can also be put in &prompt directly. + # Depending on what is more convenient, this overriding can be changed. + prompt_space = arg.get_object() + __msg_update(*get_llm_args(metta, prompt_space, *prompt_space.get_atoms())) + elif isinstance(arg, ExpressionAtom): ch = arg.get_children() - if len(ch) > 1 and ch[0].get_name() == 'Messages': - msgs = arg - messages += get_message_list(ch[1:]) - if len(ch) > 1 and ch[0].get_name() == 'Functions': - for fn in ch[1:]: - doc = metta.run(f"! (doc {fn})") - if len(doc) == 0: - # TODO: error / warning - continue - # TODO: format is not checked - doc = doc[0][0].get_children() - properties = {} - for par in doc[2].get_children()[1:]: - p = par.get_children() - properties.update({ - p[0].get_name(): { - "type": "string", - "description": p[1].get_object().value, - "enum": list(map(lambda x: x.get_object().value, p[2].get_children())) + if len(ch) > 1: + name = ch[0].get_name() + if name == 'Messages': + __msg_update(*get_llm_args(metta, prompt_space, *ch[1:])) + elif name in ['system', 'user', 'assistant']: + # We have to interpret the message in the main space context, + # if the prompt template is in a separate file and contains + # some external symbols like (user-query) + msg = interpret(metta.space(), ch[1])[0] + messages += [{'role': name, 'content': atom2msg(msg)}] + msg_atoms += [arg] + elif name in ['Functions', 'function']: + for fn in ch[1:]: + doc = None + if prompt_space is not None: + # TODO: Querying for a function description in prompt_space works well, + # but it is useless, because this function cannot be called + # from the main script, so the functional call is not reduced. + # Fixing this requires in general better library management in MeTTa, + # although it can be managed here by interpreting the functional call expression. + # Another approach would be to have load-template, which will import all functions to &self + # (or just to declare function in separate files and load to self, since we may want them + # to be reusable between templates) + r = prompt_space.query(E(S('='), E(S('doc'), fn), V('r'))) + if not r.is_empty(): + doc = r[0]['r'] + if doc is None: + # We use `match` here instead of direct `doc` evaluation + # to evoid non-reduced `doc` + doc = metta.run(f"! (match &self (= (doc {fn}) $r) $r)") + if len(doc) == 0 or len(doc[0]) == 0: + raise RuntimeError(f"No {fn} function description") + doc = doc[0][0] + # TODO: format is not checked + doc = doc.get_children() + properties = {} + for par in doc[2].get_children()[1:]: + p = par.get_children() + properties.update({ + p[0].get_name(): { + "type": "string", + "description": p[1].get_object().value, + "enum": list(map(lambda x: x.get_object().value, p[2].get_children())) + } + }) + functions += [{ + "name": fn.get_name(), + "description": doc[1].get_children()[1].get_object().value, + "parameters": { + "type": "object", + "properties": properties } - }) - functions += [{ - "name": fn.get_name(), - "description": doc[1].get_children()[1].get_object().value, - "parameters": { - "type": "object", - "properties": properties - } - }] + }] + elif name == '=': + # We ignore equalities here: if a space is used to store messages, + # it can contain equalities as well (another approach would be to + # ignore everythins except valid roles) + continue + else: + raise RuntimeError("Unrecognized argument: " + repr(arg)) + else: + # Ignore an empty expression () for convenience, but we need + # to put it back into msg_atoms to keep the structure + msg_atoms += [arg] + else: + raise RuntimeError("Unrecognized argument: " + repr(arg)) + # Do not wrap a single message into Message (necessary to avoid double + # wrapping of single Message argument) + return messages, functions, \ + msg_atoms[0] if len(msg_atoms) == 1 else E(S('Messages'), *msg_atoms) + +def llm(metta: MeTTa, *args): + messages, functions, msgs_atom = get_llm_args(metta, None, *args) #print(messages) #return [] if functions==[]: @@ -104,7 +144,7 @@ def llm(metta: MeTTa, *args): fs = S(response_message["function_call"]["name"]) args = response_message["function_call"]["arguments"] args = json.loads(args) - return [E(fs, to_nested_expr(list(args.values())), msgs)] + return [E(fs, to_nested_expr(list(args.values())), msgs_atom)] return [ValueAtom(response_message['content'])] @register_atoms(pass_metta=True) diff --git a/python/sandbox/neurospace/test_guide_func.metta b/python/sandbox/neurospace/test_guide_func.metta new file mode 100644 index 000000000..224d28861 --- /dev/null +++ b/python/sandbox/neurospace/test_guide_func.metta @@ -0,0 +1,20 @@ +!(extend-py! llm_gate) + +!(import! &msgs test_guide_prompt.metta) + +; Function for calls still cannot be put into the prompt space, because +; they will not be evaluated, when the functional call is evoked here +(= (doc calc_math) + (Doc + (description "You should call this function with a mathematical expression in Scheme") + (parameters + (expression "Mathematical expression in Scheme" ()) + )) +) +; This is another limitation: LLM output should somehow be transormed to MeTTa expressions +(= (calc_math $expr $msgs) + ($expr is not evaluated, because it is a string atm)) + +(= (user-query) "What is the result of 111102 + 18333?") + +! (llm &msgs) diff --git a/python/sandbox/neurospace/test_guide_prompt.metta b/python/sandbox/neurospace/test_guide_prompt.metta new file mode 100644 index 000000000..c30944571 --- /dev/null +++ b/python/sandbox/neurospace/test_guide_prompt.metta @@ -0,0 +1,5 @@ +(system "Answer the user question. Try to reason carefully.") + +(user (user-query)) + +(function calc_math) From ef91eda87aedcea917ade8eca514a191d778a398 Mon Sep 17 00:00:00 2001 From: Alexey Potapov Date: Mon, 6 Nov 2023 15:59:35 +0300 Subject: [PATCH 2/2] simple but useful postprocessing added --- python/sandbox/neurospace/llm_gate.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sandbox/neurospace/llm_gate.py b/python/sandbox/neurospace/llm_gate.py index 4dce509cc..23044d968 100644 --- a/python/sandbox/neurospace/llm_gate.py +++ b/python/sandbox/neurospace/llm_gate.py @@ -159,3 +159,13 @@ def llmgate_atoms(metta): r"atom2msg": msgAtom } + +def str_find_all(str, values): + return list(filter(lambda v: v in str, values)) + +@register_atoms +def postproc_atoms(): + strfindAtom = OperationAtom('str-find-all', str_find_all) + return { + r"str-find-all": strfindAtom, + }