Merge remote branch 'tom/hacks'

This commit is contained in:
Brian Beggs 2010-06-04 12:52:52 -04:00
commit 1aa34cb0fc
3 changed files with 128 additions and 17 deletions

View file

@ -10,6 +10,7 @@ import threading
import time import time
import logging import logging
class StateMachine(object): class StateMachine(object):
def __init__(self, states=[]): def __init__(self, states=[]):
@ -27,7 +28,7 @@ class StateMachine(object):
self.__states.append( state ) self.__states.append( state )
def transition(self, from_state, to_state, wait=0.0): def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
''' '''
Transition from the given `from_state` to the given `to_state`. Transition from the given `from_state` to the given `to_state`.
This method will return `True` if the state machine is now in `to_state`. It This method will return `True` if the state machine is now in `to_state`. It
@ -47,12 +48,22 @@ class StateMachine(object):
if thread_should_exit: return if thread_should_exit: return
# perform actions here after successful transition # perform actions here after successful transition
This allows the thread to be interrupted by setting `thread_should_exit=True` This allows the thread to be responsive by setting `thread_should_exit=True`.
The optional `func` argument allows the user to pass a callable operation which occurs
within the context of the state transition (e.g. while the state machine is locked.)
If `func` returns a True value, the transition will occur. If `func` returns a non-
True value or if an exception is thrown, the transition will not occur. Any thrown
exception is not caught by the state machine and is the caller's responsibility to handle.
If `func` completes normally, this method will return the value returned by `func.` If
values for `args` and `kwargs` are provided, they are expanded and passed like so:
`func( *args, **kwargs )`.
''' '''
return self.transition_any( (from_state,), to_state, wait=wait ) return self.transition_any( (from_state,), to_state, wait=wait,
func=func, args=args, kwargs=kwargs )
def transition_any(self, from_states, to_state, wait=0.0): def transition_any(self, from_states, to_state, wait=0.0, func=None, args=[], kwargs={} ):
''' '''
Transition from any of the given `from_states` to the given `to_state`. Transition from any of the given `from_states` to the given `to_state`.
''' '''
@ -73,10 +84,19 @@ class StateMachine(object):
self.lock.wait(wait) self.lock.wait(wait)
if self.__current_state in from_states: # should always be True due to lock if self.__current_state in from_states: # should always be True due to lock
return_val = True
# Note that func might throw an exception, but that's OK, it aborts the transition
if func is not None: return_val = func(*args,**kwargs)
# some 'false' value returned from func,
# indicating that transition should not occur:
if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state) logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self.__current_state = to_state self.__current_state = to_state
self.lock.notifyAll() self.lock.notifyAll()
return True return return_val # some 'true' value returned by func or True if func was None
else: else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" ) logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False return False

View file

@ -100,18 +100,20 @@ class XMLStream(object):
self.filesocket = filesocket self.filesocket = filesocket
def connect(self, host='', port=0, use_ssl=None, use_tls=None): def connect(self, host='', port=0, use_ssl=None, use_tls=None):
"Link to connectTCP" "Establish a socket connection to the given XMPP server."
if not self.state.transition('disconnected','connecting'): if not self.state.transition('disconnected','connecting'):
logging.warning("Can't connect now; Already in state %s", self.state.current_state()) logging.warning("Can't connect now; Already in state %s", self.state.current_state())
return False return False
if not self.connectTCP(host, port, use_ssl, use_tls): try:
# return to the 'disconnected' state if connect failed: return self.connectTCP(host, port, use_ssl, use_tls)
# otherwise the connect method is not reentrant finally:
# attempt to ensure once a connection attempt starts, we leave either in the
# 'connected' or 'disconnected' state. Otherwise the connect method is not reentrant
if self.state['connecting']:
if not self.state.transition('connecting','disconnected'): if not self.state.transition('connecting','disconnected'):
logging.error("Couldn't transition to the 'disconnected' state!") logging.error("Couldn't return to the 'disconnected' state after connection failure!")
return False
return True
# TODO currently a caller can't distinguish between "connection failed" and # TODO currently a caller can't distinguish between "connection failed" and
# "we're already trying to connect from another thread" # "we're already trying to connect from another thread"
@ -281,11 +283,12 @@ class XMLStream(object):
data = None data = None
try: try:
data = self.sendqueue.get(True,10) data = self.sendqueue.get(True,5)
logging.debug("SEND: %s" % data) logging.debug("SEND: %s" % data)
self.socket.sendall(data.encode('utf-8')) self.socket.sendall(data.encode('utf-8'))
except queue.Empty: except queue.Empty:
logging.debug('nothing on send queue') # logging.debug('Nothing on send queue')
pass
except socket.timeout: except socket.timeout:
# this is to prevent a thread blocked indefinitely # this is to prevent a thread blocked indefinitely
logging.debug('timeout sending packet data') logging.debug('timeout sending packet data')
@ -372,6 +375,7 @@ class XMLStream(object):
try: try:
event = self.eventqueue.get(True, timeout=5) event = self.eventqueue.get(True, timeout=5)
except queue.Empty: except queue.Empty:
# logging.debug('Nothing on event queue')
event = None event = None
if event is not None: if event is not None:
etype = event[0] etype = event[0]

View file

@ -1,5 +1,5 @@
import unittest import unittest
import time, threading import time, threading, random, functools
if __name__ == '__main__': if __name__ == '__main__':
import sys, os import sys, os
@ -83,12 +83,12 @@ class testStateMachine(unittest.TestCase):
thread_state = {'ready': False, 'transitioned': False} thread_state = {'ready': False, 'transitioned': False}
def t1(): def t1():
# this will block until the main thread transitions to 'two'
if s['two']: if s['two']:
print 'thread has already transitioned!' print 'thread has already transitioned!'
self.fail() self.fail()
thread_state['ready'] = True thread_state['ready'] = True
print 'Thread is ready' print 'Thread is ready'
# this will block until the main thread transitions to 'two'
self.assertTrue( s.transition('two','three', wait=20) ) self.assertTrue( s.transition('two','three', wait=20) )
print 'transitioned to three!' print 'transitioned to three!'
thread_state['transitioned'] = True thread_state['transitioned'] = True
@ -109,6 +109,93 @@ class testStateMachine(unittest.TestCase):
self.assertTrue( thread_state['transitioned'] ) self.assertTrue( thread_state['transitioned'] )
def testForRaceCondition(self):
"""Attempt to allow two threads to perform the same transition;
only one should ever make it."""
s = sm.StateMachine(('one','two','three'))
def t1(num):
while True:
if not trigger['go'] or thread_state[num] in (True,False):
time.sleep( random.random()/100 ) # < .01s
if thread_state[num] == 'quit': break
continue
thread_state[num] = s.transition('one','two' )
# print '-',
thread_count = 20
threads = []
thread_state = {}
def reset():
for c in range(thread_count): thread_state[c] = "reset"
trigger = {'go':False} # use of a plain boolean seems to be non-volatile between threads.
for c in range(thread_count):
thread_state[c] = "reset"
thread = threading.Thread( target= functools.partial(t1,c) )
threads.append( thread )
thread.daemon = True
thread.start()
for x in range(100): # this will take 10s to execute
# print "+",
trigger['go'] = True
time.sleep(.1)
trigger['go'] = False
winners = 0
for (num, state) in thread_state.items():
if state == True: winners = winners +1
elif state != False: raise Exception( "!%d!%s!" % (num,state) )
self.assertEqual( 1, winners, "Expected one winner! %d" % winners )
self.assertTrue( s.ensure('two') )
self.assertTrue( s.transition('two','one') ) # return to the first state.
reset()
# now let the threads quit gracefully:
for c in range(thread_count): thread_state[c] = 'quit'
time.sleep(2)
def testTransitionFunctions(self):
"test that a `func` argument allows or blocks the transition correctly."
s = sm.StateMachine(('one','two','three'))
def alwaysFalse(): return False
def alwaysTrue(): return True
self.failIf( s.transition('one','two', func=alwaysFalse) )
self.assertTrue(s['one'])
self.failIf(s['two'])
self.assertTrue( s.transition('one','two', func=alwaysTrue) )
self.failIf(s['one'])
self.assertTrue(s['two'])
def testTransitionFuncException(self):
"if a transition function throws an exeption, ensure we're in a sane state"
s = sm.StateMachine(('one','two','three'))
def alwaysException(): raise Exception('whups!')
try:
self.failIf( s.transition('one','two', func=alwaysException) )
self.fail("exception should have been thrown")
except: pass #expected exception
self.assertTrue(s['one'])
self.failIf(s['two'])
# ensure a subsequent attempt completes normally:
self.assertTrue( s.transition('one','two') )
self.failIf(s['one'])
self.assertTrue(s['two'])
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine) suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)