From cd768b8a6e9a7d7a748f295b8dd9d2219974a23b Mon Sep 17 00:00:00 2001 From: Damien Arnodo Date: Thu, 26 Feb 2026 12:37:42 +0000 Subject: [PATCH 1/6] feat(infrahub): add Infrahub client for fabric intent (#42) --- pyproject.toml | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e0df97e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "arista-evpn-vxlan-infrahub-client" +version = "0.1.0" +description = "Infrahub client for Arista EVPN-VXLAN fabric intent" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "infrahub-sdk>=0.16.0", + "pydantic>=2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-asyncio", + "ruff", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/infrahub_client"] + +[tool.ruff] +line-length = 100 +target-version = "py312" + +[tool.ruff.lint] +select = ["E", "F", "I", "W"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] -- 2.53.0 From 78f8ceb1db5638e32acb649efafeefcf00da95bd Mon Sep 17 00:00:00 2001 From: Damien Arnodo Date: Thu, 26 Feb 2026 12:40:04 +0000 Subject: [PATCH 2/6] feat(infrahub): add Infrahub client for fabric intent (#42) --- src/infrahub_client/exceptions.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 src/infrahub_client/exceptions.py diff --git a/src/infrahub_client/exceptions.py b/src/infrahub_client/exceptions.py new file mode 100644 index 0000000..ad1c713 --- /dev/null +++ b/src/infrahub_client/exceptions.py @@ -0,0 +1,17 @@ +"""Exception hierarchy for the Infrahub fabric intent client.""" + + +class InfrahubClientError(Exception): + """Base exception for all Infrahub client errors.""" + + +class InfrahubConnectionError(InfrahubClientError): + """Raised when the client cannot connect to the Infrahub server.""" + + +class InfrahubQueryError(InfrahubClientError): + """Raised when a query to the Infrahub server fails.""" + + +class InfrahubNotFoundError(InfrahubClientError): + """Raised when a requested resource is not found in Infrahub.""" -- 2.53.0 From 26c9e04f582fd1362b7879ab29303d87979c3842 Mon Sep 17 00:00:00 2001 From: Damien Arnodo Date: Thu, 26 Feb 2026 12:41:28 +0000 Subject: [PATCH 3/6] feat(infrahub): add Infrahub client for fabric intent (#42) --- src/infrahub_client/models.py | 137 ++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 src/infrahub_client/models.py diff --git a/src/infrahub_client/models.py b/src/infrahub_client/models.py new file mode 100644 index 0000000..3cc80db --- /dev/null +++ b/src/infrahub_client/models.py @@ -0,0 +1,137 @@ +"""Pydantic v2 models for typed Infrahub fabric intent responses.""" + +from pydantic import BaseModel, ConfigDict + + +class DeviceIntent(BaseModel): + """Intent model for an Infrahub InfraDevice node.""" + + model_config = ConfigDict(frozen=True) + + name: str + role: str + status: str + platform: str | None = None + site: str | None = None + asn: int | None = None + + +class VlanIntent(BaseModel): + """Intent model for an Infrahub InfraVLAN node.""" + + model_config = ConfigDict(frozen=True) + + vlan_id: int + name: str + status: str + vlan_type: str + vni: int | None = None + stp_enabled: bool + + +class VniIntent(BaseModel): + """Intent model for an Infrahub InfraVNI node.""" + + model_config = ConfigDict(frozen=True) + + vni: int + vni_type: str + description: str | None = None + + +class BgpRouterConfigIntent(BaseModel): + """Intent model for an Infrahub InfraBGPRouterConfig node.""" + + model_config = ConfigDict(frozen=True) + + router_id: str + local_asn: int + default_ipv4_unicast: bool + ecmp_max_paths: int + + +class BgpPeerGroupIntent(BaseModel): + """Intent model for an Infrahub InfraBGPPeerGroup node.""" + + model_config = ConfigDict(frozen=True) + + name: str + peer_group_type: str + remote_asn: int | None = None + update_source: str | None = None + send_community: bool + ebgp_multihop: int | None = None + next_hop_unchanged: bool + + +class BgpSessionIntent(BaseModel): + """Intent model for an Infrahub InfraBGPSession node.""" + + model_config = ConfigDict(frozen=True) + + peer_address: str + description: str | None = None + enabled: bool + peer_group: str | None = None + remote_asn: int | None = None + + +class VrfIntent(BaseModel): + """Intent model for an Infrahub InfraVRF node.""" + + model_config = ConfigDict(frozen=True) + + name: str + route_distinguisher: str | None = None + vrf_id: int | None = None + l3vni: int | None = None + import_targets: list[str] + export_targets: list[str] + + +class VtepIntent(BaseModel): + """Intent model for an Infrahub InfraVTEP node.""" + + model_config = ConfigDict(frozen=True) + + source_address: str + udp_port: int + learn_restrict: bool + vlan_vni_mappings: list[tuple[int, int]] + + +class MlagDomainIntent(BaseModel): + """Intent model for an Infrahub InfraMlagDomain node.""" + + model_config = ConfigDict(frozen=True) + + domain_id: str + virtual_mac: str + heartbeat_vrf: str + dual_primary_detection: bool + dual_primary_delay: int + dual_primary_action: str + peer_devices: list[str] + + +class MlagPeerConfigIntent(BaseModel): + """Intent model for an Infrahub InfraMlagPeerConfig node.""" + + model_config = ConfigDict(frozen=True) + + local_interface_ip: str + peer_address: str + heartbeat_peer_ip: str + peer_link: str + + +class EvpnInstanceIntent(BaseModel): + """Intent model for an Infrahub InfraEVPNInstance node.""" + + model_config = ConfigDict(frozen=True) + + route_distinguisher: str + route_target_import: str + route_target_export: str + redistribute_learned: bool + vlan_id: int -- 2.53.0 From b2ed4199cdf950f603854fdc4cc886b7eb389cbb Mon Sep 17 00:00:00 2001 From: Damien Arnodo Date: Thu, 26 Feb 2026 12:42:45 +0000 Subject: [PATCH 4/6] feat(infrahub): add Infrahub client for fabric intent (#42) --- src/infrahub_client/client.py | 680 ++++++++++++++++++++++++++++++++++ 1 file changed, 680 insertions(+) create mode 100644 src/infrahub_client/client.py diff --git a/src/infrahub_client/client.py b/src/infrahub_client/client.py new file mode 100644 index 0000000..b97718f --- /dev/null +++ b/src/infrahub_client/client.py @@ -0,0 +1,680 @@ +"""Async Infrahub client for Arista EVPN-VXLAN fabric intent retrieval.""" + +import time +from typing import Any + +from infrahub_sdk import Config, InfrahubClient + +from .exceptions import InfrahubNotFoundError, InfrahubQueryError +from .models import ( + BgpPeerGroupIntent, + BgpRouterConfigIntent, + BgpSessionIntent, + DeviceIntent, + EvpnInstanceIntent, + MlagDomainIntent, + MlagPeerConfigIntent, + VlanIntent, + VniIntent, # noqa: F401 — re-exported for convenience + VrfIntent, + VtepIntent, +) + +_CACHE_TTL_SECONDS = 60 + + +class FabricInfrahubClient: + """Async client that wraps the Infrahub SDK to fetch fabric intent data. + + All public methods return immutable Pydantic models. Results are cached + in-memory with a TTL of ``_CACHE_TTL_SECONDS`` seconds to avoid redundant + queries during a single reconciliation pass. + + Example:: + + async with FabricInfrahubClient(url="http://infrahub:8080", api_token="xxx") as client: + device = await client.get_device("leaf1") + """ + + def __init__(self, url: str, api_token: str, branch: str = "main") -> None: + """Initialise the client. + + Args: + url: Base URL of the Infrahub server (e.g. ``http://infrahub:8080``). + api_token: API token used for authentication. + branch: Infrahub branch to query (default: ``"main"``). + """ + config = Config(address=url, api_token=api_token, default_branch=branch) + self._sdk = InfrahubClient(config=config) + self._branch = branch + self._cache: dict[str, tuple[Any, float]] = {} + + # ------------------------------------------------------------------ + # Async context manager support + # ------------------------------------------------------------------ + + async def __aenter__(self) -> "FabricInfrahubClient": + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + pass + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _cache_get(self, key: str) -> Any | None: + """Return cached value for *key* if still within TTL, else ``None``.""" + entry = self._cache.get(key) + if entry is None: + return None + value, ts = entry + if time.monotonic() - ts > _CACHE_TTL_SECONDS: + del self._cache[key] + return None + return value + + def _cache_set(self, key: str, value: Any) -> None: + """Store *value* in the cache under *key* with current timestamp.""" + self._cache[key] = (value, time.monotonic()) + + def _opt_value(self, node: Any, attr: str) -> Any | None: + """Safely read ``node..value``, returning ``None`` if missing.""" + try: + attr_obj = getattr(node, attr, None) + if attr_obj is None: + return None + return attr_obj.value + except AttributeError: + return None + + async def _fetch_node(self, kind: str, **filters: Any) -> Any: + """Fetch a single node; raise :exc:`InfrahubNotFoundError` if absent. + + Args: + kind: Infrahub node kind (e.g. ``"InfraDevice"``). + **filters: Keyword arguments forwarded to :meth:`InfrahubClient.get`. + + Returns: + The SDK node object. + + Raises: + InfrahubNotFoundError: When no node matches the filters. + InfrahubQueryError: On SDK/network error. + """ + try: + node = await self._sdk.get(kind=kind, branch=self._branch, **filters) + except Exception as exc: + msg = str(exc).lower() + if "not found" in msg or "does not exist" in msg or "nodnotfound" in msg.replace(" ", ""): + raise InfrahubNotFoundError( + f"{kind} with filters {filters} not found" + ) from exc + raise InfrahubQueryError(f"Query failed for {kind} with {filters}: {exc}") from exc + + if node is None: + raise InfrahubNotFoundError(f"{kind} with filters {filters} not found") + return node + + async def _fetch_all(self, kind: str, **filters: Any) -> list[Any]: + """Fetch all nodes of *kind* matching optional *filters*. + + Args: + kind: Infrahub node kind. + **filters: Optional filter keyword arguments. + + Returns: + List of SDK node objects (may be empty). + + Raises: + InfrahubQueryError: On SDK/network error. + """ + try: + if filters: + return await self._sdk.filters(kind=kind, branch=self._branch, **filters) + return await self._sdk.all(kind=kind, branch=self._branch) + except Exception as exc: + raise InfrahubQueryError(f"Query failed for {kind}: {exc}") from exc + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def get_device(self, name: str) -> DeviceIntent: + """Fetch a device and its core relationships by name. + + Args: + name: Device name as stored in Infrahub (e.g. ``"leaf1"``). + + Returns: + :class:`DeviceIntent` populated from Infrahub. + + Raises: + InfrahubNotFoundError: When no device with *name* exists. + InfrahubQueryError: On query failure. + """ + cache_key = f"device:{name}" + if cached := self._cache_get(cache_key): + return cached + + node = await self._fetch_node("InfraDevice", name__value=name) + + # Resolve optional one-cardinality relationships + platform_name: str | None = None + if node.platform and node.platform.initialized: + await node.platform.fetch() + platform_name = self._opt_value(node.platform.peer, "name") + + site_name: str | None = None + if node.site and node.site.initialized: + await node.site.fetch() + site_name = self._opt_value(node.site.peer, "name") + + asn_value: int | None = None + if node.asn and node.asn.initialized: + await node.asn.fetch() + asn_value = self._opt_value(node.asn.peer, "asn") + + result = DeviceIntent( + name=node.name.value, + role=node.role.value, + status=node.status.value, + platform=platform_name, + site=site_name, + asn=asn_value, + ) + self._cache_set(cache_key, result) + return result + + async def get_device_vlans(self, device_name: str) -> list[VlanIntent]: + """Fetch VLANs associated with a device via its VTEP vlan-vni mappings. + + Falls back to VLANs linked through SVI interfaces when no VTEP exists. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + List of :class:`VlanIntent` objects (may be empty). + + Raises: + InfrahubNotFoundError: When the device does not exist. + InfrahubQueryError: On query failure. + """ + cache_key = f"device_vlans:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + # Ensure device exists + device_node = await self._fetch_node("InfraDevice", name__value=device_name) + + vlan_intents: list[VlanIntent] = [] + seen_vlan_ids: set[int] = set() + + # Try via VTEP → VlanVniMapping → VLAN + try: + vtep_nodes = await self._fetch_all( + "InfraVTEP", device__name__value=device_name + ) + except InfrahubQueryError: + vtep_nodes = [] + + for vtep in vtep_nodes: + if not (vtep.vlan_vni_mappings and vtep.vlan_vni_mappings.initialized): + continue + for mapping_rel in vtep.vlan_vni_mappings.peers: + mapping = mapping_rel.peer + if not (mapping.vlan and mapping.vlan.initialized): + continue + await mapping.vlan.fetch() + vlan_node = mapping.vlan.peer + if vlan_node is None: + continue + vlan_id: int = vlan_node.vlan_id.value + if vlan_id in seen_vlan_ids: + continue + seen_vlan_ids.add(vlan_id) + + vni_val: int | None = None + if vlan_node.vni and vlan_node.vni.initialized: + await vlan_node.vni.fetch() + if vlan_node.vni.peer: + vni_val = vlan_node.vni.peer.vni.value + + vlan_intents.append( + VlanIntent( + vlan_id=vlan_id, + name=vlan_node.name.value, + status=vlan_node.status.value, + vlan_type=vlan_node.vlan_type.value, + vni=vni_val, + stp_enabled=bool(self._opt_value(vlan_node, "stp_enabled")), + ) + ) + + # Fallback: VLANs via SVI interfaces on the device + if not vlan_intents: + interfaces = await self._fetch_all( + "InfraInterfaceVlan", device__name__value=device_name + ) + for intf in interfaces: + if not (intf.vlan and intf.vlan.initialized): + continue + await intf.vlan.fetch() + vlan_node = intf.vlan.peer + if vlan_node is None: + continue + vlan_id = vlan_node.vlan_id.value + if vlan_id in seen_vlan_ids: + continue + seen_vlan_ids.add(vlan_id) + + vni_val = None + if vlan_node.vni and vlan_node.vni.initialized: + await vlan_node.vni.fetch() + if vlan_node.vni.peer: + vni_val = vlan_node.vni.peer.vni.value + + vlan_intents.append( + VlanIntent( + vlan_id=vlan_id, + name=vlan_node.name.value, + status=vlan_node.status.value, + vlan_type=vlan_node.vlan_type.value, + vni=vni_val, + stp_enabled=bool(self._opt_value(vlan_node, "stp_enabled")), + ) + ) + + # Suppress unused variable warning — device_node used for existence check + _ = device_node + + self._cache_set(cache_key, vlan_intents) + return vlan_intents + + async def get_device_bgp_config(self, device_name: str) -> BgpRouterConfigIntent: + """Fetch the BGP router config for a device. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + :class:`BgpRouterConfigIntent` for the device. + + Raises: + InfrahubNotFoundError: When no BGP config exists for the device. + InfrahubQueryError: On query failure. + """ + cache_key = f"bgp_config:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + node = await self._fetch_node( + "InfraBGPRouterConfig", device__name__value=device_name + ) + + local_asn: int = 0 + if node.local_asn and node.local_asn.initialized: + await node.local_asn.fetch() + if node.local_asn.peer: + local_asn = node.local_asn.peer.asn.value + + result = BgpRouterConfigIntent( + router_id=node.router_id.value, + local_asn=local_asn, + default_ipv4_unicast=bool(self._opt_value(node, "default_ipv4_unicast")), + ecmp_max_paths=int(self._opt_value(node, "ecmp_max_paths") or 1), + ) + self._cache_set(cache_key, result) + return result + + async def get_device_bgp_peer_groups(self, device_name: str) -> list[BgpPeerGroupIntent]: + """Fetch all BGP peer groups configured on a device. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + List of :class:`BgpPeerGroupIntent` objects. + + Raises: + InfrahubNotFoundError: When no BGP config exists for the device. + InfrahubQueryError: On query failure. + """ + cache_key = f"bgp_peer_groups:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + bgp_node = await self._fetch_node( + "InfraBGPRouterConfig", device__name__value=device_name + ) + + results: list[BgpPeerGroupIntent] = [] + if not (bgp_node.peer_groups and bgp_node.peer_groups.initialized): + self._cache_set(cache_key, results) + return results + + for pg_rel in bgp_node.peer_groups.peers: + pg = pg_rel.peer + + remote_asn: int | None = None + if pg.remote_asn and pg.remote_asn.initialized: + await pg.remote_asn.fetch() + if pg.remote_asn.peer: + remote_asn = pg.remote_asn.peer.asn.value + + results.append( + BgpPeerGroupIntent( + name=pg.name.value, + peer_group_type=pg.peer_group_type.value, + remote_asn=remote_asn, + update_source=self._opt_value(pg, "update_source"), + send_community=bool(self._opt_value(pg, "send_community")), + ebgp_multihop=self._opt_value(pg, "ebgp_multihop"), + next_hop_unchanged=bool(self._opt_value(pg, "next_hop_unchanged")), + ) + ) + + self._cache_set(cache_key, results) + return results + + async def get_device_bgp_sessions(self, device_name: str) -> list[BgpSessionIntent]: + """Fetch all BGP sessions configured on a device. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + List of :class:`BgpSessionIntent` objects. + + Raises: + InfrahubNotFoundError: When no BGP config exists for the device. + InfrahubQueryError: On query failure. + """ + cache_key = f"bgp_sessions:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + bgp_node = await self._fetch_node( + "InfraBGPRouterConfig", device__name__value=device_name + ) + + results: list[BgpSessionIntent] = [] + if not (bgp_node.sessions and bgp_node.sessions.initialized): + self._cache_set(cache_key, results) + return results + + for sess_rel in bgp_node.sessions.peers: + sess = sess_rel.peer + + peer_group_name: str | None = None + if sess.peer_group and sess.peer_group.initialized: + await sess.peer_group.fetch() + if sess.peer_group.peer: + peer_group_name = sess.peer_group.peer.name.value + + remote_asn: int | None = None + if sess.remote_asn and sess.remote_asn.initialized: + await sess.remote_asn.fetch() + if sess.remote_asn.peer: + remote_asn = sess.remote_asn.peer.asn.value + + results.append( + BgpSessionIntent( + peer_address=sess.peer_address.value, + description=self._opt_value(sess, "description"), + enabled=bool(self._opt_value(sess, "enabled")), + peer_group=peer_group_name, + remote_asn=remote_asn, + ) + ) + + self._cache_set(cache_key, results) + return results + + async def get_device_vrfs(self, device_name: str) -> list[VrfIntent]: + """Fetch all VRFs assigned to a device via VRFDeviceAssignment. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + List of :class:`VrfIntent` objects (may be empty). + + Raises: + InfrahubQueryError: On query failure. + """ + cache_key = f"device_vrfs:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + assignments = await self._fetch_all( + "InfraVRFDeviceAssignment", device__name__value=device_name + ) + + results: list[VrfIntent] = [] + for assignment in assignments: + if not (assignment.vrf and assignment.vrf.initialized): + continue + await assignment.vrf.fetch() + vrf_node = assignment.vrf.peer + if vrf_node is None: + continue + + l3vni: int | None = None + if vrf_node.l3vni and vrf_node.l3vni.initialized: + await vrf_node.l3vni.fetch() + if vrf_node.l3vni.peer: + l3vni = vrf_node.l3vni.peer.vni.value + + import_targets: list[str] = [] + if vrf_node.import_targets and vrf_node.import_targets.initialized: + for rt_rel in vrf_node.import_targets.peers: + import_targets.append(rt_rel.peer.target.value) + + export_targets: list[str] = [] + if vrf_node.export_targets and vrf_node.export_targets.initialized: + for rt_rel in vrf_node.export_targets.peers: + export_targets.append(rt_rel.peer.target.value) + + results.append( + VrfIntent( + name=vrf_node.name.value, + route_distinguisher=self._opt_value(vrf_node, "route_distinguisher"), + vrf_id=self._opt_value(vrf_node, "vrf_id"), + l3vni=l3vni, + import_targets=import_targets, + export_targets=export_targets, + ) + ) + + self._cache_set(cache_key, results) + return results + + async def get_device_vtep(self, device_name: str) -> VtepIntent | None: + """Fetch the VTEP configuration for a device. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + :class:`VtepIntent` if a VTEP exists for the device, else ``None``. + + Raises: + InfrahubQueryError: On query failure. + """ + cache_key = f"device_vtep:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + vtep_nodes = await self._fetch_all( + "InfraVTEP", device__name__value=device_name + ) + if not vtep_nodes: + self._cache_set(cache_key, None) + return None + + vtep = vtep_nodes[0] + mappings: list[tuple[int, int]] = [] + + if vtep.vlan_vni_mappings and vtep.vlan_vni_mappings.initialized: + for mapping_rel in vtep.vlan_vni_mappings.peers: + mapping = mapping_rel.peer + vlan_id: int | None = None + vni_id: int | None = None + + if mapping.vlan and mapping.vlan.initialized: + await mapping.vlan.fetch() + if mapping.vlan.peer: + vlan_id = mapping.vlan.peer.vlan_id.value + + if mapping.vni and mapping.vni.initialized: + await mapping.vni.fetch() + if mapping.vni.peer: + vni_id = mapping.vni.peer.vni.value + + if vlan_id is not None and vni_id is not None: + mappings.append((vlan_id, vni_id)) + + result = VtepIntent( + source_address=vtep.source_address.value, + udp_port=int(self._opt_value(vtep, "udp_port") or 4789), + learn_restrict=bool(self._opt_value(vtep, "learn_restrict")), + vlan_vni_mappings=mappings, + ) + self._cache_set(cache_key, result) + return result + + async def get_device_evpn_instances(self, device_name: str) -> list[EvpnInstanceIntent]: + """Fetch all EVPN instances for a device. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + List of :class:`EvpnInstanceIntent` objects. + + Raises: + InfrahubQueryError: On query failure. + """ + cache_key = f"device_evpn:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + nodes = await self._fetch_all( + "InfraEVPNInstance", device__name__value=device_name + ) + + results: list[EvpnInstanceIntent] = [] + for node in nodes: + vlan_id: int = 0 + if node.vlan and node.vlan.initialized: + await node.vlan.fetch() + if node.vlan.peer: + vlan_id = node.vlan.peer.vlan_id.value + + results.append( + EvpnInstanceIntent( + route_distinguisher=node.route_distinguisher.value, + route_target_import=node.route_target_import.value, + route_target_export=node.route_target_export.value, + redistribute_learned=bool( + self._opt_value(node, "redistribute_learned") + ), + vlan_id=vlan_id, + ) + ) + + self._cache_set(cache_key, results) + return results + + async def get_mlag_domain(self, device_name: str) -> MlagDomainIntent | None: + """Fetch the MLAG domain that includes the given device. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + :class:`MlagDomainIntent` if the device belongs to an MLAG domain, + else ``None``. + + Raises: + InfrahubQueryError: On query failure. + """ + cache_key = f"mlag_domain:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + domains = await self._fetch_all("InfraMlagDomain") + for domain in domains: + if not (domain.devices and domain.devices.initialized): + continue + device_names: list[str] = [] + for dev_rel in domain.devices.peers: + device_names.append(dev_rel.peer.name.value) + + if device_name not in device_names: + continue + + peer_vlan: str | None = None + if domain.peer_vlan and domain.peer_vlan.initialized: + await domain.peer_vlan.fetch() + if domain.peer_vlan.peer: + peer_vlan = str(domain.peer_vlan.peer.vlan_id.value) + + result = MlagDomainIntent( + domain_id=str(domain.domain_id.value), + virtual_mac=domain.virtual_mac.value, + heartbeat_vrf=self._opt_value(domain, "heartbeat_vrf") or "", + dual_primary_detection=bool( + self._opt_value(domain, "dual_primary_detection") + ), + dual_primary_delay=int(self._opt_value(domain, "dual_primary_delay") or 0), + dual_primary_action=self._opt_value(domain, "dual_primary_action") or "", + peer_devices=device_names, + ) + self._cache_set(cache_key, result) + return result + + self._cache_set(cache_key, None) + return None + + async def get_mlag_peer_config(self, device_name: str) -> MlagPeerConfigIntent | None: + """Fetch the MLAG peer config for a device. + + Args: + device_name: Device name as stored in Infrahub. + + Returns: + :class:`MlagPeerConfigIntent` if a config exists, else ``None``. + + Raises: + InfrahubQueryError: On query failure. + """ + cache_key = f"mlag_peer_config:{device_name}" + if cached := self._cache_get(cache_key): + return cached + + configs = await self._fetch_all( + "InfraMlagPeerConfig", device__name__value=device_name + ) + if not configs: + self._cache_set(cache_key, None) + return None + + cfg = configs[0] + + peer_link_name: str = "" + if cfg.peer_link and cfg.peer_link.initialized: + await cfg.peer_link.fetch() + if cfg.peer_link.peer: + peer_link_name = cfg.peer_link.peer.name.value + + result = MlagPeerConfigIntent( + local_interface_ip=cfg.local_interface_ip.value, + peer_address=cfg.peer_address.value, + heartbeat_peer_ip=self._opt_value(cfg, "heartbeat_peer_ip") or "", + peer_link=peer_link_name, + ) + self._cache_set(cache_key, result) + return result -- 2.53.0 From ae29117827439f775d3911cdb40fcf7a778a12a2 Mon Sep 17 00:00:00 2001 From: Damien Arnodo Date: Thu, 26 Feb 2026 12:42:52 +0000 Subject: [PATCH 5/6] feat(infrahub): add Infrahub client for fabric intent (#42) --- src/infrahub_client/__init__.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/infrahub_client/__init__.py diff --git a/src/infrahub_client/__init__.py b/src/infrahub_client/__init__.py new file mode 100644 index 0000000..6c7f338 --- /dev/null +++ b/src/infrahub_client/__init__.py @@ -0,0 +1,28 @@ +"""Infrahub client for Arista EVPN-VXLAN fabric intent. + +Public API:: + + from infrahub_client import FabricInfrahubClient + from infrahub_client import ( + InfrahubClientError, + InfrahubConnectionError, + InfrahubQueryError, + InfrahubNotFoundError, + ) +""" + +from .client import FabricInfrahubClient +from .exceptions import ( + InfrahubClientError, + InfrahubConnectionError, + InfrahubNotFoundError, + InfrahubQueryError, +) + +__all__ = [ + "FabricInfrahubClient", + "InfrahubClientError", + "InfrahubConnectionError", + "InfrahubNotFoundError", + "InfrahubQueryError", +] -- 2.53.0 From 8967f04f4f2f925478b493dc484cf326d7a3398d Mon Sep 17 00:00:00 2001 From: Damien Arnodo Date: Thu, 26 Feb 2026 12:44:07 +0000 Subject: [PATCH 6/6] feat(infrahub): add Infrahub client for fabric intent (#42) --- tests/test_infrahub_client.py | 598 ++++++++++++++++++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 tests/test_infrahub_client.py diff --git a/tests/test_infrahub_client.py b/tests/test_infrahub_client.py new file mode 100644 index 0000000..3e3f4f5 --- /dev/null +++ b/tests/test_infrahub_client.py @@ -0,0 +1,598 @@ +"""Unit tests for FabricInfrahubClient — all Infrahub SDK calls are mocked.""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from infrahub_client import ( + FabricInfrahubClient, + InfrahubNotFoundError, + InfrahubQueryError, +) +from infrahub_client.models import ( + BgpPeerGroupIntent, + BgpRouterConfigIntent, + BgpSessionIntent, + DeviceIntent, + EvpnInstanceIntent, + MlagDomainIntent, + MlagPeerConfigIntent, + VlanIntent, + VrfIntent, + VtepIntent, +) + + +# --------------------------------------------------------------------------- +# Helpers to build mock SDK node objects +# --------------------------------------------------------------------------- + + +def _attr(value): + """Return a simple attribute mock with ``.value`` set.""" + m = MagicMock() + m.value = value + return m + + +def _rel_one(peer_node, initialized=True): + """Return a mock for a cardinality-one relationship.""" + m = MagicMock() + m.initialized = initialized + m.peer = peer_node + m.fetch = AsyncMock() + return m + + +def _rel_many(peers, initialized=True): + """Return a mock for a cardinality-many relationship (list of peer wrappers).""" + m = MagicMock() + m.initialized = initialized + peer_wrappers = [] + for p in peers: + pw = MagicMock() + pw.peer = p + peer_wrappers.append(pw) + m.peers = peer_wrappers + return m + + +def _make_device( + name="leaf1", + role="leaf", + status="active", + platform_name="eos", + site_name="dc1", + asn_value=65001, +): + """Build a mock InfraDevice SDK node.""" + node = MagicMock() + node.name = _attr(name) + node.role = _attr(role) + node.status = _attr(status) + + platform = MagicMock() + platform.name = _attr(platform_name) + node.platform = _rel_one(platform) + + site = MagicMock() + site.name = _attr(site_name) + node.site = _rel_one(site) + + asn = MagicMock() + asn.asn = _attr(asn_value) + node.asn = _rel_one(asn) + + return node + + +def _make_sdk_client(get_return=None, all_return=None, filters_return=None, get_raises=None): + """Patch InfrahubClient inside client module and configure return values.""" + mock_sdk = MagicMock() + if get_raises: + mock_sdk.get = AsyncMock(side_effect=get_raises) + else: + mock_sdk.get = AsyncMock(return_value=get_return) + mock_sdk.all = AsyncMock(return_value=all_return or []) + mock_sdk.filters = AsyncMock(return_value=filters_return or []) + return mock_sdk + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def client(): + """FabricInfrahubClient with a patched InfrahubClient SDK.""" + with patch("infrahub_client.client.InfrahubClient") as MockSdk: + MockSdk.return_value = MagicMock() + c = FabricInfrahubClient(url="http://infrahub:8080", api_token="token", branch="main") + yield c + + +# --------------------------------------------------------------------------- +# get_device +# --------------------------------------------------------------------------- + + +class TestGetDevice: + async def test_returns_device_intent(self, client): + device_node = _make_device() + client._sdk.get = AsyncMock(return_value=device_node) + + result = await client.get_device("leaf1") + + assert isinstance(result, DeviceIntent) + assert result.name == "leaf1" + assert result.role == "leaf" + assert result.status == "active" + assert result.platform == "eos" + assert result.site == "dc1" + assert result.asn == 65001 + + async def test_not_found_raises_error(self, client): + client._sdk.get = AsyncMock(side_effect=Exception("not found")) + + with pytest.raises(InfrahubNotFoundError): + await client.get_device("nonexistent") + + async def test_branch_is_forwarded(self, client): + client._branch = "proposed" + device_node = _make_device() + client._sdk.get = AsyncMock(return_value=device_node) + + await client.get_device("leaf1") + + client._sdk.get.assert_called_once_with( + kind="InfraDevice", branch="proposed", name__value="leaf1" + ) + + async def test_caching_second_call_skips_sdk(self, client): + device_node = _make_device() + client._sdk.get = AsyncMock(return_value=device_node) + + first = await client.get_device("leaf1") + second = await client.get_device("leaf1") + + assert first == second + # SDK should only have been called once + assert client._sdk.get.call_count == 1 + + async def test_optional_fields_none_when_uninitialized(self, client): + node = MagicMock() + node.name = _attr("spine1") + node.role = _attr("spine") + node.status = _attr("active") + node.platform = _rel_one(None, initialized=False) + node.site = _rel_one(None, initialized=False) + node.asn = _rel_one(None, initialized=False) + client._sdk.get = AsyncMock(return_value=node) + + result = await client.get_device("spine1") + + assert result.platform is None + assert result.site is None + assert result.asn is None + + +# --------------------------------------------------------------------------- +# get_device_bgp_config +# --------------------------------------------------------------------------- + + +class TestGetDeviceBgpConfig: + def _make_bgp_node(self, router_id="10.0.0.1", asn=65001, default_ipv4=True, ecmp=4): + node = MagicMock() + node.router_id = _attr(router_id) + node.default_ipv4_unicast = _attr(default_ipv4) + node.ecmp_max_paths = _attr(ecmp) + + asn_peer = MagicMock() + asn_peer.asn = _attr(asn) + node.local_asn = _rel_one(asn_peer) + + return node + + async def test_returns_bgp_config_intent(self, client): + bgp_node = self._make_bgp_node() + client._sdk.get = AsyncMock(return_value=bgp_node) + + result = await client.get_device_bgp_config("leaf1") + + assert isinstance(result, BgpRouterConfigIntent) + assert result.router_id == "10.0.0.1" + assert result.local_asn == 65001 + assert result.default_ipv4_unicast is True + assert result.ecmp_max_paths == 4 + + async def test_not_found_raises_error(self, client): + client._sdk.get = AsyncMock(side_effect=Exception("not found")) + + with pytest.raises(InfrahubNotFoundError): + await client.get_device_bgp_config("leaf1") + + async def test_caching(self, client): + bgp_node = self._make_bgp_node() + client._sdk.get = AsyncMock(return_value=bgp_node) + + await client.get_device_bgp_config("leaf1") + await client.get_device_bgp_config("leaf1") + + assert client._sdk.get.call_count == 1 + + +# --------------------------------------------------------------------------- +# get_device_bgp_peer_groups +# --------------------------------------------------------------------------- + + +class TestGetDeviceBgpPeerGroups: + def _make_pg(self, name="UNDERLAY", pg_type="underlay", remote_asn=65000): + pg = MagicMock() + pg.name = _attr(name) + pg.peer_group_type = _attr(pg_type) + pg.update_source = _attr("Loopback0") + pg.send_community = _attr(True) + pg.ebgp_multihop = _attr(None) + pg.next_hop_unchanged = _attr(False) + + asn_peer = MagicMock() + asn_peer.asn = _attr(remote_asn) + pg.remote_asn = _rel_one(asn_peer) + + return pg + + async def test_returns_peer_group_intents(self, client): + pg = self._make_pg() + bgp_node = MagicMock() + bgp_node.peer_groups = _rel_many([pg]) + client._sdk.get = AsyncMock(return_value=bgp_node) + + results = await client.get_device_bgp_peer_groups("leaf1") + + assert len(results) == 1 + assert isinstance(results[0], BgpPeerGroupIntent) + assert results[0].name == "UNDERLAY" + assert results[0].peer_group_type == "underlay" + assert results[0].remote_asn == 65000 + assert results[0].send_community is True + + async def test_empty_when_no_peer_groups(self, client): + bgp_node = MagicMock() + bgp_node.peer_groups = _rel_many([], initialized=False) + client._sdk.get = AsyncMock(return_value=bgp_node) + + results = await client.get_device_bgp_peer_groups("leaf1") + + assert results == [] + + +# --------------------------------------------------------------------------- +# get_device_bgp_sessions +# --------------------------------------------------------------------------- + + +class TestGetDeviceBgpSessions: + def _make_session(self, peer_addr="192.168.1.1", enabled=True): + sess = MagicMock() + sess.peer_address = _attr(peer_addr) + sess.description = _attr("peer session") + sess.enabled = _attr(enabled) + + pg_peer = MagicMock() + pg_peer.name = _attr("UNDERLAY") + sess.peer_group = _rel_one(pg_peer) + + asn_peer = MagicMock() + asn_peer.asn = _attr(65000) + sess.remote_asn = _rel_one(asn_peer) + + return sess + + async def test_returns_session_intents(self, client): + sess = self._make_session() + bgp_node = MagicMock() + bgp_node.sessions = _rel_many([sess]) + client._sdk.get = AsyncMock(return_value=bgp_node) + + results = await client.get_device_bgp_sessions("leaf1") + + assert len(results) == 1 + assert isinstance(results[0], BgpSessionIntent) + assert results[0].peer_address == "192.168.1.1" + assert results[0].enabled is True + assert results[0].peer_group == "UNDERLAY" + assert results[0].remote_asn == 65000 + + async def test_empty_when_no_sessions(self, client): + bgp_node = MagicMock() + bgp_node.sessions = _rel_many([], initialized=False) + client._sdk.get = AsyncMock(return_value=bgp_node) + + results = await client.get_device_bgp_sessions("leaf1") + + assert results == [] + + +# --------------------------------------------------------------------------- +# get_device_vrfs +# --------------------------------------------------------------------------- + + +class TestGetDeviceVrfs: + def _make_assignment(self, vrf_name="PROD", rd="65001:100", vni=100100): + vni_peer = MagicMock() + vni_peer.vni = _attr(vni) + + rt_import = MagicMock() + rt_import.target = _attr("65001:100") + rt_export = MagicMock() + rt_export.target = _attr("65001:100") + + vrf_node = MagicMock() + vrf_node.name = _attr(vrf_name) + vrf_node.route_distinguisher = _attr(rd) + vrf_node.vrf_id = _attr(None) + vrf_node.l3vni = _rel_one(vni_peer) + vrf_node.import_targets = _rel_many([rt_import]) + vrf_node.export_targets = _rel_many([rt_export]) + + assignment = MagicMock() + assignment.vrf = _rel_one(vrf_node) + + return assignment + + async def test_returns_vrf_intents(self, client): + assignment = self._make_assignment() + client._sdk.filters = AsyncMock(return_value=[assignment]) + + results = await client.get_device_vrfs("leaf1") + + assert len(results) == 1 + assert isinstance(results[0], VrfIntent) + assert results[0].name == "PROD" + assert results[0].route_distinguisher == "65001:100" + assert results[0].l3vni == 100100 + assert "65001:100" in results[0].import_targets + + async def test_empty_when_no_assignments(self, client): + client._sdk.filters = AsyncMock(return_value=[]) + + results = await client.get_device_vrfs("leaf1") + + assert results == [] + + async def test_caching(self, client): + assignment = self._make_assignment() + client._sdk.filters = AsyncMock(return_value=[assignment]) + + await client.get_device_vrfs("leaf1") + await client.get_device_vrfs("leaf1") + + assert client._sdk.filters.call_count == 1 + + +# --------------------------------------------------------------------------- +# get_device_vtep +# --------------------------------------------------------------------------- + + +class TestGetDeviceVtep: + def _make_vtep(self, src_addr="10.255.0.1", udp_port=4789): + vlan_node = MagicMock() + vlan_node.vlan_id = _attr(100) + vni_node = MagicMock() + vni_node.vni = _attr(10100) + + mapping = MagicMock() + mapping.vlan = _rel_one(vlan_node) + mapping.vni = _rel_one(vni_node) + + vtep = MagicMock() + vtep.source_address = _attr(src_addr) + vtep.udp_port = _attr(udp_port) + vtep.learn_restrict = _attr(True) + vtep.vlan_vni_mappings = _rel_many([mapping]) + + return vtep + + async def test_returns_vtep_intent(self, client): + vtep = self._make_vtep() + client._sdk.filters = AsyncMock(return_value=[vtep]) + + result = await client.get_device_vtep("leaf1") + + assert isinstance(result, VtepIntent) + assert result.source_address == "10.255.0.1" + assert result.udp_port == 4789 + assert result.learn_restrict is True + assert (100, 10100) in result.vlan_vni_mappings + + async def test_returns_none_when_no_vtep(self, client): + client._sdk.filters = AsyncMock(return_value=[]) + + result = await client.get_device_vtep("leaf1") + + assert result is None + + +# --------------------------------------------------------------------------- +# get_device_evpn_instances +# --------------------------------------------------------------------------- + + +class TestGetDeviceEvpnInstances: + def _make_evpn_instance(self, rd="65001:100", rt_import="65001:100", rt_export="65001:100"): + vlan_node = MagicMock() + vlan_node.vlan_id = _attr(100) + + node = MagicMock() + node.route_distinguisher = _attr(rd) + node.route_target_import = _attr(rt_import) + node.route_target_export = _attr(rt_export) + node.redistribute_learned = _attr(True) + node.vlan = _rel_one(vlan_node) + + return node + + async def test_returns_evpn_instance_intents(self, client): + evpn_node = self._make_evpn_instance() + client._sdk.filters = AsyncMock(return_value=[evpn_node]) + + results = await client.get_device_evpn_instances("leaf1") + + assert len(results) == 1 + assert isinstance(results[0], EvpnInstanceIntent) + assert results[0].route_distinguisher == "65001:100" + assert results[0].redistribute_learned is True + assert results[0].vlan_id == 100 + + async def test_empty_when_no_instances(self, client): + client._sdk.filters = AsyncMock(return_value=[]) + + results = await client.get_device_evpn_instances("leaf1") + + assert results == [] + + +# --------------------------------------------------------------------------- +# get_mlag_domain +# --------------------------------------------------------------------------- + + +class TestGetMlagDomain: + def _make_domain(self, device_names=("leaf1", "leaf2")): + devices = [] + for dname in device_names: + d = MagicMock() + d.name = _attr(dname) + devices.append(d) + + peer_vlan_node = MagicMock() + peer_vlan_node.vlan_id = _attr(4094) + + domain = MagicMock() + domain.domain_id = _attr("1") + domain.virtual_mac = _attr("00:1c:73:00:00:01") + domain.heartbeat_vrf = _attr("MGMT") + domain.dual_primary_detection = _attr(True) + domain.dual_primary_delay = _attr(10) + domain.dual_primary_action = _attr("errDisable") + domain.devices = _rel_many(devices) + domain.peer_vlan = _rel_one(peer_vlan_node) + + return domain + + async def test_returns_mlag_domain_intent(self, client): + domain = self._make_domain() + client._sdk.all = AsyncMock(return_value=[domain]) + + result = await client.get_mlag_domain("leaf1") + + assert isinstance(result, MlagDomainIntent) + assert result.domain_id == "1" + assert result.virtual_mac == "00:1c:73:00:00:01" + assert "leaf1" in result.peer_devices + assert "leaf2" in result.peer_devices + + async def test_returns_none_when_device_not_in_any_domain(self, client): + domain = self._make_domain(device_names=("leaf3", "leaf4")) + client._sdk.all = AsyncMock(return_value=[domain]) + + result = await client.get_mlag_domain("leaf1") + + assert result is None + + async def test_caching(self, client): + domain = self._make_domain() + client._sdk.all = AsyncMock(return_value=[domain]) + + await client.get_mlag_domain("leaf1") + await client.get_mlag_domain("leaf1") + + assert client._sdk.all.call_count == 1 + + +# --------------------------------------------------------------------------- +# get_mlag_peer_config +# --------------------------------------------------------------------------- + + +class TestGetMlagPeerConfig: + def _make_peer_config(self): + peer_link_node = MagicMock() + peer_link_node.name = _attr("Port-Channel1") + + cfg = MagicMock() + cfg.local_interface_ip = _attr("192.168.255.0/31") + cfg.peer_address = _attr("192.168.255.1") + cfg.heartbeat_peer_ip = _attr("192.168.1.2") + cfg.peer_link = _rel_one(peer_link_node) + + return cfg + + async def test_returns_mlag_peer_config_intent(self, client): + cfg = self._make_peer_config() + client._sdk.filters = AsyncMock(return_value=[cfg]) + + result = await client.get_mlag_peer_config("leaf1") + + assert isinstance(result, MlagPeerConfigIntent) + assert result.local_interface_ip == "192.168.255.0/31" + assert result.peer_address == "192.168.255.1" + assert result.peer_link == "Port-Channel1" + + async def test_returns_none_when_no_config(self, client): + client._sdk.filters = AsyncMock(return_value=[]) + + result = await client.get_mlag_peer_config("leaf1") + + assert result is None + + +# --------------------------------------------------------------------------- +# Async context manager +# --------------------------------------------------------------------------- + + +class TestContextManager: + async def test_context_manager(self): + with patch("infrahub_client.client.InfrahubClient"): + async with FabricInfrahubClient( + url="http://infrahub:8080", api_token="token" + ) as client: + assert isinstance(client, FabricInfrahubClient) + + +# --------------------------------------------------------------------------- +# Cache TTL expiry +# --------------------------------------------------------------------------- + + +class TestCacheTtl: + async def test_expired_cache_triggers_new_sdk_call(self, client): + device_node = _make_device() + client._sdk.get = AsyncMock(return_value=device_node) + + await client.get_device("leaf1") + # Manually expire the cache entry + client._cache["device:leaf1"] = (client._cache["device:leaf1"][0], 0.0) + await client.get_device("leaf1") + + assert client._sdk.get.call_count == 2 + + +# --------------------------------------------------------------------------- +# InfrahubQueryError propagation +# --------------------------------------------------------------------------- + + +class TestQueryError: + async def test_sdk_unexpected_error_raises_query_error(self, client): + client._sdk.get = AsyncMock(side_effect=Exception("connection refused")) + + with pytest.raises(InfrahubQueryError): + await client.get_device_bgp_config("leaf1") -- 2.53.0