From 64a6161b9060e9f219f8d02e0d97d4914428491a Mon Sep 17 00:00:00 2001 From: Wessie Date: Mon, 20 Jan 2014 20:53:25 +0100 Subject: [PATCH] Rewrote multiprocessing parts to only require one database connection List of changes in detail: - fingerprint_directory: Now uses mp.Pool.imap_unordered and updates the database in the caller's process. - fingerprint_file: Now uses _fingerprint_worker internally. - _fingerprint_worker: Some slight changes to argument handling due to `imap` usage. Changed to return (song_name, hashes) where hashes is a set of hashes from all channels. - path_to_songname: Changed to use `os.path.splitext`; the previous `split(".")[0]` approach returned the wrong name for files with a period in their name. --- dejavu/__init__.py | 91 +++++++++++++++++++++++++--------------------- dejavu/decoder.py | 4 +- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/dejavu/__init__.py b/dejavu/__init__.py index d18551a..bbd5b0b 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -16,8 +16,8 @@ class Dejavu(object): self.db = db_cls(**config.get("database", {})) self.db.setup() - - # if we should limit seconds fingerprinted, + + # if we should limit seconds fingerprinted, # None|-1 means use entire track self.limit = self.config.get("fingerprint_limit", None) if self.limit == -1: # for JSON compatibility @@ -43,7 +43,7 @@ class Dejavu(object): pool = multiprocessing.Pool(nprocesses) - results = [] + filenames_to_fingerprint = [] for filename, _ in decoder.find_files(path, extensions): # don't refingerprint already fingerprinted files @@ -51,39 +51,46 @@ class Dejavu(object): print "%s already fingerprinted, continuing..." % filename continue - result = pool.apply_async(_fingerprint_worker, - (filename, self.db, self.limit)) - results.append(result) + filenames_to_fingerprint.append(filename) - while len(results): - for result in results[:]: - # TODO: Handle errors gracefully and return them to the callee - # in some way. 
- try: - result.get(timeout=2) - except multiprocessing.TimeoutError: - continue - except: - import traceback, sys - traceback.print_exc(file=sys.stdout) - results.remove(result) - else: - results.remove(result) + # Prepare _fingerprint_worker input + worker_input = zip(filenames_to_fingerprint, + [self.limit] * len(filenames_to_fingerprint)) + + # Send off our tasks + iterator = pool.imap_unordered(_fingerprint_worker, + worker_input) + + # Loop till we have all of them + while True: + try: + song_name, hashes = iterator.next() + except multiprocessing.TimeoutError: + continue + except StopIteration: + break + except: + print("Failed fingerprinting") + + # Print traceback because we can't reraise it here + import traceback, sys + traceback.print_exc(file=sys.stdout) + else: + sid = self.db.insert_song(song_name) + + self.db.insert_hashes(sid, hashes) pool.close() pool.join() def fingerprint_file(self, filepath, song_name=None): - channels, Fs = decoder.read(filepath) + song_name, hashes = _fingerprint_worker(filepath, + self.limit, + song_name=song_name) - if not song_name: - print "Song name: %s" % song_name - song_name = decoder.path_to_songname(filepath) - song_id = self.db.insert_song(song_name) + sid = self.db.insert_song(song_name) - for data in channels: - hashes = fingerprint.fingerprint(data, Fs=Fs) - self.db.insert_hashes(song_id, hashes) + self.db.insert_hashes(sid, hashes) def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS): hashes = fingerprint.fingerprint(samples, Fs=Fs) @@ -130,7 +137,7 @@ class Dejavu(object): "song_id": song_id, "song_name": songname, "confidence": largest_count, - "offset" : largest + "offset": largest } return song @@ -140,13 +147,21 @@ class Dejavu(object): return r.recognize(*options, **kwoptions) -def _fingerprint_worker(filename, db, limit): - song_name, extension = os.path.splitext(os.path.basename(filename)) +def _fingerprint_worker(filename, limit=None, song_name=None): + # Pool.imap sends arguments as tuples so we 
have to unpack + # them ourself. + try: + filename, limit = filename + except ValueError: + pass + + songname, extension = os.path.splitext(os.path.basename(filename)) + + song_name = song_name or songname channels, Fs = decoder.read(filename, limit) - # insert song into database - sid = db.insert_song(song_name) + result = set() channel_amount = len(channels) for channeln, channel in enumerate(channels): @@ -158,15 +173,9 @@ def _fingerprint_worker(filename, db, limit): print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount, filename)) - print("Inserting fingerprints for channel %d/%d for %s" % - (channeln + 1, channel_amount, filename)) - db.insert_hashes(sid, hashes) - print("Finished inserting for channel %d/%d for %s" % - (channeln + 1, channel_amount, filename)) + result |= set(hashes) - print("Marking %s finished" % (filename,)) - db.set_song_fingerprinted(sid) - print("%s finished" % (filename,)) + return song_name, result def chunkify(lst, n): diff --git a/dejavu/decoder.py b/dejavu/decoder.py index d01a8ef..e2e2d33 100644 --- a/dejavu/decoder.py +++ b/dejavu/decoder.py @@ -43,6 +43,6 @@ def read(filename, limit=None): def path_to_songname(path): """ Extracts song name from a filepath. Used to identify which songs - have already been fingerprinted on disk. + have already been fingerprinted on disk. """ - return os.path.basename(path).split(".")[0] + return os.path.splitext(os.path.basename(path))[0]