Refactored the fingerprint module

This commit is contained in:
Vin 2013-12-16 23:12:50 +00:00
parent 9aa80ff87c
commit 6a6ae94e3d

197
dejavu/fingerprint.py Normal file → Executable file
View file

@ -13,8 +13,6 @@ import time
import hashlib import hashlib
import pickle import pickle
class Fingerprinter():
IDX_FREQ_I = 0 IDX_FREQ_I = 0
IDX_TIME_J = 1 IDX_TIME_J = 1
@ -27,6 +25,106 @@ class Fingerprinter():
PEAK_NEIGHBORHOOD_SIZE = 20 PEAK_NEIGHBORHOOD_SIZE = 20
MIN_HASH_TIME_DELTA = 0 MIN_HASH_TIME_DELTA = 0
def fingerprint(channel_samples,
                Fs=DEFAULT_FS,
                wsize=DEFAULT_WINDOW_SIZE,
                wratio=DEFAULT_OVERLAP_RATIO,
                fan_value=DEFAULT_FAN_VALUE,
                amp_min=DEFAULT_AMP_MIN):
    """
    Turn one channel of audio samples into a list of fingerprint hashes.

    Steps: spectrogram (mlab.specgram) -> 10*log10 transform -> peak
    picking (get_2D_peaks) -> pairwise peak hashing (generate_hashes).

    channel_samples -- samples handed unchanged to mlab.specgram
    Fs              -- sampling frequency forwarded to mlab.specgram
    wsize           -- FFT window size (NFFT)
    wratio          -- fraction of wsize used as window overlap
    fan_value       -- forwarded to generate_hashes
    amp_min         -- forwarded to get_2D_peaks

    Returns the value produced by generate_hashes.
    """
    overlap = int(wsize * wratio)

    # element [0] of specgram()'s return value is the spectrum array
    spectrum = mlab.specgram(
        channel_samples,
        NFFT=wsize,
        Fs=Fs,
        window=mlab.window_hanning,
        noverlap=overlap)[0]

    # specgram() returns a linear array, so move it onto a log scale
    log_spectrum = 10 * np.log10(spectrum)
    # log10(0) produces -inf; overwrite those cells with zero
    log_spectrum[log_spectrum == -np.inf] = 0

    peaks = get_2D_peaks(log_spectrum, plot=False, amp_min=amp_min)
    return generate_hashes(peaks, fan_value=fan_value)
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
    """
    Locate local maxima in a 2D array and return their indices.

    arr2D   -- 2D array; row indices are treated as frequency and column
               indices as time
    plot    -- when true, show the array with the kept peaks overlaid
    amp_min -- only peaks whose value is strictly greater than this are kept

    Returns a list of (frequency_idx, time_idx) tuples.
    """
    # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
    struct = generate_binary_structure(2, 1)
    neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)

    # find local maxima using our filter shape
    local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
    background = (arr2D == 0)
    eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)

    # Boolean mask of arr2D with True at peaks. XOR is used instead of `-`
    # because NumPy no longer supports subtracting boolean arrays; XOR is
    # what the old boolean `-` computed.
    detected_peaks = local_max ^ eroded_background

    # Boolean indexing and np.where both walk the mask in row-major order,
    # so amps[k] belongs to (row_idx[k], col_idx[k]).
    amps = arr2D[detected_peaks]
    row_idx, col_idx = np.where(detected_peaks)

    # keep only peaks above the amplitude threshold, as (time, freq) pairs
    peaks_filtered = [
        (col, row)
        for col, row, amp in zip(col_idx, row_idx, amps)
        if amp > amp_min
    ]

    # get indices for frequency and time
    frequency_idx = [row for _, row in peaks_filtered]
    time_idx = [col for col, _ in peaks_filtered]

    if plot:
        # scatter of the peaks
        fig, ax = plt.subplots()
        ax.imshow(arr2D)
        ax.scatter(time_idx, frequency_idx)
        ax.set_xlabel('Time')
        ax.set_ylabel('Frequency')
        # NOTE(review): the title is hard-coded to one specific song
        ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke")
        plt.gca().invert_yaxis()
        plt.show()

    # materialise the pairs so callers can use len() and indexing
    return list(zip(frequency_idx, time_idx))
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
    """
    Hash pairs of peaks into fingerprints.

    Each peak i is paired with peaks i, i+1, ..., i+fan_value-1 (as far as
    the list extends). For every pair whose time difference is at least
    MIN_HASH_TIME_DELTA, the string "freq1|freq2|t_delta" is SHA-1 hashed.

    peaks     -- indexable sequence of peaks; IDX_FREQ_I / IDX_TIME_J select
                 the frequency and time entries of each peak
    fan_value -- number of consecutive peaks (starting with the peak itself)
                 each peak is paired with

    Hash list structure:
       sha1-hash[0:20]    time_offset
    [(e05b341a9b77a51fd26, 32), ... ]
    """
    # No "already hashed" bookkeeping is needed: every (i, i+j) index pair
    # is visited exactly once by the two loops below.
    hashes = []
    num_peaks = len(peaks)

    for i in range(num_peaks):
        freq1 = peaks[i][IDX_FREQ_I]
        t1 = peaks[i][IDX_TIME_J]

        # NOTE(review): j starts at 0, so each peak is also paired with
        # itself (t_delta == 0) -- confirm this is intended.
        for j in range(min(fan_value, num_peaks - i)):
            freq2 = peaks[i + j][IDX_FREQ_I]
            t2 = peaks[i + j][IDX_TIME_J]
            t_delta = t2 - t1

            if t_delta >= MIN_HASH_TIME_DELTA:
                h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
                hashes.append((h.hexdigest()[0:20], t1))

    return hashes
# TODO: move all of the below to a class with DB access
class Fingerprinter():
def __init__(self, config, def __init__(self, config,
Fs=DEFAULT_FS, Fs=DEFAULT_FS,
wsize=DEFAULT_WINDOW_SIZE, wsize=DEFAULT_WINDOW_SIZE,
@ -56,103 +154,14 @@ class Fingerprinter():
print "Generated %d hashes" % len(hashes) print "Generated %d hashes" % len(hashes)
self.db.insert_hashes(hashes) self.db.insert_hashes(hashes)
# TODO: put this in another module
def match(self, samples): def match(self, samples):
"""Used for matching unknown songs""" """Used for matching unknown songs"""
hashes = self.process_channel(samples) hashes = self.process_channel(samples)
matches = self.db.return_matches(hashes) matches = self.db.return_matches(hashes)
return matches return matches
def process_channel(self, channel_samples, song_id=None): # TODO: this function has nothing to do with fingerprinting. is it needed?
    """
    Turn one channel of audio samples into fingerprint hashes.

    Steps: spectrogram (mlab.specgram, configured from self.window_size,
    self.Fs and self.noverlap) -> 10*log10 transform -> peak picking
    (self.get_2D_peaks) -> hashing (self.generate_hashes).

    channel_samples -- samples handed unchanged to mlab.specgram
    song_id         -- forwarded to self.generate_hashes

    Returns the value produced by self.generate_hashes.
    """
    # element [0] of specgram()'s return value is the spectrum array
    spectrum = mlab.specgram(
        channel_samples,
        NFFT=self.window_size,
        Fs=self.Fs,
        window=mlab.window_hanning,
        noverlap=self.noverlap)[0]

    # specgram() returns a linear array, so move it onto a log scale
    log_spectrum = 10 * np.log10(spectrum)
    # log10(0) produces -inf; overwrite those cells with zero
    log_spectrum[log_spectrum == -np.inf] = 0

    peaks = self.get_2D_peaks(log_spectrum, plot=False)
    return self.generate_hashes(peaks, song_id=song_id)
def get_2D_peaks(self, arr2D, plot=False):
    """
    Locate local maxima in a 2D array and return their indices.

    arr2D -- 2D array; row indices are treated as frequency and column
             indices as time
    plot  -- when true, show the array with the kept peaks overlaid

    Only peaks whose value is strictly greater than self.amp_min are kept.
    Returns a list of (frequency_idx, time_idx) tuples.
    """
    # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
    struct = generate_binary_structure(2, 1)
    neighborhood = iterate_structure(struct, Fingerprinter.PEAK_NEIGHBORHOOD_SIZE)

    # find local maxima using our filter shape
    local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
    background = (arr2D == 0)
    eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)

    # Boolean mask of arr2D with True at peaks. XOR is used instead of `-`
    # because NumPy no longer supports subtracting boolean arrays; XOR is
    # what the old boolean `-` computed.
    detected_peaks = local_max ^ eroded_background

    # Boolean indexing and np.where both walk the mask in row-major order,
    # so amps[k] belongs to (row_idx[k], col_idx[k]).
    amps = arr2D[detected_peaks]
    row_idx, col_idx = np.where(detected_peaks)

    # keep only peaks above the amplitude threshold, as (time, freq) pairs
    peaks_filtered = [
        (col, row)
        for col, row, amp in zip(col_idx, row_idx, amps)
        if amp > self.amp_min
    ]

    # get indices for frequency and time
    frequency_idx = [row for _, row in peaks_filtered]
    time_idx = [col for col, _ in peaks_filtered]

    if plot:
        # scatter of the peaks
        fig, ax = plt.subplots()
        ax.imshow(arr2D)
        ax.scatter(time_idx, frequency_idx)
        ax.set_xlabel('Time')
        ax.set_ylabel('Frequency')
        # NOTE(review): the title is hard-coded to one specific song
        ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke")
        plt.gca().invert_yaxis()
        plt.show()

    # materialise the pairs so callers can use len() and indexing
    return list(zip(frequency_idx, time_idx))
def generate_hashes(self, peaks, song_id=None):
    """
    Hash pairs of peaks into fingerprints tagged with a song id.

    Each peak i is paired with peaks i, i+1, ..., i+self.fan_value-1 (as
    far as the list extends). For every pair whose time difference is at
    least Fingerprinter.MIN_HASH_TIME_DELTA, the string
    "freq1|freq2|t_delta" is SHA-1 hashed.

    peaks   -- indexable sequence of peaks; Fingerprinter.IDX_FREQ_I /
               IDX_TIME_J select the frequency and time entries of each peak
    song_id -- stored alongside the time offset in every returned entry

    Hash list structure:
       sha1-hash[0:20]    song_id, time_offset
    [(e05b341a9b77a51fd26, (3, 32)), ... ]
    """
    # No "already hashed" bookkeeping is needed: every (i, i+j) index pair
    # is visited exactly once by the two loops below.
    hashes = []
    num_peaks = len(peaks)

    for i in range(num_peaks):
        freq1 = peaks[i][Fingerprinter.IDX_FREQ_I]
        t1 = peaks[i][Fingerprinter.IDX_TIME_J]

        # NOTE(review): j starts at 0, so each peak is also paired with
        # itself (t_delta == 0) -- confirm this is intended.
        for j in range(min(self.fan_value, num_peaks - i)):
            freq2 = peaks[i + j][Fingerprinter.IDX_FREQ_I]
            t2 = peaks[i + j][Fingerprinter.IDX_TIME_J]
            t_delta = t2 - t1

            if t_delta >= Fingerprinter.MIN_HASH_TIME_DELTA:
                h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
                hashes.append((h.hexdigest()[0:20], (song_id, t1)))

    return hashes
def insert_into_db(self, key, value):
    """Store `value` under `key` by delegating to self.db.insert."""
    database = self.db
    database.insert(key, value)
def print_stats(self): def print_stats(self):
iterable = self.db.get_iterable_kv_pairs() iterable = self.db.get_iterable_kv_pairs()
@ -169,9 +178,11 @@ class Fingerprinter():
song_name = self.song_names[song_id] song_name = self.song_names[song_id]
print "%s has %d spectrogram peaks" % (song_name, count) print "%s has %d spectrogram peaks" % (song_name, count)
# NOTE: purpose unclear -- this setter appears to be used only by print_stats() above
def set_song_names(self, wpaths): def set_song_names(self, wpaths):
self.song_names = wpaths self.song_names = wpaths
# TODO: put this in another module
def align_matches(self, matches, starttime, record_seconds=0, verbose=False): def align_matches(self, matches, starttime, record_seconds=0, verbose=False):
""" """
Finds hash matches that align in time with other matches and finds Finds hash matches that align in time with other matches and finds