diff options
| author | Blake Rouse <blake.rouse@canonical.com> | 2017-08-15 19:28:57 (GMT) |
|---|---|---|
| committer | Blake Rouse <blake.rouse@canonical.com> | 2017-08-15 19:28:57 (GMT) |
| commit | e34ededffc9cb96124ee2232793e0c064fdd735a (patch) | |
| tree | 906d3f98aba72ce81965ae98aa43a5965032e5d8 | |
| parent | fdd2a22d7ff9ebb99a5ce3e22124a105f29418dc (diff) | |
Backport c2aed3017ef73b5af23c72d40c9c0c0fc1cf475f and 6ffe84b98701cff8ead64d082b807642da83f326 from master.
LP: #1707971 - Only expose the source address for each subnet on a region controller.
| -rw-r--r-- | src/maasserver/rpc/regionservice.py | 19 | ||||
| -rw-r--r-- | src/maasserver/rpc/tests/test_regionservice.py | 35 | ||||
| -rw-r--r-- | src/provisioningserver/utils/network.py | 68 | ||||
| -rw-r--r-- | src/provisioningserver/utils/tests/test_network.py | 76 |
4 files changed, 182 insertions, 16 deletions
diff --git a/src/maasserver/rpc/regionservice.py b/src/maasserver/rpc/regionservice.py index 4d962e8..3ddf50c 100644 --- a/src/maasserver/rpc/regionservice.py +++ b/src/maasserver/rpc/regionservice.py @@ -81,6 +81,7 @@ from provisioningserver.security import calculate_digest from provisioningserver.utils.events import EventGroup from provisioningserver.utils.network import ( get_all_interface_addresses, + get_all_interface_source_addresses, resolves_to_loopback_address, ) from provisioningserver.utils.ps import is_pid_running @@ -1171,7 +1172,14 @@ class RegionAdvertisingService(service.Service): if port is None: return set() # Not serving yet. else: - addresses = set() + addresses = get_all_interface_source_addresses() + if len(addresses) > 0: + return set( + (addr, port) + for addr in addresses + ) + # There are no non-loopback addresses, so return loopback + # address as a fallback. loopback_addresses = set() for addr in get_all_interface_addresses(): ipaddr = IPAddress(addr) @@ -1179,14 +1187,7 @@ class RegionAdvertisingService(service.Service): continue # Don't advertise link-local addresses. if ipaddr.is_loopback(): loopback_addresses.add((addr, port)) - else: - addresses.add((addr, port)) - if len(addresses) > 0: - return addresses - else: - # There are no non-loopback addresses, so return loopback - # address as a fallback. - return loopback_addresses + return loopback_addresses class RegionAdvertising: diff --git a/src/maasserver/rpc/tests/test_regionservice.py b/src/maasserver/rpc/tests/test_regionservice.py index 972a3d7..0a8016a 100644 --- a/src/maasserver/rpc/tests/test_regionservice.py +++ b/src/maasserver/rpc/tests/test_regionservice.py @@ -65,6 +65,7 @@ from maastesting.matchers import ( MockAnyCall, MockCalledOnceWith, MockCallsMatch, + MockNotCalled, Provides, ) from maastesting.runtest import MAASCrochetRunTest @@ -1269,10 +1270,13 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase): getPort = getServiceNamed.return_value.getPort getPort.return_value = port - def patch_addresses(self, addresses): + def patch_addresses(self, source_addresses, old_addresses): + get_all_interface_source_addresses = self.patch( + regionservice, "get_all_interface_source_addresses") + get_all_interface_source_addresses.return_value = source_addresses get_all_interface_addresses = self.patch( regionservice, "get_all_interface_addresses") - get_all_interface_addresses.return_value = addresses + get_all_interface_addresses.return_value = old_addresses @wait_for_reactor @inlineCallbacks @@ -1286,7 +1290,7 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase): dump = yield deferToDatabase(RegionAdvertising.dump) self.assertItemsEqual([], dump) - def test__getAddresses_excluding_loopback(self): + def test__getAddresses_uses_source_addresses(self): service = RegionAdvertisingService() example_port = factory.pick_port() @@ -1311,11 +1315,10 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase): str(netaddr.ip.IPV6_LOOPBACK), } self.patch_addresses( - example_ipv4_addrs | example_ipv6_addrs | + example_ipv4_addrs | example_ipv6_addrs, example_link_local_addrs | example_loopback_addrs) - # IPv6 addresses, link-local addresses and loopback are excluded, and - # thus not advertised. + # Only the source addresses are returned. self.assertItemsEqual( [(addr, example_port) for addr in example_ipv4_addrs.union(example_ipv6_addrs)], @@ -1325,8 +1328,11 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase): eventloop.services.getServiceNamed, MockCalledOnceWith("rpc")) self.assertThat( - regionservice.get_all_interface_addresses, + regionservice.get_all_interface_source_addresses, MockCalledOnceWith()) + self.assertThat( + regionservice.get_all_interface_addresses, + MockNotCalled()) def test__getAddresses_including_loopback(self): service = RegionAdvertisingService() @@ -1334,6 +1340,16 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase): example_port = factory.pick_port() self.patch_port(example_port) + example_ipv4_addrs = set() + for _ in range(5): + ip = factory.make_ipv4_address() + if not netaddr.IPAddress(ip).is_loopback(): + example_ipv4_addrs.add(ip) + example_ipv6_addrs = set() + for _ in range(5): + ip = factory.make_ipv6_address() + if not netaddr.IPAddress(ip).is_loopback(): + example_ipv6_addrs.add(ip) example_link_local_addrs = { factory.pick_ip_in_network(netaddr.ip.IPV4_LINK_LOCAL), factory.pick_ip_in_network(netaddr.ip.IPV6_LINK_LOCAL), @@ -1344,6 +1360,8 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase): str(netaddr.ip.IPV6_LOOPBACK), } self.patch_addresses( + set(), + example_ipv4_addrs | example_ipv6_addrs | example_link_local_addrs | example_loopback_addrs) # Only IPv4 loopback is exposed. @@ -1355,6 +1373,9 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase): eventloop.services.getServiceNamed, MockCalledOnceWith("rpc")) self.assertThat( + regionservice.get_all_interface_source_addresses, + MockCalledOnceWith()) + self.assertThat( regionservice.get_all_interface_addresses, MockCalledOnceWith()) diff --git a/src/provisioningserver/utils/network.py b/src/provisioningserver/utils/network.py index 264921c..baf7148 100644 --- a/src/provisioningserver/utils/network.py +++ b/src/provisioningserver/utils/network.py @@ -94,6 +94,9 @@ OuterRange = TypeVar('OuterRange', IPRange, IPNetwork, bytes, str) # were passed into the `netaddr.IPAddress` constructor. MaybeIPAddress = TypeVar('MaybeIPAddress', IPAddress, bytes, str, int) +IPAddressOrNetwork = TypeVar( + 'IPAddressOrNetwork', IPNetwork, IPAddress, bytes, str, int) + class IPRANGE_TYPE: """Well-known purpose types for IP ranges.""" @@ -1191,6 +1194,39 @@ def get_all_interfaces_definition(annotate_with_monitored: bool=True) -> dict: return interfaces +def get_all_interface_subnets(): + """Returns all subnets that this machine has access to. + + Uses the `get_all_interfaces_definition` to get the available interfaces, + and returns a set of subnets for the machine. + + :return: set of IP networks + :rtype: set of `IPNetwork` + """ + return set( + IPNetwork(link["address"]) + for interface in get_all_interfaces_definition().values() + for link in interface["links"] + ) + + +def get_all_interface_source_addresses(): + """Return one source address per subnets defined on this machine. + + Uses the `get_all_interface_subnets` and `get_source_address` to determine + the best source addresses for this machine. + + :return: set of IP addresses + :rtype: set of `str` + """ + source_addresses = set() + for network in get_all_interface_subnets(): + src = get_source_address(network) + if src is not None: + source_addresses.add(src) + return source_addresses + + def has_ipv4_address(interfaces: dict, interface: str) -> bool: """Returns True if the specified interface has an IPv4 address assigned. @@ -1321,3 +1357,35 @@ def coerce_to_valid_hostname(hostname): if hostname == '' or len(hostname) > 64: return None return hostname + + +def get_source_address(destination_ip: IPAddressOrNetwork): + """Returns the local source address for the specified destination IP. + + :param destination_ip: Can be an IP address in string format, an IPNetwork, + or an IPAddress object. + :return: the string representation of the local IP address that would be + used for communication with the specified destination. + """ + if isinstance(destination_ip, IPNetwork): + destination_ip = IPAddress(destination_ip.first + 1) + else: + destination_ip = make_ipaddress(destination_ip) + af = AF_INET if destination_ip.version == 4 else AF_INET6 + with socket.socket(af, socket.SOCK_DGRAM) as sock: + peername = str(destination_ip) + local_address = "0.0.0.0" if af == socket.AF_INET else "::" + try: + # Note: this sets up the socket *just enough* to get the source + # address. No network traffic will be transmitted. + sock.bind((local_address, 0)) + sock.connect((peername, 7)) + sockname = sock.getsockname() + own_ip = sockname[0] + return own_ip + except OSError: + # Probably "can't assign requested address", which probably means + # we tried to connect to an IPv6 address, but IPv6 is not + # configured. Could also happen if a network or broadcast address + # is passed in, or we otherwise cannot route to the destination. + return None diff --git a/src/provisioningserver/utils/tests/test_network.py b/src/provisioningserver/utils/tests/test_network.py index 57498af..fbfd678 100644 --- a/src/provisioningserver/utils/tests/test_network.py +++ b/src/provisioningserver/utils/tests/test_network.py @@ -51,11 +51,14 @@ from provisioningserver.utils.network import ( format_eui, get_all_addresses_for_interface, get_all_interface_addresses, + get_all_interface_source_addresses, + get_all_interface_subnets, get_all_interfaces_definition, get_default_monitored_interfaces, get_eui_organization, get_interface_children, get_mac_organization, + get_source_address, has_ipv4_address, hex_str_to_bytes, inet_ntop, @@ -1518,6 +1521,52 @@ class TestGetAllInterfacesDefinition(MAASTestCase): self.assertInterfacesResult(ip_addr, iproute_info, {}, expected_result) +class TestGetAllInterfacesSubnets(MAASTestCase): + """Tests for `get_all_interface_subnets()`.""" + + def test_includes_unique_subnets(self): + interface_definition = { + 'eth0': { + 'links': [{ + 'address': '192.168.122.1/24', + }, { + 'address': '192.168.122.3/24', + }], + }, + 'eth1': { + 'links': [{ + 'address': '192.168.123.1/24', + }, { + 'address': '192.168.123.2/24', + }] + } + } + self.patch( + network_module, + 'get_all_interfaces_definition').return_value = ( + interface_definition) + self.assertEquals( + set([ + IPNetwork('192.168.122.1/24'), IPNetwork('192.168.123.1/24')]), + get_all_interface_subnets()) + + +class TestGetAllInterfacesSourceAddresses(MAASTestCase): + """Tests for `get_all_interface_source_addresses()`.""" + + def test_includes_unique_subnets(self): + interface_subnets = set([ + IPNetwork('192.168.122.1/24'), IPNetwork('192.168.123.1/24')]) + self.patch( + network_module, + 'get_all_interface_subnets').return_value = interface_subnets + self.patch(network_module, 'get_source_address').side_effect = [ + '192.168.122.1', None] + self.assertEquals( + set(['192.168.122.1']), + get_all_interface_source_addresses()) + + class TestHasIPv4Address(MAASTestCase): """Tests for `has_ipv4_address()`.""" @@ -2119,3 +2168,30 @@ class TestCoerceHostname(MAASTestCase): def test_returns_none_if_result_too_large(self): self.assertIsNone(coerce_to_valid_hostname('a' * 65)) + + +class TestGetSourceAddress(MAASTestCase): + + def test__accepts_ipnetwork(self): + self.assertThat( + get_source_address(IPNetwork("127.0.0.1/8")), Equals("127.0.0.1")) + + def test__accepts_ipaddress(self): + self.assertThat( + get_source_address(IPAddress("127.0.0.1")), Equals("127.0.0.1")) + + def test__accepts_string(self): + self.assertThat( + get_source_address("127.0.0.1"), Equals("127.0.0.1")) + + def test__supports_ipv6(self): + self.assertThat( + get_source_address("::1"), Equals("::1")) + + def test__returns_none_if_no_route_found(self): + self.assertThat( + get_source_address("127.0.0.0"), Is(None)) + + def test__returns_appropriate_address_for_global_ip(self): + self.assertThat( + get_source_address("8.8.8.8"), Not(Is(None))) |
