feat(infrahub): add Infrahub client for fabric intent (#42) #19
35
pyproject.toml
Normal file
35
pyproject.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "arista-evpn-vxlan-infrahub-client"
|
||||
version = "0.1.0"
|
||||
description = "Infrahub client for Arista EVPN-VXLAN fabric intent"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"infrahub-sdk>=0.16.0",
|
||||
"pydantic>=2.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
"pytest-asyncio",
|
||||
"ruff",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/infrahub_client"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py312"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "W"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
28
src/infrahub_client/__init__.py
Normal file
28
src/infrahub_client/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Infrahub client for Arista EVPN-VXLAN fabric intent.
|
||||
|
||||
Public API::
|
||||
|
||||
from infrahub_client import FabricInfrahubClient
|
||||
from infrahub_client import (
|
||||
InfrahubClientError,
|
||||
InfrahubConnectionError,
|
||||
InfrahubQueryError,
|
||||
InfrahubNotFoundError,
|
||||
)
|
||||
"""
|
||||
|
||||
from .client import FabricInfrahubClient
|
||||
from .exceptions import (
|
||||
InfrahubClientError,
|
||||
InfrahubConnectionError,
|
||||
InfrahubNotFoundError,
|
||||
InfrahubQueryError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FabricInfrahubClient",
|
||||
"InfrahubClientError",
|
||||
"InfrahubConnectionError",
|
||||
"InfrahubNotFoundError",
|
||||
"InfrahubQueryError",
|
||||
]
|
||||
680
src/infrahub_client/client.py
Normal file
680
src/infrahub_client/client.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""Async Infrahub client for Arista EVPN-VXLAN fabric intent retrieval."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from infrahub_sdk import Config, InfrahubClient
|
||||
|
||||
from .exceptions import InfrahubNotFoundError, InfrahubQueryError
|
||||
from .models import (
|
||||
BgpPeerGroupIntent,
|
||||
BgpRouterConfigIntent,
|
||||
BgpSessionIntent,
|
||||
DeviceIntent,
|
||||
EvpnInstanceIntent,
|
||||
MlagDomainIntent,
|
||||
MlagPeerConfigIntent,
|
||||
VlanIntent,
|
||||
VniIntent, # noqa: F401 — re-exported for convenience
|
||||
VrfIntent,
|
||||
VtepIntent,
|
||||
)
|
||||
|
||||
_CACHE_TTL_SECONDS = 60
|
||||
|
||||
|
||||
class FabricInfrahubClient:
|
||||
"""Async client that wraps the Infrahub SDK to fetch fabric intent data.
|
||||
|
||||
All public methods return immutable Pydantic models. Results are cached
|
||||
in-memory with a TTL of ``_CACHE_TTL_SECONDS`` seconds to avoid redundant
|
||||
queries during a single reconciliation pass.
|
||||
|
||||
Example::
|
||||
|
||||
async with FabricInfrahubClient(url="http://infrahub:8080", api_token="xxx") as client:
|
||||
device = await client.get_device("leaf1")
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, api_token: str, branch: str = "main") -> None:
|
||||
"""Initialise the client.
|
||||
|
||||
Args:
|
||||
url: Base URL of the Infrahub server (e.g. ``http://infrahub:8080``).
|
||||
api_token: API token used for authentication.
|
||||
branch: Infrahub branch to query (default: ``"main"``).
|
||||
"""
|
||||
config = Config(address=url, api_token=api_token, default_branch=branch)
|
||||
self._sdk = InfrahubClient(config=config)
|
||||
self._branch = branch
|
||||
self._cache: dict[str, tuple[Any, float]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Async context manager support
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def __aenter__(self) -> "FabricInfrahubClient":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _cache_get(self, key: str) -> Any | None:
|
||||
"""Return cached value for *key* if still within TTL, else ``None``."""
|
||||
entry = self._cache.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
value, ts = entry
|
||||
if time.monotonic() - ts > _CACHE_TTL_SECONDS:
|
||||
del self._cache[key]
|
||||
return None
|
||||
return value
|
||||
|
||||
def _cache_set(self, key: str, value: Any) -> None:
|
||||
"""Store *value* in the cache under *key* with current timestamp."""
|
||||
self._cache[key] = (value, time.monotonic())
|
||||
|
||||
def _opt_value(self, node: Any, attr: str) -> Any | None:
|
||||
"""Safely read ``node.<attr>.value``, returning ``None`` if missing."""
|
||||
try:
|
||||
attr_obj = getattr(node, attr, None)
|
||||
if attr_obj is None:
|
||||
return None
|
||||
return attr_obj.value
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
async def _fetch_node(self, kind: str, **filters: Any) -> Any:
|
||||
"""Fetch a single node; raise :exc:`InfrahubNotFoundError` if absent.
|
||||
|
||||
Args:
|
||||
kind: Infrahub node kind (e.g. ``"InfraDevice"``).
|
||||
**filters: Keyword arguments forwarded to :meth:`InfrahubClient.get`.
|
||||
|
||||
Returns:
|
||||
The SDK node object.
|
||||
|
||||
Raises:
|
||||
InfrahubNotFoundError: When no node matches the filters.
|
||||
InfrahubQueryError: On SDK/network error.
|
||||
"""
|
||||
try:
|
||||
node = await self._sdk.get(kind=kind, branch=self._branch, **filters)
|
||||
except Exception as exc:
|
||||
msg = str(exc).lower()
|
||||
if "not found" in msg or "does not exist" in msg or "nodnotfound" in msg.replace(" ", ""):
|
||||
raise InfrahubNotFoundError(
|
||||
f"{kind} with filters {filters} not found"
|
||||
) from exc
|
||||
raise InfrahubQueryError(f"Query failed for {kind} with {filters}: {exc}") from exc
|
||||
|
||||
if node is None:
|
||||
raise InfrahubNotFoundError(f"{kind} with filters {filters} not found")
|
||||
return node
|
||||
|
||||
async def _fetch_all(self, kind: str, **filters: Any) -> list[Any]:
|
||||
"""Fetch all nodes of *kind* matching optional *filters*.
|
||||
|
||||
Args:
|
||||
kind: Infrahub node kind.
|
||||
**filters: Optional filter keyword arguments.
|
||||
|
||||
Returns:
|
||||
List of SDK node objects (may be empty).
|
||||
|
||||
Raises:
|
||||
InfrahubQueryError: On SDK/network error.
|
||||
"""
|
||||
try:
|
||||
if filters:
|
||||
return await self._sdk.filters(kind=kind, branch=self._branch, **filters)
|
||||
return await self._sdk.all(kind=kind, branch=self._branch)
|
||||
except Exception as exc:
|
||||
raise InfrahubQueryError(f"Query failed for {kind}: {exc}") from exc
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_device(self, name: str) -> DeviceIntent:
|
||||
"""Fetch a device and its core relationships by name.
|
||||
|
||||
Args:
|
||||
name: Device name as stored in Infrahub (e.g. ``"leaf1"``).
|
||||
|
||||
Returns:
|
||||
:class:`DeviceIntent` populated from Infrahub.
|
||||
|
||||
Raises:
|
||||
InfrahubNotFoundError: When no device with *name* exists.
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"device:{name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
node = await self._fetch_node("InfraDevice", name__value=name)
|
||||
|
||||
# Resolve optional one-cardinality relationships
|
||||
platform_name: str | None = None
|
||||
if node.platform and node.platform.initialized:
|
||||
await node.platform.fetch()
|
||||
platform_name = self._opt_value(node.platform.peer, "name")
|
||||
|
||||
site_name: str | None = None
|
||||
if node.site and node.site.initialized:
|
||||
await node.site.fetch()
|
||||
site_name = self._opt_value(node.site.peer, "name")
|
||||
|
||||
asn_value: int | None = None
|
||||
if node.asn and node.asn.initialized:
|
||||
await node.asn.fetch()
|
||||
asn_value = self._opt_value(node.asn.peer, "asn")
|
||||
|
||||
result = DeviceIntent(
|
||||
name=node.name.value,
|
||||
role=node.role.value,
|
||||
status=node.status.value,
|
||||
platform=platform_name,
|
||||
site=site_name,
|
||||
asn=asn_value,
|
||||
)
|
||||
self._cache_set(cache_key, result)
|
||||
return result
|
||||
|
||||
async def get_device_vlans(self, device_name: str) -> list[VlanIntent]:
|
||||
"""Fetch VLANs associated with a device via its VTEP vlan-vni mappings.
|
||||
|
||||
Falls back to VLANs linked through SVI interfaces when no VTEP exists.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
List of :class:`VlanIntent` objects (may be empty).
|
||||
|
||||
Raises:
|
||||
InfrahubNotFoundError: When the device does not exist.
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"device_vlans:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
# Ensure device exists
|
||||
device_node = await self._fetch_node("InfraDevice", name__value=device_name)
|
||||
|
||||
vlan_intents: list[VlanIntent] = []
|
||||
seen_vlan_ids: set[int] = set()
|
||||
|
||||
# Try via VTEP → VlanVniMapping → VLAN
|
||||
try:
|
||||
vtep_nodes = await self._fetch_all(
|
||||
"InfraVTEP", device__name__value=device_name
|
||||
)
|
||||
except InfrahubQueryError:
|
||||
vtep_nodes = []
|
||||
|
||||
for vtep in vtep_nodes:
|
||||
if not (vtep.vlan_vni_mappings and vtep.vlan_vni_mappings.initialized):
|
||||
continue
|
||||
for mapping_rel in vtep.vlan_vni_mappings.peers:
|
||||
mapping = mapping_rel.peer
|
||||
if not (mapping.vlan and mapping.vlan.initialized):
|
||||
continue
|
||||
await mapping.vlan.fetch()
|
||||
vlan_node = mapping.vlan.peer
|
||||
if vlan_node is None:
|
||||
continue
|
||||
vlan_id: int = vlan_node.vlan_id.value
|
||||
if vlan_id in seen_vlan_ids:
|
||||
continue
|
||||
seen_vlan_ids.add(vlan_id)
|
||||
|
||||
vni_val: int | None = None
|
||||
if vlan_node.vni and vlan_node.vni.initialized:
|
||||
await vlan_node.vni.fetch()
|
||||
if vlan_node.vni.peer:
|
||||
vni_val = vlan_node.vni.peer.vni.value
|
||||
|
||||
vlan_intents.append(
|
||||
VlanIntent(
|
||||
vlan_id=vlan_id,
|
||||
name=vlan_node.name.value,
|
||||
status=vlan_node.status.value,
|
||||
vlan_type=vlan_node.vlan_type.value,
|
||||
vni=vni_val,
|
||||
stp_enabled=bool(self._opt_value(vlan_node, "stp_enabled")),
|
||||
)
|
||||
)
|
||||
|
||||
# Fallback: VLANs via SVI interfaces on the device
|
||||
if not vlan_intents:
|
||||
interfaces = await self._fetch_all(
|
||||
"InfraInterfaceVlan", device__name__value=device_name
|
||||
)
|
||||
for intf in interfaces:
|
||||
if not (intf.vlan and intf.vlan.initialized):
|
||||
continue
|
||||
await intf.vlan.fetch()
|
||||
vlan_node = intf.vlan.peer
|
||||
if vlan_node is None:
|
||||
continue
|
||||
vlan_id = vlan_node.vlan_id.value
|
||||
if vlan_id in seen_vlan_ids:
|
||||
continue
|
||||
seen_vlan_ids.add(vlan_id)
|
||||
|
||||
vni_val = None
|
||||
if vlan_node.vni and vlan_node.vni.initialized:
|
||||
await vlan_node.vni.fetch()
|
||||
if vlan_node.vni.peer:
|
||||
vni_val = vlan_node.vni.peer.vni.value
|
||||
|
||||
vlan_intents.append(
|
||||
VlanIntent(
|
||||
vlan_id=vlan_id,
|
||||
name=vlan_node.name.value,
|
||||
status=vlan_node.status.value,
|
||||
vlan_type=vlan_node.vlan_type.value,
|
||||
vni=vni_val,
|
||||
stp_enabled=bool(self._opt_value(vlan_node, "stp_enabled")),
|
||||
)
|
||||
)
|
||||
|
||||
# Suppress unused variable warning — device_node used for existence check
|
||||
_ = device_node
|
||||
|
||||
self._cache_set(cache_key, vlan_intents)
|
||||
return vlan_intents
|
||||
|
||||
async def get_device_bgp_config(self, device_name: str) -> BgpRouterConfigIntent:
|
||||
"""Fetch the BGP router config for a device.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
:class:`BgpRouterConfigIntent` for the device.
|
||||
|
||||
Raises:
|
||||
InfrahubNotFoundError: When no BGP config exists for the device.
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"bgp_config:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
node = await self._fetch_node(
|
||||
"InfraBGPRouterConfig", device__name__value=device_name
|
||||
)
|
||||
|
||||
local_asn: int = 0
|
||||
if node.local_asn and node.local_asn.initialized:
|
||||
await node.local_asn.fetch()
|
||||
if node.local_asn.peer:
|
||||
local_asn = node.local_asn.peer.asn.value
|
||||
|
||||
result = BgpRouterConfigIntent(
|
||||
router_id=node.router_id.value,
|
||||
local_asn=local_asn,
|
||||
default_ipv4_unicast=bool(self._opt_value(node, "default_ipv4_unicast")),
|
||||
ecmp_max_paths=int(self._opt_value(node, "ecmp_max_paths") or 1),
|
||||
)
|
||||
self._cache_set(cache_key, result)
|
||||
return result
|
||||
|
||||
async def get_device_bgp_peer_groups(self, device_name: str) -> list[BgpPeerGroupIntent]:
|
||||
"""Fetch all BGP peer groups configured on a device.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
List of :class:`BgpPeerGroupIntent` objects.
|
||||
|
||||
Raises:
|
||||
InfrahubNotFoundError: When no BGP config exists for the device.
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"bgp_peer_groups:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
bgp_node = await self._fetch_node(
|
||||
"InfraBGPRouterConfig", device__name__value=device_name
|
||||
)
|
||||
|
||||
results: list[BgpPeerGroupIntent] = []
|
||||
if not (bgp_node.peer_groups and bgp_node.peer_groups.initialized):
|
||||
self._cache_set(cache_key, results)
|
||||
return results
|
||||
|
||||
for pg_rel in bgp_node.peer_groups.peers:
|
||||
pg = pg_rel.peer
|
||||
|
||||
remote_asn: int | None = None
|
||||
if pg.remote_asn and pg.remote_asn.initialized:
|
||||
await pg.remote_asn.fetch()
|
||||
if pg.remote_asn.peer:
|
||||
remote_asn = pg.remote_asn.peer.asn.value
|
||||
|
||||
results.append(
|
||||
BgpPeerGroupIntent(
|
||||
name=pg.name.value,
|
||||
peer_group_type=pg.peer_group_type.value,
|
||||
remote_asn=remote_asn,
|
||||
update_source=self._opt_value(pg, "update_source"),
|
||||
send_community=bool(self._opt_value(pg, "send_community")),
|
||||
ebgp_multihop=self._opt_value(pg, "ebgp_multihop"),
|
||||
next_hop_unchanged=bool(self._opt_value(pg, "next_hop_unchanged")),
|
||||
)
|
||||
)
|
||||
|
||||
self._cache_set(cache_key, results)
|
||||
return results
|
||||
|
||||
async def get_device_bgp_sessions(self, device_name: str) -> list[BgpSessionIntent]:
|
||||
"""Fetch all BGP sessions configured on a device.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
List of :class:`BgpSessionIntent` objects.
|
||||
|
||||
Raises:
|
||||
InfrahubNotFoundError: When no BGP config exists for the device.
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"bgp_sessions:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
bgp_node = await self._fetch_node(
|
||||
"InfraBGPRouterConfig", device__name__value=device_name
|
||||
)
|
||||
|
||||
results: list[BgpSessionIntent] = []
|
||||
if not (bgp_node.sessions and bgp_node.sessions.initialized):
|
||||
self._cache_set(cache_key, results)
|
||||
return results
|
||||
|
||||
for sess_rel in bgp_node.sessions.peers:
|
||||
sess = sess_rel.peer
|
||||
|
||||
peer_group_name: str | None = None
|
||||
if sess.peer_group and sess.peer_group.initialized:
|
||||
await sess.peer_group.fetch()
|
||||
if sess.peer_group.peer:
|
||||
peer_group_name = sess.peer_group.peer.name.value
|
||||
|
||||
remote_asn: int | None = None
|
||||
if sess.remote_asn and sess.remote_asn.initialized:
|
||||
await sess.remote_asn.fetch()
|
||||
if sess.remote_asn.peer:
|
||||
remote_asn = sess.remote_asn.peer.asn.value
|
||||
|
||||
results.append(
|
||||
BgpSessionIntent(
|
||||
peer_address=sess.peer_address.value,
|
||||
description=self._opt_value(sess, "description"),
|
||||
enabled=bool(self._opt_value(sess, "enabled")),
|
||||
peer_group=peer_group_name,
|
||||
remote_asn=remote_asn,
|
||||
)
|
||||
)
|
||||
|
||||
self._cache_set(cache_key, results)
|
||||
return results
|
||||
|
||||
async def get_device_vrfs(self, device_name: str) -> list[VrfIntent]:
|
||||
"""Fetch all VRFs assigned to a device via VRFDeviceAssignment.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
List of :class:`VrfIntent` objects (may be empty).
|
||||
|
||||
Raises:
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"device_vrfs:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
assignments = await self._fetch_all(
|
||||
"InfraVRFDeviceAssignment", device__name__value=device_name
|
||||
)
|
||||
|
||||
results: list[VrfIntent] = []
|
||||
for assignment in assignments:
|
||||
if not (assignment.vrf and assignment.vrf.initialized):
|
||||
continue
|
||||
await assignment.vrf.fetch()
|
||||
vrf_node = assignment.vrf.peer
|
||||
if vrf_node is None:
|
||||
continue
|
||||
|
||||
l3vni: int | None = None
|
||||
if vrf_node.l3vni and vrf_node.l3vni.initialized:
|
||||
await vrf_node.l3vni.fetch()
|
||||
if vrf_node.l3vni.peer:
|
||||
l3vni = vrf_node.l3vni.peer.vni.value
|
||||
|
||||
import_targets: list[str] = []
|
||||
if vrf_node.import_targets and vrf_node.import_targets.initialized:
|
||||
for rt_rel in vrf_node.import_targets.peers:
|
||||
import_targets.append(rt_rel.peer.target.value)
|
||||
|
||||
export_targets: list[str] = []
|
||||
if vrf_node.export_targets and vrf_node.export_targets.initialized:
|
||||
for rt_rel in vrf_node.export_targets.peers:
|
||||
export_targets.append(rt_rel.peer.target.value)
|
||||
|
||||
results.append(
|
||||
VrfIntent(
|
||||
name=vrf_node.name.value,
|
||||
route_distinguisher=self._opt_value(vrf_node, "route_distinguisher"),
|
||||
vrf_id=self._opt_value(vrf_node, "vrf_id"),
|
||||
l3vni=l3vni,
|
||||
import_targets=import_targets,
|
||||
export_targets=export_targets,
|
||||
)
|
||||
)
|
||||
|
||||
self._cache_set(cache_key, results)
|
||||
return results
|
||||
|
||||
async def get_device_vtep(self, device_name: str) -> VtepIntent | None:
|
||||
"""Fetch the VTEP configuration for a device.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
:class:`VtepIntent` if a VTEP exists for the device, else ``None``.
|
||||
|
||||
Raises:
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"device_vtep:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
vtep_nodes = await self._fetch_all(
|
||||
"InfraVTEP", device__name__value=device_name
|
||||
)
|
||||
if not vtep_nodes:
|
||||
self._cache_set(cache_key, None)
|
||||
return None
|
||||
|
||||
vtep = vtep_nodes[0]
|
||||
mappings: list[tuple[int, int]] = []
|
||||
|
||||
if vtep.vlan_vni_mappings and vtep.vlan_vni_mappings.initialized:
|
||||
for mapping_rel in vtep.vlan_vni_mappings.peers:
|
||||
mapping = mapping_rel.peer
|
||||
vlan_id: int | None = None
|
||||
vni_id: int | None = None
|
||||
|
||||
if mapping.vlan and mapping.vlan.initialized:
|
||||
await mapping.vlan.fetch()
|
||||
if mapping.vlan.peer:
|
||||
vlan_id = mapping.vlan.peer.vlan_id.value
|
||||
|
||||
if mapping.vni and mapping.vni.initialized:
|
||||
await mapping.vni.fetch()
|
||||
if mapping.vni.peer:
|
||||
vni_id = mapping.vni.peer.vni.value
|
||||
|
||||
if vlan_id is not None and vni_id is not None:
|
||||
mappings.append((vlan_id, vni_id))
|
||||
|
||||
result = VtepIntent(
|
||||
source_address=vtep.source_address.value,
|
||||
udp_port=int(self._opt_value(vtep, "udp_port") or 4789),
|
||||
learn_restrict=bool(self._opt_value(vtep, "learn_restrict")),
|
||||
vlan_vni_mappings=mappings,
|
||||
)
|
||||
self._cache_set(cache_key, result)
|
||||
return result
|
||||
|
||||
async def get_device_evpn_instances(self, device_name: str) -> list[EvpnInstanceIntent]:
|
||||
"""Fetch all EVPN instances for a device.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
List of :class:`EvpnInstanceIntent` objects.
|
||||
|
||||
Raises:
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"device_evpn:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
nodes = await self._fetch_all(
|
||||
"InfraEVPNInstance", device__name__value=device_name
|
||||
)
|
||||
|
||||
results: list[EvpnInstanceIntent] = []
|
||||
for node in nodes:
|
||||
vlan_id: int = 0
|
||||
if node.vlan and node.vlan.initialized:
|
||||
await node.vlan.fetch()
|
||||
if node.vlan.peer:
|
||||
vlan_id = node.vlan.peer.vlan_id.value
|
||||
|
||||
results.append(
|
||||
EvpnInstanceIntent(
|
||||
route_distinguisher=node.route_distinguisher.value,
|
||||
route_target_import=node.route_target_import.value,
|
||||
route_target_export=node.route_target_export.value,
|
||||
redistribute_learned=bool(
|
||||
self._opt_value(node, "redistribute_learned")
|
||||
),
|
||||
vlan_id=vlan_id,
|
||||
)
|
||||
)
|
||||
|
||||
self._cache_set(cache_key, results)
|
||||
return results
|
||||
|
||||
async def get_mlag_domain(self, device_name: str) -> MlagDomainIntent | None:
|
||||
"""Fetch the MLAG domain that includes the given device.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
:class:`MlagDomainIntent` if the device belongs to an MLAG domain,
|
||||
else ``None``.
|
||||
|
||||
Raises:
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"mlag_domain:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
domains = await self._fetch_all("InfraMlagDomain")
|
||||
for domain in domains:
|
||||
if not (domain.devices and domain.devices.initialized):
|
||||
continue
|
||||
device_names: list[str] = []
|
||||
for dev_rel in domain.devices.peers:
|
||||
device_names.append(dev_rel.peer.name.value)
|
||||
|
||||
if device_name not in device_names:
|
||||
continue
|
||||
|
||||
peer_vlan: str | None = None
|
||||
if domain.peer_vlan and domain.peer_vlan.initialized:
|
||||
await domain.peer_vlan.fetch()
|
||||
if domain.peer_vlan.peer:
|
||||
peer_vlan = str(domain.peer_vlan.peer.vlan_id.value)
|
||||
|
||||
result = MlagDomainIntent(
|
||||
domain_id=str(domain.domain_id.value),
|
||||
virtual_mac=domain.virtual_mac.value,
|
||||
heartbeat_vrf=self._opt_value(domain, "heartbeat_vrf") or "",
|
||||
dual_primary_detection=bool(
|
||||
self._opt_value(domain, "dual_primary_detection")
|
||||
),
|
||||
dual_primary_delay=int(self._opt_value(domain, "dual_primary_delay") or 0),
|
||||
dual_primary_action=self._opt_value(domain, "dual_primary_action") or "",
|
||||
peer_devices=device_names,
|
||||
)
|
||||
self._cache_set(cache_key, result)
|
||||
return result
|
||||
|
||||
self._cache_set(cache_key, None)
|
||||
return None
|
||||
|
||||
async def get_mlag_peer_config(self, device_name: str) -> MlagPeerConfigIntent | None:
|
||||
"""Fetch the MLAG peer config for a device.
|
||||
|
||||
Args:
|
||||
device_name: Device name as stored in Infrahub.
|
||||
|
||||
Returns:
|
||||
:class:`MlagPeerConfigIntent` if a config exists, else ``None``.
|
||||
|
||||
Raises:
|
||||
InfrahubQueryError: On query failure.
|
||||
"""
|
||||
cache_key = f"mlag_peer_config:{device_name}"
|
||||
if cached := self._cache_get(cache_key):
|
||||
return cached
|
||||
|
||||
configs = await self._fetch_all(
|
||||
"InfraMlagPeerConfig", device__name__value=device_name
|
||||
)
|
||||
if not configs:
|
||||
self._cache_set(cache_key, None)
|
||||
return None
|
||||
|
||||
cfg = configs[0]
|
||||
|
||||
peer_link_name: str = ""
|
||||
if cfg.peer_link and cfg.peer_link.initialized:
|
||||
await cfg.peer_link.fetch()
|
||||
if cfg.peer_link.peer:
|
||||
peer_link_name = cfg.peer_link.peer.name.value
|
||||
|
||||
result = MlagPeerConfigIntent(
|
||||
local_interface_ip=cfg.local_interface_ip.value,
|
||||
peer_address=cfg.peer_address.value,
|
||||
heartbeat_peer_ip=self._opt_value(cfg, "heartbeat_peer_ip") or "",
|
||||
peer_link=peer_link_name,
|
||||
)
|
||||
self._cache_set(cache_key, result)
|
||||
return result
|
||||
17
src/infrahub_client/exceptions.py
Normal file
17
src/infrahub_client/exceptions.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Exception hierarchy for the Infrahub fabric intent client."""
|
||||
|
||||
|
||||
class InfrahubClientError(Exception):
|
||||
"""Base exception for all Infrahub client errors."""
|
||||
|
||||
|
||||
class InfrahubConnectionError(InfrahubClientError):
|
||||
"""Raised when the client cannot connect to the Infrahub server."""
|
||||
|
||||
|
||||
class InfrahubQueryError(InfrahubClientError):
|
||||
"""Raised when a query to the Infrahub server fails."""
|
||||
|
||||
|
||||
class InfrahubNotFoundError(InfrahubClientError):
|
||||
"""Raised when a requested resource is not found in Infrahub."""
|
||||
137
src/infrahub_client/models.py
Normal file
137
src/infrahub_client/models.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Pydantic v2 models for typed Infrahub fabric intent responses."""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class DeviceIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraDevice node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
platform: str | None = None
|
||||
site: str | None = None
|
||||
asn: int | None = None
|
||||
|
||||
|
||||
class VlanIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraVLAN node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
vlan_id: int
|
||||
name: str
|
||||
status: str
|
||||
vlan_type: str
|
||||
vni: int | None = None
|
||||
stp_enabled: bool
|
||||
|
||||
|
||||
class VniIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraVNI node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
vni: int
|
||||
vni_type: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class BgpRouterConfigIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraBGPRouterConfig node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
router_id: str
|
||||
local_asn: int
|
||||
default_ipv4_unicast: bool
|
||||
ecmp_max_paths: int
|
||||
|
||||
|
||||
class BgpPeerGroupIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraBGPPeerGroup node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
name: str
|
||||
peer_group_type: str
|
||||
remote_asn: int | None = None
|
||||
update_source: str | None = None
|
||||
send_community: bool
|
||||
ebgp_multihop: int | None = None
|
||||
next_hop_unchanged: bool
|
||||
|
||||
|
||||
class BgpSessionIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraBGPSession node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
peer_address: str
|
||||
description: str | None = None
|
||||
enabled: bool
|
||||
peer_group: str | None = None
|
||||
remote_asn: int | None = None
|
||||
|
||||
|
||||
class VrfIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraVRF node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
name: str
|
||||
route_distinguisher: str | None = None
|
||||
vrf_id: int | None = None
|
||||
l3vni: int | None = None
|
||||
import_targets: list[str]
|
||||
export_targets: list[str]
|
||||
|
||||
|
||||
class VtepIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraVTEP node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
source_address: str
|
||||
udp_port: int
|
||||
learn_restrict: bool
|
||||
vlan_vni_mappings: list[tuple[int, int]]
|
||||
|
||||
|
||||
class MlagDomainIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraMlagDomain node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
domain_id: str
|
||||
virtual_mac: str
|
||||
heartbeat_vrf: str
|
||||
dual_primary_detection: bool
|
||||
dual_primary_delay: int
|
||||
dual_primary_action: str
|
||||
peer_devices: list[str]
|
||||
|
||||
|
||||
class MlagPeerConfigIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraMlagPeerConfig node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
local_interface_ip: str
|
||||
peer_address: str
|
||||
heartbeat_peer_ip: str
|
||||
peer_link: str
|
||||
|
||||
|
||||
class EvpnInstanceIntent(BaseModel):
|
||||
"""Intent model for an Infrahub InfraEVPNInstance node."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
route_distinguisher: str
|
||||
route_target_import: str
|
||||
route_target_export: str
|
||||
redistribute_learned: bool
|
||||
vlan_id: int
|
||||
598
tests/test_infrahub_client.py
Normal file
598
tests/test_infrahub_client.py
Normal file
@@ -0,0 +1,598 @@
|
||||
"""Unit tests for FabricInfrahubClient — all Infrahub SDK calls are mocked."""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrahub_client import (
|
||||
FabricInfrahubClient,
|
||||
InfrahubNotFoundError,
|
||||
InfrahubQueryError,
|
||||
)
|
||||
from infrahub_client.models import (
|
||||
BgpPeerGroupIntent,
|
||||
BgpRouterConfigIntent,
|
||||
BgpSessionIntent,
|
||||
DeviceIntent,
|
||||
EvpnInstanceIntent,
|
||||
MlagDomainIntent,
|
||||
MlagPeerConfigIntent,
|
||||
VlanIntent,
|
||||
VrfIntent,
|
||||
VtepIntent,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to build mock SDK node objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _attr(value):
|
||||
"""Return a simple attribute mock with ``.value`` set."""
|
||||
m = MagicMock()
|
||||
m.value = value
|
||||
return m
|
||||
|
||||
|
||||
def _rel_one(peer_node, initialized=True):
|
||||
"""Return a mock for a cardinality-one relationship."""
|
||||
m = MagicMock()
|
||||
m.initialized = initialized
|
||||
m.peer = peer_node
|
||||
m.fetch = AsyncMock()
|
||||
return m
|
||||
|
||||
|
||||
def _rel_many(peers, initialized=True):
|
||||
"""Return a mock for a cardinality-many relationship (list of peer wrappers)."""
|
||||
m = MagicMock()
|
||||
m.initialized = initialized
|
||||
peer_wrappers = []
|
||||
for p in peers:
|
||||
pw = MagicMock()
|
||||
pw.peer = p
|
||||
peer_wrappers.append(pw)
|
||||
m.peers = peer_wrappers
|
||||
return m
|
||||
|
||||
|
||||
def _make_device(
|
||||
name="leaf1",
|
||||
role="leaf",
|
||||
status="active",
|
||||
platform_name="eos",
|
||||
site_name="dc1",
|
||||
asn_value=65001,
|
||||
):
|
||||
"""Build a mock InfraDevice SDK node."""
|
||||
node = MagicMock()
|
||||
node.name = _attr(name)
|
||||
node.role = _attr(role)
|
||||
node.status = _attr(status)
|
||||
|
||||
platform = MagicMock()
|
||||
platform.name = _attr(platform_name)
|
||||
node.platform = _rel_one(platform)
|
||||
|
||||
site = MagicMock()
|
||||
site.name = _attr(site_name)
|
||||
node.site = _rel_one(site)
|
||||
|
||||
asn = MagicMock()
|
||||
asn.asn = _attr(asn_value)
|
||||
node.asn = _rel_one(asn)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
def _make_sdk_client(get_return=None, all_return=None, filters_return=None, get_raises=None):
|
||||
"""Patch InfrahubClient inside client module and configure return values."""
|
||||
mock_sdk = MagicMock()
|
||||
if get_raises:
|
||||
mock_sdk.get = AsyncMock(side_effect=get_raises)
|
||||
else:
|
||||
mock_sdk.get = AsyncMock(return_value=get_return)
|
||||
mock_sdk.all = AsyncMock(return_value=all_return or [])
|
||||
mock_sdk.filters = AsyncMock(return_value=filters_return or [])
|
||||
return mock_sdk
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
"""FabricInfrahubClient with a patched InfrahubClient SDK."""
|
||||
with patch("infrahub_client.client.InfrahubClient") as MockSdk:
|
||||
MockSdk.return_value = MagicMock()
|
||||
c = FabricInfrahubClient(url="http://infrahub:8080", api_token="token", branch="main")
|
||||
yield c
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_device
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDevice:
|
||||
async def test_returns_device_intent(self, client):
|
||||
device_node = _make_device()
|
||||
client._sdk.get = AsyncMock(return_value=device_node)
|
||||
|
||||
result = await client.get_device("leaf1")
|
||||
|
||||
assert isinstance(result, DeviceIntent)
|
||||
assert result.name == "leaf1"
|
||||
assert result.role == "leaf"
|
||||
assert result.status == "active"
|
||||
assert result.platform == "eos"
|
||||
assert result.site == "dc1"
|
||||
assert result.asn == 65001
|
||||
|
||||
async def test_not_found_raises_error(self, client):
|
||||
client._sdk.get = AsyncMock(side_effect=Exception("not found"))
|
||||
|
||||
with pytest.raises(InfrahubNotFoundError):
|
||||
await client.get_device("nonexistent")
|
||||
|
||||
async def test_branch_is_forwarded(self, client):
|
||||
client._branch = "proposed"
|
||||
device_node = _make_device()
|
||||
client._sdk.get = AsyncMock(return_value=device_node)
|
||||
|
||||
await client.get_device("leaf1")
|
||||
|
||||
client._sdk.get.assert_called_once_with(
|
||||
kind="InfraDevice", branch="proposed", name__value="leaf1"
|
||||
)
|
||||
|
||||
async def test_caching_second_call_skips_sdk(self, client):
|
||||
device_node = _make_device()
|
||||
client._sdk.get = AsyncMock(return_value=device_node)
|
||||
|
||||
first = await client.get_device("leaf1")
|
||||
second = await client.get_device("leaf1")
|
||||
|
||||
assert first == second
|
||||
# SDK should only have been called once
|
||||
assert client._sdk.get.call_count == 1
|
||||
|
||||
async def test_optional_fields_none_when_uninitialized(self, client):
|
||||
node = MagicMock()
|
||||
node.name = _attr("spine1")
|
||||
node.role = _attr("spine")
|
||||
node.status = _attr("active")
|
||||
node.platform = _rel_one(None, initialized=False)
|
||||
node.site = _rel_one(None, initialized=False)
|
||||
node.asn = _rel_one(None, initialized=False)
|
||||
client._sdk.get = AsyncMock(return_value=node)
|
||||
|
||||
result = await client.get_device("spine1")
|
||||
|
||||
assert result.platform is None
|
||||
assert result.site is None
|
||||
assert result.asn is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_device_bgp_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDeviceBgpConfig:
|
||||
def _make_bgp_node(self, router_id="10.0.0.1", asn=65001, default_ipv4=True, ecmp=4):
|
||||
node = MagicMock()
|
||||
node.router_id = _attr(router_id)
|
||||
node.default_ipv4_unicast = _attr(default_ipv4)
|
||||
node.ecmp_max_paths = _attr(ecmp)
|
||||
|
||||
asn_peer = MagicMock()
|
||||
asn_peer.asn = _attr(asn)
|
||||
node.local_asn = _rel_one(asn_peer)
|
||||
|
||||
return node
|
||||
|
||||
async def test_returns_bgp_config_intent(self, client):
|
||||
bgp_node = self._make_bgp_node()
|
||||
client._sdk.get = AsyncMock(return_value=bgp_node)
|
||||
|
||||
result = await client.get_device_bgp_config("leaf1")
|
||||
|
||||
assert isinstance(result, BgpRouterConfigIntent)
|
||||
assert result.router_id == "10.0.0.1"
|
||||
assert result.local_asn == 65001
|
||||
assert result.default_ipv4_unicast is True
|
||||
assert result.ecmp_max_paths == 4
|
||||
|
||||
async def test_not_found_raises_error(self, client):
|
||||
client._sdk.get = AsyncMock(side_effect=Exception("not found"))
|
||||
|
||||
with pytest.raises(InfrahubNotFoundError):
|
||||
await client.get_device_bgp_config("leaf1")
|
||||
|
||||
async def test_caching(self, client):
|
||||
bgp_node = self._make_bgp_node()
|
||||
client._sdk.get = AsyncMock(return_value=bgp_node)
|
||||
|
||||
await client.get_device_bgp_config("leaf1")
|
||||
await client.get_device_bgp_config("leaf1")
|
||||
|
||||
assert client._sdk.get.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_device_bgp_peer_groups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDeviceBgpPeerGroups:
|
||||
def _make_pg(self, name="UNDERLAY", pg_type="underlay", remote_asn=65000):
|
||||
pg = MagicMock()
|
||||
pg.name = _attr(name)
|
||||
pg.peer_group_type = _attr(pg_type)
|
||||
pg.update_source = _attr("Loopback0")
|
||||
pg.send_community = _attr(True)
|
||||
pg.ebgp_multihop = _attr(None)
|
||||
pg.next_hop_unchanged = _attr(False)
|
||||
|
||||
asn_peer = MagicMock()
|
||||
asn_peer.asn = _attr(remote_asn)
|
||||
pg.remote_asn = _rel_one(asn_peer)
|
||||
|
||||
return pg
|
||||
|
||||
async def test_returns_peer_group_intents(self, client):
|
||||
pg = self._make_pg()
|
||||
bgp_node = MagicMock()
|
||||
bgp_node.peer_groups = _rel_many([pg])
|
||||
client._sdk.get = AsyncMock(return_value=bgp_node)
|
||||
|
||||
results = await client.get_device_bgp_peer_groups("leaf1")
|
||||
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], BgpPeerGroupIntent)
|
||||
assert results[0].name == "UNDERLAY"
|
||||
assert results[0].peer_group_type == "underlay"
|
||||
assert results[0].remote_asn == 65000
|
||||
assert results[0].send_community is True
|
||||
|
||||
async def test_empty_when_no_peer_groups(self, client):
|
||||
bgp_node = MagicMock()
|
||||
bgp_node.peer_groups = _rel_many([], initialized=False)
|
||||
client._sdk.get = AsyncMock(return_value=bgp_node)
|
||||
|
||||
results = await client.get_device_bgp_peer_groups("leaf1")
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_device_bgp_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDeviceBgpSessions:
|
||||
def _make_session(self, peer_addr="192.168.1.1", enabled=True):
|
||||
sess = MagicMock()
|
||||
sess.peer_address = _attr(peer_addr)
|
||||
sess.description = _attr("peer session")
|
||||
sess.enabled = _attr(enabled)
|
||||
|
||||
pg_peer = MagicMock()
|
||||
pg_peer.name = _attr("UNDERLAY")
|
||||
sess.peer_group = _rel_one(pg_peer)
|
||||
|
||||
asn_peer = MagicMock()
|
||||
asn_peer.asn = _attr(65000)
|
||||
sess.remote_asn = _rel_one(asn_peer)
|
||||
|
||||
return sess
|
||||
|
||||
async def test_returns_session_intents(self, client):
|
||||
sess = self._make_session()
|
||||
bgp_node = MagicMock()
|
||||
bgp_node.sessions = _rel_many([sess])
|
||||
client._sdk.get = AsyncMock(return_value=bgp_node)
|
||||
|
||||
results = await client.get_device_bgp_sessions("leaf1")
|
||||
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], BgpSessionIntent)
|
||||
assert results[0].peer_address == "192.168.1.1"
|
||||
assert results[0].enabled is True
|
||||
assert results[0].peer_group == "UNDERLAY"
|
||||
assert results[0].remote_asn == 65000
|
||||
|
||||
async def test_empty_when_no_sessions(self, client):
|
||||
bgp_node = MagicMock()
|
||||
bgp_node.sessions = _rel_many([], initialized=False)
|
||||
client._sdk.get = AsyncMock(return_value=bgp_node)
|
||||
|
||||
results = await client.get_device_bgp_sessions("leaf1")
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_device_vrfs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDeviceVrfs:
|
||||
def _make_assignment(self, vrf_name="PROD", rd="65001:100", vni=100100):
|
||||
vni_peer = MagicMock()
|
||||
vni_peer.vni = _attr(vni)
|
||||
|
||||
rt_import = MagicMock()
|
||||
rt_import.target = _attr("65001:100")
|
||||
rt_export = MagicMock()
|
||||
rt_export.target = _attr("65001:100")
|
||||
|
||||
vrf_node = MagicMock()
|
||||
vrf_node.name = _attr(vrf_name)
|
||||
vrf_node.route_distinguisher = _attr(rd)
|
||||
vrf_node.vrf_id = _attr(None)
|
||||
vrf_node.l3vni = _rel_one(vni_peer)
|
||||
vrf_node.import_targets = _rel_many([rt_import])
|
||||
vrf_node.export_targets = _rel_many([rt_export])
|
||||
|
||||
assignment = MagicMock()
|
||||
assignment.vrf = _rel_one(vrf_node)
|
||||
|
||||
return assignment
|
||||
|
||||
async def test_returns_vrf_intents(self, client):
|
||||
assignment = self._make_assignment()
|
||||
client._sdk.filters = AsyncMock(return_value=[assignment])
|
||||
|
||||
results = await client.get_device_vrfs("leaf1")
|
||||
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], VrfIntent)
|
||||
assert results[0].name == "PROD"
|
||||
assert results[0].route_distinguisher == "65001:100"
|
||||
assert results[0].l3vni == 100100
|
||||
assert "65001:100" in results[0].import_targets
|
||||
|
||||
async def test_empty_when_no_assignments(self, client):
|
||||
client._sdk.filters = AsyncMock(return_value=[])
|
||||
|
||||
results = await client.get_device_vrfs("leaf1")
|
||||
|
||||
assert results == []
|
||||
|
||||
async def test_caching(self, client):
|
||||
assignment = self._make_assignment()
|
||||
client._sdk.filters = AsyncMock(return_value=[assignment])
|
||||
|
||||
await client.get_device_vrfs("leaf1")
|
||||
await client.get_device_vrfs("leaf1")
|
||||
|
||||
assert client._sdk.filters.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_device_vtep
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDeviceVtep:
|
||||
def _make_vtep(self, src_addr="10.255.0.1", udp_port=4789):
|
||||
vlan_node = MagicMock()
|
||||
vlan_node.vlan_id = _attr(100)
|
||||
vni_node = MagicMock()
|
||||
vni_node.vni = _attr(10100)
|
||||
|
||||
mapping = MagicMock()
|
||||
mapping.vlan = _rel_one(vlan_node)
|
||||
mapping.vni = _rel_one(vni_node)
|
||||
|
||||
vtep = MagicMock()
|
||||
vtep.source_address = _attr(src_addr)
|
||||
vtep.udp_port = _attr(udp_port)
|
||||
vtep.learn_restrict = _attr(True)
|
||||
vtep.vlan_vni_mappings = _rel_many([mapping])
|
||||
|
||||
return vtep
|
||||
|
||||
async def test_returns_vtep_intent(self, client):
|
||||
vtep = self._make_vtep()
|
||||
client._sdk.filters = AsyncMock(return_value=[vtep])
|
||||
|
||||
result = await client.get_device_vtep("leaf1")
|
||||
|
||||
assert isinstance(result, VtepIntent)
|
||||
assert result.source_address == "10.255.0.1"
|
||||
assert result.udp_port == 4789
|
||||
assert result.learn_restrict is True
|
||||
assert (100, 10100) in result.vlan_vni_mappings
|
||||
|
||||
async def test_returns_none_when_no_vtep(self, client):
|
||||
client._sdk.filters = AsyncMock(return_value=[])
|
||||
|
||||
result = await client.get_device_vtep("leaf1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_device_evpn_instances
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDeviceEvpnInstances:
|
||||
def _make_evpn_instance(self, rd="65001:100", rt_import="65001:100", rt_export="65001:100"):
|
||||
vlan_node = MagicMock()
|
||||
vlan_node.vlan_id = _attr(100)
|
||||
|
||||
node = MagicMock()
|
||||
node.route_distinguisher = _attr(rd)
|
||||
node.route_target_import = _attr(rt_import)
|
||||
node.route_target_export = _attr(rt_export)
|
||||
node.redistribute_learned = _attr(True)
|
||||
node.vlan = _rel_one(vlan_node)
|
||||
|
||||
return node
|
||||
|
||||
async def test_returns_evpn_instance_intents(self, client):
|
||||
evpn_node = self._make_evpn_instance()
|
||||
client._sdk.filters = AsyncMock(return_value=[evpn_node])
|
||||
|
||||
results = await client.get_device_evpn_instances("leaf1")
|
||||
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], EvpnInstanceIntent)
|
||||
assert results[0].route_distinguisher == "65001:100"
|
||||
assert results[0].redistribute_learned is True
|
||||
assert results[0].vlan_id == 100
|
||||
|
||||
async def test_empty_when_no_instances(self, client):
|
||||
client._sdk.filters = AsyncMock(return_value=[])
|
||||
|
||||
results = await client.get_device_evpn_instances("leaf1")
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_mlag_domain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetMlagDomain:
|
||||
def _make_domain(self, device_names=("leaf1", "leaf2")):
|
||||
devices = []
|
||||
for dname in device_names:
|
||||
d = MagicMock()
|
||||
d.name = _attr(dname)
|
||||
devices.append(d)
|
||||
|
||||
peer_vlan_node = MagicMock()
|
||||
peer_vlan_node.vlan_id = _attr(4094)
|
||||
|
||||
domain = MagicMock()
|
||||
domain.domain_id = _attr("1")
|
||||
domain.virtual_mac = _attr("00:1c:73:00:00:01")
|
||||
domain.heartbeat_vrf = _attr("MGMT")
|
||||
domain.dual_primary_detection = _attr(True)
|
||||
domain.dual_primary_delay = _attr(10)
|
||||
domain.dual_primary_action = _attr("errDisable")
|
||||
domain.devices = _rel_many(devices)
|
||||
domain.peer_vlan = _rel_one(peer_vlan_node)
|
||||
|
||||
return domain
|
||||
|
||||
async def test_returns_mlag_domain_intent(self, client):
|
||||
domain = self._make_domain()
|
||||
client._sdk.all = AsyncMock(return_value=[domain])
|
||||
|
||||
result = await client.get_mlag_domain("leaf1")
|
||||
|
||||
assert isinstance(result, MlagDomainIntent)
|
||||
assert result.domain_id == "1"
|
||||
assert result.virtual_mac == "00:1c:73:00:00:01"
|
||||
assert "leaf1" in result.peer_devices
|
||||
assert "leaf2" in result.peer_devices
|
||||
|
||||
async def test_returns_none_when_device_not_in_any_domain(self, client):
|
||||
domain = self._make_domain(device_names=("leaf3", "leaf4"))
|
||||
client._sdk.all = AsyncMock(return_value=[domain])
|
||||
|
||||
result = await client.get_mlag_domain("leaf1")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_caching(self, client):
|
||||
domain = self._make_domain()
|
||||
client._sdk.all = AsyncMock(return_value=[domain])
|
||||
|
||||
await client.get_mlag_domain("leaf1")
|
||||
await client.get_mlag_domain("leaf1")
|
||||
|
||||
assert client._sdk.all.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_mlag_peer_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetMlagPeerConfig:
|
||||
def _make_peer_config(self):
|
||||
peer_link_node = MagicMock()
|
||||
peer_link_node.name = _attr("Port-Channel1")
|
||||
|
||||
cfg = MagicMock()
|
||||
cfg.local_interface_ip = _attr("192.168.255.0/31")
|
||||
cfg.peer_address = _attr("192.168.255.1")
|
||||
cfg.heartbeat_peer_ip = _attr("192.168.1.2")
|
||||
cfg.peer_link = _rel_one(peer_link_node)
|
||||
|
||||
return cfg
|
||||
|
||||
async def test_returns_mlag_peer_config_intent(self, client):
|
||||
cfg = self._make_peer_config()
|
||||
client._sdk.filters = AsyncMock(return_value=[cfg])
|
||||
|
||||
result = await client.get_mlag_peer_config("leaf1")
|
||||
|
||||
assert isinstance(result, MlagPeerConfigIntent)
|
||||
assert result.local_interface_ip == "192.168.255.0/31"
|
||||
assert result.peer_address == "192.168.255.1"
|
||||
assert result.peer_link == "Port-Channel1"
|
||||
|
||||
async def test_returns_none_when_no_config(self, client):
|
||||
client._sdk.filters = AsyncMock(return_value=[])
|
||||
|
||||
result = await client.get_mlag_peer_config("leaf1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
async def test_context_manager(self):
|
||||
with patch("infrahub_client.client.InfrahubClient"):
|
||||
async with FabricInfrahubClient(
|
||||
url="http://infrahub:8080", api_token="token"
|
||||
) as client:
|
||||
assert isinstance(client, FabricInfrahubClient)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache TTL expiry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheTtl:
|
||||
async def test_expired_cache_triggers_new_sdk_call(self, client):
|
||||
device_node = _make_device()
|
||||
client._sdk.get = AsyncMock(return_value=device_node)
|
||||
|
||||
await client.get_device("leaf1")
|
||||
# Manually expire the cache entry
|
||||
client._cache["device:leaf1"] = (client._cache["device:leaf1"][0], 0.0)
|
||||
await client.get_device("leaf1")
|
||||
|
||||
assert client._sdk.get.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InfrahubQueryError propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueryError:
|
||||
async def test_sdk_unexpected_error_raises_query_error(self, client):
|
||||
client._sdk.get = AsyncMock(side_effect=Exception("connection refused"))
|
||||
|
||||
with pytest.raises(InfrahubQueryError):
|
||||
await client.get_device_bgp_config("leaf1")
|
||||
Reference in New Issue
Block a user