A fairly big batch of changes; blame me for forgetting to commit

Changes in no particular order:
- Replaced all uses of wavfile/wave and extract_channels with the new decoder.read function
- Added a 'recognize' method to the Dejavu class as a shortcut for recognizing songs (see the sketch after this list)
- Renamed 'do_fingerprint' to 'fingerprint_directory'
- Removed parameters that are no longer required from fingerprint_directory
- Cleaned up fingerprint.py
- Made fingerprint.generate_hashes a generator
- WaveFileRecognizer is now FileRecognizer and can take any formats supported by pydub
- Fixed MicrophoneRecognizer so it actually runs; the previous version had many small mistakes
- Renamed 'fingerprint_worker' to '_fingerprint_worker' to signify it is not to be used publicly
- Moved 'chunkify' outside the Dejavu class
- Cleaned up PEP 8 style mistakes in all edited files.
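
To make the new surface area concrete, here is a minimal usage sketch of the reworked API. The database settings, paths, extension format, and the dejavu.recognize import path are assumptions (this diff does not show SQLDatabase's exact keyword arguments), so adjust them to your setup:

from dejavu import Dejavu
from dejavu.recognize import FileRecognizer

config = {
    # placeholder credentials; forwarded verbatim to SQLDatabase(**...)
    "database": {
        "host": "127.0.0.1",
        "user": "root",
        "passwd": "",
        "db": "dejavu",
    },
}

djv = Dejavu(config)

# fingerprint everything under ./mp3, one worker per CPU by default
djv.fingerprint_directory("./mp3", [".mp3"])  # extension list format is an assumption

# the new shortcut: Dejavu.recognize instantiates the recognizer
# and forwards the remaining arguments to its recognize() method
song = djv.recognize(FileRecognizer, "./mp3/some-song.mp3")
print song
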
This commit is contained in:
Wessie 2013-12-19 00:54:17 +01:00
parent 7895bae23e
commit 29029238fc
4 changed files with 120 additions and 205 deletions

View file

@ -1,22 +1,20 @@
from dejavu.database import SQLDatabase
import dejavu.decode as decoder
import dejavu.decoder as decoder
import fingerprint
from scipy.io import wavfile
from multiprocessing import Process
import wave, os
from multiprocessing import Process, cpu_count
import os
import random
DEBUG = False
class Dejavu(object):
def __init__(self, config):
super(Dejavu, self).__init__()
self.config = config
# initialize db
self.db = SQLDatabase(**config.get("database", {}))
#self.fingerprinter = Fingerprinter(self.config)
self.db.setup()
# get songs previously indexed
@ -29,27 +27,27 @@ class Dejavu(object):
self.songnames_set.add(song_name)
print "Added: %s to the set of fingerprinted songs..." % song_name
def chunkify(self, lst, n):
"""
Splits a list into roughly n equal parts.
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
"""
return [lst[i::n] for i in xrange(n)]
def do_fingerprint(self, path, output, extensions, nprocesses):
def fingerprint_directory(self, path, extensions, nprocesses=None):
# Try to use the maximum amount of processes if not given.
if nprocesses is None:
try:
nprocesses = cpu_count()
except NotImplementedError:
nprocesses = 1
# convert files, shuffle order
files = decoder.find_files(path, extensions)
files = list(decoder.find_files(path, extensions))
random.shuffle(files)
files_split = self.chunkify(files, nprocesses)
files_split = chunkify(files, nprocesses)
# split into processes here
processes = []
for i in range(nprocesses):
# create process and start it
p = Process(target=self.fingerprint_worker,
args=(files_split[i], self.db, output))
p = Process(target=self._fingerprint_worker,
args=(files_split[i], self.db))
p.start()
processes.append(p)
@ -57,55 +55,37 @@ class Dejavu(object):
for p in processes:
p.join()
# delete orphans
# print "Done fingerprinting. Deleting orphaned fingerprints..."
# TODO: need a more performant query in database.py for the
#self.fingerprinter.db.delete_orphans()
def fingerprint_worker(self, files, sql_connection, output):
def _fingerprint_worker(self, files, db):
for filename, extension in files:
# if there are already fingerprints in database, don't re-fingerprint or convert
# if there are already fingerprints in database,
# don't re-fingerprint
song_name = os.path.basename(filename).split(".")[0]
if DEBUG and song_name in self.songnames_set:
if song_name in self.songnames_set:
print("-> Already fingerprinted, continuing...")
continue
channels, Fs = decoder.read(filename)
# insert song name into database
song_id = sql_connection.insert_song(song_name)
song_id = db.insert_song(song_name)
for c in range(len(channels)):
channel = channels[c]
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
hashes = fingerprint.fingerprint(channel, Fs=Fs)
sql_connection.insert_hashes(song_id, hashes)
db.insert_hashes(song_id, hashes)
# only after done fingerprinting do confirm
sql_connection.set_song_fingerprinted(song_id)
def extract_channels(self, path):
"""
Reads channels from disk.
Returns a tuple with (channels, sample_rate)
"""
channels = []
Fs, frames = wavfile.read(path)
wave_object = wave.open(path)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
#assert Fs == self.fingerprinter.Fs
for channel in range(nchannels):
channels.append(frames[:, channel])
return (channels, Fs)
db.set_song_fingerprinted(song_id)
def fingerprint_file(self, filepath, song_name=None):
# TODO: replace with something that handles all audio formats
channels, Fs = self.extract_channels(path)
channels, Fs = decoder.read(filepath)
if not song_name:
song_name = os.path.basename(filename).split(".")[0]
song_name = os.path.basename(filepath).split(".")[0]
song_id = self.db.insert_song(song_name)
for data in channels:
@ -141,7 +121,6 @@ class Dejavu(object):
largest_count = diff_counter[diff][sid]
song_id = sid
if DEBUG:
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
# extract identification
@ -151,9 +130,6 @@ class Dejavu(object):
else:
return None
if DEBUG:
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
# return match info
song = {
"song_id": song_id,
@ -162,3 +138,15 @@ class Dejavu(object):
}
return song
def recognize(self, recognizer, *options, **kwoptions):
r = recognizer(self)
return r.recognize(*options, **kwoptions)
def chunkify(lst, n):
"""
Splits a list into roughly n equal parts.
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
"""
return [lst[i::n] for i in xrange(n)]
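
For clarity, chunkify now lives at module level and produces interleaved, roughly equal slices rather than contiguous runs:

from dejavu import chunkify

parts = chunkify(range(10), 3)
# parts == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]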

View file

@ -1,17 +1,11 @@
import numpy as np
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy.io import wavfile
from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion
from dejavu.database import SQLDatabase
import os
import wave
import sys
import time
from scipy.ndimage.morphology import (generate_binary_structure,
iterate_structure, binary_erosion)
import hashlib
import pickle
IDX_FREQ_I = 0
IDX_TIME_J = 1
@ -25,8 +19,8 @@ DEFAULT_AMP_MIN = 10
PEAK_NEIGHBORHOOD_SIZE = 20
MIN_HASH_TIME_DELTA = 0
def fingerprint(channel_samples,
Fs=DEFAULT_FS,
def fingerprint(channel_samples, Fs=DEFAULT_FS,
wsize=DEFAULT_WINDOW_SIZE,
wratio=DEFAULT_OVERLAP_RATIO,
fan_value=DEFAULT_FAN_VALUE,
@ -53,8 +47,8 @@ def fingerprint(channel_samples,
# return hashes
return generate_hashes(local_maxima, fan_value=fan_value)
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
struct = generate_binary_structure(2, 1)
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
@ -62,8 +56,11 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
# find local maxima using our filter shape
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
background = (arr2D == 0)
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks
eroded_background = binary_erosion(background, structure=neighborhood,
border_value=1)
# Boolean mask of arr2D with True at peaks
detected_peaks = local_max - eroded_background
# extract peaks
amps = arr2D[detected_peaks]
@ -85,97 +82,37 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
ax.scatter(time_idx, frequency_idx)
ax.set_xlabel('Time')
ax.set_ylabel('Frequency')
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke");
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke")
plt.gca().invert_yaxis()
plt.show()
return zip(frequency_idx, time_idx)
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
"""
Hash list structure:
sha1-hash[0:20] time_offset
sha1_hash[0:20] time_offset
[(e05b341a9b77a51fd26, 32), ... ]
"""
fingerprinted = set() # to avoid rehashing same pairs
hashes = []
for i in range(len(peaks)):
for j in range(fan_value):
if i+j < len(peaks) and not (i, i+j) in fingerprinted:
if (i + j) < len(peaks) and not (i, i + j) in fingerprinted:
freq1 = peaks[i][IDX_FREQ_I]
freq2 = peaks[i + j][IDX_FREQ_I]
t1 = peaks[i][IDX_TIME_J]
t2 = peaks[i + j][IDX_TIME_J]
t_delta = t2 - t1
if t_delta >= MIN_HASH_TIME_DELTA:
h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
hashes.append((h.hexdigest()[0:20], t1))
h = hashlib.sha1(
"%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))
)
yield (h.hexdigest()[0:20], t1)
# ensure we don't repeat hashing
fingerprinted.add((i, i + j))
return hashes
# TODO: move all of the below to a class with DB access
class Fingerprinter():
def __init__(self, config,
Fs=DEFAULT_FS,
wsize=DEFAULT_WINDOW_SIZE,
wratio=DEFAULT_OVERLAP_RATIO,
fan_value=DEFAULT_FAN_VALUE,
amp_min=DEFAULT_AMP_MIN):
self.config = config
self.Fs = Fs
self.dt = 1.0 / self.Fs
self.window_size = wsize
self.window_overlap_ratio = wratio
self.fan_value = fan_value
self.noverlap = int(self.window_size * self.window_overlap_ratio)
self.amp_min = amp_min
def fingerprint(self, samples, path, sid, cid):
"""Used for learning known songs"""
hashes = self.process_channel(samples, song_id=sid)
print "Generated %d hashes" % len(hashes)
self.db.insert_hashes(hashes)
# TODO: put this in another module
def match(self, samples):
"""Used for matching unknown songs"""
hashes = self.process_channel(samples)
matches = self.db.return_matches(hashes)
return matches
# TODO: this function has nothing to do with fingerprinting. is it needed?
def print_stats(self):
iterable = self.db.get_iterable_kv_pairs()
counter = {}
for t in iterable:
sid, toff = t
if not sid in counter:
counter[sid] = 1
else:
counter[sid] += 1
for song_id, count in counter.iteritems():
song_name = self.song_names[song_id]
print "%s has %d spectrogram peaks" % (song_name, count)
# this does... what? this seems to only be used for the above function
def set_song_names(self, wpaths):
self.song_names = wpaths
# TODO: put this in another module
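
Because generate_hashes (and therefore fingerprint.fingerprint) now returns a generator, hashes are produced lazily; callers that need a count or a second pass have to materialize them first. A short sketch, where samples is assumed to be a single channel of PCM data:

import dejavu.fingerprint as fingerprint

# lazily yields (sha1_hash[0:20], time_offset) pairs; the fingerprint worker
# feeds this generator straight into db.insert_hashes(song_id, hashes)
hashes = fingerprint.fingerprint(samples, Fs=44100)

# materialize it when a length or multiple passes are needed
hashes = list(fingerprint.fingerprint(samples, Fs=44100))
print "Generated %d hashes" % len(hashes)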

View file

@ -1,14 +1,8 @@
from multiprocessing import Queue, Process
from dejavu.database import SQLDatabase
import dejavu.fingerprint as fingerprint
from dejavu import Dejavu
from scipy.io import wavfile
import wave
import dejavu.decoder as decoder
import numpy as np
import pyaudio
import sys
import time
import array
class BaseRecognizer(object):
@ -26,25 +20,17 @@ class BaseRecognizer(object):
def recognize(self):
pass # base class does nothing
class WaveFileRecognizer(BaseRecognizer):
def __init__(self, dejavu, filename=None):
super(WaveFileRecognizer, self).__init__(dejavu)
self.filename = filename
class FileRecognizer(BaseRecognizer):
def __init__(self, dejavu):
super(FileRecognizer, self).__init__(dejavu)
def recognize_file(self, filename):
Fs, frames = wavfile.read(filename)
frames, Fs = decoder.read(filename)
self.Fs = Fs
wave_object = wave.open(filename)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
channels = []
for channel in range(nchannels):
channels.append(frames[:, channel])
t = time.time()
match = self._recognize(*channels)
match = self._recognize(*frames)
t = time.time() - t
if match:
@ -52,50 +38,53 @@ class WaveFileRecognizer(BaseRecognizer):
return match
def recognize(self):
return self.recognize_file(self.filename)
def recognize(self, filename):
return self.recognize_file(filename)
class MicrophoneRecognizer(BaseRecognizer):
default_chunksize = 8192
default_format = pyaudio.paInt16
default_channels = 2
default_samplerate = 44100
CHUNK = 8192 # 44100 is a multiple of 1225
FORMAT = pyaudio.paInt16
CHANNELS = 2
RATE = 44100
def __init__(self, dejavu, seconds=None):
def __init__(self, dejavu):
super(MicrophoneRecognizer, self).__init__(dejavu)
self.audio = pyaudio.PyAudio()
self.stream = None
self.data = []
self.channels = CHANNELS
self.chunk_size = CHUNK
self.rate = RATE
self.channels = self.default_channels
self.chunksize = self.default_chunksize
self.samplerate = self.default_samplerate
self.recorded = False
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
self.chunk_size = chunk
def start_recording(self, channels=default_channels,
samplerate=default_samplerate,
chunksize=default_chunksize):
self.chunksize = chunksize
self.channels = channels
self.recorded = False
self.rate = rate
self.samplerate = samplerate
if self.stream:
self.stream.stop_stream()
self.stream.close()
self.stream = self.audio.open(format=FORMAT,
self.stream = self.audio.open(
format=self.default_format,
channels=channels,
rate=rate,
rate=samplerate,
input=True,
frames_per_buffer=chunk)
frames_per_buffer=chunksize,
)
self.data = [[] for i in range(channels)]
def process_recording(self):
data = self.stream.read(self.chunk_size)
data = self.stream.read(self.chunksize)
nums = np.fromstring(data, np.int16)
for c in range(self.channels):
self.data[c].extend(nums[c::c+1])
self.data[c].extend(nums[c::self.channels])
def stop_recording(self):
self.stream.stop_stream()
@ -111,13 +100,14 @@ class MicrophoneRecognizer(BaseRecognizer):
def get_recorded_time(self):
return len(self.data[0]) / self.rate
def recognize(self):
def recognize(self, seconds=None):
self.start_recording()
for i in range(0, int(self.rate / self.chunk * self.seconds)):
for i in range(0, int(self.samplerate / self.chunksize
* seconds)):
self.process_recording()
self.stop_recording()
return self.recognize_recording()
class NoRecordingError(Exception):
pass
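
Finally, the microphone path through the same Dejavu.recognize shortcut; recognize() now takes the recording length as an argument instead of storing seconds on the instance. A sketch, assuming a working PyAudio input device and an existing Dejavu instance djv:

from dejavu.recognize import MicrophoneRecognizer

# record for 10 seconds, then match against the fingerprint database
song = djv.recognize(MicrophoneRecognizer, seconds=10)
print song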