Skip to content

Commit

Permalink
Merge pull request #31 from TommasU/parth_testing
Browse files Browse the repository at this point in the history
Basic testing
  • Loading branch information
TommasU authored Nov 3, 2021
2 parents f92b060 + 356c6ae commit 7a9ac6b
Show file tree
Hide file tree
Showing 11 changed files with 234 additions and 9 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
on: push
name: on push
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Install Python 3
uses: actions/setup-python@v1
with:
python-version: 3.6
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run tests with pytest
run: python -m pytest
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ sumy
monkeylearn
punctuator==0.9.6
wget

pytest==6.2.5
24 changes: 24 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from setuptools import setup

setup(
name='Scrivener',
version='1.4',
author='scrivener',
author_email='[email protected]',
license='MIT',
url='https://github.com/TommasU/scrivener/',

install_requires=open('requirements.txt').readlines(),

description='Video transcript summarizer',
long_description="""\
Scrivener is a video transcript summarizer for Youtube videos.
""",
keywords=['python', 'video summarizer', 'youtube', 'transcript'],
classifiers=[
'License :: OSI Approved :: MIT License',
"Programming Language :: Python",
],

packages=["source"],
)
3 changes: 3 additions & 0 deletions source/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
@author: Scrivener
"""
6 changes: 5 additions & 1 deletion source/main/punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
except:
print("punctuator.py not found in: " + os.getcwd())
from punctuator import Punctuator

from pathlib import Path

class Punctuation:
"""
Expand Down Expand Up @@ -39,6 +39,10 @@ def add_punctuation_transcript(transcription):
Using the generated transcript, add punctuation
"""
# initialize the punctuator ML model
# cwd = Path.cwd()
# template = "./source/punct_model_full.pcl"
# file_path = (cwd / template).resolve()
# print(cwd, flush=True)
punct_model = Punctuator(os.path.abspath("source/punct_model_full.pcl"))

# Add punctuation to text
Expand Down
8 changes: 4 additions & 4 deletions source/main/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
"""

# Import Libraries
from main.summarize import Summary
from source.main.summarize import Summary
import speech_recognition as sr
import moviepy.editor as mp
from helper.split_audio import splitwavaudio
from source.helper.split_audio import splitwavaudio
import os
from helper.cleanup import Cleanup
from source.helper.cleanup import Cleanup

from main.punctuation import Punctuation
from source.main.punctuation import Punctuation


class TranscribeVideo:
Expand Down
6 changes: 3 additions & 3 deletions source/main/transcribe_yt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
"""

# Import Libraries
from main.summarize import Summary
from source.main.summarize import Summary
import pytube
import os
from youtube_transcript_api import YouTubeTranscriptApi
from main.transcribe import TranscribeVideo
from main.punctuation import Punctuation
from source.main.transcribe import TranscribeVideo
from source.main.punctuation import Punctuation


class TranscribeYtVideo:
Expand Down
Empty file added test/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import shutil
import wget
import nltk
import os

def pytest_configure(config):
def check_for_full_model():
# Check if ML model files have been combined, if not combine them
# This needs to be done because the full file is greater than 100mb
# and GitHub does not allow files larger than 100mb to be pushed
first_file = os.path.abspath('source/punct_model_part1.pcl')
second_file = os.path.abspath('source/punct_model_part2.pcl')
third_file = os.path.abspath('source/punct_model_part3.pcl')
new_file = os.path.abspath('source/punct_model_full.pcl')

if not os.path.exists(new_file):
print("Creating punct_model_full.pcl file for ML model...")

# Storing these models in github causes an issue with the Heroku deployment and exceeds 500 MB (it is 618 MB)
# slug/payload limit. Therefore, using this alternative to get it from Github during runtime.
if not os.path.exists(first_file):
print("Downloading punct_model_part1.pcl file for ML model...")
url1 = 'https://github.com/SN-18/scrivener/raw/developer/source/punct_model_part1.pcl'
filename = wget.download(url1, out='source/punct_model_part1.pcl')
print("\nDownloaded file: " + filename)

if not os.path.exists(second_file):
print("Downloading punct_model_part2.pcl file for ML model...")
url2 = 'https://github.com/SN-18/scrivener/raw/developer/source/punct_model_part2.pcl'
filename = wget.download(url2, out='source/punct_model_part2.pcl')
print("\nDownloaded file: " + filename)

if not os.path.exists(third_file):
print("Downloading punct_model_part3.pcl file for ML model...")
url3 = 'https://github.com/SN-18/scrivener/raw/developer/source/punct_model_part3.pcl'
filename = wget.download(url3, out='source/punct_model_part3.pcl')
print("\nDownloaded file: " + filename)

with open(new_file, "wb") as wfd:
for f in [first_file, second_file, third_file]:
with open(f, "rb") as fd:
shutil.copyfileobj(fd, wfd, 1024 * 1024 * 10)

check_for_full_model()

nltk.download('punkt')
37 changes: 37 additions & 0 deletions test/test_punctuation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
from source.main.punctuation import Punctuation
import string
import re

# Khan Academy's video on agriculture
# contains transcripts
# contains closed captions
video_one_url = "https://www.youtube.com/watch?v=JvBHwVpBCwM"

# Lecture 38 — Bloom Filters | Mining of Massive Datasets | Stanford University
# contains no transcripts
video_two_url = "https://www.youtube.com/watch?v=qBTdukbzc78"

test_sentence = """
A computer is a machine that can be programmed to carry
out sequences of arithmetic or logical operations automatically.
Modern computers can perform generic sets of operations known
as programs. These programs enable computers to perform a wide
range of tasks. A computer system is a "complete" computer that
includes the hardware, operating system (main software), and
peripheral equipment needed and used for "full" operation. This
term may also refer to a group of computers that are linked and
function together, such as a computer network or computer cluster.
"""

@pytest.mark.parametrize('test_sentence', [
(test_sentence),
])
def test_add_punctuation_transcript(test_sentence):
# Removing punctuations to create a test example
translation = {punct: "" for punct in string.punctuation}
sentence_no_punct = test_sentence.translate(str.maketrans(translation))
sentence_no_punct = sentence_no_punct.lower().strip()

punct_text = Punctuation.add_punctuation_transcript(sentence_no_punct)
assert len(re.findall("["+string.punctuation+"]", punct_text)) > 0
94 changes: 94 additions & 0 deletions test/test_transcribe_yt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import pytest

from youtube_transcript_api._transcripts import Transcript
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import TranscriptsDisabled
from urllib.parse import urlparse
from urllib.parse import parse_qs
import nltk

from source.main.transcribe_yt import TranscribeYtVideo
from source.main.punctuation import Punctuation

# Khan Academy's video on agriculture
# contains transcripts
# contains closed captions
video_one_url = "https://www.youtube.com/watch?v=JvBHwVpBCwM"

# Lecture 38 — Bloom Filters | Mining of Massive Datasets | Stanford University
# contains no transcripts
video_two_url = "https://www.youtube.com/watch?v=qBTdukbzc78"

@pytest.mark.parametrize('youtube_url, test_index', [
(video_one_url, 1),
(video_two_url, 2)
])
def test_check_yt_cc(youtube_url, test_index):
transc_obj = TranscribeYtVideo(youtube_url)
transcript = transc_obj.check_yt_cc()
if test_index == 1:
# Check for type
assert isinstance(transcript, Transcript)
# Check for language correctness
assert transcript.language == "English - Default"

# Check if video ID is the same
parsed = urlparse(youtube_url)
assert transcript.video_id == parse_qs(parsed.query)["v"][0]
elif test_index == 2:
assert transcript == None

def fetch_transcript(yt_id):
full_text = YouTubeTranscriptApi.get_transcript(yt_id)
transcript_text = str()
for rec in full_text:
transcript_text += " " + rec["text"]
punctuated_transcription = Punctuation.add_punctuation_transcript(
transcript_text
)
return punctuated_transcription

@pytest.mark.parametrize('youtube_url, test_index', [
(video_one_url, 1),
(video_two_url, 2)
])
def test_transcribe_yt_video_w_cc(youtube_url, test_index):
transc_obj = TranscribeYtVideo(youtube_url)
assert transc_obj.summary == ""
if test_index == 1:
N = 10
transc_obj.transcribe_yt_video_w_cc()
# checks if the summary is still a string type
assert isinstance(transc_obj.summary, str)
# checks if summary has atleast N characters
assert len(transc_obj.summary) > N

# calculates the BLEU score
full_text = fetch_transcript(transc_obj.yt_id)
full_text_tokens = nltk.tokenize.word_tokenize(full_text)
summary_tokens = nltk.tokenize.word_tokenize(transc_obj.summary)
score = nltk.translate.bleu_score.sentence_bleu(
[full_text_tokens],
summary_tokens,
weights=(1,)
)
# print(f"BLEU score for {youtube_url} is {score}", flush=True)
assert score > 0
elif test_index == 2:
try:
transc_obj.transcribe_yt_video_w_cc()
except Exception as e:
assert isinstance(e, TranscriptsDisabled)

@pytest.mark.parametrize('youtube_url, test_index', [
(video_one_url, 1),
(video_two_url, 2)
])
def test_transcribe_yt_video(youtube_url, test_index):
transc_obj = TranscribeYtVideo(youtube_url)
assert transc_obj.summary == str()

transc_obj.transcribe_yt_video()
N = 10
assert transc_obj.summary != str()
assert len(transc_obj.summary) > N

0 comments on commit 7a9ac6b

Please sign in to comment.