diff --git a/src/gnmi_eos/client.py b/src/gnmi_eos/client.py new file mode 100644 index 0000000..53bf0e1 --- /dev/null +++ b/src/gnmi_eos/client.py @@ -0,0 +1,348 @@ +""" +gNMI Client for Arista EOS devices. + +Provides a high-level gNMI client using pygnmi with: +- Connection management via context manager +- Typed exceptions for connection and path errors +- get / set / subscribe / capabilities operations +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any, Literal + +from pygnmi.client import gNMIclient + +# ============================================================================== +# Exceptions +# ============================================================================== + + +class GNMIError(Exception): + """Base exception for gNMI operations.""" + + pass + + +class GNMIConnectionError(GNMIError): + """Raised when connection to device fails.""" + + pass + + +class GNMIPathError(GNMIError): + """Raised when path is invalid or not found.""" + + pass + + +# ============================================================================== +# Data Types +# ============================================================================== + +DataType = Literal["config", "state", "all"] +SetOperation = Literal["update", "replace", "delete"] +SubscribeMode = Literal["once", "stream", "poll"] +StreamMode = Literal["on-change", "sample", "target-defined"] + + +@dataclass +class Capability: + """Represents a gNMI capability/model.""" + + name: str + organization: str + version: str + + def __str__(self) -> str: + return f"{self.name} ({self.organization}) v{self.version}" + + +@dataclass +class SubscriptionUpdate: + """Represents a subscription update.""" + + path: str + value: Any + timestamp: int + + +# ============================================================================== +# gNMI Client +# ============================================================================== + + +class GNMIClient: + """ + gNMI Client wrapper using pygnmi. + + Provides high-level methods for gNMI operations with proper + error handling and context manager support. + + Example: + with GNMIClient(host="leaf1", port=6030) as client: + caps = client.capabilities() + counters = client.get("/interfaces/interface[name=Ethernet1]/state/counters") + """ + + def __init__( + self, + host: str, + port: int = 6030, + username: str = "admin", + password: str = "admin", + insecure: bool = True, + skip_verify: bool = True, + timeout: int = 10, + ): + """ + Initialize gNMI client. + + Args: + host: Target hostname or IP + port: gNMI port (default: 6030) + username: Username for authentication + password: Password for authentication + insecure: Skip TLS verification (default: True for lab) + skip_verify: Skip certificate verification + timeout: Connection timeout in seconds + """ + self.host = host + self.port = port + self.username = username + self.password = password + self.insecure = insecure + self.skip_verify = skip_verify + self.timeout = timeout + self._client: gNMIclient | None = None + + @property + def target(self) -> str: + """Get target string (host:port).""" + return f"{self.host}:{self.port}" + + def _ensure_connected(self) -> gNMIclient: + """Ensure client is connected.""" + if self._client is None: + raise GNMIConnectionError("Client not connected. Use context manager.") + return self._client + + def __enter__(self) -> GNMIClient: + """Enter context manager.""" + try: + self._client = gNMIclient( + target=(self.host, self.port), + username=self.username, + password=self.password, + insecure=self.insecure, + skip_verify=self.skip_verify, + ) + self._client.__enter__() + return self + except Exception as e: + raise GNMIConnectionError( + f"Failed to connect to {self.target}: {e}" + ) from e + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager.""" + if self._client: + try: + self._client.__exit__(exc_type, exc_val, exc_tb) + except Exception: + pass # Ignore cleanup errors + self._client = None + + # ========================================================================== + # gNMI Operations + # ========================================================================== + + def capabilities(self) -> dict[str, Any]: + """ + Get device capabilities. + + Returns: + Dictionary containing: + - gnmi_version: gNMI protocol version + - supported_models: List of supported YANG models + - supported_encodings: List of supported encodings + """ + client = self._ensure_connected() + try: + return client.capabilities() + except Exception as e: + raise GNMIError(f"Failed to get capabilities: {e}") from e + + def get_models(self) -> list[Capability]: + """ + Get list of supported YANG models. + + Returns: + List of Capability objects + """ + caps = self.capabilities() + return [ + Capability( + name=m.get("name", ""), + organization=m.get("organization", ""), + version=m.get("version", ""), + ) + for m in caps.get("supported_models", []) + ] + + def get( + self, + path: str | list[str], + data_type: DataType = "all", + encoding: str = "json_ietf", + ) -> dict[str, Any]: + """ + Get data at path. + + Args: + path: YANG path or list of paths + data_type: "config", "state", or "all" (default) + encoding: Data encoding (default: json_ietf) + + Returns: + Dictionary with path data + """ + client = self._ensure_connected() + paths = [path] if isinstance(path, str) else path + datatype_map = {"config": "config", "state": "state", "all": "all"} + datatype = datatype_map.get(data_type, "all") + + try: + return client.get(path=paths, datatype=datatype, encoding=encoding) + except Exception as e: + err = str(e).lower() + if "not found" in err or "invalid" in err: + raise GNMIPathError(f"Path not found or invalid: {path}") from e + raise GNMIError(f"Failed to get data at {path}: {e}") from e + + def set( + self, + path: str, + value: Any, + operation: SetOperation = "update", + encoding: str = "json_ietf", + dry_run: bool = True, + ) -> dict[str, Any]: + """ + Set configuration at path. + + Args: + path: YANG path + value: Value to set (dict or JSON string) + operation: "update" (merge), "replace", or "delete" + encoding: Data encoding (default: json_ietf) + dry_run: If True, only validate without applying (default: True) + + Returns: + Result of set operation, or dry-run preview dict + """ + if dry_run: + return { + "dry_run": True, + "operation": operation, + "path": path, + "value": value, + "message": "Dry-run mode - no changes applied", + } + + client = self._ensure_connected() + + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError: + pass # Keep as string + + try: + if operation == "delete": + return client.set(delete=[path]) + elif operation == "replace": + return client.set(replace=[(path, value)]) + else: # update + return client.set(update=[(path, value)]) + except Exception as e: + raise GNMIError(f"Failed to set {path}: {e}") from e + + def subscribe( + self, + paths: str | list[str], + mode: SubscribeMode = "stream", + stream_mode: StreamMode = "on-change", + sample_interval: int = 10, + encoding: str = "json_ietf", + ) -> Any: + """ + Subscribe to path updates. + + Args: + paths: YANG path(s) to subscribe to + mode: "once", "stream" (default), or "poll" + stream_mode: "on-change" (default), "sample", or "target-defined" + sample_interval: Interval in seconds for sample mode + encoding: Data encoding + + Yields: + Subscription updates + """ + client = self._ensure_connected() + path_list = [paths] if isinstance(paths, str) else paths + + subscribe_list = [] + for p in path_list: + sub: dict[str, Any] = { + "path": p, + "mode": stream_mode.replace("-", "_"), # on-change -> on_change + } + if stream_mode == "sample": + sub["sample_interval"] = sample_interval * 1_000_000_000 # nanoseconds + subscribe_list.append(sub) + + subscribe_request = { + "subscription": subscribe_list, + "mode": mode, + "encoding": encoding, + } + + try: + return client.subscribe2(subscribe=subscribe_request) + except Exception as e: + raise GNMIError(f"Failed to subscribe to {paths}: {e}") from e + + # ========================================================================== + # Utility Methods + # ========================================================================== + + def explore(self, path: str = "/") -> dict[str, Any]: + """ + Get all data under a given path. + + Args: + path: Base path to explore (default: root) + + Returns: + Dictionary with discovered data + """ + return self.get(path, data_type="all") + + def validate_path(self, path: str) -> bool: + """ + Check if path exists on device. + + Args: + path: YANG path to validate + + Returns: + True if path exists, False otherwise + """ + try: + self.get(path) + return True + except (GNMIPathError, GNMIError): + return False