mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
Rewrote multiprocessing parts to only require one database connection
List of changes in detail: - fingerprint_directory: Now uses mp.Pool.imap and updates database in callers process. - fingerprint_file: Now uses _fingerprint_worker internally. - _fingerprint_worker: Some slight changes to argument handling due to `imap` usage. Changed to return (song_name, hashes) where hashes is a set of hashes from all channels. - path_to_songname: Changed to use `os.path.splitext`, this caused files with a period in their name to return the wrong name.
This commit is contained in:
parent
cd7adc1485
commit
64a6161b90
2 changed files with 52 additions and 43 deletions
|
@ -43,7 +43,7 @@ class Dejavu(object):
|
|||
|
||||
pool = multiprocessing.Pool(nprocesses)
|
||||
|
||||
results = []
|
||||
filenames_to_fingerprint = []
|
||||
for filename, _ in decoder.find_files(path, extensions):
|
||||
|
||||
# don't refingerprint already fingerprinted files
|
||||
|
@ -51,39 +51,46 @@ class Dejavu(object):
|
|||
print "%s already fingerprinted, continuing..." % filename
|
||||
continue
|
||||
|
||||
result = pool.apply_async(_fingerprint_worker,
|
||||
(filename, self.db, self.limit))
|
||||
results.append(result)
|
||||
filenames_to_fingerprint.append(filename)
|
||||
|
||||
while len(results):
|
||||
for result in results[:]:
|
||||
# TODO: Handle errors gracefully and return them to the callee
|
||||
# in some way.
|
||||
# Prepare _fingerprint_worker input
|
||||
worker_input = zip(filenames_to_fingerprint,
|
||||
[self.limit] * len(filenames_to_fingerprint))
|
||||
|
||||
# Send off our tasks
|
||||
iterator = pool.imap_unordered(_fingerprint_worker,
|
||||
worker_input)
|
||||
|
||||
# Loop till we have all of them
|
||||
while True:
|
||||
try:
|
||||
result.get(timeout=2)
|
||||
song_name, hashes = iterator.next()
|
||||
except multiprocessing.TimeoutError:
|
||||
continue
|
||||
except StopIteration:
|
||||
break
|
||||
except:
|
||||
print("Failed fingerprinting")
|
||||
|
||||
# Print traceback because we can't reraise it here
|
||||
import traceback, sys
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
results.remove(result)
|
||||
else:
|
||||
results.remove(result)
|
||||
sid = self.db.insert_song(song_name)
|
||||
|
||||
self.db.insert_hashes(sid, hashes)
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
def fingerprint_file(self, filepath, song_name=None):
|
||||
channels, Fs = decoder.read(filepath)
|
||||
song_name, hashes = _fingerprint_worker(filepath,
|
||||
self.limit,
|
||||
song_name=song_name)
|
||||
|
||||
if not song_name:
|
||||
print "Song name: %s" % song_name
|
||||
song_name = decoder.path_to_songname(filepath)
|
||||
song_id = self.db.insert_song(song_name)
|
||||
sid = self.db.insert_song(song_name)
|
||||
|
||||
for data in channels:
|
||||
hashes = fingerprint.fingerprint(data, Fs=Fs)
|
||||
self.db.insert_hashes(song_id, hashes)
|
||||
self.db.insert_hashes(sid, hashes)
|
||||
|
||||
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
||||
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
||||
|
@ -140,13 +147,21 @@ class Dejavu(object):
|
|||
return r.recognize(*options, **kwoptions)
|
||||
|
||||
|
||||
def _fingerprint_worker(filename, db, limit):
|
||||
song_name, extension = os.path.splitext(os.path.basename(filename))
|
||||
def _fingerprint_worker(filename, limit=None, song_name=None):
|
||||
# Pool.imap sends arguments as tuples so we have to unpack
|
||||
# them ourself.
|
||||
try:
|
||||
filename, limit = filename
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
songname, extension = os.path.splitext(os.path.basename(filename))
|
||||
|
||||
song_name = song_name or songname
|
||||
|
||||
channels, Fs = decoder.read(filename, limit)
|
||||
|
||||
# insert song into database
|
||||
sid = db.insert_song(song_name)
|
||||
result = set()
|
||||
|
||||
channel_amount = len(channels)
|
||||
for channeln, channel in enumerate(channels):
|
||||
|
@ -158,15 +173,9 @@ def _fingerprint_worker(filename, db, limit):
|
|||
print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
|
||||
filename))
|
||||
|
||||
print("Inserting fingerprints for channel %d/%d for %s" %
|
||||
(channeln + 1, channel_amount, filename))
|
||||
db.insert_hashes(sid, hashes)
|
||||
print("Finished inserting for channel %d/%d for %s" %
|
||||
(channeln + 1, channel_amount, filename))
|
||||
result |= set(hashes)
|
||||
|
||||
print("Marking %s finished" % (filename,))
|
||||
db.set_song_fingerprinted(sid)
|
||||
print("%s finished" % (filename,))
|
||||
return song_name, result
|
||||
|
||||
|
||||
def chunkify(lst, n):
|
||||
|
|
|
@ -45,4 +45,4 @@ def path_to_songname(path):
|
|||
Extracts song name from a filepath. Used to identify which songs
|
||||
have already been fingerprinted on disk.
|
||||
"""
|
||||
return os.path.basename(path).split(".")[0]
|
||||
return os.path.splitext(os.path.basename(path))[0]
|
||||
|
|
Loading…
Reference in a new issue