Began moving recognizer functionality into separate classes

This commit is contained in:
Vin 2013-12-17 20:55:20 +00:00
parent f02ab94192
commit 371742a314

66
dejavu/recognize.py Normal file → Executable file
View file

@ -1,5 +1,7 @@
from multiprocessing import Queue, Process from multiprocessing import Queue, Process
from dejavu.database import SQLDatabase from dejavu.database import SQLDatabase
import dejavu.fingerprint
from dejavu import Dejavu
from scipy.io import wavfile from scipy.io import wavfile
import wave import wave
import numpy as np import numpy as np
@ -8,6 +10,52 @@ import sys
import time import time
import array import array
class BaseRecognizer(object):
def __init__(self, dejavu):
self.dejavu = dejavu
self.Fs = dejavu.fingerprint.DEFAULT_FS
def recognize(self, *data):
matches = []
for d in data:
matches.extend(self.dejavu.find_matches(data, Fs=self.Fs))
return self.dejavu.align_matches(matches)
class WaveFileRecognizer(BaseRecognizer):
def __init__(self, dejavu):
super(BaseRecognizer, self).__init__(dejavu)
def recognize_file(self, filepath):
Fs, frames = wavfile.read(filename)
self.Fs = Fs
wave_object = wave.open(filename)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
channels = []
for channel in range(nchannels):
channels.append(frames[:, channel])
t = time.time()
match = self.recognize(*channels)
t = time.time() - t
if match:
match['match_time'] = t
return match
class MicrophoneRecognizer(BaseRecognizer):
pass
class Recognizer(object): class Recognizer(object):
CHUNK = 8192 # 44100 is a multiple of 1225 CHUNK = 8192 # 44100 is a multiple of 1225
@ -20,24 +68,6 @@ class Recognizer(object):
self.fingerprinter = fingerprinter self.fingerprinter = fingerprinter
self.config = config self.config = config
self.audio = pyaudio.PyAudio() self.audio = pyaudio.PyAudio()
def read(self, filename, verbose=False):
# read file into channels
channels = []
Fs, frames = wavfile.read(filename)
wave_object = wave.open(filename)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
for channel in range(nchannels):
channels.append(frames[:, channel])
# get matches
starttime = time.time()
matches = []
for channel in channels:
matches.extend(self.fingerprinter.match(channel))
return self.fingerprinter.align_matches(matches, starttime, verbose=verbose)
def listen(self, seconds=10, verbose=False): def listen(self, seconds=10, verbose=False):