diff --git a/src/cli/run/session-resolver.test.ts b/src/cli/run/session-resolver.test.ts index ca7d9f65..a9775bb4 100644 --- a/src/cli/run/session-resolver.test.ts +++ b/src/cli/run/session-resolver.test.ts @@ -1,6 +1,8 @@ -import { describe, it, expect, beforeEach, mock, spyOn } from "bun:test" -import { resolveSession } from "./session-resolver" -import type { OpencodeClient } from "./types" +/// + +import { beforeEach, describe, expect, it, mock, spyOn } from "bun:test"; +import { resolveSession } from "./session-resolver"; +import type { OpencodeClient } from "./types"; const createMockClient = (overrides: { getResult?: { error?: unknown; data?: { id: string } } @@ -58,7 +60,9 @@ describe("resolveSession", () => { const result = resolveSession({ client: mockClient, sessionId }) // then - await expect(result).rejects.toThrow(`Session not found: ${sessionId}`) + await Promise.resolve( + expect(result).rejects.toThrow(`Session not found: ${sessionId}`) + ) expect(mockClient.session.get).toHaveBeenCalledWith({ path: { id: sessionId }, }) @@ -77,7 +81,12 @@ describe("resolveSession", () => { // then expect(result).toBe("new-session-id") expect(mockClient.session.create).toHaveBeenCalledWith({ - body: { title: "oh-my-opencode run" }, + body: { + title: "oh-my-opencode run", + permission: [ + { permission: "question", action: "deny", pattern: "*" }, + ], + }, }) expect(mockClient.session.get).not.toHaveBeenCalled() }) @@ -98,7 +107,12 @@ describe("resolveSession", () => { expect(result).toBe("retried-session-id") expect(mockClient.session.create).toHaveBeenCalledTimes(2) expect(mockClient.session.create).toHaveBeenCalledWith({ - body: { title: "oh-my-opencode run" }, + body: { + title: "oh-my-opencode run", + permission: [ + { permission: "question", action: "deny", pattern: "*" }, + ], + }, }) }) @@ -116,7 +130,9 @@ describe("resolveSession", () => { const result = resolveSession({ client: mockClient }) // then - await expect(result).rejects.toThrow("Failed to create 
session after all retries") + await Promise.resolve( + expect(result).rejects.toThrow("Failed to create session after all retries") + ) expect(mockClient.session.create).toHaveBeenCalledTimes(3) }) @@ -134,7 +150,9 @@ describe("resolveSession", () => { const result = resolveSession({ client: mockClient }) // then - await expect(result).rejects.toThrow("Failed to create session after all retries") + await Promise.resolve( + expect(result).rejects.toThrow("Failed to create session after all retries") + ) expect(mockClient.session.create).toHaveBeenCalledTimes(3) }) }) diff --git a/src/cli/run/session-resolver.ts b/src/cli/run/session-resolver.ts index 31bd5a2c..1ec07199 100644 --- a/src/cli/run/session-resolver.ts +++ b/src/cli/run/session-resolver.ts @@ -19,14 +19,18 @@ export async function resolveSession(options: { return sessionId } - let lastError: unknown for (let attempt = 1; attempt <= SESSION_CREATE_MAX_RETRIES; attempt++) { const res = await client.session.create({ - body: { title: "oh-my-opencode run" }, + body: { + title: "oh-my-opencode run", + // In CLI run mode there's no TUI to answer questions. 
+ permission: [ + { permission: "question", action: "deny" as const, pattern: "*" }, + ], + } as any, }) if (res.error) { - lastError = res.error console.error( pc.yellow(`Session create attempt ${attempt}/${SESSION_CREATE_MAX_RETRIES} failed:`) ) @@ -44,9 +48,6 @@ export async function resolveSession(options: { return res.data.id } - lastError = new Error( - `Unexpected response: ${JSON.stringify(res, null, 2)}` - ) console.error( pc.yellow( `Session create attempt ${attempt}/${SESSION_CREATE_MAX_RETRIES}: No session ID returned` diff --git a/src/features/background-agent/manager.test.ts b/src/features/background-agent/manager.test.ts index c4db9056..d67ae2ad 100644 --- a/src/features/background-agent/manager.test.ts +++ b/src/features/background-agent/manager.test.ts @@ -1412,14 +1412,14 @@ describe("BackgroundManager - Non-blocking Queue Integration", () => { let manager: BackgroundManager let mockClient: ReturnType - function createMockClient() { - return { - session: { - create: async () => ({ data: { id: `ses_${crypto.randomUUID()}` } }), - get: async () => ({ data: { directory: "/test/dir" } }), - prompt: async () => ({}), - promptAsync: async () => ({}), - messages: async () => ({ data: [] }), + function createMockClient() { + return { + session: { + create: async (_args?: any) => ({ data: { id: `ses_${crypto.randomUUID()}` } }), + get: async () => ({ data: { directory: "/test/dir" } }), + prompt: async () => ({}), + promptAsync: async () => ({}), + messages: async () => ({ data: [] }), todo: async () => ({ data: [] }), status: async () => ({ data: {} }), abort: async () => ({}), @@ -1520,6 +1520,55 @@ describe("BackgroundManager - Non-blocking Queue Integration", () => { }) describe("task transitions pending→running when slot available", () => { + test("should inherit parent session permission rules (and force deny question)", async () => { + // given + const createCalls: any[] = [] + const parentPermission = [ + { permission: "question", action: "allow" 
as const, pattern: "*" }, + { permission: "plan_enter", action: "deny" as const, pattern: "*" }, + ] + + const customClient = { + session: { + create: async (args?: any) => { + createCalls.push(args) + return { data: { id: `ses_${crypto.randomUUID()}` } } + }, + get: async () => ({ data: { directory: "/test/dir", permission: parentPermission } }), + prompt: async () => ({}), + promptAsync: async () => ({}), + messages: async () => ({ data: [] }), + todo: async () => ({ data: [] }), + status: async () => ({ data: {} }), + abort: async () => ({}), + }, + } + manager.shutdown() + manager = new BackgroundManager({ client: customClient, directory: tmpdir() } as unknown as PluginInput, { + defaultConcurrency: 5, + }) + + const input = { + description: "Test task", + prompt: "Do something", + agent: "test-agent", + parentSessionID: "parent-session", + parentMessageID: "parent-message", + } + + // when + await manager.launch(input) + await new Promise(resolve => setTimeout(resolve, 50)) + + // then + expect(createCalls).toHaveLength(1) + const permission = createCalls[0]?.body?.permission + expect(permission).toEqual([ + { permission: "plan_enter", action: "deny", pattern: "*" }, + { permission: "question", action: "deny", pattern: "*" }, + ]) + }) + test("should transition first task to running immediately", async () => { // given const config = { defaultConcurrency: 5 } diff --git a/src/features/background-agent/manager.ts b/src/features/background-agent/manager.ts index 320a93f9..cfe8808a 100644 --- a/src/features/background-agent/manager.ts +++ b/src/features/background-agent/manager.ts @@ -1,32 +1,70 @@ + import type { PluginInput } from "@opencode-ai/plugin" -import type { BackgroundTask, LaunchInput, ResumeInput } from "./types" -import type { BackgroundTaskConfig, TmuxConfig } from "../../config/schema" - -import { log } from "../../shared" +import type { + BackgroundTask, + LaunchInput, + ResumeInput, +} from "./types" +import { log, getAgentToolRestrictions, 
promptWithModelSuggestionRetry } from "../../shared" import { ConcurrencyManager } from "./concurrency" -import { POLLING_INTERVAL_MS } from "./constants" +import type { BackgroundTaskConfig, TmuxConfig } from "../../config/schema" +import { isInsideTmux } from "../../shared/tmux" +import { + DEFAULT_STALE_TIMEOUT_MS, + MIN_IDLE_TIME_MS, + MIN_RUNTIME_BEFORE_STALE_MS, + MIN_STABILITY_TIME_MS, + POLLING_INTERVAL_MS, + TASK_CLEANUP_DELAY_MS, + TASK_TTL_MS, +} from "./constants" -import { handleBackgroundEvent } from "./background-event-handler" -import { shutdownBackgroundManager } from "./background-manager-shutdown" -import { clearNotifications, clearNotificationsForTask, cleanupPendingByParent, getPendingNotifications, markForNotification } from "./notification-tracker" -import { notifyParentSession as notifyParentSessionInternal } from "./notify-parent-session" -import { pollRunningTasks } from "./poll-running-tasks" -import { registerProcessSignal, type ProcessCleanupEvent } from "./process-signal" -import { validateSessionHasOutput, checkSessionTodos } from "./session-validator" -import { pruneStaleState } from "./stale-task-pruner" -import { getAllDescendantTasks, getCompletedTasks, getRunningTasks, getTasksByParentSession, hasRunningTasks, findTaskBySession } from "./task-queries" -import { checkAndInterruptStaleTasks } from "./task-poller" -import { cancelBackgroundTask } from "./task-canceller" -import { tryCompleteBackgroundTask } from "./task-completer" -import { launchBackgroundTask } from "./task-launch" -import { processConcurrencyKeyQueue } from "./task-queue-processor" -import { resumeBackgroundTask } from "./task-resumer" -import { startQueuedTask } from "./task-starter" -import { trackExternalTask } from "./task-tracker" +import { subagentSessions } from "../claude-code-session-state" +import { getTaskToastManager } from "../task-toast-manager" +import { findNearestMessageWithFields, MESSAGE_STORAGE } from "../hook-message-injector" +import { 
existsSync, readdirSync } from "node:fs" +import { join } from "node:path" -type QueueItem = { task: BackgroundTask; input: LaunchInput } +type ProcessCleanupEvent = NodeJS.Signals | "beforeExit" | "exit" + +type OpencodeClient = PluginInput["client"] + + +interface MessagePartInfo { + sessionID?: string + type?: string + tool?: string +} + +interface EventProperties { + sessionID?: string + info?: { id?: string } + [key: string]: unknown +} + +interface Event { + type: string + properties?: EventProperties +} + +interface Todo { + content: string + status: string + priority: string + id: string +} + +interface QueueItem { + task: BackgroundTask + input: LaunchInput +} + +export interface SubagentSessionCreatedEvent { + sessionID: string + parentID: string + title: string +} -export interface SubagentSessionCreatedEvent { sessionID: string; parentID: string; title: string } export type OnSubagentSessionCreated = (event: SubagentSessionCreatedEvent) => Promise export class BackgroundManager { @@ -34,26 +72,37 @@ export class BackgroundManager { private static cleanupRegistered = false private static cleanupHandlers = new Map void>() - private tasks = new Map() - private notifications = new Map() - private pendingByParent = new Map>() - private queuesByKey = new Map() - private processingKeys = new Set() - private completionTimers = new Map>() - private idleDeferralTimers = new Map>() - private notificationQueueByParent = new Map>() - - private client: PluginInput["client"] + private tasks: Map + private notifications: Map + private pendingByParent: Map> // Track pending tasks per parent for batching + private client: OpencodeClient private directory: string private pollingInterval?: ReturnType private concurrencyManager: ConcurrencyManager - private shutdownTriggered = { value: false } + private shutdownTriggered = false private config?: BackgroundTaskConfig private tmuxEnabled: boolean private onSubagentSessionCreated?: OnSubagentSessionCreated private onShutdown?: 
() => void - constructor(ctx: PluginInput, config?: BackgroundTaskConfig, options?: { tmuxConfig?: TmuxConfig; onSubagentSessionCreated?: OnSubagentSessionCreated; onShutdown?: () => void }) { + private queuesByKey: Map = new Map() + private processingKeys: Set = new Set() + private completionTimers: Map> = new Map() + private idleDeferralTimers: Map> = new Map() + private notificationQueueByParent: Map> = new Map() + + constructor( + ctx: PluginInput, + config?: BackgroundTaskConfig, + options?: { + tmuxConfig?: TmuxConfig + onSubagentSessionCreated?: OnSubagentSessionCreated + onShutdown?: () => void + } + ) { + this.tasks = new Map() + this.notifications = new Map() + this.pendingByParent = new Map() this.client = ctx.client this.directory = ctx.directory this.concurrencyManager = new ConcurrencyManager(config) @@ -65,82 +114,1487 @@ export class BackgroundManager { } async launch(input: LaunchInput): Promise { - return launchBackgroundTask({ input, tasks: this.tasks, pendingByParent: this.pendingByParent, queuesByKey: this.queuesByKey, getConcurrencyKeyFromInput: (i) => this.getConcurrencyKeyFromInput(i), processKey: (key) => void this.processKey(key) }) + log("[background-agent] launch() called with:", { + agent: input.agent, + model: input.model, + description: input.description, + parentSessionID: input.parentSessionID, + }) + + if (!input.agent || input.agent.trim() === "") { + throw new Error("Agent parameter is required") + } + + // Create task immediately with status="pending" + const task: BackgroundTask = { + id: `bg_${crypto.randomUUID().slice(0, 8)}`, + status: "pending", + queuedAt: new Date(), + // Do NOT set startedAt - will be set when running + // Do NOT set sessionID - will be set when running + description: input.description, + prompt: input.prompt, + agent: input.agent, + parentSessionID: input.parentSessionID, + parentMessageID: input.parentMessageID, + parentModel: input.parentModel, + parentAgent: input.parentAgent, + model: input.model, + 
category: input.category, + } + + this.tasks.set(task.id, task) + + // Track for batched notifications immediately (pending state) + if (input.parentSessionID) { + const pending = this.pendingByParent.get(input.parentSessionID) ?? new Set() + pending.add(task.id) + this.pendingByParent.set(input.parentSessionID, pending) + } + + // Add to queue + const key = this.getConcurrencyKeyFromInput(input) + const queue = this.queuesByKey.get(key) ?? [] + queue.push({ task, input }) + this.queuesByKey.set(key, queue) + + log("[background-agent] Task queued:", { taskId: task.id, key, queueLength: queue.length }) + + const toastManager = getTaskToastManager() + if (toastManager) { + toastManager.addTask({ + id: task.id, + description: input.description, + agent: input.agent, + isBackground: true, + status: "queued", + skills: input.skills, + }) + } + + // Trigger processing (fire-and-forget); a rejection is logged instead of surfacing as an unhandled promise rejection + void this.processKey(key).catch((err) => log("[background-agent] processKey failed:", err)) + + return task } - async trackTask(input: { taskId: string; sessionID: string; parentSessionID: string; description: string; agent?: string; parentAgent?: string; concurrencyKey?: string }): Promise { - return trackExternalTask({ input, tasks: this.tasks, pendingByParent: this.pendingByParent, concurrencyManager: this.concurrencyManager, startPolling: () => this.startPolling(), cleanupPendingByParent: (task) => this.cleanupPendingByParent(task) }) + private async processKey(key: string): Promise<void> { + if (this.processingKeys.has(key)) { // a drain loop for this key is already running + return + } + + this.processingKeys.add(key) + + try { + const queue = this.queuesByKey.get(key) + while (queue && queue.length > 0) { + const item = queue[0] // peek only; the item is shifted after it has been started or skipped + + await this.concurrencyManager.acquire(key) + + if (item.task.status === "cancelled") { // cancelled while queued: give the slot back and drop the item + this.concurrencyManager.release(key) + queue.shift() + continue + } + + try { + await this.startTask(item) + } catch (error) { + log("[background-agent] Error starting task:", error) + // Release concurrency slot if startTask failed and didn't release it itself + // This prevents slot 
leaks when errors occur after acquire but before task.concurrencyKey is set + if (!item.task.concurrencyKey) { + this.concurrencyManager.release(key) + } + } + + queue.shift() + } + } finally { + this.processingKeys.delete(key) + } + } + + private async startTask(item: QueueItem): Promise { + const { task, input } = item + + log("[background-agent] Starting task:", { + taskId: task.id, + agent: input.agent, + model: input.model, + }) + + const concurrencyKey = this.getConcurrencyKeyFromInput(input) + + const parentSession = await this.client.session.get({ + path: { id: input.parentSessionID }, + }).catch((err) => { + log(`[background-agent] Failed to get parent session: ${err}`) + return null + }) + const parentDirectory = parentSession?.data?.directory ?? this.directory + log(`[background-agent] Parent dir: ${parentSession?.data?.directory}, using: ${parentDirectory}`) + + const inheritedPermission = (parentSession as any)?.data?.permission + const permissionRules = Array.isArray(inheritedPermission) + ? 
inheritedPermission.filter((r: any) => r?.permission !== "question") + : [] + permissionRules.push({ permission: "question", action: "deny" as const, pattern: "*" }) // appended after inherited "question" rules were filtered out, so the deny rule is the only one left + + const createResult = await this.client.session.create({ + body: { + parentID: input.parentSessionID, + title: `${input.description} (@${input.agent} subagent)`, + permission: permissionRules, + } as any, + query: { + directory: parentDirectory, + }, + }) + + if (createResult.error) { // the error may be a structured payload; format it explicitly so the message is not "[object Object]" + throw new Error(`Failed to create background session: ${createResult.error instanceof Error ? createResult.error.message : typeof createResult.error === "string" ? createResult.error : JSON.stringify(createResult.error)}`) + } + + if (!createResult.data?.id) { + throw new Error("Failed to create background session: API returned no session ID") + } + + const sessionID = createResult.data.id + subagentSessions.add(sessionID) + + log("[background-agent] tmux callback check", { + hasCallback: !!this.onSubagentSessionCreated, + tmuxEnabled: this.tmuxEnabled, + isInsideTmux: isInsideTmux(), + sessionID, + parentID: input.parentSessionID, + }) + + if (this.onSubagentSessionCreated && this.tmuxEnabled && isInsideTmux()) { + log("[background-agent] Invoking tmux callback NOW", { sessionID }) + await this.onSubagentSessionCreated({ + sessionID, + parentID: input.parentSessionID, + title: input.description, + }).catch((err) => { + log("[background-agent] Failed to spawn tmux pane:", err) + }) + log("[background-agent] tmux callback completed, waiting 200ms") + await new Promise(r => setTimeout(r, 200)) + } else { + log("[background-agent] SKIP tmux callback - conditions not met") + } + + // Update task to running state + task.status = "running" + task.startedAt = new Date() + task.sessionID = sessionID + task.progress = { + toolCalls: 0, + lastUpdate: new Date(), + } + task.concurrencyKey = concurrencyKey + task.concurrencyGroup = concurrencyKey + + this.startPolling() + + log("[background-agent] Launching task:", { taskId: task.id, sessionID, agent: input.agent }) + + const toastManager = getTaskToastManager() + if (toastManager) { + toastManager.updateTask(task.id, 
"running") + } + + log("[background-agent] Calling prompt (fire-and-forget) for launch with:", { + sessionID, + agent: input.agent, + model: input.model, + hasSkillContent: !!input.skillContent, + promptLength: input.prompt.length, + }) + + // Fire-and-forget prompt via promptAsync (no response body needed) + // Include model if caller provided one (e.g., from Sisyphus category configs) + // IMPORTANT: variant must be a top-level field in the body, NOT nested inside model + // OpenCode's PromptInput schema expects: { model: { providerID, modelID }, variant: "max" } + const launchModel = input.model + ? { providerID: input.model.providerID, modelID: input.model.modelID } + : undefined + const launchVariant = input.model?.variant + + promptWithModelSuggestionRetry(this.client, { + path: { id: sessionID }, + body: { + agent: input.agent, + ...(launchModel ? { model: launchModel } : {}), + ...(launchVariant ? { variant: launchVariant } : {}), + system: input.skillContent, + tools: { + ...getAgentToolRestrictions(input.agent), + task: false, + call_omo_agent: true, + question: false, + }, + parts: [{ type: "text", text: input.prompt }], + }, + }).catch((error) => { + log("[background-agent] promptAsync error:", error) + const existingTask = this.findBySession(sessionID) + if (existingTask) { + existingTask.status = "error" + const errorMessage = error instanceof Error ? error.message : String(error) + if (errorMessage.includes("agent.name") || errorMessage.includes("undefined")) { + existingTask.error = `Agent "${input.agent}" not found. 
Make sure the agent is registered in your opencode.json or provided by a plugin.` + } else { + existingTask.error = errorMessage + } + existingTask.completedAt = new Date() + if (existingTask.concurrencyKey) { + this.concurrencyManager.release(existingTask.concurrencyKey) + existingTask.concurrencyKey = undefined + } + + // Abort the session to prevent infinite polling hang + this.client.session.abort({ + path: { id: sessionID }, + }).catch(() => {}) + + this.markForNotification(existingTask) + this.cleanupPendingByParent(existingTask) + this.enqueueNotificationForParent(existingTask.parentSessionID, () => this.notifyParentSession(existingTask)).catch(err => { + log("[background-agent] Failed to notify on error:", err) + }) + } + }) + } + + getTask(id: string): BackgroundTask | undefined { + return this.tasks.get(id) + } + + getTasksByParentSession(sessionID: string): BackgroundTask[] { + const result: BackgroundTask[] = [] + for (const task of this.tasks.values()) { + if (task.parentSessionID === sessionID) { + result.push(task) + } + } + return result + } + + getAllDescendantTasks(sessionID: string): BackgroundTask[] { + const result: BackgroundTask[] = [] + const directChildren = this.getTasksByParentSession(sessionID) + + for (const child of directChildren) { + result.push(child) + if (child.sessionID) { + const descendants = this.getAllDescendantTasks(child.sessionID) + result.push(...descendants) + } + } + + return result + } + + findBySession(sessionID: string): BackgroundTask | undefined { + for (const task of this.tasks.values()) { + if (task.sessionID === sessionID) { + return task + } + } + return undefined + } + + private getConcurrencyKeyFromInput(input: LaunchInput): string { + if (input.model) { + return `${input.model.providerID}/${input.model.modelID}` + } + return input.agent + } + + /** + * Track a task created elsewhere (e.g., from task) for notification tracking. 
+ * This allows tasks created by other tools to receive the same toast/prompt notifications. + */ + async trackTask(input: { + taskId: string + sessionID: string + parentSessionID: string + description: string + agent?: string + parentAgent?: string + concurrencyKey?: string + }): Promise { + const existingTask = this.tasks.get(input.taskId) + if (existingTask) { + // P2 fix: Clean up old parent's pending set BEFORE changing parent + // Otherwise cleanupPendingByParent would use the new parent ID + const parentChanged = input.parentSessionID !== existingTask.parentSessionID + if (parentChanged) { + this.cleanupPendingByParent(existingTask) // Clean from OLD parent + existingTask.parentSessionID = input.parentSessionID + } + if (input.parentAgent !== undefined) { + existingTask.parentAgent = input.parentAgent + } + if (!existingTask.concurrencyGroup) { + existingTask.concurrencyGroup = input.concurrencyKey ?? existingTask.agent + } + + if (existingTask.sessionID) { + subagentSessions.add(existingTask.sessionID) + } + this.startPolling() + + // Track for batched notifications if task is pending or running + if (existingTask.status === "pending" || existingTask.status === "running") { + const pending = this.pendingByParent.get(input.parentSessionID) ?? new Set() + pending.add(existingTask.id) + this.pendingByParent.set(input.parentSessionID, pending) + } else if (!parentChanged) { + // Only clean up if parent didn't change (already cleaned above if it did) + this.cleanupPendingByParent(existingTask) + } + + log("[background-agent] External task already registered:", { taskId: existingTask.id, sessionID: existingTask.sessionID, status: existingTask.status }) + + return existingTask + } + + const concurrencyGroup = input.concurrencyKey ?? input.agent ?? 
"task" + + // Acquire concurrency slot if a key is provided + if (input.concurrencyKey) { + await this.concurrencyManager.acquire(input.concurrencyKey) + } + + const task: BackgroundTask = { + id: input.taskId, + sessionID: input.sessionID, + parentSessionID: input.parentSessionID, + parentMessageID: "", + description: input.description, + prompt: "", + agent: input.agent || "task", + status: "running", + startedAt: new Date(), + progress: { + toolCalls: 0, + lastUpdate: new Date(), + }, + parentAgent: input.parentAgent, + concurrencyKey: input.concurrencyKey, + concurrencyGroup, + } + + this.tasks.set(task.id, task) + subagentSessions.add(input.sessionID) + this.startPolling() + + if (input.parentSessionID) { + const pending = this.pendingByParent.get(input.parentSessionID) ?? new Set() + pending.add(task.id) + this.pendingByParent.set(input.parentSessionID, pending) + } + + log("[background-agent] Registered external task:", { taskId: task.id, sessionID: input.sessionID }) + + return task } async resume(input: ResumeInput): Promise { - return resumeBackgroundTask({ input, findBySession: (id) => this.findBySession(id), client: this.client, concurrencyManager: this.concurrencyManager, pendingByParent: this.pendingByParent, startPolling: () => this.startPolling(), markForNotification: (task) => this.markForNotification(task), cleanupPendingByParent: (task) => this.cleanupPendingByParent(task), notifyParentSession: (task) => this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) }) + const existingTask = this.findBySession(input.sessionId) + if (!existingTask) { + throw new Error(`Task not found for session: ${input.sessionId}`) + } + + if (!existingTask.sessionID) { + throw new Error(`Task has no sessionID: ${existingTask.id}`) + } + + if (existingTask.status === "running") { + log("[background-agent] Resume skipped - task already running:", { + taskId: existingTask.id, + sessionID: existingTask.sessionID, + }) + return 
existingTask + } + + // Re-acquire concurrency using the persisted concurrency group + const concurrencyKey = existingTask.concurrencyGroup ?? existingTask.agent + await this.concurrencyManager.acquire(concurrencyKey) + existingTask.concurrencyKey = concurrencyKey + existingTask.concurrencyGroup = concurrencyKey + + + existingTask.status = "running" + existingTask.completedAt = undefined + existingTask.error = undefined + existingTask.parentSessionID = input.parentSessionID + existingTask.parentMessageID = input.parentMessageID + existingTask.parentModel = input.parentModel + existingTask.parentAgent = input.parentAgent + // Reset startedAt on resume to prevent immediate completion + // The MIN_IDLE_TIME_MS check uses startedAt, so resumed tasks need fresh timing + existingTask.startedAt = new Date() + + existingTask.progress = { + toolCalls: existingTask.progress?.toolCalls ?? 0, + lastUpdate: new Date(), + } + + this.startPolling() + if (existingTask.sessionID) { + subagentSessions.add(existingTask.sessionID) + } + + if (input.parentSessionID) { + const pending = this.pendingByParent.get(input.parentSessionID) ?? 
new Set() + pending.add(existingTask.id) + this.pendingByParent.set(input.parentSessionID, pending) + } + + const toastManager = getTaskToastManager() + if (toastManager) { + toastManager.addTask({ + id: existingTask.id, + description: existingTask.description, + agent: existingTask.agent, + isBackground: true, + }) + } + + log("[background-agent] Resuming task:", { taskId: existingTask.id, sessionID: existingTask.sessionID }) + + log("[background-agent] Resuming task - calling prompt (fire-and-forget) with:", { + sessionID: existingTask.sessionID, + agent: existingTask.agent, + model: existingTask.model, + promptLength: input.prompt.length, + }) + + // Fire-and-forget prompt via promptAsync (no response body needed) + // Include model if task has one (preserved from original launch with category config) + // variant must be top-level in body, not nested inside model (OpenCode PromptInput schema) + const resumeModel = existingTask.model + ? { providerID: existingTask.model.providerID, modelID: existingTask.model.modelID } + : undefined + const resumeVariant = existingTask.model?.variant + + this.client.session.promptAsync({ + path: { id: existingTask.sessionID }, + body: { + agent: existingTask.agent, + ...(resumeModel ? { model: resumeModel } : {}), + ...(resumeVariant ? { variant: resumeVariant } : {}), + tools: { + ...getAgentToolRestrictions(existingTask.agent), + task: false, + call_omo_agent: true, + question: false, + }, + parts: [{ type: "text", text: input.prompt }], + }, + }).catch((error) => { + log("[background-agent] resume prompt error:", error) + existingTask.status = "error" + const errorMessage = error instanceof Error ? 
error.message : String(error) + existingTask.error = errorMessage + existingTask.completedAt = new Date() + + // Release concurrency on error to prevent slot leaks + if (existingTask.concurrencyKey) { + this.concurrencyManager.release(existingTask.concurrencyKey) + existingTask.concurrencyKey = undefined + } + + // Abort the session to prevent infinite polling hang + if (existingTask.sessionID) { + this.client.session.abort({ + path: { id: existingTask.sessionID }, + }).catch(() => {}) + } + + this.markForNotification(existingTask) + this.cleanupPendingByParent(existingTask) + this.enqueueNotificationForParent(existingTask.parentSessionID, () => this.notifyParentSession(existingTask)).catch(err => { + log("[background-agent] Failed to notify on resume error:", err) + }) + }) + + return existingTask } - getTask(id: string): BackgroundTask | undefined { return this.tasks.get(id) } - getTasksByParentSession(sessionID: string): BackgroundTask[] { return getTasksByParentSession(this.tasks.values(), sessionID) } - getAllDescendantTasks(sessionID: string): BackgroundTask[] { return getAllDescendantTasks((id) => this.getTasksByParentSession(id), sessionID) } - findBySession(sessionID: string): BackgroundTask | undefined { return findTaskBySession(this.tasks.values(), sessionID) } - getRunningTasks(): BackgroundTask[] { return getRunningTasks(this.tasks.values()) } - getCompletedTasks(): BackgroundTask[] { return getCompletedTasks(this.tasks.values()) } + private async checkSessionTodos(sessionID: string): Promise { + try { + const response = await this.client.session.todo({ + path: { id: sessionID }, + }) + const todos = (response.data ?? 
response) as Todo[] + if (!todos || todos.length === 0) return false - markForNotification(task: BackgroundTask): void { markForNotification(this.notifications, task) } - getPendingNotifications(sessionID: string): BackgroundTask[] { return getPendingNotifications(this.notifications, sessionID) } - clearNotifications(sessionID: string): void { clearNotifications(this.notifications, sessionID) } + const incomplete = todos.filter( + (t) => t.status !== "completed" && t.status !== "cancelled" + ) + return incomplete.length > 0 + } catch { + return false + } + } + handleEvent(event: Event): void { + const props = event.properties + + if (event.type === "message.part.updated") { + if (!props || typeof props !== "object" || !("sessionID" in props)) return + const partInfo = props as unknown as MessagePartInfo + const sessionID = partInfo?.sessionID + if (!sessionID) return + + const task = this.findBySession(sessionID) + if (!task) return + + // Clear any pending idle deferral timer since the task is still active + const existingTimer = this.idleDeferralTimers.get(task.id) + if (existingTimer) { + clearTimeout(existingTimer) + this.idleDeferralTimers.delete(task.id) + } + + if (partInfo?.type === "tool" || partInfo?.tool) { + if (!task.progress) { + task.progress = { + toolCalls: 0, + lastUpdate: new Date(), + } + } + task.progress.toolCalls += 1 + task.progress.lastTool = partInfo.tool + task.progress.lastUpdate = new Date() + } + } + + if (event.type === "session.idle") { + const sessionID = props?.sessionID as string | undefined + if (!sessionID) return + + const task = this.findBySession(sessionID) + if (!task || task.status !== "running") return + + const startedAt = task.startedAt + if (!startedAt) return + + // Edge guard: Require minimum elapsed time (5 seconds) before accepting idle + const elapsedMs = Date.now() - startedAt.getTime() + if (elapsedMs < MIN_IDLE_TIME_MS) { + const remainingMs = MIN_IDLE_TIME_MS - elapsedMs + if 
(!this.idleDeferralTimers.has(task.id)) { + log("[background-agent] Deferring early session.idle:", { elapsedMs, remainingMs, taskId: task.id }) + const timer = setTimeout(() => { + this.idleDeferralTimers.delete(task.id) + this.handleEvent({ type: "session.idle", properties: { sessionID } }) + }, remainingMs) + this.idleDeferralTimers.set(task.id, timer) + } else { + log("[background-agent] session.idle already deferred:", { elapsedMs, taskId: task.id }) + } + return + } + + // Edge guard: Verify session has actual assistant output before completing + this.validateSessionHasOutput(sessionID).then(async (hasValidOutput) => { + // Re-check status after async operation (could have been completed by polling) + if (task.status !== "running") { + log("[background-agent] Task status changed during validation, skipping:", { taskId: task.id, status: task.status }) + return + } + + if (!hasValidOutput) { + log("[background-agent] Session.idle but no valid output yet, waiting:", task.id) + return + } + + const hasIncompleteTodos = await this.checkSessionTodos(sessionID) + + // Re-check status after async operation again + if (task.status !== "running") { + log("[background-agent] Task status changed during todo check, skipping:", { taskId: task.id, status: task.status }) + return + } + + if (hasIncompleteTodos) { + log("[background-agent] Task has incomplete todos, waiting for todo-continuation:", task.id) + return + } + + await this.tryCompleteTask(task, "session.idle event") + }).catch(err => { + log("[background-agent] Error in session.idle handler:", err) + }) + } + + if (event.type === "session.deleted") { + const info = props?.info + if (!info || typeof info.id !== "string") return + const sessionID = info.id + + const tasksToCancel = new Map() + const directTask = this.findBySession(sessionID) + if (directTask) { + tasksToCancel.set(directTask.id, directTask) + } + for (const descendant of this.getAllDescendantTasks(sessionID)) { + tasksToCancel.set(descendant.id, 
descendant) + } + + if (tasksToCancel.size === 0) return + + for (const task of tasksToCancel.values()) { + if (task.status === "running" || task.status === "pending") { + void this.cancelTask(task.id, { + source: "session.deleted", + reason: "Session deleted", + skipNotification: true, + }).catch(err => { + log("[background-agent] Failed to cancel task on session.deleted:", { taskId: task.id, error: err }) + }) + } + + const existingTimer = this.completionTimers.get(task.id) + if (existingTimer) { + clearTimeout(existingTimer) + this.completionTimers.delete(task.id) + } + + const idleTimer = this.idleDeferralTimers.get(task.id) + if (idleTimer) { + clearTimeout(idleTimer) + this.idleDeferralTimers.delete(task.id) + } + + this.cleanupPendingByParent(task) + this.tasks.delete(task.id) + this.clearNotificationsForTask(task.id) + if (task.sessionID) { + subagentSessions.delete(task.sessionID) + } + } + } + } + + markForNotification(task: BackgroundTask): void { + const queue = this.notifications.get(task.parentSessionID) ?? [] + queue.push(task) + this.notifications.set(task.parentSessionID, queue) + } + + getPendingNotifications(sessionID: string): BackgroundTask[] { + return this.notifications.get(sessionID) ?? [] + } + + clearNotifications(sessionID: string): void { + this.notifications.delete(sessionID) + } + + /** + * Validates that a session has actual assistant/tool output before marking complete. + * Prevents premature completion when session.idle fires before agent responds. + */ + private async validateSessionHasOutput(sessionID: string): Promise { + try { + const response = await this.client.session.messages({ + path: { id: sessionID }, + }) + + const messages = response.data ?? 
[] + + // Check for at least one assistant or tool message + const hasAssistantOrToolMessage = messages.some( + (m: { info?: { role?: string } }) => + m.info?.role === "assistant" || m.info?.role === "tool" + ) + + if (!hasAssistantOrToolMessage) { + log("[background-agent] No assistant/tool messages found in session:", sessionID) + return false + } + + // Additionally check that at least one message has content (not just empty) + // OpenCode API uses different part types than Anthropic's API: + // - "reasoning" with .text property (thinking/reasoning content) + // - "tool" with .state.output property (tool call results) + // - "text" with .text property (final text output) + // - "step-start"/"step-finish" (metadata, no content) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const hasContent = messages.some((m: any) => { + if (m.info?.role !== "assistant" && m.info?.role !== "tool") return false + const parts = m.parts ?? [] + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return parts.some((p: any) => + // Text content (final output) + (p.type === "text" && p.text && p.text.trim().length > 0) || + // Reasoning content (thinking blocks) + (p.type === "reasoning" && p.text && p.text.trim().length > 0) || + // Tool calls (indicates work was done) + p.type === "tool" || + // Tool results (output from executed tools) - important for tool-only tasks + (p.type === "tool_result" && p.content && + (typeof p.content === "string" ? 
p.content.trim().length > 0 : p.content.length > 0)) + ) + }) + + if (!hasContent) { + log("[background-agent] Messages exist but no content found in session:", sessionID) + return false + } + + return true + } catch (error) { + log("[background-agent] Error validating session output:", error) + // On error, allow completion to proceed (don't block indefinitely) + return true + } + } + + private clearNotificationsForTask(taskId: string): void { + for (const [sessionID, tasks] of this.notifications.entries()) { + const filtered = tasks.filter((t) => t.id !== taskId) + if (filtered.length === 0) { + this.notifications.delete(sessionID) + } else { + this.notifications.set(sessionID, filtered) + } + } + } + + /** + * Remove task from pending tracking for its parent session. + * Cleans up the parent entry if no pending tasks remain. + */ + private cleanupPendingByParent(task: BackgroundTask): void { + if (!task.parentSessionID) return + const pending = this.pendingByParent.get(task.parentSessionID) + if (pending) { + pending.delete(task.id) + if (pending.size === 0) { + this.pendingByParent.delete(task.parentSessionID) + } + } + } + + async cancelTask( + taskId: string, + options?: { source?: string; reason?: string; abortSession?: boolean; skipNotification?: boolean } + ): Promise { + const task = this.tasks.get(taskId) + if (!task || (task.status !== "running" && task.status !== "pending")) { + return false + } + + const source = options?.source ?? "cancel" + const abortSession = options?.abortSession !== false + const reason = options?.reason + + if (task.status === "pending") { + const key = task.model + ? 
`${task.model.providerID}/${task.model.modelID}` + : task.agent + const queue = this.queuesByKey.get(key) + if (queue) { + const index = queue.findIndex(item => item.task.id === taskId) + if (index !== -1) { + queue.splice(index, 1) + if (queue.length === 0) { + this.queuesByKey.delete(key) + } + } + } + log("[background-agent] Cancelled pending task:", { taskId, key }) + } + + task.status = "cancelled" + task.completedAt = new Date() + if (reason) { + task.error = reason + } + + if (task.concurrencyKey) { + this.concurrencyManager.release(task.concurrencyKey) + task.concurrencyKey = undefined + } + + const existingTimer = this.completionTimers.get(task.id) + if (existingTimer) { + clearTimeout(existingTimer) + this.completionTimers.delete(task.id) + } + + const idleTimer = this.idleDeferralTimers.get(task.id) + if (idleTimer) { + clearTimeout(idleTimer) + this.idleDeferralTimers.delete(task.id) + } + + this.cleanupPendingByParent(task) + + if (abortSession && task.sessionID) { + this.client.session.abort({ + path: { id: task.sessionID }, + }).catch(() => {}) + } + + if (options?.skipNotification) { + log(`[background-agent] Task cancelled via ${source} (notification skipped):`, task.id) + return true + } + + this.markForNotification(task) + + try { + await this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) + log(`[background-agent] Task cancelled via ${source}:`, task.id) + } catch (err) { + log("[background-agent] Error in notifyParentSession for cancelled task:", { taskId: task.id, error: err }) + } + + return true + } + + /** + * Cancels a pending task by removing it from queue and marking as cancelled. + * Does NOT abort session (no session exists yet) or release concurrency slot (wasn't acquired). 
+ */ cancelPendingTask(taskId: string): boolean { const task = this.tasks.get(taskId) - if (!task || task.status !== "pending") return false + if (!task || task.status !== "pending") { + return false + } + void this.cancelTask(taskId, { source: "cancelPendingTask", abortSession: false }) return true } - async cancelTask(taskId: string, options?: { source?: string; reason?: string; abortSession?: boolean; skipNotification?: boolean }): Promise { - return cancelBackgroundTask({ taskId, options, tasks: this.tasks, queuesByKey: this.queuesByKey, completionTimers: this.completionTimers, idleDeferralTimers: this.idleDeferralTimers, concurrencyManager: this.concurrencyManager, client: this.client, cleanupPendingByParent: (task) => this.cleanupPendingByParent(task), markForNotification: (task) => this.markForNotification(task), notifyParentSession: (task) => this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) }) - } - - handleEvent(event: { type: string; properties?: Record }): void { - handleBackgroundEvent({ event, findBySession: (id) => this.findBySession(id), getAllDescendantTasks: (id) => this.getAllDescendantTasks(id), cancelTask: (id, opts) => this.cancelTask(id, opts), tryCompleteTask: (task, source) => this.tryCompleteTask(task, source), validateSessionHasOutput: (id) => this.validateSessionHasOutput(id), checkSessionTodos: (id) => this.checkSessionTodos(id), idleDeferralTimers: this.idleDeferralTimers, completionTimers: this.completionTimers, tasks: this.tasks, cleanupPendingByParent: (task) => this.cleanupPendingByParent(task), clearNotificationsForTask: (id) => this.clearNotificationsForTask(id), emitIdleEvent: (sessionID) => this.handleEvent({ type: "session.idle", properties: { sessionID } }) }) - } - - shutdown(): void { - this.notificationQueueByParent.clear() - shutdownBackgroundManager({ shutdownTriggered: this.shutdownTriggered, stopPolling: () => this.stopPolling(), tasks: this.tasks, client: this.client, 
onShutdown: this.onShutdown, concurrencyManager: this.concurrencyManager, completionTimers: this.completionTimers, idleDeferralTimers: this.idleDeferralTimers, notifications: this.notifications, pendingByParent: this.pendingByParent, queuesByKey: this.queuesByKey, processingKeys: this.processingKeys, unregisterProcessCleanup: () => this.unregisterProcessCleanup() }) - } - - private getConcurrencyKeyFromInput(input: LaunchInput): string { return input.model ? `${input.model.providerID}/${input.model.modelID}` : input.agent } - private async processKey(key: string): Promise { await processConcurrencyKeyQueue({ key, queuesByKey: this.queuesByKey, processingKeys: this.processingKeys, concurrencyManager: this.concurrencyManager, startTask: (item) => this.startTask(item) }) } - private async startTask(item: QueueItem): Promise { - await startQueuedTask({ item, client: this.client, defaultDirectory: this.directory, tmuxEnabled: this.tmuxEnabled, onSubagentSessionCreated: this.onSubagentSessionCreated, startPolling: () => this.startPolling(), getConcurrencyKeyFromInput: (i) => this.getConcurrencyKeyFromInput(i), concurrencyManager: this.concurrencyManager, findBySession: (id) => this.findBySession(id), markForNotification: (task) => this.markForNotification(task), cleanupPendingByParent: (task) => this.cleanupPendingByParent(task), notifyParentSession: (task) => this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) }) - } - private startPolling(): void { if (this.pollingInterval) return - this.pollingInterval = setInterval(() => void this.pollRunningTasks(), POLLING_INTERVAL_MS) + + this.pollingInterval = setInterval(() => { + this.pollRunningTasks() + }, POLLING_INTERVAL_MS) this.pollingInterval.unref() } - private stopPolling(): void { if (this.pollingInterval) { clearInterval(this.pollingInterval); this.pollingInterval = undefined } } - private async pollRunningTasks(): Promise { - await pollRunningTasks({ tasks: 
this.tasks.values(), client: this.client, pruneStaleTasksAndNotifications: () => this.pruneStaleTasksAndNotifications(), checkAndInterruptStaleTasks: () => this.checkAndInterruptStaleTasks(), validateSessionHasOutput: (id) => this.validateSessionHasOutput(id), checkSessionTodos: (id) => this.checkSessionTodos(id), tryCompleteTask: (task, source) => this.tryCompleteTask(task, source), hasRunningTasks: () => this.hasRunningTasks(), stopPolling: () => this.stopPolling() }) + private stopPolling(): void { + if (this.pollingInterval) { + clearInterval(this.pollingInterval) + this.pollingInterval = undefined + } + } + + private registerProcessCleanup(): void { + BackgroundManager.cleanupManagers.add(this) + + if (BackgroundManager.cleanupRegistered) return + BackgroundManager.cleanupRegistered = true + + const cleanupAll = () => { + for (const manager of BackgroundManager.cleanupManagers) { + try { + manager.shutdown() + } catch (error) { + log("[background-agent] Error during shutdown cleanup:", error) + } + } + } + + const registerSignal = (signal: ProcessCleanupEvent, exitAfter: boolean): void => { + const listener = registerProcessSignal(signal, cleanupAll, exitAfter) + BackgroundManager.cleanupHandlers.set(signal, listener) + } + + registerSignal("SIGINT", true) + registerSignal("SIGTERM", true) + if (process.platform === "win32") { + registerSignal("SIGBREAK", true) + } + registerSignal("beforeExit", false) + registerSignal("exit", false) + } + + private unregisterProcessCleanup(): void { + BackgroundManager.cleanupManagers.delete(this) + + if (BackgroundManager.cleanupManagers.size > 0) return + + for (const [signal, listener] of BackgroundManager.cleanupHandlers.entries()) { + process.off(signal, listener) + } + BackgroundManager.cleanupHandlers.clear() + BackgroundManager.cleanupRegistered = false + } + + + /** + * Get all running tasks (for compaction hook) + */ + getRunningTasks(): BackgroundTask[] { + return Array.from(this.tasks.values()).filter(t => 
t.status === "running") + } + + /** + * Get all completed tasks still in memory (for compaction hook) + */ + getCompletedTasks(): BackgroundTask[] { + return Array.from(this.tasks.values()).filter(t => t.status !== "running") + } + + /** + * Safely complete a task with race condition protection. + * Returns true if task was successfully completed, false if already completed by another path. + */ + private async tryCompleteTask(task: BackgroundTask, source: string): Promise { + // Guard: Check if task is still running (could have been completed by another path) + if (task.status !== "running") { + log("[background-agent] Task already completed, skipping:", { taskId: task.id, status: task.status, source }) + return false + } + + // Atomically mark as completed to prevent race conditions + task.status = "completed" + task.completedAt = new Date() + + // Release concurrency BEFORE any async operations to prevent slot leaks + if (task.concurrencyKey) { + this.concurrencyManager.release(task.concurrencyKey) + task.concurrencyKey = undefined + } + + this.markForNotification(task) + + // Ensure pending tracking is cleaned up even if notification fails + this.cleanupPendingByParent(task) + + const idleTimer = this.idleDeferralTimers.get(task.id) + if (idleTimer) { + clearTimeout(idleTimer) + this.idleDeferralTimers.delete(task.id) + } + + if (task.sessionID) { + this.client.session.abort({ + path: { id: task.sessionID }, + }).catch(() => {}) + } + + try { + await this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) + log(`[background-agent] Task completed via ${source}:`, task.id) + } catch (err) { + log("[background-agent] Error in notifyParentSession:", { taskId: task.id, error: err }) + // Concurrency already released, notification failed but task is complete + } + + return true + } + + private async notifyParentSession(task: BackgroundTask): Promise { + // Note: Callers must release concurrency before calling this method + // to 
ensure slots are freed even if notification fails + + const duration = this.formatDuration(task.startedAt ?? new Date(), task.completedAt) + + log("[background-agent] notifyParentSession called for task:", task.id) + + // Show toast notification + const toastManager = getTaskToastManager() + if (toastManager) { + toastManager.showCompletionToast({ + id: task.id, + description: task.description, + duration, + }) + } + + // Update pending tracking and check if all tasks complete + const pendingSet = this.pendingByParent.get(task.parentSessionID) + let allComplete = false + let remainingCount = 0 + if (pendingSet) { + pendingSet.delete(task.id) + remainingCount = pendingSet.size + allComplete = remainingCount === 0 + if (allComplete) { + this.pendingByParent.delete(task.parentSessionID) + } + } else { + allComplete = true + } + + const statusText = task.status === "completed" ? "COMPLETED" : "CANCELLED" + const errorInfo = task.error ? `\n**Error:** ${task.error}` : "" + + let notification: string + let completedTasks: BackgroundTask[] = [] + if (allComplete) { + completedTasks = Array.from(this.tasks.values()) + .filter(t => t.parentSessionID === task.parentSessionID && t.status !== "running" && t.status !== "pending") + const completedTasksText = completedTasks + .map(t => `- \`${t.id}\`: ${t.description}`) + .join("\n") + + notification = ` +[ALL BACKGROUND TASKS COMPLETE] + +**Completed:** +${completedTasksText || `- \`${task.id}\`: ${task.description}`} + +Use \`background_output(task_id="")\` to retrieve each result. +` + } else { + // Individual completion - silent notification + notification = ` +[BACKGROUND TASK ${statusText}] +**ID:** \`${task.id}\` +**Description:** ${task.description} +**Duration:** ${duration}${errorInfo} + +**${remainingCount} task${remainingCount === 1 ? "" : "s"} still in progress.** You WILL be notified when ALL complete. +Do NOT poll - continue productive work. 
+ +Use \`background_output(task_id="${task.id}")\` to retrieve this result when ready. +` + } + + let agent: string | undefined = task.parentAgent + let model: { providerID: string; modelID: string } | undefined + + try { + const messagesResp = await this.client.session.messages({ path: { id: task.parentSessionID } }) + const messages = (messagesResp.data ?? []) as Array<{ + info?: { agent?: string; model?: { providerID: string; modelID: string }; modelID?: string; providerID?: string } + }> + for (let i = messages.length - 1; i >= 0; i--) { + const info = messages[i].info + if (info?.agent || info?.model || (info?.modelID && info?.providerID)) { + agent = info.agent ?? task.parentAgent + model = info.model ?? (info.providerID && info.modelID ? { providerID: info.providerID, modelID: info.modelID } : undefined) + break + } + } + } catch (error) { + if (this.isAbortedSessionError(error)) { + log("[background-agent] Parent session aborted, skipping notification:", { + taskId: task.id, + parentSessionID: task.parentSessionID, + }) + return + } + const messageDir = getMessageDir(task.parentSessionID) + const currentMessage = messageDir ? findNearestMessageWithFields(messageDir) : null + agent = currentMessage?.agent ?? task.parentAgent + model = currentMessage?.model?.providerID && currentMessage?.model?.modelID + ? { providerID: currentMessage.model.providerID, modelID: currentMessage.model.modelID } + : undefined + } + + log("[background-agent] notifyParentSession context:", { + taskId: task.id, + resolvedAgent: agent, + resolvedModel: model, + }) + + try { + await this.client.session.promptAsync({ + path: { id: task.parentSessionID }, + body: { + noReply: !allComplete, + ...(agent !== undefined ? { agent } : {}), + ...(model !== undefined ? 
{ model } : {}), + parts: [{ type: "text", text: notification }], + }, + }) + log("[background-agent] Sent notification to parent session:", { + taskId: task.id, + allComplete, + noReply: !allComplete, + }) + } catch (error) { + if (this.isAbortedSessionError(error)) { + log("[background-agent] Parent session aborted, skipping notification:", { + taskId: task.id, + parentSessionID: task.parentSessionID, + }) + return + } + log("[background-agent] Failed to send notification:", error) + } + + if (allComplete) { + for (const completedTask of completedTasks) { + const taskId = completedTask.id + const existingTimer = this.completionTimers.get(taskId) + if (existingTimer) { + clearTimeout(existingTimer) + this.completionTimers.delete(taskId) + } + const timer = setTimeout(() => { + this.completionTimers.delete(taskId) + if (this.tasks.has(taskId)) { + this.clearNotificationsForTask(taskId) + this.tasks.delete(taskId) + log("[background-agent] Removed completed task from memory:", taskId) + } + }, TASK_CLEANUP_DELAY_MS) + this.completionTimers.set(taskId, timer) + } + } + } + + private formatDuration(start: Date, end?: Date): string { + const duration = (end ?? 
new Date()).getTime() - start.getTime() + const seconds = Math.floor(duration / 1000) + const minutes = Math.floor(seconds / 60) + const hours = Math.floor(minutes / 60) + + if (hours > 0) { + return `${hours}h ${minutes % 60}m ${seconds % 60}s` + } else if (minutes > 0) { + return `${minutes}m ${seconds % 60}s` + } + return `${seconds}s` + } + + private isAbortedSessionError(error: unknown): boolean { + const message = this.getErrorText(error) + return message.toLowerCase().includes("aborted") + } + + private getErrorText(error: unknown): string { + if (!error) return "" + if (typeof error === "string") return error + if (error instanceof Error) { + return `${error.name}: ${error.message}` + } + if (typeof error === "object" && error !== null) { + if ("message" in error && typeof error.message === "string") { + return error.message + } + if ("name" in error && typeof error.name === "string") { + return error.name + } + } + return "" + } + + private hasRunningTasks(): boolean { + for (const task of this.tasks.values()) { + if (task.status === "running") return true + } + return false } private pruneStaleTasksAndNotifications(): void { - pruneStaleState({ tasks: this.tasks, notifications: this.notifications, concurrencyManager: this.concurrencyManager, cleanupPendingByParent: (task) => this.cleanupPendingByParent(task), clearNotificationsForTask: (id) => this.clearNotificationsForTask(id) }) + const now = Date.now() + + for (const [taskId, task] of this.tasks.entries()) { + const timestamp = task.status === "pending" + ? task.queuedAt?.getTime() + : task.startedAt?.getTime() + + if (!timestamp) { + continue + } + + const age = now - timestamp + if (age > TASK_TTL_MS) { + const errorMessage = task.status === "pending" + ? 
"Task timed out while queued (30 minutes)" + : "Task timed out after 30 minutes" + + log("[background-agent] Pruning stale task:", { taskId, status: task.status, age: Math.round(age / 1000) + "s" }) + task.status = "error" + task.error = errorMessage + task.completedAt = new Date() + if (task.concurrencyKey) { + this.concurrencyManager.release(task.concurrencyKey) + task.concurrencyKey = undefined + } + // Clean up pendingByParent to prevent stale entries + this.cleanupPendingByParent(task) + this.clearNotificationsForTask(taskId) + this.tasks.delete(taskId) + if (task.sessionID) { + subagentSessions.delete(task.sessionID) + } + } + } + + for (const [sessionID, notifications] of this.notifications.entries()) { + if (notifications.length === 0) { + this.notifications.delete(sessionID) + continue + } + const validNotifications = notifications.filter((task) => { + if (!task.startedAt) return false + const age = now - task.startedAt.getTime() + return age <= TASK_TTL_MS + }) + if (validNotifications.length === 0) { + this.notifications.delete(sessionID) + } else if (validNotifications.length !== notifications.length) { + this.notifications.set(sessionID, validNotifications) + } + } } + private async checkAndInterruptStaleTasks(): Promise { - await checkAndInterruptStaleTasks({ tasks: this.tasks.values(), client: this.client, config: this.config, concurrencyManager: this.concurrencyManager, notifyParentSession: (task) => this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) }) + const staleTimeoutMs = this.config?.staleTimeoutMs ?? 
DEFAULT_STALE_TIMEOUT_MS + const now = Date.now() + + for (const task of this.tasks.values()) { + if (task.status !== "running") continue + if (!task.progress?.lastUpdate) continue + + const startedAt = task.startedAt + const sessionID = task.sessionID + if (!startedAt || !sessionID) continue + + const runtime = now - startedAt.getTime() + if (runtime < MIN_RUNTIME_BEFORE_STALE_MS) continue + + const timeSinceLastUpdate = now - task.progress.lastUpdate.getTime() + if (timeSinceLastUpdate <= staleTimeoutMs) continue + + if (task.status !== "running") continue + + const staleMinutes = Math.round(timeSinceLastUpdate / 60000) + task.status = "cancelled" + task.error = `Stale timeout (no activity for ${staleMinutes}min)` + task.completedAt = new Date() + + if (task.concurrencyKey) { + this.concurrencyManager.release(task.concurrencyKey) + task.concurrencyKey = undefined + } + + this.client.session.abort({ + path: { id: sessionID }, + }).catch(() => {}) + + log(`[background-agent] Task ${task.id} interrupted: stale timeout`) + + try { + await this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) + } catch (err) { + log("[background-agent] Error in notifyParentSession for stale task:", { taskId: task.id, error: err }) + } + } } - private hasRunningTasks(): boolean { return hasRunningTasks(this.tasks.values()) } - private async tryCompleteTask(task: BackgroundTask, source: string): Promise { - return tryCompleteBackgroundTask({ task, source, concurrencyManager: this.concurrencyManager, idleDeferralTimers: this.idleDeferralTimers, client: this.client, markForNotification: (t) => this.markForNotification(t), cleanupPendingByParent: (t) => this.cleanupPendingByParent(t), notifyParentSession: (t) => this.enqueueNotificationForParent(t.parentSessionID, () => this.notifyParentSession(t)) }) - } - private async notifyParentSession(task: BackgroundTask): Promise { - await notifyParentSessionInternal({ task, tasks: this.tasks, 
pendingByParent: this.pendingByParent, completionTimers: this.completionTimers, clearNotificationsForTask: (id) => this.clearNotificationsForTask(id), client: this.client }) + private async pollRunningTasks(): Promise { + this.pruneStaleTasksAndNotifications() + await this.checkAndInterruptStaleTasks() + + const statusResult = await this.client.session.status() + const allStatuses = (statusResult.data ?? {}) as Record + + for (const task of this.tasks.values()) { + if (task.status !== "running") continue + + const sessionID = task.sessionID + if (!sessionID) continue + + try { + const sessionStatus = allStatuses[sessionID] + + // Don't skip if session not in status - fall through to message-based detection + if (sessionStatus?.type === "idle") { + // Edge guard: Validate session has actual output before completing + const hasValidOutput = await this.validateSessionHasOutput(sessionID) + if (!hasValidOutput) { + log("[background-agent] Polling idle but no valid output yet, waiting:", task.id) + continue + } + + // Re-check status after async operation + if (task.status !== "running") continue + + const hasIncompleteTodos = await this.checkSessionTodos(sessionID) + if (hasIncompleteTodos) { + log("[background-agent] Task has incomplete todos via polling, waiting:", task.id) + continue + } + + await this.tryCompleteTask(task, "polling (idle status)") + continue + } + + const messagesResult = await this.client.session.messages({ + path: { id: sessionID }, + }) + + if (!messagesResult.error && messagesResult.data) { + const messages = messagesResult.data as Array<{ + info?: { role?: string } + parts?: Array<{ type?: string; tool?: string; name?: string; text?: string }> + }> + const assistantMsgs = messages.filter( + (m) => m.info?.role === "assistant" + ) + + let toolCalls = 0 + let lastTool: string | undefined + let lastMessage: string | undefined + + for (const msg of assistantMsgs) { + const parts = msg.parts ?? 
[] + for (const part of parts) { + if (part.type === "tool_use" || part.tool) { + toolCalls++ + lastTool = part.tool || part.name || "unknown" + } + if (part.type === "text" && part.text) { + lastMessage = part.text + } + } + } + + if (!task.progress) { + task.progress = { toolCalls: 0, lastUpdate: new Date() } + } + task.progress.toolCalls = toolCalls + task.progress.lastTool = lastTool + task.progress.lastUpdate = new Date() + if (lastMessage) { + task.progress.lastMessage = lastMessage + task.progress.lastMessageAt = new Date() + } + + // Stability detection: complete when message count unchanged for 3 polls + const currentMsgCount = messages.length + const startedAt = task.startedAt + if (!startedAt) continue + + const elapsedMs = Date.now() - startedAt.getTime() + + if (elapsedMs >= MIN_STABILITY_TIME_MS) { + if (task.lastMsgCount === currentMsgCount) { + task.stablePolls = (task.stablePolls ?? 0) + 1 + if (task.stablePolls >= 3) { + // Re-fetch session status to confirm agent is truly idle + const recheckStatus = await this.client.session.status() + const recheckData = (recheckStatus.data ?? {}) as Record + const currentStatus = recheckData[sessionID] + + if (currentStatus?.type !== "idle") { + log("[background-agent] Stability reached but session not idle, resetting:", { + taskId: task.id, + sessionStatus: currentStatus?.type ?? 
"not_in_status" + }) + task.stablePolls = 0 + continue + } + + // Edge guard: Validate session has actual output before completing + const hasValidOutput = await this.validateSessionHasOutput(sessionID) + if (!hasValidOutput) { + log("[background-agent] Stability reached but no valid output, waiting:", task.id) + continue + } + + // Re-check status after async operation + if (task.status !== "running") continue + + const hasIncompleteTodos = await this.checkSessionTodos(sessionID) + if (!hasIncompleteTodos) { + await this.tryCompleteTask(task, "stability detection") + continue + } + } + } else { + task.stablePolls = 0 + } + } + task.lastMsgCount = currentMsgCount + } + } catch (error) { + log("[background-agent] Poll error for task:", { taskId: task.id, error }) + } + } + + if (!this.hasRunningTasks()) { + this.stopPolling() + } } - private enqueueNotificationForParent(parentSessionID: string | undefined, operation: () => Promise): Promise { - if (!parentSessionID) return operation() + /** + * Shutdown the manager gracefully. + * Cancels all pending concurrency waiters and clears timers. + * Should be called when the plugin is unloaded. 
+ */ + shutdown(): void { + if (this.shutdownTriggered) return + this.shutdownTriggered = true + log("[background-agent] Shutting down BackgroundManager") + this.stopPolling() + + // Abort all running sessions to prevent zombie processes (#1240) + for (const task of this.tasks.values()) { + if (task.status === "running" && task.sessionID) { + this.client.session.abort({ + path: { id: task.sessionID }, + }).catch(() => {}) + } + } + + // Notify shutdown listeners (e.g., tmux cleanup) + if (this.onShutdown) { + try { + this.onShutdown() + } catch (error) { + log("[background-agent] Error in onShutdown callback:", error) + } + } + + // Release concurrency for all running tasks + for (const task of this.tasks.values()) { + if (task.concurrencyKey) { + this.concurrencyManager.release(task.concurrencyKey) + task.concurrencyKey = undefined + } + } + + for (const timer of this.completionTimers.values()) { + clearTimeout(timer) + } + this.completionTimers.clear() + + for (const timer of this.idleDeferralTimers.values()) { + clearTimeout(timer) + } + this.idleDeferralTimers.clear() + + this.concurrencyManager.clear() + this.tasks.clear() + this.notifications.clear() + this.pendingByParent.clear() + this.notificationQueueByParent.clear() + this.queuesByKey.clear() + this.processingKeys.clear() + this.unregisterProcessCleanup() + log("[background-agent] Shutdown complete") + + } + + private enqueueNotificationForParent( + parentSessionID: string | undefined, + operation: () => Promise + ): Promise { + if (!parentSessionID) { + return operation() + } const previous = this.notificationQueueByParent.get(parentSessionID) ?? 
Promise.resolve() const current = previous @@ -149,36 +1603,44 @@ export class BackgroundManager { this.notificationQueueByParent.set(parentSessionID, current) - void current - .finally(() => { - if (this.notificationQueueByParent.get(parentSessionID) === current) { - this.notificationQueueByParent.delete(parentSessionID) - } - }) - .catch(() => {}) + void current.finally(() => { + if (this.notificationQueueByParent.get(parentSessionID) === current) { + this.notificationQueueByParent.delete(parentSessionID) + } + }).catch(() => {}) return current } - - private async validateSessionHasOutput(sessionID: string): Promise { return validateSessionHasOutput(this.client, sessionID) } - private async checkSessionTodos(sessionID: string): Promise { return checkSessionTodos(this.client, sessionID) } - private clearNotificationsForTask(taskId: string): void { clearNotificationsForTask(this.notifications, taskId) } - private cleanupPendingByParent(task: BackgroundTask): void { cleanupPendingByParent(this.pendingByParent, task) } - - private registerProcessCleanup(): void { - BackgroundManager.cleanupManagers.add(this) - if (BackgroundManager.cleanupRegistered) return - BackgroundManager.cleanupRegistered = true - const cleanupAll = () => { for (const manager of BackgroundManager.cleanupManagers) { try { manager.shutdown() } catch (error) { log("[background-agent] Error during shutdown cleanup:", error) } } } - const registerSignal = (signal: ProcessCleanupEvent, exitAfter: boolean) => { const listener = registerProcessSignal(signal, cleanupAll, exitAfter); BackgroundManager.cleanupHandlers.set(signal, listener) } - registerSignal("SIGINT", true); registerSignal("SIGTERM", true); if (process.platform === "win32") registerSignal("SIGBREAK", true) - registerSignal("beforeExit", false); registerSignal("exit", false) - } - - private unregisterProcessCleanup(): void { - BackgroundManager.cleanupManagers.delete(this) - if (BackgroundManager.cleanupManagers.size > 0) return - for 
(const [signal, listener] of BackgroundManager.cleanupHandlers.entries()) process.off(signal, listener) - BackgroundManager.cleanupHandlers.clear(); BackgroundManager.cleanupRegistered = false - } +} + +function registerProcessSignal( + signal: ProcessCleanupEvent, + handler: () => void, + exitAfter: boolean +): () => void { + const listener = () => { + handler() + if (exitAfter) { + // Set exitCode and schedule exit after delay to allow other handlers to complete async cleanup + // Use 6s delay to accommodate LSP cleanup (5s timeout + 1s SIGKILL wait) + process.exitCode = 0 + setTimeout(() => process.exit(), 6000) + } + } + process.on(signal, listener) + return listener +} + + +function getMessageDir(sessionID: string): string | null { + if (!existsSync(MESSAGE_STORAGE)) return null + + const directPath = join(MESSAGE_STORAGE, sessionID) + if (existsSync(directPath)) return directPath + + for (const dir of readdirSync(MESSAGE_STORAGE)) { + const sessionPath = join(MESSAGE_STORAGE, dir, sessionID) + if (existsSync(sessionPath)) return sessionPath + } + return null } diff --git a/src/features/background-agent/spawner.test.ts b/src/features/background-agent/spawner.test.ts new file mode 100644 index 00000000..334f3762 --- /dev/null +++ b/src/features/background-agent/spawner.test.ts @@ -0,0 +1,65 @@ +import { describe, test, expect } from "bun:test" + +import { createTask, startTask } from "./spawner" + +describe("background-agent spawner.startTask", () => { + test("should inherit parent session permission rules (and force deny question)", async () => { + //#given + const createCalls: any[] = [] + const parentPermission = [ + { permission: "question", action: "allow" as const, pattern: "*" }, + { permission: "plan_enter", action: "deny" as const, pattern: "*" }, + ] + + const client = { + session: { + get: async () => ({ data: { directory: "/parent/dir", permission: parentPermission } }), + create: async (args?: any) => { + createCalls.push(args) + return { data: { 
id: "ses_child" } } + }, + promptAsync: async () => ({}), + }, + } + + const task = createTask({ + description: "Test task", + prompt: "Do work", + agent: "explore", + parentSessionID: "ses_parent", + parentMessageID: "msg_parent", + }) + + const item = { + task, + input: { + description: task.description, + prompt: task.prompt, + agent: task.agent, + parentSessionID: task.parentSessionID, + parentMessageID: task.parentMessageID, + parentModel: task.parentModel, + parentAgent: task.parentAgent, + model: task.model, + }, + } + + const ctx = { + client, + directory: "/fallback", + concurrencyManager: { release: () => {} }, + tmuxEnabled: false, + onTaskError: () => {}, + } + + //#when + await startTask(item as any, ctx as any) + + //#then + expect(createCalls).toHaveLength(1) + expect(createCalls[0]?.body?.permission).toEqual([ + { permission: "plan_enter", action: "deny", pattern: "*" }, + { permission: "question", action: "deny", pattern: "*" }, + ]) + }) +}) diff --git a/src/features/background-agent/spawner.ts b/src/features/background-agent/spawner.ts index ddf0e153..1b6773fb 100644 --- a/src/features/background-agent/spawner.ts +++ b/src/features/background-agent/spawner.ts @@ -1,4 +1,246 @@ -export type { SpawnerContext } from "./spawner/spawner-context" -export { createTask } from "./spawner/task-factory" -export { startTask } from "./spawner/task-starter" -export { resumeTask } from "./spawner/task-resumer" +import type { BackgroundTask, LaunchInput, ResumeInput } from "./types" +import type { OpencodeClient, OnSubagentSessionCreated, QueueItem } from "./constants" +import { TMUX_CALLBACK_DELAY_MS } from "./constants" +import { log, getAgentToolRestrictions, promptWithModelSuggestionRetry } from "../../shared" +import { subagentSessions } from "../claude-code-session-state" +import { getTaskToastManager } from "../task-toast-manager" +import { isInsideTmux } from "../../shared/tmux" +import type { ConcurrencyManager } from "./concurrency" + +export interface 
SpawnerContext { + client: OpencodeClient + directory: string + concurrencyManager: ConcurrencyManager + tmuxEnabled: boolean + onSubagentSessionCreated?: OnSubagentSessionCreated + onTaskError: (task: BackgroundTask, error: Error) => void +} + +export function createTask(input: LaunchInput): BackgroundTask { + return { + id: `bg_${crypto.randomUUID().slice(0, 8)}`, + status: "pending", + queuedAt: new Date(), + description: input.description, + prompt: input.prompt, + agent: input.agent, + parentSessionID: input.parentSessionID, + parentMessageID: input.parentMessageID, + parentModel: input.parentModel, + parentAgent: input.parentAgent, + model: input.model, + } +} + +export async function startTask( + item: QueueItem, + ctx: SpawnerContext +): Promise { + const { task, input } = item + const { client, directory, concurrencyManager, tmuxEnabled, onSubagentSessionCreated, onTaskError } = ctx + + log("[background-agent] Starting task:", { + taskId: task.id, + agent: input.agent, + model: input.model, + }) + + const concurrencyKey = input.model + ? `${input.model.providerID}/${input.model.modelID}` + : input.agent + + const parentSession = await client.session.get({ + path: { id: input.parentSessionID }, + }).catch((err) => { + log(`[background-agent] Failed to get parent session: ${err}`) + return null + }) + const parentDirectory = parentSession?.data?.directory ?? directory + log(`[background-agent] Parent dir: ${parentSession?.data?.directory}, using: ${parentDirectory}`) + + const inheritedPermission = (parentSession as any)?.data?.permission + const permissionRules = Array.isArray(inheritedPermission) + ? 
inheritedPermission.filter((r: any) => r?.permission !== "question") + : [] + permissionRules.push({ permission: "question", action: "deny" as const, pattern: "*" }) + + const createResult = await client.session.create({ + body: { + parentID: input.parentSessionID, + title: `Background: ${input.description}`, + permission: permissionRules, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any, + query: { + directory: parentDirectory, + }, + }).catch((error) => { + concurrencyManager.release(concurrencyKey) + throw error + }) + + if (createResult.error) { + concurrencyManager.release(concurrencyKey) + throw new Error(`Failed to create background session: ${createResult.error}`) + } + + const sessionID = createResult.data.id + subagentSessions.add(sessionID) + + log("[background-agent] tmux callback check", { + hasCallback: !!onSubagentSessionCreated, + tmuxEnabled, + isInsideTmux: isInsideTmux(), + sessionID, + parentID: input.parentSessionID, + }) + + if (onSubagentSessionCreated && tmuxEnabled && isInsideTmux()) { + log("[background-agent] Invoking tmux callback NOW", { sessionID }) + await onSubagentSessionCreated({ + sessionID, + parentID: input.parentSessionID, + title: input.description, + }).catch((err) => { + log("[background-agent] Failed to spawn tmux pane:", err) + }) + log("[background-agent] tmux callback completed, waiting") + await new Promise(r => setTimeout(r, TMUX_CALLBACK_DELAY_MS)) + } else { + log("[background-agent] SKIP tmux callback - conditions not met") + } + + task.status = "running" + task.startedAt = new Date() + task.sessionID = sessionID + task.progress = { + toolCalls: 0, + lastUpdate: new Date(), + } + task.concurrencyKey = concurrencyKey + task.concurrencyGroup = concurrencyKey + + log("[background-agent] Launching task:", { taskId: task.id, sessionID, agent: input.agent }) + + const toastManager = getTaskToastManager() + if (toastManager) { + toastManager.updateTask(task.id, "running") + } + + 
log("[background-agent] Calling prompt (fire-and-forget) for launch with:", { + sessionID, + agent: input.agent, + model: input.model, + hasSkillContent: !!input.skillContent, + promptLength: input.prompt.length, + }) + + const launchModel = input.model + ? { providerID: input.model.providerID, modelID: input.model.modelID } + : undefined + const launchVariant = input.model?.variant + + promptWithModelSuggestionRetry(client, { + path: { id: sessionID }, + body: { + agent: input.agent, + ...(launchModel ? { model: launchModel } : {}), + ...(launchVariant ? { variant: launchVariant } : {}), + system: input.skillContent, + tools: { + ...getAgentToolRestrictions(input.agent), + task: false, + call_omo_agent: true, + question: false, + }, + parts: [{ type: "text", text: input.prompt }], + }, + }).catch((error) => { + log("[background-agent] promptAsync error:", error) + onTaskError(task, error instanceof Error ? error : new Error(String(error))) + }) +} + +export async function resumeTask( + task: BackgroundTask, + input: ResumeInput, + ctx: Pick +): Promise { + const { client, concurrencyManager, onTaskError } = ctx + + if (!task.sessionID) { + throw new Error(`Task has no sessionID: ${task.id}`) + } + + if (task.status === "running") { + log("[background-agent] Resume skipped - task already running:", { + taskId: task.id, + sessionID: task.sessionID, + }) + return + } + + const concurrencyKey = task.concurrencyGroup ?? task.agent + await concurrencyManager.acquire(concurrencyKey) + task.concurrencyKey = concurrencyKey + task.concurrencyGroup = concurrencyKey + + task.status = "running" + task.completedAt = undefined + task.error = undefined + task.parentSessionID = input.parentSessionID + task.parentMessageID = input.parentMessageID + task.parentModel = input.parentModel + task.parentAgent = input.parentAgent + task.startedAt = new Date() + + task.progress = { + toolCalls: task.progress?.toolCalls ?? 
0, + lastUpdate: new Date(), + } + + subagentSessions.add(task.sessionID) + + const toastManager = getTaskToastManager() + if (toastManager) { + toastManager.addTask({ + id: task.id, + description: task.description, + agent: task.agent, + isBackground: true, + }) + } + + log("[background-agent] Resuming task:", { taskId: task.id, sessionID: task.sessionID }) + + log("[background-agent] Resuming task - calling prompt (fire-and-forget) with:", { + sessionID: task.sessionID, + agent: task.agent, + model: task.model, + promptLength: input.prompt.length, + }) + + const resumeModel = task.model + ? { providerID: task.model.providerID, modelID: task.model.modelID } + : undefined + const resumeVariant = task.model?.variant + + client.session.promptAsync({ + path: { id: task.sessionID }, + body: { + agent: task.agent, + ...(resumeModel ? { model: resumeModel } : {}), + ...(resumeVariant ? { variant: resumeVariant } : {}), + tools: { + ...getAgentToolRestrictions(task.agent), + task: false, + call_omo_agent: true, + question: false, + }, + parts: [{ type: "text", text: input.prompt }], + }, + }).catch((error) => { + log("[background-agent] resume prompt error:", error) + onTaskError(task, error instanceof Error ? 
error : new Error(String(error))) + }) +} diff --git a/src/features/tmux-subagent/manager.ts b/src/features/tmux-subagent/manager.ts index bc973eec..5bd8d6e8 100644 --- a/src/features/tmux-subagent/manager.ts +++ b/src/features/tmux-subagent/manager.ts @@ -4,20 +4,23 @@ import type { TrackedSession, CapacityConfig } from "./types" import { isInsideTmux as defaultIsInsideTmux, getCurrentPaneId as defaultGetCurrentPaneId, + POLL_INTERVAL_BACKGROUND_MS, + SESSION_MISSING_GRACE_MS, + SESSION_READY_POLL_INTERVAL_MS, + SESSION_READY_TIMEOUT_MS, } from "../../shared/tmux" import { log } from "../../shared" -import type { SessionMapping } from "./decision-engine" -import { - coerceSessionCreatedEvent, - handleSessionCreated, - handleSessionDeleted, - type SessionCreatedEvent, -} from "./event-handlers" -import { createSessionPollingController, type SessionPollingController } from "./polling" -import { cleanupTmuxSessions } from "./cleanup" - +import { queryWindowState } from "./pane-state-querier" +import { decideSpawnActions, decideCloseAction, type SessionMapping } from "./decision-engine" +import { executeActions, executeAction } from "./action-executor" +import { TmuxPollingManager } from "./polling-manager" type OpencodeClient = PluginInput["client"] +interface SessionCreatedEvent { + type: string + properties?: { info?: { id?: string; parentID?: string; title?: string } } +} + export interface TmuxUtilDeps { isInsideTmux: () => boolean getCurrentPaneId: () => string | undefined @@ -28,6 +31,13 @@ const defaultTmuxDeps: TmuxUtilDeps = { getCurrentPaneId: defaultGetCurrentPaneId, } +const SESSION_TIMEOUT_MS = 10 * 60 * 1000 + +// Stability detection constants (prevents premature closure - see issue #1330) +// Mirrors the proven pattern from background-agent/manager.ts +const MIN_STABILITY_TIME_MS = 10 * 1000 // Must run at least 10s before stability detection kicks in +const STABLE_POLLS_REQUIRED = 3 // 3 consecutive idle polls (~6s with 2s poll interval) + /** * 
State-first Tmux Session Manager * @@ -48,8 +58,7 @@ export class TmuxSessionManager { private sessions = new Map() private pendingSessions = new Set() private deps: TmuxUtilDeps - private polling: SessionPollingController - + private pollingManager: TmuxPollingManager constructor(ctx: PluginInput, tmuxConfig: TmuxConfig, deps: TmuxUtilDeps = defaultTmuxDeps) { this.client = ctx.client this.tmuxConfig = tmuxConfig @@ -57,15 +66,11 @@ export class TmuxSessionManager { const defaultPort = process.env.OPENCODE_PORT ?? "4096" this.serverUrl = ctx.serverUrl?.toString() ?? `http://localhost:${defaultPort}` this.sourcePaneId = deps.getCurrentPaneId() - - this.polling = createSessionPollingController({ - client: this.client, - tmuxConfig: this.tmuxConfig, - serverUrl: this.serverUrl, - sourcePaneId: this.sourcePaneId, - sessions: this.sessions, - }) - + this.pollingManager = new TmuxPollingManager( + this.client, + this.sessions, + this.closeSessionById.bind(this) + ) log("[tmux-session-manager] initialized", { configEnabled: this.tmuxConfig.enabled, tmuxConfig: this.tmuxConfig, @@ -73,7 +78,6 @@ export class TmuxSessionManager { sourcePaneId: this.sourcePaneId, }) } - private isEnabled(): boolean { return this.tmuxConfig.enabled && this.deps.isInsideTmux() } @@ -93,58 +97,254 @@ export class TmuxSessionManager { })) } + private async waitForSessionReady(sessionId: string): Promise { + const startTime = Date.now() + + while (Date.now() - startTime < SESSION_READY_TIMEOUT_MS) { + try { + const statusResult = await this.client.session.status({ path: undefined }) + const allStatuses = (statusResult.data ?? 
{}) as Record + + if (allStatuses[sessionId]) { + log("[tmux-session-manager] session ready", { + sessionId, + status: allStatuses[sessionId].type, + waitedMs: Date.now() - startTime, + }) + return true + } + } catch (err) { + log("[tmux-session-manager] session status check error", { error: String(err) }) + } + + await new Promise((resolve) => setTimeout(resolve, SESSION_READY_POLL_INTERVAL_MS)) + } + + log("[tmux-session-manager] session ready timeout", { + sessionId, + timeoutMs: SESSION_READY_TIMEOUT_MS, + }) + return false + } + + // NOTE: Exposed (via `as any`) for test stability checks. + // Actual polling is owned by TmuxPollingManager. + private async pollSessions(): Promise { + await (this.pollingManager as any).pollSessions() + } + async onSessionCreated(event: SessionCreatedEvent): Promise { - await handleSessionCreated( - { - client: this.client, - tmuxConfig: this.tmuxConfig, - serverUrl: this.serverUrl, - sourcePaneId: this.sourcePaneId, - sessions: this.sessions, - pendingSessions: this.pendingSessions, - isInsideTmux: this.deps.isInsideTmux, - isEnabled: () => this.isEnabled(), - getCapacityConfig: () => this.getCapacityConfig(), - getSessionMappings: () => this.getSessionMappings(), - waitForSessionReady: (sessionId) => this.polling.waitForSessionReady(sessionId), - startPolling: () => this.polling.startPolling(), - }, - event, - ) + const enabled = this.isEnabled() + log("[tmux-session-manager] onSessionCreated called", { + enabled, + tmuxConfigEnabled: this.tmuxConfig.enabled, + isInsideTmux: this.deps.isInsideTmux(), + eventType: event.type, + infoId: event.properties?.info?.id, + infoParentID: event.properties?.info?.parentID, + }) + + if (!enabled) return + if (event.type !== "session.created") return + + const info = event.properties?.info + if (!info?.id || !info?.parentID) return + + const sessionId = info.id + const title = info.title ?? 
"Subagent" + + if (this.sessions.has(sessionId) || this.pendingSessions.has(sessionId)) { + log("[tmux-session-manager] session already tracked or pending", { sessionId }) + return + } + + if (!this.sourcePaneId) { + log("[tmux-session-manager] no source pane id") + return + } + + this.pendingSessions.add(sessionId) + + try { + const state = await queryWindowState(this.sourcePaneId) + if (!state) { + log("[tmux-session-manager] failed to query window state") + return + } + + log("[tmux-session-manager] window state queried", { + windowWidth: state.windowWidth, + mainPane: state.mainPane?.paneId, + agentPaneCount: state.agentPanes.length, + agentPanes: state.agentPanes.map((p) => p.paneId), + }) + + const decision = decideSpawnActions( + state, + sessionId, + title, + this.getCapacityConfig(), + this.getSessionMappings() + ) + + log("[tmux-session-manager] spawn decision", { + canSpawn: decision.canSpawn, + reason: decision.reason, + actionCount: decision.actions.length, + actions: decision.actions.map((a) => { + if (a.type === "close") return { type: "close", paneId: a.paneId } + if (a.type === "replace") return { type: "replace", paneId: a.paneId, newSessionId: a.newSessionId } + return { type: "spawn", sessionId: a.sessionId } + }), + }) + + if (!decision.canSpawn) { + log("[tmux-session-manager] cannot spawn", { reason: decision.reason }) + return + } + + const result = await executeActions( + decision.actions, + { config: this.tmuxConfig, serverUrl: this.serverUrl, windowState: state } + ) + + for (const { action, result: actionResult } of result.results) { + if (action.type === "close" && actionResult.success) { + this.sessions.delete(action.sessionId) + log("[tmux-session-manager] removed closed session from cache", { + sessionId: action.sessionId, + }) + } + if (action.type === "replace" && actionResult.success) { + this.sessions.delete(action.oldSessionId) + log("[tmux-session-manager] removed replaced session from cache", { + oldSessionId: 
action.oldSessionId, + newSessionId: action.newSessionId, + }) + } + } + + if (result.success && result.spawnedPaneId) { + const sessionReady = await this.waitForSessionReady(sessionId) + + if (!sessionReady) { + log("[tmux-session-manager] session not ready after timeout, tracking anyway", { + sessionId, + paneId: result.spawnedPaneId, + }) + } + + const now = Date.now() + this.sessions.set(sessionId, { + sessionId, + paneId: result.spawnedPaneId, + description: title, + createdAt: new Date(now), + lastSeenAt: new Date(now), + }) + log("[tmux-session-manager] pane spawned and tracked", { + sessionId, + paneId: result.spawnedPaneId, + sessionReady, + }) + this.pollingManager.startPolling() + } else { + log("[tmux-session-manager] spawn failed", { + success: result.success, + results: result.results.map((r) => ({ + type: r.action.type, + success: r.result.success, + error: r.result.error, + })), + }) + } + } finally { + this.pendingSessions.delete(sessionId) + } } async onSessionDeleted(event: { sessionID: string }): Promise { - await handleSessionDeleted( - { - tmuxConfig: this.tmuxConfig, - serverUrl: this.serverUrl, - sourcePaneId: this.sourcePaneId, - sessions: this.sessions, - isEnabled: () => this.isEnabled(), - getSessionMappings: () => this.getSessionMappings(), - stopPolling: () => this.polling.stopPolling(), - }, - event, - ) + if (!this.isEnabled()) return + if (!this.sourcePaneId) return + + const tracked = this.sessions.get(event.sessionID) + if (!tracked) return + + log("[tmux-session-manager] onSessionDeleted", { sessionId: event.sessionID }) + + const state = await queryWindowState(this.sourcePaneId) + if (!state) { + this.sessions.delete(event.sessionID) + return + } + + const closeAction = decideCloseAction(state, event.sessionID, this.getSessionMappings()) + if (closeAction) { + await executeAction(closeAction, { config: this.tmuxConfig, serverUrl: this.serverUrl, windowState: state }) + } + + this.sessions.delete(event.sessionID) + + if 
(this.sessions.size === 0) { + this.pollingManager.stopPolling() + } + } + + + private async closeSessionById(sessionId: string): Promise { + const tracked = this.sessions.get(sessionId) + if (!tracked) return + + log("[tmux-session-manager] closing session pane", { + sessionId, + paneId: tracked.paneId, + }) + + const state = this.sourcePaneId ? await queryWindowState(this.sourcePaneId) : null + if (state) { + await executeAction( + { type: "close", paneId: tracked.paneId, sessionId }, + { config: this.tmuxConfig, serverUrl: this.serverUrl, windowState: state } + ) + } + + this.sessions.delete(sessionId) + + if (this.sessions.size === 0) { + this.pollingManager.stopPolling() + } } createEventHandler(): (input: { event: { type: string; properties?: unknown } }) => Promise { return async (input) => { - await this.onSessionCreated(coerceSessionCreatedEvent(input.event)) + await this.onSessionCreated(input.event as SessionCreatedEvent) } } - async pollSessions(): Promise { - return this.polling.pollSessions() - } - async cleanup(): Promise { - await cleanupTmuxSessions({ - tmuxConfig: this.tmuxConfig, - serverUrl: this.serverUrl, - sourcePaneId: this.sourcePaneId, - sessions: this.sessions, - stopPolling: () => this.polling.stopPolling(), - }) + this.pollingManager.stopPolling() + + if (this.sessions.size > 0) { + log("[tmux-session-manager] closing all panes", { count: this.sessions.size }) + const state = this.sourcePaneId ? 
await queryWindowState(this.sourcePaneId) : null + + if (state) { + const closePromises = Array.from(this.sessions.values()).map((s) => + executeAction( + { type: "close", paneId: s.paneId, sessionId: s.sessionId }, + { config: this.tmuxConfig, serverUrl: this.serverUrl, windowState: state } + ).catch((err) => + log("[tmux-session-manager] cleanup error for pane", { + paneId: s.paneId, + error: String(err), + }), + ), + ) + await Promise.all(closePromises) + } + this.sessions.clear() + } + + log("[tmux-session-manager] cleanup complete") } } diff --git a/src/shared/git-worktree/collect-git-diff-stats.test.ts b/src/shared/git-worktree/collect-git-diff-stats.test.ts new file mode 100644 index 00000000..678d2f67 --- /dev/null +++ b/src/shared/git-worktree/collect-git-diff-stats.test.ts @@ -0,0 +1,66 @@ +/// + +import { describe, expect, mock, test } from "bun:test" + +const execSyncMock = mock(() => { + throw new Error("execSync should not be called") +}) + +const execFileSyncMock = mock((file: string, args: string[], _opts: { cwd?: string }) => { + if (file !== "git") throw new Error(`unexpected file: ${file}`) + const subcommand = args[0] + + if (subcommand === "diff") { + return "1\t2\tfile.ts\n" + } + + if (subcommand === "status") { + return " M file.ts\n" + } + + throw new Error(`unexpected args: ${args.join(" ")}`) +}) + +mock.module("node:child_process", () => ({ + execSync: execSyncMock, + execFileSync: execFileSyncMock, +})) + +const { collectGitDiffStats } = await import("./collect-git-diff-stats") + +describe("collectGitDiffStats", () => { + test("uses execFileSync with arg arrays (no shell injection)", () => { + //#given + const directory = "/tmp/safe-repo;touch /tmp/pwn" + + //#when + const result = collectGitDiffStats(directory) + + //#then + expect(execSyncMock).not.toHaveBeenCalled() + expect(execFileSyncMock).toHaveBeenCalledTimes(2) + + const [firstCallFile, firstCallArgs, firstCallOpts] = execFileSyncMock.mock + .calls[0]! 
as unknown as [string, string[], { cwd?: string }] + expect(firstCallFile).toBe("git") + expect(firstCallArgs).toEqual(["diff", "--numstat", "HEAD"]) + expect(firstCallOpts.cwd).toBe(directory) + expect(firstCallArgs.join(" ")).not.toContain(directory) + + const [secondCallFile, secondCallArgs, secondCallOpts] = execFileSyncMock.mock + .calls[1]! as unknown as [string, string[], { cwd?: string }] + expect(secondCallFile).toBe("git") + expect(secondCallArgs).toEqual(["status", "--porcelain"]) + expect(secondCallOpts.cwd).toBe(directory) + expect(secondCallArgs.join(" ")).not.toContain(directory) + + expect(result).toEqual([ + { + path: "file.ts", + added: 1, + removed: 2, + status: "modified", + }, + ]) + }) +}) diff --git a/src/shared/migration/config-migration.ts b/src/shared/migration/config-migration.ts index 60cf708a..4a0bdefc 100644 --- a/src/shared/migration/config-migration.ts +++ b/src/shared/migration/config-migration.ts @@ -8,30 +8,32 @@ export function migrateConfigFile( configPath: string, rawConfig: Record ): boolean { + // Work on a deep copy — only apply changes to rawConfig if file write succeeds + const copy = structuredClone(rawConfig) let needsWrite = false // Load previously applied migrations - const existingMigrations = Array.isArray(rawConfig._migrations) - ? new Set(rawConfig._migrations as string[]) + const existingMigrations = Array.isArray(copy._migrations) + ? 
new Set(copy._migrations as string[]) : new Set() const allNewMigrations: string[] = [] - if (rawConfig.agents && typeof rawConfig.agents === "object") { - const { migrated, changed } = migrateAgentNames(rawConfig.agents as Record) + if (copy.agents && typeof copy.agents === "object") { + const { migrated, changed } = migrateAgentNames(copy.agents as Record) if (changed) { - rawConfig.agents = migrated + copy.agents = migrated needsWrite = true } } // Migrate model versions in agents (skip already-applied migrations) - if (rawConfig.agents && typeof rawConfig.agents === "object") { + if (copy.agents && typeof copy.agents === "object") { const { migrated, changed, newMigrations } = migrateModelVersions( - rawConfig.agents as Record, + copy.agents as Record, existingMigrations ) if (changed) { - rawConfig.agents = migrated + copy.agents = migrated needsWrite = true log("Migrated model versions in agents config") } @@ -39,13 +41,13 @@ export function migrateConfigFile( } // Migrate model versions in categories (skip already-applied migrations) - if (rawConfig.categories && typeof rawConfig.categories === "object") { + if (copy.categories && typeof copy.categories === "object") { const { migrated, changed, newMigrations } = migrateModelVersions( - rawConfig.categories as Record, + copy.categories as Record, existingMigrations ) if (changed) { - rawConfig.categories = migrated + copy.categories = migrated needsWrite = true log("Migrated model versions in categories config") } @@ -56,20 +58,20 @@ export function migrateConfigFile( if (allNewMigrations.length > 0) { const updatedMigrations = Array.from(existingMigrations) updatedMigrations.push(...allNewMigrations) - rawConfig._migrations = updatedMigrations + copy._migrations = updatedMigrations needsWrite = true } - if (rawConfig.omo_agent) { - rawConfig.sisyphus_agent = rawConfig.omo_agent - delete rawConfig.omo_agent + if (copy.omo_agent) { + copy.sisyphus_agent = copy.omo_agent + delete copy.omo_agent needsWrite = 
true } - if (rawConfig.disabled_agents && Array.isArray(rawConfig.disabled_agents)) { + if (copy.disabled_agents && Array.isArray(copy.disabled_agents)) { const migrated: string[] = [] let changed = false - for (const agent of rawConfig.disabled_agents as string[]) { + for (const agent of copy.disabled_agents as string[]) { const newAgent = AGENT_NAME_MAP[agent.toLowerCase()] ?? AGENT_NAME_MAP[agent] ?? agent if (newAgent !== agent) { changed = true @@ -77,15 +79,15 @@ export function migrateConfigFile( migrated.push(newAgent) } if (changed) { - rawConfig.disabled_agents = migrated + copy.disabled_agents = migrated needsWrite = true } } - if (rawConfig.disabled_hooks && Array.isArray(rawConfig.disabled_hooks)) { - const { migrated, changed, removed } = migrateHookNames(rawConfig.disabled_hooks as string[]) + if (copy.disabled_hooks && Array.isArray(copy.disabled_hooks)) { + const { migrated, changed, removed } = migrateHookNames(copy.disabled_hooks as string[]) if (changed) { - rawConfig.disabled_hooks = migrated + copy.disabled_hooks = migrated needsWrite = true } if (removed.length > 0) { @@ -99,13 +101,25 @@ export function migrateConfigFile( try { const timestamp = new Date().toISOString().replace(/[:.]/g, "-") const backupPath = `${configPath}.bak.${timestamp}` - fs.copyFileSync(configPath, backupPath) + try { + fs.copyFileSync(configPath, backupPath) + } catch { + // Original file may not exist yet — skip backup + } - fs.writeFileSync(configPath, JSON.stringify(rawConfig, null, 2) + "\n", "utf-8") + fs.writeFileSync(configPath, JSON.stringify(copy, null, 2) + "\n", "utf-8") log(`Migrated config file: ${configPath} (backup: ${backupPath})`) } catch (err) { log(`Failed to write migrated config to ${configPath}:`, err) + // File write failed — rawConfig is untouched, preserving user's original values + return false } + + // File write succeeded — apply changes to the original rawConfig + for (const key of Object.keys(rawConfig)) { + delete rawConfig[key] + } + 
Object.assign(rawConfig, copy) } return needsWrite diff --git a/src/shared/model-availability.test.ts b/src/shared/model-availability.test.ts index df60e8e3..cbaed0f5 100644 --- a/src/shared/model-availability.test.ts +++ b/src/shared/model-availability.test.ts @@ -5,198 +5,174 @@ import { tmpdir } from "os" import { join } from "path" let __resetModelCache: () => void -let fetchAvailableModels: ( - client?: unknown, - options?: { connectedProviders?: string[] | null }, -) => Promise> +let fetchAvailableModels: (client?: unknown, options?: { connectedProviders?: string[] | null }) => Promise> let fuzzyMatchModel: (target: string, available: Set, providers?: string[]) => string | null let isModelAvailable: (targetModel: string, availableModels: Set) => boolean let getConnectedProviders: (client: unknown) => Promise beforeAll(async () => { - ;({ - __resetModelCache, - fetchAvailableModels, - fuzzyMatchModel, - isModelAvailable, - getConnectedProviders, - } = await import("./model-availability")) + ;({ + __resetModelCache, + fetchAvailableModels, + fuzzyMatchModel, + isModelAvailable, + getConnectedProviders, + } = await import("./model-availability")) }) describe("fetchAvailableModels", () => { - let tempDir: string + let tempDir: string let originalXdgCache: string | undefined - beforeEach(() => { - __resetModelCache() - tempDir = mkdtempSync(join(tmpdir(), "opencode-test-")) + + beforeEach(() => { + __resetModelCache() + tempDir = mkdtempSync(join(tmpdir(), "opencode-test-")) originalXdgCache = process.env.XDG_CACHE_HOME process.env.XDG_CACHE_HOME = tempDir - }) + }) - afterEach(() => { - if (originalXdgCache !== undefined) { + afterEach(() => { + if (originalXdgCache !== undefined) { process.env.XDG_CACHE_HOME = originalXdgCache } else { delete process.env.XDG_CACHE_HOME } - rmSync(tempDir, { recursive: true, force: true }) - }) + rmSync(tempDir, { recursive: true, force: true }) + }) - function writeModelsCache(data: Record) { - const cacheDir = join(tempDir, 
"opencode") - require("fs").mkdirSync(cacheDir, { recursive: true }) - writeFileSync(join(cacheDir, "models.json"), JSON.stringify(data)) - } + function writeModelsCache(data: Record) { + const cacheDir = join(tempDir, "opencode") + require("fs").mkdirSync(cacheDir, { recursive: true }) + writeFileSync(join(cacheDir, "models.json"), JSON.stringify(data)) + } - it("#given cache file with models #when fetchAvailableModels called with connectedProviders #then returns Set of model IDs", async () => { - writeModelsCache({ - openai: { id: "openai", models: { "gpt-5.2": { id: "gpt-5.2" } } }, - anthropic: { - id: "anthropic", - models: { "claude-opus-4-6": { id: "claude-opus-4-6" } }, - }, - google: { id: "google", models: { "gemini-3-pro": { id: "gemini-3-pro" } } }, - }) + it("#given cache file with models #when fetchAvailableModels called with connectedProviders #then returns Set of model IDs", async () => { + writeModelsCache({ + openai: { id: "openai", models: { "gpt-5.2": { id: "gpt-5.2" } } }, + anthropic: { id: "anthropic", models: { "claude-opus-4-6": { id: "claude-opus-4-6" } } }, + google: { id: "google", models: { "gemini-3-pro": { id: "gemini-3-pro" } } }, + }) - const result = await fetchAvailableModels(undefined, { - connectedProviders: ["openai", "anthropic", "google"], - }) + const result = await fetchAvailableModels(undefined, { + connectedProviders: ["openai", "anthropic", "google"] + }) - expect(result).toBeInstanceOf(Set) - expect(result.size).toBe(3) - expect(result.has("openai/gpt-5.2")).toBe(true) - expect(result.has("anthropic/claude-opus-4-6")).toBe(true) - expect(result.has("google/gemini-3-pro")).toBe(true) - }) + expect(result).toBeInstanceOf(Set) + expect(result.size).toBe(3) + expect(result.has("openai/gpt-5.2")).toBe(true) + expect(result.has("anthropic/claude-opus-4-6")).toBe(true) + expect(result.has("google/gemini-3-pro")).toBe(true) + }) - it("#given connectedProviders unknown #when fetchAvailableModels called without options #then 
returns empty Set", async () => { - writeModelsCache({ - openai: { id: "openai", models: { "gpt-5.2": { id: "gpt-5.2" } } }, - }) + it("#given connectedProviders unknown #when fetchAvailableModels called without options #then returns empty Set", async () => { + writeModelsCache({ + openai: { id: "openai", models: { "gpt-5.2": { id: "gpt-5.2" } } }, + }) - const result = await fetchAvailableModels() + const result = await fetchAvailableModels() - expect(result).toBeInstanceOf(Set) - expect(result.size).toBe(0) - }) + expect(result).toBeInstanceOf(Set) + expect(result.size).toBe(0) + }) - it("#given connectedProviders unknown but client can list #when fetchAvailableModels called with client #then returns models from API filtered by connected providers", async () => { - const client = { - provider: { - list: async () => ({ data: { connected: ["openai"] } }), - }, - model: { - list: async () => ({ - data: [ - { id: "gpt-5.3-codex", provider: "openai" }, - { id: "gemini-3-pro", provider: "google" }, - ], - }), - }, - } + it("#given connectedProviders unknown but client can list #when fetchAvailableModels called with client #then returns models from API filtered by connected providers", async () => { + const client = { + provider: { + list: async () => ({ data: { connected: ["openai"] } }), + }, + model: { + list: async () => ({ + data: [ + { id: "gpt-5.3-codex", provider: "openai" }, + { id: "gemini-3-pro", provider: "google" }, + ], + }), + }, + } - const result = await fetchAvailableModels(client) + const result = await fetchAvailableModels(client) - expect(result).toBeInstanceOf(Set) - expect(result.has("openai/gpt-5.3-codex")).toBe(true) - expect(result.has("google/gemini-3-pro")).toBe(false) - }) + expect(result).toBeInstanceOf(Set) + expect(result.has("openai/gpt-5.3-codex")).toBe(true) + expect(result.has("google/gemini-3-pro")).toBe(false) + }) - it("#given cache file not found #when fetchAvailableModels called with connectedProviders #then returns empty Set", 
async () => { - const result = await fetchAvailableModels(undefined, { - connectedProviders: ["openai"], - }) + it("#given cache file not found #when fetchAvailableModels called with connectedProviders #then returns empty Set", async () => { + const result = await fetchAvailableModels(undefined, { connectedProviders: ["openai"] }) - expect(result).toBeInstanceOf(Set) - expect(result.size).toBe(0) - }) + expect(result).toBeInstanceOf(Set) + expect(result.size).toBe(0) + }) - it("#given cache missing but client can list #when fetchAvailableModels called with connectedProviders #then returns models from API", async () => { - const client = { - provider: { - list: async () => ({ data: { connected: ["openai", "google"] } }), - }, - model: { - list: async () => ({ - data: [ - { id: "gpt-5.3-codex", provider: "openai" }, - { id: "gemini-3-pro", provider: "google" }, - ], - }), - }, - } + it("#given cache missing but client can list #when fetchAvailableModels called with connectedProviders #then returns models from API", async () => { + const client = { + provider: { + list: async () => ({ data: { connected: ["openai", "google"] } }), + }, + model: { + list: async () => ({ + data: [ + { id: "gpt-5.3-codex", provider: "openai" }, + { id: "gemini-3-pro", provider: "google" }, + ], + }), + }, + } - const result = await fetchAvailableModels(client, { - connectedProviders: ["openai", "google"], - }) + const result = await fetchAvailableModels(client, { connectedProviders: ["openai", "google"] }) - expect(result).toBeInstanceOf(Set) - expect(result.has("openai/gpt-5.3-codex")).toBe(true) - expect(result.has("google/gemini-3-pro")).toBe(true) - }) + expect(result).toBeInstanceOf(Set) + expect(result.has("openai/gpt-5.3-codex")).toBe(true) + expect(result.has("google/gemini-3-pro")).toBe(true) + }) - it("#given cache read twice #when second call made with same providers #then reads fresh each time", async () => { - writeModelsCache({ - openai: { id: "openai", models: { "gpt-5.2": 
{ id: "gpt-5.2" } } }, - anthropic: { - id: "anthropic", - models: { "claude-opus-4-6": { id: "claude-opus-4-6" } }, - }, - }) + it("#given cache read twice #when second call made with same providers #then reads fresh each time", async () => { + writeModelsCache({ + openai: { id: "openai", models: { "gpt-5.2": { id: "gpt-5.2" } } }, + anthropic: { id: "anthropic", models: { "claude-opus-4-6": { id: "claude-opus-4-6" } } }, + }) - const result1 = await fetchAvailableModels(undefined, { - connectedProviders: ["openai"], - }) - const result2 = await fetchAvailableModels(undefined, { - connectedProviders: ["openai"], - }) + const result1 = await fetchAvailableModels(undefined, { connectedProviders: ["openai"] }) + const result2 = await fetchAvailableModels(undefined, { connectedProviders: ["openai"] }) - expect(result1.size).toBe(result2.size) - expect(result1.has("openai/gpt-5.2")).toBe(true) - }) + expect(result1.size).toBe(result2.size) + expect(result1.has("openai/gpt-5.2")).toBe(true) + }) - it("#given empty providers in cache #when fetchAvailableModels called with connectedProviders #then returns empty Set", async () => { - writeModelsCache({}) + it("#given empty providers in cache #when fetchAvailableModels called with connectedProviders #then returns empty Set", async () => { + writeModelsCache({}) - const result = await fetchAvailableModels(undefined, { - connectedProviders: ["openai"], - }) + const result = await fetchAvailableModels(undefined, { connectedProviders: ["openai"] }) - expect(result).toBeInstanceOf(Set) - expect(result.size).toBe(0) - }) + expect(result).toBeInstanceOf(Set) + expect(result.size).toBe(0) + }) - it("#given cache file with various providers #when fetchAvailableModels called with all providers #then extracts all IDs correctly", async () => { - writeModelsCache({ - openai: { - id: "openai", - models: { "gpt-5.3-codex": { id: "gpt-5.3-codex" } }, - }, - anthropic: { - id: "anthropic", - models: { "claude-sonnet-4-5": { id: 
"claude-sonnet-4-5" } }, - }, - google: { - id: "google", - models: { "gemini-3-flash": { id: "gemini-3-flash" } }, - }, - opencode: { id: "opencode", models: { "gpt-5-nano": { id: "gpt-5-nano" } } }, - }) + it("#given cache file with various providers #when fetchAvailableModels called with all providers #then extracts all IDs correctly", async () => { + writeModelsCache({ + openai: { id: "openai", models: { "gpt-5.3-codex": { id: "gpt-5.3-codex" } } }, + anthropic: { id: "anthropic", models: { "claude-sonnet-4-5": { id: "claude-sonnet-4-5" } } }, + google: { id: "google", models: { "gemini-3-flash": { id: "gemini-3-flash" } } }, + opencode: { id: "opencode", models: { "gpt-5-nano": { id: "gpt-5-nano" } } }, + }) - const result = await fetchAvailableModels(undefined, { - connectedProviders: ["openai", "anthropic", "google", "opencode"], - }) + const result = await fetchAvailableModels(undefined, { + connectedProviders: ["openai", "anthropic", "google", "opencode"] + }) - expect(result.size).toBe(4) - expect(result.has("openai/gpt-5.3-codex")).toBe(true) - expect(result.has("anthropic/claude-sonnet-4-5")).toBe(true) - expect(result.has("google/gemini-3-flash")).toBe(true) - expect(result.has("opencode/gpt-5-nano")).toBe(true) - }) + expect(result.size).toBe(4) + expect(result.has("openai/gpt-5.3-codex")).toBe(true) + expect(result.has("anthropic/claude-sonnet-4-5")).toBe(true) + expect(result.has("google/gemini-3-flash")).toBe(true) + expect(result.has("opencode/gpt-5-nano")).toBe(true) + }) }) describe("fuzzyMatchModel", () => { + // given available models from multiple providers + // when searching for a substring match + // then return the matching model it("should match substring in model name", () => { const available = new Set([ "openai/gpt-5.2", @@ -207,6 +183,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("openai/gpt-5.2") }) + // given available model with preview suffix + // when searching with provider-prefixed base model + // then return 
preview model it("should match preview suffix for gemini-3-flash", () => { const available = new Set(["google/gemini-3-flash-preview"]) const result = fuzzyMatchModel( @@ -217,6 +196,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("google/gemini-3-flash-preview") }) + // given available models with partial matches + // when searching for a substring + // then return exact match if it exists it("should prefer exact match over substring match", () => { const available = new Set([ "openai/gpt-5.2", @@ -227,6 +209,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("openai/gpt-5.2") }) + // given available models with multiple substring matches + // when searching for a substring + // then return the shorter model name (more specific) it("should prefer shorter model name when multiple matches exist", () => { const available = new Set([ "openai/gpt-5.2-ultra", @@ -236,6 +221,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("openai/gpt-5.2-ultra") }) + // given available models with claude variants + // when searching for claude-opus + // then return matching claude-opus model it("should match claude-opus to claude-opus-4-6", () => { const available = new Set([ "anthropic/claude-opus-4-6", @@ -245,6 +233,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("anthropic/claude-opus-4-6") }) + // given available models from multiple providers + // when providers filter is specified + // then only search models from specified providers it("should filter by provider when providers array is given", () => { const available = new Set([ "openai/gpt-5.2", @@ -255,6 +246,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("openai/gpt-5.2") }) + // given available models from multiple providers + // when providers filter excludes matching models + // then return null it("should return null when provider filter excludes all matches", () => { const available = new Set([ "openai/gpt-5.2", @@ -264,6 +258,9 @@ describe("fuzzyMatchModel", 
() => { expect(result).toBeNull() }) + // given available models + // when no substring match exists + // then return null it("should return null when no match found", () => { const available = new Set([ "openai/gpt-5.2", @@ -273,6 +270,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBeNull() }) + // given available models with different cases + // when searching with different case + // then match case-insensitively it("should match case-insensitively", () => { const available = new Set([ "openai/gpt-5.2", @@ -282,6 +282,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("openai/gpt-5.2") }) + // given available models with exact match and longer variants + // when searching for exact match + // then return exact match first it("should prioritize exact match over longer variants", () => { const available = new Set([ "anthropic/claude-opus-4-6", @@ -291,6 +294,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("anthropic/claude-opus-4-6") }) + // given available models with similar model IDs (e.g., glm-4.7 and glm-4.7-free) + // when searching for the longer variant (glm-4.7-free) + // then return exact model ID match, not the shorter one it("should prefer exact model ID match over shorter substring match", () => { const available = new Set([ "zai-coding-plan/glm-4.7", @@ -300,6 +306,9 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("zai-coding-plan/glm-4.7-free") }) + // given available models with similar model IDs + // when searching for the shorter variant + // then return the shorter match (existing behavior preserved) it("should still prefer shorter match when searching for shorter variant", () => { const available = new Set([ "zai-coding-plan/glm-4.7", @@ -309,12 +318,21 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("zai-coding-plan/glm-4.7") }) + // given same model ID from multiple providers + // when searching for exact model ID + // then return shortest full string (preserves tie-break behavior) 
it("should use shortest tie-break when multiple providers have same model ID", () => { - const available = new Set(["opencode/gpt-5.2", "openai/gpt-5.2"]) + const available = new Set([ + "opencode/gpt-5.2", + "openai/gpt-5.2", + ]) const result = fuzzyMatchModel("gpt-5.2", available) expect(result).toBe("openai/gpt-5.2") }) + // given available models with multiple providers + // when multiple providers are specified + // then search all specified providers it("should search all specified providers", () => { const available = new Set([ "openai/gpt-5.2", @@ -325,12 +343,21 @@ describe("fuzzyMatchModel", () => { expect(result).toBe("openai/gpt-5.2") }) + // given available models with provider prefix + // when searching with provider filter + // then only match models with correct provider prefix it("should only match models with correct provider prefix", () => { - const available = new Set(["openai/gpt-5.2", "anthropic/gpt-something"]) + const available = new Set([ + "openai/gpt-5.2", + "anthropic/gpt-something", + ]) const result = fuzzyMatchModel("gpt", available, ["openai"]) expect(result).toBe("openai/gpt-5.2") }) + // given empty available set + // when searching + // then return null it("should return null for empty available set", () => { const available = new Set() const result = fuzzyMatchModel("gpt", available) @@ -339,13 +366,16 @@ describe("fuzzyMatchModel", () => { }) describe("getConnectedProviders", () => { + // given SDK client with connected providers + // when provider.list returns data + // then returns connected array it("should return connected providers from SDK", async () => { const mockClient = { provider: { list: async () => ({ - data: { connected: ["anthropic", "opencode", "google"] }, - }), - }, + data: { connected: ["anthropic", "opencode", "google"] } + }) + } } const result = await getConnectedProviders(mockClient) @@ -353,13 +383,14 @@ describe("getConnectedProviders", () => { expect(result).toEqual(["anthropic", "opencode", "google"]) 
}) + // given SDK client + // when provider.list throws error + // then returns empty array it("should return empty array on SDK error", async () => { const mockClient = { provider: { - list: async () => { - throw new Error("Network error") - }, - }, + list: async () => { throw new Error("Network error") } + } } const result = await getConnectedProviders(mockClient) @@ -367,11 +398,14 @@ describe("getConnectedProviders", () => { expect(result).toEqual([]) }) + // given SDK client with empty connected array + // when provider.list returns empty + // then returns empty array it("should return empty array when no providers connected", async () => { const mockClient = { provider: { - list: async () => ({ data: { connected: [] } }), - }, + list: async () => ({ data: { connected: [] } }) + } } const result = await getConnectedProviders(mockClient) @@ -379,6 +413,9 @@ describe("getConnectedProviders", () => { expect(result).toEqual([]) }) + // given SDK client without provider.list method + // when getConnectedProviders called + // then returns empty array it("should return empty array when client.provider.list not available", async () => { const mockClient = {} @@ -387,17 +424,23 @@ describe("getConnectedProviders", () => { expect(result).toEqual([]) }) + // given null client + // when getConnectedProviders called + // then returns empty array it("should return empty array for null client", async () => { const result = await getConnectedProviders(null) expect(result).toEqual([]) }) + // given SDK client with missing data.connected + // when provider.list returns without connected field + // then returns empty array it("should return empty array when data.connected is undefined", async () => { const mockClient = { provider: { - list: async () => ({ data: {} }), - }, + list: async () => ({ data: {} }) + } } const result = await getConnectedProviders(mockClient) @@ -432,6 +475,9 @@ describe("fetchAvailableModels with connected providers filtering", () => { 
writeFileSync(join(cacheDir, "models.json"), JSON.stringify(data)) } + // given cache with multiple providers + // when connectedProviders specifies one provider + // then only returns models from that provider it("should filter models by connected providers", async () => { writeModelsCache({ openai: { models: { "gpt-5.2": { id: "gpt-5.2" } } }, @@ -440,7 +486,7 @@ describe("fetchAvailableModels with connected providers filtering", () => { }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["anthropic"], + connectedProviders: ["anthropic"] }) expect(result.size).toBe(1) @@ -449,6 +495,9 @@ describe("fetchAvailableModels with connected providers filtering", () => { expect(result.has("google/gemini-3-pro")).toBe(false) }) + // given cache with multiple providers + // when connectedProviders specifies multiple providers + // then returns models from all specified providers it("should filter models by multiple connected providers", async () => { writeModelsCache({ openai: { models: { "gpt-5.2": { id: "gpt-5.2" } } }, @@ -457,7 +506,7 @@ describe("fetchAvailableModels with connected providers filtering", () => { }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["anthropic", "google"], + connectedProviders: ["anthropic", "google"] }) expect(result.size).toBe(2) @@ -466,6 +515,9 @@ describe("fetchAvailableModels with connected providers filtering", () => { expect(result.has("openai/gpt-5.2")).toBe(false) }) + // given cache with models + // when connectedProviders is empty array + // then returns empty set it("should return empty set when connectedProviders is empty", async () => { writeModelsCache({ openai: { models: { "gpt-5.2": { id: "gpt-5.2" } } }, @@ -473,12 +525,15 @@ describe("fetchAvailableModels with connected providers filtering", () => { }) const result = await fetchAvailableModels(undefined, { - connectedProviders: [], + connectedProviders: [] }) expect(result.size).toBe(0) }) + // given cache with 
models + // when connectedProviders is undefined (no options) + // then returns empty set (triggers fallback in resolver) it("should return empty set when connectedProviders not specified", async () => { writeModelsCache({ openai: { models: { "gpt-5.2": { id: "gpt-5.2" } } }, @@ -490,18 +545,24 @@ describe("fetchAvailableModels with connected providers filtering", () => { expect(result.size).toBe(0) }) + // given cache with models + // when connectedProviders contains provider not in cache + // then returns empty set for that provider it("should handle provider not in cache gracefully", async () => { writeModelsCache({ openai: { models: { "gpt-5.2": { id: "gpt-5.2" } } }, }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["azure"], + connectedProviders: ["azure"] }) expect(result.size).toBe(0) }) + // given cache with models and mixed connected providers + // when some providers exist in cache and some don't + // then returns models only from matching providers it("should return models from providers that exist in both cache and connected list", async () => { writeModelsCache({ openai: { models: { "gpt-5.2": { id: "gpt-5.2" } } }, @@ -509,31 +570,39 @@ describe("fetchAvailableModels with connected providers filtering", () => { }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["anthropic", "azure", "unknown"], + connectedProviders: ["anthropic", "azure", "unknown"] }) expect(result.size).toBe(1) expect(result.has("anthropic/claude-opus-4-6")).toBe(true) }) + // given filtered fetch + // when called twice with different filters + // then does NOT use cache (dynamic per-session) it("should not cache filtered results", async () => { writeModelsCache({ openai: { models: { "gpt-5.2": { id: "gpt-5.2" } } }, anthropic: { models: { "claude-opus-4-6": { id: "claude-opus-4-6" } } }, }) + // First call with anthropic const result1 = await fetchAvailableModels(undefined, { - connectedProviders: ["anthropic"], + 
connectedProviders: ["anthropic"] }) expect(result1.size).toBe(1) + // Second call with openai - should work, not cached const result2 = await fetchAvailableModels(undefined, { - connectedProviders: ["openai"], + connectedProviders: ["openai"] }) expect(result2.size).toBe(1) expect(result2.has("openai/gpt-5.2")).toBe(true) }) + // given connectedProviders unknown + // when called twice without connectedProviders + // then always returns empty set (triggers fallback) it("should return empty set when connectedProviders unknown", async () => { writeModelsCache({ openai: { models: { "gpt-5.2": { id: "gpt-5.2" } } }, @@ -567,19 +636,13 @@ describe("fetchAvailableModels with provider-models cache (whitelist-filtered)", rmSync(tempDir, { recursive: true, force: true }) }) - function writeProviderModelsCache(data: { - models: Record - connected: string[] - }) { + function writeProviderModelsCache(data: { models: Record; connected: string[] }) { const cacheDir = join(tempDir, "oh-my-opencode") require("fs").mkdirSync(cacheDir, { recursive: true }) - writeFileSync( - join(cacheDir, "provider-models.json"), - JSON.stringify({ - ...data, - updatedAt: new Date().toISOString(), - }), - ) + writeFileSync(join(cacheDir, "provider-models.json"), JSON.stringify({ + ...data, + updatedAt: new Date().toISOString() + })) } function writeModelsCache(data: Record) { @@ -588,21 +651,24 @@ describe("fetchAvailableModels with provider-models cache (whitelist-filtered)", writeFileSync(join(cacheDir, "models.json"), JSON.stringify(data)) } + // given provider-models cache exists (whitelist-filtered) + // when fetchAvailableModels called + // then uses provider-models cache instead of models.json it("should prefer provider-models cache over models.json", async () => { writeProviderModelsCache({ models: { opencode: ["glm-4.7-free", "gpt-5-nano"], - anthropic: ["claude-opus-4-6"], + anthropic: ["claude-opus-4-6"] }, - connected: ["opencode", "anthropic"], + connected: ["opencode", "anthropic"] }) 
writeModelsCache({ opencode: { models: { "glm-4.7-free": {}, "gpt-5-nano": {}, "gpt-5.2": {} } }, - anthropic: { models: { "claude-opus-4-6": {}, "claude-sonnet-4-5": {} } }, + anthropic: { models: { "claude-opus-4-6": {}, "claude-sonnet-4-5": {} } } }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["opencode", "anthropic"], + connectedProviders: ["opencode", "anthropic"] }) expect(result.size).toBe(3) @@ -613,9 +679,13 @@ describe("fetchAvailableModels with provider-models cache (whitelist-filtered)", expect(result.has("anthropic/claude-sonnet-4-5")).toBe(false) }) + // given provider-models cache exists but has no models (API failure) + // when fetchAvailableModels called + // then falls back to models.json so fuzzy matching can still work it("should fall back to models.json when provider-models cache is empty", async () => { writeProviderModelsCache({ - models: {}, + models: { + }, connected: ["google"], }) writeModelsCache({ @@ -625,22 +695,21 @@ describe("fetchAvailableModels with provider-models cache (whitelist-filtered)", const availableModels = await fetchAvailableModels(undefined, { connectedProviders: ["google"], }) - const match = fuzzyMatchModel( - "google/gemini-3-flash", - availableModels, - ["google"], - ) + const match = fuzzyMatchModel("google/gemini-3-flash", availableModels, ["google"]) expect(match).toBe("google/gemini-3-flash-preview") }) + // given only models.json exists (no provider-models cache) + // when fetchAvailableModels called + // then falls back to models.json (no whitelist filtering) it("should fallback to models.json when provider-models cache not found", async () => { writeModelsCache({ opencode: { models: { "glm-4.7-free": {}, "gpt-5-nano": {}, "gpt-5.2": {} } }, }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["opencode"], + connectedProviders: ["opencode"] }) expect(result.size).toBe(3) @@ -649,18 +718,21 @@ describe("fetchAvailableModels with provider-models 
cache (whitelist-filtered)", expect(result.has("opencode/gpt-5.2")).toBe(true) }) + // given provider-models cache with whitelist + // when connectedProviders filters to subset + // then only returns models from connected providers it("should filter by connectedProviders even with provider-models cache", async () => { writeProviderModelsCache({ models: { opencode: ["glm-4.7-free"], anthropic: ["claude-opus-4-6"], - google: ["gemini-3-pro"], + google: ["gemini-3-pro"] }, - connected: ["opencode", "anthropic", "google"], + connected: ["opencode", "anthropic", "google"] }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["opencode"], + connectedProviders: ["opencode"] }) expect(result.size).toBe(1) @@ -673,25 +745,15 @@ describe("fetchAvailableModels with provider-models cache (whitelist-filtered)", writeProviderModelsCache({ models: { ollama: [ - { - id: "ministral-3:14b-32k-agent", - provider: "ollama", - context: 32768, - output: 8192, - }, - { - id: "qwen3-coder:32k-agent", - provider: "ollama", - context: 32768, - output: 8192, - }, - ], + { id: "ministral-3:14b-32k-agent", provider: "ollama", context: 32768, output: 8192 }, + { id: "qwen3-coder:32k-agent", provider: "ollama", context: 32768, output: 8192 } + ] }, - connected: ["ollama"], + connected: ["ollama"] }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["ollama"], + connectedProviders: ["ollama"] }) expect(result.size).toBe(2) @@ -705,14 +767,14 @@ describe("fetchAvailableModels with provider-models cache (whitelist-filtered)", anthropic: ["claude-opus-4-6", "claude-sonnet-4-5"], ollama: [ { id: "ministral-3:14b-32k-agent", provider: "ollama" }, - { id: "qwen3-coder:32k-agent", provider: "ollama" }, - ], + { id: "qwen3-coder:32k-agent", provider: "ollama" } + ] }, - connected: ["anthropic", "ollama"], + connected: ["anthropic", "ollama"] }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["anthropic", "ollama"], + 
connectedProviders: ["anthropic", "ollama"] }) expect(result.size).toBe(4) @@ -730,14 +792,14 @@ describe("fetchAvailableModels with provider-models cache (whitelist-filtered)", { provider: "ollama" }, { id: "", provider: "ollama" }, null, - "string-model", - ], + "string-model" + ] }, - connected: ["ollama"], + connected: ["ollama"] }) const result = await fetchAvailableModels(undefined, { - connectedProviders: ["ollama"], + connectedProviders: ["ollama"] }) expect(result.size).toBe(2) @@ -749,10 +811,7 @@ describe("fetchAvailableModels with provider-models cache (whitelist-filtered)", describe("isModelAvailable", () => { it("returns true when model exists via fuzzy match", () => { // given - const available = new Set([ - "openai/gpt-5.3-codex", - "anthropic/claude-opus-4-6", - ]) + const available = new Set(["openai/gpt-5.3-codex", "anthropic/claude-opus-4-6"]) // when const result = isModelAvailable("gpt-5.3-codex", available) diff --git a/src/shared/model-availability.ts b/src/shared/model-availability.ts index ad5df3b4..1ff696ee 100644 --- a/src/shared/model-availability.ts +++ b/src/shared/model-availability.ts @@ -1,4 +1,358 @@ -export { fetchAvailableModels, getConnectedProviders } from "./available-models-fetcher" -export { isAnyFallbackModelAvailable, isAnyProviderConnected } from "./fallback-model-availability" -export { __resetModelCache, isModelCacheAvailable } from "./model-cache-availability" -export { fuzzyMatchModel, isModelAvailable } from "./model-name-matcher" +import { existsSync, readFileSync } from "fs" +import { join } from "path" +import { log } from "./logger" +import { getOpenCodeCacheDir } from "./data-path" +import * as connectedProvidersCache from "./connected-providers-cache" + +/** + * Fuzzy match a target model name against available models + * + * @param target - The model name or substring to search for (e.g., "gpt-5.2", "claude-opus") + * @param available - Set of available model names in format "provider/model-name" + * @param 
providers - Optional array of provider names to filter by (e.g., ["openai", "anthropic"]) + * @returns The matched model name or null if no match found + * + * Matching priority: + * 1. Exact match (if exists) + * 2. Shorter model name (more specific) + * + * Matching is case-insensitive substring match. + * If providers array is given, only models starting with "provider/" are considered. + * + * @example + * const available = new Set(["openai/gpt-5.2", "openai/gpt-5.3-codex", "anthropic/claude-opus-4-6"]) + * fuzzyMatchModel("gpt-5.2", available) // → "openai/gpt-5.2" + * fuzzyMatchModel("claude", available, ["openai"]) // → null (provider filter excludes anthropic) + */ +function normalizeModelName(name: string): string { + return name + .toLowerCase() + .replace(/claude-(opus|sonnet|haiku)-4-5/g, "claude-$1-4.5") + .replace(/claude-(opus|sonnet|haiku)-4\.5/g, "claude-$1-4.5") +} + +export function fuzzyMatchModel( + target: string, + available: Set, + providers?: string[], +): string | null { + log("[fuzzyMatchModel] called", { target, availableCount: available.size, providers }) + + if (available.size === 0) { + log("[fuzzyMatchModel] empty available set") + return null + } + + const targetNormalized = normalizeModelName(target) + + // Filter by providers if specified + let candidates = Array.from(available) + if (providers && providers.length > 0) { + const providerSet = new Set(providers) + candidates = candidates.filter((model) => { + const [provider] = model.split("/") + return providerSet.has(provider) + }) + log("[fuzzyMatchModel] filtered by providers", { candidateCount: candidates.length, candidates: candidates.slice(0, 10) }) + } + + if (candidates.length === 0) { + log("[fuzzyMatchModel] no candidates after filter") + return null + } + + // Find all matches (case-insensitive substring match with normalization) + const matches = candidates.filter((model) => + normalizeModelName(model).includes(targetNormalized), + ) + + log("[fuzzyMatchModel] 
substring matches", { targetNormalized, matchCount: matches.length, matches }) + + if (matches.length === 0) { + return null + } + + // Priority 1: Exact match (normalized full model string) + const exactMatch = matches.find((model) => normalizeModelName(model) === targetNormalized) + if (exactMatch) { + log("[fuzzyMatchModel] exact match found", { exactMatch }) + return exactMatch + } + + // Priority 2: Exact model ID match (part after provider/) + // This ensures "glm-4.7-free" matches "zai-coding-plan/glm-4.7-free" over "zai-coding-plan/glm-4.7" + // Use filter + shortest to handle multi-provider cases (e.g., openai/gpt-5.2 + opencode/gpt-5.2) + const exactModelIdMatches = matches.filter((model) => { + const modelId = model.split("/").slice(1).join("/") + return normalizeModelName(modelId) === targetNormalized + }) + if (exactModelIdMatches.length > 0) { + const result = exactModelIdMatches.reduce((shortest, current) => + current.length < shortest.length ? current : shortest, + ) + log("[fuzzyMatchModel] exact model ID match found", { result, candidateCount: exactModelIdMatches.length }) + return result + } + + // Priority 3: Shorter model name (more specific, fallback for partial matches) + const result = matches.reduce((shortest, current) => + current.length < shortest.length ? 
current : shortest, + ) + log("[fuzzyMatchModel] shortest match", { result }) + return result +} + +/** + * Check if a target model is available (fuzzy match by model name, no provider filtering) + * + * @param targetModel - Model name to check (e.g., "gpt-5.3-codex") + * @param availableModels - Set of available models in "provider/model" format + * @returns true if model is available, false otherwise + */ +export function isModelAvailable( + targetModel: string, + availableModels: Set, +): boolean { + return fuzzyMatchModel(targetModel, availableModels) !== null +} + +export async function getConnectedProviders(client: any): Promise { + if (!client?.provider?.list) { + log("[getConnectedProviders] client.provider.list not available") + return [] + } + + try { + const result = await client.provider.list() + const connected = result.data?.connected ?? [] + log("[getConnectedProviders] connected providers", { count: connected.length, providers: connected }) + return connected + } catch (err) { + log("[getConnectedProviders] SDK error", { error: String(err) }) + return [] + } +} + +export async function fetchAvailableModels( + client?: any, + options?: { connectedProviders?: string[] | null } +): Promise> { + let connectedProviders = options?.connectedProviders ?? null + let connectedProvidersUnknown = connectedProviders === null + + log("[fetchAvailableModels] CALLED", { + connectedProvidersUnknown, + connectedProviders: options?.connectedProviders + }) + + if (connectedProvidersUnknown && client) { + const liveConnected = await getConnectedProviders(client) + if (liveConnected.length > 0) { + connectedProviders = liveConnected + connectedProvidersUnknown = false + log("[fetchAvailableModels] connected providers fetched from client", { count: liveConnected.length }) + } + } + + if (connectedProvidersUnknown) { + if (client?.model?.list) { + const modelSet = new Set() + try { + const modelsResult = await client.model.list() + const models = modelsResult.data ?? 
[] + for (const model of models) { + if (model?.provider && model?.id) { + modelSet.add(`${model.provider}/${model.id}`) + } + } + log("[fetchAvailableModels] fetched models from client without provider filter", { + count: modelSet.size, + }) + return modelSet + } catch (err) { + log("[fetchAvailableModels] client.model.list error", { error: String(err) }) + } + } + log("[fetchAvailableModels] connected providers unknown, returning empty set for fallback resolution") + return new Set() + } + + const connectedProvidersList = connectedProviders ?? [] + const connectedSet = new Set(connectedProvidersList) + const modelSet = new Set() + + const providerModelsCache = connectedProvidersCache.readProviderModelsCache() + if (providerModelsCache) { + const providerCount = Object.keys(providerModelsCache.models).length + if (providerCount === 0) { + log("[fetchAvailableModels] provider-models cache empty, falling back to models.json") + } else { + log("[fetchAvailableModels] using provider-models cache (whitelist-filtered)") + + const modelsByProvider = providerModelsCache.models as Record> + for (const [providerId, modelIds] of Object.entries(modelsByProvider)) { + if (!connectedSet.has(providerId)) { + continue + } + for (const modelItem of modelIds) { + // Handle both string[] (legacy) and object[] (with metadata) formats + const modelId = typeof modelItem === 'string' + ? 
modelItem + : (modelItem as any)?.id + + if (modelId) { + modelSet.add(`${providerId}/${modelId}`) + } + } + } + + log("[fetchAvailableModels] parsed from provider-models cache", { + count: modelSet.size, + connectedProviders: connectedProvidersList.slice(0, 5) + }) + + if (modelSet.size > 0) { + return modelSet + } + log("[fetchAvailableModels] provider-models cache produced no models for connected providers, falling back to models.json") + } + } + + log("[fetchAvailableModels] provider-models cache not found, falling back to models.json") + const cacheFile = join(getOpenCodeCacheDir(), "models.json") + + if (!existsSync(cacheFile)) { + log("[fetchAvailableModels] models.json cache file not found, falling back to client") + } else { + try { + const content = readFileSync(cacheFile, "utf-8") + const data = JSON.parse(content) as Record }> + + const providerIds = Object.keys(data) + log("[fetchAvailableModels] providers found in models.json", { count: providerIds.length, providers: providerIds.slice(0, 10) }) + + for (const providerId of providerIds) { + if (!connectedSet.has(providerId)) { + continue + } + + const provider = data[providerId] + const models = provider?.models + if (!models || typeof models !== "object") continue + + for (const modelKey of Object.keys(models)) { + modelSet.add(`${providerId}/${modelKey}`) + } + } + + log("[fetchAvailableModels] parsed models from models.json (NO whitelist filtering)", { + count: modelSet.size, + connectedProviders: connectedProvidersList.slice(0, 5) + }) + + if (modelSet.size > 0) { + return modelSet + } + } catch (err) { + log("[fetchAvailableModels] error", { error: String(err) }) + } + } + + if (client?.model?.list) { + try { + const modelsResult = await client.model.list() + const models = modelsResult.data ?? 
[] + + for (const model of models) { + if (!model?.provider || !model?.id) continue + if (connectedSet.has(model.provider)) { + modelSet.add(`${model.provider}/${model.id}`) + } + } + + log("[fetchAvailableModels] fetched models from client (filtered)", { + count: modelSet.size, + connectedProviders: connectedProvidersList.slice(0, 5), + }) + } catch (err) { + log("[fetchAvailableModels] client.model.list error", { error: String(err) }) + } + } + + return modelSet +} + +export function isAnyFallbackModelAvailable( + fallbackChain: Array<{ providers: string[]; model: string }>, + availableModels: Set, +): boolean { + // If we have models, check them first + if (availableModels.size > 0) { + for (const entry of fallbackChain) { + const hasAvailableProvider = entry.providers.some((provider) => { + return fuzzyMatchModel(entry.model, availableModels, [provider]) !== null + }) + if (hasAvailableProvider) { + return true + } + } + } + + // Fallback: check if any provider in the chain is connected + // This handles race conditions where availableModels is empty or incomplete + // but we know the provider is connected. 
+ const connectedProviders = connectedProvidersCache.readConnectedProvidersCache() + if (connectedProviders) { + const connectedSet = new Set(connectedProviders) + for (const entry of fallbackChain) { + if (entry.providers.some((p) => connectedSet.has(p))) { + log("[isAnyFallbackModelAvailable] model not in available set, but provider is connected", { + model: entry.model, + availableCount: availableModels.size, + }) + return true + } + } + } + + return false +} + +export function isAnyProviderConnected( + providers: string[], + availableModels: Set, +): boolean { + if (availableModels.size > 0) { + const providerSet = new Set(providers) + for (const model of availableModels) { + const [provider] = model.split("/") + if (providerSet.has(provider)) { + log("[isAnyProviderConnected] found model from required provider", { provider, model }) + return true + } + } + } + + const connectedProviders = connectedProvidersCache.readConnectedProvidersCache() + if (connectedProviders) { + const connectedSet = new Set(connectedProviders) + for (const provider of providers) { + if (connectedSet.has(provider)) { + log("[isAnyProviderConnected] provider connected via cache", { provider }) + return true + } + } + } + + return false +} + +export function __resetModelCache(): void {} + +export function isModelCacheAvailable(): boolean { + if (connectedProvidersCache.hasProviderModelsCache()) { + return true + } + const cacheFile = join(getOpenCodeCacheDir(), "models.json") + return existsSync(cacheFile) +} diff --git a/src/shared/model-resolution-pipeline.ts b/src/shared/model-resolution-pipeline.ts index 1d27617e..34d1c13b 100644 --- a/src/shared/model-resolution-pipeline.ts +++ b/src/shared/model-resolution-pipeline.ts @@ -1,16 +1,37 @@ import { log } from "./logger" -import { readConnectedProvidersCache } from "./connected-providers-cache" +import * as connectedProvidersCache from "./connected-providers-cache" import { fuzzyMatchModel } from "./model-availability" -import type { - 
ModelResolutionRequest, - ModelResolutionResult, -} from "./model-resolution-types" +import type { FallbackEntry } from "./model-requirements" -export type { - ModelResolutionProvenance, - ModelResolutionRequest, - ModelResolutionResult, -} from "./model-resolution-types" +export type ModelResolutionRequest = { + intent?: { + uiSelectedModel?: string + userModel?: string + categoryDefaultModel?: string + } + constraints: { + availableModels: Set + connectedProviders?: string[] | null + } + policy?: { + fallbackChain?: FallbackEntry[] + systemDefaultModel?: string + } +} + +export type ModelResolutionProvenance = + | "override" + | "category-default" + | "provider-fallback" + | "system-default" + +export type ModelResolutionResult = { + model: string + provenance: ModelResolutionProvenance + variant?: string + attempted?: string[] + reason?: string +} function normalizeModel(model?: string): string | undefined { const trimmed = model?.trim() @@ -53,7 +74,7 @@ export function resolveModelPipeline( return { model: match, provenance: "category-default", attempted } } } else { - const connectedProviders = readConnectedProvidersCache() + const connectedProviders = constraints.connectedProviders ?? connectedProvidersCache.readConnectedProvidersCache() if (connectedProviders === null) { log("Model resolved via category default (no cache, first run)", { model: normalizedCategoryDefault, @@ -78,7 +99,7 @@ export function resolveModelPipeline( if (fallbackChain && fallbackChain.length > 0) { if (availableModels.size === 0) { - const connectedProviders = readConnectedProvidersCache() + const connectedProviders = constraints.connectedProviders ?? connectedProvidersCache.readConnectedProvidersCache() const connectedSet = connectedProviders ? 
new Set(connectedProviders) : null if (connectedSet === null) { diff --git a/src/tools/delegate-task/tools.test.ts b/src/tools/delegate-task/tools.test.ts index 4f32addf..2d91acaa 100644 --- a/src/tools/delegate-task/tools.test.ts +++ b/src/tools/delegate-task/tools.test.ts @@ -10,11 +10,27 @@ import * as connectedProvidersCache from "../../shared/connected-providers-cache const SYSTEM_DEFAULT_MODEL = "anthropic/claude-sonnet-4-5" +const TEST_CONNECTED_PROVIDERS = ["anthropic", "google", "openai"] +const TEST_AVAILABLE_MODELS = new Set([ + "anthropic/claude-opus-4-6", + "anthropic/claude-sonnet-4-5", + "anthropic/claude-haiku-4-5", + "google/gemini-3-pro", + "google/gemini-3-flash", + "openai/gpt-5.2", + "openai/gpt-5.3-codex", +]) + +function createTestAvailableModels(): Set { + return new Set(TEST_AVAILABLE_MODELS) +} + describe("sisyphus-task", () => { let cacheSpy: ReturnType let providerModelsSpy: ReturnType beforeEach(() => { + mock.restore() __resetModelCache() clearSkillCache() __setTimingConfig({ @@ -271,6 +287,8 @@ describe("sisyphus-task", () => { const tool = createDelegateTask({ manager: mockManager, client: mockClient, + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const toolContext = { @@ -324,6 +342,8 @@ describe("sisyphus-task", () => { const tool = createDelegateTask({ manager: mockManager, client: mockClient, + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const toolContext = { @@ -436,6 +456,8 @@ describe("sisyphus-task", () => { const tool = createDelegateTask({ manager: mockManager, client: mockClient, + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const metadataCalls: Array<{ title?: string; metadata?: Record }> = [] @@ -727,6 +749,8 @@ describe("sisyphus-task", () => { userCategories: { ultrabrain: { model: "openai/gpt-5.2", variant: "xhigh" 
}, }, + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const toolContext = { @@ -790,6 +814,8 @@ describe("sisyphus-task", () => { const tool = createDelegateTask({ manager: mockManager, client: mockClient, + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const toolContext = { @@ -1950,6 +1976,8 @@ describe("sisyphus-task", () => { client: mockClient, // userCategories: undefined - use DEFAULT_CATEGORIES only // sisyphusJuniorModel: undefined + connectedProvidersOverride: null, + availableModelsOverride: new Set(), }) const toolContext = { @@ -2013,6 +2041,8 @@ describe("sisyphus-task", () => { userCategories: { "fallback-test": { model: "anthropic/claude-opus-4-6" }, }, + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const toolContext = { @@ -2072,6 +2102,8 @@ describe("sisyphus-task", () => { manager: mockManager, client: mockClient, sisyphusJuniorModel: "anthropic/claude-sonnet-4-5", + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const toolContext = { @@ -2135,6 +2167,8 @@ describe("sisyphus-task", () => { userCategories: { ultrabrain: { model: "openai/gpt-5.3-codex" }, }, + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const toolContext = { @@ -2194,6 +2228,8 @@ describe("sisyphus-task", () => { manager: mockManager, client: mockClient, sisyphusJuniorModel: "anthropic/claude-sonnet-4-5", + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: createTestAvailableModels(), }) const toolContext = { @@ -3207,6 +3243,8 @@ describe("sisyphus-task", () => { manager: mockManager, client: mockClient, // no agentOverrides + connectedProvidersOverride: TEST_CONNECTED_PROVIDERS, + availableModelsOverride: 
createTestAvailableModels(), }) const toolContext = { diff --git a/src/tools/delegate-task/types.ts b/src/tools/delegate-task/types.ts index 4327bdce..13d1973a 100644 --- a/src/tools/delegate-task/types.ts +++ b/src/tools/delegate-task/types.ts @@ -50,6 +50,15 @@ export interface DelegateTaskToolOptions { manager: BackgroundManager client: OpencodeClient directory: string + /** + * Test hook: bypass global cache reads (Bun runs tests in parallel). + * If provided, resolveCategoryExecution/resolveSubagentExecution uses this instead of reading from disk cache. + */ + connectedProvidersOverride?: string[] | null + /** + * Test hook: bypass fetchAvailableModels() by providing an explicit available model set. + */ + availableModelsOverride?: Set userCategories?: CategoriesConfig gitMasterConfig?: GitMasterConfig sisyphusJuniorModel?: string