dejavu/test.py

346 lines
11 KiB
Python
Executable file

# Result generator for dejavu.
# TODO: Doesn't work very well with songs whose names contain special chars.
# Use test files in the format below, with no special chars and only one "-" to separate artist from song.
import os, subprocess, json, re, sys
import logging, time
from os import listdir
from os.path import isfile, join
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from optparse import OptionParser
from dejavu.decoder import path_to_songname
import ast
#####
### Test files are in specific format:
### 'artist_name'-'song_name'_'start_time'_'duration'sec.wav
#####
# Parameters used to convert a match offset reported by the recognizer into
# seconds: offset * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO / DEFAULT_FS.
# NOTE(review): presumably these must mirror the settings dejavu fingerprints
# with (sample rate in Hz, FFT window size, window overlap) -- confirm.
DEFAULT_FS = 44100
DEFAULT_WINDOW_SIZE = 4096
DEFAULT_OVERLAP_RATIO = 0.5
# Keys of the result dict parsed from the recognizer's output.
FIELD_SONG_NAME = 'song_name'
FIELD_CONFIDENCE = 'confidence'
FIELD_QUERY_TIME = 'match_time'
FIELD_OFFSET = 'offset'
# Parse the command line. Two positional arguments are required:
#   args[0] -- DEJAVU_PATH, the folder that contains dejavu.py
#   args[1] -- TEST_FOLDER, the folder that contains the test clips
usage = "usage: %prog [options] DEJAVU_PATH TEST_FOLDER"
parser = OptionParser(usage=usage, version="%prog 1.1")
parser.add_option("--no-log",
                  action="store_false",
                  dest="log",
                  default=True,
                  help='Disables logging')
parser.add_option("--log-file",
                  dest="log_file",
                  default="results-compare.log",
                  metavar="LOG_FILE",
                  help='Set the path and filename of the log file')
parser.add_option("--test-seconds",
                  action="append",
                  dest="test_seconds",
                  default=[],
                  metavar="Xsec",
                  help='Appends seconds to test suit')
# Default to "" (figures are written to the current directory). Without a
# default the option is None, and the trailing-slash handling below used to
# crash with a TypeError.
parser.add_option("--results-folder",
                  action="store",
                  dest="results_folder",
                  default="",
                  metavar="FOLDER",
                  help='Sets the path where the results are saved')

(options, args) = parser.parse_args()
if len(args) != 2:
    parser.error("wrong number of arguments")

# No --test-seconds given: test every duration from '1sec' to '10sec'.
if not options.test_seconds:
    options.test_seconds = ['%dsec' % n for n in range(1, 11)]

if options.log:
    logging.basicConfig(filename=options.log_file, level=logging.DEBUG)

# The folder is later used as a plain string prefix for figure file names,
# so it must end with a slash; an empty prefix means the current directory
# and needs neither a slash nor creation.
if options.results_folder:
    if not options.results_folder.endswith('/'):
        options.results_folder += "/"
    # ensure results folder exists (including missing parent folders)
    if not os.path.isdir(options.results_folder):
        os.makedirs(options.results_folder)
def log_msg(msg):
    """Write msg to the log at DEBUG level unless logging was disabled (--no-log)."""
    if options.log:
        logging.debug(msg)
class DejavuTest(object):
    """Run the dejavu recognizer over a folder of test clips and tabulate the results.

    Test files are expected to be named
    'artist_name'-'song_name'_'start_time'_'duration'sec.wav.

    The results are stored in four matrices indexed [line][column], where a
    line is a song (in order of first appearance) and a column is an entry of
    ``test_seconds``:

    * ``result_match``            -- 'yes', 'no' or 'invalid' (0 if never filled)
    * ``result_matching_times``   -- matched start time minus the clip's real
                                     start time, in seconds (0 = accurate)
    * ``result_query_duration``   -- query time reported by the recognizer
    * ``result_match_confidence`` -- confidence reported by the recognizer

    Constructing an instance runs the whole test (see ``begin``).
    """

    # Duration tag inside a test file name, e.g. "5sec".
    _DURATION_RE = re.compile(r"[0-9]*sec")

    def __init__(self, folder, seconds):
        """Collect the test files of ``folder`` tagged with one of ``seconds`` and run them.

        folder  -- path of the folder holding the test clips
        seconds -- list of duration tags to test, e.g. ['1sec', '5sec']
        """
        super(DejavuTest, self).__init__()
        self.test_folder = folder
        self.test_seconds = seconds
        self.test_songs = []
        # Files without a duration tag (or with an untested one) are skipped
        # instead of raising IndexError on the missing regex match.
        self.test_files = [
            f for f in listdir(self.test_folder)
            if isfile(join(self.test_folder, f))
            and self._duration_tag(f) in self.test_seconds
        ]
        self.n_columns = len(self.test_seconds)
        # One line per distinct song. Registering the songs up front keeps
        # the matrices large enough even when a song lacks some durations
        # (len(files) / n_columns would then be too small).
        for f in self.test_files:
            self.get_line_id(self._song_name(f))
        self.n_lines = len(self.test_songs)

        # variable match results (yes, no, invalid)
        self.result_match = self._new_matrix()
        print("columns: %s" % self.n_columns)
        print("length of test files: %s" % len(self.test_files))
        print("lines: %s" % self.n_lines)
        print("result_match matrix: %s" % (self.result_match,))
        # variable match precision (if matched at the correct time)
        self.result_matching_times = self._new_matrix()
        # variable matching time (query time)
        self.result_query_duration = self._new_matrix()
        # variable confidence
        self.result_match_confidence = self._new_matrix()

        self.begin()

    @classmethod
    def _duration_tag(cls, filename):
        """Return the first duration tag (e.g. '5sec') in filename, or None."""
        found = cls._DURATION_RE.search(filename)
        return found.group(0) if found else None

    @staticmethod
    def _song_name(filename):
        """Return the song part of a test file name (everything before the first '_')."""
        return path_to_songname(filename).split("_")[0]

    def _new_matrix(self):
        """Return a zero-filled n_lines x n_columns list of lists."""
        return [[0] * self.n_columns for _ in range(self.n_lines)]

    def _record(self, line, col, match, matching_time=0, query_duration=0, confidence=0):
        """Store the outcome of one test file in the four result matrices."""
        self.result_match[line][col] = match
        self.result_matching_times[line][col] = matching_time
        self.result_query_duration[line][col] = query_duration
        self.result_match_confidence[line][col] = confidence

    def get_column_id(self, secs):
        """Return the index of ``secs`` in ``test_seconds``, or None if it is not there."""
        try:
            return self.test_seconds.index(secs)
        except ValueError:
            return None

    def get_line_id(self, song):
        """Return the line index of ``song``, registering it first if it is new."""
        try:
            return self.test_songs.index(song)
        except ValueError:
            self.test_songs.append(song)
            return len(self.test_songs) - 1

    def begin(self):
        """Run the recognizer on every test file and fill the result matrices."""
        for f in self.test_files:
            log_msg('--------------------------------------------------')
            log_msg('file: %s' % f)

            # get column
            col = self.get_column_id(self._duration_tag(f))
            song = self._song_name(f)  # format: XXXX_offset_length.mp3
            line = self.get_line_id(song)

            # universal_newlines=True makes the output text on Python 3 too,
            # so the comparison with "None" below works there as well.
            result = subprocess.check_output(
                ["python", args[0] + "/dejavu.py", 'recognize', 'file',
                 self.test_folder + "/" + f],
                universal_newlines=True).strip()
            log_msg('RESULT: %s' % result)

            if result == "None":
                log_msg('No match')
                self._record(line, col, 'no')
            else:
                # The recognizer prints a Python dict literal; parse it as is.
                # (Rewriting its quotes first broke names containing quotes.)
                result = ast.literal_eval(result)
                print("result %s" % (result,))

                # which song did we predict?
                song_result = result[FIELD_SONG_NAME]
                log_msg('song: %s' % song)
                log_msg('song_result: %s' % song_result)

                if song_result != song:
                    log_msg('invalid match')
                    self._record(line, col, 'invalid')
                else:
                    log_msg('correct match')
                    print(self.result_match)

                    query_duration = round(result[FIELD_QUERY_TIME], 3)
                    confidence = result[FIELD_CONFIDENCE]

                    # The first "_xxx" segment of the file name is the clip's
                    # start time within the song.
                    song_start_time = re.findall(r"_[^_]+", f)[0].lstrip("_ ")
                    result_start_time = round(
                        (result[FIELD_OFFSET] * DEFAULT_WINDOW_SIZE * DEFAULT_OVERLAP_RATIO)
                        / DEFAULT_FS, 0)

                    matching_time = int(result_start_time) - int(song_start_time)
                    # A difference of one second still counts as accurate.
                    if abs(matching_time) == 1:
                        matching_time = 0

                    self._record(line, col, 'yes', matching_time, query_duration, confidence)

                    log_msg('query duration: %s' % query_duration)
                    log_msg('confidence: %s' % confidence)
                    log_msg('song start_time: %s' % song_start_time)
                    log_msg('result start time: %s' % result_start_time)
                    if matching_time == 0:
                        log_msg('accurate match')
                    else:
                        log_msg('inaccurate match')
            log_msg('--------------------------------------------------\n')
# Run the whole test suite (constructing DejavuTest does the work) and time it.
print("obtaining results from dejavu")
log_msg('obtaining results from dejavu')
tm = time.time()
djv = DejavuTest(args[1], options.test_seconds)
print("finished obtaining results from dejavu in %s" % (time.time() - tm))
log_msg("finished obtaining results from dejavu in %s" % (time.time() - tm))

tests_n_lines = djv.n_lines
tests_n_columns = djv.n_columns  # len(options.test_seconds)
tests = 1  # djv
n_secs = len(options.test_seconds)  # = tests.n_columns

# Result containers, all indexed [seconds][...][test]:
#   all_match_counter          -- [sec][0: yes, 1: no, 2: invalid][test]
#   all_matching_times_counter -- [sec][0: accurate, 1: inaccurate][test]
#   all_query_duration         -- [sec][song][test]
#   all_match_confidence       -- [sec][song][test]
all_match_counter = [[[0] * tests for _ in range(3)] for _ in range(n_secs)]
all_matching_times_counter = [[[0] * tests for _ in range(2)] for _ in range(n_secs)]
all_query_duration = [[[0] * tests for _ in range(tests_n_lines)] for _ in range(n_secs)]
all_match_confidence = [[[0] * tests for _ in range(tests_n_lines)] for _ in range(n_secs)]

# Group the per-song results by tested duration.
for row in range(tests_n_lines):
    for col in range(tests_n_columns):
        # for dejavu
        all_query_duration[col][row][0] = djv.result_query_duration[row][col]
        all_match_confidence[col][row][0] = djv.result_match_confidence[row][col]

        # Anything other than 'yes' / 'no' is counted as invalid.
        outcome = djv.result_match[row][col]
        bucket = {'yes': 0, 'no': 1}.get(outcome, 2)
        all_match_counter[col][bucket][0] += 1

        # A non-zero time difference is inaccurate; a zero difference only
        # counts as accurate when the song itself was matched.
        time_diff = djv.result_matching_times[row][col]
        if time_diff != 0:
            all_matching_times_counter[col][1][0] += 1
        elif outcome == 'yes':
            all_matching_times_counter[col][0][0] += 1
def autolabel(rects, ax):
    """Label every bar in ``rects`` with its height truncated to an integer.

    The text is centred horizontally on the bar and placed at 1.05 times the
    bar's height on the axes ``ax``.
    """
    for bar in rects:
        bar_height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2.
        label = '%d' % int(bar_height)
        ax.text(x_center, 1.05 * bar_height, label, ha='center', va='bottom')
def autolabeldoubles(rects, ax):
    """Label every bar in ``rects`` with its height rounded to three decimals.

    The text is centred horizontally on the bar and placed at 1.05 times the
    bar's height on the axes ``ax``.
    """
    for bar in rects:
        bar_height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2.
        label = '%s' % round(float(bar_height), 3)
        ax.text(x_center, 1.05 * bar_height, label, ha='center', va='bottom')
def create_plots(name, results):
    """Save one bar chart per tested duration, with one bar per song.

    name    -- metric name; used as y-axis label, in the title and in the
               file name. 'Confidence' gets integer bar labels, any other
               name gets labels rounded to three decimals.
    results -- values indexed [seconds][song][0]

    Each chart is written to "<results_folder><name>_<seconds>.png".
    """
    labels = ["song %s" % (n + 1) for n in range(tests_n_lines)]

    for sec in range(n_secs):
        ind = np.arange(tests_n_lines)
        width = 0.25  # the width of the bars

        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_xlim([-1 * width, 2 * width])

        means_dvj = [x[0] for x in results[sec]]
        rects1 = ax.bar(ind, means_dvj, width, color='r')

        ax.set_ylabel(name)
        ax.set_title("%s %s Results" % (options.test_seconds[sec], name))
        ax.set_xticks(ind + width)
        ax.set_xticklabels(labels)

        # Shrink the axes to 75% of their width.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])

        if name == 'Confidence':
            autolabel(rects1, ax)
        else:
            autolabeldoubles(rects1, ax)

        plt.grid()

        fig_name = "%s%s_%s.png" % (options.results_folder, name, options.test_seconds[sec])
        fig.savefig(fig_name)
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)
create_plots('Confidence', all_match_confidence)
create_plots('Query duration', all_query_duration)


def _percentage(count, total):
    """Return count as a percentage of total, rounded to one decimal.

    Uses float division (integer division would truncate the value before
    rounding) and yields 0.0 when total is 0 instead of dividing by zero.
    """
    if not total:
        return 0.0
    return round(100.0 * count / total, 1)


def _save_percentage_chart(values, labels, x_max, ylabel, title, fig_name):
    """Draw ``values`` as labelled red bars and save the chart to ``fig_name``."""
    ind = np.arange(len(labels))
    width = 0.25  # the width of the bars

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_xlim([-1 * width, x_max])

    rects1 = ax.bar(ind, values, width, color='r')

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(ind + width)
    ax.set_xticklabels(labels)

    # Shrink the axes to 75% of their width.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.75, box.height])

    autolabeldoubles(rects1, ax)
    plt.grid()

    fig.savefig(fig_name)
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)


# Share of test files per outcome (yes / no / invalid) for each duration.
for sec in range(n_secs):
    _save_percentage_chart(
        [_percentage(x[0], tests_n_lines) for x in all_match_counter[sec]],
        ['yes', 'no', 'invalid'],
        2.75,
        'Matching Percentage',
        '%s Matching Percentage' % options.test_seconds[sec],
        "%smatching_perc_%s.png" % (options.results_folder, options.test_seconds[sec]))

# Share of correctly matched files whose start time was accurate / inaccurate.
for sec in range(n_secs):
    correct_matches = all_match_counter[sec][0][0]
    _save_percentage_chart(
        [_percentage(x[0], correct_matches) for x in all_matching_times_counter[sec]],
        ['yes', 'no'],
        1.75,
        'Matching Accuracy',
        '%s Matching Times Accuracy' % options.test_seconds[sec],
        "%smatching_acc_%s.png" % (options.results_folder, options.test_seconds[sec]))