diff --git a/README b/README index a698909..abc2d09 100644 --- a/README +++ b/README @@ -1,5 +1,8 @@ -SleekXMPP is an XMPP library written for Python 3.x (with 2.6 compatibility). +SleekXMPP is an XMPP library written for Python 3.1+ (with 2.6 compatibility). +Hosted at http://wiki.github.com/fritzy/SleekXMPP/ + Featured in examples in XMPP: The Definitive Guide by Kevin Smith, Remko Tronçon, and Peter Saint-Andre +If you're coming here from The Definitive Guide, please read http://wiki.github.com/fritzy/SleekXMPP/xmpp-the-definitive-guide SleekXMPP has several design goals/philosophies: - Low number of dependencies. @@ -31,7 +34,9 @@ Since 0.2, here's the Changelog: Credits ---------------- Main Author: Nathan Fritz fritz@netflint.net -XEP-0045 original implementation: Kevin Smith +Contributors: Kevin Smith & Lance Stout Patches: Remko Tronçon Feel free to add fritzy@netflint.net to your roster for direct support and comments. +Join sleekxmpp-discussion@googlegroups.com / http://groups.google.com/group/sleekxmpp-discussion for email discussion. +Join sleek@conference.jabber.org for groupchat discussion. diff --git a/conn_tests/test_pubsubjobs.py b/conn_tests/test_pubsubjobs.py new file mode 100644 index 0000000..edf22cc --- /dev/null +++ b/conn_tests/test_pubsubjobs.py @@ -0,0 +1,171 @@ +import logging +import sleekxmpp +from optparse import OptionParser +from xml.etree import cElementTree as ET +import os +import time +import sys +import unittest +import sleekxmpp.plugins.xep_0004 +from sleekxmpp.xmlstream.matcher.stanzapath import StanzaPath +from sleekxmpp.xmlstream.handler.waiter import Waiter +try: + import configparser +except ImportError: + import ConfigParser as configparser +try: + import queue +except ImportError: + import Queue as queue + +class TestClient(sleekxmpp.ClientXMPP): + def __init__(self, jid, password): + sleekxmpp.ClientXMPP.__init__(self, jid, password) + self.add_event_handler("session_start", self.start) + #self.add_event_handler("message", self.message) + self.waitforstart = queue.Queue() + + def start(self, event): + self.getRoster() + self.sendPresence() + self.waitforstart.put(True) + + +class TestPubsubServer(unittest.TestCase): + statev = {} + + def __init__(self, *args, **kwargs): + unittest.TestCase.__init__(self, *args, **kwargs) + + def setUp(self): + pass + + def test001getdefaultconfig(self): + """Get the default node config""" + self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode2') + self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode3') + self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode4') + self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode5') + result = self.xmpp1['xep_0060'].getNodeConfig(self.pshost) + self.statev['defaultconfig'] = result + self.failUnless(isinstance(result, sleekxmpp.plugins.xep_0004.Form)) + + def test002createdefaultnode(self): + """Create a node without config""" + self.failUnless(self.xmpp1['xep_0060'].create_node(self.pshost, 'testnode1')) + + def test003deletenode(self): + """Delete recently created node""" + self.failUnless(self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode1')) + + def test004createnode(self): + """Create a node with a config""" + self.statev['defaultconfig'].field['pubsub#access_model'].setValue('open') + self.statev['defaultconfig'].field['pubsub#notify_retract'].setValue(True) + self.statev['defaultconfig'].field['pubsub#persist_items'].setValue(True) + self.statev['defaultconfig'].field['pubsub#presence_based_delivery'].setValue(True) + p = self.xmpp2.Presence() + p['to'] = self.pshost + p.send() + self.failUnless(self.xmpp1['xep_0060'].create_node(self.pshost, 'testnode2', self.statev['defaultconfig'], ntype='job')) + + def test005reconfigure(self): + """Retrieving node config and reconfiguring""" + nconfig = self.xmpp1['xep_0060'].getNodeConfig(self.pshost, 'testnode2') + self.failUnless(nconfig, "No configuration returned") + #print("\n%s ==\n %s" % (nconfig.getValues(), self.statev['defaultconfig'].getValues())) + self.failUnless(nconfig.getValues() == self.statev['defaultconfig'].getValues(), "Configuration does not match") + self.failUnless(self.xmpp1['xep_0060'].setNodeConfig(self.pshost, 'testnode2', nconfig)) + + def test006subscribetonode(self): + """Subscribe to node from account 2""" + self.failUnless(self.xmpp2['xep_0060'].subscribe(self.pshost, "testnode2")) + + def test007publishitem(self): + """Publishing item""" + item = ET.Element('{http://netflint.net/protocol/test}test') + w = Waiter('wait publish', StanzaPath('message/pubsub_event/items')) + self.xmpp2.registerHandler(w) + #result = self.xmpp1['xep_0060'].setItem(self.pshost, "testnode2", (('test1', item),)) + result = self.xmpp1['jobs'].createJob(self.pshost, "testnode2", 'test1', item) + msg = w.wait(5) # got to get a result in 5 seconds + self.failUnless(msg != False, "Account #2 did not get message event") + #result = self.xmpp1['xep_0060'].setItem(self.pshost, "testnode2", (('test2', item),)) + result = self.xmpp1['jobs'].createJob(self.pshost, "testnode2", 'test2', item) + w = Waiter('wait publish2', StanzaPath('message/pubsub_event/items')) + self.xmpp2.registerHandler(w) + self.xmpp2['jobs'].claimJob(self.pshost, 'testnode2', 'test1') + msg = w.wait(5) # got to get a result in 5 seconds + self.xmpp2['jobs'].claimJob(self.pshost, 'testnode2', 'test2') + self.xmpp2['jobs'].finishJob(self.pshost, 'testnode2', 'test1') + self.xmpp2['jobs'].finishJob(self.pshost, 'testnode2', 'test2') + print result + #need to add check for update + + def test900cleanup(self): + "Cleaning up" + #self.failUnless(self.xmpp1['xep_0060'].deleteNode(self.pshost, 'testnode2'), "Could not delete test node.") + time.sleep(10) + + +if __name__ == '__main__': + #parse command line arguements + optp = OptionParser() + optp.add_option('-q','--quiet', help='set logging to ERROR', action='store_const', dest='loglevel', const=logging.ERROR, default=logging.INFO) + optp.add_option('-d','--debug', help='set logging to DEBUG', action='store_const', dest='loglevel', const=logging.DEBUG, default=logging.INFO) + optp.add_option('-v','--verbose', help='set logging to COMM', action='store_const', dest='loglevel', const=5, default=logging.INFO) + optp.add_option("-c","--config", dest="configfile", default="config.xml", help="set config file to use") + optp.add_option("-n","--nodenum", dest="nodenum", default="1", help="set node number to use") + optp.add_option("-p","--pubsub", dest="pubsub", default="1", help="set pubsub host to use") + opts,args = optp.parse_args() + + logging.basicConfig(level=opts.loglevel, format='%(levelname)-8s %(message)s') + + #load xml config + logging.info("Loading config file: %s" % opts.configfile) + config = configparser.RawConfigParser() + config.read(opts.configfile) + + #init + logging.info("Account 1 is %s" % config.get('account1', 'jid')) + xmpp1 = TestClient(config.get('account1','jid'), config.get('account1','pass')) + logging.info("Account 2 is %s" % config.get('account2', 'jid')) + xmpp2 = TestClient(config.get('account2','jid'), config.get('account2','pass')) + + xmpp1.registerPlugin('xep_0004') + xmpp1.registerPlugin('xep_0030') + xmpp1.registerPlugin('xep_0060') + xmpp1.registerPlugin('xep_0199') + xmpp1.registerPlugin('jobs') + xmpp2.registerPlugin('xep_0004') + xmpp2.registerPlugin('xep_0030') + xmpp2.registerPlugin('xep_0060') + xmpp2.registerPlugin('xep_0199') + xmpp2.registerPlugin('jobs') + + if not config.get('account1', 'server'): + # we don't know the server, but the lib can probably figure it out + xmpp1.connect() + else: + xmpp1.connect((config.get('account1', 'server'), 5222)) + xmpp1.process(threaded=True) + + #init + if not config.get('account2', 'server'): + # we don't know the server, but the lib can probably figure it out + xmpp2.connect() + else: + xmpp2.connect((config.get('account2', 'server'), 5222)) + xmpp2.process(threaded=True) + + TestPubsubServer.xmpp1 = xmpp1 + TestPubsubServer.xmpp2 = xmpp2 + TestPubsubServer.pshost = config.get('settings', 'pubsub') + xmpp1.waitforstart.get(True) + xmpp2.waitforstart.get(True) + testsuite = unittest.TestLoader().loadTestsFromTestCase(TestPubsubServer) + + alltests_suite = unittest.TestSuite([testsuite]) + result = unittest.TextTestRunner(verbosity=2).run(alltests_suite) + xmpp1.disconnect() + xmpp2.disconnect() diff --git a/conn_tests/test_pubsubserver.py b/conn_tests/test_pubsubserver.py index d1e2208..15635b4 100644 --- a/conn_tests/test_pubsubserver.py +++ b/conn_tests/test_pubsubserver.py @@ -5,7 +5,6 @@ from xml.etree import cElementTree as ET import os import time import sys -import thread import unittest import sleekxmpp.plugins.xep_0004 from sleekxmpp.xmlstream.matcher.stanzapath import StanzaPath diff --git a/example.py b/example.py index 1ffe724..c9b6559 100644 --- a/example.py +++ b/example.py @@ -37,8 +37,8 @@ if __name__ == '__main__': logging.basicConfig(level=opts.loglevel, format='%(levelname)-8s %(message)s') xmpp = Example('user@gmail.com/sleekxmpp', 'password') + xmpp.registerPlugin('xep_0030') xmpp.registerPlugin('xep_0004') - xmpp.registerPlugin('xep_0030') xmpp.registerPlugin('xep_0060') xmpp.registerPlugin('xep_0199') if xmpp.connect(('talk.google.com', 5222)): diff --git a/sleekxmpp/__init__.py b/sleekxmpp/__init__.py index 954ca99..6995f7c 100644 --- a/sleekxmpp/__init__.py +++ b/sleekxmpp/__init__.py @@ -94,6 +94,8 @@ class ClientXMPP(basexmpp, XMLStream): """Connect to the Jabber Server. Attempts SRV lookup, and if it fails, uses the JID server.""" + if self.state['connected']: return True + if host: self.server = host if port is None: port = self.port @@ -174,6 +176,7 @@ class ClientXMPP(basexmpp, XMLStream): self._handleRoster(iq, request=True) def _handleStreamFeatures(self, features): + logging.debug('handling stream features') self.features = [] for sub in features.xml: self.features.append(sub.tag) @@ -181,12 +184,16 @@ class ClientXMPP(basexmpp, XMLStream): for feature in self.registered_features: if feature[0].match(subelement): #if self.maskcmp(subelement, feature[0], True): + # This calls the feature handler & optionally breaks if feature[1](subelement) and feature[2]: #if breaker, don't continue return True def handler_starttls(self, xml): + logging.debug( 'TLS start handler; SSL support: %s', self.ssl_support ) if not self.authenticated and self.ssl_support: - self.add_handler("", self.handler_tls_start, instream=True) + _stanza = "" + if not self.event_handlers.get(_stanza,None): # don't add handler > once + self.add_handler( _stanza, self.handler_tls_start, instream=True ) self.sendXML(xml) return True else: @@ -221,12 +228,13 @@ class ClientXMPP(basexmpp, XMLStream): return True def handler_auth_success(self, xml): + logging.debug("Authentication successful.") self.authenticated = True self.features = [] raise RestartStream() def handler_auth_fail(self, xml): - logging.info("Authentication failed.") + logging.warning("Authentication failed.") self.disconnect() self.event("failed_auth") diff --git a/sleekxmpp/basexmpp.py b/sleekxmpp/basexmpp.py index b011f7b..a916fe8 100644 --- a/sleekxmpp/basexmpp.py +++ b/sleekxmpp/basexmpp.py @@ -85,6 +85,11 @@ class basexmpp(object): self.jid = self.getjidbare(jid) self.username = jid.split('@', 1)[0] self.domain = jid.split('@',1)[-1].split('/', 1)[0] + + def process(self, *args, **kwargs): + for idx in self.plugin: + if not self.plugin[idx].post_inited: self.plugin[idx].post_init() + return super(basexmpp, self).process(*args, **kwargs) def registerPlugin(self, plugin, pconfig = {}): """Register a plugin not in plugins.__init__.__all__ but in the plugins @@ -109,7 +114,7 @@ class basexmpp(object): plugin_list = plugins.__all__ for plugin in plugin_list: if plugin in plugins.__all__: - self.registerPlugin(plugin, self.plugin_config.get(plugin, {})) + self.registerPlugin(plugin, self.plugin_config.get(plugin, {}), False) else: raise NameError("No plugin by the name of %s listed in plugins.__all__." % plugin) # run post_init() for cross-plugin interaction @@ -185,6 +190,19 @@ class basexmpp(object): self.event_handlers[name] = [] self.event_handlers[name].append((pointer, threaded, disposable)) + def del_event_handler(self, name, pointer): + """Remove a handler for an event.""" + if not name in self.event_handlers: + return + + # Need to keep handlers that do not use + # the given function pointer + def filter_pointers(handler): + return handler[0] != pointer + + self.event_handlers[name] = filter(filter_pointers, + self.event_handlers[name]) + def event(self, name, eventdata = {}): # called on an event for handler in self.event_handlers.get(name, []): if handler[1]: #if threaded diff --git a/sleekxmpp/componentxmpp.py b/sleekxmpp/componentxmpp.py index 9c7a612..de12581 100755 --- a/sleekxmpp/componentxmpp.py +++ b/sleekxmpp/componentxmpp.py @@ -1,4 +1,4 @@ -#!/usr/bin/python2.5 +#!/usr/bin/python2.6 """ SleekXMPP: The Sleek XMPP Library @@ -54,6 +54,16 @@ class ComponentXMPP(basexmpp, XMLStream): self.secret = secret self.registerHandler(Callback('Handshake', MatchXPath('{jabber:component:accept}handshake'), self._handleHandshake)) + def __getitem__(self, key): + if key in self.plugin: + return self.plugin[key] + else: + logging.warning("""Plugin "%s" is not loaded.""" % key) + return False + + def get(self, key, default): + return self.plugin.get(key, default) + def incoming_filter(self, xmlobj): if xmlobj.tag.startswith('{jabber:client}'): xmlobj.tag = xmlobj.tag.replace('jabber:client', self.default_ns) diff --git a/sleekxmpp/plugins/base.py b/sleekxmpp/plugins/base.py index 685833f..4223646 100644 --- a/sleekxmpp/plugins/base.py +++ b/sleekxmpp/plugins/base.py @@ -24,6 +24,7 @@ class base_plugin(object): self.description = 'Base Plugin' self.xmpp = xmpp self.config = config + self.post_inited = False self.enable = config.get('enable', True) if self.enable: self.plugin_init() @@ -32,4 +33,4 @@ class base_plugin(object): pass def post_init(self): - pass + self.post_inited = True diff --git a/sleekxmpp/plugins/jobs.py b/sleekxmpp/plugins/jobs.py new file mode 100644 index 0000000..bb2e255 --- /dev/null +++ b/sleekxmpp/plugins/jobs.py @@ -0,0 +1,44 @@ +from . import base +import logging +from xml.etree import cElementTree as ET + +class jobs(base.base_plugin): + def plugin_init(self): + self.xep = 'pubsubjob' + self.description = "Job distribution over Pubsub" + + def post_init(self): + pass + #TODO add event + + def createJobNode(self, host, jid, node, config=None): + pass + + def createJob(self, host, node, jobid=None, payload=None): + return self.xmpp.plugin['xep_0060'].setItem(host, node, ((jobid, payload),)) + + def claimJob(self, host, node, jobid, ifrom=None): + return self._setState(host, node, jobid, ET.Element('{http://andyet.net/protocol/pubsubjob}claimed')) + + def unclaimJob(self, jobid): + return self._setState(host, node, jobid, ET.Element('{http://andyet.net/protocol/pubsubjob}unclaimed')) + + def finishJob(self, host, node, jobid, payload=None): + finished = ET.Element('{http://andyet.net/protocol/pubsubjob}finished') + if payload is not None: + finished.append(payload) + return self._setState(host, node, jobid, finished) + + def _setState(self, host, node, jobid, state, ifrom=None): + iq = self.xmpp.Iq() + iq['to'] = host + if ifrom: iq['from'] = ifrom + iq['type'] = 'set' + iq['psstate']['node'] = node + iq['psstate']['item'] = jobid + iq['psstate']['payload'] = state + result = iq.send() + if result is None or result['type'] != 'result': + return False + return True + diff --git a/sleekxmpp/plugins/stanza_pubsub.py b/sleekxmpp/plugins/stanza_pubsub.py index 4187d49..1a1526f 100644 --- a/sleekxmpp/plugins/stanza_pubsub.py +++ b/sleekxmpp/plugins/stanza_pubsub.py @@ -10,6 +10,39 @@ def stanzaPlugin(stanza, plugin): stanza.plugin_attrib_map[plugin.plugin_attrib] = plugin stanza.plugin_tag_map["{%s}%s" % (plugin.namespace, plugin.name)] = plugin +class PubsubState(ElementBase): + namespace = 'http://jabber.org/protocol/psstate' + name = 'state' + plugin_attrib = 'psstate' + interfaces = set(('node', 'item', 'payload')) + plugin_attrib_map = {} + plugin_tag_map = {} + + def setPayload(self, value): + self.xml.append(value) + + def getPayload(self): + childs = self.xml.getchildren() + if len(childs) > 0: + return childs[0] + + def delPayload(self): + for child in self.xml.getchildren(): + self.xml.remove(child) + +stanzaPlugin(Iq, PubsubState) + +class PubsubStateEvent(ElementBase): + namespace = 'http://jabber.org/protocol/psstate#event' + name = 'event' + plugin_attrib = 'psstate_event' + intefaces = set(tuple()) + plugin_attrib_map = {} + plugin_tag_map = {} + +stanzaPlugin(Message, PubsubStateEvent) +stanzaPlugin(PubsubStateEvent, PubsubState) + class Pubsub(ElementBase): namespace = 'http://jabber.org/protocol/pubsub' name = 'pubsub' @@ -281,7 +314,7 @@ class DefaultConfig(ElementBase): def getType(self): t = self._getAttr('type') - if not t: t == 'leaf' + if not t: t = 'leaf' return t stanzaPlugin(PubsubOwner, DefaultConfig) @@ -321,18 +354,6 @@ class Options(ElementBase): stanzaPlugin(Pubsub, Options) stanzaPlugin(Subscribe, Options) -#iq = Iq() -#iq['pubsub']['defaultconfig'] -#print(iq) - -#from xml.etree import cElementTree as ET -#iq = Iq() -#item = Item() -#item['payload'] = ET.Element("{http://netflint.net/p/crap}stupidshit") -#item['id'] = 'aa11bbcc' -#iq['pubsub']['items'].append(item) -#print(iq) - class OwnerAffiliations(Affiliations): namespace = 'http://jabber.org/protocol/pubsub#owner' interfaces = set(('node')) diff --git a/sleekxmpp/plugins/xep_0004.py b/sleekxmpp/plugins/xep_0004.py index ec85925..015bd8b 100644 --- a/sleekxmpp/plugins/xep_0004.py +++ b/sleekxmpp/plugins/xep_0004.py @@ -31,7 +31,8 @@ class xep_0004(base.base_plugin): self.xmpp.add_handler("", self.handler_message_xform) def post_init(self): - self.xmpp['xep_0030'].add_feature('jabber:x:data') + base.base_plugin.post_init(self) + self.xmpp.plugin['xep_0030'].add_feature('jabber:x:data') def handler_message_xform(self, xml): object = self.handle_form(xml) @@ -187,7 +188,6 @@ class Form(FieldContainer): #def getXML(self, tostring = False): def getXML(self, ftype=None): - logging.debug("creating form as %s" % ftype) if ftype: self.type = ftype form = ET.Element('{jabber:x:data}x') diff --git a/sleekxmpp/plugins/xep_0009.py b/sleekxmpp/plugins/xep_0009.py index e0da829..49ffac4 100644 --- a/sleekxmpp/plugins/xep_0009.py +++ b/sleekxmpp/plugins/xep_0009.py @@ -185,8 +185,9 @@ class xep_0009(base.base_plugin): self.activeCalls = [] def post_init(self): - self.xmpp['xep_0030'].add_feature('jabber:iq:rpc') - self.xmpp['xep_0030'].add_identity('automatition','rpc') + base.base_plugin.post_init(self) + self.xmpp.plugin['xep_0030'].add_feature('jabber:iq:rpc') + self.xmpp.plugin['xep_0030'].add_identity('automatition','rpc') def register_call(self, method, name=None): #@returns an string that can be used in acl commands. diff --git a/sleekxmpp/plugins/xep_0030.py b/sleekxmpp/plugins/xep_0030.py index 5432dd5..6a31d24 100644 --- a/sleekxmpp/plugins/xep_0030.py +++ b/sleekxmpp/plugins/xep_0030.py @@ -1,25 +1,184 @@ """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2007 Nathanael C. Fritz - This file is part of SleekXMPP. + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz, Lance J.T. Stout + This file is part of SleekXMPP. - SleekXMPP is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - SleekXMPP is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with SleekXMPP; if not, write to the Free Software - Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA + See the file license.txt for copying permissio """ -from . import base + import logging -from xml.etree import cElementTree as ET +from . import base +from .. xmlstream.handler.callback import Callback +from .. xmlstream.matcher.xpath import MatchXPath +from .. xmlstream.stanzabase import ElementBase, ET, JID +from .. stanza.iq import Iq + +class DiscoInfo(ElementBase): + namespace = 'http://jabber.org/protocol/disco#info' + name = 'query' + plugin_attrib = 'disco_info' + interfaces = set(('node', 'features', 'identities')) + + def getFeatures(self): + features = [] + featuresXML = self.xml.findall('{%s}feature' % self.namespace) + for feature in featuresXML: + features.append(feature.attrib['var']) + return features + + def setFeatures(self, features): + self.delFeatures() + for name in features: + self.addFeature(name) + + def delFeatures(self): + featuresXML = self.xml.findall('{%s}feature' % self.namespace) + for feature in featuresXML: + self.xml.remove(feature) + + def addFeature(self, feature): + featureXML = ET.Element('{%s}feature' % self.namespace, + {'var': feature}) + self.xml.append(featureXML) + + def delFeature(self, feature): + featuresXML = self.xml.findall('{%s}feature' % self.namespace) + for featureXML in featuresXML: + if featureXML.attrib['var'] == feature: + self.xml.remove(featureXML) + + def getIdentities(self): + ids = [] + idsXML = self.xml.findall('{%s}identity' % self.namespace) + for idXML in idsXML: + idData = (idXML.attrib['category'], + idXML.attrib['type'], + idXML.attrib.get('name', '')) + ids.append(idData) + return ids + + def setIdentities(self, ids): + self.delIdentities() + for idData in ids: + self.addIdentity(*idData) + + def delIdentities(self): + idsXML = self.xml.findall('{%s}identity' % self.namespace) + for idXML in idsXML: + self.xml.remove(idXML) + + def addIdentity(self, category, id_type, name=''): + idXML = ET.Element('{%s}identity' % self.namespace, + {'category': category, + 'type': id_type, + 'name': name}) + self.xml.append(idXML) + + def delIdentity(self, category, id_type, name=''): + idsXML = self.xml.findall('{%s}identity' % self.namespace) + for idXML in idsXML: + idData = (idXML.attrib['category'], + idXML.attrib['type']) + delId = (category, id_type) + if idData == delId: + self.xml.remove(idXML) + + +class DiscoItems(ElementBase): + namespace = 'http://jabber.org/protocol/disco#items' + name = 'query' + plugin_attrib = 'disco_items' + interfaces = set(('node', 'items')) + + def getItems(self): + items = [] + itemsXML = self.xml.findall('{%s}item' % self.namespace) + for item in itemsXML: + itemData = (item.attrib['jid'], + item.attrib.get('node'), + item.attrib.get('name')) + items.append(itemData) + return items + + def setItems(self, items): + self.delItems() + for item in items: + self.addItem(*item) + + def delItems(self): + itemsXML = self.xml.findall('{%s}item' % self.namespace) + for item in itemsXML: + self.xml.remove(item) + + def addItem(self, jid, node='', name=''): + itemXML = ET.Element('{%s}item' % self.namespace, {'jid': jid}) + if name: + itemXML.attrib['name'] = name + if node: + itemXML.attrib['node'] = node + self.xml.append(itemXML) + + def delItem(self, jid, node=''): + itemsXML = self.xml.findall('{%s}item' % self.namespace) + for itemXML in itemsXML: + itemData = (itemXML.attrib['jid'], + itemXML.attrib.get('node', '')) + itemDel = (jid, node) + if itemData == itemDel: + self.xml.remove(itemXML) + + +class DiscoNode(object): + """ + Collection object for grouping info and item information + into nodes. + """ + def __init__(self, name): + self.name = name + self.info = DiscoInfo() + self.items = DiscoItems() + + # This is a bit like poor man's inheritance, but + # to simplify adding information to the node we + # map node functions to either the info or items + # stanza objects. + # + # We don't want to make DiscoNode inherit from + # DiscoInfo and DiscoItems because DiscoNode is + # not an actual stanza, and doing so would create + # confusion and potential bugs. + + self._map(self.items, 'items', ['get', 'set', 'del']) + self._map(self.items, 'item', ['add', 'del']) + self._map(self.info, 'identities', ['get', 'set', 'del']) + self._map(self.info, 'identity', ['add', 'del']) + self._map(self.info, 'features', ['get', 'set', 'del']) + self._map(self.info, 'feature', ['add', 'del']) + + def isEmpty(self): + """ + Test if the node contains any information. Useful for + determining if a node can be deleted. + """ + ids = self.getIdentities() + features = self.getFeatures() + items = self.getItems() + + if not ids and not features and not items: + return True + return False + + def _map(self, obj, interface, access): + """ + Map functions of the form obj.accessInterface + to self.accessInterface for each given access type. + """ + interface = interface.title() + for access_type in access: + method = access_type + interface + if hasattr(obj, method): + setattr(self, method, getattr(obj, method)) + class xep_0030(base.base_plugin): """ @@ -29,85 +188,137 @@ class xep_0030(base.base_plugin): def plugin_init(self): self.xep = '0030' self.description = 'Service Discovery' - self.features = {'main': ['http://jabber.org/protocol/disco#info', 'http://jabber.org/protocol/disco#items']} - self.identities = {'main': [{'category': 'client', 'type': 'pc', 'name': 'SleekXMPP'}]} - self.items = {'main': []} - self.xmpp.add_handler("" % self.xmpp.default_ns, self.info_handler) - self.xmpp.add_handler("" % self.xmpp.default_ns, self.item_handler) + + self.xmpp.registerHandler( + Callback('Disco Items', + MatchXPath('{%s}iq/{%s}query' % (self.xmpp.default_ns, + DiscoItems.namespace)), + self.handle_item_query)) + + self.xmpp.registerHandler( + Callback('Disco Info', + MatchXPath('{%s}iq/{%s}query' % (self.xmpp.default_ns, + DiscoInfo.namespace)), + self.handle_info_query)) + + self.xmpp.stanzaPlugin(Iq, DiscoInfo) + self.xmpp.stanzaPlugin(Iq, DiscoItems) + + self.xmpp.add_event_handler('disco_items_request', self.handle_disco_items) + self.xmpp.add_event_handler('disco_info_request', self.handle_disco_info) + + self.nodes = {'main': DiscoNode('main')} + + def add_node(self, node): + if node not in self.nodes: + self.nodes[node] = DiscoNode(node) + + def del_node(self, node): + if node in self.nodes: + del self.nodes[node] + + def handle_item_query(self, iq): + if iq['type'] == 'get': + logging.debug("Items requested by %s" % iq['from']) + self.xmpp.event('disco_items_request', iq) + elif iq['type'] == 'result': + logging.debug("Items result from %s" % iq['from']) + self.xmpp.event('disco_items', iq) + + def handle_info_query(self, iq): + if iq['type'] == 'get': + logging.debug("Info requested by %s" % iq['from']) + self.xmpp.event('disco_info_request', iq) + elif iq['type'] == 'result': + logging.debug("Info result from %s" % iq['from']) + self.xmpp.event('disco_info', iq) + + def handle_disco_info(self, iq, forwarded=False): + """ + A default handler for disco#info requests. If another + handler is registered, this one will defer and not run. + """ + handlers = self.xmpp.event_handlers['disco_info_request'] + if not forwarded and len(handlers) > 1: + return + + node_name = iq['disco_info']['node'] + if not node_name: + node_name = 'main' + + logging.debug("Using default handler for disco#info on node '%s'." % node_name) + + if node_name in self.nodes: + node = self.nodes[node_name] + iq.reply().setPayload(node.info.xml).send() + else: + logging.debug("Node %s requested, but does not exist." % node_name) + iq.reply().error().setPayload(iq['disco_info'].xml) + iq['error']['code'] = '404' + iq['error']['type'] = 'cancel' + iq['error']['condition'] = 'item-not-found' + iq.send() + + def handle_disco_items(self, iq, forwarded=False): + """ + A default handler for disco#items requests. If another + handler is registered, this one will defer and not run. + + If this handler is called by your own custom handler with + forwarded set to True, then it will run as normal. + """ + handlers = self.xmpp.event_handlers['disco_items_request'] + if not forwarded and len(handlers) > 1: + return + + node_name = iq['disco_items']['node'] + if not node_name: + node_name = 'main' + + logging.debug("Using default handler for disco#items on node '%s'." % node_name) + + if node_name in self.nodes: + node = self.nodes[node_name] + iq.reply().setPayload(node.items.xml).send() + else: + logging.debug("Node %s requested, but does not exist." % node_name) + iq.reply().error().setPayload(iq['disco_items'].xml) + iq['error']['code'] = '404' + iq['error']['type'] = 'cancel' + iq['error']['condition'] = 'item-not-found' + iq.send() + + # Older interface methods for backwards compatibility + + def getInfo(self, jid, node=''): + iq = self.xmpp.Iq() + iq['type'] = 'get' + iq['to'] = jid + iq['from'] = self.xmpp.fulljid + iq['disco_info']['node'] = node + iq.send() + + def getItems(self, jid, node=''): + iq = self.xmpp.Iq() + iq['type'] = 'get' + iq['to'] = jid + iq['from'] = self.xmpp.fulljid + iq['disco_items']['node'] = node + iq.send() def add_feature(self, feature, node='main'): - if not node in self.features: - self.features[node] = [] - self.features[node].append(feature) + self.add_node(node) + self.nodes[node].addFeature(feature) - def add_identity(self, category=None, itype=None, name=None, node='main'): - if not node in self.identities: - self.identities[node] = [] - self.identities[node].append({'category': category, 'type': itype, 'name': name}) + def add_identity(self, category='', itype='', name='', node='main'): + self.add_node(node) + self.nodes[node].addIdentity(category=category, + id_type=itype, + name=name) - def add_item(self, jid=None, name=None, node='main', subnode=''): - if not node in self.items: - self.items[node] = [] - self.items[node].append({'jid': jid, 'name': name, 'node': subnode}) - - def info_handler(self, xml): - logging.debug("Info request from %s" % xml.get('from', '')) - iq = self.xmpp.makeIqResult(xml.get('id', self.xmpp.getNewId())) - iq.attrib['from'] = xml.get('to') - iq.attrib['to'] = xml.get('from', self.xmpp.server) - query = xml.find('{http://jabber.org/protocol/disco#info}query') - node = query.get('node', 'main') - for identity in self.identities.get(node, []): - idxml = ET.Element('identity') - for attrib in identity: - if identity[attrib]: - idxml.attrib[attrib] = identity[attrib] - query.append(idxml) - for feature in self.features.get(node, []): - featxml = ET.Element('feature') - featxml.attrib['var'] = feature - query.append(featxml) - iq.append(query) - #print ET.tostring(iq) - self.xmpp.send(iq) - - def item_handler(self, xml): - logging.debug("Item request from %s" % xml.get('from', '')) - iq = self.xmpp.makeIqResult(xml.get('id', self.xmpp.getNewId())) - iq.attrib['from'] = xml.get('to') - iq.attrib['to'] = xml.get('from', self.xmpp.server) - query = self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#items').find('{http://jabber.org/protocol/disco#items}query') - node = xml.find('{http://jabber.org/protocol/disco#items}query').get('node', 'main') - for item in self.items.get(node, []): - itemxml = ET.Element('item') - itemxml.attrib = item - if itemxml.attrib['jid'] is None: - itemxml.attrib['jid'] = xml.get('to') - query.append(itemxml) - self.xmpp.send(iq) - - def getItems(self, jid, node=None): - iq = self.xmpp.makeIqGet() - iq.attrib['from'] = self.xmpp.fulljid - iq.attrib['to'] = jid - self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#items') - if node: - iq.find('{http://jabber.org/protocol/disco#items}query').attrib['node'] = node - return iq.send() - - def getInfo(self, jid, node=None): - iq = self.xmpp.makeIqGet() - iq.attrib['from'] = self.xmpp.fulljid - iq.attrib['to'] = jid - self.xmpp.makeIqQuery(iq, 'http://jabber.org/protocol/disco#info') - if node: - iq.find('{http://jabber.org/protocol/disco#info}query').attrib['node'] = node - return iq.send() - - def parseInfo(self, xml): - result = {'identity': {}, 'feature': []} - for identity in xml.findall('{http://jabber.org/protocol/disco#info}query/{{http://jabber.org/protocol/disco#info}identity'): - result['identity'][identity['name']] = identity.attrib - for feature in xml.findall('{http://jabber.org/protocol/disco#info}query/{{http://jabber.org/protocol/disco#info}feature'): - result['feature'].append(feature.get('var', '__unknown__')) - return result + def add_item(self, jid=None, name='', node='main', subnode=''): + self.add_node(node) + self.add_node(subnode) + if jid is None: + jid = self.xmpp.fulljid + self.nodes[node].addItem(jid=jid, name=name, node=subnode) diff --git a/sleekxmpp/plugins/xep_0050.py b/sleekxmpp/plugins/xep_0050.py index 0ca66dd..2f356e1 100644 --- a/sleekxmpp/plugins/xep_0050.py +++ b/sleekxmpp/plugins/xep_0050.py @@ -42,6 +42,7 @@ class xep_0050(base.base_plugin): self.sd = self.xmpp.plugin['xep_0030'] def post_init(self): + base.base_plugin.post_init(self) self.sd.add_feature('http://jabber.org/protocol/commands') def addCommand(self, node, name, form, pointer=None, multi=False): diff --git a/sleekxmpp/plugins/xep_0060.py b/sleekxmpp/plugins/xep_0060.py index 44a70e9..bff158a 100644 --- a/sleekxmpp/plugins/xep_0060.py +++ b/sleekxmpp/plugins/xep_0060.py @@ -14,12 +14,14 @@ class xep_0060(base.base_plugin): self.xep = '0060' self.description = 'Publish-Subscribe' - def create_node(self, jid, node, config=None, collection=False): + def create_node(self, jid, node, config=None, collection=False, ntype=None): pubsub = ET.Element('{http://jabber.org/protocol/pubsub}pubsub') create = ET.Element('create') create.set('node', node) pubsub.append(create) configure = ET.Element('configure') + if collection: + ntype = 'collection' #if config is None: # submitform = self.xmpp.plugin['xep_0004'].makeForm('submit') #else: @@ -29,11 +31,11 @@ class xep_0060(base.base_plugin): submitform.field['FORM_TYPE'].setValue('http://jabber.org/protocol/pubsub#node_config') else: submitform.addField('FORM_TYPE', 'hidden', value='http://jabber.org/protocol/pubsub#node_config') - if collection: + if ntype: if 'pubsub#node_type' in submitform.field: - submitform.field['pubsub#node_type'].setValue('collection') + submitform.field['pubsub#node_type'].setValue(ntype) else: - submitform.addField('pubsub#node_type', value='collection') + submitform.addField('pubsub#node_type', value=ntype) else: if 'pubsub#node_type' in submitform.field: submitform.field['pubsub#node_type'].setValue('leaf') diff --git a/sleekxmpp/plugins/xep_0092.py b/sleekxmpp/plugins/xep_0092.py index 3d02638..aeebbe0 100644 --- a/sleekxmpp/plugins/xep_0092.py +++ b/sleekxmpp/plugins/xep_0092.py @@ -33,7 +33,8 @@ class xep_0092(base.base_plugin): self.xmpp.add_handler("" % self.xmpp.default_ns, self.report_version) def post_init(self): - self.xmpp['xep_0030'].add_feature('jabber:iq:version') + base.base_plugin.post_init(self) + self.xmpp.plugin['xep_0030'].add_feature('jabber:iq:version') def report_version(self, xml): iq = self.xmpp.makeIqResult(xml.get('id', 'unknown')) diff --git a/sleekxmpp/plugins/xep_0199.py b/sleekxmpp/plugins/xep_0199.py index 989e645..ccaf0b3 100644 --- a/sleekxmpp/plugins/xep_0199.py +++ b/sleekxmpp/plugins/xep_0199.py @@ -35,7 +35,8 @@ class xep_0199(base.base_plugin): #self.xmpp.add_event_handler('session_start', self.handler_pingserver, threaded=True) def post_init(self): - self.xmpp['xep_0030'].add_feature('http://www.xmpp.org/extensions/xep-0199.html#ns') + base.base_plugin.post_init(self) + self.xmpp.plugin['xep_0030'].add_feature('http://www.xmpp.org/extensions/xep-0199.html#ns') def handler_pingserver(self, xml): if not self.running: diff --git a/sleekxmpp/stanza/error.py b/sleekxmpp/stanza/error.py index 15af662..ee46722 100644 --- a/sleekxmpp/stanza/error.py +++ b/sleekxmpp/stanza/error.py @@ -11,8 +11,8 @@ class Error(ElementBase): namespace = 'jabber:client' name = 'error' plugin_attrib = 'error' - conditions = set(('bad-request', 'conflict', 'feature-not-implemented', 'forbidden', 'gone', 'item-not-found', 'jid-malformed', 'not-acceptable', 'not-allowed', 'not-authorized', 'payment-required', 'recipient-unavailable', 'redirect', 'registration-required', 'remote-server-not-found', 'remote-server-timeout', 'service-unavailable', 'subscription-required', 'undefined-condition', 'unexpected-request')) - interfaces = set(('condition', 'text', 'type')) + conditions = set(('bad-request', 'conflict', 'feature-not-implemented', 'forbidden', 'gone', 'internal-server-error', 'item-not-found', 'jid-malformed', 'not-acceptable', 'not-allowed', 'not-authorized', 'payment-required', 'recipient-unavailable', 'redirect', 'registration-required', 'remote-server-not-found', 'remote-server-timeout', 'resource-constraint', 'service-unavailable', 'subscription-required', 'undefined-condition', 'unexpected-request')) + interfaces = set(('code', 'condition', 'text', 'type')) types = set(('cancel', 'continue', 'modify', 'auth', 'wait')) sub_interfaces = set(('text',)) condition_ns = 'urn:ietf:params:xml:ns:xmpp-stanzas' diff --git a/sleekxmpp/stanza/iq.py b/sleekxmpp/stanza/iq.py index 4969b70..ded7515 100644 --- a/sleekxmpp/stanza/iq.py +++ b/sleekxmpp/stanza/iq.py @@ -37,6 +37,7 @@ class Iq(RootStanza): def setPayload(self, value): self.clear() StanzaBase.setPayload(self, value) + return self def setQuery(self, value): query = self.xml.find("{%s}query" % value) diff --git a/sleekxmpp/xmlstream/handler/base.py b/sleekxmpp/xmlstream/handler/base.py index 5d55f4e..a44edf0 100644 --- a/sleekxmpp/xmlstream/handler/base.py +++ b/sleekxmpp/xmlstream/handler/base.py @@ -18,7 +18,7 @@ class BaseHandler(object): def match(self, xml): return self._matcher.match(xml) - def prerun(self, payload): + def prerun(self, payload): # what's the point of this if the payload is called again in run?? self._payload = payload def run(self, payload): diff --git a/sleekxmpp/xmlstream/handler/callback.py b/sleekxmpp/xmlstream/handler/callback.py index 49cfa14..ea5acb5 100644 --- a/sleekxmpp/xmlstream/handler/callback.py +++ b/sleekxmpp/xmlstream/handler/callback.py @@ -17,13 +17,15 @@ class Callback(base.BaseHandler): self._once = once self._instream = instream - def prerun(self, payload): + def prerun(self, payload): # prerun actually calls run?!? WTF! Then it gets run AGAIN! base.BaseHandler.prerun(self, payload) if self._instream: + logging.debug('callback "%s" prerun', self.name) self.run(payload, True) def run(self, payload, instream=False): if not self._instream or instream: + logging.debug('callback "%s" run', self.name) base.BaseHandler.run(self, payload) #if self._thread: # x = threading.Thread(name="Callback_%s" % self.name, target=self._pointer, args=(payload,)) diff --git a/sleekxmpp/xmlstream/scheduler.py b/sleekxmpp/xmlstream/scheduler.py new file mode 100644 index 0000000..945d9fa --- /dev/null +++ b/sleekxmpp/xmlstream/scheduler.py @@ -0,0 +1,87 @@ +try: + import queue +except ImportError: + import Queue as queue +import time +import threading +import logging + +class Task(object): + """Task object for the Scheduler class""" + def __init__(self, name, seconds, callback, args=None, kwargs=None, repeat=False, qpointer=None): + self.name = name + self.seconds = seconds + self.callback = callback + self.args = args or tuple() + self.kwargs = kwargs or {} + self.repeat = repeat + self.next = time.time() + self.seconds + self.qpointer = qpointer + + def run(self): + if self.qpointer is not None: + self.qpointer.put(('schedule', self.callback, self.args)) + else: + self.callback(*self.args, **self.kwargs) + self.reset() + return self.repeat + + def reset(self): + self.next = time.time() + self.seconds + +class Scheduler(object): + """Threaded scheduler that allows for updates mid-execution unlike http://docs.python.org/library/sched.html#module-sched""" + def __init__(self, parentqueue=None): + self.addq = queue.Queue() + self.schedule = [] + self.thread = None + self.run = False + self.parentqueue = parentqueue + + def process(self, threaded=True): + if threaded: + self.thread = threading.Thread(name='shedulerprocess', target=self._process) + self.thread.start() + else: + self._process() + + def _process(self): + self.run = True + while self.run: + try: + wait = 1 + updated = False + if self.schedule: + wait = self.schedule[0].next - time.time() + try: + if wait <= 0.0: + newtask = self.addq.get(False) + else: + newtask = self.addq.get(True, wait) + except queue.Empty: + cleanup = [] + for task in self.schedule: + if time.time() >= task.next: + updated = True + if not task.run(): + cleanup.append(task) + else: + break + for task in cleanup: + x = self.schedule.pop(self.schedule.index(task)) + else: + updated = True + self.schedule.append(newtask) + finally: + if updated: self.schedule = sorted(self.schedule, key=lambda task: task.next) + except KeyboardInterrupt: + self.run = False + logging.debug("Qutting Scheduler thread") + if self.parentqueue is not None: + self.parentqueue.put(('quit', None, None)) + + def add(self, name, seconds, callback, args=None, kwargs=None, repeat=False, qpointer=None): + self.addq.put(Task(name, seconds, callback, args, kwargs, repeat, qpointer)) + + def quit(self): + self.run = False diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py index 018e81c..64020c8 100644 --- a/sleekxmpp/xmlstream/stanzabase.py +++ b/sleekxmpp/xmlstream/stanzabase.py @@ -78,6 +78,9 @@ class ElementBase(tostring.ToString): def __iter__(self): self.idx = 0 return self + + def __bool__(self): + return True def __next__(self): self.idx += 1 @@ -319,6 +322,8 @@ class StanzaBase(ElementBase): def __init__(self, stream=None, xml=None, stype=None, sto=None, sfrom=None, sid=None): self.stream = stream + if stream is not None: + self.namespace = stream.default_ns ElementBase.__init__(self, xml) if stype is not None: self['type'] = stype @@ -326,13 +331,11 @@ class StanzaBase(ElementBase): self['to'] = sto if sfrom is not None: self['from'] = sfrom - if stream is not None: - self.namespace = stream.default_ns self.tag = "{%s}%s" % (self.namespace, self.name) def setType(self, value): if value in self.types: - self.xml.attrib['type'] = value + self.xml.attrib['type'] = value return self def getPayload(self): @@ -340,15 +343,18 @@ class StanzaBase(ElementBase): def setPayload(self, value): self.xml.append(value) + return self def delPayload(self): self.clear() + return self def clear(self): for child in self.xml.getchildren(): self.xml.remove(child) for plugin in list(self.plugins.keys()): del self.plugins[plugin] + return self def reply(self): self['from'], self['to'] = self['to'], self['from'] @@ -357,6 +363,7 @@ class StanzaBase(ElementBase): def error(self): self['type'] = 'error' + return self def getTo(self): return JID(self._getAttr('to')) diff --git a/sleekxmpp/xmlstream/statemachine.py b/sleekxmpp/xmlstream/statemachine.py index fb7d150..c5f5176 100644 --- a/sleekxmpp/xmlstream/statemachine.py +++ b/sleekxmpp/xmlstream/statemachine.py @@ -7,53 +7,124 @@ """ from __future__ import with_statement import threading +import time +import logging class StateMachine(object): - def __init__(self, states=[], groups=[]): - self.lock = threading.Lock() - self.__state = {} - self.__default_state = {} - self.__group = {} + def __init__(self, states=[]): + self.lock = threading.Condition(threading.RLock()) + self.__states= [] self.addStates(states) - self.addGroups(groups) + self.__default_state = self.__states[0] + self.__current_state = self.__default_state def addStates(self, states): with self.lock: for state in states: - if state in self.__state or state in self.__group: - raise IndexError("The state or group '%s' is already in the StateMachine." % state) - self.__state[state] = states[state] - self.__default_state[state] = states[state] + if state in self.__states: + raise IndexError("The state '%s' is already in the StateMachine." % state) + self.__states.append( state ) - def addGroups(self, groups): - with self.lock: - for gstate in groups: - if gstate in self.__state or gstate in self.__group: - raise IndexError("The key or group '%s' is already in the StateMachine." % gstate) - for state in groups[gstate]: - if state in self.__state: - raise IndexError("The group %s contains a key %s which is not set in the StateMachine." % (gstate, state)) - self.__group[gstate] = groups[gstate] - def set(self, state, status): + def transition(self, from_state, to_state, wait=0.0): + ''' + Transition from the given `from_state` to the given `to_state`. + This method will return `True` if the state machine is now in `to_state`. It + will return `False` if a timeout occurred the transition did not occur. + If `wait` is 0 (the default,) this method returns immediately if the state machine + is not in `from_state`. + + If you want the thread to block and transition once the state machine to enters + `from_state`, set `wait` to a non-negative value. Note there is no 'block + indefinitely' flag since this leads to deadlock. If you want to wait indefinitely, + choose a reasonable value for `wait` (e.g. 20 seconds) and do so in a while loop like so: + + :: + + while not thread_should_exit and not state_machine.transition('disconnected', 'connecting', wait=20 ): + pass # timeout will occur every 20s unless transition occurs + if thread_should_exit: return + # perform actions here after successful transition + + This allows the thread to be interrupted by setting `thread_should_exit=True` + ''' + + return self.transition_any( (from_state,), to_state, wait=wait ) + + def transition_any(self, from_states, to_state, wait=0.0): + ''' + Transition from any of the given `from_states` to the given `to_state`. + ''' + with self.lock: - if state in self.__state: - self.__state[state] = bool(status) + for state in from_states: + if isinstance(state,tuple) or isinstance(state,list): + raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) ) + if not state in self.__states: + raise ValueError( "StateMachine does not contain from_state %s." % state ) + if not to_state in self.__states: + raise ValueError( "StateMachine does not contain to_state %s." % to_state ) + + start = time.time() + while not self.__current_state in from_states: + # detect timeout: + if time.time() >= start + wait: return False + self.lock.wait(wait) + + if self.__current_state in from_states: # should always be True due to lock + logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state) + self.__current_state = to_state + self.lock.notifyAll() + return True else: - raise KeyError("StateMachine does not contain state %s." % state) - - def __getitem__(self, key): - if key in self.__group: - for state in self.__group[key]: - if not self.__state[state]: - return False - return True - return self.__state[key] - - def __getattr__(self, attr): - return self.__getitem__(attr) + logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" ) + return False + + + def ensure(self, state, wait=0.0): + ''' + Ensure the state machine is currently in `state`, or wait until it enters `state`. + ''' + return self.ensure_any( (state,), wait=wait ) + + def ensure_any(self, states, wait=0.0): + ''' + Ensure we are currently in one of the given `states` + ''' + with self.lock: + for state in states: + if isinstance(state,tuple) or isinstance(state,list): + raise ValueError( "State %s should be a string. Did you mean to call 'StateMachine.transition_any()?" % str(state) ) + if not state in self.__states: + raise ValueError( "StateMachine does not contain state %s." % state ) + + start = time.time() + while not self.__current_state in states: + # detect timeout: + if time.time() >= start + wait: return False + self.lock.wait(wait) + return self.__current_state in states # should always be True due to lock + def reset(self): - self.__state = self.__default_state + # TODO need to lock before calling this? + self.transition(self.__current_state, self._default_state) + + + def __getitem__(self, state): + ''' + Non-blocking, non-synchronized test to determine if we are in the given state. + Use `StateMachine.ensure(state)` to wait until the machine enters a certain state. + ''' + return self.__current_state == state + + def __enter__(self): + self.lock.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.lock.nofityAll() + self.lock.release() + return False # re-raise any exception diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index cdce1fd..3bcb341 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -1,9 +1,9 @@ """ - SleekXMPP: The Sleek XMPP Library - Copyright (C) 2010 Nathanael C. Fritz - This file is part of SleekXMPP. + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. - See the file license.txt for copying permission. + See the file license.txt for copying permission. """ from __future__ import with_statement, unicode_literals @@ -22,6 +22,7 @@ import time import traceback import types import xml.sax.saxutils +from . import scheduler HANDLER_THREADS = 1 @@ -52,8 +53,9 @@ class XMLStream(object): global ssl_support self.ssl_support = ssl_support self.escape_quotes = escape_quotes - self.state = statemachine.StateMachine() - self.state.addStates({'connected':False, 'is client':False, 'ssl':False, 'tls':False, 'reconnect':True, 'processing':False, 'disconnecting':False}) #set initial states + self.state = statemachine.StateMachine(('disconnected','connecting', + 'connected')) + self.should_reconnect = True self.setSocket(socket) self.address = (host, int(port)) @@ -76,6 +78,7 @@ class XMLStream(object): self.eventqueue = queue.Queue() self.sendqueue = queue.Queue() + self.scheduler = scheduler.Scheduler(self.eventqueue) self.namespace_map = {} @@ -84,45 +87,49 @@ class XMLStream(object): def setSocket(self, socket): "Set the socket" self.socket = socket - if socket is not None: + if socket is not None and self.state.transition('disconnected','connecting'): self.filesocket = socket.makefile('rb', 0) # ElementTree.iterparse requires a file. 0 buffer files have to be binary - self.state.set('connected', True) - + self.state.transition('connecting','connected') def setFileSocket(self, filesocket): self.filesocket = filesocket - def connect(self, host='', port=0, use_ssl=False, use_tls=True): + def connect(self, host='', port=0, use_ssl=None, use_tls=None): "Link to connectTCP" - return self.connectTCP(host, port, use_ssl, use_tls) + if self.state.transition('disconnected', 'connecting'): + return self.connectTCP(host, port, use_ssl, use_tls) def connectTCP(self, host='', port=0, use_ssl=None, use_tls=None, reattempt=True): "Connect and create socket" - while reattempt and not self.state['connected']: - if host and port: - self.address = (host, int(port)) - if use_ssl is not None: - self.use_ssl = use_ssl - if use_tls is not None: - self.use_tls = use_tls - self.state.set('is client', True) - if sys.version_info < (3, 0): - self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM) - else: - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.settimeout(None) - if self.use_ssl and self.ssl_support: - logging.debug("Socket Wrapped for SSL") - self.socket = ssl.wrap_socket(self.socket,ca_certs=self.ca_certs) + while reattempt and not self.state['connected']: # the self.state part is redundant. + logging.debug('connecting....') + try: + if host and port: + self.address = (host, int(port)) + if use_ssl is not None: + self.use_ssl = use_ssl + if use_tls is not None: + self.use_tls = use_tls + if sys.version_info < (3, 0): + self.socket = filesocket.Socket26(socket.AF_INET, socket.SOCK_STREAM) + else: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.settimeout(None) #10) + if self.use_ssl and self.ssl_support: + logging.debug("Socket Wrapped for SSL") + self.socket = ssl.wrap_socket(self.socket,ca_certs=self.ca_certs) + except: + logging.exception("Connection error") try: self.socket.connect(self.address) - #self.filesocket = self.socket.makefile('rb', 0) self.filesocket = self.socket.makefile('rb', 0) - self.state.set('connected', True) + if not self.state.transition('connecting','connected'): + logging.error( "State transition error!!!! Shouldn't have happened" ) + logging.debug('connect complete.') return True except socket.error as serr: logging.error("Could not connect. Socket Error #%s: %s" % (serr.errno, serr.strerror)) - time.sleep(1) + time.sleep(1) # TODO proper quiesce if connection attempt fails def connectUnix(self, filepath): "Connect to Unix file and create socket" @@ -131,19 +138,19 @@ class XMLStream(object): "Handshakes for TLS" if self.ssl_support: logging.info("Negotiating TLS") - self.realsocket = self.socket +# self.realsocket = self.socket # NOT USED self.socket = ssl.wrap_socket(self.socket, - ssl_version=ssl.PROTOCOL_TLSv1, - do_handshake_on_connect=False, - ca_certs=self.ca_certs) - print "doing handshake..." + ssl_version=ssl.PROTOCOL_TLSv1, + do_handshake_on_connect=False, + ca_certs=self.ca_certs) self.socket.do_handshake() - print "got handshake..." if sys.version_info < (3,0): from . filesocket import filesocket self.filesocket = filesocket(self.socket) else: self.filesocket = self.socket.makefile('rb', 0) + + logging.debug("TLS negotitation successful") return True else: logging.warning("Tried to enable TLS, but ssl module not found.") @@ -151,6 +158,8 @@ class XMLStream(object): raise RestartStream() def process(self, threaded=True): + self.scheduler.process(threaded=True) + self.run = True for t in range(0, HANDLER_THREADS): th = threading.Thread(name='eventhandle%s' % t, target=self._eventRunner) th.setDaemon(True) @@ -168,58 +177,37 @@ class XMLStream(object): else: self._process() - def schedule(self, seconds, handler, args=None): - threading.Timer(seconds, handler, args).start() + def schedule(self, name, seconds, callback, args=None, kwargs=None, repeat=False): + self.scheduler.add(name, seconds, callback, args, kwargs, repeat, qpointer=self.eventqueue) def _process(self): "Start processing the socket." - firstrun = True - while self.run and (firstrun or self.state['reconnect']): - self.state.set('processing', True) - firstrun = False + logging.debug('Process thread starting...') + while self.run: + if not self.state.ensure('connected',wait=2): continue try: - if self.state['is client']: - self.sendRaw(self.stream_header) - while self.run and self.__readXML(): - if self.state['is client']: - self.sendRaw(self.stream_header) - except KeyboardInterrupt: - logging.debug("Keyboard Escape Detected") - self.state.set('processing', False) - self.state.set('reconnect', False) - self.disconnect() - self.run = False - self.eventqueue.put(('quit', None, None)) - return + self.sendRaw(self.stream_header) + while self.run and self.__readXML(): pass + except socket.timeout: + logging.debug('socket rcv timeout') + pass except CloseStream: - return - except SystemExit: + # TODO warn that the listener thread is exiting!!! + pass + except RestartStream: + logging.debug("Restarting stream...") + continue # DON'T re-initialize the stream -- this exception is sent + # specifically when we've initialized TLS and need to re-send the header. + except (KeyboardInterrupt, SystemExit): + logging.debug("System interrupt detected") + self.shutdown() self.eventqueue.put(('quit', None, None)) - return - except socket.error: - if not self.state.reconnect: - return - else: - self.state.set('processing', False) - traceback.print_exc() - self.disconnect(reconnect=True) except: - if not self.state.reconnect: - return - else: - self.state.set('processing', False) - traceback.print_exc() + logging.exception('Unexpected error in RCV thread') + if self.should_reconnect: self.disconnect(reconnect=True) - if self.state['reconnect']: - self.state.set('connected', False) - self.state.set('processing', False) - self.reconnect() - else: - self.eventqueue.put(('quit', None, None)) - #self.__thread['readXML'] = threading.Thread(name='readXML', target=self.__readXML) - #self.__thread['readXML'].start() - #self.__thread['spawnEvents'] = threading.Thread(name='spawnEvents', target=self.__spawnEvents) - #self.__thread['spawnEvents'].start() + + logging.debug('Quitting Process thread') def __readXML(self): "Parses the incoming stream, adding to xmlin queue as it goes" @@ -232,39 +220,47 @@ class XMLStream(object): if edepth == 0: # and xmlobj.tag.split('}', 1)[-1] == self.basetag: if event == b'start': root = xmlobj + logging.debug('handling start stream') self.start_stream_handler(root) if event == b'end': edepth += -1 if edepth == 0 and event == b'end': - self.disconnect(reconnect=self.state['reconnect']) + # what is this case exactly? Premature EOF? + logging.debug("Ending readXML loop") return False elif edepth == 1: #self.xmlin.put(xmlobj) - try: - self.__spawnEvent(xmlobj) - except RestartStream: - return True - except CloseStream: - return False - if root: - root.clear() + self.__spawnEvent(xmlobj) + if root: root.clear() if event == b'start': edepth += 1 + logging.debug("Exiting readXML loop") + return False def _sendThread(self): + logging.debug('send thread starting...') while self.run: - data = self.sendqueue.get(True) - logging.debug("SEND: %s" % data) + if not self.state.ensure('connected',wait=2): continue + + data = None try: - self.socket.send(data.encode('utf-8')) - #self.socket.send(bytes(data, "utf-8")) - #except socket.error,(errno, strerror): + data = self.sendqueue.get(True,10) + logging.debug("SEND: %s" % data) + self.socket.sendall(data.encode('utf-8')) + except queue.Empty: + logging.debug('nothing on send queue') + except socket.timeout: + # this is to prevent a thread blocked indefinitely + logging.debug('timeout sending packet data') except: logging.warning("Failed to send %s" % data) - self.state.set('connected', False) - if self.state.reconnect: - logging.error("Disconnected. Socket Error.") - traceback.print_exc() + logging.exception("Socket error in SEND thread") + # TODO it's somewhat unsafe for the sender thread to assume it can just + # re-intitialize the connection, since the receiver thread could be doing + # the same thing concurrently. Oops! The safer option would be to throw + # some sort of event that could be handled by a common thread or the reader + # thread to perform reconnect and then re-initialize the handler threads as well. + if self.should_reconnect: self.disconnect(reconnect=True) def sendRaw(self, data): @@ -272,42 +268,41 @@ class XMLStream(object): return True def disconnect(self, reconnect=False): - self.state.set('reconnect', reconnect) - if self.state['disconnecting']: + if not self.state.transition('connected','disconnected'): + logging.warning("Already disconnected.") return - if not self.state['reconnect']: - logging.debug("Disconnecting...") - self.state.set('disconnecting', True) - self.run = False - if self.state['connected']: - self.sendRaw(self.stream_footer) - time.sleep(1) - #send end of stream - #wait for end of stream back + logging.debug("Disconnecting...") + self.sendRaw(self.stream_footer) + time.sleep(5) + #send end of stream + #wait for end of stream back try: +# self.socket.shutdown(socket.SHUT_RDWR) self.socket.close() + except socket.error as (errno,strerror): + logging.exception("Error while disconnecting. Socket Error #%s: %s" % (errno, strerror)) + try: self.filesocket.close() - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error as serr: - #logging.warning("Error while disconnecting. Socket Error #%s: %s" % (errno, strerror)) - #thread.exit_thread() - pass - if self.state['processing']: - #raise CloseStream - pass - - def reconnect(self): - self.state.set('tls',False) - self.state.set('ssl',False) - time.sleep(1) - self.connect(self.server,self.port) + except socket.error as (errno,strerror): + logging.exception("Error closing filesocket.") + + if reconnect: self.connect() + def shutdown(self): + ''' + Disconnects and shuts down all event threads. + ''' + self.disconnect() + self.run = False + self.scheduler.run = False + def incoming_filter(self, xmlobj): return xmlobj - + def __spawnEvent(self, xmlobj): "watching xmlOut and processes handlers" #convert XML into Stanza + # TODO surround this log statement with an if, it's expensive logging.debug("RECV: %s" % cElementTree.tostring(xmlobj)) xmlobj = self.incoming_filter(xmlobj) stanza = None @@ -319,17 +314,21 @@ class XMLStream(object): if stanza is None: stanza = StanzaBase(self, xmlobj) unhandled = True + # TODO inefficient linear search; performance might be improved by hashtable lookup for handler in self.__handlers: if handler.match(stanza): + logging.debug('matched stanza to handler %s', handler.name) handler.prerun(stanza) self.eventqueue.put(('stanza', handler, stanza)) - if handler.checkDelete(): self.__handlers.pop(self.__handlers.index(handler)) + if handler.checkDelete(): + logging.debug('deleting callback %s', handler.name) + self.__handlers.pop(self.__handlers.index(handler)) unhandled = False if unhandled: stanza.unhandled() #loop through handlers and test match #spawn threads as necessary, call handlers, sending Stanza - + def _eventRunner(self): logging.debug("Loading event runner") while self.run: @@ -341,26 +340,27 @@ class XMLStream(object): etype = event[0] handler = event[1] args = event[2:] - #etype, handler, *args = event #python 3.x way + #etype, handler, *args = event #python 3.x way if etype == 'stanza': try: handler.run(args[0]) except Exception as e: - traceback.print_exc() + logging.exception("Exception in event handler") args[0].exception(e) elif etype == 'sched': try: + #handler(*args[0]) handler.run(*args) except: logging.error(traceback.format_exc()) elif etype == 'quit': logging.debug("Quitting eventRunner thread") return False - + def registerHandler(self, handler, before=None, after=None): "Add handler with matcher class and parameters." self.__handlers.append(handler) - + def removeHandler(self, name): "Removes the handler." idx = 0 @@ -446,4 +446,4 @@ class XMLStream(object): def start_stream_handler(self, xml): """Meant to be overridden""" - pass + logging.warn("No start stream handler has been implemented.") diff --git a/tests/test_disco.py b/tests/test_disco.py new file mode 100644 index 0000000..bbe285a --- /dev/null +++ b/tests/test_disco.py @@ -0,0 +1,155 @@ +import unittest +from xml.etree import cElementTree as ET +from sleekxmpp.xmlstream.matcher.stanzapath import StanzaPath +from . import xmlcompare + +import sleekxmpp.plugins.xep_0030 as sd + +def stanzaPlugin(stanza, plugin): + stanza.plugin_attrib_map[plugin.plugin_attrib] = plugin + stanza.plugin_tag_map["{%s}%s" % (plugin.namespace, plugin.name)] = plugin + +class testdisco(unittest.TestCase): + + def setUp(self): + self.sd = sd + stanzaPlugin(self.sd.Iq, self.sd.DiscoInfo) + stanzaPlugin(self.sd.Iq, self.sd.DiscoItems) + + def try3Methods(self, xmlstring, iq): + iq2 = self.sd.Iq(None, self.sd.ET.fromstring(xmlstring)) + values = iq2.getValues() + iq3 = self.sd.Iq() + iq3.setValues(values) + self.failUnless(xmlstring == str(iq) == str(iq2) == str(iq3), str(iq)+"3 methods for creating stanza don't match") + + def testCreateInfoQueryNoNode(self): + """Testing disco#info query with no node.""" + iq = self.sd.Iq() + iq['id'] = "0" + iq['disco_info']['node'] = '' + xmlstring = """""" + self.try3Methods(xmlstring, iq) + + def testCreateInfoQueryWithNode(self): + """Testing disco#info query with a node.""" + iq = self.sd.Iq() + iq['id'] = "0" + iq['disco_info']['node'] = 'foo' + xmlstring = """""" + self.try3Methods(xmlstring, iq) + + def testCreateInfoQueryNoNode(self): + """Testing disco#items query with no node.""" + iq = self.sd.Iq() + iq['id'] = "0" + iq['disco_items']['node'] = '' + xmlstring = """""" + self.try3Methods(xmlstring, iq) + + def testCreateItemsQueryWithNode(self): + """Testing disco#items query with a node.""" + iq = self.sd.Iq() + iq['id'] = "0" + iq['disco_items']['node'] = 'foo' + xmlstring = """""" + self.try3Methods(xmlstring, iq) + + def testInfoIdentities(self): + """Testing adding identities to disco#info.""" + iq = self.sd.Iq() + iq['id'] = "0" + iq['disco_info']['node'] = 'foo' + iq['disco_info'].addIdentity('conference', 'text', 'Chatroom') + xmlstring = """""" + self.try3Methods(xmlstring, iq) + + def testInfoFeatures(self): + """Testing adding features to disco#info.""" + iq = self.sd.Iq() + iq['id'] = "0" + iq['disco_info']['node'] = 'foo' + iq['disco_info'].addFeature('foo') + iq['disco_info'].addFeature('bar') + xmlstring = """""" + self.try3Methods(xmlstring, iq) + + def testItems(self): + """Testing adding features to disco#info.""" + iq = self.sd.Iq() + iq['id'] = "0" + iq['disco_items']['node'] = 'foo' + iq['disco_items'].addItem('user@localhost') + iq['disco_items'].addItem('user@localhost', 'foo') + iq['disco_items'].addItem('user@localhost', 'bar', 'Testing') + xmlstring = """""" + self.try3Methods(xmlstring, iq) + + def testAddRemoveIdentities(self): + """Test adding and removing identities to disco#info stanza""" + ids = [('automation', 'commands', 'AdHoc'), + ('conference', 'text', 'ChatRoom')] + + info = self.sd.DiscoInfo() + info.addIdentity(*ids[0]) + self.failUnless(info.getIdentities() == [ids[0]]) + + info.delIdentity('automation', 'commands') + self.failUnless(info.getIdentities() == []) + + info.setIdentities(ids) + self.failUnless(info.getIdentities() == ids) + + info.delIdentity('automation', 'commands') + self.failUnless(info.getIdentities() == [ids[1]]) + + info.delIdentities() + self.failUnless(info.getIdentities() == []) + + def testAddRemoveFeatures(self): + """Test adding and removing features to disco#info stanza""" + features = ['foo', 'bar', 'baz'] + + info = self.sd.DiscoInfo() + info.addFeature(features[0]) + self.failUnless(info.getFeatures() == [features[0]]) + + info.delFeature('foo') + self.failUnless(info.getFeatures() == []) + + info.setFeatures(features) + self.failUnless(info.getFeatures() == features) + + info.delFeature('bar') + self.failUnless(info.getFeatures() == ['foo', 'baz']) + + info.delFeatures() + self.failUnless(info.getFeatures() == []) + + def testAddRemoveItems(self): + """Test adding and removing items to disco#items stanza""" + items = [('user@localhost', None, None), + ('user@localhost', 'foo', None), + ('user@localhost', 'bar', 'Test')] + + info = self.sd.DiscoItems() + self.failUnless(True, ""+str(items[0])) + + info.addItem(*(items[0])) + self.failUnless(info.getItems() == [items[0]], info.getItems()) + + info.delItem('user@localhost') + self.failUnless(info.getItems() == []) + + info.setItems(items) + self.failUnless(info.getItems() == items) + + info.delItem('user@localhost', 'foo') + self.failUnless(info.getItems() == [items[0], items[2]]) + + info.delItems() + self.failUnless(info.getItems() == []) + + + +suite = unittest.TestLoader().loadTestsFromTestCase(testdisco) diff --git a/tests/test_events.py b/tests/test_events.py new file mode 100644 index 0000000..11821db --- /dev/null +++ b/tests/test_events.py @@ -0,0 +1,35 @@ +import unittest + +class testevents(unittest.TestCase): + + def setUp(self): + import sleekxmpp.stanza.presence as p + self.p = p + + def testEventHappening(self): + "Test handler working" + import sleekxmpp + c = sleekxmpp.ClientXMPP('crap@wherever', 'password') + happened = [] + def handletestevent(event): + happened.append(True) + c.add_event_handler("test_event", handletestevent) + c.event("test_event", {}) + c.event("test_event", {}) + self.failUnless(happened == [True, True], "event did not get triggered twice") + + def testDelEvent(self): + "Test handler working, then deleted and not triggered" + import sleekxmpp + c = sleekxmpp.ClientXMPP('crap@wherever', 'password') + happened = [] + def handletestevent(event): + happened.append(True) + c.add_event_handler("test_event", handletestevent) + c.event("test_event", {}) + c.del_event_handler("test_event", handletestevent) + c.event("test_event", {}) # should not trigger because it was deleted + self.failUnless(happened == [True], "event did not get triggered the correct number of times") + + +suite = unittest.TestLoader().loadTestsFromTestCase(testevents) diff --git a/tests/test_pubsubstanzas.py b/tests/test_pubsubstanzas.py index 5353f90..dc41fc3 100644 --- a/tests/test_pubsubstanzas.py +++ b/tests/test_pubsubstanzas.py @@ -97,6 +97,21 @@ class testpubsubstanzas(unittest.TestCase): iq3.setValues(values) self.failUnless(xmlstring == str(iq) == str(iq2) == str(iq3)) + def testState(self): + "Testing iq/psstate stanzas" + from sleekxmpp.plugins import xep_0004 + iq = self.ps.Iq() + iq['psstate']['node']= 'mynode' + iq['psstate']['item']= 'myitem' + pl = ET.Element('{http://andyet.net/protocol/pubsubqueue}claimed') + iq['psstate']['payload'] = pl + xmlstring = """""" + iq2 = self.ps.Iq(None, self.ps.ET.fromstring(xmlstring)) + iq3 = self.ps.Iq() + values = iq2.getValues() + iq3.setValues(values) + self.failUnless(xmlstring == str(iq) == str(iq2) == str(iq3)) + def testDefault(self): "Testing iq/pubsub_owner/default stanzas" from sleekxmpp.plugins import xep_0004 diff --git a/tests/test_statemachine.py b/tests/test_statemachine.py new file mode 100644 index 0000000..6749c8d --- /dev/null +++ b/tests/test_statemachine.py @@ -0,0 +1,116 @@ +import unittest +import time, threading + +if __name__ == '__main__': + import sys, os + sys.path.insert(0, os.getcwd()) + import sleekxmpp.xmlstream.statemachine as sm + + +class testStateMachine(unittest.TestCase): + + def setUp(self): pass + + + def testDefaults(self): + "Test ensure transitions occur correctly in a single thread" + s = sm.StateMachine(('one','two','three')) +# self.assertTrue(s.one) + self.assertTrue(s['one']) +# self.failIf(s.two) + self.failIf(s['two']) + try: + s.booga + self.fail('s.booga is an invalid state and should throw an exception!') + except: pass #expected exception + + + def testTransitions(self): + "Test ensure transitions occur correctly in a single thread" + s = sm.StateMachine(('one','two','three')) +# self.assertTrue(s.one) + + self.assertTrue( s.transition('one', 'two') ) +# self.assertTrue( s.two ) + self.assertTrue( s['two'] ) +# self.failIf( s.one ) + self.failIf( s['one'] ) + + self.assertTrue( s.transition('two', 'three') ) + self.assertTrue( s['three'] ) + self.failIf( s['two'] ) + + self.assertTrue( s.transition('three', 'one') ) + self.assertTrue( s['one'] ) + self.failIf( s['three'] ) + + # should return False immediately w/ no wait: + self.failIf( s.transition('three', 'one') ) + self.assertTrue( s['one'] ) + self.failIf( s['three'] ) + + # test fail condition w/ a short delay: + self.failIf( s.transition('two', 'three') ) + + # Ensure bad states are weeded out: + try: + s.transition('blah', 'three') + s.fail('Exception expected') + except: pass + + try: + s.transition('one', 'blahblah') + s.fail('Exception expected') + except: pass + + + def testTransitionsBlocking(self): + "Test that transitions block from more than one thread" + + s = sm.StateMachine(('one','two','three')) + self.assertTrue(s['one']) + + now = time.time() + self.failIf( s.transition('two', 'one', wait=5.0) ) + self.assertTrue( time.time() > now + 4 ) + self.assertTrue( time.time() < now + 7 ) + + def testThreadedTransitions(self): + "Test that transitions are atomic in > one thread" + + s = sm.StateMachine(('one','two','three')) + self.assertTrue(s['one']) + + thread_state = {'ready': False, 'transitioned': False} + def t1(): + # this will block until the main thread transitions to 'two' + if s['two']: + print 'thread has already transitioned!' + self.fail() + thread_state['ready'] = True + print 'Thread is ready' + self.assertTrue( s.transition('two','three', wait=20) ) + print 'transitioned to three!' + thread_state['transitioned'] = True + + thread = threading.Thread(target=t1) + thread.daemon = True + thread.start() + start = time.time() + while not thread_state['ready']: + print 'not ready' + if time.time() > start+10: self.fail('Timeout waiting for thread to init!') + time.sleep(0.1) + time.sleep(0.2) # the thread should be blocking on the 'transition' call at this point. + self.failIf( thread_state['transitioned'] ) # ensure it didn't 'go' yet. + print 'transitioning to two!' + self.assertTrue( s.transition('one','two') ) + time.sleep(0.2) # second thread should have transitioned now: + self.assertTrue( thread_state['transitioned'] ) + + + + +suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine) + +if __name__ == '__main__': unittest.main()