"""TrueNAS WebSocket client and dataset utilities."""
from __future__ import annotations

import asyncio
import base64
import contextlib
import hashlib
import json
import os
import ssl
import struct
from typing import Any, Optional

from .colors import log


# ─────────────────────────────────────────────────────────────────────────────
# Raw WebSocket implementation (stdlib only, RFC 6455)
# ─────────────────────────────────────────────────────────────────────────────

def _ws_mask(data: bytes, mask: bytes) -> bytes:
    """XOR *data* with a 4-byte repeating mask key."""
    out = bytearray(data)
    for i in range(len(out)):
        out[i] ^= mask[i & 3]
    return bytes(out)


def _ws_encode_frame(payload: bytes, opcode: int = 0x1) -> bytes:
    """Encode a masked client→server WebSocket frame."""
    mask = os.urandom(4)
    length = len(payload)
    header = bytearray([0x80 | opcode])
    if length < 126:
        header.append(0x80 | length)
    elif length < 65536:
        header.append(0x80 | 126)
        header += struct.pack("!H", length)
    else:
        header.append(0x80 | 127)
        header += struct.pack("!Q", length)
    return bytes(header) + mask + _ws_mask(payload, mask)


async def _ws_recv_message(reader: asyncio.StreamReader) -> str:
    """
    Read one complete WebSocket message, reassembling continuation frames.
    Skips ping/pong control frames. Raises OSError on close frame.
    """
    fragments: list[bytes] = []
    while True:
        hdr = await reader.readexactly(2)
        fin = bool(hdr[0] & 0x80)
        opcode = hdr[0] & 0x0F
        masked = bool(hdr[1] & 0x80)
        length = hdr[1] & 0x7F

        if length == 126:
            length = struct.unpack("!H", await reader.readexactly(2))[0]
        elif length == 127:
            length = struct.unpack("!Q", await reader.readexactly(8))[0]

        mask_key = await reader.readexactly(4) if masked else None
        payload = await reader.readexactly(length) if length else b""
        if mask_key:
            payload = _ws_mask(payload, mask_key)

        if opcode == 0x8:
            raise OSError("WebSocket: server sent close frame")
        if opcode in (0x9, 0xA):
            continue

        fragments.append(payload)
        if fin:
            return b"".join(fragments).decode("utf-8")


class _WebSocket:
    """asyncio StreamReader/Writer wrapped to a simple send/recv/close API."""

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        self._reader = reader
        self._writer = writer

    async def send(self, data: str) -> None:
        self._writer.write(_ws_encode_frame(data.encode("utf-8"), opcode=0x1))
        await self._writer.drain()

    async def recv(self) -> str:
        return await _ws_recv_message(self._reader)

    async def close(self) -> None:
        with contextlib.suppress(Exception):
            self._writer.write(_ws_encode_frame(b"", opcode=0x8))
            await self._writer.drain()
        self._writer.close()
        with contextlib.suppress(Exception):
            await self._writer.wait_closed()


async def _ws_connect(host: str, port: int, path: str, ssl_ctx: ssl.SSLContext) -> _WebSocket:
    """Open a TLS connection, perform the HTTP→WebSocket upgrade, return a _WebSocket."""
    reader, writer = await asyncio.open_connection(host, port, ssl=ssl_ctx)

    key = base64.b64encode(os.urandom(16)).decode()
    writer.write((
        f"GET {path} HTTP/1.1\r\n"
        f"Host: {host}:{port}\r\n"
        f"Upgrade: websocket\r\n"
        f"Connection: Upgrade\r\n"
        f"Sec-WebSocket-Key: {key}\r\n"
        f"Sec-WebSocket-Version: 13\r\n"
        f"\r\n"
    ).encode())
    await writer.drain()

    response_lines: list[bytes] = []
    while True:
        line = await asyncio.wait_for(reader.readline(), timeout=20)
        if not line:
            raise OSError("Connection closed during WebSocket handshake")
        response_lines.append(line)
        if line in (b"\r\n", b"\n"):
            break

    status = response_lines[0].decode("latin-1").strip()
    if " 101 " not in status:
        raise OSError(f"WebSocket upgrade failed: {status}")

    expected = base64.b64encode(
        hashlib.sha1(
            (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode()
        ).digest()
    ).decode().lower()
    headers_text = b"".join(response_lines).decode("latin-1").lower()
    if expected not in headers_text:
        raise OSError("WebSocket upgrade: Sec-WebSocket-Accept mismatch")

    return _WebSocket(reader, writer)


# ─────────────────────────────────────────────────────────────────────────────
# TrueNAS JSON-RPC 2.0 client
# ─────────────────────────────────────────────────────────────────────────────

class TrueNASClient:
    """
    Minimal async JSON-RPC 2.0 client for the TrueNAS WebSocket API.

    TrueNAS 25.04+ endpoint: wss://{host}:{port}/api/current
    Authentication: auth.login_with_api_key
    """

    def __init__(
        self,
        host: str,
        api_key: str,
        port: int = 443,
        verify_ssl: bool = False,
    ) -> None:
        self._host = host
        self._port = port
        self._api_key = api_key
        self._verify_ssl = verify_ssl
        self._ws = None
        self._call_id = 0

    @property
    def _url(self) -> str:
        return f"wss://{self._host}:{self._port}/api/current"

    async def __aenter__(self) -> "TrueNASClient":
        await self._connect()
        return self

    async def __aexit__(self, *_: Any) -> None:
        if self._ws:
            await self._ws.close()
            self._ws = None

    async def _connect(self) -> None:
        ctx = ssl.create_default_context()
        if not self._verify_ssl:
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE

        log.info("Connecting to %s …", self._url)
        try:
            self._ws = await _ws_connect(
                host=self._host,
                port=self._port,
                path="/api/current",
                ssl_ctx=ctx,
            )
        except (OSError, asyncio.TimeoutError) as exc:
            log.error("Connection failed: %s", exc)
            raise

        log.info("Authenticating with API key …")
        result = await self.call("auth.login_with_api_key", [self._api_key])
        if result is not True and result != "SUCCESS":
            raise PermissionError(f"Authentication rejected: {result!r}")
        log.info("Connected and authenticated.")

    async def call(self, method: str, params: Optional[list] = None) -> Any:
        """Send one JSON-RPC request and return its result.

        Raises RuntimeError if the API returns an error.
        """
        self._call_id += 1
        req_id = self._call_id

        await self._ws.send(json.dumps({
            "jsonrpc": "2.0",
            "id": req_id,
            "method": method,
            "params": params or [],
        }))

        while True:
            raw = await asyncio.wait_for(self._ws.recv(), timeout=60)
            msg = json.loads(raw)
            if "id" not in msg:
                continue
            if msg["id"] != req_id:
                continue
            if "error" in msg:
                err = msg["error"]
                reason = (
                    err.get("data", {}).get("reason")
                    or err.get("message")
                    or repr(err)
                )
                raise RuntimeError(f"API error [{method}]: {reason}")
            return msg.get("result")


# ─────────────────────────────────────────────────────────────────────────────
# Dataset utilities
# ─────────────────────────────────────────────────────────────────────────────

async def check_dataset_paths(
    client: TrueNASClient,
    paths: list[str],
) -> list[str]:
    """
    Return the subset of *paths* that have no matching ZFS dataset on the
    destination. Returns an empty list when the dataset query itself fails.
    """
    if not paths:
        return []

    unique = sorted({p.rstrip("/") for p in paths if p})
    log.info("Checking %d share path(s) against destination datasets …", len(unique))
    try:
        datasets = await client.call("pool.dataset.query") or []
    except RuntimeError as exc:
        log.warning("Could not query datasets (skipping check): %s", exc)
        return []

    mountpoints = {
        d.get("mountpoint", "").rstrip("/")
        for d in datasets
        if d.get("mountpoint")
    }

    missing = [p for p in unique if p not in mountpoints]
    if missing:
        for p in missing:
            log.warning(" MISSING dataset for path: %s", p)
    else:
        log.info(" All share paths exist as datasets.")
    return missing


async def create_dataset(client: TrueNASClient, path: str) -> bool:
    """
    Create a ZFS dataset whose mountpoint will be *path*.

    *path* must be an absolute /mnt/… path.
    Returns True on success, False on failure.
    """
    if not path.startswith("/mnt/"):
        log.error("Cannot auto-create dataset for non-/mnt/ path: %s", path)
        return False

    name = path[5:].rstrip("/")
    log.info("Creating dataset %r …", name)
    try:
        await client.call("pool.dataset.create", [{"name": name}])
        log.info(" Created: %s", name)
        return True
    except RuntimeError as exc:
        log.error(" Failed to create dataset %r: %s", name, exc)
        return False


async def create_missing_datasets(
    host: str,
    port: int,
    api_key: str,
    paths: list[str],
    verify_ssl: bool = False,
) -> None:
    """Open a fresh connection and create ZFS datasets for *paths*."""
    async with TrueNASClient(
        host=host,
        port=port,
        api_key=api_key,
        verify_ssl=verify_ssl,
    ) as client:
        for path in paths:
            await create_dataset(client, path)