diff options
| -rw-r--r-- | pyptlib/server_config.py | 40 | ||||
| -rw-r--r-- | pyptlib/test/test_server.py | 2 | ||||
| -rw-r--r-- | pyptlib/test/test_util.py | 34 | ||||
| -rw-r--r-- | pyptlib/util.py | 85 |
4 files changed, 133 insertions, 28 deletions
diff --git a/pyptlib/server_config.py b/pyptlib/server_config.py index 4ceb3f4..ba0fa63 100644 --- a/pyptlib/server_config.py +++ b/pyptlib/server_config.py @@ -6,6 +6,7 @@ Low-level parts of pyptlib that are only useful to servers. """ import pyptlib.config as config +import pyptlib.util as util class ServerConfig(config.Config): """ @@ -56,7 +57,13 @@ class ServerConfig(config.Config): bindaddrs = self.get('TOR_PT_SERVER_BINDADDR').split(',') for bindaddr in bindaddrs: (transport_name, addrport) = bindaddr.split('-') - (addr, port) = self.get_addrport_from_string(addrport) + + try: + (addr, port) = util.parse_addr_spec(addrport) + except ValueError, err: + self.writeEnvError(err) + raise config.EnvError(err) + self.serverBindAddr[transport_name] = (addr, port) # Get transports. @@ -148,30 +155,9 @@ class ServerConfig(config.Config): """ string = self.get(key) - return self.get_addrport_from_string(string) - - def get_addrport_from_string(self, string): - """ - Parse a string holding an address:port value. - - :param str string: A string. - - :returns: tuple -- (address,port) - - :raises: :class:`pyptlib.config.EnvError` if string was not in address:port format. - """ - - addrport = string.split(':') - - if (len(addrport) != 2) or (not addrport[1].isdigit()): - message = 'Parsing error (%s).' % (string) - self.writeEnvError(message) - raise config.EnvError(message) # XXX maybe return ValueError - - if (not 0 <= int(addrport[1]) < 65536): - message = 'Port out of range (%s).' % (string) - self.writeEnvError(message) - raise config.EnvError(message) - - return addrport + try: + return util.parse_addr_spec(string) + except ValueError, err: + self.writeEnvError(err) + raise config.EnvError(err) diff --git a/pyptlib/test/test_server.py b/pyptlib/test/test_server.py index e5896e8..a73fb8c 100644 --- a/pyptlib/test/test_server.py +++ b/pyptlib/test/test_server.py @@ -176,7 +176,7 @@ class testServer(unittest.TestCase): os.environ = TEST_ENVIRON retval = pyptlib.server.init(["what"]) self.assertEquals(retval['auth_cookie_file'], '/lulzie') - self.assertEquals(retval['ext_orport'], ['127.0.0.1', '5555']) + self.assertEquals(retval['ext_orport'], ('127.0.0.1', 5555)) def test_ext_or_but_no_auth_cookie(self): """TOR_PT_EXTENDED_SERVER_PORT without TOR_PT_AUTH_COOKIE_FILE.""" diff --git a/pyptlib/test/test_util.py b/pyptlib/test/test_util.py new file mode 100644 index 0000000..9b4e809 --- /dev/null +++ b/pyptlib/test/test_util.py @@ -0,0 +1,34 @@ +import unittest + +import pyptlib.util + +# Tests borrowed from flashproxy. +class ParseAddrSpecTest(unittest.TestCase): + def test_ipv4(self): + self.assertEqual(pyptlib.util.parse_addr_spec("192.168.0.1:9999"), ("192.168.0.1", 9999)) + + def test_ipv6(self): + self.assertEqual(pyptlib.util.parse_addr_spec("[12::34]:9999"), ("12::34", 9999)) + + def test_defhost_defport_ipv4(self): + self.assertEqual(pyptlib.util.parse_addr_spec("192.168.0.2:8888", defhost="192.168.0.1", defport=9999), ("192.168.0.2", 8888)) + self.assertEqual(pyptlib.util.parse_addr_spec("192.168.0.2:", defhost="192.168.0.1", defport=9999), ("192.168.0.2", 9999)) + self.assertEqual(pyptlib.util.parse_addr_spec("192.168.0.2", defhost="192.168.0.1", defport=9999), ("192.168.0.2", 9999)) + self.assertEqual(pyptlib.util.parse_addr_spec(":8888", defhost="192.168.0.1", defport=9999), ("192.168.0.1", 8888)) + self.assertEqual(pyptlib.util.parse_addr_spec(":", defhost="192.168.0.1", defport=9999), ("192.168.0.1", 9999)) + self.assertEqual(pyptlib.util.parse_addr_spec("", defhost="192.168.0.1", defport=9999), ("192.168.0.1", 9999)) + + def test_defhost_defport_ipv6(self): + self.assertEqual(pyptlib.util.parse_addr_spec("[1234::2]:8888", defhost="1234::1", defport=9999), ("1234::2", 8888)) + self.assertEqual(pyptlib.util.parse_addr_spec("[1234::2]:", defhost="1234::1", defport=9999), ("1234::2", 9999)) + self.assertEqual(pyptlib.util.parse_addr_spec("[1234::2]", defhost="1234::1", defport=9999), ("1234::2", 9999)) + self.assertEqual(pyptlib.util.parse_addr_spec(":8888", defhost="1234::1", defport=9999), ("1234::1", 8888)) + self.assertEqual(pyptlib.util.parse_addr_spec(":", defhost="1234::1", defport=9999), ("1234::1", 9999)) + self.assertEqual(pyptlib.util.parse_addr_spec("", defhost="1234::1", defport=9999), ("1234::1", 9999)) + + def test_noresolve(self): + """Test that parse_addr_spec does not do DNS resolution by default.""" + self.assertRaises(ValueError, pyptlib.util.parse_addr_spec, "example.com") + +if __name__ == "__main__": + unittest.main() diff --git a/pyptlib/util.py b/pyptlib/util.py index 9b75786..9966625 100644 --- a/pyptlib/util.py +++ b/pyptlib/util.py @@ -5,6 +5,9 @@ Utility functions. """ +import re +import socket + from pyptlib.config import Config, EnvError def checkClientMode(): # XXX WTF!???! This also exists in config.py. @@ -19,4 +22,86 @@ def checkClientMode(): # XXX WTF!???! This also exists in config.py. except EnvError: return False +# This code is borrowed from flashproxy. Thanks David! +def parse_addr_spec(spec, defhost = None, defport = None, resolve = False): + """ + Parse a host:port specification and return a 2-tuple ("host", port) as + understood by the Python socket functions. + + If resolve is true, then the host in the specification or the defhost may be + a domain name, which will be resolved. If resolve is false, then the host + must be a numeric IPv4 or IPv6 address. + + IPv6 addresses must be enclosed in square brackets. + + :returns: tuple -- (address, port) + + :raises: ValueError if spec is not well formed. + + >>> parse_addr_spec("192.168.0.1:9999") + ('192.168.0.1', 9999) + + If defhost or defport are given, those parts of the specification may be + omitted; if so, they will be filled in with defaults. + >>> parse_addr_spec("192.168.0.2:8888", defhost="192.168.0.1", defport=9999) + ('192.168.0.2', 8888) + >>> parse_addr_spec(":8888", defhost="192.168.0.1", defport=9999) + ('192.168.0.1', 8888) + >>> parse_addr_spec("192.168.0.2", defhost="192.168.0.1", defport=9999) + ('192.168.0.2', 9999) + >>> parse_addr_spec("192.168.0.2:", defhost="192.168.0.1", defport=9999) + ('192.168.0.2', 9999) + >>> parse_addr_spec(":", defhost="192.168.0.1", defport=9999) + ('192.168.0.1', 9999) + >>> parse_addr_spec("", defhost="192.168.0.1", defport=9999) + ('192.168.0.1', 9999) + """ + host = None + port = None + af = 0 + m = None + # IPv6 syntax. + if not m: + m = re.match(ur'^\[(.+)\]:(\d*)$', spec) + if m: + host, port = m.groups() + af = socket.AF_INET6 + if not m: + m = re.match(ur'^\[(.+)\]$', spec) + if m: + host, = m.groups() + af = socket.AF_INET6 + # IPv4/hostname/port-only syntax. + if not m: + try: + host, port = spec.split(":", 1) + except ValueError: + host = spec + if re.match(ur'^[\d.]+$', host): + af = socket.AF_INET + else: + af = 0 + host = host or defhost + port = port or defport + if host is None or port is None: + raise ValueError("Bad address specification \"%s\"" % spec) + + # Now we have split around the colon and have a guess at the address family. + # Forward-resolve the name into an addrinfo struct. Real DNS resolution is + # done only if resolve is true; otherwise the address must be numeric. + if resolve: + flags = 0 + else: + flags = socket.AI_NUMERICHOST + try: + addrs = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, socket.IPPROTO_TCP, flags) + except socket.gaierror, e: + raise ValueError("Bad host or port: \"%s\" \"%s\": %s" % (host, port, str(e))) + if not addrs: + raise ValueError("Bad host or port: \"%s\" \"%s\"" % (host, port)) + # Convert the result of socket.getaddrinfo (which is a 2-tuple for IPv4 and + # a 4-tuple for IPv6) into a (host, port) 2-tuple. + host, port = socket.getnameinfo(addrs[0][4], socket.NI_NUMERICHOST | socket.NI_NUMERICSERV) + port = int(port) + return host, port |
