Skip to content

Commit

Permalink
Merge pull request #166 from pipecat-ai/fix-llm-response-wake-check
Browse files Browse the repository at this point in the history
fix llm response wake check
  • Loading branch information
aconchillo authored May 23, 2024
2 parents 91c706a + e130aad commit 5d9a962
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed an issue in `LLMUserResponseAggregator` and `UserResponseAggregator`
that would cause frames after a brief pause to not be pushed to the LLM.

- Clear the audio output buffer if we are interrupted.

- Re-add exponential smoothing after volume calculation. This makes sure the
Expand Down
5 changes: 0 additions & 5 deletions examples/foundational/14-wake-phrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,6 @@ async def on_first_participant_joined(transport, participant):
transport.capture_participant_transcription(participant["id"])
await tts.say("Hi! If you want to talk to me, just say 'Hey Robot'.")

# Kick off the conversation.
# messages.append(
# {"role": "system", "content": "Please introduce yourself to the user."})
# await task.queue_frames([LLMMessagesFrame(messages)])

runner = PipelineRunner()

await runner.run(task)
Expand Down
6 changes: 3 additions & 3 deletions src/pipecat/frames/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class TextFrame(DataFrame):
text: str

def __str__(self):
return f"{self.name}(text: {self.text})"
return f"{self.name}(text: [{self.text}])"


@dataclass
Expand All @@ -132,7 +132,7 @@ class TranscriptionFrame(TextFrame):
timestamp: str

def __str__(self):
return f"{self.name}(user_id: {self.user_id}, text: {self.text}, timestamp: {self.timestamp})"
return f"{self.name}(user_id: {self.user_id}, text: [{self.text}], timestamp: {self.timestamp})"


@dataclass
Expand All @@ -143,7 +143,7 @@ class InterimTranscriptionFrame(TextFrame):
timestamp: str

def __str__(self):
return f"{self.name}(user: {self.user_id}, text: {self.text}, timestamp: {self.timestamp})"
return f"{self.name}(user: {self.user_id}, text: [{self.text}], timestamp: {self.timestamp})"


@dataclass
Expand Down
13 changes: 8 additions & 5 deletions src/pipecat/processors/aggregators/llm_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,14 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
send_aggregation = False

if isinstance(frame, self._start_frame):
self._seen_start_frame = True
self._aggregation = ""
self._aggregating = True
self._seen_start_frame = True
self._seen_end_frame = False
self._seen_interim_results = False
elif isinstance(frame, self._end_frame):
self._seen_end_frame = True
self._seen_start_frame = False

# We might have received the end frame but we might still be
# aggregating (i.e. we have seen interim results but not the final
Expand Down Expand Up @@ -118,10 +122,9 @@ async def _push_aggregation(self):
if len(self._aggregation) > 0:
self._messages.append({"role": self._role, "content": self._aggregation})

# Reset our accumulator state. Reset it before pushing it down,
# otherwise if the task gets cancelled we won't be able to clear
# things up.
self._reset()
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the task gets cancelled we won't be able to clear things up.
self._aggregation = ""

frame = LLMMessagesFrame(self._messages)
await self.push_frame(frame)
Expand Down
12 changes: 7 additions & 5 deletions src/pipecat/processors/aggregators/user_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
send_aggregation = False

if isinstance(frame, self._start_frame):
self._seen_start_frame = True
self._aggregating = True
self._seen_start_frame = True
self._seen_end_frame = False
self._seen_interim_results = False
elif isinstance(frame, self._end_frame):
self._seen_end_frame = True
self._seen_start_frame = False

# We might have received the end frame but we might still be
# aggregating (i.e. we have seen interim results but not the final
Expand Down Expand Up @@ -120,10 +123,9 @@ async def _push_aggregation(self):
if len(self._aggregation) > 0:
frame = TextFrame(self._aggregation.strip())

# Reset our accumulator state. Reset it before pushing it down,
# otherwise if the task gets cancelled we won't be able to clear
# things up.
self._reset()
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the task gets cancelled we won't be able to clear things up.
self._aggregation = ""

await self.push_frame(frame)

Expand Down
4 changes: 2 additions & 2 deletions src/pipecat/processors/filters/wake_check_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, participant_id: str):
self.wake_timer = 0.0
self.accumulator = ""

def __init__(self, wake_phrases: list[str], keepalive_timeout: float = 2):
def __init__(self, wake_phrases: list[str], keepalive_timeout: float = 3):
super().__init__()
self._participant_states = {}
self._keepalive_timeout = keepalive_timeout
Expand All @@ -55,7 +55,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
if p.state == WakeCheckFilter.WakeState.AWAKE:
if time.time() - p.wake_timer < self._keepalive_timeout:
logger.debug(
"Wake phrase keepalive timeout has not expired. Passing frame through.")
f"Wake phrase keepalive timeout has not expired. Pushing {frame}")
p.wake_timer = time.time()
await self.push_frame(frame)
return
Expand Down
2 changes: 1 addition & 1 deletion src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def _process_text_frame(self, frame: TextFrame):
else:
self._current_sentence += frame.text
if self._current_sentence.strip().endswith((".", "?", "!")):
text = self._current_sentence
text = self._current_sentence.strip()
self._current_sentence = ""

if text:
Expand Down
2 changes: 1 addition & 1 deletion src/pipecat/services/elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self._model = model

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Transcribing text: {text}")
logger.debug(f"Transcribing text: [{text}]")

url = f"https://api.elevenlabs.io/v1/text-to-speech/{self._voice_id}/stream"

Expand Down
2 changes: 2 additions & 0 deletions src/pipecat/transports/base_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,12 @@ async def _handle_interruptions(self, frame: Frame):
if self._allow_interruptions:
# Make sure we notify about interruptions quickly out-of-band
if isinstance(frame, UserStartedSpeakingFrame):
logger.debug("User started speaking")
self._push_frame_task.cancel()
self._create_push_task()
await self.push_frame(StartInterruptionFrame())
elif isinstance(frame, UserStoppedSpeakingFrame):
logger.debug("User stopped speaking")
await self.push_frame(StopInterruptionFrame())
await self._internal_push_frame(frame)

Expand Down

0 comments on commit 5d9a962

Please sign in to comment.