Restructure into package: truenas_migrate/
Split single-file script into focused modules: colors.py – ANSI helpers and shared logger summary.py – Summary dataclass and report renderer archive.py – Debug archive parser (SCALE + CORE layouts) client.py – WebSocket engine, TrueNASClient, dataset utilities migrate.py – Payload builders, migrate_smb_shares, migrate_nfs_shares cli.py – Interactive wizard, argparse, run(), main() __main__.py – python -m truenas_migrate entry point truenas_migrate.py retained as a one-line compatibility shim. Both 'python truenas_migrate.py' and 'python -m truenas_migrate' work. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
308
truenas_migrate/client.py
Normal file
308
truenas_migrate/client.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""TrueNAS WebSocket client and dataset utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
import struct
|
||||
from typing import Any, Optional
|
||||
|
||||
from .colors import log
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Raw WebSocket implementation (stdlib only, RFC 6455)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _ws_mask(data: bytes, mask: bytes) -> bytes:
|
||||
"""XOR *data* with a 4-byte repeating mask key."""
|
||||
out = bytearray(data)
|
||||
for i in range(len(out)):
|
||||
out[i] ^= mask[i & 3]
|
||||
return bytes(out)
|
||||
|
||||
|
||||
def _ws_encode_frame(payload: bytes, opcode: int = 0x1) -> bytes:
|
||||
"""Encode a masked client→server WebSocket frame."""
|
||||
mask = os.urandom(4)
|
||||
length = len(payload)
|
||||
header = bytearray([0x80 | opcode])
|
||||
if length < 126:
|
||||
header.append(0x80 | length)
|
||||
elif length < 65536:
|
||||
header.append(0x80 | 126)
|
||||
header += struct.pack("!H", length)
|
||||
else:
|
||||
header.append(0x80 | 127)
|
||||
header += struct.pack("!Q", length)
|
||||
return bytes(header) + mask + _ws_mask(payload, mask)
|
||||
|
||||
|
||||
async def _ws_recv_message(reader: asyncio.StreamReader) -> str:
|
||||
"""
|
||||
Read one complete WebSocket message, reassembling continuation frames.
|
||||
Skips ping/pong control frames. Raises OSError on close frame.
|
||||
"""
|
||||
fragments: list[bytes] = []
|
||||
while True:
|
||||
hdr = await reader.readexactly(2)
|
||||
fin = bool(hdr[0] & 0x80)
|
||||
opcode = hdr[0] & 0x0F
|
||||
masked = bool(hdr[1] & 0x80)
|
||||
length = hdr[1] & 0x7F
|
||||
|
||||
if length == 126:
|
||||
length = struct.unpack("!H", await reader.readexactly(2))[0]
|
||||
elif length == 127:
|
||||
length = struct.unpack("!Q", await reader.readexactly(8))[0]
|
||||
|
||||
mask_key = await reader.readexactly(4) if masked else None
|
||||
payload = await reader.readexactly(length) if length else b""
|
||||
if mask_key:
|
||||
payload = _ws_mask(payload, mask_key)
|
||||
|
||||
if opcode == 0x8:
|
||||
raise OSError("WebSocket: server sent close frame")
|
||||
if opcode in (0x9, 0xA):
|
||||
continue
|
||||
|
||||
fragments.append(payload)
|
||||
if fin:
|
||||
return b"".join(fragments).decode("utf-8")
|
||||
|
||||
|
||||
class _WebSocket:
|
||||
"""asyncio StreamReader/Writer wrapped to a simple send/recv/close API."""
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
|
||||
async def send(self, data: str) -> None:
|
||||
self._writer.write(_ws_encode_frame(data.encode("utf-8"), opcode=0x1))
|
||||
await self._writer.drain()
|
||||
|
||||
async def recv(self) -> str:
|
||||
return await _ws_recv_message(self._reader)
|
||||
|
||||
async def close(self) -> None:
|
||||
with contextlib.suppress(Exception):
|
||||
self._writer.write(_ws_encode_frame(b"", opcode=0x8))
|
||||
await self._writer.drain()
|
||||
self._writer.close()
|
||||
with contextlib.suppress(Exception):
|
||||
await self._writer.wait_closed()
|
||||
|
||||
|
||||
async def _ws_connect(host: str, port: int, path: str, ssl_ctx: ssl.SSLContext) -> _WebSocket:
|
||||
"""Open a TLS connection, perform the HTTP→WebSocket upgrade, return a _WebSocket."""
|
||||
reader, writer = await asyncio.open_connection(host, port, ssl=ssl_ctx)
|
||||
|
||||
key = base64.b64encode(os.urandom(16)).decode()
|
||||
writer.write((
|
||||
f"GET {path} HTTP/1.1\r\n"
|
||||
f"Host: {host}:{port}\r\n"
|
||||
f"Upgrade: websocket\r\n"
|
||||
f"Connection: Upgrade\r\n"
|
||||
f"Sec-WebSocket-Key: {key}\r\n"
|
||||
f"Sec-WebSocket-Version: 13\r\n"
|
||||
f"\r\n"
|
||||
).encode())
|
||||
await writer.drain()
|
||||
|
||||
response_lines: list[bytes] = []
|
||||
while True:
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=20)
|
||||
if not line:
|
||||
raise OSError("Connection closed during WebSocket handshake")
|
||||
response_lines.append(line)
|
||||
if line in (b"\r\n", b"\n"):
|
||||
break
|
||||
|
||||
status = response_lines[0].decode("latin-1").strip()
|
||||
if " 101 " not in status:
|
||||
raise OSError(f"WebSocket upgrade failed: {status}")
|
||||
|
||||
expected = base64.b64encode(
|
||||
hashlib.sha1(
|
||||
(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode()
|
||||
).digest()
|
||||
).decode().lower()
|
||||
headers_text = b"".join(response_lines).decode("latin-1").lower()
|
||||
if expected not in headers_text:
|
||||
raise OSError("WebSocket upgrade: Sec-WebSocket-Accept mismatch")
|
||||
|
||||
return _WebSocket(reader, writer)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# TrueNAS JSON-RPC 2.0 client
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TrueNASClient:
|
||||
"""
|
||||
Minimal async JSON-RPC 2.0 client for the TrueNAS WebSocket API.
|
||||
|
||||
TrueNAS 25.04+ endpoint: wss://<host>:<port>/api/current
|
||||
Authentication: auth.login_with_api_key
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
api_key: str,
|
||||
port: int = 443,
|
||||
verify_ssl: bool = False,
|
||||
) -> None:
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._api_key = api_key
|
||||
self._verify_ssl = verify_ssl
|
||||
self._ws = None
|
||||
self._call_id = 0
|
||||
|
||||
@property
|
||||
def _url(self) -> str:
|
||||
return f"wss://{self._host}:{self._port}/api/current"
|
||||
|
||||
async def __aenter__(self) -> "TrueNASClient":
|
||||
await self._connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_: Any) -> None:
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def _connect(self) -> None:
|
||||
ctx = ssl.create_default_context()
|
||||
if not self._verify_ssl:
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
|
||||
log.info("Connecting to %s …", self._url)
|
||||
try:
|
||||
self._ws = await _ws_connect(
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
path="/api/current",
|
||||
ssl_ctx=ctx,
|
||||
)
|
||||
except (OSError, asyncio.TimeoutError) as exc:
|
||||
log.error("Connection failed: %s", exc)
|
||||
raise
|
||||
|
||||
log.info("Authenticating with API key …")
|
||||
result = await self.call("auth.login_with_api_key", [self._api_key])
|
||||
if result is not True and result != "SUCCESS":
|
||||
raise PermissionError(f"Authentication rejected: {result!r}")
|
||||
log.info("Connected and authenticated.")
|
||||
|
||||
async def call(self, method: str, params: Optional[list] = None) -> Any:
|
||||
"""Send one JSON-RPC request and return its result.
|
||||
Raises RuntimeError if the API returns an error.
|
||||
"""
|
||||
self._call_id += 1
|
||||
req_id = self._call_id
|
||||
|
||||
await self._ws.send(json.dumps({
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"method": method,
|
||||
"params": params or [],
|
||||
}))
|
||||
|
||||
while True:
|
||||
raw = await asyncio.wait_for(self._ws.recv(), timeout=60)
|
||||
msg = json.loads(raw)
|
||||
if "id" not in msg:
|
||||
continue
|
||||
if msg["id"] != req_id:
|
||||
continue
|
||||
if "error" in msg:
|
||||
err = msg["error"]
|
||||
reason = (
|
||||
err.get("data", {}).get("reason")
|
||||
or err.get("message")
|
||||
or repr(err)
|
||||
)
|
||||
raise RuntimeError(f"API error [{method}]: {reason}")
|
||||
return msg.get("result")
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Dataset utilities
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def check_dataset_paths(
|
||||
client: TrueNASClient,
|
||||
paths: list[str],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Return the subset of *paths* that have no matching ZFS dataset on the
|
||||
destination. Returns an empty list when the dataset query itself fails.
|
||||
"""
|
||||
if not paths:
|
||||
return []
|
||||
|
||||
unique = sorted({p.rstrip("/") for p in paths if p})
|
||||
log.info("Checking %d share path(s) against destination datasets …", len(unique))
|
||||
try:
|
||||
datasets = await client.call("pool.dataset.query") or []
|
||||
except RuntimeError as exc:
|
||||
log.warning("Could not query datasets (skipping check): %s", exc)
|
||||
return []
|
||||
|
||||
mountpoints = {
|
||||
d.get("mountpoint", "").rstrip("/")
|
||||
for d in datasets
|
||||
if d.get("mountpoint")
|
||||
}
|
||||
|
||||
missing = [p for p in unique if p not in mountpoints]
|
||||
if missing:
|
||||
for p in missing:
|
||||
log.warning(" MISSING dataset for path: %s", p)
|
||||
else:
|
||||
log.info(" All share paths exist as datasets.")
|
||||
return missing
|
||||
|
||||
|
||||
async def create_dataset(client: TrueNASClient, path: str) -> bool:
|
||||
"""
|
||||
Create a ZFS dataset whose mountpoint will be *path*.
|
||||
*path* must be an absolute /mnt/… path.
|
||||
Returns True on success, False on failure.
|
||||
"""
|
||||
if not path.startswith("/mnt/"):
|
||||
log.error("Cannot auto-create dataset for non-/mnt/ path: %s", path)
|
||||
return False
|
||||
|
||||
name = path[5:].rstrip("/")
|
||||
log.info("Creating dataset %r …", name)
|
||||
try:
|
||||
await client.call("pool.dataset.create", [{"name": name}])
|
||||
log.info(" Created: %s", name)
|
||||
return True
|
||||
except RuntimeError as exc:
|
||||
log.error(" Failed to create dataset %r: %s", name, exc)
|
||||
return False
|
||||
|
||||
|
||||
async def create_missing_datasets(
|
||||
host: str,
|
||||
port: int,
|
||||
api_key: str,
|
||||
paths: list[str],
|
||||
verify_ssl: bool = False,
|
||||
) -> None:
|
||||
"""Open a fresh connection and create ZFS datasets for *paths*."""
|
||||
async with TrueNASClient(
|
||||
host=host, port=port, api_key=api_key, verify_ssl=verify_ssl,
|
||||
) as client:
|
||||
for path in paths:
|
||||
await create_dataset(client, path)
|
||||
Reference in New Issue
Block a user