Merge branch 'hacks' of git@github.com:tomstrummer/SleekXMPP into hacks

This commit is contained in:
Tom Nichols 2010-06-08 10:40:15 -04:00
commit 4fccd77685
7 changed files with 364 additions and 70 deletions

2
.gitignore vendored
View file

@ -2,3 +2,5 @@
.project
build/
*.swp
.pydevproject
.settings

View file

@ -69,6 +69,8 @@ class ClientXMPP(basexmpp, XMLStream):
#TODO: Use stream state here
self.authenticated = False
self.sessionstarted = False
self.bound = False
self.bindfail = False
self.registerHandler(Callback('Stream Features', MatchXPath('{http://etherx.jabber.org/streams}features'), self._handleStreamFeatures, thread=True))
self.registerHandler(Callback('Roster Update', MatchXPath('{%s}iq/{jabber:iq:roster}query' % self.default_ns), self._handleRoster, thread=True))
#self.registerHandler(Callback('Roster Update', MatchXMLMask("<presence xmlns='%s' type='subscribe' />" % self.default_ns), self._handlePresenceSubscribe, thread=True))
@ -146,13 +148,9 @@ class ClientXMPP(basexmpp, XMLStream):
# overriding reconnect and disconnect so that we can get some events
# should events be part of or required by xmlstream? Maybe that would be cleaner
def reconnect(self):
logging.info("Reconnecting")
self.event("disconnected")
self.authenticated = False
self.sessionstarted = False
XMLStream.reconnect(self)
self.disconnect(reconnect=True)
def disconnect(self, init=True, close=False, reconnect=False):
def disconnect(self, reconnect=False):
self.event("disconnected")
self.authenticated = False
self.sessionstarted = False
@ -248,19 +246,23 @@ class ClientXMPP(basexmpp, XMLStream):
response = iq.send()
#response = self.send(iq, self.Iq(sid=iq['id']))
self.set_jid(response.xml.find('{urn:ietf:params:xml:ns:xmpp-bind}bind/{urn:ietf:params:xml:ns:xmpp-bind}jid').text)
self.bound = True
logging.info("Node set to: %s" % self.fulljid)
if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features:
if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features or self.bindfail:
logging.debug("Established Session")
self.sessionstarted = True
self.event("session_start")
def handler_start_session(self, xml):
if self.authenticated:
if self.authenticated and self.bound:
iq = self.makeIqSet(xml)
response = iq.send()
logging.debug("Established Session")
self.sessionstarted = True
self.event("session_start")
else:
#bind probably hasn't happened yet
self.bindfail = True
def _handleRoster(self, iq, request=False):
if iq['type'] == 'set' or (iq['type'] == 'result' and request):

View file

@ -1,9 +1,9 @@
"""
SleekXMPP: The Sleek XMPP Library
Copyright (C) 2010 Nathanael C. Fritz
This file is part of SleekXMPP.
SleekXMPP: The Sleek XMPP Library
Copyright (C) 2010 Nathanael C. Fritz
This file is part of SleekXMPP.
See the file license.txt for copying permission.
See the file license.txt for copying permission.
"""
from __future__ import with_statement, unicode_literals
@ -91,20 +91,26 @@ class basexmpp(object):
if not self.plugin[idx].post_inited: self.plugin[idx].post_init()
return super(basexmpp, self).process(*args, **kwargs)
def registerPlugin(self, plugin, pconfig = {}):
def registerPlugin(self, plugin, pconfig = {}, pluginModule = None):
"""Register a plugin not in plugins.__init__.__all__ but in the plugins
directory."""
# discover relative "path" to the plugins module from the main app, and import it.
# TODO:
# gross, this probably isn't necessary anymore, especially for an installed module
__import__("%s.%s" % (globals()['plugins'].__name__, plugin))
# init the plugin class
self.plugin[plugin] = getattr(getattr(plugins, plugin), plugin)(self, pconfig) # eek
# all of this for a nice debug? sure.
xep = ''
if hasattr(self.plugin[plugin], 'xep'):
xep = "(XEP-%s) " % self.plugin[plugin].xep
logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description))
try:
if pluginModule:
module = __import__(pluginModule, globals(), locals(), [plugin])
else:
module = __import__("%s.%s" % (globals()['plugins'].__name__, plugin), globals(), locals(), [plugin])
# init the plugin class
self.plugin[plugin] = getattr(module, plugin)(self, pconfig) # eek
# all of this for a nice debug? sure.
xep = ''
if hasattr(self.plugin[plugin], 'xep'):
xep = "(XEP-%s) " % self.plugin[plugin].xep
logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description))
except:
logging.error("Unable to load plugin: %s" %(plugin) )
def register_plugins(self):
"""Initiates all plugins in the plugins/__init__.__all__"""

View file

@ -76,7 +76,7 @@ class Scheduler(object):
if updated: self.schedule = sorted(self.schedule, key=lambda task: task.next)
except KeyboardInterrupt:
self.run = False
logging.debug("Qutting Scheduler thread")
logging.debug("Quitting Scheduler thread")
if self.parentqueue is not None:
self.parentqueue.put(('quit', None, None))

View file

@ -10,6 +10,7 @@ import threading
import time
import logging
class StateMachine(object):
def __init__(self, states=[]):
@ -27,7 +28,7 @@ class StateMachine(object):
self.__states.append( state )
def transition(self, from_state, to_state, wait=0.0):
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
'''
Transition from the given `from_state` to the given `to_state`.
This method will return `True` if the state machine is now in `to_state`. It
@ -47,25 +48,37 @@ class StateMachine(object):
if thread_should_exit: return
# perform actions here after successful transition
This allows the thread to be interrupted by setting `thread_should_exit=True`
This allows the thread to be responsive by setting `thread_should_exit=True`.
The optional `func` argument allows the user to pass a callable operation which occurs
within the context of the state transition (e.g. while the state machine is locked.)
If `func` returns a True value, the transition will occur. If `func` returns a non-
True value or if an exception is thrown, the transition will not occur. Any thrown
exception is not caught by the state machine and is the caller's responsibility to handle.
If `func` completes normally, this method will return the value returned by `func.` If
values for `args` and `kwargs` are provided, they are expanded and passed like so:
`func( *args, **kwargs )`.
'''
return self.transition_any( (from_state,), to_state, wait=wait )
return self.transition_any( (from_state,), to_state, wait=wait,
func=func, args=args, kwargs=kwargs )
def transition_any(self, from_states, to_state, wait=0.0):
def transition_any(self, from_states, to_state, wait=0.0, func=None, args=[], kwargs={} ):
'''
Transition from any of the given `from_states` to the given `to_state`.
'''
with self.lock:
for state in from_states:
if isinstance(state,tuple) or isinstance(state,list):
raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) )
if not state in self.__states:
raise ValueError( "StateMachine does not contain from_state %s." % state )
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
if not (isinstance(from_states,tuple) or isinstance(from_states,list)):
raise ValueError( "from_states should be a list or tuple" )
for state in from_states:
if not state in self.__states:
raise ValueError( "StateMachine does not contain from_state %s." % state )
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
with self.lock:
start = time.time()
while not self.__current_state in from_states:
# detect timeout:
@ -73,32 +86,78 @@ class StateMachine(object):
self.lock.wait(wait)
if self.__current_state in from_states: # should always be True due to lock
return_val = True
# Note that func might throw an exception, but that's OK, it aborts the transition
if func is not None: return_val = func(*args,**kwargs)
# some 'false' value returned from func,
# indicating that transition should not occur:
if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self.__current_state = to_state
self.lock.notifyAll()
return True
return return_val # some 'true' value returned by func or True if func was None
else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False
def transition_ctx(self, from_state, to_state, wait=0.0):
'''
Use the state machine as a context manager. The transition occurs on /exit/ from
the `with` context, so long as no exception is thrown. For example:
::
with state_machine.transition_ctx('one','two', wait=5) as locked:
if locked:
# the state machine is currently locked in state 'one', and will
# transition to 'two' when the 'with' statement ends, so long as
# no exception is thrown.
print 'Currently locked in state one: %s' % state_machine['one']
else:
# The 'wait' timed out, and no lock has been acquired
print 'Timed out before entering state "one"'
print 'Since no exception was thrown, we are now in state "two": %s' % state_machine['two']
The other main difference between this method and `transition()` is that the
state machine is locked for the duration of the `with` statement (normally,
after a `transition() occurs, the state machine is immediately unlocked and
available to another thread to call `transition()` again.
'''
if not from_state in self.__states:
raise ValueError( "StateMachine does not contain from_state %s." % state )
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
return _StateCtx(self, from_state, to_state, wait)
def ensure(self, state, wait=0.0):
'''
Ensure the state machine is currently in `state`, or wait until it enters `state`.
'''
return self.ensure_any( (state,), wait=wait )
def ensure_any(self, states, wait=0.0):
'''
Ensure we are currently in one of the given `states`
'''
with self.lock:
for state in states:
if isinstance(state,tuple) or isinstance(state,list):
raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) )
if not state in self.__states:
raise ValueError( "StateMachine does not contain state %s." % state )
if not (isinstance(states,tuple) or isinstance(states,list)):
raise ValueError('states arg should be a tuple or list')
for state in states:
if not state in self.__states:
raise ValueError( "StateMachine does not contain state '%s'" % state )
with self.lock:
start = time.time()
while not self.__current_state in states:
# detect timeout:
@ -112,6 +171,18 @@ class StateMachine(object):
self.transition(self.__current_state, self._default_state)
def _set_state(self, state): #unsynchronized, only call internally after lock is acquired
self.__current_state = state
return state
def current_state(self):
'''
Return the current state name.
'''
return self.__current_state
def __getitem__(self, state):
'''
Non-blocking, non-synchronized test to determine if we are in the given state.
@ -119,12 +190,45 @@ class StateMachine(object):
'''
return self.__current_state == state
def __str__(self):
return "".join(( "StateMachine(", ','.join(self.__states), "): ", self.__current_state ))
class _StateCtx:
def __init__( self, state_machine, from_state, to_state, wait ):
self.state_machine = state_machine
self.from_state = from_state
self.to_state = to_state
self.wait = wait
self._timeout = False
def __enter__(self):
self.lock.acquire()
return self
self.state_machine.lock.acquire()
start = time.time()
while not self.state_machine[ self.from_state ]:
# detect timeout:
if time.time() >= start + self.wait:
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
self._timeout = True # to indicate we should not transition
return False
self.state_machine.lock.wait(self.wait)
logging.debug('StateMachine entered context in state: %s',
self.state_machine.current_state() )
return True
def __exit__(self, exc_type, exc_val, exc_tb):
self.lock.nofityAll()
self.lock.release()
if exc_val is not None:
logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
self.state_machine.current_state(), exc_type.__name__, exc_val )
elif not self._timeout:
logging.debug(' ==== TRANSITION %s -> %s',
self.state_machine.current_state(), self.to_state)
self.state_machine._set_state( self.to_state )
self.state_machine.lock.notifyAll()
self.state_machine.lock.release()
return False # re-raise any exception

View file

@ -16,6 +16,7 @@ from . stanzabase import StanzaBase
from xml.etree import cElementTree
from xml.parsers import expat
import logging
import random
import socket
import threading
import time
@ -46,6 +47,10 @@ class CloseStream(Exception):
stanza_extensions = {}
RECONNECT_MAX_DELAY = 3600
RECONNECT_QUIESCE_FACTOR = 1.6180339887498948 # Phi
RECONNECT_QUIESCE_JITTER = 0.11962656472 # molar Planck constant times c, joule meter/mole
class XMLStream(object):
"A connection manager with XML events."
@ -95,13 +100,29 @@ class XMLStream(object):
self.filesocket = filesocket
def connect(self, host='', port=0, use_ssl=None, use_tls=None):
"Link to connectTCP"
if self.state.transition('disconnected', 'connecting'):
return self.connectTCP(host, port, use_ssl, use_tls)
"Establish a socket connection to the given XMPP server."
if not self.state.transition('disconnected','connected',
func=self.connectTCP, args=[host, port, use_ssl, use_tls] ):
if self.state['connected']: logging.debug('Already connected')
else: logging.warning("Connection failed" )
return False
logging.debug('Connection complete.')
return True
# TODO currently a caller can't distinguish between "connection failed" and
# "we're already trying to connect from another thread"
def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True):
"Connect and create socket"
while reattempt and not self.state['connected']: # the self.state part is redundant.
# Note that this is thread-safe by merit of being called solely from connect() which
# holds the state lock.
delay = 1.0 # reconnection delay
while self.run:
logging.debug('connecting....')
try:
if host and port:
@ -109,27 +130,39 @@ class XMLStream(object):
if use_ssl is not None:
self.use_ssl = use_ssl
if use_tls is not None:
# TODO this variable doesn't seem to be used for anything!
self.use_tls = use_tls
if sys.version_info < (3, 0):
self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM)
else:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.settimeout(None) #10)
if self.use_ssl and self.ssl_support:
logging.debug("Socket Wrapped for SSL")
self.socket = ssl.wrap_socket(self.socket,ca_certs=self.ca_certs)
except:
logging.exception("Connection error")
try:
self.socket.connect(self.address)
self.filesocket = self.socket.makefile('rb', 0)
if not self.state.transition('connecting','connected'):
logging.error( "State transition error!!!! Shouldn't have happened" )
logging.debug('connect complete.')
return True
except socket.error as serr:
logging.error("Could not connect. Socket Error #%s: %s" % (serr.errno, serr.strerror))
time.sleep(1) # TODO proper quiesce if connection attempt fails
logging.exception("Socket Error #%s: %s", serr.errno, serr.strerror)
if not reattempt: return False
except:
logging.exception("Connection error")
if not reattempt: return False
# quiesce if rconnection fails:
# This algorithm based loosely on Twisted internet.protocol
# http://twistedmatrix.com/trac/browser/trunk/twisted/internet/protocol.py#L310
delay = min(delay * RECONNECT_QUIESCE_FACTOR, RECONNECT_MAX_DELAY)
delay = random.normalvariate(delay, delay * RECONNECT_QUIESCE_JITTER)
logging.debug('Waiting %fs until next reconnect attempt...', delay)
time.sleep(delay)
def connectUnix(self, filepath):
"Connect to Unix file and create socket"
@ -244,11 +277,12 @@ class XMLStream(object):
data = None
try:
data = self.sendqueue.get(True,10)
data = self.sendqueue.get(True,5)
logging.debug("SEND: %s" % data)
self.socket.sendall(data.encode('utf-8'))
except queue.Empty:
logging.debug('nothing on send queue')
# logging.debug('Nothing on send queue')
pass
except socket.timeout:
# this is to prevent a thread blocked indefinitely
logging.debug('timeout sending packet data')
@ -335,6 +369,7 @@ class XMLStream(object):
try:
event = self.eventqueue.get(True, timeout=5)
except queue.Empty:
# logging.debug('Nothing on event queue')
event = None
if event is not None:
etype = event[0]

View file

@ -1,5 +1,5 @@
import unittest
import time, threading
import time, threading, random, functools
if __name__ == '__main__':
import sys, os
@ -15,25 +15,23 @@ class testStateMachine(unittest.TestCase):
def testDefaults(self):
"Test ensure transitions occur correctly in a single thread"
s = sm.StateMachine(('one','two','three'))
# self.assertTrue(s.one)
self.assertTrue(s['one'])
# self.failIf(s.two)
self.failIf(s['two'])
try:
s.booga
s['booga']
self.fail('s.booga is an invalid state and should throw an exception!')
except: pass #expected exception
# just make sure __str__ works, no reason to test its exact value:
print str(s)
def testTransitions(self):
"Test ensure transitions occur correctly in a single thread"
s = sm.StateMachine(('one','two','three'))
# self.assertTrue(s.one)
self.assertTrue( s.transition('one', 'two') )
# self.assertTrue( s.two )
self.assertTrue( s['two'] )
# self.failIf( s.one )
self.failIf( s['one'] )
self.assertTrue( s.transition('two', 'three') )
@ -83,12 +81,12 @@ class testStateMachine(unittest.TestCase):
thread_state = {'ready': False, 'transitioned': False}
def t1():
# this will block until the main thread transitions to 'two'
if s['two']:
print 'thread has already transitioned!'
self.fail()
thread_state['ready'] = True
print 'Thread is ready'
# this will block until the main thread transitions to 'two'
self.assertTrue( s.transition('two','three', wait=20) )
print 'transitioned to three!'
thread_state['transitioned'] = True
@ -109,6 +107,153 @@ class testStateMachine(unittest.TestCase):
self.assertTrue( thread_state['transitioned'] )
def testForRaceCondition(self):
"""Attempt to allow two threads to perform the same transition;
only one should ever make it."""
s = sm.StateMachine(('one','two','three'))
def t1(num):
while True:
if not trigger['go'] or thread_state[num] in (True,False):
time.sleep( random.random()/100 ) # < .01s
if thread_state[num] == 'quit': break
continue
thread_state[num] = s.transition('one','two' )
# print '-',
thread_count = 20
threads = []
thread_state = {}
def reset():
for c in range(thread_count): thread_state[c] = "reset"
trigger = {'go':False} # use of a plain boolean seems to be non-volatile between threads.
for c in range(thread_count):
thread_state[c] = "reset"
thread = threading.Thread( target= functools.partial(t1,c) )
threads.append( thread )
thread.daemon = True
thread.start()
for x in range(100): # this will take 10s to execute
# print "+",
trigger['go'] = True
time.sleep(.1)
trigger['go'] = False
winners = 0
for (num, state) in thread_state.items():
if state == True: winners = winners +1
elif state != False: raise Exception( "!%d!%s!" % (num,state) )
self.assertEqual( 1, winners, "Expected one winner! %d" % winners )
self.assertTrue( s.ensure('two') )
self.assertTrue( s.transition('two','one') ) # return to the first state.
reset()
# now let the threads quit gracefully:
for c in range(thread_count): thread_state[c] = 'quit'
time.sleep(2)
def testTransitionFunctions(self):
"test that a `func` argument allows or blocks the transition correctly."
s = sm.StateMachine(('one','two','three'))
def alwaysFalse(): return False
def alwaysTrue(): return True
self.failIf( s.transition('one','two', func=alwaysFalse) )
self.assertTrue(s['one'])
self.failIf(s['two'])
self.assertTrue( s.transition('one','two', func=alwaysTrue) )
self.failIf(s['one'])
self.assertTrue(s['two'])
def testTransitionFuncException(self):
"if a transition function throws an exeption, ensure we're in a sane state"
s = sm.StateMachine(('one','two','three'))
def alwaysException(): raise Exception('whups!')
try:
self.failIf( s.transition('one','two', func=alwaysException) )
self.fail("exception should have been thrown")
except: pass #expected exception
self.assertTrue(s['one'])
self.failIf(s['two'])
# ensure a subsequent attempt completes normally:
self.assertTrue( s.transition('one','two') )
self.failIf(s['one'])
self.assertTrue(s['two'])
def testContextManager(self):
s = sm.StateMachine(('one','two','three'))
with s.transition_ctx('one','two'):
self.assertTrue( s['one'] )
self.failIf( s['two'] )
#successful transition b/c no exception was thrown
self.assertTrue( s['two'] )
self.failIf( s['one'] )
# failed transition because exception is thrown:
try:
with s.transition_ctx('two','three'):
raise Exception("boom!")
self.fail('exception expected')
except: pass
self.failIf( s.current_state() in ('one','three') )
self.assertTrue( s['two'] )
def testCtxManagerTransitionFailure(self):
s = sm.StateMachine(('one','two','three'))
with s.transition_ctx('two','three') as result:
self.failIf( result )
self.assertTrue( s['one'] )
self.failIf( s.current_state in ('two','three') )
self.assertTrue( s['one'] )
def r1():
print 'thread 1 started'
self.assertTrue( s.transition('one','two') )
print 'thread 1 transitioned'
def r2():
print 'thread 2 started'
self.failIf( s['two'] )
with s.transition_ctx('two','three', 10) as result:
self.assertTrue( result )
self.assertTrue( s['two'] )
print 'thread 2 will transition on exit from the context manager...'
self.assertTrue( s['three'] )
print 'transitioned to %s' % s.current_state()
t1 = threading.Thread(target=r1)
t2 = threading.Thread(target=r2)
t2.start() # this should block until r1 goes
time.sleep(1)
t1.start()
t1.join()
t2.join()
self.assertTrue( s['three'] )
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)