mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
Moved main Dejavu class from dejavu.control to dejavu
This commit is contained in:
parent
6a6ae94e3d
commit
a4ed612658
4 changed files with 178 additions and 165 deletions
176
dejavu/__init__.py
Normal file → Executable file
176
dejavu/__init__.py
Normal file → Executable file
|
@ -0,0 +1,176 @@
|
||||||
|
from dejavu.database import SQLDatabase
|
||||||
|
from dejavu.convert import Converter
|
||||||
|
import dejavu.fingerprint as fingerprint
|
||||||
|
from scipy.io import wavfile
|
||||||
|
from multiprocessing import Process
|
||||||
|
import wave, os
|
||||||
|
import random
|
||||||
|
|
||||||
|
# Module-level debug switch read by the Dejavu class; False by default.
DEBUG = False
|
||||||
|
|
||||||
|
class Dejavu():
    """
    Ties together file conversion, fingerprinting and the SQL database
    that stores songs and their fingerprint hashes.

    The `config` object passed to the constructor must expose
    `get(section, option)`; it is queried with SQLDatabase.CONNECTION and
    the SQLDatabase.KEY_* options (presumably a ConfigParser - confirm
    against the caller).
    """

    def __init__(self, config):
        """
        Sets up the database and converter, then loads the names of the
        songs the database already knows about into `self.songnames_set`.
        """
        self.config = config

        # initialize db
        self.db = self._connect()

        # create components
        self.converter = Converter()
        self.db.setup()

        # get songs previously indexed
        self.songs = self.db.get_songs()
        self.songnames_set = set()  # to know which ones we've computed before
        if self.songs:
            for song in self.songs:
                song_name = song[SQLDatabase.FIELD_SONGNAME]
                self.songnames_set.add(song_name)
                print("Added: %s to the set of fingerprinted songs..." % song_name)

    def _connect(self):
        """
        Builds a new SQLDatabase instance from the connection settings
        held in `self.config`.
        """
        return SQLDatabase(
            self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
            self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
            self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
            self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))

    def chunkify(self, lst, n):
        """
        Splits a list into roughly n equal parts.
        http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts

        Always returns exactly n lists; some are empty when len(lst) < n.
        """
        return [lst[i::n] for i in range(n)]

    def fingerprint(self, path, output, extensions, nprocesses):
        """
        Finds the files under `path` with the given `extensions` and
        fingerprints them in `nprocesses` worker processes, blocking until
        every worker has finished. `output` is handed to the converter as
        the destination for the converted WAV files.
        """
        # convert files, shuffle order
        files = self.converter.find_files(path, extensions)
        random.shuffle(files)
        files_split = self.chunkify(files, nprocesses)

        # split into processes here (chunkify yields one chunk per process)
        processes = []
        for chunk in files_split:

            # need database instance since mysql connections shouldn't be
            # shared across processes
            sql_connection = self._connect()

            # create process and start it
            p = Process(target=self.fingerprint_worker,
                        args=(chunk, sql_connection, output))
            p.start()
            processes.append(p)

        # wait for all processes to complete
        for p in processes:
            p.join()

        # TODO: orphaned fingerprints are not deleted here; the former
        # delete_orphans() call is disabled until database.py has a more
        # performant query for it.

    def fingerprint_worker(self, files, sql_connection, output):
        """
        Converts and fingerprints each (filename, extension) pair in
        `files`, writing the song and its hashes through `sql_connection`.
        Songs whose name is already in `self.songnames_set` are skipped.
        """
        for filename, extension in files:

            # if there are already fingerprints in database, don't
            # re-fingerprint or convert. The skip must not depend on DEBUG;
            # only the message does.
            song_name = os.path.basename(filename).split(".")[0]
            if song_name in self.songnames_set:
                if DEBUG:
                    print("-> Already fingerprinted, continuing...")
                continue

            # convert to WAV
            wavout_path = self.converter.convert(
                filename, extension, Converter.WAV, output, song_name)

            # insert song name into database
            song_id = sql_connection.insert_song(song_name)

            # for each channel perform FFT analysis and fingerprinting
            channels, Fs = self.extract_channels(wavout_path)
            for channel_number, channel in enumerate(channels, 1):
                print("-> Fingerprinting channel %d of song %s..." % (channel_number, song_name))
                hashes = fingerprint.fingerprint(channel, Fs=Fs)
                sql_connection.insert_hashes(song_id, hashes)

            # only after done fingerprinting do confirm
            sql_connection.set_song_fingerprinted(song_id)

    def extract_channels(self, path):
        """
        Reads channels from disk.
        Returns a tuple with (channels, sample_rate), where channels is a
        list holding one array of samples per channel.
        """
        Fs, frames = wavfile.read(path)

        # wavfile.read returns 1-D data for single-channel files and
        # (nsamples, nchannels) otherwise, so the channel count comes from
        # the array itself and the file is not opened a second time.
        if frames.ndim == 1:
            channels = [frames]
        else:
            channels = [frames[:, index] for index in range(frames.shape[1])]
        return (channels, Fs)

    def match(self, samples, Fs=None):
        """
        Fingerprints `samples` and returns the database matches for the
        resulting hashes.

        Fs is the sample rate; when omitted (None) fingerprint.DEFAULT_FS
        is used. The default is resolved here rather than in the signature
        because, inside the class body, the name `fingerprint` refers to
        the method defined above and not to the fingerprint module.
        """
        if Fs is None:
            Fs = fingerprint.DEFAULT_FS
        hashes = fingerprint.fingerprint(samples, Fs=Fs)
        return self.db.return_matches(hashes)

    def _best_alignment(self, matches):
        """
        Tallies the (song id, offset difference) pairs in `matches` and
        returns (song_id, diff, count) for the most frequent pair. The
        first pair to reach the highest count wins ties; with no matches
        the result is (-1, 0, 0).
        """
        diff_counter = {}
        largest = 0
        largest_count = 0
        song_id = -1
        for sid, diff in matches:
            per_song = diff_counter.setdefault(diff, {})
            count = per_song.get(sid, 0) + 1
            per_song[sid] = count

            if count > largest_count:
                largest = diff
                largest_count = count
                song_id = sid
        return song_id, largest, largest_count

    def align_matches(self, matches, starttime, record_seconds=None):
        """
        Finds hash matches that align in time with other matches and finds
        consensus about which hashes are "true" signal from the audio.

        `starttime` is a time.time() timestamp used to compute the elapsed
        matching time. Returns a dictionary with match information, or
        None when the best song id is not found in the database.
        """
        # imported locally: `time` is not among this module's top-level imports
        import time

        # align by diffs
        song_id, largest, largest_count = self._best_alignment(matches)

        if DEBUG:
            print("Diff is %d with %d offset-aligned matches" % (largest, largest_count))

        # extract identification
        song = self.db.get_song_by_id(song_id)
        if not song:
            return None
        songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
        songname = songname.replace("_", " ")
        elapsed = time.time() - starttime

        if DEBUG:
            print("Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed))

        # return match info
        song = {
            "song_id": song_id,
            "song_name": songname,
            "match_time": elapsed,
            "confidence": largest_count,
        }

        if record_seconds:
            song['record_time'] = record_seconds

        return song
|
|
@ -1,107 +0,0 @@
|
||||||
from dejavu.database import SQLDatabase
|
|
||||||
from dejavu.convert import Converter
|
|
||||||
from dejavu.fingerprint import Fingerprinter
|
|
||||||
from scipy.io import wavfile
|
|
||||||
from multiprocessing import Process
|
|
||||||
import wave, os
|
|
||||||
import random
|
|
||||||
|
|
||||||
class Dejavu():
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# create components
|
|
||||||
self.converter = Converter()
|
|
||||||
self.fingerprinter = Fingerprinter(self.config)
|
|
||||||
self.fingerprinter.db.setup()
|
|
||||||
|
|
||||||
# get songs previously indexed
|
|
||||||
self.songs = self.fingerprinter.db.get_songs()
|
|
||||||
self.songnames_set = set() # to know which ones we've computed before
|
|
||||||
if self.songs:
|
|
||||||
for song in self.songs:
|
|
||||||
song_id = song[SQLDatabase.FIELD_SONG_ID]
|
|
||||||
song_name = song[SQLDatabase.FIELD_SONGNAME]
|
|
||||||
self.songnames_set.add(song_name)
|
|
||||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
|
||||||
|
|
||||||
def chunkify(self, lst, n):
|
|
||||||
"""
|
|
||||||
Splits a list into roughly n equal parts.
|
|
||||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
|
||||||
"""
|
|
||||||
return [lst[i::n] for i in xrange(n)]
|
|
||||||
|
|
||||||
def fingerprint(self, path, output, extensions, nprocesses):
|
|
||||||
|
|
||||||
# convert files, shuffle order
|
|
||||||
files = self.converter.find_files(path, extensions)
|
|
||||||
random.shuffle(files)
|
|
||||||
files_split = self.chunkify(files, nprocesses)
|
|
||||||
|
|
||||||
# split into processes here
|
|
||||||
processes = []
|
|
||||||
for i in range(nprocesses):
|
|
||||||
|
|
||||||
# need database instance since mysql connections shouldn't be shared across processes
|
|
||||||
sql_connection = SQLDatabase(
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
|
||||||
|
|
||||||
# create process and start it
|
|
||||||
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output))
|
|
||||||
p.start()
|
|
||||||
processes.append(p)
|
|
||||||
|
|
||||||
# wait for all processes to complete
|
|
||||||
for p in processes:
|
|
||||||
p.join()
|
|
||||||
|
|
||||||
# delete orphans
|
|
||||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
|
||||||
# TODO: need a more performant query in database.py for the
|
|
||||||
#self.fingerprinter.db.delete_orphans()
|
|
||||||
|
|
||||||
def fingerprint_worker(self, files, sql_connection, output):
|
|
||||||
|
|
||||||
for filename, extension in files:
|
|
||||||
|
|
||||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
|
||||||
song_name = os.path.basename(filename).split(".")[0]
|
|
||||||
if song_name in self.songnames_set:
|
|
||||||
print "-> Already fingerprinted, continuing..."
|
|
||||||
continue
|
|
||||||
|
|
||||||
# convert to WAV
|
|
||||||
wavout_path = self.converter.convert(filename, extension, Converter.WAV, output, song_name)
|
|
||||||
|
|
||||||
# insert song name into database
|
|
||||||
song_id = sql_connection.insert_song(song_name)
|
|
||||||
|
|
||||||
# for each channel perform FFT analysis and fingerprinting
|
|
||||||
channels = self.extract_channels(wavout_path)
|
|
||||||
for c in range(len(channels)):
|
|
||||||
channel = channels[c]
|
|
||||||
print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name)
|
|
||||||
self.fingerprinter.fingerprint(channel, wavout_path, song_id, c+1)
|
|
||||||
|
|
||||||
# only after done fingerprinting do confirm
|
|
||||||
sql_connection.set_song_fingerprinted(song_id)
|
|
||||||
|
|
||||||
def extract_channels(self, path):
|
|
||||||
"""
|
|
||||||
Reads channels from disk.
|
|
||||||
"""
|
|
||||||
channels = []
|
|
||||||
Fs, frames = wavfile.read(path)
|
|
||||||
wave_object = wave.open(path)
|
|
||||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
|
||||||
assert Fs == self.fingerprinter.Fs
|
|
||||||
|
|
||||||
for channel in range(nchannels):
|
|
||||||
channels.append(frames[:, channel])
|
|
||||||
return channels
|
|
|
@ -133,12 +133,7 @@ class Fingerprinter():
|
||||||
amp_min=DEFAULT_AMP_MIN):
|
amp_min=DEFAULT_AMP_MIN):
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
database = SQLDatabase(
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
|
||||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
|
||||||
self.db = database
|
|
||||||
|
|
||||||
self.Fs = Fs
|
self.Fs = Fs
|
||||||
self.dt = 1.0 / self.Fs
|
self.dt = 1.0 / self.Fs
|
||||||
|
@ -183,55 +178,4 @@ class Fingerprinter():
|
||||||
self.song_names = wpaths
|
self.song_names = wpaths
|
||||||
|
|
||||||
# TODO: put this in another module
|
# TODO: put this in another module
|
||||||
def align_matches(self, matches, starttime, record_seconds=0, verbose=False):
|
|
||||||
"""
|
|
||||||
Finds hash matches that align in time with other matches and finds
|
|
||||||
consensus about which hashes are "true" signal from the audio.
|
|
||||||
|
|
||||||
Returns a dictionary with match information.
|
|
||||||
"""
|
|
||||||
# align by diffs
|
|
||||||
diff_counter = {}
|
|
||||||
largest = 0
|
|
||||||
largest_count = 0
|
|
||||||
song_id = -1
|
|
||||||
for tup in matches:
|
|
||||||
sid, diff = tup
|
|
||||||
if not diff in diff_counter:
|
|
||||||
diff_counter[diff] = {}
|
|
||||||
if not sid in diff_counter[diff]:
|
|
||||||
diff_counter[diff][sid] = 0
|
|
||||||
diff_counter[diff][sid] += 1
|
|
||||||
|
|
||||||
if diff_counter[diff][sid] > largest_count:
|
|
||||||
largest = diff
|
|
||||||
largest_count = diff_counter[diff][sid]
|
|
||||||
song_id = sid
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print "Diff is %d with %d offset-aligned matches" % (largest, largest_count)
|
|
||||||
|
|
||||||
# extract idenfication
|
|
||||||
song = self.db.get_song_by_id(song_id)
|
|
||||||
if song:
|
|
||||||
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
songname = songname.replace("_", " ")
|
|
||||||
elapsed = time.time() - starttime
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print "Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed)
|
|
||||||
|
|
||||||
# return match info
|
|
||||||
song = {
|
|
||||||
"song_id" : song_id,
|
|
||||||
"song_name" : songname,
|
|
||||||
"match_time" : elapsed,
|
|
||||||
"confidence" : largest_count
|
|
||||||
}
|
|
||||||
|
|
||||||
if record_seconds:
|
|
||||||
song['record_time'] = record_seconds
|
|
||||||
|
|
||||||
return song
|
|
||||||
|
|
2
go.py
Normal file → Executable file
2
go.py
Normal file → Executable file
|
@ -1,4 +1,4 @@
|
||||||
from dejavu.control import Dejavu
|
from dejavu import Dejavu
|
||||||
from ConfigParser import ConfigParser
|
from ConfigParser import ConfigParser
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
Loading…
Reference in a new issue