A fairly big batch of changes, blame me to forget committing

Changes in no particular order:
- Replaced all use cases of wavfile/wave and extract_channels with the new decoder.read function
- Added a 'recognize' method to the Dejavu class. This is a shortcut for recognizing songs.
- Renamed 'do_fingerprint' into 'fingerprint_directory'
- Removed parameters not required anymore from fingerprint_directory
- Cleaned up fingerprint.py
- Made fingerprint.generate_hashes a generator
- WaveFileRecognizer is now FileRecognizer and can take any formats supported by pydub
- Fixed MicrophoneRecognizer to actually run, previous version had many small mistakes
- Renamed 'fingerprint_worker' to '_fingerprint_worker' to signify it is not to be used publicly
- Moved 'chunkify' outside the Dejavu class
- Cleaned up pep8 styling mistakes in all edited files.
This commit is contained in:
Wessie 2013-12-19 00:54:17 +01:00
parent 7895bae23e
commit 29029238fc
4 changed files with 120 additions and 205 deletions

View file

@ -1,22 +1,20 @@
from dejavu.database import SQLDatabase from dejavu.database import SQLDatabase
import dejavu.decode as decoder import dejavu.decoder as decoder
import fingerprint import fingerprint
from scipy.io import wavfile from multiprocessing import Process, cpu_count
from multiprocessing import Process import os
import wave, os
import random import random
DEBUG = False
class Dejavu(object): class Dejavu(object):
def __init__(self, config): def __init__(self, config):
super(Dejavu, self).__init__()
self.config = config self.config = config
# initialize db # initialize db
self.db = SQLDatabase(**config.get("database", {})) self.db = SQLDatabase(**config.get("database", {}))
#self.fingerprinter = Fingerprinter(self.config)
self.db.setup() self.db.setup()
# get songs previously indexed # get songs previously indexed
@ -29,27 +27,27 @@ class Dejavu(object):
self.songnames_set.add(song_name) self.songnames_set.add(song_name)
print "Added: %s to the set of fingerprinted songs..." % song_name print "Added: %s to the set of fingerprinted songs..." % song_name
def chunkify(self, lst, n): def fingerprint_directory(self, path, extensions, nprocesses=None):
""" # Try to use the maximum amount of processes if not given.
Splits a list into roughly n equal parts. if nprocesses is None:
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts try:
""" nprocesses = cpu_count()
return [lst[i::n] for i in xrange(n)] except NotImplementedError:
nprocesses = 1
def do_fingerprint(self, path, output, extensions, nprocesses):
# convert files, shuffle order # convert files, shuffle order
files = decoder.find_files(path, extensions) files = list(decoder.find_files(path, extensions))
random.shuffle(files) random.shuffle(files)
files_split = self.chunkify(files, nprocesses)
files_split = chunkify(files, nprocesses)
# split into processes here # split into processes here
processes = [] processes = []
for i in range(nprocesses): for i in range(nprocesses):
# create process and start it # create process and start it
p = Process(target=self.fingerprint_worker, p = Process(target=self._fingerprint_worker,
args=(files_split[i], self.db, output)) args=(files_split[i], self.db))
p.start() p.start()
processes.append(p) processes.append(p)
@ -57,55 +55,37 @@ class Dejavu(object):
for p in processes: for p in processes:
p.join() p.join()
# delete orphans def _fingerprint_worker(self, files, db):
# print "Done fingerprinting. Deleting orphaned fingerprints..."
# TODO: need a more performant query in database.py for the
#self.fingerprinter.db.delete_orphans()
def fingerprint_worker(self, files, sql_connection, output):
for filename, extension in files: for filename, extension in files:
# if there are already fingerprints in database, don't re-fingerprint or convert # if there are already fingerprints in database,
# don't re-fingerprint
song_name = os.path.basename(filename).split(".")[0] song_name = os.path.basename(filename).split(".")[0]
if DEBUG and song_name in self.songnames_set: if song_name in self.songnames_set:
print("-> Already fingerprinted, continuing...") print("-> Already fingerprinted, continuing...")
continue continue
channels, Fs = decoder.read(filename) channels, Fs = decoder.read(filename)
# insert song name into database # insert song name into database
song_id = sql_connection.insert_song(song_name) song_id = db.insert_song(song_name)
for c in range(len(channels)): for c in range(len(channels)):
channel = channels[c] channel = channels[c]
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name) print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
hashes = fingerprint.fingerprint(channel, Fs=Fs) hashes = fingerprint.fingerprint(channel, Fs=Fs)
sql_connection.insert_hashes(song_id, hashes)
db.insert_hashes(song_id, hashes)
# only after done fingerprinting do confirm # only after done fingerprinting do confirm
sql_connection.set_song_fingerprinted(song_id) db.set_song_fingerprinted(song_id)
def extract_channels(self, path):
"""
Reads channels from disk.
Returns a tuple with (channels, sample_rate)
"""
channels = []
Fs, frames = wavfile.read(path)
wave_object = wave.open(path)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
#assert Fs == self.fingerprinter.Fs
for channel in range(nchannels):
channels.append(frames[:, channel])
return (channels, Fs)
def fingerprint_file(self, filepath, song_name=None): def fingerprint_file(self, filepath, song_name=None):
# TODO: replace with something that handles all audio formats channels, Fs = decoder.read(filepath)
channels, Fs = self.extract_channels(path)
if not song_name: if not song_name:
song_name = os.path.basename(filename).split(".")[0] song_name = os.path.basename(filepath).split(".")[0]
song_id = self.db.insert_song(song_name) song_id = self.db.insert_song(song_name)
for data in channels: for data in channels:
@ -141,7 +121,6 @@ class Dejavu(object):
largest_count = diff_counter[diff][sid] largest_count = diff_counter[diff][sid]
song_id = sid song_id = sid
if DEBUG:
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count)) print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
# extract idenfication # extract idenfication
@ -151,9 +130,6 @@ class Dejavu(object):
else: else:
return None return None
if DEBUG:
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
# return match info # return match info
song = { song = {
"song_id": song_id, "song_id": song_id,
@ -162,3 +138,15 @@ class Dejavu(object):
} }
return song return song
def recognize(self, recognizer, *options, **kwoptions):
r = recognizer(self)
return r.recognize(*options, **kwoptions)
def chunkify(lst, n):
"""
Splits a list into roughly n equal parts.
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
"""
return [lst[i::n] for i in xrange(n)]

View file

@ -1,17 +1,11 @@
import numpy as np import numpy as np
import matplotlib.mlab as mlab import matplotlib.mlab as mlab
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy.io import wavfile
from scipy.ndimage.filters import maximum_filter from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion from scipy.ndimage.morphology import (generate_binary_structure,
from dejavu.database import SQLDatabase iterate_structure, binary_erosion)
import os
import wave
import sys
import time
import hashlib import hashlib
import pickle
IDX_FREQ_I = 0 IDX_FREQ_I = 0
IDX_TIME_J = 1 IDX_TIME_J = 1
@ -25,8 +19,8 @@ DEFAULT_AMP_MIN = 10
PEAK_NEIGHBORHOOD_SIZE = 20 PEAK_NEIGHBORHOOD_SIZE = 20
MIN_HASH_TIME_DELTA = 0 MIN_HASH_TIME_DELTA = 0
def fingerprint(channel_samples,
Fs=DEFAULT_FS, def fingerprint(channel_samples, Fs=DEFAULT_FS,
wsize=DEFAULT_WINDOW_SIZE, wsize=DEFAULT_WINDOW_SIZE,
wratio=DEFAULT_OVERLAP_RATIO, wratio=DEFAULT_OVERLAP_RATIO,
fan_value=DEFAULT_FAN_VALUE, fan_value=DEFAULT_FAN_VALUE,
@ -53,8 +47,8 @@ def fingerprint(channel_samples,
# return hashes # return hashes
return generate_hashes(local_maxima, fan_value=fan_value) return generate_hashes(local_maxima, fan_value=fan_value)
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
struct = generate_binary_structure(2, 1) struct = generate_binary_structure(2, 1)
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
@ -62,8 +56,11 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
# find local maxima using our fliter shape # find local maxima using our fliter shape
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
background = (arr2D == 0) background = (arr2D == 0)
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) eroded_background = binary_erosion(background, structure=neighborhood,
detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks border_value=1)
# Boolean mask of arr2D with True at peaks
detected_peaks = local_max - eroded_background
# extract peaks # extract peaks
amps = arr2D[detected_peaks] amps = arr2D[detected_peaks]
@ -85,97 +82,37 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
ax.scatter(time_idx, frequency_idx) ax.scatter(time_idx, frequency_idx)
ax.set_xlabel('Time') ax.set_xlabel('Time')
ax.set_ylabel('Frequency') ax.set_ylabel('Frequency')
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke"); ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke")
plt.gca().invert_yaxis() plt.gca().invert_yaxis()
plt.show() plt.show()
return zip(frequency_idx, time_idx) return zip(frequency_idx, time_idx)
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
""" """
Hash list structure: Hash list structure:
sha1-hash[0:20] time_offset sha1_hash[0:20] time_offset
[(e05b341a9b77a51fd26, 32), ... ] [(e05b341a9b77a51fd26, 32), ... ]
""" """
fingerprinted = set() # to avoid rehashing same pairs fingerprinted = set() # to avoid rehashing same pairs
hashes = []
for i in range(len(peaks)): for i in range(len(peaks)):
for j in range(fan_value): for j in range(fan_value):
if i+j < len(peaks) and not (i, i+j) in fingerprinted: if (i + j) < len(peaks) and not (i, i + j) in fingerprinted:
freq1 = peaks[i][IDX_FREQ_I] freq1 = peaks[i][IDX_FREQ_I]
freq2 = peaks[i + j][IDX_FREQ_I] freq2 = peaks[i + j][IDX_FREQ_I]
t1 = peaks[i][IDX_TIME_J] t1 = peaks[i][IDX_TIME_J]
t2 = peaks[i + j][IDX_TIME_J] t2 = peaks[i + j][IDX_TIME_J]
t_delta = t2 - t1 t_delta = t2 - t1
if t_delta >= MIN_HASH_TIME_DELTA: if t_delta >= MIN_HASH_TIME_DELTA:
h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))) h = hashlib.sha1(
hashes.append((h.hexdigest()[0:20], t1)) "%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))
)
yield (h.hexdigest()[0:20], t1)
# ensure we don't repeat hashing # ensure we don't repeat hashing
fingerprinted.add((i, i + j)) fingerprinted.add((i, i + j))
return hashes
# TODO: move all of the below to a class with DB access
class Fingerprinter():
def __init__(self, config,
Fs=DEFAULT_FS,
wsize=DEFAULT_WINDOW_SIZE,
wratio=DEFAULT_OVERLAP_RATIO,
fan_value=DEFAULT_FAN_VALUE,
amp_min=DEFAULT_AMP_MIN):
self.config = config
self.Fs = Fs
self.dt = 1.0 / self.Fs
self.window_size = wsize
self.window_overlap_ratio = wratio
self.fan_value = fan_value
self.noverlap = int(self.window_size * self.window_overlap_ratio)
self.amp_min = amp_min
def fingerprint(self, samples, path, sid, cid):
"""Used for learning known songs"""
hashes = self.process_channel(samples, song_id=sid)
print "Generated %d hashes" % len(hashes)
self.db.insert_hashes(hashes)
# TODO: put this in another module
def match(self, samples):
"""Used for matching unknown songs"""
hashes = self.process_channel(samples)
matches = self.db.return_matches(hashes)
return matches
# TODO: this function has nothing to do with fingerprinting. is it needed?
def print_stats(self):
iterable = self.db.get_iterable_kv_pairs()
counter = {}
for t in iterable:
sid, toff = t
if not sid in counter:
counter[sid] = 1
else:
counter[sid] += 1
for song_id, count in counter.iteritems():
song_name = self.song_names[song_id]
print "%s has %d spectrogram peaks" % (song_name, count)
# this does... what? this seems to only be used for the above function
def set_song_names(self, wpaths):
self.song_names = wpaths
# TODO: put this in another module

View file

@ -1,14 +1,8 @@
from multiprocessing import Queue, Process
from dejavu.database import SQLDatabase
import dejavu.fingerprint as fingerprint import dejavu.fingerprint as fingerprint
from dejavu import Dejavu import dejavu.decoder as decoder
from scipy.io import wavfile
import wave
import numpy as np import numpy as np
import pyaudio import pyaudio
import sys
import time import time
import array
class BaseRecognizer(object): class BaseRecognizer(object):
@ -26,25 +20,17 @@ class BaseRecognizer(object):
def recognize(self): def recognize(self):
pass # base class does nothing pass # base class does nothing
class WaveFileRecognizer(BaseRecognizer):
def __init__(self, dejavu, filename=None): class FileRecognizer(BaseRecognizer):
super(WaveFileRecognizer, self).__init__(dejavu) def __init__(self, dejavu):
self.filename = filename super(FileRecognizer, self).__init__(dejavu)
def recognize_file(self, filename): def recognize_file(self, filename):
Fs, frames = wavfile.read(filename) Fs, frames = decoder.read(filename)
self.Fs = Fs self.Fs = Fs
wave_object = wave.open(filename)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
channels = []
for channel in range(nchannels):
channels.append(frames[:, channel])
t = time.time() t = time.time()
match = self._recognize(*channels) match = self._recognize(*frames)
t = time.time() - t t = time.time() - t
if match: if match:
@ -52,50 +38,53 @@ class WaveFileRecognizer(BaseRecognizer):
return match return match
def recognize(self): def recognize(self, filename):
return self.recognize_file(self.filename) return self.recognize_file(filename)
class MicrophoneRecognizer(BaseRecognizer): class MicrophoneRecognizer(BaseRecognizer):
default_chunksize = 8192
default_format = pyaudio.paInt16
default_channels = 2
default_samplerate = 44100
CHUNK = 8192 # 44100 is a multiple of 1225 def __init__(self, dejavu):
FORMAT = pyaudio.paInt16
CHANNELS = 2
RATE = 44100
def __init__(self, dejavu, seconds=None):
super(MicrophoneRecognizer, self).__init__(dejavu) super(MicrophoneRecognizer, self).__init__(dejavu)
self.audio = pyaudio.PyAudio() self.audio = pyaudio.PyAudio()
self.stream = None self.stream = None
self.data = [] self.data = []
self.channels = CHANNELS self.channels = self.default_channels
self.chunk_size = CHUNK self.chunksize = self.default_chunk
self.rate = RATE self.samplerate = self.default_samplerate
self.recorded = False self.recorded = False
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK): def start_recording(self, channels=default_channels,
self.chunk_size = chunk samplerate=default_samplerate,
chunksize=default_chunksize):
self.chunksize = chunksize
self.channels = channels self.channels = channels
self.recorded = False self.recorded = False
self.rate = rate self.samplerate = samplerate
if self.stream: if self.stream:
self.stream.stop_stream() self.stream.stop_stream()
self.stream.close() self.stream.close()
self.stream = self.audio.open(format=FORMAT, self.stream = self.audio.open(
format=self.default_format,
channels=channels, channels=channels,
rate=rate, rate=samplerate,
input=True, input=True,
frames_per_buffer=chunk) frames_per_buffer=chunksize,
)
self.data = [[] for i in range(channels)] self.data = [[] for i in range(channels)]
def process_recording(self): def process_recording(self):
data = self.stream.read(self.chunk_size) data = self.stream.read(self.chunksize)
nums = np.fromstring(data, np.int16) nums = np.fromstring(data, np.int16)
for c in range(self.channels): for c in range(self.channels):
self.data[c].extend(nums[c::c+1]) self.data[c].extend(nums[c::len(self.channels)])
def stop_recording(self): def stop_recording(self):
self.stream.stop_stream() self.stream.stop_stream()
@ -111,13 +100,14 @@ class MicrophoneRecognizer(BaseRecognizer):
def get_recorded_time(self): def get_recorded_time(self):
return len(self.data[0]) / self.rate return len(self.data[0]) / self.rate
def recognize(self): def recognize(self, seconds=None):
self.start_recording() self.start_recording()
for i in range(0, int(self.rate / self.chunk * self.seconds)): for i in range(0, int(self.samplerate / self.chunksize
* seconds)):
self.process_recording() self.process_recording()
self.stop_recording() self.stop_recording()
return self.recognize_recording() return self.recognize_recording()
class NoRecordingError(Exception): class NoRecordingError(Exception):
pass pass