Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance Tweaks, Iterators, and Lazy Evaluations #308

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
107 changes: 44 additions & 63 deletions dejavu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import multiprocessing
import os
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import groupby
from time import time
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -33,6 +33,8 @@ def __init__(self, config):
self.limit = self.config.get("fingerprint_limit", None)
if self.limit == -1: # for JSON compatibility
self.limit = None
self.songs = None
self.songhashes_set = set() # to know which ones we've computed before
self.__load_fingerprinted_audio_hashes()

def __load_fingerprinted_audio_hashes(self) -> None:
Expand All @@ -44,8 +46,7 @@ def __load_fingerprinted_audio_hashes(self) -> None:
self.songs = self.db.get_songs()
self.songhashes_set = set() # to know which ones we've computed before
for song in self.songs:
song_hash = song[FIELD_FILE_SHA1]
self.songhashes_set.add(song_hash)
self.songhashes_set.add(song[FIELD_FILE_SHA1])

def get_fingerprinted_songs(self) -> List[Dict[str, any]]:
"""
Expand All @@ -71,52 +72,37 @@ def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = No
:param extensions: list of file extensions to consider.
:param nprocesses: amount of processes to fingerprint the files within the directory.
"""
# Try to use the maximum amount of processes if not given.
try:
nprocesses = nprocesses or multiprocessing.cpu_count()
except NotImplementedError:
nprocesses = 1
else:
nprocesses = 1 if nprocesses <= 0 else nprocesses

pool = multiprocessing.Pool(nprocesses)

filenames_to_fingerprint = []
for filename, _ in decoder.find_files(path, extensions):
# don't refingerprint already fingerprinted files
if decoder.unique_hash(filename) in self.songhashes_set:
print(f"{filename} already fingerprinted, continuing...")
continue

filenames_to_fingerprint.append(filename)

# Prepare _fingerprint_worker input
worker_input = list(zip(filenames_to_fingerprint, [self.limit] * len(filenames_to_fingerprint)))

# Send off our tasks
iterator = pool.imap_unordered(Dejavu._fingerprint_worker, worker_input)

# Loop till we have all of them
while True:
try:
song_name, hashes, file_hash = next(iterator)
except multiprocessing.TimeoutError:
continue
except StopIteration:
break
except Exception:
print("Failed fingerprinting")
# Print traceback because we can't reraise it here
traceback.print_exc(file=sys.stdout)
else:
sid = self.db.insert_song(song_name, file_hash, len(hashes))

self.db.insert_hashes(sid, hashes)
self.db.set_song_fingerprinted(sid)
self.__load_fingerprinted_audio_hashes()

pool.close()
pool.join()
nprocesses = int(nprocesses) if nprocesses is not None else None

with ProcessPoolExecutor(max_workers=nprocesses) as executor:
futures = []
for filename, _ in decoder.find_files_g(path, extensions):
# don't refingerprint already fingerprinted files
if decoder.unique_hash(filename) in self.songhashes_set:
print(f"{filename} already fingerprinted, continuing...")
else:
futures.append(
executor.submit(
self._fingerprint_worker,
filename,
self.limit,
)
)
for future in as_completed(futures):
try:
song_name, hashes, file_hash = future.result()
except StopIteration:
break
except Exception:
print("Failed fingerprinting")
# Print traceback because we can't reraise it here
traceback.print_exc(file=sys.stdout)
else:
sid = self.db.insert_song(song_name, file_hash, len(hashes))
self.db.insert_hashes(sid, hashes)
self.db.set_song_fingerprinted(sid)
# Wait until all songs are processed to reload hashes
self.__load_fingerprinted_audio_hashes()

def fingerprint_file(self, file_path: str, song_name: str = None) -> None:
"""
Expand Down Expand Up @@ -187,9 +173,9 @@ def align_matches(self, matches: List[Tuple[int, int]], dedup_hashes: Dict[str,
"""
# count offset occurrences per song and keep only the maximum ones.
sorted_matches = sorted(matches, key=lambda m: (m[0], m[1]))
counts = [(*key, len(list(group))) for key, group in groupby(sorted_matches, key=lambda m: (m[0], m[1]))]
counts = ((*key, len(list(group))) for key, group in groupby(sorted_matches, key=lambda m: (m[0], m[1])))
songs_matches = sorted(
[max(list(group), key=lambda g: g[2]) for key, group in groupby(counts, key=lambda count: count[0])],
(max(group, key=lambda g: g[2]) for key, group in groupby(counts, key=lambda count: count[0])),
key=lambda count: count[2], reverse=True
)

Expand Down Expand Up @@ -226,18 +212,13 @@ def recognize(self, recognizer, *options, **kwoptions) -> Dict[str, any]:
return r.recognize(*options, **kwoptions)

@staticmethod
def _fingerprint_worker(arguments):
# Pool.imap sends arguments as tuples so we have to unpack
# them ourself.
try:
file_name, limit = arguments
except ValueError:
pass

song_name, extension = os.path.splitext(os.path.basename(file_name))

fingerprints, file_hash = Dejavu.get_file_fingerprints(file_name, limit, print_output=True)

def _fingerprint_worker(file_name, limit):
song_name = os.path.splitext(os.path.basename(file_name))[0]
# Suppressing print_output because MP will step all over itself
# while printing to stdout
fingerprints, file_hash = Dejavu.get_file_fingerprints(
file_name, limit, print_output=False
)
return song_name, fingerprints, file_hash

@staticmethod
Expand Down
41 changes: 30 additions & 11 deletions dejavu/logic/decoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import fnmatch
import os
from hashlib import sha1
from typing import List, Tuple
from typing import Generator, List, Tuple

import numpy as np
from pydub import AudioSegment
Expand All @@ -23,10 +23,7 @@ def unique_hash(file_path: str, block_size: int = 2**20) -> str:
"""
s = sha1()
with open(file_path, "rb") as f:
while True:
buf = f.read(block_size)
if not buf:
break
while buf := f.read(block_size):
s.update(buf)
return s.hexdigest().upper()

Expand All @@ -51,6 +48,30 @@ def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]:
return results


def find_files_g(path: str, extensions: List[str]) -> Generator[Tuple[str, str], None, None]:
    """
    Lazily walk a directory tree and yield the files that meet the specified extensions.

    Matching is case-insensitive, and every requested extension may be written
    with or without its leading dot (".mp3" and "mp3" are equivalent).

    :param path: path to a directory with audio files.
    :param extensions: file extensions to look for. A single string is treated
        as one extension instead of being iterated character by character.
    :yields: a tuple with the joined file path and the file's lower-cased
        extension exactly as os.path.splitext reports it (leading dot included).
    """
    # A bare string is itself iterable; without this guard "mp3" would be
    # split into 'm', 'p', '3' and never match any file.
    if isinstance(extensions, str):
        extensions = [extensions]

    # Allow both with ".mp3" and without "mp3" to be used for extensions
    wanted = set()
    for extension in extensions:
        lowered = extension.lower()
        wanted.add(lowered)
        if lowered.startswith('.'):
            wanted.add(lowered.lstrip('.'))
        else:
            wanted.add(f'.{lowered}')

    for root, _dirs, files in os.walk(path):
        for name in files:
            # splitext keeps the leading dot, so the dotted variants stored in
            # `wanted` are the ones that normally match.
            ext = os.path.splitext(name)[1].lower()
            if ext in wanted:
                yield os.path.join(root, name), ext


def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
"""
Reads any file supported by pydub (ffmpeg) and returns the data contained
Expand All @@ -74,9 +95,9 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:

data = np.fromstring(audiofile.raw_data, np.int16)

channels = []
for chn in range(audiofile.channels):
channels.append(data[chn::audiofile.channels])
channels = [
data[chn::audiofile.channels] for chn in range(audiofile.channels)
]

audiofile.frame_rate
except audioop.error:
Expand All @@ -88,9 +109,7 @@ def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
audiofile = audiofile.T
audiofile = audiofile.astype(np.int16)

channels = []
for chn in audiofile:
channels.append(chn)
channels = [chn for chn in audiofile]

return channels, audiofile.frame_rate, unique_hash(file_name)

Expand Down