From 83292dc8f32bc30a003d0e71362ad12733f66473 Mon Sep 17 00:00:00 2001 From: Ed Summers Date: Sat, 13 Apr 2024 16:03:21 -0400 Subject: [PATCH] added manifest --- run | 28 ++++++++++++++++++---------- transcribe/aws.py | 4 ++-- transcribe/google.py | 4 ++-- transcribe/utils.py | 5 ++--- transcribe/whisper.py | 8 ++++---- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/run b/run index fa43cec..95c1177 100755 --- a/run +++ b/run @@ -4,6 +4,7 @@ import argparse import datetime import logging import os +import sys from transcribe import aws, google, whisper @@ -11,7 +12,8 @@ parser = argparse.ArgumentParser( prog="run", description="Run transcription generation for sample data" ) -parser.add_argument("--output_dir", help="Path to a directory to write results") +parser.add_argument("--output-dir", help="Path to a directory to write results") +parser.add_argument("--manifest", default="data.csv", help="Path to data manifest CSV") parser.add_argument( "--only", choices=["whisper", "preprocessing", "aws", "google"], @@ -27,6 +29,10 @@ if output_dir is None: if not os.path.isdir(output_dir): os.makedirs(output_dir) +# ensure manifest CSV exists +if not os.path.isfile(args.manifest): + sys.exit(f"manifest file {args.manifest} doesn't exist") + logging.basicConfig( filename=os.path.join(output_dir, "transcribe.log"), filemode="a", @@ -35,18 +41,20 @@ logging.basicConfig( level=logging.INFO, ) - # run one of the transcription types individually or run them all if args.only == "whisper": - whisper.run(output_dir) + whisper.run(output_dir, args.manifest) elif args.only == "preprocessing": - whisper.run_preprocessing(output_dir) + whisper.run_preprocessing(output_dir, args.manifest) elif args.only == "aws": - aws.run(output_dir) + aws.run(output_dir, args.manifest) elif args.only == "google": - google.run(output_dir) + google.run(output_dir, args.manifest) else: - whisper.run(output_dir) - whisper.run_preprocessing(output_dir) - aws.run(output_dir) - google.run(output_dir) + whisper.run(output_dir, args.manifest) + print() + whisper.run_preprocessing(output_dir, args.manifest) + print() + aws.run(output_dir, args.manifest) + print() + google.run(output_dir, args.manifest) diff --git a/transcribe/aws.py b/transcribe/aws.py index 215b371..67ec0a9 100644 --- a/transcribe/aws.py +++ b/transcribe/aws.py @@ -18,9 +18,9 @@ dotenv.load_dotenv() -def run(output_dir): +def run(output_dir, manifest): results = [] - for file_metadata in tqdm.tqdm(utils.get_data_files(), desc="aws".ljust(10)): + for file_metadata in tqdm.tqdm(utils.get_data_files(manifest), desc="aws".ljust(10)): file_metadata["run_count"] = len(results) + 1 file = file_metadata["media_filename"] diff --git a/transcribe/google.py b/transcribe/google.py index 7f763a1..6e18505 100644 --- a/transcribe/google.py +++ b/transcribe/google.py @@ -13,9 +13,9 @@ from . import utils -def run(output_dir): +def run(output_dir, manifest): results = [] - for file_metadata in tqdm.tqdm(utils.get_data_files(), desc="google".ljust(10)): + for file_metadata in tqdm.tqdm(utils.get_data_files(manifest), desc="google".ljust(10)): file_metadata["run_count"] = len(results) + 1 file = file_metadata["media_filename"] diff --git a/transcribe/utils.py b/transcribe/utils.py index cd2b364..1d5df79 100644 --- a/transcribe/utils.py +++ b/transcribe/utils.py @@ -28,10 +28,9 @@ ] -def get_data_files(): +def get_data_files(manifest): rows = [] - data_csv = Path(__file__).parent.parent / "data.csv" - for row in csv.DictReader(open(data_csv)): + for row in csv.DictReader(open(manifest)): rows.append(row) return rows diff --git a/transcribe/whisper.py b/transcribe/whisper.py index 1e2d5a5..41d2414 100644 --- a/transcribe/whisper.py +++ b/transcribe/whisper.py @@ -42,9 +42,9 @@ ] -def run(output_dir): +def run(output_dir, manifest): combinations = list(whisper_option_combinations()) - files = utils.get_data_files() + files = utils.get_data_files(manifest) total = len(combinations) * len(files) progress = tqdm.tqdm(total=total, desc="whisper".ljust(10)) @@ -60,9 +60,9 @@ def run(output_dir): utils.write_report(results, csv_filename, extra_cols=["options"]) -def run_preprocessing(output_dir): +def run_preprocessing(output_dir, manifest): results = [] - files = utils.get_data_files() + files = utils.get_data_files(manifest) total = len(files) * len(preprocessing_combinations) progress = tqdm.tqdm(total=total, desc="preprocess".ljust(10))