mirror of
https://github.com/correl/SleekXMPP.git
synced 2025-01-03 03:00:20 +00:00
Merge branch 'hacks' of git@github.com:tomstrummer/SleekXMPP into hacks
This commit is contained in:
commit
4fccd77685
7 changed files with 364 additions and 70 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -2,3 +2,5 @@
|
|||
.project
|
||||
build/
|
||||
*.swp
|
||||
.pydevproject
|
||||
.settings
|
||||
|
|
|
@ -69,6 +69,8 @@ class ClientXMPP(basexmpp, XMLStream):
|
|||
#TODO: Use stream state here
|
||||
self.authenticated = False
|
||||
self.sessionstarted = False
|
||||
self.bound = False
|
||||
self.bindfail = False
|
||||
self.registerHandler(Callback('Stream Features', MatchXPath('{http://etherx.jabber.org/streams}features'), self._handleStreamFeatures, thread=True))
|
||||
self.registerHandler(Callback('Roster Update', MatchXPath('{%s}iq/{jabber:iq:roster}query' % self.default_ns), self._handleRoster, thread=True))
|
||||
#self.registerHandler(Callback('Roster Update', MatchXMLMask("<presence xmlns='%s' type='subscribe' />" % self.default_ns), self._handlePresenceSubscribe, thread=True))
|
||||
|
@ -146,13 +148,9 @@ class ClientXMPP(basexmpp, XMLStream):
|
|||
# overriding reconnect and disconnect so that we can get some events
|
||||
# should events be part of or required by xmlstream? Maybe that would be cleaner
|
||||
def reconnect(self):
|
||||
logging.info("Reconnecting")
|
||||
self.event("disconnected")
|
||||
self.authenticated = False
|
||||
self.sessionstarted = False
|
||||
XMLStream.reconnect(self)
|
||||
self.disconnect(reconnect=True)
|
||||
|
||||
def disconnect(self, init=True, close=False, reconnect=False):
|
||||
def disconnect(self, reconnect=False):
|
||||
self.event("disconnected")
|
||||
self.authenticated = False
|
||||
self.sessionstarted = False
|
||||
|
@ -248,19 +246,23 @@ class ClientXMPP(basexmpp, XMLStream):
|
|||
response = iq.send()
|
||||
#response = self.send(iq, self.Iq(sid=iq['id']))
|
||||
self.set_jid(response.xml.find('{urn:ietf:params:xml:ns:xmpp-bind}bind/{urn:ietf:params:xml:ns:xmpp-bind}jid').text)
|
||||
self.bound = True
|
||||
logging.info("Node set to: %s" % self.fulljid)
|
||||
if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features:
|
||||
if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features or self.bindfail:
|
||||
logging.debug("Established Session")
|
||||
self.sessionstarted = True
|
||||
self.event("session_start")
|
||||
|
||||
def handler_start_session(self, xml):
|
||||
if self.authenticated:
|
||||
if self.authenticated and self.bound:
|
||||
iq = self.makeIqSet(xml)
|
||||
response = iq.send()
|
||||
logging.debug("Established Session")
|
||||
self.sessionstarted = True
|
||||
self.event("session_start")
|
||||
else:
|
||||
#bind probably hasn't happened yet
|
||||
self.bindfail = True
|
||||
|
||||
def _handleRoster(self, iq, request=False):
|
||||
if iq['type'] == 'set' or (iq['type'] == 'result' and request):
|
||||
|
|
|
@ -91,20 +91,26 @@ class basexmpp(object):
|
|||
if not self.plugin[idx].post_inited: self.plugin[idx].post_init()
|
||||
return super(basexmpp, self).process(*args, **kwargs)
|
||||
|
||||
def registerPlugin(self, plugin, pconfig = {}):
|
||||
def registerPlugin(self, plugin, pconfig = {}, pluginModule = None):
|
||||
"""Register a plugin not in plugins.__init__.__all__ but in the plugins
|
||||
directory."""
|
||||
# discover relative "path" to the plugins module from the main app, and import it.
|
||||
# TODO:
|
||||
# gross, this probably isn't necessary anymore, especially for an installed module
|
||||
__import__("%s.%s" % (globals()['plugins'].__name__, plugin))
|
||||
try:
|
||||
if pluginModule:
|
||||
module = __import__(pluginModule, globals(), locals(), [plugin])
|
||||
else:
|
||||
module = __import__("%s.%s" % (globals()['plugins'].__name__, plugin), globals(), locals(), [plugin])
|
||||
# init the plugin class
|
||||
self.plugin[plugin] = getattr(getattr(plugins, plugin), plugin)(self, pconfig) # eek
|
||||
self.plugin[plugin] = getattr(module, plugin)(self, pconfig) # eek
|
||||
# all of this for a nice debug? sure.
|
||||
xep = ''
|
||||
if hasattr(self.plugin[plugin], 'xep'):
|
||||
xep = "(XEP-%s) " % self.plugin[plugin].xep
|
||||
logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description))
|
||||
except:
|
||||
logging.error("Unable to load plugin: %s" %(plugin) )
|
||||
|
||||
def register_plugins(self):
|
||||
"""Initiates all plugins in the plugins/__init__.__all__"""
|
||||
|
|
|
@ -76,7 +76,7 @@ class Scheduler(object):
|
|||
if updated: self.schedule = sorted(self.schedule, key=lambda task: task.next)
|
||||
except KeyboardInterrupt:
|
||||
self.run = False
|
||||
logging.debug("Qutting Scheduler thread")
|
||||
logging.debug("Quitting Scheduler thread")
|
||||
if self.parentqueue is not None:
|
||||
self.parentqueue.put(('quit', None, None))
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import threading
|
|||
import time
|
||||
import logging
|
||||
|
||||
|
||||
class StateMachine(object):
|
||||
|
||||
def __init__(self, states=[]):
|
||||
|
@ -27,7 +28,7 @@ class StateMachine(object):
|
|||
self.__states.append( state )
|
||||
|
||||
|
||||
def transition(self, from_state, to_state, wait=0.0):
|
||||
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
|
||||
'''
|
||||
Transition from the given `from_state` to the given `to_state`.
|
||||
This method will return `True` if the state machine is now in `to_state`. It
|
||||
|
@ -47,25 +48,37 @@ class StateMachine(object):
|
|||
if thread_should_exit: return
|
||||
# perform actions here after successful transition
|
||||
|
||||
This allows the thread to be interrupted by setting `thread_should_exit=True`
|
||||
This allows the thread to be responsive by setting `thread_should_exit=True`.
|
||||
|
||||
The optional `func` argument allows the user to pass a callable operation which occurs
|
||||
within the context of the state transition (e.g. while the state machine is locked.)
|
||||
If `func` returns a True value, the transition will occur. If `func` returns a non-
|
||||
True value or if an exception is thrown, the transition will not occur. Any thrown
|
||||
exception is not caught by the state machine and is the caller's responsibility to handle.
|
||||
If `func` completes normally, this method will return the value returned by `func.` If
|
||||
values for `args` and `kwargs` are provided, they are expanded and passed like so:
|
||||
`func( *args, **kwargs )`.
|
||||
'''
|
||||
|
||||
return self.transition_any( (from_state,), to_state, wait=wait )
|
||||
return self.transition_any( (from_state,), to_state, wait=wait,
|
||||
func=func, args=args, kwargs=kwargs )
|
||||
|
||||
def transition_any(self, from_states, to_state, wait=0.0):
|
||||
|
||||
def transition_any(self, from_states, to_state, wait=0.0, func=None, args=[], kwargs={} ):
|
||||
'''
|
||||
Transition from any of the given `from_states` to the given `to_state`.
|
||||
'''
|
||||
|
||||
with self.lock:
|
||||
if not (isinstance(from_states,tuple) or isinstance(from_states,list)):
|
||||
raise ValueError( "from_states should be a list or tuple" )
|
||||
|
||||
for state in from_states:
|
||||
if isinstance(state,tuple) or isinstance(state,list):
|
||||
raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) )
|
||||
if not state in self.__states:
|
||||
raise ValueError( "StateMachine does not contain from_state %s." % state )
|
||||
if not to_state in self.__states:
|
||||
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
|
||||
|
||||
with self.lock:
|
||||
start = time.time()
|
||||
while not self.__current_state in from_states:
|
||||
# detect timeout:
|
||||
|
@ -73,32 +86,78 @@ class StateMachine(object):
|
|||
self.lock.wait(wait)
|
||||
|
||||
if self.__current_state in from_states: # should always be True due to lock
|
||||
|
||||
return_val = True
|
||||
# Note that func might throw an exception, but that's OK, it aborts the transition
|
||||
if func is not None: return_val = func(*args,**kwargs)
|
||||
|
||||
# some 'false' value returned from func,
|
||||
# indicating that transition should not occur:
|
||||
if not return_val: return return_val
|
||||
|
||||
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
|
||||
self.__current_state = to_state
|
||||
self.lock.notifyAll()
|
||||
return True
|
||||
return return_val # some 'true' value returned by func or True if func was None
|
||||
else:
|
||||
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
|
||||
return False
|
||||
|
||||
|
||||
def transition_ctx(self, from_state, to_state, wait=0.0):
|
||||
'''
|
||||
Use the state machine as a context manager. The transition occurs on /exit/ from
|
||||
the `with` context, so long as no exception is thrown. For example:
|
||||
|
||||
::
|
||||
|
||||
with state_machine.transition_ctx('one','two', wait=5) as locked:
|
||||
if locked:
|
||||
# the state machine is currently locked in state 'one', and will
|
||||
# transition to 'two' when the 'with' statement ends, so long as
|
||||
# no exception is thrown.
|
||||
print 'Currently locked in state one: %s' % state_machine['one']
|
||||
|
||||
else:
|
||||
# The 'wait' timed out, and no lock has been acquired
|
||||
print 'Timed out before entering state "one"'
|
||||
|
||||
print 'Since no exception was thrown, we are now in state "two": %s' % state_machine['two']
|
||||
|
||||
|
||||
The other main difference between this method and `transition()` is that the
|
||||
state machine is locked for the duration of the `with` statement (normally,
|
||||
after a `transition() occurs, the state machine is immediately unlocked and
|
||||
available to another thread to call `transition()` again.
|
||||
'''
|
||||
|
||||
if not from_state in self.__states:
|
||||
raise ValueError( "StateMachine does not contain from_state %s." % state )
|
||||
if not to_state in self.__states:
|
||||
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
|
||||
|
||||
return _StateCtx(self, from_state, to_state, wait)
|
||||
|
||||
|
||||
def ensure(self, state, wait=0.0):
|
||||
'''
|
||||
Ensure the state machine is currently in `state`, or wait until it enters `state`.
|
||||
'''
|
||||
return self.ensure_any( (state,), wait=wait )
|
||||
|
||||
|
||||
def ensure_any(self, states, wait=0.0):
|
||||
'''
|
||||
Ensure we are currently in one of the given `states`
|
||||
'''
|
||||
with self.lock:
|
||||
for state in states:
|
||||
if isinstance(state,tuple) or isinstance(state,list):
|
||||
raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) )
|
||||
if not state in self.__states:
|
||||
raise ValueError( "StateMachine does not contain state %s." % state )
|
||||
if not (isinstance(states,tuple) or isinstance(states,list)):
|
||||
raise ValueError('states arg should be a tuple or list')
|
||||
|
||||
for state in states:
|
||||
if not state in self.__states:
|
||||
raise ValueError( "StateMachine does not contain state '%s'" % state )
|
||||
|
||||
with self.lock:
|
||||
start = time.time()
|
||||
while not self.__current_state in states:
|
||||
# detect timeout:
|
||||
|
@ -112,6 +171,18 @@ class StateMachine(object):
|
|||
self.transition(self.__current_state, self._default_state)
|
||||
|
||||
|
||||
def _set_state(self, state): #unsynchronized, only call internally after lock is acquired
|
||||
self.__current_state = state
|
||||
return state
|
||||
|
||||
|
||||
def current_state(self):
|
||||
'''
|
||||
Return the current state name.
|
||||
'''
|
||||
return self.__current_state
|
||||
|
||||
|
||||
def __getitem__(self, state):
|
||||
'''
|
||||
Non-blocking, non-synchronized test to determine if we are in the given state.
|
||||
|
@ -119,12 +190,45 @@ class StateMachine(object):
|
|||
'''
|
||||
return self.__current_state == state
|
||||
|
||||
def __str__(self):
|
||||
return "".join(( "StateMachine(", ','.join(self.__states), "): ", self.__current_state ))
|
||||
|
||||
|
||||
|
||||
class _StateCtx:
|
||||
|
||||
def __init__( self, state_machine, from_state, to_state, wait ):
|
||||
self.state_machine = state_machine
|
||||
self.from_state = from_state
|
||||
self.to_state = to_state
|
||||
self.wait = wait
|
||||
self._timeout = False
|
||||
|
||||
def __enter__(self):
|
||||
self.lock.acquire()
|
||||
return self
|
||||
self.state_machine.lock.acquire()
|
||||
start = time.time()
|
||||
while not self.state_machine[ self.from_state ]:
|
||||
# detect timeout:
|
||||
if time.time() >= start + self.wait:
|
||||
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
|
||||
self._timeout = True # to indicate we should not transition
|
||||
return False
|
||||
self.state_machine.lock.wait(self.wait)
|
||||
|
||||
logging.debug('StateMachine entered context in state: %s',
|
||||
self.state_machine.current_state() )
|
||||
return True
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.lock.nofityAll()
|
||||
self.lock.release()
|
||||
if exc_val is not None:
|
||||
logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
|
||||
self.state_machine.current_state(), exc_type.__name__, exc_val )
|
||||
elif not self._timeout:
|
||||
logging.debug(' ==== TRANSITION %s -> %s',
|
||||
self.state_machine.current_state(), self.to_state)
|
||||
self.state_machine._set_state( self.to_state )
|
||||
|
||||
self.state_machine.lock.notifyAll()
|
||||
self.state_machine.lock.release()
|
||||
return False # re-raise any exception
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from . stanzabase import StanzaBase
|
|||
from xml.etree import cElementTree
|
||||
from xml.parsers import expat
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
@ -46,6 +47,10 @@ class CloseStream(Exception):
|
|||
|
||||
stanza_extensions = {}
|
||||
|
||||
RECONNECT_MAX_DELAY = 3600
|
||||
RECONNECT_QUIESCE_FACTOR = 1.6180339887498948 # Phi
|
||||
RECONNECT_QUIESCE_JITTER = 0.11962656472 # molar Planck constant times c, joule meter/mole
|
||||
|
||||
class XMLStream(object):
|
||||
"A connection manager with XML events."
|
||||
|
||||
|
@ -95,13 +100,29 @@ class XMLStream(object):
|
|||
self.filesocket = filesocket
|
||||
|
||||
def connect(self, host='', port=0, use_ssl=None, use_tls=None):
|
||||
"Link to connectTCP"
|
||||
if self.state.transition('disconnected', 'connecting'):
|
||||
return self.connectTCP(host, port, use_ssl, use_tls)
|
||||
"Establish a socket connection to the given XMPP server."
|
||||
|
||||
if not self.state.transition('disconnected','connected',
|
||||
func=self.connectTCP, args=[host, port, use_ssl, use_tls] ):
|
||||
|
||||
if self.state['connected']: logging.debug('Already connected')
|
||||
else: logging.warning("Connection failed" )
|
||||
return False
|
||||
|
||||
logging.debug('Connection complete.')
|
||||
return True
|
||||
|
||||
# TODO currently a caller can't distinguish between "connection failed" and
|
||||
# "we're already trying to connect from another thread"
|
||||
|
||||
def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True):
|
||||
"Connect and create socket"
|
||||
while reattempt and not self.state['connected']: # the self.state part is redundant.
|
||||
|
||||
# Note that this is thread-safe by merit of being called solely from connect() which
|
||||
# holds the state lock.
|
||||
|
||||
delay = 1.0 # reconnection delay
|
||||
while self.run:
|
||||
logging.debug('connecting....')
|
||||
try:
|
||||
if host and port:
|
||||
|
@ -109,27 +130,39 @@ class XMLStream(object):
|
|||
if use_ssl is not None:
|
||||
self.use_ssl = use_ssl
|
||||
if use_tls is not None:
|
||||
# TODO this variable doesn't seem to be used for anything!
|
||||
self.use_tls = use_tls
|
||||
if sys.version_info < (3, 0):
|
||||
self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM)
|
||||
else:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.settimeout(None) #10)
|
||||
|
||||
if self.use_ssl and self.ssl_support:
|
||||
logging.debug("Socket Wrapped for SSL")
|
||||
self.socket = ssl.wrap_socket(self.socket,ca_certs=self.ca_certs)
|
||||
except:
|
||||
logging.exception("Connection error")
|
||||
try:
|
||||
|
||||
self.socket.connect(self.address)
|
||||
self.filesocket = self.socket.makefile('rb', 0)
|
||||
if not self.state.transition('connecting','connected'):
|
||||
logging.error( "State transition error!!!! Shouldn't have happened" )
|
||||
logging.debug('connect complete.')
|
||||
|
||||
return True
|
||||
|
||||
except socket.error as serr:
|
||||
logging.error("Could not connect. Socket Error #%s: %s" % (serr.errno, serr.strerror))
|
||||
time.sleep(1) # TODO proper quiesce if connection attempt fails
|
||||
logging.exception("Socket Error #%s: %s", serr.errno, serr.strerror)
|
||||
if not reattempt: return False
|
||||
except:
|
||||
logging.exception("Connection error")
|
||||
if not reattempt: return False
|
||||
|
||||
# quiesce if rconnection fails:
|
||||
# This algorithm based loosely on Twisted internet.protocol
|
||||
# http://twistedmatrix.com/trac/browser/trunk/twisted/internet/protocol.py#L310
|
||||
delay = min(delay * RECONNECT_QUIESCE_FACTOR, RECONNECT_MAX_DELAY)
|
||||
delay = random.normalvariate(delay, delay * RECONNECT_QUIESCE_JITTER)
|
||||
logging.debug('Waiting %fs until next reconnect attempt...', delay)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
|
||||
def connectUnix(self, filepath):
|
||||
"Connect to Unix file and create socket"
|
||||
|
@ -244,11 +277,12 @@ class XMLStream(object):
|
|||
|
||||
data = None
|
||||
try:
|
||||
data = self.sendqueue.get(True,10)
|
||||
data = self.sendqueue.get(True,5)
|
||||
logging.debug("SEND: %s" % data)
|
||||
self.socket.sendall(data.encode('utf-8'))
|
||||
except queue.Empty:
|
||||
logging.debug('nothing on send queue')
|
||||
# logging.debug('Nothing on send queue')
|
||||
pass
|
||||
except socket.timeout:
|
||||
# this is to prevent a thread blocked indefinitely
|
||||
logging.debug('timeout sending packet data')
|
||||
|
@ -335,6 +369,7 @@ class XMLStream(object):
|
|||
try:
|
||||
event = self.eventqueue.get(True, timeout=5)
|
||||
except queue.Empty:
|
||||
# logging.debug('Nothing on event queue')
|
||||
event = None
|
||||
if event is not None:
|
||||
etype = event[0]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
import time, threading
|
||||
import time, threading, random, functools
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys, os
|
||||
|
@ -15,25 +15,23 @@ class testStateMachine(unittest.TestCase):
|
|||
def testDefaults(self):
|
||||
"Test ensure transitions occur correctly in a single thread"
|
||||
s = sm.StateMachine(('one','two','three'))
|
||||
# self.assertTrue(s.one)
|
||||
self.assertTrue(s['one'])
|
||||
# self.failIf(s.two)
|
||||
self.failIf(s['two'])
|
||||
try:
|
||||
s.booga
|
||||
s['booga']
|
||||
self.fail('s.booga is an invalid state and should throw an exception!')
|
||||
except: pass #expected exception
|
||||
|
||||
# just make sure __str__ works, no reason to test its exact value:
|
||||
print str(s)
|
||||
|
||||
|
||||
def testTransitions(self):
|
||||
"Test ensure transitions occur correctly in a single thread"
|
||||
s = sm.StateMachine(('one','two','three'))
|
||||
# self.assertTrue(s.one)
|
||||
|
||||
self.assertTrue( s.transition('one', 'two') )
|
||||
# self.assertTrue( s.two )
|
||||
self.assertTrue( s['two'] )
|
||||
# self.failIf( s.one )
|
||||
self.failIf( s['one'] )
|
||||
|
||||
self.assertTrue( s.transition('two', 'three') )
|
||||
|
@ -83,12 +81,12 @@ class testStateMachine(unittest.TestCase):
|
|||
|
||||
thread_state = {'ready': False, 'transitioned': False}
|
||||
def t1():
|
||||
# this will block until the main thread transitions to 'two'
|
||||
if s['two']:
|
||||
print 'thread has already transitioned!'
|
||||
self.fail()
|
||||
thread_state['ready'] = True
|
||||
print 'Thread is ready'
|
||||
# this will block until the main thread transitions to 'two'
|
||||
self.assertTrue( s.transition('two','three', wait=20) )
|
||||
print 'transitioned to three!'
|
||||
thread_state['transitioned'] = True
|
||||
|
@ -109,6 +107,153 @@ class testStateMachine(unittest.TestCase):
|
|||
self.assertTrue( thread_state['transitioned'] )
|
||||
|
||||
|
||||
def testForRaceCondition(self):
|
||||
"""Attempt to allow two threads to perform the same transition;
|
||||
only one should ever make it."""
|
||||
|
||||
s = sm.StateMachine(('one','two','three'))
|
||||
|
||||
def t1(num):
|
||||
while True:
|
||||
if not trigger['go'] or thread_state[num] in (True,False):
|
||||
time.sleep( random.random()/100 ) # < .01s
|
||||
if thread_state[num] == 'quit': break
|
||||
continue
|
||||
|
||||
thread_state[num] = s.transition('one','two' )
|
||||
# print '-',
|
||||
|
||||
thread_count = 20
|
||||
threads = []
|
||||
thread_state = {}
|
||||
def reset():
|
||||
for c in range(thread_count): thread_state[c] = "reset"
|
||||
trigger = {'go':False} # use of a plain boolean seems to be non-volatile between threads.
|
||||
|
||||
for c in range(thread_count):
|
||||
thread_state[c] = "reset"
|
||||
thread = threading.Thread( target= functools.partial(t1,c) )
|
||||
threads.append( thread )
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
for x in range(100): # this will take 10s to execute
|
||||
# print "+",
|
||||
trigger['go'] = True
|
||||
time.sleep(.1)
|
||||
trigger['go'] = False
|
||||
winners = 0
|
||||
for (num, state) in thread_state.items():
|
||||
if state == True: winners = winners +1
|
||||
elif state != False: raise Exception( "!%d!%s!" % (num,state) )
|
||||
|
||||
self.assertEqual( 1, winners, "Expected one winner! %d" % winners )
|
||||
self.assertTrue( s.ensure('two') )
|
||||
self.assertTrue( s.transition('two','one') ) # return to the first state.
|
||||
reset()
|
||||
|
||||
# now let the threads quit gracefully:
|
||||
for c in range(thread_count): thread_state[c] = 'quit'
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def testTransitionFunctions(self):
|
||||
"test that a `func` argument allows or blocks the transition correctly."
|
||||
|
||||
s = sm.StateMachine(('one','two','three'))
|
||||
|
||||
def alwaysFalse(): return False
|
||||
def alwaysTrue(): return True
|
||||
|
||||
self.failIf( s.transition('one','two', func=alwaysFalse) )
|
||||
self.assertTrue(s['one'])
|
||||
self.failIf(s['two'])
|
||||
|
||||
self.assertTrue( s.transition('one','two', func=alwaysTrue) )
|
||||
self.failIf(s['one'])
|
||||
self.assertTrue(s['two'])
|
||||
|
||||
|
||||
def testTransitionFuncException(self):
|
||||
"if a transition function throws an exeption, ensure we're in a sane state"
|
||||
|
||||
s = sm.StateMachine(('one','two','three'))
|
||||
|
||||
def alwaysException(): raise Exception('whups!')
|
||||
|
||||
try:
|
||||
self.failIf( s.transition('one','two', func=alwaysException) )
|
||||
self.fail("exception should have been thrown")
|
||||
except: pass #expected exception
|
||||
|
||||
self.assertTrue(s['one'])
|
||||
self.failIf(s['two'])
|
||||
|
||||
# ensure a subsequent attempt completes normally:
|
||||
self.assertTrue( s.transition('one','two') )
|
||||
self.failIf(s['one'])
|
||||
self.assertTrue(s['two'])
|
||||
|
||||
|
||||
def testContextManager(self):
|
||||
|
||||
s = sm.StateMachine(('one','two','three'))
|
||||
|
||||
with s.transition_ctx('one','two'):
|
||||
self.assertTrue( s['one'] )
|
||||
self.failIf( s['two'] )
|
||||
|
||||
#successful transition b/c no exception was thrown
|
||||
self.assertTrue( s['two'] )
|
||||
self.failIf( s['one'] )
|
||||
|
||||
# failed transition because exception is thrown:
|
||||
try:
|
||||
with s.transition_ctx('two','three'):
|
||||
raise Exception("boom!")
|
||||
self.fail('exception expected')
|
||||
except: pass
|
||||
|
||||
self.failIf( s.current_state() in ('one','three') )
|
||||
self.assertTrue( s['two'] )
|
||||
|
||||
def testCtxManagerTransitionFailure(self):
|
||||
|
||||
s = sm.StateMachine(('one','two','three'))
|
||||
|
||||
with s.transition_ctx('two','three') as result:
|
||||
self.failIf( result )
|
||||
self.assertTrue( s['one'] )
|
||||
self.failIf( s.current_state in ('two','three') )
|
||||
|
||||
self.assertTrue( s['one'] )
|
||||
|
||||
def r1():
|
||||
print 'thread 1 started'
|
||||
self.assertTrue( s.transition('one','two') )
|
||||
print 'thread 1 transitioned'
|
||||
|
||||
def r2():
|
||||
print 'thread 2 started'
|
||||
self.failIf( s['two'] )
|
||||
with s.transition_ctx('two','three', 10) as result:
|
||||
self.assertTrue( result )
|
||||
self.assertTrue( s['two'] )
|
||||
print 'thread 2 will transition on exit from the context manager...'
|
||||
self.assertTrue( s['three'] )
|
||||
print 'transitioned to %s' % s.current_state()
|
||||
|
||||
t1 = threading.Thread(target=r1)
|
||||
t2 = threading.Thread(target=r2)
|
||||
|
||||
t2.start() # this should block until r1 goes
|
||||
time.sleep(1)
|
||||
t1.start()
|
||||
|
||||
t1.join()
|
||||
t2.join()
|
||||
|
||||
self.assertTrue( s['three'] )
|
||||
|
||||
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
|
||||
|
|
Loading…
Reference in a new issue