diff --git a/dejavu/__init__.py b/dejavu/__init__.py
index fac72bc5..1be47f0f 100755
--- a/dejavu/__init__.py
+++ b/dejavu/__init__.py
@@ -1,7 +1,7 @@
-import multiprocessing
 import os
 import sys
 import traceback
+from concurrent.futures import ProcessPoolExecutor, as_completed
 from itertools import groupby
 from time import time
 from typing import Dict, List, Tuple
@@ -33,6 +33,8 @@ def __init__(self, config):
         self.limit = self.config.get("fingerprint_limit", None)
         if self.limit == -1:  # for JSON compatibility
             self.limit = None
+        self.songs = None
+        self.songhashes_set = set()  # to know which ones we've computed before
         self.__load_fingerprinted_audio_hashes()
 
     def __load_fingerprinted_audio_hashes(self) -> None:
@@ -44,8 +46,7 @@ def __load_fingerprinted_audio_hashes(self) -> None:
         self.songs = self.db.get_songs()
         self.songhashes_set = set()  # to know which ones we've computed before
         for song in self.songs:
-            song_hash = song[FIELD_FILE_SHA1]
-            self.songhashes_set.add(song_hash)
+            self.songhashes_set.add(song[FIELD_FILE_SHA1])
 
     def get_fingerprinted_songs(self) -> List[Dict[str, any]]:
         """
@@ -71,52 +72,37 @@ def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = No
         :param extensions: list of file extensions to consider.
         :param nprocesses: amount of processes to fingerprint the files within the directory.
         """
-        # Try to use the maximum amount of processes if not given.
-        try:
-            nprocesses = nprocesses or multiprocessing.cpu_count()
-        except NotImplementedError:
-            nprocesses = 1
-        else:
-            nprocesses = 1 if nprocesses <= 0 else nprocesses
-
-        pool = multiprocessing.Pool(nprocesses)
-
-        filenames_to_fingerprint = []
-        for filename, _ in decoder.find_files(path, extensions):
-            # don't refingerprint already fingerprinted files
-            if decoder.unique_hash(filename) in self.songhashes_set:
-                print(f"{filename} already fingerprinted, continuing...")
-                continue
-
-            filenames_to_fingerprint.append(filename)
-
-        # Prepare _fingerprint_worker input
-        worker_input = list(zip(filenames_to_fingerprint, [self.limit] * len(filenames_to_fingerprint)))
-
-        # Send off our tasks
-        iterator = pool.imap_unordered(Dejavu._fingerprint_worker, worker_input)
-
-        # Loop till we have all of them
-        while True:
-            try:
-                song_name, hashes, file_hash = next(iterator)
-            except multiprocessing.TimeoutError:
-                continue
-            except StopIteration:
-                break
-            except Exception:
-                print("Failed fingerprinting")
-                # Print traceback because we can't reraise it here
-                traceback.print_exc(file=sys.stdout)
-            else:
-                sid = self.db.insert_song(song_name, file_hash, len(hashes))
-
-                self.db.insert_hashes(sid, hashes)
-                self.db.set_song_fingerprinted(sid)
-                self.__load_fingerprinted_audio_hashes()
-
-        pool.close()
-        pool.join()
+        # ProcessPoolExecutor rejects max_workers <= 0, so clamp explicit values
+        # to at least one worker; None/0 lets the executor pick its default.
+        nprocesses = max(1, int(nprocesses)) if nprocesses else None
+
+        with ProcessPoolExecutor(max_workers=nprocesses) as executor:
+            futures = []
+            for filename, _ in decoder.find_files_g(path, extensions):
+                # don't refingerprint already fingerprinted files
+                if decoder.unique_hash(filename) in self.songhashes_set:
+                    print(f"{filename} already fingerprinted, continuing...")
+                else:
+                    futures.append(
+                        executor.submit(
+                            self._fingerprint_worker,
+                            filename,
+                            self.limit,
+                        )
+                    )
+            for future in as_completed(futures):
+                try:
+                    song_name, hashes, file_hash = future.result()
+                except Exception:
+                    print("Failed fingerprinting")
+                    # Print traceback because we can't reraise it here
+                    traceback.print_exc(file=sys.stdout)
+                else:
+                    sid = self.db.insert_song(song_name, file_hash, len(hashes))
+                    self.db.insert_hashes(sid, hashes)
+                    self.db.set_song_fingerprinted(sid)
+        # Wait until all songs are processed to reload hashes
+        self.__load_fingerprinted_audio_hashes()
 
     def fingerprint_file(self, file_path: str, song_name: str = None) -> None:
         """
@@ -187,9 +173,9 @@ def align_matches(self, matches: List[Tuple[int, int]], dedup_hashes: Dict[str,
         """
         # count offset occurrences per song and keep only the maximum ones.
         sorted_matches = sorted(matches, key=lambda m: (m[0], m[1]))
-        counts = [(*key, len(list(group))) for key, group in groupby(sorted_matches, key=lambda m: (m[0], m[1]))]
+        counts = ((*key, len(list(group))) for key, group in groupby(sorted_matches, key=lambda m: (m[0], m[1])))
         songs_matches = sorted(
-            [max(list(group), key=lambda g: g[2]) for key, group in groupby(counts, key=lambda count: count[0])],
+            (max(group, key=lambda g: g[2]) for key, group in groupby(counts, key=lambda count: count[0])),
             key=lambda count: count[2], reverse=True
         )
 
@@ -226,18 +212,13 @@ def recognize(self, recognizer, *options, **kwoptions) -> Dict[str, any]:
         return r.recognize(*options, **kwoptions)
 
     @staticmethod
-    def _fingerprint_worker(arguments):
-        # Pool.imap sends arguments as tuples so we have to unpack
-        # them ourself.
-        try:
-            file_name, limit = arguments
-        except ValueError:
-            pass
-
-        song_name, extension = os.path.splitext(os.path.basename(file_name))
-
-        fingerprints, file_hash = Dejavu.get_file_fingerprints(file_name, limit, print_output=True)
-
+    def _fingerprint_worker(file_name, limit):
+        song_name = os.path.splitext(os.path.basename(file_name))[0]
+        # Suppressing print_output because MP will step all over itself
+        # while printing to stdout
+        fingerprints, file_hash = Dejavu.get_file_fingerprints(
+            file_name, limit, print_output=False
+        )
         return song_name, fingerprints, file_hash
 
     @staticmethod
diff --git a/dejavu/logic/decoder.py b/dejavu/logic/decoder.py
index ccafa262..cab09bb6 100755
--- a/dejavu/logic/decoder.py
+++ b/dejavu/logic/decoder.py
@@ -1,7 +1,7 @@
 import fnmatch
 import os
 from hashlib import sha1
-from typing import List, Tuple
+from typing import Generator, List, Tuple
 
 import numpy as np
 from pydub import AudioSegment
@@ -23,10 +23,7 @@ def unique_hash(file_path: str, block_size: int = 2**20) -> str:
     """
     s = sha1()
     with open(file_path, "rb") as f:
-        while True:
-            buf = f.read(block_size)
-            if not buf:
-                break
+        while buf := f.read(block_size):
             s.update(buf)
     return s.hexdigest().upper()
 
@@ -51,6 +48,30 @@ def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]:
     return results
 
 
+def find_files_g(path: str, extensions: List[str]) -> Generator[Tuple[str, str], None, None]:
+    """
+    Get all files that meet the specified extensions.
+
+    :param path: path to a directory with audio files.
+    :param extensions: file extensions to look for.
+    :yields: a tuple with file name and its extension.
+    """
+    # Allow both with ".mp3" and without "mp3" to be used for extensions
+    norm_extensions = set()
+    for extension in extensions:
+        extension = extension.lower()
+        norm_extensions.add(extension)
+        if extension.startswith('.'):
+            norm_extensions.add(extension.lstrip('.'))
+        else:
+            norm_extensions.add(f'.{extension}')
+    for root, dirs, files in os.walk(path):
+        for f in files:
+            ext = os.path.splitext(f)[1].lower()
+            if ext in norm_extensions:
+                yield os.path.join(root, f), ext
+
+
 def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
     """
     Reads any file supported by pydub (ffmpeg) and returns the data contained
@@ -74,9 +95,9 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
 
         data = np.fromstring(audiofile.raw_data, np.int16)
 
-        channels = []
-        for chn in range(audiofile.channels):
-            channels.append(data[chn::audiofile.channels])
+        channels = [
+            data[chn::audiofile.channels] for chn in range(audiofile.channels)
+        ]
 
         audiofile.frame_rate
     except audioop.error:
@@ -88,9 +109,7 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
         audiofile = audiofile.T
         audiofile = audiofile.astype(np.int16)
 
-        channels = []
-        for chn in audiofile:
-            channels.append(chn)
+        channels = [chn for chn in audiofile]
 
     return channels, audiofile.frame_rate, unique_hash(file_name)
 