mirror of
https://github.com/correl/dejavu.git
synced 2024-11-23 11:09:52 +00:00
Merge pull request #205 from mauriciorepetto/dejavu_python_3.6.6
Dejavu on python 3.6.6
This commit is contained in:
commit
f48b09740c
36 changed files with 2180 additions and 1453 deletions
28
README.md
28
README.md
|
@ -24,6 +24,8 @@ Second, you'll need to create a MySQL database where Dejavu can store fingerprin
|
|||
|
||||
Now you're ready to start fingerprinting your audio collection!
|
||||
|
||||
Obs: The same from above goes for postgres database if you want to use it.
|
||||
|
||||
## Quickstart
|
||||
|
||||
```bash
|
||||
|
@ -44,8 +46,8 @@ Start by creating a Dejavu object with your configurations settings (Dejavu take
|
|||
... "database": {
|
||||
... "host": "127.0.0.1",
|
||||
... "user": "root",
|
||||
... "passwd": <password above>,
|
||||
... "db": <name of the database you created above>,
|
||||
... "password": <password above>,
|
||||
... "database": <name of the database you created above>,
|
||||
... }
|
||||
... }
|
||||
>>> djv = Dejavu(config)
|
||||
|
@ -81,7 +83,7 @@ The following keys are mandatory:
|
|||
The following keys are optional:
|
||||
|
||||
* `fingerprint_limit`: allows you to control how many seconds of each audio file to fingerprint. Leaving out this key, or alternatively using `-1` and `None` will cause Dejavu to fingerprint the entire audio file. Default value is `None`.
|
||||
* `database_type`: as of now, only `mysql` (the default value) is supported. If you'd like to subclass `Database` and add another, please fork and send a pull request!
|
||||
* `database_type`: `mysql` (the default value) and `postgres` are supported. If you'd like to add another subclass for `BaseDatabase` and implement a new type of database, please fork and send a pull request!
|
||||
|
||||
An example configuration is as follows:
|
||||
|
||||
|
@ -91,8 +93,8 @@ An example configuration is as follows:
|
|||
... "database": {
|
||||
... "host": "127.0.0.1",
|
||||
... "user": "root",
|
||||
... "passwd": "Password123",
|
||||
... "db": "dejavu_db",
|
||||
... "password": "Password123",
|
||||
... "database": "dejavu_db",
|
||||
... },
|
||||
... "database_type" : "mysql",
|
||||
... "fingerprint_limit" : 10
|
||||
|
@ -102,16 +104,16 @@ An example configuration is as follows:
|
|||
|
||||
## Tuning
|
||||
|
||||
Inside `fingerprint.py`, you may want to adjust following parameters (some values are given below).
|
||||
Inside `config/settings.py`, you may want to adjust following parameters (some values are given below).
|
||||
|
||||
FINGERPRINT_REDUCTION = 30
|
||||
PEAK_SORT = False
|
||||
DEFAULT_OVERLAP_RATIO = 0.4
|
||||
DEFAULT_FAN_VALUE = 10
|
||||
DEFAULT_AMP_MIN = 15
|
||||
PEAK_NEIGHBORHOOD_SIZE = 30
|
||||
DEFAULT_FAN_VALUE = 5
|
||||
DEFAULT_AMP_MIN = 10
|
||||
PEAK_NEIGHBORHOOD_SIZE = 10
|
||||
|
||||
These parameters are described in the `fingerprint.py` in detail. Read that in-order to understand the impact of changing these values.
|
||||
These parameters are described within the file in detail. Read that in-order to understand the impact of changing these values.
|
||||
|
||||
## Recognizing
|
||||
|
||||
|
@ -123,13 +125,13 @@ Through the terminal:
|
|||
|
||||
```bash
|
||||
$ python dejavu.py --recognize file sometrack.wav
|
||||
{'song_id': 1, 'song_name': 'Taylor Swift - Shake It Off', 'confidence': 3948, 'offset_seconds': 30.00018, 'match_time': 0.7159781455993652, 'offset': 646L}
|
||||
{'total_time': 2.863781690597534, 'fingerprint_time': 2.4306554794311523, 'query_time': 0.4067542552947998, 'align_time': 0.007731199264526367, 'results': [{'song_id': 1, 'song_name': 'Taylor Swift - Shake It Off', 'input_total_hashes': 76168, 'fingerprinted_hashes_in_db': 4919, 'hashes_matched_in_input': 794, 'input_confidence': 0.01, 'fingerprinted_confidence': 0.16, 'offset': -924, 'offset_seconds': -30.00018, 'file_sha1': b'3DC269DF7B8DB9B30D2604DA80783155912593E8'}, {...}, ...]}
|
||||
```
|
||||
|
||||
or in scripting, assuming you've already instantiated a Dejavu object:
|
||||
|
||||
```python
|
||||
>>> from dejavu.recognize import FileRecognizer
|
||||
>>> from dejavu.logic.recognizer.file_recognizer import FileRecognizer
|
||||
>>> song = djv.recognize(FileRecognizer, "va_us_top_40/wav/Mirrors - Justin Timberlake.wav")
|
||||
```
|
||||
|
||||
|
@ -138,7 +140,7 @@ or in scripting, assuming you've already instantiated a Dejavu object:
|
|||
With scripting:
|
||||
|
||||
```python
|
||||
>>> from dejavu.recognize import MicrophoneRecognizer
|
||||
>>> from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
|
||||
>>> song = djv.recognize(MicrophoneRecognizer, seconds=10) # Defaults to 10 seconds.
|
||||
```
|
||||
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
"database": {
|
||||
"host": "127.0.0.1",
|
||||
"user": "root",
|
||||
"passwd": "12345678",
|
||||
"db": "dejavu"
|
||||
}
|
||||
"password": "rootpass",
|
||||
"database": "dejavu"
|
||||
},
|
||||
"database_type": "mysql"
|
||||
}
|
38
dejavu.py
38
dejavu.py
|
@ -1,17 +1,12 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import warnings
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from argparse import RawTextHelpFormatter
|
||||
from os.path import isdir
|
||||
|
||||
from dejavu import Dejavu
|
||||
from dejavu.recognize import FileRecognizer
|
||||
from dejavu.recognize import MicrophoneRecognizer
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
from dejavu.logic.recognizer.file_recognizer import FileRecognizer
|
||||
from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
|
||||
|
||||
DEFAULT_CONFIG_FILE = "dejavu.cnf.SAMPLE"
|
||||
|
||||
|
@ -24,7 +19,7 @@ def init(configpath):
|
|||
with open(configpath) as f:
|
||||
config = json.load(f)
|
||||
except IOError as err:
|
||||
print("Cannot open configuration: %s. Exiting" % (str(err)))
|
||||
print(f"Cannot open configuration: {str(err)}. Exiting")
|
||||
sys.exit(1)
|
||||
|
||||
# create a Dejavu instance
|
||||
|
@ -46,7 +41,7 @@ if __name__ == '__main__':
|
|||
'--fingerprint /path/to/directory')
|
||||
parser.add_argument('-r', '--recognize', nargs=2,
|
||||
help='Recognize what is '
|
||||
'playing through the microphone\n'
|
||||
'playing through the microphone or in a file.\n'
|
||||
'Usage: \n'
|
||||
'--recognize mic number_of_seconds \n'
|
||||
'--recognize file path/to/file \n')
|
||||
|
@ -59,7 +54,6 @@ if __name__ == '__main__':
|
|||
config_file = args.config
|
||||
if config_file is None:
|
||||
config_file = DEFAULT_CONFIG_FILE
|
||||
# print "Using default config file: %s" % (config_file)
|
||||
|
||||
djv = init(config_file)
|
||||
if args.fingerprint:
|
||||
|
@ -67,28 +61,24 @@ if __name__ == '__main__':
|
|||
if len(args.fingerprint) == 2:
|
||||
directory = args.fingerprint[0]
|
||||
extension = args.fingerprint[1]
|
||||
print("Fingerprinting all .%s files in the %s directory"
|
||||
% (extension, directory))
|
||||
print(f"Fingerprinting all .{extension} files in the {directory} directory")
|
||||
djv.fingerprint_directory(directory, ["." + extension], 4)
|
||||
|
||||
elif len(args.fingerprint) == 1:
|
||||
filepath = args.fingerprint[0]
|
||||
if os.path.isdir(filepath):
|
||||
if isdir(filepath):
|
||||
print("Please specify an extension if you'd like to fingerprint a directory!")
|
||||
sys.exit(1)
|
||||
djv.fingerprint_file(filepath)
|
||||
|
||||
elif args.recognize:
|
||||
# Recognize audio source
|
||||
song = None
|
||||
songs = None
|
||||
source = args.recognize[0]
|
||||
opt_arg = args.recognize[1]
|
||||
|
||||
if source in ('mic', 'microphone'):
|
||||
song = djv.recognize(MicrophoneRecognizer, seconds=opt_arg)
|
||||
songs = djv.recognize(MicrophoneRecognizer, seconds=opt_arg)
|
||||
elif source == 'file':
|
||||
song = djv.recognize(FileRecognizer, opt_arg)
|
||||
decoded_song = repr(song).decode('string_escape')
|
||||
print(decoded_song)
|
||||
|
||||
sys.exit(0)
|
||||
songs = djv.recognize(FileRecognizer, opt_arg)
|
||||
print(songs)
|
||||
|
|
|
@ -1,28 +1,29 @@
|
|||
from dejavu.database import get_database, Database
|
||||
import dejavu.decoder as decoder
|
||||
import fingerprint
|
||||
import multiprocessing
|
||||
import os
|
||||
import traceback
|
||||
import sys
|
||||
import traceback
|
||||
from itertools import groupby
|
||||
from time import time
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import dejavu.logic.decoder as decoder
|
||||
from dejavu.base_classes.base_database import get_database
|
||||
from dejavu.config.settings import (DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
|
||||
DEFAULT_WINDOW_SIZE, FIELD_FILE_SHA1,
|
||||
FIELD_TOTAL_HASHES,
|
||||
FINGERPRINTED_CONFIDENCE,
|
||||
FINGERPRINTED_HASHES, HASHES_MATCHED,
|
||||
INPUT_CONFIDENCE, INPUT_HASHES, OFFSET,
|
||||
OFFSET_SECS, SONG_ID, SONG_NAME, TOPN)
|
||||
from dejavu.logic.fingerprint import fingerprint
|
||||
|
||||
|
||||
class Dejavu(object):
|
||||
|
||||
SONG_ID = "song_id"
|
||||
SONG_NAME = 'song_name'
|
||||
CONFIDENCE = 'confidence'
|
||||
MATCH_TIME = 'match_time'
|
||||
OFFSET = 'offset'
|
||||
OFFSET_SECS = 'offset_seconds'
|
||||
|
||||
class Dejavu:
|
||||
def __init__(self, config):
|
||||
super(Dejavu, self).__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
# initialize db
|
||||
db_cls = get_database(config.get("database_type", None))
|
||||
db_cls = get_database(config.get("database_type", "mysql").lower())
|
||||
|
||||
self.db = db_cls(**config.get("database", {}))
|
||||
self.db.setup()
|
||||
|
@ -32,17 +33,44 @@ class Dejavu(object):
|
|||
self.limit = self.config.get("fingerprint_limit", None)
|
||||
if self.limit == -1: # for JSON compatibility
|
||||
self.limit = None
|
||||
self.get_fingerprinted_songs()
|
||||
self.__load_fingerprinted_audio_hashes()
|
||||
|
||||
def get_fingerprinted_songs(self):
|
||||
def __load_fingerprinted_audio_hashes(self) -> None:
|
||||
"""
|
||||
Keeps a dictionary with the hashes of the fingerprinted songs, in that way is possible to check
|
||||
whether or not an audio file was already processed.
|
||||
"""
|
||||
# get songs previously indexed
|
||||
self.songs = self.db.get_songs()
|
||||
self.songhashes_set = set() # to know which ones we've computed before
|
||||
for song in self.songs:
|
||||
song_hash = song[Database.FIELD_FILE_SHA1]
|
||||
song_hash = song[FIELD_FILE_SHA1]
|
||||
self.songhashes_set.add(song_hash)
|
||||
|
||||
def fingerprint_directory(self, path, extensions, nprocesses=None):
|
||||
def get_fingerprinted_songs(self) -> List[Dict[str, any]]:
|
||||
"""
|
||||
To pull all fingerprinted songs from the database.
|
||||
|
||||
:return: a list of fingerprinted audios from the database.
|
||||
"""
|
||||
return self.db.get_songs()
|
||||
|
||||
def delete_songs_by_id(self, song_ids: List[int]) -> None:
|
||||
"""
|
||||
Deletes all audios given their ids.
|
||||
|
||||
:param song_ids: song ids to delete from the database.
|
||||
"""
|
||||
self.db.delete_songs_by_id(song_ids)
|
||||
|
||||
def fingerprint_directory(self, path: str, extensions: str, nprocesses: int = None) -> None:
|
||||
"""
|
||||
Given a directory and a set of extensions it fingerprints all files that match each extension specified.
|
||||
|
||||
:param path: path to the directory.
|
||||
:param extensions: list of file extensions to consider.
|
||||
:param nprocesses: amount of processes to fingerprint the files within the directory.
|
||||
"""
|
||||
# Try to use the maximum amount of processes if not given.
|
||||
try:
|
||||
nprocesses = nprocesses or multiprocessing.cpu_count()
|
||||
|
@ -55,54 +83,58 @@ class Dejavu(object):
|
|||
|
||||
filenames_to_fingerprint = []
|
||||
for filename, _ in decoder.find_files(path, extensions):
|
||||
|
||||
# don't refingerprint already fingerprinted files
|
||||
if decoder.unique_hash(filename) in self.songhashes_set:
|
||||
print "%s already fingerprinted, continuing..." % filename
|
||||
print(f"{filename} already fingerprinted, continuing...")
|
||||
continue
|
||||
|
||||
filenames_to_fingerprint.append(filename)
|
||||
|
||||
# Prepare _fingerprint_worker input
|
||||
worker_input = zip(filenames_to_fingerprint,
|
||||
[self.limit] * len(filenames_to_fingerprint))
|
||||
worker_input = list(zip(filenames_to_fingerprint, [self.limit] * len(filenames_to_fingerprint)))
|
||||
|
||||
# Send off our tasks
|
||||
iterator = pool.imap_unordered(_fingerprint_worker,
|
||||
worker_input)
|
||||
iterator = pool.imap_unordered(Dejavu._fingerprint_worker, worker_input)
|
||||
|
||||
# Loop till we have all of them
|
||||
while True:
|
||||
try:
|
||||
song_name, hashes, file_hash = iterator.next()
|
||||
song_name, hashes, file_hash = next(iterator)
|
||||
except multiprocessing.TimeoutError:
|
||||
continue
|
||||
except StopIteration:
|
||||
break
|
||||
except:
|
||||
except Exception:
|
||||
print("Failed fingerprinting")
|
||||
# Print traceback because we can't reraise it here
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
else:
|
||||
sid = self.db.insert_song(song_name, file_hash)
|
||||
sid = self.db.insert_song(song_name, file_hash, len(hashes))
|
||||
|
||||
self.db.insert_hashes(sid, hashes)
|
||||
self.db.set_song_fingerprinted(sid)
|
||||
self.get_fingerprinted_songs()
|
||||
self.__load_fingerprinted_audio_hashes()
|
||||
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
def fingerprint_file(self, filepath, song_name=None):
|
||||
songname = decoder.path_to_songname(filepath)
|
||||
song_hash = decoder.unique_hash(filepath)
|
||||
song_name = song_name or songname
|
||||
def fingerprint_file(self, file_path: str, song_name: str = None) -> None:
|
||||
"""
|
||||
Given a path to a file the method generates hashes for it and stores them in the database
|
||||
for later be queried.
|
||||
|
||||
:param file_path: path to the file.
|
||||
:param song_name: song name associated to the audio file.
|
||||
"""
|
||||
song_name_from_path = decoder.get_audio_name_from_path(file_path)
|
||||
song_hash = decoder.unique_hash(file_path)
|
||||
song_name = song_name or song_name_from_path
|
||||
# don't refingerprint already fingerprinted files
|
||||
if song_hash in self.songhashes_set:
|
||||
print "%s already fingerprinted, continuing..." % song_name
|
||||
print(f"{song_name} already fingerprinted, continuing...")
|
||||
else:
|
||||
song_name, hashes, file_hash = _fingerprint_worker(
|
||||
filepath,
|
||||
song_name, hashes, file_hash = Dejavu._fingerprint_worker(
|
||||
file_path,
|
||||
self.limit,
|
||||
song_name=song_name
|
||||
)
|
||||
|
@ -110,93 +142,118 @@ class Dejavu(object):
|
|||
|
||||
self.db.insert_hashes(sid, hashes)
|
||||
self.db.set_song_fingerprinted(sid)
|
||||
self.get_fingerprinted_songs()
|
||||
self.__load_fingerprinted_audio_hashes()
|
||||
|
||||
def find_matches(self, samples, Fs=fingerprint.DEFAULT_FS):
|
||||
hashes = fingerprint.fingerprint(samples, Fs=Fs)
|
||||
return self.db.return_matches(hashes)
|
||||
def generate_fingerprints(self, samples: List[int], Fs=DEFAULT_FS) -> Tuple[List[Tuple[str, int]], float]:
|
||||
f"""
|
||||
Generate the fingerprints for the given sample data (channel).
|
||||
|
||||
def align_matches(self, matches):
|
||||
:param samples: list of ints which represents the channel info of the given audio file.
|
||||
:param Fs: sampling rate which defaults to {DEFAULT_FS}.
|
||||
:return: a list of tuples for hash and its corresponding offset, together with the generation time.
|
||||
"""
|
||||
Finds hash matches that align in time with other matches and finds
|
||||
consensus about which hashes are "true" signal from the audio.
|
||||
t = time()
|
||||
hashes = fingerprint(samples, Fs=Fs)
|
||||
fingerprint_time = time() - t
|
||||
return hashes, fingerprint_time
|
||||
|
||||
Returns a dictionary with match information.
|
||||
def find_matches(self, hashes: List[Tuple[str, int]]) -> Tuple[List[Tuple[int, int]], Dict[str, int], float]:
|
||||
"""
|
||||
# align by diffs
|
||||
diff_counter = {}
|
||||
largest = 0
|
||||
largest_count = 0
|
||||
song_id = -1
|
||||
for tup in matches:
|
||||
sid, diff = tup
|
||||
if diff not in diff_counter:
|
||||
diff_counter[diff] = {}
|
||||
if sid not in diff_counter[diff]:
|
||||
diff_counter[diff][sid] = 0
|
||||
diff_counter[diff][sid] += 1
|
||||
Finds the corresponding matches on the fingerprinted audios for the given hashes.
|
||||
|
||||
if diff_counter[diff][sid] > largest_count:
|
||||
largest = diff
|
||||
largest_count = diff_counter[diff][sid]
|
||||
song_id = sid
|
||||
:param hashes: list of tuples for hashes and their corresponding offsets
|
||||
:return: a tuple containing the matches found against the db, a dictionary which counts the different
|
||||
hashes matched for each song (with the song id as key), and the time that the query took.
|
||||
|
||||
# extract idenfication
|
||||
song = self.db.get_song_by_id(song_id)
|
||||
if song:
|
||||
# TODO: Clarify what `get_song_by_id` should return.
|
||||
songname = song.get(Dejavu.SONG_NAME, None)
|
||||
else:
|
||||
return None
|
||||
"""
|
||||
t = time()
|
||||
matches, dedup_hashes = self.db.return_matches(hashes)
|
||||
query_time = time() - t
|
||||
|
||||
# return match info
|
||||
nseconds = round(float(largest) / fingerprint.DEFAULT_FS *
|
||||
fingerprint.DEFAULT_WINDOW_SIZE *
|
||||
fingerprint.DEFAULT_OVERLAP_RATIO, 5)
|
||||
song = {
|
||||
Dejavu.SONG_ID : song_id,
|
||||
Dejavu.SONG_NAME : songname.encode("utf8"),
|
||||
Dejavu.CONFIDENCE : largest_count,
|
||||
Dejavu.OFFSET : int(largest),
|
||||
Dejavu.OFFSET_SECS : nseconds,
|
||||
Database.FIELD_FILE_SHA1 : song.get(Database.FIELD_FILE_SHA1, None).encode("utf8"),}
|
||||
return song
|
||||
return matches, dedup_hashes, query_time
|
||||
|
||||
def recognize(self, recognizer, *options, **kwoptions):
|
||||
def align_matches(self, matches: List[Tuple[int, int]], dedup_hashes: Dict[str, int], queried_hashes: int,
|
||||
topn: int = TOPN) -> List[Dict[str, any]]:
|
||||
"""
|
||||
Finds hash matches that align in time with other matches and finds
|
||||
consensus about which hashes are "true" signal from the audio.
|
||||
|
||||
:param matches: matches from the database
|
||||
:param dedup_hashes: dictionary containing the hashes matched without duplicates for each song
|
||||
(key is the song id).
|
||||
:param queried_hashes: amount of hashes sent for matching against the db
|
||||
:param topn: number of results being returned back.
|
||||
:return: a list of dictionaries (based on topn) with match information.
|
||||
"""
|
||||
# count offset occurrences per song and keep only the maximum ones.
|
||||
sorted_matches = sorted(matches, key=lambda m: (m[0], m[1]))
|
||||
counts = [(*key, len(list(group))) for key, group in groupby(sorted_matches, key=lambda m: (m[0], m[1]))]
|
||||
songs_matches = sorted(
|
||||
[max(list(group), key=lambda g: g[2]) for key, group in groupby(counts, key=lambda count: count[0])],
|
||||
key=lambda count: count[2], reverse=True
|
||||
)
|
||||
|
||||
songs_result = []
|
||||
for song_id, offset, _ in songs_matches[0:topn]: # consider topn elements in the result
|
||||
song = self.db.get_song_by_id(song_id)
|
||||
|
||||
song_name = song.get(SONG_NAME, None)
|
||||
song_hashes = song.get(FIELD_TOTAL_HASHES, None)
|
||||
nseconds = round(float(offset) / DEFAULT_FS * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO, 5)
|
||||
hashes_matched = dedup_hashes[song_id]
|
||||
|
||||
song = {
|
||||
SONG_ID: song_id,
|
||||
SONG_NAME: song_name.encode("utf8"),
|
||||
INPUT_HASHES: queried_hashes,
|
||||
FINGERPRINTED_HASHES: song_hashes,
|
||||
HASHES_MATCHED: hashes_matched,
|
||||
# Percentage regarding hashes matched vs hashes from the input.
|
||||
INPUT_CONFIDENCE: round(hashes_matched / queried_hashes, 2),
|
||||
# Percentage regarding hashes matched vs hashes fingerprinted in the db.
|
||||
FINGERPRINTED_CONFIDENCE: round(hashes_matched / song_hashes, 2),
|
||||
OFFSET: offset,
|
||||
OFFSET_SECS: nseconds,
|
||||
FIELD_FILE_SHA1: song.get(FIELD_FILE_SHA1, None).encode("utf8")
|
||||
}
|
||||
|
||||
songs_result.append(song)
|
||||
|
||||
return songs_result
|
||||
|
||||
def recognize(self, recognizer, *options, **kwoptions) -> Dict[str, any]:
|
||||
r = recognizer(self)
|
||||
return r.recognize(*options, **kwoptions)
|
||||
|
||||
@staticmethod
|
||||
def _fingerprint_worker(arguments):
|
||||
# Pool.imap sends arguments as tuples so we have to unpack
|
||||
# them ourself.
|
||||
try:
|
||||
file_name, limit = arguments
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _fingerprint_worker(filename, limit=None, song_name=None):
|
||||
# Pool.imap sends arguments as tuples so we have to unpack
|
||||
# them ourself.
|
||||
try:
|
||||
filename, limit = filename
|
||||
except ValueError:
|
||||
pass
|
||||
song_name, extension = os.path.splitext(os.path.basename(file_name))
|
||||
|
||||
songname, extension = os.path.splitext(os.path.basename(filename))
|
||||
song_name = song_name or songname
|
||||
channels, Fs, file_hash = decoder.read(filename, limit)
|
||||
result = set()
|
||||
channel_amount = len(channels)
|
||||
fingerprints, file_hash = Dejavu.get_file_fingerprints(file_name, limit, print_output=True)
|
||||
|
||||
for channeln, channel in enumerate(channels):
|
||||
# TODO: Remove prints or change them into optional logging.
|
||||
print("Fingerprinting channel %d/%d for %s" % (channeln + 1,
|
||||
channel_amount,
|
||||
filename))
|
||||
hashes = fingerprint.fingerprint(channel, Fs=Fs)
|
||||
print("Finished channel %d/%d for %s" % (channeln + 1, channel_amount,
|
||||
filename))
|
||||
result |= set(hashes)
|
||||
return song_name, fingerprints, file_hash
|
||||
|
||||
return song_name, result, file_hash
|
||||
@staticmethod
|
||||
def get_file_fingerprints(file_name: str, limit: int, print_output: bool = False):
|
||||
channels, fs, file_hash = decoder.read(file_name, limit)
|
||||
fingerprints = set()
|
||||
channel_amount = len(channels)
|
||||
for channeln, channel in enumerate(channels, start=1):
|
||||
if print_output:
|
||||
print(f"Fingerprinting channel {channeln}/{channel_amount} for {file_name}")
|
||||
|
||||
hashes = fingerprint(channel, Fs=fs)
|
||||
|
||||
def chunkify(lst, n):
|
||||
"""
|
||||
Splits a list into roughly n equal parts.
|
||||
http://stackoverflow.com/questions/2130016/splitting-a-list-of-arbitrary-size-into-only-roughly-n-equal-parts
|
||||
"""
|
||||
return [lst[i::n] for i in xrange(n)]
|
||||
if print_output:
|
||||
print(f"Finished channel {channeln}/{channel_amount} for {file_name}")
|
||||
|
||||
fingerprints |= set(hashes)
|
||||
|
||||
return fingerprints, file_hash
|
||||
|
|
0
dejavu/base_classes/__init__.py
Normal file
0
dejavu/base_classes/__init__.py
Normal file
195
dejavu/base_classes/base_database.py
Executable file
195
dejavu/base_classes/base_database.py
Executable file
|
@ -0,0 +1,195 @@
|
|||
import abc
|
||||
import importlib
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from dejavu.config.settings import DATABASES
|
||||
|
||||
|
||||
class BaseDatabase(object, metaclass=abc.ABCMeta):
|
||||
# Name of your Database subclass, this is used in configuration
|
||||
# to refer to your class
|
||||
type = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def before_fork(self) -> None:
|
||||
"""
|
||||
Called before the database instance is given to the new process
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_fork(self) -> None:
|
||||
"""
|
||||
Called after the database instance has been given to the new process
|
||||
|
||||
This will be called in the new process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def setup(self) -> None:
|
||||
"""
|
||||
Called on creation or shortly afterwards.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def empty(self) -> None:
|
||||
"""
|
||||
Called when the database should be cleared of all data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_unfingerprinted_songs(self) -> None:
|
||||
"""
|
||||
Called to remove any song entries that do not have any fingerprints
|
||||
associated with them.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_num_songs(self) -> int:
|
||||
"""
|
||||
Returns the song's count stored.
|
||||
|
||||
:return: the amount of songs in the database.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_num_fingerprints(self) -> int:
|
||||
"""
|
||||
Returns the fingerprints' count stored.
|
||||
|
||||
:return: the number of fingerprints in the database.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_song_fingerprinted(self, song_id: int):
|
||||
"""
|
||||
Sets a specific song as having all fingerprints in the database.
|
||||
|
||||
:param song_id: song identifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_songs(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Returns all fully fingerprinted songs in the database
|
||||
|
||||
:return: a dictionary with the songs info.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_song_by_id(self, song_id: int) -> Dict[str, str]:
|
||||
"""
|
||||
Brings the song info from the database.
|
||||
|
||||
:param song_id: song identifier.
|
||||
:return: a song by its identifier. Result must be a Dictionary.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert(self, fingerprint: str, song_id: int, offset: int):
|
||||
"""
|
||||
Inserts a single fingerprint into the database.
|
||||
|
||||
:param fingerprint: Part of a sha1 hash, in hexadecimal format
|
||||
:param song_id: Song identifier this fingerprint is off
|
||||
:param offset: The offset this fingerprint is from.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int:
|
||||
"""
|
||||
Inserts a song name into the database, returns the new
|
||||
identifier of the song.
|
||||
|
||||
:param song_name: The name of the song.
|
||||
:param file_hash: Hash from the fingerprinted file.
|
||||
:param total_hashes: amount of hashes to be inserted on fingerprint table.
|
||||
:return: the inserted id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def query(self, fingerprint: str = None) -> List[Tuple]:
|
||||
"""
|
||||
Returns all matching fingerprint entries associated with
|
||||
the given hash as parameter, if None is passed it returns all entries.
|
||||
|
||||
:param fingerprint: part of a sha1 hash, in hexadecimal format
|
||||
:return: a list of fingerprint records stored in the db.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_iterable_kv_pairs(self) -> List[Tuple]:
|
||||
"""
|
||||
Returns all fingerprints in the database.
|
||||
|
||||
:return: a list containing all fingerprints stored in the db.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_hashes(self, song_id: int, hashes: List[Tuple[str, int]], batch_size: int = 1000) -> None:
|
||||
"""
|
||||
Insert a multitude of fingerprints.
|
||||
|
||||
:param song_id: Song identifier the fingerprints belong to
|
||||
:param hashes: A sequence of tuples in the format (hash, offset)
|
||||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
:param batch_size: insert batches.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def return_matches(self, hashes: List[Tuple[str, int]], batch_size: int = 1000) \
|
||||
-> Tuple[List[Tuple[int, int]], Dict[int, int]]:
|
||||
"""
|
||||
Searches the database for pairs of (hash, offset) values.
|
||||
|
||||
:param hashes: A sequence of tuples in the format (hash, offset)
|
||||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
:param batch_size: number of query's batches.
|
||||
:return: a list of (sid, offset_difference) tuples and a
|
||||
dictionary with the amount of hashes matched (not considering
|
||||
duplicated hashes) in each song.
|
||||
- song id: Song identifier
|
||||
- offset_difference: (database_offset - sampled_offset)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
|
||||
"""
|
||||
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
|
||||
|
||||
:param song_ids: song ids to be deleted from the database.
|
||||
:param batch_size: number of query's batches.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_database(database_type: str = "mysql") -> BaseDatabase:
|
||||
"""
|
||||
Given a database type it returns a database instance for that type.
|
||||
|
||||
:param database_type: type of the database.
|
||||
:return: an instance of BaseDatabase depending on given database_type.
|
||||
"""
|
||||
try:
|
||||
path, db_class_name = DATABASES[database_type]
|
||||
db_module = importlib.import_module(path)
|
||||
db_class = getattr(db_module, db_class_name)
|
||||
return db_class
|
||||
except (ImportError, KeyError):
|
||||
raise TypeError("Unsupported database type supplied.")
|
33
dejavu/base_classes/base_recognizer.py
Normal file
33
dejavu/base_classes/base_recognizer.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
import abc
|
||||
from time import time
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from dejavu.config.settings import DEFAULT_FS
|
||||
|
||||
|
||||
class BaseRecognizer(object, metaclass=abc.ABCMeta):
|
||||
def __init__(self, dejavu):
|
||||
self.dejavu = dejavu
|
||||
self.Fs = DEFAULT_FS
|
||||
|
||||
def _recognize(self, *data) -> Tuple[List[Dict[str, any]], int, int, int]:
|
||||
fingerprint_times = []
|
||||
hashes = set() # to remove possible duplicated fingerprints we built a set.
|
||||
for channel in data:
|
||||
fingerprints, fingerprint_time = self.dejavu.generate_fingerprints(channel, Fs=self.Fs)
|
||||
fingerprint_times.append(fingerprint_time)
|
||||
hashes |= set(fingerprints)
|
||||
|
||||
matches, dedup_hashes, query_time = self.dejavu.find_matches(hashes)
|
||||
|
||||
t = time()
|
||||
final_results = self.dejavu.align_matches(matches, dedup_hashes, len(hashes))
|
||||
align_time = time() - t
|
||||
|
||||
return final_results, np.sum(fingerprint_times), query_time, align_time
|
||||
|
||||
@abc.abstractmethod
|
||||
def recognize(self) -> Dict[str, any]:
|
||||
pass # base class does nothing
|
232
dejavu/base_classes/common_database.py
Normal file
232
dejavu/base_classes/common_database.py
Normal file
|
@ -0,0 +1,232 @@
|
|||
import abc
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from dejavu.base_classes.base_database import BaseDatabase
|
||||
|
||||
|
||||
class CommonDatabase(BaseDatabase, metaclass=abc.ABCMeta):
|
||||
# Since several methods across different databases are actually just the same
|
||||
# I've built this class with the idea to reuse that logic instead of copy pasting
|
||||
# over and over the same code.
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def before_fork(self) -> None:
|
||||
"""
|
||||
Called before the database instance is given to the new process
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_fork(self) -> None:
|
||||
"""
|
||||
Called after the database instance has been given to the new process
|
||||
|
||||
This will be called in the new process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def setup(self) -> None:
|
||||
"""
|
||||
Called on creation or shortly afterwards.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.CREATE_SONGS_TABLE)
|
||||
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
|
||||
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def empty(self) -> None:
|
||||
"""
|
||||
Called when the database should be cleared of all data.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.DROP_FINGERPRINTS)
|
||||
cur.execute(self.DROP_SONGS)
|
||||
|
||||
self.setup()
|
||||
|
||||
def delete_unfingerprinted_songs(self) -> None:
|
||||
"""
|
||||
Called to remove any song entries that do not have any fingerprints
|
||||
associated with them.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def get_num_songs(self) -> int:
|
||||
"""
|
||||
Returns the song's count stored.
|
||||
|
||||
:return: the amount of songs in the database.
|
||||
"""
|
||||
with self.cursor(buffered=True) as cur:
|
||||
cur.execute(self.SELECT_UNIQUE_SONG_IDS)
|
||||
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
||||
|
||||
return count
|
||||
|
||||
def get_num_fingerprints(self) -> int:
|
||||
"""
|
||||
Returns the fingerprints' count stored.
|
||||
|
||||
:return: the number of fingerprints in the database.
|
||||
"""
|
||||
with self.cursor(buffered=True) as cur:
|
||||
cur.execute(self.SELECT_NUM_FINGERPRINTS)
|
||||
count = cur.fetchone()[0] if cur.rowcount != 0 else 0
|
||||
|
||||
return count
|
||||
|
||||
def set_song_fingerprinted(self, song_id):
|
||||
"""
|
||||
Sets a specific song as having all fingerprints in the database.
|
||||
|
||||
:param song_id: song identifier.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.UPDATE_SONG_FINGERPRINTED, (song_id,))
|
||||
|
||||
def get_songs(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Returns all fully fingerprinted songs in the database
|
||||
|
||||
:return: a dictionary with the songs info.
|
||||
"""
|
||||
with self.cursor(dictionary=True) as cur:
|
||||
cur.execute(self.SELECT_SONGS)
|
||||
return list(cur)
|
||||
|
||||
def get_song_by_id(self, song_id: int) -> Dict[str, str]:
|
||||
"""
|
||||
Brings the song info from the database.
|
||||
|
||||
:param song_id: song identifier.
|
||||
:return: a song by its identifier. Result must be a Dictionary.
|
||||
"""
|
||||
with self.cursor(dictionary=True) as cur:
|
||||
cur.execute(self.SELECT_SONG, (song_id,))
|
||||
return cur.fetchone()
|
||||
|
||||
def insert(self, fingerprint: str, song_id: int, offset: int):
|
||||
"""
|
||||
Inserts a single fingerprint into the database.
|
||||
|
||||
:param fingerprint: Part of a sha1 hash, in hexadecimal format
|
||||
:param song_id: Song identifier this fingerprint is off
|
||||
:param offset: The offset this fingerprint is from.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.INSERT_FINGERPRINT, (fingerprint, song_id, offset))
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int:
|
||||
"""
|
||||
Inserts a song name into the database, returns the new
|
||||
identifier of the song.
|
||||
|
||||
:param song_name: The name of the song.
|
||||
:param file_hash: Hash from the fingerprinted file.
|
||||
:param total_hashes: amount of hashes to be inserted on fingerprint table.
|
||||
:return: the inserted id.
|
||||
"""
|
||||
pass
|
||||
|
||||
def query(self, fingerprint: str = None) -> List[Tuple]:
|
||||
"""
|
||||
Returns all matching fingerprint entries associated with
|
||||
the given hash as parameter, if None is passed it returns all entries.
|
||||
|
||||
:param fingerprint: part of a sha1 hash, in hexadecimal format
|
||||
:return: a list of fingerprint records stored in the db.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
if fingerprint:
|
||||
cur.execute(self.SELECT, (fingerprint,))
|
||||
else: # select all if no key
|
||||
cur.execute(self.SELECT_ALL)
|
||||
return list(cur)
|
||||
|
||||
def get_iterable_kv_pairs(self) -> List[Tuple]:
|
||||
"""
|
||||
Returns all fingerprints in the database.
|
||||
|
||||
:return: a list containing all fingerprints stored in the db.
|
||||
"""
|
||||
return self.query(None)
|
||||
|
||||
def insert_hashes(self, song_id: int, hashes: List[Tuple[str, int]], batch_size: int = 1000) -> None:
|
||||
"""
|
||||
Insert a multitude of fingerprints.
|
||||
|
||||
:param song_id: Song identifier the fingerprints belong to
|
||||
:param hashes: A sequence of tuples in the format (hash, offset)
|
||||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
:param batch_size: insert batches.
|
||||
"""
|
||||
values = [(song_id, hsh, int(offset)) for hsh, offset in hashes]
|
||||
|
||||
with self.cursor() as cur:
|
||||
for index in range(0, len(hashes), batch_size):
|
||||
cur.executemany(self.INSERT_FINGERPRINT, values[index: index + batch_size])
|
||||
|
||||
def return_matches(self, hashes: List[Tuple[str, int]],
|
||||
batch_size: int = 1000) -> Tuple[List[Tuple[int, int]], Dict[int, int]]:
|
||||
"""
|
||||
Searches the database for pairs of (hash, offset) values.
|
||||
|
||||
:param hashes: A sequence of tuples in the format (hash, offset)
|
||||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
:param batch_size: number of query's batches.
|
||||
:return: a list of (sid, offset_difference) tuples and a
|
||||
dictionary with the amount of hashes matched (not considering
|
||||
duplicated hashes) in each song.
|
||||
- song id: Song identifier
|
||||
- offset_difference: (database_offset - sampled_offset)
|
||||
"""
|
||||
# Create a dictionary of hash => offset pairs for later lookups
|
||||
mapper = {}
|
||||
for hsh, offset in hashes:
|
||||
if hsh.upper() in mapper.keys():
|
||||
mapper[hsh.upper()].append(offset)
|
||||
else:
|
||||
mapper[hsh.upper()] = [offset]
|
||||
|
||||
values = list(mapper.keys())
|
||||
|
||||
# in order to count each hash only once per db offset we use the dic below
|
||||
dedup_hashes = {}
|
||||
|
||||
results = []
|
||||
with self.cursor() as cur:
|
||||
for index in range(0, len(values), batch_size):
|
||||
# Create our IN part of the query
|
||||
query = self.SELECT_MULTIPLE % ', '.join([self.IN_MATCH] * len(values[index: index + batch_size]))
|
||||
|
||||
cur.execute(query, values[index: index + batch_size])
|
||||
|
||||
for hsh, sid, offset in cur:
|
||||
if sid not in dedup_hashes.keys():
|
||||
dedup_hashes[sid] = 1
|
||||
else:
|
||||
dedup_hashes[sid] += 1
|
||||
# we now evaluate all offset for each hash matched
|
||||
for song_sampled_offset in mapper[hsh]:
|
||||
results.append((sid, offset - song_sampled_offset))
|
||||
|
||||
return results, dedup_hashes
|
||||
|
||||
def delete_songs_by_id(self, song_ids: List[int], batch_size: int = 1000) -> None:
|
||||
"""
|
||||
Given a list of song ids it deletes all songs specified and their corresponding fingerprints.
|
||||
|
||||
:param song_ids: song ids to be deleted from the database.
|
||||
:param batch_size: number of query's batches.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
for index in range(0, len(song_ids), batch_size):
|
||||
# Create our IN part of the query
|
||||
query = self.DELETE_SONGS % ', '.join(['%s'] * len(song_ids[index: index + batch_size]))
|
||||
|
||||
cur.execute(query, song_ids[index: index + batch_size])
|
0
dejavu/config/__init__.py
Normal file
0
dejavu/config/__init__.py
Normal file
102
dejavu/config/settings.py
Normal file
102
dejavu/config/settings.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
# Dejavu
|
||||
|
||||
# DEJAVU JSON RESPONSE
|
||||
SONG_ID = "song_id"
|
||||
SONG_NAME = 'song_name'
|
||||
RESULTS = 'results'
|
||||
|
||||
HASHES_MATCHED = 'hashes_matched_in_input'
|
||||
|
||||
# Hashes fingerprinted in the db.
|
||||
FINGERPRINTED_HASHES = 'fingerprinted_hashes_in_db'
|
||||
# Percentage regarding hashes matched vs hashes fingerprinted in the db.
|
||||
FINGERPRINTED_CONFIDENCE = 'fingerprinted_confidence'
|
||||
|
||||
# Hashes generated from the input.
|
||||
INPUT_HASHES = 'input_total_hashes'
|
||||
# Percentage regarding hashes matched vs hashes from the input.
|
||||
INPUT_CONFIDENCE = 'input_confidence'
|
||||
|
||||
TOTAL_TIME = 'total_time'
|
||||
FINGERPRINT_TIME = 'fingerprint_time'
|
||||
QUERY_TIME = 'query_time'
|
||||
ALIGN_TIME = 'align_time'
|
||||
OFFSET = 'offset'
|
||||
OFFSET_SECS = 'offset_seconds'
|
||||
|
||||
# DATABASE CLASS INSTANCES:
|
||||
DATABASES = {
|
||||
'mysql': ("dejavu.database_handler.mysql_database", "MySQLDatabase"),
|
||||
'postgres': ("dejavu.database_handler.postgres_database", "PostgreSQLDatabase")
|
||||
}
|
||||
|
||||
# TABLE SONGS
|
||||
SONGS_TABLENAME = "songs"
|
||||
|
||||
# SONGS FIELDS
|
||||
FIELD_SONG_ID = 'song_id'
|
||||
FIELD_SONGNAME = 'song_name'
|
||||
FIELD_FINGERPRINTED = "fingerprinted"
|
||||
FIELD_FILE_SHA1 = 'file_sha1'
|
||||
FIELD_TOTAL_HASHES = 'total_hashes'
|
||||
|
||||
# TABLE FINGERPRINTS
|
||||
FINGERPRINTS_TABLENAME = "fingerprints"
|
||||
|
||||
# FINGERPRINTS FIELDS
|
||||
FIELD_HASH = 'hash'
|
||||
FIELD_OFFSET = 'offset'
|
||||
|
||||
# FINGERPRINTS CONFIG:
|
||||
# This is used as connectivity parameter for scipy.generate_binary_structure function. This parameter
|
||||
# changes the morphology mask when looking for maximum peaks on the spectrogram matrix.
|
||||
# Possible values are: [1, 2]
|
||||
# Where 1 sets a diamond morphology which implies that diagonal elements are not considered as neighbors (this
|
||||
# is the value used in the original dejavu code).
|
||||
# And 2 sets a square mask, i.e. all elements are considered neighbors.
|
||||
CONNECTIVITY_MASK = 2
|
||||
|
||||
# Sampling rate, related to the Nyquist conditions, which affects
|
||||
# the range frequencies we can detect.
|
||||
DEFAULT_FS = 44100
|
||||
|
||||
# Size of the FFT window, affects frequency granularity
|
||||
DEFAULT_WINDOW_SIZE = 4096
|
||||
|
||||
# Ratio by which each sequential window overlaps the last and the
|
||||
# next window. Higher overlap will allow a higher granularity of offset
|
||||
# matching, but potentially more fingerprints.
|
||||
DEFAULT_OVERLAP_RATIO = 0.5
|
||||
|
||||
# Degree to which a fingerprint can be paired with its neighbors. Higher values will
|
||||
# cause more fingerprints, but potentially better accuracy.
|
||||
DEFAULT_FAN_VALUE = 5 # 15 was the original value.
|
||||
|
||||
# Minimum amplitude in spectrogram in order to be considered a peak.
|
||||
# This can be raised to reduce number of fingerprints, but can negatively
|
||||
# affect accuracy.
|
||||
DEFAULT_AMP_MIN = 10
|
||||
|
||||
# Number of cells around an amplitude peak in the spectrogram in order
|
||||
# for Dejavu to consider it a spectral peak. Higher values mean less
|
||||
# fingerprints and faster matching, but can potentially affect accuracy.
|
||||
PEAK_NEIGHBORHOOD_SIZE = 10 # 20 was the original value.
|
||||
|
||||
# Thresholds on how close or far fingerprints can be in time in order
|
||||
# to be paired as a fingerprint. If your max is too low, higher values of
|
||||
# DEFAULT_FAN_VALUE may not perform as expected.
|
||||
MIN_HASH_TIME_DELTA = 0
|
||||
MAX_HASH_TIME_DELTA = 200
|
||||
|
||||
# If True, will sort peaks temporally for fingerprinting;
|
||||
# not sorting will cut down number of fingerprints, but potentially
|
||||
# affect performance.
|
||||
PEAK_SORT = True
|
||||
|
||||
# Number of bits to grab from the front of the SHA1 hash in the
|
||||
# fingerprint calculation. The more you grab, the more memory storage,
|
||||
# with potentially lesser collisions of matches.
|
||||
FINGERPRINT_REDUCTION = 20
|
||||
|
||||
# Number of results being returned for file recognition
|
||||
TOPN = 2
|
|
@ -1,176 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
import abc
|
||||
|
||||
|
||||
class Database(object):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
FIELD_FILE_SHA1 = 'file_sha1'
|
||||
FIELD_SONG_ID = 'song_id'
|
||||
FIELD_SONGNAME = 'song_name'
|
||||
FIELD_OFFSET = 'offset'
|
||||
FIELD_HASH = 'hash'
|
||||
|
||||
# Name of your Database subclass, this is used in configuration
|
||||
# to refer to your class
|
||||
type = None
|
||||
|
||||
def __init__(self):
|
||||
super(Database, self).__init__()
|
||||
|
||||
def before_fork(self):
|
||||
"""
|
||||
Called before the database instance is given to the new process
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_fork(self):
|
||||
"""
|
||||
Called after the database instance has been given to the new process
|
||||
|
||||
This will be called in the new process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
Called on creation or shortly afterwards.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def empty(self):
|
||||
"""
|
||||
Called when the database should be cleared of all data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_unfingerprinted_songs(self):
|
||||
"""
|
||||
Called to remove any song entries that do not have any fingerprints
|
||||
associated with them.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_num_songs(self):
|
||||
"""
|
||||
Returns the amount of songs in the database.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_num_fingerprints(self):
|
||||
"""
|
||||
Returns the number of fingerprints in the database.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_song_fingerprinted(self, sid):
|
||||
"""
|
||||
Sets a specific song as having all fingerprints in the database.
|
||||
|
||||
sid: Song identifier
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_songs(self):
|
||||
"""
|
||||
Returns all fully fingerprinted songs in the database.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_song_by_id(self, sid):
|
||||
"""
|
||||
Return a song by its identifier
|
||||
|
||||
sid: Song identifier
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert(self, hash, sid, offset):
|
||||
"""
|
||||
Inserts a single fingerprint into the database.
|
||||
|
||||
hash: Part of a sha1 hash, in hexadecimal format
|
||||
sid: Song identifier this fingerprint is off
|
||||
offset: The offset this hash is from
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_song(self, song_name):
|
||||
"""
|
||||
Inserts a song name into the database, returns the new
|
||||
identifier of the song.
|
||||
|
||||
song_name: The name of the song.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def query(self, hash):
|
||||
"""
|
||||
Returns all matching fingerprint entries associated with
|
||||
the given hash as parameter.
|
||||
|
||||
hash: Part of a sha1 hash, in hexadecimal format
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_iterable_kv_pairs(self):
|
||||
"""
|
||||
Returns all fingerprints in the database.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_hashes(self, sid, hashes):
|
||||
"""
|
||||
Insert a multitude of fingerprints.
|
||||
|
||||
sid: Song identifier the fingerprints belong to
|
||||
hashes: A sequence of tuples in the format (hash, offset)
|
||||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def return_matches(self, hashes):
|
||||
"""
|
||||
Searches the database for pairs of (hash, offset) values.
|
||||
|
||||
hashes: A sequence of tuples in the format (hash, offset)
|
||||
- hash: Part of a sha1 hash, in hexadecimal format
|
||||
- offset: Offset this hash was created from/at.
|
||||
|
||||
Returns a sequence of (sid, offset_difference) tuples.
|
||||
|
||||
sid: Song identifier
|
||||
offset_difference: (offset - database_offset)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_database(database_type=None):
|
||||
# Default to using the mysql database
|
||||
database_type = database_type or "mysql"
|
||||
# Lower all the input.
|
||||
database_type = database_type.lower()
|
||||
|
||||
for db_cls in Database.__subclasses__():
|
||||
if db_cls.type == database_type:
|
||||
return db_cls
|
||||
|
||||
raise TypeError("Unsupported database type supplied.")
|
||||
|
||||
|
||||
# Import our default database handler
|
||||
import dejavu.database_sql
|
0
dejavu/database_handler/__init__.py
Normal file
0
dejavu/database_handler/__init__.py
Normal file
203
dejavu/database_handler/mysql_database.py
Executable file
203
dejavu/database_handler/mysql_database.py
Executable file
|
@ -0,0 +1,203 @@
|
|||
import queue
|
||||
|
||||
import mysql.connector
|
||||
from mysql.connector.errors import DatabaseError
|
||||
|
||||
from dejavu.base_classes.common_database import CommonDatabase
|
||||
from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
||||
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
||||
FIELD_SONGNAME, FIELD_TOTAL_HASHES,
|
||||
FINGERPRINTS_TABLENAME, SONGS_TABLENAME)
|
||||
|
||||
|
||||
class MySQLDatabase(CommonDatabase):
|
||||
type = "mysql"
|
||||
|
||||
# CREATES
|
||||
CREATE_SONGS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS `{SONGS_TABLENAME}` (
|
||||
`{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT
|
||||
, `{FIELD_SONGNAME}` VARCHAR(250) NOT NULL
|
||||
, `{FIELD_FINGERPRINTED}` TINYINT DEFAULT 0
|
||||
, `{FIELD_FILE_SHA1}` BINARY(20) NOT NULL
|
||||
, `{FIELD_TOTAL_HASHES}` INT NOT NULL DEFAULT 0
|
||||
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
, CONSTRAINT `pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}` PRIMARY KEY (`{FIELD_SONG_ID}`)
|
||||
, CONSTRAINT `uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}` UNIQUE KEY (`{FIELD_SONG_ID}`)
|
||||
) ENGINE=INNODB;
|
||||
"""
|
||||
|
||||
CREATE_FINGERPRINTS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS `{FINGERPRINTS_TABLENAME}` (
|
||||
`{FIELD_HASH}` BINARY(10) NOT NULL
|
||||
, `{FIELD_SONG_ID}` MEDIUMINT UNSIGNED NOT NULL
|
||||
, `{FIELD_OFFSET}` INT UNSIGNED NOT NULL
|
||||
, `date_created` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
, `date_modified` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
, INDEX `ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}` (`{FIELD_HASH}`)
|
||||
, CONSTRAINT `uq_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}_{FIELD_OFFSET}_{FIELD_HASH}`
|
||||
UNIQUE KEY (`{FIELD_SONG_ID}`, `{FIELD_OFFSET}`, `{FIELD_HASH}`)
|
||||
, CONSTRAINT `fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}` FOREIGN KEY (`{FIELD_SONG_ID}`)
|
||||
REFERENCES `{SONGS_TABLENAME}`(`{FIELD_SONG_ID}`) ON DELETE CASCADE
|
||||
) ENGINE=INNODB;
|
||||
"""
|
||||
|
||||
# INSERTS (IGNORES DUPLICATES)
|
||||
INSERT_FINGERPRINT = f"""
|
||||
INSERT IGNORE INTO `{FINGERPRINTS_TABLENAME}` (
|
||||
`{FIELD_SONG_ID}`
|
||||
, `{FIELD_HASH}`
|
||||
, `{FIELD_OFFSET}`)
|
||||
VALUES (%s, UNHEX(%s), %s);
|
||||
"""
|
||||
|
||||
INSERT_SONG = f"""
|
||||
INSERT INTO `{SONGS_TABLENAME}` (`{FIELD_SONGNAME}`,`{FIELD_FILE_SHA1}`,`{FIELD_TOTAL_HASHES}`)
|
||||
VALUES (%s, UNHEX(%s), %s);
|
||||
"""
|
||||
|
||||
# SELECTS
|
||||
SELECT = f"""
|
||||
SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
||||
FROM `{FINGERPRINTS_TABLENAME}`
|
||||
WHERE `{FIELD_HASH}` = UNHEX(%s);
|
||||
"""
|
||||
|
||||
SELECT_MULTIPLE = f"""
|
||||
SELECT HEX(`{FIELD_HASH}`), `{FIELD_SONG_ID}`, `{FIELD_OFFSET}`
|
||||
FROM `{FINGERPRINTS_TABLENAME}`
|
||||
WHERE `{FIELD_HASH}` IN (%s);
|
||||
"""
|
||||
|
||||
SELECT_ALL = f"SELECT `{FIELD_SONG_ID}`, `{FIELD_OFFSET}` FROM `{FINGERPRINTS_TABLENAME}`;"
|
||||
|
||||
SELECT_SONG = f"""
|
||||
SELECT `{FIELD_SONGNAME}`, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`, `{FIELD_TOTAL_HASHES}`
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_SONG_ID}` = %s;
|
||||
"""
|
||||
|
||||
SELECT_NUM_FINGERPRINTS = f"SELECT COUNT(*) AS n FROM `{FINGERPRINTS_TABLENAME}`;"
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = f"""
|
||||
SELECT COUNT(`{FIELD_SONG_ID}`) AS n
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
||||
"""
|
||||
|
||||
SELECT_SONGS = f"""
|
||||
SELECT
|
||||
`{FIELD_SONG_ID}`
|
||||
, `{FIELD_SONGNAME}`
|
||||
, HEX(`{FIELD_FILE_SHA1}`) AS `{FIELD_FILE_SHA1}`
|
||||
, `{FIELD_TOTAL_HASHES}`
|
||||
, `date_created`
|
||||
FROM `{SONGS_TABLENAME}`
|
||||
WHERE `{FIELD_FINGERPRINTED}` = 1;
|
||||
"""
|
||||
|
||||
# DROPS
|
||||
DROP_FINGERPRINTS = f"DROP TABLE IF EXISTS `{FINGERPRINTS_TABLENAME}`;"
|
||||
DROP_SONGS = f"DROP TABLE IF EXISTS `{SONGS_TABLENAME}`;"
|
||||
|
||||
# UPDATE
|
||||
UPDATE_SONG_FINGERPRINTED = f"""
|
||||
UPDATE `{SONGS_TABLENAME}` SET `{FIELD_FINGERPRINTED}` = 1 WHERE `{FIELD_SONG_ID}` = %s;
|
||||
"""
|
||||
|
||||
# DELETES
|
||||
DELETE_UNFINGERPRINTED = f"""
|
||||
DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_FINGERPRINTED}` = 0;
|
||||
"""
|
||||
|
||||
DELETE_SONGS = f"""
|
||||
DELETE FROM `{SONGS_TABLENAME}` WHERE `{FIELD_SONG_ID}` IN (%s);
|
||||
"""
|
||||
|
||||
# IN
|
||||
IN_MATCH = f"UNHEX(%s)"
|
||||
|
||||
def __init__(self, **options):
|
||||
super().__init__()
|
||||
self.cursor = cursor_factory(**options)
|
||||
self._options = options
|
||||
|
||||
def after_fork(self) -> None:
|
||||
# Clear the cursor cache, we don't want any stale connections from
|
||||
# the previous process.
|
||||
Cursor.clear_cache()
|
||||
|
||||
def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int:
|
||||
"""
|
||||
Inserts a song name into the database, returns the new
|
||||
identifier of the song.
|
||||
|
||||
:param song_name: The name of the song.
|
||||
:param file_hash: Hash from the fingerprinted file.
|
||||
:param total_hashes: amount of hashes to be inserted on fingerprint table.
|
||||
:return: the inserted id.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.INSERT_SONG, (song_name, file_hash, total_hashes))
|
||||
return cur.lastrowid
|
||||
|
||||
def __getstate__(self):
|
||||
return self._options,
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._options, = state
|
||||
self.cursor = cursor_factory(**self._options)
|
||||
|
||||
|
||||
def cursor_factory(**factory_options):
|
||||
def cursor(**options):
|
||||
options.update(factory_options)
|
||||
return Cursor(**options)
|
||||
return cursor
|
||||
|
||||
|
||||
class Cursor(object):
|
||||
"""
|
||||
Establishes a connection to the database and returns an open cursor.
|
||||
# Use as context manager
|
||||
with Cursor() as cur:
|
||||
cur.execute(query)
|
||||
...
|
||||
"""
|
||||
def __init__(self, dictionary=False, **options):
|
||||
super().__init__()
|
||||
|
||||
self._cache = queue.Queue(maxsize=5)
|
||||
|
||||
try:
|
||||
conn = self._cache.get_nowait()
|
||||
# Ping the connection before using it from the cache.
|
||||
conn.ping(True)
|
||||
except queue.Empty:
|
||||
conn = mysql.connector.connect(**options)
|
||||
|
||||
self.conn = conn
|
||||
self.dictionary = dictionary
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
cls._cache = queue.Queue(maxsize=5)
|
||||
|
||||
def __enter__(self):
|
||||
self.cursor = self.conn.cursor(dictionary=self.dictionary)
|
||||
return self.cursor
|
||||
|
||||
def __exit__(self, extype, exvalue, traceback):
|
||||
# if we had a MySQL related error we try to rollback the cursor.
|
||||
if extype is DatabaseError:
|
||||
self.cursor.rollback()
|
||||
|
||||
self.cursor.close()
|
||||
self.conn.commit()
|
||||
|
||||
# Put it back on the queue
|
||||
try:
|
||||
self._cache.put_nowait(self.conn)
|
||||
except queue.Full:
|
||||
self.conn.close()
|
219
dejavu/database_handler/postgres_database.py
Executable file
219
dejavu/database_handler/postgres_database.py
Executable file
|
@ -0,0 +1,219 @@
|
|||
import queue
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import DictCursor
|
||||
|
||||
from dejavu.base_classes.common_database import CommonDatabase
|
||||
from dejavu.config.settings import (FIELD_FILE_SHA1, FIELD_FINGERPRINTED,
|
||||
FIELD_HASH, FIELD_OFFSET, FIELD_SONG_ID,
|
||||
FIELD_SONGNAME, FIELD_TOTAL_HASHES,
|
||||
FINGERPRINTS_TABLENAME, SONGS_TABLENAME)
|
||||
|
||||
|
||||
class PostgreSQLDatabase(CommonDatabase):
|
||||
type = "postgres"
|
||||
|
||||
# CREATES
|
||||
CREATE_SONGS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS "{SONGS_TABLENAME}" (
|
||||
"{FIELD_SONG_ID}" SERIAL
|
||||
, "{FIELD_SONGNAME}" VARCHAR(250) NOT NULL
|
||||
, "{FIELD_FINGERPRINTED}" SMALLINT DEFAULT 0
|
||||
, "{FIELD_FILE_SHA1}" BYTEA
|
||||
, "{FIELD_TOTAL_HASHES}" INT NOT NULL DEFAULT 0
|
||||
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, CONSTRAINT "pk_{SONGS_TABLENAME}_{FIELD_SONG_ID}" PRIMARY KEY ("{FIELD_SONG_ID}")
|
||||
, CONSTRAINT "uq_{SONGS_TABLENAME}_{FIELD_SONG_ID}" UNIQUE ("{FIELD_SONG_ID}")
|
||||
);
|
||||
"""
|
||||
|
||||
CREATE_FINGERPRINTS_TABLE = f"""
|
||||
CREATE TABLE IF NOT EXISTS "{FINGERPRINTS_TABLENAME}" (
|
||||
"{FIELD_HASH}" BYTEA NOT NULL
|
||||
, "{FIELD_SONG_ID}" INT NOT NULL
|
||||
, "{FIELD_OFFSET}" INT NOT NULL
|
||||
, "date_created" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, "date_modified" TIMESTAMP NOT NULL DEFAULT now()
|
||||
, CONSTRAINT "uq_{FINGERPRINTS_TABLENAME}" UNIQUE ("{FIELD_SONG_ID}", "{FIELD_OFFSET}", "{FIELD_HASH}")
|
||||
, CONSTRAINT "fk_{FINGERPRINTS_TABLENAME}_{FIELD_SONG_ID}" FOREIGN KEY ("{FIELD_SONG_ID}")
|
||||
REFERENCES "{SONGS_TABLENAME}"("{FIELD_SONG_ID}") ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}"
|
||||
USING hash ("{FIELD_HASH}");
|
||||
"""
|
||||
|
||||
CREATE_FINGERPRINTS_TABLE_INDEX = f"""
|
||||
CREATE INDEX "ix_{FINGERPRINTS_TABLENAME}_{FIELD_HASH}" ON "{FINGERPRINTS_TABLENAME}"
|
||||
USING hash ("{FIELD_HASH}");
|
||||
"""
|
||||
|
||||
# INSERTS (IGNORES DUPLICATES)
|
||||
INSERT_FINGERPRINT = f"""
|
||||
INSERT INTO "{FINGERPRINTS_TABLENAME}" (
|
||||
"{FIELD_SONG_ID}"
|
||||
, "{FIELD_HASH}"
|
||||
, "{FIELD_OFFSET}")
|
||||
VALUES (%s, decode(%s, 'hex'), %s) ON CONFLICT DO NOTHING;
|
||||
"""
|
||||
|
||||
INSERT_SONG = f"""
|
||||
INSERT INTO "{SONGS_TABLENAME}" ("{FIELD_SONGNAME}", "{FIELD_FILE_SHA1}","{FIELD_TOTAL_HASHES}")
|
||||
VALUES (%s, decode(%s, 'hex'), %s)
|
||||
RETURNING "{FIELD_SONG_ID}";
|
||||
"""
|
||||
|
||||
# SELECTS
|
||||
SELECT = f"""
|
||||
SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
||||
FROM "{FINGERPRINTS_TABLENAME}"
|
||||
WHERE "{FIELD_HASH}" = decode(%s, 'hex');
|
||||
"""
|
||||
|
||||
SELECT_MULTIPLE = f"""
|
||||
SELECT upper(encode("{FIELD_HASH}", 'hex')), "{FIELD_SONG_ID}", "{FIELD_OFFSET}"
|
||||
FROM "{FINGERPRINTS_TABLENAME}"
|
||||
WHERE "{FIELD_HASH}" IN (%s);
|
||||
"""
|
||||
|
||||
SELECT_ALL = f'SELECT "{FIELD_SONG_ID}", "{FIELD_OFFSET}" FROM "{FINGERPRINTS_TABLENAME}";'
|
||||
|
||||
SELECT_SONG = f"""
|
||||
SELECT
|
||||
"{FIELD_SONGNAME}"
|
||||
, upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
||||
, "{FIELD_TOTAL_HASHES}"
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_SONG_ID}" = %s;
|
||||
"""
|
||||
|
||||
SELECT_NUM_FINGERPRINTS = f'SELECT COUNT(*) AS n FROM "{FINGERPRINTS_TABLENAME}";'
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = f"""
|
||||
SELECT COUNT("{FIELD_SONG_ID}") AS n
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
||||
"""
|
||||
|
||||
SELECT_SONGS = f"""
|
||||
SELECT
|
||||
"{FIELD_SONG_ID}"
|
||||
, "{FIELD_SONGNAME}"
|
||||
, upper(encode("{FIELD_FILE_SHA1}", 'hex')) AS "{FIELD_FILE_SHA1}"
|
||||
, "{FIELD_TOTAL_HASHES}"
|
||||
, "date_created"
|
||||
FROM "{SONGS_TABLENAME}"
|
||||
WHERE "{FIELD_FINGERPRINTED}" = 1;
|
||||
"""
|
||||
|
||||
# DROPS
|
||||
DROP_FINGERPRINTS = F'DROP TABLE IF EXISTS "{FINGERPRINTS_TABLENAME}";'
|
||||
DROP_SONGS = F'DROP TABLE IF EXISTS "{SONGS_TABLENAME}";'
|
||||
|
||||
# UPDATE
|
||||
UPDATE_SONG_FINGERPRINTED = f"""
|
||||
UPDATE "{SONGS_TABLENAME}" SET
|
||||
"{FIELD_FINGERPRINTED}" = 1
|
||||
, "date_modified" = now()
|
||||
WHERE "{FIELD_SONG_ID}" = %s;
|
||||
"""
|
||||
|
||||
# DELETES
|
||||
DELETE_UNFINGERPRINTED = f"""
|
||||
DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_FINGERPRINTED}" = 0;
|
||||
"""
|
||||
|
||||
DELETE_SONGS = f"""
|
||||
DELETE FROM "{SONGS_TABLENAME}" WHERE "{FIELD_SONG_ID}" IN (%s);
|
||||
"""
|
||||
|
||||
# IN
|
||||
IN_MATCH = f"decode(%s, 'hex')"
|
||||
|
||||
def __init__(self, **options):
|
||||
super().__init__()
|
||||
self.cursor = cursor_factory(**options)
|
||||
self._options = options
|
||||
|
||||
def after_fork(self) -> None:
|
||||
# Clear the cursor cache, we don't want any stale connections from
|
||||
# the previous process.
|
||||
Cursor.clear_cache()
|
||||
|
||||
def insert_song(self, song_name: str, file_hash: str, total_hashes: int) -> int:
|
||||
"""
|
||||
Inserts a song name into the database, returns the new
|
||||
identifier of the song.
|
||||
|
||||
:param song_name: The name of the song.
|
||||
:param file_hash: Hash from the fingerprinted file.
|
||||
:param total_hashes: amount of hashes to be inserted on fingerprint table.
|
||||
:return: the inserted id.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute(self.INSERT_SONG, (song_name, file_hash, total_hashes))
|
||||
return cur.fetchone()[0]
|
||||
|
||||
def __getstate__(self):
|
||||
return self._options,
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._options, = state
|
||||
self.cursor = cursor_factory(**self._options)
|
||||
|
||||
|
||||
def cursor_factory(**factory_options):
|
||||
def cursor(**options):
|
||||
options.update(factory_options)
|
||||
return Cursor(**options)
|
||||
return cursor
|
||||
|
||||
|
||||
class Cursor(object):
|
||||
"""
|
||||
Establishes a connection to the database and returns an open cursor.
|
||||
# Use as context manager
|
||||
with Cursor() as cur:
|
||||
cur.execute(query)
|
||||
...
|
||||
"""
|
||||
def __init__(self, dictionary=False, **options):
|
||||
super().__init__()
|
||||
|
||||
self._cache = queue.Queue(maxsize=5)
|
||||
|
||||
try:
|
||||
conn = self._cache.get_nowait()
|
||||
# Ping the connection before using it from the cache.
|
||||
conn.ping(True)
|
||||
except queue.Empty:
|
||||
conn = psycopg2.connect(**options)
|
||||
|
||||
self.conn = conn
|
||||
self.dictionary = dictionary
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
cls._cache = queue.Queue(maxsize=5)
|
||||
|
||||
def __enter__(self):
|
||||
if self.dictionary:
|
||||
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
|
||||
else:
|
||||
self.cursor = self.conn.cursor()
|
||||
return self.cursor
|
||||
|
||||
def __exit__(self, extype, exvalue, traceback):
|
||||
# if we had a PostgreSQL related error we try to rollback the cursor.
|
||||
if extype is psycopg2.DatabaseError:
|
||||
self.cursor.rollback()
|
||||
|
||||
self.cursor.close()
|
||||
self.conn.commit()
|
||||
|
||||
# Put it back on the queue
|
||||
try:
|
||||
self._cache.put_nowait(self.conn)
|
||||
except queue.Full:
|
||||
self.conn.close()
|
|
@ -1,373 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
from itertools import izip_longest
|
||||
import Queue
|
||||
|
||||
import MySQLdb as mysql
|
||||
from MySQLdb.cursors import DictCursor
|
||||
|
||||
from dejavu.database import Database
|
||||
|
||||
|
||||
class SQLDatabase(Database):
|
||||
"""
|
||||
Queries:
|
||||
|
||||
1) Find duplicates (shouldn't be any, though):
|
||||
|
||||
select `hash`, `song_id`, `offset`, count(*) cnt
|
||||
from fingerprints
|
||||
group by `hash`, `song_id`, `offset`
|
||||
having cnt > 1
|
||||
order by cnt asc;
|
||||
|
||||
2) Get number of hashes by song:
|
||||
|
||||
select song_id, song_name, count(song_id) as num
|
||||
from fingerprints
|
||||
natural join songs
|
||||
group by song_id
|
||||
order by count(song_id) desc;
|
||||
|
||||
3) get hashes with highest number of collisions
|
||||
|
||||
select
|
||||
hash,
|
||||
count(distinct song_id) as n
|
||||
from fingerprints
|
||||
group by `hash`
|
||||
order by n DESC;
|
||||
|
||||
=> 26 different songs with same fingerprint (392 times):
|
||||
|
||||
select songs.song_name, fingerprints.offset
|
||||
from fingerprints natural join songs
|
||||
where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
|
||||
"""
|
||||
|
||||
type = "mysql"
|
||||
|
||||
# tables
|
||||
FINGERPRINTS_TABLENAME = "fingerprints"
|
||||
SONGS_TABLENAME = "songs"
|
||||
|
||||
# fields
|
||||
FIELD_FINGERPRINTED = "fingerprinted"
|
||||
|
||||
# creates
|
||||
CREATE_FINGERPRINTS_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS `%s` (
|
||||
`%s` binary(10) not null,
|
||||
`%s` mediumint unsigned not null,
|
||||
`%s` int unsigned not null,
|
||||
INDEX (%s),
|
||||
UNIQUE KEY `unique_constraint` (%s, %s, %s),
|
||||
FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
|
||||
) ENGINE=INNODB;""" % (
|
||||
FINGERPRINTS_TABLENAME, Database.FIELD_HASH,
|
||||
Database.FIELD_SONG_ID, Database.FIELD_OFFSET, Database.FIELD_HASH,
|
||||
Database.FIELD_SONG_ID, Database.FIELD_OFFSET, Database.FIELD_HASH,
|
||||
Database.FIELD_SONG_ID, SONGS_TABLENAME, Database.FIELD_SONG_ID
|
||||
)
|
||||
|
||||
CREATE_SONGS_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS `%s` (
|
||||
`%s` mediumint unsigned not null auto_increment,
|
||||
`%s` varchar(250) not null,
|
||||
`%s` tinyint default 0,
|
||||
`%s` binary(20) not null,
|
||||
PRIMARY KEY (`%s`),
|
||||
UNIQUE KEY `%s` (`%s`)
|
||||
) ENGINE=INNODB;""" % (
|
||||
SONGS_TABLENAME, Database.FIELD_SONG_ID, Database.FIELD_SONGNAME, FIELD_FINGERPRINTED,
|
||||
Database.FIELD_FILE_SHA1,
|
||||
Database.FIELD_SONG_ID, Database.FIELD_SONG_ID, Database.FIELD_SONG_ID,
|
||||
)
|
||||
|
||||
# inserts (ignores duplicates)
|
||||
INSERT_FINGERPRINT = """
|
||||
INSERT IGNORE INTO %s (%s, %s, %s) values
|
||||
(UNHEX(%%s), %%s, %%s);
|
||||
""" % (FINGERPRINTS_TABLENAME, Database.FIELD_HASH, Database.FIELD_SONG_ID, Database.FIELD_OFFSET)
|
||||
|
||||
INSERT_SONG = "INSERT INTO %s (%s, %s) values (%%s, UNHEX(%%s));" % (
|
||||
SONGS_TABLENAME, Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1)
|
||||
|
||||
# selects
|
||||
SELECT = """
|
||||
SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);
|
||||
""" % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME, Database.FIELD_HASH)
|
||||
|
||||
SELECT_MULTIPLE = """
|
||||
SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s);
|
||||
""" % (Database.FIELD_HASH, Database.FIELD_SONG_ID, Database.FIELD_OFFSET,
|
||||
FINGERPRINTS_TABLENAME, Database.FIELD_HASH)
|
||||
|
||||
SELECT_ALL = """
|
||||
SELECT %s, %s FROM %s;
|
||||
""" % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME)
|
||||
|
||||
SELECT_SONG = """
|
||||
SELECT %s, HEX(%s) as %s FROM %s WHERE %s = %%s;
|
||||
""" % (Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1, Database.FIELD_FILE_SHA1, SONGS_TABLENAME, Database.FIELD_SONG_ID)
|
||||
|
||||
SELECT_NUM_FINGERPRINTS = """
|
||||
SELECT COUNT(*) as n FROM %s
|
||||
""" % (FINGERPRINTS_TABLENAME)
|
||||
|
||||
SELECT_UNIQUE_SONG_IDS = """
|
||||
SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;
|
||||
""" % (Database.FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
|
||||
SELECT_SONGS = """
|
||||
SELECT %s, %s, HEX(%s) as %s FROM %s WHERE %s = 1;
|
||||
""" % (Database.FIELD_SONG_ID, Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1, Database.FIELD_FILE_SHA1,
|
||||
SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
|
||||
# drops
|
||||
DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME
|
||||
DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME
|
||||
|
||||
# update
|
||||
UPDATE_SONG_FINGERPRINTED = """
|
||||
UPDATE %s SET %s = 1 WHERE %s = %%s
|
||||
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, Database.FIELD_SONG_ID)
|
||||
|
||||
# delete
|
||||
DELETE_UNFINGERPRINTED = """
|
||||
DELETE FROM %s WHERE %s = 0;
|
||||
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED)
|
||||
|
||||
def __init__(self, **options):
|
||||
super(SQLDatabase, self).__init__()
|
||||
self.cursor = cursor_factory(**options)
|
||||
self._options = options
|
||||
|
||||
def after_fork(self):
|
||||
# Clear the cursor cache, we don't want any stale connections from
|
||||
# the previous process.
|
||||
Cursor.clear_cache()
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
Creates any non-existing tables required for dejavu to function.
|
||||
|
||||
This also removes all songs that have been added but have no
|
||||
fingerprints associated with them.
|
||||
"""
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(self.CREATE_SONGS_TABLE)
|
||||
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
|
||||
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def empty(self):
|
||||
"""
|
||||
Drops tables created by dejavu and then creates them again
|
||||
by calling `SQLDatabase.setup`.
|
||||
|
||||
.. warning:
|
||||
This will result in a loss of data
|
||||
"""
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(self.DROP_FINGERPRINTS)
|
||||
cur.execute(self.DROP_SONGS)
|
||||
|
||||
self.setup()
|
||||
|
||||
def delete_unfingerprinted_songs(self):
|
||||
"""
|
||||
Removes all songs that have no fingerprints associated with them.
|
||||
"""
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(self.DELETE_UNFINGERPRINTED)
|
||||
|
||||
def get_num_songs(self):
|
||||
"""
|
||||
Returns number of songs the database has fingerprinted.
|
||||
"""
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(self.SELECT_UNIQUE_SONG_IDS)
|
||||
|
||||
for count, in cur:
|
||||
return count
|
||||
return 0
|
||||
|
||||
def get_num_fingerprints(self):
|
||||
"""
|
||||
Returns number of fingerprints the database has fingerprinted.
|
||||
"""
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(self.SELECT_NUM_FINGERPRINTS)
|
||||
|
||||
for count, in cur:
|
||||
return count
|
||||
return 0
|
||||
|
||||
def set_song_fingerprinted(self, sid):
|
||||
"""
|
||||
Set the fingerprinted flag to TRUE (1) once a song has been completely
|
||||
fingerprinted in the database.
|
||||
"""
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))
|
||||
|
||||
def get_songs(self):
|
||||
"""
|
||||
Return songs that have the fingerprinted flag set TRUE (1).
|
||||
"""
|
||||
with self.cursor(cursor_type=DictCursor, charset="utf8") as cur:
|
||||
cur.execute(self.SELECT_SONGS)
|
||||
for row in cur:
|
||||
yield row
|
||||
|
||||
def get_song_by_id(self, sid):
|
||||
"""
|
||||
Returns song by its ID.
|
||||
"""
|
||||
with self.cursor(cursor_type=DictCursor, charset="utf8") as cur:
|
||||
cur.execute(self.SELECT_SONG, (sid,))
|
||||
return cur.fetchone()
|
||||
|
||||
def insert(self, hash, sid, offset):
|
||||
"""
|
||||
Insert a (sha1, song_id, offset) row into database.
|
||||
"""
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset))
|
||||
|
||||
def insert_song(self, songname, file_hash):
|
||||
"""
|
||||
Inserts song in the database and returns the ID of the inserted record.
|
||||
"""
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(self.INSERT_SONG, (songname, file_hash))
|
||||
return cur.lastrowid
|
||||
|
||||
def query(self, hash):
|
||||
"""
|
||||
Return all tuples associated with hash.
|
||||
|
||||
If hash is None, returns all entries in the
|
||||
database (be careful with that one!).
|
||||
"""
|
||||
# select all if no key
|
||||
query = self.SELECT_ALL if hash is None else self.SELECT
|
||||
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
cur.execute(query)
|
||||
for sid, offset in cur:
|
||||
yield (sid, offset)
|
||||
|
||||
def get_iterable_kv_pairs(self):
|
||||
"""
|
||||
Returns all tuples in database.
|
||||
"""
|
||||
return self.query(None)
|
||||
|
||||
def insert_hashes(self, sid, hashes):
|
||||
"""
|
||||
Insert series of hash => song_id, offset
|
||||
values into the database.
|
||||
"""
|
||||
values = []
|
||||
for hash, offset in hashes:
|
||||
values.append((hash, sid, offset))
|
||||
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
for split_values in grouper(values, 1000):
|
||||
cur.executemany(self.INSERT_FINGERPRINT, split_values)
|
||||
|
||||
def return_matches(self, hashes):
|
||||
"""
|
||||
Return the (song_id, offset_diff) tuples associated with
|
||||
a list of (sha1, sample_offset) values.
|
||||
"""
|
||||
# Create a dictionary of hash => offset pairs for later lookups
|
||||
mapper = {}
|
||||
for hash, offset in hashes:
|
||||
mapper[hash.upper()] = offset
|
||||
|
||||
# Get an iteratable of all the hashes we need
|
||||
values = mapper.keys()
|
||||
|
||||
with self.cursor(charset="utf8") as cur:
|
||||
for split_values in grouper(values, 1000):
|
||||
# Create our IN part of the query
|
||||
query = self.SELECT_MULTIPLE
|
||||
query = query % ', '.join(['UNHEX(%s)'] * len(split_values))
|
||||
|
||||
cur.execute(query, split_values)
|
||||
|
||||
for hash, sid, offset in cur:
|
||||
# (sid, db_offset - song_sampled_offset)
|
||||
yield (sid, offset - mapper[hash])
|
||||
|
||||
def __getstate__(self):
|
||||
return (self._options,)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._options, = state
|
||||
self.cursor = cursor_factory(**self._options)
|
||||
|
||||
|
||||
def grouper(iterable, n, fillvalue=None):
|
||||
args = [iter(iterable)] * n
|
||||
return (filter(None, values) for values
|
||||
in izip_longest(fillvalue=fillvalue, *args))
|
||||
|
||||
|
||||
def cursor_factory(**factory_options):
|
||||
def cursor(**options):
|
||||
options.update(factory_options)
|
||||
return Cursor(**options)
|
||||
return cursor
|
||||
|
||||
|
||||
class Cursor(object):
|
||||
"""
|
||||
Establishes a connection to the database and returns an open cursor.
|
||||
|
||||
|
||||
```python
|
||||
# Use as context manager
|
||||
with Cursor() as cur:
|
||||
cur.execute(query)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, cursor_type=mysql.cursors.Cursor, **options):
|
||||
super(Cursor, self).__init__()
|
||||
|
||||
self._cache = Queue.Queue(maxsize=5)
|
||||
try:
|
||||
conn = self._cache.get_nowait()
|
||||
except Queue.Empty:
|
||||
conn = mysql.connect(**options)
|
||||
else:
|
||||
# Ping the connection before using it from the cache.
|
||||
conn.ping(True)
|
||||
|
||||
self.conn = conn
|
||||
self.conn.autocommit(False)
|
||||
self.cursor_type = cursor_type
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
cls._cache = Queue.Queue(maxsize=5)
|
||||
|
||||
def __enter__(self):
|
||||
self.cursor = self.conn.cursor(self.cursor_type)
|
||||
return self.cursor
|
||||
|
||||
def __exit__(self, extype, exvalue, traceback):
|
||||
# if we had a MySQL related error we try to rollback the cursor.
|
||||
if extype is mysql.MySQLError:
|
||||
self.cursor.rollback()
|
||||
|
||||
self.cursor.close()
|
||||
self.conn.commit()
|
||||
|
||||
# Put it back on the queue
|
||||
try:
|
||||
self._cache.put_nowait(self.conn)
|
||||
except Queue.Full:
|
||||
self.conn.close()
|
|
@ -1,157 +0,0 @@
|
|||
import numpy as np
|
||||
import matplotlib.mlab as mlab
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.ndimage.filters import maximum_filter
|
||||
from scipy.ndimage.morphology import (generate_binary_structure,
|
||||
iterate_structure, binary_erosion)
|
||||
import hashlib
|
||||
from operator import itemgetter
|
||||
|
||||
IDX_FREQ_I = 0
|
||||
IDX_TIME_J = 1
|
||||
|
||||
######################################################################
|
||||
# Sampling rate, related to the Nyquist conditions, which affects
|
||||
# the range frequencies we can detect.
|
||||
DEFAULT_FS = 44100
|
||||
|
||||
######################################################################
|
||||
# Size of the FFT window, affects frequency granularity
|
||||
DEFAULT_WINDOW_SIZE = 4096
|
||||
|
||||
######################################################################
|
||||
# Ratio by which each sequential window overlaps the last and the
|
||||
# next window. Higher overlap will allow a higher granularity of offset
|
||||
# matching, but potentially more fingerprints.
|
||||
DEFAULT_OVERLAP_RATIO = 0.5
|
||||
|
||||
######################################################################
|
||||
# Degree to which a fingerprint can be paired with its neighbors --
|
||||
# higher will cause more fingerprints, but potentially better accuracy.
|
||||
DEFAULT_FAN_VALUE = 15
|
||||
|
||||
######################################################################
|
||||
# Minimum amplitude in spectrogram in order to be considered a peak.
|
||||
# This can be raised to reduce number of fingerprints, but can negatively
|
||||
# affect accuracy.
|
||||
DEFAULT_AMP_MIN = 10
|
||||
|
||||
######################################################################
|
||||
# Number of cells around an amplitude peak in the spectrogram in order
|
||||
# for Dejavu to consider it a spectral peak. Higher values mean less
|
||||
# fingerprints and faster matching, but can potentially affect accuracy.
|
||||
PEAK_NEIGHBORHOOD_SIZE = 20
|
||||
|
||||
######################################################################
|
||||
# Thresholds on how close or far fingerprints can be in time in order
|
||||
# to be paired as a fingerprint. If your max is too low, higher values of
|
||||
# DEFAULT_FAN_VALUE may not perform as expected.
|
||||
MIN_HASH_TIME_DELTA = 0
|
||||
MAX_HASH_TIME_DELTA = 200
|
||||
|
||||
######################################################################
|
||||
# If True, will sort peaks temporally for fingerprinting;
|
||||
# not sorting will cut down number of fingerprints, but potentially
|
||||
# affect performance.
|
||||
PEAK_SORT = True
|
||||
|
||||
######################################################################
|
||||
# Number of bits to grab from the front of the SHA1 hash in the
|
||||
# fingerprint calculation. The more you grab, the more memory storage,
|
||||
# with potentially lesser collisions of matches.
|
||||
FINGERPRINT_REDUCTION = 20
|
||||
|
||||
def fingerprint(channel_samples, Fs=DEFAULT_FS,
|
||||
wsize=DEFAULT_WINDOW_SIZE,
|
||||
wratio=DEFAULT_OVERLAP_RATIO,
|
||||
fan_value=DEFAULT_FAN_VALUE,
|
||||
amp_min=DEFAULT_AMP_MIN):
|
||||
"""
|
||||
FFT the channel, log transform output, find local maxima, then return
|
||||
locally sensitive hashes.
|
||||
"""
|
||||
# FFT the signal and extract frequency components
|
||||
arr2D = mlab.specgram(
|
||||
channel_samples,
|
||||
NFFT=wsize,
|
||||
Fs=Fs,
|
||||
window=mlab.window_hanning,
|
||||
noverlap=int(wsize * wratio))[0]
|
||||
|
||||
# apply log transform since specgram() returns linear array
|
||||
arr2D = 10 * np.log10(arr2D)
|
||||
arr2D[arr2D == -np.inf] = 0 # replace infs with zeros
|
||||
|
||||
# find local maxima
|
||||
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
|
||||
|
||||
# return hashes
|
||||
return generate_hashes(local_maxima, fan_value=fan_value)
|
||||
|
||||
|
||||
def get_2D_peaks(arr2D, plot=False, amp_min=DEFAULT_AMP_MIN):
|
||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html#scipy.ndimage.iterate_structure
|
||||
struct = generate_binary_structure(2, 1)
|
||||
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
||||
|
||||
# find local maxima using our filter shape
|
||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
||||
background = (arr2D == 0)
|
||||
eroded_background = binary_erosion(background, structure=neighborhood,
|
||||
border_value=1)
|
||||
|
||||
# Boolean mask of arr2D with True at peaks (Fixed deprecated boolean operator by changing '-' to '^')
|
||||
detected_peaks = local_max ^ eroded_background
|
||||
|
||||
# extract peaks
|
||||
amps = arr2D[detected_peaks]
|
||||
j, i = np.where(detected_peaks)
|
||||
|
||||
# filter peaks
|
||||
amps = amps.flatten()
|
||||
peaks = zip(i, j, amps)
|
||||
peaks_filtered = filter(lambda x: x[2]>amp_min, peaks) # freq, time, amp
|
||||
# get indices for frequency and time
|
||||
frequency_idx = []
|
||||
time_idx = []
|
||||
for x in peaks_filtered:
|
||||
frequency_idx.append(x[1])
|
||||
time_idx.append(x[0])
|
||||
|
||||
if plot:
|
||||
# scatter of the peaks
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(arr2D)
|
||||
ax.scatter(time_idx, frequency_idx)
|
||||
ax.set_xlabel('Time')
|
||||
ax.set_ylabel('Frequency')
|
||||
ax.set_title("Spectrogram")
|
||||
plt.gca().invert_yaxis()
|
||||
plt.show()
|
||||
|
||||
return zip(frequency_idx, time_idx)
|
||||
|
||||
|
||||
def generate_hashes(peaks, fan_value=DEFAULT_FAN_VALUE):
|
||||
"""
|
||||
Hash list structure:
|
||||
sha1_hash[0:20] time_offset
|
||||
[(e05b341a9b77a51fd26, 32), ... ]
|
||||
"""
|
||||
if PEAK_SORT:
|
||||
peaks.sort(key=itemgetter(1))
|
||||
|
||||
for i in range(len(peaks)):
|
||||
for j in range(1, fan_value):
|
||||
if (i + j) < len(peaks):
|
||||
|
||||
freq1 = peaks[i][IDX_FREQ_I]
|
||||
freq2 = peaks[i + j][IDX_FREQ_I]
|
||||
t1 = peaks[i][IDX_TIME_J]
|
||||
t2 = peaks[i + j][IDX_TIME_J]
|
||||
t_delta = t2 - t1
|
||||
|
||||
if t_delta >= MIN_HASH_TIME_DELTA and t_delta <= MAX_HASH_TIME_DELTA:
|
||||
h = hashlib.sha1(
|
||||
"%s|%s|%s" % (str(freq1), str(freq2), str(t_delta)))
|
||||
yield (h.hexdigest()[0:FINGERPRINT_REDUCTION], t1)
|
0
dejavu/logic/__init__.py
Normal file
0
dejavu/logic/__init__.py
Normal file
|
@ -1,40 +1,57 @@
|
|||
import os
|
||||
import fnmatch
|
||||
import os
|
||||
from hashlib import sha1
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from pydub import AudioSegment
|
||||
from pydub.utils import audioop
|
||||
import wavio
|
||||
from hashlib import sha1
|
||||
|
||||
def unique_hash(filepath, blocksize=2**20):
|
||||
from dejavu.third_party import wavio
|
||||
|
||||
|
||||
def unique_hash(file_path: str, block_size: int = 2**20) -> str:
|
||||
""" Small function to generate a hash to uniquely generate
|
||||
a file. Inspired by MD5 version here:
|
||||
http://stackoverflow.com/a/1131255/712997
|
||||
|
||||
Works with large files.
|
||||
|
||||
:param file_path: path to file.
|
||||
:param block_size: read block size.
|
||||
:return: a hash in an hexagesimal string form.
|
||||
"""
|
||||
s = sha1()
|
||||
with open(filepath , "rb") as f:
|
||||
with open(file_path, "rb") as f:
|
||||
while True:
|
||||
buf = f.read(blocksize)
|
||||
buf = f.read(block_size)
|
||||
if not buf:
|
||||
break
|
||||
s.update(buf)
|
||||
return s.hexdigest().upper()
|
||||
|
||||
|
||||
def find_files(path, extensions):
|
||||
def find_files(path: str, extensions: List[str]) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Get all files that meet the specified extensions.
|
||||
|
||||
:param path: path to a directory with audio files.
|
||||
:param extensions: file extensions to look for.
|
||||
:return: a list of tuples with file name and its extension.
|
||||
"""
|
||||
# Allow both with ".mp3" and without "mp3" to be used for extensions
|
||||
extensions = [e.replace(".", "") for e in extensions]
|
||||
|
||||
results = []
|
||||
for dirpath, dirnames, files in os.walk(path):
|
||||
for extension in extensions:
|
||||
for f in fnmatch.filter(files, "*.%s" % extension):
|
||||
for f in fnmatch.filter(files, f"*.{extension}"):
|
||||
p = os.path.join(dirpath, f)
|
||||
yield (p, extension)
|
||||
results.append((p, extension))
|
||||
return results
|
||||
|
||||
|
||||
def read(filename, limit=None):
|
||||
def read(file_name: str, limit: int = None) -> Tuple[List[List[int]], int, str]:
|
||||
"""
|
||||
Reads any file supported by pydub (ffmpeg) and returns the data contained
|
||||
within. If file reading fails due to input being a 24-bit wav file,
|
||||
|
@ -44,24 +61,26 @@ def read(filename, limit=None):
|
|||
of the file by specifying the `limit` parameter. This is the amount of
|
||||
seconds from the start of the file.
|
||||
|
||||
returns: (channels, samplerate)
|
||||
:param file_name: file to be read.
|
||||
:param limit: number of seconds to limit.
|
||||
:return: tuple list of (channels, sample_rate, content_file_hash).
|
||||
"""
|
||||
# pydub does not support 24-bit wav files, use wavio when this occurs
|
||||
try:
|
||||
audiofile = AudioSegment.from_file(filename)
|
||||
audiofile = AudioSegment.from_file(file_name)
|
||||
|
||||
if limit:
|
||||
audiofile = audiofile[:limit * 1000]
|
||||
|
||||
data = np.fromstring(audiofile._data, np.int16)
|
||||
data = np.fromstring(audiofile.raw_data, np.int16)
|
||||
|
||||
channels = []
|
||||
for chn in xrange(audiofile.channels):
|
||||
for chn in range(audiofile.channels):
|
||||
channels.append(data[chn::audiofile.channels])
|
||||
|
||||
fs = audiofile.frame_rate
|
||||
audiofile.frame_rate
|
||||
except audioop.error:
|
||||
fs, _, audiofile = wavio.readwav(filename)
|
||||
_, _, audiofile = wavio.readwav(file_name)
|
||||
|
||||
if limit:
|
||||
audiofile = audiofile[:limit * 1000]
|
||||
|
@ -73,12 +92,14 @@ def read(filename, limit=None):
|
|||
for chn in audiofile:
|
||||
channels.append(chn)
|
||||
|
||||
return channels, audiofile.frame_rate, unique_hash(filename)
|
||||
return channels, audiofile.frame_rate, unique_hash(file_name)
|
||||
|
||||
|
||||
def path_to_songname(path):
|
||||
def get_audio_name_from_path(file_path: str) -> str:
|
||||
"""
|
||||
Extracts song name from a filepath. Used to identify which songs
|
||||
have already been fingerprinted on disk.
|
||||
Extracts song name from a file path.
|
||||
|
||||
:param file_path: path to an audio file.
|
||||
:return: file name
|
||||
"""
|
||||
return os.path.splitext(os.path.basename(path))[0]
|
||||
return os.path.splitext(os.path.basename(file_path))[0]
|
156
dejavu/logic/fingerprint.py
Executable file
156
dejavu/logic/fingerprint.py
Executable file
|
@ -0,0 +1,156 @@
|
|||
import hashlib
|
||||
from operator import itemgetter
|
||||
from typing import List, Tuple
|
||||
|
||||
import matplotlib.mlab as mlab
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from scipy.ndimage.filters import maximum_filter
|
||||
from scipy.ndimage.morphology import (binary_erosion,
|
||||
generate_binary_structure,
|
||||
iterate_structure)
|
||||
|
||||
from dejavu.config.settings import (CONNECTIVITY_MASK, DEFAULT_AMP_MIN,
|
||||
DEFAULT_FAN_VALUE, DEFAULT_FS,
|
||||
DEFAULT_OVERLAP_RATIO, DEFAULT_WINDOW_SIZE,
|
||||
FINGERPRINT_REDUCTION, MAX_HASH_TIME_DELTA,
|
||||
MIN_HASH_TIME_DELTA,
|
||||
PEAK_NEIGHBORHOOD_SIZE, PEAK_SORT)
|
||||
|
||||
|
||||
def fingerprint(channel_samples: List[int],
|
||||
Fs: int = DEFAULT_FS,
|
||||
wsize: int = DEFAULT_WINDOW_SIZE,
|
||||
wratio: float = DEFAULT_OVERLAP_RATIO,
|
||||
fan_value: int = DEFAULT_FAN_VALUE,
|
||||
amp_min: int = DEFAULT_AMP_MIN) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
FFT the channel, log transform output, find local maxima, then return locally sensitive hashes.
|
||||
|
||||
:param channel_samples: channel samples to fingerprint.
|
||||
:param Fs: audio sampling rate.
|
||||
:param wsize: FFT windows size.
|
||||
:param wratio: ratio by which each sequential window overlaps the last and the next window.
|
||||
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
|
||||
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
|
||||
:return: a list of hashes with their corresponding offsets.
|
||||
"""
|
||||
# FFT the signal and extract frequency components
|
||||
arr2D = mlab.specgram(
|
||||
channel_samples,
|
||||
NFFT=wsize,
|
||||
Fs=Fs,
|
||||
window=mlab.window_hanning,
|
||||
noverlap=int(wsize * wratio))[0]
|
||||
|
||||
# Apply log transform since specgram function returns linear array. 0s are excluded to avoid np warning.
|
||||
arr2D = 10 * np.log10(arr2D, out=np.zeros_like(arr2D), where=(arr2D != 0))
|
||||
|
||||
local_maxima = get_2D_peaks(arr2D, plot=False, amp_min=amp_min)
|
||||
|
||||
# return hashes
|
||||
return generate_hashes(local_maxima, fan_value=fan_value)
|
||||
|
||||
|
||||
def get_2D_peaks(arr2D: np.array, plot: bool = False, amp_min: int = DEFAULT_AMP_MIN)\
|
||||
-> List[Tuple[List[int], List[int]]]:
|
||||
"""
|
||||
Extract maximum peaks from the spectogram matrix (arr2D).
|
||||
|
||||
:param arr2D: matrix representing the spectogram.
|
||||
:param plot: for plotting the results.
|
||||
:param amp_min: minimum amplitude in spectrogram in order to be considered a peak.
|
||||
:return: a list composed by a list of frequencies and times.
|
||||
"""
|
||||
# Original code from the repo is using a morphology mask that does not consider diagonal elements
|
||||
# as neighbors (basically a diamond figure) and then applies a dilation over it, so what I'm proposing
|
||||
# is to change from the current diamond figure to a just a normal square one:
|
||||
# F T F T T T
|
||||
# T T T ==> T T T
|
||||
# F T F T T T
|
||||
# In my local tests time performance of the square mask was ~3 times faster
|
||||
# respect to the diamond one, without hurting accuracy of the predictions.
|
||||
# I've made now the mask shape configurable in order to allow both ways of find maximum peaks.
|
||||
# That being said, we generate the mask by using the following function
|
||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.generate_binary_structure.html
|
||||
struct = generate_binary_structure(2, CONNECTIVITY_MASK)
|
||||
|
||||
# And then we apply dilation using the following function
|
||||
# http://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.iterate_structure.html
|
||||
# Take into account that if PEAK_NEIGHBORHOOD_SIZE is 2 you can avoid the use of the scipy functions and just
|
||||
# change it by the following code:
|
||||
# neighborhood = np.ones((PEAK_NEIGHBORHOOD_SIZE * 2 + 1, PEAK_NEIGHBORHOOD_SIZE * 2 + 1), dtype=bool)
|
||||
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
|
||||
|
||||
# find local maxima using our filter mask
|
||||
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
|
||||
|
||||
# Applying erosion, the dejavu documentation does not talk about this step.
|
||||
background = (arr2D == 0)
|
||||
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
|
||||
|
||||
# Boolean mask of arr2D with True at peaks (applying XOR on both matrices).
|
||||
detected_peaks = local_max != eroded_background
|
||||
|
||||
# extract peaks
|
||||
amps = arr2D[detected_peaks]
|
||||
freqs, times = np.where(detected_peaks)
|
||||
|
||||
# filter peaks
|
||||
amps = amps.flatten()
|
||||
|
||||
# get indices for frequency and time
|
||||
filter_idxs = np.where(amps > amp_min)
|
||||
|
||||
freqs_filter = freqs[filter_idxs]
|
||||
times_filter = times[filter_idxs]
|
||||
|
||||
if plot:
|
||||
# scatter of the peaks
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(arr2D)
|
||||
ax.scatter(times_filter, freqs_filter)
|
||||
ax.set_xlabel('Time')
|
||||
ax.set_ylabel('Frequency')
|
||||
ax.set_title("Spectrogram")
|
||||
plt.gca().invert_yaxis()
|
||||
plt.show()
|
||||
|
||||
return list(zip(freqs_filter, times_filter))
|
||||
|
||||
|
||||
def generate_hashes(peaks: List[Tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
Hash list structure:
|
||||
sha1_hash[0:FINGERPRINT_REDUCTION] time_offset
|
||||
[(e05b341a9b77a51fd26, 32), ... ]
|
||||
|
||||
:param peaks: list of peak frequencies and times.
|
||||
:param fan_value: degree to which a fingerprint can be paired with its neighbors.
|
||||
:return: a list of hashes with their corresponding offsets.
|
||||
"""
|
||||
# frequencies are in the first position of the tuples
|
||||
idx_freq = 0
|
||||
# times are in the second position of the tuples
|
||||
idx_time = 1
|
||||
|
||||
if PEAK_SORT:
|
||||
peaks.sort(key=itemgetter(1))
|
||||
|
||||
hashes = []
|
||||
for i in range(len(peaks)):
|
||||
for j in range(1, fan_value):
|
||||
if (i + j) < len(peaks):
|
||||
|
||||
freq1 = peaks[i][idx_freq]
|
||||
freq2 = peaks[i + j][idx_freq]
|
||||
t1 = peaks[i][idx_time]
|
||||
t2 = peaks[i + j][idx_time]
|
||||
t_delta = t2 - t1
|
||||
|
||||
if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
|
||||
h = hashlib.sha1(f"{str(freq1)}|{str(freq2)}|{str(t_delta)}".encode('utf-8'))
|
||||
|
||||
hashes.append((h.hexdigest()[0:FINGERPRINT_REDUCTION], t1))
|
||||
|
||||
return hashes
|
0
dejavu/logic/recognizer/__init__.py
Normal file
0
dejavu/logic/recognizer/__init__.py
Normal file
32
dejavu/logic/recognizer/file_recognizer.py
Normal file
32
dejavu/logic/recognizer/file_recognizer.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from time import time
|
||||
from typing import Dict
|
||||
|
||||
import dejavu.logic.decoder as decoder
|
||||
from dejavu.base_classes.base_recognizer import BaseRecognizer
|
||||
from dejavu.config.settings import (ALIGN_TIME, FINGERPRINT_TIME, QUERY_TIME,
|
||||
RESULTS, TOTAL_TIME)
|
||||
|
||||
|
||||
class FileRecognizer(BaseRecognizer):
|
||||
def __init__(self, dejavu):
|
||||
super().__init__(dejavu)
|
||||
|
||||
def recognize_file(self, filename: str) -> Dict[str, any]:
|
||||
channels, self.Fs, _ = decoder.read(filename, self.dejavu.limit)
|
||||
|
||||
t = time()
|
||||
matches, fingerprint_time, query_time, align_time = self._recognize(*channels)
|
||||
t = time() - t
|
||||
|
||||
results = {
|
||||
TOTAL_TIME: t,
|
||||
FINGERPRINT_TIME: fingerprint_time,
|
||||
QUERY_TIME: query_time,
|
||||
ALIGN_TIME: align_time,
|
||||
RESULTS: matches
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
def recognize(self, filename: str) -> Dict[str, any]:
|
||||
return self.recognize_file(filename)
|
|
@ -1,55 +1,17 @@
|
|||
# encoding: utf-8
|
||||
import dejavu.fingerprint as fingerprint
|
||||
import dejavu.decoder as decoder
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
import time
|
||||
|
||||
|
||||
class BaseRecognizer(object):
|
||||
|
||||
def __init__(self, dejavu):
|
||||
self.dejavu = dejavu
|
||||
self.Fs = fingerprint.DEFAULT_FS
|
||||
|
||||
def _recognize(self, *data):
|
||||
matches = []
|
||||
for d in data:
|
||||
matches.extend(self.dejavu.find_matches(d, Fs=self.Fs))
|
||||
return self.dejavu.align_matches(matches)
|
||||
|
||||
def recognize(self):
|
||||
pass # base class does nothing
|
||||
|
||||
|
||||
class FileRecognizer(BaseRecognizer):
|
||||
def __init__(self, dejavu):
|
||||
super(FileRecognizer, self).__init__(dejavu)
|
||||
|
||||
def recognize_file(self, filename):
|
||||
frames, self.Fs, file_hash = decoder.read(filename, self.dejavu.limit)
|
||||
|
||||
t = time.time()
|
||||
match = self._recognize(*frames)
|
||||
t = time.time() - t
|
||||
|
||||
if match:
|
||||
match['match_time'] = t
|
||||
|
||||
return match
|
||||
|
||||
def recognize(self, filename):
|
||||
return self.recognize_file(filename)
|
||||
from dejavu.base_classes.base_recognizer import BaseRecognizer
|
||||
|
||||
|
||||
class MicrophoneRecognizer(BaseRecognizer):
|
||||
default_chunksize = 8192
|
||||
default_format = pyaudio.paInt16
|
||||
default_channels = 2
|
||||
default_samplerate = 44100
|
||||
default_chunksize = 8192
|
||||
default_format = pyaudio.paInt16
|
||||
default_channels = 2
|
||||
default_samplerate = 44100
|
||||
|
||||
def __init__(self, dejavu):
|
||||
super(MicrophoneRecognizer, self).__init__(dejavu)
|
||||
super().__init__(dejavu)
|
||||
self.audio = pyaudio.PyAudio()
|
||||
self.stream = None
|
||||
self.data = []
|
0
dejavu/tests/__init__.py
Normal file
0
dejavu/tests/__init__.py
Normal file
|
@ -1,14 +1,183 @@
|
|||
from __future__ import division
|
||||
from pydub import AudioSegment
|
||||
from dejavu.decoder import path_to_songname
|
||||
from dejavu import Dejavu
|
||||
from dejavu.fingerprint import *
|
||||
import traceback
|
||||
import fnmatch
|
||||
import os, re, ast
|
||||
import subprocess
|
||||
import random
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import traceback
|
||||
from os import listdir, makedirs, walk
|
||||
from os.path import basename, exists, isfile, join, splitext
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from pydub import AudioSegment
|
||||
|
||||
from dejavu.config.settings import (DEFAULT_FS, DEFAULT_OVERLAP_RATIO,
|
||||
DEFAULT_WINDOW_SIZE, HASHES_MATCHED,
|
||||
OFFSET, RESULTS, SONG_NAME, TOTAL_TIME)
|
||||
from dejavu.logic.decoder import get_audio_name_from_path
|
||||
|
||||
|
||||
class DejavuTest:
|
||||
def __init__(self, folder, seconds):
|
||||
super().__init__()
|
||||
|
||||
self.test_folder = folder
|
||||
self.test_seconds = seconds
|
||||
self.test_songs = []
|
||||
|
||||
print("test_seconds", self.test_seconds)
|
||||
|
||||
self.test_files = [
|
||||
f for f in listdir(self.test_folder)
|
||||
if isfile(join(self.test_folder, f))
|
||||
and any([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds])
|
||||
]
|
||||
|
||||
print("test_files", self.test_files)
|
||||
|
||||
self.n_columns = len(self.test_seconds)
|
||||
self.n_lines = int(len(self.test_files) / self.n_columns)
|
||||
|
||||
print("columns:", self.n_columns)
|
||||
print("length of test files:", len(self.test_files))
|
||||
print("lines:", self.n_lines)
|
||||
|
||||
# variable match results (yes, no, invalid)
|
||||
self.result_match = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
||||
|
||||
print("result_match matrix:", self.result_match)
|
||||
|
||||
# variable match precision (if matched in the corrected time)
|
||||
self.result_matching_times = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
||||
|
||||
# variable matching time (query time)
|
||||
self.result_query_duration = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
||||
|
||||
# variable confidence
|
||||
self.result_match_confidence = [[0 for x in range(self.n_columns)] for x in range(self.n_lines)]
|
||||
|
||||
self.begin()
|
||||
|
||||
def get_column_id(self, secs):
|
||||
for i, sec in enumerate(self.test_seconds):
|
||||
if secs == sec:
|
||||
return i
|
||||
|
||||
def get_line_id(self, song):
|
||||
for i, s in enumerate(self.test_songs):
|
||||
if song == s:
|
||||
return i
|
||||
self.test_songs.append(song)
|
||||
return len(self.test_songs) - 1
|
||||
|
||||
def create_plots(self, name, results, results_folder):
|
||||
for sec in range(0, len(self.test_seconds)):
|
||||
ind = np.arange(self.n_lines)
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1 * width, 2 * width])
|
||||
|
||||
means_dvj = [x[0] for x in results[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
# add some
|
||||
ax.set_ylabel(name)
|
||||
ax.set_title(f"{self.test_seconds[sec]} {name} Results")
|
||||
ax.set_xticks(ind + width)
|
||||
|
||||
labels = [0 for x in range(0, self.n_lines)]
|
||||
for x in range(0, self.n_lines):
|
||||
labels[x] = f"song {x+1}"
|
||||
ax.set_xticklabels(labels)
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
|
||||
if name == 'Confidence':
|
||||
autolabel(rects1, ax)
|
||||
else:
|
||||
autolabeldoubles(rects1, ax)
|
||||
|
||||
plt.grid()
|
||||
|
||||
fig_name = join(results_folder, f"{name}_{self.test_seconds[sec]}.png")
|
||||
fig.savefig(fig_name)
|
||||
|
||||
def begin(self):
|
||||
for f in self.test_files:
|
||||
log_msg('--------------------------------------------------')
|
||||
log_msg(f'file: {f}')
|
||||
|
||||
# get column
|
||||
col = self.get_column_id([x for x in re.findall("[0-9]sec", f) if x in self.test_seconds][0])
|
||||
|
||||
# format: XXXX_offset_length.mp3, we also take into account underscores within XXXX
|
||||
splits = get_audio_name_from_path(f).split("_")
|
||||
song = "_".join(splits[0:len(get_audio_name_from_path(f).split("_")) - 2])
|
||||
line = self.get_line_id(song)
|
||||
result = subprocess.check_output([
|
||||
"python",
|
||||
"dejavu.py",
|
||||
'-r',
|
||||
'file',
|
||||
join(self.test_folder, f)])
|
||||
|
||||
if result.strip() == "None":
|
||||
log_msg('No match')
|
||||
self.result_match[line][col] = 'no'
|
||||
self.result_matching_times[line][col] = 0
|
||||
self.result_query_duration[line][col] = 0
|
||||
self.result_match_confidence[line][col] = 0
|
||||
|
||||
else:
|
||||
result = result.strip()
|
||||
# we parse the output song back to a json
|
||||
result = json.loads(result.decode('utf-8').replace("'", '"').replace(': b"', ':"'))
|
||||
|
||||
# which song did we predict? We consider only the first match.
|
||||
match = result[RESULTS][0]
|
||||
song_result = match[SONG_NAME]
|
||||
log_msg(f'song: {song}')
|
||||
log_msg(f'song_result: {song_result}')
|
||||
|
||||
if song_result != song:
|
||||
log_msg('invalid match')
|
||||
self.result_match[line][col] = 'invalid'
|
||||
self.result_matching_times[line][col] = 0
|
||||
self.result_query_duration[line][col] = 0
|
||||
self.result_match_confidence[line][col] = 0
|
||||
else:
|
||||
log_msg('correct match')
|
||||
print(self.result_match)
|
||||
self.result_match[line][col] = 'yes'
|
||||
self.result_query_duration[line][col] = round(result[TOTAL_TIME], 3)
|
||||
self.result_match_confidence[line][col] = match[HASHES_MATCHED]
|
||||
|
||||
# using replace in f for getting rid of underscores in name
|
||||
song_start_time = re.findall("_[^_]+", f.replace(song, ""))
|
||||
song_start_time = song_start_time[0].lstrip("_ ")
|
||||
|
||||
result_start_time = round((match[OFFSET] * DEFAULT_WINDOW_SIZE *
|
||||
DEFAULT_OVERLAP_RATIO) / DEFAULT_FS, 0)
|
||||
|
||||
self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time)
|
||||
if abs(self.result_matching_times[line][col]) == 1:
|
||||
self.result_matching_times[line][col] = 0
|
||||
|
||||
log_msg(f'query duration: {round(result[TOTAL_TIME], 3)}')
|
||||
log_msg(f'confidence: {match[HASHES_MATCHED]}')
|
||||
log_msg(f'song start_time: {song_start_time}')
|
||||
log_msg(f'result start time: {result_start_time}')
|
||||
|
||||
if self.result_matching_times[line][col] == 0:
|
||||
log_msg('accurate match')
|
||||
else:
|
||||
log_msg('inaccurate match')
|
||||
log_msg('--------------------------------------------------\n')
|
||||
|
||||
|
||||
def set_seed(seed=None):
|
||||
"""
|
||||
|
@ -17,17 +186,22 @@ def set_seed(seed=None):
|
|||
Setting your own seed means that you can produce the
|
||||
same experiment over and over.
|
||||
"""
|
||||
if seed != None:
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
def get_files_recursive(src, fmt):
|
||||
"""
|
||||
`src` is the source directory.
|
||||
`fmt` is the extension, ie ".mp3" or "mp3", etc.
|
||||
"""
|
||||
for root, dirnames, filenames in os.walk(src):
|
||||
files = []
|
||||
for root, dirnames, filenames in walk(src):
|
||||
for filename in fnmatch.filter(filenames, '*' + fmt):
|
||||
yield os.path.join(root, filename)
|
||||
files.append(join(root, filename))
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def get_length_audio(audiopath, extension):
|
||||
"""
|
||||
|
@ -36,11 +210,12 @@ def get_length_audio(audiopath, extension):
|
|||
"""
|
||||
try:
|
||||
audio = AudioSegment.from_file(audiopath, extension.replace(".", ""))
|
||||
except:
|
||||
print "Error in get_length_audio(): %s" % traceback.format_exc()
|
||||
except Exception:
|
||||
print(f"Error in get_length_audio(): {traceback.format_exc()}")
|
||||
return None
|
||||
return int(len(audio) / 1000.0)
|
||||
|
||||
|
||||
def get_starttime(length, nseconds, padding):
|
||||
"""
|
||||
`length` is total audio length in seconds
|
||||
|
@ -52,6 +227,7 @@ def get_starttime(length, nseconds, padding):
|
|||
return 0
|
||||
return random.randint(padding, maximum)
|
||||
|
||||
|
||||
def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10):
|
||||
"""
|
||||
Generates a test file for each file recursively in `src` directory
|
||||
|
@ -64,213 +240,47 @@ def generate_test_files(src, dest, nseconds, fmts=[".mp3", ".wav"], padding=10):
|
|||
avoid silence, etc.
|
||||
"""
|
||||
# create directories if necessary
|
||||
for directory in [src, dest]:
|
||||
try:
|
||||
os.stat(directory)
|
||||
except:
|
||||
os.mkdir(directory)
|
||||
if not exists(dest):
|
||||
makedirs(dest)
|
||||
|
||||
# find files recursively of a given file format
|
||||
for fmt in fmts:
|
||||
testsources = get_files_recursive(src, fmt)
|
||||
for audiosource in testsources:
|
||||
|
||||
print "audiosource:", audiosource
|
||||
print("audiosource:", audiosource)
|
||||
|
||||
filename, extension = os.path.splitext(os.path.basename(audiosource))
|
||||
filename, extension = splitext(basename(audiosource))
|
||||
length = get_length_audio(audiosource, extension)
|
||||
starttime = get_starttime(length, nseconds, padding)
|
||||
|
||||
test_file_name = "%s_%s_%ssec.%s" % (
|
||||
os.path.join(dest, filename), starttime,
|
||||
nseconds, extension.replace(".", ""))
|
||||
test_file_name = f"{join(dest, filename)}_{starttime}_{nseconds}sec.{extension.replace('.', '')}"
|
||||
|
||||
subprocess.check_output([
|
||||
"ffmpeg", "-y",
|
||||
"-ss", "%d" % starttime,
|
||||
'-t' , "%d" % nseconds,
|
||||
"-ss", f"{starttime}",
|
||||
'-t', f"{nseconds}",
|
||||
"-i", audiosource,
|
||||
test_file_name])
|
||||
|
||||
|
||||
def log_msg(msg, log=True, silent=False):
|
||||
if log:
|
||||
logging.debug(msg)
|
||||
if not silent:
|
||||
print msg
|
||||
print(msg)
|
||||
|
||||
|
||||
def autolabel(rects, ax):
|
||||
# attach some text labels
|
||||
for rect in rects:
|
||||
height = rect.get_height()
|
||||
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height,
|
||||
'%d' % int(height), ha='center', va='bottom')
|
||||
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{int(height)}', ha='center', va='bottom')
|
||||
|
||||
|
||||
def autolabeldoubles(rects, ax):
|
||||
# attach some text labels
|
||||
for rect in rects:
|
||||
height = rect.get_height()
|
||||
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height,
|
||||
'%s' % round(float(height), 3), ha='center', va='bottom')
|
||||
|
||||
class DejavuTest(object):
|
||||
def __init__(self, folder, seconds):
|
||||
super(DejavuTest, self).__init__()
|
||||
|
||||
self.test_folder = folder
|
||||
self.test_seconds = seconds
|
||||
self.test_songs = []
|
||||
|
||||
print "test_seconds", self.test_seconds
|
||||
|
||||
self.test_files = [
|
||||
f for f in os.listdir(self.test_folder)
|
||||
if os.path.isfile(os.path.join(self.test_folder, f))
|
||||
and re.findall("[0-9]*sec", f)[0] in self.test_seconds]
|
||||
|
||||
print "test_files", self.test_files
|
||||
|
||||
self.n_columns = len(self.test_seconds)
|
||||
self.n_lines = int(len(self.test_files) / self.n_columns)
|
||||
|
||||
print "columns:", self.n_columns
|
||||
print "length of test files:", len(self.test_files)
|
||||
print "lines:", self.n_lines
|
||||
|
||||
# variable match results (yes, no, invalid)
|
||||
self.result_match = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
|
||||
|
||||
print "result_match matrix:", self.result_match
|
||||
|
||||
# variable match precision (if matched in the corrected time)
|
||||
self.result_matching_times = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
|
||||
|
||||
# variable mahing time (query time)
|
||||
self.result_query_duration = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
|
||||
|
||||
# variable confidence
|
||||
self.result_match_confidence = [[0 for x in xrange(self.n_columns)] for x in xrange(self.n_lines)]
|
||||
|
||||
self.begin()
|
||||
|
||||
def get_column_id (self, secs):
|
||||
for i, sec in enumerate(self.test_seconds):
|
||||
if secs == sec:
|
||||
return i
|
||||
|
||||
def get_line_id (self, song):
|
||||
for i, s in enumerate(self.test_songs):
|
||||
if song == s:
|
||||
return i
|
||||
self.test_songs.append(song)
|
||||
return len(self.test_songs) - 1
|
||||
|
||||
def create_plots(self, name, results, results_folder):
|
||||
for sec in range(0, len(self.test_seconds)):
|
||||
ind = np.arange(self.n_lines) #
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1 * width, 2 * width])
|
||||
|
||||
means_dvj = [x[0] for x in results[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
# add some
|
||||
ax.set_ylabel(name)
|
||||
ax.set_title("%s %s Results" % (self.test_seconds[sec], name))
|
||||
ax.set_xticks(ind + width)
|
||||
|
||||
labels = [0 for x in range(0, self.n_lines)]
|
||||
for x in range(0, self.n_lines):
|
||||
labels[x] = "song %s" % (x+1)
|
||||
ax.set_xticklabels(labels)
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
|
||||
#ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
|
||||
if name == 'Confidence':
|
||||
autolabel(rects1, ax)
|
||||
else:
|
||||
autolabeldoubles(rects1, ax)
|
||||
|
||||
plt.grid()
|
||||
|
||||
fig_name = os.path.join(results_folder, "%s_%s.png" % (name, self.test_seconds[sec]))
|
||||
fig.savefig(fig_name)
|
||||
|
||||
def begin(self):
|
||||
for f in self.test_files:
|
||||
log_msg('--------------------------------------------------')
|
||||
log_msg('file: %s' % f)
|
||||
|
||||
# get column
|
||||
col = self.get_column_id(re.findall("[0-9]*sec", f)[0])
|
||||
# format: XXXX_offset_length.mp3
|
||||
song = path_to_songname(f).split("_")[0]
|
||||
line = self.get_line_id(song)
|
||||
result = subprocess.check_output([
|
||||
"python",
|
||||
"dejavu.py",
|
||||
'-r',
|
||||
'file',
|
||||
self.test_folder + "/" + f])
|
||||
|
||||
if result.strip() == "None":
|
||||
log_msg('No match')
|
||||
self.result_match[line][col] = 'no'
|
||||
self.result_matching_times[line][col] = 0
|
||||
self.result_query_duration[line][col] = 0
|
||||
self.result_match_confidence[line][col] = 0
|
||||
|
||||
else:
|
||||
result = result.strip()
|
||||
result = result.replace(" \'", ' "')
|
||||
result = result.replace("{\'", '{"')
|
||||
result = result.replace("\':", '":')
|
||||
result = result.replace("\',", '",')
|
||||
|
||||
# which song did we predict?
|
||||
result = ast.literal_eval(result)
|
||||
song_result = result["song_name"]
|
||||
log_msg('song: %s' % song)
|
||||
log_msg('song_result: %s' % song_result)
|
||||
|
||||
if song_result != song:
|
||||
log_msg('invalid match')
|
||||
self.result_match[line][col] = 'invalid'
|
||||
self.result_matching_times[line][col] = 0
|
||||
self.result_query_duration[line][col] = 0
|
||||
self.result_match_confidence[line][col] = 0
|
||||
else:
|
||||
log_msg('correct match')
|
||||
print self.result_match
|
||||
self.result_match[line][col] = 'yes'
|
||||
self.result_query_duration[line][col] = round(result[Dejavu.MATCH_TIME],3)
|
||||
self.result_match_confidence[line][col] = result[Dejavu.CONFIDENCE]
|
||||
|
||||
song_start_time = re.findall("\_[^\_]+",f)
|
||||
song_start_time = song_start_time[0].lstrip("_ ")
|
||||
|
||||
result_start_time = round((result[Dejavu.OFFSET] * DEFAULT_WINDOW_SIZE *
|
||||
DEFAULT_OVERLAP_RATIO) / (DEFAULT_FS), 0)
|
||||
|
||||
self.result_matching_times[line][col] = int(result_start_time) - int(song_start_time)
|
||||
if (abs(self.result_matching_times[line][col]) == 1):
|
||||
self.result_matching_times[line][col] = 0
|
||||
|
||||
log_msg('query duration: %s' % round(result[Dejavu.MATCH_TIME],3))
|
||||
log_msg('confidence: %s' % result[Dejavu.CONFIDENCE])
|
||||
log_msg('song start_time: %s' % song_start_time)
|
||||
log_msg('result start time: %s' % result_start_time)
|
||||
if (self.result_matching_times[line][col] == 0):
|
||||
log_msg('accurate match')
|
||||
else:
|
||||
log_msg('inaccurate match')
|
||||
log_msg('--------------------------------------------------\n')
|
||||
|
||||
|
||||
|
||||
|
||||
ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, f'{round(float(height), 3)}',
|
||||
ha='center', va='bottom')
|
0
dejavu/third_party/__init__.py
vendored
Normal file
0
dejavu/third_party/__init__.py
vendored
Normal file
357
dejavu/third_party/wavio.py
vendored
Normal file
357
dejavu/third_party/wavio.py
vendored
Normal file
|
@ -0,0 +1,357 @@
|
|||
# wavio.py
|
||||
# Author: Warren Weckesser
|
||||
# License: BSD 2-Clause (http://opensource.org/licenses/BSD-2-Clause)
|
||||
# Synopsis: A Python module for reading and writing 24 bit WAV files.
|
||||
# Github: github.com/WarrenWeckesser/wavio
|
||||
|
||||
"""
|
||||
The wavio module defines the functions:
|
||||
read(file)
|
||||
Read a WAV file and return a `wavio.Wav` object, with attributes
|
||||
`data`, `rate` and `sampwidth`.
|
||||
write(filename, data, rate, scale=None, sampwidth=None)
|
||||
Write a numpy array to a WAV file.
|
||||
-----
|
||||
Author: Warren Weckesser
|
||||
License: BSD 2-Clause:
|
||||
Copyright (c) 2015, Warren Weckesser
|
||||
All rights reserved.
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
1. Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
||||
"""
|
||||
|
||||
|
||||
import wave as _wave
|
||||
|
||||
import numpy as _np
|
||||
|
||||
__version__ = "0.0.5.dev1"
|
||||
|
||||
|
||||
def _wav2array(nchannels, sampwidth, data):
|
||||
"""data must be the string containing the bytes from the wav file."""
|
||||
num_samples, remainder = divmod(len(data), sampwidth * nchannels)
|
||||
if remainder > 0:
|
||||
raise ValueError('The length of data is not a multiple of '
|
||||
'sampwidth * num_channels.')
|
||||
if sampwidth > 4:
|
||||
raise ValueError("sampwidth must not be greater than 4.")
|
||||
|
||||
if sampwidth == 3:
|
||||
a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8)
|
||||
raw_bytes = _np.frombuffer(data, dtype=_np.uint8)
|
||||
a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
|
||||
a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
|
||||
result = a.view('<i4').reshape(a.shape[:-1])
|
||||
else:
|
||||
# 8 bit samples are stored as unsigned ints; others as signed ints.
|
||||
dt_char = 'u' if sampwidth == 1 else 'i'
|
||||
a = _np.frombuffer(data, dtype=f'<{dt_char}{sampwidth}')
|
||||
result = a.reshape(-1, nchannels)
|
||||
return result
|
||||
|
||||
|
||||
def _array2wav(a, sampwidth):
|
||||
"""
|
||||
Convert the input array `a` to a string of WAV data.
|
||||
a.dtype must be one of uint8, int16 or int32. Allowed sampwidth
|
||||
values are:
|
||||
dtype sampwidth
|
||||
uint8 1
|
||||
int16 2
|
||||
int32 3 or 4
|
||||
When sampwidth is 3, the *low* bytes of `a` are assumed to contain
|
||||
the values to include in the string.
|
||||
"""
|
||||
if sampwidth == 3:
|
||||
# `a` must have dtype int32
|
||||
if a.ndim == 1:
|
||||
# Convert to a 2D array with a single column.
|
||||
a = a.reshape(-1, 1)
|
||||
# By shifting first 0 bits, then 8, then 16, the resulting output
|
||||
# is 24 bit little-endian.
|
||||
a8 = (a.reshape(a.shape + (1,)) >> _np.array([0, 8, 16])) & 255
|
||||
wavdata = a8.astype(_np.uint8).tostring()
|
||||
else:
|
||||
# Make sure the array is little-endian, and then convert using
|
||||
# tostring()
|
||||
a = a.astype('<' + a.dtype.str[1:], copy=False)
|
||||
wavdata = a.tostring()
|
||||
return wavdata
|
||||
|
||||
|
||||
class Wav(object):
|
||||
"""
|
||||
Object returned by `wavio.read`. Attributes are:
|
||||
data : numpy array
|
||||
The array of data read from the WAV file.
|
||||
rate : float
|
||||
The sample rate of the WAV file.
|
||||
sampwidth : int
|
||||
The sample width (i.e. number of bytes per sample) of the WAV file.
|
||||
For example, `sampwidth == 3` is a 24 bit WAV file.
|
||||
"""
|
||||
|
||||
def __init__(self, data, rate, sampwidth):
|
||||
self.data = data
|
||||
self.rate = rate
|
||||
self.sampwidth = sampwidth
|
||||
|
||||
def __repr__(self):
|
||||
s = (f"Wav(data.shape={self.data.shape}, data.dtype={self.data.dtype}, "
|
||||
f"rate={self.rate}, sampwidth={self.sampwidth})")
|
||||
return s
|
||||
|
||||
|
||||
def read(file):
|
||||
"""
|
||||
Read a WAV file.
|
||||
Parameters
|
||||
----------
|
||||
file : string or file object
|
||||
Either the name of a file or an open file pointer.
|
||||
Returns
|
||||
-------
|
||||
wav : wavio.Wav() instance
|
||||
The return value is an instance of the class `wavio.Wav`,
|
||||
with the following attributes:
|
||||
data : numpy array
|
||||
The array containing the data. The shape of the array
|
||||
is (num_samples, num_channels). num_channels is the
|
||||
number of audio channels (1 for mono, 2 for stereo).
|
||||
rate : float
|
||||
The sampling frequency (i.e. frame rate)
|
||||
sampwidth : float
|
||||
The sample width, in bytes. E.g. for a 24 bit WAV file,
|
||||
sampwidth is 3.
|
||||
Notes
|
||||
-----
|
||||
This function uses the `wave` module of the Python standard libary
|
||||
to read the WAV file, so it has the same limitations as that library.
|
||||
In particular, the function does not read compressed WAV files, and
|
||||
it does not read files with floating point data.
|
||||
The array returned by `wavio.read` is always two-dimensional. If the
|
||||
WAV data is mono, the array will have shape (num_samples, 1).
|
||||
`wavio.read()` does not scale or normalize the data. The data in the
|
||||
array `wav.data` is the data that was in the file. When the file
|
||||
contains 24 bit samples, the resulting numpy array is 32 bit integers,
|
||||
with values that have been sign-extended.
|
||||
"""
|
||||
wav = _wave.open(file)
|
||||
rate = wav.getframerate()
|
||||
nchannels = wav.getnchannels()
|
||||
sampwidth = wav.getsampwidth()
|
||||
nframes = wav.getnframes()
|
||||
data = wav.readframes(nframes)
|
||||
wav.close()
|
||||
array = _wav2array(nchannels, sampwidth, data)
|
||||
w = Wav(data=array, rate=rate, sampwidth=sampwidth)
|
||||
return w
|
||||
|
||||
|
||||
_sampwidth_dtypes = {1: _np.uint8,
|
||||
2: _np.int16,
|
||||
3: _np.int32,
|
||||
4: _np.int32}
|
||||
_sampwidth_ranges = {1: (0, 256),
|
||||
2: (-2**15, 2**15),
|
||||
3: (-2**23, 2**23),
|
||||
4: (-2**31, 2**31)}
|
||||
|
||||
|
||||
def _scale_to_sampwidth(data, sampwidth, vmin, vmax):
|
||||
# Scale and translate the values to fit the range of the data type
|
||||
# associated with the given sampwidth.
|
||||
|
||||
data = data.clip(vmin, vmax)
|
||||
|
||||
dt = _sampwidth_dtypes[sampwidth]
|
||||
if vmax == vmin:
|
||||
data = _np.zeros(data.shape, dtype=dt)
|
||||
else:
|
||||
outmin, outmax = _sampwidth_ranges[sampwidth]
|
||||
if outmin != vmin or outmax != vmax:
|
||||
vmin = float(vmin)
|
||||
vmax = float(vmax)
|
||||
data = (float(outmax - outmin) * (data - vmin) /
|
||||
(vmax - vmin)).astype(_np.int64) + outmin
|
||||
data[data == outmax] = outmax - 1
|
||||
data = data.astype(dt)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def write(file, data, rate, scale=None, sampwidth=None):
|
||||
"""
|
||||
Write the numpy array `data` to a WAV file.
|
||||
The Python standard library "wave" is used to write the data
|
||||
to the file, so this function has the same limitations as that
|
||||
module. In particular, the Python library does not support
|
||||
floating point data. When given a floating point array, this
|
||||
function converts the values to integers. See below for the
|
||||
conversion rules.
|
||||
Parameters
|
||||
----------
|
||||
file : string, or file object open for writing in binary mode
|
||||
Either the name of a file or an open file pointer.
|
||||
data : numpy array, 1- or 2-dimensional, integer or floating point
|
||||
If it is 2-d, the rows are the frames (i.e. samples) and the
|
||||
columns are the channels.
|
||||
rate : float
|
||||
The sampling frequency (i.e. frame rate) of the data.
|
||||
sampwidth : int, optional
|
||||
The sample width, in bytes, of the output file.
|
||||
If `sampwidth` is not given, it is inferred (if possible) from
|
||||
the data type of `data`, as follows::
|
||||
data.dtype sampwidth
|
||||
---------- ---------
|
||||
uint8, int8 1
|
||||
uint16, int16 2
|
||||
uint32, int32 4
|
||||
For any other data types, or to write a 24 bit file, `sampwidth`
|
||||
must be given.
|
||||
scale : tuple or str, optional
|
||||
By default, the data written to the file is scaled up or down to
|
||||
occupy the full range of the output data type. So, for example,
|
||||
the unsigned 8 bit data [0, 1, 2, 15] would be written to the file
|
||||
as [0, 17, 30, 255]. More generally, the default behavior is
|
||||
(roughly)::
|
||||
vmin = data.min()
|
||||
vmax = data.max()
|
||||
outmin = <minimum integer of the output dtype>
|
||||
outmax = <maximum integer of the output dtype>
|
||||
outdata = (outmax - outmin)*(data - vmin)/(vmax - vmin) + outmin
|
||||
The `scale` argument allows the scaling of the output data to be
|
||||
changed. `scale` can be a tuple of the form `(vmin, vmax)`, in which
|
||||
case the given values override the use of `data.min()` and
|
||||
`data.max()` for `vmin` and `vmax` shown above. (If either value
|
||||
is `None`, the value shown above is used.) Data outside the
|
||||
range (vmin, vmax) is clipped. If `vmin == vmax`, the output is
|
||||
all zeros.
|
||||
If `scale` is the string "none", then `vmin` and `vmax` are set to
|
||||
`outmin` and `outmax`, respectively. This means the data is written
|
||||
to the file with no scaling. (Note: `scale="none" is not the same
|
||||
as `scale=None`. The latter means "use the default behavior",
|
||||
which is to scale by the data minimum and maximum.)
|
||||
If `scale` is the string "dtype-limits", then `vmin` and `vmax`
|
||||
are set to the minimum and maximum integers of `data.dtype`.
|
||||
The string "dtype-limits" is not allowed when the `data` is a
|
||||
floating point array.
|
||||
If using `scale` results in values that exceed the limits of the
|
||||
output sample width, the data is clipped. For example, the
|
||||
following code::
|
||||
>> x = np.array([-100, 0, 100, 200, 300, 325])
|
||||
>> wavio.write('foo.wav', x, 8000, scale='none', sampwidth=1)
|
||||
will write the values [0, 0, 100, 200, 255, 255] to the file.
|
||||
Example
|
||||
-------
|
||||
Create a 3 second 440 Hz sine wave, and save it in a 24-bit WAV file.
|
||||
>> import numpy as np
|
||||
>> import wavio
|
||||
>> rate = 22050 # samples per second
|
||||
>> T = 3 # sample duration (seconds)
|
||||
>> f = 440.0 # sound frequency (Hz)
|
||||
>> t = np.linspace(0, T, T*rate, endpoint=False)
|
||||
>> x = np.sin(2*np.pi * f * t)
|
||||
>> wavio.write("sine24.wav", x, rate, sampwidth=3)
|
||||
Create a file that contains the 16 bit integer values -10000 and 10000
|
||||
repeated 100 times. Don't automatically scale the values. Use a sample
|
||||
rate 8000.
|
||||
>> x = np.empty(200, dtype=np.int16)
|
||||
>> x[::2] = -10000
|
||||
>> x[1::2] = 10000
|
||||
>> wavio.write("foo.wav", x, 8000, scale='none')
|
||||
Check that the file contains what we expect.
|
||||
>> w = wavio.read("foo.wav")
|
||||
>> np.all(w.data[:, 0] == x)
|
||||
True
|
||||
In the following, the values -10000 and 10000 (from within the 16 bit
|
||||
range [-2**15, 2**15-1]) are mapped to the corresponding values 88 and
|
||||
168 (in the range [0, 2**8-1]).
|
||||
>> wavio.write("foo.wav", x, 8000, sampwidth=1, scale='dtype-limits')
|
||||
>> w = wavio.read("foo.wav")
|
||||
>> w.data[:4, 0]
|
||||
array([ 88, 168, 88, 168], dtype=uint8)
|
||||
"""
|
||||
|
||||
if sampwidth is None:
|
||||
if not _np.issubdtype(data.dtype, _np.integer) or data.itemsize > 4:
|
||||
raise ValueError('when data.dtype is not an 8-, 16-, or 32-bit integer type, sampwidth must be specified.')
|
||||
sampwidth = data.itemsize
|
||||
else:
|
||||
if sampwidth not in [1, 2, 3, 4]:
|
||||
raise ValueError('sampwidth must be 1, 2, 3 or 4.')
|
||||
|
||||
outdtype = _sampwidth_dtypes[sampwidth]
|
||||
outmin, outmax = _sampwidth_ranges[sampwidth]
|
||||
|
||||
if scale == "none":
|
||||
data = data.clip(outmin, outmax-1).astype(outdtype)
|
||||
elif scale == "dtype-limits":
|
||||
if not _np.issubdtype(data.dtype, _np.integer):
|
||||
raise ValueError("scale cannot be 'dtype-limits' with non-integer data.")
|
||||
# Easy transforms that just changed the signedness of the data.
|
||||
if sampwidth == 1 and data.dtype == _np.int8:
|
||||
data = (data.astype(_np.int16) + 128).astype(_np.uint8)
|
||||
elif sampwidth == 2 and data.dtype == _np.uint16:
|
||||
data = (data.astype(_np.int32) - 32768).astype(_np.int16)
|
||||
elif sampwidth == 4 and data.dtype == _np.uint32:
|
||||
data = (data.astype(_np.int64) - 2**31).astype(_np.int32)
|
||||
elif data.itemsize != sampwidth:
|
||||
# Integer input, but rescaling is needed to adjust the
|
||||
# input range to the output sample width.
|
||||
ii = _np.iinfo(data.dtype)
|
||||
vmin = ii.min
|
||||
vmax = ii.max
|
||||
data = _scale_to_sampwidth(data, sampwidth, vmin, vmax)
|
||||
else:
|
||||
if scale is None:
|
||||
vmin = data.min()
|
||||
vmax = data.max()
|
||||
else:
|
||||
# scale must be a tuple of the form (vmin, vmax)
|
||||
vmin, vmax = scale
|
||||
if vmin is None:
|
||||
vmin = data.min()
|
||||
if vmax is None:
|
||||
vmax = data.max()
|
||||
|
||||
data = _scale_to_sampwidth(data, sampwidth, vmin, vmax)
|
||||
|
||||
# At this point, `data` has been converted to have one of the following:
|
||||
# sampwidth dtype
|
||||
# --------- -----
|
||||
# 1 uint8
|
||||
# 2 int16
|
||||
# 3 int32
|
||||
# 4 int32
|
||||
# The values in `data` are in the form in which they will be saved;
|
||||
# no more scaling will take place.
|
||||
|
||||
if data.ndim == 1:
|
||||
data = data.reshape(-1, 1)
|
||||
|
||||
wavdata = _array2wav(data, sampwidth)
|
||||
|
||||
w = _wave.open(file, 'wb')
|
||||
w.setnchannels(data.shape[1])
|
||||
w.setsampwidth(sampwidth)
|
||||
w.setframerate(rate)
|
||||
w.writeframes(wavdata)
|
||||
w.close()
|
121
dejavu/wavio.py
121
dejavu/wavio.py
|
@ -1,121 +0,0 @@
|
|||
# wavio.py
|
||||
# Author: Warren Weckesser
|
||||
# License: BSD 3-Clause (http://opensource.org/licenses/BSD-3-Clause)
|
||||
# Synopsis: A Python module for reading and writing 24 bit WAV files.
|
||||
# Github: github.com/WarrenWeckesser/wavio
|
||||
|
||||
import wave as _wave
|
||||
import numpy as _np
|
||||
|
||||
|
||||
def _wav2array(nchannels, sampwidth, data):
|
||||
"""data must be the string containing the bytes from the wav file."""
|
||||
num_samples, remainder = divmod(len(data), sampwidth * nchannels)
|
||||
if remainder > 0:
|
||||
raise ValueError('The length of data is not a multiple of '
|
||||
'sampwidth * num_channels.')
|
||||
if sampwidth > 4:
|
||||
raise ValueError("sampwidth must not be greater than 4.")
|
||||
|
||||
if sampwidth == 3:
|
||||
a = _np.empty((num_samples, nchannels, 4), dtype=_np.uint8)
|
||||
raw_bytes = _np.fromstring(data, dtype=_np.uint8)
|
||||
a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
|
||||
a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
|
||||
result = a.view('<i4').reshape(a.shape[:-1])
|
||||
else:
|
||||
# 8 bit samples are stored as unsigned ints; others as signed ints.
|
||||
dt_char = 'u' if sampwidth == 1 else 'i'
|
||||
a = _np.fromstring(data, dtype='<%s%d' % (dt_char, sampwidth))
|
||||
result = a.reshape(-1, nchannels)
|
||||
return result
|
||||
|
||||
|
||||
def readwav(file):
|
||||
"""
|
||||
Read a WAV file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file : string or file object
|
||||
Either the name of a file or an open file pointer.
|
||||
|
||||
Return Values
|
||||
-------------
|
||||
rate : float
|
||||
The sampling frequency (i.e. frame rate)
|
||||
sampwidth : float
|
||||
The sample width, in bytes. E.g. for a 24 bit WAV file,
|
||||
sampwidth is 3.
|
||||
data : numpy array
|
||||
The array containing the data. The shape of the array is
|
||||
(num_samples, num_channels). num_channels is the number of
|
||||
audio channels (1 for mono, 2 for stereo).
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function uses the `wave` module of the Python standard libary
|
||||
to read the WAV file, so it has the same limitations as that library.
|
||||
In particular, the function does not read compressed WAV files.
|
||||
|
||||
"""
|
||||
wav = _wave.open(file)
|
||||
rate = wav.getframerate()
|
||||
nchannels = wav.getnchannels()
|
||||
sampwidth = wav.getsampwidth()
|
||||
nframes = wav.getnframes()
|
||||
data = wav.readframes(nframes)
|
||||
wav.close()
|
||||
array = _wav2array(nchannels, sampwidth, data)
|
||||
return rate, sampwidth, array
|
||||
|
||||
|
||||
def writewav24(filename, rate, data):
|
||||
"""
|
||||
Create a 24 bit wav file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : string
|
||||
Name of the file to create.
|
||||
rate : float
|
||||
The sampling frequency (i.e. frame rate) of the data.
|
||||
data : array-like collection of integer or floating point values
|
||||
data must be "array-like", either 1- or 2-dimensional. If it
|
||||
is 2-d, the rows are the frames (i.e. samples) and the columns
|
||||
are the channels.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The data is assumed to be signed, and the values are assumed to be
|
||||
within the range of a 24 bit integer. Floating point values are
|
||||
converted to integers. The data is not rescaled or normalized before
|
||||
writing it to the file.
|
||||
|
||||
Example
|
||||
-------
|
||||
Create a 3 second 440 Hz sine wave.
|
||||
|
||||
>>> rate = 22050 # samples per second
|
||||
>>> T = 3 # sample duration (seconds)
|
||||
>>> f = 440.0 # sound frequency (Hz)
|
||||
>>> t = np.linspace(0, T, T*rate, endpoint=False)
|
||||
>>> x = (2**23 - 1) * np.sin(2 * np.pi * f * t)
|
||||
>>> writewav24("sine24.wav", rate, x)
|
||||
|
||||
"""
|
||||
a32 = _np.asarray(data, dtype=_np.int32)
|
||||
if a32.ndim == 1:
|
||||
# Convert to a 2D array with a single column.
|
||||
a32.shape = a32.shape + (1,)
|
||||
# By shifting first 0 bits, then 8, then 16, the resulting output
|
||||
# is 24 bit little-endian.
|
||||
a8 = (a32.reshape(a32.shape + (1,)) >> _np.array([0, 8, 16])) & 255
|
||||
wavdata = a8.astype(_np.uint8).tostring()
|
||||
|
||||
w = _wave.open(filename, 'wb')
|
||||
w.setnchannels(a32.shape[1])
|
||||
w.setsampwidth(3)
|
||||
w.setframerate(rate)
|
||||
w.writeframes(wavdata)
|
||||
w.close()
|
35
example.py
35
example.py
|
@ -1,35 +0,0 @@
|
|||
import warnings
|
||||
import json
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from dejavu import Dejavu
|
||||
from dejavu.recognize import FileRecognizer, MicrophoneRecognizer
|
||||
|
||||
# load config from a JSON file (or anything outputting a python dictionary)
|
||||
with open("dejavu.cnf.SAMPLE") as f:
|
||||
config = json.load(f)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# create a Dejavu instance
|
||||
djv = Dejavu(config)
|
||||
|
||||
# Fingerprint all the mp3's in the directory we give it
|
||||
djv.fingerprint_directory("mp3", [".mp3"])
|
||||
|
||||
# Recognize audio from a file
|
||||
song = djv.recognize(FileRecognizer, "mp3/Sean-Fournier--Falling-For-You.mp3")
|
||||
print "From file we recognized: %s\n" % song
|
||||
|
||||
# Or recognize audio from your microphone for `secs` seconds
|
||||
secs = 5
|
||||
song = djv.recognize(MicrophoneRecognizer, seconds=secs)
|
||||
if song is None:
|
||||
print "Nothing recognized -- did you play the song out loud so your mic could hear it? :)"
|
||||
else:
|
||||
print "From mic with %d seconds we recognized: %s\n" % (secs, song)
|
||||
|
||||
# Or use a recognizer without the shortcut, in anyway you would like
|
||||
recognizer = FileRecognizer(djv)
|
||||
song = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
|
||||
print "No shortcut, we recognized: %s\n" % song
|
34
example_script.py
Executable file
34
example_script.py
Executable file
|
@ -0,0 +1,34 @@
|
|||
import json
|
||||
|
||||
from dejavu import Dejavu
|
||||
from dejavu.logic.recognizer.file_recognizer import FileRecognizer
|
||||
from dejavu.logic.recognizer.microphone_recognizer import MicrophoneRecognizer
|
||||
|
||||
# load config from a JSON file (or anything outputting a python dictionary)
|
||||
with open("dejavu.cnf.SAMPLE") as f:
|
||||
config = json.load(f)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# create a Dejavu instance
|
||||
djv = Dejavu(config)
|
||||
|
||||
# Fingerprint all the mp3's in the directory we give it
|
||||
djv.fingerprint_directory("test", [".wav"])
|
||||
|
||||
# Recognize audio from a file
|
||||
results = djv.recognize(FileRecognizer, "mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
|
||||
print(f"From file we recognized: {results}\n")
|
||||
|
||||
# Or recognize audio from your microphone for `secs` seconds
|
||||
secs = 5
|
||||
results = djv.recognize(MicrophoneRecognizer, seconds=secs)
|
||||
if results is None:
|
||||
print("Nothing recognized -- did you play the song out loud so your mic could hear it? :)")
|
||||
else:
|
||||
print(f"From mic with {secs} seconds we recognized: {results}\n")
|
||||
|
||||
# Or use a recognizer without the shortcut, in anyway you would like
|
||||
recognizer = FileRecognizer(djv)
|
||||
results = recognizer.recognize_file("mp3/Josh-Woodward--I-Want-To-Destroy-Something-Beautiful.mp3")
|
||||
print(f"No shortcut, we recognized: {results}\n")
|
BIN
mp3/azan_test.wav
Normal file
BIN
mp3/azan_test.wav
Normal file
Binary file not shown.
|
@ -1,9 +1,7 @@
|
|||
# requirements file
|
||||
|
||||
### BEGIN ###
|
||||
pydub>=0.9.4
|
||||
PyAudio>=0.2.7
|
||||
numpy>=1.8.2
|
||||
scipy>=0.12.1
|
||||
matplotlib>=1.3.1
|
||||
### END ###
|
||||
pydub==0.23.1
|
||||
PyAudio==0.2.11
|
||||
numpy==1.17.2
|
||||
scipy==1.3.1
|
||||
matplotlib==3.1.1
|
||||
mysql-connector-python==8.0.17
|
||||
psycopg2==2.8.3
|
||||
|
|
280
run_tests.py
280
run_tests.py
|
@ -1,184 +1,166 @@
|
|||
from dejavu.testing import *
|
||||
from dejavu import Dejavu
|
||||
from optparse import OptionParser
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
import shutil
|
||||
from os import makedirs
|
||||
from os.path import exists, join
|
||||
from shutil import rmtree
|
||||
|
||||
usage = "usage: %prog [options] TESTING_AUDIOFOLDER"
|
||||
parser = OptionParser(usage=usage, version="%prog 1.1")
|
||||
parser.add_option("--secs",
|
||||
action="store",
|
||||
dest="secs",
|
||||
default=5,
|
||||
type=int,
|
||||
help='Number of seconds starting from zero to test')
|
||||
parser.add_option("--results",
|
||||
action="store",
|
||||
dest="results_folder",
|
||||
default="./dejavu_test_results",
|
||||
help='Sets the path where the results are saved')
|
||||
parser.add_option("--temp",
|
||||
action="store",
|
||||
dest="temp_folder",
|
||||
default="./dejavu_temp_testing_files",
|
||||
help='Sets the path where the temp files are saved')
|
||||
parser.add_option("--log",
|
||||
action="store_true",
|
||||
dest="log",
|
||||
default=True,
|
||||
help='Enables logging')
|
||||
parser.add_option("--silent",
|
||||
action="store_false",
|
||||
dest="silent",
|
||||
default=False,
|
||||
help='Disables printing')
|
||||
parser.add_option("--log-file",
|
||||
dest="log_file",
|
||||
default="results-compare.log",
|
||||
help='Set the path and filename of the log file')
|
||||
parser.add_option("--padding",
|
||||
action="store",
|
||||
dest="padding",
|
||||
default=10,
|
||||
type=int,
|
||||
help='Number of seconds to pad choice of place to test from')
|
||||
parser.add_option("--seed",
|
||||
action="store",
|
||||
dest="seed",
|
||||
default=None,
|
||||
type=int,
|
||||
help='Random seed')
|
||||
options, args = parser.parse_args()
|
||||
test_folder = args[0]
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# set random seed if set by user
|
||||
set_seed(options.seed)
|
||||
from dejavu.tests.dejavu_test import (DejavuTest, autolabeldoubles,
|
||||
generate_test_files, log_msg, set_seed)
|
||||
|
||||
# ensure results folder exists
|
||||
try:
|
||||
os.stat(options.results_folder)
|
||||
except:
|
||||
os.mkdir(options.results_folder)
|
||||
|
||||
# set logging
|
||||
if options.log:
|
||||
logging.basicConfig(filename=options.log_file, level=logging.DEBUG)
|
||||
def main(seconds: int, results_folder: str, temp_folder: str, log: bool, silent: bool,
|
||||
log_file: str, padding: int, seed: int, src: str):
|
||||
|
||||
# set test seconds
|
||||
test_seconds = ['%dsec' % i for i in range(1, options.secs + 1, 1)]
|
||||
# set random seed if set by user
|
||||
set_seed(seed)
|
||||
|
||||
# generate testing files
|
||||
for i in range(1, options.secs + 1, 1):
|
||||
generate_test_files(test_folder, options.temp_folder,
|
||||
i, padding=options.padding)
|
||||
# ensure results folder exists
|
||||
if not exists(results_folder):
|
||||
makedirs(results_folder)
|
||||
|
||||
# scan files
|
||||
log_msg("Running Dejavu fingerprinter on files in %s..." % test_folder,
|
||||
log=options.log, silent=options.silent)
|
||||
# set logging
|
||||
if log:
|
||||
logging.basicConfig(filename=log_file, level=logging.DEBUG)
|
||||
|
||||
tm = time.time()
|
||||
djv = DejavuTest(options.temp_folder, test_seconds)
|
||||
log_msg("finished obtaining results from dejavu in %s" % (time.time() - tm),
|
||||
log=options.log, silent=options.silent)
|
||||
# set test seconds
|
||||
test_seconds = [f'{i}sec' for i in range(1, seconds + 1, 1)]
|
||||
|
||||
tests = 1 # djv
|
||||
n_secs = len(test_seconds)
|
||||
# generate testing files
|
||||
for i in range(1, seconds + 1, 1):
|
||||
generate_test_files(src, temp_folder, i, padding=padding)
|
||||
|
||||
# set result variables -> 4d variables
|
||||
all_match_counter = [[[0 for x in xrange(tests)] for x in xrange(3)] for x in xrange(n_secs)]
|
||||
all_matching_times_counter = [[[0 for x in xrange(tests)] for x in xrange(2)] for x in xrange(n_secs)]
|
||||
all_query_duration = [[[0 for x in xrange(tests)] for x in xrange(djv.n_lines)] for x in xrange(n_secs)]
|
||||
all_match_confidence = [[[0 for x in xrange(tests)] for x in xrange(djv.n_lines)] for x in xrange(n_secs)]
|
||||
# scan files
|
||||
log_msg(f"Running Dejavu fingerprinter on files in {src}...", log=log, silent=silent)
|
||||
|
||||
# group results by seconds
|
||||
for line in range(0, djv.n_lines):
|
||||
for col in range(0, djv.n_columns):
|
||||
# for dejavu
|
||||
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
|
||||
all_match_confidence[col][line][0] = djv.result_match_confidence[line][col]
|
||||
tm = time.time()
|
||||
djv = DejavuTest(temp_folder, test_seconds)
|
||||
log_msg(f"finished obtaining results from dejavu in {(time.time() - tm)}", log=log, silent=silent)
|
||||
|
||||
djv_match_result = djv.result_match[line][col]
|
||||
tests = 1 # djv
|
||||
n_secs = len(test_seconds)
|
||||
|
||||
if djv_match_result == 'yes':
|
||||
all_match_counter[col][0][0] += 1
|
||||
elif djv_match_result == 'no':
|
||||
all_match_counter[col][1][0] += 1
|
||||
else:
|
||||
all_match_counter[col][2][0] += 1
|
||||
# set result variables -> 4d variables
|
||||
all_match_counter = [[[0 for x in range(tests)] for x in range(3)] for x in range(n_secs)]
|
||||
all_matching_times_counter = [[[0 for x in range(tests)] for x in range(2)] for x in range(n_secs)]
|
||||
all_query_duration = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
||||
all_match_confidence = [[[0 for x in range(tests)] for x in range(djv.n_lines)] for x in range(n_secs)]
|
||||
|
||||
djv_match_acc = djv.result_matching_times[line][col]
|
||||
# group results by seconds
|
||||
for line in range(0, djv.n_lines):
|
||||
for col in range(0, djv.n_columns):
|
||||
# for dejavu
|
||||
all_query_duration[col][line][0] = djv.result_query_duration[line][col]
|
||||
all_match_confidence[col][line][0] = djv.result_match_confidence[line][col]
|
||||
|
||||
if djv_match_acc == 0 and djv_match_result == 'yes':
|
||||
all_matching_times_counter[col][0][0] += 1
|
||||
elif djv_match_acc != 0:
|
||||
all_matching_times_counter[col][1][0] += 1
|
||||
djv_match_result = djv.result_match[line][col]
|
||||
|
||||
# create plots
|
||||
djv.create_plots('Confidence', all_match_confidence, options.results_folder)
|
||||
djv.create_plots('Query duration', all_query_duration, options.results_folder)
|
||||
if djv_match_result == 'yes':
|
||||
all_match_counter[col][0][0] += 1
|
||||
elif djv_match_result == 'no':
|
||||
all_match_counter[col][1][0] += 1
|
||||
else:
|
||||
all_match_counter[col][2][0] += 1
|
||||
|
||||
for sec in range(0, n_secs):
|
||||
ind = np.arange(3) #
|
||||
width = 0.25 # the width of the bars
|
||||
djv_match_acc = djv.result_matching_times[line][col]
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1 * width, 2.75])
|
||||
if djv_match_acc == 0 and djv_match_result == 'yes':
|
||||
all_matching_times_counter[col][0][0] += 1
|
||||
elif djv_match_acc != 0:
|
||||
all_matching_times_counter[col][1][0] += 1
|
||||
|
||||
means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
# create plots
|
||||
djv.create_plots('Confidence', all_match_confidence, results_folder)
|
||||
djv.create_plots('Query duration', all_query_duration, results_folder)
|
||||
|
||||
# add some
|
||||
ax.set_ylabel('Matching Percentage')
|
||||
ax.set_title('%s Matching Percentage' % test_seconds[sec])
|
||||
ax.set_xticks(ind + width)
|
||||
for sec in range(0, n_secs):
|
||||
ind = np.arange(3)
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
labels = ['yes','no','invalid']
|
||||
ax.set_xticklabels( labels )
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1 * width, 2.75])
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
#ax.legend((rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
autolabeldoubles(rects1,ax)
|
||||
plt.grid()
|
||||
means_dvj = [round(x[0] * 100 / djv.n_lines, 1) for x in all_match_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
fig_name = os.path.join(options.results_folder, "matching_perc_%s.png" % test_seconds[sec])
|
||||
fig.savefig(fig_name)
|
||||
# add some
|
||||
ax.set_ylabel('Matching Percentage')
|
||||
ax.set_title(f'{test_seconds[sec]} Matching Percentage')
|
||||
ax.set_xticks(ind + width)
|
||||
|
||||
for sec in range(0, n_secs):
|
||||
ind = np.arange(2) #
|
||||
width = 0.25 # the width of the bars
|
||||
labels = ['yes', 'no', 'invalid']
|
||||
ax.set_xticklabels(labels)
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1*width, 1.75])
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
autolabeldoubles(rects1, ax)
|
||||
plt.grid()
|
||||
|
||||
div = all_match_counter[sec][0][0]
|
||||
if div == 0 :
|
||||
div = 1000000
|
||||
fig_name = join(results_folder, f"matching_perc_{test_seconds[sec]}.png")
|
||||
fig.savefig(fig_name)
|
||||
|
||||
means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
for sec in range(0, n_secs):
|
||||
ind = np.arange(2)
|
||||
width = 0.25 # the width of the bars
|
||||
|
||||
# add some
|
||||
ax.set_ylabel('Matching Accuracy')
|
||||
ax.set_title('%s Matching Times Accuracy' % test_seconds[sec])
|
||||
ax.set_xticks(ind + width)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
ax.set_xlim([-1 * width, 1.75])
|
||||
|
||||
labels = ['yes','no']
|
||||
ax.set_xticklabels( labels )
|
||||
div = all_match_counter[sec][0][0]
|
||||
if div == 0:
|
||||
div = 1000000
|
||||
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
means_dvj = [round(x[0] * 100 / div, 1) for x in all_matching_times_counter[sec]]
|
||||
rects1 = ax.bar(ind, means_dvj, width, color='r')
|
||||
|
||||
#ax.legend( (rects1[0]), ('Dejavu'), loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
autolabeldoubles(rects1,ax)
|
||||
# add some
|
||||
ax.set_ylabel('Matching Accuracy')
|
||||
ax.set_title(f'{test_seconds[sec]} Matching Times Accuracy')
|
||||
ax.set_xticks(ind + width)
|
||||
|
||||
plt.grid()
|
||||
labels = ['yes', 'no']
|
||||
ax.set_xticklabels(labels)
|
||||
|
||||
fig_name = os.path.join(options.results_folder, "matching_acc_%s.png" % test_seconds[sec])
|
||||
fig.savefig(fig_name)
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])
|
||||
autolabeldoubles(rects1, ax)
|
||||
|
||||
# remove temporary folder
|
||||
shutil.rmtree(options.temp_folder)
|
||||
plt.grid()
|
||||
|
||||
fig_name = join(results_folder, f"matching_acc_{test_seconds[sec]}.png")
|
||||
fig.savefig(fig_name)
|
||||
|
||||
# remove temporary folder
|
||||
rmtree(temp_folder)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description=f'Runs a few tests for dejavu to evaluate '
|
||||
f'its configuration performance. '
|
||||
f'Usage: %(prog).py [options] TESTING_AUDIOFOLDER'
|
||||
)
|
||||
|
||||
parser.add_argument("-sec", "--seconds", action="store", default=5, type=int,
|
||||
help='Number of seconds starting from zero to test.')
|
||||
parser.add_argument("-res", "--results-folder", action="store", default="./dejavu_test_results",
|
||||
help='Sets the path where the results are saved.')
|
||||
parser.add_argument("-temp", "--temp-folder", action="store", default="./dejavu_temp_testing_files",
|
||||
help='Sets the path where the temp files are saved.')
|
||||
parser.add_argument("-l", "--log", action="store_true", default=False, help='Enables logging.')
|
||||
parser.add_argument("-sl", "--silent", action="store_false", default=False, help='Disables printing.')
|
||||
parser.add_argument("-lf", "--log-file", default="results-compare.log",
|
||||
help='Set the path and filename of the log file.')
|
||||
parser.add_argument("-pad", "--padding", action="store", default=10, type=int,
|
||||
help='Number of seconds to pad choice of place to test from.')
|
||||
parser.add_argument("-sd", "--seed", action="store", default=None, type=int, help='Random seed.')
|
||||
parser.add_argument("src", type=str, help='Source folder for audios to use as tests.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.seconds, args.results_folder, args.temp_folder, args.log, args.silent, args.log_file, args.padding,
|
||||
args.seed, args.src)
|
||||
|
|
3
setup.cfg
Normal file
3
setup.cfg
Normal file
|
@ -0,0 +1,3 @@
|
|||
[flake8]
|
||||
max-line-length = 120
|
||||
|
10
setup.py
10
setup.py
|
@ -1,5 +1,4 @@
|
|||
from setuptools import setup, find_packages
|
||||
# import os, sys
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def parse_requirements(requirements):
|
||||
|
@ -7,13 +6,14 @@ def parse_requirements(requirements):
|
|||
with open(requirements) as f:
|
||||
lines = [l for l in f]
|
||||
# remove spaces
|
||||
stripped = map((lambda x: x.strip()), lines)
|
||||
stripped = list(map((lambda x: x.strip()), lines))
|
||||
# remove comments
|
||||
nocomments = filter((lambda x: not x.startswith('#')), stripped)
|
||||
nocomments = list(filter((lambda x: not x.startswith('#')), stripped))
|
||||
# remove empty lines
|
||||
reqs = filter((lambda x: x), nocomments)
|
||||
reqs = list(filter((lambda x: x), nocomments))
|
||||
return reqs
|
||||
|
||||
|
||||
PACKAGE_NAME = "PyDejavu"
|
||||
PACKAGE_VERSION = "0.1.3"
|
||||
SUMMARY = 'Dejavu: Audio Fingerprinting in Python'
|
||||
|
|
BIN
test/sean_secs.wav
Normal file
BIN
test/sean_secs.wav
Normal file
Binary file not shown.
BIN
test/woodward_43s.wav
Normal file
BIN
test/woodward_43s.wav
Normal file
Binary file not shown.
Loading…
Reference in a new issue