mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
Switched fingerprint_directory to using multiprocessing.Pool
Fixed an issue of 'grouper' items being generators due to ifilter usage. Temporary fix applied for the need of referencing SQLDatabase.FIELD_SONGNAME in __init__. Cleaned up some pep8 style issues.
This commit is contained in:
parent
e071804ea5
commit
7d14e0734a
2 changed files with 66 additions and 53 deletions
|
@ -1,9 +1,8 @@
|
|||
from dejavu.database import get_database
|
||||
import dejavu.decoder as decoder
|
||||
import fingerprint
|
||||
from multiprocessing import Process, cpu_count
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
|
||||
|
||||
class Dejavu(object):
|
||||
|
@ -30,57 +29,39 @@ class Dejavu(object):
|
|||
|
||||
def fingerprint_directory(self, path, extensions, nprocesses=None):
    """
    Fingerprint every file found under `path` using a pool of processes.

    Each file returned by `decoder.find_files(path, extensions)` is handed
    to the module-level `_fingerprint_worker` together with `self.db`.
    A failure in one file is printed to stdout and does not stop the
    remaining files from being processed.

    :param path: directory passed on to `decoder.find_files`.
    :param extensions: extensions passed on to `decoder.find_files`.
    :param nprocesses: number of worker processes; when falsy the CPU
        count is used, and values <= 0 are clamped to 1.
    """
    import sys
    import traceback

    # Try to use the maximum number of processes if not given.
    try:
        nprocesses = nprocesses or multiprocessing.cpu_count()
    except NotImplementedError:
        # cpu_count() is not available on every platform.
        nprocesses = 1
    else:
        nprocesses = 1 if nprocesses <= 0 else nprocesses

    pool = multiprocessing.Pool(nprocesses)
    try:
        # TODO: Don't queue up files that have already been fingerprinted.
        pending = [
            pool.apply_async(_fingerprint_worker, (filename, self.db))
            for filename, _ in decoder.find_files(path, extensions)
        ]

        # Poll with a timeout instead of blocking on one result, so
        # results are collected in whatever order the workers finish.
        while pending:
            # Iterate over a copy: finished results are removed in place.
            for result in pending[:]:
                # TODO: Handle errors gracefully and return them to the
                # caller in some way.
                try:
                    result.get(timeout=2)
                except multiprocessing.TimeoutError:
                    # Not finished yet; look at it again on the next pass.
                    continue
                except Exception:
                    # A worker failed: report it and drop the result, but
                    # let KeyboardInterrupt/SystemExit propagate.
                    traceback.print_exc(file=sys.stdout)
                pending.remove(result)
    except BaseException:
        # Aborted: stop the workers instead of waiting for queued files.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
|
||||
|
||||
def fingerprint_file(self, filepath, song_name=None):
|
||||
channels, Fs = decoder.read(filepath)
|
||||
|
@ -122,12 +103,14 @@ class Dejavu(object):
|
|||
largest_count = diff_counter[diff][sid]
|
||||
song_id = sid
|
||||
|
||||
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
|
||||
print("Diff is %d with %d offset-aligned matches" % (largest,
|
||||
largest_count))
|
||||
|
||||
# extract identification
|
||||
song = self.db.get_song_by_id(song_id)
|
||||
if song:
|
||||
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
||||
# TODO: Clarify what `get_song_by_id` should return.
|
||||
songname = song.get("song_name", None)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -145,6 +128,35 @@ class Dejavu(object):
|
|||
return r.recognize(*options, **kwoptions)
|
||||
|
||||
|
||||
def _fingerprint_worker(filename, db):
    """
    Fingerprint a single file and store the result through `db`.

    The song name is the file's base name without its extension. The song
    is inserted first, then every channel returned by `decoder.read` is
    fingerprinted and its hashes inserted, and finally the song is marked
    as fingerprinted.

    :param filename: path of the file to read with `decoder.read`.
    :param db: object providing `insert_song`, `insert_hashes` and
        `set_song_fingerprinted`.
    """
    base_name = os.path.basename(filename)
    song_name = os.path.splitext(base_name)[0]

    channels, Fs = decoder.read(filename)

    # Insert the song first so the hashes can be attached to its id.
    sid = db.insert_song(song_name)

    total = len(channels)
    for number, channel in enumerate(channels, 1):
        # TODO: Remove prints or change them into optional logging.
        progress = (number, total, filename)

        print("Fingerprinting channel %d/%d for %s" % progress)
        hashes = fingerprint.fingerprint(channel, Fs=Fs)
        print("Finished channel %d/%d for %s" % progress)

        print("Inserting fingerprints for channel %d/%d for %s" % progress)
        db.insert_hashes(sid, hashes)
        print("Finished inserting for channel %d/%d for %s" % progress)

    # Only mark the song as done once every channel has been stored.
    print("Marking %s finished" % (filename,))
    db.set_song_fingerprinted(sid)
    print("%s finished" % (filename,))
|
||||
|
||||
|
||||
def chunkify(lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import absolute_import
|
||||
from itertools import izip_longest, ifilter
|
||||
from itertools import izip_longest
|
||||
import Queue
|
||||
|
||||
import MySQLdb as mysql
|
||||
|
@ -312,7 +312,8 @@ class SQLDatabase(Database):
|
|||
|
||||
def grouper(iterable, n, fillvalue=None):
    """
    Split `iterable` into consecutive groups of at most `n` items.

    The last group is padded with `fillvalue` by `izip_longest`, and every
    falsy entry (which includes the default `None` padding) is then
    dropped from each group. Each group is yielded as a tuple, so it has
    a length and can be iterated more than once.

    :param iterable: items to split into groups.
    :param n: maximum number of items per group.
    :param fillvalue: padding value for the last group.
    :returns: a generator of tuples.
    """
    # The same iterator repeated n times makes izip_longest pull n
    # consecutive items for every group.
    args = [iter(iterable)] * n
    return (tuple(value for value in values if value)
            for values in izip_longest(fillvalue=fillvalue, *args))
|
||||
|
||||
|
||||
def cursor_factory(**factory_options):
|
||||
|
|
Loading…
Reference in a new issue