From 8967f04f4f2f925478b493dc484cf326d7a3398d Mon Sep 17 00:00:00 2001 From: Damien Arnodo Date: Thu, 26 Feb 2026 12:44:07 +0000 Subject: [PATCH] feat(infrahub): add Infrahub client for fabric intent (#42) --- tests/test_infrahub_client.py | 598 ++++++++++++++++++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 tests/test_infrahub_client.py diff --git a/tests/test_infrahub_client.py b/tests/test_infrahub_client.py new file mode 100644 index 0000000..3e3f4f5 --- /dev/null +++ b/tests/test_infrahub_client.py @@ -0,0 +1,598 @@ +"""Unit tests for FabricInfrahubClient — all Infrahub SDK calls are mocked.""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from infrahub_client import ( + FabricInfrahubClient, + InfrahubNotFoundError, + InfrahubQueryError, +) +from infrahub_client.models import ( + BgpPeerGroupIntent, + BgpRouterConfigIntent, + BgpSessionIntent, + DeviceIntent, + EvpnInstanceIntent, + MlagDomainIntent, + MlagPeerConfigIntent, + VlanIntent, + VrfIntent, + VtepIntent, +) + + +# --------------------------------------------------------------------------- +# Helpers to build mock SDK node objects +# --------------------------------------------------------------------------- + + +def _attr(value): + """Return a simple attribute mock with ``.value`` set.""" + m = MagicMock() + m.value = value + return m + + +def _rel_one(peer_node, initialized=True): + """Return a mock for a cardinality-one relationship.""" + m = MagicMock() + m.initialized = initialized + m.peer = peer_node + m.fetch = AsyncMock() + return m + + +def _rel_many(peers, initialized=True): + """Return a mock for a cardinality-many relationship (list of peer wrappers).""" + m = MagicMock() + m.initialized = initialized + peer_wrappers = [] + for p in peers: + pw = MagicMock() + pw.peer = p + peer_wrappers.append(pw) + m.peers = peer_wrappers + return m + + +def _make_device( + name="leaf1", + role="leaf", + status="active", + platform_name="eos", + site_name="dc1", + asn_value=65001, +): + """Build a mock InfraDevice SDK node.""" + node = MagicMock() + node.name = _attr(name) + node.role = _attr(role) + node.status = _attr(status) + + platform = MagicMock() + platform.name = _attr(platform_name) + node.platform = _rel_one(platform) + + site = MagicMock() + site.name = _attr(site_name) + node.site = _rel_one(site) + + asn = MagicMock() + asn.asn = _attr(asn_value) + node.asn = _rel_one(asn) + + return node + + +def _make_sdk_client(get_return=None, all_return=None, filters_return=None, get_raises=None): + """Patch InfrahubClient inside client module and configure return values.""" + mock_sdk = MagicMock() + if get_raises: + mock_sdk.get = AsyncMock(side_effect=get_raises) + else: + mock_sdk.get = AsyncMock(return_value=get_return) + mock_sdk.all = AsyncMock(return_value=all_return or []) + mock_sdk.filters = AsyncMock(return_value=filters_return or []) + return mock_sdk + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def client(): + """FabricInfrahubClient with a patched InfrahubClient SDK.""" + with patch("infrahub_client.client.InfrahubClient") as MockSdk: + MockSdk.return_value = MagicMock() + c = FabricInfrahubClient(url="http://infrahub:8080", api_token="token", branch="main") + yield c + + +# --------------------------------------------------------------------------- +# get_device +# --------------------------------------------------------------------------- + + +class TestGetDevice: + async def test_returns_device_intent(self, client): + device_node = _make_device() + client._sdk.get = AsyncMock(return_value=device_node) + + result = await client.get_device("leaf1") + + assert isinstance(result, DeviceIntent) + assert result.name == "leaf1" + assert result.role == "leaf" + assert result.status == "active" + assert result.platform == "eos" + assert result.site == "dc1" + assert result.asn == 65001 + + async def test_not_found_raises_error(self, client): + client._sdk.get = AsyncMock(side_effect=Exception("not found")) + + with pytest.raises(InfrahubNotFoundError): + await client.get_device("nonexistent") + + async def test_branch_is_forwarded(self, client): + client._branch = "proposed" + device_node = _make_device() + client._sdk.get = AsyncMock(return_value=device_node) + + await client.get_device("leaf1") + + client._sdk.get.assert_called_once_with( + kind="InfraDevice", branch="proposed", name__value="leaf1" + ) + + async def test_caching_second_call_skips_sdk(self, client): + device_node = _make_device() + client._sdk.get = AsyncMock(return_value=device_node) + + first = await client.get_device("leaf1") + second = await client.get_device("leaf1") + + assert first == second + # SDK should only have been called once + assert client._sdk.get.call_count == 1 + + async def test_optional_fields_none_when_uninitialized(self, client): + node = MagicMock() + node.name = _attr("spine1") + node.role = _attr("spine") + node.status = _attr("active") + node.platform = _rel_one(None, initialized=False) + node.site = _rel_one(None, initialized=False) + node.asn = _rel_one(None, initialized=False) + client._sdk.get = AsyncMock(return_value=node) + + result = await client.get_device("spine1") + + assert result.platform is None + assert result.site is None + assert result.asn is None + + +# --------------------------------------------------------------------------- +# get_device_bgp_config +# --------------------------------------------------------------------------- + + +class TestGetDeviceBgpConfig: + def _make_bgp_node(self, router_id="10.0.0.1", asn=65001, default_ipv4=True, ecmp=4): + node = MagicMock() + node.router_id = _attr(router_id) + node.default_ipv4_unicast = _attr(default_ipv4) + node.ecmp_max_paths = _attr(ecmp) + + asn_peer = MagicMock() + asn_peer.asn = _attr(asn) + node.local_asn = _rel_one(asn_peer) + + return node + + async def test_returns_bgp_config_intent(self, client): + bgp_node = self._make_bgp_node() + client._sdk.get = AsyncMock(return_value=bgp_node) + + result = await client.get_device_bgp_config("leaf1") + + assert isinstance(result, BgpRouterConfigIntent) + assert result.router_id == "10.0.0.1" + assert result.local_asn == 65001 + assert result.default_ipv4_unicast is True + assert result.ecmp_max_paths == 4 + + async def test_not_found_raises_error(self, client): + client._sdk.get = AsyncMock(side_effect=Exception("not found")) + + with pytest.raises(InfrahubNotFoundError): + await client.get_device_bgp_config("leaf1") + + async def test_caching(self, client): + bgp_node = self._make_bgp_node() + client._sdk.get = AsyncMock(return_value=bgp_node) + + await client.get_device_bgp_config("leaf1") + await client.get_device_bgp_config("leaf1") + + assert client._sdk.get.call_count == 1 + + +# --------------------------------------------------------------------------- +# get_device_bgp_peer_groups +# --------------------------------------------------------------------------- + + +class TestGetDeviceBgpPeerGroups: + def _make_pg(self, name="UNDERLAY", pg_type="underlay", remote_asn=65000): + pg = MagicMock() + pg.name = _attr(name) + pg.peer_group_type = _attr(pg_type) + pg.update_source = _attr("Loopback0") + pg.send_community = _attr(True) + pg.ebgp_multihop = _attr(None) + pg.next_hop_unchanged = _attr(False) + + asn_peer = MagicMock() + asn_peer.asn = _attr(remote_asn) + pg.remote_asn = _rel_one(asn_peer) + + return pg + + async def test_returns_peer_group_intents(self, client): + pg = self._make_pg() + bgp_node = MagicMock() + bgp_node.peer_groups = _rel_many([pg]) + client._sdk.get = AsyncMock(return_value=bgp_node) + + results = await client.get_device_bgp_peer_groups("leaf1") + + assert len(results) == 1 + assert isinstance(results[0], BgpPeerGroupIntent) + assert results[0].name == "UNDERLAY" + assert results[0].peer_group_type == "underlay" + assert results[0].remote_asn == 65000 + assert results[0].send_community is True + + async def test_empty_when_no_peer_groups(self, client): + bgp_node = MagicMock() + bgp_node.peer_groups = _rel_many([], initialized=False) + client._sdk.get = AsyncMock(return_value=bgp_node) + + results = await client.get_device_bgp_peer_groups("leaf1") + + assert results == [] + + +# --------------------------------------------------------------------------- +# get_device_bgp_sessions +# --------------------------------------------------------------------------- + + +class TestGetDeviceBgpSessions: + def _make_session(self, peer_addr="192.168.1.1", enabled=True): + sess = MagicMock() + sess.peer_address = _attr(peer_addr) + sess.description = _attr("peer session") + sess.enabled = _attr(enabled) + + pg_peer = MagicMock() + pg_peer.name = _attr("UNDERLAY") + sess.peer_group = _rel_one(pg_peer) + + asn_peer = MagicMock() + asn_peer.asn = _attr(65000) + sess.remote_asn = _rel_one(asn_peer) + + return sess + + async def test_returns_session_intents(self, client): + sess = self._make_session() + bgp_node = MagicMock() + bgp_node.sessions = _rel_many([sess]) + client._sdk.get = AsyncMock(return_value=bgp_node) + + results = await client.get_device_bgp_sessions("leaf1") + + assert len(results) == 1 + assert isinstance(results[0], BgpSessionIntent) + assert results[0].peer_address == "192.168.1.1" + assert results[0].enabled is True + assert results[0].peer_group == "UNDERLAY" + assert results[0].remote_asn == 65000 + + async def test_empty_when_no_sessions(self, client): + bgp_node = MagicMock() + bgp_node.sessions = _rel_many([], initialized=False) + client._sdk.get = AsyncMock(return_value=bgp_node) + + results = await client.get_device_bgp_sessions("leaf1") + + assert results == [] + + +# --------------------------------------------------------------------------- +# get_device_vrfs +# --------------------------------------------------------------------------- + + +class TestGetDeviceVrfs: + def _make_assignment(self, vrf_name="PROD", rd="65001:100", vni=100100): + vni_peer = MagicMock() + vni_peer.vni = _attr(vni) + + rt_import = MagicMock() + rt_import.target = _attr("65001:100") + rt_export = MagicMock() + rt_export.target = _attr("65001:100") + + vrf_node = MagicMock() + vrf_node.name = _attr(vrf_name) + vrf_node.route_distinguisher = _attr(rd) + vrf_node.vrf_id = _attr(None) + vrf_node.l3vni = _rel_one(vni_peer) + vrf_node.import_targets = _rel_many([rt_import]) + vrf_node.export_targets = _rel_many([rt_export]) + + assignment = MagicMock() + assignment.vrf = _rel_one(vrf_node) + + return assignment + + async def test_returns_vrf_intents(self, client): + assignment = self._make_assignment() + client._sdk.filters = AsyncMock(return_value=[assignment]) + + results = await client.get_device_vrfs("leaf1") + + assert len(results) == 1 + assert isinstance(results[0], VrfIntent) + assert results[0].name == "PROD" + assert results[0].route_distinguisher == "65001:100" + assert results[0].l3vni == 100100 + assert "65001:100" in results[0].import_targets + + async def test_empty_when_no_assignments(self, client): + client._sdk.filters = AsyncMock(return_value=[]) + + results = await client.get_device_vrfs("leaf1") + + assert results == [] + + async def test_caching(self, client): + assignment = self._make_assignment() + client._sdk.filters = AsyncMock(return_value=[assignment]) + + await client.get_device_vrfs("leaf1") + await client.get_device_vrfs("leaf1") + + assert client._sdk.filters.call_count == 1 + + +# --------------------------------------------------------------------------- +# get_device_vtep +# --------------------------------------------------------------------------- + + +class TestGetDeviceVtep: + def _make_vtep(self, src_addr="10.255.0.1", udp_port=4789): + vlan_node = MagicMock() + vlan_node.vlan_id = _attr(100) + vni_node = MagicMock() + vni_node.vni = _attr(10100) + + mapping = MagicMock() + mapping.vlan = _rel_one(vlan_node) + mapping.vni = _rel_one(vni_node) + + vtep = MagicMock() + vtep.source_address = _attr(src_addr) + vtep.udp_port = _attr(udp_port) + vtep.learn_restrict = _attr(True) + vtep.vlan_vni_mappings = _rel_many([mapping]) + + return vtep + + async def test_returns_vtep_intent(self, client): + vtep = self._make_vtep() + client._sdk.filters = AsyncMock(return_value=[vtep]) + + result = await client.get_device_vtep("leaf1") + + assert isinstance(result, VtepIntent) + assert result.source_address == "10.255.0.1" + assert result.udp_port == 4789 + assert result.learn_restrict is True + assert (100, 10100) in result.vlan_vni_mappings + + async def test_returns_none_when_no_vtep(self, client): + client._sdk.filters = AsyncMock(return_value=[]) + + result = await client.get_device_vtep("leaf1") + + assert result is None + + +# --------------------------------------------------------------------------- +# get_device_evpn_instances +# --------------------------------------------------------------------------- + + +class TestGetDeviceEvpnInstances: + def _make_evpn_instance(self, rd="65001:100", rt_import="65001:100", rt_export="65001:100"): + vlan_node = MagicMock() + vlan_node.vlan_id = _attr(100) + + node = MagicMock() + node.route_distinguisher = _attr(rd) + node.route_target_import = _attr(rt_import) + node.route_target_export = _attr(rt_export) + node.redistribute_learned = _attr(True) + node.vlan = _rel_one(vlan_node) + + return node + + async def test_returns_evpn_instance_intents(self, client): + evpn_node = self._make_evpn_instance() + client._sdk.filters = AsyncMock(return_value=[evpn_node]) + + results = await client.get_device_evpn_instances("leaf1") + + assert len(results) == 1 + assert isinstance(results[0], EvpnInstanceIntent) + assert results[0].route_distinguisher == "65001:100" + assert results[0].redistribute_learned is True + assert results[0].vlan_id == 100 + + async def test_empty_when_no_instances(self, client): + client._sdk.filters = AsyncMock(return_value=[]) + + results = await client.get_device_evpn_instances("leaf1") + + assert results == [] + + +# --------------------------------------------------------------------------- +# get_mlag_domain +# --------------------------------------------------------------------------- + + +class TestGetMlagDomain: + def _make_domain(self, device_names=("leaf1", "leaf2")): + devices = [] + for dname in device_names: + d = MagicMock() + d.name = _attr(dname) + devices.append(d) + + peer_vlan_node = MagicMock() + peer_vlan_node.vlan_id = _attr(4094) + + domain = MagicMock() + domain.domain_id = _attr("1") + domain.virtual_mac = _attr("00:1c:73:00:00:01") + domain.heartbeat_vrf = _attr("MGMT") + domain.dual_primary_detection = _attr(True) + domain.dual_primary_delay = _attr(10) + domain.dual_primary_action = _attr("errDisable") + domain.devices = _rel_many(devices) + domain.peer_vlan = _rel_one(peer_vlan_node) + + return domain + + async def test_returns_mlag_domain_intent(self, client): + domain = self._make_domain() + client._sdk.all = AsyncMock(return_value=[domain]) + + result = await client.get_mlag_domain("leaf1") + + assert isinstance(result, MlagDomainIntent) + assert result.domain_id == "1" + assert result.virtual_mac == "00:1c:73:00:00:01" + assert "leaf1" in result.peer_devices + assert "leaf2" in result.peer_devices + + async def test_returns_none_when_device_not_in_any_domain(self, client): + domain = self._make_domain(device_names=("leaf3", "leaf4")) + client._sdk.all = AsyncMock(return_value=[domain]) + + result = await client.get_mlag_domain("leaf1") + + assert result is None + + async def test_caching(self, client): + domain = self._make_domain() + client._sdk.all = AsyncMock(return_value=[domain]) + + await client.get_mlag_domain("leaf1") + await client.get_mlag_domain("leaf1") + + assert client._sdk.all.call_count == 1 + + +# --------------------------------------------------------------------------- +# get_mlag_peer_config +# --------------------------------------------------------------------------- + + +class TestGetMlagPeerConfig: + def _make_peer_config(self): + peer_link_node = MagicMock() + peer_link_node.name = _attr("Port-Channel1") + + cfg = MagicMock() + cfg.local_interface_ip = _attr("192.168.255.0/31") + cfg.peer_address = _attr("192.168.255.1") + cfg.heartbeat_peer_ip = _attr("192.168.1.2") + cfg.peer_link = _rel_one(peer_link_node) + + return cfg + + async def test_returns_mlag_peer_config_intent(self, client): + cfg = self._make_peer_config() + client._sdk.filters = AsyncMock(return_value=[cfg]) + + result = await client.get_mlag_peer_config("leaf1") + + assert isinstance(result, MlagPeerConfigIntent) + assert result.local_interface_ip == "192.168.255.0/31" + assert result.peer_address == "192.168.255.1" + assert result.peer_link == "Port-Channel1" + + async def test_returns_none_when_no_config(self, client): + client._sdk.filters = AsyncMock(return_value=[]) + + result = await client.get_mlag_peer_config("leaf1") + + assert result is None + + +# --------------------------------------------------------------------------- +# Async context manager +# --------------------------------------------------------------------------- + + +class TestContextManager: + async def test_context_manager(self): + with patch("infrahub_client.client.InfrahubClient"): + async with FabricInfrahubClient( + url="http://infrahub:8080", api_token="token" + ) as client: + assert isinstance(client, FabricInfrahubClient) + + +# --------------------------------------------------------------------------- +# Cache TTL expiry +# --------------------------------------------------------------------------- + + +class TestCacheTtl: + async def test_expired_cache_triggers_new_sdk_call(self, client): + device_node = _make_device() + client._sdk.get = AsyncMock(return_value=device_node) + + await client.get_device("leaf1") + # Manually expire the cache entry + client._cache["device:leaf1"] = (client._cache["device:leaf1"][0], 0.0) + await client.get_device("leaf1") + + assert client._sdk.get.call_count == 2 + + +# --------------------------------------------------------------------------- +# InfrahubQueryError propagation +# --------------------------------------------------------------------------- + + +class TestQueryError: + async def test_sdk_unexpected_error_raises_query_error(self, client): + client._sdk.get = AsyncMock(side_effect=Exception("connection refused")) + + with pytest.raises(InfrahubQueryError): + await client.get_device_bgp_config("leaf1")