feat(service/prompt.py): optimize intention and generate with citation format (#393)

* feat(service/prompt.py): optimize intention parsing and citation generate
tpoisonooo authored Oct 14, 2024
1 parent 1651950 commit 9c98f42
Showing 16 changed files with 286 additions and 129 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -210,7 +210,7 @@ After running, test with `python3 -m huixiangdou.main --standalone`. At this tim
python3 -m huixiangdou.main --standalone

+---------------------------+---------+----------------------------+-----------------+
| Query | State | Part of Reply | References |
| Query | State | Reply | References |
+===========================+=========+============================+=================+
| How to install mmpose? | success | To install mmpose, plea.. | installation.md |
--------------------------------------------------------------------------------------
2 changes: 1 addition & 1 deletion README_zh.md
@@ -210,7 +210,7 @@ python3 -m huixiangdou.service.feature_store
python3 -m huixiangdou.main --standalone

+-----------------------+---------+--------------------------------+-----------------+
| Query | State | Part of Reply | References |
| Query | State | Reply | References |
+=======================+=========+================================+=================+
| 请问如何安装 mmpose ? | success | 要安装 mmpose,请按照以下步骤操作..| installation.md |
--------------------------------------------------------------------------------------
4 changes: 2 additions & 2 deletions huixiangdou/gradio_ui.py
@@ -24,7 +24,7 @@ def ymd():

def parse_args():
"""Parse args."""
parser = argparse.ArgumentParser(description='SerialPipeline.')
parser = argparse.ArgumentParser(description='Gradio UI for parallel/serial pipeline.')
parser.add_argument('--work_dir',
type=str,
default='workdir',
@@ -34,7 +34,7 @@ def parse_args():
'--config_path',
default='config.ini',
type=str,
help='SerialPipeline configuration path. Default value is config.ini')
help='Pipeline configuration path. Default value is config.ini')
parser.add_argument('--standalone',
action='store_true',
default=True,
4 changes: 2 additions & 2 deletions huixiangdou/service/feature_store.py
@@ -428,11 +428,11 @@ def test_query(retriever: Retriever, sample: str = None):

table = Texttable()
table.set_cols_valign(['t', 't', 't', 't'])
table.header(['Query', 'State', 'Part of Chunks', 'References'])
table.header(['Query', 'State', 'Chunks', 'References'])

for example in real_questions:
example = example[0:400]
chunks, context, refs = retriever.query(example)
chunks, context, refs, context_texts = retriever.query(example)
if chunks:
table.add_row(
[example, 'Accepted', chunks[0:100] + '..', ','.join(refs)])
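Note on the hunk above: `Retriever.query` now returns four values instead of three. A minimal sketch of an updated call site, assuming (this diff does not confirm it) that the extra `context_texts` value carries the raw chunk texts later consumed by the citation prompt:

    # Hypothetical call site for the new 4-tuple return of Retriever.query;
    # variable roles other than `refs` are assumptions based on this diff.
    chunks, context, refs, context_texts = retriever.query('How to install mmpose?')
    if chunks:
        print(','.join(refs))         # e.g. installation.md
        print(len(context_texts))     # number of raw context texts kept for citation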
2 changes: 1 addition & 1 deletion huixiangdou/service/helper.py
@@ -327,7 +327,7 @@ def extract_json_from_str(raw: str):
def build_reply_text(code, query: str, reply: str, refs: list, max_len:int=20):
table = Texttable()
table.set_cols_valign(['t', 't', 't', 't'])
table.header(['Query', 'State', 'Part of Reply', 'References'])
table.header(['Query', 'State', 'Reply', 'References'])
table.add_row([query, str(code), reply[0:max_len] + '..', ','.join(refs)])
return table.draw()

9 changes: 5 additions & 4 deletions huixiangdou/service/llm_server_hybrid.py
@@ -9,7 +9,6 @@
import torch
import pdb
import pytoml
import requests
from loguru import logger
from openai import OpenAI
from huixiangdou.primitive import RPM, TPM
@@ -38,7 +37,6 @@ def check_gpu_max_memory_gb():
logger.error(str(e))
return -1


def build_messages(prompt, history, system: str = None):
messages = []
if system is not None and len(system) > 0:
@@ -340,7 +338,6 @@ async def call_openai(self,
system=system)

logger.debug('remote api sending: {}'.format(messages))

stream = client.chat.completions.create(
model=model,
messages=messages,
@@ -415,7 +412,11 @@ async def chat_stream(self, prompt, history=[], backend='local'):
target_fn = map_fn[backend]

# build args for `target_fn`
args = {'prompt': prompt, 'history': history, 'model':self.backend2model[backend]}
default_model = self.backend2model[backend]
model = self.server_config['remote_llm_model']
if model is None or len(model) < 1:
model = default_model
args = {'prompt': prompt, 'history': history, 'model': model}
if backend in map_base_url:
args['base_url'] = map_base_url[backend]

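The `chat_stream` hunk above changes how the model name is chosen: the `remote_llm_model` from the server config wins, and the per-backend default is only a fallback. A self-contained sketch of that selection logic (the dict shapes and example names are assumptions for illustration, not part of the commit):

    def pick_model(server_config: dict, backend2model: dict, backend: str) -> str:
        """Prefer the configured remote model; fall back to the backend default."""
        default_model = backend2model[backend]
        model = server_config.get('remote_llm_model')
        if model is None or len(model) < 1:
            model = default_model
        return model

    # An empty remote_llm_model falls back to the backend's default entry.
    print(pick_model({'remote_llm_model': ''},
                     {'some_backend': 'some-default-model'},
                     'some_backend'))
    # -> some-default-model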
86 changes: 54 additions & 32 deletions huixiangdou/service/parallel_pipeline.py
@@ -2,29 +2,23 @@
"""Pipeline."""
import argparse
import asyncio
import datetime
import json
import os
import re
import time
import pdb
import copy
from abc import ABC, abstractmethod
from typing import List, Tuple, Union, Generator
from typing import List, Tuple, Union, Generator, AsyncGenerator

import pytoml
from loguru import logger

from huixiangdou.primitive import Query, Chunk

from .helper import ErrorCode, is_truth
from .helper import ErrorCode
from .llm_client import ChatClient
from .retriever import CacheRetriever, Retriever
from .sg_search import SourceGraphProxy
from .session import Session
from .web_search import WebSearch
from .prompt import (SCORING_QUESTION_TEMPLTE_CN, CR_NEED_CN, CR_CN, TOPIC_TEMPLATE_CN, SCORING_RELAVANCE_TEMPLATE_CN, GENERATE_TEMPLATE_CN, KEYWORDS_TEMPLATE_CN, PERPLESITY_TEMPLATE_CN, SECURITY_TEMAPLTE_CN)
from .prompt import (SCORING_QUESTION_TEMPLTE_EN, CR_NEED_EN, CR_EN, TOPIC_TEMPLATE_EN, SCORING_RELAVANCE_TEMPLATE_EN, GENERATE_TEMPLATE_EN, KEYWORDS_TEMPLATE_EN, PERPLESITY_TEMPLATE_EN, SECURITY_TEMAPLTE_EN)
from .prompt import (INTENTION_TEMPLATE_CN, CR_CN, SCORING_RELAVANCE_TEMPLATE_CN, KEYWORDS_TEMPLATE_CN)
from .prompt import (INTENTION_TEMPLATE_EN, CR_EN, SCORING_RELAVANCE_TEMPLATE_EN, KEYWORDS_TEMPLATE_EN)
from .prompt import CitationGeneratePrompt

class PreprocNode:
"""PreprocNode is for coreference resolution and scoring based on group
@@ -38,10 +32,10 @@ def __init__(self, config: dict, llm: ChatClient, language: str):
self.enable_cr = config['worker']['enable_cr']

if language == 'zh':
self.SCORING_QUESTION_TEMPLTE = SCORING_QUESTION_TEMPLTE_CN
self.INTENTION_TEMPLATE = INTENTION_TEMPLATE_CN
self.CR = CR_CN
else:
self.SCORING_QUESTION_TEMPLTE = SCORING_QUESTION_TEMPLTE_EN
self.INTENTION_TEMPLATE = INTENTION_TEMPLATE_EN
self.CR = CR_EN

def process(self, sess: Session) -> Generator[Session, None, None]:
@@ -51,16 +45,42 @@ def process(self, sess: Session) -> Generator[Session, None, None]:
yield sess
return

prompt = self.SCORING_QUESTION_TEMPLTE.format(sess.query.text)
truth, logs = is_truth(llm=self.llm,
prompt=prompt,
throttle=6,
default=3)
sess.debug['PreprocNode_is_question'] = logs
if not truth:
sess.code = ErrorCode.NOT_A_QUESTION
yield sess
return
prompt = self.INTENTION_TEMPLATE.format(sess.query.text)
json_str = self.llm.generate_response(prompt=prompt, backend='remote')
sess.debug['PreprocNode_intention_response'] = json_str
logger.info('intention response {}'.format(json_str))
try:
if json_str.startswith('```json'):
json_str = json_str[len('```json'):]

if json_str.endswith('```'):
json_str = json_str[0:-3]

json_obj = json.loads(json_str)
intention = json_obj['intention']
if intention is not None:
intention = intention.lower()
else:
intention = 'undefine'
topic = json_obj['topic']
if topic is not None:
topic = topic.lower()
else:
topic = 'undefine'

for block_intention in ['问候', 'greeting', 'undefine']:
if block_intention in intention:
sess.code = ErrorCode.NOT_A_QUESTION
yield sess
return

for block_topic in ['身份', 'identity', 'undefine']:
if block_topic in topic:
sess.code = ErrorCode.NOT_A_QUESTION
yield sess
return
except Exception as e:
logger.error(str(e))

if not self.enable_cr:
yield sess
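The rewritten PreprocNode above replaces the old is-it-a-question scoring with a single intention/topic classification: it asks the remote LLM for a JSON object, strips an optional ```json fence, and rejects greetings, identity probes, and undefined intents. A runnable sketch of that handling, where the sample reply string is only an assumption about what INTENTION_TEMPLATE elicits:

    import json

    # Hypothetical model reply to INTENTION_TEMPLATE; the field values are made up.
    raw = '```json\n{"intention": "greeting", "topic": "undefine"}\n```'

    # Strip the markdown fence, as PreprocNode.process does before json.loads.
    if raw.startswith('```json'):
        raw = raw[len('```json'):]
    if raw.endswith('```'):
        raw = raw[:-3]

    obj = json.loads(raw)
    intention = (obj.get('intention') or 'undefine').lower()
    topic = (obj.get('topic') or 'undefine').lower()

    # Greetings, identity probes, and undefined results are rejected up front.
    blocked = any(b in intention for b in ['问候', 'greeting', 'undefine']) or \
              any(b in topic for b in ['身份', 'identity', 'undefine'])
    print(blocked)  # -> True for this sample reply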
@@ -161,7 +181,7 @@ def __init__(self, config: dict, config_path: str, llm: ChatClient,
self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_EN
self.KEYWORDS_TEMPLATE = KEYWORDS_TEMPLATE_EN

async def process(self, sess: Session) -> Generator[Session, None, None]:
async def process(self, sess: Session) -> AsyncGenerator[Session, None]:
"""Try web search."""

if not self.enable:
@@ -190,7 +210,7 @@ async def process(self, sess: Session) -> Generator[Session, None, None]:
yield sess
return

for article_id, article in enumerate(articles):
for _, article in enumerate(articles):
article.cut(0, self.context_max_length)
c = Chunk(content_or_path=article.content, metadata={'source': article.source})
sess.parallel_chunks.append(c)
@@ -207,16 +227,13 @@ class ReduceGenerate:
def __init__(self, config: dict, llm: ChatClient, retriever: CacheRetriever, language: str):
self.llm = llm
self.retriever = retriever
if language == 'zh':
self.GENERATE_TEMPLATE = GENERATE_TEMPLATE_CN
else:
self.GENERATE_TEMPLATE = GENERATE_TEMPLATE_EN
llm_config = config['llm']
self.context_max_length = llm_config['server']['local_llm_max_text_length']
if llm_config['enable_remote']:
self.context_max_length = llm_config['server']['remote_llm_max_text_length']
self.language = language

async def process(self, sess: Session) -> Generator[Session, None, None]:
async def process(self, sess: Session) -> AsyncGenerator[Session, None]:
question = sess.query.text
history = sess.history

@@ -226,9 +243,14 @@ async def process(self, sess: Session) -> Generator[Session, None, None]:
sess.delta = part
yield sess
else:
_, context_str, references = self.retriever.rerank_fuse(query=sess.query, chunks=sess.parallel_chunks, context_max_length=self.context_max_length)
_, _, references, context_texts = self.retriever.rerank_fuse(query=sess.query, chunks=sess.parallel_chunks, context_max_length=self.context_max_length)
sess.references = references
prompt = self.GENERATE_TEMPLATE.format(context_str, sess.query.text)

citation = CitationGeneratePrompt(self.language)
prompt = citation.build(texts=context_texts, question=question)
with open('citation_generate_prompt.txt', 'w') as f:
f.write(prompt)
f.flush()
async for part in self.llm.chat_stream(prompt=prompt, history=history):
sess.delta = part
yield sess
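ReduceGenerate above drops the flat GENERATE_TEMPLATE_CN/EN in favour of CitationGeneratePrompt, which builds the prompt from the reranked context texts so the answer can cite its sources. A minimal sketch of that flow, assuming only the build(texts=..., question=...) signature visible in the diff and treating the session, retriever, and LLM objects as already constructed:

    from huixiangdou.service.prompt import CitationGeneratePrompt

    async def reduce_generate(sess, retriever, llm, context_max_length, language='zh'):
        # Fuse and rerank the chunks gathered by the parallel retrieval/web-search nodes.
        _, _, references, context_texts = retriever.rerank_fuse(
            query=sess.query,
            chunks=sess.parallel_chunks,
            context_max_length=context_max_length)
        sess.references = references

        # Build a citation-style prompt from the retained context texts; the exact
        # citation format lives in CitationGeneratePrompt and is not shown in this diff.
        prompt = CitationGeneratePrompt(language).build(
            texts=context_texts, question=sess.query.text)

        # Stream the answer back through the session object.
        async for part in llm.chat_stream(prompt=prompt, history=sess.history):
            sess.delta = part
            yield sess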