diff --git a/README.md b/README.md index 76699f3..51e5e0a 100644 Binary files a/README.md and b/README.md differ diff --git a/dejavu/__init__.py b/dejavu/__init__.py old mode 100644 new mode 100755 index e69de29..bf47743 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -0,0 +1,171 @@ +from dejavu.database import get_database +import dejavu.decoder as decoder +import fingerprint +import multiprocessing +import os + + +class Dejavu(object): + def __init__(self, config): + super(Dejavu, self).__init__() + + self.config = config + + # initialize db + db_cls = get_database(config.get("database_type", None)) + + self.db = db_cls(**config.get("database", {})) + self.db.setup() + + # get songs previously indexed + # TODO: should probably use a checksum of the file instead of filename + self.songs = self.db.get_songs() + self.songnames_set = set() # to know which ones we've computed before + for song in self.songs: + song_name = song[self.db.FIELD_SONGNAME] + self.songnames_set.add(song_name) + print "Added: %s to the set of fingerprinted songs..." % song_name + + def fingerprint_directory(self, path, extensions, nprocesses=None): + # Try to use the maximum amount of processes if not given. + try: + nprocesses = nprocesses or multiprocessing.cpu_count() + except NotImplementedError: + nprocesses = 1 + else: + nprocesses = 1 if nprocesses <= 0 else nprocesses + + pool = multiprocessing.Pool(nprocesses) + + results = [] + for filename, _ in decoder.find_files(path, extensions): + + # don't refingerprint already fingerprinted files + if decoder.path_to_songname(filename) in self.songnames_set: + print "%s already fingerprinted, continuing..." % filename + continue + + result = pool.apply_async(_fingerprint_worker, + (filename, self.db)) + results.append(result) + + while len(results): + for result in results[:]: + # TODO: Handle errors gracefully and return them to the callee + # in some way. + try: + result.get(timeout=2) + except multiprocessing.TimeoutError: + continue + except: + import traceback, sys + traceback.print_exc(file=sys.stdout) + results.remove(result) + else: + results.remove(result) + + pool.close() + pool.join() + + def fingerprint_file(self, filepath, song_name=None): + channels, Fs = decoder.read(filepath) + + if not song_name: + print "Song name: %s" % song_name + song_name = decoder.path_to_songname(filepath) + song_id = self.db.insert_song(song_name) + + for data in channels: + hashes = fingerprint.fingerprint(data, Fs=Fs) + self.db.insert_hashes(song_id, hashes) + + def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS): + hashes = fingerprint.fingerprint(samples, Fs=Fs) + return self.db.return_matches(hashes) + + def align_matches(self, matches): + """ + Finds hash matches that align in time with other matches and finds + consensus about which hashes are "true" signal from the audio. + + Returns a dictionary with match information. + """ + # align by diffs + diff_counter = {} + largest = 0 + largest_count = 0 + song_id = -1 + for tup in matches: + sid, diff = tup + if not diff in diff_counter: + diff_counter[diff] = {} + if not sid in diff_counter[diff]: + diff_counter[diff][sid] = 0 + diff_counter[diff][sid] += 1 + + if diff_counter[diff][sid] > largest_count: + largest = diff + largest_count = diff_counter[diff][sid] + song_id = sid + + print("Diff is %d with %d offset-aligned matches" % (largest, + largest_count)) + + # extract idenfication + song = self.db.get_song_by_id(song_id) + if song: + # TODO: Clarifey what `get_song_by_id` should return. + songname = song.get("song_name", None) + else: + return None + + # return match info + song = { + "song_id": song_id, + "song_name": songname, + "confidence": largest_count, + "offset" : largest + } + + return song + + def recognize(self, recognizer, *options, **kwoptions): + r = recognizer(self) + return r.recognize(*options, **kwoptions) + + +def _fingerprint_worker(filename, db): + song_name, extension = os.path.splitext(os.path.basename(filename)) + + channels, Fs = decoder.read(filename) + + # insert song into database + sid = db.insert_song(song_name) + + channel_amount = len(channels) + for channeln, channel in enumerate(channels): + # TODO: Remove prints or change them into optional logging. + print("Fingerprinting channel %d/%d for %s" % (channeln + 1, + channel_amount, + filename)) + hashes = fingerprint.fingerprint(channel, Fs=Fs) + print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount, + filename)) + + print("Inserting fingerprints for channel %d/%d for %s" % + (channeln + 1, channel_amount, filename)) + db.insert_hashes(sid, hashes) + print("Finished inserting for channel %d/%d for %s" % + (channeln + 1, channel_amount, filename)) + + print("Marking %s finished" % (filename,)) + db.set_song_fingerprinted(sid) + print("%s finished" % (filename,)) + + +def chunkify(lst, n): + """ + Splits a list into roughly n equal parts. + http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts + """ + return [lst[i::n] for i in xrange(n)] diff --git a/dejavu/control.py b/dejavu/control.py deleted file mode 100644 index 85144be..0000000 --- a/dejavu/control.py +++ /dev/null @@ -1,120 +0,0 @@ -from dejavu.database import SQLDatabase -from dejavu.convert import Converter -from dejavu.fingerprint import Fingerprinter -from scipy.io import wavfile -from multiprocessing import Process -import wave, os -import random - -class Dejavu(): - - def __init__(self, config): - - self.config = config - - # create components - self.converter = Converter(config) - self.fingerprinter = Fingerprinter(self.config) - self.fingerprinter.db.setup() - - # get songs previously indexed - self.songs = self.fingerprinter.db.get_songs() - self.songnames_set = set() # to know which ones we've computed before - if self.songs: - for song in self.songs: - song_id = song[SQLDatabase.FIELD_SONG_ID] - song_name = song[SQLDatabase.FIELD_SONGNAME] - self.songnames_set.add(song_name) - print "Added: %s to the set of fingerprinted songs..." % song_name - - def chunkify(self, lst, n): - """ - Splits a list into roughly n equal parts. - http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts - """ - return [lst[i::n] for i in xrange(n)] - - def fingerprint(self, path, output, extensions, nprocesses, keep_wav=False): - - # convert files, shuffle order - files = self.converter.find_files(path, extensions) - random.shuffle(files) - files_split = self.chunkify(files, nprocesses) - - # split into processes here - processes = [] - for i in range(nprocesses): - - # need database instance since mysql connections shouldn't be shared across processes - sql_connection = SQLDatabase( - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE)) - - # create process and start it - p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output, keep_wav)) - p.start() - processes.append(p) - - # wait for all processes to complete - try: - for p in processes: - p.join() - except KeyboardInterrupt: - print "-> Exiting.." - for worker in processes: - worker.terminate() - worker.join() - - # delete orphans - # print "Done fingerprinting. Deleting orphaned fingerprints..." - # TODO: need a more performant query in database.py for the - #self.fingerprinter.db.delete_orphans() - - def fingerprint_worker(self, files, sql_connection, output, keep_wav): - - for filename, extension in files: - # if there are already fingerprints in database, don't re-fingerprint or convert - if filename in self.songnames_set: - print "-> Already fingerprinted, continuing..." - continue - - # convert to WAV - wavout_path = self.converter.convert(filename, extension, Converter.WAV, output) - - # for each channel perform FFT analysis and fingerprinting - try: - channels = self.extract_channels(wavout_path) - except AssertionError, e: - print "-> File not supported, skipping." - continue - - # insert song name into database - song_id = sql_connection.insert_song(filename) - - for c in range(len(channels)): - channel = channels[c] - print "-> Fingerprinting channel %d of song %s..." % (c+1, filename) - self.fingerprinter.fingerprint(channel, wavout_path, song_id, c+1) - - # remove wav file if not required - if not keep_wav: - os.unlink(wavout_path) - - # only after done fingerprinting do confirm - sql_connection.set_song_fingerprinted(song_id) - - def extract_channels(self, path): - """ - Reads channels from disk. - """ - channels = [] - Fs, frames = wavfile.read(path) - wave_object = wave.open(path) - nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams() - assert Fs == self.fingerprinter.Fs - - for channel in range(nchannels): - channels.append(frames[:, channel]) - return channels diff --git a/dejavu/convert.py b/dejavu/convert.py deleted file mode 100644 index f1e7b2c..0000000 --- a/dejavu/convert.py +++ /dev/null @@ -1,58 +0,0 @@ -import os, fnmatch -from pydub import AudioSegment - -class Converter(): - - WAV = "wav" - MP3 = "mp3" - FORMATS = [ - WAV, - MP3] - - def __init__(self, config): - self.config = config - if self.config.has_section("input") and self.config.has_option("input", "length"): - self.max_input_len = self.config.getint("input", "length") - else: - self.max_input_len = None - - def ensure_folder(self, extension): - if not os.path.exists(extension): - os.makedirs(extension) - - def find_files(self, path, extensions): - filepaths = [] - extensions = [e.replace(".", "") for e in extensions if e.replace(".", "") in Converter.FORMATS] - print "Supported formats: %s" % extensions - for dirpath, dirnames, files in os.walk(path): - for extension in extensions: - for f in fnmatch.filter(files, "*.%s" % extension): - p = os.path.join(dirpath, f) - #print "Found file: %s with extension %s" % (renamed, extension) - filepaths.append((p, extension)) - return filepaths - - def convert(self, orig_path, from_format, to_format, output_folder): - path, song_name = os.path.split(orig_path) - # start conversion - self.ensure_folder(output_folder) - print "-> Now converting: %s from %s format to %s format..." % (song_name, from_format, to_format) - # MP3 --> WAV - if from_format == Converter.MP3 and to_format == Converter.WAV: - - newpath = os.path.join(output_folder, "%s.%s" % (song_name, Converter.WAV)) - if os.path.isfile(newpath): - print "-> Already converted, skipping..." - else: - mp3file = AudioSegment.from_mp3(orig_path) - if self.max_input_len: - print "-> Reading input seconds: ", self.max_input_len - mp3file = mp3file[:self.max_input_len * 1000] - mp3file.export(newpath, format=Converter.WAV) - - # unsupported - else: - print "CONVERSION ERROR:\nThe conversion from %s to %s is not supported!" % (from_format, to_format) - - print "-> Conversion complete." - return newpath diff --git a/dejavu/database.py b/dejavu/database.py old mode 100644 new mode 100755 index f03e33e..5903541 --- a/dejavu/database.py +++ b/dejavu/database.py @@ -1,325 +1,170 @@ -import MySQLdb as mysql -import MySQLdb.cursors as cursors -import os +from __future__ import absolute_import +import abc -class SQLDatabase(): - """ - Queries: - 1) Find duplicates (shouldn't be any, though): +class Database(object): + __metaclass__ = abc.ABCMeta - select `hash`, `song_id`, `offset`, count(*) cnt - from fingerprints - group by `hash`, `song_id`, `offset` - having cnt > 1 - order by cnt asc; + # Name of your Database subclass, this is used in configuration + # to refer to your class + type = None - 2) Get number of hashes by song: + def __init__(self): + super(Database, self).__init__() - select song_id, song_name, count(song_id) as num - from fingerprints - natural join songs - group by song_id - order by count(song_id) desc; + def before_fork(self): + """ + Called before the database instance is given to the new process + """ + pass - 3) get hashes with highest number of collisions + def after_fork(self): + """ + Called after the database instance has been given to the new process - select - hash, - count(distinct song_id) as n - from fingerprints - group by `hash` - order by n DESC; - - => 26 different songs with same fingerprint (392 times): - - select songs.song_name, fingerprints.offset - from fingerprints natural join songs - where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73"; - """ - - # config keys - CONNECTION = "connection" - KEY_USERNAME = "username" - KEY_DATABASE = "database" - KEY_PASSWORD = "password" - KEY_HOSTNAME = "hostname" - - # tables - FINGERPRINTS_TABLENAME = "fingerprints" - SONGS_TABLENAME = "songs" - - # fields - FIELD_HASH = "hash" - FIELD_SONG_ID = "song_id" - FIELD_OFFSET = "offset" - FIELD_SONGNAME = "song_name" - FIELD_FINGERPRINTED = "fingerprinted" - - # creates - CREATE_FINGERPRINTS_TABLE = """ - CREATE TABLE IF NOT EXISTS `%s` ( - `%s` binary(10) not null, - `%s` mediumint unsigned not null, - `%s` int unsigned not null, - INDEX(%s), - UNIQUE(%s, %s, %s) - );""" % (FINGERPRINTS_TABLENAME, FIELD_HASH, - FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH, - FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH) - - CREATE_SONGS_TABLE = """ - CREATE TABLE IF NOT EXISTS `%s` ( - `%s` mediumint unsigned not null auto_increment, - `%s` varchar(250) not null, - `%s` tinyint default 0, - PRIMARY KEY (`%s`), - UNIQUE KEY `%s` (`%s`) - );""" % (SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED, - FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID) - - # inserts - INSERT_FINGERPRINT = "INSERT IGNORE INTO %s (%s, %s, %s) VALUES (UNHEX(%%s), %%s, %%s)" % ( - FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET) # ignore duplicates and don't insert them - INSERT_SONG = "INSERT INTO %s (%s) VALUES (%%s);" % ( - SONGS_TABLENAME, FIELD_SONGNAME) - - # selects - SELECT = "SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH) - SELECT_ALL = "SELECT %s, %s FROM %s;" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME) - SELECT_SONG = "SELECT %s FROM %s WHERE %s = %%s" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID) - SELECT_NUM_FINGERPRINTS = "SELECT COUNT(*) as n FROM %s" % (FINGERPRINTS_TABLENAME) - - SELECT_UNIQUE_SONG_IDS = "SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED) - SELECT_SONGS = "SELECT %s, %s FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED) - - # drops - DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME - DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME - - # update - UPDATE_SONG_FINGERPRINTED = "UPDATE %s SET %s = 1 WHERE %s = %%s" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID) - - # delete - DELETE_UNFINGERPRINTED = "DELETE FROM %s WHERE %s = 0;" % (SONGS_TABLENAME, FIELD_FINGERPRINTED) - DELETE_ORPHANS = """ - delete from fingerprints - where not exists ( - select * from songs where fingerprints.song_id = songs.song_id - )""" - - def __init__(self, hostname, username, password, database): - # connect - self.database = database - try: - # http://www.halfcooked.com/mt/archives/000969.html - self.connection = mysql.connect( - hostname, username, password, - database, cursorclass=cursors.DictCursor) - - self.connection.autocommit(False) # for fast bulk inserts - self.cursor = self.connection.cursor() - - except mysql.Error, e: - print "Connection error %d: %s" % (e.args[0], e.args[1]) + This will be called in the new process. + """ + pass def setup(self): - try: - # create fingerprints table - self.cursor.execute("USE %s;" % self.database) - self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE) - self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE) - self.delete_unfingerprinted_songs() - self.connection.commit() - except mysql.Error, e: - print "Connection error %d: %s" % (e.args[0], e.args[1]) - self.connection.rollback() + """ + Called on creation or shortly afterwards. + """ + pass + @abc.abstractmethod def empty(self): """ - Drops all tables and re-adds them. Be carfeul with this! + Called when the database should be cleared of all data. """ - try: - self.cursor.execute("USE %s;" % self.database) + pass - # drop tables - self.cursor.execute(SQLDatabase.DROP_FINGERPRINTS) - self.cursor.execute(SQLDatabase.DROP_SONGS) - - # recreate - self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE) - self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE) - self.connection.commit() - - except mysql.Error, e: - print "Error in empty(), %d: %s" % (e.args[0], e.args[1]) - self.connection.rollback() - - def delete_orphans(self): - try: - self.cursor = self.connection.cursor() - ### TODO: SQLDatabase.DELETE_ORPHANS is not performant enough, need better query - ### to delete fingerprints for which no song is tied to. - #self.cursor.execute(SQLDatabase.DELETE_ORPHANS) - #self.connection.commit() - except mysql.Error, e: - print "Error in delete_orphans(), %d: %s" % (e.args[0], e.args[1]) - self.connection.rollback() - + @abc.abstractmethod def delete_unfingerprinted_songs(self): - try: - self.cursor = self.connection.cursor() - self.cursor.execute(SQLDatabase.DELETE_UNFINGERPRINTED) - self.connection.commit() - except mysql.Error, e: - print "Error in delete_unfingerprinted_songs(), %d: %s" % (e.args[0], e.args[1]) - self.connection.rollback() + """ + Called to remove any song entries that do not have any fingerprints + associated with them. + """ + pass + @abc.abstractmethod def get_num_songs(self): """ - Returns number of songs the database has fingerprinted. + Returns the amount of songs in the database. """ - try: - self.cursor = self.connection.cursor() - self.cursor.execute(SQLDatabase.SELECT_UNIQUE_SONG_IDS) - record = self.cursor.fetchone() - return int(record['n']) - except mysql.Error, e: - print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1]) - + pass + + @abc.abstractmethod def get_num_fingerprints(self): """ - Returns number of fingerprints the database has fingerprinted. + Returns the number of fingerprints in the database. """ - try: - self.cursor = self.connection.cursor() - self.cursor.execute(SQLDatabase.SELECT_NUM_FINGERPRINTS) - record = self.cursor.fetchone() - return int(record['n']) - except mysql.Error, e: - print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1]) - + pass - def set_song_fingerprinted(self, song_id): + @abc.abstractmethod + def set_song_fingerprinted(self, sid): """ - Set the fingerprinted flag to TRUE (1) once a song has been completely - fingerprinted in the database. - """ - try: - self.cursor = self.connection.cursor() - self.cursor.execute(SQLDatabase.UPDATE_SONG_FINGERPRINTED, song_id) - self.connection.commit() - except mysql.Error, e: - print "Error in set_song_fingerprinted(), %d: %s" % (e.args[0], e.args[1]) - self.connection.rollback() + Sets a specific song as having all fingerprints in the database. + sid: Song identifier + """ + pass + + @abc.abstractmethod def get_songs(self): """ - Return songs that have the fingerprinted flag set TRUE (1). + Returns all fully fingerprinted songs in the database. """ - try: - self.cursor.execute(SQLDatabase.SELECT_SONGS) - return self.cursor.fetchall() - except mysql.Error, e: - print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1]) - return None - + pass + + @abc.abstractmethod def get_song_by_id(self, sid): """ - Returns song by its ID. + Return a song by its identifier + + sid: Song identifier """ - try: - self.cursor.execute(SQLDatabase.SELECT_SONG, (sid,)) - return self.cursor.fetchone() - except mysql.Error, e: - print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1]) - return None - + pass - def insert(self, key, value): + @abc.abstractmethod + def insert(self, hash, sid, offset): """ - Insert a (sha1, song_id, offset) row into database. + Inserts a single fingerprint into the database. - key is a sha1 hash, value = (song_id, offset) + hash: Part of a sha1 hash, in hexadecimal format + sid: Song identifier this fingerprint is off + offset: The offset this hash is from """ - try: - args = (key, value[0], value[1]) - self.cursor.execute(SQLDatabase.INSERT_FINGERPRINT, args) - except mysql.Error, e: - print "Error in insert(), %d: %s" % (e.args[0], e.args[1]) - self.connection.rollback() + pass - def insert_song(self, songname): + @abc.abstractmethod + def insert_song(self, song_name): """ - Inserts song in the database and returns the ID of the inserted record. + Inserts a song name into the database, returns the new + identifier of the song. + + song_name: The name of the song. """ - try: - self.cursor.execute(SQLDatabase.INSERT_SONG, (songname,)) - self.connection.commit() - return int(self.cursor.lastrowid) - except mysql.Error, e: - print "Error in insert_song(), %d: %s" % (e.args[0], e.args[1]) - self.connection.rollback() - return None + pass - def query(self, key): + @abc.abstractmethod + def query(self, hash): """ - Return all tuples associated with hash. + Returns all matching fingerprint entries associated with + the given hash as parameter. - If hash is None, returns all entries in the - database (be careful with that one!). + hash: Part of a sha1 hash, in hexadecimal format """ - # select all if no key - if key is not None: - sql = SQLDatabase.SELECT - else: - sql = SQLDatabase.SELECT_ALL - - matches = [] - try: - self.cursor.execute(sql, (key,)) - - # collect all matches - records = self.cursor.fetchall() - for record in records: - matches.append((record[SQLDatabase.FIELD_SONG_ID], record[SQLDatabase.FIELD_OFFSET])) - - except mysql.Error, e: - print "Error in query(), %d: %s" % (e.args[0], e.args[1]) - - return matches + pass + @abc.abstractmethod def get_iterable_kv_pairs(self): """ - Returns all tuples in database. + Returns all fingerprints in the database. """ - return self.query(None) + pass - def insert_hashes(self, hashes): + @abc.abstractmethod + def insert_hashes(self, sid, hashes): """ - Insert series of hash => song_id, offset - values into the database. - """ - for h in hashes: - sha1, val = h - self.insert(sha1, val) - self.connection.commit() + Insert a multitude of fingerprints. + sid: Song identifier the fingerprints belong to + hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. + """ + pass + + @abc.abstractmethod def return_matches(self, hashes): """ - Return the (song_id, offset_diff) tuples associated with - a list of + Searches the database for pairs of (hash, offset) values. - sha1 => (None, sample_offset) + hashes: A sequence of tuples in the format (hash, offset) + - hash: Part of a sha1 hash, in hexadecimal format + - offset: Offset this hash was created from/at. - values. + Returns a sequence of (sid, offset_difference) tuples. + + sid: Song identifier + offset_difference: (offset - database_offset) """ - matches = [] - for h in hashes: - sha1, val = h - list_of_tups = self.query(sha1) - if list_of_tups: - for t in list_of_tups: - # (song_id, db_offset, song_sampled_offset) - matches.append((t[0], t[1] - val[1])) - return matches + pass + + +def get_database(database_type=None): + # Default to using the mysql database + database_type = database_type or "mysql" + # Lower all the input. + database_type = database_type.lower() + + for db_cls in Database.__subclasses__(): + if db_cls.type == database_type: + return db_cls + + raise TypeError("Unsupported database type supplied.") + + +# Import our default database handler +import dejavu.database_sql diff --git a/dejavu/database_sql.py b/dejavu/database_sql.py new file mode 100644 index 0000000..565d83f --- /dev/null +++ b/dejavu/database_sql.py @@ -0,0 +1,374 @@ +from __future__ import absolute_import +from itertools import izip_longest +import Queue + +import MySQLdb as mysql +from MySQLdb.cursors import DictCursor + +from dejavu.database import Database + + +class SQLDatabase(Database): + """ + Queries: + + 1) Find duplicates (shouldn't be any, though): + + select `hash`, `song_id`, `offset`, count(*) cnt + from fingerprints + group by `hash`, `song_id`, `offset` + having cnt > 1 + order by cnt asc; + + 2) Get number of hashes by song: + + select song_id, song_name, count(song_id) as num + from fingerprints + natural join songs + group by song_id + order by count(song_id) desc; + + 3) get hashes with highest number of collisions + + select + hash, + count(distinct song_id) as n + from fingerprints + group by `hash` + order by n DESC; + + => 26 different songs with same fingerprint (392 times): + + select songs.song_name, fingerprints.offset + from fingerprints natural join songs + where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73"; + """ + + type = "mysql" + + # tables + FINGERPRINTS_TABLENAME = "fingerprints" + SONGS_TABLENAME = "songs" + + # fields + FIELD_HASH = "hash" + FIELD_SONG_ID = "song_id" + FIELD_OFFSET = "offset" + FIELD_SONGNAME = "song_name" + FIELD_FINGERPRINTED = "fingerprinted" + + # creates + CREATE_FINGERPRINTS_TABLE = """ + CREATE TABLE IF NOT EXISTS `%s` ( + `%s` binary(10) not null, + `%s` mediumint unsigned not null, + `%s` int unsigned not null, + PRIMARY KEY(%s), + UNIQUE(%s, %s, %s), + FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE + ) ENGINE=INNODB;""" % ( + FINGERPRINTS_TABLENAME, FIELD_HASH, + FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH, + FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH, + FIELD_SONG_ID, SONGS_TABLENAME, FIELD_SONG_ID + ) + + CREATE_SONGS_TABLE = """ + CREATE TABLE IF NOT EXISTS `%s` ( + `%s` mediumint unsigned not null auto_increment, + `%s` varchar(250) not null, + `%s` tinyint default 0, + PRIMARY KEY (`%s`), + UNIQUE KEY `%s` (`%s`) + ) ENGINE=INNODB;""" % ( + SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED, + FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID, + ) + + # inserts (ignores duplicates) + INSERT_FINGERPRINT = """ + INSERT IGNORE INTO %s (%s, %s, %s) values + (UNHEX(%%s), %%s, %%s); + """ % (FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET) + + INSERT_SONG = "INSERT INTO %s (%s) values (%%s);" % ( + SONGS_TABLENAME, FIELD_SONGNAME) + + # selects + SELECT = """ + SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s); + """ % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH) + + SELECT_MULTIPLE = """ + SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s); + """ % (FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET, + FINGERPRINTS_TABLENAME, FIELD_HASH) + + SELECT_ALL = """ + SELECT %s, %s FROM %s; + """ % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME) + + SELECT_SONG = """ + SELECT %s FROM %s WHERE %s = %%s + """ % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID) + + SELECT_NUM_FINGERPRINTS = """ + SELECT COUNT(*) as n FROM %s + """ % (FINGERPRINTS_TABLENAME) + + SELECT_UNIQUE_SONG_IDS = """ + SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1; + """ % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED) + + SELECT_SONGS = """ + SELECT %s, %s FROM %s WHERE %s = 1; + """ % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED) + + # drops + DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME + DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME + + # update + UPDATE_SONG_FINGERPRINTED = """ + UPDATE %s SET %s = 1 WHERE %s = %%s + """ % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID) + + # delete + DELETE_UNFINGERPRINTED = """ + DELETE FROM %s WHERE %s = 0; + """ % (SONGS_TABLENAME, FIELD_FINGERPRINTED) + + def __init__(self, **options): + super(SQLDatabase, self).__init__() + self.cursor = cursor_factory(**options) + self._options = options + + def after_fork(self): + # Clear the cursor cache, we don't want any stale connections from + # the previous process. + Cursor.clear_cache() + + def setup(self): + """ + Creates any non-existing tables required for dejavu to function. + + This also removes all songs that have been added but have no + fingerprints associated with them. + """ + with self.cursor() as cur: + cur.execute(self.CREATE_SONGS_TABLE) + cur.execute(self.CREATE_FINGERPRINTS_TABLE) + cur.execute(self.DELETE_UNFINGERPRINTED) + + def empty(self): + """ + Drops tables created by dejavu and then creates them again + by calling `SQLDatabase.setup`. + + .. warning: + This will result in a loss of data + """ + with self.cursor() as cur: + cur.execute(self.DROP_FINGERPRINTS) + cur.execute(self.DROP_SONGS) + + self.setup() + + def delete_unfingerprinted_songs(self): + """ + Removes all songs that have no fingerprints associated with them. + """ + with self.cursor() as cur: + cur.execute(self.DELETE_UNFINGERPRINTED) + + def get_num_songs(self): + """ + Returns number of songs the database has fingerprinted. + """ + with self.cursor() as cur: + cur.execute(self.SELECT_UNIQUE_SONG_IDS) + + for count, in cur: + return count + return 0 + + def get_num_fingerprints(self): + """ + Returns number of fingerprints the database has fingerprinted. + """ + with self.cursor() as cur: + cur.execute(self.SELECT_NUM_FINGERPRINTS) + + for count, in cur: + return count + return 0 + + def set_song_fingerprinted(self, sid): + """ + Set the fingerprinted flag to TRUE (1) once a song has been completely + fingerprinted in the database. + """ + with self.cursor() as cur: + cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,)) + + def get_songs(self): + """ + Return songs that have the fingerprinted flag set TRUE (1). + """ + with self.cursor(cursor_type=DictCursor) as cur: + cur.execute(self.SELECT_SONGS) + for row in cur: + yield row + + def get_song_by_id(self, sid): + """ + Returns song by its ID. + """ + with self.cursor(cursor_type=DictCursor) as cur: + cur.execute(self.SELECT_SONG, (sid,)) + return cur.fetchone() + + def insert(self, hash, sid, offset): + """ + Insert a (sha1, song_id, offset) row into database. + """ + with self.cursor() as cur: + cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset)) + + def insert_song(self, songname): + """ + Inserts song in the database and returns the ID of the inserted record. + """ + with self.cursor() as cur: + cur.execute(self.INSERT_SONG, (songname,)) + return cur.lastrowid + + def query(self, hash): + """ + Return all tuples associated with hash. + + If hash is None, returns all entries in the + database (be careful with that one!). + """ + # select all if no key + query = self.SELECT_ALL if hash is None else self.SELECT + + with self.cursor() as cur: + cur.execute(query) + for sid, offset in cur: + yield (sid, offset) + + def get_iterable_kv_pairs(self): + """ + Returns all tuples in database. + """ + return self.query(None) + + def insert_hashes(self, sid, hashes): + """ + Insert series of hash => song_id, offset + values into the database. + """ + values = [] + for hash, offset in hashes: + values.append((hash, sid, offset)) + + with self.cursor() as cur: + for split_values in grouper(values, 1000): + cur.executemany(self.INSERT_FINGERPRINT, split_values) + + def return_matches(self, hashes): + """ + Return the (song_id, offset_diff) tuples associated with + a list of (sha1, sample_offset) values. + """ + # Create a dictionary of hash => offset pairs for later lookups + mapper = {} + for hash, offset in hashes: + mapper[hash.upper()] = offset + + # Get an iteratable of all the hashes we need + values = mapper.keys() + + with self.cursor() as cur: + for split_values in grouper(values, 1000): + # Create our IN part of the query + query = self.SELECT_MULTIPLE + query = query % ', '.join(['UNHEX(%s)'] * len(split_values)) + + cur.execute(query, split_values) + + for hash, sid, offset in cur: + # (sid, db_offset - song_sampled_offset) + yield (sid, offset - mapper[hash]) + + def __getstate__(self): + return (self._options,) + + def __setstate__(self, state): + self._options, = state + self.cursor = cursor_factory(**self._options) + + +def grouper(iterable, n, fillvalue=None): + args = [iter(iterable)] * n + return (filter(None, values) for values + in izip_longest(fillvalue=fillvalue, *args)) + + +def cursor_factory(**factory_options): + def cursor(**options): + options.update(factory_options) + return Cursor(**options) + return cursor + + +class Cursor(object): + """ + Establishes a connection to the database and returns an open cursor. + + + ```python + # Use as context manager + with Cursor() as cur: + cur.execute(query) + ``` + """ + _cache = Queue.Queue(maxsize=5) + + def __init__(self, cursor_type=mysql.cursors.Cursor, **options): + super(Cursor, self).__init__() + + try: + conn = self._cache.get_nowait() + except Queue.Empty: + conn = mysql.connect(**options) + else: + # Ping the connection before using it from the cache. + conn.ping(True) + + self.conn = conn + self.conn.autocommit(False) + self.cursor_type = cursor_type + + @classmethod + def clear_cache(cls): + cls._cache = Queue.Queue(maxsize=5) + + def __enter__(self): + self.cursor = self.conn.cursor(self.cursor_type) + return self.cursor + + def __exit__(self, extype, exvalue, traceback): + # if we had a MySQL related error we try to rollback the cursor. + if extype is mysql.MySQLError: + self.cursor.rollback() + + self.cursor.close() + self.conn.commit() + + # Put it back on the queue + try: + self._cache.put_nowait(self.conn) + except Queue.Full: + self.conn.close() diff --git a/dejavu/decoder.py b/dejavu/decoder.py new file mode 100644 index 0000000..d01a8ef --- /dev/null +++ b/dejavu/decoder.py @@ -0,0 +1,48 @@ +import os +import fnmatch +import numpy as np +from pydub import AudioSegment + + +def find_files(path, extensions): + # Allow both with ".mp3" and without "mp3" to be used for extensions + extensions = [e.replace(".", "") for e in extensions] + + for dirpath, dirnames, files in os.walk(path): + for extension in extensions: + for f in fnmatch.filter(files, "*.%s" % extension): + p = os.path.join(dirpath, f) + yield (p, extension) + + +def read(filename, limit=None): + """ + Reads any file supported by pydub (ffmpeg) and returns the data contained + within. + + Can be optionally limited to a certain amount of seconds from the start + of the file by specifying the `limit` parameter. This is the amount of + seconds from the start of the file. + + returns: (channels, samplerate) + """ + audiofile = AudioSegment.from_file(filename) + + if limit: + audiofile = audiofile[:limit * 1000] + + data = np.fromstring(audiofile._data, np.int16) + + channels = [] + for chn in xrange(audiofile.channels): + channels.append(data[chn::audiofile.channels]) + + return channels, audiofile.frame_rate + + +def path_to_songname(path): + """ + Extracts song name from a filepath. Used to identify which songs + have already been fingerprinted on disk. + """ + return os.path.basename(path).split(".")[0] diff --git a/dejavu/fingerprint.py b/dejavu/fingerprint.py old mode 100644 new mode 100755 index d1f78cc..68f6104 --- a/dejavu/fingerprint.py +++ b/dejavu/fingerprint.py @@ -1,226 +1,118 @@ import numpy as np import matplotlib.mlab as mlab import matplotlib.pyplot as plt -import matplotlib.image as mpimg -from scipy.io import wavfile from scipy.ndimage.filters import maximum_filter -from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion -from dejavu.database import SQLDatabase -import os -import wave -import sys -import time +from scipy.ndimage.morphology import (generate_binary_structure, + iterate_structure, binary_erosion) import hashlib -import pickle -class Fingerprinter(): - IDX_FREQ_I = 0 - IDX_TIME_J = 1 - - DEFAULT_FS = 44100 - DEFAULT_WINDOW_SIZE = 4096 - DEFAULT_OVERLAP_RATIO = 0.5 - DEFAULT_FAN_VALUE = 15 - - DEFAULT_AMP_MIN = 10 - PEAK_NEIGHBORHOOD_SIZE = 20 - MIN_HASH_TIME_DELTA = 0 +IDX_FREQ_I = 0 +IDX_TIME_J = 1 - def __init__(self, config, - Fs=DEFAULT_FS, - wsize=DEFAULT_WINDOW_SIZE, - wratio=DEFAULT_OVERLAP_RATIO, - fan_value=DEFAULT_FAN_VALUE, - amp_min=DEFAULT_AMP_MIN): +DEFAULT_FS = 44100 +DEFAULT_WINDOW_SIZE = 4096 +DEFAULT_OVERLAP_RATIO = 0.5 +DEFAULT_FAN_VALUE = 15 - self.config = config - database = SQLDatabase( - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD), - self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE)) - self.db = database +DEFAULT_AMP_MIN = 10 +PEAK_NEIGHBORHOOD_SIZE = 20 +MIN_HASH_TIME_DELTA = 0 - self.Fs = Fs - self.dt = 1.0 / self.Fs - self.window_size = wsize - self.window_overlap_ratio = wratio - self.fan_value = fan_value - self.noverlap = int(self.window_size * self.window_overlap_ratio) - self.amp_min = amp_min - def fingerprint(self, samples, path, sid, cid): - """Used for learning known songs""" - hashes = self.process_channel(samples, song_id=sid) - print "Generated %d hashes" % len(hashes) - self.db.insert_hashes(hashes) +def fingerprint(channel_samples, Fs=DEFAULT_FS, + wsize=DEFAULT_WINDOW_SIZE, + wratio=DEFAULT_OVERLAP_RATIO, + fan_value=DEFAULT_FAN_VALUE, + amp_min=DEFAULT_AMP_MIN): + """ + FFT the channel, log transform output, find local maxima, then return + locally sensitive hashes. + """ + # FFT the signal and extract frequency components + arr2D = mlab.specgram( + channel_samples, + NFFT=wsize, + Fs=Fs, + window=mlab.window_hanning, + noverlap=int(wsize * wratio))[0] - def match(self, samples): - """Used for matching unknown songs""" - hashes = self.process_channel(samples) - matches = self.db.return_matches(hashes) - return matches + # apply log transform since specgram() returns linear array + arr2D = 10 * np.log10(arr2D) + arr2D[arr2D == -np.inf] = 0 # replace infs with zeros - def process_channel(self, channel_samples, song_id=None): - """ - FFT the channel, log transform output, find local maxima, then return - locally sensitive hashes. - """ - # FFT the signal and extract frequency components - arr2D = mlab.specgram( - channel_samples, - NFFT=self.window_size, - Fs=self.Fs, - window=mlab.window_hanning, - noverlap=self.noverlap)[0] + # find local maxima + local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min) - # apply log transform since specgram() returns linear array - arr2D = 10 * np.log10(arr2D) - arr2D[arr2D == -np.inf] = 0 # replace infs with zeros - - # find local maxima - local_maxima = self.get_2D_peaks(arr2D, plot=False) + # return hashes + return generate_hashes(local_maxima, fan_value=fan_value) - # return hashes - return self.generate_hashes(local_maxima, song_id=song_id) - def get_2D_peaks(self, arr2D, plot=False): +def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN): + # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure + struct = generate_binary_structure(2, 1) + neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) - # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure - struct = generate_binary_structure(2, 1) - neighborhood = iterate_structure(struct, Fingerprinter.PEAK_NEIGHBORHOOD_SIZE) + # find local maxima using our fliter shape + local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D + background = (arr2D == 0) + eroded_background = binary_erosion(background, structure=neighborhood, + border_value=1) - # find local maxima using our fliter shape - local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D - background = (arr2D == 0) - eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) - detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks + # Boolean mask of arr2D with True at peaks + detected_peaks = local_max - eroded_background - # extract peaks - amps = arr2D[detected_peaks] - j, i = np.where(detected_peaks) + # extract peaks + amps = arr2D[detected_peaks] + j, i = np.where(detected_peaks) - # filter peaks - amps = amps.flatten() - peaks = zip(i, j, amps) - peaks_filtered = [x for x in peaks if x[2] > self.amp_min] # freq, time, amp - - # get indices for frequency and time - frequency_idx = [x[1] for x in peaks_filtered] - time_idx = [x[0] for x in peaks_filtered] + # filter peaks + amps = amps.flatten() + peaks = zip(i, j, amps) + peaks_filtered = [x for x in peaks if x[2] > amp_min] # freq, time, amp - if plot: - # scatter of the peaks - fig, ax = plt.subplots() - ax.imshow(arr2D) - ax.scatter(time_idx, frequency_idx) - ax.set_xlabel('Time') - ax.set_ylabel('Frequency') - ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke"); - plt.gca().invert_yaxis() - plt.show() + # get indices for frequency and time + frequency_idx = [x[1] for x in peaks_filtered] + time_idx = [x[0] for x in peaks_filtered] - return zip(frequency_idx, time_idx) + if plot: + # scatter of the peaks + fig, ax = plt.subplots() + ax.imshow(arr2D) + ax.scatter(time_idx, frequency_idx) + ax.set_xlabel('Time') + ax.set_ylabel('Frequency') + ax.set_title("Spectrogram") + plt.gca().invert_yaxis() + plt.show() - def generate_hashes(self, peaks, song_id=None): - """ - Hash list structure: - sha1-hash[0:20] song_id, time_offset - [(e05b341a9b77a51fd26, (3, 32)), ... ] - """ - fingerprinted = set() # to avoid rehashing same pairs - hashes = [] + return zip(frequency_idx, time_idx) - for i in range(len(peaks)): - for j in range(self.fan_value): - if i+j < len(peaks) and not (i, i+j) in fingerprinted: - freq1 = peaks[i][Fingerprinter.IDX_FREQ_I] - freq2 = peaks[i+j][Fingerprinter.IDX_FREQ_I] - t1 = peaks[i][Fingerprinter.IDX_TIME_J] - t2 = peaks[i+j][Fingerprinter.IDX_TIME_J] - t_delta = t2 - t1 - - if t_delta >= Fingerprinter.MIN_HASH_TIME_DELTA: - h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))) - hashes.append((h.hexdigest()[0:20], (song_id, t1))) - - # ensure we don't repeat hashing - fingerprinted.add((i, i+j)) - return hashes +def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE): + """ + Hash list structure: + sha1_hash[0:20] time_offset + [(e05b341a9b77a51fd26, 32), ... ] + """ + fingerprinted = set() # to avoid rehashing same pairs - def insert_into_db(self, key, value): - self.db.insert(key, value) + for i in range(len(peaks)): + for j in range(fan_value): + if (i + j) < len(peaks) and not (i, i + j) in fingerprinted: + freq1 = peaks[i][IDX_FREQ_I] + freq2 = peaks[i + j][IDX_FREQ_I] - def print_stats(self): + t1 = peaks[i][IDX_TIME_J] + t2 = peaks[i + j][IDX_TIME_J] - iterable = self.db.get_iterable_kv_pairs() + t_delta = t2 - t1 - counter = {} - for t in iterable: - sid, toff = t - if not sid in counter: - counter[sid] = 1 - else: - counter[sid] += 1 + if t_delta >= MIN_HASH_TIME_DELTA: + h = hashlib.sha1( + "%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)) + ) + yield (h.hexdigest()[0:20], t1) - for song_id, count in counter.iteritems(): - song_name = self.song_names[song_id] - print "%s has %d spectrogram peaks" % (song_name, count) - - def set_song_names(self, wpaths): - self.song_names = wpaths - - def align_matches(self, matches, starttime, record_seconds=0, verbose=False): - """ - Finds hash matches that align in time with other matches and finds - consensus about which hashes are "true" signal from the audio. - - Returns a dictionary with match information. - """ - # align by diffs - diff_counter = {} - largest = 0 - largest_count = 0 - song_id = -1 - for tup in matches: - sid, diff = tup - if not diff in diff_counter: - diff_counter[diff] = {} - if not sid in diff_counter[diff]: - diff_counter[diff][sid] = 0 - diff_counter[diff][sid] += 1 - - if diff_counter[diff][sid] > largest_count: - largest = diff - largest_count = diff_counter[diff][sid] - song_id = sid - - if verbose: - print "Diff is %d with %d offset-aligned matches" % (largest, largest_count) - - # extract idenfication - song = self.db.get_song_by_id(song_id) - if song: - songname = song.get(SQLDatabase.FIELD_SONGNAME, None) - else: - return None - songname = songname.replace("_", " ") - elapsed = time.time() - starttime - - if verbose: - print "Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed) - - # return match info - song = { - "song_id" : song_id, - "song_name" : songname, - "match_time" : elapsed, - "confidence" : largest_count - } - - if record_seconds: - song['record_time'] = record_seconds - - return song + # ensure we don't repeat hashing + fingerprinted.add((i, i + j)) diff --git a/dejavu/recognize.py b/dejavu/recognize.py old mode 100644 new mode 100755 index bcb5d3c..d87a323 --- a/dejavu/recognize.py +++ b/dejavu/recognize.py @@ -1,72 +1,112 @@ -from multiprocessing import Queue, Process -from dejavu.database import SQLDatabase -from scipy.io import wavfile -import wave +import dejavu.fingerprint as fingerprint +import dejavu.decoder as decoder import numpy as np import pyaudio -import sys import time -import array -class Recognizer(object): - CHUNK = 8192 # 44100 is a multiple of 1225 - FORMAT = pyaudio.paInt16 - CHANNELS = 2 - RATE = 44100 +class BaseRecognizer(object): - def __init__(self, fingerprinter, config): + def __init__(self, dejavu): + self.dejavu = dejavu + self.Fs = fingerprint.DEFAULT_FS - self.fingerprinter = fingerprinter - self.config = config + def _recognize(self, *data): + matches = [] + for d in data: + matches.extend(self.dejavu.find_matches(d, Fs=self.Fs)) + return self.dejavu.align_matches(matches) + + def recognize(self): + pass # base class does nothing + + +class FileRecognizer(BaseRecognizer): + def __init__(self, dejavu): + super(FileRecognizer, self).__init__(dejavu) + + def recognize_file(self, filename): + frames, self.Fs = decoder.read(filename) + + t = time.time() + match = self._recognize(*frames) + t = time.time() - t + + if match: + match['match_time'] = t + + return match + + def recognize(self, filename): + return self.recognize_file(filename) + + +class MicrophoneRecognizer(BaseRecognizer): + default_chunksize = 8192 + default_format = pyaudio.paInt16 + default_channels = 2 + default_samplerate = 44100 + + def __init__(self, dejavu): + super(MicrophoneRecognizer, self).__init__(dejavu) self.audio = pyaudio.PyAudio() - - def read(self, filename, verbose=False): - - # read file into channels - channels = [] - Fs, frames = wavfile.read(filename) - wave_object = wave.open(filename) - nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams() - for channel in range(nchannels): - channels.append(frames[:, channel]) - - # get matches - starttime = time.time() - matches = [] - for channel in channels: - matches.extend(self.fingerprinter.match(channel)) - - return self.fingerprinter.align_matches(matches, starttime, verbose=verbose) - - def listen(self, seconds=10, verbose=False): + self.stream = None + self.data = [] + self.channels = MicrophoneRecognizer.default_channels + self.chunksize = MicrophoneRecognizer.default_chunksize + self.samplerate = MicrophoneRecognizer.default_samplerate + self.recorded = False - # open stream - stream = self.audio.open(format=Recognizer.FORMAT, - channels=Recognizer.CHANNELS, - rate=Recognizer.RATE, - input=True, - frames_per_buffer=Recognizer.CHUNK) - - # record - if verbose: print("* recording") - left, right = [], [] - for i in range(0, int(Recognizer.RATE / Recognizer.CHUNK * seconds)): - data = stream.read(Recognizer.CHUNK) - nums = np.fromstring(data, np.int16) - left.extend(nums[1::2]) - right.extend(nums[0::2]) - if verbose: print("* done recording") - - # close and stop the stream - stream.stop_stream() - stream.close() - - # match both channels - starttime = time.time() - matches = [] - matches.extend(self.fingerprinter.match(left)) - matches.extend(self.fingerprinter.match(right)) - - # align and return - return self.fingerprinter.align_matches(matches, starttime, record_seconds=seconds, verbose=verbose) \ No newline at end of file + def start_recording(self, channels=default_channels, + samplerate=default_samplerate, + chunksize=default_chunksize): + self.chunksize = chunksize + self.channels = channels + self.recorded = False + self.samplerate = samplerate + + if self.stream: + self.stream.stop_stream() + self.stream.close() + + self.stream = self.audio.open( + format=self.default_format, + channels=channels, + rate=samplerate, + input=True, + frames_per_buffer=chunksize, + ) + + self.data = [[] for i in range(channels)] + + def process_recording(self): + data = self.stream.read(self.chunksize) + nums = np.fromstring(data, np.int16) + for c in range(self.channels): + self.data[c].extend(nums[c::self.channels]) + + def stop_recording(self): + self.stream.stop_stream() + self.stream.close() + self.stream = None + self.recorded = True + + def recognize_recording(self): + if not self.recorded: + raise NoRecordingError("Recording was not complete/begun") + return self._recognize(*self.data) + + def get_recorded_time(self): + return len(self.data[0]) / self.rate + + def recognize(self, seconds=10): + self.start_recording() + for i in range(0, int(self.samplerate / self.chunksize + * seconds)): + self.process_recording() + self.stop_recording() + return self.recognize_recording() + + +class NoRecordingError(Exception): + pass diff --git a/go.py b/go.py old mode 100644 new mode 100755 index 2b22eac..90947e2 --- a/go.py +++ b/go.py @@ -1,20 +1,26 @@ -from dejavu.control import Dejavu -from ConfigParser import ConfigParser +from dejavu import Dejavu import warnings +import json warnings.filterwarnings("ignore") -# load config -config = ConfigParser() -config.read("dejavu.cnf") +# load config from a JSON file (or anything outputting a python dictionary) +with open("dejavu.cnf") as f: + config = json.load(f) -# create Dejavu object -dejavu = Dejavu(config) -dejavu.fingerprint("va_us_top_40/mp3", "va_us_top_40/wav", [".mp3"], 5) +# create a Dejavu instance +djv = Dejavu(config) +# Fingerprint all the mp3's in the directory we give it +djv.fingerprint_directory("va_us_top_40/mp3", [".mp3"]) -# recognize microphone audio -from dejavu.recognize import Recognizer -recognizer = Recognizer(dejavu.fingerprinter, config) +# Recognize audio from a file +from dejavu.recognize import FileRecognizer +song = djv.recognize(FileRecognizer, "mp3/beware.mp3") -# recognize song playing over microphone for 10 seconds -song = recognizer.listen(seconds=5, verbose=True) -print song \ No newline at end of file +# Or recognize audio from your microphone for 10 seconds +from dejavu.recognize import MicrophoneRecognizer +song = djv.recognize(MicrophoneRecognizer, seconds=2) + +# Or use a recognizer without the shortcut, in anyway you would like +from dejavu.recognize import FileRecognizer +recognizer = FileRecognizer(djv) +song = recognizer.recognize_file("va_us_top_40/wav/17_-_#Beautiful_-_Mariah_Carey_ft.wav") diff --git a/performance.py b/performance.py deleted file mode 100644 index 2bee5d0..0000000 --- a/performance.py +++ /dev/null @@ -1,159 +0,0 @@ -from dejavu.control import Dejavu -from dejavu.recognize import Recognizer -from dejavu.convert import Converter -from dejavu.database import SQLDatabase -from ConfigParser import ConfigParser -from scipy.io import wavfile -import matplotlib.pyplot as plt -import warnings -import pyaudio -import os, wave, sys -import random -import numpy as np -warnings.filterwarnings("ignore") - -config = ConfigParser() -config.read("dejavu.cnf") -dejavu = Dejavu(config) -recognizer = Recognizer(dejavu.fingerprinter, config) - -def test_recording_lengths(recognizer): - - # settings for run - RATE = 44100 - FORMAT = pyaudio.paInt16 - padding_seconds = 10 - SONG_PADDING = RATE * padding_seconds - OUTPUT_FILE = "output.wav" - p = pyaudio.PyAudio() - c = Converter() - files = c.find_files("va_us_top_40/wav/", [".wav"])[-25:] - total = len(files) - recording_lengths = [4] - correct = 0 - count = 0 - score = {} - - for r in recording_lengths: - - RECORD_LENGTH = RATE * r - - for tup in files: - f, ext = tup - - # read the file - #print "reading: %s" % f - Fs, frames = wavfile.read(f) - wave_object = wave.open(f) - nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams() - - # chose at random a segment of audio to play - possible_end = num_frames - SONG_PADDING - RECORD_LENGTH - possible_start = SONG_PADDING - if possible_end - possible_start < RECORD_LENGTH: - print "ERROR! Song is too short to sample based on padding and recording seconds preferences." - sys.exit() - start = random.randint(possible_start, possible_end) - end = start + RECORD_LENGTH + 1 - - # get that segment of samples - channels = [] - frames = frames[start:end, :] - wav_string = frames.tostring() - - # write to disk - wf = wave.open(OUTPUT_FILE, 'wb') - wf.setnchannels(nchannels) - wf.setsampwidth(p.get_sample_size(FORMAT)) - wf.setframerate(RATE) - wf.writeframes(b''.join(wav_string)) - wf.close() - - # play and test - correctname = os.path.basename(f).replace(".wav", "").replace("_", " ") - inp = raw_input("Click ENTER when playing %s ..." % OUTPUT_FILE) - song = recognizer.listen(seconds=r+1, verbose=False) - print "PREDICTED: %s" % song['song_name'] - print "ACTUAL: %s" % correctname - if song['song_name'] == correctname: - correct += 1 - count += 1 - - print "Currently %d correct out of %d in total of %d" % (correct, count, total) - - score[r] = (correct, total) - print "UPDATE AFTER %d TRIAL: %s" % (r, score) - - return score - -def plot_match_time_trials(): - - # I did this manually - t = np.array([1, 2, 3, 4, 5, 6, 7, 8, 10, 15, 25, 30, 45, 60]) - m = np.array([.47, .79, 1.1, 1.5, 1.8, 2.18, 2.62, 2.8, 3.65, 5.29, 8.92, 10.63, 16.09, 22.29]) - mplust = t + m - - # linear regression - A = np.matrix([t, np.ones(len(t))]) - print A - w = np.linalg.lstsq(A.T, mplust)[0] - line = w[0] * t + w[1] - print "Equation for line is %f * record_time + %f = time_to_match" % (w[0], w[1]) - - # and plot - plt.title("Recording vs Matching time for \"Get Lucky\" by Daft Punk") - plt.xlabel("Time recorded (s)") - plt.ylabel("Time recorded + time to match (s)") - #plt.scatter(t, mplust) - plt.plot(t, line, 'r-', t, mplust, 'o') - plt.show() - -def plot_accuracy(): - # also did this manually - secs = np.array([1, 2, 3, 4, 5, 6]) - correct = np.array([27.0, 43.0, 44.0, 44.0, 45.0, 45.0]) - total = 45.0 - correct = correct / total - - plt.title("Dejavu Recognition Accuracy as a Function of Time") - plt.xlabel("Time recorded (s)") - plt.ylabel("Accuracy") - plt.plot(secs, correct) - plt.ylim([0.0, 1.05]) - plt.show() - -def plot_hashes_per_song(): - squery = """select song_name, count(song_id) as num - from fingerprints - natural join songs - group by song_name - order by count(song_id) asc;""" - sql = SQLDatabase(username="root", password="root", database="dejavu", hostname="localhost") - cursor = sql.connection.cursor() - cursor.execute(squery) - counts = cursor.fetchall() - - songs = [] - count = [] - for item in counts: - songs.append(item['song_name'].replace("_", " ")[4:]) - count.append(item['num']) - - pos = np.arange(len(songs)) + 0.5 - - fig = plt.figure() - ax = fig.add_subplot(111) - ax.barh(pos, count, align='center') - ax.set_yticks(pos, tuple(songs)) - - ax.axvline(0, color='k', lw=3) - - ax.set_xlabel('Number of Fingerprints') - ax.set_title('Number of Fingerprints by Song') - ax.grid(True) - plt.show() - -#plot_accuracy() -#score = test_recording_lengths(recognizer) -#plot_match_time_trials() -#plot_hashes_per_song()