mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
A fairly big batch of changes, blame me to forget committing
Changes in no particular order: - Replaced all use cases of wavfile/wave and extract_channels with the new decoder.read function - Added a 'recognize' method to the Dejavu class. This is a shortcut for recognizing songs. - Renamed 'do_fingerprint' into 'fingerprint_directory' - Removed parameters not required anymore from fingerprint_directory - Cleaned up fingerprint.py - Made fingerprint.generate_hashes a generator - WaveFileRecognizer is now FileRecognizer and can take any formats supported by pydub - Fixed MicrophoneRecognizer to actually run, previous version had many small mistakes - Renamed 'fingerprint_worker' to '_fingerprint_worker' to signify it is not to be used publicly - Moved 'chunkify' outside the Dejavu class - Cleaned up pep8 styling mistakes in all edited files.
This commit is contained in:
parent
7895bae23e
commit
29029238fc
4 changed files with 120 additions and 205 deletions
|
@ -1,22 +1,20 @@
|
|||
from dejavu.database import SQLDatabase
|
||||
import dejavu.decode as decoder
|
||||
import dejavu.decoder as decoder
|
||||
import fingerprint
|
||||
from scipy.io import wavfile
|
||||
from multiprocessing import Process
|
||||
import wave, os
|
||||
from multiprocessing import Process, cpu_count
|
||||
import os
|
||||
import random
|
||||
|
||||
DEBUG = False
|
||||
|
||||
class Dejavu(object):
|
||||
def __init__(self, config):
|
||||
super(Dejavu, self).__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
# initialize db
|
||||
self.db = SQLDatabase(**config.get("database", {}))
|
||||
|
||||
#self.fingerprinter = Fingerprinter(self.config)
|
||||
self.db.setup()
|
||||
|
||||
# get songs previously indexed
|
||||
|
@ -29,27 +27,27 @@ class Dejavu(object):
|
|||
self.songnames_set.add(song_name)
|
||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||
|
||||
def chunkify(self, lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||
"""
|
||||
return [lst[i::n] for i in xrange(n)]
|
||||
|
||||
def do_fingerprint(self, path, output, extensions, nprocesses):
|
||||
def fingerprint_directory(self, path, extensions, nprocesses=None):
|
||||
# Try to use the maximum amount of processes if not given.
|
||||
if nprocesses is None:
|
||||
try:
|
||||
nprocesses = cpu_count()
|
||||
except NotImplementedError:
|
||||
nprocesses = 1
|
||||
|
||||
# convert files, shuffle order
|
||||
files = decoder.find_files(path, extensions)
|
||||
files = list(decoder.find_files(path, extensions))
|
||||
random.shuffle(files)
|
||||
files_split = self.chunkify(files, nprocesses)
|
||||
|
||||
files_split = chunkify(files, nprocesses)
|
||||
|
||||
# split into processes here
|
||||
processes = []
|
||||
for i in range(nprocesses):
|
||||
|
||||
# create process and start it
|
||||
p = Process(target=self.fingerprint_worker,
|
||||
args=(files_split[i], self.db, output))
|
||||
p = Process(target=self._fingerprint_worker,
|
||||
args=(files_split[i], self.db))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
|
@ -57,55 +55,37 @@ class Dejavu(object):
|
|||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# delete orphans
|
||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
||||
# TODO: need a more performant query in database.py for the
|
||||
#self.fingerprinter.db.delete_orphans()
|
||||
|
||||
def fingerprint_worker(self, files, sql_connection, output):
|
||||
|
||||
def _fingerprint_worker(self, files, db):
|
||||
for filename, extension in files:
|
||||
|
||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
||||
# if there are already fingerprints in database,
|
||||
# don't re-fingerprint
|
||||
song_name = os.path.basename(filename).split(".")[0]
|
||||
if DEBUG and song_name in self.songnames_set:
|
||||
if song_name in self.songnames_set:
|
||||
print("-> Already fingerprinted, continuing...")
|
||||
continue
|
||||
|
||||
channels, Fs = decoder.read(filename)
|
||||
|
||||
# insert song name into database
|
||||
song_id = sql_connection.insert_song(song_name)
|
||||
song_id = db.insert_song(song_name)
|
||||
|
||||
for c in range(len(channels)):
|
||||
channel = channels[c]
|
||||
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
|
||||
|
||||
hashes = fingerprint.fingerprint(channel, Fs=Fs)
|
||||
sql_connection.insert_hashes(song_id, hashes)
|
||||
|
||||
db.insert_hashes(song_id, hashes)
|
||||
|
||||
# only after done fingerprinting do confirm
|
||||
sql_connection.set_song_fingerprinted(song_id)
|
||||
|
||||
def extract_channels(self, path):
|
||||
"""
|
||||
Reads channels from disk.
|
||||
Returns a tuple with (channels, sample_rate)
|
||||
"""
|
||||
channels = []
|
||||
Fs, frames = wavfile.read(path)
|
||||
wave_object = wave.open(path)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
#assert Fs == self.fingerprinter.Fs
|
||||
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
return (channels, Fs)
|
||||
db.set_song_fingerprinted(song_id)
|
||||
|
||||
def fingerprint_file(self, filepath, song_name=None):
|
||||
# TODO: replace with something that handles all audio formats
|
||||
channels, Fs = self.extract_channels(path)
|
||||
channels, Fs = decoder.read(filepath)
|
||||
|
||||
if not song_name:
|
||||
song_name = os.path.basename(filename).split(".")[0]
|
||||
song_name = os.path.basename(filepath).split(".")[0]
|
||||
song_id = self.db.insert_song(song_name)
|
||||
|
||||
for data in channels:
|
||||
|
@ -141,7 +121,6 @@ class Dejavu(object):
|
|||
largest_count = diff_counter[diff][sid]
|
||||
song_id = sid
|
||||
|
||||
if DEBUG:
|
||||
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
|
||||
|
||||
# extract idenfication
|
||||
|
@ -151,9 +130,6 @@ class Dejavu(object):
|
|||
else:
|
||||
return None
|
||||
|
||||
if DEBUG:
|
||||
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
|
||||
|
||||
# return match info
|
||||
song = {
|
||||
"song_id": song_id,
|
||||
|
@ -162,3 +138,15 @@ class Dejavu(object):
|
|||
}
|
||||
|
||||
return song
|
||||
|
||||
def recognize(self, recognizer, *options, **kwoptions):
|
||||
r = recognizer(self)
|
||||
return r.recognize(*options, **kwoptions)
|
||||
|
||||
|
||||
def chunkify(lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||
"""
|
||||
return [lst[i::n] for i in xrange(n)]
|
||||
|
|
|
@ -1,17 +1,11 @@
|
|||
import numpy as np
|
||||
import matplotlib.mlab as mlab
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.image as mpimg
|
||||
from scipy.io import wavfile
|
||||
from scipy.ndimage.filters import maximum_filter
|
||||
from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion
|
||||
from dejavu.database import SQLDatabase
|
||||
import os
|
||||
import wave
|
||||
import sys
|
||||
import time
|
||||
from scipy.ndimage.morphology import (generate_binary_structure,
|
||||
iterate_structure, binary_erosion)
|
||||
import hashlib
|
||||
import pickle
|
||||
|
||||
|
||||
IDX_FREQ_I = 0
|
||||
IDX_TIME_J = 1
|
||||
|
@ -25,8 +19,8 @@ DEFAULT_AMP_MIN = 10
|
|||
PEAK_NEIGHBORHOOD_SIZE = 20
|
||||
MIN_HASH_TIME_DELTA = 0
|
||||
|
||||
def fingerprint(channel_samples,
|
||||
Fs=DEFAULT_FS,
|
||||
|
||||
def fingerprint(channel_samples, Fs=DEFAULT_FS,
|
||||
wsize=DEFAULT_WINDOW_SIZE,
|
||||
wratio=DEFAULT_OVERLAP_RATIO,
|
||||
fan_value=DEFAULT_FAN_VALUE,
|
||||
|
@ -53,8 +47,8 @@ def fingerprint(channel_samples,
|
|||
# return hashes
|
||||
return generate_hashes(local_maxima, fan_value=fan_value)
|
||||
|
||||
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||
|
||||
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
|
||||
struct = generate_binary_structure(2, 1)
|
||||
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
||||
|
@ -62,8 +56,11 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
|||
# find local maxima using our fliter shape
|
||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
||||
background = (arr2D == 0)
|
||||
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
|
||||
detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks
|
||||
eroded_background = binary_erosion(background, structure=neighborhood,
|
||||
border_value=1)
|
||||
|
||||
# Boolean mask of arr2D with True at peaks
|
||||
detected_peaks = local_max - eroded_background
|
||||
|
||||
# extract peaks
|
||||
amps = arr2D[detected_peaks]
|
||||
|
@ -85,97 +82,37 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
|||
ax.scatter(time_idx, frequency_idx)
|
||||
ax.set_xlabel('Time')
|
||||
ax.set_ylabel('Frequency')
|
||||
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke");
|
||||
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke")
|
||||
plt.gca().invert_yaxis()
|
||||
plt.show()
|
||||
|
||||
return zip(frequency_idx, time_idx)
|
||||
|
||||
|
||||
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
||||
"""
|
||||
Hash list structure:
|
||||
sha1-hash[0:20] time_offset
|
||||
sha1_hash[0:20] time_offset
|
||||
[(e05b341a9b77a51fd26, 32), ... ]
|
||||
"""
|
||||
fingerprinted = set() # to avoid rehashing same pairs
|
||||
hashes = []
|
||||
|
||||
for i in range(len(peaks)):
|
||||
for j in range(fan_value):
|
||||
if i+j < len(peaks) and not (i, i+j) in fingerprinted:
|
||||
|
||||
if (i + j) < len(peaks) and not (i, i + j) in fingerprinted:
|
||||
freq1 = peaks[i][IDX_FREQ_I]
|
||||
freq2 = peaks[i + j][IDX_FREQ_I]
|
||||
|
||||
t1 = peaks[i][IDX_TIME_J]
|
||||
t2 = peaks[i + j][IDX_TIME_J]
|
||||
|
||||
t_delta = t2 - t1
|
||||
|
||||
if t_delta >= MIN_HASH_TIME_DELTA:
|
||||
h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
|
||||
hashes.append((h.hexdigest()[0:20], t1))
|
||||
h = hashlib.sha1(
|
||||
"%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))
|
||||
)
|
||||
yield (h.hexdigest()[0:20], t1)
|
||||
|
||||
# ensure we don't repeat hashing
|
||||
fingerprinted.add((i, i + j))
|
||||
return hashes
|
||||
|
||||
# TODO: move all of the below to a class with DB access
|
||||
|
||||
|
||||
class Fingerprinter():
|
||||
|
||||
|
||||
|
||||
def __init__(self, config,
|
||||
Fs=DEFAULT_FS,
|
||||
wsize=DEFAULT_WINDOW_SIZE,
|
||||
wratio=DEFAULT_OVERLAP_RATIO,
|
||||
fan_value=DEFAULT_FAN_VALUE,
|
||||
amp_min=DEFAULT_AMP_MIN):
|
||||
|
||||
self.config = config
|
||||
|
||||
|
||||
self.Fs = Fs
|
||||
self.dt = 1.0 / self.Fs
|
||||
self.window_size = wsize
|
||||
self.window_overlap_ratio = wratio
|
||||
self.fan_value = fan_value
|
||||
self.noverlap = int(self.window_size * self.window_overlap_ratio)
|
||||
self.amp_min = amp_min
|
||||
|
||||
def fingerprint(self, samples, path, sid, cid):
|
||||
"""Used for learning known songs"""
|
||||
hashes = self.process_channel(samples, song_id=sid)
|
||||
print "Generated %d hashes" % len(hashes)
|
||||
self.db.insert_hashes(hashes)
|
||||
|
||||
# TODO: put this in another module
|
||||
def match(self, samples):
|
||||
"""Used for matching unknown songs"""
|
||||
hashes = self.process_channel(samples)
|
||||
matches = self.db.return_matches(hashes)
|
||||
return matches
|
||||
|
||||
# TODO: this function has nothing to do with fingerprinting. is it needed?
|
||||
def print_stats(self):
|
||||
|
||||
iterable = self.db.get_iterable_kv_pairs()
|
||||
|
||||
counter = {}
|
||||
for t in iterable:
|
||||
sid, toff = t
|
||||
if not sid in counter:
|
||||
counter[sid] = 1
|
||||
else:
|
||||
counter[sid] += 1
|
||||
|
||||
for song_id, count in counter.iteritems():
|
||||
song_name = self.song_names[song_id]
|
||||
print "%s has %d spectrogram peaks" % (song_name, count)
|
||||
|
||||
# this does... what? this seems to only be used for the above function
|
||||
def set_song_names(self, wpaths):
|
||||
self.song_names = wpaths
|
||||
|
||||
# TODO: put this in another module
|
||||
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
from multiprocessing import Queue, Process
|
||||
from dejavu.database import SQLDatabase
|
||||
import dejavu.fingerprint as fingerprint
|
||||
from dejavu import Dejavu
|
||||
from scipy.io import wavfile
|
||||
import wave
|
||||
import dejavu.decoder as decoder
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
import sys
|
||||
import time
|
||||
import array
|
||||
|
||||
|
||||
class BaseRecognizer(object):
|
||||
|
@ -26,25 +20,17 @@ class BaseRecognizer(object):
|
|||
def recognize(self):
|
||||
pass # base class does nothing
|
||||
|
||||
class WaveFileRecognizer(BaseRecognizer):
|
||||
|
||||
def __init__(self, dejavu, filename=None):
|
||||
super(WaveFileRecognizer, self).__init__(dejavu)
|
||||
self.filename = filename
|
||||
class FileRecognizer(BaseRecognizer):
|
||||
def __init__(self, dejavu):
|
||||
super(FileRecognizer, self).__init__(dejavu)
|
||||
|
||||
def recognize_file(self, filename):
|
||||
Fs, frames = wavfile.read(filename)
|
||||
Fs, frames = decoder.read(filename)
|
||||
self.Fs = Fs
|
||||
|
||||
wave_object = wave.open(filename)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
|
||||
channels = []
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
|
||||
t = time.time()
|
||||
match = self._recognize(*channels)
|
||||
match = self._recognize(*frames)
|
||||
t = time.time() - t
|
||||
|
||||
if match:
|
||||
|
@ -52,50 +38,53 @@ class WaveFileRecognizer(BaseRecognizer):
|
|||
|
||||
return match
|
||||
|
||||
def recognize(self):
|
||||
return self.recognize_file(self.filename)
|
||||
def recognize(self, filename):
|
||||
return self.recognize_file(filename)
|
||||
|
||||
|
||||
class MicrophoneRecognizer(BaseRecognizer):
|
||||
default_chunksize = 8192
|
||||
default_format = pyaudio.paInt16
|
||||
default_channels = 2
|
||||
default_samplerate = 44100
|
||||
|
||||
CHUNK = 8192 # 44100 is a multiple of 1225
|
||||
FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 2
|
||||
RATE = 44100
|
||||
|
||||
def __init__(self, dejavu, seconds=None):
|
||||
def __init__(self, dejavu):
|
||||
super(MicrophoneRecognizer, self).__init__(dejavu)
|
||||
self.audio = pyaudio.PyAudio()
|
||||
self.stream = None
|
||||
self.data = []
|
||||
self.channels = CHANNELS
|
||||
self.chunk_size = CHUNK
|
||||
self.rate = RATE
|
||||
self.channels = self.default_channels
|
||||
self.chunksize = self.default_chunk
|
||||
self.samplerate = self.default_samplerate
|
||||
self.recorded = False
|
||||
|
||||
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
|
||||
self.chunk_size = chunk
|
||||
def start_recording(self, channels=default_channels,
|
||||
samplerate=default_samplerate,
|
||||
chunksize=default_chunksize):
|
||||
self.chunksize = chunksize
|
||||
self.channels = channels
|
||||
self.recorded = False
|
||||
self.rate = rate
|
||||
self.samplerate = samplerate
|
||||
|
||||
if self.stream:
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
|
||||
self.stream = self.audio.open(format=FORMAT,
|
||||
self.stream = self.audio.open(
|
||||
format=self.default_format,
|
||||
channels=channels,
|
||||
rate=rate,
|
||||
rate=samplerate,
|
||||
input=True,
|
||||
frames_per_buffer=chunk)
|
||||
frames_per_buffer=chunksize,
|
||||
)
|
||||
|
||||
self.data = [[] for i in range(channels)]
|
||||
|
||||
def process_recording(self):
|
||||
data = self.stream.read(self.chunk_size)
|
||||
data = self.stream.read(self.chunksize)
|
||||
nums = np.fromstring(data, np.int16)
|
||||
for c in range(self.channels):
|
||||
self.data[c].extend(nums[c::c+1])
|
||||
self.data[c].extend(nums[c::len(self.channels)])
|
||||
|
||||
def stop_recording(self):
|
||||
self.stream.stop_stream()
|
||||
|
@ -111,13 +100,14 @@ class MicrophoneRecognizer(BaseRecognizer):
|
|||
def get_recorded_time(self):
|
||||
return len(self.data[0]) / self.rate
|
||||
|
||||
def recognize(self):
|
||||
def recognize(self, seconds=None):
|
||||
self.start_recording()
|
||||
for i in range(0, int(self.rate / self.chunk * self.seconds)):
|
||||
for i in range(0, int(self.samplerate / self.chunksize
|
||||
* seconds)):
|
||||
self.process_recording()
|
||||
self.stop_recording()
|
||||
return self.recognize_recording()
|
||||
|
||||
|
||||
class NoRecordingError(Exception):
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in a new issue