Fixed various small things that weren't caught before.

- Fixes SQL queries for table creations
- Table creation is now down in reverse order to accompany the foreign key
- Fixed a typo in the BaseRecognizer that caused it to not work
- Changed configuration passed to Dejavu into a (nested) dictionary
This commit is contained in:
Wessie 2013-12-18 00:31:57 +01:00
parent 1c9eddc3a2
commit 3b72768f94
3 changed files with 59 additions and 74 deletions

View file

@ -8,20 +8,14 @@ import random
DEBUG = False
class Dejavu():
class Dejavu(object):
def __init__(self, config):
self.config = config
# initialize db
database = SQLDatabase(
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
self.db = database
self.db = SQLDatabase(**config.get("database", {}))
# create components
self.converter = Converter()
#self.fingerprinter = Fingerprinter(self.config)
@ -30,16 +24,16 @@ class Dejavu():
# get songs previously indexed
self.songs = self.db.get_songs()
self.songnames_set = set() # to know which ones we've computed before
if self.songs:
for song in self.songs:
song_id = song[SQLDatabase.FIELD_SONG_ID]
song_name = song[SQLDatabase.FIELD_SONGNAME]
self.songnames_set.add(song_name)
print "Added: %s to the set of fingerprinted songs..." % song_name
for song in self.songs:
song_name = song[self.db.FIELD_SONGNAME]
self.songnames_set.add(song_name)
print "Added: %s to the set of fingerprinted songs..." % song_name
def chunkify(self, lst, n):
"""
Splits a list into roughly n equal parts.
Splits a list into roughly n equal parts.
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
"""
return [lst[i::n] for i in xrange(n)]
@ -55,25 +49,19 @@ class Dejavu():
processes = []
for i in range(nprocesses):
# need database instance since mysql connections shouldn't be shared across processes
sql_connection = SQLDatabase(
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
# create process and start it
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
p = Process(target=self.fingerprint_worker,
args=(files_split[i], self.db, output))
p.start()
processes.append(p)
# wait for all processes to complete
for p in processes:
p.join()
# delete orphans
# print "Done fingerprinting. Deleting orphaned fingerprints..."
# TODO: need a more performant query in database.py for the
# TODO: need a more performant query in database.py for the
#self.fingerprinter.db.delete_orphans()
def fingerprint_worker(self, files, sql_connection, output):
@ -82,7 +70,7 @@ class Dejavu():
# if there are already fingerprints in database, don't re-fingerprint or convert
song_name = os.path.basename(filename).split(".")[0]
if DEBUG and song_name in self.songnames_set:
if DEBUG and song_name in self.songnames_set:
print("-> Already fingerprinted, continuing...")
continue
@ -117,27 +105,27 @@ class Dejavu():
for channel in range(nchannels):
channels.append(frames[:, channel])
return (channels, Fs)
def fingerprint_file(self, filepath, song_name=None):
# TODO: replace with something that handles all audio formats
channels, Fs = self.extract_channels(path)
if not song_name:
song_name = os.path.basename(filename).split(".")[0]
song_id = self.db.insert_song(song_name)
for data in channels:
hashes = fingerprint.fingerprint(data, Fs=Fs)
self.db.insert_hashes(song_id, hashes)
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
hashes = fingerprint.fingerprint(samples, Fs=Fs)
return self.db.return_matches(hashes)
def align_matches(self, matches):
"""
Finds hash matches that align in time with other matches and finds
consensus about which hashes are "true" signal from the audio.
Returns a dictionary with match information.
"""
# align by diffs
@ -158,24 +146,24 @@ class Dejavu():
largest_count = diff_counter[diff][sid]
song_id = sid
if DEBUG:
if DEBUG:
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
# extract idenfication
# extract idenfication
song = self.db.get_song_by_id(song_id)
if song:
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
else:
return None
if DEBUG:
if DEBUG:
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
# return match info
song = {
"song_id" : song_id,
"song_name" : songname,
"confidence" : largest_count
}
return song
return song

View file

@ -70,7 +70,7 @@ class SQLDatabase(Database):
`%s` binary(10) not null,
`%s` mediumint unsigned not null,
`%s` int unsigned not null,
INDEX(%s),
PRIMARY KEY(%s),
UNIQUE(%s, %s, %s),
FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
) ENGINE=INNODB;""" % (
@ -157,8 +157,8 @@ class SQLDatabase(Database):
fingerprints associated with them.
"""
with self.cursor() as cur:
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
cur.execute(self.CREATE_SONGS_TABLE)
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
cur.execute(self.DELETE_UNFINGERPRINTED)
def empty(self):

View file

@ -1,6 +1,6 @@
from multiprocessing import Queue, Process
from dejavu.database import SQLDatabase
import dejavu.fingerprint
import dejavu.fingerprint as fingerprint
from dejavu import Dejavu
from scipy.io import wavfile
import wave
@ -12,60 +12,57 @@ import array
class BaseRecognizer(object):
def __init__(self, dejavu):
self.dejavu = dejavu
self.Fs = dejavu.fingerprint.DEFAULT_FS
self.Fs = fingerprint.DEFAULT_FS
def _recognize(self, *data):
matches = []
for d in data:
matches.extend(self.dejavu.find_matches(data, Fs=self.Fs))
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
return self.dejavu.align_matches(matches)
def recognize(self):
pass # base class does nothing
def recognize(self):
pass # base class does nothing
class WaveFileRecognizer(BaseRecognizer):
def __init__(self, dejavu, filename=None):
super(WaveFileRecognizer, self).__init__(dejavu)
self.filename = filename
def recognize_file(self, filename):
Fs, frames = wavfile.read(filename)
self.Fs = Fs
wave_object = wave.open(filename)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
channels = []
for channel in range(nchannels):
channels.append(frames[:, channel])
t = time.time()
match = self._recognize(*channels)
t = time.time() - t
if match:
match['match_time'] = t
return match
def recognize(self):
return self.recognize_file(self.filename)
class MicrophoneRecognizer(BaseRecognizer):
CHUNK = 8192 # 44100 is a multiple of 1225
FORMAT = pyaudio.paInt16
CHANNELS = 2
RATE = 44100
def __init__(self, dejavu, seconds=None):
super(MicrophoneRecognizer, self).__init__(dejavu)
self.audio = pyaudio.PyAudio()
@ -75,52 +72,52 @@ class MicrophoneRecognizer(BaseRecognizer):
self.chunk_size = CHUNK
self.rate = RATE
self.recorded = False
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
self.chunk_size = chunk
self.channels = channels
self.recorded = False
self.rate = rate
if self.stream:
self.stream.stop_stream()
self.stream.close()
self.stream = self.audio.open(format=FORMAT,
channels=channels,
rate=rate,
input=True,
frames_per_buffer=chunk)
self.data = [[] for i in range(channels)]
def process_recording(self):
data = self.stream.read(self.chunk_size)
nums = np.fromstring(data, np.int16)
for c in range(self.channels):
self.data[c].extend(nums[c::c+1])
def stop_recording(self):
self.stream.stop_stream()
self.stream.close()
self.stream = None
self.recorded = True
def recognize_recording(self):
if not self.recorded:
raise NoRecordingError("Recording was not complete/begun")
return self._recognize(*self.data)
def get_recorded_time(self):
return len(self.data[0]) / self.rate
def recognize(self):
self.start_recording()
for i in range(0, int(self.rate / self.chunk * self.seconds)):
self.process_recording()
self.stop_recording()
return self.recognize_recording()
class NoRecordingError(Exception):
pass