bugfixes and tweaks. also now returns offset of matched song

This commit is contained in:
worldveil 2013-12-25 02:56:43 -06:00
commit 9eca3cc05a
11 changed files with 920 additions and 881 deletions

BIN
README.md

Binary file not shown.

171
dejavu/__init__.py Normal file → Executable file
View file

@ -0,0 +1,171 @@
from dejavu.database import get_database
import dejavu.decoder as decoder
import fingerprint
import multiprocessing
import os
class Dejavu(object):
    """
    Entry point tying together the fingerprint database, the audio decoder
    and the fingerprinting algorithm.

    config: a mapping. "database_type" selects the Database subclass
            (see `get_database`) and "database" holds the keyword
            arguments handed to that class' constructor.
    """

    def __init__(self, config):
        super(Dejavu, self).__init__()

        self.config = config

        # initialize db
        db_cls = get_database(config.get("database_type", None))

        self.db = db_cls(**config.get("database", {}))
        self.db.setup()

        # get songs previously indexed
        # TODO: should probably use a checksum of the file instead of filename
        # Materialise the result: the database may hand back a one-shot
        # generator, which the loop below would otherwise leave exhausted.
        self.songs = list(self.db.get_songs())
        self.songnames_set = set()  # to know which ones we've computed before

        for song in self.songs:
            song_name = song[self.db.FIELD_SONGNAME]
            self.songnames_set.add(song_name)
            print("Added: %s to the set of fingerprinted songs..." % song_name)

    def fingerprint_directory(self, path, extensions, nprocesses=None):
        """
        Fingerprints every file below `path` whose extension is listed in
        `extensions`, spreading the work over a process pool.

        path: Directory that is walked recursively.
        extensions: Extensions to look for, with or without a leading dot.
        nprocesses: Pool size; defaults to the CPU count, and falls back
                    to 1 when that is unknown or the value is not positive.
        """
        # Try to use the maximum amount of processes if not given.
        try:
            nprocesses = nprocesses or multiprocessing.cpu_count()
        except NotImplementedError:
            nprocesses = 1
        else:
            nprocesses = 1 if nprocesses <= 0 else nprocesses

        pool = multiprocessing.Pool(nprocesses)

        results = []
        for filename, _ in decoder.find_files(path, extensions):

            # don't refingerprint already fingerprinted files
            if decoder.path_to_songname(filename) in self.songnames_set:
                print("%s already fingerprinted, continuing..." % filename)
                continue

            result = pool.apply_async(_fingerprint_worker,
                                      (filename, self.db))
            results.append(result)

        try:
            while results:
                # Iterate over a copy, finished results are removed in-loop.
                for result in results[:]:
                    # TODO: Handle errors gracefully and return them to the
                    # callee in some way.
                    try:
                        result.get(timeout=2)
                    except multiprocessing.TimeoutError:
                        # Still running, poll it again on the next pass.
                        continue
                    except Exception:
                        # A worker failed: report it and keep going with
                        # the remaining files.
                        import sys
                        import traceback
                        traceback.print_exc(file=sys.stdout)
                        results.remove(result)
                    else:
                        results.remove(result)
        except KeyboardInterrupt:
            # Do not leave orphaned workers behind when the user aborts.
            pool.terminate()
            pool.join()
            raise

        pool.close()
        pool.join()

    def fingerprint_file(self, filepath, song_name=None):
        """
        Fingerprints a single file in the current process.

        filepath: Path of the audio file.
        song_name: Name to store the song under; derived from the file
                   name when omitted or empty.
        """
        channels, Fs = decoder.read(filepath)

        if not song_name:
            # Resolve the name first so the message shows the real value.
            song_name = decoder.path_to_songname(filepath)
            print("Song name: %s" % song_name)
        song_id = self.db.insert_song(song_name)

        for data in channels:
            hashes = fingerprint.fingerprint(data, Fs=Fs)
            self.db.insert_hashes(song_id, hashes)

        # Without this flag the song is not reported by `get_songs` and
        # is removed again by `delete_unfingerprinted_songs`.
        self.db.set_song_fingerprinted(song_id)
        self.songnames_set.add(song_name)

    def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
        """
        Fingerprints `samples` and returns what the database's
        `return_matches` yields for the resulting hashes.
        """
        hashes = fingerprint.fingerprint(samples, Fs=Fs)
        return self.db.return_matches(hashes)

    def align_matches(self, matches):
        """
        Finds hash matches that align in time with other matches and finds
        consensus about which hashes are "true" signal from the audio.

        matches: An iterable of (song id, offset difference) tuples.

        Returns a dictionary with match information ("song_id",
        "song_name", "confidence", "offset"), or None when the winning
        song id cannot be found in the database.
        """
        # align by diffs
        diff_counter = {}
        largest = 0
        largest_count = 0
        song_id = -1
        for tup in matches:
            sid, diff = tup
            if diff not in diff_counter:
                diff_counter[diff] = {}
            if sid not in diff_counter[diff]:
                diff_counter[diff][sid] = 0
            diff_counter[diff][sid] += 1

            # Track the (diff, song) pair with the most votes so far.
            if diff_counter[diff][sid] > largest_count:
                largest = diff
                largest_count = diff_counter[diff][sid]
                song_id = sid

        print("Diff is %d with %d offset-aligned matches" % (largest,
                                                             largest_count))

        # extract identification
        song = self.db.get_song_by_id(song_id)
        if song:
            # TODO: Clarify what `get_song_by_id` should return.
            songname = song.get("song_name", None)
        else:
            return None

        # return match info
        song = {
            "song_id": song_id,
            "song_name": songname,
            "confidence": largest_count,
            "offset": largest
        }

        return song

    def recognize(self, recognizer, *options, **kwoptions):
        """
        Instantiates `recognizer` with this Dejavu instance and forwards
        all remaining arguments to its `recognize` method.
        """
        r = recognizer(self)
        return r.recognize(*options, **kwoptions)
def _fingerprint_worker(filename, db):
    """
    Pool task: decodes `filename`, fingerprints each of its channels,
    stores the hashes through `db` and finally flags the song as
    completely fingerprinted.

    filename: Path of the audio file to process.
    db: Database object providing `insert_song`, `insert_hashes` and
        `set_song_fingerprinted`.
    """
    base_name = os.path.basename(filename)
    song_name = os.path.splitext(base_name)[0]

    channels, Fs = decoder.read(filename)

    # insert song into database
    sid = db.insert_song(song_name)

    total = len(channels)
    for position, samples in enumerate(channels, 1):
        # TODO: Remove prints or change them into optional logging.
        progress = (position, total, filename)

        print("Fingerprinting channel %d/%d for %s" % progress)
        hashes = fingerprint.fingerprint(samples, Fs=Fs)
        print("Finished channel %d/%d for %s" % progress)

        print("Inserting fingerprints for channel %d/%d for %s" % progress)
        db.insert_hashes(sid, hashes)
        print("Finished inserting for channel %d/%d for %s" % progress)

    # Only flag the song once every channel has been stored.
    print("Marking %s finished" % (filename,))
    db.set_song_fingerprinted(sid)
    print("%s finished" % (filename,))
def chunkify(lst, n):
    """
    Splits a list into roughly n equal parts.

    Part i holds every n-th element of `lst` starting at index i, so the
    parts are interleaved rather than contiguous.

    http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
    """
    # `range` instead of `xrange`: same result, works on Python 2 and 3.
    return [lst[i::n] for i in range(n)]

View file

@ -1,120 +0,0 @@
from dejavu.database import SQLDatabase
from dejavu.convert import Converter
from dejavu.fingerprint import Fingerprinter
from scipy.io import wavfile
from multiprocessing import Process
import wave, os
import random
class Dejavu():
    """
    Converts audio files to WAV, fingerprints every channel and stores
    the result through per-process SQL connections.
    """

    def __init__(self, config):

        self.config = config

        # create components
        self.converter = Converter(config)
        self.fingerprinter = Fingerprinter(self.config)
        self.fingerprinter.db.setup()

        # get songs previously indexed
        self.songs = self.fingerprinter.db.get_songs()
        self.songnames_set = set()  # to know which ones we've computed before
        if self.songs:
            for song in self.songs:
                song_name = song[SQLDatabase.FIELD_SONGNAME]
                self.songnames_set.add(song_name)
                print("Added: %s to the set of fingerprinted songs..." % song_name)

    def chunkify(self, lst, n):
        """
        Splits a list into roughly n equal parts.
        http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
        """
        return [lst[i::n] for i in range(n)]

    def fingerprint(self, path, output, extensions, nprocesses, keep_wav=False):
        """
        Fingerprints all supported files below `path` using `nprocesses`
        worker processes.

        path: Directory searched for audio files.
        output: Folder that receives the converted WAV files.
        extensions: Extensions to look for.
        nprocesses: Number of worker processes to start.
        keep_wav: Keep the converted WAV files when True.
        """
        # convert files, shuffle order
        files = self.converter.find_files(path, extensions)
        random.shuffle(files)
        files_split = self.chunkify(files, nprocesses)

        # split into processes here
        processes = []
        for i in range(nprocesses):

            # need database instance since mysql connections shouldn't be shared across processes
            sql_connection = SQLDatabase(
                self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME),
                self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
                self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
                self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))

            # create process and start it
            p = Process(target=self.fingerprint_worker,
                        args=(files_split[i], sql_connection, output, keep_wav))
            p.start()
            processes.append(p)

        # wait for all processes to complete
        try:
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            print("-> Exiting..")
            for worker in processes:
                worker.terminate()
                worker.join()

        # delete orphans
        # print "Done fingerprinting. Deleting orphaned fingerprints..."
        # TODO: need a more performant query in database.py for the
        #self.fingerprinter.db.delete_orphans()

    def fingerprint_worker(self, files, sql_connection, output, keep_wav):
        """
        Worker body: converts and fingerprints every (filename, extension)
        pair in `files`, using `sql_connection` for the song bookkeeping.
        """
        for filename, extension in files:

            # if there are already fingerprints in database, don't re-fingerprint or convert
            if filename in self.songnames_set:
                print("-> Already fingerprinted, continuing...")
                continue

            # convert to WAV
            wavout_path = self.converter.convert(filename, extension, Converter.WAV, output)

            # for each channel perform FFT analysis and fingerprinting
            try:
                channels = self.extract_channels(wavout_path)
            except AssertionError:
                print("-> File not supported, skipping.")
                continue

            # insert song name into database
            song_id = sql_connection.insert_song(filename)

            for number, channel in enumerate(channels, 1):
                print("-> Fingerprinting channel %d of song %s..." % (number, filename))
                self.fingerprinter.fingerprint(channel, wavout_path, song_id, number)

            # remove wav file if not required
            if not keep_wav:
                os.unlink(wavout_path)

            # only after done fingerprinting do confirm
            sql_connection.set_song_fingerprinted(song_id)

    def extract_channels(self, path):
        """
        Reads channels from disk.

        path: Path of a WAV file.

        Returns a list with one sample array per channel.
        Raises AssertionError when the file's sample rate differs from
        the fingerprinter's `Fs`.
        """
        Fs, frames = wavfile.read(path)

        # Raised explicitly so the check also holds under `python -O`;
        # `fingerprint_worker` relies on AssertionError to skip the file.
        if Fs != self.fingerprinter.Fs:
            raise AssertionError(
                "sample rate %s does not match expected %s"
                % (Fs, self.fingerprinter.Fs))

        # A one-dimensional array cannot be indexed as [:, channel];
        # treat it as a single channel.
        if frames.ndim == 1:
            return [frames]
        return [frames[:, channel] for channel in range(frames.shape[1])]

View file

@ -1,58 +0,0 @@
import os, fnmatch
from pydub import AudioSegment
class Converter():
    """
    Locates audio files and converts them to WAV with pydub.
    """

    WAV = "wav"
    MP3 = "mp3"
    FORMATS = [
        WAV,
        MP3]

    def __init__(self, config):
        self.config = config
        # Optional cap, in seconds, on how much of each input is converted.
        if self.config.has_section("input") and self.config.has_option("input", "length"):
            self.max_input_len = self.config.getint("input", "length")
        else:
            self.max_input_len = None

    def ensure_folder(self, extension):
        """
        Creates the folder named by `extension` (a path, despite the
        parameter name) when it does not exist yet.
        """
        if not os.path.exists(extension):
            os.makedirs(extension)

    def find_files(self, path, extensions):
        """
        Walks `path` and returns a list of (filepath, extension) tuples
        for every file whose extension is both requested and listed in
        `Converter.FORMATS`.
        """
        filepaths = []
        extensions = [e.replace(".", "") for e in extensions if e.replace(".", "") in Converter.FORMATS]
        print("Supported formats: %s" % extensions)
        for dirpath, dirnames, files in os.walk(path):
            for extension in extensions:
                for f in fnmatch.filter(files, "*.%s" % extension):
                    p = os.path.join(dirpath, f)
                    filepaths.append((p, extension))
        return filepaths

    def convert(self, orig_path, from_format, to_format, output_folder):
        """
        Converts `orig_path` into `output_folder` and returns the path of
        the converted file. Only MP3 to WAV is supported.

        Raises ValueError for any other format combination.
        """
        # Reject unsupported conversions up front. Falling through used
        # to reach `return newpath` with `newpath` never assigned.
        if not (from_format == Converter.MP3 and to_format == Converter.WAV):
            raise ValueError(
                "The conversion from %s to %s is not supported!"
                % (from_format, to_format))

        path, song_name = os.path.split(orig_path)

        # start conversion
        self.ensure_folder(output_folder)
        print("-> Now converting: %s from %s format to %s format..." % (song_name, from_format, to_format))

        # MP3 --> WAV
        newpath = os.path.join(output_folder, "%s.%s" % (song_name, Converter.WAV))
        if os.path.isfile(newpath):
            print("-> Already converted, skipping...")
        else:
            mp3file = AudioSegment.from_mp3(orig_path)
            if self.max_input_len:
                print("-> Reading input seconds:  %s" % self.max_input_len)
                # pydub slices are expressed in milliseconds
                mp3file = mp3file[:self.max_input_len * 1000]
            mp3file.export(newpath, format=Converter.WAV)

        print("-> Conversion complete.")
        return newpath

387
dejavu/database.py Normal file → Executable file
View file

@ -1,325 +1,170 @@
import MySQLdb as mysql from __future__ import absolute_import
import MySQLdb.cursors as cursors import abc
import os
class SQLDatabase():
"""
Queries:
1) Find duplicates (shouldn't be any, though): class Database(object):
__metaclass__ = abc.ABCMeta
select `hash`, `song_id`, `offset`, count(*) cnt # Name of your Database subclass, this is used in configuration
from fingerprints # to refer to your class
group by `hash`, `song_id`, `offset` type = None
having cnt > 1
order by cnt asc;
2) Get number of hashes by song: def __init__(self):
super(Database, self).__init__()
select song_id, song_name, count(song_id) as num def before_fork(self):
from fingerprints """
natural join songs Called before the database instance is given to the new process
group by song_id """
order by count(song_id) desc; pass
3) get hashes with highest number of collisions def after_fork(self):
"""
Called after the database instance has been given to the new process
select This will be called in the new process.
hash, """
count(distinct song_id) as n pass
from fingerprints
group by `hash`
order by n DESC;
=> 26 different songs with same fingerprint (392 times):
select songs.song_name, fingerprints.offset
from fingerprints natural join songs
where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
"""
# config keys
CONNECTION = "connection"
KEY_USERNAME = "username"
KEY_DATABASE = "database"
KEY_PASSWORD = "password"
KEY_HOSTNAME = "hostname"
# tables
FINGERPRINTS_TABLENAME = "fingerprints"
SONGS_TABLENAME = "songs"
# fields
FIELD_HASH = "hash"
FIELD_SONG_ID = "song_id"
FIELD_OFFSET = "offset"
FIELD_SONGNAME = "song_name"
FIELD_FINGERPRINTED = "fingerprinted"
# creates
CREATE_FINGERPRINTS_TABLE = """
CREATE TABLE IF NOT EXISTS `%s` (
`%s` binary(10) not null,
`%s` mediumint unsigned not null,
`%s` int unsigned not null,
INDEX(%s),
UNIQUE(%s, %s, %s)
);""" % (FINGERPRINTS_TABLENAME, FIELD_HASH,
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH)
CREATE_SONGS_TABLE = """
CREATE TABLE IF NOT EXISTS `%s` (
`%s` mediumint unsigned not null auto_increment,
`%s` varchar(250) not null,
`%s` tinyint default 0,
PRIMARY KEY (`%s`),
UNIQUE KEY `%s` (`%s`)
);""" % (SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED,
FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID)
# inserts
INSERT_FINGERPRINT = "INSERT IGNORE INTO %s (%s, %s, %s) VALUES (UNHEX(%%s), %%s, %%s)" % (
FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET) # ignore duplicates and don't insert them
INSERT_SONG = "INSERT INTO %s (%s) VALUES (%%s);" % (
SONGS_TABLENAME, FIELD_SONGNAME)
# selects
SELECT = "SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH)
SELECT_ALL = "SELECT %s, %s FROM %s;" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME)
SELECT_SONG = "SELECT %s FROM %s WHERE %s = %%s" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID)
SELECT_NUM_FINGERPRINTS = "SELECT COUNT(*) as n FROM %s" % (FINGERPRINTS_TABLENAME)
SELECT_UNIQUE_SONG_IDS = "SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)
SELECT_SONGS = "SELECT %s, %s FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED)
# drops
DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME
DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME
# update
UPDATE_SONG_FINGERPRINTED = "UPDATE %s SET %s = 1 WHERE %s = %%s" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID)
# delete
DELETE_UNFINGERPRINTED = "DELETE FROM %s WHERE %s = 0;" % (SONGS_TABLENAME, FIELD_FINGERPRINTED)
DELETE_ORPHANS = """
delete from fingerprints
where not exists (
select * from songs where fingerprints.song_id = songs.song_id
)"""
def __init__(self, hostname, username, password, database):
# connect
self.database = database
try:
# http://www.halfcooked.com/mt/archives/000969.html
self.connection = mysql.connect(
hostname, username, password,
database, cursorclass=cursors.DictCursor)
self.connection.autocommit(False) # for fast bulk inserts
self.cursor = self.connection.cursor()
except mysql.Error, e:
print "Connection error %d: %s" % (e.args[0], e.args[1])
def setup(self): def setup(self):
try: """
# create fingerprints table Called on creation or shortly afterwards.
self.cursor.execute("USE %s;" % self.database) """
self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE) pass
self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE)
self.delete_unfingerprinted_songs()
self.connection.commit()
except mysql.Error, e:
print "Connection error %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
@abc.abstractmethod
def empty(self): def empty(self):
""" """
Drops all tables and re-adds them. Be carfeul with this! Called when the database should be cleared of all data.
""" """
try: pass
self.cursor.execute("USE %s;" % self.database)
# drop tables @abc.abstractmethod
self.cursor.execute(SQLDatabase.DROP_FINGERPRINTS)
self.cursor.execute(SQLDatabase.DROP_SONGS)
# recreate
self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE)
self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE)
self.connection.commit()
except mysql.Error, e:
print "Error in empty(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def delete_orphans(self):
try:
self.cursor = self.connection.cursor()
### TODO: SQLDatabase.DELETE_ORPHANS is not performant enough, need better query
### to delete fingerprints for which no song is tied to.
#self.cursor.execute(SQLDatabase.DELETE_ORPHANS)
#self.connection.commit()
except mysql.Error, e:
print "Error in delete_orphans(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def delete_unfingerprinted_songs(self): def delete_unfingerprinted_songs(self):
try: """
self.cursor = self.connection.cursor() Called to remove any song entries that do not have any fingerprints
self.cursor.execute(SQLDatabase.DELETE_UNFINGERPRINTED) associated with them.
self.connection.commit() """
except mysql.Error, e: pass
print "Error in delete_unfingerprinted_songs(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
@abc.abstractmethod
def get_num_songs(self): def get_num_songs(self):
""" """
Returns number of songs the database has fingerprinted. Returns the amount of songs in the database.
""" """
try: pass
self.cursor = self.connection.cursor()
self.cursor.execute(SQLDatabase.SELECT_UNIQUE_SONG_IDS) @abc.abstractmethod
record = self.cursor.fetchone()
return int(record['n'])
except mysql.Error, e:
print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1])
def get_num_fingerprints(self): def get_num_fingerprints(self):
""" """
Returns number of fingerprints the database has fingerprinted. Returns the number of fingerprints in the database.
""" """
try: pass
self.cursor = self.connection.cursor()
self.cursor.execute(SQLDatabase.SELECT_NUM_FINGERPRINTS)
record = self.cursor.fetchone()
return int(record['n'])
except mysql.Error, e:
print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1])
def set_song_fingerprinted(self, song_id): @abc.abstractmethod
def set_song_fingerprinted(self, sid):
""" """
Set the fingerprinted flag to TRUE (1) once a song has been completely Sets a specific song as having all fingerprints in the database.
fingerprinted in the database.
"""
try:
self.cursor = self.connection.cursor()
self.cursor.execute(SQLDatabase.UPDATE_SONG_FINGERPRINTED, song_id)
self.connection.commit()
except mysql.Error, e:
print "Error in set_song_fingerprinted(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
sid: Song identifier
"""
pass
@abc.abstractmethod
def get_songs(self): def get_songs(self):
""" """
Return songs that have the fingerprinted flag set TRUE (1). Returns all fully fingerprinted songs in the database.
""" """
try: pass
self.cursor.execute(SQLDatabase.SELECT_SONGS)
return self.cursor.fetchall() @abc.abstractmethod
except mysql.Error, e:
print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1])
return None
def get_song_by_id(self, sid): def get_song_by_id(self, sid):
""" """
Returns song by its ID. Return a song by its identifier
sid: Song identifier
""" """
try: pass
self.cursor.execute(SQLDatabase.SELECT_SONG, (sid,))
return self.cursor.fetchone()
except mysql.Error, e:
print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1])
return None
def insert(self, key, value): @abc.abstractmethod
def insert(self, hash, sid, offset):
""" """
Insert a (sha1, song_id, offset) row into database. Inserts a single fingerprint into the database.
key is a sha1 hash, value = (song_id, offset) hash: Part of a sha1 hash, in hexadecimal format
sid: Song identifier this fingerprint is off
offset: The offset this hash is from
""" """
try: pass
args = (key, value[0], value[1])
self.cursor.execute(SQLDatabase.INSERT_FINGERPRINT, args)
except mysql.Error, e:
print "Error in insert(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def insert_song(self, songname): @abc.abstractmethod
def insert_song(self, song_name):
""" """
Inserts song in the database and returns the ID of the inserted record. Inserts a song name into the database, returns the new
identifier of the song.
song_name: The name of the song.
""" """
try: pass
self.cursor.execute(SQLDatabase.INSERT_SONG, (songname,))
self.connection.commit()
return int(self.cursor.lastrowid)
except mysql.Error, e:
print "Error in insert_song(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
return None
def query(self, key): @abc.abstractmethod
def query(self, hash):
""" """
Return all tuples associated with hash. Returns all matching fingerprint entries associated with
the given hash as parameter.
If hash is None, returns all entries in the hash: Part of a sha1 hash, in hexadecimal format
database (be careful with that one!).
""" """
# select all if no key pass
if key is not None:
sql = SQLDatabase.SELECT
else:
sql = SQLDatabase.SELECT_ALL
matches = []
try:
self.cursor.execute(sql, (key,))
# collect all matches
records = self.cursor.fetchall()
for record in records:
matches.append((record[SQLDatabase.FIELD_SONG_ID], record[SQLDatabase.FIELD_OFFSET]))
except mysql.Error, e:
print "Error in query(), %d: %s" % (e.args[0], e.args[1])
return matches
@abc.abstractmethod
def get_iterable_kv_pairs(self): def get_iterable_kv_pairs(self):
""" """
Returns all tuples in database. Returns all fingerprints in the database.
""" """
return self.query(None) pass
def insert_hashes(self, hashes): @abc.abstractmethod
def insert_hashes(self, sid, hashes):
""" """
Insert series of hash => song_id, offset Insert a multitude of fingerprints.
values into the database.
"""
for h in hashes:
sha1, val = h
self.insert(sha1, val)
self.connection.commit()
sid: Song identifier the fingerprints belong to
hashes: A sequence of tuples in the format (hash, offset)
- hash: Part of a sha1 hash, in hexadecimal format
- offset: Offset this hash was created from/at.
"""
pass
@abc.abstractmethod
def return_matches(self, hashes): def return_matches(self, hashes):
""" """
Return the (song_id, offset_diff) tuples associated with Searches the database for pairs of (hash, offset) values.
a list of
sha1 => (None, sample_offset) hashes: A sequence of tuples in the format (hash, offset)
- hash: Part of a sha1 hash, in hexadecimal format
- offset: Offset this hash was created from/at.
values. Returns a sequence of (sid, offset_difference) tuples.
sid: Song identifier
offset_difference: (offset - database_offset)
""" """
matches = [] pass
for h in hashes:
sha1, val = h
list_of_tups = self.query(sha1) def get_database(database_type=None):
if list_of_tups: # Default to using the mysql database
for t in list_of_tups: database_type = database_type or "mysql"
# (song_id, db_offset, song_sampled_offset) # Lower all the input.
matches.append((t[0], t[1] - val[1])) database_type = database_type.lower()
return matches
for db_cls in Database.__subclasses__():
if db_cls.type == database_type:
return db_cls
raise TypeError("Unsupported database type supplied.")
# Import our default database handler
import dejavu.database_sql

374
dejavu/database_sql.py Normal file
View file

@ -0,0 +1,374 @@
from __future__ import absolute_import
from itertools import izip_longest
import Queue
import MySQLdb as mysql
from MySQLdb.cursors import DictCursor
from dejavu.database import Database
class SQLDatabase(Database):
    """
    MySQL backed implementation of `Database`.

    Queries:

    1) Find duplicates (shouldn't be any, though):

        select `hash`, `song_id`, `offset`, count(*) cnt
        from fingerprints
        group by `hash`, `song_id`, `offset`
        having cnt > 1
        order by cnt asc;

    2) Get number of hashes by song:

        select song_id, song_name, count(song_id) as num
        from fingerprints
        natural join songs
        group by song_id
        order by count(song_id) desc;

    3) get hashes with highest number of collisions

        select
            hash,
            count(distinct song_id) as n
        from fingerprints
        group by `hash`
        order by n DESC;

    => 26 different songs with same fingerprint (392 times):

        select songs.song_name, fingerprints.offset
        from fingerprints natural join songs
        where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
    """

    type = "mysql"

    # tables
    FINGERPRINTS_TABLENAME = "fingerprints"
    SONGS_TABLENAME = "songs"

    # fields
    FIELD_HASH = "hash"
    FIELD_SONG_ID = "song_id"
    FIELD_OFFSET = "offset"
    FIELD_SONGNAME = "song_name"
    FIELD_FINGERPRINTED = "fingerprinted"

    # creates
    # The hash column is a plain INDEX, not a PRIMARY KEY: one hash may
    # occur in many songs and at many offsets. Uniqueness is enforced on
    # the (song_id, offset, hash) triple only.
    CREATE_FINGERPRINTS_TABLE = """
        CREATE TABLE IF NOT EXISTS `%s` (
             `%s` binary(10) not null,
             `%s` mediumint unsigned not null,
             `%s` int unsigned not null,
         INDEX(%s),
         UNIQUE(%s, %s, %s),
         FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
    ) ENGINE=INNODB;""" % (
        FINGERPRINTS_TABLENAME, FIELD_HASH,
        FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
        FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
        FIELD_SONG_ID, SONGS_TABLENAME, FIELD_SONG_ID
    )

    CREATE_SONGS_TABLE = """
        CREATE TABLE IF NOT EXISTS `%s` (
            `%s` mediumint unsigned not null auto_increment,
            `%s` varchar(250) not null,
            `%s` tinyint default 0,
        PRIMARY KEY (`%s`),
        UNIQUE KEY `%s` (`%s`)
    ) ENGINE=INNODB;""" % (
        SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED,
        FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID,
    )

    # inserts (ignores duplicates)
    INSERT_FINGERPRINT = """
        INSERT IGNORE INTO %s (%s, %s, %s) values
            (UNHEX(%%s), %%s, %%s);
    """ % (FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET)

    INSERT_SONG = "INSERT INTO %s (%s) values (%%s);" % (
        SONGS_TABLENAME, FIELD_SONGNAME)

    # selects
    SELECT = """
        SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);
    """ % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH)

    SELECT_MULTIPLE = """
        SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s);
    """ % (FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET,
           FINGERPRINTS_TABLENAME, FIELD_HASH)

    SELECT_ALL = """
        SELECT %s, %s FROM %s;
    """ % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME)

    SELECT_SONG = """
        SELECT %s FROM %s WHERE %s = %%s
    """ % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID)

    SELECT_NUM_FINGERPRINTS = """
        SELECT COUNT(*) as n FROM %s
    """ % (FINGERPRINTS_TABLENAME)

    SELECT_UNIQUE_SONG_IDS = """
        SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;
    """ % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)

    SELECT_SONGS = """
        SELECT %s, %s FROM %s WHERE %s = 1;
    """ % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED)

    # drops
    DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME
    DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME

    # update
    UPDATE_SONG_FINGERPRINTED = """
        UPDATE %s SET %s = 1 WHERE %s = %%s
    """ % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID)

    # delete
    DELETE_UNFINGERPRINTED = """
        DELETE FROM %s WHERE %s = 0;
    """ % (SONGS_TABLENAME, FIELD_FINGERPRINTED)

    def __init__(self, **options):
        """
        options: Keyword arguments forwarded to `cursor_factory`; they
                 are also kept so the instance can be pickled.
        """
        super(SQLDatabase, self).__init__()
        self.cursor = cursor_factory(**options)
        self._options = options

    def after_fork(self):
        # Clear the cursor cache, we don't want any stale connections from
        # the previous process.
        Cursor.clear_cache()

    def setup(self):
        """
        Creates any non-existing tables required for dejavu to function.

        This also removes all songs that have been added but have no
        fingerprints associated with them.
        """
        with self.cursor() as cur:
            # songs first: the fingerprints table references it
            cur.execute(self.CREATE_SONGS_TABLE)
            cur.execute(self.CREATE_FINGERPRINTS_TABLE)
            cur.execute(self.DELETE_UNFINGERPRINTED)

    def empty(self):
        """
        Drops tables created by dejavu and then creates them again
        by calling `SQLDatabase.setup`.

        .. warning:
            This will result in a loss of data
        """
        with self.cursor() as cur:
            cur.execute(self.DROP_FINGERPRINTS)
            cur.execute(self.DROP_SONGS)

        self.setup()

    def delete_unfingerprinted_songs(self):
        """
        Removes all songs that have no fingerprints associated with them.
        """
        with self.cursor() as cur:
            cur.execute(self.DELETE_UNFINGERPRINTED)

    def get_num_songs(self):
        """
        Returns number of songs the database has fingerprinted.
        """
        with self.cursor() as cur:
            cur.execute(self.SELECT_UNIQUE_SONG_IDS)

            for count, in cur:
                return count
            return 0

    def get_num_fingerprints(self):
        """
        Returns number of fingerprints the database has fingerprinted.
        """
        with self.cursor() as cur:
            cur.execute(self.SELECT_NUM_FINGERPRINTS)

            for count, in cur:
                return count
            return 0

    def set_song_fingerprinted(self, sid):
        """
        Set the fingerprinted flag to TRUE (1) once a song has been completely
        fingerprinted in the database.
        """
        with self.cursor() as cur:
            cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))

    def get_songs(self):
        """
        Return songs that have the fingerprinted flag set TRUE (1).

        This is a generator yielding one dictionary per song.
        """
        with self.cursor(cursor_type=DictCursor) as cur:
            cur.execute(self.SELECT_SONGS)
            for row in cur:
                yield row

    def get_song_by_id(self, sid):
        """
        Returns song by its ID.
        """
        with self.cursor(cursor_type=DictCursor) as cur:
            cur.execute(self.SELECT_SONG, (sid,))
            return cur.fetchone()

    def insert(self, hash, sid, offset):
        """
        Insert a (sha1, song_id, offset) row into database.
        """
        with self.cursor() as cur:
            cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset))

    def insert_song(self, songname):
        """
        Inserts song in the database and returns the ID of the inserted record.
        """
        with self.cursor() as cur:
            cur.execute(self.INSERT_SONG, (songname,))
            return cur.lastrowid

    def query(self, hash):
        """
        Return all tuples associated with hash.

        If hash is None, returns all entries in the
        database (be careful with that one!).

        This is a generator yielding (song id, offset) tuples.
        """
        # select all if no key; SELECT carries a placeholder that has to
        # be bound to the hash.
        if hash is None:
            query, params = self.SELECT_ALL, None
        else:
            query, params = self.SELECT, (hash,)

        with self.cursor() as cur:
            cur.execute(query, params)
            for sid, offset in cur:
                yield (sid, offset)

    def get_iterable_kv_pairs(self):
        """
        Returns all tuples in database.
        """
        return self.query(None)

    def insert_hashes(self, sid, hashes):
        """
        Insert series of hash => song_id, offset
        values into the database.

        sid: Song identifier the hashes belong to.
        hashes: An iterable of (hash, offset) tuples.
        """
        values = [(hash, sid, offset) for hash, offset in hashes]

        with self.cursor() as cur:
            # Insert in batches of 1000 rows.
            for split_values in grouper(values, 1000):
                cur.executemany(self.INSERT_FINGERPRINT, split_values)

    def return_matches(self, hashes):
        """
        Return the (song_id, offset_diff) tuples associated with
        a list of (sha1, sample_offset) values.

        This is a generator.
        """
        # Create a dictionary of hash => offset pairs for later lookups.
        # Keys are upper-cased before being compared with the HEX()
        # output of the query below.
        # NOTE: when the same hash occurs at several sample offsets only
        # the last offset is kept.
        mapper = {}
        for hash, offset in hashes:
            mapper[hash.upper()] = offset

        # Get an iterable of all the hashes we need
        values = mapper.keys()

        with self.cursor() as cur:
            for split_values in grouper(values, 1000):
                # Create our IN part of the query
                query = self.SELECT_MULTIPLE
                query = query % ', '.join(['UNHEX(%s)'] * len(split_values))

                cur.execute(query, split_values)

                for hash, sid, offset in cur:
                    # (sid, db_offset - song_sampled_offset)
                    yield (sid, offset - mapper[hash])

    def __getstate__(self):
        # Only the connection options are pickled.
        return (self._options,)

    def __setstate__(self, state):
        # Rebuild the cursor factory from the pickled options.
        self._options, = state
        self.cursor = cursor_factory(**self._options)
def grouper(iterable, n, fillvalue=None):
    """
    Yields the items of `iterable` as tuples of at most `n` items.

    iterable: Any iterable.
    n: Maximum chunk size; nothing is yielded when it is below 1.
    fillvalue: When truthy, the last chunk is padded with it up to `n`
               items. With the default the last chunk is simply shorter.

    Items are never dropped because of their truth value, so 0, "" or
    None in the data are preserved.
    """
    from itertools import islice

    if n < 1:
        return

    iterator = iter(iterable)
    while True:
        chunk = tuple(islice(iterator, n))
        if not chunk:
            return
        if fillvalue and len(chunk) < n:
            chunk += (fillvalue,) * (n - len(chunk))
        yield chunk
def cursor_factory(**factory_options):
    """
    Returns a callable that builds `Cursor` objects.

    The keyword arguments given here are merged into those of every call
    and win whenever both specify the same key.
    """
    def cursor(**options):
        merged = dict(options)
        merged.update(factory_options)
        return Cursor(**merged)

    return cursor
class Cursor(object):
    """
    Establishes a connection to the database and returns an open cursor.

    Connections are taken from, and handed back to, a small class-level
    cache.

    ```python
    # Use as context manager
    with Cursor() as cur:
        cur.execute(query)
    ```
    """
    _cache = Queue.Queue(maxsize=5)

    def __init__(self, cursor_type=mysql.cursors.Cursor, **options):
        """
        cursor_type: Cursor class passed to `conn.cursor`.
        options: Keyword arguments for `mysql.connect`; only used when
                 no cached connection is available.
        """
        super(Cursor, self).__init__()

        try:
            conn = self._cache.get_nowait()
        except Queue.Empty:
            conn = mysql.connect(**options)
        else:
            # Ping the connection before using it from the cache.
            conn.ping(True)

        self.conn = conn
        self.conn.autocommit(False)
        self.cursor_type = cursor_type

    @classmethod
    def clear_cache(cls):
        """Replaces the connection cache with a new, empty one."""
        cls._cache = Queue.Queue(maxsize=5)

    def __enter__(self):
        self.cursor = self.conn.cursor(self.cursor_type)
        return self.cursor

    def __exit__(self, extype, exvalue, traceback):
        # Roll back on any MySQL related error. `issubclass` also matches
        # subclasses of MySQLError, and rollback is a method of the
        # connection, not of the cursor.
        failed = extype is not None and issubclass(extype, mysql.MySQLError)
        if failed:
            self.conn.rollback()

        self.cursor.close()

        # Committing after a rollback would be pointless.
        if not failed:
            self.conn.commit()

        # Put it back on the queue
        try:
            self._cache.put_nowait(self.conn)
        except Queue.Full:
            self.conn.close()

48
dejavu/decoder.py Normal file
View file

@ -0,0 +1,48 @@
import os
import fnmatch
import numpy as np
from pydub import AudioSegment
def find_files(path, extensions):
    """
    Walks `path` recursively and yields a (filepath, extension) tuple for
    every file name matching "*.<extension>".

    Extensions may be given with or without dots (".mp3" or "mp3"); the
    yielded extension never contains one.
    """
    # Allow both with ".mp3" and without "mp3" to be used for extensions
    suffixes = [ext.replace(".", "") for ext in extensions]

    for root, _subdirs, names in os.walk(path):
        for suffix in suffixes:
            pattern = "*.%s" % suffix
            for name in fnmatch.filter(names, pattern):
                yield (os.path.join(root, name), suffix)
def read(filename, limit=None):
    """
    Decode an audio file through pydub and return its samples per channel.

    `limit`, when given, is a number of seconds: only that much audio from
    the start of the file is kept.

    The raw sample bytes are interpreted as 16-bit integers and
    de-interleaved into one array per channel.

    returns: (channels, samplerate)
    """
    segment = AudioSegment.from_file(filename)

    if limit:
        # The slice bound is limit * 1000, i.e. seconds scaled to ms.
        segment = segment[:limit * 1000]

    samples = np.fromstring(segment._data, np.int16)

    nchannels = segment.channels
    # Samples are interleaved: channel k owns every nchannels-th sample
    # starting at index k.
    channels = [samples[idx::nchannels] for idx in xrange(nchannels)]

    return channels, segment.frame_rate
def path_to_songname(path):
    """
    Extracts song name from a filepath. Used to identify which songs
    have already been fingerprinted on disk.

    Only the final extension is removed, so dots that are part of the name
    itself (e.g. "Artist_ft._Other.mp3") are preserved.
    """
    return os.path.splitext(os.path.basename(path))[0]

284
dejavu/fingerprint.py Normal file → Executable file
View file

@ -1,226 +1,118 @@
import numpy as np import numpy as np
import matplotlib.mlab as mlab import matplotlib.mlab as mlab
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy.io import wavfile
from scipy.ndimage.filters import maximum_filter from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion from scipy.ndimage.morphology import (generate_binary_structure,
from dejavu.database import SQLDatabase iterate_structure, binary_erosion)
import os
import wave
import sys
import time
import hashlib import hashlib
import pickle
class Fingerprinter():
IDX_FREQ_I = 0 IDX_FREQ_I = 0
IDX_TIME_J = 1 IDX_TIME_J = 1
DEFAULT_FS = 44100
DEFAULT_WINDOW_SIZE = 4096
DEFAULT_OVERLAP_RATIO = 0.5
DEFAULT_FAN_VALUE = 15
DEFAULT_AMP_MIN = 10
PEAK_NEIGHBORHOOD_SIZE = 20
MIN_HASH_TIME_DELTA = 0
def __init__(self, config, DEFAULT_FS = 44100
Fs=DEFAULT_FS, DEFAULT_WINDOW_SIZE = 4096
wsize=DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO = 0.5
wratio=DEFAULT_OVERLAP_RATIO, DEFAULT_FAN_VALUE = 15
fan_value=DEFAULT_FAN_VALUE,
amp_min=DEFAULT_AMP_MIN):
self.config = config DEFAULT_AMP_MIN = 10
database = SQLDatabase( PEAK_NEIGHBORHOOD_SIZE = 20
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME), MIN_HASH_TIME_DELTA = 0
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD),
self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE))
self.db = database
self.Fs = Fs
self.dt = 1.0 / self.Fs
self.window_size = wsize
self.window_overlap_ratio = wratio
self.fan_value = fan_value
self.noverlap = int(self.window_size * self.window_overlap_ratio)
self.amp_min = amp_min
def fingerprint(self, samples, path, sid, cid): def fingerprint(channel_samples, Fs=DEFAULT_FS,
"""Used for learning known songs""" wsize=DEFAULT_WINDOW_SIZE,
hashes = self.process_channel(samples, song_id=sid) wratio=DEFAULT_OVERLAP_RATIO,
print "Generated %d hashes" % len(hashes) fan_value=DEFAULT_FAN_VALUE,
self.db.insert_hashes(hashes) amp_min=DEFAULT_AMP_MIN):
"""
FFT the channel, log transform output, find local maxima, then return
locally sensitive hashes.
"""
# FFT the signal and extract frequency components
arr2D = mlab.specgram(
channel_samples,
NFFT=wsize,
Fs=Fs,
window=mlab.window_hanning,
noverlap=int(wsize * wratio))[0]
def match(self, samples): # apply log transform since specgram() returns linear array
"""Used for matching unknown songs""" arr2D = 10 * np.log10(arr2D)
hashes = self.process_channel(samples) arr2D[arr2D == -np.inf] = 0 # replace infs with zeros
matches = self.db.return_matches(hashes)
return matches
def process_channel(self, channel_samples, song_id=None): # find local maxima
""" local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
FFT the channel, log transform output, find local maxima, then return
locally sensitive hashes.
"""
# FFT the signal and extract frequency components
arr2D = mlab.specgram(
channel_samples,
NFFT=self.window_size,
Fs=self.Fs,
window=mlab.window_hanning,
noverlap=self.noverlap)[0]
# apply log transform since specgram() returns linear array # return hashes
arr2D = 10 * np.log10(arr2D) return generate_hashes(local_maxima, fan_value=fan_value)
arr2D[arr2D == -np.inf] = 0 # replace infs with zeros
# find local maxima
local_maxima = self.get_2D_peaks(arr2D, plot=False)
# return hashes
return self.generate_hashes(local_maxima, song_id=song_id)
def get_2D_peaks(self, arr2D, plot=False): def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure
struct = generate_binary_structure(2, 1)
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure # find local maxima using our fliter shape
struct = generate_binary_structure(2, 1) local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
neighborhood = iterate_structure(struct, Fingerprinter.PEAK_NEIGHBORHOOD_SIZE) background = (arr2D == 0)
eroded_background = binary_erosion(background, structure=neighborhood,
border_value=1)
# find local maxima using our fliter shape # Boolean mask of arr2D with True at peaks
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D detected_peaks = local_max - eroded_background
background = (arr2D == 0)
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks
# extract peaks # extract peaks
amps = arr2D[detected_peaks] amps = arr2D[detected_peaks]
j, i = np.where(detected_peaks) j, i = np.where(detected_peaks)
# filter peaks # filter peaks
amps = amps.flatten() amps = amps.flatten()
peaks = zip(i, j, amps) peaks = zip(i, j, amps)
peaks_filtered = [x for x in peaks if x[2] > self.amp_min] # freq, time, amp peaks_filtered = [x for x in peaks if x[2] > amp_min] # freq, time, amp
# get indices for frequency and time
frequency_idx = [x[1] for x in peaks_filtered]
time_idx = [x[0] for x in peaks_filtered]
if plot: # get indices for frequency and time
# scatter of the peaks frequency_idx = [x[1] for x in peaks_filtered]
fig, ax = plt.subplots() time_idx = [x[0] for x in peaks_filtered]
ax.imshow(arr2D)
ax.scatter(time_idx, frequency_idx)
ax.set_xlabel('Time')
ax.set_ylabel('Frequency')
ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke");
plt.gca().invert_yaxis()
plt.show()
return zip(frequency_idx, time_idx) if plot:
# scatter of the peaks
fig, ax = plt.subplots()
ax.imshow(arr2D)
ax.scatter(time_idx, frequency_idx)
ax.set_xlabel('Time')
ax.set_ylabel('Frequency')
ax.set_title("Spectrogram")
plt.gca().invert_yaxis()
plt.show()
def generate_hashes(self, peaks, song_id=None): return zip(frequency_idx, time_idx)
"""
Hash list structure:
sha1-hash[0:20] song_id, time_offset
[(e05b341a9b77a51fd26, (3, 32)), ... ]
"""
fingerprinted = set() # to avoid rehashing same pairs
hashes = []
for i in range(len(peaks)):
for j in range(self.fan_value):
if i+j < len(peaks) and not (i, i+j) in fingerprinted:
freq1 = peaks[i][Fingerprinter.IDX_FREQ_I] def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
freq2 = peaks[i+j][Fingerprinter.IDX_FREQ_I] """
t1 = peaks[i][Fingerprinter.IDX_TIME_J] Hash list structure:
t2 = peaks[i+j][Fingerprinter.IDX_TIME_J] sha1_hash[0:20] time_offset
t_delta = t2 - t1 [(e05b341a9b77a51fd26, 32), ... ]
"""
if t_delta >= Fingerprinter.MIN_HASH_TIME_DELTA: fingerprinted = set() # to avoid rehashing same pairs
h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
hashes.append((h.hexdigest()[0:20], (song_id, t1)))
# ensure we don't repeat hashing
fingerprinted.add((i, i+j))
return hashes
def insert_into_db(self, key, value): for i in range(len(peaks)):
self.db.insert(key, value) for j in range(fan_value):
if (i + j) < len(peaks) and not (i, i + j) in fingerprinted:
freq1 = peaks[i][IDX_FREQ_I]
freq2 = peaks[i + j][IDX_FREQ_I]
def print_stats(self): t1 = peaks[i][IDX_TIME_J]
t2 = peaks[i + j][IDX_TIME_J]
iterable = self.db.get_iterable_kv_pairs() t_delta = t2 - t1
counter = {} if t_delta >= MIN_HASH_TIME_DELTA:
for t in iterable: h = hashlib.sha1(
sid, toff = t "%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))
if not sid in counter: )
counter[sid] = 1 yield (h.hexdigest()[0:20], t1)
else:
counter[sid] += 1
for song_id, count in counter.iteritems(): # ensure we don't repeat hashing
song_name = self.song_names[song_id] fingerprinted.add((i, i + j))
print "%s has %d spectrogram peaks" % (song_name, count)
def set_song_names(self, wpaths):
    """
    Store `wpaths` as this object's `song_names` attribute.

    NOTE(review): `song_names` looks like a lookup indexed by song id --
    confirm against the callers.
    """
    names = wpaths
    self.song_names = names
def align_matches(self, matches, starttime, record_seconds=0, verbose=False):
"""
Finds hash matches that align in time with other matches and finds
consensus about which hashes are "true" signal from the audio.
Returns a dictionary with match information.
"""
# align by diffs
diff_counter = {}
largest = 0
largest_count = 0
song_id = -1
for tup in matches:
sid, diff = tup
if not diff in diff_counter:
diff_counter[diff] = {}
if not sid in diff_counter[diff]:
diff_counter[diff][sid] = 0
diff_counter[diff][sid] += 1
if diff_counter[diff][sid] > largest_count:
largest = diff
largest_count = diff_counter[diff][sid]
song_id = sid
if verbose:
print "Diff is %d with %d offset-aligned matches" % (largest, largest_count)
# extract idenfication
song = self.db.get_song_by_id(song_id)
if song:
songname = song.get(SQLDatabase.FIELD_SONGNAME, None)
else:
return None
songname = songname.replace("_", " ")
elapsed = time.time() - starttime
if verbose:
print "Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed)
# return match info
song = {
"song_id" : song_id,
"song_name" : songname,
"match_time" : elapsed,
"confidence" : largest_count
}
if record_seconds:
song['record_time'] = record_seconds
return song

166
dejavu/recognize.py Normal file → Executable file
View file

@ -1,72 +1,112 @@
from multiprocessing import Queue, Process import dejavu.fingerprint as fingerprint
from dejavu.database import SQLDatabase import dejavu.decoder as decoder
from scipy.io import wavfile
import wave
import numpy as np import numpy as np
import pyaudio import pyaudio
import sys
import time import time
import array
class Recognizer(object):
CHUNK = 8192 # 44100 is a multiple of 1225 class BaseRecognizer(object):
FORMAT = pyaudio.paInt16
CHANNELS = 2
RATE = 44100
def __init__(self, fingerprinter, config): def __init__(self, dejavu):
self.dejavu = dejavu
self.Fs = fingerprint.DEFAULT_FS
self.fingerprinter = fingerprinter def _recognize(self, *data):
self.config = config matches = []
for d in data:
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
return self.dejavu.align_matches(matches)
def recognize(self):
pass # base class does nothing
class FileRecognizer(BaseRecognizer):
def __init__(self, dejavu):
super(FileRecognizer, self).__init__(dejavu)
def recognize_file(self, filename):
frames, self.Fs = decoder.read(filename)
t = time.time()
match = self._recognize(*frames)
t = time.time() - t
if match:
match['match_time'] = t
return match
def recognize(self, filename):
return self.recognize_file(filename)
class MicrophoneRecognizer(BaseRecognizer):
default_chunksize = 8192
default_format = pyaudio.paInt16
default_channels = 2
default_samplerate = 44100
def __init__(self, dejavu):
super(MicrophoneRecognizer, self).__init__(dejavu)
self.audio = pyaudio.PyAudio() self.audio = pyaudio.PyAudio()
self.stream = None
def read(self, filename, verbose=False): self.data = []
self.channels = MicrophoneRecognizer.default_channels
# read file into channels self.chunksize = MicrophoneRecognizer.default_chunksize
channels = [] self.samplerate = MicrophoneRecognizer.default_samplerate
Fs, frames = wavfile.read(filename) self.recorded = False
wave_object = wave.open(filename)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
for channel in range(nchannels):
channels.append(frames[:, channel])
# get matches
starttime = time.time()
matches = []
for channel in channels:
matches.extend(self.fingerprinter.match(channel))
return self.fingerprinter.align_matches(matches, starttime, verbose=verbose)
def listen(self, seconds=10, verbose=False):
# open stream def start_recording(self, channels=default_channels,
stream = self.audio.open(format=Recognizer.FORMAT, samplerate=default_samplerate,
channels=Recognizer.CHANNELS, chunksize=default_chunksize):
rate=Recognizer.RATE, self.chunksize = chunksize
input=True, self.channels = channels
frames_per_buffer=Recognizer.CHUNK) self.recorded = False
self.samplerate = samplerate
# record
if verbose: print("* recording") if self.stream:
left, right = [], [] self.stream.stop_stream()
for i in range(0, int(Recognizer.RATE / Recognizer.CHUNK * seconds)): self.stream.close()
data = stream.read(Recognizer.CHUNK)
nums = np.fromstring(data, np.int16) self.stream = self.audio.open(
left.extend(nums[1::2]) format=self.default_format,
right.extend(nums[0::2]) channels=channels,
if verbose: print("* done recording") rate=samplerate,
input=True,
# close and stop the stream frames_per_buffer=chunksize,
stream.stop_stream() )
stream.close()
self.data = [[] for i in range(channels)]
# match both channels
starttime = time.time() def process_recording(self):
matches = [] data = self.stream.read(self.chunksize)
matches.extend(self.fingerprinter.match(left)) nums = np.fromstring(data, np.int16)
matches.extend(self.fingerprinter.match(right)) for c in range(self.channels):
self.data[c].extend(nums[c::self.channels])
# align and return
return self.fingerprinter.align_matches(matches, starttime, record_seconds=seconds, verbose=verbose) def stop_recording(self):
self.stream.stop_stream()
self.stream.close()
self.stream = None
self.recorded = True
def recognize_recording(self):
    """
    Match the recorded audio, passing each channel to `_recognize`.

    Raises NoRecordingError when `self.recorded` is not set.
    """
    if self.recorded:
        return self._recognize(*self.data)
    raise NoRecordingError("Recording was not complete/begun")
def get_recorded_time(self):
    """
    Return the length of the recording in seconds.

    Computed as the number of samples collected for the first channel
    divided by the sample rate. With integer operands under Python 2 this
    is a whole number of seconds.
    """
    # The sample rate is stored as `samplerate`; no `rate` attribute is
    # ever assigned on this object.
    return len(self.data[0]) / self.samplerate
def recognize(self, seconds=10):
    """
    Record for about `seconds` seconds, then match the recording.

    Starts a recording with the default settings, reads as many whole
    chunks as fit into `seconds`, stops the recording and returns the
    result of `recognize_recording()`.
    """
    self.start_recording()

    # Multiply before dividing: with integer operands under Python 2,
    # samplerate / chunksize would be truncated first (44100 / 8192 -> 5)
    # and the recording would come out noticeably short.
    nchunks = int(self.samplerate * seconds / self.chunksize)
    for _ in range(nchunks):
        self.process_recording()

    self.stop_recording()
    return self.recognize_recording()
class NoRecordingError(Exception):
    """Raised when recognition is requested before a recording is complete."""

34
go.py Normal file → Executable file
View file

@ -1,20 +1,26 @@
from dejavu.control import Dejavu from dejavu import Dejavu
from ConfigParser import ConfigParser
import warnings import warnings
import json
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
# load config # load config from a JSON file (or anything outputting a python dictionary)
config = ConfigParser() with open("dejavu.cnf") as f:
config.read("dejavu.cnf") config = json.load(f)
# create Dejavu object # create a Dejavu instance
dejavu = Dejavu(config) djv = Dejavu(config)
dejavu.fingerprint("va_us_top_40/mp3", "va_us_top_40/wav", [".mp3"], 5) # Fingerprint all the mp3's in the directory we give it
djv.fingerprint_directory("va_us_top_40/mp3", [".mp3"])
# recognize microphone audio # Recognize audio from a file
from dejavu.recognize import Recognizer from dejavu.recognize import FileRecognizer
recognizer = Recognizer(dejavu.fingerprinter, config) song = djv.recognize(FileRecognizer, "mp3/beware.mp3")
# recognize song playing over microphone for 10 seconds # Or recognize audio from your microphone for 10 seconds
song = recognizer.listen(seconds=5, verbose=True) from dejavu.recognize import MicrophoneRecognizer
print song song = djv.recognize(MicrophoneRecognizer, seconds=2)
# Or use a recognizer without the shortcut, in any way you would like
from dejavu.recognize import FileRecognizer
recognizer = FileRecognizer(djv)
song = recognizer.recognize_file("va_us_top_40/wav/17_-_#Beautiful_-_Mariah_Carey_ft.wav")

View file

@ -1,159 +0,0 @@
from dejavu.control import Dejavu
from dejavu.recognize import Recognizer
from dejavu.convert import Converter
from dejavu.database import SQLDatabase
from ConfigParser import ConfigParser
from scipy.io import wavfile
import matplotlib.pyplot as plt
import warnings
import pyaudio
import os, wave, sys
import random
import numpy as np
# Silence every warning for the duration of this script.
warnings.filterwarnings("ignore")
# Read the settings from dejavu.cnf in the current working directory.
config = ConfigParser()
config.read("dejavu.cnf")
# Module-level objects shared by the helpers below; the recognizer is built
# from the Dejavu instance's fingerprinter and the same config.
dejavu = Dejavu(config)
recognizer = Recognizer(dejavu.fingerprinter, config)
def test_recording_lengths(recognizer):
# Interactive accuracy harness. For each recording length and each WAV file:
# cut a random segment out of the file, write it to OUTPUT_FILE, wait for the
# operator to press ENTER, let `recognizer` listen, and compare the predicted
# song name with the name derived from the file name.
# Returns `score`, a dict mapping recording length -> (correct, total).
# NOTE(review): indentation is lost in this listing, so the loop nesting
# (e.g. whether `score[r]` is updated per file or per length) must be
# confirmed against the original file.
# settings for run
RATE = 44100
FORMAT = pyaudio.paInt16
padding_seconds = 10
# number of frames to stay away from at both ends of a song
SONG_PADDING = RATE * padding_seconds
OUTPUT_FILE = "output.wav"
p = pyaudio.PyAudio()
c = Converter()
# only the last 25 files found are used
files = c.find_files("va_us_top_40/wav/", [".wav"])[-25:]
total = len(files)
recording_lengths = [4]
# NOTE(review): correct/count are initialised once here and are not reset
# inside the loop over recording_lengths, so with more than one length the
# stored scores would accumulate across lengths -- confirm intent.
correct = 0
count = 0
score = {}
for r in recording_lengths:
# segment length in frames for this recording length (seconds)
RECORD_LENGTH = RATE * r
for tup in files:
f, ext = tup
# read the file
#print "reading: %s" % f
Fs, frames = wavfile.read(f)
# NOTE(review): wave_object is only used for getparams() and is not closed
# anywhere in this function.
wave_object = wave.open(f)
nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams()
# choose at random a segment of audio to play
possible_end = num_frames - SONG_PADDING - RECORD_LENGTH
possible_start = SONG_PADDING
if possible_end - possible_start < RECORD_LENGTH:
print "ERROR! Song is too short to sample based on padding and recording seconds preferences."
# aborts the whole process, not just this file
sys.exit()
start = random.randint(possible_start, possible_end)
end = start + RECORD_LENGTH + 1
# get that segment of samples
# NOTE(review): `channels` is assigned but not used in this function.
channels = []
# the two-axis slice assumes frames is 2-D (frame, channel) -- TODO confirm for mono files
frames = frames[start:end, :]
wav_string = frames.tostring()
# write to disk
wf = wave.open(OUTPUT_FILE, 'wb')
wf.setnchannels(nchannels)
wf.setsampwidth(p.get_sample_size(FORMAT))
wf.setframerate(RATE)
wf.writeframes(b''.join(wav_string))
wf.close()
# play and test
# expected name: file name without ".wav", underscores turned into spaces
correctname = os.path.basename(f).replace(".wav", "").replace("_", " ")
# blocks until the operator presses ENTER; the typed text itself is unused
inp = raw_input("Click ENTER when playing %s ..." % OUTPUT_FILE)
# listen one second longer than the segment that was written
song = recognizer.listen(seconds=r+1, verbose=False)
# NOTE(review): `song` is indexed without a None check -- confirm that
# listen() always returns a dict.
print "PREDICTED: %s" % song['song_name']
print "ACTUAL: %s" % correctname
if song['song_name'] == correctname:
correct += 1
count += 1
print "Currently %d correct out of %d in total of %d" % (correct, count, total)
score[r] = (correct, total)
print "UPDATE AFTER %d TRIAL: %s" % (r, score)
return score
def plot_match_time_trials():
# I did this manually
t = np.array([1, 2, 3, 4, 5, 6, 7, 8, 10, 15, 25, 30, 45, 60])
m = np.array([.47, .79, 1.1, 1.5, 1.8, 2.18, 2.62, 2.8, 3.65, 5.29, 8.92, 10.63, 16.09, 22.29])
mplust = t + m
# linear regression
A = np.matrix([t, np.ones(len(t))])
print A
w = np.linalg.lstsq(A.T, mplust)[0]
line = w[0] * t + w[1]
print "Equation for line is %f * record_time + %f = time_to_match" % (w[0], w[1])
# and plot
plt.title("Recording vs Matching time for \"Get Lucky\" by Daft Punk")
plt.xlabel("Time recorded (s)")
plt.ylabel("Time recorded + time to match (s)")
#plt.scatter(t, mplust)
plt.plot(t, line, 'r-', t, mplust, 'o')
plt.show()
def plot_accuracy():
    """
    Plot recognition accuracy against the number of seconds recorded.

    The counts are hard-coded: correct identifications out of 45 trials
    for each recording length.
    """
    # Hand-collected results.
    seconds = np.array([1, 2, 3, 4, 5, 6])
    hits = np.array([27.0, 43.0, 44.0, 44.0, 45.0, 45.0])
    trials = 45.0
    accuracy = hits / trials

    plt.title("Dejavu Recognition Accuracy as a Function of Time")
    plt.xlabel("Time recorded (s)")
    plt.ylabel("Accuracy")
    plt.plot(seconds, accuracy)
    # Leave a little headroom above 100%.
    plt.ylim([0.0, 1.05])
    plt.show()
def plot_hashes_per_song():
    """
    Show a horizontal bar chart of the number of fingerprints per song.

    Queries the fingerprints/songs tables, ordered by ascending fingerprint
    count, and labels each bar with the song name.
    """
    squery = """select song_name, count(song_id) as num
        from fingerprints
        natural join songs
        group by song_name
        order by count(song_id) asc;"""

    # NOTE(review): credentials are hard-coded here; consider reading them
    # from the same dejavu.cnf the rest of the script uses.
    sql = SQLDatabase(username="root", password="root", database="dejavu", hostname="localhost")
    cursor = sql.connection.cursor()
    try:
        cursor.execute(squery)
        counts = cursor.fetchall()
    finally:
        # Release the cursor even if the query fails.
        cursor.close()

    # Rows are accessed by column name -- assumes a dict-style cursor.
    # Underscores become spaces and the first 4 characters of each name are
    # dropped (presumably a track-number prefix -- TODO confirm).
    songs = [item['song_name'].replace("_", " ")[4:] for item in counts]
    count = [item['num'] for item in counts]

    # Centre each bar half a unit above its index.
    pos = np.arange(len(songs)) + 0.5

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.barh(pos, count, align='center')
    # Tick positions and tick labels are set separately: the second
    # positional argument of set_yticks() has not always meant "labels".
    ax.set_yticks(pos)
    ax.set_yticklabels(songs)
    ax.axvline(0, color='k', lw=3)

    ax.set_xlabel('Number of Fingerprints')
    ax.set_title('Number of Fingerprints by Song')
    ax.grid(True)
    plt.show()
#plot_accuracy()
#score = test_recording_lengths(recognizer)
#plot_match_time_trials()
#plot_hashes_per_song()