Switched fingerprint_directory to using multiprocessing.Pool

Fixed an issue of 'grouper' items being generators due to ifilter usage.
Temporary fix applied for the need of referencing SQLDatabase.FIELD_SONGNAME in __init__
Cleaned up some pep8 style issues
This commit is contained in:
Wessie 2013-12-23 14:59:08 +01:00
parent e071804ea5
commit 7d14e0734a
2 changed files with 66 additions and 53 deletions

View file

@ -1,9 +1,8 @@
from dejavu.database import get_database
import dejavu.decoder as decoder
import fingerprint
from multiprocessing import Process, cpu_count
import multiprocessing
import os
import random
class Dejavu(object):
@ -30,57 +29,39 @@ class Dejavu(object):
def fingerprint_directory(self, path, extensions, nprocesses=None):
# Try to use the maximum amount of processes if not given.
if nprocesses is None:
try:
nprocesses = cpu_count()
nprocesses = nprocesses or multiprocessing.cpu_count()
except NotImplementedError:
nprocesses = 1
else:
nprocesses = 1 if nprocesses <= 0 else nprocesses
# convert files, shuffle order
files = list(decoder.find_files(path, extensions))
random.shuffle(files)
pool = multiprocessing.Pool(nprocesses)
files_split = chunkify(files, nprocesses)
results = []
for filename, _ in decoder.find_files(path, extensions):
# TODO: Don't queue up files that have already been fingerprinted.
result = pool.apply_async(_fingerprint_worker,
(filename, self.db))
results.append(result)
# split into processes here
processes = []
for i in range(nprocesses):
# create process and start it
p = Process(target=self._fingerprint_worker,
args=(files_split[i], self.db))
p.start()
processes.append(p)
# wait for all processes to complete
for p in processes:
p.join()
def _fingerprint_worker(self, files, db):
for filename, extension in files:
# if there are already fingerprints in database,
# don't re-fingerprint
song_name = os.path.basename(filename).split(".")[0]
if song_name in self.songnames_set:
print("-> Already fingerprinted, continuing...")
while len(results):
for result in results[:]:
# TODO: Handle errors gracefully and return them to the callee
# in some way.
try:
result.get(timeout=2)
except multiprocessing.TimeoutError:
continue
except:
import traceback, sys
traceback.print_exc(file=sys.stdout)
results.remove(result)
else:
results.remove(result)
channels, Fs = decoder.read(filename)
# insert song name into database
song_id = db.insert_song(song_name)
for c in range(len(channels)):
channel = channels[c]
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
hashes = fingerprint.fingerprint(channel, Fs=Fs)
db.insert_hashes(song_id, hashes)
# only after done fingerprinting do confirm
db.set_song_fingerprinted(song_id)
pool.close()
pool.join()
def fingerprint_file(self, filepath, song_name=None):
channels, Fs = decoder.read(filepath)
@ -122,12 +103,14 @@ class Dejavu(object):
largest_count = diff_counter[diff][sid]
song_id = sid
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
print("Diff is %d with %d offset-aligned matches" % (largest,
largest_count))
# extract idenfication
song = self.db.get_song_by_id(song_id)
if song:
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
# TODO: Clarifey what `get_song_by_id` should return.
songname = song.get("song_name", None)
else:
return None
@ -145,6 +128,35 @@ class Dejavu(object):
return r.recognize(*options, **kwoptions)
def _fingerprint_worker(filename, db):
song_name, extension = os.path.splitext(os.path.basename(filename))
channels, Fs = decoder.read(filename)
# insert song into database
sid = db.insert_song(song_name)
channel_amount = len(channels)
for channeln, channel in enumerate(channels):
# TODO: Remove prints or change them into optional logging.
print("Fingerprinting channel %d/%d for %s" % (channeln + 1,
channel_amount,
filename))
hashes = fingerprint.fingerprint(channel, Fs=Fs)
print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
filename))
print("Inserting fingerprints for channel %d/%d for %s" %
(channeln + 1, channel_amount, filename))
db.insert_hashes(sid, hashes)
print("Finished inserting for channel %d/%d for %s" %
(channeln + 1, channel_amount, filename))
print("Marking %s finished" % (filename,))
db.set_song_fingerprinted(sid)
print("%s finished" % (filename,))
def chunkify(lst, n):
"""
Splits a list into roughly n equal parts.

View file

@ -1,5 +1,5 @@
from __future__ import absolute_import
from itertools import izip_longest, ifilter
from itertools import izip_longest
import Queue
import MySQLdb as mysql
@ -312,7 +312,8 @@ class SQLDatabase(Database):
def grouper(iterable, n, fillvalue=None):
args = [iter(iterable)] * n
return (ifilter(None, values) for values in izip_longest(fillvalue=fillvalue, *args))
return (filter(None, values) for values
in izip_longest(fillvalue=fillvalue, *args))
def cursor_factory(**factory_options):