mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
Rewrote multiprocessing parts to only require one database connection
List of changes in detail: - fingerprint_directory: Now uses mp.Pool.imap_unordered and updates the database in the caller's process. - fingerprint_file: Now uses _fingerprint_worker internally. - _fingerprint_worker: Some slight changes to argument handling due to `imap` usage. Changed to return (song_name, hashes), where hashes is a set of hashes from all channels. - path_to_songname: Changed to use `os.path.splitext`; the previous `split(".")[0]` caused files with a period in their name to return the wrong name.
This commit is contained in:
parent
cd7adc1485
commit
64a6161b90
2 changed files with 52 additions and 43 deletions
|
@ -16,8 +16,8 @@ class Dejavu(object):
|
||||||
|
|
||||||
self.db = db_cls(**config.get("database", {}))
|
self.db = db_cls(**config.get("database", {}))
|
||||||
self.db.setup()
|
self.db.setup()
|
||||||
|
|
||||||
# if we should limit seconds fingerprinted,
|
# if we should limit seconds fingerprinted,
|
||||||
# None|-1 means use entire track
|
# None|-1 means use entire track
|
||||||
self.limit = self.config.get("fingerprint_limit", None)
|
self.limit = self.config.get("fingerprint_limit", None)
|
||||||
if self.limit == -1: # for JSON compatibility
|
if self.limit == -1: # for JSON compatibility
|
||||||
|
@ -43,7 +43,7 @@ class Dejavu(object):
|
||||||
|
|
||||||
pool = multiprocessing.Pool(nprocesses)
|
pool = multiprocessing.Pool(nprocesses)
|
||||||
|
|
||||||
results = []
|
filenames_to_fingerprint = []
|
||||||
for filename, _ in decoder.find_files(path, extensions):
|
for filename, _ in decoder.find_files(path, extensions):
|
||||||
|
|
||||||
# don't refingerprint already fingerprinted files
|
# don't refingerprint already fingerprinted files
|
||||||
|
@ -51,39 +51,46 @@ class Dejavu(object):
|
||||||
print "%s already fingerprinted, continuing..." % filename
|
print "%s already fingerprinted, continuing..." % filename
|
||||||
continue
|
continue
|
||||||
|
|
||||||
result = pool.apply_async(_fingerprint_worker,
|
filenames_to_fingerprint.append(filename)
|
||||||
(filename, self.db, self.limit))
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
while len(results):
|
# Prepare _fingerprint_worker input
|
||||||
for result in results[:]:
|
worker_input = zip(filenames_to_fingerprint,
|
||||||
# TODO: Handle errors gracefully and return them to the callee
|
[self.limit] * len(filenames_to_fingerprint))
|
||||||
# in some way.
|
|
||||||
try:
|
# Send off our tasks
|
||||||
result.get(timeout=2)
|
iterator = pool.imap_unordered(_fingerprint_worker,
|
||||||
except multiprocessing.TimeoutError:
|
worker_input)
|
||||||
continue
|
|
||||||
except:
|
# Loop till we have all of them
|
||||||
import traceback, sys
|
while True:
|
||||||
traceback.print_exc(file=sys.stdout)
|
try:
|
||||||
results.remove(result)
|
song_name, hashes = iterator.next()
|
||||||
else:
|
except multiprocessing.TimeoutError:
|
||||||
results.remove(result)
|
continue
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
print("Failed fingerprinting")
|
||||||
|
|
||||||
|
# Print traceback because we can't reraise it here
|
||||||
|
import traceback, sys
|
||||||
|
traceback.print_exc(file=sys.stdout)
|
||||||
|
else:
|
||||||
|
sid = self.db.insert_song(song_name)
|
||||||
|
|
||||||
|
self.db.insert_hashes(sid, hashes)
|
||||||
|
|
||||||
pool.close()
|
pool.close()
|
||||||
pool.join()
|
pool.join()
|
||||||
|
|
||||||
def fingerprint_file(self, filepath, song_name=None):
|
def fingerprint_file(self, filepath, song_name=None):
|
||||||
channels, Fs = decoder.read(filepath)
|
song_name, hashes = _fingerprint_worker(filepath,
|
||||||
|
self.limit,
|
||||||
|
song_name=song_name)
|
||||||
|
|
||||||
if not song_name:
|
sid = self.db.insert_song(song_name)
|
||||||
print "Song name: %s" % song_name
|
|
||||||
song_name = decoder.path_to_songname(filepath)
|
|
||||||
song_id = self.db.insert_song(song_name)
|
|
||||||
|
|
||||||
for data in channels:
|
self.db.insert_hashes(sid, hashes)
|
||||||
hashes = fingerprint.fingerprint(data, Fs=Fs)
|
|
||||||
self.db.insert_hashes(song_id, hashes)
|
|
||||||
|
|
||||||
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
||||||
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
||||||
|
@ -130,7 +137,7 @@ class Dejavu(object):
|
||||||
"song_id": song_id,
|
"song_id": song_id,
|
||||||
"song_name": songname,
|
"song_name": songname,
|
||||||
"confidence": largest_count,
|
"confidence": largest_count,
|
||||||
"offset" : largest
|
"offset": largest
|
||||||
}
|
}
|
||||||
|
|
||||||
return song
|
return song
|
||||||
|
@ -140,13 +147,21 @@ class Dejavu(object):
|
||||||
return r.recognize(*options, **kwoptions)
|
return r.recognize(*options, **kwoptions)
|
||||||
|
|
||||||
|
|
||||||
def _fingerprint_worker(filename, db, limit):
|
def _fingerprint_worker(filename, limit=None, song_name=None):
|
||||||
song_name, extension = os.path.splitext(os.path.basename(filename))
|
# Pool.imap sends arguments as tuples so we have to unpack
|
||||||
|
# them ourselves.
|
||||||
|
try:
|
||||||
|
filename, limit = filename
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
songname, extension = os.path.splitext(os.path.basename(filename))
|
||||||
|
|
||||||
|
song_name = song_name or songname
|
||||||
|
|
||||||
channels, Fs = decoder.read(filename, limit)
|
channels, Fs = decoder.read(filename, limit)
|
||||||
|
|
||||||
# insert song into database
|
result = set()
|
||||||
sid = db.insert_song(song_name)
|
|
||||||
|
|
||||||
channel_amount = len(channels)
|
channel_amount = len(channels)
|
||||||
for channeln, channel in enumerate(channels):
|
for channeln, channel in enumerate(channels):
|
||||||
|
@ -158,15 +173,9 @@ def _fingerprint_worker(filename, db, limit):
|
||||||
print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
|
print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
|
||||||
filename))
|
filename))
|
||||||
|
|
||||||
print("Inserting fingerprints for channel %d/%d for %s" %
|
result |= set(hashes)
|
||||||
(channeln + 1, channel_amount, filename))
|
|
||||||
db.insert_hashes(sid, hashes)
|
|
||||||
print("Finished inserting for channel %d/%d for %s" %
|
|
||||||
(channeln + 1, channel_amount, filename))
|
|
||||||
|
|
||||||
print("Marking %s finished" % (filename,))
|
return song_name, result
|
||||||
db.set_song_fingerprinted(sid)
|
|
||||||
print("%s finished" % (filename,))
|
|
||||||
|
|
||||||
|
|
||||||
def chunkify(lst, n):
|
def chunkify(lst, n):
|
||||||
|
|
|
@ -43,6 +43,6 @@ def read(filename, limit=None):
|
||||||
def path_to_songname(path):
|
def path_to_songname(path):
|
||||||
"""
|
"""
|
||||||
Extracts song name from a filepath. Used to identify which songs
|
Extracts song name from a filepath. Used to identify which songs
|
||||||
have already been fingerprinted on disk.
|
have already been fingerprinted on disk.
|
||||||
"""
|
"""
|
||||||
return os.path.basename(path).split(".")[0]
|
return os.path.splitext(os.path.basename(path))[0]
|
||||||
|
|
Loading…
Reference in a new issue