mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
Fixed various small things that weren't caught before.
- Fixes SQL queries for table creation - Table creation is now done in reverse order to accommodate the foreign key - Fixed a typo in the BaseRecognizer that caused it to not work - Changed configuration passed to Dejavu into a (nested) dictionary
This commit is contained in:
parent
1c9eddc3a2
commit
3b72768f94
3 changed files with 59 additions and 74 deletions
|
@ -8,20 +8,14 @@ import random
|
||||||
|
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
|
|
||||||
class Dejavu():
|
class Dejavu(object):
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# initialize db
|
# initialize db
|
||||||
database = SQLDatabase(
|
self.db = SQLDatabase(**config.get("database", {}))
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
|
||||||
self.db = database
|
|
||||||
|
|
||||||
# create components
|
# create components
|
||||||
self.converter = Converter()
|
self.converter = Converter()
|
||||||
#self.fingerprinter = Fingerprinter(self.config)
|
#self.fingerprinter = Fingerprinter(self.config)
|
||||||
|
@ -30,16 +24,16 @@ class Dejavu():
|
||||||
# get songs previously indexed
|
# get songs previously indexed
|
||||||
self.songs = self.db.get_songs()
|
self.songs = self.db.get_songs()
|
||||||
self.songnames_set = set() # to know which ones we've computed before
|
self.songnames_set = set() # to know which ones we've computed before
|
||||||
if self.songs:
|
|
||||||
for song in self.songs:
|
for song in self.songs:
|
||||||
song_id = song[SQLDatabase.FIELD_SONG_ID]
|
song_name = song[self.db.FIELD_SONGNAME]
|
||||||
song_name = song[SQLDatabase.FIELD_SONGNAME]
|
|
||||||
self.songnames_set.add(song_name)
|
self.songnames_set.add(song_name)
|
||||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||||
|
|
||||||
def chunkify(self, lst, n):
|
def chunkify(self, lst, n):
|
||||||
"""
|
"""
|
||||||
Splits a list into roughly n equal parts.
|
Splits a list into roughly n equal parts.
|
||||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||||
"""
|
"""
|
||||||
return [lst[i::n] for i in xrange(n)]
|
return [lst[i::n] for i in xrange(n)]
|
||||||
|
@ -55,25 +49,19 @@ class Dejavu():
|
||||||
processes = []
|
processes = []
|
||||||
for i in range(nprocesses):
|
for i in range(nprocesses):
|
||||||
|
|
||||||
# need database instance since mysql connections shouldn't be shared across processes
|
|
||||||
sql_connection = SQLDatabase(
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
|
||||||
|
|
||||||
# create process and start it
|
# create process and start it
|
||||||
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
|
p = Process(target=self.fingerprint_worker,
|
||||||
|
args=(files_split[i], self.db, output))
|
||||||
p.start()
|
p.start()
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
|
|
||||||
# wait for all processes to complete
|
# wait for all processes to complete
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
# delete orphans
|
# delete orphans
|
||||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
||||||
# TODO: need a more performant query in database.py for the
|
# TODO: need a more performant query in database.py for the
|
||||||
#self.fingerprinter.db.delete_orphans()
|
#self.fingerprinter.db.delete_orphans()
|
||||||
|
|
||||||
def fingerprint_worker(self, files, sql_connection, output):
|
def fingerprint_worker(self, files, sql_connection, output):
|
||||||
|
@ -82,7 +70,7 @@ class Dejavu():
|
||||||
|
|
||||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
# if there are already fingerprints in database, don't re-fingerprint or convert
|
||||||
song_name = os.path.basename(filename).split(".")[0]
|
song_name = os.path.basename(filename).split(".")[0]
|
||||||
if DEBUG and song_name in self.songnames_set:
|
if DEBUG and song_name in self.songnames_set:
|
||||||
print("-> Already fingerprinted, continuing...")
|
print("-> Already fingerprinted, continuing...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -117,27 +105,27 @@ class Dejavu():
|
||||||
for channel in range(nchannels):
|
for channel in range(nchannels):
|
||||||
channels.append(frames[:, channel])
|
channels.append(frames[:, channel])
|
||||||
return (channels, Fs)
|
return (channels, Fs)
|
||||||
|
|
||||||
def fingerprint_file(self, filepath, song_name=None):
|
def fingerprint_file(self, filepath, song_name=None):
|
||||||
# TODO: replace with something that handles all audio formats
|
# TODO: replace with something that handles all audio formats
|
||||||
channels, Fs = self.extract_channels(path)
|
channels, Fs = self.extract_channels(path)
|
||||||
if not song_name:
|
if not song_name:
|
||||||
song_name = os.path.basename(filename).split(".")[0]
|
song_name = os.path.basename(filename).split(".")[0]
|
||||||
song_id = self.db.insert_song(song_name)
|
song_id = self.db.insert_song(song_name)
|
||||||
|
|
||||||
for data in channels:
|
for data in channels:
|
||||||
hashes = fingerprint.fingerprint(data, Fs=Fs)
|
hashes = fingerprint.fingerprint(data, Fs=Fs)
|
||||||
self.db.insert_hashes(song_id, hashes)
|
self.db.insert_hashes(song_id, hashes)
|
||||||
|
|
||||||
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
||||||
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
||||||
return self.db.return_matches(hashes)
|
return self.db.return_matches(hashes)
|
||||||
|
|
||||||
def align_matches(self, matches):
|
def align_matches(self, matches):
|
||||||
"""
|
"""
|
||||||
Finds hash matches that align in time with other matches and finds
|
Finds hash matches that align in time with other matches and finds
|
||||||
consensus about which hashes are "true" signal from the audio.
|
consensus about which hashes are "true" signal from the audio.
|
||||||
|
|
||||||
Returns a dictionary with match information.
|
Returns a dictionary with match information.
|
||||||
"""
|
"""
|
||||||
# align by diffs
|
# align by diffs
|
||||||
|
@ -158,24 +146,24 @@ class Dejavu():
|
||||||
largest_count = diff_counter[diff][sid]
|
largest_count = diff_counter[diff][sid]
|
||||||
song_id = sid
|
song_id = sid
|
||||||
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
|
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
|
||||||
|
|
||||||
# extract idenfication
|
# extract idenfication
|
||||||
song = self.db.get_song_by_id(song_id)
|
song = self.db.get_song_by_id(song_id)
|
||||||
if song:
|
if song:
|
||||||
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
|
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
|
||||||
|
|
||||||
# return match info
|
# return match info
|
||||||
song = {
|
song = {
|
||||||
"song_id" : song_id,
|
"song_id" : song_id,
|
||||||
"song_name" : songname,
|
"song_name" : songname,
|
||||||
"confidence" : largest_count
|
"confidence" : largest_count
|
||||||
}
|
}
|
||||||
|
|
||||||
return song
|
return song
|
||||||
|
|
|
@ -70,7 +70,7 @@ class SQLDatabase(Database):
|
||||||
`%s` binary(10) not null,
|
`%s` binary(10) not null,
|
||||||
`%s` mediumint unsigned not null,
|
`%s` mediumint unsigned not null,
|
||||||
`%s` int unsigned not null,
|
`%s` int unsigned not null,
|
||||||
INDEX(%s),
|
PRIMARY KEY(%s),
|
||||||
UNIQUE(%s, %s, %s),
|
UNIQUE(%s, %s, %s),
|
||||||
FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
|
FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
|
||||||
) ENGINE=INNODB;""" % (
|
) ENGINE=INNODB;""" % (
|
||||||
|
@ -157,8 +157,8 @@ class SQLDatabase(Database):
|
||||||
fingerprints associated with them.
|
fingerprints associated with them.
|
||||||
"""
|
"""
|
||||||
with self.cursor() as cur:
|
with self.cursor() as cur:
|
||||||
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
|
|
||||||
cur.execute(self.CREATE_SONGS_TABLE)
|
cur.execute(self.CREATE_SONGS_TABLE)
|
||||||
|
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
|
||||||
cur.execute(self.DELETE_UNFINGERPRINTED)
|
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||||
|
|
||||||
def empty(self):
|
def empty(self):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from multiprocessing import Queue, Process
|
from multiprocessing import Queue, Process
|
||||||
from dejavu.database import SQLDatabase
|
from dejavu.database import SQLDatabase
|
||||||
import dejavu.fingerprint
|
import dejavu.fingerprint as fingerprint
|
||||||
from dejavu import Dejavu
|
from dejavu import Dejavu
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
import wave
|
import wave
|
||||||
|
@ -12,60 +12,57 @@ import array
|
||||||
|
|
||||||
|
|
||||||
class BaseRecognizer(object):
|
class BaseRecognizer(object):
|
||||||
|
|
||||||
def __init__(self, dejavu):
|
def __init__(self, dejavu):
|
||||||
self.dejavu = dejavu
|
self.dejavu = dejavu
|
||||||
self.Fs = dejavu.fingerprint.DEFAULT_FS
|
self.Fs = fingerprint.DEFAULT_FS
|
||||||
|
|
||||||
def _recognize(self, *data):
|
def _recognize(self, *data):
|
||||||
matches = []
|
matches = []
|
||||||
for d in data:
|
for d in data:
|
||||||
matches.extend(self.dejavu.find_matches(data, Fs=self.Fs))
|
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
|
||||||
return self.dejavu.align_matches(matches)
|
return self.dejavu.align_matches(matches)
|
||||||
|
|
||||||
def recognize(self):
|
|
||||||
pass # base class does nothing
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def recognize(self):
|
||||||
|
pass # base class does nothing
|
||||||
|
|
||||||
class WaveFileRecognizer(BaseRecognizer):
|
class WaveFileRecognizer(BaseRecognizer):
|
||||||
|
|
||||||
def __init__(self, dejavu, filename=None):
|
def __init__(self, dejavu, filename=None):
|
||||||
super(WaveFileRecognizer, self).__init__(dejavu)
|
super(WaveFileRecognizer, self).__init__(dejavu)
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
|
|
||||||
def recognize_file(self, filename):
|
def recognize_file(self, filename):
|
||||||
Fs, frames = wavfile.read(filename)
|
Fs, frames = wavfile.read(filename)
|
||||||
self.Fs = Fs
|
self.Fs = Fs
|
||||||
|
|
||||||
wave_object = wave.open(filename)
|
wave_object = wave.open(filename)
|
||||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||||
|
|
||||||
channels = []
|
channels = []
|
||||||
for channel in range(nchannels):
|
for channel in range(nchannels):
|
||||||
channels.append(frames[:, channel])
|
channels.append(frames[:, channel])
|
||||||
|
|
||||||
t = time.time()
|
t = time.time()
|
||||||
match = self._recognize(*channels)
|
match = self._recognize(*channels)
|
||||||
t = time.time() - t
|
t = time.time() - t
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
match['match_time'] = t
|
match['match_time'] = t
|
||||||
|
|
||||||
return match
|
return match
|
||||||
|
|
||||||
def recognize(self):
|
def recognize(self):
|
||||||
return self.recognize_file(self.filename)
|
return self.recognize_file(self.filename)
|
||||||
|
|
||||||
|
|
||||||
class MicrophoneRecognizer(BaseRecognizer):
|
class MicrophoneRecognizer(BaseRecognizer):
|
||||||
|
|
||||||
CHUNK = 8192 # 44100 is a multiple of 1225
|
CHUNK = 8192 # 44100 is a multiple of 1225
|
||||||
FORMAT = pyaudio.paInt16
|
FORMAT = pyaudio.paInt16
|
||||||
CHANNELS = 2
|
CHANNELS = 2
|
||||||
RATE = 44100
|
RATE = 44100
|
||||||
|
|
||||||
def __init__(self, dejavu, seconds=None):
|
def __init__(self, dejavu, seconds=None):
|
||||||
super(MicrophoneRecognizer, self).__init__(dejavu)
|
super(MicrophoneRecognizer, self).__init__(dejavu)
|
||||||
self.audio = pyaudio.PyAudio()
|
self.audio = pyaudio.PyAudio()
|
||||||
|
@ -75,52 +72,52 @@ class MicrophoneRecognizer(BaseRecognizer):
|
||||||
self.chunk_size = CHUNK
|
self.chunk_size = CHUNK
|
||||||
self.rate = RATE
|
self.rate = RATE
|
||||||
self.recorded = False
|
self.recorded = False
|
||||||
|
|
||||||
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
|
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
|
||||||
self.chunk_size = chunk
|
self.chunk_size = chunk
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.recorded = False
|
self.recorded = False
|
||||||
self.rate = rate
|
self.rate = rate
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
self.stream.stop_stream()
|
self.stream.stop_stream()
|
||||||
self.stream.close()
|
self.stream.close()
|
||||||
|
|
||||||
self.stream = self.audio.open(format=FORMAT,
|
self.stream = self.audio.open(format=FORMAT,
|
||||||
channels=channels,
|
channels=channels,
|
||||||
rate=rate,
|
rate=rate,
|
||||||
input=True,
|
input=True,
|
||||||
frames_per_buffer=chunk)
|
frames_per_buffer=chunk)
|
||||||
|
|
||||||
self.data = [[] for i in range(channels)]
|
self.data = [[] for i in range(channels)]
|
||||||
|
|
||||||
def process_recording(self):
|
def process_recording(self):
|
||||||
data = self.stream.read(self.chunk_size)
|
data = self.stream.read(self.chunk_size)
|
||||||
nums = np.fromstring(data, np.int16)
|
nums = np.fromstring(data, np.int16)
|
||||||
for c in range(self.channels):
|
for c in range(self.channels):
|
||||||
self.data[c].extend(nums[c::c+1])
|
self.data[c].extend(nums[c::c+1])
|
||||||
|
|
||||||
def stop_recording(self):
|
def stop_recording(self):
|
||||||
self.stream.stop_stream()
|
self.stream.stop_stream()
|
||||||
self.stream.close()
|
self.stream.close()
|
||||||
self.stream = None
|
self.stream = None
|
||||||
self.recorded = True
|
self.recorded = True
|
||||||
|
|
||||||
def recognize_recording(self):
|
def recognize_recording(self):
|
||||||
if not self.recorded:
|
if not self.recorded:
|
||||||
raise NoRecordingError("Recording was not complete/begun")
|
raise NoRecordingError("Recording was not complete/begun")
|
||||||
return self._recognize(*self.data)
|
return self._recognize(*self.data)
|
||||||
|
|
||||||
def get_recorded_time(self):
|
def get_recorded_time(self):
|
||||||
return len(self.data[0]) / self.rate
|
return len(self.data[0]) / self.rate
|
||||||
|
|
||||||
def recognize(self):
|
def recognize(self):
|
||||||
self.start_recording()
|
self.start_recording()
|
||||||
for i in range(0, int(self.rate / self.chunk * self.seconds)):
|
for i in range(0, int(self.rate / self.chunk * self.seconds)):
|
||||||
self.process_recording()
|
self.process_recording()
|
||||||
self.stop_recording()
|
self.stop_recording()
|
||||||
return self.recognize_recording()
|
return self.recognize_recording()
|
||||||
|
|
||||||
class NoRecordingError(Exception):
|
class NoRecordingError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue