diff --git a/dejavu/__init__.py b/dejavu/__init__.py index e93edac..bb52f1e 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -1,9 +1,8 @@ from dejavu.database import get_database import dejavu.decoder as decoder import fingerprint -from multiprocessing import Process, cpu_count +import multiprocessing import os -import random class Dejavu(object): @@ -30,57 +29,39 @@ class Dejavu(object): def fingerprint_directory(self, path, extensions, nprocesses=None): # Try to use the maximum amount of processes if not given. - if nprocesses is None: - try: - nprocesses = cpu_count() - except NotImplementedError: - nprocesses = 1 + try: + nprocesses = nprocesses or multiprocessing.cpu_count() + except NotImplementedError: + nprocesses = 1 + else: + nprocesses = 1 if nprocesses <= 0 else nprocesses - # convert files, shuffle order - files = list(decoder.find_files(path, extensions)) - random.shuffle(files) + pool = multiprocessing.Pool(nprocesses) - files_split = chunkify(files, nprocesses) + results = [] + for filename, _ in decoder.find_files(path, extensions): + # TODO: Don't queue up files that have already been fingerprinted. + result = pool.apply_async(_fingerprint_worker, + (filename, self.db)) + results.append(result) - # split into processes here - processes = [] - for i in range(nprocesses): + while len(results): + for result in results[:]: + # TODO: Handle errors gracefully and return them to the callee + # in some way. + try: + result.get(timeout=2) + except multiprocessing.TimeoutError: + continue + except: + import traceback, sys + traceback.print_exc(file=sys.stdout) + results.remove(result) + else: + results.remove(result) - # create process and start it - p = Process(target=self._fingerprint_worker, - args=(files_split[i], self.db)) - p.start() - processes.append(p) - - # wait for all processes to complete - for p in processes: - p.join() - - def _fingerprint_worker(self, files, db): - for filename, extension in files: - - # if there are already fingerprints in database, - # don't re-fingerprint - song_name = os.path.basename(filename).split(".")[0] - if song_name in self.songnames_set: - print("-> Already fingerprinted, continuing...") - continue - - channels, Fs = decoder.read(filename) - - # insert song name into database - song_id = db.insert_song(song_name) - - for c in range(len(channels)): - channel = channels[c] - print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name) - - hashes = fingerprint.fingerprint(channel, Fs=Fs) - - db.insert_hashes(song_id, hashes) - - # only after done fingerprinting do confirm - db.set_song_fingerprinted(song_id) + pool.close() + pool.join() def fingerprint_file(self, filepath, song_name=None): channels, Fs = decoder.read(filepath) @@ -122,12 +103,14 @@ class Dejavu(object): largest_count = diff_counter[diff][sid] song_id = sid - print("Diff is %d with %d offset-aligned matches" % (largest, largest_count)) + print("Diff is %d with %d offset-aligned matches" % (largest, + largest_count)) # extract idenfication song = self.db.get_song_by_id(song_id) if song: - songname = song.get(SQLDatabase.FIELD_SONGNAME, None) + # TODO: Clarifey what `get_song_by_id` should return. + songname = song.get("song_name", None) else: return None @@ -145,6 +128,35 @@ class Dejavu(object): return r.recognize(*options, **kwoptions) +def _fingerprint_worker(filename, db): + song_name, extension = os.path.splitext(os.path.basename(filename)) + + channels, Fs = decoder.read(filename) + + # insert song into database + sid = db.insert_song(song_name) + + channel_amount = len(channels) + for channeln, channel in enumerate(channels): + # TODO: Remove prints or change them into optional logging. + print("Fingerprinting channel %d/%d for %s" % (channeln + 1, + channel_amount, + filename)) + hashes = fingerprint.fingerprint(channel, Fs=Fs) + print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount, + filename)) + + print("Inserting fingerprints for channel %d/%d for %s" % + (channeln + 1, channel_amount, filename)) + db.insert_hashes(sid, hashes) + print("Finished inserting for channel %d/%d for %s" % + (channeln + 1, channel_amount, filename)) + + print("Marking %s finished" % (filename,)) + db.set_song_fingerprinted(sid) + print("%s finished" % (filename,)) + + def chunkify(lst, n): """ Splits a list into roughly n equal parts. diff --git a/dejavu/database_sql.py b/dejavu/database_sql.py index a93b1f9..565d83f 100644 --- a/dejavu/database_sql.py +++ b/dejavu/database_sql.py @@ -1,5 +1,5 @@ from __future__ import absolute_import -from itertools import izip_longest, ifilter +from itertools import izip_longest import Queue import MySQLdb as mysql @@ -312,7 +312,8 @@ class SQLDatabase(Database): def grouper(iterable, n, fillvalue=None): args = [iter(iterable)] * n - return (ifilter(None, values) for values in izip_longest(fillvalue=fillvalue, *args)) + return (filter(None, values) for values + in izip_longest(fillvalue=fillvalue, *args)) def cursor_factory(**factory_options):