mirror of
https://github.com/correl/dejavu.git
synced 2024-11-27 11:09:51 +00:00
A fairly big batch of changes, blame me to forget committing
Changes in no particular order: - Replaced all use cases of wavfile/wave and extract_channels with the new decoder.read function - Added a 'recognize' method to the Dejavu class. This is a shortcut for recognizing songs. - Renamed 'do_fingerprint' into 'fingerprint_directory' - Removed parameters not required anymore from fingerprint_directory - Cleaned up fingerprint.py - Made fingerprint.generate_hashes a generator - WaveFileRecognizer is now FileRecognizer and can take any formats supported by pydub - Fixed MicrophoneRecognizer to actually run, previous version had many small mistakes - Renamed 'fingerprint_worker' to '_fingerprint_worker' to signify it is not to be used publicly - Moved 'chunkify' outside the Dejavu class - Cleaned up pep8 styling mistakes in all edited files.
This commit is contained in:
parent
7895bae23e
commit
29029238fc
4 changed files with 120 additions and 205 deletions
|
@ -1,22 +1,20 @@
|
||||||
from dejavu.database import SQLDatabase
|
from dejavu.database import SQLDatabase
|
||||||
import dejavu.decode as decoder
|
import dejavu.decoder as decoder
|
||||||
import fingerprint
|
import fingerprint
|
||||||
from scipy.io import wavfile
|
from multiprocessing import Process, cpu_count
|
||||||
from multiprocessing import Process
|
import os
|
||||||
import wave, os
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
DEBUG = False
|
|
||||||
|
|
||||||
class Dejavu(object):
|
class Dejavu(object):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
super(Dejavu, self).__init__()
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# initialize db
|
# initialize db
|
||||||
self.db = SQLDatabase(**config.get("database", {}))
|
self.db = SQLDatabase(**config.get("database", {}))
|
||||||
|
|
||||||
#self.fingerprinter = Fingerprinter(self.config)
|
|
||||||
self.db.setup()
|
self.db.setup()
|
||||||
|
|
||||||
# get songs previously indexed
|
# get songs previously indexed
|
||||||
|
@ -29,27 +27,27 @@ class Dejavu(object):
|
||||||
self.songnames_set.add(song_name)
|
self.songnames_set.add(song_name)
|
||||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||||
|
|
||||||
def chunkify(self, lst, n):
|
def fingerprint_directory(self, path, extensions, nprocesses=None):
|
||||||
"""
|
# Try to use the maximum amount of processes if not given.
|
||||||
Splits a list into roughly n equal parts.
|
if nprocesses is None:
|
||||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
try:
|
||||||
"""
|
nprocesses = cpu_count()
|
||||||
return [lst[i::n] for i in xrange(n)]
|
except NotImplementedError:
|
||||||
|
nprocesses = 1
|
||||||
def do_fingerprint(self, path, output, extensions, nprocesses):
|
|
||||||
|
|
||||||
# convert files, shuffle order
|
# convert files, shuffle order
|
||||||
files = decoder.find_files(path, extensions)
|
files = list(decoder.find_files(path, extensions))
|
||||||
random.shuffle(files)
|
random.shuffle(files)
|
||||||
files_split = self.chunkify(files, nprocesses)
|
|
||||||
|
files_split = chunkify(files, nprocesses)
|
||||||
|
|
||||||
# split into processes here
|
# split into processes here
|
||||||
processes = []
|
processes = []
|
||||||
for i in range(nprocesses):
|
for i in range(nprocesses):
|
||||||
|
|
||||||
# create process and start it
|
# create process and start it
|
||||||
p = Process(target=self.fingerprint_worker,
|
p = Process(target=self._fingerprint_worker,
|
||||||
args=(files_split[i], self.db, output))
|
args=(files_split[i], self.db))
|
||||||
p.start()
|
p.start()
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
|
|
||||||
|
@ -57,55 +55,37 @@ class Dejavu(object):
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
# delete orphans
|
def _fingerprint_worker(self, files, db):
|
||||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
|
||||||
# TODO: need a more performant query in database.py for the
|
|
||||||
#self.fingerprinter.db.delete_orphans()
|
|
||||||
|
|
||||||
def fingerprint_worker(self, files, sql_connection, output):
|
|
||||||
|
|
||||||
for filename, extension in files:
|
for filename, extension in files:
|
||||||
|
|
||||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
# if there are already fingerprints in database,
|
||||||
|
# don't re-fingerprint
|
||||||
song_name = os.path.basename(filename).split(".")[0]
|
song_name = os.path.basename(filename).split(".")[0]
|
||||||
if DEBUG and song_name in self.songnames_set:
|
if song_name in self.songnames_set:
|
||||||
print("-> Already fingerprinted, continuing...")
|
print("-> Already fingerprinted, continuing...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
channels, Fs = decoder.read(filename)
|
channels, Fs = decoder.read(filename)
|
||||||
|
|
||||||
# insert song name into database
|
# insert song name into database
|
||||||
song_id = sql_connection.insert_song(song_name)
|
song_id = db.insert_song(song_name)
|
||||||
|
|
||||||
for c in range(len(channels)):
|
for c in range(len(channels)):
|
||||||
channel = channels[c]
|
channel = channels[c]
|
||||||
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
|
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
|
||||||
|
|
||||||
hashes = fingerprint.fingerprint(channel, Fs=Fs)
|
hashes = fingerprint.fingerprint(channel, Fs=Fs)
|
||||||
sql_connection.insert_hashes(song_id, hashes)
|
|
||||||
|
db.insert_hashes(song_id, hashes)
|
||||||
|
|
||||||
# only after done fingerprinting do confirm
|
# only after done fingerprinting do confirm
|
||||||
sql_connection.set_song_fingerprinted(song_id)
|
db.set_song_fingerprinted(song_id)
|
||||||
|
|
||||||
def extract_channels(self, path):
|
|
||||||
"""
|
|
||||||
Reads channels from disk.
|
|
||||||
Returns a tuple with (channels, sample_rate)
|
|
||||||
"""
|
|
||||||
channels = []
|
|
||||||
Fs, frames = wavfile.read(path)
|
|
||||||
wave_object = wave.open(path)
|
|
||||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
|
||||||
#assert Fs == self.fingerprinter.Fs
|
|
||||||
|
|
||||||
for channel in range(nchannels):
|
|
||||||
channels.append(frames[:, channel])
|
|
||||||
return (channels, Fs)
|
|
||||||
|
|
||||||
def fingerprint_file(self, filepath, song_name=None):
|
def fingerprint_file(self, filepath, song_name=None):
|
||||||
# TODO: replace with something that handles all audio formats
|
channels, Fs = decoder.read(filepath)
|
||||||
channels, Fs = self.extract_channels(path)
|
|
||||||
if not song_name:
|
if not song_name:
|
||||||
song_name = os.path.basename(filename).split(".")[0]
|
song_name = os.path.basename(filepath).split(".")[0]
|
||||||
song_id = self.db.insert_song(song_name)
|
song_id = self.db.insert_song(song_name)
|
||||||
|
|
||||||
for data in channels:
|
for data in channels:
|
||||||
|
@ -141,7 +121,6 @@ class Dejavu(object):
|
||||||
largest_count = diff_counter[diff][sid]
|
largest_count = diff_counter[diff][sid]
|
||||||
song_id = sid
|
song_id = sid
|
||||||
|
|
||||||
if DEBUG:
|
|
||||||
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
|
print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))
|
||||||
|
|
||||||
# extract idenfication
|
# extract idenfication
|
||||||
|
@ -151,9 +130,6 @@ class Dejavu(object):
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if DEBUG:
|
|
||||||
print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))
|
|
||||||
|
|
||||||
# return match info
|
# return match info
|
||||||
song = {
|
song = {
|
||||||
"song_id": song_id,
|
"song_id": song_id,
|
||||||
|
@ -162,3 +138,15 @@ class Dejavu(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
return song
|
return song
|
||||||
|
|
||||||
|
def recognize(self, recognizer, *options, **kwoptions):
|
||||||
|
r = recognizer(self)
|
||||||
|
return r.recognize(*options, **kwoptions)
|
||||||
|
|
||||||
|
|
||||||
|
def chunkify(lst, n):
|
||||||
|
"""
|
||||||
|
Splits a list into roughly n equal parts.
|
||||||
|
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||||
|
"""
|
||||||
|
return [lst[i::n] for i in xrange(n)]
|
||||||
|
|
|
@ -1,17 +1,11 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.mlab as mlab
|
import matplotlib.mlab as mlab
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.image as mpimg
|
|
||||||
from scipy.io import wavfile
|
|
||||||
from scipy.ndimage.filters import maximum_filter
|
from scipy.ndimage.filters import maximum_filter
|
||||||
from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion
|
from scipy.ndimage.morphology import (generate_binary_structure,
|
||||||
from dejavu.database import SQLDatabase
|
iterate_structure, binary_erosion)
|
||||||
import os
|
|
||||||
import wave
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import pickle
|
|
||||||
|
|
||||||
IDX_FREQ_I = 0
|
IDX_FREQ_I = 0
|
||||||
IDX_TIME_J = 1
|
IDX_TIME_J = 1
|
||||||
|
@ -25,8 +19,8 @@ DEFAULT_AMP_MIN = 10
|
||||||
PEAK_NEIGHBORHOOD_SIZE = 20
|
PEAK_NEIGHBORHOOD_SIZE = 20
|
||||||
MIN_HASH_TIME_DELTA = 0
|
MIN_HASH_TIME_DELTA = 0
|
||||||
|
|
||||||
def fingerprint(channel_samples,
|
|
||||||
Fs=DEFAULT_FS,
|
def fingerprint(channel_samples, Fs=DEFAULT_FS,
|
||||||
wsize=DEFAULT_WINDOW_SIZE,
|
wsize=DEFAULT_WINDOW_SIZE,
|
||||||
wratio=DEFAULT_OVERLAP_RATIO,
|
wratio=DEFAULT_OVERLAP_RATIO,
|
||||||
fan_value=DEFAULT_FAN_VALUE,
|
fan_value=DEFAULT_FAN_VALUE,
|
||||||
|
@ -53,8 +47,8 @@ def fingerprint(channel_samples,
|
||||||
# return hashes
|
# return hashes
|
||||||
return generate_hashes(local_maxima, fan_value=fan_value)
|
return generate_hashes(local_maxima, fan_value=fan_value)
|
||||||
|
|
||||||
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
|
||||||
|
|
||||||
|
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
|
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
|
||||||
struct = generate_binary_structure(2, 1)
|
struct = generate_binary_structure(2, 1)
|
||||||
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
||||||
|
@ -62,8 +56,11 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||||
# find local maxima using our fliter shape
|
# find local maxima using our fliter shape
|
||||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
||||||
background = (arr2D == 0)
|
background = (arr2D == 0)
|
||||||
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
|
eroded_background = binary_erosion(background, structure=neighborhood,
|
||||||
detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks
|
border_value=1)
|
||||||
|
|
||||||
|
# Boolean mask of arr2D with True at peaks
|
||||||
|
detected_peaks = local_max - eroded_background
|
||||||
|
|
||||||
# extract peaks
|
# extract peaks
|
||||||
amps = arr2D[detected_peaks]
|
amps = arr2D[detected_peaks]
|
||||||
|
@ -85,97 +82,37 @@ def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||||
ax.scatter(time_idx, frequency_idx)
|
ax.scatter(time_idx, frequency_idx)
|
||||||
ax.set_xlabel('Time')
|
ax.set_xlabel('Time')
|
||||||
ax.set_ylabel('Frequency')
|
ax.set_ylabel('Frequency')
|
||||||
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke");
|
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke")
|
||||||
plt.gca().invert_yaxis()
|
plt.gca().invert_yaxis()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
return zip(frequency_idx, time_idx)
|
return zip(frequency_idx, time_idx)
|
||||||
|
|
||||||
|
|
||||||
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
||||||
"""
|
"""
|
||||||
Hash list structure:
|
Hash list structure:
|
||||||
sha1-hash[0:20] time_offset
|
sha1_hash[0:20] time_offset
|
||||||
[(e05b341a9b77a51fd26, 32), ... ]
|
[(e05b341a9b77a51fd26, 32), ... ]
|
||||||
"""
|
"""
|
||||||
fingerprinted = set() # to avoid rehashing same pairs
|
fingerprinted = set() # to avoid rehashing same pairs
|
||||||
hashes = []
|
|
||||||
|
|
||||||
for i in range(len(peaks)):
|
for i in range(len(peaks)):
|
||||||
for j in range(fan_value):
|
for j in range(fan_value):
|
||||||
if i+j < len(peaks) and not (i, i+j) in fingerprinted:
|
if (i + j) < len(peaks) and not (i, i + j) in fingerprinted:
|
||||||
|
|
||||||
freq1 = peaks[i][IDX_FREQ_I]
|
freq1 = peaks[i][IDX_FREQ_I]
|
||||||
freq2 = peaks[i + j][IDX_FREQ_I]
|
freq2 = peaks[i + j][IDX_FREQ_I]
|
||||||
|
|
||||||
t1 = peaks[i][IDX_TIME_J]
|
t1 = peaks[i][IDX_TIME_J]
|
||||||
t2 = peaks[i + j][IDX_TIME_J]
|
t2 = peaks[i + j][IDX_TIME_J]
|
||||||
|
|
||||||
t_delta = t2 - t1
|
t_delta = t2 - t1
|
||||||
|
|
||||||
if t_delta >= MIN_HASH_TIME_DELTA:
|
if t_delta >= MIN_HASH_TIME_DELTA:
|
||||||
h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
|
h = hashlib.sha1(
|
||||||
hashes.append((h.hexdigest()[0:20], t1))
|
"%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))
|
||||||
|
)
|
||||||
|
yield (h.hexdigest()[0:20], t1)
|
||||||
|
|
||||||
# ensure we don't repeat hashing
|
# ensure we don't repeat hashing
|
||||||
fingerprinted.add((i, i + j))
|
fingerprinted.add((i, i + j))
|
||||||
return hashes
|
|
||||||
|
|
||||||
# TODO: move all of the below to a class with DB access
|
|
||||||
|
|
||||||
|
|
||||||
class Fingerprinter():
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, config,
|
|
||||||
Fs=DEFAULT_FS,
|
|
||||||
wsize=DEFAULT_WINDOW_SIZE,
|
|
||||||
wratio=DEFAULT_OVERLAP_RATIO,
|
|
||||||
fan_value=DEFAULT_FAN_VALUE,
|
|
||||||
amp_min=DEFAULT_AMP_MIN):
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
|
|
||||||
self.Fs = Fs
|
|
||||||
self.dt = 1.0 / self.Fs
|
|
||||||
self.window_size = wsize
|
|
||||||
self.window_overlap_ratio = wratio
|
|
||||||
self.fan_value = fan_value
|
|
||||||
self.noverlap = int(self.window_size * self.window_overlap_ratio)
|
|
||||||
self.amp_min = amp_min
|
|
||||||
|
|
||||||
def fingerprint(self, samples, path, sid, cid):
|
|
||||||
"""Used for learning known songs"""
|
|
||||||
hashes = self.process_channel(samples, song_id=sid)
|
|
||||||
print "Generated %d hashes" % len(hashes)
|
|
||||||
self.db.insert_hashes(hashes)
|
|
||||||
|
|
||||||
# TODO: put this in another module
|
|
||||||
def match(self, samples):
|
|
||||||
"""Used for matching unknown songs"""
|
|
||||||
hashes = self.process_channel(samples)
|
|
||||||
matches = self.db.return_matches(hashes)
|
|
||||||
return matches
|
|
||||||
|
|
||||||
# TODO: this function has nothing to do with fingerprinting. is it needed?
|
|
||||||
def print_stats(self):
|
|
||||||
|
|
||||||
iterable = self.db.get_iterable_kv_pairs()
|
|
||||||
|
|
||||||
counter = {}
|
|
||||||
for t in iterable:
|
|
||||||
sid, toff = t
|
|
||||||
if not sid in counter:
|
|
||||||
counter[sid] = 1
|
|
||||||
else:
|
|
||||||
counter[sid] += 1
|
|
||||||
|
|
||||||
for song_id, count in counter.iteritems():
|
|
||||||
song_name = self.song_names[song_id]
|
|
||||||
print "%s has %d spectrogram peaks" % (song_name, count)
|
|
||||||
|
|
||||||
# this does... what? this seems to only be used for the above function
|
|
||||||
def set_song_names(self, wpaths):
|
|
||||||
self.song_names = wpaths
|
|
||||||
|
|
||||||
# TODO: put this in another module
|
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,8 @@
|
||||||
from multiprocessing import Queue, Process
|
|
||||||
from dejavu.database import SQLDatabase
|
|
||||||
import dejavu.fingerprint as fingerprint
|
import dejavu.fingerprint as fingerprint
|
||||||
from dejavu import Dejavu
|
import dejavu.decoder as decoder
|
||||||
from scipy.io import wavfile
|
|
||||||
import wave
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyaudio
|
import pyaudio
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
import array
|
|
||||||
|
|
||||||
|
|
||||||
class BaseRecognizer(object):
|
class BaseRecognizer(object):
|
||||||
|
@ -26,25 +20,17 @@ class BaseRecognizer(object):
|
||||||
def recognize(self):
|
def recognize(self):
|
||||||
pass # base class does nothing
|
pass # base class does nothing
|
||||||
|
|
||||||
class WaveFileRecognizer(BaseRecognizer):
|
|
||||||
|
|
||||||
def __init__(self, dejavu, filename=None):
|
class FileRecognizer(BaseRecognizer):
|
||||||
super(WaveFileRecognizer, self).__init__(dejavu)
|
def __init__(self, dejavu):
|
||||||
self.filename = filename
|
super(FileRecognizer, self).__init__(dejavu)
|
||||||
|
|
||||||
def recognize_file(self, filename):
|
def recognize_file(self, filename):
|
||||||
Fs, frames = wavfile.read(filename)
|
Fs, frames = decoder.read(filename)
|
||||||
self.Fs = Fs
|
self.Fs = Fs
|
||||||
|
|
||||||
wave_object = wave.open(filename)
|
|
||||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
|
||||||
|
|
||||||
channels = []
|
|
||||||
for channel in range(nchannels):
|
|
||||||
channels.append(frames[:, channel])
|
|
||||||
|
|
||||||
t = time.time()
|
t = time.time()
|
||||||
match = self._recognize(*channels)
|
match = self._recognize(*frames)
|
||||||
t = time.time() - t
|
t = time.time() - t
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
|
@ -52,50 +38,53 @@ class WaveFileRecognizer(BaseRecognizer):
|
||||||
|
|
||||||
return match
|
return match
|
||||||
|
|
||||||
def recognize(self):
|
def recognize(self, filename):
|
||||||
return self.recognize_file(self.filename)
|
return self.recognize_file(filename)
|
||||||
|
|
||||||
|
|
||||||
class MicrophoneRecognizer(BaseRecognizer):
|
class MicrophoneRecognizer(BaseRecognizer):
|
||||||
|
default_chunksize = 8192
|
||||||
|
default_format = pyaudio.paInt16
|
||||||
|
default_channels = 2
|
||||||
|
default_samplerate = 44100
|
||||||
|
|
||||||
CHUNK = 8192 # 44100 is a multiple of 1225
|
def __init__(self, dejavu):
|
||||||
FORMAT = pyaudio.paInt16
|
|
||||||
CHANNELS = 2
|
|
||||||
RATE = 44100
|
|
||||||
|
|
||||||
def __init__(self, dejavu, seconds=None):
|
|
||||||
super(MicrophoneRecognizer, self).__init__(dejavu)
|
super(MicrophoneRecognizer, self).__init__(dejavu)
|
||||||
self.audio = pyaudio.PyAudio()
|
self.audio = pyaudio.PyAudio()
|
||||||
self.stream = None
|
self.stream = None
|
||||||
self.data = []
|
self.data = []
|
||||||
self.channels = CHANNELS
|
self.channels = self.default_channels
|
||||||
self.chunk_size = CHUNK
|
self.chunksize = self.default_chunk
|
||||||
self.rate = RATE
|
self.samplerate = self.default_samplerate
|
||||||
self.recorded = False
|
self.recorded = False
|
||||||
|
|
||||||
def start_recording(self, channels=CHANNELS, rate=RATE, chunk=CHUNK):
|
def start_recording(self, channels=default_channels,
|
||||||
self.chunk_size = chunk
|
samplerate=default_samplerate,
|
||||||
|
chunksize=default_chunksize):
|
||||||
|
self.chunksize = chunksize
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.recorded = False
|
self.recorded = False
|
||||||
self.rate = rate
|
self.samplerate = samplerate
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
self.stream.stop_stream()
|
self.stream.stop_stream()
|
||||||
self.stream.close()
|
self.stream.close()
|
||||||
|
|
||||||
self.stream = self.audio.open(format=FORMAT,
|
self.stream = self.audio.open(
|
||||||
|
format=self.default_format,
|
||||||
channels=channels,
|
channels=channels,
|
||||||
rate=rate,
|
rate=samplerate,
|
||||||
input=True,
|
input=True,
|
||||||
frames_per_buffer=chunk)
|
frames_per_buffer=chunksize,
|
||||||
|
)
|
||||||
|
|
||||||
self.data = [[] for i in range(channels)]
|
self.data = [[] for i in range(channels)]
|
||||||
|
|
||||||
def process_recording(self):
|
def process_recording(self):
|
||||||
data = self.stream.read(self.chunk_size)
|
data = self.stream.read(self.chunksize)
|
||||||
nums = np.fromstring(data, np.int16)
|
nums = np.fromstring(data, np.int16)
|
||||||
for c in range(self.channels):
|
for c in range(self.channels):
|
||||||
self.data[c].extend(nums[c::c+1])
|
self.data[c].extend(nums[c::len(self.channels)])
|
||||||
|
|
||||||
def stop_recording(self):
|
def stop_recording(self):
|
||||||
self.stream.stop_stream()
|
self.stream.stop_stream()
|
||||||
|
@ -111,13 +100,14 @@ class MicrophoneRecognizer(BaseRecognizer):
|
||||||
def get_recorded_time(self):
|
def get_recorded_time(self):
|
||||||
return len(self.data[0]) / self.rate
|
return len(self.data[0]) / self.rate
|
||||||
|
|
||||||
def recognize(self):
|
def recognize(self, seconds=None):
|
||||||
self.start_recording()
|
self.start_recording()
|
||||||
for i in range(0, int(self.rate / self.chunk * self.seconds)):
|
for i in range(0, int(self.samplerate / self.chunksize
|
||||||
|
* seconds)):
|
||||||
self.process_recording()
|
self.process_recording()
|
||||||
self.stop_recording()
|
self.stop_recording()
|
||||||
return self.recognize_recording()
|
return self.recognize_recording()
|
||||||
|
|
||||||
|
|
||||||
class NoRecordingError(Exception):
|
class NoRecordingError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue