""" Unit tests for FabricInfrahubClient. All SDK interactions are mocked with AsyncMock — no real Infrahub instance required. """ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import pytest # noqa: I001 from src.infrahub.client import FabricInfrahubClient from src.infrahub.exceptions import InfrahubNotFoundError, InfrahubQueryError from src.infrahub.models import ( BgpPeerGroupIntent, BgpRouterConfigIntent, BgpSessionIntent, DeviceIntent, EvpnInstanceIntent, MlagDomainIntent, MlagPeerConfigIntent, VlanIntent, VrfIntent, VtepIntent, ) # ============================================================================= # Helpers — mock node factories # ============================================================================= def _attr(value): """Return a simple mock whose .value is set.""" m = MagicMock() m.value = value return m def _rel_one(peer_mock): """Return a mock relationship (cardinality one) with an already-loaded peer.""" m = MagicMock() m.peer = peer_mock m.fetch = AsyncMock() return m def _rel_many(peer_mocks): """Return a mock relationship (cardinality many) whose .peers yields (peer_rel, peer).""" peers = [] for p in peer_mocks: rel = MagicMock() rel.peer = p peers.append(rel) m = MagicMock() m.peers = peers return m def make_device_node( name="leaf1", role="leaf", status="active", platform_name="EOS", site_name="dc1", asn_val=65001, ): node = MagicMock() node.name = _attr(name) node.role = _attr(role) node.status = _attr(status) platform = MagicMock() platform.name = _attr(platform_name) node.platform = _rel_one(platform) site = MagicMock() site.name = _attr(site_name) node.site = _rel_one(site) asn_node = MagicMock() asn_node.asn = _attr(asn_val) node.asn = _rel_one(asn_node) return node def make_vlan_node(vlan_id=10, name="VLAN10", status="active", vlan_type="standard", vni=10010): node = MagicMock() node.vlan_id = _attr(vlan_id) node.name = _attr(name) node.status = _attr(status) node.vlan_type = 
_attr(vlan_type) node.stp_enabled = _attr(True) vni_node = MagicMock() vni_node.vni = _attr(vni) node.vni = _rel_one(vni_node) return node def make_bgp_config_node(router_id="10.0.0.1", asn=65001, device_name="leaf1"): node = MagicMock() node.router_id = _attr(router_id) node.default_ipv4_unicast = _attr(True) node.ecmp_max_paths = _attr(4) asn_node = MagicMock() asn_node.asn = _attr(asn) node.local_asn = _rel_one(asn_node) device = MagicMock() device.name = _attr(device_name) node.device = _rel_one(device) node.peer_groups = _rel_many([]) node.sessions = _rel_many([]) return node # ============================================================================= # Fixtures # ============================================================================= @pytest.fixture def mock_sdk_client(): """Patch InfrahubClient and Config so no real connection is made.""" with ( patch("src.infrahub.client.InfrahubClient") as mock_cls, patch("src.infrahub.client.Config"), ): sdk = AsyncMock() mock_cls.return_value = sdk yield sdk @pytest.fixture def client(mock_sdk_client): """Return a FabricInfrahubClient backed by a mocked SDK.""" return FabricInfrahubClient(url="http://infrahub:8080", api_token="test-token", branch="main") # ============================================================================= # Context manager # ============================================================================= @pytest.mark.asyncio async def test_context_manager(mock_sdk_client): async with FabricInfrahubClient(url="http://infrahub:8080", api_token="test-token") as c: assert isinstance(c, FabricInfrahubClient) # ============================================================================= # get_device # ============================================================================= @pytest.mark.asyncio async def test_get_device_returns_device_intent(client, mock_sdk_client): mock_sdk_client.get = AsyncMock(return_value=make_device_node()) result = await client.get_device("leaf1") assert 
isinstance(result, DeviceIntent) assert result.name == "leaf1" assert result.role == "leaf" assert result.status == "active" assert result.platform == "EOS" assert result.site == "dc1" assert result.asn == 65001 @pytest.mark.asyncio async def test_get_device_not_found_raises(client, mock_sdk_client): mock_sdk_client.get = AsyncMock(side_effect=Exception("not found")) with pytest.raises(InfrahubNotFoundError): await client.get_device("ghost-device") @pytest.mark.asyncio async def test_get_device_none_raises_not_found(client, mock_sdk_client): mock_sdk_client.get = AsyncMock(return_value=None) with pytest.raises(InfrahubNotFoundError): await client.get_device("missing") @pytest.mark.asyncio async def test_get_device_query_error(client, mock_sdk_client): mock_sdk_client.get = AsyncMock(side_effect=Exception("connection refused")) with pytest.raises(InfrahubQueryError): await client.get_device("leaf1") # ============================================================================= # Caching # ============================================================================= @pytest.mark.asyncio async def test_get_device_caches_result(client, mock_sdk_client): mock_sdk_client.get = AsyncMock(return_value=make_device_node()) result1 = await client.get_device("leaf1") result2 = await client.get_device("leaf1") assert result1 == result2 # SDK should only be called once despite two client calls mock_sdk_client.get.assert_called_once() @pytest.mark.asyncio async def test_cache_separate_keys_per_device(client, mock_sdk_client): mock_sdk_client.get = AsyncMock( side_effect=[ make_device_node(name="leaf1"), make_device_node(name="leaf2", asn=65002), ] ) r1 = await client.get_device("leaf1") r2 = await client.get_device("leaf2") assert r1.name == "leaf1" assert r2.name == "leaf2" assert mock_sdk_client.get.call_count == 2 # ============================================================================= # Branch selection # 
# =============================================================================


def test_branch_passed_to_config():
    """The ``branch`` argument is forwarded to the SDK Config as ``default_branch``."""
    with (
        patch("src.infrahub.client.InfrahubClient"),
        patch("src.infrahub.client.Config") as mock_cfg,
    ):
        FabricInfrahubClient(url="http://infrahub:8080", api_token="tok", branch="proposed-change")
        mock_cfg.assert_called_once_with(
            address="http://infrahub:8080",
            api_token="tok",
            default_branch="proposed-change",
        )


# =============================================================================
# get_device_bgp_config
# =============================================================================


@pytest.mark.asyncio
async def test_get_device_bgp_config(client, mock_sdk_client):
    """The first BGP config node returned by filters() becomes a BgpRouterConfigIntent."""
    bgp_node = make_bgp_config_node(router_id="10.0.0.1", asn=65001, device_name="leaf1")
    mock_sdk_client.filters = AsyncMock(return_value=[bgp_node])

    result = await client.get_device_bgp_config("leaf1")

    assert isinstance(result, BgpRouterConfigIntent)
    assert result.router_id == "10.0.0.1"
    assert result.local_asn == 65001
    assert result.default_ipv4_unicast is True
    assert result.ecmp_max_paths == 4


@pytest.mark.asyncio
async def test_get_device_bgp_config_not_found(client, mock_sdk_client):
    """An empty filters() result raises InfrahubNotFoundError."""
    mock_sdk_client.filters = AsyncMock(return_value=[])

    with pytest.raises(InfrahubNotFoundError):
        await client.get_device_bgp_config("leaf1")


# =============================================================================
# get_device_bgp_peer_groups
# =============================================================================


@pytest.mark.asyncio
async def test_get_device_bgp_peer_groups(client, mock_sdk_client):
    """Peer groups hanging off the BGP config are converted to BgpPeerGroupIntent."""
    pg_node = MagicMock()
    pg_node.name = _attr("EVPN-PEERS")
    pg_node.peer_group_type = _attr("evpn")
    pg_node.update_source = _attr("Loopback0")
    pg_node.send_community = _attr("extended")
    pg_node.ebgp_multihop = _attr(3)
    pg_node.next_hop_unchanged = _attr(True)

    remote_asn_node = MagicMock()
    remote_asn_node.asn = _attr(65000)
    pg_node.remote_asn = _rel_one(remote_asn_node)

    bgp_node = make_bgp_config_node(device_name="leaf1")
    bgp_node.peer_groups = _rel_many([pg_node])
    mock_sdk_client.filters = AsyncMock(return_value=[bgp_node])

    result = await client.get_device_bgp_peer_groups("leaf1")

    assert len(result) == 1
    pg = result[0]
    assert isinstance(pg, BgpPeerGroupIntent)
    assert pg.name == "EVPN-PEERS"
    assert pg.peer_group_type == "evpn"
    assert pg.remote_asn == 65000
    assert pg.send_community == "extended"
    assert pg.next_hop_unchanged is True


# =============================================================================
# get_device_bgp_sessions
# =============================================================================


@pytest.mark.asyncio
async def test_get_device_bgp_sessions(client, mock_sdk_client):
    """Sessions on the BGP config become BgpSessionIntent with resolved peer group and ASN."""
    sess_node = MagicMock()
    sess_node.peer_address = _attr("10.0.0.2")
    sess_node.description = _attr("to-spine1")
    sess_node.enabled = _attr(True)

    pg = MagicMock()
    pg.name = _attr("UNDERLAY")
    sess_node.peer_group = _rel_one(pg)

    remote_asn = MagicMock()
    remote_asn.asn = _attr(65000)
    sess_node.remote_asn = _rel_one(remote_asn)

    bgp_node = make_bgp_config_node(device_name="leaf1")
    bgp_node.sessions = _rel_many([sess_node])
    mock_sdk_client.filters = AsyncMock(return_value=[bgp_node])

    result = await client.get_device_bgp_sessions("leaf1")

    assert len(result) == 1
    sess = result[0]
    assert isinstance(sess, BgpSessionIntent)
    assert sess.peer_address == "10.0.0.2"
    assert sess.description == "to-spine1"
    assert sess.enabled is True
    assert sess.peer_group == "UNDERLAY"
    assert sess.remote_asn == 65000


# =============================================================================
# get_device_vrfs
# =============================================================================


@pytest.mark.asyncio
async def test_get_device_vrfs(client, mock_sdk_client):
    """A VRF assignment is flattened into a VrfIntent with RD, L3VNI and route targets."""
    rt_import = MagicMock()
    rt_import.target = _attr("65000:100")
    rt_export = MagicMock()
    rt_export.target = _attr("65000:100")

    vni_node = MagicMock()
    vni_node.vni = _attr(10000)

    vrf_node = MagicMock()
    vrf_node.name = _attr("PROD")
    vrf_node.vrf_id = _attr(100)
    vrf_node.l3vni = _rel_one(vni_node)

    device = MagicMock()
    device.name = _attr("leaf1")

    asgn = MagicMock()
    asgn.route_distinguisher = _attr("10.0.0.1:100")
    asgn.vrf = _rel_one(vrf_node)
    asgn.device = _rel_one(device)
    asgn.import_targets = _rel_many([rt_import])
    asgn.export_targets = _rel_many([rt_export])
    mock_sdk_client.filters = AsyncMock(return_value=[asgn])

    result = await client.get_device_vrfs("leaf1")

    assert len(result) == 1
    vrf = result[0]
    assert isinstance(vrf, VrfIntent)
    assert vrf.name == "PROD"
    assert vrf.route_distinguisher == "10.0.0.1:100"
    assert vrf.vrf_id == 100
    assert vrf.l3vni == 10000
    assert vrf.import_targets == ["65000:100"]
    assert vrf.export_targets == ["65000:100"]


# =============================================================================
# get_device_vtep
# =============================================================================


@pytest.mark.asyncio
async def test_get_device_vtep(client, mock_sdk_client):
    """A VTEP node becomes a VtepIntent including its (vlan_id, vni) mappings."""
    vlan_node = MagicMock()
    vlan_node.vlan_id = _attr(10)
    vni_node = MagicMock()
    vni_node.vni = _attr(10010)

    # _rel_one already provides an awaitable (AsyncMock) .fetch on each side.
    mapping = MagicMock()
    mapping.vlan = _rel_one(vlan_node)
    mapping.vni = _rel_one(vni_node)

    device = MagicMock()
    device.name = _attr("leaf1")

    vtep = MagicMock()
    vtep.source_address = _attr("10.0.0.1")
    vtep.udp_port = _attr(4789)
    vtep.learn_restrict = _attr(False)
    vtep.device = _rel_one(device)
    vtep.vlan_vni_mappings = _rel_many([mapping])
    mock_sdk_client.filters = AsyncMock(return_value=[vtep])

    result = await client.get_device_vtep("leaf1")

    assert isinstance(result, VtepIntent)
    assert result.source_address == "10.0.0.1"
    assert result.udp_port == 4789
    assert result.learn_restrict is False
    assert (10, 10010) in result.vlan_vni_mappings


@pytest.mark.asyncio
async def test_get_device_vtep_none_when_missing(client, mock_sdk_client):
    """No VTEP node for the device yields None rather than raising."""
    mock_sdk_client.filters = AsyncMock(return_value=[])

    result = await client.get_device_vtep("leaf1")

    assert result is None


# =============================================================================
# get_device_evpn_instances
# =============================================================================


@pytest.mark.asyncio
async def test_get_device_evpn_instances(client, mock_sdk_client):
    """EVPN instance nodes become EvpnInstanceIntent with the related VLAN id."""
    vlan_node = MagicMock()
    vlan_node.vlan_id = _attr(10)
    device = MagicMock()
    device.name = _attr("leaf1")

    node = MagicMock()
    node.route_distinguisher = _attr("10.0.0.1:10")
    node.route_target_import = _attr("65000:10")
    node.route_target_export = _attr("65000:10")
    node.redistribute_learned = _attr(True)
    node.device = _rel_one(device)
    node.vlan = _rel_one(vlan_node)
    mock_sdk_client.filters = AsyncMock(return_value=[node])

    result = await client.get_device_evpn_instances("leaf1")

    assert len(result) == 1
    ev = result[0]
    assert isinstance(ev, EvpnInstanceIntent)
    assert ev.route_distinguisher == "10.0.0.1:10"
    assert ev.route_target_import == "65000:10"
    assert ev.route_target_export == "65000:10"
    assert ev.redistribute_learned is True
    assert ev.vlan_id == 10


# =============================================================================
# get_mlag_domain
# =============================================================================


@pytest.mark.asyncio
async def test_get_mlag_domain(client, mock_sdk_client):
    """The MLAG domain containing the device becomes an MlagDomainIntent listing its members."""
    dev1 = MagicMock()
    dev1.name = _attr("leaf1")
    dev2 = MagicMock()
    dev2.name = _attr("leaf2")

    domain = MagicMock()
    domain.domain_id = _attr("1")
    domain.virtual_mac = _attr("00:1c:73:00:00:01")
    domain.heartbeat_vrf = _attr("MGMT")
    domain.dual_primary_detection = _attr(True)
    domain.dual_primary_delay = _attr(10)
    domain.dual_primary_action = _attr("errdisable")
    domain.devices = _rel_many([dev1, dev2])
    mock_sdk_client.filters = AsyncMock(return_value=[domain])

    result = await client.get_mlag_domain("leaf1")

    assert isinstance(result, MlagDomainIntent)
    assert result.domain_id == "1"
    assert result.virtual_mac == "00:1c:73:00:00:01"
    assert "leaf1" in result.peer_devices
    assert "leaf2" in result.peer_devices


@pytest.mark.asyncio
async def test_get_mlag_domain_returns_none_when_not_member(client, mock_sdk_client):
    """Domains that do not include the device are ignored, yielding None."""
    dev_other = MagicMock()
    dev_other.name = _attr("spine1")

    domain = MagicMock()
    domain.domain_id = _attr("99")
    domain.devices = _rel_many([dev_other])
    mock_sdk_client.filters = AsyncMock(return_value=[domain])

    result = await client.get_mlag_domain("leaf1")

    assert result is None


# =============================================================================
# get_mlag_peer_config
# =============================================================================


@pytest.mark.asyncio
async def test_get_mlag_peer_config(client, mock_sdk_client):
    """An MLAG peer-config node becomes an MlagPeerConfigIntent with the peer-link name."""
    lag = MagicMock()
    lag.name = _attr("Port-Channel1")
    device = MagicMock()
    device.name = _attr("leaf1")

    node = MagicMock()
    node.local_interface_ip = _attr("10.255.255.0/31")
    node.peer_address = _attr("10.255.255.1")
    node.heartbeat_peer_ip = _attr("192.168.0.2")
    node.device = _rel_one(device)
    node.peer_link = _rel_one(lag)
    mock_sdk_client.filters = AsyncMock(return_value=[node])

    result = await client.get_mlag_peer_config("leaf1")

    assert isinstance(result, MlagPeerConfigIntent)
    assert result.local_interface_ip == "10.255.255.0/31"
    assert result.peer_address == "10.255.255.1"
    assert result.heartbeat_peer_ip == "192.168.0.2"
    assert result.peer_link == "Port-Channel1"


@pytest.mark.asyncio
async def test_get_mlag_peer_config_returns_none_when_missing(client, mock_sdk_client):
    """No peer-config node for the device yields None rather than raising."""
    mock_sdk_client.filters = AsyncMock(return_value=[])

    result = await client.get_mlag_peer_config("spine1")

    assert result is None


# =============================================================================
# get_device_vlans
# =============================================================================


@pytest.mark.asyncio
async def test_get_device_vlans_via_vtep(client, mock_sdk_client):
    """VLANs are discovered through the device VTEP's VLAN/VNI mappings."""
    vlan_node = make_vlan_node(vlan_id=10, name="VLAN10", vni=10010)

    # _rel_one already provides an awaitable (AsyncMock) .fetch.
    mapping = MagicMock()
    mapping.vlan = _rel_one(vlan_node)

    device = MagicMock()
    device.name = _attr("leaf1")

    vtep = MagicMock()
    vtep.device = _rel_one(device)
    vtep.vlan_vni_mappings = _rel_many([mapping])
    mock_sdk_client.filters = AsyncMock(return_value=[vtep])

    result = await client.get_device_vlans("leaf1")

    assert len(result) == 1
    assert isinstance(result[0], VlanIntent)
    assert result[0].vlan_id == 10
    assert result[0].vni == 10010


@pytest.mark.asyncio
async def test_get_device_vlans_caches_result(client, mock_sdk_client):
    """Repeated VLAN lookups for one device call filters() only once."""
    device = MagicMock()
    device.name = _attr("leaf1")

    vtep = MagicMock()
    vtep.device = _rel_one(device)
    vtep.vlan_vni_mappings = _rel_many([])
    mock_sdk_client.filters = AsyncMock(return_value=[vtep])

    r1 = await client.get_device_vlans("leaf1")
    r2 = await client.get_device_vlans("leaf1")

    assert r1 == r2
    mock_sdk_client.filters.assert_called_once()