diff --git a/sleekxmpp/clientxmpp.py b/sleekxmpp/clientxmpp.py index d722748..f7bf8f8 100644 --- a/sleekxmpp/clientxmpp.py +++ b/sleekxmpp/clientxmpp.py @@ -27,11 +27,12 @@ from sleekxmpp.xmlstream.matcher import * from sleekxmpp.xmlstream.handler import * # Flag indicating if DNS SRV records are available for use. -SRV_SUPPORT = True try: import dns.resolver -except: - SRV_SUPPORT = False +except ImportError: + DNSPYTHON = False +else: + DNSPYTHON = True log = logging.getLogger(__name__) @@ -78,7 +79,7 @@ class ClientXMPP(BaseXMPP): self.escape_quotes = escape_quotes self.plugin_config = plugin_config self.plugin_whitelist = plugin_whitelist - self.srv_support = SRV_SUPPORT + self.default_port = 5222 self.stream_header = "" % ( self.boundjid.host, @@ -133,55 +134,28 @@ class ClientXMPP(BaseXMPP): connection. Defaults to True. """ self.session_started_event.clear() - if not address or len(address) < 2: - if not self.srv_support: - log.debug("Did not supply (address, port) to connect" + \ - " to and no SRV support is installed" + \ - " (http://www.dnspython.org)." + \ - " Continuing to attempt connection, using" + \ - " server hostname from JID.") - else: - log.debug("Since no address is supplied," + \ - "attempting SRV lookup.") - try: - xmpp_srv = "_xmpp-client._tcp.%s" % self.boundjid.host - answers = dns.resolver.query(xmpp_srv, dns.rdatatype.SRV) - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): - log.debug("No appropriate SRV record found." + \ - " Using JID server name.") - except (dns.exception.Timeout,): - log.debug("DNS resolution timed out.") - else: - # Pick a random server, weighted by priority. - - addresses = {} - intmax = 0 - topprio = 65535 - for answer in answers: - topprio = min(topprio, answer.priority) - for answer in answers: - if answer.priority == topprio: - intmax += answer.weight - addresses[intmax] = (answer.target.to_text()[:-1], - answer.port) - - #python3 returns a generator for dictionary keys - items = [x for x in addresses.keys()] - items.sort() - - picked = random.randint(0, intmax) - for item in items: - if picked <= item: - address = addresses[item] - break - - if not address: - # If all else fails, use the server from the JID. - address = (self.boundjid.host, 5222) + address = (self.boundjid.host, 5222) return XMLStream.connect(self, address[0], address[1], use_tls=use_tls, reattempt=reattempt) + def get_dns_records(self, domain, port=None): + if port is None: + port = self.default_port + if DNSPYTHON: + try: + answers = [((answer.target.to_text()[:-1], answer.port), answer.priority, answer.weight) for answer in dns.resolver.query("_xmpp-client._tcp.%s" % domain, dns.rdatatype.SRV)] + except dns.resolver.NXDOMAIN, dns.resolver.NoAnswer: + log.warning("No SRV records for %s" % domain) + answers = super(ClientXMPP, self).get_dns_records(domain, port) + except dns.exception.Timeout: + log.warning("DNS resolution timed out for SRV record of %s" % domain) + answers = super(ClientXMPP, self).get_dns_records(domain, port) + return answers + else: + log.warning("dnspython is not installed -- relying on OS A record resolution") + return [((domain, port), 0, 0)] + def register_feature(self, name, handler, restart=False, order=5000): """ Register a stream feature. diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index 8126d98..b090bea 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -36,6 +36,13 @@ from sleekxmpp.xmlstream.matcher import MatchXMLMask if sys.version_info < (3, 0): from sleekxmpp.xmlstream.filesocket import FileSocket, Socket26 +try: + import dns.resolver +except ImportError: + DNSPYTHON = False +else: + DNSPYTHON = True + # The time in seconds to wait before timing out waiting for response stanzas. RESPONSE_TIMEOUT = 10 @@ -51,7 +58,6 @@ SSL_SUPPORT = True # Maximum time to delay between connection attempts is one hour. RECONNECT_MAX_DELAY = 600 - log = logging.getLogger(__name__) @@ -92,6 +98,7 @@ class XMLStream(object): events to be processed. filesocket -- A filesocket created from the main connection socket. Required for ElementTree.iterparse. + default_port -- Default port to connect to. namespace_map -- Optional mapping of namespaces to namespace prefixes. scheduler -- A scheduler object for triggering events after a given period of time. @@ -121,6 +128,7 @@ class XMLStream(object): reconnect_max_delay -- Maximum time to delay between connection attempts. Defaults to RECONNECT_MAX_DELAY, which is one hour. + dns_answers -- List of dns answers not yet used to connect. Methods: add_event_handler -- Add a handler for a custom event. @@ -177,6 +185,8 @@ class XMLStream(object): self.state = StateMachine(('disconnected', 'connected')) self.state._set_state('disconnected') + self.default_port = int(port) + self.default_domain = '' self.address = (host, int(port)) self.filesocket = None self.set_socket(socket) @@ -219,6 +229,7 @@ class XMLStream(object): self.auto_reconnect = True self.is_client = False + self.dns_answers = [] def use_signals(self, signals=None): """ @@ -303,6 +314,10 @@ class XMLStream(object): """ if host and port: self.address = (host, int(port)) + try: + Socket.inet_aton(self.address[0]) + except Socket.error: + self.default_domain = self.address[0] self.is_client = True # Respect previous SSL and TLS usage directives. @@ -322,6 +337,8 @@ class XMLStream(object): def _connect(self): self.stop.clear() + if self.default_domain: + self.address = self.pick_dns_answer(self.default_domain, self.address[1]) self.socket = self.socket_class(Socket.AF_INET, Socket.SOCK_STREAM) self.socket.settimeout(None) @@ -639,6 +656,51 @@ class XMLStream(object): idx += 1 return False + def get_dns_records(self, domain, port=None): + if port is None: + port = self.default_port + if DNSPYTHON: + try: + answers = dns.resolver.query(domain, dns.rdatatype.A) + except dns.resolver.NXDOMAIN, dns.resolver.NoAnswer: + log.warning("No A records for %s" % domain) + except dns.exception.Timeout: + log.warning("DNS resolution timed out for A record of %s" % domain) + answers = [((answer.address, port), 0, 0) for answer in answers] + return answers + else: + log.warning("dnspython is not installed -- relying on OS A record resolution") + return [((domain, port), 0, 0)] + + def pick_dns_answer(self, domain, port=None): + if not self.dns_answers: + self.dns_answers = self.get_dns_records(domain, port) + addresses = {} + intmax = 0 + topprio = 65535 + for answer in self.dns_answers: + topprio = min(topprio, answer[1]) + for answer in self.dns_answers: + if answer[1] == topprio: + intmax += answer[2] + addresses[intmax] = answer[0] + + #python3 returns a generator for dictionary keys + items = [x for x in addresses.keys()] + items.sort() + + picked = random.randint(0, intmax) + for item in items: + if picked <= item: + address = addresses[item] + break + for idx, answer in enumerate(self.dns_answers): + if self.dns_answers[0] == address: + break + self.dns_answers.pop(idx) + log.debug("Trying to connect to %s:%s" % address) + return address + def add_event_handler(self, name, pointer, threaded=False, disposable=False): """