Rewrote multiprocessing parts to only require one database connection

List of changes in detail:
- fingerprint_directory:
	Now uses mp.Pool.imap and updates database in callers
	process.
- fingerprint_file:
	Now uses _fingerprint_worker internally.
- _fingerprint_worker:
	Some slight changes to argument handling due to `imap`
	usage.

	Changed to return (song_name, hashes) where hashes is
	a set of hashes from all channels.
- path_to_songname:
	Changed to use `os.path.splitext`, this caused files with
	a period in their name to return the wrong name.
This commit is contained in:
Wessie 2014-01-20 20:53:25 +01:00
parent cd7adc1485
commit 64a6161b90
2 changed files with 52 additions and 43 deletions

View file

@ -43,7 +43,7 @@ class Dejavu(object):
pool = multiprocessing.Pool(nprocesses)
results = []
filenames_to_fingerprint = []
for filename, _ in decoder.find_files(path, extensions):
# don't refingerprint already fingerprinted files
@ -51,39 +51,46 @@ class Dejavu(object):
print "%s already fingerprinted, continuing..." % filename
continue
result = pool.apply_async(_fingerprint_worker,
(filename, self.db, self.limit))
results.append(result)
filenames_to_fingerprint.append(filename)
while len(results):
for result in results[:]:
# TODO: Handle errors gracefully and return them to the callee
# in some way.
# Prepare _fingerprint_worker input
worker_input = zip(filenames_to_fingerprint,
[self.limit] * len(filenames_to_fingerprint))
# Send off our tasks
iterator = pool.imap_unordered(_fingerprint_worker,
worker_input)
# Loop till we have all of them
while True:
try:
result.get(timeout=2)
song_name, hashes = iterator.next()
except multiprocessing.TimeoutError:
continue
except StopIteration:
break
except:
print("Failed fingerprinting")
# Print traceback because we can't reraise it here
import traceback, sys
traceback.print_exc(file=sys.stdout)
results.remove(result)
else:
results.remove(result)
sid = self.db.insert_song(song_name)
self.db.insert_hashes(sid, hashes)
pool.close()
pool.join()
def fingerprint_file(self, filepath, song_name=None):
channels, Fs = decoder.read(filepath)
song_name, hashes = _fingerprint_worker(filepath,
self.limit,
song_name=song_name)
if not song_name:
print "Song name: %s" % song_name
song_name = decoder.path_to_songname(filepath)
song_id = self.db.insert_song(song_name)
sid = self.db.insert_song(song_name)
for data in channels:
hashes = fingerprint.fingerprint(data, Fs=Fs)
self.db.insert_hashes(song_id, hashes)
self.db.insert_hashes(sid, hashes)
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
hashes = fingerprint.fingerprint(samples, Fs=Fs)
@ -130,7 +137,7 @@ class Dejavu(object):
"song_id": song_id,
"song_name": songname,
"confidence": largest_count,
"offset" : largest
"offset": largest
}
return song
@ -140,13 +147,21 @@ class Dejavu(object):
return r.recognize(*options, **kwoptions)
def _fingerprint_worker(filename, db, limit):
song_name, extension = os.path.splitext(os.path.basename(filename))
def _fingerprint_worker(filename, limit=None, song_name=None):
# Pool.imap sends arguments as tuples so we have to unpack
# them ourself.
try:
filename, limit = filename
except ValueError:
pass
songname, extension = os.path.splitext(os.path.basename(filename))
song_name = song_name or songname
channels, Fs = decoder.read(filename, limit)
# insert song into database
sid = db.insert_song(song_name)
result = set()
channel_amount = len(channels)
for channeln, channel in enumerate(channels):
@ -158,15 +173,9 @@ def _fingerprint_worker(filename, db, limit):
print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
filename))
print("Inserting fingerprints for channel %d/%d for %s" %
(channeln + 1, channel_amount, filename))
db.insert_hashes(sid, hashes)
print("Finished inserting for channel %d/%d for %s" %
(channeln + 1, channel_amount, filename))
result |= set(hashes)
print("Marking %s finished" % (filename,))
db.set_song_fingerprinted(sid)
print("%s finished" % (filename,))
return song_name, result
def chunkify(lst, n):

View file

@ -45,4 +45,4 @@ def path_to_songname(path):
Extracts song name from a filepath. Used to identify which songs
have already been fingerprinted on disk.
"""
return os.path.basename(path).split(".")[0]
return os.path.splitext(os.path.basename(path))[0]