mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 11:09:52 +00:00
Began moving recognizer functionality into separate classes
This commit is contained in:
parent
f02ab94192
commit
371742a314
1 changed files with 48 additions and 18 deletions
66
dejavu/recognize.py
Normal file → Executable file
66
dejavu/recognize.py
Normal file → Executable file
|
@ -1,5 +1,7 @@
|
|||
from multiprocessing import Queue, Process
|
||||
from dejavu.database import SQLDatabase
|
||||
import dejavu.fingerprint
|
||||
from dejavu import Dejavu
|
||||
from scipy.io import wavfile
|
||||
import wave
|
||||
import numpy as np
|
||||
|
@ -8,6 +10,52 @@ import sys
|
|||
import time
|
||||
import array
|
||||
|
||||
|
||||
class BaseRecognizer(object):
|
||||
|
||||
def __init__(self, dejavu):
|
||||
self.dejavu = dejavu
|
||||
self.Fs = dejavu.fingerprint.DEFAULT_FS
|
||||
|
||||
def recognize(self, *data):
|
||||
matches = []
|
||||
for d in data:
|
||||
matches.extend(self.dejavu.find_matches(data, Fs=self.Fs))
|
||||
return self.dejavu.align_matches(matches)
|
||||
|
||||
|
||||
class WaveFileRecognizer(BaseRecognizer):
|
||||
|
||||
def __init__(self, dejavu):
|
||||
super(BaseRecognizer, self).__init__(dejavu)
|
||||
|
||||
def recognize_file(self, filepath):
|
||||
Fs, frames = wavfile.read(filename)
|
||||
self.Fs = Fs
|
||||
|
||||
wave_object = wave.open(filename)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
|
||||
channels = []
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
|
||||
t = time.time()
|
||||
match = self.recognize(*channels)
|
||||
t = time.time() - t
|
||||
|
||||
if match:
|
||||
match['match_time'] = t
|
||||
|
||||
return match
|
||||
|
||||
|
||||
class MicrophoneRecognizer(BaseRecognizer):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
class Recognizer(object):
|
||||
|
||||
CHUNK = 8192 # 44100 is a multiple of 1225
|
||||
|
@ -20,24 +68,6 @@ class Recognizer(object):
|
|||
self.fingerprinter = fingerprinter
|
||||
self.config = config
|
||||
self.audio = pyaudio.PyAudio()
|
||||
|
||||
def read(self, filename, verbose=False):
|
||||
|
||||
# read file into channels
|
||||
channels = []
|
||||
Fs, frames = wavfile.read(filename)
|
||||
wave_object = wave.open(filename)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
|
||||
# get matches
|
||||
starttime = time.time()
|
||||
matches = []
|
||||
for channel in channels:
|
||||
matches.extend(self.fingerprinter.match(channel))
|
||||
|
||||
return self.fingerprinter.align_matches(matches, starttime, verbose=verbose)
|
||||
|
||||
def listen(self, seconds=10, verbose=False):
|
||||
|
||||
|
|
Loading…
Reference in a new issue