Rewrote multiprocessing parts to only require one database connection

List of changes in detail:
- fingerprint_directory:
	Now uses mp.Pool.imap and updates database in callers
	process.
- fingerprint_file:
	Now uses _fingerprint_worker internally.
- _fingerprint_worker:
	Some slight changes to argument handling due to `imap`
	usage.

	Changed to return (song_name, hashes) where hashes is
	a set of hashes from all channels.
- path_to_songname:
	Changed to use `os.path.splitext`, this caused files with
	a period in their name to return the wrong name.
This commit is contained in:
Wessie 2014-01-20 20:53:25 +01:00
parent cd7adc1485
commit 64a6161b90
2 changed files with 52 additions and 43 deletions

View file

@ -43,7 +43,7 @@ class Dejavu(object):
pool = multiprocessing.Pool(nprocesses) pool = multiprocessing.Pool(nprocesses)
results = [] filenames_to_fingerprint = []
for filename, _ in decoder.find_files(path, extensions): for filename, _ in decoder.find_files(path, extensions):
# don't refingerprint already fingerprinted files # don't refingerprint already fingerprinted files
@ -51,39 +51,46 @@ class Dejavu(object):
print "%s already fingerprinted, continuing..." % filename print "%s already fingerprinted, continuing..." % filename
continue continue
result = pool.apply_async(_fingerprint_worker, filenames_to_fingerprint.append(filename)
(filename, self.db, self.limit))
results.append(result)
while len(results): # Prepare _fingerprint_worker input
for result in results[:]: worker_input = zip(filenames_to_fingerprint,
# TODO: Handle errors gracefully and return them to the callee [self.limit] * len(filenames_to_fingerprint))
# in some way.
# Send off our tasks
iterator = pool.imap_unordered(_fingerprint_worker,
worker_input)
# Loop till we have all of them
while True:
try: try:
result.get(timeout=2) song_name, hashes = iterator.next()
except multiprocessing.TimeoutError: except multiprocessing.TimeoutError:
continue continue
except StopIteration:
break
except: except:
print("Failed fingerprinting")
# Print traceback because we can't reraise it here
import traceback, sys import traceback, sys
traceback.print_exc(file=sys.stdout) traceback.print_exc(file=sys.stdout)
results.remove(result)
else: else:
results.remove(result) sid = self.db.insert_song(song_name)
self.db.insert_hashes(sid, hashes)
pool.close() pool.close()
pool.join() pool.join()
def fingerprint_file(self, filepath, song_name=None): def fingerprint_file(self, filepath, song_name=None):
channels, Fs = decoder.read(filepath) song_name, hashes = _fingerprint_worker(filepath,
self.limit,
song_name=song_name)
if not song_name: sid = self.db.insert_song(song_name)
print "Song name: %s" % song_name
song_name = decoder.path_to_songname(filepath)
song_id = self.db.insert_song(song_name)
for data in channels: self.db.insert_hashes(sid, hashes)
hashes = fingerprint.fingerprint(data, Fs=Fs)
self.db.insert_hashes(song_id, hashes)
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS): def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
hashes = fingerprint.fingerprint(samples, Fs=Fs) hashes = fingerprint.fingerprint(samples, Fs=Fs)
@ -130,7 +137,7 @@ class Dejavu(object):
"song_id": song_id, "song_id": song_id,
"song_name": songname, "song_name": songname,
"confidence": largest_count, "confidence": largest_count,
"offset" : largest "offset": largest
} }
return song return song
@ -140,13 +147,21 @@ class Dejavu(object):
return r.recognize(*options, **kwoptions) return r.recognize(*options, **kwoptions)
def _fingerprint_worker(filename, db, limit): def _fingerprint_worker(filename, limit=None, song_name=None):
song_name, extension = os.path.splitext(os.path.basename(filename)) # Pool.imap sends arguments as tuples so we have to unpack
# them ourself.
try:
filename, limit = filename
except ValueError:
pass
songname, extension = os.path.splitext(os.path.basename(filename))
song_name = song_name or songname
channels, Fs = decoder.read(filename, limit) channels, Fs = decoder.read(filename, limit)
# insert song into database result = set()
sid = db.insert_song(song_name)
channel_amount = len(channels) channel_amount = len(channels)
for channeln, channel in enumerate(channels): for channeln, channel in enumerate(channels):
@ -158,15 +173,9 @@ def _fingerprint_worker(filename, db, limit):
print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount, print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
filename)) filename))
print("Inserting fingerprints for channel %d/%d for %s" % result |= set(hashes)
(channeln + 1, channel_amount, filename))
db.insert_hashes(sid, hashes)
print("Finished inserting for channel %d/%d for %s" %
(channeln + 1, channel_amount, filename))
print("Marking %s finished" % (filename,)) return song_name, result
db.set_song_fingerprinted(sid)
print("%s finished" % (filename,))
def chunkify(lst, n): def chunkify(lst, n):

View file

@ -45,4 +45,4 @@ def path_to_songname(path):
Extracts song name from a filepath. Used to identify which songs Extracts song name from a filepath. Used to identify which songs
have already been fingerprinted on disk. have already been fingerprinted on disk.
""" """
return os.path.basename(path).split(".")[0] return os.path.splitext(os.path.basename(path))[0]