Files
scott c28ce9e3b8 Add destination audit wizard with selective deletion
New top-level wizard option (2) lets users inspect and clean up an
existing destination before migration. Queries all SMB shares, NFS
exports, iSCSI objects, datasets, and zvols; displays a structured
inventory report; then offers per-category deletion with escalating
warnings — standard confirm for shares/iSCSI, explicit "DELETE" phrase
required for zvols and datasets to guard against accidental data loss.

Adds to client.py: query_destination_inventory, delete_smb_shares,
delete_nfs_exports, delete_zvols, delete_datasets.
Adds to cli.py: _fmt_bytes, _print_inventory_report, _run_audit_wizard.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-05 15:56:10 -05:00

489 lines
17 KiB
Python

"""TrueNAS WebSocket client and dataset utilities."""
from __future__ import annotations
import asyncio
import base64
import contextlib
import hashlib
import json
import os
import ssl
import struct
from typing import Any, Optional
from .colors import log
# ─────────────────────────────────────────────────────────────────────────────
# Raw WebSocket implementation (stdlib only, RFC 6455)
# ─────────────────────────────────────────────────────────────────────────────
def _ws_mask(data: bytes, mask: bytes) -> bytes:
"""XOR *data* with a 4-byte repeating mask key."""
out = bytearray(data)
for i in range(len(out)):
out[i] ^= mask[i & 3]
return bytes(out)
def _ws_encode_frame(payload: bytes, opcode: int = 0x1) -> bytes:
"""Encode a masked client→server WebSocket frame."""
mask = os.urandom(4)
length = len(payload)
header = bytearray([0x80 | opcode])
if length < 126:
header.append(0x80 | length)
elif length < 65536:
header.append(0x80 | 126)
header += struct.pack("!H", length)
else:
header.append(0x80 | 127)
header += struct.pack("!Q", length)
return bytes(header) + mask + _ws_mask(payload, mask)
async def _ws_recv_message(reader: asyncio.StreamReader) -> str:
"""
Read one complete WebSocket message, reassembling continuation frames.
Skips ping/pong control frames. Raises OSError on close frame.
"""
fragments: list[bytes] = []
while True:
hdr = await reader.readexactly(2)
fin = bool(hdr[0] & 0x80)
opcode = hdr[0] & 0x0F
masked = bool(hdr[1] & 0x80)
length = hdr[1] & 0x7F
if length == 126:
length = struct.unpack("!H", await reader.readexactly(2))[0]
elif length == 127:
length = struct.unpack("!Q", await reader.readexactly(8))[0]
mask_key = await reader.readexactly(4) if masked else None
payload = await reader.readexactly(length) if length else b""
if mask_key:
payload = _ws_mask(payload, mask_key)
if opcode == 0x8:
raise OSError("WebSocket: server sent close frame")
if opcode in (0x9, 0xA):
continue
fragments.append(payload)
if fin:
return b"".join(fragments).decode("utf-8")
class _WebSocket:
    """Thin send/recv/close facade over an asyncio stream pair."""

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        self._reader = reader
        self._writer = writer

    async def send(self, data: str) -> None:
        """Send *data* as a single masked text frame."""
        frame = _ws_encode_frame(data.encode("utf-8"), opcode=0x1)
        self._writer.write(frame)
        await self._writer.drain()

    async def recv(self) -> str:
        """Receive one complete text message (see _ws_recv_message)."""
        return await _ws_recv_message(self._reader)

    async def close(self) -> None:
        """Best-effort close: send a close frame, then tear down the transport.

        All errors are suppressed — the peer may already be gone.
        """
        with contextlib.suppress(Exception):
            self._writer.write(_ws_encode_frame(b"", opcode=0x8))
            await self._writer.drain()
            self._writer.close()
        with contextlib.suppress(Exception):
            await self._writer.wait_closed()
async def _ws_connect(host: str, port: int, path: str, ssl_ctx: ssl.SSLContext) -> _WebSocket:
    """Open a TLS connection, perform the HTTP→WebSocket upgrade, return a _WebSocket.

    Raises OSError when the connection drops mid-handshake, the server
    does not answer 101, or the Sec-WebSocket-Accept value is wrong.
    """
    reader, writer = await asyncio.open_connection(host, port, ssl=ssl_ctx)
    key = base64.b64encode(os.urandom(16)).decode()
    writer.write((
        f"GET {path} HTTP/1.1\r\n"
        f"Host: {host}:{port}\r\n"
        f"Upgrade: websocket\r\n"
        f"Connection: Upgrade\r\n"
        f"Sec-WebSocket-Key: {key}\r\n"
        f"Sec-WebSocket-Version: 13\r\n"
        f"\r\n"
    ).encode())
    await writer.drain()
    response_lines: list[bytes] = []
    while True:
        line = await asyncio.wait_for(reader.readline(), timeout=20)
        if not line:
            raise OSError("Connection closed during WebSocket handshake")
        response_lines.append(line)
        if line in (b"\r\n", b"\n"):
            break
    status = response_lines[0].decode("latin-1").strip()
    if " 101 " not in status:
        raise OSError(f"WebSocket upgrade failed: {status}")
    # Expected accept token per RFC 6455 §4.2.2: base64(SHA-1(key + GUID)).
    expected = base64.b64encode(
        hashlib.sha1(
            (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode()
        ).digest()
    ).decode()
    # Parse the response headers. Header NAMES are case-insensitive, but
    # the base64 accept VALUE is case-sensitive, so compare it exactly —
    # the previous lowercased substring search over the whole header blob
    # would accept a wrong-case (invalid) token or a coincidental match.
    headers: dict[str, str] = {}
    for raw in response_lines[1:]:
        text = raw.decode("latin-1").strip()
        if ":" in text:
            hname, _, hval = text.partition(":")
            headers[hname.strip().lower()] = hval.strip()
    if headers.get("sec-websocket-accept") != expected:
        raise OSError("WebSocket upgrade: Sec-WebSocket-Accept mismatch")
    return _WebSocket(reader, writer)
# ─────────────────────────────────────────────────────────────────────────────
# TrueNAS JSON-RPC 2.0 client
# ─────────────────────────────────────────────────────────────────────────────
class TrueNASClient:
    """
    Minimal async JSON-RPC 2.0 client for the TrueNAS WebSocket API.

    TrueNAS 25.04+ endpoint: wss://<host>:<port>/api/current
    Authentication: auth.login_with_api_key

    Use as an async context manager: the connection is opened and
    authenticated on enter, and closed on exit.
    """

    def __init__(
        self,
        host: str,
        api_key: str,
        port: int = 443,
        verify_ssl: bool = False,
    ) -> None:
        self._host = host
        self._port = port
        self._api_key = api_key
        self._verify_ssl = verify_ssl
        self._ws = None        # set by _connect(), cleared on exit
        self._call_id = 0      # monotonically increasing JSON-RPC request id

    @property
    def _url(self) -> str:
        """WebSocket endpoint URL for the configured host/port."""
        return f"wss://{self._host}:{self._port}/api/current"

    async def __aenter__(self) -> "TrueNASClient":
        await self._connect()
        return self

    async def __aexit__(self, *_: Any) -> None:
        if self._ws:
            await self._ws.close()
        self._ws = None

    async def _connect(self) -> None:
        """Open the WebSocket and authenticate with the API key.

        Raises OSError/asyncio.TimeoutError on connection failure and
        PermissionError when the server rejects the key.
        """
        ctx = ssl.create_default_context()
        if not self._verify_ssl:
            # Allow self-signed certificates when verification is disabled.
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE
        log.info("Connecting to %s", self._url)
        try:
            self._ws = await _ws_connect(
                host=self._host,
                port=self._port,
                path="/api/current",
                ssl_ctx=ctx,
            )
        except (OSError, asyncio.TimeoutError) as exc:
            log.error("Connection failed: %s", exc)
            raise
        log.info("Authenticating with API key …")
        result = await self.call("auth.login_with_api_key", [self._api_key])
        # Accept either True or the "SUCCESS" string as a successful login.
        if result is not True and result != "SUCCESS":
            raise PermissionError(f"Authentication rejected: {result!r}")
        log.info("Connected and authenticated.")

    async def call(self, method: str, params: Optional[list] = None) -> Any:
        """Send one JSON-RPC request and return its result.

        Messages without an id (notifications) and responses for other
        request ids are skipped. Raises RuntimeError if the API returns
        an error, asyncio.TimeoutError if no response arrives within 60 s.
        """
        self._call_id += 1
        req_id = self._call_id
        await self._ws.send(json.dumps({
            "jsonrpc": "2.0",
            "id": req_id,
            "method": method,
            "params": params or [],
        }))
        while True:
            raw = await asyncio.wait_for(self._ws.recv(), timeout=60)
            msg = json.loads(raw)
            if "id" not in msg:
                continue
            if msg["id"] != req_id:
                continue
            if "error" in msg:
                err = msg["error"]
                # "error.data" may be present but null — guard before .get(),
                # otherwise formatting the error would itself crash with
                # AttributeError on NoneType.
                data = err.get("data") or {}
                reason = (
                    data.get("reason")
                    or err.get("message")
                    or repr(err)
                )
                raise RuntimeError(f"API error [{method}]: {reason}")
            return msg.get("result")
# ─────────────────────────────────────────────────────────────────────────────
# Dataset utilities
# ─────────────────────────────────────────────────────────────────────────────
async def check_dataset_paths(
    client: TrueNASClient,
    paths: list[str],
) -> list[str]:
    """
    Return the subset of *paths* that have no matching ZFS dataset on the
    destination. Returns an empty list when the dataset query itself fails.
    """
    if not paths:
        return []
    wanted = sorted({p.rstrip("/") for p in paths if p})
    log.info("Checking %d share path(s) against destination datasets …", len(wanted))
    try:
        found = await client.call("pool.dataset.query") or []
    except RuntimeError as exc:
        log.warning("Could not query datasets (skipping check): %s", exc)
        return []
    # Normalize mountpoints the same way as the requested paths.
    known: set[str] = set()
    for entry in found:
        mp = entry.get("mountpoint")
        if mp:
            known.add(mp.rstrip("/"))
    absent = [p for p in wanted if p not in known]
    if not absent:
        log.info(" All share paths exist as datasets.")
        return absent
    for p in absent:
        log.warning(" MISSING dataset for path: %s", p)
    return absent
async def create_dataset(client: TrueNASClient, path: str) -> bool:
    """
    Create a ZFS dataset whose mountpoint will be *path*.

    *path* must be an absolute /mnt/… path; the dataset name is the path
    with the /mnt/ prefix stripped.
    Returns True on success, False on failure.
    """
    if not path.startswith("/mnt/"):
        log.error("Cannot auto-create dataset for non-/mnt/ path: %s", path)
        return False
    ds_name = path[5:].rstrip("/")
    log.info("Creating dataset %r", ds_name)
    try:
        await client.call("pool.dataset.create", [{"name": ds_name}])
    except RuntimeError as exc:
        log.error(" Failed to create dataset %r: %s", ds_name, exc)
        return False
    log.info(" Created: %s", ds_name)
    return True
async def create_missing_datasets(
    host: str,
    port: int,
    api_key: str,
    paths: list[str],
    verify_ssl: bool = False,
) -> None:
    """Open a fresh connection and create ZFS datasets for *paths*."""
    client = TrueNASClient(
        host=host, port=port, api_key=api_key, verify_ssl=verify_ssl,
    )
    async with client:
        for share_path in paths:
            await create_dataset(client, share_path)
# ─────────────────────────────────────────────────────────────────────────────
# iSCSI zvol utilities
# ─────────────────────────────────────────────────────────────────────────────
async def check_iscsi_zvols(
    client: TrueNASClient,
    zvol_names: list[str],
) -> list[str]:
    """
    Return the subset of *zvol_names* that do not exist on the destination.

    Names are the dataset path without the leading 'zvol/' prefix
    (e.g. 'tank/VMWARE001'). Returns [] when the query itself fails.
    """
    if not zvol_names:
        return []
    wanted = sorted(set(zvol_names))
    log.info("Checking %d zvol(s) against destination datasets …", len(wanted))
    try:
        volumes = await client.call(
            "pool.dataset.query", [[["type", "=", "VOLUME"]]]
        ) or []
    except RuntimeError as exc:
        log.warning("Could not query zvols (skipping check): %s", exc)
        return []
    present = {v["name"] for v in volumes}
    absent = [name for name in wanted if name not in present]
    if not absent:
        log.info(" All iSCSI zvols exist on destination.")
        return absent
    for name in absent:
        log.warning(" MISSING zvol: %s", name)
    return absent
async def create_zvol(
    client: TrueNASClient,
    name: str,
    volsize: int,
) -> bool:
    """
    Create a ZFS volume (zvol) on the destination.

    *name* is the dataset path (e.g. 'tank/VMWARE001').
    *volsize* is the size in bytes.
    Returns True on success, False on failure.
    """
    log.info("Creating zvol %r (%d bytes) …", name, volsize)
    payload = {"name": name, "type": "VOLUME", "volsize": volsize}
    try:
        await client.call("pool.dataset.create", [payload])
    except RuntimeError as exc:
        log.error(" Failed to create zvol %r: %s", name, exc)
        return False
    log.info(" Created: %s", name)
    return True
async def create_missing_zvols(
    host: str,
    port: int,
    api_key: str,
    zvols: dict[str, int],
    verify_ssl: bool = False,
) -> None:
    """Open a fresh connection and create zvols from {name: volsize_bytes}."""
    client = TrueNASClient(
        host=host, port=port, api_key=api_key, verify_ssl=verify_ssl,
    )
    async with client:
        for zvol_name, size_bytes in zvols.items():
            await create_zvol(client, zvol_name, size_bytes)
# ─────────────────────────────────────────────────────────────────────────────
# Destination inventory
# ─────────────────────────────────────────────────────────────────────────────
async def query_destination_inventory(client: TrueNASClient) -> dict[str, list]:
    """
    Query all current configuration from the destination system.

    Returns a dict with keys: smb_shares, nfs_exports, datasets, zvols,
    iscsi_extents, iscsi_initiators, iscsi_portals, iscsi_targets,
    iscsi_targetextents. Each value is a list (may be empty if the query
    fails or returns nothing).
    """
    # (inventory key, API method, optional query-filters param)
    plan = (
        ("smb_shares", "sharing.smb.query", None),
        ("nfs_exports", "sharing.nfs.query", None),
        ("datasets", "pool.dataset.query", [[["type", "=", "FILESYSTEM"]]]),
        ("zvols", "pool.dataset.query", [[["type", "=", "VOLUME"]]]),
        ("iscsi_extents", "iscsi.extent.query", None),
        ("iscsi_initiators", "iscsi.initiator.query", None),
        ("iscsi_portals", "iscsi.portal.query", None),
        ("iscsi_targets", "iscsi.target.query", None),
        ("iscsi_targetextents", "iscsi.targetextent.query", None),
    )
    inventory: dict[str, list] = {}
    for key, method, params in plan:
        try:
            rows = await client.call(method, params)
        except RuntimeError as exc:
            log.warning("Could not query %s: %s", key, exc)
            rows = None
        inventory[key] = rows or []
    return inventory
async def delete_smb_shares(
    client: TrueNASClient, shares: list[dict]
) -> tuple[int, int]:
    """Delete SMB shares by ID. Returns (deleted, failed)."""
    ok = bad = 0
    for entry in shares:
        try:
            await client.call("sharing.smb.delete", [entry["id"]])
        except RuntimeError as exc:
            log.error(" Failed to delete SMB share %r: %s", entry.get("name"), exc)
            bad += 1
        else:
            log.info(" Deleted SMB share %r", entry.get("name"))
            ok += 1
    return ok, bad
async def delete_nfs_exports(
    client: TrueNASClient, exports: list[dict]
) -> tuple[int, int]:
    """Delete NFS exports by ID. Returns (deleted, failed)."""
    ok = bad = 0
    for entry in exports:
        try:
            await client.call("sharing.nfs.delete", [entry["id"]])
        except RuntimeError as exc:
            log.error(" Failed to delete NFS export %r: %s", entry.get("path"), exc)
            bad += 1
        else:
            log.info(" Deleted NFS export %r", entry.get("path"))
            ok += 1
    return ok, bad
async def delete_zvols(
    client: TrueNASClient, zvols: list[dict]
) -> tuple[int, int]:
    """Delete zvols. Returns (deleted, failed)."""
    ok = bad = 0
    for entry in zvols:
        ds_id = entry["id"]
        try:
            # recursive=True also removes snapshots/children under the zvol.
            await client.call("pool.dataset.delete", [ds_id, {"recursive": True}])
        except RuntimeError as exc:
            log.error(" Failed to delete zvol %r: %s", ds_id, exc)
            bad += 1
        else:
            log.info(" Deleted zvol %r", ds_id)
            ok += 1
    return ok, bad
async def delete_datasets(
    client: TrueNASClient, datasets: list[dict]
) -> tuple[int, int]:
    """
    Delete datasets deepest-first to avoid parent-before-child errors.

    Skips pool root datasets (no '/' in the dataset name).
    Returns (deleted, failed).
    """
    # Depth = number of '/' separators; delete children before parents.
    candidates = [d for d in datasets if "/" in d["id"]]
    candidates.sort(key=lambda d: d["id"].count("/"), reverse=True)
    ok = bad = 0
    for entry in candidates:
        ds_id = entry["id"]
        try:
            await client.call("pool.dataset.delete", [ds_id, {"recursive": True}])
        except RuntimeError as exc:
            log.error(" Failed to delete dataset %r: %s", ds_id, exc)
            bad += 1
        else:
            log.info(" Deleted dataset %r", ds_id)
            ok += 1
    return ok, bad