-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
21921b7
commit 95ae175
Showing
10 changed files
with
1,387 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
Steps to Build the Chatbot: | ||
|
||
Setup: | ||
|
||
1. Install necessary dependencies: | ||
|
||
```bash | ||
pip install langgraph bytewax websockets | ||
``` | ||
|
||
2. Integrate the Dataflow: Use the Coinbase order book dataflow to generate real-time data. | ||
3. Build the Chatbot: Define a chatbot in LangGraph. | ||
4. Incorporate the real-time data into the chatbot's response. | ||
5. Run the Bot: Combine the chatbot and dataflow into a unified system. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import Annotated | ||
from typing_extensions import TypedDict | ||
from langgraph.graph import StateGraph, START, END | ||
from langgraph.graph.message import add_messages | ||
from langchain_anthropic import ChatAnthropic | ||
|
||
# Chatbot state | ||
class ChatState(TypedDict): | ||
messages: Annotated[list, add_messages] | ||
|
||
# LangGraph chatbot | ||
graph_builder = StateGraph(ChatState) | ||
|
||
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620") | ||
|
||
|
||
def chatbot(state: ChatState): | ||
user_message = state["messages"][-1][1] | ||
if "BTC-USD" in user_message: | ||
# Fetch real-time data (mocked for this example) | ||
best_bid = 30000 | ||
best_ask = 30010 | ||
response = f"The best bid is ${best_bid} and the best ask is ${best_ask}." | ||
else: | ||
response = llm.invoke(state["messages"]) | ||
return {"messages": [("assistant", response)]} | ||
|
||
|
||
graph_builder.add_node("chatbot", chatbot) | ||
graph_builder.add_edge(START, "chatbot") | ||
graph_builder.add_edge("chatbot", END) | ||
graph = graph_builder.compile() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,171 @@ | ||
from bytewax.dataflow import Dataflow | ||
from bytewax.inputs import FixedPartitionedSource | ||
from bytewax.inputs import (FixedPartitionedSource,\ | ||
StatefulSourcePartition,\ | ||
batch_async) | ||
from bytewax import operators as op | ||
from bytewax.connectors.stdio import StdOutSink | ||
from typing import List | ||
from bytewax.connectors.files import FileSink | ||
from bytewax.run import cli_main | ||
import websockets | ||
import json | ||
from dataclasses import dataclass, field | ||
from datetime import timedelta, datetime | ||
from typing import Dict, List, Optional | ||
import ssl | ||
|
||
import ssl | ||
|
||
async def _ws_agen(product_id): | ||
"""Connect to websocket and yield messages as they arrive.""" | ||
url = "wss://ws-feed.exchange.coinbase.com" | ||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS) | ||
|
||
async with websockets.connect(url, ssl=ssl_context, max_size=10**7) as websocket: | ||
msg = json.dumps( | ||
{ | ||
"type": "subscribe", | ||
"product_ids": [product_id], | ||
"channels": ["level2_batch"], | ||
} | ||
) | ||
await websocket.send(msg) | ||
await websocket.recv() | ||
|
||
while True: | ||
msg = await websocket.recv() | ||
try: | ||
# Parse the incoming message as JSON | ||
parsed_msg = json.loads(msg) | ||
yield (product_id, parsed_msg) # Ensure we yield (key, dictionary) | ||
except json.JSONDecodeError as e: | ||
print(f"Error decoding message: {msg}, error: {e}") | ||
|
||
|
||
|
||
class CoinbasePartition(StatefulSourcePartition): | ||
def __init__(self, product_id): | ||
agen = _ws_agen(product_id) | ||
self._batcher = batch_async(agen, timedelta(seconds=0.5), 100) | ||
|
||
def next_batch(self): | ||
return next(self._batcher) | ||
|
||
|
||
def snapshot(self): | ||
return None | ||
|
||
# Coinbase data source | ||
class CoinbaseSource(FixedPartitionedSource): | ||
def list_parts(self): | ||
return ["BTC-USD"] | ||
|
||
def build_part(self, step_id, for_key, _resume_state): | ||
return CoinbasePartition(for_key) | ||
|
||
@dataclass(frozen=True) | ||
class OrderBookSummary: | ||
"""Represents a summary of the order book state.""" | ||
|
||
class CoinbasePartition: | ||
def __init__(self, product_id): | ||
self._product_id = product_id | ||
self._data = iter([ | ||
{"product_id": "BTC-USD", "bids": [["30000", "1.5"]], "asks": [["30010", "2.0"]]}, | ||
{"product_id": "BTC-USD", "changes": [["buy", "30000", "1.0"]]}, | ||
]) | ||
bid_price: float | ||
bid_size: float | ||
ask_price: float | ||
ask_size: float | ||
spread: float | ||
timestamp: datetime | ||
|
||
def next_batch(self): | ||
try: | ||
return [next(self._data)] | ||
except StopIteration: | ||
return [] | ||
@dataclass | ||
class OrderBookState: | ||
"""Maintains the state of the order book.""" | ||
|
||
def snapshot(self): | ||
return None | ||
bids: Dict[float, float] = field(default_factory=dict) | ||
asks: Dict[float, float] = field(default_factory=dict) | ||
bid_price: Optional[float] = None | ||
ask_price: Optional[float] = None | ||
|
||
def update(self, data): | ||
"""Update the order book state with the given data. | ||
Args: | ||
data: The data to update the order book state with. | ||
""" | ||
# Initialize bids and asks if they're empty | ||
if not self.bids: | ||
self.bids = {float(price): float(size) for price, size in data["bids"]} | ||
self.bid_price = max(self.bids.keys(), default=None) | ||
if not self.asks: | ||
self.asks = {float(price): float(size) for price, size in data["asks"]} | ||
self.ask_price = min(self.asks.keys(), default=None) | ||
|
||
# Process updates from the "changes" field in the data | ||
for change in data.get("changes", []): | ||
side, price_str, size_str = change | ||
price, size = float(price_str), float(size_str) | ||
|
||
target_dict = self.asks if side == "sell" else self.bids | ||
|
||
# If size is zero, remove the price level; otherwise, | ||
# update/add the price level | ||
if size == 0.0: | ||
target_dict.pop(price, None) | ||
else: | ||
target_dict[price] = size | ||
|
||
# After update, recalculate the best bid and ask prices | ||
if side == "sell": | ||
self.ask_price = min(self.asks.keys(), default=None) | ||
else: | ||
self.bid_price = max(self.bids.keys(), default=None) | ||
|
||
def spread(self) -> float: | ||
"""Calculate the spread between the best bid and ask prices. | ||
Returns: | ||
float: The spread between the best bid and ask prices. | ||
""" | ||
return self.ask_price - self.bid_price # type: ignore | ||
|
||
|
||
def summarize(self): | ||
"""Summarize the order book state. | ||
Returns: | ||
OrderBookSummary: A summary of the order book state. | ||
""" | ||
return OrderBookSummary( | ||
bid_price=self.bid_price, | ||
bid_size=self.bids[self.bid_price], | ||
ask_price=self.ask_price, | ||
ask_size=self.asks[self.ask_price], | ||
spread=self.spread(), | ||
timestamp=datetime.now(), | ||
) | ||
|
||
|
||
def create_dataflow(init_name, percentage): | ||
flow = Dataflow(init_name) | ||
source = CoinbaseSource() | ||
inp = op.input("coinbase", flow, source) | ||
|
||
def mapper(state, value): | ||
"""Update the state with the given value and return the state and a summary.""" | ||
if state is None: | ||
state = OrderBookState() | ||
|
||
# Define Dataflow | ||
flow = Dataflow() | ||
state.update(value) | ||
return (state, state.summarize()) | ||
|
||
source = CoinbaseSource() | ||
inp = op.input("coinbase", flow, source) | ||
|
||
def process_orderbook(state, value): | ||
if state is None: | ||
state = {"bids": {}, "asks": {}} | ||
# Process bids and asks | ||
for bid, size in value.get("bids", []): | ||
state["bids"][float(bid)] = float(size) | ||
for ask, size in value.get("asks", []): | ||
state["asks"][float(ask)] = float(size) | ||
return state, { | ||
"best_bid": max(state["bids"].keys(), default=None), | ||
"best_ask": min(state["asks"].keys(), default=None) | ||
} | ||
stats = op.stateful_map("orderbook", inp, mapper) | ||
|
||
def just_large_spread(prod_summary): | ||
"""Filter out products with a spread less than a given percentage.""" | ||
product, summary = prod_summary | ||
return summary.spread / summary.ask_price > percentage | ||
|
||
filter = op.filter("big_spread", stats, just_large_spread) | ||
|
||
orderbook_stats = op.stateful_map("orderbook", inp, process_orderbook) | ||
op.output("out", filter, StdOutSink()) | ||
return flow | ||
|
||
op.output("out", orderbook_stats, StdOutSink()) | ||
percentage = 0.0001 | ||
flow = create_dataflow("coinbase", percentage) | ||
cli_main(flow) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import threading | ||
from bytewax.run import cli_main | ||
from chatbot import graph | ||
from dataflow import create_dataflow | ||
|
||
def run_dataflow(): | ||
flow = create_dataflow() | ||
cli_main(flow) | ||
|
||
def run_chatbot(): | ||
print("Chatbot with real-time data. Type 'quit' to exit.") | ||
while True: | ||
user_input = input("You: ") | ||
if user_input.lower() in ["quit", "exit"]: | ||
print("Goodbye!") | ||
break | ||
state = {"messages": [("user", user_input)]} | ||
events = graph.stream(state) | ||
for event in events: | ||
for key, value in event.items(): | ||
print(f"Assistant: {value['messages'][-1][1]}") | ||
|
||
if __name__ == "__main__": | ||
# Run dataflow in a separate thread | ||
dataflow_thread = threading.Thread(target=run_dataflow) | ||
dataflow_thread.daemon = True | ||
dataflow_thread.start() | ||
|
||
# Run chatbot | ||
run_chatbot() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.