From f663cc623c48562993394bb012eed65c37e3e873 Mon Sep 17 00:00:00 2001 From: Damien Date: Thu, 26 Feb 2026 14:03:09 +0100 Subject: [PATCH] feat(infrahub): add Infrahub client for fabric intent (#42) - Replace pynetbox with infrahub-sdk>=0.16.0 + pydantic>=2.0 in dependencies - Add pytest and pytest-asyncio to dev dependencies - Implement FabricInfrahubClient async client (src/infrahub/client.py) - get_device, get_device_vlans, get_device_bgp_config - get_device_bgp_peer_groups, get_device_bgp_sessions - get_device_vrfs, get_device_vtep, get_device_evpn_instances - get_mlag_domain, get_mlag_peer_config - TTL-based caching (60s) to avoid redundant SDK queries - Async context manager support - Add Pydantic v2 frozen models for all intent types (src/infrahub/models.py) - Add custom exception hierarchy (src/infrahub/exceptions.py) - Add unit tests with fully mocked SDK (tests/test_infrahub_client.py) - Tests for correct model return, NotFoundError, branch selection, caching Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 7 +- src/infrahub/__init__.py | 24 ++ src/infrahub/client.py | 789 ++++++++++++++++++++++++++++++++++ src/infrahub/exceptions.py | 31 ++ src/infrahub/models.py | 144 +++++++ tests/test_infrahub_client.py | 617 ++++++++++++++++++++++++++ 6 files changed, 1610 insertions(+), 2 deletions(-) create mode 100644 src/infrahub/__init__.py create mode 100644 src/infrahub/client.py create mode 100644 src/infrahub/exceptions.py create mode 100644 src/infrahub/models.py create mode 100644 tests/test_infrahub_client.py diff --git a/pyproject.toml b/pyproject.toml index af449fc..983ecd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,14 @@ [project] name = "fabric-orchestrator" version = "0.1.0" -description = "Declarative Network Fabric Orchestrator - Terraform-like infrastructure management for Arista EVPN-VXLAN using gNMI, YANG, and NetBox as Source of Truth" +description = "Declarative Network Fabric Orchestrator - Terraform-like infrastructure management for Arista EVPN-VXLAN using gNMI, YANG, and Infrahub as Source of Truth" readme = "README.md" requires-python = ">=3.12" dependencies = [ "click>=8.1.0", "pygnmi>=0.8.0", - "pynetbox>=7.5.0", + "infrahub-sdk>=0.16.0", + "pydantic>=2.0", "rich>=13.0.0", ] @@ -35,5 +36,7 @@ ignore = ["E501"] [dependency-groups] dev = [ "basedpyright>=1.37.4", + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", "ruff>=0.14.10", ] diff --git a/src/infrahub/__init__.py b/src/infrahub/__init__.py new file mode 100644 index 0000000..707fa11 --- /dev/null +++ b/src/infrahub/__init__.py @@ -0,0 +1,24 @@ +""" +Infrahub client package for the Fabric Orchestrator. + +Exports the main async client and exception hierarchy for use by the +Reconciler and Prefect flows. +""" + +from __future__ import annotations + +from .client import FabricInfrahubClient +from .exceptions import ( + InfrahubClientError, + InfrahubConnectionError, + InfrahubNotFoundError, + InfrahubQueryError, +) + +__all__ = [ + "FabricInfrahubClient", + "InfrahubClientError", + "InfrahubConnectionError", + "InfrahubNotFoundError", + "InfrahubQueryError", +] diff --git a/src/infrahub/client.py b/src/infrahub/client.py new file mode 100644 index 0000000..a4e92a9 --- /dev/null +++ b/src/infrahub/client.py @@ -0,0 +1,789 @@ +""" +Infrahub Client for Fabric Intent. + +Async client wrapping `infrahub-sdk` to fetch fabric intent data from a remote +Infrahub instance. Replaces the previous NetBox/pynetbox client and is used by +the Reconciler and Prefect flows. +""" + +from __future__ import annotations + +import time +from typing import Any + +from infrahub_sdk import Config, InfrahubClient + +from .exceptions import InfrahubNotFoundError, InfrahubQueryError +from .models import ( + BgpPeerGroupIntent, + BgpRouterConfigIntent, + BgpSessionIntent, + DeviceIntent, + EvpnInstanceIntent, + MlagDomainIntent, + MlagPeerConfigIntent, + VlanIntent, + VrfIntent, + VtepIntent, +) + +# TTL for cache entries in seconds +_CACHE_TTL: int = 60 + + +class FabricInfrahubClient: + """ + Async client for querying fabric intent data from a remote Infrahub instance. + + Wraps `infrahub-sdk` with structured Pydantic model responses, error handling, + and a simple TTL-based cache to avoid redundant queries. + + Usage as async context manager:: + + async with FabricInfrahubClient(url="http://infrahub:8080", api_token="xxx") as client: + device = await client.get_device("leaf1") + vlans = await client.get_device_vlans("leaf1") + """ + + def __init__(self, url: str, api_token: str, branch: str = "main") -> None: + """ + Initialize the Infrahub client. + + Args: + url: Base URL of the Infrahub instance (e.g. "http://infrahub:8080") + api_token: API token for authentication + branch: Default branch to query (default: "main") + """ + self._url = url + self._branch = branch + config = Config(address=url, api_token=api_token, default_branch=branch) + self._client = InfrahubClient(config=config) + self._cache: dict[str, tuple[Any, float]] = {} + + async def __aenter__(self) -> FabricInfrahubClient: + """Enter async context manager.""" + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit async context manager.""" + pass + + # ========================================================================= + # Cache helpers + # ========================================================================= + + def _cache_get(self, key: str) -> Any | None: + """Return cached value if still valid, else None.""" + if key in self._cache: + value, ts = self._cache[key] + if time.monotonic() - ts < _CACHE_TTL: + return value + del self._cache[key] + return None + + def _cache_set(self, key: str, value: Any) -> None: + """Store a value in the cache with the current timestamp.""" + self._cache[key] = (value, time.monotonic()) + + # ========================================================================= + # Device + # ========================================================================= + + async def get_device(self, name: str) -> DeviceIntent: + """ + Fetch a device by name and return a structured DeviceIntent. + + Args: + name: Device name as stored in Infrahub + + Returns: + DeviceIntent with device attributes and resolved relationship names + + Raises: + InfrahubNotFoundError: When no device with the given name exists + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"device:{name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + try: + node = await self._client.get(kind="InfraDevice", name__value=name) + except Exception as e: + msg = str(e).lower() + if "not found" in msg or "no node" in msg: + raise InfrahubNotFoundError(f"Device '{name}' not found in Infrahub") from e + raise InfrahubQueryError(f"Failed to query device '{name}': {e}") from e + + if node is None: + raise InfrahubNotFoundError(f"Device '{name}' not found in Infrahub") + + # Resolve cardinality-one relationships + platform: str | None = None + site: str | None = None + asn: int | None = None + + try: + if node.platform.peer is None: + await node.platform.fetch() + if node.platform.peer is not None: + platform = node.platform.peer.name.value + except Exception: + pass + + try: + if node.site.peer is None: + await node.site.fetch() + if node.site.peer is not None: + site = node.site.peer.name.value + except Exception: + pass + + try: + if node.asn.peer is None: + await node.asn.fetch() + if node.asn.peer is not None: + asn = node.asn.peer.asn.value + except Exception: + pass + + result = DeviceIntent( + name=node.name.value, + role=node.role.value, + status=node.status.value, + platform=platform, + site=site, + asn=asn, + ) + self._cache_set(cache_key, result) + return result + + # ========================================================================= + # VLANs + # ========================================================================= + + async def get_device_vlans(self, device_name: str) -> list[VlanIntent]: + """ + Fetch all VLANs associated with a device via its VTEP's vlan_vni_mappings. + + Falls back to VLANs linked through SVI interfaces if no VTEP is found. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + List of VlanIntent objects for all VLANs associated with this device + + Raises: + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"device_vlans:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + try: + vtep_nodes = await self._client.filters(kind="InfraVTEP", include=["vlan_vni_mappings"]) + except Exception as e: + raise InfrahubQueryError(f"Failed to query VTEPs: {e}") from e + + # Find VTEP belonging to this device + device_vtep = None + for vtep in vtep_nodes: + try: + if vtep.device.peer is None: + await vtep.device.fetch() + if vtep.device.peer is not None and vtep.device.peer.name.value == device_name: + device_vtep = vtep + break + except Exception: + continue + + vlans: list[VlanIntent] = [] + + if device_vtep is not None: + for mapping_rel in device_vtep.vlan_vni_mappings.peers: + try: + mapping = mapping_rel.peer + await mapping.vlan.fetch() + vlan_node = mapping.vlan.peer + if vlan_node is None: + continue + + vni_val: int | None = None + try: + await vlan_node.vni.fetch() + if vlan_node.vni.peer is not None: + vni_val = vlan_node.vni.peer.vni.value + except Exception: + pass + + vlans.append( + VlanIntent( + vlan_id=vlan_node.vlan_id.value, + name=vlan_node.name.value, + status=vlan_node.status.value, + vlan_type=vlan_node.vlan_type.value, + vni=vni_val, + stp_enabled=vlan_node.stp_enabled.value, + ) + ) + except Exception: + continue + else: + # Fallback: discover VLANs via InfraInterfaceVlan → VLAN + try: + svi_nodes = await self._client.filters(kind="InfraInterfaceVlan") + for svi in svi_nodes: + try: + if svi.device.peer is None: + await svi.device.fetch() + if svi.device.peer is None or svi.device.peer.name.value != device_name: + continue + + if svi.vlan.peer is None: + await svi.vlan.fetch() + vlan_node = svi.vlan.peer + if vlan_node is None: + continue + + vni_val = None + try: + await vlan_node.vni.fetch() + if vlan_node.vni.peer is not None: + vni_val = vlan_node.vni.peer.vni.value + except Exception: + pass + + vlans.append( + VlanIntent( + vlan_id=vlan_node.vlan_id.value, + name=vlan_node.name.value, + status=vlan_node.status.value, + vlan_type=vlan_node.vlan_type.value, + vni=vni_val, + stp_enabled=vlan_node.stp_enabled.value, + ) + ) + except Exception: + continue + except Exception as e: + raise InfrahubQueryError( + f"Failed to query SVI interfaces for device '{device_name}': {e}" + ) from e + + self._cache_set(cache_key, vlans) + return vlans + + # ========================================================================= + # BGP + # ========================================================================= + + async def _get_bgp_router_config_node(self, device_name: str) -> Any: + """ + Fetch the BGPRouterConfig node for a device. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + The raw InfraBGPRouterConfig SDK node + + Raises: + InfrahubNotFoundError: When no BGP config exists for this device + InfrahubQueryError: When the query fails unexpectedly + """ + try: + nodes = await self._client.filters( + kind="InfraBGPRouterConfig", + include=["peer_groups", "sessions"], + ) + except Exception as e: + raise InfrahubQueryError( + f"Failed to query BGP router config for '{device_name}': {e}" + ) from e + + for node in nodes: + try: + if node.device.peer is None: + await node.device.fetch() + if node.device.peer is not None and node.device.peer.name.value == device_name: + return node + except Exception: + continue + + raise InfrahubNotFoundError(f"No BGP router config found for device '{device_name}'") + + async def get_device_bgp_config(self, device_name: str) -> BgpRouterConfigIntent: + """ + Fetch the BGP router configuration for a device. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + BgpRouterConfigIntent with BGP router configuration + + Raises: + InfrahubNotFoundError: When no BGP config exists for this device + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"bgp_config:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + node = await self._get_bgp_router_config_node(device_name) + + local_asn: int = 0 + try: + if node.local_asn.peer is None: + await node.local_asn.fetch() + if node.local_asn.peer is not None: + local_asn = node.local_asn.peer.asn.value + except Exception: + pass + + result = BgpRouterConfigIntent( + router_id=node.router_id.value, + local_asn=local_asn, + default_ipv4_unicast=node.default_ipv4_unicast.value, + ecmp_max_paths=node.ecmp_max_paths.value, + ) + self._cache_set(cache_key, result) + return result + + async def get_device_bgp_peer_groups(self, device_name: str) -> list[BgpPeerGroupIntent]: + """ + Fetch all BGP peer groups for a device via its BGPRouterConfig. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + List of BgpPeerGroupIntent objects + + Raises: + InfrahubNotFoundError: When no BGP config exists for this device + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"bgp_peer_groups:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + bgp_node = await self._get_bgp_router_config_node(device_name) + + peer_groups: list[BgpPeerGroupIntent] = [] + for pg_rel in bgp_node.peer_groups.peers: + try: + pg = pg_rel.peer + + remote_asn: int | None = None + try: + if pg.remote_asn.peer is None: + await pg.remote_asn.fetch() + if pg.remote_asn.peer is not None: + remote_asn = pg.remote_asn.peer.asn.value + except Exception: + pass + + peer_groups.append( + BgpPeerGroupIntent( + name=pg.name.value, + peer_group_type=pg.peer_group_type.value, + remote_asn=remote_asn, + update_source=pg.update_source.value if pg.update_source.value else None, + send_community=pg.send_community.value, + ebgp_multihop=pg.ebgp_multihop.value if pg.ebgp_multihop.value else None, + next_hop_unchanged=pg.next_hop_unchanged.value, + ) + ) + except Exception: + continue + + self._cache_set(cache_key, peer_groups) + return peer_groups + + async def get_device_bgp_sessions(self, device_name: str) -> list[BgpSessionIntent]: + """ + Fetch all BGP sessions for a device via its BGPRouterConfig. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + List of BgpSessionIntent objects + + Raises: + InfrahubNotFoundError: When no BGP config exists for this device + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"bgp_sessions:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + bgp_node = await self._get_bgp_router_config_node(device_name) + + sessions: list[BgpSessionIntent] = [] + for sess_rel in bgp_node.sessions.peers: + try: + sess = sess_rel.peer + + peer_group: str | None = None + try: + if sess.peer_group.peer is None: + await sess.peer_group.fetch() + if sess.peer_group.peer is not None: + peer_group = sess.peer_group.peer.name.value + except Exception: + pass + + remote_asn: int | None = None + try: + if sess.remote_asn.peer is None: + await sess.remote_asn.fetch() + if sess.remote_asn.peer is not None: + remote_asn = sess.remote_asn.peer.asn.value + except Exception: + pass + + sessions.append( + BgpSessionIntent( + peer_address=sess.peer_address.value, + description=sess.description.value if sess.description.value else None, + enabled=sess.enabled.value, + peer_group=peer_group, + remote_asn=remote_asn, + ) + ) + except Exception: + continue + + self._cache_set(cache_key, sessions) + return sessions + + # ========================================================================= + # VRF + # ========================================================================= + + async def get_device_vrfs(self, device_name: str) -> list[VrfIntent]: + """ + Fetch all VRFs assigned to a device via VRFDeviceAssignment. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + List of VrfIntent objects with resolved route targets + + Raises: + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"device_vrfs:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + try: + assignments = await self._client.filters( + kind="InfraVRFDeviceAssignment", + include=["import_targets", "export_targets"], + ) + except Exception as e: + raise InfrahubQueryError( + f"Failed to query VRF assignments for device '{device_name}': {e}" + ) from e + + vrfs: list[VrfIntent] = [] + for asgn in assignments: + try: + if asgn.device.peer is None: + await asgn.device.fetch() + if asgn.device.peer is None or asgn.device.peer.name.value != device_name: + continue + + if asgn.vrf.peer is None: + await asgn.vrf.fetch() + vrf_node = asgn.vrf.peer + if vrf_node is None: + continue + + l3vni: int | None = None + try: + if vrf_node.l3vni.peer is None: + await vrf_node.l3vni.fetch() + if vrf_node.l3vni.peer is not None: + l3vni = vrf_node.l3vni.peer.vni.value + except Exception: + pass + + import_targets: list[str] = [] + for rt_rel in asgn.import_targets.peers: + try: + import_targets.append(rt_rel.peer.target.value) + except Exception: + pass + + export_targets: list[str] = [] + for rt_rel in asgn.export_targets.peers: + try: + export_targets.append(rt_rel.peer.target.value) + except Exception: + pass + + rd = asgn.route_distinguisher.value if asgn.route_distinguisher.value else None + + vrfs.append( + VrfIntent( + name=vrf_node.name.value, + route_distinguisher=rd, + vrf_id=vrf_node.vrf_id.value if vrf_node.vrf_id.value else None, + l3vni=l3vni, + import_targets=import_targets, + export_targets=export_targets, + ) + ) + except Exception: + continue + + self._cache_set(cache_key, vrfs) + return vrfs + + # ========================================================================= + # VTEP + # ========================================================================= + + async def get_device_vtep(self, device_name: str) -> VtepIntent | None: + """ + Fetch the VTEP configuration for a device. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + VtepIntent if a VTEP exists for this device, None otherwise + + Raises: + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"device_vtep:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + try: + vtep_nodes = await self._client.filters(kind="InfraVTEP", include=["vlan_vni_mappings"]) + except Exception as e: + raise InfrahubQueryError(f"Failed to query VTEPs: {e}") from e + + for vtep in vtep_nodes: + try: + if vtep.device.peer is None: + await vtep.device.fetch() + if vtep.device.peer is None or vtep.device.peer.name.value != device_name: + continue + + mappings: list[tuple[int, int]] = [] + for mapping_rel in vtep.vlan_vni_mappings.peers: + try: + mapping = mapping_rel.peer + await mapping.vlan.fetch() + await mapping.vni.fetch() + vlan_node = mapping.vlan.peer + vni_node = mapping.vni.peer + if vlan_node is not None and vni_node is not None: + mappings.append((vlan_node.vlan_id.value, vni_node.vni.value)) + except Exception: + continue + + result = VtepIntent( + source_address=vtep.source_address.value, + udp_port=vtep.udp_port.value, + learn_restrict=vtep.learn_restrict.value, + vlan_vni_mappings=mappings, + ) + self._cache_set(cache_key, result) + return result + except Exception: + continue + + self._cache_set(cache_key, None) + return None + + # ========================================================================= + # EVPN Instances + # ========================================================================= + + async def get_device_evpn_instances(self, device_name: str) -> list[EvpnInstanceIntent]: + """ + Fetch all EVPN instances for a device. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + List of EvpnInstanceIntent objects + + Raises: + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"evpn_instances:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + try: + nodes = await self._client.filters(kind="InfraEVPNInstance") + except Exception as e: + raise InfrahubQueryError( + f"Failed to query EVPN instances for device '{device_name}': {e}" + ) from e + + instances: list[EvpnInstanceIntent] = [] + for node in nodes: + try: + if node.device.peer is None: + await node.device.fetch() + if node.device.peer is None or node.device.peer.name.value != device_name: + continue + + if node.vlan.peer is None: + await node.vlan.fetch() + vlan_node = node.vlan.peer + vlan_id: int = vlan_node.vlan_id.value if vlan_node is not None else 0 + + instances.append( + EvpnInstanceIntent( + route_distinguisher=node.route_distinguisher.value, + route_target_import=node.route_target_import.value, + route_target_export=node.route_target_export.value, + redistribute_learned=node.redistribute_learned.value, + vlan_id=vlan_id, + ) + ) + except Exception: + continue + + self._cache_set(cache_key, instances) + return instances + + # ========================================================================= + # MLAG + # ========================================================================= + + async def get_mlag_domain(self, device_name: str) -> MlagDomainIntent | None: + """ + Fetch the MLAG domain that contains the given device. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + MlagDomainIntent if an MLAG domain exists for this device, None otherwise + + Raises: + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"mlag_domain:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + try: + nodes = await self._client.filters(kind="InfraMlagDomain", include=["devices"]) + except Exception as e: + raise InfrahubQueryError( + f"Failed to query MLAG domains for device '{device_name}': {e}" + ) from e + + for node in nodes: + try: + peer_device_names: list[str] = [] + found = False + + for dev_rel in node.devices.peers: + try: + dev_name = dev_rel.peer.name.value + peer_device_names.append(dev_name) + if dev_name == device_name: + found = True + except Exception: + pass + + if not found: + continue + + result = MlagDomainIntent( + domain_id=node.domain_id.value, + virtual_mac=node.virtual_mac.value, + heartbeat_vrf=node.heartbeat_vrf.value, + dual_primary_detection=node.dual_primary_detection.value, + dual_primary_delay=node.dual_primary_delay.value, + dual_primary_action=node.dual_primary_action.value, + peer_devices=peer_device_names, + ) + self._cache_set(cache_key, result) + return result + except Exception: + continue + + self._cache_set(cache_key, None) + return None + + async def get_mlag_peer_config(self, device_name: str) -> MlagPeerConfigIntent | None: + """ + Fetch the MLAG peer configuration for a device. + + Args: + device_name: Device name as stored in Infrahub + + Returns: + MlagPeerConfigIntent if MLAG peer config exists for this device, None otherwise + + Raises: + InfrahubQueryError: When the query fails unexpectedly + """ + cache_key = f"mlag_peer_config:{device_name}" + cached = self._cache_get(cache_key) + if cached is not None: + return cached + + try: + nodes = await self._client.filters(kind="InfraMlagPeerConfig") + except Exception as e: + raise InfrahubQueryError( + f"Failed to query MLAG peer configs for device '{device_name}': {e}" + ) from e + + for node in nodes: + try: + if node.device.peer is None: + await node.device.fetch() + if node.device.peer is None or node.device.peer.name.value != device_name: + continue + + peer_link: str = "" + try: + if node.peer_link.peer is None: + await node.peer_link.fetch() + if node.peer_link.peer is not None: + peer_link = node.peer_link.peer.name.value + except Exception: + pass + + result = MlagPeerConfigIntent( + local_interface_ip=node.local_interface_ip.value, + peer_address=node.peer_address.value, + heartbeat_peer_ip=node.heartbeat_peer_ip.value, + peer_link=peer_link, + ) + self._cache_set(cache_key, result) + return result + except Exception: + continue + + self._cache_set(cache_key, None) + return None diff --git a/src/infrahub/exceptions.py b/src/infrahub/exceptions.py new file mode 100644 index 0000000..dc8bb6b --- /dev/null +++ b/src/infrahub/exceptions.py @@ -0,0 +1,31 @@ +""" +Infrahub Client Exceptions. + +Custom exception hierarchy for the FabricInfrahubClient. +""" + +from __future__ import annotations + + +class InfrahubClientError(Exception): + """Base exception for all Infrahub client errors.""" + + pass + + +class InfrahubConnectionError(InfrahubClientError): + """Raised when connection to the Infrahub instance fails.""" + + pass + + +class InfrahubQueryError(InfrahubClientError): + """Raised when a query to Infrahub fails or returns unexpected data.""" + + pass + + +class InfrahubNotFoundError(InfrahubClientError): + """Raised when a requested node is not found in Infrahub.""" + + pass diff --git a/src/infrahub/models.py b/src/infrahub/models.py new file mode 100644 index 0000000..b8e87a4 --- /dev/null +++ b/src/infrahub/models.py @@ -0,0 +1,144 @@ +""" +Pydantic models for Infrahub fabric intent data. + +These immutable models represent the structured intent retrieved from Infrahub +and are used throughout the orchestrator for typed access to fabric state. +""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + + +class DeviceIntent(BaseModel): + """Represents a network device from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + name: str + role: str + status: str + platform: str | None + site: str | None + asn: int | None + + +class VlanIntent(BaseModel): + """Represents a VLAN from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + vlan_id: int + name: str + status: str + vlan_type: str + vni: int | None + stp_enabled: bool + + +class VniIntent(BaseModel): + """Represents a VNI (VXLAN Network Identifier) from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + vni: int + vni_type: str + description: str | None + + +class BgpRouterConfigIntent(BaseModel): + """Represents a BGP router configuration from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + router_id: str + local_asn: int + default_ipv4_unicast: bool + ecmp_max_paths: int + + +class BgpPeerGroupIntent(BaseModel): + """Represents a BGP peer group from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + name: str + peer_group_type: str + remote_asn: int | None + update_source: str | None + send_community: str + ebgp_multihop: int | None + next_hop_unchanged: bool + + +class BgpSessionIntent(BaseModel): + """Represents a BGP session from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + peer_address: str + description: str | None + enabled: bool + peer_group: str | None + remote_asn: int | None + + +class VrfIntent(BaseModel): + """Represents a VRF from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + name: str + route_distinguisher: str | None + vrf_id: int | None + l3vni: int | None + import_targets: list[str] + export_targets: list[str] + + +class VtepIntent(BaseModel): + """Represents a VTEP (VXLAN Tunnel Endpoint) from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + source_address: str + udp_port: int + learn_restrict: bool + vlan_vni_mappings: list[tuple[int, int]] # (vlan_id, vni) + + +class MlagDomainIntent(BaseModel): + """Represents an MLAG domain from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + domain_id: str + virtual_mac: str + heartbeat_vrf: str + dual_primary_detection: bool + dual_primary_delay: int + dual_primary_action: str + peer_devices: list[str] + + +class MlagPeerConfigIntent(BaseModel): + """Represents an MLAG peer configuration from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + local_interface_ip: str + peer_address: str + heartbeat_peer_ip: str + peer_link: str + + +class EvpnInstanceIntent(BaseModel): + """Represents an EVPN instance from Infrahub.""" + + model_config = ConfigDict(frozen=True) + + route_distinguisher: str + route_target_import: str + route_target_export: str + redistribute_learned: bool + vlan_id: int diff --git a/tests/test_infrahub_client.py b/tests/test_infrahub_client.py new file mode 100644 index 0000000..a2c60d9 --- /dev/null +++ b/tests/test_infrahub_client.py @@ -0,0 +1,617 @@ +""" +Unit tests for FabricInfrahubClient. + +All SDK interactions are mocked with AsyncMock — no real Infrahub instance required. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest # noqa: I001 + +from src.infrahub.client import FabricInfrahubClient +from src.infrahub.exceptions import InfrahubNotFoundError, InfrahubQueryError +from src.infrahub.models import ( + BgpPeerGroupIntent, + BgpRouterConfigIntent, + BgpSessionIntent, + DeviceIntent, + EvpnInstanceIntent, + MlagDomainIntent, + MlagPeerConfigIntent, + VlanIntent, + VrfIntent, + VtepIntent, +) + +# ============================================================================= +# Helpers — mock node factories +# ============================================================================= + + +def _attr(value): + """Return a simple mock whose .value is set.""" + m = MagicMock() + m.value = value + return m + + +def _rel_one(peer_mock): + """Return a mock relationship (cardinality one) with an already-loaded peer.""" + m = MagicMock() + m.peer = peer_mock + m.fetch = AsyncMock() + return m + + +def _rel_many(peer_mocks): + """Return a mock relationship (cardinality many) whose .peers yields (peer_rel, peer).""" + peers = [] + for p in peer_mocks: + rel = MagicMock() + rel.peer = p + peers.append(rel) + m = MagicMock() + m.peers = peers + return m + + +def make_device_node( + name="leaf1", + role="leaf", + status="active", + platform_name="EOS", + site_name="dc1", + asn_val=65001, +): + node = MagicMock() + node.name = _attr(name) + node.role = _attr(role) + node.status = _attr(status) + + platform = MagicMock() + platform.name = _attr(platform_name) + node.platform = _rel_one(platform) + + site = MagicMock() + site.name = _attr(site_name) + node.site = _rel_one(site) + + asn_node = MagicMock() + asn_node.asn = _attr(asn_val) + node.asn = _rel_one(asn_node) + + return node + + +def make_vlan_node(vlan_id=10, name="VLAN10", status="active", vlan_type="standard", vni=10010): + node = MagicMock() + node.vlan_id = _attr(vlan_id) + node.name = _attr(name) + node.status = _attr(status) + node.vlan_type = _attr(vlan_type) + node.stp_enabled = _attr(True) + + vni_node = MagicMock() + vni_node.vni = _attr(vni) + node.vni = _rel_one(vni_node) + + return node + + +def make_bgp_config_node(router_id="10.0.0.1", asn=65001, device_name="leaf1"): + node = MagicMock() + node.router_id = _attr(router_id) + node.default_ipv4_unicast = _attr(True) + node.ecmp_max_paths = _attr(4) + + asn_node = MagicMock() + asn_node.asn = _attr(asn) + node.local_asn = _rel_one(asn_node) + + device = MagicMock() + device.name = _attr(device_name) + node.device = _rel_one(device) + + node.peer_groups = _rel_many([]) + node.sessions = _rel_many([]) + + return node + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def mock_sdk_client(): + """Patch InfrahubClient and Config so no real connection is made.""" + with ( + patch("src.infrahub.client.InfrahubClient") as mock_cls, + patch("src.infrahub.client.Config"), + ): + sdk = AsyncMock() + mock_cls.return_value = sdk + yield sdk + + +@pytest.fixture +def client(mock_sdk_client): + """Return a FabricInfrahubClient backed by a mocked SDK.""" + return FabricInfrahubClient(url="http://infrahub:8080", api_token="test-token", branch="main") + + +# ============================================================================= +# Context manager +# ============================================================================= + + +@pytest.mark.asyncio +async def test_context_manager(mock_sdk_client): + async with FabricInfrahubClient(url="http://infrahub:8080", api_token="test-token") as c: + assert isinstance(c, FabricInfrahubClient) + + +# ============================================================================= +# get_device +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_returns_device_intent(client, mock_sdk_client): + mock_sdk_client.get = AsyncMock(return_value=make_device_node()) + + result = await client.get_device("leaf1") + + assert isinstance(result, DeviceIntent) + assert result.name == "leaf1" + assert result.role == "leaf" + assert result.status == "active" + assert result.platform == "EOS" + assert result.site == "dc1" + assert result.asn == 65001 + + +@pytest.mark.asyncio +async def test_get_device_not_found_raises(client, mock_sdk_client): + mock_sdk_client.get = AsyncMock(side_effect=Exception("not found")) + + with pytest.raises(InfrahubNotFoundError): + await client.get_device("ghost-device") + + +@pytest.mark.asyncio +async def test_get_device_none_raises_not_found(client, mock_sdk_client): + mock_sdk_client.get = AsyncMock(return_value=None) + + with pytest.raises(InfrahubNotFoundError): + await client.get_device("missing") + + +@pytest.mark.asyncio +async def test_get_device_query_error(client, mock_sdk_client): + mock_sdk_client.get = AsyncMock(side_effect=Exception("connection refused")) + + with pytest.raises(InfrahubQueryError): + await client.get_device("leaf1") + + +# ============================================================================= +# Caching +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_caches_result(client, mock_sdk_client): + mock_sdk_client.get = AsyncMock(return_value=make_device_node()) + + result1 = await client.get_device("leaf1") + result2 = await client.get_device("leaf1") + + assert result1 == result2 + # SDK should only be called once despite two client calls + mock_sdk_client.get.assert_called_once() + + +@pytest.mark.asyncio +async def test_cache_separate_keys_per_device(client, mock_sdk_client): + mock_sdk_client.get = AsyncMock( + side_effect=[ + make_device_node(name="leaf1"), + make_device_node(name="leaf2", asn=65002), + ] + ) + + r1 = await client.get_device("leaf1") + r2 = await client.get_device("leaf2") + + assert r1.name == "leaf1" + assert r2.name == "leaf2" + assert mock_sdk_client.get.call_count == 2 + + +# ============================================================================= +# Branch selection +# ============================================================================= + + +def test_branch_passed_to_config(): + with ( + patch("src.infrahub.client.InfrahubClient"), + patch("src.infrahub.client.Config") as mock_cfg, + ): + FabricInfrahubClient(url="http://infrahub:8080", api_token="tok", branch="proposed-change") + mock_cfg.assert_called_once_with( + address="http://infrahub:8080", + api_token="tok", + default_branch="proposed-change", + ) + + +# ============================================================================= +# get_device_bgp_config +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_bgp_config(client, mock_sdk_client): + bgp_node = make_bgp_config_node(router_id="10.0.0.1", asn=65001, device_name="leaf1") + mock_sdk_client.filters = AsyncMock(return_value=[bgp_node]) + + result = await client.get_device_bgp_config("leaf1") + + assert isinstance(result, BgpRouterConfigIntent) + assert result.router_id == "10.0.0.1" + assert result.local_asn == 65001 + assert result.default_ipv4_unicast is True + assert result.ecmp_max_paths == 4 + + +@pytest.mark.asyncio +async def test_get_device_bgp_config_not_found(client, mock_sdk_client): + mock_sdk_client.filters = AsyncMock(return_value=[]) + + with pytest.raises(InfrahubNotFoundError): + await client.get_device_bgp_config("leaf1") + + +# ============================================================================= +# get_device_bgp_peer_groups +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_bgp_peer_groups(client, mock_sdk_client): + pg_node = MagicMock() + pg_node.name = _attr("EVPN-PEERS") + pg_node.peer_group_type = _attr("evpn") + pg_node.update_source = _attr("Loopback0") + pg_node.send_community = _attr("extended") + pg_node.ebgp_multihop = _attr(3) + pg_node.next_hop_unchanged = _attr(True) + + remote_asn_node = MagicMock() + remote_asn_node.asn = _attr(65000) + pg_node.remote_asn = _rel_one(remote_asn_node) + + bgp_node = make_bgp_config_node(device_name="leaf1") + bgp_node.peer_groups = _rel_many([pg_node]) + + mock_sdk_client.filters = AsyncMock(return_value=[bgp_node]) + + result = await client.get_device_bgp_peer_groups("leaf1") + + assert len(result) == 1 + pg = result[0] + assert isinstance(pg, BgpPeerGroupIntent) + assert pg.name == "EVPN-PEERS" + assert pg.peer_group_type == "evpn" + assert pg.remote_asn == 65000 + assert pg.send_community == "extended" + assert pg.next_hop_unchanged is True + + +# ============================================================================= +# get_device_bgp_sessions +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_bgp_sessions(client, mock_sdk_client): + sess_node = MagicMock() + sess_node.peer_address = _attr("10.0.0.2") + sess_node.description = _attr("to-spine1") + sess_node.enabled = _attr(True) + + pg = MagicMock() + pg.name = _attr("UNDERLAY") + sess_node.peer_group = _rel_one(pg) + + remote_asn = MagicMock() + remote_asn.asn = _attr(65000) + sess_node.remote_asn = _rel_one(remote_asn) + + bgp_node = make_bgp_config_node(device_name="leaf1") + bgp_node.sessions = _rel_many([sess_node]) + + mock_sdk_client.filters = AsyncMock(return_value=[bgp_node]) + + result = await client.get_device_bgp_sessions("leaf1") + + assert len(result) == 1 + sess = result[0] + assert isinstance(sess, BgpSessionIntent) + assert sess.peer_address == "10.0.0.2" + assert sess.description == "to-spine1" + assert sess.enabled is True + assert sess.peer_group == "UNDERLAY" + assert sess.remote_asn == 65000 + + +# ============================================================================= +# get_device_vrfs +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_vrfs(client, mock_sdk_client): + rt_import = MagicMock() + rt_import.target = _attr("65000:100") + + rt_export = MagicMock() + rt_export.target = _attr("65000:100") + + vni_node = MagicMock() + vni_node.vni = _attr(10000) + + vrf_node = MagicMock() + vrf_node.name = _attr("PROD") + vrf_node.vrf_id = _attr(100) + vrf_node.l3vni = _rel_one(vni_node) + + asgn = MagicMock() + asgn.route_distinguisher = _attr("10.0.0.1:100") + asgn.vrf = _rel_one(vrf_node) + + device = MagicMock() + device.name = _attr("leaf1") + asgn.device = _rel_one(device) + + asgn.import_targets = _rel_many([rt_import]) + asgn.export_targets = _rel_many([rt_export]) + + mock_sdk_client.filters = AsyncMock(return_value=[asgn]) + + result = await client.get_device_vrfs("leaf1") + + assert len(result) == 1 + vrf = result[0] + assert isinstance(vrf, VrfIntent) + assert vrf.name == "PROD" + assert vrf.route_distinguisher == "10.0.0.1:100" + assert vrf.vrf_id == 100 + assert vrf.l3vni == 10000 + assert vrf.import_targets == ["65000:100"] + assert vrf.export_targets == ["65000:100"] + + +# ============================================================================= +# get_device_vtep +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_vtep(client, mock_sdk_client): + vlan_node = MagicMock() + vlan_node.vlan_id = _attr(10) + + vni_node = MagicMock() + vni_node.vni = _attr(10010) + + mapping = MagicMock() + mapping.vlan = _rel_one(vlan_node) + mapping.vni = _rel_one(vni_node) + mapping.vlan.fetch = AsyncMock() + mapping.vni.fetch = AsyncMock() + + device = MagicMock() + device.name = _attr("leaf1") + + vtep = MagicMock() + vtep.source_address = _attr("10.0.0.1") + vtep.udp_port = _attr(4789) + vtep.learn_restrict = _attr(False) + vtep.device = _rel_one(device) + vtep.vlan_vni_mappings = _rel_many([mapping]) + + mock_sdk_client.filters = AsyncMock(return_value=[vtep]) + + result = await client.get_device_vtep("leaf1") + + assert isinstance(result, VtepIntent) + assert result.source_address == "10.0.0.1" + assert result.udp_port == 4789 + assert result.learn_restrict is False + assert (10, 10010) in result.vlan_vni_mappings + + +@pytest.mark.asyncio +async def test_get_device_vtep_none_when_missing(client, mock_sdk_client): + mock_sdk_client.filters = AsyncMock(return_value=[]) + + result = await client.get_device_vtep("leaf1") + + assert result is None + + +# ============================================================================= +# get_device_evpn_instances +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_evpn_instances(client, mock_sdk_client): + vlan_node = MagicMock() + vlan_node.vlan_id = _attr(10) + + device = MagicMock() + device.name = _attr("leaf1") + + node = MagicMock() + node.route_distinguisher = _attr("10.0.0.1:10") + node.route_target_import = _attr("65000:10") + node.route_target_export = _attr("65000:10") + node.redistribute_learned = _attr(True) + node.device = _rel_one(device) + node.vlan = _rel_one(vlan_node) + + mock_sdk_client.filters = AsyncMock(return_value=[node]) + + result = await client.get_device_evpn_instances("leaf1") + + assert len(result) == 1 + ev = result[0] + assert isinstance(ev, EvpnInstanceIntent) + assert ev.route_distinguisher == "10.0.0.1:10" + assert ev.route_target_import == "65000:10" + assert ev.route_target_export == "65000:10" + assert ev.redistribute_learned is True + assert ev.vlan_id == 10 + + +# ============================================================================= +# get_mlag_domain +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_mlag_domain(client, mock_sdk_client): + dev1 = MagicMock() + dev1.name = _attr("leaf1") + dev2 = MagicMock() + dev2.name = _attr("leaf2") + + domain = MagicMock() + domain.domain_id = _attr("1") + domain.virtual_mac = _attr("00:1c:73:00:00:01") + domain.heartbeat_vrf = _attr("MGMT") + domain.dual_primary_detection = _attr(True) + domain.dual_primary_delay = _attr(10) + domain.dual_primary_action = _attr("errdisable") + domain.devices = _rel_many([dev1, dev2]) + + mock_sdk_client.filters = AsyncMock(return_value=[domain]) + + result = await client.get_mlag_domain("leaf1") + + assert isinstance(result, MlagDomainIntent) + assert result.domain_id == "1" + assert result.virtual_mac == "00:1c:73:00:00:01" + assert "leaf1" in result.peer_devices + assert "leaf2" in result.peer_devices + + +@pytest.mark.asyncio +async def test_get_mlag_domain_returns_none_when_not_member(client, mock_sdk_client): + dev_other = MagicMock() + dev_other.name = _attr("spine1") + + domain = MagicMock() + domain.domain_id = _attr("99") + domain.devices = _rel_many([dev_other]) + + mock_sdk_client.filters = AsyncMock(return_value=[domain]) + + result = await client.get_mlag_domain("leaf1") + + assert result is None + + +# ============================================================================= +# get_mlag_peer_config +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_mlag_peer_config(client, mock_sdk_client): + lag = MagicMock() + lag.name = _attr("Port-Channel1") + + device = MagicMock() + device.name = _attr("leaf1") + + node = MagicMock() + node.local_interface_ip = _attr("10.255.255.0/31") + node.peer_address = _attr("10.255.255.1") + node.heartbeat_peer_ip = _attr("192.168.0.2") + node.device = _rel_one(device) + node.peer_link = _rel_one(lag) + + mock_sdk_client.filters = AsyncMock(return_value=[node]) + + result = await client.get_mlag_peer_config("leaf1") + + assert isinstance(result, MlagPeerConfigIntent) + assert result.local_interface_ip == "10.255.255.0/31" + assert result.peer_address == "10.255.255.1" + assert result.heartbeat_peer_ip == "192.168.0.2" + assert result.peer_link == "Port-Channel1" + + +@pytest.mark.asyncio +async def test_get_mlag_peer_config_returns_none_when_missing(client, mock_sdk_client): + mock_sdk_client.filters = AsyncMock(return_value=[]) + + result = await client.get_mlag_peer_config("spine1") + + assert result is None + + +# ============================================================================= +# get_device_vlans +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_device_vlans_via_vtep(client, mock_sdk_client): + vlan_node = make_vlan_node(vlan_id=10, name="VLAN10", vni=10010) + + mapping = MagicMock() + mapping.vlan = _rel_one(vlan_node) + mapping.vlan.fetch = AsyncMock() + + device = MagicMock() + device.name = _attr("leaf1") + + vtep = MagicMock() + vtep.device = _rel_one(device) + vtep.vlan_vni_mappings = _rel_many([mapping]) + + mock_sdk_client.filters = AsyncMock(return_value=[vtep]) + + result = await client.get_device_vlans("leaf1") + + assert len(result) == 1 + assert isinstance(result[0], VlanIntent) + assert result[0].vlan_id == 10 + assert result[0].vni == 10010 + + +@pytest.mark.asyncio +async def test_get_device_vlans_caches_result(client, mock_sdk_client): + device = MagicMock() + device.name = _attr("leaf1") + + vtep = MagicMock() + vtep.device = _rel_one(device) + vtep.vlan_vni_mappings = _rel_many([]) + + mock_sdk_client.filters = AsyncMock(return_value=[vtep]) + + r1 = await client.get_device_vlans("leaf1") + r2 = await client.get_device_vlans("leaf1") + + assert r1 == r2 + mock_sdk_client.filters.assert_called_once()