summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBlake Rouse <blake.rouse@canonical.com>2017-08-15 19:28:57 (GMT)
committerBlake Rouse <blake.rouse@canonical.com>2017-08-15 19:28:57 (GMT)
commite34ededffc9cb96124ee2232793e0c064fdd735a (patch)
tree906d3f98aba72ce81965ae98aa43a5965032e5d8
parentfdd2a22d7ff9ebb99a5ce3e22124a105f29418dc (diff)
Backport c2aed3017ef73b5af23c72d40c9c0c0fc1cf475f and 6ffe84b98701cff8ead64d082b807642da83f326 from master.
LP: #1707971 - Only expose the source address for each subnet on a region controller.
-rw-r--r--src/maasserver/rpc/regionservice.py19
-rw-r--r--src/maasserver/rpc/tests/test_regionservice.py35
-rw-r--r--src/provisioningserver/utils/network.py68
-rw-r--r--src/provisioningserver/utils/tests/test_network.py76
4 files changed, 182 insertions, 16 deletions
diff --git a/src/maasserver/rpc/regionservice.py b/src/maasserver/rpc/regionservice.py
index 4d962e8..3ddf50c 100644
--- a/src/maasserver/rpc/regionservice.py
+++ b/src/maasserver/rpc/regionservice.py
@@ -81,6 +81,7 @@ from provisioningserver.security import calculate_digest
from provisioningserver.utils.events import EventGroup
from provisioningserver.utils.network import (
get_all_interface_addresses,
+ get_all_interface_source_addresses,
resolves_to_loopback_address,
)
from provisioningserver.utils.ps import is_pid_running
@@ -1171,7 +1172,14 @@ class RegionAdvertisingService(service.Service):
if port is None:
return set() # Not serving yet.
else:
- addresses = set()
+ addresses = get_all_interface_source_addresses()
+ if len(addresses) > 0:
+ return set(
+ (addr, port)
+ for addr in addresses
+ )
+ # There are no non-loopback addresses, so return loopback
+ # address as a fallback.
loopback_addresses = set()
for addr in get_all_interface_addresses():
ipaddr = IPAddress(addr)
@@ -1179,14 +1187,7 @@ class RegionAdvertisingService(service.Service):
continue # Don't advertise link-local addresses.
if ipaddr.is_loopback():
loopback_addresses.add((addr, port))
- else:
- addresses.add((addr, port))
- if len(addresses) > 0:
- return addresses
- else:
- # There are no non-loopback addresses, so return loopback
- # address as a fallback.
- return loopback_addresses
+ return loopback_addresses
class RegionAdvertising:
diff --git a/src/maasserver/rpc/tests/test_regionservice.py b/src/maasserver/rpc/tests/test_regionservice.py
index 972a3d7..0a8016a 100644
--- a/src/maasserver/rpc/tests/test_regionservice.py
+++ b/src/maasserver/rpc/tests/test_regionservice.py
@@ -65,6 +65,7 @@ from maastesting.matchers import (
MockAnyCall,
MockCalledOnceWith,
MockCallsMatch,
+ MockNotCalled,
Provides,
)
from maastesting.runtest import MAASCrochetRunTest
@@ -1269,10 +1270,13 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase):
getPort = getServiceNamed.return_value.getPort
getPort.return_value = port
- def patch_addresses(self, addresses):
+ def patch_addresses(self, source_addresses, old_addresses):
+ get_all_interface_source_addresses = self.patch(
+ regionservice, "get_all_interface_source_addresses")
+ get_all_interface_source_addresses.return_value = source_addresses
get_all_interface_addresses = self.patch(
regionservice, "get_all_interface_addresses")
- get_all_interface_addresses.return_value = addresses
+ get_all_interface_addresses.return_value = old_addresses
@wait_for_reactor
@inlineCallbacks
@@ -1286,7 +1290,7 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase):
dump = yield deferToDatabase(RegionAdvertising.dump)
self.assertItemsEqual([], dump)
- def test__getAddresses_excluding_loopback(self):
+ def test__getAddresses_uses_source_addresses(self):
service = RegionAdvertisingService()
example_port = factory.pick_port()
@@ -1311,11 +1315,10 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase):
str(netaddr.ip.IPV6_LOOPBACK),
}
self.patch_addresses(
- example_ipv4_addrs | example_ipv6_addrs |
+ example_ipv4_addrs | example_ipv6_addrs,
example_link_local_addrs | example_loopback_addrs)
- # IPv6 addresses, link-local addresses and loopback are excluded, and
- # thus not advertised.
+ # Only the source addresses are returned.
self.assertItemsEqual(
[(addr, example_port)
for addr in example_ipv4_addrs.union(example_ipv6_addrs)],
@@ -1325,8 +1328,11 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase):
eventloop.services.getServiceNamed,
MockCalledOnceWith("rpc"))
self.assertThat(
- regionservice.get_all_interface_addresses,
+ regionservice.get_all_interface_source_addresses,
MockCalledOnceWith())
+ self.assertThat(
+ regionservice.get_all_interface_addresses,
+ MockNotCalled())
def test__getAddresses_including_loopback(self):
service = RegionAdvertisingService()
@@ -1334,6 +1340,16 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase):
example_port = factory.pick_port()
self.patch_port(example_port)
+ example_ipv4_addrs = set()
+ for _ in range(5):
+ ip = factory.make_ipv4_address()
+ if not netaddr.IPAddress(ip).is_loopback():
+ example_ipv4_addrs.add(ip)
+ example_ipv6_addrs = set()
+ for _ in range(5):
+ ip = factory.make_ipv6_address()
+ if not netaddr.IPAddress(ip).is_loopback():
+ example_ipv6_addrs.add(ip)
example_link_local_addrs = {
factory.pick_ip_in_network(netaddr.ip.IPV4_LINK_LOCAL),
factory.pick_ip_in_network(netaddr.ip.IPV6_LINK_LOCAL),
@@ -1344,6 +1360,8 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase):
str(netaddr.ip.IPV6_LOOPBACK),
}
self.patch_addresses(
+ set(),
+ example_ipv4_addrs | example_ipv6_addrs |
example_link_local_addrs | example_loopback_addrs)
# Only IPv4 loopback is exposed.
@@ -1355,6 +1373,9 @@ class TestRegionAdvertisingService(MAASTransactionServerTestCase):
eventloop.services.getServiceNamed,
MockCalledOnceWith("rpc"))
self.assertThat(
+ regionservice.get_all_interface_source_addresses,
+ MockCalledOnceWith())
+ self.assertThat(
regionservice.get_all_interface_addresses,
MockCalledOnceWith())
diff --git a/src/provisioningserver/utils/network.py b/src/provisioningserver/utils/network.py
index 264921c..baf7148 100644
--- a/src/provisioningserver/utils/network.py
+++ b/src/provisioningserver/utils/network.py
@@ -94,6 +94,9 @@ OuterRange = TypeVar('OuterRange', IPRange, IPNetwork, bytes, str)
# were passed into the `netaddr.IPAddress` constructor.
MaybeIPAddress = TypeVar('MaybeIPAddress', IPAddress, bytes, str, int)
+IPAddressOrNetwork = TypeVar(
+ 'IPAddressOrNetwork', IPNetwork, IPAddress, bytes, str, int)
+
class IPRANGE_TYPE:
"""Well-known purpose types for IP ranges."""
@@ -1191,6 +1194,39 @@ def get_all_interfaces_definition(annotate_with_monitored: bool=True) -> dict:
return interfaces
+def get_all_interface_subnets():
+ """Returns all subnets that this machine has access to.
+
+ Uses the `get_all_interfaces_definition` to get the available interfaces,
+ and returns a set of subnets for the machine.
+
+ :return: set of IP networks
+ :rtype: set of `IPNetwork`
+ """
+ return set(
+ IPNetwork(link["address"])
+ for interface in get_all_interfaces_definition().values()
+ for link in interface["links"]
+ )
+
+
+def get_all_interface_source_addresses():
+ """Return one source address per subnets defined on this machine.
+
+ Uses the `get_all_interface_subnets` and `get_source_address` to determine
+ the best source addresses for this machine.
+
+ :return: set of IP addresses
+ :rtype: set of `str`
+ """
+ source_addresses = set()
+ for network in get_all_interface_subnets():
+ src = get_source_address(network)
+ if src is not None:
+ source_addresses.add(src)
+ return source_addresses
+
+
def has_ipv4_address(interfaces: dict, interface: str) -> bool:
"""Returns True if the specified interface has an IPv4 address assigned.
@@ -1321,3 +1357,35 @@ def coerce_to_valid_hostname(hostname):
if hostname == '' or len(hostname) > 64:
return None
return hostname
+
+
+def get_source_address(destination_ip: IPAddressOrNetwork):
+ """Returns the local source address for the specified destination IP.
+
+ :param destination_ip: Can be an IP address in string format, an IPNetwork,
+ or an IPAddress object.
+ :return: the string representation of the local IP address that would be
+ used for communication with the specified destination.
+ """
+ if isinstance(destination_ip, IPNetwork):
+ destination_ip = IPAddress(destination_ip.first + 1)
+ else:
+ destination_ip = make_ipaddress(destination_ip)
+ af = AF_INET if destination_ip.version == 4 else AF_INET6
+ with socket.socket(af, socket.SOCK_DGRAM) as sock:
+ peername = str(destination_ip)
+ local_address = "0.0.0.0" if af == socket.AF_INET else "::"
+ try:
+ # Note: this sets up the socket *just enough* to get the source
+ # address. No network traffic will be transmitted.
+ sock.bind((local_address, 0))
+ sock.connect((peername, 7))
+ sockname = sock.getsockname()
+ own_ip = sockname[0]
+ return own_ip
+ except OSError:
+ # Probably "can't assign requested address", which probably means
+ # we tried to connect to an IPv6 address, but IPv6 is not
+ # configured. Could also happen if a network or broadcast address
+ # is passed in, or we otherwise cannot route to the destination.
+ return None
diff --git a/src/provisioningserver/utils/tests/test_network.py b/src/provisioningserver/utils/tests/test_network.py
index 57498af..fbfd678 100644
--- a/src/provisioningserver/utils/tests/test_network.py
+++ b/src/provisioningserver/utils/tests/test_network.py
@@ -51,11 +51,14 @@ from provisioningserver.utils.network import (
format_eui,
get_all_addresses_for_interface,
get_all_interface_addresses,
+ get_all_interface_source_addresses,
+ get_all_interface_subnets,
get_all_interfaces_definition,
get_default_monitored_interfaces,
get_eui_organization,
get_interface_children,
get_mac_organization,
+ get_source_address,
has_ipv4_address,
hex_str_to_bytes,
inet_ntop,
@@ -1518,6 +1521,52 @@ class TestGetAllInterfacesDefinition(MAASTestCase):
self.assertInterfacesResult(ip_addr, iproute_info, {}, expected_result)
+class TestGetAllInterfacesSubnets(MAASTestCase):
+ """Tests for `get_all_interface_subnets()`."""
+
+ def test_includes_unique_subnets(self):
+ interface_definition = {
+ 'eth0': {
+ 'links': [{
+ 'address': '192.168.122.1/24',
+ }, {
+ 'address': '192.168.122.3/24',
+ }],
+ },
+ 'eth1': {
+ 'links': [{
+ 'address': '192.168.123.1/24',
+ }, {
+ 'address': '192.168.123.2/24',
+ }]
+ }
+ }
+ self.patch(
+ network_module,
+ 'get_all_interfaces_definition').return_value = (
+ interface_definition)
+ self.assertEquals(
+ set([
+ IPNetwork('192.168.122.1/24'), IPNetwork('192.168.123.1/24')]),
+ get_all_interface_subnets())
+
+
+class TestGetAllInterfacesSourceAddresses(MAASTestCase):
+ """Tests for `get_all_interface_source_addresses()`."""
+
+ def test_includes_unique_subnets(self):
+ interface_subnets = set([
+ IPNetwork('192.168.122.1/24'), IPNetwork('192.168.123.1/24')])
+ self.patch(
+ network_module,
+ 'get_all_interface_subnets').return_value = interface_subnets
+ self.patch(network_module, 'get_source_address').side_effect = [
+ '192.168.122.1', None]
+ self.assertEquals(
+ set(['192.168.122.1']),
+ get_all_interface_source_addresses())
+
+
class TestHasIPv4Address(MAASTestCase):
"""Tests for `has_ipv4_address()`."""
@@ -2119,3 +2168,30 @@ class TestCoerceHostname(MAASTestCase):
def test_returns_none_if_result_too_large(self):
self.assertIsNone(coerce_to_valid_hostname('a' * 65))
+
+
+class TestGetSourceAddress(MAASTestCase):
+
+ def test__accepts_ipnetwork(self):
+ self.assertThat(
+ get_source_address(IPNetwork("127.0.0.1/8")), Equals("127.0.0.1"))
+
+ def test__accepts_ipaddress(self):
+ self.assertThat(
+ get_source_address(IPAddress("127.0.0.1")), Equals("127.0.0.1"))
+
+ def test__accepts_string(self):
+ self.assertThat(
+ get_source_address("127.0.0.1"), Equals("127.0.0.1"))
+
+ def test__supports_ipv6(self):
+ self.assertThat(
+ get_source_address("::1"), Equals("::1"))
+
+ def test__returns_none_if_no_route_found(self):
+ self.assertThat(
+ get_source_address("127.0.0.0"), Is(None))
+
+ def test__returns_appropriate_address_for_global_ip(self):
+ self.assertThat(
+ get_source_address("8.8.8.8"), Not(Is(None)))