Merge remote branch 'tom/hacks'

This commit is contained in:
Brian Beggs 2010-06-04 12:52:52 -04:00
commit 1aa34cb0fc
3 changed files with 128 additions and 17 deletions

View file

@ -10,6 +10,7 @@ import threading
import time
import logging
class StateMachine(object):
def __init__(self, states=[]):
@ -27,7 +28,7 @@ class StateMachine(object):
self.__states.append( state )
def transition(self, from_state, to_state, wait=0.0):
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
'''
Transition from the given `from_state` to the given `to_state`.
This method will return `True` if the state machine is now in `to_state`. It
@ -47,12 +48,22 @@ class StateMachine(object):
if thread_should_exit: return
# perform actions here after successful transition
This allows the thread to be interrupted by setting `thread_should_exit=True`
This allows the thread to be responsive by setting `thread_should_exit=True`.
The optional `func` argument allows the user to pass a callable operation which occurs
within the context of the state transition (e.g. while the state machine is locked.)
If `func` returns a True value, the transition will occur. If `func` returns a non-
True value or if an exception is thrown, the transition will not occur. Any thrown
exception is not caught by the state machine and is the caller's responsibility to handle.
If `func` completes normally, this method will return the value returned by `func.` If
values for `args` and `kwargs` are provided, they are expanded and passed like so:
`func( *args, **kwargs )`.
'''
return self.transition_any( (from_state,), to_state, wait=wait )
return self.transition_any( (from_state,), to_state, wait=wait,
func=func, args=args, kwargs=kwargs )
def transition_any(self, from_states, to_state, wait=0.0):
def transition_any(self, from_states, to_state, wait=0.0, func=None, args=[], kwargs={} ):
'''
Transition from any of the given `from_states` to the given `to_state`.
'''
@ -73,10 +84,19 @@ class StateMachine(object):
self.lock.wait(wait)
if self.__current_state in from_states: # should always be True due to lock
return_val = True
# Note that func might throw an exception, but that's OK, it aborts the transition
if func is not None: return_val = func(*args,**kwargs)
# some 'false' value returned from func,
# indicating that transition should not occur:
if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self.__current_state = to_state
self.lock.notifyAll()
return True
return return_val # some 'true' value returned by func or True if func was None
else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False

View file

@ -100,18 +100,20 @@ class XMLStream(object):
self.filesocket = filesocket
def connect(self, host='', port=0, use_ssl=None, use_tls=None):
"Link to connectTCP"
"Establish a socket connection to the given XMPP server."
if not self.state.transition('disconnected','connecting'):
logging.warning("Can't connect now; Already in state %s", self.state.current_state())
return False
if not self.connectTCP(host, port, use_ssl, use_tls):
# return to the 'disconnected' state if connect failed:
# otherwise the connect method is not reentrant
if not self.state.transition('connecting','disconnected'):
logging.error("Couldn't transition to the 'disconnected' state!")
return False
return True
try:
return self.connectTCP(host, port, use_ssl, use_tls)
finally:
# attempt to ensure once a connection attempt starts, we leave either in the
# 'connected' or 'disconnected' state. Otherwise the connect method is not reentrant
if self.state['connecting']:
if not self.state.transition('connecting','disconnected'):
logging.error("Couldn't return to the 'disconnected' state after connection failure!")
# TODO currently a caller can't distinguish between "connection failed" and
# "we're already trying to connect from another thread"
@ -281,11 +283,12 @@ class XMLStream(object):
data = None
try:
data = self.sendqueue.get(True,10)
data = self.sendqueue.get(True,5)
logging.debug("SEND: %s" % data)
self.socket.sendall(data.encode('utf-8'))
except queue.Empty:
logging.debug('nothing on send queue')
# logging.debug('Nothing on send queue')
pass
except socket.timeout:
# this is to prevent a thread blocked indefinitely
logging.debug('timeout sending packet data')
@ -372,6 +375,7 @@ class XMLStream(object):
try:
event = self.eventqueue.get(True, timeout=5)
except queue.Empty:
# logging.debug('Nothing on event queue')
event = None
if event is not None:
etype = event[0]

View file

@ -1,5 +1,5 @@
import unittest
import time, threading
import time, threading, random, functools
if __name__ == '__main__':
import sys, os
@ -83,12 +83,12 @@ class testStateMachine(unittest.TestCase):
thread_state = {'ready': False, 'transitioned': False}
def t1():
# this will block until the main thread transitions to 'two'
if s['two']:
print 'thread has already transitioned!'
self.fail()
thread_state['ready'] = True
print 'Thread is ready'
# this will block until the main thread transitions to 'two'
self.assertTrue( s.transition('two','three', wait=20) )
print 'transitioned to three!'
thread_state['transitioned'] = True
@ -109,6 +109,93 @@ class testStateMachine(unittest.TestCase):
self.assertTrue( thread_state['transitioned'] )
def testForRaceCondition(self):
"""Attempt to allow two threads to perform the same transition;
only one should ever make it."""
s = sm.StateMachine(('one','two','three'))
def t1(num):
while True:
if not trigger['go'] or thread_state[num] in (True,False):
time.sleep( random.random()/100 ) # < .01s
if thread_state[num] == 'quit': break
continue
thread_state[num] = s.transition('one','two' )
# print '-',
thread_count = 20
threads = []
thread_state = {}
def reset():
for c in range(thread_count): thread_state[c] = "reset"
trigger = {'go':False} # use of a plain boolean seems to be non-volatile between threads.
for c in range(thread_count):
thread_state[c] = "reset"
thread = threading.Thread( target= functools.partial(t1,c) )
threads.append( thread )
thread.daemon = True
thread.start()
for x in range(100): # this will take 10s to execute
# print "+",
trigger['go'] = True
time.sleep(.1)
trigger['go'] = False
winners = 0
for (num, state) in thread_state.items():
if state == True: winners = winners +1
elif state != False: raise Exception( "!%d!%s!" % (num,state) )
self.assertEqual( 1, winners, "Expected one winner! %d" % winners )
self.assertTrue( s.ensure('two') )
self.assertTrue( s.transition('two','one') ) # return to the first state.
reset()
# now let the threads quit gracefully:
for c in range(thread_count): thread_state[c] = 'quit'
time.sleep(2)
def testTransitionFunctions(self):
"test that a `func` argument allows or blocks the transition correctly."
s = sm.StateMachine(('one','two','three'))
def alwaysFalse(): return False
def alwaysTrue(): return True
self.failIf( s.transition('one','two', func=alwaysFalse) )
self.assertTrue(s['one'])
self.failIf(s['two'])
self.assertTrue( s.transition('one','two', func=alwaysTrue) )
self.failIf(s['one'])
self.assertTrue(s['two'])
def testTransitionFuncException(self):
"if a transition function throws an exeption, ensure we're in a sane state"
s = sm.StateMachine(('one','two','three'))
def alwaysException(): raise Exception('whups!')
try:
self.failIf( s.transition('one','two', func=alwaysException) )
self.fail("exception should have been thrown")
except: pass #expected exception
self.assertTrue(s['one'])
self.failIf(s['two'])
# ensure a subsequent attempt completes normally:
self.assertTrue( s.transition('one','two') )
self.failIf(s['one'])
self.assertTrue(s['two'])
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)