mirror of
https://github.com/correl/dejavu.git
synced 2024-12-29 11:09:26 +00:00
Fixed various small things that weren't caught before.
- Fixes SQL queries for table creations - Table creation is now down in reverse order to accompany the foreign key - Fixed a typo in the BaseRecognizer that caused it to not work - Changed configuration passed to Dejavu into a (nested) dictionary
This commit is contained in:
parent
1c9eddc3a2
commit
3b72768f94
3 changed files with 59 additions and 74 deletions
|
@ -8,20 +8,14 @@ import random
|
|||
|
||||
DEBUG = False
|
||||
|
||||
class Dejavu():
|
||||
|
||||
class Dejavu(object):
|
||||
def __init__(self, config):
|
||||
|
||||
|
||||
self.config = config
|
||||
|
||||
|
||||
# initialize db
|
||||
database = SQLDatabase(
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
||||
self.db = database
|
||||
|
||||
self.db = SQLDatabase(**config.get("database", {}))
|
||||
|
||||
# create components
|
||||
self.converter = Converter()
|
||||
#self.fingerprinter = Fingerprinter(self.config)
|
||||
|
@ -30,16 +24,16 @@ class Dejavu():
|
|||
# get songs previously indexed
|
||||
self.songs = self.db.get_songs()
|
||||
self.songnames_set = set() # to know which ones we've computed before
|
||||
if self.songs:
|
||||
for song in self.songs:
|
||||
song_id = song[SQLDatabase.FIELD_SONG_ID]
|
||||
song_name = song[SQLDatabase.FIELD_SONGNAME]
|
||||
self.songnames_set.add(song_name)
|
||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||
|
||||
for song in self.songs:
|
||||
song_name = song[self.db.FIELD_SONGNAME]
|
||||
|
||||
self.songnames_set.add(song_name)
|
||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||
|
||||
def chunkify(self, lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
Splits a list into roughly n equal parts.
|
||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||
"""
|
||||
return [lst[i::n] for i in xrange(n)]
|
||||
|
@ -55,25 +49,19 @@ class Dejavu():
|
|||
processes = []
|
||||
for i in range(nprocesses):
|
||||
|
||||
# need database instance since mysql connections shouldn't be shared across processes
|
||||
sql_connection = SQLDatabase(
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
||||
|
||||
# create process and start it
|
||||
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
|
||||
p = Process(target=self.fingerprint_worker,
|
||||
args=(files_split[i], self.db, output))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
# wait for all processes to complete
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
|
||||
# delete orphans
|
||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
||||
# TODO: need a more performant query in database.py for the
|
||||
# TODO: need a more performant query in database.py for the
|
||||
#self.fingerprinter.db.delete_orphans()
|
||||
|
||||
def fingerprint_worker(self, files, sql_connection, output):
|
||||
|
@ -82,7 +70,7 @@ class Dejavu():
|
|||
|
||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
||||
song_name = os.path.basename(filename).split(".")[0]
|
||||
if DEBUG and song_name in self.songnames_set:
|
||||
if DEBUG and song_name in self.songnames_set:
|
||||
print("-> Already fingerprinted, continuing...")
|
||||
continue
|
||||
|
||||
|
@ -117,27 +105,27 @@ class Dejavu():
|
|||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
return (channels, Fs)
|
||||
|
||||
|
||||
def fingerprint_file(self, filepath, song_name=None):
|
||||
# TODO: replace with something that handles all audio formats
|
||||
channels, Fs = self.extract_channels(path)
|
||||
if not song_name:
|
||||
song_name = os.path.basename(filename).split(".")[0]
|
||||
song_id = self.db.insert_song(song_name)
|
||||
|
||||
|
||||
for data in channels:
|
||||
hashes = fingerprint.fingerprint(data, Fs=Fs)
|
||||
self.db.insert_hashes(song_id, hashes)
|
||||
|
||||
|
||||
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
||||
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
||||
return self.db.return_matches(hashes)
|
||||
|
||||
|
||||
def align_matches(self, matches):
|
||||
"""
|
||||
Finds hash matches that align in time with other matches and finds
|
||||
consensus about which hashes are "true" signal from the audio.
|
||||
|
||||
|
||||
Returns a dictionary with match information.
|
||||
"""
|
||||
# align by diffs
|
||||
|
@ -158,24 +146,24 @@ class Dejavu():
|
|||
largest_count = diff_counter[diff][sid]
|
||||
song_id = sid
|
||||
|
||||
if DEBUG:
|
||||
if DEBUG:
|
||||
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
|
||||
|
||||
# extract idenfication
|
||||
|
||||
# extract idenfication
|
||||
song = self.db.get_song_by_id(song_id)
|
||||
if song:
|
||||
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
||||
else:
|
||||
return None
|
||||
|
||||
if DEBUG:
|
||||
|
||||
if DEBUG:
|
||||
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
|
||||
|
||||
|
||||
# return match info
|
||||
song = {
|
||||
"song_id" : song_id,
|
||||
"song_name" : songname,
|
||||
"confidence" : largest_count
|
||||
}
|
||||
|
||||
return song
|
||||
|
||||
return song
|
||||
|
|
|
@ -70,7 +70,7 @@ class SQLDatabase(Database):
|
|||
`%s` binary(10) not null,
|
||||
`%s` mediumint unsigned not null,
|
||||
`%s` int unsigned not null,
|
||||
INDEX(%s),
|
||||
PRIMARY KEY(%s),
|
||||
UNIQUE(%s, %s, %s),
|
||||
FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
|
||||
) ENGINE=INNODB;""" % (
|
||||
|
@ -157,8 +157,8 @@ class SQLDatabase(Database):
|
|||
fingerprints associated with them.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
|
||||
cur.execute(self.CREATE_SONGS_TABLE)
|
||||
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
|
||||
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def empty(self):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from multiprocessing import Queue, Process
|
||||
from dejavu.database import SQLDatabase
|
||||
import dejavu.fingerprint
|
||||
import dejavu.fingerprint as fingerprint
|
||||
from dejavu import Dejavu
|
||||
from scipy.io import wavfile
|
||||
import wave
|
||||
|
@ -12,60 +12,57 @@ import array
|
|||
|
||||
|
||||
class BaseRecognizer(object):
|
||||
|
||||
|
||||
def __init__(self, dejavu):
|
||||
self.dejavu = dejavu
|
||||
self.Fs = dejavu.fingerprint.DEFAULT_FS
|
||||
|
||||
self.Fs = fingerprint.DEFAULT_FS
|
||||
|
||||
def _recognize(self, *data):
|
||||
matches = []
|
||||
for d in data:
|
||||
matches.extend(self.dejavu.find_matches(data, Fs=self.Fs))
|
||||
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
|
||||
return self.dejavu.align_matches(matches)
|
||||
|
||||
def recognize(self):
|
||||
pass # base class does nothing
|
||||
|
||||
|
||||
|
||||
def recognize(self):
|
||||
pass # base class does nothing
|
||||
|
||||
class WaveFileRecognizer(BaseRecognizer):
|
||||
|
||||
|
||||
def __init__(self, dejavu, filename=None):
|
||||
super(WaveFileRecognizer, self).__init__(dejavu)
|
||||
self.filename = filename
|
||||
|
||||
|
||||
def recognize_file(self, filename):
|
||||
Fs, frames = wavfile.read(filename)
|
||||
self.Fs = Fs
|
||||
|
||||
|
||||
wave_object = wave.open(filename)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
|
||||
|
||||
channels = []
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
|
||||
|
||||
t = time.time()
|
||||
match = self._recognize(*channels)
|
||||
t = time.time() - t
|
||||
|
||||
|
||||
if match:
|
||||
match['match_time'] = t
|
||||
|
||||
|
||||
return match
|
||||
|
||||
|
||||
def recognize(self):
|
||||
return self.recognize_file(self.filename)
|
||||
|
||||
|
||||
class MicrophoneRecognizer(BaseRecognizer):
|
||||
|
||||
|
||||
CHUNK = 8192 # 44100 is a multiple of 1225
|
||||
FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 2
|
||||
RATE = 44100
|
||||
|
||||
|
||||
def __init__(self, dejavu, seconds=None):
|
||||
super(MicrophoneRecognizer, self).__init__(dejavu)
|
||||
self.audio = pyaudio.PyAudio()
|
||||
|
@ -75,52 +72,52 @@ class MicrophoneRecognizer(BaseRecognizer):
|
|||
self.chunk_size = CHUNK
|
||||
self.rate = RATE
|
||||
self.recorded = False
|
||||
|
||||
|
||||
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
|
||||
self.chunk_size = chunk
|
||||
self.channels = channels
|
||||
self.recorded = False
|
||||
self.rate = rate
|
||||
|
||||
|
||||
if self.stream:
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
|
||||
|
||||
self.stream = self.audio.open(format=FORMAT,
|
||||
channels=channels,
|
||||
rate=rate,
|
||||
input=True,
|
||||
frames_per_buffer=chunk)
|
||||
|
||||
|
||||
self.data = [[] for i in range(channels)]
|
||||
|
||||
|
||||
def process_recording(self):
|
||||
data = self.stream.read(self.chunk_size)
|
||||
nums = np.fromstring(data, np.int16)
|
||||
for c in range(self.channels):
|
||||
self.data[c].extend(nums[c::c+1])
|
||||
|
||||
|
||||
def stop_recording(self):
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
self.stream = None
|
||||
self.recorded = True
|
||||
|
||||
|
||||
def recognize_recording(self):
|
||||
if not self.recorded:
|
||||
raise NoRecordingError("Recording was not complete/begun")
|
||||
return self._recognize(*self.data)
|
||||
|
||||
|
||||
def get_recorded_time(self):
|
||||
return len(self.data[0]) / self.rate
|
||||
|
||||
|
||||
def recognize(self):
|
||||
self.start_recording()
|
||||
for i in range(0, int(self.rate / self.chunk * self.seconds)):
|
||||
self.process_recording()
|
||||
self.stop_recording()
|
||||
return self.recognize_recording()
|
||||
|
||||
|
||||
class NoRecordingError(Exception):
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in a new issue