Merge remote-tracking branch 'wessie/master'

This commit is contained in:
Vin 2013-12-17 00:46:17 +00:00
commit 9faf2cd591
3 changed files with 243 additions and 206 deletions

BIN
README.md

Binary file not shown.

52
dejavu/cursor.py Normal file
View file

@ -0,0 +1,52 @@
from __future__ import unicode_literals
from __future__ import absolute_import
import Queue
import pymysql
import pymysql.cursors
def cursor_factory(**factory_options):
def cursor(**options):
options.update(factory_options)
return Cursor(**options)
return cursor
class Cursor(object):
"""
Establishes a connection to the database and returns an open cursor.
```python
# Use as context manager
with Cursor() as cur:
cur.execute(query)
```
"""
_cache = Queue.Queue(maxsize=5)
def __init__(self, cursor_type=pymysql.cursors.DictCursor, **options):
super(Cursor, self).__init__()
try:
conn = self._cache.get_nowait()
except Queue.Empty:
conn = pymysql.connect(**options)
self.conn = conn
self.cursor_type = cursor_type
def __enter__(self):
self.cursor = self.conn.cursor(self.cursor_type)
return self.cursor
def __exit__(self, type, value, traceback):
self.cursor.close()
self.conn.commit()
# Put it back on the queue
try:
self._cache.put_nowait(self.conn)
except Queue.Full:
self.conn.close()

View file

@ -1,8 +1,12 @@
import MySQLdb as mysql
import MySQLdb.cursors as cursors
import os
from __future__ import absolute_import
from binascii import unhexlify
class SQLDatabase():
class Database(object):
def __init__(self):
super(Database, self).__init__()
class SQLDatabase(Database):
"""
Queries:
@ -64,9 +68,11 @@ class SQLDatabase():
`%s` int unsigned not null,
INDEX(%s),
UNIQUE(%s, %s, %s)
);""" % (FINGERPRINTS_TABLENAME, FIELD_HASH,
);""" % (
FINGERPRINTS_TABLENAME, FIELD_HASH,
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH)
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
)
CREATE_SONGS_TABLE = """
CREATE TABLE IF NOT EXISTS `%s` (
@ -75,194 +81,165 @@ class SQLDatabase():
`%s` tinyint default 0,
PRIMARY KEY (`%s`),
UNIQUE KEY `%s` (`%s`)
);""" % (SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED,
FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID)
);""" % (
SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED,
FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID,
)
# inserts (ignores duplicates)
INSERT_FINGERPRINT = """
INSERT IGNORE INTO %s (%s, %s, %s) VALUES
(UNHEX(%%s), %%s, %%s);
""" % (FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET)
# inserts
INSERT_FINGERPRINT = "INSERT IGNORE INTO %s (%s, %s, %s) VALUES (UNHEX(%%s), %%s, %%s)" % (
FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET) # ignore duplicates and don't insert them
INSERT_SONG = "INSERT INTO %s (%s) VALUES (%%s);" % (
SONGS_TABLENAME, FIELD_SONGNAME)
# selects
SELECT = "SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH)
SELECT_ALL = "SELECT %s, %s FROM %s;" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME)
SELECT_SONG = "SELECT %s FROM %s WHERE %s = %%s" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID)
SELECT_NUM_FINGERPRINTS = "SELECT COUNT(*) as n FROM %s" % (FINGERPRINTS_TABLENAME)
SELECT = """
SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);
""" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH)
SELECT_UNIQUE_SONG_IDS = "SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)
SELECT_SONGS = "SELECT %s, %s FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED)
SELECT_MULTIPLE = """
SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s);
""" % (FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET,
FINGERPRINTS_TABLENAME, FIELD_HASH)
SELECT_ALL = """
SELECT %s, %s FROM %s;
""" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME)
SELECT_SONG = """
SELECT %s FROM %s WHERE %s = %%s
""" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID)
SELECT_NUM_FINGERPRINTS = """
SELECT COUNT(*) as n FROM %s
""" % (FINGERPRINTS_TABLENAME)
SELECT_UNIQUE_SONG_IDS = """
SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;
""" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)
SELECT_SONGS = """
SELECT %s, %s FROM %s WHERE %s = 1;
""" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED)
# drops
DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME
DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME
# update
UPDATE_SONG_FINGERPRINTED = "UPDATE %s SET %s = 1 WHERE %s = %%s" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID)
UPDATE_SONG_FINGERPRINTED = """
UPDATE %s SET %s = 1 WHERE %s = %%s
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID)
# delete
DELETE_UNFINGERPRINTED = "DELETE FROM %s WHERE %s = 0;" % (SONGS_TABLENAME, FIELD_FINGERPRINTED)
DELETE_UNFINGERPRINTED = """
DELETE FROM %s WHERE %s = 0;
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED)
DELETE_ORPHANS = """
delete from fingerprints
where not exists (
select * from songs where fingerprints.song_id = songs.song_id
)"""
);
"""
def __init__(self, hostname, username, password, database):
# connect
self.database = database
try:
# http://www.halfcooked.com/mt/archives/000969.html
self.connection = mysql.connect(
hostname, username, password,
database, cursorclass=cursors.DictCursor)
self.connection.autocommit(False) # for fast bulk inserts
self.cursor = self.connection.cursor()
except mysql.Error, e:
print "Connection error %d: %s" % (e.args[0], e.args[1])
def __init__(self, cursor):
super(SQLDatabase, self).__init__()
self.cursor = cursor
def setup(self):
try:
# create fingerprints table
self.cursor.execute("USE %s;" % self.database)
self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE)
self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE)
self.delete_unfingerprinted_songs()
self.connection.commit()
except mysql.Error, e:
print "Connection error %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
with self.cursor() as cur:
cur.execute(self.CREATE_FINGERPRINTS_TABLE)
cur.execute(self.CREATE_SONGS_TABLE)
cur.execute(self.DELETE_UNFINGERPRINTED)
def empty(self):
"""
Drops all tables and re-adds them. Be carfeul with this!
"""
try:
self.cursor.execute("USE %s;" % self.database)
with self.cursor() as cur:
cur.execute(self.DROP_FINGERPRINTS)
cur.execute(self.DROP_SONGS)
# drop tables
self.cursor.execute(SQLDatabase.DROP_FINGERPRINTS)
self.cursor.execute(SQLDatabase.DROP_SONGS)
# recreate
self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE)
self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE)
self.connection.commit()
except mysql.Error, e:
print "Error in empty(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
self.setup()
def delete_orphans(self):
try:
self.cursor = self.connection.cursor()
### TODO: SQLDatabase.DELETE_ORPHANS is not performant enough, need better query
### to delete fingerprints for which no song is tied to.
#self.cursor.execute(SQLDatabase.DELETE_ORPHANS)
#self.connection.commit()
except mysql.Error, e:
print "Error in delete_orphans(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
# TODO: SQLDatabase.DELETE_ORPHANS is not
# performant enough, need better query to
# delete fingerprints for which no song is tied to.
# with self.cursor() as cur:
# cur.execute(self.DELETE_ORPHANS)
pass
def delete_unfingerprinted_songs(self):
try:
self.cursor = self.connection.cursor()
self.cursor.execute(SQLDatabase.DELETE_UNFINGERPRINTED)
self.connection.commit()
except mysql.Error, e:
print "Error in delete_unfingerprinted_songs(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
with self.cursor() as cur:
cur.execute(self.DELETE_UNFINGERPRINTED)
def get_num_songs(self):
"""
Returns number of songs the database has fingerprinted.
"""
try:
self.cursor = self.connection.cursor()
self.cursor.execute(SQLDatabase.SELECT_UNIQUE_SONG_IDS)
record = self.cursor.fetchone()
return int(record['n'])
except mysql.Error, e:
print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1])
with self.cursor() as cur:
cur.execute(self.SELECT_UNIQUE_SONG_IDS)
for row in cur:
return row['n']
def get_num_fingerprints(self):
"""
Returns number of fingerprints the database has fingerprinted.
"""
try:
self.cursor = self.connection.cursor()
self.cursor.execute(SQLDatabase.SELECT_NUM_FINGERPRINTS)
record = self.cursor.fetchone()
return int(record['n'])
except mysql.Error, e:
print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1])
with self.cursor() as cur:
cur.execute(self.SELECT_NUM_FINGERPRINTS)
for row in cur:
return row['n']
def set_song_fingerprinted(self, song_id):
def set_song_fingerprinted(self, sid):
"""
Set the fingerprinted flag to TRUE (1) once a song has been completely
fingerprinted in the database.
"""
try:
self.cursor = self.connection.cursor()
self.cursor.execute(SQLDatabase.UPDATE_SONG_FINGERPRINTED, song_id)
self.connection.commit()
except mysql.Error, e:
print "Error in set_song_fingerprinted(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
with self.cursor() as cur:
cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))
def get_songs(self):
"""
Return songs that have the fingerprinted flag set TRUE (1).
"""
try:
self.cursor.execute(SQLDatabase.SELECT_SONGS)
return self.cursor.fetchall()
except mysql.Error, e:
print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1])
return None
with self.cursor() as cur:
cur.execute(self.SELECT_SONGS)
for row in cur:
yield row
def get_song_by_id(self, sid):
"""
Returns song by its ID.
"""
try:
self.cursor.execute(SQLDatabase.SELECT_SONG, (sid,))
return self.cursor.fetchone()
except mysql.Error, e:
print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1])
return None
with self.cursor() as cur:
cur.execute(self.SELECT_SONG, (sid,))
return cur.fetchone()
def insert(self, key, value):
def insert(self, hash, sid, offset):
"""
Insert a (sha1, song_id, offset) row into database.
key is a sha1 hash, value = (song_id, offset)
"""
try:
args = (key, value[0], value[1])
self.cursor.execute(SQLDatabase.INSERT_FINGERPRINT, args)
except mysql.Error, e:
print "Error in insert(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
with self.cursor() as cur:
cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset))
def insert_song(self, songname):
"""
Inserts song in the database and returns the ID of the inserted record.
"""
try:
self.cursor.execute(SQLDatabase.INSERT_SONG, (songname,))
self.connection.commit()
return int(self.cursor.lastrowid)
except mysql.Error, e:
print "Error in insert_song(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
return None
with self.cursor() as cur:
cur.execute(self.INSERT_SONG, (songname,))
return cur.lastrowid
def query(self, key):
def query(self, hash):
"""
Return all tuples associated with hash.
@ -270,24 +247,12 @@ class SQLDatabase():
database (be careful with that one!).
"""
# select all if no key
if key is not None:
sql = SQLDatabase.SELECT
else:
sql = SQLDatabase.SELECT_ALL
query = self.SELECT_ALL if hash is None else self.SELECT
matches = []
try:
self.cursor.execute(sql, (key,))
# collect all matches
records = self.cursor.fetchall()
for record in records:
matches.append((record[SQLDatabase.FIELD_SONG_ID], record[SQLDatabase.FIELD_OFFSET]))
except mysql.Error, e:
print "Error in query(), %d: %s" % (e.args[0], e.args[1])
return matches
with self.cursor() as cur:
cur.execute(query)
for row in cur:
yield (row[self.FIELD_SONG_ID], row[self.FIELD_OFFSET])
def get_iterable_kv_pairs(self):
"""
@ -300,10 +265,13 @@ class SQLDatabase():
Insert series of hash => song_id, offset
values into the database.
"""
for h in hashes:
sha1, val = h
self.insert(sha1, val)
self.connection.commit()
# TODO: Fix this when hashes will be a new format.
values = []
for hash, (sid, offset) in hashes:
values.append((hash, sid, offset))
with self.cursor() as cur:
cur.executemany(self.INSERT_FINGERPRINT, values)
def return_matches(self, hashes):
"""
@ -314,12 +282,29 @@ class SQLDatabase():
values.
"""
matches = []
for h in hashes:
sha1, val = h
list_of_tups = self.query(sha1)
if list_of_tups:
for t in list_of_tups:
# (song_id, db_offset, song_sampled_offset)
matches.append((t[0], t[1] - val[1]))
return matches
from pymysql.cursors import Cursor
# Create a dictionary of hash => offset pairs for later lookups
mapper = {}
for hash, (_, offset) in hashes:
mapper[hash.upper()] = offset
# Get an iteratable of all the hashes we need
values = mapper.keys()
with self.cursor(cursor_type=Cursor) as cur:
for split_values in grouper(values, 1000):
# Create our IN part of the query
query = self.SELECT_MULTIPLE
query = query % ', '.join(['UNHEX(%s)'] * len(split_values))
cur.execute(query, split_values)
for hash, sid, offset in cur:
# (sid, db_offset - song_sampled_offset)
yield (sid, offset - mapper[hash])
from itertools import izip_longest
def grouper(iterable, n, fillvalue=None):
args = [iter(iterable)] * n
return izip_longest(fillvalue=fillvalue, *args)