diff --git a/extensions/pi-remote-control/auth.ts b/extensions/pi-remote-control/auth.ts new file mode 100644 index 0000000..f42a097 --- /dev/null +++ b/extensions/pi-remote-control/auth.ts @@ -0,0 +1,40 @@ +/** + * Authentication helpers for remote-control. + * + * Provides one-time token generation/validation and session cookie management. + */ + +import { randomBytes, timingSafeEqual } from "node:crypto"; + +export function generateToken(): string { + return randomBytes(24).toString("base64url"); // 32 chars, URL-safe +} + +export function validateToken(provided: string, expected: string): boolean { + const a = Buffer.from(provided); + const b = Buffer.from(expected); + if (a.length !== b.length) return false; + return timingSafeEqual(a, b); +} + +/** Name of the cookie that grants access after initial token validation */ +export const SESSION_COOKIE = "pi_rc_session"; + +export function generateSessionId(): string { + return randomBytes(24).toString("base64url"); +} + +export function parseCookies(header: string | undefined): Record { + const cookies: Record = {}; + if (!header) return cookies; + for (const pair of header.split(";")) { + const idx = pair.indexOf("="); + if (idx < 0) continue; + const name = pair.slice(0, idx).trim(); + const raw = pair.slice(idx + 1).trim(); + let value = raw; + try { value = decodeURIComponent(raw); } catch { /* keep raw */ } + cookies[name] = value; + } + return cookies; +} diff --git a/extensions/pi-remote-control/config.ts b/extensions/pi-remote-control/config.ts new file mode 100644 index 0000000..fe39270 --- /dev/null +++ b/extensions/pi-remote-control/config.ts @@ -0,0 +1,106 @@ +/** + * Configuration management for remote-control. + * + * Reads/writes the `remote-control.json` config file from the agent directory, + * and provides the UI flow for configuring the public base URL. + */ + +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import type { ExtensionContext } from "@mariozechner/pi-coding-agent"; + +const REMOTE_CONTROL_CONFIG_FILE = "remote-control.json"; + +export interface RemoteControlConfig { + publicBaseUrl?: string; +} + +function getAgentDir(): string { + const envCandidates = ["PI_CODING_AGENT_DIR", "TAU_CODING_AGENT_DIR"]; + let envDir: string | undefined; + for (const key of envCandidates) { + if (process.env[key]) { + envDir = process.env[key]; + break; + } + } + if (!envDir) { + for (const [key, value] of Object.entries(process.env)) { + if (key.endsWith("_CODING_AGENT_DIR") && value) { + envDir = value; + break; + } + } + } + + if (envDir === "~") return os.homedir(); + if (envDir?.startsWith("~/")) return path.join(os.homedir(), envDir.slice(2)); + return envDir ?? path.join(os.homedir(), ".pi", "agent"); +} + +function getRemoteControlConfigPath(): string { + return path.join(getAgentDir(), REMOTE_CONTROL_CONFIG_FILE); +} + +export async function readRemoteControlConfig(): Promise { + try { + const raw = await fs.readFile(getRemoteControlConfigPath(), "utf8"); + const parsed = JSON.parse(raw) as RemoteControlConfig; + if (!parsed || typeof parsed !== "object") return {}; + return parsed; + } catch { + return {}; + } +} + +async function writeRemoteControlConfig(config: RemoteControlConfig): Promise { + const configPath = getRemoteControlConfigPath(); + await fs.mkdir(path.dirname(configPath), { recursive: true }); + await fs.writeFile(configPath, JSON.stringify(config, null, 2) + "\n", "utf8"); +} + +export function normalizePublicBaseUrl(value: string): string { + const parsed = new URL(value.trim()); + parsed.username = ""; + parsed.password = ""; + parsed.pathname = ""; + parsed.search = ""; + parsed.hash = ""; + return parsed.toString().replace(/\/+$/, ""); +} + +export function buildRemoteControlUrl(publicBaseUrl: string, port: number, token: string): string { + const parsed = new URL(normalizePublicBaseUrl(publicBaseUrl)); + if (parsed.protocol === "http:") { + parsed.port = String(port); + } + parsed.searchParams.set("token", token); + return parsed.toString(); +} + +export async function configureRemoteControlUI(ctx: ExtensionContext): Promise { + if (!ctx.hasUI) return; + + const current = (await readRemoteControlConfig()).publicBaseUrl ?? ""; + const title = current + ? `Public base URL (current: ${current})` + : "Public base URL"; + const raw = await ctx.ui.input(title, "e.g. http://pi.myhost"); + if (raw === undefined) return; + + let value: string; + try { + value = normalizePublicBaseUrl(raw); + } catch { + ctx.ui.notify("Public base URL must be a valid http:// or https:// URL", "warning"); + return; + } + if (!["http:", "https:"].includes(new URL(value).protocol)) { + ctx.ui.notify("Public base URL must start with http:// or https://", "warning"); + return; + } + + await writeRemoteControlConfig({ publicBaseUrl: value }); + ctx.ui.notify(`Saved remote-control URL to ${getRemoteControlConfigPath()}`, "info"); +} diff --git a/extensions/pi-remote-control/html.ts b/extensions/pi-remote-control/html.ts new file mode 100644 index 0000000..c52a363 --- /dev/null +++ b/extensions/pi-remote-control/html.ts @@ -0,0 +1,656 @@ +/** + * Inline web UI for remote-control. + * + * Generates the single-page HTML/CSS/JS served to the browser client. + * Everything is self-contained — no external dependencies. + */ + +export function buildHTML(nonce: string): string { +return /* html */ ` + + + + + Pi Remote + + + +
+
+
+ Connecting\u2026 + +
+
+
+
+ + +
+
+ + +`; +} diff --git a/extensions/pi-remote-control/index.ts b/extensions/pi-remote-control/index.ts index ee2a879..07f3a6d 100644 --- a/extensions/pi-remote-control/index.ts +++ b/extensions/pi-remote-control/index.ts @@ -10,1102 +10,17 @@ * The server stops automatically when the session closes. */ -import { createServer } from "node:http"; -import { createRequire } from "node:module"; import { execFileSync } from "node:child_process"; -import { randomBytes, timingSafeEqual } from "node:crypto"; -import fs from "node:fs/promises"; -import os from "node:os"; -import path from "node:path"; import type { ExtensionAPI, ExtensionContext } from "@mariozechner/pi-coding-agent"; import { DynamicBorder } from "@mariozechner/pi-coding-agent"; import { Container, Key, Text, matchesKey } from "@mariozechner/pi-tui"; - - -// Load ws (bundled with pi) without needing @types/ws installed locally -const _require = createRequire(import.meta.url); -const wsModule = _require("ws") as { - WebSocketServer: new (opts: { noServer: boolean }) => any; - OPEN: number; -}; -const { WebSocketServer, OPEN } = wsModule; - -const REMOTE_CONTROL_CONFIG_FILE = "remote-control.json"; - -interface RemoteControlConfig { - publicBaseUrl?: string; -} - -function getAgentDir(): string { - const envCandidates = ["PI_CODING_AGENT_DIR", "TAU_CODING_AGENT_DIR"]; - let envDir: string | undefined; - for (const key of envCandidates) { - if (process.env[key]) { - envDir = process.env[key]; - break; - } - } - if (!envDir) { - for (const [key, value] of Object.entries(process.env)) { - if (key.endsWith("_CODING_AGENT_DIR") && value) { - envDir = value; - break; - } - } - } - - if (envDir === "~") return os.homedir(); - if (envDir?.startsWith("~/")) return path.join(os.homedir(), envDir.slice(2)); - return envDir ?? path.join(os.homedir(), ".pi", "agent"); -} - -function getRemoteControlConfigPath(): string { - return path.join(getAgentDir(), REMOTE_CONTROL_CONFIG_FILE); -} - -async function readRemoteControlConfig(): Promise { - try { - const raw = await fs.readFile(getRemoteControlConfigPath(), "utf8"); - const parsed = JSON.parse(raw) as RemoteControlConfig; - if (!parsed || typeof parsed !== "object") return {}; - return parsed; - } catch { - return {}; - } -} - -async function writeRemoteControlConfig(config: RemoteControlConfig): Promise { - const configPath = getRemoteControlConfigPath(); - await fs.mkdir(path.dirname(configPath), { recursive: true }); - await fs.writeFile(configPath, JSON.stringify(config, null, 2) + "\n", "utf8"); -} - -function normalizePublicBaseUrl(value: string): string { - const parsed = new URL(value.trim()); - parsed.username = ""; - parsed.password = ""; - parsed.pathname = ""; - parsed.search = ""; - parsed.hash = ""; - return parsed.toString().replace(/\/+$/, ""); -} - -function buildRemoteControlUrl(publicBaseUrl: string, port: number, token: string): string { - const parsed = new URL(normalizePublicBaseUrl(publicBaseUrl)); - if (parsed.protocol === "http:") { - parsed.port = String(port); - } - parsed.searchParams.set("token", token); - return parsed.toString(); -} - -async function configureRemoteControlUI(ctx: ExtensionContext): Promise { - if (!ctx.hasUI) return; - - const current = (await readRemoteControlConfig()).publicBaseUrl ?? ""; - const title = current - ? `Public base URL (current: ${current})` - : "Public base URL"; - const raw = await ctx.ui.input(title, "e.g. http://pi.myhost"); - if (raw === undefined) return; - - let value: string; - try { - value = normalizePublicBaseUrl(raw); - } catch { - ctx.ui.notify("Public base URL must be a valid http:// or https:// URL", "warning"); - return; - } - if (!["http:", "https:"].includes(new URL(value).protocol)) { - ctx.ui.notify("Public base URL must start with http:// or https://", "warning"); - return; - } - - await writeRemoteControlConfig({ publicBaseUrl: value }); - ctx.ui.notify(`Saved remote-control URL to ${getRemoteControlConfigPath()}`, "info"); -} - -// ── Auth helpers ───────────────────────────────────────────────────────────── - -function generateToken(): string { - return randomBytes(24).toString("base64url"); // 32 chars, URL-safe -} - -function validateToken(provided: string, expected: string): boolean { - const a = Buffer.from(provided); - const b = Buffer.from(expected); - if (a.length !== b.length) return false; - return timingSafeEqual(a, b); -} - -/** Name of the cookie that grants access after initial token validation */ -const SESSION_COOKIE = "pi_rc_session"; - -function generateSessionId(): string { - return randomBytes(24).toString("base64url"); -} - -function parsecookies(header: string | undefined): Record { - const cookies: Record = {}; - if (!header) return cookies; - for (const pair of header.split(";")) { - const idx = pair.indexOf("="); - if (idx < 0) continue; - const name = pair.slice(0, idx).trim(); - const raw = pair.slice(idx + 1).trim(); - let value = raw; - try { value = decodeURIComponent(raw); } catch { /* keep raw */ } - cookies[name] = value; - } - return cookies; -} - -// ── Wire protocol types ────────────────────────────────────────────────────── - -interface RenderMsg { - id: string; // SessionEntry id, or "pending" while streaming - role: "user" | "assistant" | "tool_result"; - text: string; - toolCalls?: Array<{ id: string; name: string; args: string }>; - toolName?: string; - toolCallId?: string; - isError?: boolean; - model?: string; -} - -// ── Message serialization ──────────────────────────────────────────────────── - -function serializeMessage(id: string, msg: any): RenderMsg | null { - if (msg.role === "user") { - const text = - typeof msg.content === "string" - ? msg.content - : (msg.content as any[]) - .filter((c) => c.type === "text") - .map((c) => c.text) - .join(""); - return { id, role: "user", text }; - } - - if (msg.role === "assistant") { - const text = (msg.content as any[]) - .filter((c) => c.type === "text") - .map((c) => c.text) - .join(""); - const toolCalls = (msg.content as any[]) - .filter((c) => c.type === "toolCall") - .map((c) => ({ - id: c.id, - name: c.name, - args: JSON.stringify(c.arguments, null, 2), - })); - return { - id, - role: "assistant", - text, - toolCalls: toolCalls.length > 0 ? toolCalls : undefined, - model: msg.model, - }; - } - - if (msg.role === "toolResult") { - const text = (msg.content as any[]) - .filter((c) => c.type === "text") - .map((c) => c.text) - .join(""); - return { - id, - role: "tool_result", - text, - toolName: msg.toolName, - toolCallId: msg.toolCallId, - isError: msg.isError, - }; - } - - return null; -} - -function getBranchMessages(ctx: ExtensionContext): RenderMsg[] { - const branch = ctx.sessionManager.getBranch(); - const out: RenderMsg[] = []; - for (const entry of branch) { - if (entry.type !== "message") continue; - const m = serializeMessage(entry.id, (entry as any).message); - if (m) out.push(m); - } - return out; -} - -// ── Inlined web UI ─────────────────────────────────────────────────────────── - -function buildHTML(nonce: string): string { -return /* html */ ` - - - - - Pi Remote - - - -
-
-
- Connecting\u2026 - -
-
-
-
- - -
-
- - -`; -} - -// ── HTTP + WebSocket server ────────────────────────────────────────────────── - -interface RemoteServer { - broadcast: (msg: object) => void; - stop: () => Promise; - clientCount: () => number; - onClientChange: (cb: () => void) => void; - port: number; - token: string; -} - -function startServer(pi: ExtensionAPI, ctx: ExtensionContext): Promise { - const clientChangeListeners: Array<() => void> = []; - const clients = new Set(); - const token = generateToken(); - // Map of valid session IDs → expiry timestamp (ms since epoch) - const SESSION_TTL_MS = 86_400_000; // 24 h — matches cookie Max-Age - const validSessions = new Map(); - const pruneExpiredSessions = (): void => { - const now = Date.now(); - for (const [id, expiresAt] of validSessions) { - if (expiresAt <= now) validSessions.delete(id); - } - }; - - /** Check if a request is authenticated (valid token query param OR valid session cookie) */ - function isAuthenticated(req: any): boolean { - // Check session cookie first - const cookies = parsecookies(req.headers.cookie); - const sessionId = cookies[SESSION_COOKIE]; - const sessionExpiry = sessionId ? validSessions.get(sessionId) : undefined; - if (sessionExpiry !== undefined && sessionExpiry > Date.now()) return true; - - // Check token query param - const url = new URL(req.url ?? "/", "http://localhost"); - const providedToken = url.searchParams.get("token"); - if (providedToken && validateToken(providedToken, token)) return true; - - return false; - } - - function broadcast(msg: object): void { - const data = JSON.stringify(msg); - for (const client of clients) { - if (client.readyState === OPEN) { - try { - client.send(data); - } catch { - /* ignore */ - } - } - } - } - - const httpServer = createServer((req, res) => { - const url = new URL(req.url ?? "/", "http://localhost"); - const pathname = url.pathname; - - if (pathname === "/" || pathname === "/index.html") { - // Check authentication - const cookies = parsecookies(req.headers.cookie); - const sc = cookies[SESSION_COOKIE]; - const hasValidSession = sc !== undefined && (validSessions.get(sc) ?? 0) > Date.now(); - const providedToken = url.searchParams.get("token"); - const hasValidToken = providedToken && validateToken(providedToken, token); - - if (!hasValidSession && !hasValidToken) { - res.writeHead(403, { "Content-Type": "text/plain" }); - res.end("Forbidden — valid token required. Use the URL shown in the pi terminal."); - return; - } - - // If authenticated via token (first visit), issue a session cookie and redirect to clean URL - if (!hasValidSession && hasValidToken) { - pruneExpiredSessions(); - const sessionId = generateSessionId(); - validSessions.set(sessionId, Date.now() + SESSION_TTL_MS); - res.writeHead(302, { - "Set-Cookie": `${SESSION_COOKIE}=${sessionId}; Path=/; HttpOnly; SameSite=Strict; Max-Age=86400`, - Location: "/", - }); - res.end(); - return; - } - - // Valid session cookie — serve the page - const nonce = randomBytes(16).toString("base64"); - res.writeHead(200, { - "Content-Type": "text/html; charset=utf-8", - "X-Frame-Options": "DENY", - "X-Content-Type-Options": "nosniff", - "Referrer-Policy": "no-referrer", - "Content-Security-Policy": - `default-src 'none'; script-src 'nonce-${nonce}'; style-src 'nonce-${nonce}'; connect-src 'self'; base-uri 'none'`, - }); - res.end(buildHTML(nonce)); - } else { - res.writeHead(404, { "Content-Type": "text/plain" }); - res.end("Not found"); - } - }); - - const wss = new WebSocketServer({ noServer: true }); - - httpServer.on("error", (err: Error) => { - console.error("[remote-control] httpServer error:", err.message); - }); - - wss.on("error", (err: Error) => { - console.error("[remote-control] wss error:", err.message); - }); - - httpServer.on("upgrade", (request: any, socket: any, head: any) => { - const url = new URL(request.url, "http://localhost"); - if (url.pathname === "/ws") { - // Validate auth: session cookie or token query param - if (!isAuthenticated(request)) { - socket.write("HTTP/1.1 403 Forbidden\r\n\r\n"); - socket.destroy(); - return; - } - wss.handleUpgrade(request, socket, head, (ws: any) => { - wss.emit("connection", ws, request); - }); - } else { - socket.destroy(); - } - }); - - wss.on("connection", (ws: any) => { - clients.add(ws); - for (const cb of clientChangeListeners) cb(); - - // Send full state snapshot to the new client - try { - ws.send( - JSON.stringify({ - type: "sync", - messages: getBranchMessages(ctx), - state: { - isStreaming: !ctx.isIdle(), - model: ctx.model?.id, - cwd: ctx.cwd, - sessionName: ctx.sessionManager.getSessionName(), - }, - }), - ); - } catch { - /* client disconnected before first send */ - } - - // Per-connection rate limiting: max 30 prompts per 60 seconds - const RATE_WINDOW_MS = 60_000; - const RATE_MAX = 30; - const MAX_MSG_BYTES = 64 * 1024; - const recentPrompts: number[] = []; - - ws.on("message", (data: any) => { - if (data.length > MAX_MSG_BYTES) return; - let msg: any; - try { - msg = JSON.parse(data.toString()); - } catch { - return; - } - if (msg.type === "prompt" && typeof msg.text === "string" && msg.text.trim()) { - const text = msg.text.trim(); - // Sliding-window rate limit - const now = Date.now(); - const cutoff = now - RATE_WINDOW_MS; - while (recentPrompts.length > 0 && recentPrompts[0] < cutoff) recentPrompts.shift(); - if (recentPrompts.length >= RATE_MAX) return; - recentPrompts.push(now); - if (ctx.isIdle()) { - pi.sendUserMessage(text); - } else { - pi.sendUserMessage(text, { deliverAs: "followUp" }); - } - } - }); - - const onClose = () => { - clients.delete(ws); - broadcast({ type: "status", clientCount: clients.size }); - for (const cb of clientChangeListeners) cb(); - }; - ws.on("close", onClose); - ws.on("error", onClose); - }); - - return new Promise((resolve) => { - httpServer.listen(0, "127.0.0.1", () => { - resolve({ - broadcast, - stop: () => - new Promise((res) => { - for (const client of clients) { - try { - client.close(); - } catch { - /* ignore */ - } - } - wss.close(() => httpServer.close(() => res())); - }), - clientCount: () => clients.size, - onClientChange: (cb: () => void) => { clientChangeListeners.push(cb); }, - get port() { - return (httpServer.address() as any)?.port ?? 0; - }, - get token() { - return token; - }, - }); - }); - }); -} +import { + readRemoteControlConfig, + buildRemoteControlUrl, + configureRemoteControlUI, +} from "./config.js"; +import { serializeMessage } from "./messages.js"; +import { type RemoteServer, startServer } from "./server.js"; // ── Extension entry point ──────────────────────────────────────────────────── diff --git a/extensions/pi-remote-control/messages.ts b/extensions/pi-remote-control/messages.ts new file mode 100644 index 0000000..e65a8dd --- /dev/null +++ b/extensions/pi-remote-control/messages.ts @@ -0,0 +1,81 @@ +/** + * Wire protocol types and message serialization for remote-control. + * + * Converts pi session entries into the simplified RenderMsg format + * consumed by the browser client. + */ + +import type { ExtensionContext } from "@mariozechner/pi-coding-agent"; + +export interface RenderMsg { + id: string; // SessionEntry id, or "pending" while streaming + role: "user" | "assistant" | "tool_result"; + text: string; + toolCalls?: Array<{ id: string; name: string; args: string }>; + toolName?: string; + toolCallId?: string; + isError?: boolean; + model?: string; +} + +export function serializeMessage(id: string, msg: any): RenderMsg | null { + if (msg.role === "user") { + const text = + typeof msg.content === "string" + ? msg.content + : (msg.content as any[]) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join(""); + return { id, role: "user", text }; + } + + if (msg.role === "assistant") { + const text = (msg.content as any[]) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join(""); + const toolCalls = (msg.content as any[]) + .filter((c) => c.type === "toolCall") + .map((c) => ({ + id: c.id, + name: c.name, + args: JSON.stringify(c.arguments, null, 2), + })); + return { + id, + role: "assistant", + text, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + model: msg.model, + }; + } + + if (msg.role === "toolResult") { + const text = (msg.content as any[]) + .filter((c) => c.type === "text") + .map((c) => c.text) + .join(""); + return { + id, + role: "tool_result", + text, + toolName: msg.toolName, + toolCallId: msg.toolCallId, + isError: msg.isError, + }; + } + + return null; +} + +export function getBranchMessages(ctx: ExtensionContext): RenderMsg[] { + const branch = ctx.sessionManager.getBranch(); + const out: RenderMsg[] = []; + for (const entry of branch) { + if (entry.type !== "message") continue; + const m = serializeMessage(entry.id, (entry as any).message); + if (m) out.push(m); + } + return out; +} diff --git a/extensions/pi-remote-control/server.ts b/extensions/pi-remote-control/server.ts new file mode 100644 index 0000000..7908dec --- /dev/null +++ b/extensions/pi-remote-control/server.ts @@ -0,0 +1,244 @@ +/** + * HTTP + WebSocket server for remote-control. + * + * Handles authentication, serves the web UI, and manages WebSocket connections + * for real-time message streaming between the pi session and browser clients. + */ + +import { createServer } from "node:http"; +import { createRequire } from "node:module"; +import { randomBytes } from "node:crypto"; +import type { ExtensionAPI, ExtensionContext } from "@mariozechner/pi-coding-agent"; +import { + generateToken, + validateToken, + SESSION_COOKIE, + generateSessionId, + parseCookies, +} from "./auth.js"; +import { getBranchMessages } from "./messages.js"; +import { buildHTML } from "./html.js"; + +// Load ws (bundled with pi) without needing @types/ws installed locally +const _require = createRequire(import.meta.url); +const wsModule = _require("ws") as { + WebSocketServer: new (opts: { noServer: boolean }) => any; + OPEN: number; +}; +const { WebSocketServer, OPEN } = wsModule; + +export interface RemoteServer { + broadcast: (msg: object) => void; + stop: () => Promise; + clientCount: () => number; + onClientChange: (cb: () => void) => void; + port: number; + token: string; +} + +export function startServer(pi: ExtensionAPI, ctx: ExtensionContext): Promise { + const clientChangeListeners: Array<() => void> = []; + const clients = new Set(); + const token = generateToken(); + // Map of valid session IDs → expiry timestamp (ms since epoch) + const SESSION_TTL_MS = 86_400_000; // 24 h — matches cookie Max-Age + const validSessions = new Map(); + const pruneExpiredSessions = (): void => { + const now = Date.now(); + for (const [id, expiresAt] of validSessions) { + if (expiresAt <= now) validSessions.delete(id); + } + }; + + /** Check if a request is authenticated (valid token query param OR valid session cookie) */ + function isAuthenticated(req: any): boolean { + // Check session cookie first + const cookies = parseCookies(req.headers.cookie); + const sessionId = cookies[SESSION_COOKIE]; + const sessionExpiry = sessionId ? validSessions.get(sessionId) : undefined; + if (sessionExpiry !== undefined && sessionExpiry > Date.now()) return true; + + // Check token query param + const url = new URL(req.url ?? "/", "http://localhost"); + const providedToken = url.searchParams.get("token"); + if (providedToken && validateToken(providedToken, token)) return true; + + return false; + } + + function broadcast(msg: object): void { + const data = JSON.stringify(msg); + for (const client of clients) { + if (client.readyState === OPEN) { + try { + client.send(data); + } catch { + /* ignore */ + } + } + } + } + + const httpServer = createServer((req, res) => { + const url = new URL(req.url ?? "/", "http://localhost"); + const pathname = url.pathname; + + if (pathname === "/" || pathname === "/index.html") { + // Check authentication + const cookies = parseCookies(req.headers.cookie); + const sc = cookies[SESSION_COOKIE]; + const hasValidSession = sc !== undefined && (validSessions.get(sc) ?? 0) > Date.now(); + const providedToken = url.searchParams.get("token"); + const hasValidToken = providedToken && validateToken(providedToken, token); + + if (!hasValidSession && !hasValidToken) { + res.writeHead(403, { "Content-Type": "text/plain" }); + res.end("Forbidden — valid token required. Use the URL shown in the pi terminal."); + return; + } + + // If authenticated via token (first visit), issue a session cookie and redirect to clean URL + if (!hasValidSession && hasValidToken) { + pruneExpiredSessions(); + const sessionId = generateSessionId(); + validSessions.set(sessionId, Date.now() + SESSION_TTL_MS); + res.writeHead(302, { + "Set-Cookie": `${SESSION_COOKIE}=${sessionId}; Path=/; HttpOnly; SameSite=Strict; Max-Age=86400`, + Location: "/", + }); + res.end(); + return; + } + + // Valid session cookie — serve the page + const nonce = randomBytes(16).toString("base64"); + res.writeHead(200, { + "Content-Type": "text/html; charset=utf-8", + "X-Frame-Options": "DENY", + "X-Content-Type-Options": "nosniff", + "Referrer-Policy": "no-referrer", + "Content-Security-Policy": + `default-src 'none'; script-src 'nonce-${nonce}'; style-src 'nonce-${nonce}'; connect-src 'self'; base-uri 'none'`, + }); + res.end(buildHTML(nonce)); + } else { + res.writeHead(404, { "Content-Type": "text/plain" }); + res.end("Not found"); + } + }); + + const wss = new WebSocketServer({ noServer: true }); + + httpServer.on("error", (err: Error) => { + console.error("[remote-control] httpServer error:", err.message); + }); + + wss.on("error", (err: Error) => { + console.error("[remote-control] wss error:", err.message); + }); + + httpServer.on("upgrade", (request: any, socket: any, head: any) => { + const url = new URL(request.url, "http://localhost"); + if (url.pathname === "/ws") { + // Validate auth: session cookie or token query param + if (!isAuthenticated(request)) { + socket.write("HTTP/1.1 403 Forbidden\r\n\r\n"); + socket.destroy(); + return; + } + wss.handleUpgrade(request, socket, head, (ws: any) => { + wss.emit("connection", ws, request); + }); + } else { + socket.destroy(); + } + }); + + wss.on("connection", (ws: any) => { + clients.add(ws); + for (const cb of clientChangeListeners) cb(); + + // Send full state snapshot to the new client + try { + ws.send( + JSON.stringify({ + type: "sync", + messages: getBranchMessages(ctx), + state: { + isStreaming: !ctx.isIdle(), + model: ctx.model?.id, + cwd: ctx.cwd, + sessionName: ctx.sessionManager.getSessionName(), + }, + }), + ); + } catch { + /* client disconnected before first send */ + } + + // Per-connection rate limiting: max 30 prompts per 60 seconds + const RATE_WINDOW_MS = 60_000; + const RATE_MAX = 30; + const MAX_MSG_BYTES = 64 * 1024; + const recentPrompts: number[] = []; + + ws.on("message", (data: any) => { + if (data.length > MAX_MSG_BYTES) return; + let msg: any; + try { + msg = JSON.parse(data.toString()); + } catch { + return; + } + if (msg.type === "prompt" && typeof msg.text === "string" && msg.text.trim()) { + const text = msg.text.trim(); + // Sliding-window rate limit + const now = Date.now(); + const cutoff = now - RATE_WINDOW_MS; + while (recentPrompts.length > 0 && recentPrompts[0] < cutoff) recentPrompts.shift(); + if (recentPrompts.length >= RATE_MAX) return; + recentPrompts.push(now); + if (ctx.isIdle()) { + pi.sendUserMessage(text); + } else { + pi.sendUserMessage(text, { deliverAs: "followUp" }); + } + } + }); + + const onClose = () => { + clients.delete(ws); + broadcast({ type: "status", clientCount: clients.size }); + for (const cb of clientChangeListeners) cb(); + }; + ws.on("close", onClose); + ws.on("error", onClose); + }); + + return new Promise((resolve) => { + httpServer.listen(0, "127.0.0.1", () => { + resolve({ + broadcast, + stop: () => + new Promise((res) => { + for (const client of clients) { + try { + client.close(); + } catch { + /* ignore */ + } + } + wss.close(() => httpServer.close(() => res())); + }), + clientCount: () => clients.size, + onClientChange: (cb: () => void) => { clientChangeListeners.push(cb); }, + get port() { + return (httpServer.address() as any)?.port ?? 0; + }, + get token() { + return token; + }, + }); + }); + }); +}