Skip to content

Commit

Permalink
added manifest
Browse files Browse the repository at this point in the history
  • Loading branch information
edsu committed Apr 13, 2024
1 parent 5a1fd1b commit 83292dc
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 21 deletions.
28 changes: 18 additions & 10 deletions run
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ import argparse
import datetime
import logging
import os
import sys

from transcribe import aws, google, whisper

parser = argparse.ArgumentParser(
prog="run", description="Run transcription generation for sample data"
)

parser.add_argument("--output_dir", help="Path to a directory to write results")
parser.add_argument("--output-dir", help="Path to a directory to write results")
parser.add_argument("--manifest", default="data.csv", help="Path to data manifest CSV")
parser.add_argument(
"--only",
choices=["whisper", "preprocessing", "aws", "google"],
Expand All @@ -27,6 +29,10 @@ if output_dir is None:
if not os.path.isdir(output_dir):
os.makedirs(output_dir)

# ensure manifest CSV exists
if not os.path.isfile(args.manifest):
sys.exit(f"manifest file {args.manifest} doesn't exist")

logging.basicConfig(
filename=os.path.join(output_dir, "transcribe.log"),
filemode="a",
Expand All @@ -35,18 +41,20 @@ logging.basicConfig(
level=logging.INFO,
)


# run one of the transcription types individually or run them all
if args.only == "whisper":
whisper.run(output_dir)
whisper.run(output_dir, args.manifest)
elif args.only == "preprocessing":
whisper.run_preprocessing(output_dir)
whisper.run_preprocessing(output_dir, args.manifest)
elif args.only == "aws":
aws.run(output_dir)
aws.run(output_dir, args.manifest)
elif args.only == "google":
google.run(output_dir)
google.run(output_dir, args.manifest)
else:
whisper.run(output_dir)
whisper.run_preprocessing(output_dir)
aws.run(output_dir)
google.run(output_dir)
whisper.run(output_dir, args.manifest)
print()
whisper.run_preprocessing(output_dir, args.manifest)
print()
aws.run(output_dir, args.manifest)
print()
google.run(output_dir, args.manifest)
4 changes: 2 additions & 2 deletions transcribe/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
dotenv.load_dotenv()


def run(output_dir):
def run(output_dir, manifest):
results = []
for file_metadata in tqdm.tqdm(utils.get_data_files(), desc="aws".ljust(10)):
for file_metadata in tqdm.tqdm(utils.get_data_files(manifest), desc="aws".ljust(10)):
file_metadata["run_count"] = len(results) + 1
file = file_metadata["media_filename"]

Expand Down
4 changes: 2 additions & 2 deletions transcribe/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from . import utils


def run(output_dir):
def run(output_dir, manifest):
results = []
for file_metadata in tqdm.tqdm(utils.get_data_files(), desc="google".ljust(10)):
for file_metadata in tqdm.tqdm(utils.get_data_files(manifest), desc="google".ljust(10)):
file_metadata["run_count"] = len(results) + 1
file = file_metadata["media_filename"]

Expand Down
5 changes: 2 additions & 3 deletions transcribe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
]


def get_data_files():
def get_data_files(manifest):
rows = []
data_csv = Path(__file__).parent.parent / "data.csv"
for row in csv.DictReader(open(data_csv)):
for row in csv.DictReader(open(manifest)):
rows.append(row)
return rows

Expand Down
8 changes: 4 additions & 4 deletions transcribe/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
]


def run(output_dir):
def run(output_dir, manifest):
combinations = list(whisper_option_combinations())
files = utils.get_data_files()
files = utils.get_data_files(manifest)
total = len(combinations) * len(files)
progress = tqdm.tqdm(total=total, desc="whisper".ljust(10))

Expand All @@ -60,9 +60,9 @@ def run(output_dir):
utils.write_report(results, csv_filename, extra_cols=["options"])


def run_preprocessing(output_dir):
def run_preprocessing(output_dir, manifest):
results = []
files = utils.get_data_files()
files = utils.get_data_files(manifest)
total = len(files) * len(preprocessing_combinations)
progress = tqdm.tqdm(total=total, desc="preprocess".ljust(10))

Expand Down

0 comments on commit 83292dc

Please sign in to comment.