feat(infrahub): add Infrahub client for fabric intent (#42)

This commit is contained in:
2026-02-26 12:44:07 +00:00
parent ae29117827
commit 8967f04f4f

View File

@@ -0,0 +1,598 @@
"""Unit tests for FabricInfrahubClient — all Infrahub SDK calls are mocked."""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from infrahub_client import (
FabricInfrahubClient,
InfrahubNotFoundError,
InfrahubQueryError,
)
from infrahub_client.models import (
BgpPeerGroupIntent,
BgpRouterConfigIntent,
BgpSessionIntent,
DeviceIntent,
EvpnInstanceIntent,
MlagDomainIntent,
MlagPeerConfigIntent,
VlanIntent,
VrfIntent,
VtepIntent,
)
# ---------------------------------------------------------------------------
# Helpers to build mock SDK node objects
# ---------------------------------------------------------------------------
def _attr(value):
"""Return a simple attribute mock with ``.value`` set."""
m = MagicMock()
m.value = value
return m
def _rel_one(peer_node, initialized=True):
"""Return a mock for a cardinality-one relationship."""
m = MagicMock()
m.initialized = initialized
m.peer = peer_node
m.fetch = AsyncMock()
return m
def _rel_many(peers, initialized=True):
"""Return a mock for a cardinality-many relationship (list of peer wrappers)."""
m = MagicMock()
m.initialized = initialized
peer_wrappers = []
for p in peers:
pw = MagicMock()
pw.peer = p
peer_wrappers.append(pw)
m.peers = peer_wrappers
return m
def _make_device(
name="leaf1",
role="leaf",
status="active",
platform_name="eos",
site_name="dc1",
asn_value=65001,
):
"""Build a mock InfraDevice SDK node."""
node = MagicMock()
node.name = _attr(name)
node.role = _attr(role)
node.status = _attr(status)
platform = MagicMock()
platform.name = _attr(platform_name)
node.platform = _rel_one(platform)
site = MagicMock()
site.name = _attr(site_name)
node.site = _rel_one(site)
asn = MagicMock()
asn.asn = _attr(asn_value)
node.asn = _rel_one(asn)
return node
def _make_sdk_client(get_return=None, all_return=None, filters_return=None, get_raises=None):
"""Patch InfrahubClient inside client module and configure return values."""
mock_sdk = MagicMock()
if get_raises:
mock_sdk.get = AsyncMock(side_effect=get_raises)
else:
mock_sdk.get = AsyncMock(return_value=get_return)
mock_sdk.all = AsyncMock(return_value=all_return or [])
mock_sdk.filters = AsyncMock(return_value=filters_return or [])
return mock_sdk
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def client():
"""FabricInfrahubClient with a patched InfrahubClient SDK."""
with patch("infrahub_client.client.InfrahubClient") as MockSdk:
MockSdk.return_value = MagicMock()
c = FabricInfrahubClient(url="http://infrahub:8080", api_token="token", branch="main")
yield c
# ---------------------------------------------------------------------------
# get_device
# ---------------------------------------------------------------------------
class TestGetDevice:
async def test_returns_device_intent(self, client):
device_node = _make_device()
client._sdk.get = AsyncMock(return_value=device_node)
result = await client.get_device("leaf1")
assert isinstance(result, DeviceIntent)
assert result.name == "leaf1"
assert result.role == "leaf"
assert result.status == "active"
assert result.platform == "eos"
assert result.site == "dc1"
assert result.asn == 65001
async def test_not_found_raises_error(self, client):
client._sdk.get = AsyncMock(side_effect=Exception("not found"))
with pytest.raises(InfrahubNotFoundError):
await client.get_device("nonexistent")
async def test_branch_is_forwarded(self, client):
client._branch = "proposed"
device_node = _make_device()
client._sdk.get = AsyncMock(return_value=device_node)
await client.get_device("leaf1")
client._sdk.get.assert_called_once_with(
kind="InfraDevice", branch="proposed", name__value="leaf1"
)
async def test_caching_second_call_skips_sdk(self, client):
device_node = _make_device()
client._sdk.get = AsyncMock(return_value=device_node)
first = await client.get_device("leaf1")
second = await client.get_device("leaf1")
assert first == second
# SDK should only have been called once
assert client._sdk.get.call_count == 1
async def test_optional_fields_none_when_uninitialized(self, client):
node = MagicMock()
node.name = _attr("spine1")
node.role = _attr("spine")
node.status = _attr("active")
node.platform = _rel_one(None, initialized=False)
node.site = _rel_one(None, initialized=False)
node.asn = _rel_one(None, initialized=False)
client._sdk.get = AsyncMock(return_value=node)
result = await client.get_device("spine1")
assert result.platform is None
assert result.site is None
assert result.asn is None
# ---------------------------------------------------------------------------
# get_device_bgp_config
# ---------------------------------------------------------------------------
class TestGetDeviceBgpConfig:
def _make_bgp_node(self, router_id="10.0.0.1", asn=65001, default_ipv4=True, ecmp=4):
node = MagicMock()
node.router_id = _attr(router_id)
node.default_ipv4_unicast = _attr(default_ipv4)
node.ecmp_max_paths = _attr(ecmp)
asn_peer = MagicMock()
asn_peer.asn = _attr(asn)
node.local_asn = _rel_one(asn_peer)
return node
async def test_returns_bgp_config_intent(self, client):
bgp_node = self._make_bgp_node()
client._sdk.get = AsyncMock(return_value=bgp_node)
result = await client.get_device_bgp_config("leaf1")
assert isinstance(result, BgpRouterConfigIntent)
assert result.router_id == "10.0.0.1"
assert result.local_asn == 65001
assert result.default_ipv4_unicast is True
assert result.ecmp_max_paths == 4
async def test_not_found_raises_error(self, client):
client._sdk.get = AsyncMock(side_effect=Exception("not found"))
with pytest.raises(InfrahubNotFoundError):
await client.get_device_bgp_config("leaf1")
async def test_caching(self, client):
bgp_node = self._make_bgp_node()
client._sdk.get = AsyncMock(return_value=bgp_node)
await client.get_device_bgp_config("leaf1")
await client.get_device_bgp_config("leaf1")
assert client._sdk.get.call_count == 1
# ---------------------------------------------------------------------------
# get_device_bgp_peer_groups
# ---------------------------------------------------------------------------
class TestGetDeviceBgpPeerGroups:
def _make_pg(self, name="UNDERLAY", pg_type="underlay", remote_asn=65000):
pg = MagicMock()
pg.name = _attr(name)
pg.peer_group_type = _attr(pg_type)
pg.update_source = _attr("Loopback0")
pg.send_community = _attr(True)
pg.ebgp_multihop = _attr(None)
pg.next_hop_unchanged = _attr(False)
asn_peer = MagicMock()
asn_peer.asn = _attr(remote_asn)
pg.remote_asn = _rel_one(asn_peer)
return pg
async def test_returns_peer_group_intents(self, client):
pg = self._make_pg()
bgp_node = MagicMock()
bgp_node.peer_groups = _rel_many([pg])
client._sdk.get = AsyncMock(return_value=bgp_node)
results = await client.get_device_bgp_peer_groups("leaf1")
assert len(results) == 1
assert isinstance(results[0], BgpPeerGroupIntent)
assert results[0].name == "UNDERLAY"
assert results[0].peer_group_type == "underlay"
assert results[0].remote_asn == 65000
assert results[0].send_community is True
async def test_empty_when_no_peer_groups(self, client):
bgp_node = MagicMock()
bgp_node.peer_groups = _rel_many([], initialized=False)
client._sdk.get = AsyncMock(return_value=bgp_node)
results = await client.get_device_bgp_peer_groups("leaf1")
assert results == []
# ---------------------------------------------------------------------------
# get_device_bgp_sessions
# ---------------------------------------------------------------------------
class TestGetDeviceBgpSessions:
def _make_session(self, peer_addr="192.168.1.1", enabled=True):
sess = MagicMock()
sess.peer_address = _attr(peer_addr)
sess.description = _attr("peer session")
sess.enabled = _attr(enabled)
pg_peer = MagicMock()
pg_peer.name = _attr("UNDERLAY")
sess.peer_group = _rel_one(pg_peer)
asn_peer = MagicMock()
asn_peer.asn = _attr(65000)
sess.remote_asn = _rel_one(asn_peer)
return sess
async def test_returns_session_intents(self, client):
sess = self._make_session()
bgp_node = MagicMock()
bgp_node.sessions = _rel_many([sess])
client._sdk.get = AsyncMock(return_value=bgp_node)
results = await client.get_device_bgp_sessions("leaf1")
assert len(results) == 1
assert isinstance(results[0], BgpSessionIntent)
assert results[0].peer_address == "192.168.1.1"
assert results[0].enabled is True
assert results[0].peer_group == "UNDERLAY"
assert results[0].remote_asn == 65000
async def test_empty_when_no_sessions(self, client):
bgp_node = MagicMock()
bgp_node.sessions = _rel_many([], initialized=False)
client._sdk.get = AsyncMock(return_value=bgp_node)
results = await client.get_device_bgp_sessions("leaf1")
assert results == []
# ---------------------------------------------------------------------------
# get_device_vrfs
# ---------------------------------------------------------------------------
class TestGetDeviceVrfs:
def _make_assignment(self, vrf_name="PROD", rd="65001:100", vni=100100):
vni_peer = MagicMock()
vni_peer.vni = _attr(vni)
rt_import = MagicMock()
rt_import.target = _attr("65001:100")
rt_export = MagicMock()
rt_export.target = _attr("65001:100")
vrf_node = MagicMock()
vrf_node.name = _attr(vrf_name)
vrf_node.route_distinguisher = _attr(rd)
vrf_node.vrf_id = _attr(None)
vrf_node.l3vni = _rel_one(vni_peer)
vrf_node.import_targets = _rel_many([rt_import])
vrf_node.export_targets = _rel_many([rt_export])
assignment = MagicMock()
assignment.vrf = _rel_one(vrf_node)
return assignment
async def test_returns_vrf_intents(self, client):
assignment = self._make_assignment()
client._sdk.filters = AsyncMock(return_value=[assignment])
results = await client.get_device_vrfs("leaf1")
assert len(results) == 1
assert isinstance(results[0], VrfIntent)
assert results[0].name == "PROD"
assert results[0].route_distinguisher == "65001:100"
assert results[0].l3vni == 100100
assert "65001:100" in results[0].import_targets
async def test_empty_when_no_assignments(self, client):
client._sdk.filters = AsyncMock(return_value=[])
results = await client.get_device_vrfs("leaf1")
assert results == []
async def test_caching(self, client):
assignment = self._make_assignment()
client._sdk.filters = AsyncMock(return_value=[assignment])
await client.get_device_vrfs("leaf1")
await client.get_device_vrfs("leaf1")
assert client._sdk.filters.call_count == 1
# ---------------------------------------------------------------------------
# get_device_vtep
# ---------------------------------------------------------------------------
class TestGetDeviceVtep:
def _make_vtep(self, src_addr="10.255.0.1", udp_port=4789):
vlan_node = MagicMock()
vlan_node.vlan_id = _attr(100)
vni_node = MagicMock()
vni_node.vni = _attr(10100)
mapping = MagicMock()
mapping.vlan = _rel_one(vlan_node)
mapping.vni = _rel_one(vni_node)
vtep = MagicMock()
vtep.source_address = _attr(src_addr)
vtep.udp_port = _attr(udp_port)
vtep.learn_restrict = _attr(True)
vtep.vlan_vni_mappings = _rel_many([mapping])
return vtep
async def test_returns_vtep_intent(self, client):
vtep = self._make_vtep()
client._sdk.filters = AsyncMock(return_value=[vtep])
result = await client.get_device_vtep("leaf1")
assert isinstance(result, VtepIntent)
assert result.source_address == "10.255.0.1"
assert result.udp_port == 4789
assert result.learn_restrict is True
assert (100, 10100) in result.vlan_vni_mappings
async def test_returns_none_when_no_vtep(self, client):
client._sdk.filters = AsyncMock(return_value=[])
result = await client.get_device_vtep("leaf1")
assert result is None
# ---------------------------------------------------------------------------
# get_device_evpn_instances
# ---------------------------------------------------------------------------
class TestGetDeviceEvpnInstances:
def _make_evpn_instance(self, rd="65001:100", rt_import="65001:100", rt_export="65001:100"):
vlan_node = MagicMock()
vlan_node.vlan_id = _attr(100)
node = MagicMock()
node.route_distinguisher = _attr(rd)
node.route_target_import = _attr(rt_import)
node.route_target_export = _attr(rt_export)
node.redistribute_learned = _attr(True)
node.vlan = _rel_one(vlan_node)
return node
async def test_returns_evpn_instance_intents(self, client):
evpn_node = self._make_evpn_instance()
client._sdk.filters = AsyncMock(return_value=[evpn_node])
results = await client.get_device_evpn_instances("leaf1")
assert len(results) == 1
assert isinstance(results[0], EvpnInstanceIntent)
assert results[0].route_distinguisher == "65001:100"
assert results[0].redistribute_learned is True
assert results[0].vlan_id == 100
async def test_empty_when_no_instances(self, client):
client._sdk.filters = AsyncMock(return_value=[])
results = await client.get_device_evpn_instances("leaf1")
assert results == []
# ---------------------------------------------------------------------------
# get_mlag_domain
# ---------------------------------------------------------------------------
class TestGetMlagDomain:
def _make_domain(self, device_names=("leaf1", "leaf2")):
devices = []
for dname in device_names:
d = MagicMock()
d.name = _attr(dname)
devices.append(d)
peer_vlan_node = MagicMock()
peer_vlan_node.vlan_id = _attr(4094)
domain = MagicMock()
domain.domain_id = _attr("1")
domain.virtual_mac = _attr("00:1c:73:00:00:01")
domain.heartbeat_vrf = _attr("MGMT")
domain.dual_primary_detection = _attr(True)
domain.dual_primary_delay = _attr(10)
domain.dual_primary_action = _attr("errDisable")
domain.devices = _rel_many(devices)
domain.peer_vlan = _rel_one(peer_vlan_node)
return domain
async def test_returns_mlag_domain_intent(self, client):
domain = self._make_domain()
client._sdk.all = AsyncMock(return_value=[domain])
result = await client.get_mlag_domain("leaf1")
assert isinstance(result, MlagDomainIntent)
assert result.domain_id == "1"
assert result.virtual_mac == "00:1c:73:00:00:01"
assert "leaf1" in result.peer_devices
assert "leaf2" in result.peer_devices
async def test_returns_none_when_device_not_in_any_domain(self, client):
domain = self._make_domain(device_names=("leaf3", "leaf4"))
client._sdk.all = AsyncMock(return_value=[domain])
result = await client.get_mlag_domain("leaf1")
assert result is None
async def test_caching(self, client):
domain = self._make_domain()
client._sdk.all = AsyncMock(return_value=[domain])
await client.get_mlag_domain("leaf1")
await client.get_mlag_domain("leaf1")
assert client._sdk.all.call_count == 1
# ---------------------------------------------------------------------------
# get_mlag_peer_config
# ---------------------------------------------------------------------------
class TestGetMlagPeerConfig:
def _make_peer_config(self):
peer_link_node = MagicMock()
peer_link_node.name = _attr("Port-Channel1")
cfg = MagicMock()
cfg.local_interface_ip = _attr("192.168.255.0/31")
cfg.peer_address = _attr("192.168.255.1")
cfg.heartbeat_peer_ip = _attr("192.168.1.2")
cfg.peer_link = _rel_one(peer_link_node)
return cfg
async def test_returns_mlag_peer_config_intent(self, client):
cfg = self._make_peer_config()
client._sdk.filters = AsyncMock(return_value=[cfg])
result = await client.get_mlag_peer_config("leaf1")
assert isinstance(result, MlagPeerConfigIntent)
assert result.local_interface_ip == "192.168.255.0/31"
assert result.peer_address == "192.168.255.1"
assert result.peer_link == "Port-Channel1"
async def test_returns_none_when_no_config(self, client):
client._sdk.filters = AsyncMock(return_value=[])
result = await client.get_mlag_peer_config("leaf1")
assert result is None
# ---------------------------------------------------------------------------
# Async context manager
# ---------------------------------------------------------------------------
class TestContextManager:
async def test_context_manager(self):
with patch("infrahub_client.client.InfrahubClient"):
async with FabricInfrahubClient(
url="http://infrahub:8080", api_token="token"
) as client:
assert isinstance(client, FabricInfrahubClient)
# ---------------------------------------------------------------------------
# Cache TTL expiry
# ---------------------------------------------------------------------------
class TestCacheTtl:
async def test_expired_cache_triggers_new_sdk_call(self, client):
device_node = _make_device()
client._sdk.get = AsyncMock(return_value=device_node)
await client.get_device("leaf1")
# Manually expire the cache entry
client._cache["device:leaf1"] = (client._cache["device:leaf1"][0], 0.0)
await client.get_device("leaf1")
assert client._sdk.get.call_count == 2
# ---------------------------------------------------------------------------
# InfrahubQueryError propagation
# ---------------------------------------------------------------------------
class TestQueryError:
async def test_sdk_unexpected_error_raises_query_error(self, client):
client._sdk.get = AsyncMock(side_effect=Exception("connection refused"))
with pytest.raises(InfrahubQueryError):
await client.get_device_bgp_config("leaf1")