Moved main Dejavu class from dejavu.control to dejavu

This commit is contained in:
Vin 2013-12-16 23:38:58 +00:00
parent 6a6ae94e3d
commit a4ed612658
4 changed files with 178 additions and 165 deletions

176
dejavu/__init__.py Normal file → Executable file
View file

@ -0,0 +1,176 @@
from dejavu.database import SQLDatabase
from dejavu.convert import Converter
import dejavu.fingerprint as fingerprint
from scipy.io import wavfile
from multiprocessing import Process
import wave, os
import random
DEBUG = False
class Dejavu():
def __init__(self, config):
self.config = config
# initialize db
database = SQLDatabase(
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
self.db = database
# create components
self.converter = Converter()
#self.fingerprinter = Fingerprinter(self.config)
self.db.setup()
# get songs previously indexed
self.songs = self.db.get_songs()
self.songnames_set = set() # to know which ones we've computed before
if self.songs:
for song in self.songs:
song_id = song[SQLDatabase.FIELD_SONG_ID]
song_name = song[SQLDatabase.FIELD_SONGNAME]
self.songnames_set.add(song_name)
print "Added: %s to the set of fingerprinted songs..." % song_name
def chunkify(self, lst, n):
"""
Splits a list into roughly n equal parts.
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
"""
return [lst[i::n] for i in xrange(n)]
def fingerprint(self, path, output, extensions, nprocesses):
# convert files, shuffle order
files = self.converter.find_files(path, extensions)
random.shuffle(files)
files_split = self.chunkify(files, nprocesses)
# split into processes here
processes = []
for i in range(nprocesses):
# need database instance since mysql connections shouldn't be shared across processes
sql_connection = SQLDatabase(
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
# create process and start it
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
p.start()
processes.append(p)
# wait for all processes to complete
for p in processes:
p.join()
# delete orphans
# print "Done fingerprinting. Deleting orphaned fingerprints..."
# TODO: need a more performant query in database.py for the
#self.fingerprinter.db.delete_orphans()
def fingerprint_worker(self, files, sql_connection, output):
for filename, extension in files:
# if there are already fingerprints in database, don't re-fingerprint or convert
song_name = os.path.basename(filename).split(".")[0]
if DEBUG and song_name in self.songnames_set:
print("-> Already fingerprinted, continuing...")
continue
# convert to WAV
wavout_path = self.converter.convert(filename, extension, Converter.WAV, output, song_name)
# insert song name into database
song_id = sql_connection.insert_song(song_name)
# for each channel perform FFT analysis and fingerprinting
channels, Fs = self.extract_channels(wavout_path)
for c in range(len(channels)):
channel = channels[c]
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
hashes = fingerprint.fingerprint(channel, Fs=Fs)
sql_connection.insert_hashes(song_id, hashes)
# only after done fingerprinting do confirm
sql_connection.set_song_fingerprinted(song_id)
def extract_channels(self, path):
"""
Reads channels from disk.
Returns a tuple with (channels, sample_rate)
"""
channels = []
Fs, frames = wavfile.read(path)
wave_object = wave.open(path)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
#assert Fs == self.fingerprinter.Fs
for channel in range(nchannels):
channels.append(frames[:, channel])
return (channels, Fs)
def match(self, samples, Fs=fingerprint.DEFAULT_FS):
hashes = fingerprint.fingerprint(samples, Fs=Fs)
return self.db.return_matches(hashes)
def align_matches(self, matches, starttime, record_seconds=None):
"""
Finds hash matches that align in time with other matches and finds
consensus about which hashes are "true" signal from the audio.
Returns a dictionary with match information.
"""
# align by diffs
diff_counter = {}
largest = 0
largest_count = 0
song_id = -1
for tup in matches:
sid, diff = tup
if not diff in diff_counter:
diff_counter[diff] = {}
if not sid in diff_counter[diff]:
diff_counter[diff][sid] = 0
diff_counter[diff][sid] += 1
if diff_counter[diff][sid] > largest_count:
largest = diff
largest_count = diff_counter[diff][sid]
song_id = sid
if DEBUG:
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
# extract idenfication
song = self.db.get_song_by_id(song_id)
if song:
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
else:
return None
songname = songname.replace("_", " ")
elapsed = time.time() - starttime
if DEBUG:
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
# return match info
song = {
"song_id" : song_id,
"song_name" : songname,
"match_time" : elapsed,
"confidence" : largest_count
}
if record_seconds:
song['record_time'] = record_seconds
return song

View file

@ -1,107 +0,0 @@
from dejavu.database import SQLDatabase
from dejavu.convert import Converter
from dejavu.fingerprint import Fingerprinter
from scipy.io import wavfile
from multiprocessing import Process
import wave, os
import random
class Dejavu():
def __init__(self, config):
self.config = config
# create components
self.converter = Converter()
self.fingerprinter = Fingerprinter(self.config)
self.fingerprinter.db.setup()
# get songs previously indexed
self.songs = self.fingerprinter.db.get_songs()
self.songnames_set = set() # to know which ones we've computed before
if self.songs:
for song in self.songs:
song_id = song[SQLDatabase.FIELD_SONG_ID]
song_name = song[SQLDatabase.FIELD_SONGNAME]
self.songnames_set.add(song_name)
print "Added: %s to the set of fingerprinted songs..." % song_name
def chunkify(self, lst, n):
"""
Splits a list into roughly n equal parts.
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
"""
return [lst[i::n] for i in xrange(n)]
def fingerprint(self, path, output, extensions, nprocesses):
# convert files, shuffle order
files = self.converter.find_files(path, extensions)
random.shuffle(files)
files_split = self.chunkify(files, nprocesses)
# split into processes here
processes = []
for i in range(nprocesses):
# need database instance since mysql connections shouldn't be shared across processes
sql_connection = SQLDatabase(
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
# create process and start it
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
p.start()
processes.append(p)
# wait for all processes to complete
for p in processes:
p.join()
# delete orphans
# print "Done fingerprinting. Deleting orphaned fingerprints..."
# TODO: need a more performant query in database.py for the
#self.fingerprinter.db.delete_orphans()
def fingerprint_worker(self, files, sql_connection, output):
for filename, extension in files:
# if there are already fingerprints in database, don't re-fingerprint or convert
song_name = os.path.basename(filename).split(".")[0]
if song_name in self.songnames_set:
print "-> Already fingerprinted, continuing..."
continue
# convert to WAV
wavout_path = self.converter.convert(filename, extension, Converter.WAV, output, song_name)
# insert song name into database
song_id = sql_connection.insert_song(song_name)
# for each channel perform FFT analysis and fingerprinting
channels = self.extract_channels(wavout_path)
for c in range(len(channels)):
channel = channels[c]
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
self.fingerprinter.fingerprint(channel, wavout_path, song_id, c+1)
# only after done fingerprinting do confirm
sql_connection.set_song_fingerprinted(song_id)
def extract_channels(self, path):
"""
Reads channels from disk.
"""
channels = []
Fs, frames = wavfile.read(path)
wave_object = wave.open(path)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
assert Fs == self.fingerprinter.Fs
for channel in range(nchannels):
channels.append(frames[:, channel])
return channels

View file

@ -133,12 +133,7 @@ class Fingerprinter():
amp_min=DEFAULT_AMP_MIN): amp_min=DEFAULT_AMP_MIN):
self.config = config self.config = config
database = SQLDatabase(
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
self.db = database
self.Fs = Fs self.Fs = Fs
self.dt = 1.0 / self.Fs self.dt = 1.0 / self.Fs
@ -183,55 +178,4 @@ class Fingerprinter():
self.song_names = wpaths self.song_names = wpaths
# TODO: put this in another module # TODO: put this in another module
def align_matches(self, matches, starttime, record_seconds=0, verbose=False):
"""
Finds hash matches that align in time with other matches and finds
consensus about which hashes are "true" signal from the audio.
Returns a dictionary with match information.
"""
# align by diffs
diff_counter = {}
largest = 0
largest_count = 0
song_id = -1
for tup in matches:
sid, diff = tup
if not diff in diff_counter:
diff_counter[diff] = {}
if not sid in diff_counter[diff]:
diff_counter[diff][sid] = 0
diff_counter[diff][sid] += 1
if diff_counter[diff][sid] > largest_count:
largest = diff
largest_count = diff_counter[diff][sid]
song_id = sid
if verbose:
print "Diff is %d with %d offset-aligned matches" % (largest, largest_count)
# extract idenfication
song = self.db.get_song_by_id(song_id)
if song:
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
else:
return None
songname = songname.replace("_", " ")
elapsed = time.time() - starttime
if verbose:
print "Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed)
# return match info
song = {
"song_id" : song_id,
"song_name" : songname,
"match_time" : elapsed,
"confidence" : largest_count
}
if record_seconds:
song['record_time'] = record_seconds
return song

2
go.py Normal file → Executable file
View file

@ -1,4 +1,4 @@
from dejavu.control import Dejavu from dejavu import Dejavu
from ConfigParser import ConfigParser from ConfigParser import ConfigParser
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")