Switched fingerprint_directory to using multiprocessing.Pool
Fixed an issue with 'grouper' items being generators due to ifilter usage.
Applied a temporary fix for the need to reference SQLDatabase.FIELD_SONGNAME
in __init__. Cleaned up some PEP 8 style issues.
parent e071804ea5
commit 7d14e0734a

2 changed files with 66 additions and 53 deletions
First changed file (the package `__init__`, defining `class Dejavu`):

@@ -1,9 +1,8 @@
 from dejavu.database import get_database
 import dejavu.decoder as decoder
 import fingerprint
-from multiprocessing import Process, cpu_count
+import multiprocessing
 import os
-import random


 class Dejavu(object):
@@ -30,57 +29,39 @@ class Dejavu(object):
     def fingerprint_directory(self, path, extensions, nprocesses=None):
         # Try to use the maximum amount of processes if not given.
-        if nprocesses is None:
-            try:
-                nprocesses = cpu_count()
-            except NotImplementedError:
-                nprocesses = 1
+        try:
+            nprocesses = nprocesses or multiprocessing.cpu_count()
+        except NotImplementedError:
+            nprocesses = 1
+        else:
+            nprocesses = 1 if nprocesses <= 0 else nprocesses

-        # convert files, shuffle order
-        files = list(decoder.find_files(path, extensions))
-        random.shuffle(files)
+        pool = multiprocessing.Pool(nprocesses)

-        files_split = chunkify(files, nprocesses)
+        results = []
+        for filename, _ in decoder.find_files(path, extensions):
+            # TODO: Don't queue up files that have already been fingerprinted.
+            result = pool.apply_async(_fingerprint_worker,
+                                      (filename, self.db))
+            results.append(result)

-        # split into processes here
-        processes = []
-        for i in range(nprocesses):
-
-            # create process and start it
-            p = Process(target=self._fingerprint_worker,
-                        args=(files_split[i], self.db))
-            p.start()
-            processes.append(p)
-
-        # wait for all processes to complete
-        for p in processes:
-            p.join()
-
-    def _fingerprint_worker(self, files, db):
-        for filename, extension in files:
-
-            # if there are already fingerprints in database,
-            # don't re-fingerprint
-            song_name = os.path.basename(filename).split(".")[0]
-            if song_name in self.songnames_set:
-                print("-> Already fingerprinted, continuing...")
-                continue
-
-            channels, Fs = decoder.read(filename)
-
-            # insert song name into database
-            song_id = db.insert_song(song_name)
-
-            for c in range(len(channels)):
-                channel = channels[c]
-                print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
-                hashes = fingerprint.fingerprint(channel, Fs=Fs)
-                db.insert_hashes(song_id, hashes)
-
-            # only after done fingerprinting do confirm
-            db.set_song_fingerprinted(song_id)
+        while len(results):
+            for result in results[:]:
+                # TODO: Handle errors gracefully and return them to the callee
+                # in some way.
+                try:
+                    result.get(timeout=2)
+                except multiprocessing.TimeoutError:
+                    continue
+                except:
+                    import traceback, sys
+                    traceback.print_exc(file=sys.stdout)
+                    results.remove(result)
+                else:
+                    results.remove(result)
+
+        pool.close()
+        pool.join()

     def fingerprint_file(self, filepath, song_name=None):
         channels, Fs = decoder.read(filepath)
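For readers unfamiliar with the pattern the new code adopts, here is a
minimal, self-contained sketch of the same idea: queue jobs with
`Pool.apply_async`, then poll the returned `AsyncResult` handles with a
short `get` timeout so one slow job cannot block collection of the others.
The `slow_square` function is a hypothetical stand-in for the real
fingerprinting worker.

    import multiprocessing


    def slow_square(n):
        # Stand-in for the real fingerprinting worker.
        return n * n


    if __name__ == '__main__':
        pool = multiprocessing.Pool(2)

        # Queue every job up front; apply_async returns immediately
        # with an AsyncResult handle.
        results = [pool.apply_async(slow_square, (n,)) for n in range(10)]

        # Poll the handles, removing each one as it completes. Iterating
        # over a copy (results[:]) makes removal during iteration safe.
        while results:
            for result in results[:]:
                try:
                    print(result.get(timeout=2))
                except multiprocessing.TimeoutError:
                    continue
                results.remove(result)

        pool.close()
        pool.join()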
@@ -122,12 +103,14 @@ class Dejavu(object):
                 largest_count = diff_counter[diff][sid]
                 song_id = sid

-        print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
+        print("Diff is %d with %d offset-aligned matches" % (largest,
+                                                             largest_count))

         # extract identification
         song = self.db.get_song_by_id(song_id)
         if song:
-            songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
+            # TODO: Clarify what `get_song_by_id` should return.
+            songname = song.get("song_name", None)
         else:
             return None
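The hard-coded "song_name" key is the temporary fix mentioned in the commit
message: the old line reached into the SQL-specific
`SQLDatabase.FIELD_SONGNAME` constant from `__init__`, coupling the front end
to one database backend. A sketch of one way the TODO might eventually be
resolved, with illustrative names only: hoist the field constant into a
backend-agnostic `Database` base class.

    class Database(object):
        # Illustrative: a backend-agnostic home for the field name.
        FIELD_SONGNAME = "song_name"


    class SQLDatabase(Database):
        pass  # inherits FIELD_SONGNAME


    song = {"song_name": "example-track"}
    songname = song.get(Database.FIELD_SONGNAME, None)
    print(songname)  # example-track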
@@ -145,6 +128,35 @@ class Dejavu(object):
         return r.recognize(*options, **kwoptions)


+def _fingerprint_worker(filename, db):
+    song_name, extension = os.path.splitext(os.path.basename(filename))
+
+    channels, Fs = decoder.read(filename)
+
+    # insert song into database
+    sid = db.insert_song(song_name)
+
+    channel_amount = len(channels)
+    for channeln, channel in enumerate(channels):
+        # TODO: Remove prints or change them into optional logging.
+        print("Fingerprinting channel %d/%d for %s" % (channeln + 1,
+                                                       channel_amount,
+                                                       filename))
+        hashes = fingerprint.fingerprint(channel, Fs=Fs)
+        print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
+                                                 filename))
+
+        print("Inserting fingerprints for channel %d/%d for %s" %
+              (channeln + 1, channel_amount, filename))
+        db.insert_hashes(sid, hashes)
+        print("Finished inserting for channel %d/%d for %s" %
+              (channeln + 1, channel_amount, filename))
+
+    print("Marking %s finished" % (filename,))
+    db.set_song_fingerprinted(sid)
+    print("%s finished" % (filename,))
+
+
 def chunkify(lst, n):
     """
     Splits a list into roughly n equal parts.
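Two details of the new worker are easy to miss. First, it is a module-level
function rather than a method: `multiprocessing.Pool` must pickle the
callable it sends to worker processes, and bound methods are not picklable
on Python 2 (which this codebase targets, judging by the `print` statement
and the `Queue`/`izip_longest` imports). Second, deriving the song name with
`os.path.splitext` fixes a subtle bug in the old `split(".")[0]` approach,
which truncated any filename containing extra dots:

    import os.path

    filename = "01. Some Song.mp3"

    # Old approach: cuts at the first dot, losing most of the name.
    print(os.path.basename(filename).split(".")[0])
    # -> "01"

    # New approach: strips only the final extension.
    print(os.path.splitext(os.path.basename(filename))[0])
    # -> "01. Some Song"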
Second changed file (the `SQLDatabase` module):

@@ -1,5 +1,5 @@
 from __future__ import absolute_import
-from itertools import izip_longest, ifilter
+from itertools import izip_longest
 import Queue

 import MySQLdb as mysql
@@ -312,7 +312,8 @@ class SQLDatabase(Database):

 def grouper(iterable, n, fillvalue=None):
     args = [iter(iterable)] * n
-    return (ifilter(None, values) for values in izip_longest(fillvalue=fillvalue, *args))
+    return (filter(None, values) for values
+            in izip_longest(fillvalue=fillvalue, *args))


 def cursor_factory(**factory_options):
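This `grouper` change is the generator fix called out in the commit message.
On Python 2, `itertools.ifilter` returns a lazy iterator, so each group that
`grouper` yielded could only be consumed once; the built-in `filter` returns
a plain, reusable list. A minimal Python 2 demonstration of the difference
(`ifilter` no longer exists in Python 3):

    from itertools import ifilter  # Python 2 only

    lazy = ifilter(None, [1, 0, 2])
    print(list(lazy))   # [1, 2]
    print(list(lazy))   # []  -- exhausted after a single pass

    eager = filter(None, [1, 0, 2])  # a plain list on Python 2
    print(eager)        # [1, 2]
    print(eager)        # [1, 2] -- safe to consume repeatedly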