fix for statemachine where operations would unintentionally block if the lock was acquired in a long-running transition

This commit is contained in:
Tom Nichols 2010-07-01 15:10:22 -04:00
parent 8bdfa77024
commit 0a23f84ec3
2 changed files with 122 additions and 35 deletions

View file

@ -5,7 +5,6 @@
See the file license.txt for copying permission.
"""
from __future__ import with_statement
import threading
import time
import logging
@ -14,18 +13,21 @@ import logging
class StateMachine(object):
def __init__(self, states=[]):
self.lock = threading.Condition(threading.RLock())
self.lock = threading.Lock()
self.notifier = threading.Event()
self.__states= []
self.addStates(states)
self.__default_state = self.__states[0]
self.__current_state = self.__default_state
def addStates(self, states):
with self.lock:
self.lock.acquire()
try:
for state in states:
if state in self.__states:
raise IndexError("The state '%s' is already in the StateMachine." % state)
self.__states.append( state )
finally: self.lock.release()
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
@ -78,30 +80,33 @@ class StateMachine(object):
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
with self.lock:
start = time.time()
while not self.__current_state in from_states:
# detect timeout:
if time.time() >= start + wait: return False
self.lock.wait(wait)
start = time.time()
while not self.__current_state in from_states or not self.lock.acquire(False):
# detect timeout:
if time.time() >= start + wait: return False
self.notifier.wait(wait)
try: # lock is acquired; all other threads will return false or wait until notify/timeout
self.notifier.clear()
if self.__current_state in from_states: # should always be True due to lock
return_val = True
# Note that func might throw an exception, but that's OK, it aborts the transition
if func is not None: return_val = func(*args,**kwargs)
return_val = func(*args,**kwargs) if func is not None else True
# some 'false' value returned from func,
# indicating that transition should not occur:
if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self.__current_state = to_state
self.lock.notify_all()
self._set_state( to_state )
return return_val # some 'true' value returned by func or True if func was None
else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False
finally:
self.notifier.set()
self.lock.release()
def transition_ctx(self, from_state, to_state, wait=0.0):
@ -148,7 +153,15 @@ class StateMachine(object):
def ensure_any(self, states, wait=0.0):
'''
Ensure we are currently in one of the given `states`
Ensure we are currently in one of the given `states` or wait until
we enter one of those states.
Note that due to the nature of the function, you cannot guarantee that
the entirety of some operation completes while you remain in a given
state. That would require acquiring and holding a lock, which
would mean no other threads could do the same. (You'd essentially
be serializing all of the threads that are 'ensuring' their tasks
occurred in some state.
'''
if not (isinstance(states,tuple) or isinstance(states,list)):
raise ValueError('states arg should be a tuple or list')
@ -157,13 +170,17 @@ class StateMachine(object):
if not state in self.__states:
raise ValueError( "StateMachine does not contain state '%s'" % state )
with self.lock:
start = time.time()
while not self.__current_state in states:
# detect timeout:
if time.time() >= start + wait: return False
self.lock.wait(wait)
return self.__current_state in states # should always be True due to lock
# Locking never really gained us anything here, since the lock was released
# before the function returned anyways. The only thing it _did_ do was
# increase the probability that this function would block for longer than
# intended if a `transition` function or context was running while holding
# the lock.
start = time.time()
while not self.__current_state in states:
# detect timeout:
if time.time() >= start + wait: return False
self.notifier.wait(wait)
return True
def reset(self):
@ -202,19 +219,19 @@ class _StateCtx:
self.from_state = from_state
self.to_state = to_state
self.wait = wait
self._timeout = False
self._locked = False
def __enter__(self):
self.state_machine.lock.acquire()
start = time.time()
while not self.state_machine[ self.from_state ]:
while not self.state_machine[ self.from_state ] or not self.state_machine.lock.acquire(False):
# detect timeout:
if time.time() >= start + self.wait:
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
self._timeout = True # to indicate we should not transition
return False
self.state_machine.lock.wait(self.wait)
self.state_machine.notifier.wait(self.wait)
self._locked = True # lock has been acquired at this point
self.state_machine.notifier.clear()
logging.debug('StateMachine entered context in state: %s',
self.state_machine.current_state() )
return True
@ -222,13 +239,16 @@ class _StateCtx:
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is not None:
logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
self.state_machine.current_state(), exc_type.__name__, exc_val )
elif not self._timeout:
logging.debug(' ==== TRANSITION %s -> %s',
self.state_machine.current_state(), self.to_state)
self.state_machine._set_state( self.to_state )
self.state_machine.current_state(), exc_type.__name__, exc_val )
if self._locked:
if exc_val is None:
logging.debug(' ==== TRANSITION %s -> %s',
self.state_machine.current_state(), self.to_state)
self.state_machine._set_state( self.to_state )
self.state_machine.notifier.set()
self.state_machine.lock.release()
self.state_machine.lock.notify_all()
self.state_machine.lock.release()
return False # re-raise any exception

View file

@ -256,6 +256,73 @@ class testStateMachine(unittest.TestCase):
self.assertTrue( s['three'] )
def testTransitionsDontUnintentionallyBlock(self):
'''
There was a bug where a long-running transition (e.g. one with a 'func'
arg or a `transition_ctx` call would cause any `transition` or `ensure`
call to block since the lock is acquired before checking the current
state. Attempts to acquire the mutex need to be non-blocking so when a
timeout is _not_ given, the caller can return immediately. At the same
time, threads that _do_ want to wait need the ability to be notified
(to avoid waiting beyond when the lock is released) so we've moved to a
combination of a plain-ol `threading.Lock` to act as mutex, and a
`threading.Event` to perform notification for threads who choose to wait.
'''
s = sm.StateMachine(('one','two','three'))
with s.transition_ctx('two','three') as result:
self.failIf( result )
self.assertTrue( s['one'] )
self.failIf( s.current_state in ('two','three') )
self.assertTrue( s['one'] )
statuses = {'t1':"not started",
't2':'not started'}
def t1():
print 'thread 1 started'
# no wait, so this should 'return False' immediately.
self.failIf( s.transition('two','three') )
statuses['t1'] = 'complete'
print 'thread 1 transitioned'
def t2():
print 'thread 2 started'
self.failIf( s['two'] )
self.failIf( s['three'] )
# we want this thread to acquire the lock, but for
# the second thread not to wait on the first.
with s.transition_ctx('one','two', 10) as locked:
statuses['t2'] = 'started'
print 'thread 2 has entered context'
self.assertTrue( locked )
# give thread1 a chance to complete while this
# thread still owns the lock
time.sleep(5)
self.assertTrue( s['two'] )
statuses['t2'] = 'complete'
t1 = threading.Thread(target=t1)
t2 = threading.Thread(target=t2)
t2.start() # this should acquire the lock
time.sleep(.2)
self.assertEqual( 'started', statuses['t2'] )
t1.start() # but it shouldn't prevent thread 1 from completing
time.sleep(1)
self.assertEqual( 'complete', statuses['t1'] )
t1.join()
t2.join()
self.assertEqual( 'complete', statuses['t2'] )
self.assertTrue( s['two'] )
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
if __name__ == '__main__': unittest.main()