mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 11:09:52 +00:00
- Extracted common logic from database handlers.
- Fixed tests. - Refactored solution architecture once more. - Refactored solution hierarchy. - Adding audios for testing. - Solved flake8 issues and reorganized several imports in the process.
This commit is contained in:
parent
6bc15ab8d4
commit
e046eeee93
32 changed files with 1154 additions and 1058 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -1,7 +1,3 @@
|
||||||
*.pyc
|
*.pyc
|
||||||
wav
|
|
||||||
mp3
|
|
||||||
*.wav
|
|
||||||
*.mp3
|
|
||||||
.DS_Store
|
.DS_Store
|
||||||
*.cnf
|
*.cnf
|
||||||
|
|
|
@ -4,5 +4,6 @@
|
||||||
"user": "root",
|
"user": "root",
|
||||||
"password": "rootpass",
|
"password": "rootpass",
|
||||||
"database": "dejavu"
|
"database": "dejavu"
|
||||||
}
|
},
|
||||||
|
"database_type": "mysql"
|
||||||
}
|
}
|
||||||
|
|
24
dejavu.py
24
dejavu.py
|
@ -1,16 +1,12 @@
|
||||||
#!/usr/bin/python
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
from os.path import isdir
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
from argparse import RawTextHelpFormatter
|
from argparse import RawTextHelpFormatter
|
||||||
|
|
||||||
from dejavu import Dejavu
|
from dejavu import Dejavu
|
||||||
from dejavu.recognize import FileRecognizer, MicrophoneRecognizer
|
from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
|
||||||
|
from dejavu.logic.recognizer.file_recognizer import FileRecognizer
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
|
|
||||||
DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE"
|
DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE"
|
||||||
|
|
||||||
|
@ -58,7 +54,6 @@ if __name__ == '__main__':
|
||||||
config_file = args.config
|
config_file = args.config
|
||||||
if config_file is None:
|
if config_file is None:
|
||||||
config_file = DEFAULT_CONFIG_FILE
|
config_file = DEFAULT_CONFIG_FILE
|
||||||
# print ("Using default config file: {config_file}")
|
|
||||||
|
|
||||||
djv = init(config_file)
|
djv = init(config_file)
|
||||||
if args.fingerprint:
|
if args.fingerprint:
|
||||||
|
@ -71,22 +66,19 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
elif len(args.fingerprint) == 1:
|
elif len(args.fingerprint) == 1:
|
||||||
filepath = args.fingerprint[0]
|
filepath = args.fingerprint[0]
|
||||||
if os.path.isdir(filepath):
|
if isdir(filepath):
|
||||||
print("Please specify an extension if you'd like to fingerprint a directory!")
|
print("Please specify an extension if you'd like to fingerprint a directory!")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
djv.fingerprint_file(filepath)
|
djv.fingerprint_file(filepath)
|
||||||
|
|
||||||
elif args.recognize:
|
elif args.recognize:
|
||||||
# Recognize audio source
|
# Recognize audio source
|
||||||
song = None
|
songs = None
|
||||||
source = args.recognize[0]
|
source = args.recognize[0]
|
||||||
opt_arg = args.recognize[1]
|
opt_arg = args.recognize[1]
|
||||||
|
|
||||||
if source in ('mic', 'microphone'):
|
if source in ('mic', 'microphone'):
|
||||||
song = djv.recognize(MicrophoneRecognizer, seconds=opt_arg)
|
songs = djv.recognize(MicrophoneRecognizer, seconds=opt_arg)
|
||||||
elif source == 'file':
|
elif source == 'file':
|
||||||
song = djv.recognize(FileRecognizer, opt_arg)
|
songs = djv.recognize(FileRecognizer, opt_arg)
|
||||||
decoded_song = repr(song).decode('string_escape')
|
print(songs)
|
||||||
print(decoded_song)
|
|
||||||
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
|
@ -3,13 +3,13 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import dejavu.decoder as decoder
|
import dejavu.logic.decoder as decoder
|
||||||
from dejavu.config.config import (CONFIDENCE, DEFAULT_FS,
|
from dejavu.base_classes.base_database import get_database
|
||||||
|
from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS,
|
||||||
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
||||||
FIELD_FILE_SHA1, OFFSET, OFFSET_SECS,
|
FIELD_FILE_SHA1, OFFSET, OFFSET_SECS,
|
||||||
SONG_ID, SONG_NAME, TOPN)
|
SONG_ID, SONG_NAME, TOPN)
|
||||||
from dejavu.database import get_database
|
from dejavu.logic.fingerprint import fingerprint
|
||||||
from dejavu.fingerprint import fingerprint
|
|
||||||
|
|
||||||
|
|
||||||
class Dejavu:
|
class Dejavu:
|
||||||
|
@ -71,7 +71,7 @@ class Dejavu:
|
||||||
continue
|
continue
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
except:
|
except Exception:
|
||||||
print("Failed fingerprinting")
|
print("Failed fingerprinting")
|
||||||
# Print traceback because we can't reraise it here
|
# Print traceback because we can't reraise it here
|
||||||
traceback.print_exc(file=sys.stdout)
|
traceback.print_exc(file=sys.stdout)
|
||||||
|
@ -119,6 +119,7 @@ class Dejavu:
|
||||||
diff_counter = {}
|
diff_counter = {}
|
||||||
largest_count = 0
|
largest_count = 0
|
||||||
|
|
||||||
|
# TODO: review logic to get topn results.
|
||||||
for tup in matches:
|
for tup in matches:
|
||||||
sid, diff = tup
|
sid, diff = tup
|
||||||
if diff not in diff_counter:
|
if diff not in diff_counter:
|
||||||
|
|
0
dejavu/base_classes/__init__.py
Normal file
0
dejavu/base_classes/__init__.py
Normal file
|
@ -1,9 +1,11 @@
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
from dejavu.config.config import DATABASES
|
from typing import Dict
|
||||||
|
|
||||||
|
from dejavu.config.settings import DATABASES
|
||||||
|
|
||||||
|
|
||||||
class Database(object, metaclass=abc.ABCMeta):
|
class BaseDatabase(object, metaclass=abc.ABCMeta):
|
||||||
# Name of your Database subclass, this is used in configuration
|
# Name of your Database subclass, this is used in configuration
|
||||||
# to refer to your class
|
# to refer to your class
|
||||||
type = None
|
type = None
|
||||||
|
@ -70,17 +72,16 @@ class Database(object, metaclass=abc.ABCMeta):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_songs(self):
|
def get_songs(self) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Returns all fully fingerprinted songs in the database.
|
Returns all fully fingerprinted songs in the database. Result must be a Dictionary.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_song_by_id(self, sid):
|
def get_song_by_id(self, sid) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Return a song by its identifier
|
Return a song by its identifier. Result must be a Dictionary.
|
||||||
|
|
||||||
sid: Song identifier
|
sid: Song identifier
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
@ -124,7 +125,7 @@ class Database(object, metaclass=abc.ABCMeta):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def insert_hashes(self, sid, hashes):
|
def insert_hashes(self, sid, hashes, batch=1000):
|
||||||
"""
|
"""
|
||||||
Insert a multitude of fingerprints.
|
Insert a multitude of fingerprints.
|
||||||
|
|
||||||
|
@ -133,7 +134,6 @@ class Database(object, metaclass=abc.ABCMeta):
|
||||||
- hash: Part of a sha1 hash, in hexadecimal format
|
- hash: Part of a sha1 hash, in hexadecimal format
|
||||||
- offset: Offset this hash was created from/at.
|
- offset: Offset this hash was created from/at.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def return_matches(self, hashes):
|
def return_matches(self, hashes):
|
19
dejavu/base_classes/base_recognizer.py
Normal file
19
dejavu/base_classes/base_recognizer.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
import abc
|
||||||
|
|
||||||
|
from dejavu.config.settings import DEFAULT_FS
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRecognizer(object, metaclass=abc.ABCMeta):
|
||||||
|
def __init__(self, dejavu):
|
||||||
|
self.dejavu = dejavu
|
||||||
|
self.Fs = DEFAULT_FS
|
||||||
|
|
||||||
|
def _recognize(self, *data):
|
||||||
|
matches = []
|
||||||
|
for d in data:
|
||||||
|
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
|
||||||
|
return self.dejavu.align_matches(matches)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def recognize(self):
|
||||||
|
pass # base class does nothing
|
194
dejavu/base_classes/common_database.py
Normal file
194
dejavu/base_classes/common_database.py
Normal file
|
@ -0,0 +1,194 @@
|
||||||
|
import abc
|
||||||
|
|
||||||
|
from dejavu.base_classes.base_database import BaseDatabase
|
||||||
|
|
||||||
|
|
||||||
|
class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta):
|
||||||
|
# Since several methods across different databases are actually just the same
|
||||||
|
# I've built this class with the idea to reuse that logic instead of copy pasting
|
||||||
|
# over and over the same code.
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def before_fork(self):
|
||||||
|
"""
|
||||||
|
Called before the database instance is given to the new process
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def after_fork(self):
|
||||||
|
"""
|
||||||
|
Called after the database instance has been given to the new process
|
||||||
|
|
||||||
|
This will be called in the new process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
"""
|
||||||
|
Called on creation or shortly afterwards.
|
||||||
|
"""
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.CREATE_SONGS_TABLE)
|
||||||
|
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
|
||||||
|
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||||
|
|
||||||
|
def empty(self):
|
||||||
|
"""
|
||||||
|
Called when the database should be cleared of all data.
|
||||||
|
"""
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.DROP_FINGERPRINTS)
|
||||||
|
cur.execute(self.DROP_SONGS)
|
||||||
|
|
||||||
|
self.setup()
|
||||||
|
|
||||||
|
def delete_unfingerprinted_songs(self):
|
||||||
|
"""
|
||||||
|
Called to remove any song entries that do not have any fingerprints
|
||||||
|
associated with them.
|
||||||
|
"""
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||||
|
|
||||||
|
def get_num_songs(self):
|
||||||
|
"""
|
||||||
|
Returns the amount of songs in the database.
|
||||||
|
"""
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.SELECT_UNIQUE_SONG_IDS)
|
||||||
|
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def get_num_fingerprints(self):
|
||||||
|
"""
|
||||||
|
Returns the number of fingerprints in the database.
|
||||||
|
"""
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.SELECT_NUM_FINGERPRINTS)
|
||||||
|
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
||||||
|
cur.close()
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def set_song_fingerprinted(self, sid):
|
||||||
|
"""
|
||||||
|
Sets a specific song as having all fingerprints in the database.
|
||||||
|
|
||||||
|
sid: Song identifier
|
||||||
|
"""
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))
|
||||||
|
|
||||||
|
def get_songs(self):
|
||||||
|
"""
|
||||||
|
Returns all fully fingerprinted songs in the database. Result must be a Dictionary.
|
||||||
|
"""
|
||||||
|
with self.cursor(dictionary=True) as cur:
|
||||||
|
cur.execute(self.SELECT_SONGS)
|
||||||
|
for row in cur:
|
||||||
|
yield row
|
||||||
|
|
||||||
|
def get_song_by_id(self, sid):
|
||||||
|
"""
|
||||||
|
Return a song by its identifier. Result must be a Dictionary.
|
||||||
|
sid: Song identifier
|
||||||
|
"""
|
||||||
|
with self.cursor(dictionary=True) as cur:
|
||||||
|
cur.execute(self.SELECT_SONG, (sid,))
|
||||||
|
return cur.fetchone()
|
||||||
|
|
||||||
|
def insert(self, hash, sid, offset):
|
||||||
|
"""
|
||||||
|
Inserts a single fingerprint into the database.
|
||||||
|
|
||||||
|
hash: Part of a sha1 hash, in hexadecimal format
|
||||||
|
sid: Song identifier this fingerprint is off
|
||||||
|
offset: The offset this hash is from
|
||||||
|
"""
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset))
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def insert_song(self, song_name):
|
||||||
|
"""
|
||||||
|
Inserts a song name into the database, returns the new
|
||||||
|
identifier of the song.
|
||||||
|
|
||||||
|
song_name: The name of the song.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def query(self, fingerprint):
|
||||||
|
"""
|
||||||
|
Returns all matching fingerprint entries associated with
|
||||||
|
the given hash as parameter.
|
||||||
|
|
||||||
|
hash: Part of a sha1 hash, in hexadecimal format
|
||||||
|
"""
|
||||||
|
if fingerprint:
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.SELECT, (fingerprint,))
|
||||||
|
for sid, offset in cur:
|
||||||
|
yield (sid, offset)
|
||||||
|
else: # select all if no key
|
||||||
|
with self.cursor() as cur:
|
||||||
|
cur.execute(self.SELECT_ALL)
|
||||||
|
for sid, offset in cur:
|
||||||
|
yield (sid, offset)
|
||||||
|
|
||||||
|
def get_iterable_kv_pairs(self):
|
||||||
|
"""
|
||||||
|
Returns all fingerprints in the database.
|
||||||
|
"""
|
||||||
|
return self.query(None)
|
||||||
|
|
||||||
|
def insert_hashes(self, sid, hashes, batch=1000):
|
||||||
|
"""
|
||||||
|
Insert a multitude of fingerprints.
|
||||||
|
|
||||||
|
sid: Song identifier the fingerprints belong to
|
||||||
|
hashes: A sequence of tuples in the format (hash, offset)
|
||||||
|
- hash: Part of a sha1 hash, in hexadecimal format
|
||||||
|
- offset: Offset this hash was created from/at.
|
||||||
|
"""
|
||||||
|
values = [(sid, hsh, int(offset)) for hsh, offset in hashes]
|
||||||
|
|
||||||
|
with self.cursor() as cur:
|
||||||
|
for index in range(0, len(hashes), batch):
|
||||||
|
cur.executemany(self.INSERT_FINGERPRINT, values[index: index + batch])
|
||||||
|
|
||||||
|
def return_matches(self, hashes, batch=1000):
|
||||||
|
"""
|
||||||
|
Searches the database for pairs of (hash, offset) values.
|
||||||
|
|
||||||
|
hashes: A sequence of tuples in the format (hash, offset)
|
||||||
|
- hash: Part of a sha1 hash, in hexadecimal format
|
||||||
|
- offset: Offset this hash was created from/at.
|
||||||
|
|
||||||
|
Returns a sequence of (sid, offset_difference) tuples.
|
||||||
|
|
||||||
|
sid: Song identifier
|
||||||
|
offset_difference: (offset - database_offset)
|
||||||
|
"""
|
||||||
|
# Create a dictionary of hash => offset pairs for later lookups
|
||||||
|
mapper = {}
|
||||||
|
for hsh, offset in hashes:
|
||||||
|
mapper[hsh.upper()] = offset
|
||||||
|
|
||||||
|
# Get an iterable of all the hashes we need
|
||||||
|
values = list(mapper.keys())
|
||||||
|
|
||||||
|
with self.cursor() as cur:
|
||||||
|
for index in range(0, len(values), batch):
|
||||||
|
# Create our IN part of the query
|
||||||
|
query = self.SELECT_MULTIPLE
|
||||||
|
query = query % ', '.join([self.IN_MATCH] * len(values[index: index + batch]))
|
||||||
|
|
||||||
|
cur.execute(query, values[index: index + batch])
|
||||||
|
|
||||||
|
for hsh, sid, offset in cur:
|
||||||
|
# (sid, db_offset - song_sampled_offset)
|
||||||
|
yield (sid, offset - mapper[hsh])
|
|
@ -3,13 +3,114 @@ import queue
|
||||||
import mysql.connector
|
import mysql.connector
|
||||||
from mysql.connector.errors import DatabaseError
|
from mysql.connector.errors import DatabaseError
|
||||||
|
|
||||||
import dejavu.database_handler.mysql_queries as queries
|
from dejavu.base_classes.common_database import CommonDatabase
|
||||||
from dejavu.database import Database
|
from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
||||||
|
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
||||||
|
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
|
||||||
|
SONGS_TABLENAME)
|
||||||
|
|
||||||
|
|
||||||
class MySQLDatabase(Database):
|
class MySQLDatabase(CommonDatabase):
|
||||||
type = "mysql"
|
type = "mysql"
|
||||||
|
|
||||||
|
# CREATES
|
||||||
|
CREATE_SONGS_TABLE = f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` (
|
||||||
|
`{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT
|
||||||
|
, `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL
|
||||||
|
, `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0
|
||||||
|
, `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL
|
||||||
|
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||||
|
, CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`)
|
||||||
|
, CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`)
|
||||||
|
) ENGINE=INNODB;
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE_FINGERPRINTS_TABLE = f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` (
|
||||||
|
`{FIELD_HASH}` BINARY(10) NOT NULL
|
||||||
|
, `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL
|
||||||
|
, `{FIELD_OFFSET}` INT UNSIGNED NOT NULL
|
||||||
|
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||||
|
, INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`)
|
||||||
|
, CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}`
|
||||||
|
UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`)
|
||||||
|
, CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`)
|
||||||
|
REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE
|
||||||
|
) ENGINE=INNODB;
|
||||||
|
"""
|
||||||
|
|
||||||
|
# INSERTS (IGNORES DUPLICATES)
|
||||||
|
INSERT_FINGERPRINT = f"""
|
||||||
|
INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` (
|
||||||
|
`{FIELD_SONG_ID}`
|
||||||
|
, `{FIELD_HASH}`
|
||||||
|
, `{FIELD_OFFSET}`)
|
||||||
|
VALUES (%s, UNHEX(%s), %s);
|
||||||
|
"""
|
||||||
|
|
||||||
|
INSERT_SONG = f"""
|
||||||
|
INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`)
|
||||||
|
VALUES (%s, UNHEX(%s));
|
||||||
|
"""
|
||||||
|
|
||||||
|
# SELECTS
|
||||||
|
SELECT = f"""
|
||||||
|
SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
||||||
|
FROM `{FINGERPRINTS_TABLENAME}`
|
||||||
|
WHERE `{FIELD_HASH}` = UNHEX(%s);
|
||||||
|
"""
|
||||||
|
|
||||||
|
SELECT_MULTIPLE = f"""
|
||||||
|
SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
||||||
|
FROM `{FINGERPRINTS_TABLENAME}`
|
||||||
|
WHERE `{FIELD_HASH}` IN (%s);
|
||||||
|
"""
|
||||||
|
|
||||||
|
SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;"
|
||||||
|
|
||||||
|
SELECT_SONG = f"""
|
||||||
|
SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
||||||
|
FROM `{SONGS_TABLENAME}`
|
||||||
|
WHERE `{FIELD_SONG_ID}` = %s;
|
||||||
|
"""
|
||||||
|
|
||||||
|
SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;"
|
||||||
|
|
||||||
|
SELECT_UNIQUE_SONG_IDS = f"""
|
||||||
|
SELECT COUNT(`{FIELD_SONG_ID}`) AS n
|
||||||
|
FROM `{SONGS_TABLENAME}`
|
||||||
|
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
||||||
|
"""
|
||||||
|
|
||||||
|
SELECT_SONGS = f"""
|
||||||
|
SELECT
|
||||||
|
`{FIELD_SONG_ID}`
|
||||||
|
, `{FIELD_SONGNAME}`
|
||||||
|
, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
||||||
|
FROM `{SONGS_TABLENAME}`
|
||||||
|
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
||||||
|
"""
|
||||||
|
|
||||||
|
# DROPS
|
||||||
|
DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;"
|
||||||
|
DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;"
|
||||||
|
|
||||||
|
# update
|
||||||
|
UPDATE_SONG_FINGERPRINTED = f"""
|
||||||
|
UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s;
|
||||||
|
"""
|
||||||
|
|
||||||
|
# DELETE
|
||||||
|
DELETE_UNFINGERPRINTED = f"""
|
||||||
|
DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0;
|
||||||
|
"""
|
||||||
|
|
||||||
|
# IN
|
||||||
|
IN_MATCH = f"UNHEX(%s)"
|
||||||
|
|
||||||
def __init__(self, **options):
|
def __init__(self, **options):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cursor = cursor_factory(**options)
|
self.cursor = cursor_factory(**options)
|
||||||
|
@ -20,160 +121,14 @@ class MySQLDatabase(Database):
|
||||||
# the previous process.
|
# the previous process.
|
||||||
Cursor.clear_cache()
|
Cursor.clear_cache()
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
"""
|
|
||||||
Creates any non-existing tables required for dejavu to function.
|
|
||||||
|
|
||||||
This also removes all songs that have been added but have no
|
|
||||||
fingerprints associated with them.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.CREATE_SONGS_TABLE)
|
|
||||||
cur.execute(queries.CREATE_FINGERPRINTS_TABLE)
|
|
||||||
cur.execute(queries.DELETE_UNFINGERPRINTED)
|
|
||||||
|
|
||||||
def empty(self):
|
|
||||||
"""
|
|
||||||
Drops tables created by dejavu and then creates them again
|
|
||||||
by calling `SQLDatabase.setup`.
|
|
||||||
|
|
||||||
.. warning:
|
|
||||||
This will result in a loss of data
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.DROP_FINGERPRINTS)
|
|
||||||
cur.execute(queries.DROP_SONGS)
|
|
||||||
|
|
||||||
self.setup()
|
|
||||||
|
|
||||||
def delete_unfingerprinted_songs(self):
|
|
||||||
"""
|
|
||||||
Removes all songs that have no fingerprints associated with them.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.DELETE_UNFINGERPRINTED)
|
|
||||||
|
|
||||||
def get_num_songs(self):
|
|
||||||
"""
|
|
||||||
Returns number of songs the database has fingerprinted.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.SELECT_UNIQUE_SONG_IDS)
|
|
||||||
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
def get_num_fingerprints(self):
|
|
||||||
"""
|
|
||||||
Returns number of fingerprints the database has fingerprinted.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.SELECT_NUM_FINGERPRINTS)
|
|
||||||
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
|
||||||
cur.close()
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
def set_song_fingerprinted(self, sid):
|
|
||||||
"""
|
|
||||||
Set the fingerprinted flag to TRUE (1) once a song has been completely
|
|
||||||
fingerprinted in the database.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,))
|
|
||||||
|
|
||||||
def get_songs(self):
|
|
||||||
"""
|
|
||||||
Return songs that have the fingerprinted flag set TRUE (1).
|
|
||||||
"""
|
|
||||||
with self.cursor(dictionary=True) as cur:
|
|
||||||
cur.execute(queries.SELECT_SONGS)
|
|
||||||
for row in cur:
|
|
||||||
yield row
|
|
||||||
|
|
||||||
def get_song_by_id(self, sid):
|
|
||||||
"""
|
|
||||||
Returns song by its ID.
|
|
||||||
"""
|
|
||||||
with self.cursor(dictionary=True) as cur:
|
|
||||||
cur.execute(queries.SELECT_SONG, (sid,))
|
|
||||||
return cur.fetchone()
|
|
||||||
|
|
||||||
def insert(self, hash, sid, offset):
|
|
||||||
"""
|
|
||||||
Insert a (sha1, song_id, offset) row into database.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset))
|
|
||||||
|
|
||||||
def insert_song(self, song_name, file_hash):
|
def insert_song(self, song_name, file_hash):
|
||||||
"""
|
"""
|
||||||
Inserts song in the database and returns the ID of the inserted record.
|
Inserts song in the database and returns the ID of the inserted record.
|
||||||
"""
|
"""
|
||||||
with self.cursor() as cur:
|
with self.cursor() as cur:
|
||||||
cur.execute(queries.INSERT_SONG, (song_name, file_hash))
|
cur.execute(self.INSERT_SONG, (song_name, file_hash))
|
||||||
return cur.lastrowid
|
return cur.lastrowid
|
||||||
|
|
||||||
def query(self, hash):
|
|
||||||
"""
|
|
||||||
Return all tuples associated with hash.
|
|
||||||
|
|
||||||
If hash is None, returns all entries in the
|
|
||||||
database (be careful with that one!).
|
|
||||||
"""
|
|
||||||
if hash:
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.SELECT, (hash,))
|
|
||||||
for sid, offset in cur:
|
|
||||||
yield (sid, offset)
|
|
||||||
else: # select all if no key
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.SELECT_ALL)
|
|
||||||
for sid, offset in cur:
|
|
||||||
yield (sid, offset)
|
|
||||||
|
|
||||||
def get_iterable_kv_pairs(self):
|
|
||||||
"""
|
|
||||||
Returns all tuples in database.
|
|
||||||
"""
|
|
||||||
return self.query(None)
|
|
||||||
|
|
||||||
def insert_hashes(self, sid, hashes, batch=1000):
|
|
||||||
"""
|
|
||||||
Insert series of hash => song_id, offset
|
|
||||||
values into the database.
|
|
||||||
"""
|
|
||||||
values = [(sid, hash, int(offset)) for hash, offset in hashes]
|
|
||||||
|
|
||||||
with self.cursor() as cur:
|
|
||||||
for index in range(0, len(hashes), batch):
|
|
||||||
cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch])
|
|
||||||
|
|
||||||
def return_matches(self, hashes, batch=1000):
|
|
||||||
"""
|
|
||||||
Return the (song_id, offset_diff) tuples associated with
|
|
||||||
a list of (sha1, sample_offset) values.
|
|
||||||
"""
|
|
||||||
# Create a dictionary of hash => offset pairs for later lookups
|
|
||||||
mapper = {}
|
|
||||||
for hash, offset in hashes:
|
|
||||||
mapper[hash.upper()] = offset
|
|
||||||
|
|
||||||
# Get an iterable of all the hashes we need
|
|
||||||
values = list(mapper.keys())
|
|
||||||
|
|
||||||
with self.cursor() as cur:
|
|
||||||
for index in range(0, len(values), batch):
|
|
||||||
# Create our IN part of the query
|
|
||||||
query = queries.SELECT_MULTIPLE
|
|
||||||
query = query % ', '.join(['UNHEX(%s)'] * len(values[index: index + batch]))
|
|
||||||
|
|
||||||
cur.execute(query, values[index: index + batch])
|
|
||||||
|
|
||||||
for hash, sid, offset in cur:
|
|
||||||
# (sid, db_offset - song_sampled_offset)
|
|
||||||
yield (sid, offset - mapper[hash])
|
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return self._options,
|
return self._options,
|
||||||
|
|
||||||
|
|
|
@ -1,134 +0,0 @@
|
||||||
from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
|
||||||
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
|
||||||
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
|
|
||||||
SONGS_TABLENAME)
|
|
||||||
|
|
||||||
"""
|
|
||||||
Queries:
|
|
||||||
|
|
||||||
1) Find duplicates (shouldn't be any, though):
|
|
||||||
|
|
||||||
select `hash`, `song_id`, `offset`, count(*) cnt
|
|
||||||
from fingerprints
|
|
||||||
group by `hash`, `song_id`, `offset`
|
|
||||||
having cnt > 1
|
|
||||||
order by cnt asc;
|
|
||||||
|
|
||||||
2) Get number of hashes by song:
|
|
||||||
|
|
||||||
select song_id, song_name, count(song_id) as num
|
|
||||||
from fingerprints
|
|
||||||
natural join songs
|
|
||||||
group by song_id
|
|
||||||
order by count(song_id) desc;
|
|
||||||
|
|
||||||
3) get hashes with highest number of collisions
|
|
||||||
|
|
||||||
select
|
|
||||||
hash,
|
|
||||||
count(distinct song_id) as n
|
|
||||||
from fingerprints
|
|
||||||
group by `hash`
|
|
||||||
order by n DESC;
|
|
||||||
|
|
||||||
=> 26 different songs with same fingerprint (392 times):
|
|
||||||
|
|
||||||
select songs.song_name, fingerprints.offset
|
|
||||||
from fingerprints natural join songs
|
|
||||||
where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
|
|
||||||
"""
|
|
||||||
|
|
||||||
# creates
|
|
||||||
CREATE_SONGS_TABLE = f"""
|
|
||||||
CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` (
|
|
||||||
`{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT
|
|
||||||
, `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL
|
|
||||||
, `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0
|
|
||||||
, `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL
|
|
||||||
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
||||||
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
|
||||||
, CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`)
|
|
||||||
, CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`)
|
|
||||||
) ENGINE=INNODB;
|
|
||||||
"""
|
|
||||||
|
|
||||||
CREATE_FINGERPRINTS_TABLE = f"""
|
|
||||||
CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` (
|
|
||||||
`{FIELD_HASH}` BINARY(10) NOT NULL
|
|
||||||
, `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL
|
|
||||||
, `{FIELD_OFFSET}` INT UNSIGNED NOT NULL
|
|
||||||
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
||||||
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
|
||||||
, INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`)
|
|
||||||
, CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}`
|
|
||||||
UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`)
|
|
||||||
, CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`)
|
|
||||||
REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE
|
|
||||||
) ENGINE=INNODB;
|
|
||||||
"""
|
|
||||||
|
|
||||||
# inserts (ignores duplicates)
|
|
||||||
INSERT_FINGERPRINT = f"""
|
|
||||||
INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` (
|
|
||||||
`{FIELD_SONG_ID}`
|
|
||||||
, `{FIELD_HASH}`
|
|
||||||
, `{FIELD_OFFSET}`)
|
|
||||||
VALUES (%s, UNHEX(%s), %s);
|
|
||||||
"""
|
|
||||||
|
|
||||||
INSERT_SONG = f"""
|
|
||||||
INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`)
|
|
||||||
VALUES (%s, UNHEX(%s));
|
|
||||||
"""
|
|
||||||
|
|
||||||
# selects
|
|
||||||
SELECT = f"""
|
|
||||||
SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
|
||||||
FROM `{FINGERPRINTS_TABLENAME}`
|
|
||||||
WHERE `{FIELD_HASH}` = UNHEX(%s);
|
|
||||||
"""
|
|
||||||
|
|
||||||
SELECT_MULTIPLE = f"""
|
|
||||||
SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
|
||||||
FROM `{FINGERPRINTS_TABLENAME}`
|
|
||||||
WHERE `{FIELD_HASH}` IN (%s);
|
|
||||||
"""
|
|
||||||
|
|
||||||
SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;"
|
|
||||||
|
|
||||||
SELECT_SONG = f"""
|
|
||||||
SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
|
||||||
FROM `{SONGS_TABLENAME}`
|
|
||||||
WHERE `{FIELD_SONG_ID}` = %s;
|
|
||||||
"""
|
|
||||||
|
|
||||||
SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;"
|
|
||||||
|
|
||||||
SELECT_UNIQUE_SONG_IDS = f"""
|
|
||||||
SELECT COUNT(`{FIELD_SONG_ID}`) AS n
|
|
||||||
FROM `{SONGS_TABLENAME}`
|
|
||||||
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
|
||||||
"""
|
|
||||||
|
|
||||||
SELECT_SONGS = f"""
|
|
||||||
SELECT
|
|
||||||
`{FIELD_SONG_ID}`
|
|
||||||
, `{FIELD_SONGNAME}`
|
|
||||||
, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
|
||||||
FROM `{SONGS_TABLENAME}`
|
|
||||||
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
|
||||||
"""
|
|
||||||
|
|
||||||
# drops
|
|
||||||
DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;"
|
|
||||||
DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;"
|
|
||||||
|
|
||||||
# update
|
|
||||||
UPDATE_SONG_FINGERPRINTED = f"""
|
|
||||||
UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s;
|
|
||||||
"""
|
|
||||||
|
|
||||||
# delete
|
|
||||||
DELETE_UNFINGERPRINTED = f"""
|
|
||||||
DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0;
|
|
||||||
"""
|
|
|
@ -3,13 +3,124 @@ import queue
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from psycopg2.extras import DictCursor
|
from psycopg2.extras import DictCursor
|
||||||
|
|
||||||
import dejavu.database_handler.postgres_queries as queries
|
from dejavu.base_classes.common_database import CommonDatabase
|
||||||
from dejavu.database import Database
|
from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
||||||
|
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
||||||
|
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
|
||||||
|
SONGS_TABLENAME)
|
||||||
|
|
||||||
|
|
||||||
class PostgreSQLDatabase(Database):
|
class PostgreSQLDatabase(CommonDatabase):
|
||||||
type = "postgres"
|
type = "postgres"
|
||||||
|
|
||||||
|
# CREATES
|
||||||
|
CREATE_SONGS_TABLE = f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" (
|
||||||
|
"{FIELD_SONG_ID}" SERIAL
|
||||||
|
, "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL
|
||||||
|
, "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0
|
||||||
|
, "{FIELD_FILE_SHA1}" BYTEA
|
||||||
|
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
||||||
|
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
||||||
|
, CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}")
|
||||||
|
, CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}")
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE_FINGERPRINTS_TABLE = f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" (
|
||||||
|
"{FIELD_HASH}" BYTEA NOT NULL
|
||||||
|
, "{FIELD_SONG_ID}" INT NOT NULL
|
||||||
|
, "{FIELD_OFFSET}" INT NOT NULL
|
||||||
|
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
||||||
|
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
||||||
|
, CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}")
|
||||||
|
, CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}")
|
||||||
|
REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}"
|
||||||
|
USING hash ("{FIELD_HASH}");
|
||||||
|
"""
|
||||||
|
|
||||||
|
CREATE_FINGERPRINTS_TABLE_INDEX = f"""
|
||||||
|
CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}"
|
||||||
|
USING hash ("{FIELD_HASH}");
|
||||||
|
"""
|
||||||
|
|
||||||
|
# INSERTS (IGNORES DUPLICATES)
|
||||||
|
INSERT_FINGERPRINT = f"""
|
||||||
|
INSERT INTO "{FINGERPRINTS_TABLENAME}" (
|
||||||
|
"{FIELD_SONG_ID}"
|
||||||
|
, "{FIELD_HASH}"
|
||||||
|
, "{FIELD_OFFSET}")
|
||||||
|
VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING;
|
||||||
|
"""
|
||||||
|
|
||||||
|
INSERT_SONG = f"""
|
||||||
|
INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}")
|
||||||
|
VALUES (%s, decode(%s, 'hex'))
|
||||||
|
RETURNING "{FIELD_SONG_ID}";
|
||||||
|
"""
|
||||||
|
|
||||||
|
# SELECTS
|
||||||
|
SELECT = f"""
|
||||||
|
SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
||||||
|
FROM "{FINGERPRINTS_TABLENAME}"
|
||||||
|
WHERE "{FIELD_HASH}" = decode(%s, 'hex');
|
||||||
|
"""
|
||||||
|
|
||||||
|
SELECT_MULTIPLE = f"""
|
||||||
|
SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
||||||
|
FROM "{FINGERPRINTS_TABLENAME}"
|
||||||
|
WHERE "{FIELD_HASH}" IN (%s);
|
||||||
|
"""
|
||||||
|
|
||||||
|
SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";'
|
||||||
|
|
||||||
|
SELECT_SONG = f"""
|
||||||
|
SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
||||||
|
FROM "{SONGS_TABLENAME}"
|
||||||
|
WHERE "{FIELD_SONG_ID}" = %s;
|
||||||
|
"""
|
||||||
|
|
||||||
|
SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";'
|
||||||
|
|
||||||
|
SELECT_UNIQUE_SONG_IDS = f"""
|
||||||
|
SELECT COUNT("{FIELD_SONG_ID}") AS n
|
||||||
|
FROM "{SONGS_TABLENAME}"
|
||||||
|
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
||||||
|
"""
|
||||||
|
|
||||||
|
SELECT_SONGS = f"""
|
||||||
|
SELECT
|
||||||
|
"{FIELD_SONG_ID}"
|
||||||
|
, "{FIELD_SONGNAME}"
|
||||||
|
, upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
||||||
|
FROM "{SONGS_TABLENAME}"
|
||||||
|
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
||||||
|
"""
|
||||||
|
|
||||||
|
# DROPS
|
||||||
|
DROP_FINGERPRINTS = F'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";'
|
||||||
|
DROP_SONGS = F'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";'
|
||||||
|
|
||||||
|
# UPDATE
|
||||||
|
UPDATE_SONG_FINGERPRINTED = f"""
|
||||||
|
UPDATE "{SONGS_TABLENAME}" SET
|
||||||
|
"{FIELD_FINGERPRINTED}" = 1
|
||||||
|
, "date_modified" = now()
|
||||||
|
WHERE "{FIELD_SONG_ID}" = %s;
|
||||||
|
"""
|
||||||
|
|
||||||
|
# DELETE
|
||||||
|
DELETE_UNFINGERPRINTED = f"""
|
||||||
|
DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0;
|
||||||
|
"""
|
||||||
|
|
||||||
|
# IN
|
||||||
|
IN_MATCH = f"decode(%s, 'hex')"
|
||||||
|
|
||||||
def __init__(self, **options):
|
def __init__(self, **options):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cursor = cursor_factory(**options)
|
self.cursor = cursor_factory(**options)
|
||||||
|
@ -20,160 +131,14 @@ class PostgreSQLDatabase(Database):
|
||||||
# the previous process.
|
# the previous process.
|
||||||
Cursor.clear_cache()
|
Cursor.clear_cache()
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
"""
|
|
||||||
Creates any non-existing tables required for dejavu to function.
|
|
||||||
|
|
||||||
This also removes all songs that have been added but have no
|
|
||||||
fingerprints associated with them.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.CREATE_SONGS_TABLE)
|
|
||||||
cur.execute(queries.CREATE_FINGERPRINTS_TABLE)
|
|
||||||
cur.execute(queries.DELETE_UNFINGERPRINTED)
|
|
||||||
|
|
||||||
def empty(self):
|
|
||||||
"""
|
|
||||||
Drops tables created by dejavu and then creates them again
|
|
||||||
by calling `SQLDatabase.setup`.
|
|
||||||
|
|
||||||
.. warning:
|
|
||||||
This will result in a loss of data
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.DROP_FINGERPRINTS)
|
|
||||||
cur.execute(queries.DROP_SONGS)
|
|
||||||
|
|
||||||
self.setup()
|
|
||||||
|
|
||||||
def delete_unfingerprinted_songs(self):
|
|
||||||
"""
|
|
||||||
Removes all songs that have no fingerprints associated with them.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.DELETE_UNFINGERPRINTED)
|
|
||||||
|
|
||||||
def get_num_songs(self):
|
|
||||||
"""
|
|
||||||
Returns number of songs the database has fingerprinted.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.SELECT_UNIQUE_SONG_IDS)
|
|
||||||
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
def get_num_fingerprints(self):
|
|
||||||
"""
|
|
||||||
Returns number of fingerprints the database has fingerprinted.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.SELECT_NUM_FINGERPRINTS)
|
|
||||||
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
|
||||||
cur.close()
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
def set_song_fingerprinted(self, sid):
|
|
||||||
"""
|
|
||||||
Set the fingerprinted flag to TRUE (1) once a song has been completely
|
|
||||||
fingerprinted in the database.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.UPDATE_SONG_FINGERPRINTED, (sid,))
|
|
||||||
|
|
||||||
def get_songs(self):
|
|
||||||
"""
|
|
||||||
Return songs that have the fingerprinted flag set TRUE (1).
|
|
||||||
"""
|
|
||||||
with self.cursor(dictionary=True) as cur:
|
|
||||||
cur.execute(queries.SELECT_SONGS)
|
|
||||||
for row in cur:
|
|
||||||
yield row
|
|
||||||
|
|
||||||
def get_song_by_id(self, sid):
|
|
||||||
"""
|
|
||||||
Returns song by its ID.
|
|
||||||
"""
|
|
||||||
with self.cursor(dictionary=True) as cur:
|
|
||||||
cur.execute(queries.SELECT_SONG, (sid,))
|
|
||||||
return cur.fetchone()
|
|
||||||
|
|
||||||
def insert(self, hash, sid, offset):
|
|
||||||
"""
|
|
||||||
Insert a (sha1, song_id, offset) row into database.
|
|
||||||
"""
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.INSERT_FINGERPRINT, (hash, sid, offset))
|
|
||||||
|
|
||||||
def insert_song(self, song_name, file_hash):
|
def insert_song(self, song_name, file_hash):
|
||||||
"""
|
"""
|
||||||
Inserts song in the database and returns the ID of the inserted record.
|
Inserts song in the database and returns the ID of the inserted record.
|
||||||
"""
|
"""
|
||||||
with self.cursor() as cur:
|
with self.cursor() as cur:
|
||||||
cur.execute(queries.INSERT_SONG, (song_name, file_hash))
|
cur.execute(self.INSERT_SONG, (song_name, file_hash))
|
||||||
return cur.fetchone()[0]
|
return cur.fetchone()[0]
|
||||||
|
|
||||||
def query(self, hash):
|
|
||||||
"""
|
|
||||||
Return all tuples associated with hash.
|
|
||||||
|
|
||||||
If hash is None, returns all entries in the
|
|
||||||
database (be careful with that one!).
|
|
||||||
"""
|
|
||||||
if hash:
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.SELECT, (hash,))
|
|
||||||
for sid, offset in cur:
|
|
||||||
yield (sid, offset)
|
|
||||||
else: # select all if no key
|
|
||||||
with self.cursor() as cur:
|
|
||||||
cur.execute(queries.SELECT_ALL)
|
|
||||||
for sid, offset in cur:
|
|
||||||
yield (sid, offset)
|
|
||||||
|
|
||||||
def get_iterable_kv_pairs(self):
|
|
||||||
"""
|
|
||||||
Returns all tuples in database.
|
|
||||||
"""
|
|
||||||
return self.query(None)
|
|
||||||
|
|
||||||
def insert_hashes(self, sid, hashes, batch=1000):
|
|
||||||
"""
|
|
||||||
Insert series of hash => song_id, offset
|
|
||||||
values into the database.
|
|
||||||
"""
|
|
||||||
values = [(sid, hash, int(offset)) for hash, offset in hashes]
|
|
||||||
|
|
||||||
with self.cursor() as cur:
|
|
||||||
for index in range(0, len(hashes), batch):
|
|
||||||
cur.executemany(queries.INSERT_FINGERPRINT, values[index: index + batch])
|
|
||||||
|
|
||||||
def return_matches(self, hashes, batch=1000):
|
|
||||||
"""
|
|
||||||
Return the (song_id, offset_diff) tuples associated with
|
|
||||||
a list of (sha1, sample_offset) values.
|
|
||||||
"""
|
|
||||||
# Create a dictionary of hash => offset pairs for later lookups
|
|
||||||
mapper = {}
|
|
||||||
for hash, offset in hashes:
|
|
||||||
mapper[hash.upper()] = offset
|
|
||||||
|
|
||||||
# Get an iterable of all the hashes we need
|
|
||||||
values = list(mapper.keys())
|
|
||||||
|
|
||||||
with self.cursor() as cur:
|
|
||||||
for index in range(0, len(values), batch):
|
|
||||||
# Create our IN part of the query
|
|
||||||
query = queries.SELECT_MULTIPLE
|
|
||||||
query = query % ', '.join(["decode(%s, 'hex')"] * len(values[index: index + batch]))
|
|
||||||
|
|
||||||
cur.execute(query, values[index: index + batch])
|
|
||||||
|
|
||||||
for hash, sid, offset in cur:
|
|
||||||
# (sid, db_offset - song_sampled_offset)
|
|
||||||
yield (sid, offset - mapper[hash.upper()])
|
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return self._options,
|
return self._options,
|
||||||
|
|
||||||
|
|
|
@ -1,104 +0,0 @@
|
||||||
from dejavu.config.config import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
|
||||||
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
|
||||||
FIELD_SONGNAME, FINGERPRINTS_TABLENAME,
|
|
||||||
SONGS_TABLENAME)
|
|
||||||
|
|
||||||
# creates
|
|
||||||
CREATE_SONGS_TABLE = f"""
|
|
||||||
CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" (
|
|
||||||
"{FIELD_SONG_ID}" SERIAL
|
|
||||||
, "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL
|
|
||||||
, "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0
|
|
||||||
, "{FIELD_FILE_SHA1}" BYTEA
|
|
||||||
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
|
||||||
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
|
||||||
, CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}")
|
|
||||||
, CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}")
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
|
|
||||||
CREATE_FINGERPRINTS_TABLE = f"""
|
|
||||||
CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" (
|
|
||||||
"{FIELD_HASH}" BYTEA NOT NULL
|
|
||||||
, "{FIELD_SONG_ID}" INT NOT NULL
|
|
||||||
, "{FIELD_OFFSET}" INT NOT NULL
|
|
||||||
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
|
||||||
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
|
||||||
, CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}")
|
|
||||||
, CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}")
|
|
||||||
REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}");
|
|
||||||
"""
|
|
||||||
|
|
||||||
CREATE_FINGERPRINTS_TABLE_INDEX = f"""
|
|
||||||
CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}" USING hash ("{FIELD_HASH}");
|
|
||||||
"""
|
|
||||||
|
|
||||||
# inserts (ignores duplicates)
|
|
||||||
INSERT_FINGERPRINT = f"""
|
|
||||||
INSERT INTO "{FINGERPRINTS_TABLENAME}" (
|
|
||||||
"{FIELD_SONG_ID}"
|
|
||||||
, "{FIELD_HASH}"
|
|
||||||
, "{FIELD_OFFSET}")
|
|
||||||
VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING;
|
|
||||||
"""
|
|
||||||
|
|
||||||
INSERT_SONG = f"""
|
|
||||||
INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}")
|
|
||||||
VALUES (%s, decode(%s, 'hex'))
|
|
||||||
RETURNING "{FIELD_SONG_ID}";
|
|
||||||
"""
|
|
||||||
|
|
||||||
# selects
|
|
||||||
SELECT = f"""
|
|
||||||
SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
|
||||||
FROM "{FINGERPRINTS_TABLENAME}"
|
|
||||||
WHERE "{FIELD_HASH}" = decode(%s, 'hex');
|
|
||||||
"""
|
|
||||||
|
|
||||||
SELECT_MULTIPLE = f"""
|
|
||||||
SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
|
||||||
FROM "{FINGERPRINTS_TABLENAME}"
|
|
||||||
WHERE "{FIELD_HASH}" IN (%s);
|
|
||||||
"""
|
|
||||||
|
|
||||||
SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";'
|
|
||||||
|
|
||||||
SELECT_SONG = f"""
|
|
||||||
SELECT "{FIELD_SONGNAME}", upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
|
||||||
FROM "{SONGS_TABLENAME}"
|
|
||||||
WHERE "{FIELD_SONG_ID}" = %s;
|
|
||||||
"""
|
|
||||||
|
|
||||||
SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";'
|
|
||||||
|
|
||||||
SELECT_UNIQUE_SONG_IDS = f"""
|
|
||||||
SELECT COUNT("{FIELD_SONG_ID}") AS n
|
|
||||||
FROM "{SONGS_TABLENAME}"
|
|
||||||
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
|
||||||
"""
|
|
||||||
|
|
||||||
SELECT_SONGS = f"""
|
|
||||||
SELECT
|
|
||||||
"{FIELD_SONG_ID}"
|
|
||||||
, "{FIELD_SONGNAME}"
|
|
||||||
, upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
|
||||||
FROM "{SONGS_TABLENAME}"
|
|
||||||
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
|
||||||
"""
|
|
||||||
|
|
||||||
# drops
|
|
||||||
DROP_FINGERPRINTS = f'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";'
|
|
||||||
DROP_SONGS = f'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";'
|
|
||||||
|
|
||||||
# update
|
|
||||||
UPDATE_SONG_FINGERPRINTED = f"""
|
|
||||||
UPDATE "{SONGS_TABLENAME}" SET "{FIELD_FINGERPRINTED}" = 1, "date_modified" = now() WHERE "{FIELD_SONG_ID}" = %s;
|
|
||||||
"""
|
|
||||||
|
|
||||||
# delete
|
|
||||||
DELETE_UNFINGERPRINTED = f"""
|
|
||||||
DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0;
|
|
||||||
"""
|
|
0
dejavu/logic/__init__.py
Normal file
0
dejavu/logic/__init__.py
Normal file
|
@ -1,10 +1,12 @@
|
||||||
import os
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
import os
|
||||||
|
from hashlib import sha1
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
from pydub.utils import audioop
|
from pydub.utils import audioop
|
||||||
from . import wavio
|
|
||||||
from hashlib import sha1
|
from dejavu.third_party import wavio
|
||||||
|
|
||||||
|
|
||||||
def unique_hash(filepath, blocksize=2**20):
|
def unique_hash(filepath, blocksize=2**20):
|
|
@ -9,7 +9,7 @@ from scipy.ndimage.morphology import (binary_erosion,
|
||||||
generate_binary_structure,
|
generate_binary_structure,
|
||||||
iterate_structure)
|
iterate_structure)
|
||||||
|
|
||||||
from dejavu.config.config import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE,
|
from dejavu.config.settings import (DEFAULT_AMP_MIN, DEFAULT_FAN_VALUE,
|
||||||
DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
|
DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
|
||||||
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION,
|
DEFAULT_WINDOW_SIZE, FINGERPRINT_REDUCTION,
|
||||||
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA,
|
MAX_HASH_TIME_DELTA, MIN_HASH_TIME_DELTA,
|
0
dejavu/logic/recognizer/__init__.py
Normal file
0
dejavu/logic/recognizer/__init__.py
Normal file
24
dejavu/logic/recognizer/file_recognizer.py
Normal file
24
dejavu/logic/recognizer/file_recognizer.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import dejavu.logic.decoder as decoder
|
||||||
|
from dejavu.base_classes.base_recognizer import BaseRecognizer
|
||||||
|
|
||||||
|
|
||||||
|
class FileRecognizer(BaseRecognizer):
|
||||||
|
def __init__(self, dejavu):
|
||||||
|
super().__init__(dejavu)
|
||||||
|
|
||||||
|
def recognize_file(self, filename):
|
||||||
|
frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit)
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
matches = self._recognize(*frames)
|
||||||
|
t = time.time() - t
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
match['match_time'] = t
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def recognize(self, filename):
|
||||||
|
return self.recognize_file(filename)
|
|
@ -1,45 +1,7 @@
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyaudio
|
import pyaudio
|
||||||
|
|
||||||
import dejavu.decoder as decoder
|
from dejavu.base_classes.base_recognizer import BaseRecognizer
|
||||||
from dejavu.config.config import DEFAULT_FS
|
|
||||||
|
|
||||||
|
|
||||||
class BaseRecognizer(object):
|
|
||||||
def __init__(self, dejavu):
|
|
||||||
self.dejavu = dejavu
|
|
||||||
self.Fs = DEFAULT_FS
|
|
||||||
|
|
||||||
def _recognize(self, *data):
|
|
||||||
matches = []
|
|
||||||
for d in data:
|
|
||||||
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
|
|
||||||
return self.dejavu.align_matches(matches)
|
|
||||||
|
|
||||||
def recognize(self):
|
|
||||||
pass # base class does nothing
|
|
||||||
|
|
||||||
|
|
||||||
class FileRecognizer(BaseRecognizer):
|
|
||||||
def __init__(self, dejavu):
|
|
||||||
super().__init__(dejavu)
|
|
||||||
|
|
||||||
def recognize_file(self, filename):
|
|
||||||
frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit)
|
|
||||||
|
|
||||||
t = time.time()
|
|
||||||
matches = self._recognize(*frames)
|
|
||||||
t = time.time() - t
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
match['match_time'] = t
|
|
||||||
|
|
||||||
return matches
|
|
||||||
|
|
||||||
def recognize(self, filename):
|
|
||||||
return self.recognize_file(filename)
|
|
||||||
|
|
||||||
|
|
||||||
class MicrophoneRecognizer(BaseRecognizer):
|
class MicrophoneRecognizer(BaseRecognizer):
|
0
dejavu/tests/__init__.py
Normal file
0
dejavu/tests/__init__.py
Normal file
|
@ -1,130 +1,26 @@
|
||||||
|
|
||||||
import ast
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import traceback
|
import traceback
|
||||||
|
from os import listdir, makedirs, walk
|
||||||
|
from os.path import basename, exists, isfile, join, splitext
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
from dejavu import Dejavu
|
from dejavu.config.settings import (CONFIDENCE, DEFAULT_FS,
|
||||||
from dejavu.decoder import path_to_songname
|
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
||||||
from dejavu.fingerprint import *
|
MATCH_TIME, OFFSET, SONG_NAME)
|
||||||
|
from dejavu.logic.decoder import path_to_songname
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed=None):
|
class DejavuTest:
|
||||||
"""
|
|
||||||
`seed` as None means that the sampling will be random.
|
|
||||||
|
|
||||||
Setting your own seed means that you can produce the
|
|
||||||
same experiment over and over.
|
|
||||||
"""
|
|
||||||
if seed != None:
|
|
||||||
random.seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def get_files_recursive(src, fmt):
|
|
||||||
"""
|
|
||||||
`src` is the source directory.
|
|
||||||
`fmt` is the extension, ie ".mp3" or "mp3", etc.
|
|
||||||
"""
|
|
||||||
for root, dirnames, filenames in os.walk(src):
|
|
||||||
for filename in fnmatch.filter(filenames, '*' + fmt):
|
|
||||||
yield os.path.join(root, filename)
|
|
||||||
|
|
||||||
|
|
||||||
def get_length_audio(audiopath, extension):
|
|
||||||
"""
|
|
||||||
Returns length of audio in seconds.
|
|
||||||
Returns None if format isn't supported or in case of error.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
audio = AudioSegment.from_file(audiopath, extension.replace(".", ""))
|
|
||||||
except:
|
|
||||||
print(f"Error in get_length_audio(): {traceback.format_exc()}")
|
|
||||||
return None
|
|
||||||
return int(len(audio) / 1000.0)
|
|
||||||
|
|
||||||
|
|
||||||
def get_starttime(length, nseconds, padding):
|
|
||||||
"""
|
|
||||||
`length` is total audio length in seconds
|
|
||||||
`nseconds` is amount of time to sample in seconds
|
|
||||||
`padding` is off-limits seconds at beginning and ending
|
|
||||||
"""
|
|
||||||
maximum = length - padding - nseconds
|
|
||||||
if padding > maximum:
|
|
||||||
return 0
|
|
||||||
return random.randint(padding, maximum)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10):
|
|
||||||
"""
|
|
||||||
Generates a test file for each file recursively in `src` directory
|
|
||||||
of given format using `nseconds` sampled from the audio file.
|
|
||||||
|
|
||||||
Results are written to `dest` directory.
|
|
||||||
|
|
||||||
`padding` is the number of off-limit seconds and the beginning and
|
|
||||||
end of a track that won't be sampled in testing. Often you want to
|
|
||||||
avoid silence, etc.
|
|
||||||
"""
|
|
||||||
# create directories if necessary
|
|
||||||
for directory in [src, dest]:
|
|
||||||
try:
|
|
||||||
os.stat(directory)
|
|
||||||
except:
|
|
||||||
os.mkdir(directory)
|
|
||||||
|
|
||||||
# find files recursively of a given file format
|
|
||||||
for fmt in fmts:
|
|
||||||
testsources = get_files_recursive(src, fmt)
|
|
||||||
for audiosource in testsources:
|
|
||||||
|
|
||||||
print("audiosource:", audiosource)
|
|
||||||
|
|
||||||
filename, extension = os.path.splitext(os.path.basename(audiosource))
|
|
||||||
length = get_length_audio(audiosource, extension)
|
|
||||||
starttime = get_starttime(length, nseconds, padding)
|
|
||||||
|
|
||||||
test_file_name = f"{os.path.join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}"
|
|
||||||
|
|
||||||
subprocess.check_output([
|
|
||||||
"ffmpeg", "-y",
|
|
||||||
"-ss", f"{starttime}",
|
|
||||||
'-t', f"{nseconds}",
|
|
||||||
"-i", audiosource,
|
|
||||||
test_file_name])
|
|
||||||
|
|
||||||
|
|
||||||
def log_msg(msg, log=True, silent=False):
|
|
||||||
if log:
|
|
||||||
logging.debug(msg)
|
|
||||||
if not silent:
|
|
||||||
print(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def autolabel(rects, ax):
|
|
||||||
# attach some text labels
|
|
||||||
for rect in rects:
|
|
||||||
height = rect.get_height()
|
|
||||||
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom')
|
|
||||||
|
|
||||||
|
|
||||||
def autolabeldoubles(rects, ax):
|
|
||||||
# attach some text labels
|
|
||||||
for rect in rects:
|
|
||||||
height = rect.get_height()
|
|
||||||
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}',
|
|
||||||
ha='center', va='bottom')
|
|
||||||
|
|
||||||
|
|
||||||
class DejavuTest(object):
|
|
||||||
def __init__(self, folder, seconds):
|
def __init__(self, folder, seconds):
|
||||||
super(DejavuTest, self).__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.test_folder = folder
|
self.test_folder = folder
|
||||||
self.test_seconds = seconds
|
self.test_seconds = seconds
|
||||||
|
@ -133,9 +29,10 @@ class DejavuTest(object):
|
||||||
print("test_seconds", self.test_seconds)
|
print("test_seconds", self.test_seconds)
|
||||||
|
|
||||||
self.test_files = [
|
self.test_files = [
|
||||||
f for f in os.listdir(self.test_folder)
|
f for f in listdir(self.test_folder)
|
||||||
if os.path.isfile(os.path.join(self.test_folder, f))
|
if isfile(join(self.test_folder, f))
|
||||||
and re.findall("[0-9]*sec", f)[0] in self.test_seconds]
|
and any([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds])
|
||||||
|
]
|
||||||
|
|
||||||
print("test_files", self.test_files)
|
print("test_files", self.test_files)
|
||||||
|
|
||||||
|
@ -154,7 +51,7 @@ class DejavuTest(object):
|
||||||
# variable match precision (if matched in the corrected time)
|
# variable match precision (if matched in the corrected time)
|
||||||
self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
||||||
|
|
||||||
# variable mahing time (query time)
|
# variable matching time (query time)
|
||||||
self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
||||||
|
|
||||||
# variable confidence
|
# variable confidence
|
||||||
|
@ -162,12 +59,12 @@ class DejavuTest(object):
|
||||||
|
|
||||||
self.begin()
|
self.begin()
|
||||||
|
|
||||||
def get_column_id (self, secs):
|
def get_column_id(self, secs):
|
||||||
for i, sec in enumerate(self.test_seconds):
|
for i, sec in enumerate(self.test_seconds):
|
||||||
if secs == sec:
|
if secs == sec:
|
||||||
return i
|
return i
|
||||||
|
|
||||||
def get_line_id (self, song):
|
def get_line_id(self, song):
|
||||||
for i, s in enumerate(self.test_songs):
|
for i, s in enumerate(self.test_songs):
|
||||||
if song == s:
|
if song == s:
|
||||||
return i
|
return i
|
||||||
|
@ -176,7 +73,7 @@ class DejavuTest(object):
|
||||||
|
|
||||||
def create_plots(self, name, results, results_folder):
|
def create_plots(self, name, results, results_folder):
|
||||||
for sec in range(0, len(self.test_seconds)):
|
for sec in range(0, len(self.test_seconds)):
|
||||||
ind = np.arange(self.n_lines) #
|
ind = np.arange(self.n_lines)
|
||||||
width = 0.25 # the width of the bars
|
width = 0.25 # the width of the bars
|
||||||
|
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
|
@ -206,7 +103,7 @@ class DejavuTest(object):
|
||||||
|
|
||||||
plt.grid()
|
plt.grid()
|
||||||
|
|
||||||
fig_name = os.path.join(results_folder, f"{name}_{self.test_seconds[sec]}.png")
|
fig_name = join(results_folder, f"{name}_{self.test_seconds[sec]}.png")
|
||||||
fig.savefig(fig_name)
|
fig.savefig(fig_name)
|
||||||
|
|
||||||
def begin(self):
|
def begin(self):
|
||||||
|
@ -215,16 +112,18 @@ class DejavuTest(object):
|
||||||
log_msg(f'file: {f}')
|
log_msg(f'file: {f}')
|
||||||
|
|
||||||
# get column
|
# get column
|
||||||
col = self.get_column_id(re.findall("[0-9]*sec", f)[0])
|
col = self.get_column_id([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds][0])
|
||||||
# format: XXXX_offset_length.mp3
|
|
||||||
song = path_to_songname(f).split("_")[0]
|
# format: XXXX_offset_length.mp3, we also take into account underscores within XXXX
|
||||||
|
splits = path_to_songname(f).split("_")
|
||||||
|
song = "_".join(splits[0:len(path_to_songname(f).split("_")) - 2])
|
||||||
line = self.get_line_id(song)
|
line = self.get_line_id(song)
|
||||||
result = subprocess.check_output([
|
result = subprocess.check_output([
|
||||||
"python",
|
"python",
|
||||||
"dejavu.py",
|
"dejavu.py",
|
||||||
'-r',
|
'-r',
|
||||||
'file',
|
'file',
|
||||||
self.test_folder + "/" + f])
|
join(self.test_folder, f)])
|
||||||
|
|
||||||
if result.strip() == "None":
|
if result.strip() == "None":
|
||||||
log_msg('No match')
|
log_msg('No match')
|
||||||
|
@ -235,14 +134,12 @@ class DejavuTest(object):
|
||||||
|
|
||||||
else:
|
else:
|
||||||
result = result.strip()
|
result = result.strip()
|
||||||
result = result.replace(" \'", ' "')
|
# we parse the output song back to a json
|
||||||
result = result.replace("{\'", '{"')
|
result = json.loads(result.decode('utf-8').replace("'", '"').replace(': b"', ':"'))
|
||||||
result = result.replace("\':", '":')
|
|
||||||
result = result.replace("\',", '",')
|
|
||||||
|
|
||||||
# which song did we predict?
|
# which song did we predict? We consider only the first match.
|
||||||
result = ast.literal_eval(result)
|
result = result[0]
|
||||||
song_result = result["song_name"]
|
song_result = result[SONG_NAME]
|
||||||
log_msg(f'song: {song}')
|
log_msg(f'song: {song}')
|
||||||
log_msg(f'song_result: {song_result}')
|
log_msg(f'song_result: {song_result}')
|
||||||
|
|
||||||
|
@ -256,21 +153,22 @@ class DejavuTest(object):
|
||||||
log_msg('correct match')
|
log_msg('correct match')
|
||||||
print(self.result_match)
|
print(self.result_match)
|
||||||
self.result_match[line][col] = 'yes'
|
self.result_match[line][col] = 'yes'
|
||||||
self.result_query_duration[line][col] = round(result[Dejavu.MATCH_TIME],3)
|
self.result_query_duration[line][col] = round(result[MATCH_TIME], 3)
|
||||||
self.result_match_confidence[line][col] = result[Dejavu.CONFIDENCE]
|
self.result_match_confidence[line][col] = result[CONFIDENCE]
|
||||||
|
|
||||||
song_start_time = re.findall("_[^_]+", f)
|
# using replace in f for getting rid of underscores in name
|
||||||
|
song_start_time = re.findall("_[^_]+", f.replace(song, ""))
|
||||||
song_start_time = song_start_time[0].lstrip("_ ")
|
song_start_time = song_start_time[0].lstrip("_ ")
|
||||||
|
|
||||||
result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE *
|
result_start_time = round((result[OFFSET] * DEFAULT_WINDOW_SIZE *
|
||||||
DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0)
|
DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0)
|
||||||
|
|
||||||
self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time)
|
self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time)
|
||||||
if abs(self.result_matching_times[line][col]) == 1:
|
if abs(self.result_matching_times[line][col]) == 1:
|
||||||
self.result_matching_times[line][col] = 0
|
self.result_matching_times[line][col] = 0
|
||||||
|
|
||||||
log_msg(f'query duration: {round(result[Dejavu.MATCH_TIME], 3)}')
|
log_msg(f'query duration: {round(result[MATCH_TIME], 3)}')
|
||||||
log_msg(f'confidence: {result[Dejavu.CONFIDENCE]}')
|
log_msg(f'confidence: {result[CONFIDENCE]}')
|
||||||
log_msg(f'song start_time: {song_start_time}')
|
log_msg(f'song start_time: {song_start_time}')
|
||||||
log_msg(f'result start time: {result_start_time}')
|
log_msg(f'result start time: {result_start_time}')
|
||||||
|
|
||||||
|
@ -279,3 +177,110 @@ class DejavuTest(object):
|
||||||
else:
|
else:
|
||||||
log_msg('inaccurate match')
|
log_msg('inaccurate match')
|
||||||
log_msg('--------------------------------------------------\n')
|
log_msg('--------------------------------------------------\n')
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed=None):
|
||||||
|
"""
|
||||||
|
`seed` as None means that the sampling will be random.
|
||||||
|
|
||||||
|
Setting your own seed means that you can produce the
|
||||||
|
same experiment over and over.
|
||||||
|
"""
|
||||||
|
if seed is not None:
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def get_files_recursive(src, fmt):
|
||||||
|
"""
|
||||||
|
`src` is the source directory.
|
||||||
|
`fmt` is the extension, ie ".mp3" or "mp3", etc.
|
||||||
|
"""
|
||||||
|
files = []
|
||||||
|
for root, dirnames, filenames in walk(src):
|
||||||
|
for filename in fnmatch.filter(filenames, '*' + fmt):
|
||||||
|
files.append(join(root, filename))
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
def get_length_audio(audiopath, extension):
|
||||||
|
"""
|
||||||
|
Returns length of audio in seconds.
|
||||||
|
Returns None if format isn't supported or in case of error.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
audio = AudioSegment.from_file(audiopath, extension.replace(".", ""))
|
||||||
|
except Exception:
|
||||||
|
print(f"Error in get_length_audio(): {traceback.format_exc()}")
|
||||||
|
return None
|
||||||
|
return int(len(audio) / 1000.0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_starttime(length, nseconds, padding):
|
||||||
|
"""
|
||||||
|
`length` is total audio length in seconds
|
||||||
|
`nseconds` is amount of time to sample in seconds
|
||||||
|
`padding` is off-limits seconds at beginning and ending
|
||||||
|
"""
|
||||||
|
maximum = length - padding - nseconds
|
||||||
|
if padding > maximum:
|
||||||
|
return 0
|
||||||
|
return random.randint(padding, maximum)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10):
|
||||||
|
"""
|
||||||
|
Generates a test file for each file recursively in `src` directory
|
||||||
|
of given format using `nseconds` sampled from the audio file.
|
||||||
|
|
||||||
|
Results are written to `dest` directory.
|
||||||
|
|
||||||
|
`padding` is the number of off-limit seconds and the beginning and
|
||||||
|
end of a track that won't be sampled in testing. Often you want to
|
||||||
|
avoid silence, etc.
|
||||||
|
"""
|
||||||
|
# create directories if necessary
|
||||||
|
if not exists(dest):
|
||||||
|
makedirs(dest)
|
||||||
|
|
||||||
|
# find files recursively of a given file format
|
||||||
|
for fmt in fmts:
|
||||||
|
testsources = get_files_recursive(src, fmt)
|
||||||
|
for audiosource in testsources:
|
||||||
|
|
||||||
|
print("audiosource:", audiosource)
|
||||||
|
|
||||||
|
filename, extension = splitext(basename(audiosource))
|
||||||
|
length = get_length_audio(audiosource, extension)
|
||||||
|
starttime = get_starttime(length, nseconds, padding)
|
||||||
|
|
||||||
|
test_file_name = f"{join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}"
|
||||||
|
|
||||||
|
subprocess.check_output([
|
||||||
|
"ffmpeg", "-y",
|
||||||
|
"-ss", f"{starttime}",
|
||||||
|
'-t', f"{nseconds}",
|
||||||
|
"-i", audiosource,
|
||||||
|
test_file_name])
|
||||||
|
|
||||||
|
|
||||||
|
def log_msg(msg, log=True, silent=False):
|
||||||
|
if log:
|
||||||
|
logging.debug(msg)
|
||||||
|
if not silent:
|
||||||
|
print(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def autolabel(rects, ax):
|
||||||
|
# attach some text labels
|
||||||
|
for rect in rects:
|
||||||
|
height = rect.get_height()
|
||||||
|
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom')
|
||||||
|
|
||||||
|
|
||||||
|
def autolabeldoubles(rects, ax):
|
||||||
|
# attach some text labels
|
||||||
|
for rect in rects:
|
||||||
|
height = rect.get_height()
|
||||||
|
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}',
|
||||||
|
ha='center', va='bottom')
|
0
dejavu/third_party/__init__.py
vendored
Normal file
0
dejavu/third_party/__init__.py
vendored
Normal file
357
dejavu/third_party/wavio.py
vendored
Normal file
357
dejavu/third_party/wavio.py
vendored
Normal file
|
@ -0,0 +1,357 @@
|
||||||
|
# wavio.py
|
||||||
|
# Author: Warren Weckesser
|
||||||
|
# License: BSD 2-Clause (http://opensource.org/licenses/BSD-2-Clause)
|
||||||
|
# Synopsis: A Python module for reading and writing 24 bit WAV files.
|
||||||
|
# Github: github.com/WarrenWeckesser/wavio
|
||||||
|
|
||||||
|
"""
|
||||||
|
The wavio module defines the functions:
|
||||||
|
read(file)
|
||||||
|
Read a WAV file and return a `wavio.Wav` object, with attributes
|
||||||
|
`data`, `rate` and `sampwidth`.
|
||||||
|
write(filename, data, rate, scale=None, sampwidth=None)
|
||||||
|
Write a numpy array to a WAV file.
|
||||||
|
-----
|
||||||
|
Author: Warren Weckesser
|
||||||
|
License: BSD 2-Clause:
|
||||||
|
Copyright (c) 2015, Warren Weckesser
|
||||||
|
All rights reserved.
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
1. Redistributions of source code must retain the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer.
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||||
|
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import wave as _wave
|
||||||
|
|
||||||
|
import numpy as _np
|
||||||
|
|
||||||
|
__version__ = "0.0.5.dev1"
|
||||||
|
|
||||||
|
|
||||||
|
def _wav2array(nchannels, sampwidth, data):
|
||||||
|
"""data must be the string containing the bytes from the wav file."""
|
||||||
|
num_samples, remainder = divmod(len(data), sampwidth * nchannels)
|
||||||
|
if remainder > 0:
|
||||||
|
raise ValueError('The length of data is not a multiple of '
|
||||||
|
'sampwidth * num_channels.')
|
||||||
|
if sampwidth > 4:
|
||||||
|
raise ValueError("sampwidth must not be greater than 4.")
|
||||||
|
|
||||||
|
if sampwidth == 3:
|
||||||
|
a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8)
|
||||||
|
raw_bytes = _np.frombuffer(data, dtype=_np.uint8)
|
||||||
|
a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
|
||||||
|
a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
|
||||||
|
result = a.view('<i4').reshape(a.shape[:-1])
|
||||||
|
else:
|
||||||
|
# 8 bit samples are stored as unsigned ints; others as signed ints.
|
||||||
|
dt_char = 'u' if sampwidth == 1 else 'i'
|
||||||
|
a = _np.frombuffer(data, dtype=f'<{dt_char}{sampwidth}')
|
||||||
|
result = a.reshape(-1, nchannels)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _array2wav(a, sampwidth):
|
||||||
|
"""
|
||||||
|
Convert the input array `a` to a string of WAV data.
|
||||||
|
a.dtype must be one of uint8, int16 or int32. Allowed sampwidth
|
||||||
|
values are:
|
||||||
|
dtype sampwidth
|
||||||
|
uint8 1
|
||||||
|
int16 2
|
||||||
|
int32 3 or 4
|
||||||
|
When sampwidth is 3, the *low* bytes of `a` are assumed to contain
|
||||||
|
the values to include in the string.
|
||||||
|
"""
|
||||||
|
if sampwidth == 3:
|
||||||
|
# `a` must have dtype int32
|
||||||
|
if a.ndim == 1:
|
||||||
|
# Convert to a 2D array with a single column.
|
||||||
|
a = a.reshape(-1, 1)
|
||||||
|
# By shifting first 0 bits, then 8, then 16, the resulting output
|
||||||
|
# is 24 bit little-endian.
|
||||||
|
a8 = (a.reshape(a.shape + (1,)) >> _np.array([0, 8, 16])) & 255
|
||||||
|
wavdata = a8.astype(_np.uint8).tostring()
|
||||||
|
else:
|
||||||
|
# Make sure the array is little-endian, and then convert using
|
||||||
|
# tostring()
|
||||||
|
a = a.astype('<' + a.dtype.str[1:], copy=False)
|
||||||
|
wavdata = a.tostring()
|
||||||
|
return wavdata
|
||||||
|
|
||||||
|
|
||||||
|
class Wav(object):
|
||||||
|
"""
|
||||||
|
Object returned by `wavio.read`. Attributes are:
|
||||||
|
data : numpy array
|
||||||
|
The array of data read from the WAV file.
|
||||||
|
rate : float
|
||||||
|
The sample rate of the WAV file.
|
||||||
|
sampwidth : int
|
||||||
|
The sample width (i.e. number of bytes per sample) of the WAV file.
|
||||||
|
For example, `sampwidth == 3` is a 24 bit WAV file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data, rate, sampwidth):
|
||||||
|
self.data = data
|
||||||
|
self.rate = rate
|
||||||
|
self.sampwidth = sampwidth
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
s = (f"Wav(data.shape={self.data.shape}, data.dtype={self.data.dtype}, "
|
||||||
|
f"rate={self.rate}, sampwidth={self.sampwidth})")
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def read(file):
|
||||||
|
"""
|
||||||
|
Read a WAV file.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
file : string or file object
|
||||||
|
Either the name of a file or an open file pointer.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
wav : wavio.Wav() instance
|
||||||
|
The return value is an instance of the class `wavio.Wav`,
|
||||||
|
with the following attributes:
|
||||||
|
data : numpy array
|
||||||
|
The array containing the data. The shape of the array
|
||||||
|
is (num_samples, num_channels). num_channels is the
|
||||||
|
number of audio channels (1 for mono, 2 for stereo).
|
||||||
|
rate : float
|
||||||
|
The sampling frequency (i.e. frame rate)
|
||||||
|
sampwidth : float
|
||||||
|
The sample width, in bytes. E.g. for a 24 bit WAV file,
|
||||||
|
sampwidth is 3.
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
This function uses the `wave` module of the Python standard libary
|
||||||
|
to read the WAV file, so it has the same limitations as that library.
|
||||||
|
In particular, the function does not read compressed WAV files, and
|
||||||
|
it does not read files with floating point data.
|
||||||
|
The array returned by `wavio.read` is always two-dimensional. If the
|
||||||
|
WAV data is mono, the array will have shape (num_samples, 1).
|
||||||
|
`wavio.read()` does not scale or normalize the data. The data in the
|
||||||
|
array `wav.data` is the data that was in the file. When the file
|
||||||
|
contains 24 bit samples, the resulting numpy array is 32 bit integers,
|
||||||
|
with values that have been sign-extended.
|
||||||
|
"""
|
||||||
|
wav = _wave.open(file)
|
||||||
|
rate = wav.getframerate()
|
||||||
|
nchannels = wav.getnchannels()
|
||||||
|
sampwidth = wav.getsampwidth()
|
||||||
|
nframes = wav.getnframes()
|
||||||
|
data = wav.readframes(nframes)
|
||||||
|
wav.close()
|
||||||
|
array = _wav2array(nchannels, sampwidth, data)
|
||||||
|
w = Wav(data=array, rate=rate, sampwidth=sampwidth)
|
||||||
|
return w
|
||||||
|
|
||||||
|
|
||||||
|
_sampwidth_dtypes = {1: _np.uint8,
|
||||||
|
2: _np.int16,
|
||||||
|
3: _np.int32,
|
||||||
|
4: _np.int32}
|
||||||
|
_sampwidth_ranges = {1: (0, 256),
|
||||||
|
2: (-2**15, 2**15),
|
||||||
|
3: (-2**23, 2**23),
|
||||||
|
4: (-2**31, 2**31)}
|
||||||
|
|
||||||
|
|
||||||
|
def _scale_to_sampwidth(data, sampwidth, vmin, vmax):
|
||||||
|
# Scale and translate the values to fit the range of the data type
|
||||||
|
# associated with the given sampwidth.
|
||||||
|
|
||||||
|
data = data.clip(vmin, vmax)
|
||||||
|
|
||||||
|
dt = _sampwidth_dtypes[sampwidth]
|
||||||
|
if vmax == vmin:
|
||||||
|
data = _np.zeros(data.shape, dtype=dt)
|
||||||
|
else:
|
||||||
|
outmin, outmax = _sampwidth_ranges[sampwidth]
|
||||||
|
if outmin != vmin or outmax != vmax:
|
||||||
|
vmin = float(vmin)
|
||||||
|
vmax = float(vmax)
|
||||||
|
data = (float(outmax - outmin) * (data - vmin) /
|
||||||
|
(vmax - vmin)).astype(_np.int64) + outmin
|
||||||
|
data[data == outmax] = outmax - 1
|
||||||
|
data = data.astype(dt)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def write(file, data, rate, scale=None, sampwidth=None):
|
||||||
|
"""
|
||||||
|
Write the numpy array `data` to a WAV file.
|
||||||
|
The Python standard library "wave" is used to write the data
|
||||||
|
to the file, so this function has the same limitations as that
|
||||||
|
module. In particular, the Python library does not support
|
||||||
|
floating point data. When given a floating point array, this
|
||||||
|
function converts the values to integers. See below for the
|
||||||
|
conversion rules.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
file : string, or file object open for writing in binary mode
|
||||||
|
Either the name of a file or an open file pointer.
|
||||||
|
data : numpy array, 1- or 2-dimensional, integer or floating point
|
||||||
|
If it is 2-d, the rows are the frames (i.e. samples) and the
|
||||||
|
columns are the channels.
|
||||||
|
rate : float
|
||||||
|
The sampling frequency (i.e. frame rate) of the data.
|
||||||
|
sampwidth : int, optional
|
||||||
|
The sample width, in bytes, of the output file.
|
||||||
|
If `sampwidth` is not given, it is inferred (if possible) from
|
||||||
|
the data type of `data`, as follows::
|
||||||
|
data.dtype sampwidth
|
||||||
|
---------- ---------
|
||||||
|
uint8, int8 1
|
||||||
|
uint16, int16 2
|
||||||
|
uint32, int32 4
|
||||||
|
For any other data types, or to write a 24 bit file, `sampwidth`
|
||||||
|
must be given.
|
||||||
|
scale : tuple or str, optional
|
||||||
|
By default, the data written to the file is scaled up or down to
|
||||||
|
occupy the full range of the output data type. So, for example,
|
||||||
|
the unsigned 8 bit data [0, 1, 2, 15] would be written to the file
|
||||||
|
as [0, 17, 30, 255]. More generally, the default behavior is
|
||||||
|
(roughly)::
|
||||||
|
vmin = data.min()
|
||||||
|
vmax = data.max()
|
||||||
|
outmin = <minimum integer of the output dtype>
|
||||||
|
outmax = <maximum integer of the output dtype>
|
||||||
|
outdata = (outmax - outmin)*(data - vmin)/(vmax - vmin) + outmin
|
||||||
|
The `scale` argument allows the scaling of the output data to be
|
||||||
|
changed. `scale` can be a tuple of the form `(vmin, vmax)`, in which
|
||||||
|
case the given values override the use of `data.min()` and
|
||||||
|
`data.max()` for `vmin` and `vmax` shown above. (If either value
|
||||||
|
is `None`, the value shown above is used.) Data outside the
|
||||||
|
range (vmin, vmax) is clipped. If `vmin == vmax`, the output is
|
||||||
|
all zeros.
|
||||||
|
If `scale` is the string "none", then `vmin` and `vmax` are set to
|
||||||
|
`outmin` and `outmax`, respectively. This means the data is written
|
||||||
|
to the file with no scaling. (Note: `scale="none" is not the same
|
||||||
|
as `scale=None`. The latter means "use the default behavior",
|
||||||
|
which is to scale by the data minimum and maximum.)
|
||||||
|
If `scale` is the string "dtype-limits", then `vmin` and `vmax`
|
||||||
|
are set to the minimum and maximum integers of `data.dtype`.
|
||||||
|
The string "dtype-limits" is not allowed when the `data` is a
|
||||||
|
floating point array.
|
||||||
|
If using `scale` results in values that exceed the limits of the
|
||||||
|
output sample width, the data is clipped. For example, the
|
||||||
|
following code::
|
||||||
|
>> x = np.array([-100, 0, 100, 200, 300, 325])
|
||||||
|
>> wavio.write('foo.wav', x, 8000, scale='none', sampwidth=1)
|
||||||
|
will write the values [0, 0, 100, 200, 255, 255] to the file.
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
Create a 3 second 440 Hz sine wave, and save it in a 24-bit WAV file.
|
||||||
|
>> import numpy as np
|
||||||
|
>> import wavio
|
||||||
|
>> rate = 22050 # samples per second
|
||||||
|
>> T = 3 # sample duration (seconds)
|
||||||
|
>> f = 440.0 # sound frequency (Hz)
|
||||||
|
>> t = np.linspace(0, T, T*rate, endpoint=False)
|
||||||
|
>> x = np.sin(2*np.pi * f * t)
|
||||||
|
>> wavio.write("sine24.wav", x, rate, sampwidth=3)
|
||||||
|
Create a file that contains the 16 bit integer values -10000 and 10000
|
||||||
|
repeated 100 times. Don't automatically scale the values. Use a sample
|
||||||
|
rate 8000.
|
||||||
|
>> x = np.empty(200, dtype=np.int16)
|
||||||
|
>> x[::2] = -10000
|
||||||
|
>> x[1::2] = 10000
|
||||||
|
>> wavio.write("foo.wav", x, 8000, scale='none')
|
||||||
|
Check that the file contains what we expect.
|
||||||
|
>> w = wavio.read("foo.wav")
|
||||||
|
>> np.all(w.data[:, 0] == x)
|
||||||
|
True
|
||||||
|
In the following, the values -10000 and 10000 (from within the 16 bit
|
||||||
|
range [-2**15, 2**15-1]) are mapped to the corresponding values 88 and
|
||||||
|
168 (in the range [0, 2**8-1]).
|
||||||
|
>> wavio.write("foo.wav", x, 8000, sampwidth=1, scale='dtype-limits')
|
||||||
|
>> w = wavio.read("foo.wav")
|
||||||
|
>> w.data[:4, 0]
|
||||||
|
array([ 88, 168, 88, 168], dtype=uint8)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if sampwidth is None:
|
||||||
|
if not _np.issubdtype(data.dtype, _np.integer) or data.itemsize > 4:
|
||||||
|
raise ValueError('when data.dtype is not an 8-, 16-, or 32-bit integer type, sampwidth must be specified.')
|
||||||
|
sampwidth = data.itemsize
|
||||||
|
else:
|
||||||
|
if sampwidth not in [1, 2, 3, 4]:
|
||||||
|
raise ValueError('sampwidth must be 1, 2, 3 or 4.')
|
||||||
|
|
||||||
|
outdtype = _sampwidth_dtypes[sampwidth]
|
||||||
|
outmin, outmax = _sampwidth_ranges[sampwidth]
|
||||||
|
|
||||||
|
if scale == "none":
|
||||||
|
data = data.clip(outmin, outmax-1).astype(outdtype)
|
||||||
|
elif scale == "dtype-limits":
|
||||||
|
if not _np.issubdtype(data.dtype, _np.integer):
|
||||||
|
raise ValueError("scale cannot be 'dtype-limits' with non-integer data.")
|
||||||
|
# Easy transforms that just changed the signedness of the data.
|
||||||
|
if sampwidth == 1 and data.dtype == _np.int8:
|
||||||
|
data = (data.astype(_np.int16) + 128).astype(_np.uint8)
|
||||||
|
elif sampwidth == 2 and data.dtype == _np.uint16:
|
||||||
|
data = (data.astype(_np.int32) - 32768).astype(_np.int16)
|
||||||
|
elif sampwidth == 4 and data.dtype == _np.uint32:
|
||||||
|
data = (data.astype(_np.int64) - 2**31).astype(_np.int32)
|
||||||
|
elif data.itemsize != sampwidth:
|
||||||
|
# Integer input, but rescaling is needed to adjust the
|
||||||
|
# input range to the output sample width.
|
||||||
|
ii = _np.iinfo(data.dtype)
|
||||||
|
vmin = ii.min
|
||||||
|
vmax = ii.max
|
||||||
|
data = _scale_to_sampwidth(data, sampwidth, vmin, vmax)
|
||||||
|
else:
|
||||||
|
if scale is None:
|
||||||
|
vmin = data.min()
|
||||||
|
vmax = data.max()
|
||||||
|
else:
|
||||||
|
# scale must be a tuple of the form (vmin, vmax)
|
||||||
|
vmin, vmax = scale
|
||||||
|
if vmin is None:
|
||||||
|
vmin = data.min()
|
||||||
|
if vmax is None:
|
||||||
|
vmax = data.max()
|
||||||
|
|
||||||
|
data = _scale_to_sampwidth(data, sampwidth, vmin, vmax)
|
||||||
|
|
||||||
|
# At this point, `data` has been converted to have one of the following:
|
||||||
|
# sampwidth dtype
|
||||||
|
# --------- -----
|
||||||
|
# 1 uint8
|
||||||
|
# 2 int16
|
||||||
|
# 3 int32
|
||||||
|
# 4 int32
|
||||||
|
# The values in `data` are in the form in which they will be saved;
|
||||||
|
# no more scaling will take place.
|
||||||
|
|
||||||
|
if data.ndim == 1:
|
||||||
|
data = data.reshape(-1, 1)
|
||||||
|
|
||||||
|
wavdata = _array2wav(data, sampwidth)
|
||||||
|
|
||||||
|
w = _wave.open(file, 'wb')
|
||||||
|
w.setnchannels(data.shape[1])
|
||||||
|
w.setsampwidth(sampwidth)
|
||||||
|
w.setframerate(rate)
|
||||||
|
w.writeframes(wavdata)
|
||||||
|
w.close()
|
122
dejavu/wavio.py
122
dejavu/wavio.py
|
@ -1,122 +0,0 @@
|
||||||
# wavio.py
|
|
||||||
# Author: Warren Weckesser
|
|
||||||
# License: BSD 3-Clause (http://opensource.org/licenses/BSD-3-Clause)
|
|
||||||
# Synopsis: A Python module for reading and writing 24 bit WAV files.
|
|
||||||
# Github: github.com/WarrenWeckesser/wavio
|
|
||||||
|
|
||||||
import wave as _wave
|
|
||||||
|
|
||||||
import numpy as _np
|
|
||||||
|
|
||||||
|
|
||||||
def _wav2array(nchannels, sampwidth, data):
|
|
||||||
"""data must be the string containing the bytes from the wav file."""
|
|
||||||
num_samples, remainder = divmod(len(data), sampwidth * nchannels)
|
|
||||||
if remainder > 0:
|
|
||||||
raise ValueError('The length of data is not a multiple of '
|
|
||||||
'sampwidth * num_channels.')
|
|
||||||
if sampwidth > 4:
|
|
||||||
raise ValueError("sampwidth must not be greater than 4.")
|
|
||||||
|
|
||||||
if sampwidth == 3:
|
|
||||||
a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8)
|
|
||||||
raw_bytes = _np.fromstring(data, dtype=_np.uint8)
|
|
||||||
a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
|
|
||||||
a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
|
|
||||||
result = a.view('<i4').reshape(a.shape[:-1])
|
|
||||||
else:
|
|
||||||
# 8 bit samples are stored as unsigned ints; others as signed ints.
|
|
||||||
dt_char = 'u' if sampwidth == 1 else 'i'
|
|
||||||
a = _np.fromstring(data, dtype='<%s%d' % (dt_char, sampwidth))
|
|
||||||
result = a.reshape(-1, nchannels)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def readwav(file):
|
|
||||||
"""
|
|
||||||
Read a WAV file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
file : string or file object
|
|
||||||
Either the name of a file or an open file pointer.
|
|
||||||
|
|
||||||
Return Values
|
|
||||||
-------------
|
|
||||||
rate : float
|
|
||||||
The sampling frequency (i.e. frame rate)
|
|
||||||
sampwidth : float
|
|
||||||
The sample width, in bytes. E.g. for a 24 bit WAV file,
|
|
||||||
sampwidth is 3.
|
|
||||||
data : numpy array
|
|
||||||
The array containing the data. The shape of the array is
|
|
||||||
(num_samples, num_channels). num_channels is the number of
|
|
||||||
audio channels (1 for mono, 2 for stereo).
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
This function uses the `wave` module of the Python standard libary
|
|
||||||
to read the WAV file, so it has the same limitations as that library.
|
|
||||||
In particular, the function does not read compressed WAV files.
|
|
||||||
|
|
||||||
"""
|
|
||||||
wav = _wave.open(file)
|
|
||||||
rate = wav.getframerate()
|
|
||||||
nchannels = wav.getnchannels()
|
|
||||||
sampwidth = wav.getsampwidth()
|
|
||||||
nframes = wav.getnframes()
|
|
||||||
data = wav.readframes(nframes)
|
|
||||||
wav.close()
|
|
||||||
array = _wav2array(nchannels, sampwidth, data)
|
|
||||||
return rate, sampwidth, array
|
|
||||||
|
|
||||||
|
|
||||||
def writewav24(filename, rate, data):
|
|
||||||
"""
|
|
||||||
Create a 24 bit wav file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
filename : string
|
|
||||||
Name of the file to create.
|
|
||||||
rate : float
|
|
||||||
The sampling frequency (i.e. frame rate) of the data.
|
|
||||||
data : array-like collection of integer or floating point values
|
|
||||||
data must be "array-like", either 1- or 2-dimensional. If it
|
|
||||||
is 2-d, the rows are the frames (i.e. samples) and the columns
|
|
||||||
are the channels.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
The data is assumed to be signed, and the values are assumed to be
|
|
||||||
within the range of a 24 bit integer. Floating point values are
|
|
||||||
converted to integers. The data is not rescaled or normalized before
|
|
||||||
writing it to the file.
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
Create a 3 second 440 Hz sine wave.
|
|
||||||
|
|
||||||
>>> rate = 22050 # samples per second
|
|
||||||
>>> T = 3 # sample duration (seconds)
|
|
||||||
>>> f = 440.0 # sound frequency (Hz)
|
|
||||||
>>> t = np.linspace(0, T, T*rate, endpoint=False)
|
|
||||||
>>> x = (2**23 - 1) * np.sin(2 * np.pi * f * t)
|
|
||||||
>>> writewav24("sine24.wav", rate, x)
|
|
||||||
|
|
||||||
"""
|
|
||||||
a32 = _np.asarray(data, dtype=_np.int32)
|
|
||||||
if a32.ndim == 1:
|
|
||||||
# Convert to a 2D array with a single column.
|
|
||||||
a32.shape = a32.shape + (1,)
|
|
||||||
# By shifting first 0 bits, then 8, then 16, the resulting output
|
|
||||||
# is 24 bit little-endian.
|
|
||||||
a8 = (a32.reshape(a32.shape + (1,)) >> _np.array([0, 8, 16])) & 255
|
|
||||||
wavdata = a8.astype(_np.uint8).tostring()
|
|
||||||
|
|
||||||
w = _wave.open(filename, 'wb')
|
|
||||||
w.setnchannels(a32.shape[1])
|
|
||||||
w.setsampwidth(3)
|
|
||||||
w.setframerate(rate)
|
|
||||||
w.writeframes(wavdata)
|
|
||||||
w.close()
|
|
|
@ -1,11 +1,8 @@
|
||||||
import json
|
import json
|
||||||
import warnings
|
|
||||||
|
|
||||||
from dejavu import Dejavu
|
from dejavu import Dejavu
|
||||||
from dejavu.recognize import FileRecognizer, MicrophoneRecognizer
|
from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
|
||||||
|
from dejavu.logic.recognizer.file_recognizer import FileRecognizer
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
|
|
||||||
|
|
||||||
# load config from a JSON file (or anything outputting a python dictionary)
|
# load config from a JSON file (or anything outputting a python dictionary)
|
||||||
with open("dejavu.cnf.SAMPLE") as f:
|
with open("dejavu.cnf.SAMPLE") as f:
|
||||||
|
@ -17,10 +14,10 @@ if __name__ == '__main__':
|
||||||
djv = Dejavu(config)
|
djv = Dejavu(config)
|
||||||
|
|
||||||
# Fingerprint all the mp3's in the directory we give it
|
# Fingerprint all the mp3's in the directory we give it
|
||||||
djv.fingerprint_directory("mp3", [".mp3"])
|
djv.fingerprint_directory("test", [".wav"])
|
||||||
|
|
||||||
# Recognize audio from a file
|
# Recognize audio from a file
|
||||||
song = djv.recognize(FileRecognizer, "mp3/Sean-Fournier--Falling-For-You.mp3")
|
song = djv.recognize(FileRecognizer, "mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
|
||||||
print(f"From file we recognized: {song}\n")
|
print(f"From file we recognized: {song}\n")
|
||||||
|
|
||||||
# Or recognize audio from your microphone for `secs` seconds
|
# Or recognize audio from your microphone for `secs` seconds
|
||||||
|
@ -29,7 +26,7 @@ if __name__ == '__main__':
|
||||||
if song is None:
|
if song is None:
|
||||||
print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)")
|
print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)")
|
||||||
else:
|
else:
|
||||||
print(f"From mic with %d seconds we recognized: {(secs, song)}\n")
|
print(f"From mic with {secs} seconds we recognized: {song}\n")
|
||||||
|
|
||||||
# Or use a recognizer without the shortcut, in anyway you would like
|
# Or use a recognizer without the shortcut, in anyway you would like
|
||||||
recognizer = FileRecognizer(djv)
|
recognizer = FileRecognizer(djv)
|
BIN
mp3/azan_test.wav
Normal file
BIN
mp3/azan_test.wav
Normal file
Binary file not shown.
|
@ -5,3 +5,4 @@ scipy==1.3.1
|
||||||
matplotlib==3.1.1
|
matplotlib==3.1.1
|
||||||
mysql-connector-python==8.0.17
|
mysql-connector-python==8.0.17
|
||||||
psycopg2==2.8.3
|
psycopg2==2.8.3
|
||||||
|
|
||||||
|
|
196
run_tests.py
196
run_tests.py
|
@ -1,98 +1,56 @@
|
||||||
from dejavu.testing import *
|
import argparse
|
||||||
from dejavu import Dejavu
|
import logging
|
||||||
from optparse import OptionParser
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import time
|
import time
|
||||||
import shutil
|
from os import makedirs
|
||||||
|
from os.path import exists, join
|
||||||
|
from shutil import rmtree
|
||||||
|
|
||||||
usage = "usage: %prog [options] TESTING_AUDIOFOLDER"
|
import matplotlib.pyplot as plt
|
||||||
parser = OptionParser(usage=usage, version="%prog 1.1")
|
import numpy as np
|
||||||
parser.add_option("--secs",
|
|
||||||
action="store",
|
|
||||||
dest="secs",
|
|
||||||
default=5,
|
|
||||||
type=int,
|
|
||||||
help='Number of seconds starting from zero to test')
|
|
||||||
parser.add_option("--results",
|
|
||||||
action="store",
|
|
||||||
dest="results_folder",
|
|
||||||
default="./dejavu_test_results",
|
|
||||||
help='Sets the path where the results are saved')
|
|
||||||
parser.add_option("--temp",
|
|
||||||
action="store",
|
|
||||||
dest="temp_folder",
|
|
||||||
default="./dejavu_temp_testing_files",
|
|
||||||
help='Sets the path where the temp files are saved')
|
|
||||||
parser.add_option("--log",
|
|
||||||
action="store_true",
|
|
||||||
dest="log",
|
|
||||||
default=True,
|
|
||||||
help='Enables logging')
|
|
||||||
parser.add_option("--silent",
|
|
||||||
action="store_false",
|
|
||||||
dest="silent",
|
|
||||||
default=False,
|
|
||||||
help='Disables printing')
|
|
||||||
parser.add_option("--log-file",
|
|
||||||
dest="log_file",
|
|
||||||
default="results-compare.log",
|
|
||||||
help='Set the path and filename of the log file')
|
|
||||||
parser.add_option("--padding",
|
|
||||||
action="store",
|
|
||||||
dest="padding",
|
|
||||||
default=10,
|
|
||||||
type=int,
|
|
||||||
help='Number of seconds to pad choice of place to test from')
|
|
||||||
parser.add_option("--seed",
|
|
||||||
action="store",
|
|
||||||
dest="seed",
|
|
||||||
default=None,
|
|
||||||
type=int,
|
|
||||||
help='Random seed')
|
|
||||||
options, args = parser.parse_args()
|
|
||||||
test_folder = args[0]
|
|
||||||
|
|
||||||
# set random seed if set by user
|
from dejavu.tests.dejavu_test import (DejavuTest, autolabeldoubles,
|
||||||
set_seed(options.seed)
|
generate_test_files, log_msg, set_seed)
|
||||||
|
|
||||||
# ensure results folder exists
|
|
||||||
try:
|
|
||||||
os.stat(options.results_folder)
|
|
||||||
except:
|
|
||||||
os.mkdir(options.results_folder)
|
|
||||||
|
|
||||||
# set logging
|
def main(seconds: int, results_folder: str, temp_folder: str, log: bool, silent: bool,
|
||||||
if options.log:
|
log_file: str, padding: int, seed: int, src: str):
|
||||||
logging.basicConfig(filename=options.log_file, level=logging.DEBUG)
|
|
||||||
|
|
||||||
# set test seconds
|
# set random seed if set by user
|
||||||
test_seconds = ['%dsec' % i for i in range(1, options.secs + 1, 1)]
|
set_seed(seed)
|
||||||
|
|
||||||
# generate testing files
|
# ensure results folder exists
|
||||||
for i in range(1, options.secs + 1, 1):
|
if not exists(results_folder):
|
||||||
generate_test_files(test_folder, options.temp_folder,
|
makedirs(results_folder)
|
||||||
i, padding=options.padding)
|
|
||||||
|
|
||||||
# scan files
|
# set logging
|
||||||
log_msg("Running Dejavu fingerprinter on files in %s..." % test_folder,
|
if log:
|
||||||
log=options.log, silent=options.silent)
|
logging.basicConfig(filename=log_file, level=logging.DEBUG)
|
||||||
|
|
||||||
tm = time.time()
|
# set test seconds
|
||||||
djv = DejavuTest(options.temp_folder, test_seconds)
|
test_seconds = [f'{i}sec' for i in range(1, seconds + 1, 1)]
|
||||||
log_msg("finished obtaining results from dejavu in %s" % (time.time() - tm),
|
|
||||||
log=options.log, silent=options.silent)
|
|
||||||
|
|
||||||
tests = 1 # djv
|
# generate testing files
|
||||||
n_secs = len(test_seconds)
|
for i in range(1, seconds + 1, 1):
|
||||||
|
generate_test_files(src, temp_folder, i, padding=padding)
|
||||||
|
|
||||||
# set result variables -> 4d variables
|
# scan files
|
||||||
all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)]
|
log_msg(f"Running Dejavu fingerprinter on files in {src}...", log=log, silent=silent)
|
||||||
all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)]
|
|
||||||
all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
|
||||||
all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
|
||||||
|
|
||||||
# group results by seconds
|
tm = time.time()
|
||||||
for line in range(0, djv.n_lines):
|
djv = DejavuTest(temp_folder, test_seconds)
|
||||||
|
log_msg(f"finished obtaining results from dejavu in {(time.time() - tm)}", log=log, silent=silent)
|
||||||
|
|
||||||
|
tests = 1 # djv
|
||||||
|
n_secs = len(test_seconds)
|
||||||
|
|
||||||
|
# set result variables -> 4d variables
|
||||||
|
all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)]
|
||||||
|
all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)]
|
||||||
|
all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
||||||
|
all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
||||||
|
|
||||||
|
# group results by seconds
|
||||||
|
for line in range(0, djv.n_lines):
|
||||||
for col in range(0, djv.n_columns):
|
for col in range(0, djv.n_columns):
|
||||||
# for dejavu
|
# for dejavu
|
||||||
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
|
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
|
||||||
|
@ -114,12 +72,12 @@ for line in range(0, djv.n_lines):
|
||||||
elif djv_match_acc != 0:
|
elif djv_match_acc != 0:
|
||||||
all_matching_times_counter[col][1][0] += 1
|
all_matching_times_counter[col][1][0] += 1
|
||||||
|
|
||||||
# create plots
|
# create plots
|
||||||
djv.create_plots('Confidence', all_match_confidence, options.results_folder)
|
djv.create_plots('Confidence', all_match_confidence, results_folder)
|
||||||
djv.create_plots('Query duration', all_query_duration, options.results_folder)
|
djv.create_plots('Query duration', all_query_duration, results_folder)
|
||||||
|
|
||||||
for sec in range(0, n_secs):
|
for sec in range(0, n_secs):
|
||||||
ind = np.arange(3) #
|
ind = np.arange(3)
|
||||||
width = 0.25 # the width of the bars
|
width = 0.25 # the width of the bars
|
||||||
|
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
|
@ -131,31 +89,30 @@ for sec in range(0, n_secs):
|
||||||
|
|
||||||
# add some
|
# add some
|
||||||
ax.set_ylabel('Matching Percentage')
|
ax.set_ylabel('Matching Percentage')
|
||||||
ax.set_title('%s Matching Percentage' % test_seconds[sec])
|
ax.set_title(f'{test_seconds[sec]} Matching Percentage')
|
||||||
ax.set_xticks(ind + width)
|
ax.set_xticks(ind + width)
|
||||||
|
|
||||||
labels = ['yes','no','invalid']
|
labels = ['yes', 'no', 'invalid']
|
||||||
ax.set_xticklabels( labels )
|
ax.set_xticklabels(labels)
|
||||||
|
|
||||||
box = ax.get_position()
|
box = ax.get_position()
|
||||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||||
#ax.legend((rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
autolabeldoubles(rects1, ax)
|
||||||
autolabeldoubles(rects1,ax)
|
|
||||||
plt.grid()
|
plt.grid()
|
||||||
|
|
||||||
fig_name = os.path.join(options.results_folder, "matching_perc_%s.png" % test_seconds[sec])
|
fig_name = join(results_folder, f"matching_perc_{test_seconds[sec]}.png")
|
||||||
fig.savefig(fig_name)
|
fig.savefig(fig_name)
|
||||||
|
|
||||||
for sec in range(0, n_secs):
|
for sec in range(0, n_secs):
|
||||||
ind = np.arange(2) #
|
ind = np.arange(2)
|
||||||
width = 0.25 # the width of the bars
|
width = 0.25 # the width of the bars
|
||||||
|
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
ax = fig.add_subplot(111)
|
ax = fig.add_subplot(111)
|
||||||
ax.set_xlim([-1*width, 1.75])
|
ax.set_xlim([-1 * width, 1.75])
|
||||||
|
|
||||||
div = all_match_counter[sec][0][0]
|
div = all_match_counter[sec][0][0]
|
||||||
if div == 0 :
|
if div == 0:
|
||||||
div = 1000000
|
div = 1000000
|
||||||
|
|
||||||
means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]]
|
means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]]
|
||||||
|
@ -163,22 +120,47 @@ for sec in range(0, n_secs):
|
||||||
|
|
||||||
# add some
|
# add some
|
||||||
ax.set_ylabel('Matching Accuracy')
|
ax.set_ylabel('Matching Accuracy')
|
||||||
ax.set_title('%s Matching Times Accuracy' % test_seconds[sec])
|
ax.set_title(f'{test_seconds[sec]} Matching Times Accuracy')
|
||||||
ax.set_xticks(ind + width)
|
ax.set_xticks(ind + width)
|
||||||
|
|
||||||
labels = ['yes','no']
|
labels = ['yes', 'no']
|
||||||
ax.set_xticklabels( labels )
|
ax.set_xticklabels(labels)
|
||||||
|
|
||||||
box = ax.get_position()
|
box = ax.get_position()
|
||||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||||
|
autolabeldoubles(rects1, ax)
|
||||||
#ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
|
||||||
autolabeldoubles(rects1,ax)
|
|
||||||
|
|
||||||
plt.grid()
|
plt.grid()
|
||||||
|
|
||||||
fig_name = os.path.join(options.results_folder, "matching_acc_%s.png" % test_seconds[sec])
|
fig_name = join(results_folder, f"matching_acc_{test_seconds[sec]}.png")
|
||||||
fig.savefig(fig_name)
|
fig.savefig(fig_name)
|
||||||
|
|
||||||
# remove temporary folder
|
# remove temporary folder
|
||||||
shutil.rmtree(options.temp_folder)
|
rmtree(temp_folder)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description=f'Runs a few tests for dejavu to evaluate '
|
||||||
|
f'its configuration performance. '
|
||||||
|
f'Usage: %(prog).py [options] TESTING_AUDIOFOLDER'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("-sec", "--seconds", action="store", default=5, type=int,
|
||||||
|
help='Number of seconds starting from zero to test.')
|
||||||
|
parser.add_argument("-res", "--results-folder", action="store", default="./dejavu_test_results",
|
||||||
|
help='Sets the path where the results are saved.')
|
||||||
|
parser.add_argument("-temp", "--temp-folder", action="store", default="./dejavu_temp_testing_files",
|
||||||
|
help='Sets the path where the temp files are saved.')
|
||||||
|
parser.add_argument("-l", "--log", action="store_true", default=False, help='Enables logging.')
|
||||||
|
parser.add_argument("-sl", "--silent", action="store_false", default=False, help='Disables printing.')
|
||||||
|
parser.add_argument("-lf", "--log-file", default="results-compare.log",
|
||||||
|
help='Set the path and filename of the log file.')
|
||||||
|
parser.add_argument("-pad", "--padding", action="store", default=10, type=int,
|
||||||
|
help='Number of seconds to pad choice of place to test from.')
|
||||||
|
parser.add_argument("-sd", "--seed", action="store", default=None, type=int, help='Random seed.')
|
||||||
|
parser.add_argument("src", type=str, help='Source folder for audios to use as tests.')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args.seconds, args.results_folder, args.temp_folder, args.log, args.silent, args.log_file, args.padding,
|
||||||
|
args.seed, args.src)
|
||||||
|
|
3
setup.cfg
Normal file
3
setup.cfg
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
[flake8]
|
||||||
|
max-line-length = 120
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -1,5 +1,4 @@
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
# import os, sys
|
|
||||||
|
|
||||||
|
|
||||||
def parse_requirements(requirements):
|
def parse_requirements(requirements):
|
||||||
|
@ -14,6 +13,7 @@ def parse_requirements(requirements):
|
||||||
reqs = list(filter((lambda x: x), nocomments))
|
reqs = list(filter((lambda x: x), nocomments))
|
||||||
return reqs
|
return reqs
|
||||||
|
|
||||||
|
|
||||||
PACKAGE_NAME = "PyDejavu"
|
PACKAGE_NAME = "PyDejavu"
|
||||||
PACKAGE_VERSION = "0.1.3"
|
PACKAGE_VERSION = "0.1.3"
|
||||||
SUMMARY = 'Dejavu: Audio Fingerprinting in Python'
|
SUMMARY = 'Dejavu: Audio Fingerprinting in Python'
|
||||||
|
|
BIN
test/sean_secs.wav
Normal file
BIN
test/sean_secs.wav
Normal file
Binary file not shown.
BIN
test/woodward_43s.wav
Normal file
BIN
test/woodward_43s.wav
Normal file
Binary file not shown.
Loading…
Reference in a new issue