From 3b72768f942b8cbd4659649e85bd50f6bc346ac3 Mon Sep 17 00:00:00 2001 From: Wessie Date: Wed, 18 Dec 2013 00:31:57 +0100 Subject: [PATCH] Fixed various small things that weren't caught before. - Fixes SQL queries for table creations - Table creation is now down in reverse order to accompany the foreign key - Fixed a typo in the BaseRecognizer that caused it to not work - Changed configuration passed to Dejavu into a (nested) dictionary --- dejavu/__init__.py | 72 +++++++++++++++++++-------------------------- dejavu/database.py | 4 +-- dejavu/recognize.py | 57 +++++++++++++++++------------------ 3 files changed, 59 insertions(+), 74 deletions(-) diff --git a/dejavu/__init__.py b/dejavu/__init__.py index 340db94..3222395 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -8,20 +8,14 @@ import random DEBUG = False -class Dejavu(): - +class Dejavu(object): def __init__(self, config): - + self.config = config - + # initialize db - database = SQLDatabase( - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE)) - self.db = database - + self.db = SQLDatabase(**config.get("database", {})) + # create components self.converter = Converter() #self.fingerprinter = Fingerprinter(self.config) @@ -30,16 +24,16 @@ class Dejavu(): # get songs previously indexed self.songs = self.db.get_songs() self.songnames_set = set() # to know which ones we've computed before - if self.songs: - for song in self.songs: - song_id = song[SQLDatabase.FIELD_SONG_ID] - song_name = song[SQLDatabase.FIELD_SONGNAME] - self.songnames_set.add(song_name) - print "Added: %s to the set of fingerprinted songs..." % song_name + + for song in self.songs: + song_name = song[self.db.FIELD_SONGNAME] + + self.songnames_set.add(song_name) + print "Added: %s to the set of fingerprinted songs..." % song_name def chunkify(self, lst, n): """ - Splits a list into roughly n equal parts. + Splits a list into roughly n equal parts. http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts """ return [lst[i::n] for i in xrange(n)] @@ -55,25 +49,19 @@ class Dejavu(): processes = [] for i in range(nprocesses): - # need database instance since mysql connections shouldn't be shared across processes - sql_connection = SQLDatabase( - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE)) - # create process and start it - p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output)) + p = Process(target=self.fingerprint_worker, + args=(files_split[i], self.db, output)) p.start() processes.append(p) # wait for all processes to complete for p in processes: p.join() - + # delete orphans # print "Done fingerprinting. Deleting orphaned fingerprints..." - # TODO: need a more performant query in database.py for the + # TODO: need a more performant query in database.py for the #self.fingerprinter.db.delete_orphans() def fingerprint_worker(self, files, sql_connection, output): @@ -82,7 +70,7 @@ class Dejavu(): # if there are already fingerprints in database, don't re-fingerprint or convert song_name = os.path.basename(filename).split(".")[0] - if DEBUG and song_name in self.songnames_set: + if DEBUG and song_name in self.songnames_set: print("-> Already fingerprinted, continuing...") continue @@ -117,27 +105,27 @@ class Dejavu(): for channel in range(nchannels): channels.append(frames[:, channel]) return (channels, Fs) - + def fingerprint_file(self, filepath, song_name=None): # TODO: replace with something that handles all audio formats channels, Fs = self.extract_channels(path) if not song_name: song_name = os.path.basename(filename).split(".")[0] song_id = self.db.insert_song(song_name) - + for data in channels: hashes = fingerprint.fingerprint(data, Fs=Fs) self.db.insert_hashes(song_id, hashes) - + def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS): hashes = fingerprint.fingerprint(samples, Fs=Fs) return self.db.return_matches(hashes) - + def align_matches(self, matches): """ Finds hash matches that align in time with other matches and finds consensus about which hashes are "true" signal from the audio. - + Returns a dictionary with match information. """ # align by diffs @@ -158,24 +146,24 @@ class Dejavu(): largest_count = diff_counter[diff][sid] song_id = sid - if DEBUG: + if DEBUG: print("Diff is %d with %d offset-aligned matches" % (largest, largest_count)) - - # extract idenfication + + # extract idenfication song = self.db.get_song_by_id(song_id) if song: songname = song.get(SQLDatabase.FIELD_SONGNAME, None) else: return None - - if DEBUG: + + if DEBUG: print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed)) - + # return match info song = { "song_id" : song_id, "song_name" : songname, "confidence" : largest_count } - - return song \ No newline at end of file + + return song diff --git a/dejavu/database.py b/dejavu/database.py index 929c4eb..75af1bf 100755 --- a/dejavu/database.py +++ b/dejavu/database.py @@ -70,7 +70,7 @@ class SQLDatabase(Database): `%s` binary(10) not null, `%s` mediumint unsigned not null, `%s` int unsigned not null, - INDEX(%s), + PRIMARY KEY(%s), UNIQUE(%s, %s, %s), FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE ) ENGINE=INNODB;""" % ( @@ -157,8 +157,8 @@ class SQLDatabase(Database): fingerprints associated with them. """ with self.cursor() as cur: - cur.execute(self.CREATE_FINGERPRINTS_TABLE) cur.execute(self.CREATE_SONGS_TABLE) + cur.execute(self.CREATE_FINGERPRINTS_TABLE) cur.execute(self.DELETE_UNFINGERPRINTED) def empty(self): diff --git a/dejavu/recognize.py b/dejavu/recognize.py index 68700fd..9283c6a 100755 --- a/dejavu/recognize.py +++ b/dejavu/recognize.py @@ -1,6 +1,6 @@ from multiprocessing import Queue, Process from dejavu.database import SQLDatabase -import dejavu.fingerprint +import dejavu.fingerprint as fingerprint from dejavu import Dejavu from scipy.io import wavfile import wave @@ -12,60 +12,57 @@ import array class BaseRecognizer(object): - + def __init__(self, dejavu): self.dejavu = dejavu - self.Fs = dejavu.fingerprint.DEFAULT_FS - + self.Fs = fingerprint.DEFAULT_FS + def _recognize(self, *data): matches = [] for d in data: - matches.extend(self.dejavu.find_matches(data, Fs=self.Fs)) + matches.extend(self.dejavu.find_matches(d, Fs=self.Fs)) return self.dejavu.align_matches(matches) - - def recognize(self): - pass # base class does nothing - - + def recognize(self): + pass # base class does nothing class WaveFileRecognizer(BaseRecognizer): - + def __init__(self, dejavu, filename=None): super(WaveFileRecognizer, self).__init__(dejavu) self.filename = filename - + def recognize_file(self, filename): Fs, frames = wavfile.read(filename) self.Fs = Fs - + wave_object = wave.open(filename) nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams() - + channels = [] for channel in range(nchannels): channels.append(frames[:, channel]) - + t = time.time() match = self._recognize(*channels) t = time.time() - t - + if match: match['match_time'] = t - + return match - + def recognize(self): return self.recognize_file(self.filename) class MicrophoneRecognizer(BaseRecognizer): - + CHUNK = 8192 # 44100 is a multiple of 1225 FORMAT = pyaudio.paInt16 CHANNELS = 2 RATE = 44100 - + def __init__(self, dejavu, seconds=None): super(MicrophoneRecognizer, self).__init__(dejavu) self.audio = pyaudio.PyAudio() @@ -75,52 +72,52 @@ class MicrophoneRecognizer(BaseRecognizer): self.chunk_size = CHUNK self.rate = RATE self.recorded = False - + def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK): self.chunk_size = chunk self.channels = channels self.recorded = False self.rate = rate - + if self.stream: self.stream.stop_stream() self.stream.close() - + self.stream = self.audio.open(format=FORMAT, channels=channels, rate=rate, input=True, frames_per_buffer=chunk) - + self.data = [[] for i in range(channels)] - + def process_recording(self): data = self.stream.read(self.chunk_size) nums = np.fromstring(data, np.int16) for c in range(self.channels): self.data[c].extend(nums[c::c+1]) - + def stop_recording(self): self.stream.stop_stream() self.stream.close() self.stream = None self.recorded = True - + def recognize_recording(self): if not self.recorded: raise NoRecordingError("Recording was not complete/begun") return self._recognize(*self.data) - + def get_recorded_time(self): return len(self.data[0]) / self.rate - + def recognize(self): self.start_recording() for i in range(0, int(self.rate / self.chunk * self.seconds)): self.process_recording() self.stop_recording() return self.recognize_recording() - + class NoRecordingError(Exception): pass