Merge branch 'hacks' of git@github.com:tomstrummer/SleekXMPP into hacks

This commit is contained in:
Tom Nichols 2010-06-08 10:40:15 -04:00
commit 4fccd77685
7 changed files with 364 additions and 70 deletions

2
.gitignore vendored
View file

@ -2,3 +2,5 @@
.project .project
build/ build/
*.swp *.swp
.pydevproject
.settings

View file

@ -69,6 +69,8 @@ class ClientXMPP(basexmpp, XMLStream):
#TODO: Use stream state here #TODO: Use stream state here
self.authenticated = False self.authenticated = False
self.sessionstarted = False self.sessionstarted = False
self.bound = False
self.bindfail = False
self.registerHandler(Callback('Stream Features', MatchXPath('{http://etherx.jabber.org/streams}features'), self._handleStreamFeatures, thread=True)) self.registerHandler(Callback('Stream Features', MatchXPath('{http://etherx.jabber.org/streams}features'), self._handleStreamFeatures, thread=True))
self.registerHandler(Callback('Roster Update', MatchXPath('{%s}iq/{jabber:iq:roster}query' % self.default_ns), self._handleRoster, thread=True)) self.registerHandler(Callback('Roster Update', MatchXPath('{%s}iq/{jabber:iq:roster}query' % self.default_ns), self._handleRoster, thread=True))
#self.registerHandler(Callback('Roster Update', MatchXMLMask("<presence xmlns='%s' type='subscribe' />" % self.default_ns), self._handlePresenceSubscribe, thread=True)) #self.registerHandler(Callback('Roster Update', MatchXMLMask("<presence xmlns='%s' type='subscribe' />" % self.default_ns), self._handlePresenceSubscribe, thread=True))
@ -146,13 +148,9 @@ class ClientXMPP(basexmpp, XMLStream):
# overriding reconnect and disconnect so that we can get some events # overriding reconnect and disconnect so that we can get some events
# should events be part of or required by xmlstream? Maybe that would be cleaner # should events be part of or required by xmlstream? Maybe that would be cleaner
def reconnect(self): def reconnect(self):
logging.info("Reconnecting") self.disconnect(reconnect=True)
self.event("disconnected")
self.authenticated = False
self.sessionstarted = False
XMLStream.reconnect(self)
def disconnect(self, init=True, close=False, reconnect=False): def disconnect(self, reconnect=False):
self.event("disconnected") self.event("disconnected")
self.authenticated = False self.authenticated = False
self.sessionstarted = False self.sessionstarted = False
@ -248,19 +246,23 @@ class ClientXMPP(basexmpp, XMLStream):
response = iq.send() response = iq.send()
#response = self.send(iq, self.Iq(sid=iq['id'])) #response = self.send(iq, self.Iq(sid=iq['id']))
self.set_jid(response.xml.find('{urn:ietf:params:xml:ns:xmpp-bind}bind/{urn:ietf:params:xml:ns:xmpp-bind}jid').text) self.set_jid(response.xml.find('{urn:ietf:params:xml:ns:xmpp-bind}bind/{urn:ietf:params:xml:ns:xmpp-bind}jid').text)
self.bound = True
logging.info("Node set to: %s" % self.fulljid) logging.info("Node set to: %s" % self.fulljid)
if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features: if "{urn:ietf:params:xml:ns:xmpp-session}session" not in self.features or self.bindfail:
logging.debug("Established Session") logging.debug("Established Session")
self.sessionstarted = True self.sessionstarted = True
self.event("session_start") self.event("session_start")
def handler_start_session(self, xml): def handler_start_session(self, xml):
if self.authenticated: if self.authenticated and self.bound:
iq = self.makeIqSet(xml) iq = self.makeIqSet(xml)
response = iq.send() response = iq.send()
logging.debug("Established Session") logging.debug("Established Session")
self.sessionstarted = True self.sessionstarted = True
self.event("session_start") self.event("session_start")
else:
#bind probably hasn't happened yet
self.bindfail = True
def _handleRoster(self, iq, request=False): def _handleRoster(self, iq, request=False):
if iq['type'] == 'set' or (iq['type'] == 'result' and request): if iq['type'] == 'set' or (iq['type'] == 'result' and request):

View file

@ -1,9 +1,9 @@
""" """
SleekXMPP: The Sleek XMPP Library SleekXMPP: The Sleek XMPP Library
Copyright (C) 2010 Nathanael C. Fritz Copyright (C) 2010 Nathanael C. Fritz
This file is part of SleekXMPP. This file is part of SleekXMPP.
See the file license.txt for copying permission. See the file license.txt for copying permission.
""" """
from __future__ import with_statement, unicode_literals from __future__ import with_statement, unicode_literals
@ -91,21 +91,27 @@ class basexmpp(object):
if not self.plugin[idx].post_inited: self.plugin[idx].post_init() if not self.plugin[idx].post_inited: self.plugin[idx].post_init()
return super(basexmpp, self).process(*args, **kwargs) return super(basexmpp, self).process(*args, **kwargs)
def registerPlugin(self, plugin, pconfig = {}): def registerPlugin(self, plugin, pconfig = {}, pluginModule = None):
"""Register a plugin not in plugins.__init__.__all__ but in the plugins """Register a plugin not in plugins.__init__.__all__ but in the plugins
directory.""" directory."""
# discover relative "path" to the plugins module from the main app, and import it. # discover relative "path" to the plugins module from the main app, and import it.
# TODO: # TODO:
# gross, this probably isn't necessary anymore, especially for an installed module # gross, this probably isn't necessary anymore, especially for an installed module
__import__("%s.%s" % (globals()['plugins'].__name__, plugin)) try:
# init the plugin class if pluginModule:
self.plugin[plugin] = getattr(getattr(plugins, plugin), plugin)(self, pconfig) # eek module = __import__(pluginModule, globals(), locals(), [plugin])
# all of this for a nice debug? sure. else:
xep = '' module = __import__("%s.%s" % (globals()['plugins'].__name__, plugin), globals(), locals(), [plugin])
if hasattr(self.plugin[plugin], 'xep'): # init the plugin class
xep = "(XEP-%s) " % self.plugin[plugin].xep self.plugin[plugin] = getattr(module, plugin)(self, pconfig) # eek
logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description)) # all of this for a nice debug? sure.
xep = ''
if hasattr(self.plugin[plugin], 'xep'):
xep = "(XEP-%s) " % self.plugin[plugin].xep
logging.debug("Loaded Plugin %s%s" % (xep, self.plugin[plugin].description))
except:
logging.error("Unable to load plugin: %s" %(plugin) )
def register_plugins(self): def register_plugins(self):
"""Initiates all plugins in the plugins/__init__.__all__""" """Initiates all plugins in the plugins/__init__.__all__"""
if self.plugin_whitelist: if self.plugin_whitelist:

View file

@ -76,7 +76,7 @@ class Scheduler(object):
if updated: self.schedule = sorted(self.schedule, key=lambda task: task.next) if updated: self.schedule = sorted(self.schedule, key=lambda task: task.next)
except KeyboardInterrupt: except KeyboardInterrupt:
self.run = False self.run = False
logging.debug("Qutting Scheduler thread") logging.debug("Quitting Scheduler thread")
if self.parentqueue is not None: if self.parentqueue is not None:
self.parentqueue.put(('quit', None, None)) self.parentqueue.put(('quit', None, None))

View file

@ -10,6 +10,7 @@ import threading
import time import time
import logging import logging
class StateMachine(object): class StateMachine(object):
def __init__(self, states=[]): def __init__(self, states=[]):
@ -27,7 +28,7 @@ class StateMachine(object):
self.__states.append( state ) self.__states.append( state )
def transition(self, from_state, to_state, wait=0.0): def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
''' '''
Transition from the given `from_state` to the given `to_state`. Transition from the given `from_state` to the given `to_state`.
This method will return `True` if the state machine is now in `to_state`. It This method will return `True` if the state machine is now in `to_state`. It
@ -47,25 +48,37 @@ class StateMachine(object):
if thread_should_exit: return if thread_should_exit: return
# perform actions here after successful transition # perform actions here after successful transition
This allows the thread to be interrupted by setting `thread_should_exit=True` This allows the thread to be responsive by setting `thread_should_exit=True`.
The optional `func` argument allows the user to pass a callable operation which occurs
within the context of the state transition (e.g. while the state machine is locked.)
If `func` returns a True value, the transition will occur. If `func` returns a non-
True value or if an exception is thrown, the transition will not occur. Any thrown
exception is not caught by the state machine and is the caller's responsibility to handle.
If `func` completes normally, this method will return the value returned by `func.` If
values for `args` and `kwargs` are provided, they are expanded and passed like so:
`func( *args, **kwargs )`.
''' '''
return self.transition_any( (from_state,), to_state, wait=wait ) return self.transition_any( (from_state,), to_state, wait=wait,
func=func, args=args, kwargs=kwargs )
def transition_any(self, from_states, to_state, wait=0.0):
def transition_any(self, from_states, to_state, wait=0.0, func=None, args=[], kwargs={} ):
''' '''
Transition from any of the given `from_states` to the given `to_state`. Transition from any of the given `from_states` to the given `to_state`.
''' '''
with self.lock: if not (isinstance(from_states,tuple) or isinstance(from_states,list)):
for state in from_states: raise ValueError( "from_states should be a list or tuple" )
if isinstance(state,tuple) or isinstance(state,list):
raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) )
if not state in self.__states:
raise ValueError( "StateMachine does not contain from_state %s." % state )
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
for state in from_states:
if not state in self.__states:
raise ValueError( "StateMachine does not contain from_state %s." % state )
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
with self.lock:
start = time.time() start = time.time()
while not self.__current_state in from_states: while not self.__current_state in from_states:
# detect timeout: # detect timeout:
@ -73,32 +86,78 @@ class StateMachine(object):
self.lock.wait(wait) self.lock.wait(wait)
if self.__current_state in from_states: # should always be True due to lock if self.__current_state in from_states: # should always be True due to lock
return_val = True
# Note that func might throw an exception, but that's OK, it aborts the transition
if func is not None: return_val = func(*args,**kwargs)
# some 'false' value returned from func,
# indicating that transition should not occur:
if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state) logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self.__current_state = to_state self.__current_state = to_state
self.lock.notifyAll() self.lock.notifyAll()
return True return return_val # some 'true' value returned by func or True if func was None
else: else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" ) logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False return False
def transition_ctx(self, from_state, to_state, wait=0.0):
'''
Use the state machine as a context manager. The transition occurs on /exit/ from
the `with` context, so long as no exception is thrown. For example:
::
with state_machine.transition_ctx('one','two', wait=5) as locked:
if locked:
# the state machine is currently locked in state 'one', and will
# transition to 'two' when the 'with' statement ends, so long as
# no exception is thrown.
print 'Currently locked in state one: %s' % state_machine['one']
else:
# The 'wait' timed out, and no lock has been acquired
print 'Timed out before entering state "one"'
print 'Since no exception was thrown, we are now in state "two": %s' % state_machine['two']
The other main difference between this method and `transition()` is that the
state machine is locked for the duration of the `with` statement (normally,
after a `transition() occurs, the state machine is immediately unlocked and
available to another thread to call `transition()` again.
'''
if not from_state in self.__states:
raise ValueError( "StateMachine does not contain from_state %s." % state )
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
return _StateCtx(self, from_state, to_state, wait)
def ensure(self, state, wait=0.0): def ensure(self, state, wait=0.0):
''' '''
Ensure the state machine is currently in `state`, or wait until it enters `state`. Ensure the state machine is currently in `state`, or wait until it enters `state`.
''' '''
return self.ensure_any( (state,), wait=wait ) return self.ensure_any( (state,), wait=wait )
def ensure_any(self, states, wait=0.0): def ensure_any(self, states, wait=0.0):
''' '''
Ensure we are currently in one of the given `states` Ensure we are currently in one of the given `states`
''' '''
with self.lock: if not (isinstance(states,tuple) or isinstance(states,list)):
for state in states: raise ValueError('states arg should be a tuple or list')
if isinstance(state,tuple) or isinstance(state,list):
raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) )
if not state in self.__states:
raise ValueError( "StateMachine does not contain state %s." % state )
for state in states:
if not state in self.__states:
raise ValueError( "StateMachine does not contain state '%s'" % state )
with self.lock:
start = time.time() start = time.time()
while not self.__current_state in states: while not self.__current_state in states:
# detect timeout: # detect timeout:
@ -110,7 +169,19 @@ class StateMachine(object):
def reset(self): def reset(self):
# TODO need to lock before calling this? # TODO need to lock before calling this?
self.transition(self.__current_state, self._default_state) self.transition(self.__current_state, self._default_state)
def _set_state(self, state): #unsynchronized, only call internally after lock is acquired
self.__current_state = state
return state
def current_state(self):
'''
Return the current state name.
'''
return self.__current_state
def __getitem__(self, state): def __getitem__(self, state):
''' '''
@ -118,13 +189,46 @@ class StateMachine(object):
Use `StateMachine.ensure(state)` to wait until the machine enters a certain state. Use `StateMachine.ensure(state)` to wait until the machine enters a certain state.
''' '''
return self.__current_state == state return self.__current_state == state
def __str__(self):
return "".join(( "StateMachine(", ','.join(self.__states), "): ", self.__current_state ))
class _StateCtx:
def __init__( self, state_machine, from_state, to_state, wait ):
self.state_machine = state_machine
self.from_state = from_state
self.to_state = to_state
self.wait = wait
self._timeout = False
def __enter__(self): def __enter__(self):
self.lock.acquire() self.state_machine.lock.acquire()
return self start = time.time()
while not self.state_machine[ self.from_state ]:
# detect timeout:
if time.time() >= start + self.wait:
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
self._timeout = True # to indicate we should not transition
return False
self.state_machine.lock.wait(self.wait)
logging.debug('StateMachine entered context in state: %s',
self.state_machine.current_state() )
return True
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.lock.nofityAll() if exc_val is not None:
self.lock.release() logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
self.state_machine.current_state(), exc_type.__name__, exc_val )
elif not self._timeout:
logging.debug(' ==== TRANSITION %s -> %s',
self.state_machine.current_state(), self.to_state)
self.state_machine._set_state( self.to_state )
self.state_machine.lock.notifyAll()
self.state_machine.lock.release()
return False # re-raise any exception return False # re-raise any exception

View file

@ -16,6 +16,7 @@ from . stanzabase import StanzaBase
from xml.etree import cElementTree from xml.etree import cElementTree
from xml.parsers import expat from xml.parsers import expat
import logging import logging
import random
import socket import socket
import threading import threading
import time import time
@ -46,6 +47,10 @@ class CloseStream(Exception):
stanza_extensions = {} stanza_extensions = {}
RECONNECT_MAX_DELAY = 3600
RECONNECT_QUIESCE_FACTOR = 1.6180339887498948 # Phi
RECONNECT_QUIESCE_JITTER = 0.11962656472 # molar Planck constant times c, joule meter/mole
class XMLStream(object): class XMLStream(object):
"A connection manager with XML events." "A connection manager with XML events."
@ -95,13 +100,29 @@ class XMLStream(object):
self.filesocket = filesocket self.filesocket = filesocket
def connect(self, host='', port=0, use_ssl=None, use_tls=None): def connect(self, host='', port=0, use_ssl=None, use_tls=None):
"Link to connectTCP" "Establish a socket connection to the given XMPP server."
if self.state.transition('disconnected', 'connecting'):
return self.connectTCP(host, port, use_ssl, use_tls) if not self.state.transition('disconnected','connected',
func=self.connectTCP, args=[host, port, use_ssl, use_tls] ):
if self.state['connected']: logging.debug('Already connected')
else: logging.warning("Connection failed" )
return False
logging.debug('Connection complete.')
return True
# TODO currently a caller can't distinguish between "connection failed" and
# "we're already trying to connect from another thread"
def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True): def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True):
"Connect and create socket" "Connect and create socket"
while reattempt and not self.state['connected']: # the self.state part is redundant.
# Note that this is thread-safe by merit of being called solely from connect() which
# holds the state lock.
delay = 1.0 # reconnection delay
while self.run:
logging.debug('connecting....') logging.debug('connecting....')
try: try:
if host and port: if host and port:
@ -109,27 +130,39 @@ class XMLStream(object):
if use_ssl is not None: if use_ssl is not None:
self.use_ssl = use_ssl self.use_ssl = use_ssl
if use_tls is not None: if use_tls is not None:
# TODO this variable doesn't seem to be used for anything!
self.use_tls = use_tls self.use_tls = use_tls
if sys.version_info < (3, 0): if sys.version_info < (3, 0):
self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM) self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM)
else: else:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.settimeout(None) #10) self.socket.settimeout(None) #10)
if self.use_ssl and self.ssl_support: if self.use_ssl and self.ssl_support:
logging.debug("Socket Wrapped for SSL") logging.debug("Socket Wrapped for SSL")
self.socket = ssl.wrap_socket(self.socket,ca_certs=self.ca_certs) self.socket = ssl.wrap_socket(self.socket,ca_certs=self.ca_certs)
except:
logging.exception("Connection error")
try:
self.socket.connect(self.address) self.socket.connect(self.address)
self.filesocket = self.socket.makefile('rb', 0) self.filesocket = self.socket.makefile('rb', 0)
if not self.state.transition('connecting','connected'):
logging.error( "State transition error!!!! Shouldn't have happened" )
logging.debug('connect complete.')
return True return True
except socket.error as serr: except socket.error as serr:
logging.error("Could not connect. Socket Error #%s: %s" % (serr.errno, serr.strerror)) logging.exception("Socket Error #%s: %s", serr.errno, serr.strerror)
time.sleep(1) # TODO proper quiesce if connection attempt fails if not reattempt: return False
except:
logging.exception("Connection error")
if not reattempt: return False
# quiesce if rconnection fails:
# This algorithm based loosely on Twisted internet.protocol
# http://twistedmatrix.com/trac/browser/trunk/twisted/internet/protocol.py#L310
delay = min(delay * RECONNECT_QUIESCE_FACTOR, RECONNECT_MAX_DELAY)
delay = random.normalvariate(delay, delay * RECONNECT_QUIESCE_JITTER)
logging.debug('Waiting %fs until next reconnect attempt...', delay)
time.sleep(delay)
def connectUnix(self, filepath): def connectUnix(self, filepath):
"Connect to Unix file and create socket" "Connect to Unix file and create socket"
@ -244,11 +277,12 @@ class XMLStream(object):
data = None data = None
try: try:
data = self.sendqueue.get(True,10) data = self.sendqueue.get(True,5)
logging.debug("SEND: %s" % data) logging.debug("SEND: %s" % data)
self.socket.sendall(data.encode('utf-8')) self.socket.sendall(data.encode('utf-8'))
except queue.Empty: except queue.Empty:
logging.debug('nothing on send queue') # logging.debug('Nothing on send queue')
pass
except socket.timeout: except socket.timeout:
# this is to prevent a thread blocked indefinitely # this is to prevent a thread blocked indefinitely
logging.debug('timeout sending packet data') logging.debug('timeout sending packet data')
@ -335,6 +369,7 @@ class XMLStream(object):
try: try:
event = self.eventqueue.get(True, timeout=5) event = self.eventqueue.get(True, timeout=5)
except queue.Empty: except queue.Empty:
# logging.debug('Nothing on event queue')
event = None event = None
if event is not None: if event is not None:
etype = event[0] etype = event[0]

View file

@ -1,5 +1,5 @@
import unittest import unittest
import time, threading import time, threading, random, functools
if __name__ == '__main__': if __name__ == '__main__':
import sys, os import sys, os
@ -15,25 +15,23 @@ class testStateMachine(unittest.TestCase):
def testDefaults(self): def testDefaults(self):
"Test ensure transitions occur correctly in a single thread" "Test ensure transitions occur correctly in a single thread"
s = sm.StateMachine(('one','two','three')) s = sm.StateMachine(('one','two','three'))
# self.assertTrue(s.one)
self.assertTrue(s['one']) self.assertTrue(s['one'])
# self.failIf(s.two)
self.failIf(s['two']) self.failIf(s['two'])
try: try:
s.booga s['booga']
self.fail('s.booga is an invalid state and should throw an exception!') self.fail('s.booga is an invalid state and should throw an exception!')
except: pass #expected exception except: pass #expected exception
# just make sure __str__ works, no reason to test its exact value:
print str(s)
def testTransitions(self): def testTransitions(self):
"Test ensure transitions occur correctly in a single thread" "Test ensure transitions occur correctly in a single thread"
s = sm.StateMachine(('one','two','three')) s = sm.StateMachine(('one','two','three'))
# self.assertTrue(s.one)
self.assertTrue( s.transition('one', 'two') ) self.assertTrue( s.transition('one', 'two') )
# self.assertTrue( s.two )
self.assertTrue( s['two'] ) self.assertTrue( s['two'] )
# self.failIf( s.one )
self.failIf( s['one'] ) self.failIf( s['one'] )
self.assertTrue( s.transition('two', 'three') ) self.assertTrue( s.transition('two', 'three') )
@ -83,12 +81,12 @@ class testStateMachine(unittest.TestCase):
thread_state = {'ready': False, 'transitioned': False} thread_state = {'ready': False, 'transitioned': False}
def t1(): def t1():
# this will block until the main thread transitions to 'two'
if s['two']: if s['two']:
print 'thread has already transitioned!' print 'thread has already transitioned!'
self.fail() self.fail()
thread_state['ready'] = True thread_state['ready'] = True
print 'Thread is ready' print 'Thread is ready'
# this will block until the main thread transitions to 'two'
self.assertTrue( s.transition('two','three', wait=20) ) self.assertTrue( s.transition('two','three', wait=20) )
print 'transitioned to three!' print 'transitioned to three!'
thread_state['transitioned'] = True thread_state['transitioned'] = True
@ -109,6 +107,153 @@ class testStateMachine(unittest.TestCase):
self.assertTrue( thread_state['transitioned'] ) self.assertTrue( thread_state['transitioned'] )
def testForRaceCondition(self):
"""Attempt to allow two threads to perform the same transition;
only one should ever make it."""
s = sm.StateMachine(('one','two','three'))
def t1(num):
while True:
if not trigger['go'] or thread_state[num] in (True,False):
time.sleep( random.random()/100 ) # < .01s
if thread_state[num] == 'quit': break
continue
thread_state[num] = s.transition('one','two' )
# print '-',
thread_count = 20
threads = []
thread_state = {}
def reset():
for c in range(thread_count): thread_state[c] = "reset"
trigger = {'go':False} # use of a plain boolean seems to be non-volatile between threads.
for c in range(thread_count):
thread_state[c] = "reset"
thread = threading.Thread( target= functools.partial(t1,c) )
threads.append( thread )
thread.daemon = True
thread.start()
for x in range(100): # this will take 10s to execute
# print "+",
trigger['go'] = True
time.sleep(.1)
trigger['go'] = False
winners = 0
for (num, state) in thread_state.items():
if state == True: winners = winners +1
elif state != False: raise Exception( "!%d!%s!" % (num,state) )
self.assertEqual( 1, winners, "Expected one winner! %d" % winners )
self.assertTrue( s.ensure('two') )
self.assertTrue( s.transition('two','one') ) # return to the first state.
reset()
# now let the threads quit gracefully:
for c in range(thread_count): thread_state[c] = 'quit'
time.sleep(2)
def testTransitionFunctions(self):
"test that a `func` argument allows or blocks the transition correctly."
s = sm.StateMachine(('one','two','three'))
def alwaysFalse(): return False
def alwaysTrue(): return True
self.failIf( s.transition('one','two', func=alwaysFalse) )
self.assertTrue(s['one'])
self.failIf(s['two'])
self.assertTrue( s.transition('one','two', func=alwaysTrue) )
self.failIf(s['one'])
self.assertTrue(s['two'])
def testTransitionFuncException(self):
"if a transition function throws an exeption, ensure we're in a sane state"
s = sm.StateMachine(('one','two','three'))
def alwaysException(): raise Exception('whups!')
try:
self.failIf( s.transition('one','two', func=alwaysException) )
self.fail("exception should have been thrown")
except: pass #expected exception
self.assertTrue(s['one'])
self.failIf(s['two'])
# ensure a subsequent attempt completes normally:
self.assertTrue( s.transition('one','two') )
self.failIf(s['one'])
self.assertTrue(s['two'])
def testContextManager(self):
s = sm.StateMachine(('one','two','three'))
with s.transition_ctx('one','two'):
self.assertTrue( s['one'] )
self.failIf( s['two'] )
#successful transition b/c no exception was thrown
self.assertTrue( s['two'] )
self.failIf( s['one'] )
# failed transition because exception is thrown:
try:
with s.transition_ctx('two','three'):
raise Exception("boom!")
self.fail('exception expected')
except: pass
self.failIf( s.current_state() in ('one','three') )
self.assertTrue( s['two'] )
def testCtxManagerTransitionFailure(self):
s = sm.StateMachine(('one','two','three'))
with s.transition_ctx('two','three') as result:
self.failIf( result )
self.assertTrue( s['one'] )
self.failIf( s.current_state in ('two','three') )
self.assertTrue( s['one'] )
def r1():
print 'thread 1 started'
self.assertTrue( s.transition('one','two') )
print 'thread 1 transitioned'
def r2():
print 'thread 2 started'
self.failIf( s['two'] )
with s.transition_ctx('two','three', 10) as result:
self.assertTrue( result )
self.assertTrue( s['two'] )
print 'thread 2 will transition on exit from the context manager...'
self.assertTrue( s['three'] )
print 'transitioned to %s' % s.current_state()
t1 = threading.Thread(target=r1)
t2 = threading.Thread(target=r2)
t2.start() # this should block until r1 goes
time.sleep(1)
t1.start()
t1.join()
t2.join()
self.assertTrue( s['three'] )
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine) suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)