mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
Merged with fingerprinter and recognizer cleanup
This commit is contained in:
commit
1c9eddc3a2
6 changed files with 410 additions and 331 deletions
181
dejavu/__init__.py
Normal file → Executable file
181
dejavu/__init__.py
Normal file → Executable file
|
@ -0,0 +1,181 @@
|
|||
from dejavu.database import SQLDatabase
|
||||
from dejavu.convert import Converter
|
||||
import fingerprint
|
||||
from scipy.io import wavfile
|
||||
from multiprocessing import Process
|
||||
import wave, os
|
||||
import random
|
||||
|
||||
DEBUG = False
|
||||
|
||||
class Dejavu():
|
||||
|
||||
def __init__(self, config):
|
||||
|
||||
self.config = config
|
||||
|
||||
# initialize db
|
||||
database = SQLDatabase(
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
||||
self.db = database
|
||||
|
||||
# create components
|
||||
self.converter = Converter()
|
||||
#self.fingerprinter = Fingerprinter(self.config)
|
||||
self.db.setup()
|
||||
|
||||
# get songs previously indexed
|
||||
self.songs = self.db.get_songs()
|
||||
self.songnames_set = set() # to know which ones we've computed before
|
||||
if self.songs:
|
||||
for song in self.songs:
|
||||
song_id = song[SQLDatabase.FIELD_SONG_ID]
|
||||
song_name = song[SQLDatabase.FIELD_SONGNAME]
|
||||
self.songnames_set.add(song_name)
|
||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||
|
||||
def chunkify(self, lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||
"""
|
||||
return [lst[i::n] for i in xrange(n)]
|
||||
|
||||
def do_fingerprint(self, path, output, extensions, nprocesses):
|
||||
|
||||
# convert files, shuffle order
|
||||
files = self.converter.find_files(path, extensions)
|
||||
random.shuffle(files)
|
||||
files_split = self.chunkify(files, nprocesses)
|
||||
|
||||
# split into processes here
|
||||
processes = []
|
||||
for i in range(nprocesses):
|
||||
|
||||
# need database instance since mysql connections shouldn't be shared across processes
|
||||
sql_connection = SQLDatabase(
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
||||
|
||||
# create process and start it
|
||||
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
# wait for all processes to complete
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# delete orphans
|
||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
||||
# TODO: need a more performant query in database.py for the
|
||||
#self.fingerprinter.db.delete_orphans()
|
||||
|
||||
def fingerprint_worker(self, files, sql_connection, output):
|
||||
|
||||
for filename, extension in files:
|
||||
|
||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
||||
song_name = os.path.basename(filename).split(".")[0]
|
||||
if DEBUG and song_name in self.songnames_set:
|
||||
print("-> Already fingerprinted, continuing...")
|
||||
continue
|
||||
|
||||
# convert to WAV
|
||||
wavout_path = self.converter.convert(filename, extension, Converter.WAV, output, song_name)
|
||||
|
||||
# insert song name into database
|
||||
song_id = sql_connection.insert_song(song_name)
|
||||
|
||||
# for each channel perform FFT analysis and fingerprinting
|
||||
channels, Fs = self.extract_channels(wavout_path)
|
||||
for c in range(len(channels)):
|
||||
channel = channels[c]
|
||||
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
|
||||
hashes = fingerprint.fingerprint(channel, Fs=Fs)
|
||||
sql_connection.insert_hashes(song_id, hashes)
|
||||
|
||||
# only after done fingerprinting do confirm
|
||||
sql_connection.set_song_fingerprinted(song_id)
|
||||
|
||||
def extract_channels(self, path):
|
||||
"""
|
||||
Reads channels from disk.
|
||||
Returns a tuple with (channels, sample_rate)
|
||||
"""
|
||||
channels = []
|
||||
Fs, frames = wavfile.read(path)
|
||||
wave_object = wave.open(path)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
#assert Fs == self.fingerprinter.Fs
|
||||
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
return (channels, Fs)
|
||||
|
||||
def fingerprint_file(self, filepath, song_name=None):
|
||||
# TODO: replace with something that handles all audio formats
|
||||
channels, Fs = self.extract_channels(path)
|
||||
if not song_name:
|
||||
song_name = os.path.basename(filename).split(".")[0]
|
||||
song_id = self.db.insert_song(song_name)
|
||||
|
||||
for data in channels:
|
||||
hashes = fingerprint.fingerprint(data, Fs=Fs)
|
||||
self.db.insert_hashes(song_id, hashes)
|
||||
|
||||
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
||||
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
||||
return self.db.return_matches(hashes)
|
||||
|
||||
def align_matches(self, matches):
|
||||
"""
|
||||
Finds hash matches that align in time with other matches and finds
|
||||
consensus about which hashes are "true" signal from the audio.
|
||||
|
||||
Returns a dictionary with match information.
|
||||
"""
|
||||
# align by diffs
|
||||
diff_counter = {}
|
||||
largest = 0
|
||||
largest_count = 0
|
||||
song_id = -1
|
||||
for tup in matches:
|
||||
sid, diff = tup
|
||||
if not diff in diff_counter:
|
||||
diff_counter[diff] = {}
|
||||
if not sid in diff_counter[diff]:
|
||||
diff_counter[diff][sid] = 0
|
||||
diff_counter[diff][sid] += 1
|
||||
|
||||
if diff_counter[diff][sid] > largest_count:
|
||||
largest = diff
|
||||
largest_count = diff_counter[diff][sid]
|
||||
song_id = sid
|
||||
|
||||
if DEBUG:
|
||||
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
|
||||
|
||||
# extract idenfication
|
||||
song = self.db.get_song_by_id(song_id)
|
||||
if song:
|
||||
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
||||
else:
|
||||
return None
|
||||
|
||||
if DEBUG:
|
||||
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
|
||||
|
||||
# return match info
|
||||
song = {
|
||||
"song_id" : song_id,
|
||||
"song_name" : songname,
|
||||
"confidence" : largest_count
|
||||
}
|
||||
|
||||
return song
|
|
@ -1,107 +0,0 @@
|
|||
from dejavu.database import SQLDatabase
|
||||
from dejavu.convert import Converter
|
||||
from dejavu.fingerprint import Fingerprinter
|
||||
from scipy.io import wavfile
|
||||
from multiprocessing import Process
|
||||
import wave, os
|
||||
import random
|
||||
|
||||
class Dejavu():
|
||||
|
||||
def __init__(self, config):
|
||||
|
||||
self.config = config
|
||||
|
||||
# create components
|
||||
self.converter = Converter()
|
||||
self.fingerprinter = Fingerprinter(self.config)
|
||||
self.fingerprinter.db.setup()
|
||||
|
||||
# get songs previously indexed
|
||||
self.songs = self.fingerprinter.db.get_songs()
|
||||
self.songnames_set = set() # to know which ones we've computed before
|
||||
if self.songs:
|
||||
for song in self.songs:
|
||||
song_id = song[SQLDatabase.FIELD_SONG_ID]
|
||||
song_name = song[SQLDatabase.FIELD_SONGNAME]
|
||||
self.songnames_set.add(song_name)
|
||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||
|
||||
def chunkify(self, lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||
"""
|
||||
return [lst[i::n] for i in xrange(n)]
|
||||
|
||||
def fingerprint(self, path, output, extensions, nprocesses):
|
||||
|
||||
# convert files, shuffle order
|
||||
files = self.converter.find_files(path, extensions)
|
||||
random.shuffle(files)
|
||||
files_split = self.chunkify(files, nprocesses)
|
||||
|
||||
# split into processes here
|
||||
processes = []
|
||||
for i in range(nprocesses):
|
||||
|
||||
# need database instance since mysql connections shouldn't be shared across processes
|
||||
sql_connection = SQLDatabase(
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
||||
|
||||
# create process and start it
|
||||
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
# wait for all processes to complete
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# delete orphans
|
||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
||||
# TODO: need a more performant query in database.py for the
|
||||
#self.fingerprinter.db.delete_orphans()
|
||||
|
||||
def fingerprint_worker(self, files, sql_connection, output):
|
||||
|
||||
for filename, extension in files:
|
||||
|
||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
||||
song_name = os.path.basename(filename).split(".")[0]
|
||||
if song_name in self.songnames_set:
|
||||
print "-> Already fingerprinted, continuing..."
|
||||
continue
|
||||
|
||||
# convert to WAV
|
||||
wavout_path = self.converter.convert(filename, extension, Converter.WAV, output, song_name)
|
||||
|
||||
# insert song name into database
|
||||
song_id = sql_connection.insert_song(song_name)
|
||||
|
||||
# for each channel perform FFT analysis and fingerprinting
|
||||
channels = self.extract_channels(wavout_path)
|
||||
for c in range(len(channels)):
|
||||
channel = channels[c]
|
||||
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
|
||||
self.fingerprinter.fingerprint(channel, wavout_path, song_id, c+1)
|
||||
|
||||
# only after done fingerprinting do confirm
|
||||
sql_connection.set_song_fingerprinted(song_id)
|
||||
|
||||
def extract_channels(self, path):
|
||||
"""
|
||||
Reads channels from disk.
|
||||
"""
|
||||
channels = []
|
||||
Fs, frames = wavfile.read(path)
|
||||
wave_object = wave.open(path)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
assert Fs == self.fingerprinter.Fs
|
||||
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
return channels
|
|
@ -280,15 +280,11 @@ class SQLDatabase(Database):
|
|||
def return_matches(self, hashes):
|
||||
"""
|
||||
Return the (song_id, offset_diff) tuples associated with
|
||||
a list of
|
||||
|
||||
sha1 => (None, sample_offset)
|
||||
|
||||
values.
|
||||
a list of (sha1, sample_offset) values.
|
||||
"""
|
||||
# Create a dictionary of hash => offset pairs for later lookups
|
||||
mapper = {}
|
||||
for hash, (_, offset) in hashes:
|
||||
for hash, offset in hashes:
|
||||
mapper[hash.upper()] = offset
|
||||
|
||||
# Get an iteratable of all the hashes we need
|
||||
|
|
181
dejavu/fingerprint.py
Normal file → Executable file
181
dejavu/fingerprint.py
Normal file → Executable file
|
@ -13,56 +13,24 @@ import time
|
|||
import hashlib
|
||||
import pickle
|
||||
|
||||
class Fingerprinter():
|
||||
IDX_FREQ_I = 0
|
||||
IDX_TIME_J = 1
|
||||
|
||||
IDX_FREQ_I = 0
|
||||
IDX_TIME_J = 1
|
||||
DEFAULT_FS = 44100
|
||||
DEFAULT_WINDOW_SIZE = 4096
|
||||
DEFAULT_OVERLAP_RATIO = 0.5
|
||||
DEFAULT_FAN_VALUE = 15
|
||||
|
||||
DEFAULT_FS = 44100
|
||||
DEFAULT_WINDOW_SIZE = 4096
|
||||
DEFAULT_OVERLAP_RATIO = 0.5
|
||||
DEFAULT_FAN_VALUE = 15
|
||||
DEFAULT_AMP_MIN = 10
|
||||
PEAK_NEIGHBORHOOD_SIZE = 20
|
||||
MIN_HASH_TIME_DELTA = 0
|
||||
|
||||
DEFAULT_AMP_MIN = 10
|
||||
PEAK_NEIGHBORHOOD_SIZE = 20
|
||||
MIN_HASH_TIME_DELTA = 0
|
||||
|
||||
def __init__(self, config,
|
||||
def fingerprint(channel_samples,
|
||||
Fs=DEFAULT_FS,
|
||||
wsize=DEFAULT_WINDOW_SIZE,
|
||||
wratio=DEFAULT_OVERLAP_RATIO,
|
||||
fan_value=DEFAULT_FAN_VALUE,
|
||||
amp_min=DEFAULT_AMP_MIN):
|
||||
|
||||
self.config = config
|
||||
database = SQLDatabase(
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
||||
self.db = database
|
||||
|
||||
self.Fs = Fs
|
||||
self.dt = 1.0 / self.Fs
|
||||
self.window_size = wsize
|
||||
self.window_overlap_ratio = wratio
|
||||
self.fan_value = fan_value
|
||||
self.noverlap = int(self.window_size * self.window_overlap_ratio)
|
||||
self.amp_min = amp_min
|
||||
|
||||
def fingerprint(self, samples, path, sid, cid):
|
||||
"""Used for learning known songs"""
|
||||
hashes = self.process_channel(samples, song_id=sid)
|
||||
print "Generated %d hashes" % len(hashes)
|
||||
self.db.insert_hashes(hashes)
|
||||
|
||||
def match(self, samples):
|
||||
"""Used for matching unknown songs"""
|
||||
hashes = self.process_channel(samples)
|
||||
matches = self.db.return_matches(hashes)
|
||||
return matches
|
||||
|
||||
def process_channel(self, channel_samples, song_id=None):
|
||||
"""
|
||||
FFT the channel, log transform output, find local maxima, then return
|
||||
locally sensitive hashes.
|
||||
|
@ -70,26 +38,26 @@ class Fingerprinter():
|
|||
# FFT the signal and extract frequency components
|
||||
arr2D = mlab.specgram(
|
||||
channel_samples,
|
||||
NFFT=self.window_size,
|
||||
Fs=self.Fs,
|
||||
NFFT=wsize,
|
||||
Fs=Fs,
|
||||
window=mlab.window_hanning,
|
||||
noverlap=self.noverlap)[0]
|
||||
noverlap=int(wsize * wratio))[0]
|
||||
|
||||
# apply log transform since specgram() returns linear array
|
||||
arr2D = 10 * np.log10(arr2D)
|
||||
arr2D[arr2D == -np.inf] = 0 # replace infs with zeros
|
||||
|
||||
# find local maxima
|
||||
local_maxima = self.get_2D_peaks(arr2D, plot=False)
|
||||
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
|
||||
|
||||
# return hashes
|
||||
return self.generate_hashes(local_maxima, song_id=song_id)
|
||||
return generate_hashes(local_maxima, fan_value=fan_value)
|
||||
|
||||
def get_2D_peaks(self, arr2D, plot=False):
|
||||
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||
|
||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
|
||||
struct = generate_binary_structure(2, 1)
|
||||
neighborhood = iterate_structure(struct, Fingerprinter.PEAK_NEIGHBORHOOD_SIZE)
|
||||
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
||||
|
||||
# find local maxima using our fliter shape
|
||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
||||
|
@ -104,7 +72,7 @@ class Fingerprinter():
|
|||
# filter peaks
|
||||
amps = amps.flatten()
|
||||
peaks = zip(i, j, amps)
|
||||
peaks_filtered = [x for x in peaks if x[2] > self.amp_min] # freq, time, amp
|
||||
peaks_filtered = [x for x in peaks if x[2] > amp_min] # freq, time, amp
|
||||
|
||||
# get indices for frequency and time
|
||||
frequency_idx = [x[1] for x in peaks_filtered]
|
||||
|
@ -123,36 +91,72 @@ class Fingerprinter():
|
|||
|
||||
return zip(frequency_idx, time_idx)
|
||||
|
||||
def generate_hashes(self, peaks, song_id=None):
|
||||
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
||||
"""
|
||||
Hash list structure:
|
||||
sha1-hash[0:20] song_id, time_offset
|
||||
[(e05b341a9b77a51fd26, (3, 32)), ... ]
|
||||
sha1-hash[0:20] time_offset
|
||||
[(e05b341a9b77a51fd26, 32), ... ]
|
||||
"""
|
||||
fingerprinted = set() # to avoid rehashing same pairs
|
||||
hashes = []
|
||||
|
||||
for i in range(len(peaks)):
|
||||
for j in range(self.fan_value):
|
||||
for j in range(fan_value):
|
||||
if i+j < len(peaks) and not (i, i+j) in fingerprinted:
|
||||
|
||||
freq1 = peaks[i][Fingerprinter.IDX_FREQ_I]
|
||||
freq2 = peaks[i+j][Fingerprinter.IDX_FREQ_I]
|
||||
t1 = peaks[i][Fingerprinter.IDX_TIME_J]
|
||||
t2 = peaks[i+j][Fingerprinter.IDX_TIME_J]
|
||||
freq1 = peaks[i][IDX_FREQ_I]
|
||||
freq2 = peaks[i+j][IDX_FREQ_I]
|
||||
t1 = peaks[i][IDX_TIME_J]
|
||||
t2 = peaks[i+j][IDX_TIME_J]
|
||||
t_delta = t2 - t1
|
||||
|
||||
if t_delta >= Fingerprinter.MIN_HASH_TIME_DELTA:
|
||||
if t_delta >= MIN_HASH_TIME_DELTA:
|
||||
h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
|
||||
hashes.append((h.hexdigest()[0:20], (song_id, t1)))
|
||||
hashes.append((h.hexdigest()[0:20], t1))
|
||||
|
||||
# ensure we don't repeat hashing
|
||||
fingerprinted.add((i, i+j))
|
||||
return hashes
|
||||
|
||||
def insert_into_db(self, key, value):
|
||||
self.db.insert(key, value)
|
||||
# TODO: move all of the below to a class with DB access
|
||||
|
||||
|
||||
class Fingerprinter():
|
||||
|
||||
|
||||
|
||||
def __init__(self, config,
|
||||
Fs=DEFAULT_FS,
|
||||
wsize=DEFAULT_WINDOW_SIZE,
|
||||
wratio=DEFAULT_OVERLAP_RATIO,
|
||||
fan_value=DEFAULT_FAN_VALUE,
|
||||
amp_min=DEFAULT_AMP_MIN):
|
||||
|
||||
self.config = config
|
||||
|
||||
|
||||
self.Fs = Fs
|
||||
self.dt = 1.0 / self.Fs
|
||||
self.window_size = wsize
|
||||
self.window_overlap_ratio = wratio
|
||||
self.fan_value = fan_value
|
||||
self.noverlap = int(self.window_size * self.window_overlap_ratio)
|
||||
self.amp_min = amp_min
|
||||
|
||||
def fingerprint(self, samples, path, sid, cid):
|
||||
"""Used for learning known songs"""
|
||||
hashes = self.process_channel(samples, song_id=sid)
|
||||
print "Generated %d hashes" % len(hashes)
|
||||
self.db.insert_hashes(hashes)
|
||||
|
||||
# TODO: put this in another module
|
||||
def match(self, samples):
|
||||
"""Used for matching unknown songs"""
|
||||
hashes = self.process_channel(samples)
|
||||
matches = self.db.return_matches(hashes)
|
||||
return matches
|
||||
|
||||
# TODO: this function has nothing to do with fingerprinting. is it needed?
|
||||
def print_stats(self):
|
||||
|
||||
iterable = self.db.get_iterable_kv_pairs()
|
||||
|
@ -169,58 +173,9 @@ class Fingerprinter():
|
|||
song_name = self.song_names[song_id]
|
||||
print "%s has %d spectrogram peaks" % (song_name, count)
|
||||
|
||||
# this does... what? this seems to only be used for the above function
|
||||
def set_song_names(self, wpaths):
|
||||
self.song_names = wpaths
|
||||
|
||||
def align_matches(self, matches, starttime, record_seconds=0, verbose=False):
|
||||
"""
|
||||
Finds hash matches that align in time with other matches and finds
|
||||
consensus about which hashes are "true" signal from the audio.
|
||||
# TODO: put this in another module
|
||||
|
||||
Returns a dictionary with match information.
|
||||
"""
|
||||
# align by diffs
|
||||
diff_counter = {}
|
||||
largest = 0
|
||||
largest_count = 0
|
||||
song_id = -1
|
||||
for tup in matches:
|
||||
sid, diff = tup
|
||||
if not diff in diff_counter:
|
||||
diff_counter[diff] = {}
|
||||
if not sid in diff_counter[diff]:
|
||||
diff_counter[diff][sid] = 0
|
||||
diff_counter[diff][sid] += 1
|
||||
|
||||
if diff_counter[diff][sid] > largest_count:
|
||||
largest = diff
|
||||
largest_count = diff_counter[diff][sid]
|
||||
song_id = sid
|
||||
|
||||
if verbose:
|
||||
print "Diff is %d with %d offset-aligned matches" % (largest, largest_count)
|
||||
|
||||
# extract idenfication
|
||||
song = self.db.get_song_by_id(song_id)
|
||||
if song:
|
||||
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
||||
else:
|
||||
return None
|
||||
songname = songname.replace("_", " ")
|
||||
elapsed = time.time() - starttime
|
||||
|
||||
if verbose:
|
||||
print "Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed)
|
||||
|
||||
# return match info
|
||||
song = {
|
||||
"song_id" : song_id,
|
||||
"song_name" : songname,
|
||||
"match_time" : elapsed,
|
||||
"confidence" : largest_count
|
||||
}
|
||||
|
||||
if record_seconds:
|
||||
song['record_time'] = record_seconds
|
||||
|
||||
return song
|
||||
|
|
146
dejavu/recognize.py
Normal file → Executable file
146
dejavu/recognize.py
Normal file → Executable file
|
@ -1,5 +1,7 @@
|
|||
from multiprocessing import Queue, Process
|
||||
from dejavu.database import SQLDatabase
|
||||
import dejavu.fingerprint
|
||||
from dejavu import Dejavu
|
||||
from scipy.io import wavfile
|
||||
import wave
|
||||
import numpy as np
|
||||
|
@ -8,65 +10,117 @@ import sys
|
|||
import time
|
||||
import array
|
||||
|
||||
class Recognizer(object):
|
||||
|
||||
class BaseRecognizer(object):
|
||||
|
||||
def __init__(self, dejavu):
|
||||
self.dejavu = dejavu
|
||||
self.Fs = dejavu.fingerprint.DEFAULT_FS
|
||||
|
||||
def _recognize(self, *data):
|
||||
matches = []
|
||||
for d in data:
|
||||
matches.extend(self.dejavu.find_matches(data, Fs=self.Fs))
|
||||
return self.dejavu.align_matches(matches)
|
||||
|
||||
def recognize(self):
|
||||
pass # base class does nothing
|
||||
|
||||
|
||||
|
||||
|
||||
class WaveFileRecognizer(BaseRecognizer):
|
||||
|
||||
def __init__(self, dejavu, filename=None):
|
||||
super(WaveFileRecognizer, self).__init__(dejavu)
|
||||
self.filename = filename
|
||||
|
||||
def recognize_file(self, filename):
|
||||
Fs, frames = wavfile.read(filename)
|
||||
self.Fs = Fs
|
||||
|
||||
wave_object = wave.open(filename)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
|
||||
channels = []
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
|
||||
t = time.time()
|
||||
match = self._recognize(*channels)
|
||||
t = time.time() - t
|
||||
|
||||
if match:
|
||||
match['match_time'] = t
|
||||
|
||||
return match
|
||||
|
||||
def recognize(self):
|
||||
return self.recognize_file(self.filename)
|
||||
|
||||
|
||||
class MicrophoneRecognizer(BaseRecognizer):
|
||||
|
||||
CHUNK = 8192 # 44100 is a multiple of 1225
|
||||
FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 2
|
||||
RATE = 44100
|
||||
|
||||
def __init__(self, fingerprinter, config):
|
||||
|
||||
self.fingerprinter = fingerprinter
|
||||
self.config = config
|
||||
def __init__(self, dejavu, seconds=None):
|
||||
super(MicrophoneRecognizer, self).__init__(dejavu)
|
||||
self.audio = pyaudio.PyAudio()
|
||||
self.stream = None
|
||||
self.data = []
|
||||
self.channels = CHANNELS
|
||||
self.chunk_size = CHUNK
|
||||
self.rate = RATE
|
||||
self.recorded = False
|
||||
|
||||
def read(self, filename, verbose=False):
|
||||
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
|
||||
self.chunk_size = chunk
|
||||
self.channels = channels
|
||||
self.recorded = False
|
||||
self.rate = rate
|
||||
|
||||
# read file into channels
|
||||
channels = []
|
||||
Fs, frames = wavfile.read(filename)
|
||||
wave_object = wave.open(filename)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
if self.stream:
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
|
||||
# get matches
|
||||
starttime = time.time()
|
||||
matches = []
|
||||
for channel in channels:
|
||||
matches.extend(self.fingerprinter.match(channel))
|
||||
|
||||
return self.fingerprinter.align_matches(matches, starttime, verbose=verbose)
|
||||
|
||||
def listen(self, seconds=10, verbose=False):
|
||||
|
||||
# open stream
|
||||
stream = self.audio.open(format=Recognizer.FORMAT,
|
||||
channels=Recognizer.CHANNELS,
|
||||
rate=Recognizer.RATE,
|
||||
self.stream = self.audio.open(format=FORMAT,
|
||||
channels=channels,
|
||||
rate=rate,
|
||||
input=True,
|
||||
frames_per_buffer=Recognizer.CHUNK)
|
||||
frames_per_buffer=chunk)
|
||||
|
||||
# record
|
||||
if verbose: print("* recording")
|
||||
left, right = [], []
|
||||
for i in range(0, int(Recognizer.RATE / Recognizer.CHUNK * seconds)):
|
||||
data = stream.read(Recognizer.CHUNK)
|
||||
self.data = [[] for i in range(channels)]
|
||||
|
||||
def process_recording(self):
|
||||
data = self.stream.read(self.chunk_size)
|
||||
nums = np.fromstring(data, np.int16)
|
||||
left.extend(nums[1::2])
|
||||
right.extend(nums[0::2])
|
||||
if verbose: print("* done recording")
|
||||
for c in range(self.channels):
|
||||
self.data[c].extend(nums[c::c+1])
|
||||
|
||||
# close and stop the stream
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
def stop_recording(self):
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
self.stream = None
|
||||
self.recorded = True
|
||||
|
||||
# match both channels
|
||||
starttime = time.time()
|
||||
matches = []
|
||||
matches.extend(self.fingerprinter.match(left))
|
||||
matches.extend(self.fingerprinter.match(right))
|
||||
def recognize_recording(self):
|
||||
if not self.recorded:
|
||||
raise NoRecordingError("Recording was not complete/begun")
|
||||
return self._recognize(*self.data)
|
||||
|
||||
def get_recorded_time(self):
|
||||
return len(self.data[0]) / self.rate
|
||||
|
||||
def recognize(self):
|
||||
self.start_recording()
|
||||
for i in range(0, int(self.rate / self.chunk * self.seconds)):
|
||||
self.process_recording()
|
||||
self.stop_recording()
|
||||
return self.recognize_recording()
|
||||
|
||||
class NoRecordingError(Exception):
|
||||
pass
|
||||
|
||||
# align and return
|
||||
return self.fingerprinter.align_matches(matches, starttime, record_seconds=seconds, verbose=verbose)
|
2
go.py
Normal file → Executable file
2
go.py
Normal file → Executable file
|
@ -1,4 +1,4 @@
|
|||
from dejavu.control import Dejavu
|
||||
from dejavu import Dejavu
|
||||
from ConfigParser import ConfigParser
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
|
Loading…
Reference in a new issue