Cleaned up the database driver.

- The SQLDatabase class now uses a context manager for mysql access.
- Most of the error handling is done by the context manager now
- Optimized several methods that returned a list into returning a generator
- Optimized return_matches to use an IN query instead.
- Other small fixes.
This commit is contained in:
Wessie 2013-12-17 01:39:03 +01:00
parent f918b6b7e0
commit 0bd7219b87
2 changed files with 243 additions and 206 deletions

52
dejavu/cursor.py Normal file
View file

@ -0,0 +1,52 @@
from __future__ import unicode_literals
from __future__ import absolute_import
import Queue
import pymysql
import pymysql.cursors
def cursor_factory(**factory_options):
def cursor(**options):
options.update(factory_options)
return Cursor(**options)
return cursor
class Cursor(object):
"""
Establishes a connection to the database and returns an open cursor.
```python
# Use as context manager
with Cursor() as cur:
cur.execute(query)
```
"""
_cache = Queue.Queue(maxsize=5)
def __init__(self, cursor_type=pymysql.cursors.DictCursor, **options):
super(Cursor, self).__init__()
try:
conn = self._cache.get_nowait()
except Queue.Empty:
conn = pymysql.connect(**options)
self.conn = conn
self.cursor_type = cursor_type
def __enter__(self):
self.cursor = self.conn.cursor(self.cursor_type)
return self.cursor
def __exit__(self, type, value, traceback):
self.cursor.close()
self.conn.commit()
# Put it back on the queue
try:
self._cache.put_nowait(self.conn)
except Queue.Full:
self.conn.close()

View file

@ -1,8 +1,12 @@
import MySQLdb as mysql from __future__ import absolute_import
import MySQLdb.cursors as cursors from binascii import unhexlify
import os
class SQLDatabase(): class Database(object):
def __init__(self):
super(Database, self).__init__()
class SQLDatabase(Database):
""" """
Queries: Queries:
@ -64,9 +68,11 @@ class SQLDatabase():
`%s` int unsigned not null, `%s` int unsigned not null,
INDEX(%s), INDEX(%s),
UNIQUE(%s, %s, %s) UNIQUE(%s, %s, %s)
);""" % (FINGERPRINTS_TABLENAME, FIELD_HASH, );""" % (
FINGERPRINTS_TABLENAME, FIELD_HASH,
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH) FIELD_SONG_ID, FIELD_OFFSET, FIELD_HASH,
)
CREATE_SONGS_TABLE = """ CREATE_SONGS_TABLE = """
CREATE TABLE IF NOT EXISTS `%s` ( CREATE TABLE IF NOT EXISTS `%s` (
@ -75,194 +81,165 @@ class SQLDatabase():
`%s` tinyint default 0, `%s` tinyint default 0,
PRIMARY KEY (`%s`), PRIMARY KEY (`%s`),
UNIQUE KEY `%s` (`%s`) UNIQUE KEY `%s` (`%s`)
);""" % (SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED, );""" % (
FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID) SONGS_TABLENAME, FIELD_SONG_ID, FIELD_SONGNAME, FIELD_FINGERPRINTED,
FIELD_SONG_ID, FIELD_SONG_ID, FIELD_SONG_ID,
)
# inserts (ignores duplicates)
INSERT_FINGERPRINT = """
INSERT IGNORE INTO %s (%s, %s, %s) VALUES
(UNHEX(%%s), %%s, %%s);
""" % (FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET)
# inserts
INSERT_FINGERPRINT = "INSERT IGNORE INTO %s (%s, %s, %s) VALUES (UNHEX(%%s), %%s, %%s)" % (
FINGERPRINTS_TABLENAME, FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET) # ignore duplicates and don't insert them
INSERT_SONG = "INSERT INTO %s (%s) VALUES (%%s);" % ( INSERT_SONG = "INSERT INTO %s (%s) VALUES (%%s);" % (
SONGS_TABLENAME, FIELD_SONGNAME) SONGS_TABLENAME, FIELD_SONGNAME)
# selects # selects
SELECT = "SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH) SELECT = """
SELECT_ALL = "SELECT %s, %s FROM %s;" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME) SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);
SELECT_SONG = "SELECT %s FROM %s WHERE %s = %%s" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID) """ % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME, FIELD_HASH)
SELECT_NUM_FINGERPRINTS = "SELECT COUNT(*) as n FROM %s" % (FINGERPRINTS_TABLENAME)
SELECT_UNIQUE_SONG_IDS = "SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED) SELECT_MULTIPLE = """
SELECT_SONGS = "SELECT %s, %s FROM %s WHERE %s = 1;" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED) SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s);
""" % (FIELD_HASH, FIELD_SONG_ID, FIELD_OFFSET,
FINGERPRINTS_TABLENAME, FIELD_HASH)
SELECT_ALL = """
SELECT %s, %s FROM %s;
""" % (FIELD_SONG_ID, FIELD_OFFSET, FINGERPRINTS_TABLENAME)
SELECT_SONG = """
SELECT %s FROM %s WHERE %s = %%s
""" % (FIELD_SONGNAME, SONGS_TABLENAME, FIELD_SONG_ID)
SELECT_NUM_FINGERPRINTS = """
SELECT COUNT(*) as n FROM %s
""" % (FINGERPRINTS_TABLENAME)
SELECT_UNIQUE_SONG_IDS = """
SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;
""" % (FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)
SELECT_SONGS = """
SELECT %s, %s FROM %s WHERE %s = 1;
""" % (FIELD_SONG_ID, FIELD_SONGNAME, SONGS_TABLENAME, FIELD_FINGERPRINTED)
# drops # drops
DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME
DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME
# update # update
UPDATE_SONG_FINGERPRINTED = "UPDATE %s SET %s = 1 WHERE %s = %%s" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID) UPDATE_SONG_FINGERPRINTED = """
UPDATE %s SET %s = 1 WHERE %s = %%s
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, FIELD_SONG_ID)
# delete # delete
DELETE_UNFINGERPRINTED = "DELETE FROM %s WHERE %s = 0;" % (SONGS_TABLENAME, FIELD_FINGERPRINTED) DELETE_UNFINGERPRINTED = """
DELETE FROM %s WHERE %s = 0;
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED)
DELETE_ORPHANS = """ DELETE_ORPHANS = """
delete from fingerprints delete from fingerprints
where not exists ( where not exists (
select * from songs where fingerprints.song_id = songs.song_id select * from songs where fingerprints.song_id = songs.song_id
)""" );
"""
def __init__(self, hostname, username, password, database): def __init__(self, cursor):
# connect super(SQLDatabase, self).__init__()
self.database = database self.cursor = cursor
try:
# http://www.halfcooked.com/mt/archives/000969.html
self.connection = mysql.connect(
hostname, username, password,
database, cursorclass=cursors.DictCursor)
self.connection.autocommit(False) # for fast bulk inserts
self.cursor = self.connection.cursor()
except mysql.Error, e:
print "Connection error %d: %s" % (e.args[0], e.args[1])
def setup(self): def setup(self):
try: with self.cursor() as cur:
# create fingerprints table cur.execute(self.CREATE_FINGERPRINTS_TABLE)
self.cursor.execute("USE %s;" % self.database) cur.execute(self.CREATE_SONGS_TABLE)
self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE) cur.execute(self.DELETE_UNFINGERPRINTED)
self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE)
self.delete_unfingerprinted_songs()
self.connection.commit()
except mysql.Error, e:
print "Connection error %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def empty(self): def empty(self):
""" """
Drops all tables and re-adds them. Be carfeul with this! Drops all tables and re-adds them. Be carfeul with this!
""" """
try: with self.cursor() as cur:
self.cursor.execute("USE %s;" % self.database) cur.execute(self.DROP_FINGERPRINTS)
cur.execute(self.DROP_SONGS)
# drop tables self.setup()
self.cursor.execute(SQLDatabase.DROP_FINGERPRINTS)
self.cursor.execute(SQLDatabase.DROP_SONGS)
# recreate
self.cursor.execute(SQLDatabase.CREATE_FINGERPRINTS_TABLE)
self.cursor.execute(SQLDatabase.CREATE_SONGS_TABLE)
self.connection.commit()
except mysql.Error, e:
print "Error in empty(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def delete_orphans(self): def delete_orphans(self):
try: # TODO: SQLDatabase.DELETE_ORPHANS is not
self.cursor = self.connection.cursor() # performant enough, need better query to
### TODO: SQLDatabase.DELETE_ORPHANS is not performant enough, need better query # delete fingerprints for which no song is tied to.
### to delete fingerprints for which no song is tied to.
#self.cursor.execute(SQLDatabase.DELETE_ORPHANS) # with self.cursor() as cur:
#self.connection.commit() # cur.execute(self.DELETE_ORPHANS)
except mysql.Error, e: pass
print "Error in delete_orphans(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def delete_unfingerprinted_songs(self): def delete_unfingerprinted_songs(self):
try: with self.cursor() as cur:
self.cursor = self.connection.cursor() cur.execute(self.DELETE_UNFINGERPRINTED)
self.cursor.execute(SQLDatabase.DELETE_UNFINGERPRINTED)
self.connection.commit()
except mysql.Error, e:
print "Error in delete_unfingerprinted_songs(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def get_num_songs(self): def get_num_songs(self):
""" """
Returns number of songs the database has fingerprinted. Returns number of songs the database has fingerprinted.
""" """
try: with self.cursor() as cur:
self.cursor = self.connection.cursor() cur.execute(self.SELECT_UNIQUE_SONG_IDS)
self.cursor.execute(SQLDatabase.SELECT_UNIQUE_SONG_IDS)
record = self.cursor.fetchone() for row in cur:
return int(record['n']) return row['n']
except mysql.Error, e:
print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1])
def get_num_fingerprints(self): def get_num_fingerprints(self):
""" """
Returns number of fingerprints the database has fingerprinted. Returns number of fingerprints the database has fingerprinted.
""" """
try: with self.cursor() as cur:
self.cursor = self.connection.cursor() cur.execute(self.SELECT_NUM_FINGERPRINTS)
self.cursor.execute(SQLDatabase.SELECT_NUM_FINGERPRINTS)
record = self.cursor.fetchone()
return int(record['n'])
except mysql.Error, e:
print "Error in get_num_songs(), %d: %s" % (e.args[0], e.args[1])
for row in cur:
return row['n']
def set_song_fingerprinted(self, song_id): def set_song_fingerprinted(self, sid):
""" """
Set the fingerprinted flag to TRUE (1) once a song has been completely Set the fingerprinted flag to TRUE (1) once a song has been completely
fingerprinted in the database. fingerprinted in the database.
""" """
try: with self.cursor() as cur:
self.cursor = self.connection.cursor() cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))
self.cursor.execute(SQLDatabase.UPDATE_SONG_FINGERPRINTED, song_id)
self.connection.commit()
except mysql.Error, e:
print "Error in set_song_fingerprinted(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def get_songs(self): def get_songs(self):
""" """
Return songs that have the fingerprinted flag set TRUE (1). Return songs that have the fingerprinted flag set TRUE (1).
""" """
try: with self.cursor() as cur:
self.cursor.execute(SQLDatabase.SELECT_SONGS) cur.execute(self.SELECT_SONGS)
return self.cursor.fetchall() for row in cur:
except mysql.Error, e: yield row
print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1])
return None
def get_song_by_id(self, sid): def get_song_by_id(self, sid):
""" """
Returns song by its ID. Returns song by its ID.
""" """
try: with self.cursor() as cur:
self.cursor.execute(SQLDatabase.SELECT_SONG, (sid,)) cur.execute(self.SELECT_SONG, (sid,))
return self.cursor.fetchone() return cur.fetchone()
except mysql.Error, e:
print "Error in get_songs(), %d: %s" % (e.args[0], e.args[1])
return None
def insert(self, hash, sid, offset):
def insert(self, key, value):
""" """
Insert a (sha1, song_id, offset) row into database. Insert a (sha1, song_id, offset) row into database.
key is a sha1 hash, value = (song_id, offset)
""" """
try: with self.cursor() as cur:
args = (key, value[0], value[1]) cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset))
self.cursor.execute(SQLDatabase.INSERT_FINGERPRINT, args)
except mysql.Error, e:
print "Error in insert(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
def insert_song(self, songname): def insert_song(self, songname):
""" """
Inserts song in the database and returns the ID of the inserted record. Inserts song in the database and returns the ID of the inserted record.
""" """
try: with self.cursor() as cur:
self.cursor.execute(SQLDatabase.INSERT_SONG, (songname,)) cur.execute(self.INSERT_SONG, (songname,))
self.connection.commit() return cur.lastrowid
return int(self.cursor.lastrowid)
except mysql.Error, e:
print "Error in insert_song(), %d: %s" % (e.args[0], e.args[1])
self.connection.rollback()
return None
def query(self, key): def query(self, hash):
""" """
Return all tuples associated with hash. Return all tuples associated with hash.
@ -270,24 +247,12 @@ class SQLDatabase():
database (be careful with that one!). database (be careful with that one!).
""" """
# select all if no key # select all if no key
if key is not None: query = self.SELECT_ALL if hash is None else self.SELECT
sql = SQLDatabase.SELECT
else:
sql = SQLDatabase.SELECT_ALL
matches = [] with self.cursor() as cur:
try: cur.execute(query)
self.cursor.execute(sql, (key,)) for row in cur:
yield (row[self.FIELD_SONG_ID], row[self.FIELD_OFFSET])
# collect all matches
records = self.cursor.fetchall()
for record in records:
matches.append((record[SQLDatabase.FIELD_SONG_ID], record[SQLDatabase.FIELD_OFFSET]))
except mysql.Error, e:
print "Error in query(), %d: %s" % (e.args[0], e.args[1])
return matches
def get_iterable_kv_pairs(self): def get_iterable_kv_pairs(self):
""" """
@ -300,10 +265,13 @@ class SQLDatabase():
Insert series of hash => song_id, offset Insert series of hash => song_id, offset
values into the database. values into the database.
""" """
for h in hashes: # TODO: Fix this when hashes will be a new format.
sha1, val = h values = []
self.insert(sha1, val) for hash, (sid, offset) in hashes:
self.connection.commit() values.append((hash, sid, offset))
with self.cursor() as cur:
cur.executemany(self.INSERT_FINGERPRINT, values)
def return_matches(self, hashes): def return_matches(self, hashes):
""" """
@ -314,12 +282,29 @@ class SQLDatabase():
values. values.
""" """
matches = [] from pymysql.cursors import Cursor
for h in hashes: # Create a dictionary of hash => offset pairs for later lookups
sha1, val = h mapper = {}
list_of_tups = self.query(sha1) for hash, (_, offset) in hashes:
if list_of_tups: mapper[hash.upper()] = offset
for t in list_of_tups:
# (song_id, db_offset, song_sampled_offset) # Get an iteratable of all the hashes we need
matches.append((t[0], t[1] - val[1])) values = mapper.keys()
return matches
with self.cursor(cursor_type=Cursor) as cur:
for split_values in grouper(values, 1000):
# Create our IN part of the query
query = self.SELECT_MULTIPLE
query = query % ', '.join(['UNHEX(%s)'] * len(split_values))
cur.execute(query, split_values)
for hash, sid, offset in cur:
# (sid, db_offset - song_sampled_offset)
yield (sid, offset - mapper[hash])
from itertools import izip_longest
def grouper(iterable, n, fillvalue=None):
args = [iter(iterable)] * n
return izip_longest(fillvalue=fillvalue, *args)