mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 19:19:53 +00:00
bugfixes and tweaks. also now returns offset of matched song
This commit is contained in:
commit
9eca3cc05a
11 changed files with 920 additions and 881 deletions
BIN
README.md
BIN
README.md
Binary file not shown.
171
dejavu/__init__.py
Normal file → Executable file
171
dejavu/__init__.py
Normal file → Executable file
|
@ -0,0 +1,171 @@
|
|||
from dejavu.database import get_database
|
||||
import dejavu.decoder as decoder
|
||||
import fingerprint
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
|
||||
class Dejavu(object):
|
||||
def __init__(self, config):
|
||||
super(Dejavu, self).__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
# initialize db
|
||||
db_cls = get_database(config.get("database_type", None))
|
||||
|
||||
self.db = db_cls(**config.get("database", {}))
|
||||
self.db.setup()
|
||||
|
||||
# get songs previously indexed
|
||||
# TODO: should probably use a checksum of the file instead of filename
|
||||
self.songs = self.db.get_songs()
|
||||
self.songnames_set = set() # to know which ones we've computed before
|
||||
for song in self.songs:
|
||||
song_name = song[self.db.FIELD_SONGNAME]
|
||||
self.songnames_set.add(song_name)
|
||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||
|
||||
def fingerprint_directory(self, path, extensions, nprocesses=None):
|
||||
# Try to use the maximum amount of processes if not given.
|
||||
try:
|
||||
nprocesses = nprocesses or multiprocessing.cpu_count()
|
||||
except NotImplementedError:
|
||||
nprocesses = 1
|
||||
else:
|
||||
nprocesses = 1 if nprocesses <= 0 else nprocesses
|
||||
|
||||
pool = multiprocessing.Pool(nprocesses)
|
||||
|
||||
results = []
|
||||
for filename, _ in decoder.find_files(path, extensions):
|
||||
|
||||
# don't refingerprint already fingerprinted files
|
||||
if decoder.path_to_songname(filename) in self.songnames_set:
|
||||
print "%s already fingerprinted, continuing..." % filename
|
||||
continue
|
||||
|
||||
result = pool.apply_async(_fingerprint_worker,
|
||||
(filename, self.db))
|
||||
results.append(result)
|
||||
|
||||
while len(results):
|
||||
for result in results[:]:
|
||||
# TODO: Handle errors gracefully and return them to the callee
|
||||
# in some way.
|
||||
try:
|
||||
result.get(timeout=2)
|
||||
except multiprocessing.TimeoutError:
|
||||
continue
|
||||
except:
|
||||
import traceback, sys
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
results.remove(result)
|
||||
else:
|
||||
results.remove(result)
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
def fingerprint_file(self, filepath, song_name=None):
|
||||
channels, Fs = decoder.read(filepath)
|
||||
|
||||
if not song_name:
|
||||
print "Song name: %s" % song_name
|
||||
song_name = decoder.path_to_songname(filepath)
|
||||
song_id = self.db.insert_song(song_name)
|
||||
|
||||
for data in channels:
|
||||
hashes = fingerprint.fingerprint(data, Fs=Fs)
|
||||
self.db.insert_hashes(song_id, hashes)
|
||||
|
||||
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
||||
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
||||
return self.db.return_matches(hashes)
|
||||
|
||||
def align_matches(self, matches):
|
||||
"""
|
||||
Finds hash matches that align in time with other matches and finds
|
||||
consensus about which hashes are "true" signal from the audio.
|
||||
|
||||
Returns a dictionary with match information.
|
||||
"""
|
||||
# align by diffs
|
||||
diff_counter = {}
|
||||
largest = 0
|
||||
largest_count = 0
|
||||
song_id = -1
|
||||
for tup in matches:
|
||||
sid, diff = tup
|
||||
if not diff in diff_counter:
|
||||
diff_counter[diff] = {}
|
||||
if not sid in diff_counter[diff]:
|
||||
diff_counter[diff][sid] = 0
|
||||
diff_counter[diff][sid] += 1
|
||||
|
||||
if diff_counter[diff][sid] > largest_count:
|
||||
largest = diff
|
||||
largest_count = diff_counter[diff][sid]
|
||||
song_id = sid
|
||||
|
||||
print("Diff is %d with %d offset-aligned matches" % (largest,
|
||||
largest_count))
|
||||
|
||||
# extract idenfication
|
||||
song = self.db.get_song_by_id(song_id)
|
||||
if song:
|
||||
# TODO: Clarifey what `get_song_by_id` should return.
|
||||
songname = song.get("song_name", None)
|
||||
else:
|
||||
return None
|
||||
|
||||
# return match info
|
||||
song = {
|
||||
"song_id": song_id,
|
||||
"song_name": songname,
|
||||
"confidence": largest_count,
|
||||
"offset" : largest
|
||||
}
|
||||
|
||||
return song
|
||||
|
||||
def recognize(self, recognizer, *options, **kwoptions):
|
||||
r = recognizer(self)
|
||||
return r.recognize(*options, **kwoptions)
|
||||
|
||||
|
||||
def _fingerprint_worker(filename, db):
|
||||
song_name, extension = os.path.splitext(os.path.basename(filename))
|
||||
|
||||
channels, Fs = decoder.read(filename)
|
||||
|
||||
# insert song into database
|
||||
sid = db.insert_song(song_name)
|
||||
|
||||
channel_amount = len(channels)
|
||||
for channeln, channel in enumerate(channels):
|
||||
# TODO: Remove prints or change them into optional logging.
|
||||
print("Fingerprinting channel %d/%d for %s" % (channeln + 1,
|
||||
channel_amount,
|
||||
filename))
|
||||
hashes = fingerprint.fingerprint(channel, Fs=Fs)
|
||||
print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
|
||||
filename))
|
||||
|
||||
print("Inserting fingerprints for channel %d/%d for %s" %
|
||||
(channeln + 1, channel_amount, filename))
|
||||
db.insert_hashes(sid, hashes)
|
||||
print("Finished inserting for channel %d/%d for %s" %
|
||||
(channeln + 1, channel_amount, filename))
|
||||
|
||||
print("Marking %s finished" % (filename,))
|
||||
db.set_song_fingerprinted(sid)
|
||||
print("%s finished" % (filename,))
|
||||
|
||||
|
||||
def chunkify(lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||
"""
|
||||
return [lst[i::n] for i in xrange(n)]
|
|
@ -1,120 +0,0 @@
|
|||
from dejavu.database import SQLDatabase
|
||||
from dejavu.convert import Converter
|
||||
from dejavu.fingerprint import Fingerprinter
|
||||
from scipy.io import wavfile
|
||||
from multiprocessing import Process
|
||||
import wave, os
|
||||
import random
|
||||
|
||||
class Dejavu():
|
||||
|
||||
def __init__(self, config):
|
||||
|
||||
self.config = config
|
||||
|
||||
# create components
|
||||
self.converter = Converter(config)
|
||||
self.fingerprinter = Fingerprinter(self.config)
|
||||
self.fingerprinter.db.setup()
|
||||
|
||||
# get songs previously indexed
|
||||
self.songs = self.fingerprinter.db.get_songs()
|
||||
self.songnames_set = set() # to know which ones we've computed before
|
||||
if self.songs:
|
||||
for song in self.songs:
|
||||
song_id = song[SQLDatabase.FIELD_SONG_ID]
|
||||
song_name = song[SQLDatabase.FIELD_SONGNAME]
|
||||
self.songnames_set.add(song_name)
|
||||
print "Added: %s to the set of fingerprinted songs..." % song_name
|
||||
|
||||
def chunkify(self, lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||
"""
|
||||
return [lst[i::n] for i in xrange(n)]
|
||||
|
||||
def fingerprint(self, path, output, extensions, nprocesses, keep_wav=False):
|
||||
|
||||
# convert files, shuffle order
|
||||
files = self.converter.find_files(path, extensions)
|
||||
random.shuffle(files)
|
||||
files_split = self.chunkify(files, nprocesses)
|
||||
|
||||
# split into processes here
|
||||
processes = []
|
||||
for i in range(nprocesses):
|
||||
|
||||
# need database instance since mysql connections shouldn't be shared across processes
|
||||
sql_connection = SQLDatabase(
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
||||
|
||||
# create process and start it
|
||||
p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output, keep_wav))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
# wait for all processes to complete
|
||||
try:
|
||||
for p in processes:
|
||||
p.join()
|
||||
except KeyboardInterrupt:
|
||||
print "-> Exiting.."
|
||||
for worker in processes:
|
||||
worker.terminate()
|
||||
worker.join()
|
||||
|
||||
# delete orphans
|
||||
# print "Done fingerprinting. Deleting orphaned fingerprints..."
|
||||
# TODO: need a more performant query in database.py for the
|
||||
#self.fingerprinter.db.delete_orphans()
|
||||
|
||||
def fingerprint_worker(self, files, sql_connection, output, keep_wav):
|
||||
|
||||
for filename, extension in files:
|
||||
# if there are already fingerprints in database, don't re-fingerprint or convert
|
||||
if filename in self.songnames_set:
|
||||
print "-> Already fingerprinted, continuing..."
|
||||
continue
|
||||
|
||||
# convert to WAV
|
||||
wavout_path = self.converter.convert(filename, extension, Converter.WAV, output)
|
||||
|
||||
# for each channel perform FFT analysis and fingerprinting
|
||||
try:
|
||||
channels = self.extract_channels(wavout_path)
|
||||
except AssertionError, e:
|
||||
print "-> File not supported, skipping."
|
||||
continue
|
||||
|
||||
# insert song name into database
|
||||
song_id = sql_connection.insert_song(filename)
|
||||
|
||||
for c in range(len(channels)):
|
||||
channel = channels[c]
|
||||
print "-> Fingerprinting channel %d of song %s..." % (c+1, filename)
|
||||
self.fingerprinter.fingerprint(channel, wavout_path, song_id, c+1)
|
||||
|
||||
# remove wav file if not required
|
||||
if not keep_wav:
|
||||
os.unlink(wavout_path)
|
||||
|
||||
# only after done fingerprinting do confirm
|
||||
sql_connection.set_song_fingerprinted(song_id)
|
||||
|
||||
def extract_channels(self, path):
|
||||
"""
|
||||
Reads channels from disk.
|
||||
"""
|
||||
channels = []
|
||||
Fs, frames = wavfile.read(path)
|
||||
wave_object = wave.open(path)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
assert Fs == self.fingerprinter.Fs
|
||||
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
return channels
|
|
@ -1,58 +0,0 @@
|
|||
import os, fnmatch
|
||||
from pydub import AudioSegment
|
||||
|
||||
class Converter():
|
||||
|
||||
WAV = "wav"
|
||||
MP3 = "mp3"
|
||||
FORMATS = [
|
||||
WAV,
|
||||
MP3]
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
if self.config.has_section("input") and self.config.has_option("input", "length"):
|
||||
self.max_input_len = self.config.getint("input", "length")
|
||||
else:
|
||||
self.max_input_len = None
|
||||
|
||||
def ensure_folder(self, extension):
|
||||
if not os.path.exists(extension):
|
||||
os.makedirs(extension)
|
||||
|
||||
def find_files(self, path, extensions):
|
||||
filepaths = []
|
||||
extensions = [e.replace(".", "") for e in extensions if e.replace(".", "") in Converter.FORMATS]
|
||||
print "Supported formats: %s" % extensions
|
||||
for dirpath, dirnames, files in os.walk(path):
|
||||
for extension in extensions:
|
||||
for f in fnmatch.filter(files, "*.%s" % extension):
|
||||
p = os.path.join(dirpath, f)
|
||||
#print "Found file: %s with extension %s" % (renamed, extension)
|
||||
filepaths.append((p, extension))
|
||||
return filepaths
|
||||
|
||||
def convert(self, orig_path, from_format, to_format, output_folder):
|
||||
path, song_name = os.path.split(orig_path)
|
||||
# start conversion
|
||||
self.ensure_folder(output_folder)
|
||||
print "-> Now converting: %s from %s format to %s format..." % (song_name, from_format, to_format)
|
||||
# MP3 --> WAV
|
||||
if from_format == Converter.MP3 and to_format == Converter.WAV:
|
||||
|
||||
newpath = os.path.join(output_folder, "%s.%s" % (song_name, Converter.WAV))
|
||||
if os.path.isfile(newpath):
|
||||
print "-> Already converted, skipping..."
|
||||
else:
|
||||
mp3file = AudioSegment.from_mp3(orig_path)
|
||||
if self.max_input_len:
|
||||
print "-> Reading input seconds: ", self.max_input_len
|
||||
mp3file = mp3file[:self.max_input_len * 1000]
|
||||
mp3file.export(newpath, format=Converter.WAV)
|
||||
|
||||
# unsupported
|
||||
else:
|
||||
print "CONVERSION ERROR:\nThe conversion from %s to %s is not supported!" % (from_format, to_format)
|
||||
|
||||
print "-> Conversion complete."
|
||||
return newpath
|
387
dejavu/database.py
Normal file → Executable file
387
dejavu/database.py
Normal file → Executable file
|
@ -1,325 +1,170 @@
|
|||
import MySQLdb as mysql
|
||||
import MySQLdb.cursors as cursors
|
||||
import os
|
||||
from __future__ import absolute_import
|
||||
import abc
|
||||
|
||||
class SQLDatabase():
|
||||
|
||||
class Database(object):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
# Name of your Database subclass, this is used in configuration
|
||||
# to refer to your class
|
||||
type = None
|
||||
|
||||
def __init__(self):
|
||||
super(Database, self).__init__()
|
||||
|
||||
def before_fork(self):
|
||||
"""
|
||||
Queries:
|
||||
|
||||
1) Find duplicates (shouldn't be any, though):
|
||||
|
||||
select `hash`, `song_id`, `offset`, count(*) cnt
|
||||
from fingerprints
|
||||
group by `hash`, `song_id`, `offset`
|
||||
having cnt > 1
|
||||
order by cnt asc;
|
||||
|
||||
2) Get number of hashes by song:
|
||||
|
||||
select song_id, song_name, count(song_id) as num
|
||||
from fingerprints
|
||||
natural join songs
|
||||
group by song_id
|
||||
order by count(song_id) desc;
|
||||
|
||||
3) get hashes with highest number of collisions
|
||||
|
||||
select
|
||||
hash,
|
||||
count(distinct song_id) as n
|
||||
from fingerprints
|
||||
group by `hash`
|
||||
order by n DESC;
|
||||
|
||||
=> 26 different songs with same fingerprint (392 times):
|
||||
|
||||
select songs.song_name, fingerprints.offset
|
||||
from fingerprints natural join songs
|
||||
where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
|
||||
Called before the database instance is given to the new process
|
||||
"""
|
||||
pass
|
||||
|
||||
# config keys
|
||||
CONNECTION = "connection"
|
||||
KEY_USERNAME = "username"
|
||||
KEY_DATABASE = "database"
|
||||
KEY_PASSWORD = "password"
|
||||
KEY_HOSTNAME = "hostname"
|
||||
def after_fork(self):
|
||||
"""
|
||||
Called after the database instance has been given to the new process
|
||||
|
||||
# tables
|
||||
FINGERPRINTS_TABLENAME = "fingerprints"
|
||||
SONGS_TABLENAME = "songs"
|
||||
|
||||
# fields
|
||||
FIELD_HASH = "hash"
|
||||
FIELD_SONG_ID = "song_id"
|
||||
FIELD_OFFSET = "offset"
|
||||
FIELD_SONGNAME = "song_name"
|
||||
FIELD_FINGERPRINTED = "fingerprinted"
|
||||
|
||||
# creates
|
||||
CREATE_FINGERPRINTS_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS `%s` (
|
||||
`%s` binary(10) not null,
|
||||
`%s` mediumint unsigned not null,
|
||||
`%s` int unsigned not null,
|
||||
INDEX(%s),
|
||||
UNIQUE(%s, %s, %s)
|
||||
);""" % (FINGERPRINTS_TABLENAME, FIELD_HASH,
|
||||
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
|
||||
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH)
|
||||
|
||||
CREATE_SONGS_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS `%s` (
|
||||
`%s` mediumint unsigned not null auto_increment,
|
||||
`%s` varchar(250) not null,
|
||||
`%s` tinyint default 0,
|
||||
PRIMARY KEY (`%s`),
|
||||
UNIQUE KEY `%s` (`%s`)
|
||||
);""" % (SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED,
|
||||
FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID)
|
||||
|
||||
# inserts
|
||||
INSERT_FINGERPRINT = "INSERT IGNORE INTO %s (%s, %s, %s) VALUES (UNHEX(%%s), %%s, %%s)" % (
|
||||
FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET) # ignore duplicates and don't insert them
|
||||
INSERT_SONG = "INSERT INTO %s (%s) VALUES (%%s);" % (
|
||||
SONGS_TABLENAME, FIELD_SONGNAME)
|
||||
|
||||
# selects
|
||||
SELECT = "SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH)
|
||||
SELECT_ALL = "SELECT %s, %s FROM %s;" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME)
|
||||
SELECT_SONG = "SELECT %s FROM %s WHERE %s = %%s" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID)
|
||||
SELECT_NUM_FINGERPRINTS = "SELECT COUNT(*) as n FROM %s" % (FINGERPRINTS_TABLENAME)
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = "SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
SELECT_SONGS = "SELECT %s, %s FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
|
||||
# drops
|
||||
DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME
|
||||
DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME
|
||||
|
||||
# update
|
||||
UPDATE_SONG_FINGERPRINTED = "UPDATE %s SET %s = 1 WHERE %s = %%s" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID)
|
||||
|
||||
# delete
|
||||
DELETE_UNFINGERPRINTED = "DELETE FROM %s WHERE %s = 0;" % (SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
DELETE_ORPHANS = """
|
||||
delete from fingerprints
|
||||
where not exists (
|
||||
select * from songs where fingerprints.song_id = songs.song_id
|
||||
)"""
|
||||
|
||||
def __init__(self, hostname, username, password, database):
|
||||
# connect
|
||||
self.database = database
|
||||
try:
|
||||
# http://www.halfcooked.com/mt/archives/000969.html
|
||||
self.connection = mysql.connect(
|
||||
hostname, username, password,
|
||||
database, cursorclass=cursors.DictCursor)
|
||||
|
||||
self.connection.autocommit(False) # for fast bulk inserts
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
except mysql.Error, e:
|
||||
print "Connection error %d: %s" % (e.args[0], e.args[1])
|
||||
This will be called in the new process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def setup(self):
|
||||
try:
|
||||
# create fingerprints table
|
||||
self.cursor.execute("USE %s;" % self.database)
|
||||
self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE)
|
||||
self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE)
|
||||
self.delete_unfingerprinted_songs()
|
||||
self.connection.commit()
|
||||
except mysql.Error, e:
|
||||
print "Connection error %d: %s" % (e.args[0], e.args[1])
|
||||
self.connection.rollback()
|
||||
"""
|
||||
Called on creation or shortly afterwards.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def empty(self):
|
||||
"""
|
||||
Drops all tables and re-adds them. Be carfeul with this!
|
||||
Called when the database should be cleared of all data.
|
||||
"""
|
||||
try:
|
||||
self.cursor.execute("USE %s;" % self.database)
|
||||
|
||||
# drop tables
|
||||
self.cursor.execute(SQLDatabase.DROP_FINGERPRINTS)
|
||||
self.cursor.execute(SQLDatabase.DROP_SONGS)
|
||||
|
||||
# recreate
|
||||
self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE)
|
||||
self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE)
|
||||
self.connection.commit()
|
||||
|
||||
except mysql.Error, e:
|
||||
print "Error in empty(), %d: %s" % (e.args[0], e.args[1])
|
||||
self.connection.rollback()
|
||||
|
||||
def delete_orphans(self):
|
||||
try:
|
||||
self.cursor = self.connection.cursor()
|
||||
### TODO: SQLDatabase.DELETE_ORPHANS is not performant enough, need better query
|
||||
### to delete fingerprints for which no song is tied to.
|
||||
#self.cursor.execute(SQLDatabase.DELETE_ORPHANS)
|
||||
#self.connection.commit()
|
||||
except mysql.Error, e:
|
||||
print "Error in delete_orphans(), %d: %s" % (e.args[0], e.args[1])
|
||||
self.connection.rollback()
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_unfingerprinted_songs(self):
|
||||
try:
|
||||
self.cursor = self.connection.cursor()
|
||||
self.cursor.execute(SQLDatabase.DELETE_UNFINGERPRINTED)
|
||||
self.connection.commit()
|
||||
except mysql.Error, e:
|
||||
print "Error in delete_unfingerprinted_songs(), %d: %s" % (e.args[0], e.args[1])
|
||||
self.connection.rollback()
|
||||
"""
|
||||
Called to remove any song entries that do not have any fingerprints
|
||||
associated with them.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_num_songs(self):
|
||||
"""
|
||||
Returns number of songs the database has fingerprinted.
|
||||
Returns the amount of songs in the database.
|
||||
"""
|
||||
try:
|
||||
self.cursor = self.connection.cursor()
|
||||
self.cursor.execute(SQLDatabase.SELECT_UNIQUE_SONG_IDS)
|
||||
record = self.cursor.fetchone()
|
||||
return int(record['n'])
|
||||
except mysql.Error, e:
|
||||
print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1])
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_num_fingerprints(self):
|
||||
"""
|
||||
Returns number of fingerprints the database has fingerprinted.
|
||||
Returns the number of fingerprints in the database.
|
||||
"""
|
||||
try:
|
||||
self.cursor = self.connection.cursor()
|
||||
self.cursor.execute(SQLDatabase.SELECT_NUM_FINGERPRINTS)
|
||||
record = self.cursor.fetchone()
|
||||
return int(record['n'])
|
||||
except mysql.Error, e:
|
||||
print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1])
|
||||
pass
|
||||
|
||||
|
||||
def set_song_fingerprinted(self, song_id):
|
||||
@abc.abstractmethod
|
||||
def set_song_fingerprinted(self, sid):
|
||||
"""
|
||||
Set the fingerprinted flag to TRUE (1) once a song has been completely
|
||||
fingerprinted in the database.
|
||||
"""
|
||||
try:
|
||||
self.cursor = self.connection.cursor()
|
||||
self.cursor.execute(SQLDatabase.UPDATE_SONG_FINGERPRINTED, song_id)
|
||||
self.connection.commit()
|
||||
except mysql.Error, e:
|
||||
print "Error in set_song_fingerprinted(), %d: %s" % (e.args[0], e.args[1])
|
||||
self.connection.rollback()
|
||||
Sets a specific song as having all fingerprints in the database.
|
||||
|
||||
sid: Song identifier
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_songs(self):
|
||||
"""
|
||||
Return songs that have the fingerprinted flag set TRUE (1).
|
||||
Returns all fully fingerprinted songs in the database.
|
||||
"""
|
||||
try:
|
||||
self.cursor.execute(SQLDatabase.SELECT_SONGS)
|
||||
return self.cursor.fetchall()
|
||||
except mysql.Error, e:
|
||||
print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1])
|
||||
return None
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_song_by_id(self, sid):
|
||||
"""
|
||||
Returns song by its ID.
|
||||
Return a song by its identifier
|
||||
|
||||
sid: Song identifier
|
||||
"""
|
||||
try:
|
||||
self.cursor.execute(SQLDatabase.SELECT_SONG, (sid,))
|
||||
return self.cursor.fetchone()
|
||||
except mysql.Error, e:
|
||||
print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1])
|
||||
return None
|
||||
pass
|
||||
|
||||
|
||||
def insert(self, key, value):
|
||||
@abc.abstractmethod
|
||||
def insert(self, hash, sid, offset):
|
||||
"""
|
||||
Insert a (sha1, song_id, offset) row into database.
|
||||
Inserts a single fingerprint into the database.
|
||||
|
||||
key is a sha1 hash, value = (song_id, offset)
|
||||
hash: Part of a sha1 hash, in hexadecimal format
|
||||
sid: Song identifier this fingerprint is off
|
||||
offset: The offset this hash is from
|
||||
"""
|
||||
try:
|
||||
args = (key, value[0], value[1])
|
||||
self.cursor.execute(SQLDatabase.INSERT_FINGERPRINT, args)
|
||||
except mysql.Error, e:
|
||||
print "Error in insert(), %d: %s" % (e.args[0], e.args[1])
|
||||
self.connection.rollback()
|
||||
pass
|
||||
|
||||
def insert_song(self, songname):
|
||||
@abc.abstractmethod
|
||||
def insert_song(self, song_name):
|
||||
"""
|
||||
Inserts song in the database and returns the ID of the inserted record.
|
||||
Inserts a song name into the database, returns the new
|
||||
identifier of the song.
|
||||
|
||||
song_name: The name of the song.
|
||||
"""
|
||||
try:
|
||||
self.cursor.execute(SQLDatabase.INSERT_SONG, (songname,))
|
||||
self.connection.commit()
|
||||
return int(self.cursor.lastrowid)
|
||||
except mysql.Error, e:
|
||||
print "Error in insert_song(), %d: %s" % (e.args[0], e.args[1])
|
||||
self.connection.rollback()
|
||||
return None
|
||||
pass
|
||||
|
||||
def query(self, key):
|
||||
@abc.abstractmethod
|
||||
def query(self, hash):
|
||||
"""
|
||||
Return all tuples associated with hash.
|
||||
Returns all matching fingerprint entries associated with
|
||||
the given hash as parameter.
|
||||
|
||||
If hash is None, returns all entries in the
|
||||
database (be careful with that one!).
|
||||
hash: Part of a sha1 hash, in hexadecimal format
|
||||
"""
|
||||
# select all if no key
|
||||
if key is not None:
|
||||
sql = SQLDatabase.SELECT
|
||||
else:
|
||||
sql = SQLDatabase.SELECT_ALL
|
||||
|
||||
matches = []
|
||||
try:
|
||||
self.cursor.execute(sql, (key,))
|
||||
|
||||
# collect all matches
|
||||
records = self.cursor.fetchall()
|
||||
for record in records:
|
||||
matches.append((record[SQLDatabase.FIELD_SONG_ID], record[SQLDatabase.FIELD_OFFSET]))
|
||||
|
||||
except mysql.Error, e:
|
||||
print "Error in query(), %d: %s" % (e.args[0], e.args[1])
|
||||
|
||||
return matches
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_iterable_kv_pairs(self):
|
||||
"""
|
||||
Returns all tuples in database.
|
||||
Returns all fingerprints in the database.
|
||||
"""
|
||||
return self.query(None)
|
||||
pass
|
||||
|
||||
def insert_hashes(self, hashes):
|
||||
@abc.abstractmethod
|
||||
def insert_hashes(self, sid, hashes):
|
||||
"""
|
||||
Insert series of hash => song_id, offset
|
||||
values into the database.
|
||||
"""
|
||||
for h in hashes:
|
||||
sha1, val = h
|
||||
self.insert(sha1, val)
|
||||
self.connection.commit()
|
||||
Insert a multitude of fingerprints.
|
||||
|
||||
sid: Song identifier the fingerprints belong to
|
||||
hashes: A sequence of tuples in the format (hash, offset)
|
||||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def return_matches(self, hashes):
|
||||
"""
|
||||
Return the (song_id, offset_diff) tuples associated with
|
||||
a list of
|
||||
Searches the database for pairs of (hash, offset) values.
|
||||
|
||||
sha1 => (None, sample_offset)
|
||||
hashes: A sequence of tuples in the format (hash, offset)
|
||||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
|
||||
values.
|
||||
Returns a sequence of (sid, offset_difference) tuples.
|
||||
|
||||
sid: Song identifier
|
||||
offset_difference: (offset - database_offset)
|
||||
"""
|
||||
matches = []
|
||||
for h in hashes:
|
||||
sha1, val = h
|
||||
list_of_tups = self.query(sha1)
|
||||
if list_of_tups:
|
||||
for t in list_of_tups:
|
||||
# (song_id, db_offset, song_sampled_offset)
|
||||
matches.append((t[0], t[1] - val[1]))
|
||||
return matches
|
||||
pass
|
||||
|
||||
|
||||
def get_database(database_type=None):
|
||||
# Default to using the mysql database
|
||||
database_type = database_type or "mysql"
|
||||
# Lower all the input.
|
||||
database_type = database_type.lower()
|
||||
|
||||
for db_cls in Database.__subclasses__():
|
||||
if db_cls.type == database_type:
|
||||
return db_cls
|
||||
|
||||
raise TypeError("Unsupported database type supplied.")
|
||||
|
||||
|
||||
# Import our default database handler
|
||||
import dejavu.database_sql
|
||||
|
|
374
dejavu/database_sql.py
Normal file
374
dejavu/database_sql.py
Normal file
|
@ -0,0 +1,374 @@
|
|||
from __future__ import absolute_import
|
||||
from itertools import izip_longest
|
||||
import Queue
|
||||
|
||||
import MySQLdb as mysql
|
||||
from MySQLdb.cursors import DictCursor
|
||||
|
||||
from dejavu.database import Database
|
||||
|
||||
|
||||
class SQLDatabase(Database):
|
||||
"""
|
||||
Queries:
|
||||
|
||||
1) Find duplicates (shouldn't be any, though):
|
||||
|
||||
select `hash`, `song_id`, `offset`, count(*) cnt
|
||||
from fingerprints
|
||||
group by `hash`, `song_id`, `offset`
|
||||
having cnt > 1
|
||||
order by cnt asc;
|
||||
|
||||
2) Get number of hashes by song:
|
||||
|
||||
select song_id, song_name, count(song_id) as num
|
||||
from fingerprints
|
||||
natural join songs
|
||||
group by song_id
|
||||
order by count(song_id) desc;
|
||||
|
||||
3) get hashes with highest number of collisions
|
||||
|
||||
select
|
||||
hash,
|
||||
count(distinct song_id) as n
|
||||
from fingerprints
|
||||
group by `hash`
|
||||
order by n DESC;
|
||||
|
||||
=> 26 different songs with same fingerprint (392 times):
|
||||
|
||||
select songs.song_name, fingerprints.offset
|
||||
from fingerprints natural join songs
|
||||
where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
|
||||
"""
|
||||
|
||||
type = "mysql"
|
||||
|
||||
# tables
|
||||
FINGERPRINTS_TABLENAME = "fingerprints"
|
||||
SONGS_TABLENAME = "songs"
|
||||
|
||||
# fields
|
||||
FIELD_HASH = "hash"
|
||||
FIELD_SONG_ID = "song_id"
|
||||
FIELD_OFFSET = "offset"
|
||||
FIELD_SONGNAME = "song_name"
|
||||
FIELD_FINGERPRINTED = "fingerprinted"
|
||||
|
||||
# creates
|
||||
CREATE_FINGERPRINTS_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS `%s` (
|
||||
`%s` binary(10) not null,
|
||||
`%s` mediumint unsigned not null,
|
||||
`%s` int unsigned not null,
|
||||
PRIMARY KEY(%s),
|
||||
UNIQUE(%s, %s, %s),
|
||||
FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
|
||||
) ENGINE=INNODB;""" % (
|
||||
FINGERPRINTS_TABLENAME, FIELD_HASH,
|
||||
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
|
||||
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
|
||||
FIELD_SONG_ID, SONGS_TABLENAME, FIELD_SONG_ID
|
||||
)
|
||||
|
||||
CREATE_SONGS_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS `%s` (
|
||||
`%s` mediumint unsigned not null auto_increment,
|
||||
`%s` varchar(250) not null,
|
||||
`%s` tinyint default 0,
|
||||
PRIMARY KEY (`%s`),
|
||||
UNIQUE KEY `%s` (`%s`)
|
||||
) ENGINE=INNODB;""" % (
|
||||
SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED,
|
||||
FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID,
|
||||
)
|
||||
|
||||
# inserts (ignores duplicates)
|
||||
INSERT_FINGERPRINT = """
|
||||
INSERT IGNORE INTO %s (%s, %s, %s) values
|
||||
(UNHEX(%%s), %%s, %%s);
|
||||
""" % (FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET)
|
||||
|
||||
INSERT_SONG = "INSERT INTO %s (%s) values (%%s);" % (
|
||||
SONGS_TABLENAME, FIELD_SONGNAME)
|
||||
|
||||
# selects
|
||||
SELECT = """
|
||||
SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);
|
||||
""" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH)
|
||||
|
||||
SELECT_MULTIPLE = """
|
||||
SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s);
|
||||
""" % (FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET,
|
||||
FINGERPRINTS_TABLENAME, FIELD_HASH)
|
||||
|
||||
SELECT_ALL = """
|
||||
SELECT %s, %s FROM %s;
|
||||
""" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME)
|
||||
|
||||
SELECT_SONG = """
|
||||
SELECT %s FROM %s WHERE %s = %%s
|
||||
""" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID)
|
||||
|
||||
SELECT_NUM_FINGERPRINTS = """
|
||||
SELECT COUNT(*) as n FROM %s
|
||||
""" % (FINGERPRINTS_TABLENAME)
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = """
|
||||
SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;
|
||||
""" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
|
||||
SELECT_SONGS = """
|
||||
SELECT %s, %s FROM %s WHERE %s = 1;
|
||||
""" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
|
||||
# drops
|
||||
DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME
|
||||
DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME
|
||||
|
||||
# update
|
||||
UPDATE_SONG_FINGERPRINTED = """
|
||||
UPDATE %s SET %s = 1 WHERE %s = %%s
|
||||
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID)
|
||||
|
||||
# delete
|
||||
DELETE_UNFINGERPRINTED = """
|
||||
DELETE FROM %s WHERE %s = 0;
|
||||
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
|
||||
def __init__(self, **options):
|
||||
super(SQLDatabase, self).__init__()
|
||||
self.cursor = cursor_factory(**options)
|
||||
self._options = options
|
||||
|
||||
def after_fork(self):
|
||||
# Clear the cursor cache, we don't want any stale connections from
|
||||
# the previous process.
|
||||
Cursor.clear_cache()
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
Creates any non-existing tables required for dejavu to function.
|
||||
|
||||
This also removes all songs that have been added but have no
|
||||
fingerprints associated with them.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.CREATE_SONGS_TABLE)
|
||||
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
|
||||
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def empty(self):
|
||||
"""
|
||||
Drops tables created by dejavu and then creates them again
|
||||
by calling `SQLDatabase.setup`.
|
||||
|
||||
.. warning:
|
||||
This will result in a loss of data
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.DROP_FINGERPRINTS)
|
||||
cur.execute(self.DROP_SONGS)
|
||||
|
||||
self.setup()
|
||||
|
||||
def delete_unfingerprinted_songs(self):
|
||||
"""
|
||||
Removes all songs that have no fingerprints associated with them.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def get_num_songs(self):
|
||||
"""
|
||||
Returns number of songs the database has fingerprinted.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.SELECT_UNIQUE_SONG_IDS)
|
||||
|
||||
for count, in cur:
|
||||
return count
|
||||
return 0
|
||||
|
||||
def get_num_fingerprints(self):
|
||||
"""
|
||||
Returns number of fingerprints the database has fingerprinted.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.SELECT_NUM_FINGERPRINTS)
|
||||
|
||||
for count, in cur:
|
||||
return count
|
||||
return 0
|
||||
|
||||
def set_song_fingerprinted(self, sid):
|
||||
"""
|
||||
Set the fingerprinted flag to TRUE (1) once a song has been completely
|
||||
fingerprinted in the database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))
|
||||
|
||||
def get_songs(self):
|
||||
"""
|
||||
Return songs that have the fingerprinted flag set TRUE (1).
|
||||
"""
|
||||
with self.cursor(cursor_type=DictCursor) as cur:
|
||||
cur.execute(self.SELECT_SONGS)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def get_song_by_id(self, sid):
|
||||
"""
|
||||
Returns song by its ID.
|
||||
"""
|
||||
with self.cursor(cursor_type=DictCursor) as cur:
|
||||
cur.execute(self.SELECT_SONG, (sid,))
|
||||
return cur.fetchone()
|
||||
|
||||
def insert(self, hash, sid, offset):
|
||||
"""
|
||||
Insert a (sha1, song_id, offset) row into database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset))
|
||||
|
||||
def insert_song(self, songname):
|
||||
"""
|
||||
Inserts song in the database and returns the ID of the inserted record.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.INSERT_SONG, (songname,))
|
||||
return cur.lastrowid
|
||||
|
||||
def query(self, hash):
|
||||
"""
|
||||
Return all tuples associated with hash.
|
||||
|
||||
If hash is None, returns all entries in the
|
||||
database (be careful with that one!).
|
||||
"""
|
||||
# select all if no key
|
||||
query = self.SELECT_ALL if hash is None else self.SELECT
|
||||
|
||||
with self.cursor() as cur:
|
||||
cur.execute(query)
|
||||
for sid, offset in cur:
|
||||
yield (sid, offset)
|
||||
|
||||
def get_iterable_kv_pairs(self):
|
||||
"""
|
||||
Returns all tuples in database.
|
||||
"""
|
||||
return self.query(None)
|
||||
|
||||
def insert_hashes(self, sid, hashes):
|
||||
"""
|
||||
Insert series of hash => song_id, offset
|
||||
values into the database.
|
||||
"""
|
||||
values = []
|
||||
for hash, offset in hashes:
|
||||
values.append((hash, sid, offset))
|
||||
|
||||
with self.cursor() as cur:
|
||||
for split_values in grouper(values, 1000):
|
||||
cur.executemany(self.INSERT_FINGERPRINT, split_values)
|
||||
|
||||
def return_matches(self, hashes):
|
||||
"""
|
||||
Return the (song_id, offset_diff) tuples associated with
|
||||
a list of (sha1, sample_offset) values.
|
||||
"""
|
||||
# Create a dictionary of hash => offset pairs for later lookups
|
||||
mapper = {}
|
||||
for hash, offset in hashes:
|
||||
mapper[hash.upper()] = offset
|
||||
|
||||
# Get an iteratable of all the hashes we need
|
||||
values = mapper.keys()
|
||||
|
||||
with self.cursor() as cur:
|
||||
for split_values in grouper(values, 1000):
|
||||
# Create our IN part of the query
|
||||
query = self.SELECT_MULTIPLE
|
||||
query = query % ', '.join(['UNHEX(%s)'] * len(split_values))
|
||||
|
||||
cur.execute(query, split_values)
|
||||
|
||||
for hash, sid, offset in cur:
|
||||
# (sid, db_offset - song_sampled_offset)
|
||||
yield (sid, offset - mapper[hash])
|
||||
|
||||
def __getstate__(self):
|
||||
return (self._options,)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._options, = state
|
||||
self.cursor = cursor_factory(**self._options)
|
||||
|
||||
|
||||
def grouper(iterable, n, fillvalue=None):
|
||||
args = [iter(iterable)] * n
|
||||
return (filter(None, values) for values
|
||||
in izip_longest(fillvalue=fillvalue, *args))
|
||||
|
||||
|
||||
def cursor_factory(**factory_options):
|
||||
def cursor(**options):
|
||||
options.update(factory_options)
|
||||
return Cursor(**options)
|
||||
return cursor
|
||||
|
||||
|
||||
class Cursor(object):
|
||||
"""
|
||||
Establishes a connection to the database and returns an open cursor.
|
||||
|
||||
|
||||
```python
|
||||
# Use as context manager
|
||||
with Cursor() as cur:
|
||||
cur.execute(query)
|
||||
```
|
||||
"""
|
||||
_cache = Queue.Queue(maxsize=5)
|
||||
|
||||
def __init__(self, cursor_type=mysql.cursors.Cursor, **options):
|
||||
super(Cursor, self).__init__()
|
||||
|
||||
try:
|
||||
conn = self._cache.get_nowait()
|
||||
except Queue.Empty:
|
||||
conn = mysql.connect(**options)
|
||||
else:
|
||||
# Ping the connection before using it from the cache.
|
||||
conn.ping(True)
|
||||
|
||||
self.conn = conn
|
||||
self.conn.autocommit(False)
|
||||
self.cursor_type = cursor_type
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
cls._cache = Queue.Queue(maxsize=5)
|
||||
|
||||
def __enter__(self):
|
||||
self.cursor = self.conn.cursor(self.cursor_type)
|
||||
return self.cursor
|
||||
|
||||
def __exit__(self, extype, exvalue, traceback):
|
||||
# if we had a MySQL related error we try to rollback the cursor.
|
||||
if extype is mysql.MySQLError:
|
||||
self.cursor.rollback()
|
||||
|
||||
self.cursor.close()
|
||||
self.conn.commit()
|
||||
|
||||
# Put it back on the queue
|
||||
try:
|
||||
self._cache.put_nowait(self.conn)
|
||||
except Queue.Full:
|
||||
self.conn.close()
|
48
dejavu/decoder.py
Normal file
48
dejavu/decoder.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import os
|
||||
import fnmatch
|
||||
import numpy as np
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
||||
def find_files(path, extensions):
|
||||
# Allow both with ".mp3" and without "mp3" to be used for extensions
|
||||
extensions = [e.replace(".", "") for e in extensions]
|
||||
|
||||
for dirpath, dirnames, files in os.walk(path):
|
||||
for extension in extensions:
|
||||
for f in fnmatch.filter(files, "*.%s" % extension):
|
||||
p = os.path.join(dirpath, f)
|
||||
yield (p, extension)
|
||||
|
||||
|
||||
def read(filename, limit=None):
|
||||
"""
|
||||
Reads any file supported by pydub (ffmpeg) and returns the data contained
|
||||
within.
|
||||
|
||||
Can be optionally limited to a certain amount of seconds from the start
|
||||
of the file by specifying the `limit` parameter. This is the amount of
|
||||
seconds from the start of the file.
|
||||
|
||||
returns: (channels, samplerate)
|
||||
"""
|
||||
audiofile = AudioSegment.from_file(filename)
|
||||
|
||||
if limit:
|
||||
audiofile = audiofile[:limit * 1000]
|
||||
|
||||
data = np.fromstring(audiofile._data, np.int16)
|
||||
|
||||
channels = []
|
||||
for chn in xrange(audiofile.channels):
|
||||
channels.append(data[chn::audiofile.channels])
|
||||
|
||||
return channels, audiofile.frame_rate
|
||||
|
||||
|
||||
def path_to_songname(path):
|
||||
"""
|
||||
Extracts song name from a filepath. Used to identify which songs
|
||||
have already been fingerprinted on disk.
|
||||
"""
|
||||
return os.path.basename(path).split(".")[0]
|
176
dejavu/fingerprint.py
Normal file → Executable file
176
dejavu/fingerprint.py
Normal file → Executable file
|
@ -1,19 +1,11 @@
|
|||
import numpy as np
|
||||
import matplotlib.mlab as mlab
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.image as mpimg
|
||||
from scipy.io import wavfile
|
||||
from scipy.ndimage.filters import maximum_filter
|
||||
from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion
|
||||
from dejavu.database import SQLDatabase
|
||||
import os
|
||||
import wave
|
||||
import sys
|
||||
import time
|
||||
from scipy.ndimage.morphology import (generate_binary_structure,
|
||||
iterate_structure, binary_erosion)
|
||||
import hashlib
|
||||
import pickle
|
||||
|
||||
class Fingerprinter():
|
||||
|
||||
IDX_FREQ_I = 0
|
||||
IDX_TIME_J = 1
|
||||
|
@ -27,42 +19,12 @@ class Fingerprinter():
|
|||
PEAK_NEIGHBORHOOD_SIZE = 20
|
||||
MIN_HASH_TIME_DELTA = 0
|
||||
|
||||
def __init__(self, config,
|
||||
Fs=DEFAULT_FS,
|
||||
|
||||
def fingerprint(channel_samples, Fs=DEFAULT_FS,
|
||||
wsize=DEFAULT_WINDOW_SIZE,
|
||||
wratio=DEFAULT_OVERLAP_RATIO,
|
||||
fan_value=DEFAULT_FAN_VALUE,
|
||||
amp_min=DEFAULT_AMP_MIN):
|
||||
|
||||
self.config = config
|
||||
database = SQLDatabase(
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
|
||||
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
|
||||
self.db = database
|
||||
|
||||
self.Fs = Fs
|
||||
self.dt = 1.0 / self.Fs
|
||||
self.window_size = wsize
|
||||
self.window_overlap_ratio = wratio
|
||||
self.fan_value = fan_value
|
||||
self.noverlap = int(self.window_size * self.window_overlap_ratio)
|
||||
self.amp_min = amp_min
|
||||
|
||||
def fingerprint(self, samples, path, sid, cid):
|
||||
"""Used for learning known songs"""
|
||||
hashes = self.process_channel(samples, song_id=sid)
|
||||
print "Generated %d hashes" % len(hashes)
|
||||
self.db.insert_hashes(hashes)
|
||||
|
||||
def match(self, samples):
|
||||
"""Used for matching unknown songs"""
|
||||
hashes = self.process_channel(samples)
|
||||
matches = self.db.return_matches(hashes)
|
||||
return matches
|
||||
|
||||
def process_channel(self, channel_samples, song_id=None):
|
||||
"""
|
||||
FFT the channel, log transform output, find local maxima, then return
|
||||
locally sensitive hashes.
|
||||
|
@ -70,32 +32,35 @@ class Fingerprinter():
|
|||
# FFT the signal and extract frequency components
|
||||
arr2D = mlab.specgram(
|
||||
channel_samples,
|
||||
NFFT=self.window_size,
|
||||
Fs=self.Fs,
|
||||
NFFT=wsize,
|
||||
Fs=Fs,
|
||||
window=mlab.window_hanning,
|
||||
noverlap=self.noverlap)[0]
|
||||
noverlap=int(wsize * wratio))[0]
|
||||
|
||||
# apply log transform since specgram() returns linear array
|
||||
arr2D = 10 * np.log10(arr2D)
|
||||
arr2D[arr2D == -np.inf] = 0 # replace infs with zeros
|
||||
|
||||
# find local maxima
|
||||
local_maxima = self.get_2D_peaks(arr2D, plot=False)
|
||||
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
|
||||
|
||||
# return hashes
|
||||
return self.generate_hashes(local_maxima, song_id=song_id)
|
||||
return generate_hashes(local_maxima, fan_value=fan_value)
|
||||
|
||||
def get_2D_peaks(self, arr2D, plot=False):
|
||||
|
||||
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
|
||||
struct = generate_binary_structure(2, 1)
|
||||
neighborhood = iterate_structure(struct, Fingerprinter.PEAK_NEIGHBORHOOD_SIZE)
|
||||
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
||||
|
||||
# find local maxima using our fliter shape
|
||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
||||
background = (arr2D == 0)
|
||||
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
|
||||
detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks
|
||||
eroded_background = binary_erosion(background, structure=neighborhood,
|
||||
border_value=1)
|
||||
|
||||
# Boolean mask of arr2D with True at peaks
|
||||
detected_peaks = local_max - eroded_background
|
||||
|
||||
# extract peaks
|
||||
amps = arr2D[detected_peaks]
|
||||
|
@ -104,7 +69,7 @@ class Fingerprinter():
|
|||
# filter peaks
|
||||
amps = amps.flatten()
|
||||
peaks = zip(i, j, amps)
|
||||
peaks_filtered = [x for x in peaks if x[2] > self.amp_min] # freq, time, amp
|
||||
peaks_filtered = [x for x in peaks if x[2] > amp_min] # freq, time, amp
|
||||
|
||||
# get indices for frequency and time
|
||||
frequency_idx = [x[1] for x in peaks_filtered]
|
||||
|
@ -117,110 +82,37 @@ class Fingerprinter():
|
|||
ax.scatter(time_idx, frequency_idx)
|
||||
ax.set_xlabel('Time')
|
||||
ax.set_ylabel('Frequency')
|
||||
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke");
|
||||
ax.set_title("Spectrogram")
|
||||
plt.gca().invert_yaxis()
|
||||
plt.show()
|
||||
|
||||
return zip(frequency_idx, time_idx)
|
||||
|
||||
def generate_hashes(self, peaks, song_id=None):
|
||||
|
||||
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
||||
"""
|
||||
Hash list structure:
|
||||
sha1-hash[0:20] song_id, time_offset
|
||||
[(e05b341a9b77a51fd26, (3, 32)), ... ]
|
||||
sha1_hash[0:20] time_offset
|
||||
[(e05b341a9b77a51fd26, 32), ... ]
|
||||
"""
|
||||
fingerprinted = set() # to avoid rehashing same pairs
|
||||
hashes = []
|
||||
|
||||
for i in range(len(peaks)):
|
||||
for j in range(self.fan_value):
|
||||
if i+j < len(peaks) and not (i, i+j) in fingerprinted:
|
||||
for j in range(fan_value):
|
||||
if (i + j) < len(peaks) and not (i, i + j) in fingerprinted:
|
||||
freq1 = peaks[i][IDX_FREQ_I]
|
||||
freq2 = peaks[i + j][IDX_FREQ_I]
|
||||
|
||||
t1 = peaks[i][IDX_TIME_J]
|
||||
t2 = peaks[i + j][IDX_TIME_J]
|
||||
|
||||
freq1 = peaks[i][Fingerprinter.IDX_FREQ_I]
|
||||
freq2 = peaks[i+j][Fingerprinter.IDX_FREQ_I]
|
||||
t1 = peaks[i][Fingerprinter.IDX_TIME_J]
|
||||
t2 = peaks[i+j][Fingerprinter.IDX_TIME_J]
|
||||
t_delta = t2 - t1
|
||||
|
||||
if t_delta >= Fingerprinter.MIN_HASH_TIME_DELTA:
|
||||
h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
|
||||
hashes.append((h.hexdigest()[0:20], (song_id, t1)))
|
||||
if t_delta >= MIN_HASH_TIME_DELTA:
|
||||
h = hashlib.sha1(
|
||||
"%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))
|
||||
)
|
||||
yield (h.hexdigest()[0:20], t1)
|
||||
|
||||
# ensure we don't repeat hashing
|
||||
fingerprinted.add((i, i + j))
|
||||
return hashes
|
||||
|
||||
def insert_into_db(self, key, value):
|
||||
self.db.insert(key, value)
|
||||
|
||||
def print_stats(self):
|
||||
|
||||
iterable = self.db.get_iterable_kv_pairs()
|
||||
|
||||
counter = {}
|
||||
for t in iterable:
|
||||
sid, toff = t
|
||||
if not sid in counter:
|
||||
counter[sid] = 1
|
||||
else:
|
||||
counter[sid] += 1
|
||||
|
||||
for song_id, count in counter.iteritems():
|
||||
song_name = self.song_names[song_id]
|
||||
print "%s has %d spectrogram peaks" % (song_name, count)
|
||||
|
||||
def set_song_names(self, wpaths):
|
||||
self.song_names = wpaths
|
||||
|
||||
def align_matches(self, matches, starttime, record_seconds=0, verbose=False):
|
||||
"""
|
||||
Finds hash matches that align in time with other matches and finds
|
||||
consensus about which hashes are "true" signal from the audio.
|
||||
|
||||
Returns a dictionary with match information.
|
||||
"""
|
||||
# align by diffs
|
||||
diff_counter = {}
|
||||
largest = 0
|
||||
largest_count = 0
|
||||
song_id = -1
|
||||
for tup in matches:
|
||||
sid, diff = tup
|
||||
if not diff in diff_counter:
|
||||
diff_counter[diff] = {}
|
||||
if not sid in diff_counter[diff]:
|
||||
diff_counter[diff][sid] = 0
|
||||
diff_counter[diff][sid] += 1
|
||||
|
||||
if diff_counter[diff][sid] > largest_count:
|
||||
largest = diff
|
||||
largest_count = diff_counter[diff][sid]
|
||||
song_id = sid
|
||||
|
||||
if verbose:
|
||||
print "Diff is %d with %d offset-aligned matches" % (largest, largest_count)
|
||||
|
||||
# extract idenfication
|
||||
song = self.db.get_song_by_id(song_id)
|
||||
if song:
|
||||
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
|
||||
else:
|
||||
return None
|
||||
songname = songname.replace("_", " ")
|
||||
elapsed = time.time() - starttime
|
||||
|
||||
if verbose:
|
||||
print "Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed)
|
||||
|
||||
# return match info
|
||||
song = {
|
||||
"song_id" : song_id,
|
||||
"song_name" : songname,
|
||||
"match_time" : elapsed,
|
||||
"confidence" : largest_count
|
||||
}
|
||||
|
||||
if record_seconds:
|
||||
song['record_time'] = record_seconds
|
||||
|
||||
return song
|
||||
|
|
150
dejavu/recognize.py
Normal file → Executable file
150
dejavu/recognize.py
Normal file → Executable file
|
@ -1,72 +1,112 @@
|
|||
from multiprocessing import Queue, Process
|
||||
from dejavu.database import SQLDatabase
|
||||
from scipy.io import wavfile
|
||||
import wave
|
||||
import dejavu.fingerprint as fingerprint
|
||||
import dejavu.decoder as decoder
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
import sys
|
||||
import time
|
||||
import array
|
||||
|
||||
class Recognizer(object):
|
||||
|
||||
CHUNK = 8192 # 44100 is a multiple of 1225
|
||||
FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 2
|
||||
RATE = 44100
|
||||
class BaseRecognizer(object):
|
||||
|
||||
def __init__(self, fingerprinter, config):
|
||||
def __init__(self, dejavu):
|
||||
self.dejavu = dejavu
|
||||
self.Fs = fingerprint.DEFAULT_FS
|
||||
|
||||
self.fingerprinter = fingerprinter
|
||||
self.config = config
|
||||
def _recognize(self, *data):
|
||||
matches = []
|
||||
for d in data:
|
||||
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
|
||||
return self.dejavu.align_matches(matches)
|
||||
|
||||
def recognize(self):
|
||||
pass # base class does nothing
|
||||
|
||||
|
||||
class FileRecognizer(BaseRecognizer):
|
||||
def __init__(self, dejavu):
|
||||
super(FileRecognizer, self).__init__(dejavu)
|
||||
|
||||
def recognize_file(self, filename):
|
||||
frames, self.Fs = decoder.read(filename)
|
||||
|
||||
t = time.time()
|
||||
match = self._recognize(*frames)
|
||||
t = time.time() - t
|
||||
|
||||
if match:
|
||||
match['match_time'] = t
|
||||
|
||||
return match
|
||||
|
||||
def recognize(self, filename):
|
||||
return self.recognize_file(filename)
|
||||
|
||||
|
||||
class MicrophoneRecognizer(BaseRecognizer):
|
||||
default_chunksize = 8192
|
||||
default_format = pyaudio.paInt16
|
||||
default_channels = 2
|
||||
default_samplerate = 44100
|
||||
|
||||
def __init__(self, dejavu):
|
||||
super(MicrophoneRecognizer, self).__init__(dejavu)
|
||||
self.audio = pyaudio.PyAudio()
|
||||
self.stream = None
|
||||
self.data = []
|
||||
self.channels = MicrophoneRecognizer.default_channels
|
||||
self.chunksize = MicrophoneRecognizer.default_chunksize
|
||||
self.samplerate = MicrophoneRecognizer.default_samplerate
|
||||
self.recorded = False
|
||||
|
||||
def read(self, filename, verbose=False):
|
||||
def start_recording(self, channels=default_channels,
|
||||
samplerate=default_samplerate,
|
||||
chunksize=default_chunksize):
|
||||
self.chunksize = chunksize
|
||||
self.channels = channels
|
||||
self.recorded = False
|
||||
self.samplerate = samplerate
|
||||
|
||||
# read file into channels
|
||||
channels = []
|
||||
Fs, frames = wavfile.read(filename)
|
||||
wave_object = wave.open(filename)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
for channel in range(nchannels):
|
||||
channels.append(frames[:, channel])
|
||||
if self.stream:
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
|
||||
# get matches
|
||||
starttime = time.time()
|
||||
matches = []
|
||||
for channel in channels:
|
||||
matches.extend(self.fingerprinter.match(channel))
|
||||
|
||||
return self.fingerprinter.align_matches(matches, starttime, verbose=verbose)
|
||||
|
||||
def listen(self, seconds=10, verbose=False):
|
||||
|
||||
# open stream
|
||||
stream = self.audio.open(format=Recognizer.FORMAT,
|
||||
channels=Recognizer.CHANNELS,
|
||||
rate=Recognizer.RATE,
|
||||
self.stream = self.audio.open(
|
||||
format=self.default_format,
|
||||
channels=channels,
|
||||
rate=samplerate,
|
||||
input=True,
|
||||
frames_per_buffer=Recognizer.CHUNK)
|
||||
frames_per_buffer=chunksize,
|
||||
)
|
||||
|
||||
# record
|
||||
if verbose: print("* recording")
|
||||
left, right = [], []
|
||||
for i in range(0, int(Recognizer.RATE / Recognizer.CHUNK * seconds)):
|
||||
data = stream.read(Recognizer.CHUNK)
|
||||
self.data = [[] for i in range(channels)]
|
||||
|
||||
def process_recording(self):
|
||||
data = self.stream.read(self.chunksize)
|
||||
nums = np.fromstring(data, np.int16)
|
||||
left.extend(nums[1::2])
|
||||
right.extend(nums[0::2])
|
||||
if verbose: print("* done recording")
|
||||
for c in range(self.channels):
|
||||
self.data[c].extend(nums[c::self.channels])
|
||||
|
||||
# close and stop the stream
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
def stop_recording(self):
|
||||
self.stream.stop_stream()
|
||||
self.stream.close()
|
||||
self.stream = None
|
||||
self.recorded = True
|
||||
|
||||
# match both channels
|
||||
starttime = time.time()
|
||||
matches = []
|
||||
matches.extend(self.fingerprinter.match(left))
|
||||
matches.extend(self.fingerprinter.match(right))
|
||||
def recognize_recording(self):
|
||||
if not self.recorded:
|
||||
raise NoRecordingError("Recording was not complete/begun")
|
||||
return self._recognize(*self.data)
|
||||
|
||||
# align and return
|
||||
return self.fingerprinter.align_matches(matches, starttime, record_seconds=seconds, verbose=verbose)
|
||||
def get_recorded_time(self):
|
||||
return len(self.data[0]) / self.rate
|
||||
|
||||
def recognize(self, seconds=10):
|
||||
self.start_recording()
|
||||
for i in range(0, int(self.samplerate / self.chunksize
|
||||
* seconds)):
|
||||
self.process_recording()
|
||||
self.stop_recording()
|
||||
return self.recognize_recording()
|
||||
|
||||
|
||||
class NoRecordingError(Exception):
|
||||
pass
|
||||
|
|
34
go.py
Normal file → Executable file
34
go.py
Normal file → Executable file
|
@ -1,20 +1,26 @@
|
|||
from dejavu.control import Dejavu
|
||||
from ConfigParser import ConfigParser
|
||||
from dejavu import Dejavu
|
||||
import warnings
|
||||
import json
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# load config
|
||||
config = ConfigParser()
|
||||
config.read("dejavu.cnf")
|
||||
# load config from a JSON file (or anything outputting a python dictionary)
|
||||
with open("dejavu.cnf") as f:
|
||||
config = json.load(f)
|
||||
|
||||
# create Dejavu object
|
||||
dejavu = Dejavu(config)
|
||||
dejavu.fingerprint("va_us_top_40/mp3", "va_us_top_40/wav", [".mp3"], 5)
|
||||
# create a Dejavu instance
|
||||
djv = Dejavu(config)
|
||||
# Fingerprint all the mp3's in the directory we give it
|
||||
djv.fingerprint_directory("va_us_top_40/mp3", [".mp3"])
|
||||
|
||||
# recognize microphone audio
|
||||
from dejavu.recognize import Recognizer
|
||||
recognizer = Recognizer(dejavu.fingerprinter, config)
|
||||
# Recognize audio from a file
|
||||
from dejavu.recognize import FileRecognizer
|
||||
song = djv.recognize(FileRecognizer, "mp3/beware.mp3")
|
||||
|
||||
# recognize song playing over microphone for 10 seconds
|
||||
song = recognizer.listen(seconds=5, verbose=True)
|
||||
print song
|
||||
# Or recognize audio from your microphone for 10 seconds
|
||||
from dejavu.recognize import MicrophoneRecognizer
|
||||
song = djv.recognize(MicrophoneRecognizer, seconds=2)
|
||||
|
||||
# Or use a recognizer without the shortcut, in anyway you would like
|
||||
from dejavu.recognize import FileRecognizer
|
||||
recognizer = FileRecognizer(djv)
|
||||
song = recognizer.recognize_file("va_us_top_40/wav/17_-_#Beautiful_-_Mariah_Carey_ft.wav")
|
||||
|
|
159
performance.py
159
performance.py
|
@ -1,159 +0,0 @@
|
|||
from dejavu.control import Dejavu
|
||||
from dejavu.recognize import Recognizer
|
||||
from dejavu.convert import Converter
|
||||
from dejavu.database import SQLDatabase
|
||||
from ConfigParser import ConfigParser
|
||||
from scipy.io import wavfile
|
||||
import matplotlib.pyplot as plt
|
||||
import warnings
|
||||
import pyaudio
|
||||
import os, wave, sys
|
||||
import random
|
||||
import numpy as np
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
config = ConfigParser()
|
||||
config.read("dejavu.cnf")
|
||||
dejavu = Dejavu(config)
|
||||
recognizer = Recognizer(dejavu.fingerprinter, config)
|
||||
|
||||
def test_recording_lengths(recognizer):
|
||||
|
||||
# settings for run
|
||||
RATE = 44100
|
||||
FORMAT = pyaudio.paInt16
|
||||
padding_seconds = 10
|
||||
SONG_PADDING = RATE * padding_seconds
|
||||
OUTPUT_FILE = "output.wav"
|
||||
p = pyaudio.PyAudio()
|
||||
c = Converter()
|
||||
files = c.find_files("va_us_top_40/wav/", [".wav"])[-25:]
|
||||
total = len(files)
|
||||
recording_lengths = [4]
|
||||
correct = 0
|
||||
count = 0
|
||||
score = {}
|
||||
|
||||
for r in recording_lengths:
|
||||
|
||||
RECORD_LENGTH = RATE * r
|
||||
|
||||
for tup in files:
|
||||
f, ext = tup
|
||||
|
||||
# read the file
|
||||
#print "reading: %s" % f
|
||||
Fs, frames = wavfile.read(f)
|
||||
wave_object = wave.open(f)
|
||||
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
|
||||
|
||||
# chose at random a segment of audio to play
|
||||
possible_end = num_frames - SONG_PADDING - RECORD_LENGTH
|
||||
possible_start = SONG_PADDING
|
||||
if possible_end - possible_start < RECORD_LENGTH:
|
||||
print "ERROR! Song is too short to sample based on padding and recording seconds preferences."
|
||||
sys.exit()
|
||||
start = random.randint(possible_start, possible_end)
|
||||
end = start + RECORD_LENGTH + 1
|
||||
|
||||
# get that segment of samples
|
||||
channels = []
|
||||
frames = frames[start:end, :]
|
||||
wav_string = frames.tostring()
|
||||
|
||||
# write to disk
|
||||
wf = wave.open(OUTPUT_FILE, 'wb')
|
||||
wf.setnchannels(nchannels)
|
||||
wf.setsampwidth(p.get_sample_size(FORMAT))
|
||||
wf.setframerate(RATE)
|
||||
wf.writeframes(b''.join(wav_string))
|
||||
wf.close()
|
||||
|
||||
# play and test
|
||||
correctname = os.path.basename(f).replace(".wav", "").replace("_", " ")
|
||||
inp = raw_input("Click ENTER when playing %s ..." % OUTPUT_FILE)
|
||||
song = recognizer.listen(seconds=r+1, verbose=False)
|
||||
print "PREDICTED: %s" % song['song_name']
|
||||
print "ACTUAL: %s" % correctname
|
||||
if song['song_name'] == correctname:
|
||||
correct += 1
|
||||
count += 1
|
||||
|
||||
print "Currently %d correct out of %d in total of %d" % (correct, count, total)
|
||||
|
||||
score[r] = (correct, total)
|
||||
print "UPDATE AFTER %d TRIAL: %s" % (r, score)
|
||||
|
||||
return score
|
||||
|
||||
def plot_match_time_trials():
|
||||
|
||||
# I did this manually
|
||||
t = np.array([1, 2, 3, 4, 5, 6, 7, 8, 10, 15, 25, 30, 45, 60])
|
||||
m = np.array([.47, .79, 1.1, 1.5, 1.8, 2.18, 2.62, 2.8, 3.65, 5.29, 8.92, 10.63, 16.09, 22.29])
|
||||
mplust = t + m
|
||||
|
||||
# linear regression
|
||||
A = np.matrix([t, np.ones(len(t))])
|
||||
print A
|
||||
w = np.linalg.lstsq(A.T, mplust)[0]
|
||||
line = w[0] * t + w[1]
|
||||
print "Equation for line is %f * record_time + %f = time_to_match" % (w[0], w[1])
|
||||
|
||||
# and plot
|
||||
plt.title("Recording vs Matching time for \"Get Lucky\" by Daft Punk")
|
||||
plt.xlabel("Time recorded (s)")
|
||||
plt.ylabel("Time recorded + time to match (s)")
|
||||
#plt.scatter(t, mplust)
|
||||
plt.plot(t, line, 'r-', t, mplust, 'o')
|
||||
plt.show()
|
||||
|
||||
def plot_accuracy():
|
||||
# also did this manually
|
||||
secs = np.array([1, 2, 3, 4, 5, 6])
|
||||
correct = np.array([27.0, 43.0, 44.0, 44.0, 45.0, 45.0])
|
||||
total = 45.0
|
||||
correct = correct / total
|
||||
|
||||
plt.title("Dejavu Recognition Accuracy as a Function of Time")
|
||||
plt.xlabel("Time recorded (s)")
|
||||
plt.ylabel("Accuracy")
|
||||
plt.plot(secs, correct)
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.show()
|
||||
|
||||
def plot_hashes_per_song():
|
||||
squery = """select song_name, count(song_id) as num
|
||||
from fingerprints
|
||||
natural join songs
|
||||
group by song_name
|
||||
order by count(song_id) asc;"""
|
||||
sql = SQLDatabase(username="root", password="root", database="dejavu", hostname="localhost")
|
||||
cursor = sql.connection.cursor()
|
||||
cursor.execute(squery)
|
||||
counts = cursor.fetchall()
|
||||
|
||||
songs = []
|
||||
count = []
|
||||
for item in counts:
|
||||
songs.append(item['song_name'].replace("_", " ")[4:])
|
||||
count.append(item['num'])
|
||||
|
||||
pos = np.arange(len(songs)) + 0.5
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.barh(pos, count, align='center')
|
||||
ax.set_yticks(pos, tuple(songs))
|
||||
|
||||
ax.axvline(0, color='k', lw=3)
|
||||
|
||||
ax.set_xlabel('Number of Fingerprints')
|
||||
ax.set_title('Number of Fingerprints by Song')
|
||||
ax.grid(True)
|
||||
plt.show()
|
||||
|
||||
#plot_accuracy()
|
||||
#score = test_recording_lengths(recognizer)
|
||||
#plot_match_time_trials()
|
||||
#plot_hashes_per_song()
|
Loading…
Reference in a new issue