Merge pull request #175 from NotJoeMartinez/add_summarize_command
Add summarize command
NotJoeMartinez authored Sep 6, 2024
2 parents 0b715f8 + 2d0fbf5 commit 7f9f792
Showing 4 changed files with 331 additions and 11 deletions.
8 changes: 5 additions & 3 deletions yt_fts/db_utils.py
@@ -87,9 +87,10 @@ def add_video(channel_id, video_id, video_title, video_url, video_date):
(video_id,)).fetchone()

if existing_video is None:
-        cur.execute(
-            "INSERT INTO Videos (video_id, video_title, video_url, video_date, channel_id) VALUES (?, ?, ?, ?, ?)",
-            (video_id, video_title, video_url, video_date, channel_id))
+        cur.execute("""
+            INSERT INTO Videos (video_id, video_title, video_url, video_date, channel_id)
+            VALUES (?, ?, ?, ?, ?)
+            """,(video_id, video_title, video_url, video_date, channel_id))
conn.commit()

else:
@@ -358,6 +359,7 @@ def delete_channel_from_chroma(channel_id):
where={"channel_id": channel_id}
)


def get_channel_id_from_rowid(rowid):
db = Database(get_db_path())

26 changes: 22 additions & 4 deletions yt_fts/search.py
@@ -1,6 +1,7 @@
import sys

from rich.console import Console
import textwrap
from .config import get_chroma_client
from .utils import time_to_secs, bold_query_matches
from .get_embeddings import EmbeddingsHandler
@@ -36,6 +37,7 @@ def __init__(self,
self.query = ''
self.response = []
self.openai_client = openai_client
self.max_width = 80

def full_text_search(self, query):

@@ -159,6 +161,7 @@ def print_fts_res(self):
metadata = quote["metadata"]
video_name = metadata["video_title"]
video_date = metadata["video_date"]
video_id = quote["video_id"]
quote_data = {
"quote": quote["subs"],
"time_stamp": quote["time_stamp"],
@@ -167,8 +170,8 @@
if channel_name not in fts_dict:
fts_dict[channel_name] = {}
-            if (video_name, video_date) not in fts_dict[channel_name]:
-                fts_dict[channel_name][(video_name, video_date)] = []
-            fts_dict[channel_name][(video_name, video_date)].append(quote_data)
+            if (video_name, video_date, video_id) not in fts_dict[channel_name]:
+                fts_dict[channel_name][(video_name, video_date, video_id)] = []
+            fts_dict[channel_name][(video_name, video_date, video_id)].append(quote_data)

# Sort the list by the total number of quotes in each channel
channel_list = list(fts_dict.items())
@@ -182,8 +185,8 @@
video_list = list(videos.items())
video_list.sort(key=lambda x: len(x[1]))

-            for (video_name, video_date), quotes in video_list:
-                console.print(f" [bold][blue]{video_name}[/blue][/bold] ({video_date})")
+            for (video_name, video_date, video_id), quotes in video_list:
+                console.print(f"{video_id} ({video_date}) \"[bold][blue]{video_name}[/blue][/bold]\"")
console.print("")

# Sort the quotes by timestamp
@@ -249,4 +252,19 @@ def print_vector_search_results(self):

console.print(summary_str)

def wrap_text(self, text: str) -> str:
lines = text.split('\n')
wrapped_lines = []

for line in lines:
# If the line is a code block, don't wrap it
if line.strip().startswith('```') or line.strip().startswith('`'):
wrapped_lines.append(line)
else:
# Wrap the line
wrapped = textwrap.wrap(line, width=self.max_width, break_long_words=False, replace_whitespace=False)
wrapped_lines.extend(wrapped)

# Join the wrapped lines back together
return " \n".join(wrapped_lines)

269 changes: 269 additions & 0 deletions yt_fts/summarize.py
@@ -0,0 +1,269 @@
import os
import json
import sys
import sqlite3
import tempfile
import textwrap

import yt_dlp
from rich.console import Console
from rich.markdown import Markdown
from urllib.parse import urlparse, parse_qs

from .config import get_db_path
from .utils import parse_vtt
from .db_utils import get_title_from_db, get_channel_name_from_video_id

class SummarizeHandler:
def __init__(self, openai_client, model, input_video):

self.console = Console()
self.model = model
self.openai_client = openai_client
self.input_video = input_video
self.max_width = 80

self.video_title = ''
self.channel_name = ''

if "https" in input_video:
self.video_id = self.get_video_id_from_url(input_video)
else:
self.video_id = input_video

if not self.video_in_database(self.video_id):
self.transcript_text = self.download_transcript()
else:
self.video_title = get_title_from_db(self.video_id)
self.channel_name = get_channel_name_from_video_id(self.video_id)
self.transcript_text = self.get_transcript_from_database(self.video_id)

def summarize_video(self):
console = self.console
video_id = self.video_id


system_prompt = f"""
Summarize the transcript of the YouTube video given below.
- Provide valid youtube timestamped urls for key points in the video
using the format: [timestamp](https://youtu.be/{video_id}?t=[seconds])
Video Title: {self.video_title}
Channel Name: {self.channel_name}
Transcript:
"""

messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": self.transcript_text},
]


with console.status("[green]Generating summary..."):
summary_text = self.get_completion(messages)
md = Markdown(summary_text)
console.print("")
console.print(md)


def get_completion(self, messages: list) -> str:
console = self.console
try:
response = self.openai_client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.5,
max_tokens=2000,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
stop=None,
)

response_text = response.choices[0].message.content
wrapped_text = self.wrap_text(response_text)
return wrapped_text

except Exception as e:
console.print(f"[red]Error:[/red] {e}")
sys.exit(1)

def download_transcript(self):
console = self.console
video_id = self.video_id
video_url = f"https://www.youtube.com/watch?v={video_id}"

try:
console.print(f"Downloading subtitles for: {video_url}")
with tempfile.TemporaryDirectory() as tmp_dir:
ydl_opts = {
'outtmpl': f'{tmp_dir}/%(id)s',
'writeinfojson': True,
'writeautomaticsub': True,
'subtitlesformat': 'vtt',
'skip_download': True,
'subtitleslangs': ['en', '-live_chat'],
'quiet': True,
'no_warnings': True,
                    'progress_hooks': [self.quiet_progress_hook],
}

# if self.cookies_from_browser is not None:
# ydl_opts['cookiesfrombrowser'] = (self.cookies_from_browser,)

with yt_dlp.YoutubeDL(ydl_opts) as ydl:
ydl.download([video_url])


items = os.listdir(tmp_dir)
vtt_files = [os.path.join(tmp_dir, item) for item in items if item.endswith('.vtt')]
json_files = [os.path.join(tmp_dir, item) for item in items if item.endswith('.info.json')]

if len(vtt_files) == 0:
console.print("[red]Error:[/red] "
"Failed to download subtitles.")
sys.exit(1)

try:
with open(json_files[0], 'r') as f:
data = json.load(f)
title = data['title']
channel = data['uploader']
self.video_title = title
self.channel_name = channel
except Exception as e:
console.print(f"[yellow]Warning:[/yellow] {e}")
pass


vtt_file_path = vtt_files[0]
vtt_json = parse_vtt(vtt_file_path)
transcript = ""
for subtitle in vtt_json:
start_time = subtitle['start_time'][:-4]
text = subtitle['text'].strip()
if len(text) == 0:
continue
transcript += f"{start_time}: {text}\n"

return transcript

except Exception as e:
console.print(f"Failed to get: {video_id}\n{e}")
sys.exit(1)


def get_transcript_from_database(self, video_id) -> str:

console = self.console
try:
conn = sqlite3.connect(get_db_path())
curr = conn.cursor()
curr.execute(
"""
SELECT
start_time, text
FROM
Subtitles
WHERE
video_id = ?
""", (video_id,)
)
res = curr.fetchall()
transcript = ""
for row in res:
start_time, text = row
text = text.strip()
if len(text) == 0:
continue
transcript += f"{start_time[:-4]}: {text}\n"
conn.close()
return transcript
except Exception as e:
console.print(f"[red]Error:[/red] {e}")
sys.exit(1)
finally:
conn.close()

def video_in_database(self, video_id) -> bool:
console = self.console
try:
conn = sqlite3.connect(get_db_path())
curr = conn.cursor()
curr.execute(
"""
SELECT
count(*)
FROM
Videos
WHERE
video_id = ?
""", (video_id,)
)
count = curr.fetchone()[0]
conn.close()
if count > 0:
return True
return False
except Exception as e:
console.print(f"[red]Error:[/red] {e}")
sys.exit(1)
finally:
conn.close()


def get_video_id_from_url(self, video_url):
console = self.console
video_url = video_url.strip('/')
parsed = urlparse(video_url)
domain = parsed.netloc
path = parsed.path.split('/')
query = parse_qs(parsed.query)

valid_domains = ["youtube.com", "youtu.be", "www.youtube.com"]

if domain not in valid_domains:
console.print("[red]Error:[/red] "
f"Invalid URL, domain \"{domain}\" not supported.")
sys.exit(1)


if domain in ["youtube.com", "www.youtube.com"] and "watch" in path:
video_id = query.get('v', [None])[0]
elif domain == "youtu.be":
video_id = path[-1]
else:
console.print("[red]Error:[/red] "
"Invalid URL, please provide a valid YouTube video URL.")
sys.exit(1)

if video_id:
return video_id

console.print("[red]Error:[/red] "
"Invalid URL, please provide a valid YouTube video URL.")
sys.exit(1)


def quiet_progress_hook(self, d):
console = self.console
if d['status'] == 'finished':
console.print(f" -> \"{d['filename']}\"")

def wrap_text(self, text: str) -> str:
lines = text.split('\n')
wrapped_lines = []

for line in lines:
# If the line is a code block, don't wrap it
if line.strip().startswith('```') or line.strip().startswith('`'):
wrapped_lines.append(line)
else:
# Wrap the line
wrapped = textwrap.wrap(line, width=self.max_width, break_long_words=False, replace_whitespace=False)
wrapped_lines.extend(wrapped)

# Join the wrapped lines back together
return " \n".join(wrapped_lines)
