Skip to content

Commit

Permalink
fixed formatting on summaries
Browse files (browse the repository at this point in the history)
  • Loading branch information
NotJoeMartinez committed Sep 6, 2024
1 parent bb04927 commit f0eedc1
Showing 1 changed file with 39 additions and 12 deletions.
51 changes: 39 additions & 12 deletions yt_fts/summarize.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
import os
import json
import sys
import sqlite3
import tempfile
Expand All @@ -11,6 +12,7 @@

from .config import get_db_path
from .utils import parse_vtt
from .db_utils import get_title_from_db, get_channel_name_from_video_id

class SummarizeHandler:
def __init__(self, openai_client, model, input_video):
Expand All @@ -21,37 +23,47 @@ def __init__(self, openai_client, model, input_video):
self.input_video = input_video
self.max_width = 80

self.video_title = ''
self.channel_name = ''

if "https" in input_video:
self.video_id = self.get_video_id_from_url(input_video)
else:
self.video_id = input_video

if not self.video_in_database(self.video_id):
self.transcript_text = self.download_transcript()
else:
self.video_title = get_title_from_db(self.video_id)
self.channel_name = get_channel_name_from_video_id(self.video_id)
self.transcript_text = self.get_transcript_from_database(self.video_id)

def summarize_video(self):
console = self.console
video_id = self.video_id

if self.video_in_database(video_id):
transcript_text = self.get_transcript_from_database(video_id)
else:
transcript_text = self.download_transcript()

system_prompt = f"""
Summarize the transcript of the YouTube video given below.
Provide valid youtube timestamped urls for key points
in the video using the format: https://youtu.be/{video_id}?t=[seconds]
- Provide valid youtube timestamped urls for key points in the video
using the format: [timestamp](https://youtu.be/{video_id}?t=[seconds])
Video Title: {self.video_title}
Channel Name: {self.channel_name}
Transcript:
"""

messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": transcript_text},
{"role": "user", "content": self.transcript_text},
]


with console.status("[green]Generating summary..."):
summary_text = self.get_completion(messages)
md = Markdown(summary_text)
console.print("")
console.print(md)


Expand Down Expand Up @@ -93,7 +105,8 @@ def download_transcript(self):
'skip_download': True,
'subtitleslangs': ['en', '-live_chat'],
'quiet': True,
'progress_hooks': [self.quiet_progress_hook],
'no_warnings': True,
'progress_hook': [self.quiet_progress_hook],
}

# if self.cookies_from_browser is not None:
Expand All @@ -104,15 +117,28 @@ def download_transcript(self):


items = os.listdir(tmp_dir)
file_paths = [os.path.join(tmp_dir, item) for item in items if item.endswith('.vtt')]
vtt_files = [os.path.join(tmp_dir, item) for item in items if item.endswith('.vtt')]
json_files = [os.path.join(tmp_dir, item) for item in items if item.endswith('.info.json')]

if len(file_paths) == 0:
if len(vtt_files) == 0:
console.print("[red]Error:[/red] "
"Failed to download subtitles.")
sys.exit(1)

try:
with open(json_files[0], 'r') as f:
data = json.load(f)
title = data['title']
channel = data['uploader']
self.video_title = title
self.channel_name = channel
except Exception as e:
console.print(f"[yellow]Warning:[/yellow] {e}")
pass


file_path = file_paths[0]
vtt_json = parse_vtt(file_path)
vtt_file_path = vtt_files[0]
vtt_json = parse_vtt(vtt_file_path)
transcript = ""
for subtitle in vtt_json:
start_time = subtitle['start_time'][:-4]
Expand All @@ -127,6 +153,7 @@ def download_transcript(self):
console.print(f"Failed to get: {video_id}\n{e}")
sys.exit(1)


def get_transcript_from_database(self, video_id) -> str:

console = self.console
Expand Down

0 comments on commit f0eedc1

Please sign in to comment.