mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 11:09:52 +00:00
- Extracted common logic from database handlers.
- Fixed tests. - Refactored solution architecture once more. - Refactored solution hierarchy. - Added audio files for testing. - Resolved flake8 issues and reorganized several imports in the process.
This commit is contained in:
parent
6bc15ab8d4
commit
e046eeee93
32 changed files with 1154 additions and 1058 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -1,7 +1,3 @@
|
|||
*.pyc
|
||||
wav
|
||||
mp3
|
||||
*.wav
|
||||
*.mp3
|
||||
.DS_Store
|
||||
*.cnf
|
||||
|
|
|
@ -4,5 +4,6 @@
|
|||
"user": "root",
|
||||
"password": "rootpass",
|
||||
"database": "dejavu"
|
||||
}
|
||||
},
|
||||
"database_type": "mysql"
|
||||
}
|
||||
|
|
24
dejavu.py
24
dejavu.py
|
@ -1,16 +1,12 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from os.path import isdir
|
||||
import sys
|
||||
import warnings
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
from dejavu import Dejavu
|
||||
from dejavu.recognize import FileRecognizer, MicrophoneRecognizer
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
|
||||
from dejavu.logic.recognizer.file_recognizer import FileRecognizer
|
||||
|
||||
DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE"
|
||||
|
||||
|
@ -58,7 +54,6 @@ if __name__ == '__main__':
|
|||
config_file = args.config
|
||||
if config_file is None:
|
||||
config_file = DEFAULT_CONFIG_FILE
|
||||
# print ("Using default config file: {config_file}")
|
||||
|
||||
djv = init(config_file)
|
||||
if args.fingerprint:
|
||||
|
@ -71,22 +66,19 @@ if __name__ == '__main__':
|
|||
|
||||
elif len(args.fingerprint) == 1:
|
||||
filepath = args.fingerprint[0]
|
||||
if os.path.isdir(filepath):
|
||||
if isdir(filepath):
|
||||
print("Please specify an extension if you'd like to fingerprint a directory!")
|
||||
sys.exit(1)
|
||||
djv.fingerprint_file(filepath)
|
||||
|
||||
elif args.recognize:
|
||||
# Recognize audio source
|
||||
song = None
|
||||
songs = None
|
||||
source = args.recognize[0]
|
||||
opt_arg = args.recognize[1]
|
||||
|
||||
if source in ('mic', 'microphone'):
|
||||
song = djv.recognize(MicrophoneRecognizer, seconds=opt_arg)
|
||||
songs = djv.recognize(MicrophoneRecognizer, seconds=opt_arg)
|
||||
elif source == 'file':
|
||||
song = djv.recognize(FileRecognizer, opt_arg)
|
||||
decoded_song = repr(song).decode('string_escape')
|
||||
print(decoded_song)
|
||||
|
||||
sys.exit(0)
|
||||
songs = djv.recognize(FileRecognizer, opt_arg)
|
||||
print(songs)
|
||||
|
|
|
@ -3,13 +3,13 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
import dejavu.decoder as decoder
|
||||
from dejavu.config.config import (CONFIDENCE, DEFAULT_FS,
|
||||
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
||||
FIELD_FILE_SHA1, OFFSET, OFFSET_SECS,
|
||||
SONG_ID, SONG_NAME, TOPN)
|
||||
from dejavu.database import get_database
|
||||
from dejavu.fingerprint import fingerprint
|
||||
import dejavu.logic.decoder as decoder
|
||||
from dejavu.base_classes.base_database import get_database
|
||||
from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS,
|
||||
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
||||
FIELD_FILE_SHA1, OFFSET, OFFSET_SECS,
|
||||
SONG_ID, SONG_NAME, TOPN)
|
||||
from dejavu.logic.fingerprint import fingerprint
|
||||
|
||||
|
||||
class Dejavu:
|
||||
|
@ -71,7 +71,7 @@ class Dejavu:
|
|||
continue
|
||||
except StopIteration:
|
||||
break
|
||||
except:
|
||||
except Exception:
|
||||
print("Failed fingerprinting")
|
||||
# Print traceback because we can't reraise it here
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
@ -119,6 +119,7 @@ class Dejavu:
|
|||
diff_counter = {}
|
||||
largest_count = 0
|
||||
|
||||
# TODO: review logic to get topn results.
|
||||
for tup in matches:
|
||||
sid, diff = tup
|
||||
if diff not in diff_counter:
|
||||
|
|
0
dejavu/base_classes/__init__.py
Normal file
0
dejavu/base_classes/__init__.py
Normal file
|
@ -1,9 +1,11 @@
|
|||
import abc
|
||||
import importlib
|
||||
from dejavu.config.config import DATABASES
|
||||
from typing import Dict
|
||||
|
||||
from dejavu.config.settings import DATABASES
|
||||
|
||||
|
||||
class Database(object, metaclass=abc.ABCMeta):
|
||||
class BaseDatabase(object, metaclass=abc.ABCMeta):
|
||||
# Name of your Database subclass, this is used in configuration
|
||||
# to refer to your class
|
||||
type = None
|
||||
|
@ -70,17 +72,16 @@ class Database(object, metaclass=abc.ABCMeta):
|
|||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_songs(self):
|
||||
def get_songs(self) -> Dict[str, str]:
|
||||
"""
|
||||
Returns all fully fingerprinted songs in the database.
|
||||
Returns all fully fingerprinted songs in the database. Result must be a Dictionary.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_song_by_id(self, sid):
|
||||
def get_song_by_id(self, sid) -> Dict[str, str]:
|
||||
"""
|
||||
Return a song by its identifier
|
||||
|
||||
Return a song by its identifier. Result must be a Dictionary.
|
||||
sid: Song identifier
|
||||
"""
|
||||
pass
|
||||
|
@ -124,7 +125,7 @@ class Database(object, metaclass=abc.ABCMeta):
|
|||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_hashes(self, sid, hashes):
|
||||
def insert_hashes(self, sid, hashes, batch=1000):
|
||||
"""
|
||||
Insert a multitude of fingerprints.
|
||||
|
||||
|
@ -133,7 +134,6 @@ class Database(object, metaclass=abc.ABCMeta):
|
|||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def return_matches(self, hashes):
|
19
dejavu/base_classes/base_recognizer.py
Normal file
19
dejavu/base_classes/base_recognizer.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
import abc
|
||||
|
||||
from dejavu.config.settings import DEFAULT_FS
|
||||
|
||||
|
||||
class BaseRecognizer(object, metaclass=abc.ABCMeta):
    """
    Abstract parent of the recognizers.

    Holds the Dejavu instance used for matching together with the sampling
    rate (``Fs``) passed to it, and offers ``_recognize`` as a shared helper
    for concrete ``recognize`` implementations.
    """

    def __init__(self, dejavu):
        # Sampling rate starts at the configured default.
        self.Fs = DEFAULT_FS
        self.dejavu = dejavu

    def _recognize(self, *data):
        """
        Collect the matches found for every item in ``data`` and align them.

        data: items handed one by one to ``self.dejavu.find_matches``
              together with ``Fs=self.Fs``.

        Returns whatever ``self.dejavu.align_matches`` returns for the
        combined list of matches.
        """
        found = [
            match
            for item in data
            for match in self.dejavu.find_matches(item, Fs=self.Fs)
        ]
        return self.dejavu.align_matches(found)

    @abc.abstractmethod
    def recognize(self):
        pass  # base class does nothing
|
194
dejavu/base_classes/common_database.py
Normal file
194
dejavu/base_classes/common_database.py
Normal file
|
@ -0,0 +1,194 @@
|
|||
import abc
|
||||
|
||||
from dejavu.base_classes.base_database import BaseDatabase
|
||||
|
||||
|
||||
class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta):
    """
    Shared implementation of the operations that are the same across the
    different database handlers, so that logic is reused instead of being
    copy-pasted into every handler.

    The methods below read the following attributes from ``self``; none of
    them is defined in this class:

    - ``cursor``: a callable returning a context manager that yields a cursor;
      it is also called with ``dictionary=True``.
    - statement strings: ``CREATE_SONGS_TABLE``, ``CREATE_FINGERPRINTS_TABLE``,
      ``DELETE_UNFINGERPRINTED``, ``DROP_FINGERPRINTS``, ``DROP_SONGS``,
      ``SELECT_UNIQUE_SONG_IDS``, ``SELECT_NUM_FINGERPRINTS``,
      ``UPDATE_SONG_FINGERPRINTED``, ``SELECT_SONGS``, ``SELECT_SONG``,
      ``INSERT_FINGERPRINT``, ``SELECT``, ``SELECT_ALL``, ``SELECT_MULTIPLE``
      and ``IN_MATCH``.
    """

    def __init__(self):
        super().__init__()

    def before_fork(self):
        """
        Called before the database instance is given to the new process.

        Does nothing by default.
        """
        pass

    def after_fork(self):
        """
        Called after the database instance has been given to the new process.

        This will be called in the new process. Does nothing by default.
        """
        pass

    def setup(self):
        """
        Called on creation or shortly afterwards.

        Executes the create-songs-table, create-fingerprints-table and
        delete-unfingerprinted statements, in that order.
        """
        with self.cursor() as cur:
            cur.execute(self.CREATE_SONGS_TABLE)
            cur.execute(self.CREATE_FINGERPRINTS_TABLE)
            cur.execute(self.DELETE_UNFINGERPRINTED)

    def empty(self):
        """
        Called when the database should be cleared of all data.

        Executes both drop statements and then runs ``setup`` again.
        """
        with self.cursor() as cur:
            cur.execute(self.DROP_FINGERPRINTS)
            cur.execute(self.DROP_SONGS)

        self.setup()

    def delete_unfingerprinted_songs(self):
        """
        Called to remove any song entries that do not have any fingerprints
        associated with them.
        """
        with self.cursor() as cur:
            cur.execute(self.DELETE_UNFINGERPRINTED)

    def get_num_songs(self):
        """
        Returns the amount of songs in the database.

        Returns the first column of the first row produced by
        ``SELECT_UNIQUE_SONG_IDS``, or 0 when the cursor reports no rows.
        """
        with self.cursor() as cur:
            cur.execute(self.SELECT_UNIQUE_SONG_IDS)
            count = cur.fetchone()[0] if cur.rowcount != 0 else 0

        return count

    def get_num_fingerprints(self):
        """
        Returns the number of fingerprints in the database.

        Returns the first column of the first row produced by
        ``SELECT_NUM_FINGERPRINTS``, or 0 when the cursor reports no rows.
        """
        with self.cursor() as cur:
            cur.execute(self.SELECT_NUM_FINGERPRINTS)
            count = cur.fetchone()[0] if cur.rowcount != 0 else 0
            # NOTE(review): kept from the original; presumably redundant if the
            # cursor context manager closes the cursor on exit - confirm.
            cur.close()

        return count

    def set_song_fingerprinted(self, sid):
        """
        Sets a specific song as having all fingerprints in the database.

        sid: Song identifier
        """
        with self.cursor() as cur:
            cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))

    def get_songs(self):
        """
        Returns all fully fingerprinted songs in the database.

        This is a generator: it yields the rows of ``SELECT_SONGS`` one at a
        time from a cursor requested with ``dictionary=True``.
        """
        with self.cursor(dictionary=True) as cur:
            cur.execute(self.SELECT_SONGS)
            for row in cur:
                yield row

    def get_song_by_id(self, sid):
        """
        Return a song by its identifier.

        sid: Song identifier

        Returns the result of ``fetchone()`` on a cursor requested with
        ``dictionary=True``.
        """
        with self.cursor(dictionary=True) as cur:
            cur.execute(self.SELECT_SONG, (sid,))
            return cur.fetchone()

    def insert(self, hash, sid, offset):
        """
        Inserts a single fingerprint into the database.

        hash: Part of a sha1 hash, in hexadecimal format
        sid: Song identifier this fingerprint is of
        offset: The offset this hash is from
        """
        with self.cursor() as cur:
            # Bind in (sid, hash, offset) order, the same order insert_hashes
            # uses for INSERT_FINGERPRINT; the previous (hash, sid, offset)
            # order put the hash into the song-id placeholder.
            cur.execute(self.INSERT_FINGERPRINT, (sid, hash, offset))

    @abc.abstractmethod
    def insert_song(self, song_name):
        """
        Inserts a song name into the database, returns the new
        identifier of the song.

        song_name: The name of the song.
        """
        # NOTE(review): concrete handlers may accept extra arguments (such as
        # a file hash) - confirm this signature against the implementations.
        pass

    def query(self, fingerprint):
        """
        Returns all matching fingerprint entries associated with
        the given hash as parameter.

        fingerprint: Part of a sha1 hash, in hexadecimal format. When it is
                     falsy (e.g. None), every fingerprint entry is selected.

        This is a generator yielding (sid, offset) tuples.
        """
        with self.cursor() as cur:
            if fingerprint:
                cur.execute(self.SELECT, (fingerprint,))
            else:  # select all if no key
                cur.execute(self.SELECT_ALL)

            for sid, offset in cur:
                yield (sid, offset)

    def get_iterable_kv_pairs(self):
        """
        Returns all fingerprints in the database.
        """
        return self.query(None)

    def insert_hashes(self, sid, hashes, batch=1000):
        """
        Insert a multitude of fingerprints.

        sid: Song identifier the fingerprints belong to
        hashes: An iterable of tuples in the format (hash, offset)
            - hash: Part of a sha1 hash, in hexadecimal format
            - offset: Offset this hash was created from/at.
        batch: Maximum number of rows passed to one ``executemany`` call.
        """
        values = [(sid, hsh, int(offset)) for hsh, offset in hashes]

        with self.cursor() as cur:
            # Size the loop from the materialised list so that ``hashes`` may
            # be any iterable, including one without len().
            for index in range(0, len(values), batch):
                cur.executemany(self.INSERT_FINGERPRINT, values[index: index + batch])

    def return_matches(self, hashes, batch=1000):
        """
        Searches the database for pairs of (hash, offset) values.

        hashes: A sequence of tuples in the format (hash, offset)
            - hash: Part of a sha1 hash, in hexadecimal format
            - offset: Offset this hash was created from/at.
        batch: Maximum number of hashes looked up per query.

        This is a generator yielding (sid, offset_difference) tuples.

        sid: Song identifier
        offset_difference: (database_offset - sampled_offset)
        """
        # Create a dictionary of hash => offset pairs for later lookups.
        # NOTE(review): a hash that occurs at several sampled offsets keeps
        # only the last offset here, because later assignments overwrite.
        mapper = {}
        for hsh, offset in hashes:
            mapper[hsh.upper()] = offset

        # Get a list of all the hashes we need
        values = list(mapper.keys())

        with self.cursor() as cur:
            for index in range(0, len(values), batch):
                chunk = values[index: index + batch]

                # Create our IN part of the query: one IN_MATCH placeholder
                # per hash in this batch.
                query = self.SELECT_MULTIPLE % ', '.join([self.IN_MATCH] * len(chunk))

                cur.execute(query, chunk)

                for hsh, sid, offset in cur:
                    # (sid, db_offset - song_sampled_offset)
                    yield (sid, offset - mapper[hsh])
|
|
@ -3,13 +3,114 @@ import queue
|
|||
import mysql.connector
|
||||
from mysql.connector.errors import DatabaseError
|
||||
|
||||
import dejavu.database_handler.mysql_queries as queries
|
||||
from dejavu.database import Database
|
||||
from dejavu.base_classes.common_database import CommonDatabase
|
||||
from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
||||
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
||||
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
|
||||
SONGS_TABLENAME)
|
||||
|
||||
|
||||
class MySQLDatabase(Database):
|
||||
class MySQLDatabase(CommonDatabase):
|
||||
type = "mysql"
|
||||
|
||||
# CREATES
|
||||
CREATE_SONGS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` (
|
||||
`{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT
|
||||
, `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL
|
||||
, `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0
|
||||
, `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL
|
||||
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
, CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`)
|
||||
, CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`)
|
||||
) ENGINE=INNODB;
|
||||
"""
|
||||
|
||||
CREATE_FINGERPRINTS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` (
|
||||
`{FIELD_HASH}` BINARY(10) NOT NULL
|
||||
, `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL
|
||||
, `{FIELD_OFFSET}` INT UNSIGNED NOT NULL
|
||||
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
, INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`)
|
||||
, CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}`
|
||||
UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`)
|
||||
, CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`)
|
||||
REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE
|
||||
) ENGINE=INNODB;
|
||||
"""
|
||||
|
||||
# INSERTS (IGNORES DUPLICATES)
|
||||
INSERT_FINGERPRINT = f"""
|
||||
INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` (
|
||||
`{FIELD_SONG_ID}`
|
||||
, `{FIELD_HASH}`
|
||||
, `{FIELD_OFFSET}`)
|
||||
VALUES (%s, UNHEX(%s), %s);
|
||||
"""
|
||||
|
||||
INSERT_SONG = f"""
|
||||
INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`)
|
||||
VALUES (%s, UNHEX(%s));
|
||||
"""
|
||||
|
||||
# SELECTS
|
||||
SELECT = f"""
|
||||
SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
||||
FROM `{FINGERPRINTS_TABLENAME}`
|
||||
WHERE `{FIELD_HASH}` = UNHEX(%s);
|
||||
"""
|
||||
|
||||
SELECT_MULTIPLE = f"""
|
||||
SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
||||
FROM `{FINGERPRINTS_TABLENAME}`
|
||||
WHERE `{FIELD_HASH}` IN (%s);
|
||||
"""
|
||||
|
||||
SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;"
|
||||
|
||||
SELECT_SONG = f"""
|
||||
SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_SONG_ID}` = %s;
|
||||
"""
|
||||
|
||||
SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;"
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = f"""
|
||||
SELECT COUNT(`{FIELD_SONG_ID}`) AS n
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
||||
"""
|
||||
|
||||
SELECT_SONGS = f"""
|
||||
SELECT
|
||||
`{FIELD_SONG_ID}`
|
||||
, `{FIELD_SONGNAME}`
|
||||
, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
||||
"""
|
||||
|
||||
# DROPS
|
||||
DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;"
|
||||
DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;"
|
||||
|
||||
# update
|
||||
UPDATE_SONG_FINGERPRINTED = f"""
|
||||
UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s;
|
||||
"""
|
||||
|
||||
# DELETE
|
||||
DELETE_UNFINGERPRINTED = f"""
|
||||
DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0;
|
||||
"""
|
||||
|
||||
# IN
|
||||
IN_MATCH = f"UNHEX(%s)"
|
||||
|
||||
def __init__(self, **options):
|
||||
super().__init__()
|
||||
self.cursor = cursor_factory(**options)
|
||||
|
@ -20,160 +121,14 @@ class MySQLDatabase(Database):
|
|||
# the previous process.
|
||||
Cursor.clear_cache()
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
Creates any non-existing tables required for dejavu to function.
|
||||
|
||||
This also removes all songs that have been added but have no
|
||||
fingerprints associated with them.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.CREATE_SONGS_TABLE)
|
||||
cur.execute(queries.CREATE_FINGERPRINTS_TABLE)
|
||||
cur.execute(queries.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def empty(self):
|
||||
"""
|
||||
Drops tables created by dejavu and then creates them again
|
||||
by calling `SQLDatabase.setup`.
|
||||
|
||||
.. warning:
|
||||
This will result in a loss of data
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.DROP_FINGERPRINTS)
|
||||
cur.execute(queries.DROP_SONGS)
|
||||
|
||||
self.setup()
|
||||
|
||||
def delete_unfingerprinted_songs(self):
|
||||
"""
|
||||
Removes all songs that have no fingerprints associated with them.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def get_num_songs(self):
|
||||
"""
|
||||
Returns number of songs the database has fingerprinted.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.SELECT_UNIQUE_SONG_IDS)
|
||||
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
||||
|
||||
return count
|
||||
|
||||
def get_num_fingerprints(self):
|
||||
"""
|
||||
Returns number of fingerprints the database has fingerprinted.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.SELECT_NUM_FINGERPRINTS)
|
||||
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
||||
cur.close()
|
||||
|
||||
return count
|
||||
|
||||
def set_song_fingerprinted(self, sid):
|
||||
"""
|
||||
Set the fingerprinted flag to TRUE (1) once a song has been completely
|
||||
fingerprinted in the database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,))
|
||||
|
||||
def get_songs(self):
|
||||
"""
|
||||
Return songs that have the fingerprinted flag set TRUE (1).
|
||||
"""
|
||||
with self.cursor(dictionary=True) as cur:
|
||||
cur.execute(queries.SELECT_SONGS)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def get_song_by_id(self, sid):
|
||||
"""
|
||||
Returns song by its ID.
|
||||
"""
|
||||
with self.cursor(dictionary=True) as cur:
|
||||
cur.execute(queries.SELECT_SONG, (sid,))
|
||||
return cur.fetchone()
|
||||
|
||||
def insert(self, hash, sid, offset):
|
||||
"""
|
||||
Insert a (sha1, song_id, offset) row into database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset))
|
||||
|
||||
def insert_song(self, song_name, file_hash):
|
||||
"""
|
||||
Inserts song in the database and returns the ID of the inserted record.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.INSERT_SONG, (song_name, file_hash))
|
||||
cur.execute(self.INSERT_SONG, (song_name, file_hash))
|
||||
return cur.lastrowid
|
||||
|
||||
def query(self, hash):
|
||||
"""
|
||||
Return all tuples associated with hash.
|
||||
|
||||
If hash is None, returns all entries in the
|
||||
database (be careful with that one!).
|
||||
"""
|
||||
if hash:
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.SELECT, (hash,))
|
||||
for sid, offset in cur:
|
||||
yield (sid, offset)
|
||||
else: # select all if no key
|
||||
with self.cursor() as cur:
|
||||
cur.execute(queries.SELECT_ALL)
|
||||
for sid, offset in cur:
|
||||
yield (sid, offset)
|
||||
|
||||
def get_iterable_kv_pairs(self):
|
||||
"""
|
||||
Returns all tuples in database.
|
||||
"""
|
||||
return self.query(None)
|
||||
|
||||
def insert_hashes(self, sid, hashes, batch=1000):
|
||||
"""
|
||||
Insert series of hash => song_id, offset
|
||||
values into the database.
|
||||
"""
|
||||
values = [(sid, hash, int(offset)) for hash, offset in hashes]
|
||||
|
||||
with self.cursor() as cur:
|
||||
for index in range(0, len(hashes), batch):
|
||||
cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch])
|
||||
|
||||
def return_matches(self, hashes, batch=1000):
|
||||
"""
|
||||
Return the (song_id, offset_diff) tuples associated with
|
||||
a list of (sha1, sample_offset) values.
|
||||
"""
|
||||
# Create a dictionary of hash => offset pairs for later lookups
|
||||
mapper = {}
|
||||
for hash, offset in hashes:
|
||||
mapper[hash.upper()] = offset
|
||||
|
||||
# Get an iterable of all the hashes we need
|
||||
values = list(mapper.keys())
|
||||
|
||||
with self.cursor() as cur:
|
||||
for index in range(0, len(values), batch):
|
||||
# Create our IN part of the query
|
||||
query = queries.SELECT_MULTIPLE
|
||||
query = query % ', '.join(['UNHEX(%s)'] * len(values[index: index + batch]))
|
||||
|
||||
cur.execute(query, values[index: index + batch])
|
||||
|
||||
for hash, sid, offset in cur:
|
||||
# (sid, db_offset - song_sampled_offset)
|
||||
yield (sid, offset - mapper[hash])
|
||||
|
||||
def __getstate__(self):
|
||||
return self._options,
|
||||
|
||||
|
|
|
@ -1,134 +0,0 @@
|
|||
from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
||||
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
||||
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
|
||||
SONGS_TABLENAME)
|
||||
|
||||
"""
|
||||
Queries:
|
||||
|
||||
1) Find duplicates (shouldn't be any, though):
|
||||
|
||||
select `hash`, `song_id`, `offset`, count(*) cnt
|
||||
from fingerprints
|
||||
group by `hash`, `song_id`, `offset`
|
||||
having cnt > 1
|
||||
order by cnt asc;
|
||||
|
||||
2) Get number of hashes by song:
|
||||
|
||||
select song_id, song_name, count(song_id) as num
|
||||
from fingerprints
|
||||
natural join songs
|
||||
group by song_id
|
||||
order by count(song_id) desc;
|
||||
|
||||
3) get hashes with highest number of collisions
|
||||
|
||||
select
|
||||
hash,
|
||||
count(distinct song_id) as n
|
||||
from fingerprints
|
||||
group by `hash`
|
||||
order by n DESC;
|
||||
|
||||
=> 26 different songs with same fingerprint (392 times):
|
||||
|
||||
select songs.song_name, fingerprints.offset
|
||||
from fingerprints natural join songs
|
||||
where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
|
||||
"""
|
||||
|
||||
# creates
|
||||
CREATE_SONGS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` (
|
||||
`{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT
|
||||
, `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL
|
||||
, `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0
|
||||
, `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL
|
||||
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
, CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`)
|
||||
, CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`)
|
||||
) ENGINE=INNODB;
|
||||
"""
|
||||
|
||||
CREATE_FINGERPRINTS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` (
|
||||
`{FIELD_HASH}` BINARY(10) NOT NULL
|
||||
, `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL
|
||||
, `{FIELD_OFFSET}` INT UNSIGNED NOT NULL
|
||||
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
, INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`)
|
||||
, CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}`
|
||||
UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`)
|
||||
, CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`)
|
||||
REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE
|
||||
) ENGINE=INNODB;
|
||||
"""
|
||||
|
||||
# inserts (ignores duplicates)
|
||||
INSERT_FINGERPRINT = f"""
|
||||
INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` (
|
||||
`{FIELD_SONG_ID}`
|
||||
, `{FIELD_HASH}`
|
||||
, `{FIELD_OFFSET}`)
|
||||
VALUES (%s, UNHEX(%s), %s);
|
||||
"""
|
||||
|
||||
INSERT_SONG = f"""
|
||||
INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`)
|
||||
VALUES (%s, UNHEX(%s));
|
||||
"""
|
||||
|
||||
# selects
|
||||
SELECT = f"""
|
||||
SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
||||
FROM `{FINGERPRINTS_TABLENAME}`
|
||||
WHERE `{FIELD_HASH}` = UNHEX(%s);
|
||||
"""
|
||||
|
||||
SELECT_MULTIPLE = f"""
|
||||
SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
||||
FROM `{FINGERPRINTS_TABLENAME}`
|
||||
WHERE `{FIELD_HASH}` IN (%s);
|
||||
"""
|
||||
|
||||
SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;"
|
||||
|
||||
SELECT_SONG = f"""
|
||||
SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_SONG_ID}` = %s;
|
||||
"""
|
||||
|
||||
SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;"
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = f"""
|
||||
SELECT COUNT(`{FIELD_SONG_ID}`) AS n
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
||||
"""
|
||||
|
||||
SELECT_SONGS = f"""
|
||||
SELECT
|
||||
`{FIELD_SONG_ID}`
|
||||
, `{FIELD_SONGNAME}`
|
||||
, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
||||
"""
|
||||
|
||||
# drops
|
||||
DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;"
|
||||
DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;"
|
||||
|
||||
# update
|
||||
UPDATE_SONG_FINGERPRINTED = f"""
|
||||
UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s;
|
||||
"""
|
||||
|
||||
# delete
|
||||
DELETE_UNFINGERPRINTED = f"""
|
||||
DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0;
|
||||
"""
|
|
@ -3,13 +3,124 @@ import queue
|
|||
import psycopg2
|
||||
from psycopg2.extras import DictCursor
|
||||
|
||||
import dejavu.database_handler.postgres_queries as queries
|
||||
from dejavu.database import Database
|
||||
from dejavu.base_classes.common_database import CommonDatabase
|
||||
from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
||||
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
||||
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
|
||||
SONGS_TABLENAME)
|
||||
|
||||
|
||||
class PostgreSQLDatabase(Database):
|
||||
class PostgreSQLDatabase(CommonDatabase):
|
||||
type = "postgres"
|
||||
|
||||
# CREATES
|
||||
CREATE_SONGS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" (
|
||||
"{FIELD_SONG_ID}" SERIAL
|
||||
, "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL
|
||||
, "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0
|
||||
, "{FIELD_FILE_SHA1}" BYTEA
|
||||
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}")
|
||||
, CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}")
|
||||
);
|
||||
"""
|
||||
|
||||
CREATE_FINGERPRINTS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" (
|
||||
"{FIELD_HASH}" BYTEA NOT NULL
|
||||
, "{FIELD_SONG_ID}" INT NOT NULL
|
||||
, "{FIELD_OFFSET}" INT NOT NULL
|
||||
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}")
|
||||
, CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}")
|
||||
REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}"
|
||||
USING hash ("{FIELD_HASH}");
|
||||
"""
|
||||
|
||||
CREATE_FINGERPRINTS_TABLE_INDEX = f"""
|
||||
CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}"
|
||||
USING hash ("{FIELD_HASH}");
|
||||
"""
|
||||
|
||||
# INSERTS (IGNORES DUPLICATES)
|
||||
INSERT_FINGERPRINT = f"""
|
||||
INSERT INTO "{FINGERPRINTS_TABLENAME}" (
|
||||
"{FIELD_SONG_ID}"
|
||||
, "{FIELD_HASH}"
|
||||
, "{FIELD_OFFSET}")
|
||||
VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING;
|
||||
"""
|
||||
|
||||
INSERT_SONG = f"""
|
||||
INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}")
|
||||
VALUES (%s, decode(%s, 'hex'))
|
||||
RETURNING "{FIELD_SONG_ID}";
|
||||
"""
|
||||
|
||||
# SELECTS
|
||||
SELECT = f"""
|
||||
SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
||||
FROM "{FINGERPRINTS_TABLENAME}"
|
||||
WHERE "{FIELD_HASH}" = decode(%s, 'hex');
|
||||
"""
|
||||
|
||||
SELECT_MULTIPLE = f"""
|
||||
SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
||||
FROM "{FINGERPRINTS_TABLENAME}"
|
||||
WHERE "{FIELD_HASH}" IN (%s);
|
||||
"""
|
||||
|
||||
SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";'
|
||||
|
||||
SELECT_SONG = f"""
|
||||
SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_SONG_ID}" = %s;
|
||||
"""
|
||||
|
||||
SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";'
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = f"""
|
||||
SELECT COUNT("{FIELD_SONG_ID}") AS n
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
||||
"""
|
||||
|
||||
SELECT_SONGS = f"""
|
||||
SELECT
|
||||
"{FIELD_SONG_ID}"
|
||||
, "{FIELD_SONGNAME}"
|
||||
, upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
||||
"""
|
||||
|
||||
# DROPS
|
||||
DROP_FINGERPRINTS = F'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";'
|
||||
DROP_SONGS = F'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";'
|
||||
|
||||
# UPDATE
|
||||
UPDATE_SONG_FINGERPRINTED = f"""
|
||||
UPDATE "{SONGS_TABLENAME}" SET
|
||||
"{FIELD_FINGERPRINTED}" = 1
|
||||
, "date_modified" = now()
|
||||
WHERE "{FIELD_SONG_ID}" = %s;
|
||||
"""
|
||||
|
||||
# DELETE
|
||||
DELETE_UNFINGERPRINTED = f"""
|
||||
DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0;
|
||||
"""
|
||||
|
||||
# IN
|
||||
IN_MATCH = f"decode(%s, 'hex')"
|
||||
|
||||
def __init__(self, **options):
|
||||
super().__init__()
|
||||
self.cursor = cursor_factory(**options)
|
||||
|
@ -20,160 +131,14 @@ class PostgreSQLDatabase(Database):
|
|||
# the previous process.
|
||||
Cursor.clear_cache()
|
||||
|
||||
def setup(self):
    """
    Create the tables dejavu needs, if they are not there yet.

    Once the tables exist, songs that were added but never flagged as
    fingerprinted are deleted.
    """
    statements = (
        queries.CREATE_SONGS_TABLE,
        queries.CREATE_FINGERPRINTS_TABLE,
        queries.DELETE_UNFINGERPRINTED,
    )
    with self.cursor() as cursor:
        # Order matters: the fingerprints table references the songs table.
        for statement in statements:
            cursor.execute(statement)
|
||||
|
||||
def empty(self):
    """
    Drop the tables created by dejavu and build them again by calling
    `SQLDatabase.setup`.

    .. warning::
        This will result in a loss of data
    """
    with self.cursor() as cursor:
        # Fingerprints are dropped first, then the songs they point to.
        for statement in (queries.DROP_FINGERPRINTS, queries.DROP_SONGS):
            cursor.execute(statement)

    self.setup()
|
||||
|
||||
def delete_unfingerprinted_songs(self):
    """
    Delete every song whose fingerprinted flag is still unset.
    """
    statement = queries.DELETE_UNFINGERPRINTED
    with self.cursor() as cursor:
        cursor.execute(statement)
|
||||
|
||||
def get_num_songs(self):
    """
    Return how many songs in the database are flagged as fingerprinted.
    """
    with self.cursor() as cursor:
        cursor.execute(queries.SELECT_UNIQUE_SONG_IDS)
        # No row at all is reported as zero songs.
        if cursor.rowcount == 0:
            return 0
        return cursor.fetchone()[0]
|
||||
|
||||
def get_num_fingerprints(self):
    """
    Return how many fingerprints are stored in the database.
    """
    with self.cursor() as cursor:
        cursor.execute(queries.SELECT_NUM_FINGERPRINTS)
        if cursor.rowcount != 0:
            total = cursor.fetchone()[0]
        else:
            total = 0
        # NOTE(review): no sibling method closes the cursor by hand, so this
        # looks redundant inside the `with` block - confirm before removing.
        cursor.close()

    return total
|
||||
|
||||
def set_song_fingerprinted(self, sid):
    """
    Mark song `sid` as fingerprinted (flag set to 1) once all of its
    fingerprints have been stored.
    """
    params = (sid,)
    with self.cursor() as cursor:
        cursor.execute(queries.UPDATE_SONG_FINGERPRINTED, params)
|
||||
|
||||
def get_songs(self):
    """
    Yield the songs whose fingerprinted flag is set (1), one row at a time.
    """
    with self.cursor(dictionary=True) as cursor:
        cursor.execute(queries.SELECT_SONGS)
        # Rows are streamed straight from the cursor, not collected first.
        yield from cursor
|
||||
|
||||
def get_song_by_id(self, sid):
    """
    Return the row of the song whose id is `sid`.
    """
    with self.cursor(dictionary=True) as cursor:
        cursor.execute(queries.SELECT_SONG, (sid,))
        song = cursor.fetchone()
        return song
|
||||
|
||||
def insert(self, hash, sid, offset):
    """
    Insert a single fingerprint row into the database.

    :param hash: fingerprint hash; the statement hex-decodes it.
    :param sid: id of the song the fingerprint belongs to.
    :param offset: offset of the fingerprint within the song.
    """
    with self.cursor() as cur:
        # INSERT_FINGERPRINT lists its columns as (song_id, hash, offset) -
        # the order insert_hashes uses - so the parameters must follow it.
        # Passing (hash, sid, offset) would send the hash as the song id.
        cur.execute(queries.INSERT_FINGERPRINT, (sid, hash, offset))
|
||||
|
||||
def insert_song(self, song_name, file_hash):
    """
    Insert a song into the database and return the id of the new record.

    :param song_name: name stored for the song.
    :param file_hash: hash of the song file, stored next to the name.
    :return: first column of the row returned by the insert statement.
    """
    with self.cursor() as cur:
        # Execute the insert exactly once: running it twice would create two
        # song rows and return the id of the second one.
        cur.execute(self.INSERT_SONG, (song_name, file_hash))
        return cur.fetchone()[0]
|
||||
|
||||
def query(self, hash):
    """
    Yield the (song_id, offset) tuples stored for `hash`.

    When `hash` is falsy (for instance None) every entry of the
    fingerprints table is yielded instead - be careful with that one!
    """
    select_all = not hash  # no key given: select everything
    with self.cursor() as cursor:
        if select_all:
            cursor.execute(queries.SELECT_ALL)
        else:
            cursor.execute(queries.SELECT, (hash,))
        for sid, offset in cursor:
            yield (sid, offset)
|
||||
|
||||
def get_iterable_kv_pairs(self):
    """
    Return an iterable over every tuple stored in the database.
    """
    # Querying without a hash selects all entries.
    everything = self.query(None)
    return everything
|
||||
|
||||
def insert_hashes(self, sid, hashes, batch=1000):
    """
    Insert a series of hash => (song_id, offset) values into the database.

    :param sid: id of the song every fingerprint belongs to.
    :param hashes: iterable of (hash, offset) pairs; any iterable works,
        including generators, because it is materialised once up front.
    :param batch: maximum number of rows sent per executemany call.
    """
    values = [(sid, hsh, int(offset)) for hsh, offset in hashes]

    with self.cursor() as cur:
        # Slice by len(values): `hashes` may have no len() and, if it was a
        # one-shot iterator, it is already exhausted at this point.
        for index in range(0, len(values), batch):
            cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch])
|
||||
|
||||
def return_matches(self, hashes, batch=1000):
    """
    Yield the (song_id, offset_diff) tuples found for a list of
    (sha1, sample_offset) values.

    The hashes are looked up in chunks of at most `batch` entries.
    """
    # hash => sample offset lookup; a repeated hash keeps its last offset.
    sample_offsets = {hsh.upper(): offset for hsh, offset in hashes}

    # Every distinct hash that has to be looked up.
    wanted = list(sample_offsets.keys())

    with self.cursor() as cursor:
        for start in range(0, len(wanted), batch):
            chunk = wanted[start: start + batch]

            # One decode() placeholder per hash for the IN (...) clause.
            placeholders = ', '.join(["decode(%s, 'hex')"] * len(chunk))
            statement = queries.SELECT_MULTIPLE % placeholders

            cursor.execute(statement, chunk)

            for hsh, sid, offset in cursor:
                # (sid, db_offset - song_sampled_offset)
                yield (sid, offset - sample_offsets[hsh.upper()])
|
||||
|
||||
def __getstate__(self):
    # The state is a one-element tuple holding the connection options.
    return (self._options,)
|
||||
|
||||
|
|
|
@ -1,104 +0,0 @@
|
|||
from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
||||
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
||||
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
|
||||
SONGS_TABLENAME)
|
||||
|
||||
# creates
|
||||
CREATE_SONGS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" (
|
||||
"{FIELD_SONG_ID}" SERIAL
|
||||
, "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL
|
||||
, "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0
|
||||
, "{FIELD_FILE_SHA1}" BYTEA
|
||||
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}")
|
||||
, CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}")
|
||||
);
|
||||
"""
|
||||
|
||||
CREATE_FINGERPRINTS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" (
|
||||
"{FIELD_HASH}" BYTEA NOT NULL
|
||||
, "{FIELD_SONG_ID}" INT NOT NULL
|
||||
, "{FIELD_OFFSET}" INT NOT NULL
|
||||
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}")
|
||||
, CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}")
|
||||
REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}");
|
||||
"""
|
||||
|
||||
# Stand-alone creation of the hash index. IF NOT EXISTS keeps it safe to run
# after CREATE_FINGERPRINTS_TABLE, which already creates the very same index.
CREATE_FINGERPRINTS_TABLE_INDEX = f"""
CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}");
"""
|
||||
|
||||
# inserts (ignores duplicates)
|
||||
INSERT_FINGERPRINT = f"""
|
||||
INSERT INTO "{FINGERPRINTS_TABLENAME}" (
|
||||
"{FIELD_SONG_ID}"
|
||||
, "{FIELD_HASH}"
|
||||
, "{FIELD_OFFSET}")
|
||||
VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING;
|
||||
"""
|
||||
|
||||
INSERT_SONG = f"""
|
||||
INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}")
|
||||
VALUES (%s, decode(%s, 'hex'))
|
||||
RETURNING "{FIELD_SONG_ID}";
|
||||
"""
|
||||
|
||||
# selects
|
||||
SELECT = f"""
|
||||
SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
||||
FROM "{FINGERPRINTS_TABLENAME}"
|
||||
WHERE "{FIELD_HASH}" = decode(%s, 'hex');
|
||||
"""
|
||||
|
||||
SELECT_MULTIPLE = f"""
|
||||
SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
||||
FROM "{FINGERPRINTS_TABLENAME}"
|
||||
WHERE "{FIELD_HASH}" IN (%s);
|
||||
"""
|
||||
|
||||
SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";'
|
||||
|
||||
SELECT_SONG = f"""
|
||||
SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_SONG_ID}" = %s;
|
||||
"""
|
||||
|
||||
SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";'
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = f"""
|
||||
SELECT COUNT("{FIELD_SONG_ID}") AS n
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
||||
"""
|
||||
|
||||
SELECT_SONGS = f"""
|
||||
SELECT
|
||||
"{FIELD_SONG_ID}"
|
||||
, "{FIELD_SONGNAME}"
|
||||
, upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
||||
"""
|
||||
|
||||
# drops
|
||||
DROP_FINGERPRINTS = f'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";'
|
||||
DROP_SONGS = f'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";'
|
||||
|
||||
# update
|
||||
UPDATE_SONG_FINGERPRINTED = f"""
|
||||
UPDATE "{SONGS_TABLENAME}" SET "{FIELD_FINGERPRINTED}" = 1, "date_modified" = now() WHERE "{FIELD_SONG_ID}" = %s;
|
||||
"""
|
||||
|
||||
# delete
|
||||
DELETE_UNFINGERPRINTED = f"""
|
||||
DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0;
|
||||
"""
|
0
dejavu/logic/__init__.py
Normal file
0
dejavu/logic/__init__.py
Normal file
|
@ -1,10 +1,12 @@
|
|||
import os
|
||||
import fnmatch
|
||||
import os
|
||||
from hashlib import sha1
|
||||
|
||||
import numpy as np
|
||||
from pydub import AudioSegment
|
||||
from pydub.utils import audioop
|
||||
from . import wavio
|
||||
from hashlib import sha1
|
||||
|
||||
from dejavu.third_party import wavio
|
||||
|
||||
|
||||
def unique_hash(filepath, blocksize=2**20):
|
|
@ -9,11 +9,11 @@ from scipy.ndimage.morphology import (binary_erosion,
|
|||
generate_binary_structure,
|
||||
iterate_structure)
|
||||
|
||||
from dejavu.config.config import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE,
|
||||
DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
|
||||
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION,
|
||||
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA,
|
||||
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
|
||||
from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE,
|
||||
DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
|
||||
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION,
|
||||
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA,
|
||||
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
|
||||
|
||||
IDX_FREQ_I = 0
|
||||
IDX_TIME_J = 1
|
0
dejavu/logic/recognizer/__init__.py
Normal file
0
dejavu/logic/recognizer/__init__.py
Normal file
24
dejavu/logic/recognizer/file_recognizer.py
Normal file
24
dejavu/logic/recognizer/file_recognizer.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
import time
|
||||
|
||||
import dejavu.logic.decoder as decoder
|
||||
from dejavu.base_classes.base_recognizer import BaseRecognizer
|
||||
|
||||
|
||||
class FileRecognizer(BaseRecognizer):
    """Recognizer that takes its audio from a file."""

    def __init__(self, dejavu):
        super().__init__(dejavu)

    def recognize_file(self, filename):
        """
        Decode `filename`, run the recognition on it and return the matches.

        Every match gets a 'match_time' entry holding the seconds spent in
        the recognition step (decoding the file is not included).
        """
        # The third value returned by the decoder is not needed here.
        frames, self.Fs, _ = decoder.read(filename, self.dejavu.limit)

        started = time.time()
        matches = self._recognize(*frames)
        elapsed = time.time() - started

        for match in matches:
            match['match_time'] = elapsed

        return matches

    def recognize(self, filename):
        """Same as `recognize_file`."""
        return self.recognize_file(filename)
|
|
@ -1,45 +1,7 @@
|
|||
import time
|
||||
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
|
||||
import dejavu.decoder as decoder
|
||||
from dejavu.config.config import DEFAULT_FS
|
||||
|
||||
|
||||
class BaseRecognizer:
    """Common base of the recognizers: holds the dejavu instance and Fs."""

    def __init__(self, dejavu):
        self.dejavu = dejavu
        self.Fs = DEFAULT_FS

    def _recognize(self, *data):
        """
        Collect the matches of every item in `data` and return them aligned.
        """
        matches = [
            match
            for item in data
            for match in self.dejavu.find_matches(item, Fs=self.Fs)
        ]
        return self.dejavu.align_matches(matches)

    def recognize(self):
        # The base class does nothing; subclasses provide the behaviour.
        return None
|
||||
|
||||
|
||||
class FileRecognizer(BaseRecognizer):
    """Recognizer working on audio read from a file."""

    def __init__(self, dejavu):
        super().__init__(dejavu)

    def recognize_file(self, filename):
        """
        Read `filename` through the decoder, recognize it and return the
        matches, each tagged with the recognition time under 'match_time'.
        """
        # Only the frames and the sample rate are used; the hash is ignored.
        frames, self.Fs, _ = decoder.read(filename, self.dejavu.limit)

        start = time.time()
        matches = self._recognize(*frames)
        duration = time.time() - start

        for match in matches:
            match['match_time'] = duration

        return matches

    def recognize(self, filename):
        """Alias of `recognize_file`."""
        return self.recognize_file(filename)
|
||||
from dejavu.base_classes.base_recognizer import BaseRecognizer
|
||||
|
||||
|
||||
class MicrophoneRecognizer(BaseRecognizer):
|
0
dejavu/tests/__init__.py
Normal file
0
dejavu/tests/__init__.py
Normal file
|
@ -1,130 +1,26 @@
|
|||
|
||||
import ast
|
||||
import fnmatch
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import traceback
|
||||
from os import listdir, makedirs, walk
|
||||
from os.path import basename, exists, isfile, join, splitext
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from pydub import AudioSegment
|
||||
|
||||
from dejavu import Dejavu
|
||||
from dejavu.decoder import path_to_songname
|
||||
from dejavu.fingerprint import *
|
||||
from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS,
|
||||
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
||||
MATCH_TIME, OFFSET, SONG_NAME)
|
||||
from dejavu.logic.decoder import path_to_songname
|
||||
|
||||
|
||||
def set_seed(seed=None):
    """
    `seed` as None means that the sampling will be random.

    Setting your own seed means that you can produce the
    same experiment over and over.
    """
    # Identity test: None is a singleton, and `!=` could be overridden by
    # the seed object's own comparison.
    if seed is not None:
        random.seed(seed)
|
||||
|
||||
|
||||
def get_files_recursive(src, fmt):
    """
    Yield the path of every file below `src` whose name ends with `fmt`.

    `src` is the source directory.
    `fmt` is the extension, ie ".mp3" or "mp3", etc.
    """
    pattern = '*' + fmt
    for folder, _subdirs, names in os.walk(src):
        yield from (os.path.join(folder, name) for name in fnmatch.filter(names, pattern))
|
||||
|
||||
|
||||
def get_length_audio(audiopath, extension):
    """
    Returns length of audio in seconds.
    Returns None if format isn't supported or in case of error.

    `audiopath` is the file to measure and `extension` its extension,
    with or without the leading dot.
    """
    try:
        audio = AudioSegment.from_file(audiopath, extension.replace(".", ""))
    except Exception:
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit still propagate instead of being reported as a bad file.
        print(f"Error in get_length_audio(): {traceback.format_exc()}")
        return None
    # assumes len(audio) is in milliseconds (pydub convention) - TODO confirm
    return int(len(audio) / 1000.0)
|
||||
|
||||
|
||||
def get_starttime(length, nseconds, padding):
    """
    Pick a random start second for a sample of `nseconds`.

    `length` is total audio length in seconds
    `nseconds` is amount of time to sample in seconds
    `padding` is off-limits seconds at beginning and ending
    """
    latest_start = length - padding - nseconds
    # Track too short to honour the padding: sample from the very beginning.
    return 0 if padding > latest_start else random.randint(padding, latest_start)
|
||||
|
||||
|
||||
def generate_test_files(src, dest, nseconds, fmts=(".mp3", ".wav"), padding=10):
    """
    Generates a test file for each file recursively in `src` directory
    of given format using `nseconds` sampled from the audio file.

    Results are written to `dest` directory.

    `padding` is the number of off-limit seconds at the beginning and
    end of a track that won't be sampled in testing. Often you want to
    avoid silence, etc.

    `fmts` is the collection of extensions to look for. Files whose
    length cannot be determined are skipped.
    """
    # create directories if necessary (nested paths included)
    for directory in (src, dest):
        os.makedirs(directory, exist_ok=True)

    # find files recursively of a given file format
    for fmt in fmts:
        for audiosource in get_files_recursive(src, fmt):

            print("audiosource:", audiosource)

            filename, extension = os.path.splitext(os.path.basename(audiosource))
            length = get_length_audio(audiosource, extension)
            if length is None:
                # get_length_audio already printed the error; without a
                # length no start time can be computed for this file.
                continue
            starttime = get_starttime(length, nseconds, padding)

            test_file_name = f"{os.path.join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}"

            subprocess.check_output([
                "ffmpeg", "-y",
                "-ss", f"{starttime}",
                '-t', f"{nseconds}",
                "-i", audiosource,
                test_file_name])
|
||||
|
||||
|
||||
def log_msg(msg, log=True, silent=False):
    """
    Send `msg` to the debug log when `log` is set and print it unless
    `silent` is set.
    """
    if log:
        logging.debug(msg)
    if silent:
        return
    print(msg)
|
||||
|
||||
|
||||
def autolabel(rects, ax):
    """Write the integer height of every bar in `rects` just above it."""
    for bar in rects:
        bar_height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2.
        label = f'{int(bar_height)}'
        ax.text(x_center, 1.05 * bar_height, label, ha='center', va='bottom')
|
||||
|
||||
|
||||
def autolabeldoubles(rects, ax):
    """Write the height of every bar, rounded to 3 decimals, just above it."""
    for bar in rects:
        bar_height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2.
        label = f'{round(float(bar_height), 3)}'
        ax.text(x_center, 1.05 * bar_height, label, ha='center', va='bottom')
|
||||
|
||||
|
||||
class DejavuTest(object):
|
||||
class DejavuTest:
|
||||
def __init__(self, folder, seconds):
|
||||
super(DejavuTest, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.test_folder = folder
|
||||
self.test_seconds = seconds
|
||||
|
@ -133,9 +29,10 @@ class DejavuTest(object):
|
|||
print("test_seconds", self.test_seconds)
|
||||
|
||||
self.test_files = [
|
||||
f for f in os.listdir(self.test_folder)
|
||||
if os.path.isfile(os.path.join(self.test_folder, f))
|
||||
and re.findall("[0-9]*sec", f)[0] in self.test_seconds]
|
||||
f for f in listdir(self.test_folder)
|
||||
if isfile(join(self.test_folder, f))
|
||||
and any([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds])
|
||||
]
|
||||
|
||||
print("test_files", self.test_files)
|
||||
|
||||
|
@ -154,7 +51,7 @@ class DejavuTest(object):
|
|||
# variable match precision (if matched in the corrected time)
|
||||
self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
||||
|
||||
# variable mahing time (query time)
|
||||
# variable matching time (query time)
|
||||
self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
||||
|
||||
# variable confidence
|
||||
|
@ -162,12 +59,12 @@ class DejavuTest(object):
|
|||
|
||||
self.begin()
|
||||
|
||||
def get_column_id (self, secs):
|
||||
def get_column_id(self, secs):
|
||||
for i, sec in enumerate(self.test_seconds):
|
||||
if secs == sec:
|
||||
return i
|
||||
|
||||
def get_line_id (self, song):
|
||||
def get_line_id(self, song):
|
||||
for i, s in enumerate(self.test_songs):
|
||||
if song == s:
|
||||
return i
|
||||
|
@ -176,7 +73,7 @@ class DejavuTest(object):
|
|||
|
||||
def create_plots(self, name, results, results_folder):
|
||||
for sec in range(0, len(self.test_seconds)):
|
||||
ind = np.arange(self.n_lines) #
|
||||
ind = np.arange(self.n_lines)
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
fig = plt.figure()
|
||||
|
@ -206,7 +103,7 @@ class DejavuTest(object):
|
|||
|
||||
plt.grid()
|
||||
|
||||
fig_name = os.path.join(results_folder, f"{name}_{self.test_seconds[sec]}.png")
|
||||
fig_name = join(results_folder, f"{name}_{self.test_seconds[sec]}.png")
|
||||
fig.savefig(fig_name)
|
||||
|
||||
def begin(self):
|
||||
|
@ -215,16 +112,18 @@ class DejavuTest(object):
|
|||
log_msg(f'file: {f}')
|
||||
|
||||
# get column
|
||||
col = self.get_column_id(re.findall("[0-9]*sec", f)[0])
|
||||
# format: XXXX_offset_length.mp3
|
||||
song = path_to_songname(f).split("_")[0]
|
||||
col = self.get_column_id([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds][0])
|
||||
|
||||
# format: XXXX_offset_length.mp3, we also take into account underscores within XXXX
|
||||
splits = path_to_songname(f).split("_")
|
||||
song = "_".join(splits[0:len(path_to_songname(f).split("_")) - 2])
|
||||
line = self.get_line_id(song)
|
||||
result = subprocess.check_output([
|
||||
"python",
|
||||
"dejavu.py",
|
||||
'-r',
|
||||
'file',
|
||||
self.test_folder + "/" + f])
|
||||
join(self.test_folder, f)])
|
||||
|
||||
if result.strip() == "None":
|
||||
log_msg('No match')
|
||||
|
@ -235,14 +134,12 @@ class DejavuTest(object):
|
|||
|
||||
else:
|
||||
result = result.strip()
|
||||
result = result.replace(" \'", ' "')
|
||||
result = result.replace("{\'", '{"')
|
||||
result = result.replace("\':", '":')
|
||||
result = result.replace("\',", '",')
|
||||
# we parse the output song back to a json
|
||||
result = json.loads(result.decode('utf-8').replace("'", '"').replace(': b"', ':"'))
|
||||
|
||||
# which song did we predict?
|
||||
result = ast.literal_eval(result)
|
||||
song_result = result["song_name"]
|
||||
# which song did we predict? We consider only the first match.
|
||||
result = result[0]
|
||||
song_result = result[SONG_NAME]
|
||||
log_msg(f'song: {song}')
|
||||
log_msg(f'song_result: {song_result}')
|
||||
|
||||
|
@ -256,21 +153,22 @@ class DejavuTest(object):
|
|||
log_msg('correct match')
|
||||
print(self.result_match)
|
||||
self.result_match[line][col] = 'yes'
|
||||
self.result_query_duration[line][col] = round(result[Dejavu.MATCH_TIME],3)
|
||||
self.result_match_confidence[line][col] = result[Dejavu.CONFIDENCE]
|
||||
self.result_query_duration[line][col] = round(result[MATCH_TIME], 3)
|
||||
self.result_match_confidence[line][col] = result[CONFIDENCE]
|
||||
|
||||
song_start_time = re.findall("_[^_]+", f)
|
||||
# using replace in f for getting rid of underscores in name
|
||||
song_start_time = re.findall("_[^_]+", f.replace(song, ""))
|
||||
song_start_time = song_start_time[0].lstrip("_ ")
|
||||
|
||||
result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE *
|
||||
result_start_time = round((result[OFFSET] * DEFAULT_WINDOW_SIZE *
|
||||
DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0)
|
||||
|
||||
self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time)
|
||||
if abs(self.result_matching_times[line][col]) == 1:
|
||||
self.result_matching_times[line][col] = 0
|
||||
|
||||
log_msg(f'query duration: {round(result[Dejavu.MATCH_TIME], 3)}')
|
||||
log_msg(f'confidence: {result[Dejavu.CONFIDENCE]}')
|
||||
log_msg(f'query duration: {round(result[MATCH_TIME], 3)}')
|
||||
log_msg(f'confidence: {result[CONFIDENCE]}')
|
||||
log_msg(f'song start_time: {song_start_time}')
|
||||
log_msg(f'result start time: {result_start_time}')
|
||||
|
||||
|
@ -279,3 +177,110 @@ class DejavuTest(object):
|
|||
else:
|
||||
log_msg('inaccurate match')
|
||||
log_msg('--------------------------------------------------\n')
|
||||
|
||||
|
||||
def set_seed(seed=None):
    """
    Seed the random module so that an experiment can be reproduced.

    With `seed` left as None nothing is seeded and the sampling stays
    random.
    """
    if seed is None:
        return
    random.seed(seed)
|
||||
|
||||
|
||||
def get_files_recursive(src, fmt):
    """
    Return the paths of all files below `src` whose name ends with `fmt`.

    `src` is the source directory.
    `fmt` is the extension, ie ".mp3" or "mp3", etc.
    """
    pattern = '*' + fmt
    return [
        join(folder, name)
        for folder, _subdirs, names in walk(src)
        for name in fnmatch.filter(names, pattern)
    ]
|
||||
|
||||
|
||||
def get_length_audio(audiopath, extension):
    """
    Return the length of the audio file in whole seconds.

    None is returned when the file cannot be loaded (unsupported format
    or any other error); the traceback is printed in that case.
    """
    try:
        segment = AudioSegment.from_file(audiopath, extension.replace(".", ""))
    except Exception:
        print(f"Error in get_length_audio(): {traceback.format_exc()}")
        return None
    else:
        # assumes len(segment) is in milliseconds (pydub) - TODO confirm
        milliseconds = len(segment)
        return int(milliseconds / 1000.0)
|
||||
|
||||
|
||||
def get_starttime(length, nseconds, padding):
    """
    Choose at random the second at which a sample of `nseconds` starts.

    `length` is total audio length in seconds
    `nseconds` is amount of time to sample in seconds
    `padding` is off-limits seconds at beginning and ending
    """
    upper_bound = length - padding - nseconds
    if upper_bound >= padding:
        return random.randint(padding, upper_bound)
    # Not enough room for the padding on both sides: start at zero.
    return 0
|
||||
|
||||
|
||||
def generate_test_files(src, dest, nseconds, fmts=(".mp3", ".wav"), padding=10):
    """
    Generates a test file for each file recursively in `src` directory
    of given format using `nseconds` sampled from the audio file.

    Results are written to `dest` directory.

    `padding` is the number of off-limit seconds at the beginning and
    end of a track that won't be sampled in testing. Often you want to
    avoid silence, etc.

    `fmts` is the collection of extensions to look for. Files whose
    length cannot be determined are skipped.
    """
    # create the destination if necessary; exist_ok avoids failing when the
    # directory appears between a separate existence check and the creation
    makedirs(dest, exist_ok=True)

    # find files recursively of a given file format
    for fmt in fmts:
        for audiosource in get_files_recursive(src, fmt):

            print("audiosource:", audiosource)

            filename, extension = splitext(basename(audiosource))
            length = get_length_audio(audiosource, extension)
            if length is None:
                # get_length_audio already printed the error; without a
                # length no start time can be computed for this file.
                continue
            starttime = get_starttime(length, nseconds, padding)

            test_file_name = f"{join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}"

            subprocess.check_output([
                "ffmpeg", "-y",
                "-ss", f"{starttime}",
                '-t', f"{nseconds}",
                "-i", audiosource,
                test_file_name])
|
||||
|
||||
|
||||
def log_msg(msg, log=True, silent=False):
    """
    Log `msg` at debug level when `log` is true, and print it as well
    unless `silent` is true.
    """
    if log:
        logging.debug(msg)
    if silent:
        return
    print(msg)
|
||||
|
||||
|
||||
def autolabel(rects, ax):
    """Label each bar of `rects` with its height truncated to an integer."""
    for bar in rects:
        bar_height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2.
        caption = f'{int(bar_height)}'
        ax.text(x_center, 1.05 * bar_height, caption, ha='center', va='bottom')
|
||||
|
||||
|
||||
def autolabeldoubles(rects, ax):
    """Label each bar of `rects` with its height rounded to 3 decimals."""
    for bar in rects:
        bar_height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2.
        caption = f'{round(float(bar_height), 3)}'
        ax.text(x_center, 1.05 * bar_height, caption, ha='center', va='bottom')
|
0
dejavu/third_party/__init__.py
vendored
Normal file
0
dejavu/third_party/__init__.py
vendored
Normal file
357
dejavu/third_party/wavio.py
vendored
Normal file
357
dejavu/third_party/wavio.py
vendored
Normal file
|
@ -0,0 +1,357 @@
|
|||
# wavio.py
|
||||
# Author: Warren Weckesser
|
||||
# License: BSD 2-Clause (http://opensource.org/licenses/BSD-2-Clause)
|
||||
# Synopsis: A Python module for reading and writing 24 bit WAV files.
|
||||
# Github: github.com/WarrenWeckesser/wavio
|
||||
|
||||
"""
|
||||
The wavio module defines the functions:
|
||||
read(file)
|
||||
Read a WAV file and return a `wavio.Wav` object, with attributes
|
||||
`data`, `rate` and `sampwidth`.
|
||||
write(filename, data, rate, scale=None, sampwidth=None)
|
||||
Write a numpy array to a WAV file.
|
||||
-----
|
||||
Author: Warren Weckesser
|
||||
License: BSD 2-Clause:
|
||||
Copyright (c) 2015, Warren Weckesser
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
1. Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
|
||||
|
||||
import wave as _wave
|
||||
|
||||
import numpy as _np
|
||||
|
||||
__version__ = "0.0.5.dev1"
|
||||
|
||||
|
||||
def _wav2array(nchannels, sampwidth, data):
    """
    Turn the raw bytes of a wav file into a (num_samples, nchannels) array.

    data must be the string containing the bytes from the wav file.
    """
    num_samples, leftover = divmod(len(data), sampwidth * nchannels)
    if leftover > 0:
        raise ValueError('The length of data is not a multiple of '
                         'sampwidth * num_channels.')
    if sampwidth > 4:
        raise ValueError("sampwidth must not be greater than 4.")

    if sampwidth != 3:
        # 8 bit samples are stored as unsigned ints; others as signed ints.
        kind = 'u' if sampwidth == 1 else 'i'
        samples = _np.frombuffer(data, dtype=f'<{kind}{sampwidth}')
        return samples.reshape(-1, nchannels)

    # 24 bit samples: copy the three bytes into four-byte slots, then fill
    # the fourth byte with the sign so the slots read as little-endian int32.
    widened = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8)
    packed = _np.frombuffer(data, dtype=_np.uint8)
    widened[:, :, :sampwidth] = packed.reshape(-1, nchannels, sampwidth)
    widened[:, :, sampwidth:] = (widened[:, :, sampwidth - 1:sampwidth] >> 7) * 255
    return widened.view('<i4').reshape(widened.shape[:-1])
|
||||
|
||||
|
||||
def _array2wav(a, sampwidth):
    """
    Convert the input array `a` to a string of WAV data.
    a.dtype must be one of uint8, int16 or int32. Allowed sampwidth
    values are:
        dtype    sampwidth
        uint8        1
        int16        2
        int32      3 or 4
    When sampwidth is 3, the *low* bytes of `a` are assumed to contain
    the values to include in the string.

    Returns the WAV data as a bytes object.
    """
    if sampwidth == 3:
        # `a` must have dtype int32
        if a.ndim == 1:
            # Convert to a 2D array with a single column.
            a = a.reshape(-1, 1)
        # By shifting first 0 bits, then 8, then 16, the resulting output
        # is 24 bit little-endian.
        a8 = (a.reshape(a.shape + (1,)) >> _np.array([0, 8, 16])) & 255
        # tobytes() replaces tostring(), which NumPy deprecated and removed
        # in 2.0; both return the same bytes.
        wavdata = a8.astype(_np.uint8).tobytes()
    else:
        # Make sure the array is little-endian, and then convert using
        # tobytes()
        a = a.astype('<' + a.dtype.str[1:], copy=False)
        wavdata = a.tobytes()
    return wavdata
|
||||
|
||||
|
||||
class Wav:
    """
    Result object of `wavio.read`, with these attributes:
    data : numpy array
        The array of data read from the WAV file.
    rate : float
        The sample rate of the WAV file.
    sampwidth : int
        The sample width (i.e. number of bytes per sample) of the WAV file.
        For example, `sampwidth == 3` is a 24 bit WAV file.
    """

    def __init__(self, data, rate, sampwidth):
        self.data, self.rate, self.sampwidth = data, rate, sampwidth

    def __repr__(self):
        samples = self.data
        head = f"Wav(data.shape={samples.shape}, data.dtype={samples.dtype}, "
        tail = f"rate={self.rate}, sampwidth={self.sampwidth})"
        return head + tail
|
||||
|
||||
|
||||
def read(file):
    """
    Read a WAV file.

    Parameters
    ----------
    file : string or file object
        Either the name of a file or an open file pointer.

    Returns
    -------
    wav : wavio.Wav() instance
        The return value is an instance of the class `wavio.Wav`,
        with the following attributes:

            data : numpy array
                The array containing the data.  The shape of the array
                is (num_samples, num_channels).  num_channels is the
                number of audio channels (1 for mono, 2 for stereo).
            rate : int
                The sampling frequency (i.e. frame rate), as reported by
                the `wave` module.
            sampwidth : int
                The sample width, in bytes.  E.g. for a 24 bit WAV file,
                sampwidth is 3.

    Notes
    -----
    This function uses the `wave` module of the Python standard library
    to read the WAV file, so it has the same limitations as that library.
    In particular, the function does not read compressed WAV files, and
    it does not read files with floating point data.

    The array returned by `wavio.read` is always two-dimensional.  If the
    WAV data is mono, the array will have shape (num_samples, 1).

    `wavio.read()` does not scale or normalize the data.  The data in the
    array `wav.data` is the data that was in the file.  When the file
    contains 24 bit samples, the resulting numpy array is 32 bit integers,
    with values that have been sign-extended.
    """
    # The context manager closes the reader even if a header getter or
    # readframes() raises, so a malformed file cannot leak the handle.
    with _wave.open(file) as wav:
        rate = wav.getframerate()
        nchannels = wav.getnchannels()
        sampwidth = wav.getsampwidth()
        data = wav.readframes(wav.getnframes())

    array = _wav2array(nchannels, sampwidth, data)
    return Wav(data=array, rate=rate, sampwidth=sampwidth)
|
||||
|
||||
|
||||
# Sample width (bytes per sample) -> numpy dtype used to hold the samples.
# 24 bit samples (width 3) have no native dtype and are carried in int32.
_sampwidth_dtypes = dict(zip((1, 2, 3, 4),
                             (_np.uint8, _np.int16, _np.int32, _np.int32)))

# Sample width -> (lowest value, one past the highest value).
# Width 1 is unsigned; widths 2-4 are signed and symmetric about zero.
_sampwidth_ranges = {1: (0, 2**8)}
_sampwidth_ranges.update(
    (width, (-2**(8 * width - 1), 2**(8 * width - 1))) for width in (2, 3, 4)
)
|
||||
|
||||
|
||||
def _scale_to_sampwidth(data, sampwidth, vmin, vmax):
    """
    Clip `data` to [vmin, vmax] and map that interval linearly onto the
    integer range associated with `sampwidth`.

    Returns an array of dtype `_sampwidth_dtypes[sampwidth]`.  When
    `vmin == vmax` the result is all zeros.
    """
    clipped = data.clip(vmin, vmax)
    target_dtype = _sampwidth_dtypes[sampwidth]

    # Degenerate input interval: nothing to scale, emit zeros.
    if vmax == vmin:
        return _np.zeros(clipped.shape, dtype=target_dtype)

    out_lo, out_hi = _sampwidth_ranges[sampwidth]

    # Input interval already equals the output range: only a cast is needed.
    if out_lo == vmin and out_hi == vmax:
        return clipped.astype(target_dtype)

    src_lo = float(vmin)
    src_hi = float(vmax)
    scaled = (float(out_hi - out_lo) * (clipped - src_lo) /
              (src_hi - src_lo)).astype(_np.int64) + out_lo
    # The top of the input interval lands exactly on out_hi, which is one
    # past the largest representable value; pull it back into range.
    scaled[scaled == out_hi] = out_hi - 1
    return scaled.astype(target_dtype)
|
||||
|
||||
|
||||
def write(file, data, rate, scale=None, sampwidth=None):
    """
    Write the numpy array `data` to a WAV file.

    The Python standard library "wave" is used to write the data
    to the file, so this function has the same limitations as that
    module.  In particular, the Python library does not support
    floating point data.  When given a floating point array, this
    function converts the values to integers.  See below for the
    conversion rules.

    Parameters
    ----------
    file : string, or file object open for writing in binary mode
        Either the name of a file or an open file pointer.
    data : numpy array (or array-like), 1- or 2-dimensional, integer or
           floating point
        If it is 2-d, the rows are the frames (i.e. samples) and the
        columns are the channels.  Inputs that are not numpy arrays are
        converted with `numpy.asarray`.
    rate : float
        The sampling frequency (i.e. frame rate) of the data.
    sampwidth : int, optional
        The sample width, in bytes, of the output file.
        If `sampwidth` is not given, it is inferred (if possible) from
        the data type of `data`, as follows::

            data.dtype     sampwidth
            ----------     ---------
            uint8, int8        1
            uint16, int16      2
            uint32, int32      4

        For any other data types, or to write a 24 bit file, `sampwidth`
        must be given.
    scale : tuple or str, optional
        By default, the data written to the file is scaled up or down to
        occupy the full range of the output data type.  So, for example,
        the unsigned 8 bit data [0, 1, 2, 15] would be written to the file
        as [0, 17, 34, 255].  More generally, the default behavior is
        (roughly)::

            vmin = data.min()
            vmax = data.max()
            outmin = <minimum integer of the output dtype>
            outmax = <maximum integer of the output dtype>
            outdata = (outmax - outmin)*(data - vmin)/(vmax - vmin) + outmin

        The `scale` argument allows the scaling of the output data to be
        changed.  `scale` can be a tuple of the form `(vmin, vmax)`, in which
        case the given values override the use of `data.min()` and
        `data.max()` for `vmin` and `vmax` shown above.  (If either value
        is `None`, the value shown above is used.)  Data outside the
        range (vmin, vmax) is clipped.  If `vmin == vmax`, the output is
        all zeros.

        If `scale` is the string "none", then `vmin` and `vmax` are set to
        `outmin` and `outmax`, respectively. This means the data is written
        to the file with no scaling.  (Note: `scale="none"` is not the same
        as `scale=None`.  The latter means "use the default behavior",
        which is to scale by the data minimum and maximum.)

        If `scale` is the string "dtype-limits", then `vmin` and `vmax`
        are set to the minimum and maximum integers of `data.dtype`.
        The string "dtype-limits" is not allowed when the `data` is a
        floating point array.

        If using `scale` results in values that exceed the limits of the
        output sample width, the data is clipped.  For example, the
        following code::

            >> x = np.array([-100, 0, 100, 200, 300, 325])
            >> wavio.write('foo.wav', x, 8000, scale='none', sampwidth=1)

        will write the values [0, 0, 100, 200, 255, 255] to the file.

    Raises
    ------
    ValueError
        If `sampwidth` is omitted and cannot be inferred from `data.dtype`,
        if `sampwidth` is not 1, 2, 3 or 4, or if `scale` is "dtype-limits"
        and `data` is not an integer array.

    Example
    -------
    Create a 3 second 440 Hz sine wave, and save it in a 24-bit WAV file.

    >> import numpy as np
    >> import wavio

    >> rate = 22050  # samples per second
    >> T = 3         # sample duration (seconds)
    >> f = 440.0     # sound frequency (Hz)
    >> t = np.linspace(0, T, T*rate, endpoint=False)
    >> x = np.sin(2*np.pi * f * t)
    >> wavio.write("sine24.wav", x, rate, sampwidth=3)

    Create a file that contains the 16 bit integer values -10000 and 10000
    repeated 100 times.  Don't automatically scale the values.  Use a sample
    rate 8000.

    >> x = np.empty(200, dtype=np.int16)
    >> x[::2] = -10000
    >> x[1::2] = 10000
    >> wavio.write("foo.wav", x, 8000, scale='none')

    Check that the file contains what we expect.

    >> w = wavio.read("foo.wav")
    >> np.all(w.data[:, 0] == x)
    True

    In the following, the values -10000 and 10000 (from within the 16 bit
    range [-2**15, 2**15-1]) are mapped to the corresponding values 88 and
    167 (in the range [0, 2**8-1]).

    >> wavio.write("foo.wav", x, 8000, sampwidth=1, scale='dtype-limits')
    >> w = wavio.read("foo.wav")
    >> w.data[:4, 0]
    array([ 88, 167,  88, 167], dtype=uint8)
    """
    # Accept plain sequences as well; ndarrays (and subclasses) pass through
    # untouched so existing callers see no change.
    if not isinstance(data, _np.ndarray):
        data = _np.asarray(data)

    if sampwidth is None:
        if not _np.issubdtype(data.dtype, _np.integer) or data.itemsize > 4:
            raise ValueError('when data.dtype is not an 8-, 16-, or 32-bit integer type, sampwidth must be specified.')
        sampwidth = data.itemsize
    else:
        if sampwidth not in [1, 2, 3, 4]:
            raise ValueError('sampwidth must be 1, 2, 3 or 4.')

    outdtype = _sampwidth_dtypes[sampwidth]
    outmin, outmax = _sampwidth_ranges[sampwidth]

    if scale == "none":
        # No scaling: just clip to the representable range and cast.
        data = data.clip(outmin, outmax-1).astype(outdtype)
    elif scale == "dtype-limits":
        if not _np.issubdtype(data.dtype, _np.integer):
            raise ValueError("scale cannot be 'dtype-limits' with non-integer data.")
        # Easy transforms that just changed the signedness of the data.
        if sampwidth == 1 and data.dtype == _np.int8:
            data = (data.astype(_np.int16) + 128).astype(_np.uint8)
        elif sampwidth == 2 and data.dtype == _np.uint16:
            data = (data.astype(_np.int32) - 32768).astype(_np.int16)
        elif sampwidth == 4 and data.dtype == _np.uint32:
            data = (data.astype(_np.int64) - 2**31).astype(_np.int32)
        elif data.itemsize != sampwidth:
            # Integer input, but rescaling is needed to adjust the
            # input range to the output sample width.
            ii = _np.iinfo(data.dtype)
            vmin = ii.min
            vmax = ii.max
            data = _scale_to_sampwidth(data, sampwidth, vmin, vmax)
    else:
        if scale is None:
            vmin = data.min()
            vmax = data.max()
        else:
            # scale must be a tuple of the form (vmin, vmax)
            vmin, vmax = scale
            if vmin is None:
                vmin = data.min()
            if vmax is None:
                vmax = data.max()

        data = _scale_to_sampwidth(data, sampwidth, vmin, vmax)

    # At this point, `data` has been converted to have one of the following:
    #    sampwidth   dtype
    #    ---------   -----
    #        1       uint8
    #        2       int16
    #        3       int32
    #        4       int32
    # The values in `data` are in the form in which they will be saved;
    # no more scaling will take place.

    if data.ndim == 1:
        data = data.reshape(-1, 1)

    wavdata = _array2wav(data, sampwidth)

    # The context manager finalizes the header and closes the writer even
    # if one of the setters or writeframes() raises.
    with _wave.open(file, 'wb') as w:
        w.setnchannels(data.shape[1])
        w.setsampwidth(sampwidth)
        w.setframerate(rate)
        w.writeframes(wavdata)
|
122
dejavu/wavio.py
122
dejavu/wavio.py
|
@ -1,122 +0,0 @@
|
|||
# wavio.py
|
||||
# Author: Warren Weckesser
|
||||
# License: BSD 3-Clause (http://opensource.org/licenses/BSD-3-Clause)
|
||||
# Synopsis: A Python module for reading and writing 24 bit WAV files.
|
||||
# Github: github.com/WarrenWeckesser/wavio
|
||||
|
||||
import wave as _wave
|
||||
|
||||
import numpy as _np
|
||||
|
||||
|
||||
def _wav2array(nchannels, sampwidth, data):
    """
    Convert raw WAV frame bytes into a 2-d numpy array.

    Parameters
    ----------
    nchannels : int
        Number of interleaved audio channels in `data`.
    sampwidth : int
        Number of bytes per sample (at most 4).
    data : bytes
        The bytes read from the wav file.

    Returns
    -------
    result : numpy array
        Array of shape (num_samples, nchannels).  The dtype is uint8 for
        sampwidth 1, little-endian signed integers of the same width for
        sampwidth 2 and 4, and sign-extended 32 bit integers for sampwidth 3.

    Raises
    ------
    ValueError
        If len(data) is not a multiple of sampwidth * nchannels, or if
        sampwidth is greater than 4.
    """
    num_samples, remainder = divmod(len(data), sampwidth * nchannels)
    if remainder > 0:
        raise ValueError('The length of data is not a multiple of '
                         'sampwidth * num_channels.')
    if sampwidth > 4:
        raise ValueError("sampwidth must not be greater than 4.")

    if sampwidth == 3:
        # Widen each 3-byte sample to 4 bytes: copy the three data bytes,
        # then fill the fourth with 0x00 or 0xFF according to the sign bit.
        a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8)
        # frombuffer replaces the deprecated binary mode of fromstring.
        raw_bytes = _np.frombuffer(data, dtype=_np.uint8)
        a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
        a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
        result = a.view('<i4').reshape(a.shape[:-1])
    else:
        # 8 bit samples are stored as unsigned ints; others as signed ints.
        dt_char = 'u' if sampwidth == 1 else 'i'
        # frombuffer yields a read-only view of `data`; copy so the result
        # stays writable, as it was with fromstring.
        a = _np.frombuffer(data, dtype='<%s%d' % (dt_char, sampwidth)).copy()
        result = a.reshape(-1, nchannels)
    return result
|
||||
|
||||
|
||||
def readwav(file):
    """
    Read a WAV file.

    Parameters
    ----------
    file : string or file object
        Either the name of a file or an open file pointer.

    Return Values
    -------------
    rate : int
        The sampling frequency (i.e. frame rate), as reported by the
        `wave` module.
    sampwidth : int
        The sample width, in bytes.  E.g. for a 24 bit WAV file,
        sampwidth is 3.
    data : numpy array
        The array containing the data.  The shape of the array is
        (num_samples, num_channels).  num_channels is the number of
        audio channels (1 for mono, 2 for stereo).

    Notes
    -----
    This function uses the `wave` module of the Python standard library
    to read the WAV file, so it has the same limitations as that library.
    In particular, the function does not read compressed WAV files.

    """
    # The context manager closes the reader even if a header getter or
    # readframes() raises, so a malformed file cannot leak the handle.
    with _wave.open(file) as wav:
        rate = wav.getframerate()
        nchannels = wav.getnchannels()
        sampwidth = wav.getsampwidth()
        data = wav.readframes(wav.getnframes())

    array = _wav2array(nchannels, sampwidth, data)
    return rate, sampwidth, array
|
||||
|
||||
|
||||
def writewav24(filename, rate, data):
    """
    Create a 24 bit wav file.

    Parameters
    ----------
    filename : string
        Name of the file to create.
    rate : float
        The sampling frequency (i.e. frame rate) of the data.
    data : array-like collection of integer or floating point values
        data must be "array-like", either 1- or 2-dimensional.  If it
        is 2-d, the rows are the frames (i.e. samples) and the columns
        are the channels.  The input is never modified.

    Notes
    -----
    The data is assumed to be signed, and the values are assumed to be
    within the range of a 24 bit integer.  Floating point values are
    converted to integers.  The data is not rescaled or normalized before
    writing it to the file.

    Example
    -------
    Create a 3 second 440 Hz sine wave.

    >>> rate = 22050  # samples per second
    >>> T = 3         # sample duration (seconds)
    >>> f = 440.0     # sound frequency (Hz)
    >>> t = np.linspace(0, T, T*rate, endpoint=False)
    >>> x = (2**23 - 1) * np.sin(2 * np.pi * f * t)
    >>> writewav24("sine24.wav", rate, x)

    """
    a32 = _np.asarray(data, dtype=_np.int32)
    if a32.ndim == 1:
        # Convert to a 2D array with a single column.  reshape() returns a
        # new view; assigning to `.shape` would also reshape the caller's
        # array whenever asarray() hands back the input object itself.
        a32 = a32.reshape(-1, 1)
    # By shifting first 0 bits, then 8, then 16, the resulting output
    # is 24 bit little-endian.
    a8 = (a32.reshape(a32.shape + (1,)) >> _np.array([0, 8, 16])) & 255
    # tobytes() replaces ndarray.tostring(), removed in NumPy 2.0.
    wavdata = a8.astype(_np.uint8).tobytes()

    # The context manager finalizes the header and closes the writer even
    # if one of the setters or writeframes() raises.
    with _wave.open(filename, 'wb') as w:
        w.setnchannels(a32.shape[1])
        w.setsampwidth(3)
        w.setframerate(rate)
        w.writeframes(wavdata)
|
|
@ -1,11 +1,8 @@
|
|||
import json
|
||||
import warnings
|
||||
|
||||
from dejavu import Dejavu
|
||||
from dejavu.recognize import FileRecognizer, MicrophoneRecognizer
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
|
||||
from dejavu.logic.recognizer.file_recognizer import FileRecognizer
|
||||
|
||||
# load config from a JSON file (or anything outputting a python dictionary)
|
||||
with open("dejavu.cnf.SAMPLE") as f:
|
||||
|
@ -17,10 +14,10 @@ if __name__ == '__main__':
|
|||
djv = Dejavu(config)
|
||||
|
||||
# Fingerprint all the mp3's in the directory we give it
|
||||
djv.fingerprint_directory("mp3", [".mp3"])
|
||||
djv.fingerprint_directory("test", [".wav"])
|
||||
|
||||
# Recognize audio from a file
|
||||
song = djv.recognize(FileRecognizer, "mp3/Sean-Fournier--Falling-For-You.mp3")
|
||||
song = djv.recognize(FileRecognizer, "mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
|
||||
print(f"From file we recognized: {song}\n")
|
||||
|
||||
# Or recognize audio from your microphone for `secs` seconds
|
||||
|
@ -29,7 +26,7 @@ if __name__ == '__main__':
|
|||
if song is None:
|
||||
print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)")
|
||||
else:
|
||||
print(f"From mic with %d seconds we recognized: {(secs, song)}\n")
|
||||
print(f"From mic with {secs} seconds we recognized: {song}\n")
|
||||
|
||||
# Or use a recognizer without the shortcut, in anyway you would like
|
||||
recognizer = FileRecognizer(djv)
|
BIN
mp3/azan_test.wav
Normal file
BIN
mp3/azan_test.wav
Normal file
Binary file not shown.
|
@ -5,3 +5,4 @@ scipy==1.3.1
|
|||
matplotlib==3.1.1
|
||||
mysql-connector-python==8.0.17
|
||||
psycopg2==2.8.3
|
||||
|
||||
|
|
280
run_tests.py
280
run_tests.py
|
@ -1,184 +1,166 @@
|
|||
from dejavu.testing import *
|
||||
from dejavu import Dejavu
|
||||
from optparse import OptionParser
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
import shutil
|
||||
from os import makedirs
|
||||
from os.path import exists, join
|
||||
from shutil import rmtree
|
||||
|
||||
usage = "usage: %prog [options] TESTING_AUDIOFOLDER"
|
||||
parser = OptionParser(usage=usage, version="%prog 1.1")
|
||||
parser.add_option("--secs",
|
||||
action="store",
|
||||
dest="secs",
|
||||
default=5,
|
||||
type=int,
|
||||
help='Number of seconds starting from zero to test')
|
||||
parser.add_option("--results",
|
||||
action="store",
|
||||
dest="results_folder",
|
||||
default="./dejavu_test_results",
|
||||
help='Sets the path where the results are saved')
|
||||
parser.add_option("--temp",
|
||||
action="store",
|
||||
dest="temp_folder",
|
||||
default="./dejavu_temp_testing_files",
|
||||
help='Sets the path where the temp files are saved')
|
||||
parser.add_option("--log",
|
||||
action="store_true",
|
||||
dest="log",
|
||||
default=True,
|
||||
help='Enables logging')
|
||||
parser.add_option("--silent",
|
||||
action="store_false",
|
||||
dest="silent",
|
||||
default=False,
|
||||
help='Disables printing')
|
||||
parser.add_option("--log-file",
|
||||
dest="log_file",
|
||||
default="results-compare.log",
|
||||
help='Set the path and filename of the log file')
|
||||
parser.add_option("--padding",
|
||||
action="store",
|
||||
dest="padding",
|
||||
default=10,
|
||||
type=int,
|
||||
help='Number of seconds to pad choice of place to test from')
|
||||
parser.add_option("--seed",
|
||||
action="store",
|
||||
dest="seed",
|
||||
default=None,
|
||||
type=int,
|
||||
help='Random seed')
|
||||
options, args = parser.parse_args()
|
||||
test_folder = args[0]
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# set random seed if set by user
|
||||
set_seed(options.seed)
|
||||
from dejavu.tests.dejavu_test import (DejavuTest, autolabeldoubles,
|
||||
generate_test_files, log_msg, set_seed)
|
||||
|
||||
# ensure results folder exists
|
||||
try:
|
||||
os.stat(options.results_folder)
|
||||
except:
|
||||
os.mkdir(options.results_folder)
|
||||
|
||||
# set logging
|
||||
if options.log:
|
||||
logging.basicConfig(filename=options.log_file, level=logging.DEBUG)
|
||||
def main(seconds: int, results_folder: str, temp_folder: str, log: bool, silent: bool,
|
||||
log_file: str, padding: int, seed: int, src: str):
|
||||
|
||||
# set test seconds
|
||||
test_seconds = ['%dsec' % i for i in range(1, options.secs + 1, 1)]
|
||||
# set random seed if set by user
|
||||
set_seed(seed)
|
||||
|
||||
# generate testing files
|
||||
for i in range(1, options.secs + 1, 1):
|
||||
generate_test_files(test_folder, options.temp_folder,
|
||||
i, padding=options.padding)
|
||||
# ensure results folder exists
|
||||
if not exists(results_folder):
|
||||
makedirs(results_folder)
|
||||
|
||||
# scan files
|
||||
log_msg("Running Dejavu fingerprinter on files in %s..." % test_folder,
|
||||
log=options.log, silent=options.silent)
|
||||
# set logging
|
||||
if log:
|
||||
logging.basicConfig(filename=log_file, level=logging.DEBUG)
|
||||
|
||||
tm = time.time()
|
||||
djv = DejavuTest(options.temp_folder, test_seconds)
|
||||
log_msg("finished obtaining results from dejavu in %s" % (time.time() - tm),
|
||||
log=options.log, silent=options.silent)
|
||||
# set test seconds
|
||||
test_seconds = [f'{i}sec' for i in range(1, seconds + 1, 1)]
|
||||
|
||||
tests = 1 # djv
|
||||
n_secs = len(test_seconds)
|
||||
# generate testing files
|
||||
for i in range(1, seconds + 1, 1):
|
||||
generate_test_files(src, temp_folder, i, padding=padding)
|
||||
|
||||
# set result variables -> 4d variables
|
||||
all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)]
|
||||
all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)]
|
||||
all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
||||
all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
||||
# scan files
|
||||
log_msg(f"Running Dejavu fingerprinter on files in {src}...", log=log, silent=silent)
|
||||
|
||||
# group results by seconds
|
||||
for line in range(0, djv.n_lines):
|
||||
for col in range(0, djv.n_columns):
|
||||
# for dejavu
|
||||
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
|
||||
all_match_confidence[col][line][0] = djv.result_match_confidence[line][col]
|
||||
tm = time.time()
|
||||
djv = DejavuTest(temp_folder, test_seconds)
|
||||
log_msg(f"finished obtaining results from dejavu in {(time.time() - tm)}", log=log, silent=silent)
|
||||
|
||||
djv_match_result = djv.result_match[line][col]
|
||||
tests = 1 # djv
|
||||
n_secs = len(test_seconds)
|
||||
|
||||
if djv_match_result == 'yes':
|
||||
all_match_counter[col][0][0] += 1
|
||||
elif djv_match_result == 'no':
|
||||
all_match_counter[col][1][0] += 1
|
||||
else:
|
||||
all_match_counter[col][2][0] += 1
|
||||
# set result variables -> 4d variables
|
||||
all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)]
|
||||
all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)]
|
||||
all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
||||
all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
||||
|
||||
djv_match_acc = djv.result_matching_times[line][col]
|
||||
# group results by seconds
|
||||
for line in range(0, djv.n_lines):
|
||||
for col in range(0, djv.n_columns):
|
||||
# for dejavu
|
||||
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
|
||||
all_match_confidence[col][line][0] = djv.result_match_confidence[line][col]
|
||||
|
||||
if djv_match_acc == 0 and djv_match_result == 'yes':
|
||||
all_matching_times_counter[col][0][0] += 1
|
||||
elif djv_match_acc != 0:
|
||||
all_matching_times_counter[col][1][0] += 1
|
||||
djv_match_result = djv.result_match[line][col]
|
||||
|
||||
# create plots
|
||||
djv.create_plots('Confidence', all_match_confidence, options.results_folder)
|
||||
djv.create_plots('Query duration', all_query_duration, options.results_folder)
|
||||
if djv_match_result == 'yes':
|
||||
all_match_counter[col][0][0] += 1
|
||||
elif djv_match_result == 'no':
|
||||
all_match_counter[col][1][0] += 1
|
||||
else:
|
||||
all_match_counter[col][2][0] += 1
|
||||
|
||||
for sec in range(0, n_secs):
|
||||
ind = np.arange(3) #
|
||||
width = 0.25 # the width of the bars
|
||||
djv_match_acc = djv.result_matching_times[line][col]
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1 * width, 2.75])
|
||||
if djv_match_acc == 0 and djv_match_result == 'yes':
|
||||
all_matching_times_counter[col][0][0] += 1
|
||||
elif djv_match_acc != 0:
|
||||
all_matching_times_counter[col][1][0] += 1
|
||||
|
||||
means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
# create plots
|
||||
djv.create_plots('Confidence', all_match_confidence, results_folder)
|
||||
djv.create_plots('Query duration', all_query_duration, results_folder)
|
||||
|
||||
# add some
|
||||
ax.set_ylabel('Matching Percentage')
|
||||
ax.set_title('%s Matching Percentage' % test_seconds[sec])
|
||||
ax.set_xticks(ind + width)
|
||||
for sec in range(0, n_secs):
|
||||
ind = np.arange(3)
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
labels = ['yes','no','invalid']
|
||||
ax.set_xticklabels( labels )
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1 * width, 2.75])
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
#ax.legend((rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
autolabeldoubles(rects1,ax)
|
||||
plt.grid()
|
||||
means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
fig_name = os.path.join(options.results_folder, "matching_perc_%s.png" % test_seconds[sec])
|
||||
fig.savefig(fig_name)
|
||||
# add some
|
||||
ax.set_ylabel('Matching Percentage')
|
||||
ax.set_title(f'{test_seconds[sec]} Matching Percentage')
|
||||
ax.set_xticks(ind + width)
|
||||
|
||||
for sec in range(0, n_secs):
|
||||
ind = np.arange(2) #
|
||||
width = 0.25 # the width of the bars
|
||||
labels = ['yes', 'no', 'invalid']
|
||||
ax.set_xticklabels(labels)
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1*width, 1.75])
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
autolabeldoubles(rects1, ax)
|
||||
plt.grid()
|
||||
|
||||
div = all_match_counter[sec][0][0]
|
||||
if div == 0 :
|
||||
div = 1000000
|
||||
fig_name = join(results_folder, f"matching_perc_{test_seconds[sec]}.png")
|
||||
fig.savefig(fig_name)
|
||||
|
||||
means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
for sec in range(0, n_secs):
|
||||
ind = np.arange(2)
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
# add some
|
||||
ax.set_ylabel('Matching Accuracy')
|
||||
ax.set_title('%s Matching Times Accuracy' % test_seconds[sec])
|
||||
ax.set_xticks(ind + width)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1 * width, 1.75])
|
||||
|
||||
labels = ['yes','no']
|
||||
ax.set_xticklabels( labels )
|
||||
div = all_match_counter[sec][0][0]
|
||||
if div == 0:
|
||||
div = 1000000
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
#ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
autolabeldoubles(rects1,ax)
|
||||
# add some
|
||||
ax.set_ylabel('Matching Accuracy')
|
||||
ax.set_title(f'{test_seconds[sec]} Matching Times Accuracy')
|
||||
ax.set_xticks(ind + width)
|
||||
|
||||
plt.grid()
|
||||
labels = ['yes', 'no']
|
||||
ax.set_xticklabels(labels)
|
||||
|
||||
fig_name = os.path.join(options.results_folder, "matching_acc_%s.png" % test_seconds[sec])
|
||||
fig.savefig(fig_name)
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
autolabeldoubles(rects1, ax)
|
||||
|
||||
# remove temporary folder
|
||||
shutil.rmtree(options.temp_folder)
|
||||
plt.grid()
|
||||
|
||||
fig_name = join(results_folder, f"matching_acc_{test_seconds[sec]}.png")
|
||||
fig.savefig(fig_name)
|
||||
|
||||
# remove temporary folder
|
||||
rmtree(temp_folder)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    # `%(prog)s` is the placeholder argparse expands to the program name; the
    # previous `%(prog).py` was an invalid format and made `--help` raise
    # ValueError.
    parser = argparse.ArgumentParser(description='Runs a few tests for dejavu to evaluate '
                                                 'its configuration performance. '
                                                 'Usage: %(prog)s [options] TESTING_AUDIOFOLDER'
                                     )

    parser.add_argument("-sec", "--seconds", action="store", default=5, type=int,
                        help='Number of seconds starting from zero to test.')
    parser.add_argument("-res", "--results-folder", action="store", default="./dejavu_test_results",
                        help='Sets the path where the results are saved.')
    parser.add_argument("-temp", "--temp-folder", action="store", default="./dejavu_temp_testing_files",
                        help='Sets the path where the temp files are saved.')
    parser.add_argument("-l", "--log", action="store_true", default=False, help='Enables logging.')
    # store_true so that passing the flag actually turns silence on; with
    # store_false and default=False the value could never become True.
    parser.add_argument("-sl", "--silent", action="store_true", default=False, help='Disables printing.')
    parser.add_argument("-lf", "--log-file", default="results-compare.log",
                        help='Set the path and filename of the log file.')
    parser.add_argument("-pad", "--padding", action="store", default=10, type=int,
                        help='Number of seconds to pad choice of place to test from.')
    parser.add_argument("-sd", "--seed", action="store", default=None, type=int, help='Random seed.')
    parser.add_argument("src", type=str, help='Source folder for audios to use as tests.')

    args = parser.parse_args()

    main(args.seconds, args.results_folder, args.temp_folder, args.log, args.silent, args.log_file, args.padding,
         args.seed, args.src)
|
||||
|
|
3
setup.cfg
Normal file
3
setup.cfg
Normal file
|
@ -0,0 +1,3 @@
|
|||
[flake8]
|
||||
max-line-length = 120
|
||||
|
2
setup.py
2
setup.py
|
@ -1,5 +1,4 @@
|
|||
from setuptools import setup, find_packages
|
||||
# import os, sys
|
||||
|
||||
|
||||
def parse_requirements(requirements):
|
||||
|
@ -14,6 +13,7 @@ def parse_requirements(requirements):
|
|||
reqs = list(filter((lambda x: x), nocomments))
|
||||
return reqs
|
||||
|
||||
|
||||
PACKAGE_NAME = "PyDejavu"
|
||||
PACKAGE_VERSION = "0.1.3"
|
||||
SUMMARY = 'Dejavu: Audio Fingerprinting in Python'
|
||||
|
|
BIN
test/sean_secs.wav
Normal file
BIN
test/sean_secs.wav
Normal file
Binary file not shown.
BIN
test/woodward_43s.wav
Normal file
BIN
test/woodward_43s.wav
Normal file
Binary file not shown.
Loading…
Reference in a new issue