New top-level wizard option (2) lets users inspect and clean up an existing destination before migration. Queries all SMB shares, NFS exports, iSCSI objects, datasets, and zvols; displays a structured inventory report; then offers per-category deletion with escalating warnings — standard confirm for shares/iSCSI, explicit "DELETE" phrase required for zvols and datasets to guard against accidental data loss. Adds to client.py: query_destination_inventory, delete_smb_shares, delete_nfs_exports, delete_zvols, delete_datasets. Adds to cli.py: _fmt_bytes, _print_inventory_report, _run_audit_wizard. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
489 lines
17 KiB
Python
489 lines
17 KiB
Python
"""TrueNAS WebSocket client and dataset utilities."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import contextlib
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import ssl
|
|
import struct
|
|
from typing import Any, Optional
|
|
|
|
from .colors import log
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# Raw WebSocket implementation (stdlib only, RFC 6455)
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
def _ws_mask(data: bytes, mask: bytes) -> bytes:
    """
    Apply the WebSocket masking transform to *data*.

    Byte ``i`` of the result is ``data[i] ^ mask[i % 4]``. Because XOR is its
    own inverse, calling this twice with the same *mask* restores *data*.
    """
    return bytes(byte ^ mask[idx & 3] for idx, byte in enumerate(data))
|
|
|
|
|
|
def _ws_encode_frame(payload: bytes, opcode: int = 0x1) -> bytes:
    """
    Build one complete, masked client-to-server WebSocket frame.

    The FIN bit is always set, so *payload* is never fragmented, and the mask
    bit is always set with a fresh random 4-byte key. The payload length uses
    the 7-bit, 16-bit (126) or 64-bit (127) encoding as required.
    """
    key = os.urandom(4)
    size = len(payload)

    if size >= 65536:
        length_field = bytes([0x80 | 127]) + struct.pack("!Q", size)
    elif size >= 126:
        length_field = bytes([0x80 | 126]) + struct.pack("!H", size)
    else:
        length_field = bytes([0x80 | size])

    first_byte = bytes([0x80 | opcode])
    return b"".join((first_byte, length_field, key, _ws_mask(payload, key)))
|
|
|
|
|
|
async def _ws_recv_message(reader: asyncio.StreamReader) -> str:
    """
    Receive one complete WebSocket data message from *reader* as text.

    Frames are accumulated until one with the FIN bit set arrives; the
    concatenated payload is then decoded as UTF-8. Ping (0x9) and pong (0xA)
    control frames are consumed and ignored (no pong is sent). A close frame
    (0x8) raises OSError. Masked frames are unmasked before use.
    """
    parts: list[bytes] = []
    while True:
        first, second = await reader.readexactly(2)
        is_final = (first & 0x80) != 0
        op = first & 0x0F
        size = second & 0x7F

        # 126 / 127 announce a 16-bit / 64-bit extended payload length.
        if size == 126:
            (size,) = struct.unpack("!H", await reader.readexactly(2))
        elif size == 127:
            (size,) = struct.unpack("!Q", await reader.readexactly(8))

        key = await reader.readexactly(4) if second & 0x80 else None
        body = b"" if size == 0 else await reader.readexactly(size)
        if key:
            body = _ws_mask(body, key)

        if op == 0x8:
            raise OSError("WebSocket: server sent close frame")
        if op in (0x9, 0xA):
            continue

        parts.append(body)
        if is_final:
            return b"".join(parts).decode("utf-8")
|
|
|
|
|
|
class _WebSocket:
    """
    Thin send/recv/close facade over an asyncio StreamReader/StreamWriter pair.

    Instances are produced by _ws_connect once the opening handshake succeeds.
    Outgoing messages are sent as single masked text frames; incoming messages
    are reassembled by _ws_recv_message.
    """

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        self._reader = reader
        self._writer = writer

    async def send(self, data: str) -> None:
        """Encode *data* as UTF-8, send it as one text frame and flush."""
        frame = _ws_encode_frame(data.encode("utf-8"), opcode=0x1)
        self._writer.write(frame)
        await self._writer.drain()

    async def recv(self) -> str:
        """Wait for and return the next complete text message."""
        message = await _ws_recv_message(self._reader)
        return message

    async def close(self) -> None:
        """
        Best-effort shutdown: try to send a close frame, then close the stream.

        Any exception raised while sending the close frame, or while waiting for
        the transport to finish closing, is suppressed.
        """
        writer = self._writer
        with contextlib.suppress(Exception):
            writer.write(_ws_encode_frame(b"", opcode=0x8))
            await writer.drain()
        writer.close()
        with contextlib.suppress(Exception):
            await writer.wait_closed()
|
|
|
|
|
|
async def _ws_connect(
    host: str,
    port: int,
    path: str,
    ssl_ctx: ssl.SSLContext,
    timeout: float = 20,
) -> _WebSocket:
    """
    Open a TLS connection, perform the HTTP→WebSocket upgrade, return a _WebSocket.

    *timeout* (seconds) bounds the TCP/TLS connect and each handshake line read.

    Raises:
        OSError: the server closed the connection mid-handshake, answered with
            a status other than 101, or returned a wrong Sec-WebSocket-Accept.
        asyncio.TimeoutError: connecting or reading the handshake timed out.

    If the handshake fails after the connection is open, the stream is closed
    before the exception propagates, so the socket is not leaked.
    """
    reader, writer = await asyncio.wait_for(
        asyncio.open_connection(host, port, ssl=ssl_ctx), timeout=timeout,
    )
    try:
        key = base64.b64encode(os.urandom(16)).decode()
        # IPv6 literals must be bracketed in the Host header (RFC 7230 §5.4).
        host_hdr = f"[{host}]" if ":" in host and not host.startswith("[") else host
        writer.write((
            f"GET {path} HTTP/1.1\r\n"
            f"Host: {host_hdr}:{port}\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            f"Sec-WebSocket-Key: {key}\r\n"
            "Sec-WebSocket-Version: 13\r\n"
            "\r\n"
        ).encode())
        await writer.drain()

        # Read the status line and headers up to the blank separator line.
        response_lines: list[bytes] = []
        while True:
            line = await asyncio.wait_for(reader.readline(), timeout=timeout)
            if not line:
                raise OSError("Connection closed during WebSocket handshake")
            response_lines.append(line)
            if line in (b"\r\n", b"\n"):
                break

        status = response_lines[0].decode("latin-1").strip()
        status_parts = status.split()
        # Compare the status-code token exactly; the reason phrase is optional.
        if len(status_parts) < 2 or status_parts[1] != "101":
            raise OSError(f"WebSocket upgrade failed: {status}")

        # RFC 6455 §4.2.2: accept = base64(SHA-1(key + fixed GUID)).
        expected = base64.b64encode(
            hashlib.sha1(
                (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode()
            ).digest()
        ).decode()

        # Header names are case-insensitive; the accept value is compared exactly.
        headers: dict[str, str] = {}
        for raw in response_lines[1:]:
            name, sep, value = raw.decode("latin-1").partition(":")
            if sep:
                headers[name.strip().lower()] = value.strip()
        if headers.get("sec-websocket-accept") != expected:
            raise OSError("WebSocket upgrade: Sec-WebSocket-Accept mismatch")
    except BaseException:
        # Release the socket on every failure path, including timeouts/cancellation.
        writer.close()
        raise

    return _WebSocket(reader, writer)
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# TrueNAS JSON-RPC 2.0 client
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
class TrueNASClient:
    """
    Minimal async JSON-RPC 2.0 client for the TrueNAS WebSocket API.

    TrueNAS 25.04+ endpoint: wss://<host>:<port>/api/current
    Authentication: auth.login_with_api_key

    Use as an async context manager: entering connects and authenticates,
    exiting closes the WebSocket. If authentication fails on entry, the
    connection is closed before the error propagates.
    """

    def __init__(
        self,
        host: str,
        api_key: str,
        port: int = 443,
        verify_ssl: bool = False,
    ) -> None:
        self._host = host
        self._port = port
        self._api_key = api_key
        self._verify_ssl = verify_ssl
        # Holds the _WebSocket returned by _ws_connect; None while disconnected.
        self._ws = None
        # Incrementing JSON-RPC request id; replies are matched against it.
        self._call_id = 0

    @property
    def _url(self) -> str:
        return f"wss://{self._host}:{self._port}/api/current"

    async def __aenter__(self) -> "TrueNASClient":
        await self._connect()
        return self

    async def __aexit__(self, *_: Any) -> None:
        await self._disconnect()

    async def _disconnect(self) -> None:
        """Close the WebSocket if one is open. Safe to call more than once."""
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    async def _connect(self) -> None:
        """Open the WebSocket and authenticate with the API key.

        Raises:
            OSError / asyncio.TimeoutError: the connection or handshake failed.
            PermissionError: the server rejected the API key.
            RuntimeError: the login call itself returned an API error.
        """
        ctx = ssl.create_default_context()
        if not self._verify_ssl:
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE

        log.info("Connecting to %s …", self._url)
        try:
            self._ws = await _ws_connect(
                host=self._host,
                port=self._port,
                path="/api/current",
                ssl_ctx=ctx,
            )
        except (OSError, asyncio.TimeoutError) as exc:
            log.error("Connection failed: %s", exc)
            raise

        log.info("Authenticating with API key …")
        try:
            result = await self.call("auth.login_with_api_key", [self._api_key])
            if result is not True and result != "SUCCESS":
                raise PermissionError(f"Authentication rejected: {result!r}")
        except BaseException:
            # __aexit__ does not run when __aenter__ raises, so release the
            # socket here instead of leaking it.
            await self._disconnect()
            raise
        log.info("Connected and authenticated.")

    @staticmethod
    def _error_reason(err: Any) -> str:
        """Return the most specific message in a JSON-RPC error object.

        Prefers ``err["data"]["reason"]``, then ``err["message"]``, falling back
        to ``repr(err)``. Tolerates a null or non-dict ``data`` member and a
        non-dict *err*.
        """
        if isinstance(err, dict):
            data = err.get("data")
            if isinstance(data, dict) and data.get("reason"):
                return str(data["reason"])
            if err.get("message"):
                return str(err["message"])
        return repr(err)

    async def call(self, method: str, params: Optional[list] = None) -> Any:
        """Send one JSON-RPC request and return its result.

        Messages without a matching ``id`` (e.g. notifications) are skipped.
        Each receive waits at most 60 s.

        Raises:
            ConnectionError: the client is not connected.
            RuntimeError: the API returned an error.
            asyncio.TimeoutError: no message arrived within 60 s.
        """
        if self._ws is None:
            raise ConnectionError(f"TrueNASClient is not connected (calling {method})")

        self._call_id += 1
        req_id = self._call_id

        await self._ws.send(json.dumps({
            "jsonrpc": "2.0",
            "id": req_id,
            "method": method,
            "params": params or [],
        }))

        while True:
            raw = await asyncio.wait_for(self._ws.recv(), timeout=60)
            msg = json.loads(raw)
            if not isinstance(msg, dict) or msg.get("id") != req_id:
                continue
            if msg.get("error") is not None:
                reason = self._error_reason(msg["error"])
                raise RuntimeError(f"API error [{method}]: {reason}")
            return msg.get("result")
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# Dataset utilities
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
async def check_dataset_paths(
    client: TrueNASClient,
    paths: list[str],
) -> list[str]:
    """
    Find share paths that have no ZFS dataset mounted at them.

    Empty entries are ignored, trailing slashes are stripped and duplicates
    collapsed; the result is sorted. Each path is compared against the
    ``mountpoint`` of every dataset returned by ``pool.dataset.query``.
    If that query raises RuntimeError, the check is skipped and [] is returned.
    """
    if not paths:
        return []

    wanted = sorted({p.rstrip("/") for p in paths if p})
    log.info("Checking %d share path(s) against destination datasets …", len(wanted))
    try:
        datasets = await client.call("pool.dataset.query") or []
    except RuntimeError as exc:
        log.warning("Could not query datasets (skipping check): %s", exc)
        return []

    known = set()
    for ds in datasets:
        mount = ds.get("mountpoint")
        if mount:
            known.add(mount.rstrip("/"))

    absent = [p for p in wanted if p not in known]
    if not absent:
        log.info(" All share paths exist as datasets.")
    for p in absent:
        log.warning(" MISSING dataset for path: %s", p)
    return absent
|
|
|
|
|
|
async def create_dataset(client: TrueNASClient, path: str) -> bool:
    """
    Create the ZFS dataset that will be mounted at *path*.

    *path* has to start with '/mnt/'; the dataset name is the remainder with
    trailing slashes removed (e.g. '/mnt/tank/share/' → 'tank/share').
    Returns True if ``pool.dataset.create`` succeeded. Returns False for a
    non-/mnt/ path or when the API raises RuntimeError.
    """
    prefix = "/mnt/"
    if not path.startswith(prefix):
        log.error("Cannot auto-create dataset for non-/mnt/ path: %s", path)
        return False

    name = path[len(prefix):].rstrip("/")
    log.info("Creating dataset %r …", name)
    try:
        await client.call("pool.dataset.create", [{"name": name}])
    except RuntimeError as exc:
        log.error(" Failed to create dataset %r: %s", name, exc)
        return False
    else:
        log.info(" Created: %s", name)
        return True
|
|
|
|
|
|
async def create_missing_datasets(
    host: str,
    port: int,
    api_key: str,
    paths: list[str],
    verify_ssl: bool = False,
) -> None:
    """
    Create one ZFS dataset per entry in *paths* over a dedicated connection.

    A new TrueNASClient session is connected and authenticated for this call.
    Paths are handled one at a time, in order. The success flag returned by
    create_dataset is not inspected.
    """
    session = TrueNASClient(
        host=host, port=port, api_key=api_key, verify_ssl=verify_ssl,
    )
    async with session:
        for target in paths:
            await create_dataset(session, target)
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# iSCSI zvol utilities
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
async def check_iscsi_zvols(
    client: TrueNASClient,
    zvol_names: list[str],
) -> list[str]:
    """
    Report which of *zvol_names* do not exist on the destination.

    Names are dataset paths without the leading 'zvol/' prefix
    (e.g. 'tank/VMWARE001'). Duplicates are collapsed and the result is
    sorted. The names are compared against ``pool.dataset.query`` filtered to
    type VOLUME. If that query raises RuntimeError, [] is returned.
    """
    if not zvol_names:
        return []

    wanted = sorted(set(zvol_names))
    log.info("Checking %d zvol(s) against destination datasets …", len(wanted))
    volume_filter = [[["type", "=", "VOLUME"]]]
    try:
        volumes = await client.call("pool.dataset.query", volume_filter) or []
    except RuntimeError as exc:
        log.warning("Could not query zvols (skipping check): %s", exc)
        return []

    present = set()
    for vol in volumes:
        present.add(vol["name"])

    absent = [name for name in wanted if name not in present]
    if not absent:
        log.info(" All iSCSI zvols exist on destination.")
    for name in absent:
        log.warning(" MISSING zvol: %s", name)
    return absent
|
|
|
|
|
|
async def create_zvol(
    client: TrueNASClient,
    name: str,
    volsize: int,
) -> bool:
    """
    Create the ZFS volume *name* with a size of *volsize* bytes.

    *name* is the full dataset path (e.g. 'tank/VMWARE001'). Issues
    ``pool.dataset.create`` with type VOLUME. Returns True on success;
    logs the error and returns False when the API raises RuntimeError.
    """
    log.info("Creating zvol %r (%d bytes) …", name, volsize)
    spec = {
        "name": name,
        "type": "VOLUME",
        "volsize": volsize,
    }
    try:
        await client.call("pool.dataset.create", [spec])
    except RuntimeError as exc:
        log.error(" Failed to create zvol %r: %s", name, exc)
        return False
    log.info(" Created: %s", name)
    return True
|
|
|
|
|
|
async def create_missing_zvols(
    host: str,
    port: int,
    api_key: str,
    zvols: dict[str, int],
    verify_ssl: bool = False,
) -> None:
    """
    Create every zvol in *zvols* (dataset name → volsize in bytes) over a
    dedicated, freshly authenticated connection.

    Volumes are created one at a time in mapping order. The success flag
    returned by create_zvol is not inspected.
    """
    session = TrueNASClient(
        host=host, port=port, api_key=api_key, verify_ssl=verify_ssl,
    )
    async with session:
        for zvol_name, size_bytes in zvols.items():
            await create_zvol(session, zvol_name, size_bytes)
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
# Destination inventory
|
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
async def query_destination_inventory(client: TrueNASClient) -> dict[str, list]:
    """
    Snapshot the destination's shares, datasets, zvols and iSCSI objects.

    Returns a dict whose keys, in insertion order, are: smb_shares,
    nfs_exports, datasets, zvols, iscsi_extents, iscsi_initiators,
    iscsi_portals, iscsi_targets, iscsi_targetextents. Every value is a list.
    A query that raises RuntimeError, or that returns a falsy result, is
    recorded as [] (a RuntimeError is also logged as a warning).
    """
    queries = (
        ("smb_shares", "sharing.smb.query", None),
        ("nfs_exports", "sharing.nfs.query", None),
        ("datasets", "pool.dataset.query", [[["type", "=", "FILESYSTEM"]]]),
        ("zvols", "pool.dataset.query", [[["type", "=", "VOLUME"]]]),
        ("iscsi_extents", "iscsi.extent.query", None),
        ("iscsi_initiators", "iscsi.initiator.query", None),
        ("iscsi_portals", "iscsi.portal.query", None),
        ("iscsi_targets", "iscsi.target.query", None),
        ("iscsi_targetextents", "iscsi.targetextent.query", None),
    )

    inventory: dict[str, list] = {}
    for key, method, params in queries:
        try:
            rows = await client.call(method, params)
        except RuntimeError as exc:
            log.warning("Could not query %s: %s", key, exc)
            rows = None
        inventory[key] = rows or []
    return inventory
|
|
|
|
|
|
async def delete_smb_shares(
    client: TrueNASClient, shares: list[dict]
) -> tuple[int, int]:
    """
    Delete each SMB share in *shares* via ``sharing.smb.delete`` using its "id".

    A RuntimeError from the API is logged and counted as a failure; the
    remaining shares are still processed. Returns (deleted, failed).
    """
    ok = bad = 0
    for entry in shares:
        label = entry.get("name")
        try:
            await client.call("sharing.smb.delete", [entry["id"]])
        except RuntimeError as exc:
            log.error(" Failed to delete SMB share %r: %s", label, exc)
            bad += 1
        else:
            log.info(" Deleted SMB share %r", label)
            ok += 1
    return ok, bad
|
|
|
|
|
|
async def delete_nfs_exports(
    client: TrueNASClient, exports: list[dict]
) -> tuple[int, int]:
    """
    Delete each NFS export in *exports* via ``sharing.nfs.delete`` using its "id".

    A RuntimeError from the API is logged (with the export's "path") and
    counted as a failure; processing continues. Returns (deleted, failed).
    """
    ok = bad = 0
    for entry in exports:
        label = entry.get("path")
        try:
            await client.call("sharing.nfs.delete", [entry["id"]])
        except RuntimeError as exc:
            log.error(" Failed to delete NFS export %r: %s", label, exc)
            bad += 1
        else:
            log.info(" Deleted NFS export %r", label)
            ok += 1
    return ok, bad
|
|
|
|
|
|
async def delete_zvols(
    client: TrueNASClient, zvols: list[dict]
) -> tuple[int, int]:
    """
    Delete each zvol in *zvols* via ``pool.dataset.delete`` (recursive=True).

    Each entry's "id" is the dataset path passed to the API. A RuntimeError is
    logged and counted as a failure; processing continues.
    Returns (deleted, failed).
    """
    ok = bad = 0
    for vol in zvols:
        vol_id = vol["id"]
        try:
            await client.call("pool.dataset.delete", [vol_id, {"recursive": True}])
        except RuntimeError as exc:
            log.error(" Failed to delete zvol %r: %s", vol_id, exc)
            bad += 1
        else:
            log.info(" Deleted zvol %r", vol_id)
            ok += 1
    return ok, bad
|
|
|
|
|
|
async def delete_datasets(
    client: TrueNASClient, datasets: list[dict]
) -> tuple[int, int]:
    """
    Delete filesystem datasets, children before parents.

    Entries whose "id" contains no '/' (pool root datasets) are skipped. The
    rest are ordered by path depth, deepest first (equal depths keep their
    input order). Each is removed via ``pool.dataset.delete`` with
    ``{"recursive": True}``. Deleting deepest first avoids parent-before-child
    errors. A RuntimeError is logged and counted as a failure; processing
    continues.

    NOTE(review): with recursive=True, deleting a parent presumably also removes
    any descendants NOT listed in *datasets* (e.g. zvols the user chose to keep).
    Confirm against pool.dataset.delete semantics before passing a partial
    selection.

    Returns (deleted, failed).
    """
    eligible = [ds for ds in datasets if "/" in ds["id"]]
    eligible.sort(key=lambda ds: ds["id"].count("/"), reverse=True)

    ok = bad = 0
    for ds in eligible:
        ds_id = ds["id"]
        try:
            await client.call("pool.dataset.delete", [ds_id, {"recursive": True}])
        except RuntimeError as exc:
            log.error(" Failed to delete dataset %r: %s", ds_id, exc)
            bad += 1
        else:
            log.info(" Deleted dataset %r", ds_id)
            ok += 1
    return ok, bad
|