diff --git a/dejavu/__init__.py b/dejavu/__init__.py index 6194cec..9d4c8a1 100755 --- a/dejavu/__init__.py +++ b/dejavu/__init__.py @@ -36,12 +36,11 @@ class Dejavu(object): def get_fingerprinted_songs(self): # get songs previously indexed - # TODO: should probably use a checksum of the file instead of filename self.songs = self.db.get_songs() - self.songnames_set = set() # to know which ones we've computed before + self.songhashes_set = set() # to know which ones we've computed before for song in self.songs: - song_name = song[self.db.FIELD_SONGNAME] - self.songnames_set.add(song_name) + song_hash = song[self.db.FIELD_SHA1] + self.songhashes_set.add(song_hash) def fingerprint_directory(self, path, extensions, nprocesses=None): # Try to use the maximum amount of processes if not given. @@ -58,7 +57,7 @@ class Dejavu(object): for filename, _ in decoder.find_files(path, extensions): # don't refingerprint already fingerprinted files - if decoder.path_to_songname(filename) in self.songnames_set: + if decoder.unique_hash(filename) in self.songhashes_set: print "%s already fingerprinted, continuing..." % filename continue @@ -75,7 +74,7 @@ class Dejavu(object): # Loop till we have all of them while True: try: - song_name, hashes = iterator.next() + song_name, hashes, file_hash = iterator.next() except multiprocessing.TimeoutError: continue except StopIteration: @@ -85,7 +84,7 @@ class Dejavu(object): # Print traceback because we can't reraise it here traceback.print_exc(file=sys.stdout) else: - sid = self.db.insert_song(song_name) + sid = self.db.insert_song(song_name, file_hash) self.db.insert_hashes(sid, hashes) self.db.set_song_fingerprinted(sid) @@ -96,16 +95,18 @@ class Dejavu(object): def fingerprint_file(self, filepath, song_name=None): songname = decoder.path_to_songname(filepath) + song_hash = decoder.unique_hash(filepath) song_name = song_name or songname # don't refingerprint already fingerprinted files - if song_name in self.songnames_set: + if song_hash in self.songhashes_set: print "%s already fingerprinted, continuing..." % song_name else: - song_name, hashes = _fingerprint_worker(filepath, - self.limit, - song_name=song_name) - - sid = self.db.insert_song(song_name) + song_name, hashes, file_hash = _fingerprint_worker( + filepath, + self.limit, + song_name=song_name + ) + sid = self.db.insert_song(song_name, file_hash) self.db.insert_hashes(sid, hashes) self.db.set_song_fingerprinted(sid) @@ -177,7 +178,7 @@ def _fingerprint_worker(filename, limit=None, song_name=None): songname, extension = os.path.splitext(os.path.basename(filename)) song_name = song_name or songname - channels, Fs = decoder.read(filename, limit) + channels, Fs, file_hash = decoder.read(filename, limit) result = set() channel_amount = len(channels) @@ -191,7 +192,7 @@ def _fingerprint_worker(filename, limit=None, song_name=None): filename)) result |= set(hashes) - return song_name, result + return song_name, result, file_hash def chunkify(lst, n): diff --git a/dejavu/database_sql.py b/dejavu/database_sql.py index 07f5853..d7c0dc9 100755 --- a/dejavu/database_sql.py +++ b/dejavu/database_sql.py @@ -54,6 +54,7 @@ class SQLDatabase(Database): FIELD_HASH = "hash" FIELD_SONG_ID = "song_id" FIELD_OFFSET = "offset" + FIELD_SHA1 = 'file_sha1' FIELD_SONGNAME = "song_name" FIELD_FINGERPRINTED = "fingerprinted" @@ -78,10 +79,12 @@ class SQLDatabase(Database): `%s` mediumint unsigned not null auto_increment, `%s` varchar(250) not null, `%s` tinyint default 0, + `%s` binary(10) not null, PRIMARY KEY (`%s`), UNIQUE KEY `%s` (`%s`) ) ENGINE=INNODB;""" % ( SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED, + FIELD_SHA1, FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID, ) @@ -91,8 +94,8 @@ class SQLDatabase(Database): (UNHEX(%%s), %%s, %%s); """ % (FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET) - INSERT_SONG = "INSERT INTO %s (%s) values (%%s);" % ( - SONGS_TABLENAME, FIELD_SONGNAME) + INSERT_SONG = "INSERT INTO %s (%s, %s) values (%%s, UNHEX(%%s));" % ( + SONGS_TABLENAME, FIELD_SONGNAME, FIELD_SHA1) # selects SELECT = """ @@ -109,8 +112,8 @@ class SQLDatabase(Database): """ % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME) SELECT_SONG = """ - SELECT %s FROM %s WHERE %s = %%s - """ % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID) + SELECT %s, HEX(%s) FROM %s WHERE %s = %%s + """ % (FIELD_SONGNAME, FIELD_SHA1, SONGS_TABLENAME, FIELD_SONG_ID) SELECT_NUM_FINGERPRINTS = """ SELECT COUNT(*) as n FROM %s @@ -121,8 +124,9 @@ class SQLDatabase(Database): """ % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED) SELECT_SONGS = """ - SELECT %s, %s FROM %s WHERE %s = 1; - """ % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED) + SELECT %s, %s, HEX(%s) FROM %s WHERE %s = 1; + """ % (FIELD_SONG_ID, FIELD_SONGNAME, FIELD_SHA1, + SONGS_TABLENAME, FIELD_FINGERPRINTED) # drops DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME @@ -235,12 +239,12 @@ class SQLDatabase(Database): with self.cursor() as cur: cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset)) - def insert_song(self, songname): + def insert_song(self, songname, file_hash): """ Inserts song in the database and returns the ID of the inserted record. """ with self.cursor() as cur: - cur.execute(self.INSERT_SONG, (songname,)) + cur.execute(self.INSERT_SONG, (songname, file_hash)) return cur.lastrowid def query(self, hash): diff --git a/dejavu/decoder.py b/dejavu/decoder.py index 830b8f7..dde6b6b 100755 --- a/dejavu/decoder.py +++ b/dejavu/decoder.py @@ -4,6 +4,20 @@ import numpy as np from pydub import AudioSegment from pydub.utils import audioop import wavio +from hashlib import sha1 + +def unique_hash(filepath): + """ Small function to generate a hash to uniquely generate + a file. Taken / inspired from git's way via stackoverflow: + http://stackoverflow.com/questions/552659 + """ + filesize_bytes = os.path.getsize(filepath) + s = sha1() + s.update(("blob %u\0" % filesize_bytes).encode('ascii')) + with open(filepath, 'rb') as f: + s.update(f.read()) + return s.hexdigest() + def find_files(path, extensions): # Allow both with ".mp3" and without "mp3" to be used for extensions @@ -55,7 +69,7 @@ def read(filename, limit=None): for chn in audiofile: channels.append(chn) - return channels, fs + return channels, audiofile.frame_rate, unique_hash(filename) def path_to_songname(path): diff --git a/dejavu/recognize.py b/dejavu/recognize.py index a9d0f59..b43a879 100755 --- a/dejavu/recognize.py +++ b/dejavu/recognize.py @@ -26,7 +26,7 @@ class FileRecognizer(BaseRecognizer): super(FileRecognizer, self).__init__(dejavu) def recognize_file(self, filename): - frames, self.Fs = decoder.read(filename, self.dejavu.limit) + frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit) t = time.time() match = self._recognize(*frames)