Skip to content

Commit

Permalink
Integration tests on formatter (#49)
Browse files Browse the repository at this point in the history
* wip

* Add Yann Lecun IIT Madras lecture

* WIP

* Update

* Add uv.lock change

* Fix for linting

* Fix run pytest
  • Loading branch information
shun-liang authored Nov 2, 2024
1 parent 75587e8 commit f1e5a90
Show file tree
Hide file tree
Showing 9 changed files with 363 additions and 327 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Type check and lint
name: Type check, lint, and integration tests

on: push

Expand All @@ -12,7 +12,11 @@ jobs:
run: uv python install 3.12
- name: Install dependencies
run: uv sync
- name: Run mypy
- name: Run mypy on src
run: uv run mypy --strict src
- name: Run mypy on tests
run: uv run mypy --strict tests
- name: Run ruff check
run: uv run ruff check
run: uv run ruff check
- name: Run pytest
run: uv run python -m pytest tests
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ source = "vcs" # Use version control (e.g., git) for versioning
dev-dependencies = [
"ipython>=8.27.0",
"mypy>=1.11.2",
"pytest>=8.3.3",
"ruff>=0.6.3",
"types-tqdm>=4.66.0.20240417",
]
Expand Down
4 changes: 3 additions & 1 deletion src/yt2doc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def main(
ignore_source_chapters: typing.Annotated[
bool,
typer.Option(
"--ignore-source-chapters", "--ignore-chapters", help="Ignore original chapters from the source"
"--ignore-source-chapters",
"--ignore-chapters",
help="Ignore original chapters from the source",
),
] = False,
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions src/yt2doc/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from yt2doc.extraction.extractor import Extractor
from yt2doc.formatting.formatter import MarkdownFormatter
from yt2doc.formatting.llm_topic_segmenter import LLMTopicSegmenter
from yt2doc.formatting.llm_adapter import LLMAdapter
from yt2doc.yt2doc import Yt2Doc


Expand Down Expand Up @@ -54,8 +55,8 @@ def get_yt2doc(
),
mode=instructor.Mode.JSON,
)

llm_topic_segmenter = LLMTopicSegmenter(llm_client=llm_client, model=llm_model)
llm_adapter = LLMAdapter(llm_client=llm_client, llm_model=llm_model)
llm_topic_segmenter = LLMTopicSegmenter(llm_adapter=llm_adapter)
formatter = MarkdownFormatter(sat=sat, topic_segmenter=llm_topic_segmenter)
else:
formatter = MarkdownFormatter(sat=sat)
Expand Down
10 changes: 10 additions & 0 deletions src/yt2doc/formatting/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ class FormattedPlaylist(BaseModel):
transcripts: typing.Sequence[FormattedTranscript]


class ILLMAdapter(typing.Protocol):
def get_topic_changing_paragraph_indexes(
self, paragraphs: typing.List[typing.List[str]]
) -> typing.List[int]: ...

def generate_title_for_paragraphs(
self, paragraphs: typing.List[typing.List[str]]
) -> str: ...


class ITopicSegmenter(typing.Protocol):
def segment(
self, paragraphs: typing.List[typing.List[str]]
Expand Down
100 changes: 100 additions & 0 deletions src/yt2doc/formatting/llm_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import typing

from instructor import Instructor
from pydantic import BaseModel, AfterValidator


class LLMAdapter:
def __init__(self, llm_client: Instructor, llm_model: str) -> None:
self.llm_client = llm_client
self.llm_model = llm_model

def get_topic_changing_paragraph_indexes(
self, paragraphs: typing.List[typing.List[str]]
) -> typing.List[int]:
def validate_paragraph_indexes(v: typing.List[int]) -> typing.List[int]:
n = len(paragraphs)
unique_values = set(v)
if len(unique_values) != len(v):
raise ValueError("All elements must be unique")
for i in v:
if i <= 0:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is less than or equal to 0"
)
if i >= n:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is greater or equal to {n}"
)

return v

paragraph_texts = ["\n\n".join(p) for p in paragraphs]

class Result(BaseModel):
paragraph_indexes: typing.Annotated[
typing.List[int], AfterValidator(validate_paragraph_indexes)
]

result = self.llm_client.chat.completions.create(
model=self.llm_model,
response_model=Result,
messages=[
{
"role": "system",
"content": """
You are a smart assistant who reads paragraphs of text from an audio transcript and
find the paragraphs that significantly change topic from the previous paragraph.
Make sure only mark paragraphs that talks about a VERY DIFFERENT topic from the previous one.
The response should be an array of the index number of such paragraphs, such as `[1, 3, 5]`
If there is no paragraph that changes topic, then return an empty list.
""",
},
{
"role": "user",
"content": """
{% for paragraph in paragraphs %}
<paragraph {{ loop.index0 }}>
{{ paragraph }}
</ paragraph {{ loop.index0 }}>
{% endfor %}
""",
},
],
context={
"paragraphs": paragraph_texts,
},
)
return result.paragraph_indexes

def generate_title_for_paragraphs(
self, paragraphs: typing.List[typing.List[str]]
) -> str:
text = "\n\n".join(["".join(p) for p in paragraphs])
title = self.llm_client.chat.completions.create(
model=self.llm_model,
response_model=str,
messages=[
{
"role": "system",
"content": """
Please generate a short title for the following text.
Be VERY SUCCINCT. No more than 6 words.
""",
},
{
"role": "user",
"content": """
{{ text }}
""",
},
],
context={
"text": text,
},
)
return title
99 changes: 11 additions & 88 deletions src/yt2doc/formatting/llm_topic_segmenter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import typing
import logging

import instructor

from pydantic import BaseModel, AfterValidator
from tqdm import tqdm

from yt2doc.formatting import interfaces
Expand All @@ -12,37 +9,14 @@


class LLMTopicSegmenter:
def __init__(self, llm_client: instructor.Instructor, model: str) -> None:
self.llm_client = llm_client
self.model = model
def __init__(self, llm_adapter: interfaces.ILLMAdapter) -> None:
self.llm_adapter = llm_adapter

def _get_title_for_chapter(self, paragraphs: typing.List[typing.List[str]]) -> str:
truncated_paragraphs = [p[:10] for p in paragraphs]
truncated_text = "\n\n".join(["".join(p) for p in truncated_paragraphs])
title = self.llm_client.chat.completions.create(
model=self.model,
response_model=str,
messages=[
{
"role": "system",
"content": """
Please generate a short title for the following text.
Be VERY SUCCINCT. No more than 6 words.
""",
},
{
"role": "user",
"content": """
{{ text }}
""",
},
],
context={
"text": truncated_text,
},
return self.llm_adapter.generate_title_for_paragraphs(
paragraphs=truncated_paragraphs
)
return title

def segment(
self, paragraphs: typing.List[typing.List[str]]
Expand All @@ -60,68 +34,17 @@ def segment(
grouped_paragraphs_with_overlap, desc="Finding topic change points"
):
truncate_sentence_index = 6
truncated_grouped_paragraph_texts = [
"".join(paragraph[:truncate_sentence_index])
for paragraph in grouped_paragraphs
truncated_group_paragraphs = [
paragraph[:truncate_sentence_index] for paragraph in grouped_paragraphs
]

def validate_paragraph_indexes(v: typing.List[int]) -> typing.List[int]:
n = len(truncated_grouped_paragraph_texts)
unique_values = set(v)
if len(unique_values) != len(v):
raise ValueError("All elements must be unique")
for i in v:
if i <= 0:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is less than or equal to 0"
)
if i >= n:
raise ValueError(
f"All elements must be greater than 0 and less than {n}. Paragraph index {i} is greater or equal to {n}"
)

return v

class Result(BaseModel):
paragraph_indexes: typing.Annotated[
typing.List[int], AfterValidator(validate_paragraph_indexes)
]

result = self.llm_client.chat.completions.create(
model=self.model,
response_model=Result,
messages=[
{
"role": "system",
"content": """
You are a smart assistant who reads paragraphs of text from an audio transcript and
find the paragraphs that significantly change topic from the previous paragraph.
Make sure only mark paragraphs that talks about a VERY DIFFERENT topic from the previous one.
The response should be an array of the index number of such paragraphs, such as `[1, 3, 5]`
If there is no paragraph that changes topic, then return an empty list.
""",
},
{
"role": "user",
"content": """
{% for paragraph in paragraphs %}
<paragraph {{ loop.index0 }}>
{{ paragraph }}
</ paragraph {{ loop.index0 }}>
{% endfor %}
""",
},
],
context={
"paragraphs": truncated_grouped_paragraph_texts,
},
paragraph_indexes = self.llm_adapter.get_topic_changing_paragraph_indexes(
paragraphs=truncated_group_paragraphs
)
logger.info(f"paragraph indexes from LLM: {result}")

logger.info(f"paragraph indexes from LLM: {paragraph_indexes}")
aligned_indexes = [
start_index + index for index in sorted(result.paragraph_indexes)
start_index + index for index in sorted(paragraph_indexes)
]
topic_changed_indexes += aligned_indexes

Expand Down
Loading

0 comments on commit f1e5a90

Please sign in to comment.