From 29029238fcfc7048d86b8774ac455eac12deea5d Mon Sep 17 00:00:00 2001
From: Wessie
Date: Thu, 19 Dec 2013 00:54:17 +0100
Subject: [PATCH] A fairly big batch of changes; blame me for forgetting to commit

Changes in no particular order:
- Replaced all use cases of wavfile/wave and extract_channels with the
  new decoder.read function
- Added a 'recognize' method to the Dejavu class. This is a shortcut
  for recognizing songs (see the usage sketch after the diff).
- Renamed 'do_fingerprint' to 'fingerprint_directory'
- Removed parameters that are no longer required from
  fingerprint_directory
- Cleaned up fingerprint.py
- Made fingerprint.generate_hashes a generator
- WaveFileRecognizer is now FileRecognizer and can take any format
  supported by pydub
- Fixed MicrophoneRecognizer so it actually runs; the previous version
  had many small mistakes
- Renamed 'fingerprint_worker' to '_fingerprint_worker' to signify it
  is not meant to be used publicly
- Moved 'chunkify' outside the Dejavu class
- Cleaned up PEP 8 style mistakes in all edited files.
---
 dejavu/__init__.py               | 102 ++++++++++------------
 dejavu/{decode.py => decoder.py} |   0
 dejavu/fingerprint.py            | 145 +++++++++----------------------
 dejavu/recognize.py              |  78 ++++++++---------
 4 files changed, 120 insertions(+), 205 deletions(-)
 rename dejavu/{decode.py => decoder.py} (100%)

diff --git a/dejavu/__init__.py b/dejavu/__init__.py
index f9be8bd..c10e54d 100755
--- a/dejavu/__init__.py
+++ b/dejavu/__init__.py
@@ -1,27 +1,25 @@
 from dejavu.database import SQLDatabase
-import dejavu.decode as decoder
+import dejavu.decoder as decoder
 import fingerprint
-from scipy.io import wavfile
-from multiprocessing import Process
-import wave, os
+from multiprocessing import Process, cpu_count
+import os
 import random
 
-DEBUG = False
 
 class Dejavu(object):
 
     def __init__(self, config):
+        super(Dejavu, self).__init__()
         self.config = config
 
         # initialize db
         self.db = SQLDatabase(**config.get("database", {}))
-        #self.fingerprinter = Fingerprinter(self.config)
         self.db.setup()
 
         # get songs previously indexed
         self.songs = self.db.get_songs()
-        self.songnames_set = set() # to know which ones we've computed before
+        self.songnames_set = set()  # to know which ones we've computed before
 
         for song in self.songs:
             song_name = song[self.db.FIELD_SONGNAME]
@@ -29,27 +27,27 @@ class Dejavu(object):
             self.songnames_set.add(song_name)
             print "Added: %s to the set of fingerprinted songs..." % song_name
 
-    def chunkify(self, lst, n):
-        """
-        Splits a list into roughly n equal parts.
-        http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
-        """
-        return [lst[i::n] for i in xrange(n)]
-
-    def do_fingerprint(self, path, output, extensions, nprocesses):
+    def fingerprint_directory(self, path, extensions, nprocesses=None):
+        # Try to use the maximum number of processes if not given.
+        if nprocesses is None:
+            try:
+                nprocesses = cpu_count()
+            except NotImplementedError:
+                nprocesses = 1
 
         # convert files, shuffle order
-        files = decoder.find_files(path, extensions)
+        files = list(decoder.find_files(path, extensions))
         random.shuffle(files)
-        files_split = self.chunkify(files, nprocesses)
+
+        files_split = chunkify(files, nprocesses)
 
         # split into processes here
         processes = []
         for i in range(nprocesses):
 
             # create process and start it
-            p = Process(target=self.fingerprint_worker,
-                        args=(files_split[i], self.db, output))
+            p = Process(target=self._fingerprint_worker,
+                        args=(files_split[i], self.db))
             p.start()
             processes.append(p)
 
@@ -57,55 +55,37 @@ class Dejavu(object):
         for p in processes:
             p.join()
 
-        # delete orphans
-        # print "Done fingerprinting. Deleting orphaned fingerprints..."
-        # TODO: need a more performant query in database.py for the
-        #self.fingerprinter.db.delete_orphans()
-
-    def fingerprint_worker(self, files, sql_connection, output):
-
+    def _fingerprint_worker(self, files, db):
         for filename, extension in files:
-            # if there are already fingerprints in database, don't re-fingerprint or convert
+            # if there are already fingerprints in database,
+            # don't re-fingerprint
             song_name = os.path.basename(filename).split(".")[0]
-            if DEBUG and song_name in self.songnames_set:
+            if song_name in self.songnames_set:
                 print("-> Already fingerprinted, continuing...")
                 continue
 
             channels, Fs = decoder.read(filename)
 
             # insert song name into database
-            song_id = sql_connection.insert_song(song_name)
+            song_id = db.insert_song(song_name)
 
             for c in range(len(channels)):
                 channel = channels[c]
                 print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
+
                 hashes = fingerprint.fingerprint(channel, Fs=Fs)
-                sql_connection.insert_hashes(song_id, hashes)
+
+                db.insert_hashes(song_id, hashes)
 
             # only after done fingerprinting do confirm
-            sql_connection.set_song_fingerprinted(song_id)
-
-    def extract_channels(self, path):
-        """
-        Reads channels from disk.
-        Returns a tuple with (channels, sample_rate)
-        """
-        channels = []
-        Fs, frames = wavfile.read(path)
-        wave_object = wave.open(path)
-        nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
-        #assert Fs == self.fingerprinter.Fs
-
-        for channel in range(nchannels):
-            channels.append(frames[:, channel])
-        return (channels, Fs)
+            db.set_song_fingerprinted(song_id)
 
     def fingerprint_file(self, filepath, song_name=None):
-        # TODO: replace with something that handles all audio formats
-        channels, Fs = self.extract_channels(path)
+        channels, Fs = decoder.read(filepath)
+
         if not song_name:
-            song_name = os.path.basename(filename).split(".")[0]
+            song_name = os.path.basename(filepath).split(".")[0]
         song_id = self.db.insert_song(song_name)
 
         for data in channels:
@@ -141,8 +121,7 @@ class Dejavu(object):
                     largest_count = diff_counter[diff][sid]
                     song_id = sid
 
-        if DEBUG:
-            print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
+        print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
 
         # extract idenfication
         song = self.db.get_song_by_id(song_id)
@@ -151,14 +130,23 @@
         else:
            return None
 
-        if DEBUG:
-            print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
-
        # return match info
        song = {
-            "song_id" : song_id,
-            "song_name" : songname,
-            "confidence" : largest_count
+            "song_id": song_id,
+            "song_name": songname,
+            "confidence": largest_count
        }
 
        return song
+
+    def recognize(self, recognizer, *options, **kwoptions):
+        r = recognizer(self)
+        return r.recognize(*options, **kwoptions)
+
+
+def chunkify(lst, n):
+    """
+    Splits a list into roughly n equal parts.
+    http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
+    """
+    return [lst[i::n] for i in xrange(n)]
diff --git a/dejavu/decode.py b/dejavu/decoder.py
similarity index 100%
rename from dejavu/decode.py
rename to dejavu/decoder.py
diff --git a/dejavu/fingerprint.py b/dejavu/fingerprint.py
index 6108799..caf252e 100755
--- a/dejavu/fingerprint.py
+++ b/dejavu/fingerprint.py
@@ -1,17 +1,11 @@
 import numpy as np
 import matplotlib.mlab as mlab
 import matplotlib.pyplot as plt
-import matplotlib.image as mpimg
-from scipy.io import wavfile
 from scipy.ndimage.filters import maximum_filter
-from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion
-from dejavu.database import SQLDatabase
-import os
-import wave
-import sys
-import time
+from scipy.ndimage.morphology import (generate_binary_structure,
+                                      iterate_structure, binary_erosion)
 import hashlib
-import pickle
+
 
 IDX_FREQ_I = 0
 IDX_TIME_J = 1
@@ -25,55 +19,58 @@
 DEFAULT_AMP_MIN = 10
 PEAK_NEIGHBORHOOD_SIZE = 20
 MIN_HASH_TIME_DELTA = 0
 
-def fingerprint(channel_samples,
-                Fs=DEFAULT_FS,
-                wsize=DEFAULT_WINDOW_SIZE,
-                wratio=DEFAULT_OVERLAP_RATIO,
-                fan_value=DEFAULT_FAN_VALUE,
-                amp_min=DEFAULT_AMP_MIN):
+
+def fingerprint(channel_samples, Fs=DEFAULT_FS,
+                wsize=DEFAULT_WINDOW_SIZE,
+                wratio=DEFAULT_OVERLAP_RATIO,
+                fan_value=DEFAULT_FAN_VALUE,
+                amp_min=DEFAULT_AMP_MIN):
     """
-    FFT the channel, log transform output, find local maxima, then return
-    locally sensitive hashes.
+    FFT the channel, log transform output, find local maxima, then return
+    locally sensitive hashes.
     """
     # FFT the signal and extract frequency components
     arr2D = mlab.specgram(
-        channel_samples,
-        NFFT=wsize,
+        channel_samples,
+        NFFT=wsize,
         Fs=Fs,
         window=mlab.window_hanning,
         noverlap=int(wsize * wratio))[0]
 
     # apply log transform since specgram() returns linear array
     arr2D = 10 * np.log10(arr2D)
-    arr2D[arr2D == -np.inf] = 0 # replace infs with zeros
-
+    arr2D[arr2D == -np.inf] = 0  # replace infs with zeros
+
     # find local maxima
     local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
 
     # return hashes
     return generate_hashes(local_maxima, fan_value=fan_value)
 
-def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
+def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
     # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
     struct = generate_binary_structure(2, 1)
     neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
 
     # find local maxima using our fliter shape
-    local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
+    local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
     background = (arr2D == 0)
-    eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
-    detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks
+    eroded_background = binary_erosion(background, structure=neighborhood,
+                                       border_value=1)
+
+    # Boolean mask of arr2D with True at peaks
+    detected_peaks = local_max - eroded_background
 
     # extract peaks
     amps = arr2D[detected_peaks]
-    j, i = np.where(detected_peaks)
+    j, i = np.where(detected_peaks)
 
     # filter peaks
     amps = amps.flatten()
     peaks = zip(i, j, amps)
-    peaks_filtered = [x for x in peaks if x[2] > amp_min] # freq, time, amp
-
+    peaks_filtered = [x for x in peaks if x[2] > amp_min]  # freq, time, amp
+
     # get indices for frequency and time
     frequency_idx = [x[1] for x in peaks_filtered]
     time_idx = [x[0] for x in peaks_filtered]
@@ -85,97 +82,37 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
         ax.scatter(time_idx, frequency_idx)
         ax.set_xlabel('Time')
         ax.set_ylabel('Frequency')
-        ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke");
+        ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke")
         plt.gca().invert_yaxis()
         plt.show()
 
     return zip(frequency_idx, time_idx)
 
+
 def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
     """
     Hash list structure:
-       sha1-hash[0:20] time_offset
+       sha1_hash[0:20] time_offset
     [(e05b341a9b77a51fd26, 32), ...
    ]
    """
-    fingerprinted = set() # to avoid rehashing same pairs
-    hashes = []
+    fingerprinted = set()  # to avoid rehashing same pairs
 
    for i in range(len(peaks)):
        for j in range(fan_value):
-            if i+j < len(peaks) and not (i, i+j) in fingerprinted:
-
+            if (i + j) < len(peaks) and not (i, i + j) in fingerprinted:
                freq1 = peaks[i][IDX_FREQ_I]
-                freq2 = peaks[i+j][IDX_FREQ_I]
+                freq2 = peaks[i + j][IDX_FREQ_I]
+
                t1 = peaks[i][IDX_TIME_J]
-                t2 = peaks[i+j][IDX_TIME_J]
+                t2 = peaks[i + j][IDX_TIME_J]
+
                t_delta = t2 - t1
-
+
                if t_delta >= MIN_HASH_TIME_DELTA:
-                    h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
-                    hashes.append((h.hexdigest()[0:20], t1))
-
+                    h = hashlib.sha1(
+                        "%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))
+                    )
+                    yield (h.hexdigest()[0:20], t1)
+
                # ensure we don't repeat hashing
-                fingerprinted.add((i, i+j))
-    return hashes
-
-# TODO: move all of the below to a class with DB access
-
-
-class Fingerprinter():
-
-
-
-    def __init__(self, config,
-                 Fs=DEFAULT_FS,
-                 wsize=DEFAULT_WINDOW_SIZE,
-                 wratio=DEFAULT_OVERLAP_RATIO,
-                 fan_value=DEFAULT_FAN_VALUE,
-                 amp_min=DEFAULT_AMP_MIN):
-
-        self.config = config
-
-
-        self.Fs = Fs
-        self.dt = 1.0 / self.Fs
-        self.window_size = wsize
-        self.window_overlap_ratio = wratio
-        self.fan_value = fan_value
-        self.noverlap = int(self.window_size * self.window_overlap_ratio)
-        self.amp_min = amp_min
-
-    def fingerprint(self, samples, path, sid, cid):
-        """Used for learning known songs"""
-        hashes = self.process_channel(samples, song_id=sid)
-        print "Generated %d hashes" % len(hashes)
-        self.db.insert_hashes(hashes)
-
-    # TODO: put this in another module
-    def match(self, samples):
-        """Used for matching unknown songs"""
-        hashes = self.process_channel(samples)
-        matches = self.db.return_matches(hashes)
-        return matches
-
-    # TODO: this function has nothing to do with fingerprinting. is it needed?
-    def print_stats(self):
-
-        iterable = self.db.get_iterable_kv_pairs()
-
-        counter = {}
-        for t in iterable:
-            sid, toff = t
-            if not sid in counter:
-                counter[sid] = 1
-            else:
-                counter[sid] += 1
-
-        for song_id, count in counter.iteritems():
-            song_name = self.song_names[song_id]
-            print "%s has %d spectrogram peaks" % (song_name, count)
-
-    # this does... what? this seems to only be used for the above function
-    def set_song_names(self, wpaths):
-        self.song_names = wpaths
-
-    # TODO: put this in another module
-
+            fingerprinted.add((i, i + j))
diff --git a/dejavu/recognize.py b/dejavu/recognize.py
index 9283c6a..a723197 100755
--- a/dejavu/recognize.py
+++ b/dejavu/recognize.py
@@ -1,14 +1,8 @@
-from multiprocessing import Queue, Process
-from dejavu.database import SQLDatabase
 import dejavu.fingerprint as fingerprint
-from dejavu import Dejavu
-from scipy.io import wavfile
-import wave
+import dejavu.decoder as decoder
 import numpy as np
 import pyaudio
-import sys
 import time
-import array
 
 
 class BaseRecognizer(object):
@@ -26,25 +20,17 @@ class BaseRecognizer(object):
     def recognize(self):
         pass # base class does nothing
 
-class WaveFileRecognizer(BaseRecognizer):
-    def __init__(self, dejavu, filename=None):
-        super(WaveFileRecognizer, self).__init__(dejavu)
-        self.filename = filename
+class FileRecognizer(BaseRecognizer):
+    def __init__(self, dejavu):
+        super(FileRecognizer, self).__init__(dejavu)
 
     def recognize_file(self, filename):
-        Fs, frames = wavfile.read(filename)
+        frames, Fs = decoder.read(filename)
         self.Fs = Fs
 
-        wave_object = wave.open(filename)
-        nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
-
-        channels = []
-        for channel in range(nchannels):
-            channels.append(frames[:, channel])
-
         t = time.time()
-        match = self._recognize(*channels)
+        match = self._recognize(*frames)
         t = time.time() - t
 
         if match:
@@ -52,50 +38,53 @@
 
         return match
 
-    def recognize(self):
-        return self.recognize_file(self.filename)
+    def recognize(self, filename):
+        return self.recognize_file(filename)
 
 
 class MicrophoneRecognizer(BaseRecognizer):
+    default_chunksize = 8192
+    default_format = pyaudio.paInt16
+    default_channels = 2
+    default_samplerate = 44100
 
-    CHUNK = 8192 # 44100 is a multiple of 1225
-    FORMAT = pyaudio.paInt16
-    CHANNELS = 2
-    RATE = 44100
-
-    def __init__(self, dejavu, seconds=None):
+    def __init__(self, dejavu):
         super(MicrophoneRecognizer, self).__init__(dejavu)
         self.audio = pyaudio.PyAudio()
         self.stream = None
         self.data = []
-        self.channels = CHANNELS
-        self.chunk_size = CHUNK
-        self.rate = RATE
+        self.channels = self.default_channels
+        self.chunksize = self.default_chunksize
+        self.samplerate = self.default_samplerate
         self.recorded = False
 
-    def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
-        self.chunk_size = chunk
+    def start_recording(self, channels=default_channels,
+                        samplerate=default_samplerate,
+                        chunksize=default_chunksize):
+        self.chunksize = chunksize
         self.channels = channels
         self.recorded = False
-        self.rate = rate
+        self.samplerate = samplerate
 
         if self.stream:
             self.stream.stop_stream()
             self.stream.close()
 
-        self.stream = self.audio.open(format=FORMAT,
-                                      channels=channels,
-                                      rate=rate,
-                                      input=True,
-                                      frames_per_buffer=chunk)
+        self.stream = self.audio.open(
+            format=self.default_format,
+            channels=channels,
+            rate=samplerate,
+            input=True,
+            frames_per_buffer=chunksize,
+        )
 
         self.data = [[] for i in range(channels)]
 
     def process_recording(self):
-        data = self.stream.read(self.chunk_size)
+        data = self.stream.read(self.chunksize)
         nums = np.fromstring(data, np.int16)
         for c in range(self.channels):
-            self.data[c].extend(nums[c::c+1])
+            self.data[c].extend(nums[c::self.channels])
 
     def stop_recording(self):
         self.stream.stop_stream()
@@ -111,13 +100,14 @@
     def get_recorded_time(self):
-        return len(self.data[0]) / self.rate
+        return len(self.data[0]) / self.samplerate
 
-    def recognize(self):
+    def recognize(self, seconds=None):
         self.start_recording()
-        for i in range(0, int(self.rate / self.chunk * self.seconds)):
+        for i in range(0, int(self.samplerate / self.chunksize
+                               * seconds)):
             self.process_recording()
         self.stop_recording()
         return self.recognize_recording()
+
 
 class NoRecordingError(Exception):
     pass
-
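
Usage sketch for the API changes listed in the commit message, assuming the package layout from this patch. The database settings, the "mp3" directory, the extension list and the 10 second recording time are placeholder values, and the exact keyword arguments accepted by SQLDatabase are defined in dejavu/database.py, which this patch does not touch.

    from dejavu import Dejavu
    from dejavu.recognize import FileRecognizer, MicrophoneRecognizer

    # Placeholder connection settings; SQLDatabase(**config["database"]) receives
    # these keywords directly, so they must match what database.py expects.
    config = {
        "database": {
            "host": "127.0.0.1",
            "user": "root",
            "passwd": "",
            "db": "dejavu",
        },
    }

    djv = Dejavu(config)

    # nprocesses now defaults to cpu_count(); the extension list format is
    # whatever decoder.find_files() expects (assumed here to be file suffixes).
    djv.fingerprint_directory("mp3", [".mp3"])

    # The new shortcut: Dejavu.recognize(RecognizerClass, *args) instantiates the
    # recognizer with this Dejavu instance and forwards the remaining arguments
    # to its recognize() method.
    song = djv.recognize(FileRecognizer, "mp3/some_song.mp3")

    # MicrophoneRecognizer records for the given number of seconds, then matches.
    song = djv.recognize(MicrophoneRecognizer, seconds=10)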
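
Since generate_hashes is now a generator, fingerprint.fingerprint() returns a lazy iterable of (sha1 prefix, time offset) pairs instead of a list. A minimal consumption sketch, assuming channel_samples is one channel as returned by decoder.read() and using 44100 Hz (the MicrophoneRecognizer default) as the sample rate:

    from dejavu import fingerprint

    hashes = fingerprint.fingerprint(channel_samples, Fs=44100)

    # Iterate lazily, or wrap in list(hashes) if len() or reuse is needed.
    for hash_prefix, offset in hashes:
        print hash_prefix, offset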