Skip to content

Commit

Permalink
improve selenium page retriever, move bshr running example to example…
Browse files Browse the repository at this point in the history
…s, fix other examples, update readme, refactor, remove asserts
  • Loading branch information
doodledood committed Nov 13, 2023
1 parent b894d23 commit 894368d
Show file tree
Hide file tree
Showing 14 changed files with 205 additions and 177 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -620,4 +620,4 @@ MigrationBackup/
# dotenv
.env

*.db
examples/output
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ At the heart of ChatFlock is the Conductor, a novel entity that determines the s
- [Multi-Participant Chat with a Composition Generator with a deep (3+ level) Hierarchical Composition](examples/automatic_hierarchical_chat_composition.py)

#### End-to-End Examples
- [BSHR (Brainstorm-Search-Hypothesize-Refine) Loop](chatflock/use_cases/bshr.py) - Based on [David Shapiro's](https://github.com/daveshap/BSHR_Loop) idea.
- [BSHR (Brainstorm-Search-Hypothesize-Refine) Loop](examples/bshr_loop.py) - Based on [David Shapiro's](https://github.com/daveshap/BSHR_Loop) idea.

## 🚀 Features

Expand Down
11 changes: 6 additions & 5 deletions chatflock/ai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ def execute_chat_model_messages(
) -> str:
chat_model_args = chat_model_args or {}

assert "functions" not in chat_model_args, (
"The `functions` argument is reserved for the "
"`execute_chat_model_messages` function. If you want to add more "
"functions use the `functions` argument to this method."
)
if "functions" in chat_model_args:
raise ValueError(
"The `functions` argument is reserved for the "
"`execute_chat_model_messages` function. If you want to add more "
"functions use the `functions` argument to this method."
)

if tools is not None and len(tools) > 0:
chat_model_args["functions"] = [format_tool_to_openai_function(tool) for tool in tools]
Expand Down
5 changes: 2 additions & 3 deletions chatflock/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,8 @@ def __init__(
max_total_messages: Optional[int] = None,
hide_messages: bool = False,
):
assert max_total_messages is None or max_total_messages > 0, (
"Max total messages must be None or greater than " "0."
)
if max_total_messages is not None and max_total_messages <= 0:
raise ValueError("Max total messages must be None or greater than 0.")

self.backing_store = backing_store
self.renderer = renderer
Expand Down
16 changes: 10 additions & 6 deletions chatflock/structured_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,22 @@ def to_text(self, level: int = 0) -> str:
class StructuredString:
sections: List[Section]

def __getitem__(self, item):
assert isinstance(item, str)
def __getitem__(self, item) -> Section:
if not isinstance(item, str):
raise TypeError(f"Item must be of type str, not {type(item)}.")

relevant_sections = [section for section in self.sections if section.name == item]
if len(relevant_sections) == 0:
raise KeyError(f"No section with name {item} exists.")

return relevant_sections[0]

def __setitem__(self, key, value):
assert isinstance(key, str)
assert isinstance(value, Section)
def __setitem__(self, key, value) -> None:
if not isinstance(key, str):
raise TypeError(f"Key must be of type str, not {type(key)}.")

if not isinstance(value, Section):
raise TypeError(f"Value must be of type Section, not {type(value)}.")

try:
section = self[key]
Expand All @@ -63,5 +67,5 @@ def __str__(self) -> str:

return result

def __repr__(self):
def __repr__(self) -> str:
return self.__str__()
62 changes: 2 additions & 60 deletions chatflock/use_cases/bshr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,14 @@

import datetime
import json
import os
from functools import partial
from pathlib import Path

import questionary
from dotenv import load_dotenv
from halo import Halo
from langchain.cache import SQLiteCache
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.globals import set_llm_cache
from langchain.llms.openai import OpenAI
from langchain.memory import ConversationSummaryBufferMemory
from langchain.text_splitter import TokenTextSplitter
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

Expand All @@ -33,9 +26,6 @@
from chatflock.structured_string import Section, StructuredString
from chatflock.use_cases.request_response import get_response
from chatflock.web_research import WebSearch
from chatflock.web_research.page_analyzer import OpenAIChatPageQueryAnalyzer
from chatflock.web_research.page_retrievers.selenium_retriever import SeleniumPageRetriever
from chatflock.web_research.search import GoogleSerperSearchResultsProvider
from chatflock.web_research.web_research import WebResearchTool


Expand Down Expand Up @@ -63,7 +53,7 @@ def load_state(state_file: Optional[str]) -> Optional[BHSRState]:
return None

try:
with open(state_file, "r") as f:
with open(state_file) as f:
data = json.load(f)
return BHSRState.model_validate(data)
except FileNotFoundError:
Expand Down Expand Up @@ -445,7 +435,7 @@ def run_brainstorm_search_hypothesize_refine_loop(
break

has_feedback = questionary.confirm(
"The information need seems to have have been satisficed. Do you have " "any feedback?"
"The information need seems to have have been satisficed. Do you have any feedback?"
).ask()

if not has_feedback:
Expand Down Expand Up @@ -495,51 +485,3 @@ def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = No
)

return hypothesis


if __name__ == "__main__":
load_dotenv()

output_dir = Path(os.getenv("OUTPUT_DIR", "../../output"))
output_dir.mkdir(exist_ok=True, parents=True)

n_search_results = 2

state_file = str(output_dir / "bshr_state.json")
llm_cache = SQLiteCache(database_path=str(output_dir / "llm_cache.db"))
set_llm_cache(llm_cache)

chat_model = ChatOpenAI(temperature=0.0, model="gpt-4-1106-preview")
chat_model_for_analysis = ChatOpenAI(
temperature=0.0,
model="gpt-3.5-turbo-1106",
)

try:
max_context_size = OpenAI.modelname_to_contextsize(chat_model_for_analysis.model_name)
except ValueError:
max_context_size = 12000

web_search = WebSearch(
chat_model=chat_model,
search_results_provider=GoogleSerperSearchResultsProvider(),
page_query_analyzer=OpenAIChatPageQueryAnalyzer(
chat_model=chat_model_for_analysis,
page_retriever=SeleniumPageRetriever(),
text_splitter=TokenTextSplitter(chunk_size=max_context_size, chunk_overlap=max_context_size // 5),
use_first_split_only=True,
),
)

spinner = Halo(spinner="dots")

hypothesis = run_brainstorm_search_hypothesize_refine_loop(
confirm_satisficed=True,
web_search=web_search,
chat_model=chat_model,
n_search_results=n_search_results,
state_file=state_file,
spinner=spinner,
)

print(f"Final Answer:\n----------------\n{hypothesis}\n----------------")
8 changes: 7 additions & 1 deletion chatflock/web_research/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,10 @@
from .page_analyzer import OpenAIChatPageQueryAnalyzer, PageQueryAnalyzer
from .web_research import WebSearch

__all__ = ["WebSearch", "PageQueryAnalyzer", "TransientHTTPError", "NonTransientHTTPError"]
__all__ = [
"WebSearch",
"PageQueryAnalyzer",
"OpenAIChatPageQueryAnalyzer",
"TransientHTTPError",
"NonTransientHTTPError",
]
102 changes: 47 additions & 55 deletions chatflock/web_research/page_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,11 @@
from langchain.text_splitter import TextSplitter
from pydantic import BaseModel

from chatflock.backing_stores import InMemoryChatDataBackingStore
from chatflock.base import Chat
from chatflock.conductors import RoundRobinChatConductor
from chatflock.parsing_utils import string_output_to_pydantic
from chatflock.renderers import NoChatRenderer
from chatflock.structured_string import Section, StructuredString

from ..participants.langchain import LangChainBasedAIChatParticipant
from ..participants.user import UserChatParticipant
from ..use_cases.request_response import get_response
from .errors import NonTransientHTTPError, TransientHTTPError
from .page_retrievers import PageRetriever

Expand Down Expand Up @@ -84,6 +80,8 @@ def analyze(self, url: str, title: str, query: str, spinner: Optional[Halo] = No
return PageQueryAnalysisResult(
answer=f"The query could not be answered because an error occurred while retrieving the page: {e}"
)
finally:
self.page_retriever.close()

cleaned_html = clean_html(html)

Expand All @@ -92,71 +90,65 @@ def analyze(self, url: str, title: str, query: str, spinner: Optional[Halo] = No
answer = "No answer yet."
for i, doc in enumerate(docs):
text = doc.page_content
chat = Chat(
backing_store=InMemoryChatDataBackingStore(),
renderer=NoChatRenderer(),
initial_participants=[
UserChatParticipant(),
LangChainBasedAIChatParticipant(
name="Web Page Query Answerer",
role="Web Page Query Answerer",
personal_mission="Answer queries based on provided (partial) web page content from the web.",
chat_model=self.chat_model,
other_prompt_sections=[

query_answerer = LangChainBasedAIChatParticipant(
name="Web Page Query Answerer",
role="Web Page Query Answerer",
personal_mission="Answer queries based on provided (partial) web page content from the web.",
chat_model=self.chat_model,
other_prompt_sections=[
Section(
name="Crafting a Query Answer",
sub_sections=[
Section(
name="Process",
list=[
"Analyze the query and the given content",
"If context is provided, use it to answer the query.",
"Summarize the answer in a comprehensive, yet succinct way.",
],
list_item_prefix=None,
),
Section(
name="Crafting a Query Answer",
sub_sections=[
Section(
name="Process",
list=[
"Analyze the query and the given content",
"If context is provided, use it to answer the query.",
"Summarize the answer in a comprehensive, yet succinct way.",
],
list_item_prefix=None,
),
Section(
name="Guidelines",
list=[
"If the answer is not found in the page content, it's insufficent, or not relevant "
"to the query at all, state it clearly.",
"Do not fabricate information. Stick to provided content.",
"Provide context for the next call (e.g., if a paragraph was cut short, include "
"relevant header information, section, etc. for continuity). Assume the content is "
"partial content from the page. Be very detailed in the context.",
"If unable to answer but found important information, include it in the context "
"for the next call.",
"Pay attention to the details of the query and make sure the answer is suitable "
"for the intent of the query.",
"A potential answer might have been provided. This means you thought you found "
"the answer in a previous partial text for the same page. You should double-check "
"that and provide an alternative revised answer if you think it's wrong, "
"or repeat it if you think it's right or cannot be validated using the current "
"text.",
],
),
name="Guidelines",
list=[
"If the answer is not found in the page content, it's insufficent, or not relevant "
"to the query at all, state it clearly.",
"Do not fabricate information. Stick to provided content.",
"Provide context for the next call (e.g., if a paragraph was cut short, include "
"relevant header information, section, etc. for continuity). Assume the content is "
"partial content from the page. Be very detailed in the context.",
"If unable to answer but found important information, include it in the context "
"for the next call.",
"Pay attention to the details of the query and make sure the answer is suitable "
"for the intent of the query.",
"A potential answer might have been provided. This means you thought you found "
"the answer in a previous partial text for the same page. You should double-check "
"that and provide an alternative revised answer if you think it's wrong, "
"or repeat it if you think it's right or cannot be validated using the current "
"text.",
],
)
),
],
),
)
],
max_total_messages=2,
)
chat_conductor = RoundRobinChatConductor()
final_answer = chat_conductor.initiate_chat_with_result(
chat=chat,
initial_message=str(

final_answer, _ = get_response(
query=str(
StructuredString(
sections=[
Section(name="Query", text=query),
Section(name="Url", text=url),
Section(name="Title", text=title),
Section(name="Previous Answer", text=answer),
Section(name="Page Content", text=text),
Section(name="Page Content", text=f"```{text}```"),
]
)
),
answerer=query_answerer,
)

result = string_output_to_pydantic(
output=final_answer, chat_model=self.chat_model, output_schema=PageQueryAnalysisResult
)
Expand Down
3 changes: 3 additions & 0 deletions chatflock/web_research/page_retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@
class PageRetriever(abc.ABC):
    """Base interface for objects that fetch the HTML of a web page."""

    def retrieve_html(self, url: str, **kwargs: Any) -> str:
        """Return the HTML for ``url``.

        The base implementation always raises ``NotImplementedError``;
        subclasses are expected to override it with real retrieval logic.
        """
        raise NotImplementedError

    def close(self) -> None:
        """Release resources held by the retriever; does nothing by default."""
        return None
3 changes: 2 additions & 1 deletion chatflock/web_research/page_retrievers/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

class RetrieverWithFallback(PageRetriever):
def __init__(self, retrievers: Sequence[PageRetriever]):
assert len(retrievers) > 0, "Must provide at least one retriever."
if len(retrievers) == 0:
raise ValueError("Must provide at least one retriever.")

self.retrievers = retrievers

Expand Down
Loading

0 comments on commit 894368d

Please sign in to comment.