diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py
index dfa3736..3223901 100644
--- a/sleekxmpp/xmlstream/stanzabase.py
+++ b/sleekxmpp/xmlstream/stanzabase.py
@@ -293,6 +293,9 @@ class ElementBase(object):
3. Delete top level XML attribute named foo.
4. Remove the foo plugin, if it was loaded.
5. Do nothing.
+
+ Arguments:
+ attrib -- The name of the affected stanza interface.
"""
if attrib in self.interfaces:
del_method = "del%s" % attrib.title()
@@ -308,6 +311,49 @@ class ElementBase(object):
del self.plugins[attrib]
return self
+ def _setAttr(self, name, value):
+ """
+ Set the value of a top level attribute of the underlying XML object.
+
+ If the new value is None or an empty string, then the attribute will
+ be removed.
+
+ Arguments:
+ name -- The name of the attribute.
+ value -- The new value of the attribute, or None or '' to
+ remove it.
+ """
+ if value is None or value == '':
+ self.__delitem__(name)
+ else:
+ self.xml.attrib[name] = value
+
+ def _delAttr(self, name):
+ """
+ Remove a top level attribute of the underlying XML object.
+
+ Arguments:
+ name -- The name of the attribute.
+ """
+ if name in self.xml.attrib:
+ del self.xml.attrib[name]
+
+ def _getAttr(self, name, default=''):
+ """
+ Return the value of a top level attribute of the underlying
+ XML object.
+
+ In case the attribute has not been set, a default value can be
+ returned instead. An empty string is returned if no other default
+ is supplied.
+
+ Arguments:
+ name -- The name of the attribute.
+ default -- Optional value to return if the attribute has not
+ been set. An empty string is returned otherwise.
+ """
+ return self.xml.attrib.get(name, default)
+
@property
def attrib(self): #backwards compatibility
return self
@@ -400,19 +446,6 @@ class ElementBase(object):
return False
return True
- def _setAttr(self, name, value):
- if value is None or value == '':
- self.__delitem__(name)
- else:
- self.xml.attrib[name] = value
-
- def _delAttr(self, name):
- if name in self.xml.attrib:
- del self.xml.attrib[name]
-
- def _getAttr(self, name, default=''):
- return self.xml.attrib.get(name, default)
-
def _getSubText(self, name):
if '}' not in name:
name = "{%s}%s" % (self.namespace, name)
diff --git a/tests/test_elementbase.py b/tests/test_elementbase.py
index bf86e59..78277af 100644
--- a/tests/test_elementbase.py
+++ b/tests/test_elementbase.py
@@ -231,4 +231,41 @@ class TestElementBase(SleekTest):
""")
+ def testModifyingAttributes(self):
+ """Test modifying top level attributes of a stanza's XML object."""
+
+ class TestStanza(ElementBase):
+ name = "foo"
+ namespace = "foo"
+ interfaces = set(('bar', 'baz'))
+
+ stanza = TestStanza()
+
+ self.checkStanza(TestStanza, stanza, """
+
+ """)
+
+ self.failUnless(stanza._getAttr('bar') == '',
+ "Incorrect value returned for an unset XML attribute.")
+
+ stanza._setAttr('bar', 'a')
+ stanza._setAttr('baz', 'b')
+
+ self.checkStanza(TestStanza, stanza, """
+
+ """)
+
+ self.failUnless(stanza._getAttr('bar') == 'a',
+ "Retrieved XML attribute value is incorrect.")
+
+ stanza._setAttr('bar', None)
+ stanza._delAttr('baz')
+
+ self.checkStanza(TestStanza, stanza, """
+
+ """)
+
+ self.failUnless(stanza._getAttr('bar', 'c') == 'c',
+ "Incorrect default value returned for an unset XML attribute.")
+
suite = unittest.TestLoader().loadTestsFromTestCase(TestElementBase)