commit 90a93bc47b1ac10b70a1ff34a8eb58266b51bf26 Author: worldveil Date: Mon Nov 18 21:51:27 2013 -0500 moved to github diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8f8bc7e --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +*.pyc +wav +mp3 +*.wav +*.mp3 +.DS_Store +*.cnf diff --git a/README.md b/README.md new file mode 100644 index 0000000..3ee9ad0 Binary files /dev/null and b/README.md differ diff --git a/dejavu/__init__.py b/dejavu/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dejavu/control.py b/dejavu/control.py new file mode 100644 index 0000000..5d5b71f --- /dev/null +++ b/dejavu/control.py @@ -0,0 +1,106 @@ +from dejavu.database import SQLDatabase +from dejavu.converter import Converter +from dejavu.fingerprint import Fingerprinter +from scipy.io import wavfile +from multiprocessing import Process +import wave, os +import random + +class Dejavu(): + + def __init__(self, config): + + self.config = config + + # create components + self.converter = Converter() + self.fingerprinter = Fingerprinter(self.config) + self.fingerprinter.db.setup() + + # get songs previously indexed + self.songs = self.fingerprinter.db.get_songs() + self.songnames_set = set() # to know which ones we've computed before + if self.songs: + for song in self.songs: + song_id = song[SQLDatabase.FIELD_SONG_ID] + song_name = song[SQLDatabase.FIELD_SONGNAME] + self.songnames_set.add(song_name) + print "Added: %s to the set of fingerprinted songs..." % song_name + + def chunkify(self, lst, n): + """ + Splits a list into roughly n equal parts. + http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts + """ + return [lst[i::n] for i in xrange(n)] + + def fingerprint(self, path, output, extensions, nprocesses): + + # convert files, shuffle order + files = self.converter.find_files(path, extensions) + random.shuffle(files) + files_split = self.chunkify(files, nprocesses) + + # split into processes here + processes = [] + for i in range(nprocesses): + + # need database instance since mysql connections shouldn't be shared across processes + sql_connection = SQLDatabase( + self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME), + self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME), + self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD), + self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE)) + + # create process and start it + p = Process(target=self.fingerprint_worker, args=(files_split[i], sql_connection, output)) + p.start() + processes.append(p) + + # wait for all processes to complete + for p in processes: + p.join() + + # delete orphans + print "Done fingerprinting. Deleting orphaned fingerprints..." + self.fingerprinter.db.delete_orphans() + + def fingerprint_worker(self, files, sql_connection, output): + + for filename, extension in files: + + # if there are already fingerprints in database, don't re-fingerprint or convert + song_name = os.path.basename(filename).split(".")[0] + if song_name in self.songnames_set: + print "-> Already fingerprinted, continuing..." 
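+                # song was fingerprinted in a previous run; skip conversion and fingerprinting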
+ continue + + # convert to WAV + wavout_path = self.converter.convert(filename, extension, Converter.WAV, output, song_name) + + # insert song name into database + song_id = sql_connection.insert_song(song_name) + + # for each channel perform FFT analysis and fingerprinting + channels = self.extract_channels(wavout_path) + for c in range(len(channels)): + channel = channels[c] + print "-> Fingerprinting channel %d of song %s..." % (c+1, song_name) + self.fingerprinter.fingerprint(channel, wavout_path, song_id, c+1) + + # only after done fingerprinting do confirm + sql_connection.set_song_fingerprinted(song_id) + + def extract_channels(self, path): + """ + Reads channels from disk. + """ + channels = [] + Fs, frames = wavfile.read(path) + wave_object = wave.open(path) + nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams() + assert Fs == self.fingerprinter.Fs + + for channel in range(nchannels): + channels.append(frames[:, channel]) + return channels \ No newline at end of file diff --git a/dejavu/convert.py b/dejavu/convert.py new file mode 100644 index 0000000..77afcef --- /dev/null +++ b/dejavu/convert.py @@ -0,0 +1,54 @@ +import os, fnmatch +from pydub import AudioSegment + +class Converter(): + + WAV = "wav" + MP3 = "mp3" + FORMATS = [ + WAV, + MP3] + + def __init__(self): + pass + + def ensure_folder(self, extension): + if not os.path.exists(extension): + os.makedirs(extension) + + def find_files(self, path, extensions): + filepaths = [] + extensions = [e.replace(".", "") for e in extensions if e.replace(".", "") in Converter.FORMATS] + print "Supported formats: %s" % extensions + for dirpath, dirnames, files in os.walk(path) : + for extension in extensions: + for f in fnmatch.filter(files, "*.%s" % extension): + p = os.path.join(dirpath, f) + renamed = p.replace(" ", "_") + os.rename(p, renamed) + #print "Found file: %s with extension %s" % (renamed, extension) + filepaths.append((renamed, extension)) + return filepaths + + def convert(self, orig_path, from_format, to_format, output_folder, song_name): + + # start conversion + self.ensure_folder(output_folder) + print "-> Now converting: %s from %s format to %s format..." % (song_name, from_format, to_format) + + # MP3 --> WAV + if from_format == Converter.MP3 and to_format == Converter.WAV: + + newpath = os.path.join(output_folder, "%s.%s" % (song_name, Converter.WAV)) + if os.path.isfile(newpath): + print "-> Already converted, skipping..." + else: + mp3file = AudioSegment.from_mp3(orig_path) + mp3file.export(newpath, format=Converter.WAV) + + # unsupported + else: + print "CONVERSION ERROR:\nThe conversion from %s to %s is not supported!" % (from_format, to_format) + + print "-> Conversion complete." 
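+        # newpath is only assigned on the MP3 -> WAV branch above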
+ return newpath diff --git a/dejavu/database.py b/dejavu/database.py new file mode 100644 index 0000000..9614a5f --- /dev/null +++ b/dejavu/database.py @@ -0,0 +1,320 @@ +import MySQLdb as mysql +import MySQLdb.cursors as cursors +import os + +class SQLDatabase(): + """ + Queries: + + 1) Find duplicates (shouldn't be any, though): + + select `hash`, `song_id`, `offset`, count(*) cnt + from fingerprints + group by `hash`, `song_id`, `offset` + having cnt > 1 + order by cnt asc; + + 2) Get number of hashes by song: + + select song_id, song_name, count(song_id) as num + from fingerprints + natural join songs + group by song_id + order by count(song_id) desc; + + 3) get hashes with highest number of collisions + + select + hash, + count(distinct song_id) as n + from fingerprints + group by `hash` + order by n DESC; + + => 26 different songs with same fingerprint (392 times): + + select songs.song_name, fingerprints.offset + from fingerprints natural join songs + where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73"; + """ + + # config keys + CONNECTION = "connection" + KEY_USERNAME = "username" + KEY_DATABASE = "database" + KEY_PASSWORD = "password" + KEY_HOSTNAME = "hostname" + + # tables + FINGERPRINTS_TABLENAME = "fingerprints" + SONGS_TABLENAME = "songs" + + # fields + FIELD_HASH = "hash" + FIELD_SONG_ID = "song_id" + FIELD_OFFSET = "offset" + FIELD_SONGNAME = "song_name" + FIELD_FINGERPRINTED = "fingerprinted" + + # creates + CREATE_FINGERPRINTS_TABLE = """ + CREATE TABLE IF NOT EXISTS `%s` ( + `%s` binary(10) not null, + `%s` mediumint unsigned not null, + `%s` int unsigned not null, + INDEX(%s), + UNIQUE(%s, %s, %s) + );""" % (FINGERPRINTS_TABLENAME, FIELD_HASH, + FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH, + FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH) + + CREATE_SONGS_TABLE = """ + CREATE TABLE IF NOT EXISTS `%s` ( + `%s` mediumint unsigned not null auto_increment, + `%s` varchar(250) not null, + `%s` tinyint default 0, + PRIMARY KEY (`%s`), + UNIQUE KEY `%s` (`%s`) + );""" % (SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED, + FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID) + + # inserts + INSERT_FINGERPRINT = "INSERT IGNORE INTO %s (%s, %s, %s) VALUES (UNHEX(%%s), %%s, %%s)" % ( + FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET) # ignore duplicates and don't insert them + INSERT_SONG = "INSERT INTO %s (%s) VALUES (%%s);" % ( + SONGS_TABLENAME, FIELD_SONGNAME) + + # selects + SELECT = "SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH) + SELECT_ALL = "SELECT %s, %s FROM %s;" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME) + SELECT_SONG = "SELECT %s FROM %s WHERE %s = %%s" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID) + SELECT_NUM_FINGERPRINTS = "SELECT COUNT(*) as n FROM %s" % (FINGERPRINTS_TABLENAME) + + SELECT_UNIQUE_SONG_IDS = "SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED) + SELECT_SONGS = "SELECT %s, %s FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED) + + # drops + DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME + DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME + + # update + UPDATE_SONG_FINGERPRINTED = "UPDATE %s SET %s = 1 WHERE %s = %%s" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID) + + # delete + DELETE_UNFINGERPRINTED = "DELETE FROM %s WHERE %s = 0;" % (SONGS_TABLENAME, FIELD_FINGERPRINTED) + DELETE_ORPHANS = 
"" + + def __init__(self, hostname, username, password, database): + # connect + self.database = database + try: + # http://www.halfcooked.com/mt/archives/000969.html + self.connection = mysql.connect( + hostname, username, password, + database, cursorclass=cursors.DictCursor) + + self.connection.autocommit(False) # for fast bulk inserts + self.cursor = self.connection.cursor() + + except mysql.Error, e: + print "Connection error %d: %s" % (e.args[0], e.args[1]) + + def setup(self): + try: + # create fingerprints table + self.cursor.execute("USE %s;" % self.database) + self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE) + self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE) + self.delete_unfingerprinted_songs() + self.connection.commit() + except mysql.Error, e: + print "Connection error %d: %s" % (e.args[0], e.args[1]) + self.connection.rollback() + + def empty(self): + """ + Drops all tables and re-adds them. Be carfeul with this! + """ + try: + self.cursor.execute("USE %s;" % self.database) + + # drop tables + self.cursor.execute(SQLDatabase.DROP_FINGERPRINTS) + self.cursor.execute(SQLDatabase.DROP_SONGS) + + # recreate + self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE) + self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE) + self.connection.commit() + + except mysql.Error, e: + print "Error in empty(), %d: %s" % (e.args[0], e.args[1]) + self.connection.rollback() + + def delete_orphans(self): + try: + self.cursor = self.connection.cursor() + self.cursor.execute(SQLDatabase.DELETE_ORPHANS) + self.connection.commit() + except mysql.Error, e: + print "Error in delete_orphans(), %d: %s" % (e.args[0], e.args[1]) + self.connection.rollback() + + + def delete_unfingerprinted_songs(self): + try: + self.cursor = self.connection.cursor() + self.cursor.execute(SQLDatabase.DELETE_UNFINGERPRINTED) + self.connection.commit() + except mysql.Error, e: + print "Error in delete_unfingerprinted_songs(), %d: %s" % (e.args[0], e.args[1]) + self.connection.rollback() + + def get_num_songs(self): + """ + Returns number of songs the database has fingerprinted. + """ + try: + self.cursor = self.connection.cursor() + self.cursor.execute(SQLDatabase.SELECT_UNIQUE_SONG_IDS) + record = self.cursor.fetchone() + return int(record['n']) + except mysql.Error, e: + print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1]) + + def get_num_fingerprints(self): + """ + Returns number of fingerprints the database has fingerprinted. + """ + try: + self.cursor = self.connection.cursor() + self.cursor.execute(SQLDatabase.SELECT_NUM_FINGERPRINTS) + record = self.cursor.fetchone() + return int(record['n']) + except mysql.Error, e: + print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1]) + + + def set_song_fingerprinted(self, song_id): + """ + Set the fingerprinted flag to TRUE (1) once a song has been completely + fingerprinted in the database. + """ + try: + self.cursor = self.connection.cursor() + self.cursor.execute(SQLDatabase.UPDATE_SONG_FINGERPRINTED, song_id) + self.connection.commit() + except mysql.Error, e: + print "Error in set_song_fingerprinted(), %d: %s" % (e.args[0], e.args[1]) + self.connection.rollback() + + def get_songs(self): + """ + Return songs that have the fingerprinted flag set TRUE (1). + """ + try: + self.cursor.execute(SQLDatabase.SELECT_SONGS) + return self.cursor.fetchall() + except mysql.Error, e: + print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1]) + return None + + def get_song_by_id(self, sid): + """ + Returns song by its ID. 
+ """ + try: + self.cursor.execute(SQLDatabase.SELECT_SONG, (sid,)) + return self.cursor.fetchone() + except mysql.Error, e: + print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1]) + return None + + + def insert(self, key, value): + """ + Insert a (sha1, song_id, offset) row into database. + + key is a sha1 hash, value = (song_id, offset) + """ + try: + args = (key, value[0], value[1]) + self.cursor.execute(SQLDatabase.INSERT_FINGERPRINT, args) + except mysql.Error, e: + print "Error in insert(), %d: %s" % (e.args[0], e.args[1]) + self.connection.rollback() + + def insert_song(self, songname): + """ + Inserts song in the database and returns the ID of the inserted record. + """ + try: + self.cursor.execute(SQLDatabase.INSERT_SONG, (songname,)) + self.connection.commit() + return int(self.cursor.lastrowid) + except mysql.Error, e: + print "Error in insert_song(), %d: %s" % (e.args[0], e.args[1]) + self.connection.rollback() + return None + + def query(self, key): + """ + Return all tuples associated with hash. + + If hash is None, returns all entries in the + database (be careful with that one!). + """ + # select all if no key + if key is not None: + sql = SQLDatabase.SELECT + else: + sql = SQLDatabase.SELECT_ALL + + matches = [] + try: + self.cursor.execute(sql, (key,)) + + # collect all matches + records = self.cursor.fetchall() + for record in records: + matches.append((record[SQLDatabase.FIELD_SONG_ID], record[SQLDatabase.FIELD_OFFSET])) + + except mysql.Error, e: + print "Error in query(), %d: %s" % (e.args[0], e.args[1]) + + return matches + + def get_iterable_kv_pairs(self): + """ + Returns all tuples in database. + """ + return self.query(None) + + def insert_hashes(self, hashes): + """ + Insert series of hash => song_id, offset + values into the database. + """ + for h in hashes: + sha1, val = h + self.insert(sha1, val) + self.connection.commit() + + def return_matches(self, hashes): + """ + Return the (song_id, offset_diff) tuples associated with + a list of + + sha1 => (None, sample_offset) + + values. 
+ """ + matches = [] + for h in hashes: + sha1, val = h + list_of_tups = self.query(sha1) + if list_of_tups: + for t in list_of_tups: + # (song_id, db_offset, song_sampled_offset) + matches.append((t[0], t[1] - val[1])) + return matches diff --git a/dejavu/fingerprint.py b/dejavu/fingerprint.py new file mode 100644 index 0000000..9d7b2ee --- /dev/null +++ b/dejavu/fingerprint.py @@ -0,0 +1,224 @@ +import numpy as np +import matplotlib.mlab as mlab +import matplotlib.pyplot as plt +import matplotlib.image as mpimg +from scipy.io import wavfile +from scipy.ndimage.filters import maximum_filter +from scipy.ndimage.morphology import generate_binary_structure, iterate_structure, binary_erosion +from dejavu.database import SQLDatabase +import os +import wave +import sys +import time +import hashlib +import pickle + +class Fingerprinter(): + + IDX_FREQ_I = 0 + IDX_TIME_J = 1 + + DEFAULT_FS = 44100 + DEFAULT_WINDOW_SIZE = 4096 + DEFAULT_OVERLAP_RATIO = 0.5 + DEFAULT_FAN_VALUE = 15 + + DEFAULT_AMP_MIN = 10 + PEAK_NEIGHBORHOOD_SIZE = 20 + MIN_HASH_TIME_DELTA = 0 + + def __init__(self, config, + Fs=DEFAULT_FS, + wsize=DEFAULT_WINDOW_SIZE, + wratio=DEFAULT_OVERLAP_RATIO, + fan_value=DEFAULT_FAN_VALUE, + amp_min=DEFAULT_AMP_MIN): + + self.config = config + database = SQLDatabase( + self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_HOSTNAME), + self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_USERNAME), + self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_PASSWORD), + self.config.get(SQLDatabase.CONNECTION, SQLDatabase.KEY_DATABASE)) + self.db = database + + self.Fs = Fs + self.dt = 1.0 / self.Fs + self.window_size = wsize + self.window_overlap_ratio = wratio + self.fan_value = fan_value + self.noverlap = int(self.window_size * self.window_overlap_ratio) + self.amp_min = amp_min + + def fingerprint(self, samples, path, sid, cid): + """Used for learning known songs""" + hashes = self.process_channel(samples, song_id=sid) + print "Generated %d hashes" % len(hashes) + self.db.insert_hashes(hashes) + + def match(self, samples): + """Used for matching unknown songs""" + hashes = self.process_channel(samples) + matches = self.db.return_matches(hashes) + return matches + + def process_channel(self, channel_samples, song_id=None): + """ + FFT the channel, log transform output, find local maxima, then return + locally sensitive hashes. 
+ """ + # FFT the signal and extract frequency components + arr2D = mlab.specgram( + channel_samples, + NFFT=self.window_size, + Fs=self.Fs, + window=mlab.window_hanning, + noverlap=self.noverlap)[0] + + # apply log transform since specgram() returns linear array + arr2D = 10 * np.log10(arr2D) + arr2D[arr2D == -np.inf] = 0 # replace infs with zeros + + # find local maxima + local_maxima = self.get_2D_peaks(arr2D, plot=False) + + # return hashes + return self.generate_hashes(local_maxima, song_id=song_id) + + def get_2D_peaks(self, arr2D, plot=False): + + # http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.morphology.iterate_structure.html#scipy.ndimage.morphology.iterate_structure + struct = generate_binary_structure(2, 1) + neighborhood = iterate_structure(struct, Fingerprinter.PEAK_NEIGHBORHOOD_SIZE) + + # find local maxima using our fliter shape + local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D + background = (arr2D == 0) + eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) + detected_peaks = local_max - eroded_background # this is a boolean mask of arr2D with True at peaks + + # extract peaks + amps = arr2D[detected_peaks] + j, i = np.where(detected_peaks) + + # filter peaks + amps = amps.flatten() + peaks = zip(i, j, amps) + peaks_filtered = [x for x in peaks if x[2] > self.amp_min] # freq, time, amp + + # get indices for frequency and time + frequency_idx = [x[1] for x in peaks_filtered] + time_idx = [x[0] for x in peaks_filtered] + + if plot: + # scatter of the peaks + fig, ax = plt.subplots() + ax.imshow(arr2D) + ax.scatter(time_idx, frequency_idx) + ax.set_xlabel('Time') + ax.set_ylabel('Frequency') + ax.set_title("Spectrogram of \"Blurred Lines\" by Robin Thicke"); + plt.gca().invert_yaxis() + plt.show() + + return zip(frequency_idx, time_idx) + + def generate_hashes(self, peaks, song_id=None): + """ + Hash list structure: + sha1-hash[0:20] song_id, time_offset + [(e05b341a9b77a51fd26, (3, 32)), ... ] + """ + fingerprinted = set() # to avoid rehashing same pairs + hashes = [] + + for i in range(len(peaks)): + for j in range(self.fan_value): + if i+j < len(peaks) and not (i, i+j) in fingerprinted: + + freq1 = peaks[i][Fingerprinter.IDX_FREQ_I] + freq2 = peaks[i+j][Fingerprinter.IDX_FREQ_I] + t1 = peaks[i][Fingerprinter.IDX_TIME_J] + t2 = peaks[i+j][Fingerprinter.IDX_TIME_J] + t_delta = t2 - t1 + + if t_delta >= Fingerprinter.MIN_HASH_TIME_DELTA: + h = hashlib.sha1("%s|%s|%s" % (str(freq1), str(freq2), str(t_delta))) + hashes.append((h.hexdigest()[0:20], (song_id, t1))) + + # ensure we don't repeat hashing + fingerprinted.add((i, i+j)) + return hashes + + def insert_into_db(self, key, value): + self.db.insert(key, value) + + def print_stats(self): + + iterable = self.db.get_iterable_kv_pairs() + + counter = {} + for t in iterable: + sid, toff = t + if not sid in counter: + counter[sid] = 1 + else: + counter[sid] += 1 + + for song_id, count in counter.iteritems(): + song_name = self.song_names[song_id] + print "%s has %d spectrogram peaks" % (song_name, count) + + def set_song_names(self, wpaths): + self.song_names = wpaths + + def align_matches(self, matches, starttime, record_seconds=0, verbose=False): + """ + Finds hash matches that align in time with other matches and finds + consensus about which hashes are "true" signal from the audio. + + Returns a dictionary with match information. 
+ """ + # align by diffs + diff_counter = {} + largest = 0 + largest_count = 0 + song_id = -1 + for tup in matches: + sid, diff = tup + if not diff in diff_counter: + diff_counter[diff] = {} + if not sid in diff_counter[diff]: + diff_counter[diff][sid] = 0 + diff_counter[diff][sid] += 1 + + if diff_counter[diff][sid] > largest_count: + largest = diff + largest_count = diff_counter[diff][sid] + song_id = sid + + if verbose: print "Diff is %d with %d offset-aligned matches" % (largest, largest_count) + + #from collections import OrderedDict + #print OrderedDict(diff_counter) + + # extract idenfication + songname = self.db.get_song_by_id(song_id)[SQLDatabase.FIELD_SONGNAME] + songname = songname.replace("_", " ") + elapsed = time.time() - starttime + + if verbose: + print "Song is %s (song ID = %d) identification took %f seconds" % (songname, song_id, elapsed) + + # return match info + song = { + "song_id" : song_id, + "song_name" : songname, + "match_time" : elapsed, + "confidence" : largest_count + } + + if record_seconds: + song['record_time'] = record_seconds + + return song diff --git a/dejavu/recognize.py b/dejavu/recognize.py new file mode 100644 index 0000000..bcb5d3c --- /dev/null +++ b/dejavu/recognize.py @@ -0,0 +1,72 @@ +from multiprocessing import Queue, Process +from dejavu.database import SQLDatabase +from scipy.io import wavfile +import wave +import numpy as np +import pyaudio +import sys +import time +import array + +class Recognizer(object): + + CHUNK = 8192 # 44100 is a multiple of 1225 + FORMAT = pyaudio.paInt16 + CHANNELS = 2 + RATE = 44100 + + def __init__(self, fingerprinter, config): + + self.fingerprinter = fingerprinter + self.config = config + self.audio = pyaudio.PyAudio() + + def read(self, filename, verbose=False): + + # read file into channels + channels = [] + Fs, frames = wavfile.read(filename) + wave_object = wave.open(filename) + nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams() + for channel in range(nchannels): + channels.append(frames[:, channel]) + + # get matches + starttime = time.time() + matches = [] + for channel in channels: + matches.extend(self.fingerprinter.match(channel)) + + return self.fingerprinter.align_matches(matches, starttime, verbose=verbose) + + def listen(self, seconds=10, verbose=False): + + # open stream + stream = self.audio.open(format=Recognizer.FORMAT, + channels=Recognizer.CHANNELS, + rate=Recognizer.RATE, + input=True, + frames_per_buffer=Recognizer.CHUNK) + + # record + if verbose: print("* recording") + left, right = [], [] + for i in range(0, int(Recognizer.RATE / Recognizer.CHUNK * seconds)): + data = stream.read(Recognizer.CHUNK) + nums = np.fromstring(data, np.int16) + left.extend(nums[1::2]) + right.extend(nums[0::2]) + if verbose: print("* done recording") + + # close and stop the stream + stream.stop_stream() + stream.close() + + # match both channels + starttime = time.time() + matches = [] + matches.extend(self.fingerprinter.match(left)) + matches.extend(self.fingerprinter.match(right)) + + # align and return + return self.fingerprinter.align_matches(matches, starttime, record_seconds=seconds, verbose=verbose) \ No newline at end of file diff --git a/go.py b/go.py new file mode 100644 index 0000000..7b2d4cb --- /dev/null +++ b/go.py @@ -0,0 +1,20 @@ +from dejavu.control import Dejavu +from ConfigParser import ConfigParser +import warnings +warnings.filterwarnings("ignore") + +# load config +config = ConfigParser() +config.read("dejavu.cnf") + +# create Dejavu object +dejavu = 
Dejavu(config) +dejavu.fingerprint("va_us_top_40/mp3", "va_us_top_40/wav", [".mp3"], 5) + +# recognize microphone audio +from dejavu.recognize import Recognizer +recognizer = Recognizer(dejavu.fingerprinter, config) + +# recognize song playing over microphone for 10 seconds +song = recognizer.listen(seconds=1, verbose=True) +print song \ No newline at end of file diff --git a/performance.py b/performance.py new file mode 100644 index 0000000..2bee5d0 --- /dev/null +++ b/performance.py @@ -0,0 +1,159 @@ +from dejavu.control import Dejavu +from dejavu.recognize import Recognizer +from dejavu.convert import Converter +from dejavu.database import SQLDatabase +from ConfigParser import ConfigParser +from scipy.io import wavfile +import matplotlib.pyplot as plt +import warnings +import pyaudio +import os, wave, sys +import random +import numpy as np +warnings.filterwarnings("ignore") + +config = ConfigParser() +config.read("dejavu.cnf") +dejavu = Dejavu(config) +recognizer = Recognizer(dejavu.fingerprinter, config) + +def test_recording_lengths(recognizer): + + # settings for run + RATE = 44100 + FORMAT = pyaudio.paInt16 + padding_seconds = 10 + SONG_PADDING = RATE * padding_seconds + OUTPUT_FILE = "output.wav" + p = pyaudio.PyAudio() + c = Converter() + files = c.find_files("va_us_top_40/wav/", [".wav"])[-25:] + total = len(files) + recording_lengths = [4] + correct = 0 + count = 0 + score = {} + + for r in recording_lengths: + + RECORD_LENGTH = RATE * r + + for tup in files: + f, ext = tup + + # read the file + #print "reading: %s" % f + Fs, frames = wavfile.read(f) + wave_object = wave.open(f) + nchannels, sampwidth, framerate, num_frames, comptype, compname = wave_object.getparams() + + # chose at random a segment of audio to play + possible_end = num_frames - SONG_PADDING - RECORD_LENGTH + possible_start = SONG_PADDING + if possible_end - possible_start < RECORD_LENGTH: + print "ERROR! Song is too short to sample based on padding and recording seconds preferences." + sys.exit() + start = random.randint(possible_start, possible_end) + end = start + RECORD_LENGTH + 1 + + # get that segment of samples + channels = [] + frames = frames[start:end, :] + wav_string = frames.tostring() + + # write to disk + wf = wave.open(OUTPUT_FILE, 'wb') + wf.setnchannels(nchannels) + wf.setsampwidth(p.get_sample_size(FORMAT)) + wf.setframerate(RATE) + wf.writeframes(b''.join(wav_string)) + wf.close() + + # play and test + correctname = os.path.basename(f).replace(".wav", "").replace("_", " ") + inp = raw_input("Click ENTER when playing %s ..." 
% OUTPUT_FILE) + song = recognizer.listen(seconds=r+1, verbose=False) + print "PREDICTED: %s" % song['song_name'] + print "ACTUAL: %s" % correctname + if song['song_name'] == correctname: + correct += 1 + count += 1 + + print "Currently %d correct out of %d in total of %d" % (correct, count, total) + + score[r] = (correct, total) + print "UPDATE AFTER %d TRIAL: %s" % (r, score) + + return score + +def plot_match_time_trials(): + + # I did this manually + t = np.array([1, 2, 3, 4, 5, 6, 7, 8, 10, 15, 25, 30, 45, 60]) + m = np.array([.47, .79, 1.1, 1.5, 1.8, 2.18, 2.62, 2.8, 3.65, 5.29, 8.92, 10.63, 16.09, 22.29]) + mplust = t + m + + # linear regression + A = np.matrix([t, np.ones(len(t))]) + print A + w = np.linalg.lstsq(A.T, mplust)[0] + line = w[0] * t + w[1] + print "Equation for line is %f * record_time + %f = time_to_match" % (w[0], w[1]) + + # and plot + plt.title("Recording vs Matching time for \"Get Lucky\" by Daft Punk") + plt.xlabel("Time recorded (s)") + plt.ylabel("Time recorded + time to match (s)") + #plt.scatter(t, mplust) + plt.plot(t, line, 'r-', t, mplust, 'o') + plt.show() + +def plot_accuracy(): + # also did this manually + secs = np.array([1, 2, 3, 4, 5, 6]) + correct = np.array([27.0, 43.0, 44.0, 44.0, 45.0, 45.0]) + total = 45.0 + correct = correct / total + + plt.title("Dejavu Recognition Accuracy as a Function of Time") + plt.xlabel("Time recorded (s)") + plt.ylabel("Accuracy") + plt.plot(secs, correct) + plt.ylim([0.0, 1.05]) + plt.show() + +def plot_hashes_per_song(): + squery = """select song_name, count(song_id) as num + from fingerprints + natural join songs + group by song_name + order by count(song_id) asc;""" + sql = SQLDatabase(username="root", password="root", database="dejavu", hostname="localhost") + cursor = sql.connection.cursor() + cursor.execute(squery) + counts = cursor.fetchall() + + songs = [] + count = [] + for item in counts: + songs.append(item['song_name'].replace("_", " ")[4:]) + count.append(item['num']) + + pos = np.arange(len(songs)) + 0.5 + + fig = plt.figure() + ax = fig.add_subplot(111) + ax.barh(pos, count, align='center') + ax.set_yticks(pos, tuple(songs)) + + ax.axvline(0, color='k', lw=3) + + ax.set_xlabel('Number of Fingerprints') + ax.set_title('Number of Fingerprints by Song') + ax.grid(True) + plt.show() + +#plot_accuracy() +#score = test_recording_lengths(recognizer) +#plot_match_time_trials() +#plot_hashes_per_song() diff --git a/plots/accuracy.png b/plots/accuracy.png new file mode 100644 index 0000000..5daac25 Binary files /dev/null and b/plots/accuracy.png differ diff --git a/plots/blurred_lines_spectrogram.png b/plots/blurred_lines_spectrogram.png new file mode 100644 index 0000000..bfcd2e5 Binary files /dev/null and b/plots/blurred_lines_spectrogram.png differ diff --git a/plots/blurred_lines_vertical.png b/plots/blurred_lines_vertical.png new file mode 100644 index 0000000..9bad0ec Binary files /dev/null and b/plots/blurred_lines_vertical.png differ diff --git a/plots/blurred_lines_zoomed.png b/plots/blurred_lines_zoomed.png new file mode 100644 index 0000000..0ba8c64 Binary files /dev/null and b/plots/blurred_lines_zoomed.png differ diff --git a/plots/matching_time.png b/plots/matching_time.png new file mode 100644 index 0000000..b104120 Binary files /dev/null and b/plots/matching_time.png differ
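For reference, the pairing scheme in Fingerprinter.generate_hashes above boils down to a few lines: every anchor peak is paired with up to fan_value later peaks, and the SHA-1 of "freq1|freq2|time_delta", truncated to 20 hex characters, is stored against the song ID and the anchor's time offset. Below is a minimal standalone sketch of that idea, using made-up peak coordinates rather than a real spectrogram, and skipping the degenerate self-pairing the loop above technically allows.

import hashlib

def hash_peaks(peaks, song_id, fan_value=15):
    """peaks is a list of (frequency_bin, time_bin) tuples sorted by time."""
    hashes = []
    for i, (freq1, t1) in enumerate(peaks):
        for j in range(1, fan_value):
            if i + j < len(peaks):
                freq2, t2 = peaks[i + j]
                # hash the two frequencies and their time difference
                h = hashlib.sha1("%d|%d|%d" % (freq1, freq2, t2 - t1))
                # keep the anchor's absolute offset so matches can be aligned later
                hashes.append((h.hexdigest()[0:20], (song_id, t1)))
    return hashes

# hypothetical peaks for an imaginary song with ID 7
print hash_peaks([(120, 3), (310, 8), (95, 12)], song_id=7)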