feat(infrahub): add Infrahub client for fabric intent (#42)

- Replace pynetbox with infrahub-sdk>=0.16.0 + pydantic>=2.0 in dependencies
- Add pytest and pytest-asyncio to dev dependencies
- Implement FabricInfrahubClient async client (src/infrahub/client.py)
  - get_device, get_device_vlans, get_device_bgp_config
  - get_device_bgp_peer_groups, get_device_bgp_sessions
  - get_device_vrfs, get_device_vtep, get_device_evpn_instances
  - get_mlag_domain, get_mlag_peer_config
  - TTL-based caching (60s) to avoid redundant SDK queries
  - Async context manager support
- Add Pydantic v2 frozen models for all intent types (src/infrahub/models.py)
- Add custom exception hierarchy (src/infrahub/exceptions.py)
- Add unit tests with fully mocked SDK (tests/test_infrahub_client.py)
  - Tests for correct model return, NotFoundError, branch selection, caching

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Damien
2026-02-26 14:03:09 +01:00
parent 21a44db736
commit f663cc623c
6 changed files with 1610 additions and 2 deletions

View File

@@ -1,13 +1,14 @@
[project]
name = "fabric-orchestrator"
version = "0.1.0"
description = "Declarative Network Fabric Orchestrator - Terraform-like infrastructure management for Arista EVPN-VXLAN using gNMI, YANG, and NetBox as Source of Truth"
description = "Declarative Network Fabric Orchestrator - Terraform-like infrastructure management for Arista EVPN-VXLAN using gNMI, YANG, and Infrahub as Source of Truth"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"click>=8.1.0",
"pygnmi>=0.8.0",
"pynetbox>=7.5.0",
"infrahub-sdk>=0.16.0",
"pydantic>=2.0",
"rich>=13.0.0",
]
@@ -35,5 +36,7 @@ ignore = ["E501"]
[dependency-groups]
dev = [
"basedpyright>=1.37.4",
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"ruff>=0.14.10",
]

24
src/infrahub/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
"""
Infrahub client package for the Fabric Orchestrator.
Exports the main async client and exception hierarchy for use by the
Reconciler and Prefect flows.
"""
from __future__ import annotations
from .client import FabricInfrahubClient
from .exceptions import (
InfrahubClientError,
InfrahubConnectionError,
InfrahubNotFoundError,
InfrahubQueryError,
)
__all__ = [
"FabricInfrahubClient",
"InfrahubClientError",
"InfrahubConnectionError",
"InfrahubNotFoundError",
"InfrahubQueryError",
]

789
src/infrahub/client.py Normal file
View File

@@ -0,0 +1,789 @@
"""
Infrahub Client for Fabric Intent.
Async client wrapping `infrahub-sdk` to fetch fabric intent data from a remote
Infrahub instance. Replaces the previous NetBox/pynetbox client and is used by
the Reconciler and Prefect flows.
"""
from __future__ import annotations
import time
from typing import Any
from infrahub_sdk import Config, InfrahubClient
from .exceptions import InfrahubNotFoundError, InfrahubQueryError
from .models import (
BgpPeerGroupIntent,
BgpRouterConfigIntent,
BgpSessionIntent,
DeviceIntent,
EvpnInstanceIntent,
MlagDomainIntent,
MlagPeerConfigIntent,
VlanIntent,
VrfIntent,
VtepIntent,
)
# TTL for cache entries in seconds
_CACHE_TTL: int = 60
class FabricInfrahubClient:
"""
Async client for querying fabric intent data from a remote Infrahub instance.
Wraps `infrahub-sdk` with structured Pydantic model responses, error handling,
and a simple TTL-based cache to avoid redundant queries.
Usage as async context manager::
async with FabricInfrahubClient(url="http://infrahub:8080", api_token="xxx") as client:
device = await client.get_device("leaf1")
vlans = await client.get_device_vlans("leaf1")
"""
def __init__(self, url: str, api_token: str, branch: str = "main") -> None:
"""
Initialize the Infrahub client.
Args:
url: Base URL of the Infrahub instance (e.g. "http://infrahub:8080")
api_token: API token for authentication
branch: Default branch to query (default: "main")
"""
self._url = url
self._branch = branch
config = Config(address=url, api_token=api_token, default_branch=branch)
self._client = InfrahubClient(config=config)
self._cache: dict[str, tuple[Any, float]] = {}
async def __aenter__(self) -> FabricInfrahubClient:
"""Enter async context manager."""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Exit async context manager."""
pass
# =========================================================================
# Cache helpers
# =========================================================================
def _cache_get(self, key: str) -> Any | None:
"""Return cached value if still valid, else None."""
if key in self._cache:
value, ts = self._cache[key]
if time.monotonic() - ts < _CACHE_TTL:
return value
del self._cache[key]
return None
def _cache_set(self, key: str, value: Any) -> None:
"""Store a value in the cache with the current timestamp."""
self._cache[key] = (value, time.monotonic())
# =========================================================================
# Device
# =========================================================================
async def get_device(self, name: str) -> DeviceIntent:
"""
Fetch a device by name and return a structured DeviceIntent.
Args:
name: Device name as stored in Infrahub
Returns:
DeviceIntent with device attributes and resolved relationship names
Raises:
InfrahubNotFoundError: When no device with the given name exists
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"device:{name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
try:
node = await self._client.get(kind="InfraDevice", name__value=name)
except Exception as e:
msg = str(e).lower()
if "not found" in msg or "no node" in msg:
raise InfrahubNotFoundError(f"Device '{name}' not found in Infrahub") from e
raise InfrahubQueryError(f"Failed to query device '{name}': {e}") from e
if node is None:
raise InfrahubNotFoundError(f"Device '{name}' not found in Infrahub")
# Resolve cardinality-one relationships
platform: str | None = None
site: str | None = None
asn: int | None = None
try:
if node.platform.peer is None:
await node.platform.fetch()
if node.platform.peer is not None:
platform = node.platform.peer.name.value
except Exception:
pass
try:
if node.site.peer is None:
await node.site.fetch()
if node.site.peer is not None:
site = node.site.peer.name.value
except Exception:
pass
try:
if node.asn.peer is None:
await node.asn.fetch()
if node.asn.peer is not None:
asn = node.asn.peer.asn.value
except Exception:
pass
result = DeviceIntent(
name=node.name.value,
role=node.role.value,
status=node.status.value,
platform=platform,
site=site,
asn=asn,
)
self._cache_set(cache_key, result)
return result
# =========================================================================
# VLANs
# =========================================================================
async def get_device_vlans(self, device_name: str) -> list[VlanIntent]:
"""
Fetch all VLANs associated with a device via its VTEP's vlan_vni_mappings.
Falls back to VLANs linked through SVI interfaces if no VTEP is found.
Args:
device_name: Device name as stored in Infrahub
Returns:
List of VlanIntent objects for all VLANs associated with this device
Raises:
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"device_vlans:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
try:
vtep_nodes = await self._client.filters(kind="InfraVTEP", include=["vlan_vni_mappings"])
except Exception as e:
raise InfrahubQueryError(f"Failed to query VTEPs: {e}") from e
# Find VTEP belonging to this device
device_vtep = None
for vtep in vtep_nodes:
try:
if vtep.device.peer is None:
await vtep.device.fetch()
if vtep.device.peer is not None and vtep.device.peer.name.value == device_name:
device_vtep = vtep
break
except Exception:
continue
vlans: list[VlanIntent] = []
if device_vtep is not None:
for mapping_rel in device_vtep.vlan_vni_mappings.peers:
try:
mapping = mapping_rel.peer
await mapping.vlan.fetch()
vlan_node = mapping.vlan.peer
if vlan_node is None:
continue
vni_val: int | None = None
try:
await vlan_node.vni.fetch()
if vlan_node.vni.peer is not None:
vni_val = vlan_node.vni.peer.vni.value
except Exception:
pass
vlans.append(
VlanIntent(
vlan_id=vlan_node.vlan_id.value,
name=vlan_node.name.value,
status=vlan_node.status.value,
vlan_type=vlan_node.vlan_type.value,
vni=vni_val,
stp_enabled=vlan_node.stp_enabled.value,
)
)
except Exception:
continue
else:
# Fallback: discover VLANs via InfraInterfaceVlan → VLAN
try:
svi_nodes = await self._client.filters(kind="InfraInterfaceVlan")
for svi in svi_nodes:
try:
if svi.device.peer is None:
await svi.device.fetch()
if svi.device.peer is None or svi.device.peer.name.value != device_name:
continue
if svi.vlan.peer is None:
await svi.vlan.fetch()
vlan_node = svi.vlan.peer
if vlan_node is None:
continue
vni_val = None
try:
await vlan_node.vni.fetch()
if vlan_node.vni.peer is not None:
vni_val = vlan_node.vni.peer.vni.value
except Exception:
pass
vlans.append(
VlanIntent(
vlan_id=vlan_node.vlan_id.value,
name=vlan_node.name.value,
status=vlan_node.status.value,
vlan_type=vlan_node.vlan_type.value,
vni=vni_val,
stp_enabled=vlan_node.stp_enabled.value,
)
)
except Exception:
continue
except Exception as e:
raise InfrahubQueryError(
f"Failed to query SVI interfaces for device '{device_name}': {e}"
) from e
self._cache_set(cache_key, vlans)
return vlans
# =========================================================================
# BGP
# =========================================================================
async def _get_bgp_router_config_node(self, device_name: str) -> Any:
"""
Fetch the BGPRouterConfig node for a device.
Args:
device_name: Device name as stored in Infrahub
Returns:
The raw InfraBGPRouterConfig SDK node
Raises:
InfrahubNotFoundError: When no BGP config exists for this device
InfrahubQueryError: When the query fails unexpectedly
"""
try:
nodes = await self._client.filters(
kind="InfraBGPRouterConfig",
include=["peer_groups", "sessions"],
)
except Exception as e:
raise InfrahubQueryError(
f"Failed to query BGP router config for '{device_name}': {e}"
) from e
for node in nodes:
try:
if node.device.peer is None:
await node.device.fetch()
if node.device.peer is not None and node.device.peer.name.value == device_name:
return node
except Exception:
continue
raise InfrahubNotFoundError(f"No BGP router config found for device '{device_name}'")
async def get_device_bgp_config(self, device_name: str) -> BgpRouterConfigIntent:
"""
Fetch the BGP router configuration for a device.
Args:
device_name: Device name as stored in Infrahub
Returns:
BgpRouterConfigIntent with BGP router configuration
Raises:
InfrahubNotFoundError: When no BGP config exists for this device
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"bgp_config:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
node = await self._get_bgp_router_config_node(device_name)
local_asn: int = 0
try:
if node.local_asn.peer is None:
await node.local_asn.fetch()
if node.local_asn.peer is not None:
local_asn = node.local_asn.peer.asn.value
except Exception:
pass
result = BgpRouterConfigIntent(
router_id=node.router_id.value,
local_asn=local_asn,
default_ipv4_unicast=node.default_ipv4_unicast.value,
ecmp_max_paths=node.ecmp_max_paths.value,
)
self._cache_set(cache_key, result)
return result
async def get_device_bgp_peer_groups(self, device_name: str) -> list[BgpPeerGroupIntent]:
"""
Fetch all BGP peer groups for a device via its BGPRouterConfig.
Args:
device_name: Device name as stored in Infrahub
Returns:
List of BgpPeerGroupIntent objects
Raises:
InfrahubNotFoundError: When no BGP config exists for this device
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"bgp_peer_groups:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
bgp_node = await self._get_bgp_router_config_node(device_name)
peer_groups: list[BgpPeerGroupIntent] = []
for pg_rel in bgp_node.peer_groups.peers:
try:
pg = pg_rel.peer
remote_asn: int | None = None
try:
if pg.remote_asn.peer is None:
await pg.remote_asn.fetch()
if pg.remote_asn.peer is not None:
remote_asn = pg.remote_asn.peer.asn.value
except Exception:
pass
peer_groups.append(
BgpPeerGroupIntent(
name=pg.name.value,
peer_group_type=pg.peer_group_type.value,
remote_asn=remote_asn,
update_source=pg.update_source.value if pg.update_source.value else None,
send_community=pg.send_community.value,
ebgp_multihop=pg.ebgp_multihop.value if pg.ebgp_multihop.value else None,
next_hop_unchanged=pg.next_hop_unchanged.value,
)
)
except Exception:
continue
self._cache_set(cache_key, peer_groups)
return peer_groups
async def get_device_bgp_sessions(self, device_name: str) -> list[BgpSessionIntent]:
"""
Fetch all BGP sessions for a device via its BGPRouterConfig.
Args:
device_name: Device name as stored in Infrahub
Returns:
List of BgpSessionIntent objects
Raises:
InfrahubNotFoundError: When no BGP config exists for this device
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"bgp_sessions:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
bgp_node = await self._get_bgp_router_config_node(device_name)
sessions: list[BgpSessionIntent] = []
for sess_rel in bgp_node.sessions.peers:
try:
sess = sess_rel.peer
peer_group: str | None = None
try:
if sess.peer_group.peer is None:
await sess.peer_group.fetch()
if sess.peer_group.peer is not None:
peer_group = sess.peer_group.peer.name.value
except Exception:
pass
remote_asn: int | None = None
try:
if sess.remote_asn.peer is None:
await sess.remote_asn.fetch()
if sess.remote_asn.peer is not None:
remote_asn = sess.remote_asn.peer.asn.value
except Exception:
pass
sessions.append(
BgpSessionIntent(
peer_address=sess.peer_address.value,
description=sess.description.value if sess.description.value else None,
enabled=sess.enabled.value,
peer_group=peer_group,
remote_asn=remote_asn,
)
)
except Exception:
continue
self._cache_set(cache_key, sessions)
return sessions
# =========================================================================
# VRF
# =========================================================================
async def get_device_vrfs(self, device_name: str) -> list[VrfIntent]:
"""
Fetch all VRFs assigned to a device via VRFDeviceAssignment.
Args:
device_name: Device name as stored in Infrahub
Returns:
List of VrfIntent objects with resolved route targets
Raises:
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"device_vrfs:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
try:
assignments = await self._client.filters(
kind="InfraVRFDeviceAssignment",
include=["import_targets", "export_targets"],
)
except Exception as e:
raise InfrahubQueryError(
f"Failed to query VRF assignments for device '{device_name}': {e}"
) from e
vrfs: list[VrfIntent] = []
for asgn in assignments:
try:
if asgn.device.peer is None:
await asgn.device.fetch()
if asgn.device.peer is None or asgn.device.peer.name.value != device_name:
continue
if asgn.vrf.peer is None:
await asgn.vrf.fetch()
vrf_node = asgn.vrf.peer
if vrf_node is None:
continue
l3vni: int | None = None
try:
if vrf_node.l3vni.peer is None:
await vrf_node.l3vni.fetch()
if vrf_node.l3vni.peer is not None:
l3vni = vrf_node.l3vni.peer.vni.value
except Exception:
pass
import_targets: list[str] = []
for rt_rel in asgn.import_targets.peers:
try:
import_targets.append(rt_rel.peer.target.value)
except Exception:
pass
export_targets: list[str] = []
for rt_rel in asgn.export_targets.peers:
try:
export_targets.append(rt_rel.peer.target.value)
except Exception:
pass
rd = asgn.route_distinguisher.value if asgn.route_distinguisher.value else None
vrfs.append(
VrfIntent(
name=vrf_node.name.value,
route_distinguisher=rd,
vrf_id=vrf_node.vrf_id.value if vrf_node.vrf_id.value else None,
l3vni=l3vni,
import_targets=import_targets,
export_targets=export_targets,
)
)
except Exception:
continue
self._cache_set(cache_key, vrfs)
return vrfs
# =========================================================================
# VTEP
# =========================================================================
async def get_device_vtep(self, device_name: str) -> VtepIntent | None:
"""
Fetch the VTEP configuration for a device.
Args:
device_name: Device name as stored in Infrahub
Returns:
VtepIntent if a VTEP exists for this device, None otherwise
Raises:
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"device_vtep:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
try:
vtep_nodes = await self._client.filters(kind="InfraVTEP", include=["vlan_vni_mappings"])
except Exception as e:
raise InfrahubQueryError(f"Failed to query VTEPs: {e}") from e
for vtep in vtep_nodes:
try:
if vtep.device.peer is None:
await vtep.device.fetch()
if vtep.device.peer is None or vtep.device.peer.name.value != device_name:
continue
mappings: list[tuple[int, int]] = []
for mapping_rel in vtep.vlan_vni_mappings.peers:
try:
mapping = mapping_rel.peer
await mapping.vlan.fetch()
await mapping.vni.fetch()
vlan_node = mapping.vlan.peer
vni_node = mapping.vni.peer
if vlan_node is not None and vni_node is not None:
mappings.append((vlan_node.vlan_id.value, vni_node.vni.value))
except Exception:
continue
result = VtepIntent(
source_address=vtep.source_address.value,
udp_port=vtep.udp_port.value,
learn_restrict=vtep.learn_restrict.value,
vlan_vni_mappings=mappings,
)
self._cache_set(cache_key, result)
return result
except Exception:
continue
self._cache_set(cache_key, None)
return None
# =========================================================================
# EVPN Instances
# =========================================================================
async def get_device_evpn_instances(self, device_name: str) -> list[EvpnInstanceIntent]:
"""
Fetch all EVPN instances for a device.
Args:
device_name: Device name as stored in Infrahub
Returns:
List of EvpnInstanceIntent objects
Raises:
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"evpn_instances:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
try:
nodes = await self._client.filters(kind="InfraEVPNInstance")
except Exception as e:
raise InfrahubQueryError(
f"Failed to query EVPN instances for device '{device_name}': {e}"
) from e
instances: list[EvpnInstanceIntent] = []
for node in nodes:
try:
if node.device.peer is None:
await node.device.fetch()
if node.device.peer is None or node.device.peer.name.value != device_name:
continue
if node.vlan.peer is None:
await node.vlan.fetch()
vlan_node = node.vlan.peer
vlan_id: int = vlan_node.vlan_id.value if vlan_node is not None else 0
instances.append(
EvpnInstanceIntent(
route_distinguisher=node.route_distinguisher.value,
route_target_import=node.route_target_import.value,
route_target_export=node.route_target_export.value,
redistribute_learned=node.redistribute_learned.value,
vlan_id=vlan_id,
)
)
except Exception:
continue
self._cache_set(cache_key, instances)
return instances
# =========================================================================
# MLAG
# =========================================================================
async def get_mlag_domain(self, device_name: str) -> MlagDomainIntent | None:
"""
Fetch the MLAG domain that contains the given device.
Args:
device_name: Device name as stored in Infrahub
Returns:
MlagDomainIntent if an MLAG domain exists for this device, None otherwise
Raises:
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"mlag_domain:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
try:
nodes = await self._client.filters(kind="InfraMlagDomain", include=["devices"])
except Exception as e:
raise InfrahubQueryError(
f"Failed to query MLAG domains for device '{device_name}': {e}"
) from e
for node in nodes:
try:
peer_device_names: list[str] = []
found = False
for dev_rel in node.devices.peers:
try:
dev_name = dev_rel.peer.name.value
peer_device_names.append(dev_name)
if dev_name == device_name:
found = True
except Exception:
pass
if not found:
continue
result = MlagDomainIntent(
domain_id=node.domain_id.value,
virtual_mac=node.virtual_mac.value,
heartbeat_vrf=node.heartbeat_vrf.value,
dual_primary_detection=node.dual_primary_detection.value,
dual_primary_delay=node.dual_primary_delay.value,
dual_primary_action=node.dual_primary_action.value,
peer_devices=peer_device_names,
)
self._cache_set(cache_key, result)
return result
except Exception:
continue
self._cache_set(cache_key, None)
return None
async def get_mlag_peer_config(self, device_name: str) -> MlagPeerConfigIntent | None:
"""
Fetch the MLAG peer configuration for a device.
Args:
device_name: Device name as stored in Infrahub
Returns:
MlagPeerConfigIntent if MLAG peer config exists for this device, None otherwise
Raises:
InfrahubQueryError: When the query fails unexpectedly
"""
cache_key = f"mlag_peer_config:{device_name}"
cached = self._cache_get(cache_key)
if cached is not None:
return cached
try:
nodes = await self._client.filters(kind="InfraMlagPeerConfig")
except Exception as e:
raise InfrahubQueryError(
f"Failed to query MLAG peer configs for device '{device_name}': {e}"
) from e
for node in nodes:
try:
if node.device.peer is None:
await node.device.fetch()
if node.device.peer is None or node.device.peer.name.value != device_name:
continue
peer_link: str = ""
try:
if node.peer_link.peer is None:
await node.peer_link.fetch()
if node.peer_link.peer is not None:
peer_link = node.peer_link.peer.name.value
except Exception:
pass
result = MlagPeerConfigIntent(
local_interface_ip=node.local_interface_ip.value,
peer_address=node.peer_address.value,
heartbeat_peer_ip=node.heartbeat_peer_ip.value,
peer_link=peer_link,
)
self._cache_set(cache_key, result)
return result
except Exception:
continue
self._cache_set(cache_key, None)
return None

View File

@@ -0,0 +1,31 @@
"""
Infrahub Client Exceptions.
Custom exception hierarchy for the FabricInfrahubClient.
"""
from __future__ import annotations
class InfrahubClientError(Exception):
"""Base exception for all Infrahub client errors."""
pass
class InfrahubConnectionError(InfrahubClientError):
"""Raised when connection to the Infrahub instance fails."""
pass
class InfrahubQueryError(InfrahubClientError):
"""Raised when a query to Infrahub fails or returns unexpected data."""
pass
class InfrahubNotFoundError(InfrahubClientError):
"""Raised when a requested node is not found in Infrahub."""
pass

144
src/infrahub/models.py Normal file
View File

@@ -0,0 +1,144 @@
"""
Pydantic models for Infrahub fabric intent data.
These immutable models represent the structured intent retrieved from Infrahub
and are used throughout the orchestrator for typed access to fabric state.
"""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict
class DeviceIntent(BaseModel):
"""Represents a network device from Infrahub."""
model_config = ConfigDict(frozen=True)
name: str
role: str
status: str
platform: str | None
site: str | None
asn: int | None
class VlanIntent(BaseModel):
"""Represents a VLAN from Infrahub."""
model_config = ConfigDict(frozen=True)
vlan_id: int
name: str
status: str
vlan_type: str
vni: int | None
stp_enabled: bool
class VniIntent(BaseModel):
"""Represents a VNI (VXLAN Network Identifier) from Infrahub."""
model_config = ConfigDict(frozen=True)
vni: int
vni_type: str
description: str | None
class BgpRouterConfigIntent(BaseModel):
"""Represents a BGP router configuration from Infrahub."""
model_config = ConfigDict(frozen=True)
router_id: str
local_asn: int
default_ipv4_unicast: bool
ecmp_max_paths: int
class BgpPeerGroupIntent(BaseModel):
"""Represents a BGP peer group from Infrahub."""
model_config = ConfigDict(frozen=True)
name: str
peer_group_type: str
remote_asn: int | None
update_source: str | None
send_community: str
ebgp_multihop: int | None
next_hop_unchanged: bool
class BgpSessionIntent(BaseModel):
"""Represents a BGP session from Infrahub."""
model_config = ConfigDict(frozen=True)
peer_address: str
description: str | None
enabled: bool
peer_group: str | None
remote_asn: int | None
class VrfIntent(BaseModel):
"""Represents a VRF from Infrahub."""
model_config = ConfigDict(frozen=True)
name: str
route_distinguisher: str | None
vrf_id: int | None
l3vni: int | None
import_targets: list[str]
export_targets: list[str]
class VtepIntent(BaseModel):
"""Represents a VTEP (VXLAN Tunnel Endpoint) from Infrahub."""
model_config = ConfigDict(frozen=True)
source_address: str
udp_port: int
learn_restrict: bool
vlan_vni_mappings: list[tuple[int, int]] # (vlan_id, vni)
class MlagDomainIntent(BaseModel):
"""Represents an MLAG domain from Infrahub."""
model_config = ConfigDict(frozen=True)
domain_id: str
virtual_mac: str
heartbeat_vrf: str
dual_primary_detection: bool
dual_primary_delay: int
dual_primary_action: str
peer_devices: list[str]
class MlagPeerConfigIntent(BaseModel):
"""Represents an MLAG peer configuration from Infrahub."""
model_config = ConfigDict(frozen=True)
local_interface_ip: str
peer_address: str
heartbeat_peer_ip: str
peer_link: str
class EvpnInstanceIntent(BaseModel):
"""Represents an EVPN instance from Infrahub."""
model_config = ConfigDict(frozen=True)
route_distinguisher: str
route_target_import: str
route_target_export: str
redistribute_learned: bool
vlan_id: int

View File

@@ -0,0 +1,617 @@
"""
Unit tests for FabricInfrahubClient.
All SDK interactions are mocked with AsyncMock — no real Infrahub instance required.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest # noqa: I001
from src.infrahub.client import FabricInfrahubClient
from src.infrahub.exceptions import InfrahubNotFoundError, InfrahubQueryError
from src.infrahub.models import (
BgpPeerGroupIntent,
BgpRouterConfigIntent,
BgpSessionIntent,
DeviceIntent,
EvpnInstanceIntent,
MlagDomainIntent,
MlagPeerConfigIntent,
VlanIntent,
VrfIntent,
VtepIntent,
)
# =============================================================================
# Helpers — mock node factories
# =============================================================================
def _attr(value):
"""Return a simple mock whose .value is set."""
m = MagicMock()
m.value = value
return m
def _rel_one(peer_mock):
"""Return a mock relationship (cardinality one) with an already-loaded peer."""
m = MagicMock()
m.peer = peer_mock
m.fetch = AsyncMock()
return m
def _rel_many(peer_mocks):
"""Return a mock relationship (cardinality many) whose .peers yields (peer_rel, peer)."""
peers = []
for p in peer_mocks:
rel = MagicMock()
rel.peer = p
peers.append(rel)
m = MagicMock()
m.peers = peers
return m
def make_device_node(
name="leaf1",
role="leaf",
status="active",
platform_name="EOS",
site_name="dc1",
asn_val=65001,
):
node = MagicMock()
node.name = _attr(name)
node.role = _attr(role)
node.status = _attr(status)
platform = MagicMock()
platform.name = _attr(platform_name)
node.platform = _rel_one(platform)
site = MagicMock()
site.name = _attr(site_name)
node.site = _rel_one(site)
asn_node = MagicMock()
asn_node.asn = _attr(asn_val)
node.asn = _rel_one(asn_node)
return node
def make_vlan_node(vlan_id=10, name="VLAN10", status="active", vlan_type="standard", vni=10010):
node = MagicMock()
node.vlan_id = _attr(vlan_id)
node.name = _attr(name)
node.status = _attr(status)
node.vlan_type = _attr(vlan_type)
node.stp_enabled = _attr(True)
vni_node = MagicMock()
vni_node.vni = _attr(vni)
node.vni = _rel_one(vni_node)
return node
def make_bgp_config_node(router_id="10.0.0.1", asn=65001, device_name="leaf1"):
node = MagicMock()
node.router_id = _attr(router_id)
node.default_ipv4_unicast = _attr(True)
node.ecmp_max_paths = _attr(4)
asn_node = MagicMock()
asn_node.asn = _attr(asn)
node.local_asn = _rel_one(asn_node)
device = MagicMock()
device.name = _attr(device_name)
node.device = _rel_one(device)
node.peer_groups = _rel_many([])
node.sessions = _rel_many([])
return node
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_sdk_client():
"""Patch InfrahubClient and Config so no real connection is made."""
with (
patch("src.infrahub.client.InfrahubClient") as mock_cls,
patch("src.infrahub.client.Config"),
):
sdk = AsyncMock()
mock_cls.return_value = sdk
yield sdk
@pytest.fixture
def client(mock_sdk_client):
"""Return a FabricInfrahubClient backed by a mocked SDK."""
return FabricInfrahubClient(url="http://infrahub:8080", api_token="test-token", branch="main")
# =============================================================================
# Context manager
# =============================================================================
@pytest.mark.asyncio
async def test_context_manager(mock_sdk_client):
async with FabricInfrahubClient(url="http://infrahub:8080", api_token="test-token") as c:
assert isinstance(c, FabricInfrahubClient)
# =============================================================================
# get_device
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_returns_device_intent(client, mock_sdk_client):
mock_sdk_client.get = AsyncMock(return_value=make_device_node())
result = await client.get_device("leaf1")
assert isinstance(result, DeviceIntent)
assert result.name == "leaf1"
assert result.role == "leaf"
assert result.status == "active"
assert result.platform == "EOS"
assert result.site == "dc1"
assert result.asn == 65001
@pytest.mark.asyncio
async def test_get_device_not_found_raises(client, mock_sdk_client):
mock_sdk_client.get = AsyncMock(side_effect=Exception("not found"))
with pytest.raises(InfrahubNotFoundError):
await client.get_device("ghost-device")
@pytest.mark.asyncio
async def test_get_device_none_raises_not_found(client, mock_sdk_client):
mock_sdk_client.get = AsyncMock(return_value=None)
with pytest.raises(InfrahubNotFoundError):
await client.get_device("missing")
@pytest.mark.asyncio
async def test_get_device_query_error(client, mock_sdk_client):
mock_sdk_client.get = AsyncMock(side_effect=Exception("connection refused"))
with pytest.raises(InfrahubQueryError):
await client.get_device("leaf1")
# =============================================================================
# Caching
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_caches_result(client, mock_sdk_client):
mock_sdk_client.get = AsyncMock(return_value=make_device_node())
result1 = await client.get_device("leaf1")
result2 = await client.get_device("leaf1")
assert result1 == result2
# SDK should only be called once despite two client calls
mock_sdk_client.get.assert_called_once()
@pytest.mark.asyncio
async def test_cache_separate_keys_per_device(client, mock_sdk_client):
mock_sdk_client.get = AsyncMock(
side_effect=[
make_device_node(name="leaf1"),
make_device_node(name="leaf2", asn=65002),
]
)
r1 = await client.get_device("leaf1")
r2 = await client.get_device("leaf2")
assert r1.name == "leaf1"
assert r2.name == "leaf2"
assert mock_sdk_client.get.call_count == 2
# =============================================================================
# Branch selection
# =============================================================================
def test_branch_passed_to_config():
with (
patch("src.infrahub.client.InfrahubClient"),
patch("src.infrahub.client.Config") as mock_cfg,
):
FabricInfrahubClient(url="http://infrahub:8080", api_token="tok", branch="proposed-change")
mock_cfg.assert_called_once_with(
address="http://infrahub:8080",
api_token="tok",
default_branch="proposed-change",
)
# =============================================================================
# get_device_bgp_config
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_bgp_config(client, mock_sdk_client):
bgp_node = make_bgp_config_node(router_id="10.0.0.1", asn=65001, device_name="leaf1")
mock_sdk_client.filters = AsyncMock(return_value=[bgp_node])
result = await client.get_device_bgp_config("leaf1")
assert isinstance(result, BgpRouterConfigIntent)
assert result.router_id == "10.0.0.1"
assert result.local_asn == 65001
assert result.default_ipv4_unicast is True
assert result.ecmp_max_paths == 4
@pytest.mark.asyncio
async def test_get_device_bgp_config_not_found(client, mock_sdk_client):
mock_sdk_client.filters = AsyncMock(return_value=[])
with pytest.raises(InfrahubNotFoundError):
await client.get_device_bgp_config("leaf1")
# =============================================================================
# get_device_bgp_peer_groups
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_bgp_peer_groups(client, mock_sdk_client):
pg_node = MagicMock()
pg_node.name = _attr("EVPN-PEERS")
pg_node.peer_group_type = _attr("evpn")
pg_node.update_source = _attr("Loopback0")
pg_node.send_community = _attr("extended")
pg_node.ebgp_multihop = _attr(3)
pg_node.next_hop_unchanged = _attr(True)
remote_asn_node = MagicMock()
remote_asn_node.asn = _attr(65000)
pg_node.remote_asn = _rel_one(remote_asn_node)
bgp_node = make_bgp_config_node(device_name="leaf1")
bgp_node.peer_groups = _rel_many([pg_node])
mock_sdk_client.filters = AsyncMock(return_value=[bgp_node])
result = await client.get_device_bgp_peer_groups("leaf1")
assert len(result) == 1
pg = result[0]
assert isinstance(pg, BgpPeerGroupIntent)
assert pg.name == "EVPN-PEERS"
assert pg.peer_group_type == "evpn"
assert pg.remote_asn == 65000
assert pg.send_community == "extended"
assert pg.next_hop_unchanged is True
# =============================================================================
# get_device_bgp_sessions
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_bgp_sessions(client, mock_sdk_client):
sess_node = MagicMock()
sess_node.peer_address = _attr("10.0.0.2")
sess_node.description = _attr("to-spine1")
sess_node.enabled = _attr(True)
pg = MagicMock()
pg.name = _attr("UNDERLAY")
sess_node.peer_group = _rel_one(pg)
remote_asn = MagicMock()
remote_asn.asn = _attr(65000)
sess_node.remote_asn = _rel_one(remote_asn)
bgp_node = make_bgp_config_node(device_name="leaf1")
bgp_node.sessions = _rel_many([sess_node])
mock_sdk_client.filters = AsyncMock(return_value=[bgp_node])
result = await client.get_device_bgp_sessions("leaf1")
assert len(result) == 1
sess = result[0]
assert isinstance(sess, BgpSessionIntent)
assert sess.peer_address == "10.0.0.2"
assert sess.description == "to-spine1"
assert sess.enabled is True
assert sess.peer_group == "UNDERLAY"
assert sess.remote_asn == 65000
# =============================================================================
# get_device_vrfs
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_vrfs(client, mock_sdk_client):
rt_import = MagicMock()
rt_import.target = _attr("65000:100")
rt_export = MagicMock()
rt_export.target = _attr("65000:100")
vni_node = MagicMock()
vni_node.vni = _attr(10000)
vrf_node = MagicMock()
vrf_node.name = _attr("PROD")
vrf_node.vrf_id = _attr(100)
vrf_node.l3vni = _rel_one(vni_node)
asgn = MagicMock()
asgn.route_distinguisher = _attr("10.0.0.1:100")
asgn.vrf = _rel_one(vrf_node)
device = MagicMock()
device.name = _attr("leaf1")
asgn.device = _rel_one(device)
asgn.import_targets = _rel_many([rt_import])
asgn.export_targets = _rel_many([rt_export])
mock_sdk_client.filters = AsyncMock(return_value=[asgn])
result = await client.get_device_vrfs("leaf1")
assert len(result) == 1
vrf = result[0]
assert isinstance(vrf, VrfIntent)
assert vrf.name == "PROD"
assert vrf.route_distinguisher == "10.0.0.1:100"
assert vrf.vrf_id == 100
assert vrf.l3vni == 10000
assert vrf.import_targets == ["65000:100"]
assert vrf.export_targets == ["65000:100"]
# =============================================================================
# get_device_vtep
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_vtep(client, mock_sdk_client):
vlan_node = MagicMock()
vlan_node.vlan_id = _attr(10)
vni_node = MagicMock()
vni_node.vni = _attr(10010)
mapping = MagicMock()
mapping.vlan = _rel_one(vlan_node)
mapping.vni = _rel_one(vni_node)
mapping.vlan.fetch = AsyncMock()
mapping.vni.fetch = AsyncMock()
device = MagicMock()
device.name = _attr("leaf1")
vtep = MagicMock()
vtep.source_address = _attr("10.0.0.1")
vtep.udp_port = _attr(4789)
vtep.learn_restrict = _attr(False)
vtep.device = _rel_one(device)
vtep.vlan_vni_mappings = _rel_many([mapping])
mock_sdk_client.filters = AsyncMock(return_value=[vtep])
result = await client.get_device_vtep("leaf1")
assert isinstance(result, VtepIntent)
assert result.source_address == "10.0.0.1"
assert result.udp_port == 4789
assert result.learn_restrict is False
assert (10, 10010) in result.vlan_vni_mappings
@pytest.mark.asyncio
async def test_get_device_vtep_none_when_missing(client, mock_sdk_client):
mock_sdk_client.filters = AsyncMock(return_value=[])
result = await client.get_device_vtep("leaf1")
assert result is None
# =============================================================================
# get_device_evpn_instances
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_evpn_instances(client, mock_sdk_client):
vlan_node = MagicMock()
vlan_node.vlan_id = _attr(10)
device = MagicMock()
device.name = _attr("leaf1")
node = MagicMock()
node.route_distinguisher = _attr("10.0.0.1:10")
node.route_target_import = _attr("65000:10")
node.route_target_export = _attr("65000:10")
node.redistribute_learned = _attr(True)
node.device = _rel_one(device)
node.vlan = _rel_one(vlan_node)
mock_sdk_client.filters = AsyncMock(return_value=[node])
result = await client.get_device_evpn_instances("leaf1")
assert len(result) == 1
ev = result[0]
assert isinstance(ev, EvpnInstanceIntent)
assert ev.route_distinguisher == "10.0.0.1:10"
assert ev.route_target_import == "65000:10"
assert ev.route_target_export == "65000:10"
assert ev.redistribute_learned is True
assert ev.vlan_id == 10
# =============================================================================
# get_mlag_domain
# =============================================================================
@pytest.mark.asyncio
async def test_get_mlag_domain(client, mock_sdk_client):
dev1 = MagicMock()
dev1.name = _attr("leaf1")
dev2 = MagicMock()
dev2.name = _attr("leaf2")
domain = MagicMock()
domain.domain_id = _attr("1")
domain.virtual_mac = _attr("00:1c:73:00:00:01")
domain.heartbeat_vrf = _attr("MGMT")
domain.dual_primary_detection = _attr(True)
domain.dual_primary_delay = _attr(10)
domain.dual_primary_action = _attr("errdisable")
domain.devices = _rel_many([dev1, dev2])
mock_sdk_client.filters = AsyncMock(return_value=[domain])
result = await client.get_mlag_domain("leaf1")
assert isinstance(result, MlagDomainIntent)
assert result.domain_id == "1"
assert result.virtual_mac == "00:1c:73:00:00:01"
assert "leaf1" in result.peer_devices
assert "leaf2" in result.peer_devices
@pytest.mark.asyncio
async def test_get_mlag_domain_returns_none_when_not_member(client, mock_sdk_client):
dev_other = MagicMock()
dev_other.name = _attr("spine1")
domain = MagicMock()
domain.domain_id = _attr("99")
domain.devices = _rel_many([dev_other])
mock_sdk_client.filters = AsyncMock(return_value=[domain])
result = await client.get_mlag_domain("leaf1")
assert result is None
# =============================================================================
# get_mlag_peer_config
# =============================================================================
@pytest.mark.asyncio
async def test_get_mlag_peer_config(client, mock_sdk_client):
lag = MagicMock()
lag.name = _attr("Port-Channel1")
device = MagicMock()
device.name = _attr("leaf1")
node = MagicMock()
node.local_interface_ip = _attr("10.255.255.0/31")
node.peer_address = _attr("10.255.255.1")
node.heartbeat_peer_ip = _attr("192.168.0.2")
node.device = _rel_one(device)
node.peer_link = _rel_one(lag)
mock_sdk_client.filters = AsyncMock(return_value=[node])
result = await client.get_mlag_peer_config("leaf1")
assert isinstance(result, MlagPeerConfigIntent)
assert result.local_interface_ip == "10.255.255.0/31"
assert result.peer_address == "10.255.255.1"
assert result.heartbeat_peer_ip == "192.168.0.2"
assert result.peer_link == "Port-Channel1"
@pytest.mark.asyncio
async def test_get_mlag_peer_config_returns_none_when_missing(client, mock_sdk_client):
mock_sdk_client.filters = AsyncMock(return_value=[])
result = await client.get_mlag_peer_config("spine1")
assert result is None
# =============================================================================
# get_device_vlans
# =============================================================================
@pytest.mark.asyncio
async def test_get_device_vlans_via_vtep(client, mock_sdk_client):
vlan_node = make_vlan_node(vlan_id=10, name="VLAN10", vni=10010)
mapping = MagicMock()
mapping.vlan = _rel_one(vlan_node)
mapping.vlan.fetch = AsyncMock()
device = MagicMock()
device.name = _attr("leaf1")
vtep = MagicMock()
vtep.device = _rel_one(device)
vtep.vlan_vni_mappings = _rel_many([mapping])
mock_sdk_client.filters = AsyncMock(return_value=[vtep])
result = await client.get_device_vlans("leaf1")
assert len(result) == 1
assert isinstance(result[0], VlanIntent)
assert result[0].vlan_id == 10
assert result[0].vni == 10010
@pytest.mark.asyncio
async def test_get_device_vlans_caches_result(client, mock_sdk_client):
device = MagicMock()
device.name = _attr("leaf1")
vtep = MagicMock()
vtep.device = _rel_one(device)
vtep.vlan_vni_mappings = _rel_many([])
mock_sdk_client.filters = AsyncMock(return_value=[vtep])
r1 = await client.get_device_vlans("leaf1")
r2 = await client.get_device_vlans("leaf1")
assert r1 == r2
mock_sdk_client.filters.assert_called_once()