Fixed various small things that weren't caught before.

- Fixed the SQL queries for table creation
- Table creation is now done in reverse order to accommodate the foreign key
- Fixed a typo in the BaseRecognizer that caused it to not work
- Changed configuration passed to Dejavu into a (nested) dictionary
This commit is contained in:
Wessie 2013-12-18 00:31:57 +01:00
parent 1c9eddc3a2
commit 3b72768f94
3 changed files with 59 additions and 74 deletions

View file

@ -8,20 +8,14 @@ import random
DEBUG = False DEBUG = False
class Dejavu(): class Dejavu(object):
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
# initialize db # initialize db
database = SQLDatabase( self.db = SQLDatabase(**config.get("database", {}))
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
self.db = database
# create components # create components
self.converter = Converter() self.converter = Converter()
#self.fingerprinter = Fingerprinter(self.config) #self.fingerprinter = Fingerprinter(self.config)
@ -30,16 +24,16 @@ class Dejavu():
# get songs previously indexed # get songs previously indexed
self.songs = self.db.get_songs() self.songs = self.db.get_songs()
self.songnames_set = set() # to know which ones we've computed before self.songnames_set = set() # to know which ones we've computed before
if self.songs:
for song in self.songs: for song in self.songs:
song_id = song[SQLDatabase.FIELD_SONG_ID] song_name = song[self.db.FIELD_SONGNAME]
song_name = song[SQLDatabase.FIELD_SONGNAME]
self.songnames_set.add(song_name) self.songnames_set.add(song_name)
print "Added: %s to the set of fingerprinted songs..." % song_name print "Added: %s to the set of fingerprinted songs..." % song_name
def chunkify(self, lst, n):
    """
    Split `lst` into `n` interleaved parts of roughly equal size.

    Part `i` holds the elements at indices i, i + n, i + 2n, ..., so the
    parts differ in length by at most one element.  Some parts are empty
    when `lst` has fewer than `n` elements.

    http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
    """
    # range() yields the same indices as xrange() and also exists on Python 3.
    return [lst[i::n] for i in range(n)]
@ -55,25 +49,19 @@ class Dejavu():
processes = [] processes = []
for i in range(nprocesses): for i in range(nprocesses):
# need database instance since mysql connections shouldn't be shared across processes
sql_connection = SQLDatabase(
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
# create process and start it # create process and start it
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output)) p = Process(target=self.fingerprint_worker,
args=(files_split[i], self.db, output))
p.start() p.start()
processes.append(p) processes.append(p)
# wait for all processes to complete # wait for all processes to complete
for p in processes: for p in processes:
p.join() p.join()
# delete orphans # delete orphans
# print "Done fingerprinting. Deleting orphaned fingerprints..." # print "Done fingerprinting. Deleting orphaned fingerprints..."
# TODO: need a more performant query in database.py for the # TODO: need a more performant query in database.py for the
#self.fingerprinter.db.delete_orphans() #self.fingerprinter.db.delete_orphans()
def fingerprint_worker(self, files, sql_connection, output): def fingerprint_worker(self, files, sql_connection, output):
@ -82,7 +70,7 @@ class Dejavu():
# if there are already fingerprints in database, don't re-fingerprint or convert # if there are already fingerprints in database, don't re-fingerprint or convert
song_name = os.path.basename(filename).split(".")[0] song_name = os.path.basename(filename).split(".")[0]
if DEBUG and song_name in self.songnames_set: if DEBUG and song_name in self.songnames_set:
print("-> Already fingerprinted, continuing...") print("-> Already fingerprinted, continuing...")
continue continue
@ -117,27 +105,27 @@ class Dejavu():
for channel in range(nchannels): for channel in range(nchannels):
channels.append(frames[:, channel]) channels.append(frames[:, channel])
return (channels, Fs) return (channels, Fs)
def fingerprint_file(self, filepath, song_name=None):
    """
    Fingerprint one audio file and store its hashes in the database.

    filepath  -- path of the audio file to read.
    song_name -- name to store the song under; when empty or None it is
                 derived from the file name (the text before the first dot).
    """
    # TODO: replace with something that handles all audio formats
    channels, Fs = self.extract_channels(filepath)

    if not song_name:
        song_name = os.path.basename(filepath).split(".")[0]
    song_id = self.db.insert_song(song_name)

    # each channel is fingerprinted separately but stored under the same song
    for data in channels:
        hashes = fingerprint.fingerprint(data, Fs=Fs)
        self.db.insert_hashes(song_id, hashes)
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
    """
    Fingerprint `samples` (sampled at rate `Fs`) and return what the
    database reports as matches for the resulting hashes.
    """
    return self.db.return_matches(fingerprint.fingerprint(samples, Fs=Fs))
def align_matches(self, matches): def align_matches(self, matches):
""" """
Finds hash matches that align in time with other matches and finds Finds hash matches that align in time with other matches and finds
consensus about which hashes are "true" signal from the audio. consensus about which hashes are "true" signal from the audio.
Returns a dictionary with match information. Returns a dictionary with match information.
""" """
# align by diffs # align by diffs
@ -158,24 +146,24 @@ class Dejavu():
largest_count = diff_counter[diff][sid] largest_count = diff_counter[diff][sid]
song_id = sid song_id = sid
if DEBUG: if DEBUG:
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count)) print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
# extract idenfication # extract idenfication
song = self.db.get_song_by_id(song_id) song = self.db.get_song_by_id(song_id)
if song: if song:
songname = song.get(SQLDatabase.FIELD_SONGNAME, None) songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
else: else:
return None return None
if DEBUG: if DEBUG:
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed)) print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
# return match info # return match info
song = { song = {
"song_id" : song_id, "song_id" : song_id,
"song_name" : songname, "song_name" : songname,
"confidence" : largest_count "confidence" : largest_count
} }
return song return song

View file

@ -70,7 +70,7 @@ class SQLDatabase(Database):
`%s` binary(10) not null, `%s` binary(10) not null,
`%s` mediumint unsigned not null, `%s` mediumint unsigned not null,
`%s` int unsigned not null, `%s` int unsigned not null,
INDEX(%s), PRIMARY KEY(%s),
UNIQUE(%s, %s, %s), UNIQUE(%s, %s, %s),
FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
) ENGINE=INNODB;""" % ( ) ENGINE=INNODB;""" % (
@ -157,8 +157,8 @@ class SQLDatabase(Database):
fingerprints associated with them. fingerprints associated with them.
""" """
with self.cursor() as cur: with self.cursor() as cur:
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
cur.execute(self.CREATE_SONGS_TABLE) cur.execute(self.CREATE_SONGS_TABLE)
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
cur.execute(self.DELETE_UNFINGERPRINTED) cur.execute(self.DELETE_UNFINGERPRINTED)
def empty(self): def empty(self):

View file

@ -1,6 +1,6 @@
from multiprocessing import Queue, Process from multiprocessing import Queue, Process
from dejavu.database import SQLDatabase from dejavu.database import SQLDatabase
import dejavu.fingerprint import dejavu.fingerprint as fingerprint
from dejavu import Dejavu from dejavu import Dejavu
from scipy.io import wavfile from scipy.io import wavfile
import wave import wave
@ -12,60 +12,57 @@ import array
class BaseRecognizer(object):
    """
    Common base for recognizers: holds the Dejavu instance and the sample
    rate, and turns one or more channels of samples into a single match.
    """

    def __init__(self, dejavu):
        self.dejavu = dejavu
        # DEFAULT_FS lives on the fingerprint module, not on the Dejavu instance
        self.Fs = fingerprint.DEFAULT_FS

    def _recognize(self, *data):
        """
        Collect the matches of every channel in `data` and return the
        result of aligning them all together.
        """
        matches = []
        for d in data:
            # look up one channel at a time; passing the whole `data` tuple
            # here would fingerprint the tuple of channels instead
            matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
        return self.dejavu.align_matches(matches)

    def recognize(self):
        pass  # base class does nothing
class WaveFileRecognizer(BaseRecognizer):
    """Recognizer that identifies the audio stored in a WAV file."""

    def __init__(self, dejavu, filename=None):
        super(WaveFileRecognizer, self).__init__(dejavu)
        self.filename = filename

    def recognize_file(self, filename):
        """
        Read `filename`, recognize its channels and return the match.

        The sample rate of the file replaces `self.Fs`.  When a match is
        found, the time spent matching (in seconds) is stored under
        'match_time' in the returned value.
        """
        Fs, frames = wavfile.read(filename)
        self.Fs = Fs

        # scipy returns a 1-D array for mono files and a 2-D
        # (samples, channels) array otherwise, so the channel count comes
        # from the array itself; no second file handle has to be opened
        # (and closed) just to read the header.
        if frames.ndim == 1:
            channels = [frames]
        else:
            channels = [frames[:, c] for c in range(frames.shape[1])]

        t = time.time()
        match = self._recognize(*channels)
        t = time.time() - t

        if match:
            match['match_time'] = t

        return match

    def recognize(self):
        """Recognize the file given to the constructor."""
        return self.recognize_file(self.filename)
class MicrophoneRecognizer(BaseRecognizer): class MicrophoneRecognizer(BaseRecognizer):
CHUNK = 8192 # 44100 is a multiple of 1225 CHUNK = 8192 # 44100 is a multiple of 1225
FORMAT = pyaudio.paInt16 FORMAT = pyaudio.paInt16
CHANNELS = 2 CHANNELS = 2
RATE = 44100 RATE = 44100
def __init__(self, dejavu, seconds=None): def __init__(self, dejavu, seconds=None):
super(MicrophoneRecognizer, self).__init__(dejavu) super(MicrophoneRecognizer, self).__init__(dejavu)
self.audio = pyaudio.PyAudio() self.audio = pyaudio.PyAudio()
@ -75,52 +72,52 @@ class MicrophoneRecognizer(BaseRecognizer):
self.chunk_size = CHUNK self.chunk_size = CHUNK
self.rate = RATE self.rate = RATE
self.recorded = False self.recorded = False
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
    """
    Open an input stream on `self.audio` and reset the recording state.

    channels -- number of channels to record.
    rate     -- sample rate passed to the stream.
    chunk    -- frames per buffer; also the size read per
                process_recording() call.

    A stream that is still open is stopped and closed first.
    """
    self.chunk_size = chunk
    self.channels = channels
    self.recorded = False
    self.rate = rate

    if self.stream:
        self.stream.stop_stream()
        self.stream.close()

    # FORMAT is a class attribute: it has to be reached through self, a
    # bare name is not visible from inside a method body.
    self.stream = self.audio.open(format=self.FORMAT,
                                  channels=channels,
                                  rate=rate,
                                  input=True,
                                  frames_per_buffer=chunk)

    # one sample list per channel
    self.data = [[] for i in range(channels)]
def process_recording(self):
    """
    Read one chunk from the open stream and append its samples to the
    per-channel lists in `self.data`.
    """
    data = self.stream.read(self.chunk_size)
    nums = np.frombuffer(data, np.int16)
    # Samples are treated as interleaved: channel c owns indices
    # c, c + channels, c + 2 * channels, ...  The stride must therefore be
    # the channel count; a stride of c + 1 would hand channel 0 every sample.
    for c in range(self.channels):
        self.data[c].extend(nums[c::self.channels])
def stop_recording(self):
    """Stop and close the open stream, then flag the recording as complete."""
    stream = self.stream
    stream.stop_stream()
    stream.close()
    self.stream, self.recorded = None, True
def recognize_recording(self):
    """
    Recognize the recorded channels and return the match.

    Raises NoRecordingError when no recording has been completed.
    """
    if self.recorded:
        return self._recognize(*self.data)
    raise NoRecordingError("Recording was not complete/begun")
def get_recorded_time(self):
    """Return the recorded duration in seconds, based on the first channel."""
    # float() keeps the fractional part; on Python 2 an int / int division
    # would truncate the duration to whole seconds.
    return len(self.data[0]) / float(self.rate)
def recognize(self):
    """
    Record for `self.seconds` seconds, then recognize what was recorded.

    NOTE(review): `self.seconds` is presumably set by the constructor from
    its `seconds` argument; that part of the class is not visible here.
    """
    self.start_recording()
    # The chunk size is stored as `chunk_size`; there is no `chunk`
    # attribute.  Dividing by a float avoids Python 2 flooring
    # rate / chunk_size before it is multiplied by the duration.
    nchunks = int(self.rate * self.seconds / float(self.chunk_size))
    for i in range(nchunks):
        self.process_recording()
    self.stop_recording()
    return self.recognize_recording()
class NoRecordingError(Exception):
    """Raised when recognition is requested before a recording is complete."""