mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
Merged with fingerprinter and recognizer cleanup
This commit is contained in:
commit
1c9eddc3a2
6 changed files with 410 additions and 331 deletions
181
dejavu/__init__.py
Normal file → Executable file
181
dejavu/__init__.py
Normal file → Executable file
|
@ -0,0 +1,181 @@
|
||||||
|
from dejavu.database import SQLDatabase
|
||||||
|
from dejavu.convert import Converter
|
||||||
|
import fingerprint
|
||||||
|
from scipy.io import wavfile
|
||||||
|
from multiprocessing import Process
|
||||||
|
import wave, os
|
||||||
|
import random
|
||||||
|
|
||||||
|
DEBUG = False
|
||||||
|
|
||||||
|
class Dejavu():
    """
    Entry point for fingerprinting audio files into the SQL database and
    for matching audio samples against the stored fingerprints.
    """

    def __init__(self, config):
        """
        Connects to the database described by `config`, runs the database
        setup and caches the names of the songs that are already stored.

        `config` must provide `get(section, key)` for the SQLDatabase
        connection settings.
        """
        self.config = config

        # initialize db
        self.db = self._create_database()

        # create components
        self.converter = Converter()
        self.db.setup()

        # get songs previously indexed
        self.songs = self.db.get_songs()
        self.songnames_set = set()  # to know which ones we've computed before
        if self.songs:
            for song in self.songs:
                song_name = song[SQLDatabase.FIELD_SONGNAME]
                self.songnames_set.add(song_name)
                print("Added: %s to the set of fingerprinted songs..." % song_name)

    def _create_database(self):
        """
        Builds a new SQLDatabase from the connection settings in the config.
        """
        return SQLDatabase(
            self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
            self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
            self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
            self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))

    def chunkify(self, lst, n):
        """
        Splits a list into roughly n equal parts.
        http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
        """
        return [lst[i::n] for i in range(n)]

    def do_fingerprint(self, path, output, extensions, nprocesses):
        """
        Fingerprints every file under `path` that has one of `extensions`,
        spreading the work over `nprocesses` worker processes. `output` is
        handed to the converter for the intermediate WAV files.
        """
        # convert files, shuffle order
        files = self.converter.find_files(path, extensions)
        random.shuffle(files)
        files_split = self.chunkify(files, nprocesses)

        # split into processes here
        processes = []
        for files_chunk in files_split:

            # need database instance since mysql connections shouldn't be shared across processes
            sql_connection = self._create_database()

            # create process and start it
            p = Process(target=self.fingerprint_worker,
                        args=(files_chunk, sql_connection, output))
            p.start()
            processes.append(p)

        # wait for all processes to complete
        for p in processes:
            p.join()

        # TODO: delete orphaned fingerprints here once database.py has a
        # more performant query for it

    def fingerprint_worker(self, files, sql_connection, output):
        """
        Converts, fingerprints and stores every (filename, extension) pair
        in `files`, writing through `sql_connection`.
        """
        for filename, extension in files:

            # if there are already fingerprints in database, don't re-fingerprint or convert
            song_name = os.path.basename(filename).split(".")[0]
            if song_name in self.songnames_set:
                # the skip itself must not depend on DEBUG, only the message
                if DEBUG:
                    print("-> Already fingerprinted, continuing...")
                continue

            # convert to WAV
            wavout_path = self.converter.convert(
                filename, extension, Converter.WAV, output, song_name)

            # insert song name into database
            song_id = sql_connection.insert_song(song_name)

            # for each channel perform FFT analysis and fingerprinting
            channels, Fs = self.extract_channels(wavout_path)
            for c, channel in enumerate(channels):
                print("-> Fingerprinting channel %d of song %s..." % (c + 1, song_name))
                hashes = fingerprint.fingerprint(channel, Fs=Fs)
                sql_connection.insert_hashes(song_id, hashes)

            # only after done fingerprinting do confirm
            sql_connection.set_song_fingerprinted(song_id)

    def extract_channels(self, path):
        """
        Reads channels from disk.
        Returns a tuple with (channels, sample_rate)
        """
        Fs, frames = wavfile.read(path)

        # wavfile.read gives a 1-D array for mono files and a 2-D
        # (samples, channels) array otherwise
        if frames.ndim == 1:
            return ([frames], Fs)

        channels = [frames[:, c] for c in range(frames.shape[1])]
        return (channels, Fs)

    def fingerprint_file(self, filepath, song_name=None):
        """
        Fingerprints a single WAV file with the instance's own database
        connection. `song_name` defaults to the file's base name up to the
        first dot.
        """
        # TODO: replace with something that handles all audio formats
        channels, Fs = self.extract_channels(filepath)
        if not song_name:
            song_name = os.path.basename(filepath).split(".")[0]
        song_id = self.db.insert_song(song_name)

        # NOTE(review): unlike fingerprint_worker this never calls
        # set_song_fingerprinted() - confirm whether that is intended
        for data in channels:
            hashes = fingerprint.fingerprint(data, Fs=Fs)
            self.db.insert_hashes(song_id, hashes)

    def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
        """
        Fingerprints `samples` and returns whatever the database reports as
        matches for the resulting hashes.
        """
        hashes = fingerprint.fingerprint(samples, Fs=Fs)
        return self.db.return_matches(hashes)

    def align_matches(self, matches):
        """
        Finds hash matches that align in time with other matches and finds
        consensus about which hashes are "true" signal from the audio.

        Returns a dictionary with match information, or None when the
        winning song id is not found in the database.
        """
        # align by diffs
        diff_counter = {}
        largest = 0
        largest_count = 0
        song_id = -1
        for sid, diff in matches:
            per_song = diff_counter.setdefault(diff, {})
            per_song[sid] = per_song.get(sid, 0) + 1

            # strict comparison: the first pair to reach a count wins ties
            if per_song[sid] > largest_count:
                largest = diff
                largest_count = per_song[sid]
                song_id = sid

        if DEBUG:
            print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))

        # extract identification
        song = self.db.get_song_by_id(song_id)
        if not song:
            return None
        songname = song.get(SQLDatabase.FIELD_SONGNAME, None)

        if DEBUG:
            # timing is measured by the recognizers, not here
            print("Song is %s (song ID = %d)" % (songname, song_id))

        # return match info
        return {
            "song_id": song_id,
            "song_name": songname,
            "confidence": largest_count
        }
|
|
@ -1,107 +0,0 @@
|
||||||
from dejavu.database import SQLDatabase
|
|
||||||
from dejavu.convert import Converter
|
|
||||||
from dejavu.fingerprint import Fingerprinter
|
|
||||||
from scipy.io import wavfile
|
|
||||||
from multiprocessing import Process
|
|
||||||
import wave, os
|
|
||||||
import random
|
|
||||||
|
|
||||||
class Dejavu():
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# create components
|
|
||||||
self.converter = Converter()
|
|
||||||
self.fingerprinter = Fingerprinter(self.config)
|
|
||||||
self.fingerprinter.db.setup()
|
|
||||||
|
|
||||||
# get songs previously indexed
|
|
||||||
self.songs = self.fingerprinter.db.get_songs()
|
|
||||||
self.songnames_set = set() # to know which ones we've computed before
|
|
||||||
if self.songs:
|
|
||||||
for song in self.songs:
|
|
||||||
song_id = song[SQLDatabase.FIELD_SONG_ID]
|
|
||||||
song_name = song[SQLDatabase.FIELD_SONGNAME]
|
|
||||||
self.songnames_set.add(song_name)
|
|
||||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
|
||||||
|
|
||||||
def chunkify(self, lst, n):
|
|
||||||
"""
|
|
||||||
Splits a list into roughly n equal parts.
|
|
||||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
|
||||||
"""
|
|
||||||
return [lst[i::n] for i in xrange(n)]
|
|
||||||
|
|
||||||
def fingerprint(self, path, output, extensions, nprocesses):
|
|
||||||
|
|
||||||
# convert files, shuffle order
|
|
||||||
files = self.converter.find_files(path, extensions)
|
|
||||||
random.shuffle(files)
|
|
||||||
files_split = self.chunkify(files, nprocesses)
|
|
||||||
|
|
||||||
# split into processes here
|
|
||||||
processes = []
|
|
||||||
for i in range(nprocesses):
|
|
||||||
|
|
||||||
# need database instance since mysql connections shouldn't be shared across processes
|
|
||||||
sql_connection = SQLDatabase(
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
|
||||||
|
|
||||||
# create process and start it
|
|
||||||
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
|
|
||||||
p.start()
|
|
||||||
processes.append(p)
|
|
||||||
|
|
||||||
# wait for all processes to complete
|
|
||||||
for p in processes:
|
|
||||||
p.join()
|
|
||||||
|
|
||||||
# delete orphans
|
|
||||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
|
||||||
# TODO: need a more performant query in database.py for the
|
|
||||||
#self.fingerprinter.db.delete_orphans()
|
|
||||||
|
|
||||||
def fingerprint_worker(self, files, sql_connection, output):
|
|
||||||
|
|
||||||
for filename, extension in files:
|
|
||||||
|
|
||||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
|
||||||
song_name = os.path.basename(filename).split(".")[0]
|
|
||||||
if song_name in self.songnames_set:
|
|
||||||
print "-> Already fingerprinted, continuing..."
|
|
||||||
continue
|
|
||||||
|
|
||||||
# convert to WAV
|
|
||||||
wavout_path = self.converter.convert(filename, extension, Converter.WAV, output, song_name)
|
|
||||||
|
|
||||||
# insert song name into database
|
|
||||||
song_id = sql_connection.insert_song(song_name)
|
|
||||||
|
|
||||||
# for each channel perform FFT analysis and fingerprinting
|
|
||||||
channels = self.extract_channels(wavout_path)
|
|
||||||
for c in range(len(channels)):
|
|
||||||
channel = channels[c]
|
|
||||||
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
|
|
||||||
self.fingerprinter.fingerprint(channel, wavout_path, song_id, c+1)
|
|
||||||
|
|
||||||
# only after done fingerprinting do confirm
|
|
||||||
sql_connection.set_song_fingerprinted(song_id)
|
|
||||||
|
|
||||||
def extract_channels(self, path):
|
|
||||||
"""
|
|
||||||
Reads channels from disk.
|
|
||||||
"""
|
|
||||||
channels = []
|
|
||||||
Fs, frames = wavfile.read(path)
|
|
||||||
wave_object = wave.open(path)
|
|
||||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
|
||||||
assert Fs == self.fingerprinter.Fs
|
|
||||||
|
|
||||||
for channel in range(nchannels):
|
|
||||||
channels.append(frames[:, channel])
|
|
||||||
return channels
|
|
|
@ -280,15 +280,11 @@ class SQLDatabase(Database):
|
||||||
def return_matches(self, hashes):
|
def return_matches(self, hashes):
|
||||||
"""
|
"""
|
||||||
Return the (song_id, offset_diff) tuples associated with
|
Return the (song_id, offset_diff) tuples associated with
|
||||||
a list of
|
a list of (sha1, sample_offset) values.
|
||||||
|
|
||||||
sha1 => (None, sample_offset)
|
|
||||||
|
|
||||||
values.
|
|
||||||
"""
|
"""
|
||||||
# Create a dictionary of hash => offset pairs for later lookups
|
# Create a dictionary of hash => offset pairs for later lookups
|
||||||
mapper = {}
|
mapper = {}
|
||||||
for hash, (_, offset) in hashes:
|
for hash, offset in hashes:
|
||||||
mapper[hash.upper()] = offset
|
mapper[hash.upper()] = offset
|
||||||
|
|
||||||
# Get an iterable of all the hashes we need
|
# Get an iterable of all the hashes we need
|
||||||
|
|
271
dejavu/fingerprint.py
Normal file → Executable file
271
dejavu/fingerprint.py
Normal file → Executable file
|
@ -13,19 +13,117 @@ import time
|
||||||
import hashlib
|
import hashlib
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
IDX_FREQ_I = 0
|
||||||
|
IDX_TIME_J = 1
|
||||||
|
|
||||||
|
DEFAULT_FS = 44100
|
||||||
|
DEFAULT_WINDOW_SIZE = 4096
|
||||||
|
DEFAULT_OVERLAP_RATIO = 0.5
|
||||||
|
DEFAULT_FAN_VALUE = 15
|
||||||
|
|
||||||
|
DEFAULT_AMP_MIN = 10
|
||||||
|
PEAK_NEIGHBORHOOD_SIZE = 20
|
||||||
|
MIN_HASH_TIME_DELTA = 0
|
||||||
|
|
||||||
|
def fingerprint(channel_samples,
                Fs=DEFAULT_FS,
                wsize=DEFAULT_WINDOW_SIZE,
                wratio=DEFAULT_OVERLAP_RATIO,
                fan_value=DEFAULT_FAN_VALUE,
                amp_min=DEFAULT_AMP_MIN):
    """
    Computes the spectrogram of one channel, converts it to a log scale,
    picks its local maxima and returns the hashes generated from them.

    `wsize` is the FFT window size, `wratio` the fraction of the window
    that consecutive windows overlap, `fan_value` and `amp_min` are passed
    on to generate_hashes() and get_2D_peaks() respectively.
    """
    overlap = int(wsize * wratio)

    # specgram() returns several values; only the spectrum itself is needed
    spectrum = mlab.specgram(channel_samples,
                             NFFT=wsize,
                             Fs=Fs,
                             window=mlab.window_hanning,
                             noverlap=overlap)[0]

    # specgram() output is linear, so move it to a log scale
    log_spectrum = 10 * np.log10(spectrum)
    # log10(0) yields -inf; treat those cells as zero
    log_spectrum[log_spectrum == -np.inf] = 0

    peaks = get_2D_peaks(log_spectrum, plot=False, amp_min=amp_min)
    return generate_hashes(peaks, fan_value=fan_value)
|
||||||
|
|
||||||
|
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
    """
    Finds the local maxima of `arr2D` whose value exceeds `amp_min`.

    Returns a list of (row, column) index tuples. When called from
    fingerprint() the rows are frequency bins and the columns are time
    segments, i.e. the tuples are (frequency_idx, time_idx).

    If `plot` is true the peaks are also drawn over the array with
    matplotlib.
    """
    # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
    struct = generate_binary_structure(2, 1)
    neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)

    # find local maxima using our filter shape
    local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
    background = (arr2D == 0)
    eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)

    # boolean mask of arr2D with True at peaks; XOR is what `-` used to do
    # for boolean arrays and, unlike `-`, is accepted by current NumPy
    detected_peaks = local_max ^ eroded_background

    # extract peaks (np.where and mask indexing enumerate cells in the same order)
    amps = arr2D[detected_peaks]
    rows, cols = np.where(detected_peaks)

    # filter peaks, keeping (frequency, time) order
    peaks = [(row, col)
             for row, col, amp in zip(rows, cols, amps)
             if amp > amp_min]

    if plot:
        frequency_idx = [peak[0] for peak in peaks]
        time_idx = [peak[1] for peak in peaks]

        # scatter of the peaks
        fig, ax = plt.subplots()
        ax.imshow(arr2D)
        ax.scatter(time_idx, frequency_idx)
        ax.set_xlabel('Time')
        ax.set_ylabel('Frequency')
        ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke")
        plt.gca().invert_yaxis()
        plt.show()

    return peaks
|
||||||
|
|
||||||
|
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
    """
    Pairs every peak with the peaks that follow it and hashes each pair.

    `peaks` is an iterable of tuples indexed by IDX_FREQ_I / IDX_TIME_J and
    `fan_value` is how many consecutive peaks (starting with the peak
    itself) each peak is paired with.

    Hash list structure:
       sha1-hash[0:20]    time_offset
    [(e05b341a9b77a51fd26, 32), ... ]
    """
    # materialize so len() and indexing work for any iterable (e.g. zip)
    peaks = list(peaks)
    npeaks = len(peaks)
    hashes = []

    for i in range(npeaks):
        freq1 = peaks[i][IDX_FREQ_I]
        t1 = peaks[i][IDX_TIME_J]

        # NOTE: j starts at 0, so every peak is also paired with itself.
        # Kept as is because changing it would change the set of hashes
        # produced for a given input.
        for j in range(fan_value):
            if i + j >= npeaks:
                break  # no more peaks to pair with

            freq2 = peaks[i + j][IDX_FREQ_I]
            t2 = peaks[i + j][IDX_TIME_J]
            t_delta = t2 - t1

            if t_delta >= MIN_HASH_TIME_DELTA:
                key = "%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))
                # sha1 needs bytes; the key is plain ASCII
                h = hashlib.sha1(key.encode("utf-8"))
                hashes.append((h.hexdigest()[0:20], t1))

    return hashes
|
||||||
|
|
||||||
|
# TODO: move all of the below to a class with DB access
|
||||||
|
|
||||||
|
|
||||||
class Fingerprinter():
|
class Fingerprinter():
|
||||||
|
|
||||||
IDX_FREQ_I = 0
|
|
||||||
IDX_TIME_J = 1
|
|
||||||
|
|
||||||
DEFAULT_FS = 44100
|
|
||||||
DEFAULT_WINDOW_SIZE = 4096
|
|
||||||
DEFAULT_OVERLAP_RATIO = 0.5
|
|
||||||
DEFAULT_FAN_VALUE = 15
|
|
||||||
|
|
||||||
DEFAULT_AMP_MIN = 10
|
|
||||||
PEAK_NEIGHBORHOOD_SIZE = 20
|
|
||||||
MIN_HASH_TIME_DELTA = 0
|
|
||||||
|
|
||||||
def __init__(self, config,
|
def __init__(self, config,
|
||||||
Fs=DEFAULT_FS,
|
Fs=DEFAULT_FS,
|
||||||
|
@ -35,12 +133,7 @@ class Fingerprinter():
|
||||||
amp_min=DEFAULT_AMP_MIN):
|
amp_min=DEFAULT_AMP_MIN):
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
database = SQLDatabase(
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
|
||||||
self.db = database
|
|
||||||
|
|
||||||
self.Fs = Fs
|
self.Fs = Fs
|
||||||
self.dt = 1.0 / self.Fs
|
self.dt = 1.0 / self.Fs
|
||||||
|
@ -56,103 +149,14 @@ class Fingerprinter():
|
||||||
print "Generated %d hashes" % len(hashes)
|
print "Generated %d hashes" % len(hashes)
|
||||||
self.db.insert_hashes(hashes)
|
self.db.insert_hashes(hashes)
|
||||||
|
|
||||||
|
# TODO: put this in another module
|
||||||
def match(self, samples):
|
def match(self, samples):
|
||||||
"""Used for matching unknown songs"""
|
"""Used for matching unknown songs"""
|
||||||
hashes = self.process_channel(samples)
|
hashes = self.process_channel(samples)
|
||||||
matches = self.db.return_matches(hashes)
|
matches = self.db.return_matches(hashes)
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
def process_channel(self, channel_samples, song_id=None):
|
# TODO: this function has nothing to do with fingerprinting. is it needed?
|
||||||
"""
|
|
||||||
FFT the channel, log transform output, find local maxima, then return
|
|
||||||
locally sensitive hashes.
|
|
||||||
"""
|
|
||||||
# FFT the signal and extract frequency components
|
|
||||||
arr2D = mlab.specgram(
|
|
||||||
channel_samples,
|
|
||||||
NFFT=self.window_size,
|
|
||||||
Fs=self.Fs,
|
|
||||||
window=mlab.window_hanning,
|
|
||||||
noverlap=self.noverlap)[0]
|
|
||||||
|
|
||||||
# apply log transform since specgram() returns linear array
|
|
||||||
arr2D = 10 * np.log10(arr2D)
|
|
||||||
arr2D[arr2D == -np.inf] = 0 # replace infs with zeros
|
|
||||||
|
|
||||||
# find local maxima
|
|
||||||
local_maxima = self.get_2D_peaks(arr2D, plot=False)
|
|
||||||
|
|
||||||
# return hashes
|
|
||||||
return self.generate_hashes(local_maxima, song_id=song_id)
|
|
||||||
|
|
||||||
def get_2D_peaks(self, arr2D, plot=False):
|
|
||||||
|
|
||||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
|
|
||||||
struct = generate_binary_structure(2, 1)
|
|
||||||
neighborhood = iterate_structure(struct, Fingerprinter.PEAK_NEIGHBORHOOD_SIZE)
|
|
||||||
|
|
||||||
# find local maxima using our fliter shape
|
|
||||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
|
||||||
background = (arr2D == 0)
|
|
||||||
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
|
|
||||||
detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks
|
|
||||||
|
|
||||||
# extract peaks
|
|
||||||
amps = arr2D[detected_peaks]
|
|
||||||
j, i = np.where(detected_peaks)
|
|
||||||
|
|
||||||
# filter peaks
|
|
||||||
amps = amps.flatten()
|
|
||||||
peaks = zip(i, j, amps)
|
|
||||||
peaks_filtered = [x for x in peaks if x[2] > self.amp_min] # freq, time, amp
|
|
||||||
|
|
||||||
# get indices for frequency and time
|
|
||||||
frequency_idx = [x[1] for x in peaks_filtered]
|
|
||||||
time_idx = [x[0] for x in peaks_filtered]
|
|
||||||
|
|
||||||
if plot:
|
|
||||||
# scatter of the peaks
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.imshow(arr2D)
|
|
||||||
ax.scatter(time_idx, frequency_idx)
|
|
||||||
ax.set_xlabel('Time')
|
|
||||||
ax.set_ylabel('Frequency')
|
|
||||||
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke");
|
|
||||||
plt.gca().invert_yaxis()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
return zip(frequency_idx, time_idx)
|
|
||||||
|
|
||||||
def generate_hashes(self, peaks, song_id=None):
|
|
||||||
"""
|
|
||||||
Hash list structure:
|
|
||||||
sha1-hash[0:20] song_id, time_offset
|
|
||||||
[(e05b341a9b77a51fd26, (3, 32)), ... ]
|
|
||||||
"""
|
|
||||||
fingerprinted = set() # to avoid rehashing same pairs
|
|
||||||
hashes = []
|
|
||||||
|
|
||||||
for i in range(len(peaks)):
|
|
||||||
for j in range(self.fan_value):
|
|
||||||
if i+j < len(peaks) and not (i, i+j) in fingerprinted:
|
|
||||||
|
|
||||||
freq1 = peaks[i][Fingerprinter.IDX_FREQ_I]
|
|
||||||
freq2 = peaks[i+j][Fingerprinter.IDX_FREQ_I]
|
|
||||||
t1 = peaks[i][Fingerprinter.IDX_TIME_J]
|
|
||||||
t2 = peaks[i+j][Fingerprinter.IDX_TIME_J]
|
|
||||||
t_delta = t2 - t1
|
|
||||||
|
|
||||||
if t_delta >= Fingerprinter.MIN_HASH_TIME_DELTA:
|
|
||||||
h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
|
|
||||||
hashes.append((h.hexdigest()[0:20], (song_id, t1)))
|
|
||||||
|
|
||||||
# ensure we don't repeat hashing
|
|
||||||
fingerprinted.add((i, i+j))
|
|
||||||
return hashes
|
|
||||||
|
|
||||||
def insert_into_db(self, key, value):
|
|
||||||
self.db.insert(key, value)
|
|
||||||
|
|
||||||
def print_stats(self):
|
def print_stats(self):
|
||||||
|
|
||||||
iterable = self.db.get_iterable_kv_pairs()
|
iterable = self.db.get_iterable_kv_pairs()
|
||||||
|
@ -169,58 +173,9 @@ class Fingerprinter():
|
||||||
song_name = self.song_names[song_id]
|
song_name = self.song_names[song_id]
|
||||||
print "%s has %d spectrogram peaks" % (song_name, count)
|
print "%s has %d spectrogram peaks" % (song_name, count)
|
||||||
|
|
||||||
|
# this does... what? this seems to only be used for the above function
|
||||||
def set_song_names(self, wpaths):
|
def set_song_names(self, wpaths):
|
||||||
self.song_names = wpaths
|
self.song_names = wpaths
|
||||||
|
|
||||||
def align_matches(self, matches, starttime, record_seconds=0, verbose=False):
|
# TODO: put this in another module
|
||||||
"""
|
|
||||||
Finds hash matches that align in time with other matches and finds
|
|
||||||
consensus about which hashes are "true" signal from the audio.
|
|
||||||
|
|
||||||
Returns a dictionary with match information.
|
|
||||||
"""
|
|
||||||
# align by diffs
|
|
||||||
diff_counter = {}
|
|
||||||
largest = 0
|
|
||||||
largest_count = 0
|
|
||||||
song_id = -1
|
|
||||||
for tup in matches:
|
|
||||||
sid, diff = tup
|
|
||||||
if not diff in diff_counter:
|
|
||||||
diff_counter[diff] = {}
|
|
||||||
if not sid in diff_counter[diff]:
|
|
||||||
diff_counter[diff][sid] = 0
|
|
||||||
diff_counter[diff][sid] += 1
|
|
||||||
|
|
||||||
if diff_counter[diff][sid] > largest_count:
|
|
||||||
largest = diff
|
|
||||||
largest_count = diff_counter[diff][sid]
|
|
||||||
song_id = sid
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print "Diff is %d with %d offset-aligned matches" % (largest, largest_count)
|
|
||||||
|
|
||||||
# extract idenfication
|
|
||||||
song = self.db.get_song_by_id(song_id)
|
|
||||||
if song:
|
|
||||||
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
songname = songname.replace("_", " ")
|
|
||||||
elapsed = time.time() - starttime
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print "Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed)
|
|
||||||
|
|
||||||
# return match info
|
|
||||||
song = {
|
|
||||||
"song_id" : song_id,
|
|
||||||
"song_name" : songname,
|
|
||||||
"match_time" : elapsed,
|
|
||||||
"confidence" : largest_count
|
|
||||||
}
|
|
||||||
|
|
||||||
if record_seconds:
|
|
||||||
song['record_time'] = record_seconds
|
|
||||||
|
|
||||||
return song
|
|
||||||
|
|
144
dejavu/recognize.py
Normal file → Executable file
144
dejavu/recognize.py
Normal file → Executable file
|
@ -1,5 +1,7 @@
|
||||||
from multiprocessing import Queue, Process
|
from multiprocessing import Queue, Process
|
||||||
from dejavu.database import SQLDatabase
|
from dejavu.database import SQLDatabase
|
||||||
|
import dejavu.fingerprint
|
||||||
|
from dejavu import Dejavu
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
import wave
|
import wave
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -8,65 +10,117 @@ import sys
|
||||||
import time
|
import time
|
||||||
import array
|
import array
|
||||||
|
|
||||||
class Recognizer(object):
|
|
||||||
|
class BaseRecognizer(object):
    """
    Common base for recognizers: holds the Dejavu instance and the sample
    rate and turns per-channel sample data into a single match result.
    """

    def __init__(self, dejavu):
        """
        `dejavu` is the object whose find_matches()/align_matches() are
        used for recognition.
        """
        # The `dejavu` parameter hides the `dejavu` package inside this
        # method, so the default sample rate is imported explicitly.
        from dejavu.fingerprint import DEFAULT_FS

        self.dejavu = dejavu
        self.Fs = DEFAULT_FS

    def _recognize(self, *data):
        """
        Looks up matches for each channel in `data` separately, pools them
        and returns the aligned result from dejavu.align_matches().
        """
        matches = []
        for d in data:
            # pass the single channel, not the whole tuple of channels
            matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
        return self.dejavu.align_matches(matches)

    def recognize(self):
        pass  # base class does nothing
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class WaveFileRecognizer(BaseRecognizer):
    """
    Recognizes the audio stored in a WAV file.
    """

    def __init__(self, dejavu, filename=None):
        """
        `filename` is the file used by recognize(); recognize_file() can be
        called with any other file.
        """
        super(WaveFileRecognizer, self).__init__(dejavu)
        self.filename = filename

    def recognize_file(self, filename):
        """
        Reads `filename`, matches all of its channels and returns the
        match with a 'match_time' entry (seconds spent matching) added.
        Returns whatever falsy value _recognize() gave when nothing matched.
        """
        Fs, frames = wavfile.read(filename)
        self.Fs = Fs

        # wavfile.read gives a 1-D array for mono files and a 2-D
        # (samples, channels) array otherwise
        if frames.ndim == 1:
            channels = [frames]
        else:
            channels = [frames[:, c] for c in range(frames.shape[1])]

        t = time.time()
        match = self._recognize(*channels)
        t = time.time() - t

        if match:
            match['match_time'] = t

        return match

    def recognize(self):
        """
        Recognizes the file given to the constructor.
        """
        return self.recognize_file(self.filename)
|
||||||
|
|
||||||
|
|
||||||
|
class MicrophoneRecognizer(BaseRecognizer):
    """
    Records audio from a PyAudio input stream and recognizes it.
    """

    CHUNK = 8192  # frames read from the stream per buffer
    FORMAT = pyaudio.paInt16
    CHANNELS = 2
    RATE = 44100

    def __init__(self, dejavu, seconds=None):
        """
        `seconds` is how long recognize() records; it is only required when
        recognize() is used.
        """
        super(MicrophoneRecognizer, self).__init__(dejavu)
        self.audio = pyaudio.PyAudio()
        self.stream = None
        self.data = []
        self.seconds = seconds
        # class constants are not visible as bare names inside methods
        self.channels = MicrophoneRecognizer.CHANNELS
        self.chunk_size = MicrophoneRecognizer.CHUNK
        self.rate = MicrophoneRecognizer.RATE
        self.recorded = False

    def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
        """
        Opens a fresh input stream (closing a previous one if it is still
        open) and clears the recorded data.
        """
        self.chunk_size = chunk
        self.channels = channels
        self.recorded = False
        self.rate = rate

        if self.stream:
            self.stream.stop_stream()
            self.stream.close()

        self.stream = self.audio.open(format=MicrophoneRecognizer.FORMAT,
                                      channels=channels,
                                      rate=rate,
                                      input=True,
                                      frames_per_buffer=chunk)

        self.data = [[] for _ in range(channels)]

    def process_recording(self):
        """
        Reads one chunk from the stream and appends its samples to the
        per-channel buffers.
        """
        data = self.stream.read(self.chunk_size)
        nums = np.frombuffer(data, dtype=np.int16)
        # samples are interleaved: channel c owns every `channels`-th
        # sample starting at offset c
        for c in range(self.channels):
            self.data[c].extend(nums[c::self.channels])

    def stop_recording(self):
        """
        Closes the stream and marks the recording as complete.
        """
        self.stream.stop_stream()
        self.stream.close()
        self.stream = None
        self.recorded = True

    def recognize_recording(self):
        """
        Recognizes the recorded data.
        Raises NoRecordingError if no recording has been completed.
        """
        if not self.recorded:
            raise NoRecordingError("Recording was not complete/begun")
        return self._recognize(*self.data)

    def get_recorded_time(self):
        """
        Returns the length of the recorded data in seconds.
        """
        if not self.data:
            return 0.0
        return len(self.data[0]) / float(self.rate)

    def recognize(self):
        """
        Records for `self.seconds` seconds and recognizes the result.
        Raises ValueError if no duration was given to the constructor.
        """
        if self.seconds is None:
            raise ValueError("seconds must be set to record from the microphone")

        self.start_recording()
        # float division so the chunk count is not truncated early
        nchunks = int(self.rate * self.seconds / float(self.chunk_size))
        for _ in range(nchunks):
            self.process_recording()
        self.stop_recording()
        return self.recognize_recording()
|
||||||
|
|
||||||
|
class NoRecordingError(Exception):
    """
    Raised when a recording is to be recognized before one was completed.
    """
|
||||||
|
|
||||||
# align and return
|
|
||||||
return self.fingerprinter.align_matches(matches, starttime, record_seconds=seconds, verbose=verbose)
|
|
2
go.py
Normal file → Executable file
2
go.py
Normal file → Executable file
|
@ -1,4 +1,4 @@
|
||||||
from dejavu.control import Dejavu
|
from dejavu import Dejavu
|
||||||
from ConfigParser import ConfigParser
|
from ConfigParser import ConfigParser
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
Loading…
Reference in a new issue