Skip to content

Commit

Permalink
initialize agent
Browse files Browse the repository at this point in the history
  • Loading branch information
lfunderburk committed Dec 9, 2024
1 parent 21921b7 commit 95ae175
Show file tree
Hide file tree
Showing 10 changed files with 1,387 additions and 42 deletions.
14 changes: 14 additions & 0 deletions ch8/chatbot-real-time/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Steps to Build the Chatbot:

Setup:

1. Install necessary dependencies:

```bash
pip install langgraph bytewax websockets
```

2. Integrate the Dataflow: Use the Coinbase order book dataflow to generate real-time data.
3. Build the Chatbot: Define a chatbot in LangGraph.
4. Incorporate the real-time data into the chatbot's response.
5. Run the Bot: Combine the chatbot and dataflow into a unified system.
32 changes: 32 additions & 0 deletions ch8/chatbot-real-time/chatbot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_anthropic import ChatAnthropic

# Chatbot state
class ChatState(TypedDict):
messages: Annotated[list, add_messages]

# LangGraph chatbot
graph_builder = StateGraph(ChatState)

llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")


def chatbot(state: ChatState):
user_message = state["messages"][-1][1]
if "BTC-USD" in user_message:
# Fetch real-time data (mocked for this example)
best_bid = 30000
best_ask = 30010
response = f"The best bid is ${best_bid} and the best ask is ${best_ask}."
else:
response = llm.invoke(state["messages"])
return {"messages": [("assistant", response)]}


graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()
186 changes: 151 additions & 35 deletions ch8/chatbot-real-time/dataflow.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,171 @@
from bytewax.dataflow import Dataflow
from bytewax.inputs import FixedPartitionedSource
from bytewax.inputs import (FixedPartitionedSource,\
StatefulSourcePartition,\
batch_async)
from bytewax import operators as op
from bytewax.connectors.stdio import StdOutSink
from typing import List
from bytewax.connectors.files import FileSink
from bytewax.run import cli_main
import websockets
import json
from dataclasses import dataclass, field
from datetime import timedelta, datetime
from typing import Dict, List, Optional
import ssl

import ssl

async def _ws_agen(product_id):
"""Connect to websocket and yield messages as they arrive."""
url = "wss://ws-feed.exchange.coinbase.com"
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS)

async with websockets.connect(url, ssl=ssl_context, max_size=10**7) as websocket:
msg = json.dumps(
{
"type": "subscribe",
"product_ids": [product_id],
"channels": ["level2_batch"],
}
)
await websocket.send(msg)
await websocket.recv()

while True:
msg = await websocket.recv()
try:
# Parse the incoming message as JSON
parsed_msg = json.loads(msg)
yield (product_id, parsed_msg) # Ensure we yield (key, dictionary)
except json.JSONDecodeError as e:
print(f"Error decoding message: {msg}, error: {e}")



class CoinbasePartition(StatefulSourcePartition):
def __init__(self, product_id):
agen = _ws_agen(product_id)
self._batcher = batch_async(agen, timedelta(seconds=0.5), 100)

def next_batch(self):
return next(self._batcher)


def snapshot(self):
return None

# Coinbase data source
class CoinbaseSource(FixedPartitionedSource):
def list_parts(self):
return ["BTC-USD"]

def build_part(self, step_id, for_key, _resume_state):
return CoinbasePartition(for_key)

@dataclass(frozen=True)
class OrderBookSummary:
"""Represents a summary of the order book state."""

class CoinbasePartition:
def __init__(self, product_id):
self._product_id = product_id
self._data = iter([
{"product_id": "BTC-USD", "bids": [["30000", "1.5"]], "asks": [["30010", "2.0"]]},
{"product_id": "BTC-USD", "changes": [["buy", "30000", "1.0"]]},
])
bid_price: float
bid_size: float
ask_price: float
ask_size: float
spread: float
timestamp: datetime

def next_batch(self):
try:
return [next(self._data)]
except StopIteration:
return []
@dataclass
class OrderBookState:
"""Maintains the state of the order book."""

def snapshot(self):
return None
bids: Dict[float, float] = field(default_factory=dict)
asks: Dict[float, float] = field(default_factory=dict)
bid_price: Optional[float] = None
ask_price: Optional[float] = None

def update(self, data):
"""Update the order book state with the given data.
Args:
data: The data to update the order book state with.
"""
# Initialize bids and asks if they're empty
if not self.bids:
self.bids = {float(price): float(size) for price, size in data["bids"]}
self.bid_price = max(self.bids.keys(), default=None)
if not self.asks:
self.asks = {float(price): float(size) for price, size in data["asks"]}
self.ask_price = min(self.asks.keys(), default=None)

# Process updates from the "changes" field in the data
for change in data.get("changes", []):
side, price_str, size_str = change
price, size = float(price_str), float(size_str)

target_dict = self.asks if side == "sell" else self.bids

# If size is zero, remove the price level; otherwise,
# update/add the price level
if size == 0.0:
target_dict.pop(price, None)
else:
target_dict[price] = size

# After update, recalculate the best bid and ask prices
if side == "sell":
self.ask_price = min(self.asks.keys(), default=None)
else:
self.bid_price = max(self.bids.keys(), default=None)

def spread(self) -> float:
"""Calculate the spread between the best bid and ask prices.
Returns:
float: The spread between the best bid and ask prices.
"""
return self.ask_price - self.bid_price # type: ignore


def summarize(self):
"""Summarize the order book state.
Returns:
OrderBookSummary: A summary of the order book state.
"""
return OrderBookSummary(
bid_price=self.bid_price,
bid_size=self.bids[self.bid_price],
ask_price=self.ask_price,
ask_size=self.asks[self.ask_price],
spread=self.spread(),
timestamp=datetime.now(),
)


def create_dataflow(init_name, percentage):
flow = Dataflow(init_name)
source = CoinbaseSource()
inp = op.input("coinbase", flow, source)

def mapper(state, value):
"""Update the state with the given value and return the state and a summary."""
if state is None:
state = OrderBookState()

# Define Dataflow
flow = Dataflow()
state.update(value)
return (state, state.summarize())

source = CoinbaseSource()
inp = op.input("coinbase", flow, source)

def process_orderbook(state, value):
if state is None:
state = {"bids": {}, "asks": {}}
# Process bids and asks
for bid, size in value.get("bids", []):
state["bids"][float(bid)] = float(size)
for ask, size in value.get("asks", []):
state["asks"][float(ask)] = float(size)
return state, {
"best_bid": max(state["bids"].keys(), default=None),
"best_ask": min(state["asks"].keys(), default=None)
}
stats = op.stateful_map("orderbook", inp, mapper)

def just_large_spread(prod_summary):
"""Filter out products with a spread less than a given percentage."""
product, summary = prod_summary
return summary.spread / summary.ask_price > percentage

filter = op.filter("big_spread", stats, just_large_spread)

orderbook_stats = op.stateful_map("orderbook", inp, process_orderbook)
op.output("out", filter, StdOutSink())
return flow

op.output("out", orderbook_stats, StdOutSink())
percentage = 0.0001
flow = create_dataflow("coinbase", percentage)
cli_main(flow)
30 changes: 30 additions & 0 deletions ch8/chatbot-real-time/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import threading
from bytewax.run import cli_main
from chatbot import graph
from dataflow import create_dataflow

def run_dataflow():
flow = create_dataflow()
cli_main(flow)

def run_chatbot():
print("Chatbot with real-time data. Type 'quit' to exit.")
while True:
user_input = input("You: ")
if user_input.lower() in ["quit", "exit"]:
print("Goodbye!")
break
state = {"messages": [("user", user_input)]}
events = graph.stream(state)
for event in events:
for key, value in event.items():
print(f"Assistant: {value['messages'][-1][1]}")

if __name__ == "__main__":
# Run dataflow in a separate thread
dataflow_thread = threading.Thread(target=run_dataflow)
dataflow_thread.daemon = True
dataflow_thread.start()

# Run chatbot
run_chatbot()
Empty file.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ dependencies = [
"gunicorn",
"spacy",
"langgraph",
"websockets"
"websockets",
"certifi",
"matplotlib",
"langchain"
]

[build-system]
Expand Down
Loading

0 comments on commit 95ae175

Please sign in to comment.