diff --git a/src/features/background-agent/compaction-aware-message-resolver.test.ts b/src/features/background-agent/compaction-aware-message-resolver.test.ts new file mode 100644 index 00000000..b8a659b0 --- /dev/null +++ b/src/features/background-agent/compaction-aware-message-resolver.test.ts @@ -0,0 +1,190 @@ +import { describe, test, expect, beforeEach, afterEach } from "bun:test" +import { mkdtempSync, writeFileSync, rmSync } from "node:fs" +import { join } from "node:path" +import { tmpdir } from "node:os" +import { isCompactionAgent, findNearestMessageExcludingCompaction } from "./compaction-aware-message-resolver" + +describe("isCompactionAgent", () => { + describe("#given agent name variations", () => { + test("returns true for 'compaction'", () => { + // when + const result = isCompactionAgent("compaction") + + // then + expect(result).toBe(true) + }) + + test("returns true for 'Compaction' (case insensitive)", () => { + // when + const result = isCompactionAgent("Compaction") + + // then + expect(result).toBe(true) + }) + + test("returns true for ' compaction ' (with whitespace)", () => { + // when + const result = isCompactionAgent(" compaction ") + + // then + expect(result).toBe(true) + }) + + test("returns false for undefined", () => { + // when + const result = isCompactionAgent(undefined) + + // then + expect(result).toBe(false) + }) + + test("returns false for null", () => { + // when + const result = isCompactionAgent(null as unknown as string) + + // then + expect(result).toBe(false) + }) + + test("returns false for non-compaction agent like 'sisyphus'", () => { + // when + const result = isCompactionAgent("sisyphus") + + // then + expect(result).toBe(false) + }) + }) +}) + +describe("findNearestMessageExcludingCompaction", () => { + let tempDir: string + + beforeEach(() => { + tempDir = mkdtempSync(join(tmpdir(), "compaction-test-")) + }) + + afterEach(() => { + rmSync(tempDir, { force: true, recursive: true }) + }) + + describe("#given 
directory with messages", () => { + test("finds message with full agent and model", () => { + // given + const message = { + agent: "sisyphus", + model: { providerID: "anthropic", modelID: "claude-opus-4-6" }, + } + writeFileSync(join(tempDir, "001.json"), JSON.stringify(message)) + + // when + const result = findNearestMessageExcludingCompaction(tempDir) + + // then + expect(result).not.toBeNull() + expect(result?.agent).toBe("sisyphus") + expect(result?.model?.providerID).toBe("anthropic") + expect(result?.model?.modelID).toBe("claude-opus-4-6") + }) + + test("skips compaction agent messages", () => { + // given + const compactionMessage = { + agent: "compaction", + model: { providerID: "anthropic", modelID: "claude-opus-4-6" }, + } + const validMessage = { + agent: "sisyphus", + model: { providerID: "anthropic", modelID: "claude-opus-4-6" }, + } + writeFileSync(join(tempDir, "001.json"), JSON.stringify(compactionMessage)) + writeFileSync(join(tempDir, "002.json"), JSON.stringify(validMessage)) + + // when + const result = findNearestMessageExcludingCompaction(tempDir) + + // then + expect(result).not.toBeNull() + expect(result?.agent).toBe("sisyphus") + }) + + test("falls back to partial agent/model match", () => { + // given + const messageWithAgentOnly = { + agent: "hephaestus", + } + const messageWithModelOnly = { + model: { providerID: "openai", modelID: "gpt-5.3" }, + } + writeFileSync(join(tempDir, "001.json"), JSON.stringify(messageWithModelOnly)) + writeFileSync(join(tempDir, "002.json"), JSON.stringify(messageWithAgentOnly)) + + // when + const result = findNearestMessageExcludingCompaction(tempDir) + + // then + expect(result).not.toBeNull() + // Should find the one with agent first (sorted reverse, so 002 is checked first) + expect(result?.agent).toBe("hephaestus") + }) + + test("returns null for empty directory", () => { + // given - empty directory (tempDir is already empty) + + // when + const result = findNearestMessageExcludingCompaction(tempDir) 
/**
 * Reports whether `agent` names the compaction agent.
 *
 * The comparison ignores surrounding whitespace and letter case.
 * `null` and `undefined` are accepted and yield `false`.
 *
 * @param agent - Agent name taken from a stored message, if any.
 * @returns `true` only when the trimmed, lower-cased name equals "compaction".
 */
export function isCompactionAgent(agent: string | null | undefined): boolean {
  return agent?.trim().toLowerCase() === "compaction"
}

/** True when the message names an agent and that agent is not the compaction agent. */
function hasUsableAgent(message: StoredMessage): boolean {
  return !!message.agent && !isCompactionAgent(message.agent)
}

/** True when the message carries both a provider ID and a model ID. */
function hasUsableModel(message: StoredMessage): boolean {
  return !!message.model?.providerID && !!message.model?.modelID
}

/** True when the message has a non-compaction agent AND a complete model. */
function hasFullAgentAndModel(message: StoredMessage): boolean {
  return hasUsableAgent(message) && hasUsableModel(message)
}

/**
 * True when the message has a non-compaction agent OR a complete model.
 *
 * A message written by the compaction agent never qualifies, even when it
 * carries a model: returning it would hand the caller `agent: "compaction"`,
 * which is exactly what this resolver exists to avoid.
 */
function hasPartialAgentOrModel(message: StoredMessage): boolean {
  if (isCompactionAgent(message.agent)) return false
  return hasUsableAgent(message) || hasUsableModel(message)
}

/**
 * Finds the most relevant stored message in `messageDir`, ignoring messages
 * produced by the compaction agent.
 *
 * `*.json` files are visited in reverse lexicographic filename order. The first
 * message with both a non-compaction agent and a complete model wins. If no
 * such message exists, the first visited message with either a non-compaction
 * agent or a complete model is returned instead.
 *
 * Each file is read and parsed once; the best partial match is remembered
 * while the scan keeps looking for a full match.
 *
 * @param messageDir - Directory containing one JSON file per stored message.
 * @returns The selected message, or `null` when the directory cannot be read
 *   or contains no qualifying message. Files that cannot be read or parsed
 *   are skipped.
 */
export function findNearestMessageExcludingCompaction(messageDir: string): StoredMessage | null {
  let fileNames: string[]
  try {
    fileNames = readdirSync(messageDir)
      .filter((name) => name.endsWith(".json"))
      .sort()
      .reverse()
  } catch {
    // Missing or unreadable directory: nothing to resolve.
    return null
  }

  let firstPartialMatch: StoredMessage | null = null

  for (const fileName of fileNames) {
    try {
      const content = readFileSync(join(messageDir, fileName), "utf-8")
      const parsed = JSON.parse(content) as StoredMessage

      if (hasFullAgentAndModel(parsed)) return parsed

      // Keep only the first partial match in visiting order.
      if (firstPartialMatch === null && hasPartialAgentOrModel(parsed)) {
        firstPartialMatch = parsed
      }
    } catch {
      // Unreadable file, invalid JSON, or a payload whose shape makes the
      // predicates throw (e.g. `null`): skip it and keep scanning.
      continue
    }
  }

  return firstPartialMatch
}
expect(isRecord("hello")).toBe(false) + }) + + test("returns false for number", () => { + expect(isRecord(42)).toBe(false) + }) + + test("returns false for boolean", () => { + expect(isRecord(true)).toBe(false) + }) + + test("returns true for array (arrays are objects)", () => { + expect(isRecord([1, 2, 3])).toBe(true) + }) + }) + + describe("#given plain objects", () => { + test("returns true for empty object", () => { + expect(isRecord({})).toBe(true) + }) + + test("returns true for object with properties", () => { + expect(isRecord({ key: "value" })).toBe(true) + }) + + test("returns true for object with nested objects", () => { + expect(isRecord({ nested: { deep: true } })).toBe(true) + }) + }) + + describe("#given Error instances", () => { + test("returns true for Error instance", () => { + expect(isRecord(new Error("test"))).toBe(true) + }) + + test("returns true for TypeError instance", () => { + expect(isRecord(new TypeError("test"))).toBe(true) + }) + }) +}) + +describe("isAbortedSessionError", () => { + describe("#given error with aborted message", () => { + test("returns true for string containing aborted", () => { + expect(isAbortedSessionError("Session aborted")).toBe(true) + }) + + test("returns true for string with ABORTED uppercase", () => { + expect(isAbortedSessionError("Session ABORTED")).toBe(true) + }) + + test("returns true for Error with aborted in message", () => { + expect(isAbortedSessionError(new Error("Session aborted"))).toBe(true) + }) + + test("returns true for object with message containing aborted", () => { + expect(isAbortedSessionError({ message: "The session was aborted" })).toBe(true) + }) + }) + + describe("#given error without aborted message", () => { + test("returns false for string without aborted", () => { + expect(isAbortedSessionError("Session completed")).toBe(false) + }) + + test("returns false for Error without aborted", () => { + expect(isAbortedSessionError(new Error("Something went wrong"))).toBe(false) + }) + + 
test("returns false for empty string", () => { + expect(isAbortedSessionError("")).toBe(false) + }) + }) + + describe("#given invalid inputs", () => { + test("returns false for null", () => { + expect(isAbortedSessionError(null)).toBe(false) + }) + + test("returns false for undefined", () => { + expect(isAbortedSessionError(undefined)).toBe(false) + }) + + test("returns false for object without message", () => { + expect(isAbortedSessionError({ code: "ABORTED" })).toBe(false) + }) + }) +}) + +describe("getErrorText", () => { + describe("#given string input", () => { + test("returns the string as-is", () => { + expect(getErrorText("Something went wrong")).toBe("Something went wrong") + }) + + test("returns empty string for empty string", () => { + expect(getErrorText("")).toBe("") + }) + }) + + describe("#given Error instance", () => { + test("returns name and message format", () => { + expect(getErrorText(new Error("test message"))).toBe("Error: test message") + }) + + test("returns TypeError format", () => { + expect(getErrorText(new TypeError("type error"))).toBe("TypeError: type error") + }) + }) + + describe("#given object with message property", () => { + test("returns message property as string", () => { + expect(getErrorText({ message: "custom error" })).toBe("custom error") + }) + + test("returns name property when message not available", () => { + expect(getErrorText({ name: "CustomError" })).toBe("CustomError") + }) + + test("prefers message over name", () => { + expect(getErrorText({ name: "CustomError", message: "error message" })).toBe("error message") + }) + }) + + describe("#given invalid inputs", () => { + test("returns empty string for null", () => { + expect(getErrorText(null)).toBe("") + }) + + test("returns empty string for undefined", () => { + expect(getErrorText(undefined)).toBe("") + }) + + test("returns empty string for object without message or name", () => { + expect(getErrorText({ code: 500 })).toBe("") + }) + }) +}) + 
+describe("extractErrorName", () => { + describe("#given Error instance", () => { + test("returns Error for generic Error", () => { + expect(extractErrorName(new Error("test"))).toBe("Error") + }) + + test("returns TypeError name", () => { + expect(extractErrorName(new TypeError("test"))).toBe("TypeError") + }) + + test("returns RangeError name", () => { + expect(extractErrorName(new RangeError("test"))).toBe("RangeError") + }) + }) + + describe("#given plain object with name property", () => { + test("returns name property when string", () => { + expect(extractErrorName({ name: "CustomError" })).toBe("CustomError") + }) + + test("returns undefined when name is not string", () => { + expect(extractErrorName({ name: 123 })).toBe(undefined) + }) + }) + + describe("#given invalid inputs", () => { + test("returns undefined for null", () => { + expect(extractErrorName(null)).toBe(undefined) + }) + + test("returns undefined for undefined", () => { + expect(extractErrorName(undefined)).toBe(undefined) + }) + + test("returns undefined for string", () => { + expect(extractErrorName("Error message")).toBe(undefined) + }) + + test("returns undefined for object without name property", () => { + expect(extractErrorName({ message: "test" })).toBe(undefined) + }) + }) +}) + +describe("extractErrorMessage", () => { + describe("#given string input", () => { + test("returns the string as-is", () => { + expect(extractErrorMessage("error message")).toBe("error message") + }) + + test("returns undefined for empty string", () => { + expect(extractErrorMessage("")).toBe(undefined) + }) + }) + + describe("#given Error instance", () => { + test("returns error message", () => { + expect(extractErrorMessage(new Error("test error"))).toBe("test error") + }) + + test("returns empty string for Error with no message", () => { + expect(extractErrorMessage(new Error())).toBe("") + }) + }) + + describe("#given object with message property", () => { + test("returns message property", () => { + 
expect(extractErrorMessage({ message: "custom message" })).toBe("custom message") + }) + + test("falls through to JSON.stringify for empty message value", () => { + expect(extractErrorMessage({ message: "" })).toBe('{"message":""}') + }) + }) + + describe("#given nested error structure", () => { + test("extracts message from nested error object", () => { + expect(extractErrorMessage({ error: { message: "nested error" } })).toBe("nested error") + }) + + test("extracts message from data.error structure", () => { + expect(extractErrorMessage({ data: { error: "data error" } })).toBe("data error") + }) + + test("extracts message from cause property", () => { + expect(extractErrorMessage({ cause: "cause error" })).toBe("cause error") + }) + + test("extracts message from cause object with message", () => { + expect(extractErrorMessage({ cause: { message: "cause message" } })).toBe("cause message") + }) + }) + + describe("#given complex error with data wrapper", () => { + test("extracts from error.data.message", () => { + const error = { + data: { + message: "data message", + }, + } + expect(extractErrorMessage(error)).toBe("data message") + }) + + test("prefers top over nested-level message", () => { + const error = { + message: "top level", + data: { message: "nested" }, + } + expect(extractErrorMessage(error)).toBe("top level") + }) + }) + + describe("#given invalid inputs", () => { + test("returns undefined for null", () => { + expect(extractErrorMessage(null)).toBe(undefined) + }) + + test("returns undefined for undefined", () => { + expect(extractErrorMessage(undefined)).toBe(undefined) + }) + }) + + describe("#given object without extractable message", () => { + test("falls back to JSON.stringify for object", () => { + const obj = { code: 500, details: "error" } + const result = extractErrorMessage(obj) + expect(result).toContain('"code":500') + }) + + test("falls back to String() for non-serializable object", () => { + const circular: Record = { a: 1 } + 
circular.self = circular + const result = extractErrorMessage(circular) + expect(result).toBe("[object Object]") + }) + }) +}) + +describe("getSessionErrorMessage", () => { + describe("#given valid error properties", () => { + test("extracts message from error.message", () => { + const properties = { error: { message: "session error" } } + expect(getSessionErrorMessage(properties)).toBe("session error") + }) + + test("extracts message from error.data.message", () => { + const properties = { + error: { + data: { message: "data error message" }, + }, + } + expect(getSessionErrorMessage(properties)).toBe("data error message") + }) + + test("prefers error.data.message over error.message", () => { + const properties = { + error: { + message: "top level", + data: { message: "nested" }, + }, + } + expect(getSessionErrorMessage(properties)).toBe("nested") + }) + }) + + describe("#given missing or invalid properties", () => { + test("returns undefined when error is missing", () => { + expect(getSessionErrorMessage({})).toBe(undefined) + }) + + test("returns undefined when error is null", () => { + expect(getSessionErrorMessage({ error: null })).toBe(undefined) + }) + + test("returns undefined when error is string", () => { + expect(getSessionErrorMessage({ error: "error string" })).toBe(undefined) + }) + + test("returns undefined when data is not an object", () => { + expect(getSessionErrorMessage({ error: { data: "not an object" } })).toBe(undefined) + }) + + test("returns undefined when message is not string", () => { + expect(getSessionErrorMessage({ error: { message: 123 } })).toBe(undefined) + }) + + test("returns undefined when data.message is not string", () => { + expect(getSessionErrorMessage({ error: { data: { message: null } } })).toBe(undefined) + }) + }) +}) diff --git a/src/features/background-agent/error-classifier.ts b/src/features/background-agent/error-classifier.ts index 8be1dd3d..5c7e90b4 100644 --- a/src/features/background-agent/error-classifier.ts +++ 
/**
 * Type guard for "any non-null object".
 *
 * Arrays and `Error` instances pass, since both are objects.
 */
export function isRecord(value: unknown): value is Record<string, unknown> {
  return value !== null && typeof value === "object"
}

/**
 * Reads a string `name` from an error-like value.
 *
 * @returns The `name` property when it is a string, the `name` of an `Error`
 *   instance otherwise, or `undefined` when neither applies.
 */
export function extractErrorName(error: unknown): string | undefined {
  if (isRecord(error)) {
    const name = error["name"]
    if (typeof name === "string") return name
  }
  return error instanceof Error ? error.name : undefined
}

/**
 * Returns the text carried by one candidate: the candidate itself when it is a
 * non-empty string, or its non-empty string `message` property when it is an object.
 */
function readCandidateMessage(candidate: unknown): string | undefined {
  if (typeof candidate === "string") {
    return candidate.length > 0 ? candidate : undefined
  }
  if (!isRecord(candidate)) return undefined

  const message = candidate["message"]
  return typeof message === "string" && message.length > 0 ? message : undefined
}

/**
 * Produces a human-readable message from an arbitrary thrown or reported value.
 *
 * Resolution order:
 * 1. falsy input -> `undefined`
 * 2. string input -> the string itself
 * 3. `Error` instance -> its `message` (which may be empty)
 * 4. object input -> first non-empty message among, in order: the object,
 *    `data`, `error`, `data.error`, `cause`
 * 5. otherwise -> `JSON.stringify(error)`, or `String(error)` when serialization throws
 */
export function extractErrorMessage(error: unknown): string | undefined {
  if (!error) return undefined
  if (typeof error === "string") return error
  if (error instanceof Error) return error.message

  if (isRecord(error)) {
    const data = error["data"]
    const candidates: unknown[] = [
      error,
      data,
      error["error"],
      isRecord(data) ? data["error"] : undefined,
      error["cause"],
    ]

    for (const candidate of candidates) {
      const text = readCandidateMessage(candidate)
      if (text !== undefined) return text
    }
  }

  try {
    return JSON.stringify(error)
  } catch {
    // e.g. circular structures cannot be serialized
    return String(error)
  }
}

interface EventPropertiesLike {
  [key: string]: unknown
}

/**
 * Extracts the error message from session-error event properties.
 *
 * `properties.error.data.message` takes precedence over
 * `properties.error.message`; only string values are returned.
 *
 * @returns The message, or `undefined` when `properties.error` is not an
 *   object or neither location holds a string.
 */
export function getSessionErrorMessage(properties: EventPropertiesLike): string | undefined {
  const sessionError = properties["error"]
  if (!isRecord(sessionError)) return undefined

  const data = sessionError["data"]
  const nestedMessage = isRecord(data) ? data["message"] : undefined
  if (typeof nestedMessage === "string") return nestedMessage

  const topLevelMessage = sessionError["message"]
  if (typeof topLevelMessage === "string") return topLevelMessage
  return undefined
}
+ release: mock(() => {}), + acquire: mock(async () => {}), + getQueueLength: mock(() => 0), + getActiveCount: mock(() => 0), + } as unknown as ConcurrencyManager +} + +function createMockClient() { + return { + session: { + abort: mock(async () => ({})), + }, + } as any +} + +function createDefaultArgs(taskOverrides: Partial = {}) { + const processKeyFn = mock(() => {}) + const queuesByKey = new Map>() + const idleDeferralTimers = new Map>() + const concurrencyManager = createMockConcurrencyManager() + const client = createMockClient() + const task = createMockTask(taskOverrides) + + return { + task, + errorInfo: { name: "OverloadedError", message: "model overloaded" }, + source: "polling", + concurrencyManager, + client, + idleDeferralTimers, + queuesByKey, + processKey: processKeyFn, + } +} + +describe("tryFallbackRetry", () => { + beforeEach(() => { + ;(shouldRetryError as any).mockImplementation(() => true) + }) + + describe("#given retryable error with fallback chain", () => { + test("returns true and enqueues retry", () => { + const args = createDefaultArgs() + + const result = tryFallbackRetry(args) + + expect(result).toBe(true) + }) + + test("resets task status to pending", () => { + const args = createDefaultArgs() + + tryFallbackRetry(args) + + expect(args.task.status).toBe("pending") + }) + + test("increments attemptCount", () => { + const args = createDefaultArgs() + + tryFallbackRetry(args) + + expect(args.task.attemptCount).toBe(1) + }) + + test("updates task model to fallback", () => { + const args = createDefaultArgs() + + tryFallbackRetry(args) + + expect(args.task.model?.modelID).toBe("fallback-model-1") + expect(args.task.model?.providerID).toBe("provider-a") + }) + + test("clears sessionID and startedAt", () => { + const args = createDefaultArgs({ + sessionID: "old-session", + startedAt: new Date(), + }) + + tryFallbackRetry(args) + + expect(args.task.sessionID).toBeUndefined() + expect(args.task.startedAt).toBeUndefined() + }) + + test("clears 
error field", () => { + const args = createDefaultArgs({ error: "previous error" }) + + tryFallbackRetry(args) + + expect(args.task.error).toBeUndefined() + }) + + test("sets new queuedAt", () => { + const args = createDefaultArgs() + + tryFallbackRetry(args) + + expect(args.task.queuedAt).toBeInstanceOf(Date) + }) + + test("releases concurrency slot", () => { + const args = createDefaultArgs() + + tryFallbackRetry(args) + + expect(args.concurrencyManager.release).toHaveBeenCalledWith("provider-a/original-model") + }) + + test("clears concurrencyKey after release", () => { + const args = createDefaultArgs() + + tryFallbackRetry(args) + + expect(args.task.concurrencyKey).toBeUndefined() + }) + + test("aborts existing session", () => { + const args = createDefaultArgs({ sessionID: "session-to-abort" }) + + tryFallbackRetry(args) + + expect(args.client.session.abort).toHaveBeenCalledWith({ + path: { id: "session-to-abort" }, + }) + }) + + test("adds retry input to queue and calls processKey", () => { + const args = createDefaultArgs() + + tryFallbackRetry(args) + + const key = `${args.task.model!.providerID}/${args.task.model!.modelID}` + const queue = args.queuesByKey.get(key) + expect(queue).toBeDefined() + expect(queue!.length).toBe(1) + expect(queue![0].task).toBe(args.task) + expect(args.processKey).toHaveBeenCalledWith(key) + }) + }) + + describe("#given non-retryable error", () => { + test("returns false when shouldRetryError returns false", () => { + ;(shouldRetryError as any).mockImplementation(() => false) + const args = createDefaultArgs() + + const result = tryFallbackRetry(args) + + expect(result).toBe(false) + }) + }) + + describe("#given no fallback chain", () => { + test("returns false when fallbackChain is undefined", () => { + const args = createDefaultArgs({ fallbackChain: undefined }) + + const result = tryFallbackRetry(args) + + expect(result).toBe(false) + }) + + test("returns false when fallbackChain is empty", () => { + const args = 
/**
 * Attempts to re-queue a failed background task on the next reachable model
 * of its fallback chain.
 *
 * The retry happens only when `shouldRetryError` accepts `errorInfo`, the task
 * has a non-empty fallback chain, and `hasMoreFallbacks` reports remaining
 * entries. Chain entries whose providers are all absent from the connected
 * providers cache are skipped (and logged); when no cache is available every
 * entry is treated as reachable.
 *
 * On success the task is mutated in place: its concurrency slot is released,
 * its current session (if any) is aborted on a best-effort basis, any idle
 * deferral timer is cleared, its model/attempt count are updated, and it is
 * reset to "pending" before being pushed onto the queue for the new
 * `providerID/modelID` key, after which `processKey` is invoked for that key.
 *
 * @returns `true` when the task was re-queued, `false` when no retry is possible.
 */
export function tryFallbackRetry(args: {
  task: BackgroundTask
  errorInfo: { name?: string; message?: string }
  source: string
  concurrencyManager: ConcurrencyManager
  client: OpencodeClient
  idleDeferralTimers: Map<string, ReturnType<typeof setTimeout>>
  queuesByKey: Map<string, QueueItem[]>
  processKey: (key: string) => void
}): boolean {
  const { task, errorInfo, source, concurrencyManager, client, idleDeferralTimers, queuesByKey, processKey } = args

  // Guard clauses, evaluated in the same order as a single && chain would be.
  const chain = task.fallbackChain
  if (!shouldRetryError(errorInfo)) return false
  if (!chain || chain.length === 0) return false
  if (!hasMoreFallbacks(chain, task.attemptCount ?? 0)) return false

  const startingAttempt = task.attemptCount ?? 0
  const providerModelsCache = readProviderModelsCache()
  const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache()
  const connectedSet = connectedProviders
    ? new Set(connectedProviders.map((provider) => provider.toLowerCase()))
    : null

  // Without connectivity information every entry is assumed reachable.
  function isReachable(entry: FallbackEntry): boolean {
    if (connectedSet === null) return true
    return entry.providers.some((provider) => connectedSet.has(provider.toLowerCase()))
  }

  // Walk the chain from the current attempt; every inspected entry consumes one attempt.
  let attemptsUsed = startingAttempt
  let chosen: FallbackEntry | undefined
  while (chosen === undefined && attemptsUsed < chain.length) {
    const candidate = getNextFallback(chain, attemptsUsed)
    if (!candidate) break
    attemptsUsed += 1

    if (isReachable(candidate)) {
      chosen = candidate
    } else {
      log("[background-agent] Skipping unreachable fallback:", {
        taskId: task.id,
        source,
        model: candidate.model,
        providers: candidate.providers,
      })
    }
  }
  if (!chosen) return false

  const providerID = selectFallbackProvider(chosen.providers, task.model?.providerID)

  log("[background-agent] Retryable error, attempting fallback:", {
    taskId: task.id,
    source,
    errorName: errorInfo.name,
    errorMessage: errorInfo.message?.slice(0, 100),
    attemptCount: attemptsUsed,
    nextModel: `${providerID}/${chosen.model}`,
  })

  // Free the slot held under the previous model before re-queueing.
  if (task.concurrencyKey) {
    concurrencyManager.release(task.concurrencyKey)
    task.concurrencyKey = undefined
  }

  // Best-effort abort of the failed session; a rejection is deliberately ignored.
  if (task.sessionID) {
    client.session.abort({ path: { id: task.sessionID } }).catch(() => {})
  }

  const pendingIdleTimer = idleDeferralTimers.get(task.id)
  if (pendingIdleTimer) {
    clearTimeout(pendingIdleTimer)
    idleDeferralTimers.delete(task.id)
  }

  // Reset the task so it looks freshly queued on the fallback model.
  task.attemptCount = attemptsUsed
  const fallbackModel = {
    providerID,
    modelID: transformModelForProvider(providerID, chosen.model),
    variant: chosen.variant,
  }
  task.model = fallbackModel
  task.status = "pending"
  task.sessionID = undefined
  task.startedAt = undefined
  task.queuedAt = new Date()
  task.error = undefined

  const queueKey = `${fallbackModel.providerID}/${fallbackModel.modelID}`
  const retryInput: LaunchInput = {
    description: task.description,
    prompt: task.prompt,
    agent: task.agent,
    parentSessionID: task.parentSessionID,
    parentMessageID: task.parentMessageID,
    parentModel: task.parentModel,
    parentAgent: task.parentAgent,
    parentTools: task.parentTools,
    model: task.model,
    fallbackChain: task.fallbackChain,
    category: task.category,
  }

  const pendingItems = queuesByKey.get(queueKey) ?? []
  pendingItems.push({ task, input: retryInput })
  queuesByKey.set(queueKey, pendingItems)
  processKey(queueKey)
  return true
}
MESSAGE_STORAGE, type StoredMessage } from "../hook-message-injector" -import { existsSync, readFileSync, readdirSync } from "node:fs" -import { join } from "node:path" - -type ProcessCleanupEvent = NodeJS.Signals | "beforeExit" | "exit" +import { formatDuration } from "./duration-formatter" +import { + isAbortedSessionError, + extractErrorName, + extractErrorMessage, + getSessionErrorMessage, + isRecord, +} from "./error-classifier" +import { tryFallbackRetry } from "./fallback-retry-handler" +import { registerManagerForCleanup, unregisterManagerForCleanup } from "./process-cleanup" +import { isCompactionAgent, findNearestMessageExcludingCompaction } from "./compaction-aware-message-resolver" +import { pruneStaleTasksAndNotifications } from "./task-poller" +import { checkAndInterruptStaleTasks } from "./task-poller" type OpencodeClient = PluginInput["client"] @@ -89,9 +88,7 @@ export interface SubagentSessionCreatedEvent { export type OnSubagentSessionCreated = (event: SubagentSessionCreatedEvent) => Promise export class BackgroundManager { - private static cleanupManagers = new Set() - private static cleanupRegistered = false - private static cleanupHandlers = new Map void>() + private tasks: Map private notifications: Map @@ -705,8 +702,8 @@ export class BackgroundManager { if (!assistantError) return const errorInfo = { - name: this.extractErrorName(assistantError), - message: this.extractErrorMessage(assistantError), + name: extractErrorName(assistantError), + message: extractErrorMessage(assistantError), } this.tryFallbackRetry(task, errorInfo, "message.updated") } @@ -809,7 +806,7 @@ export class BackgroundManager { const errorObj = props?.error as { name?: string; message?: string } | undefined const errorName = errorObj?.name - const errorMessage = props ? this.getSessionErrorMessage(props) : undefined + const errorMessage = props ? 
getSessionErrorMessage(props) : undefined const errorInfo = { name: errorName, message: errorMessage } if (this.tryFallbackRetry(task, errorInfo, "session.error")) return @@ -934,110 +931,20 @@ export class BackgroundManager { errorInfo: { name?: string; message?: string }, source: string, ): boolean { - const fallbackChain = task.fallbackChain - const canRetry = - shouldRetryError(errorInfo) && - fallbackChain && - fallbackChain.length > 0 && - hasMoreFallbacks(fallbackChain, task.attemptCount ?? 0) - - if (!canRetry) return false - - const attemptCount = task.attemptCount ?? 0 - const providerModelsCache = readProviderModelsCache() - const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache() - const connectedSet = connectedProviders ? new Set(connectedProviders.map(p => p.toLowerCase())) : null - - const isReachable = (entry: FallbackEntry): boolean => { - if (!connectedSet) return true - - // Gate only on provider connectivity. Provider model lists can be stale/incomplete, - // especially after users manually add models to opencode.json. 
- return entry.providers.some((p) => connectedSet.has(p.toLowerCase())) - } - - let selectedAttemptCount = attemptCount - let nextFallback: FallbackEntry | undefined - while (fallbackChain && selectedAttemptCount < fallbackChain.length) { - const candidate = getNextFallback(fallbackChain, selectedAttemptCount) - if (!candidate) break - selectedAttemptCount++ - if (!isReachable(candidate)) { - log("[background-agent] Skipping unreachable fallback:", { - taskId: task.id, - source, - model: candidate.model, - providers: candidate.providers, - }) - continue - } - nextFallback = candidate - break - } - if (!nextFallback) return false - - const providerID = selectFallbackProvider( - nextFallback.providers, - task.model?.providerID, - ) - - log("[background-agent] Retryable error, attempting fallback:", { - taskId: task.id, + const result = tryFallbackRetry({ + task, + errorInfo, source, - errorName: errorInfo.name, - errorMessage: errorInfo.message?.slice(0, 100), - attemptCount: selectedAttemptCount, - nextModel: `${providerID}/${nextFallback.model}`, + concurrencyManager: this.concurrencyManager, + client: this.client, + idleDeferralTimers: this.idleDeferralTimers, + queuesByKey: this.queuesByKey, + processKey: (key: string) => this.processKey(key), }) - - if (task.concurrencyKey) { - this.concurrencyManager.release(task.concurrencyKey) - task.concurrencyKey = undefined - } - - if (task.sessionID) { - this.client.session.abort({ path: { id: task.sessionID } }).catch(() => {}) + if (result && task.sessionID) { subagentSessions.delete(task.sessionID) } - - const idleTimer = this.idleDeferralTimers.get(task.id) - if (idleTimer) { - clearTimeout(idleTimer) - this.idleDeferralTimers.delete(task.id) - } - - task.attemptCount = selectedAttemptCount - const transformedModelId = transformModelForProvider(providerID, nextFallback.model) - task.model = { - providerID, - modelID: transformedModelId, - variant: nextFallback.variant, - } - task.status = "pending" - task.sessionID = 
undefined - task.startedAt = undefined - task.queuedAt = new Date() - task.error = undefined - - const key = task.model ? `${task.model.providerID}/${task.model.modelID}` : task.agent - const queue = this.queuesByKey.get(key) ?? [] - const retryInput: LaunchInput = { - description: task.description, - prompt: task.prompt, - agent: task.agent, - parentSessionID: task.parentSessionID, - parentMessageID: task.parentMessageID, - parentModel: task.parentModel, - parentAgent: task.parentAgent, - parentTools: task.parentTools, - model: task.model, - fallbackChain: task.fallbackChain, - category: task.category, - } - queue.push({ task, input: retryInput }) - this.queuesByKey.set(key, queue) - this.processKey(key) - return true + return result } markForNotification(task: BackgroundTask): void { @@ -1256,45 +1163,11 @@ export class BackgroundManager { } private registerProcessCleanup(): void { - BackgroundManager.cleanupManagers.add(this) - - if (BackgroundManager.cleanupRegistered) return - BackgroundManager.cleanupRegistered = true - - const cleanupAll = () => { - for (const manager of BackgroundManager.cleanupManagers) { - try { - manager.shutdown() - } catch (error) { - log("[background-agent] Error during shutdown cleanup:", error) - } - } - } - - const registerSignal = (signal: ProcessCleanupEvent, exitAfter: boolean): void => { - const listener = registerProcessSignal(signal, cleanupAll, exitAfter) - BackgroundManager.cleanupHandlers.set(signal, listener) - } - - registerSignal("SIGINT", true) - registerSignal("SIGTERM", true) - if (process.platform === "win32") { - registerSignal("SIGBREAK", true) - } - registerSignal("beforeExit", false) - registerSignal("exit", false) + registerManagerForCleanup(this) } private unregisterProcessCleanup(): void { - BackgroundManager.cleanupManagers.delete(this) - - if (BackgroundManager.cleanupManagers.size > 0) return - - for (const [signal, listener] of BackgroundManager.cleanupHandlers.entries()) { - process.off(signal, listener) 
- } - BackgroundManager.cleanupHandlers.clear() - BackgroundManager.cleanupRegistered = false + unregisterManagerForCleanup(this) } @@ -1368,7 +1241,7 @@ export class BackgroundManager { // Note: Callers must release concurrency before calling this method // to ensure slots are freed even if notification fails - const duration = this.formatDuration(task.startedAt ?? new Date(), task.completedAt) + const duration = formatDuration(task.startedAt ?? new Date(), task.completedAt) log("[background-agent] notifyParentSession called for task:", task.id) @@ -1455,7 +1328,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea if (isCompactionAgent(info?.agent)) { continue } - const normalizedTools = this.isRecord(info?.tools) + const normalizedTools = isRecord(info?.tools) ? normalizePromptTools(info.tools as Record) : undefined if (info?.agent || info?.model || (info?.modelID && info?.providerID) || normalizedTools) { @@ -1466,7 +1339,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea } } } catch (error) { - if (this.isAbortedSessionError(error)) { + if (isAbortedSessionError(error)) { log("[background-agent] Parent session aborted while loading messages; using messageDir fallback:", { taskId: task.id, parentSessionID: task.parentSessionID, @@ -1506,7 +1379,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea noReply: !allComplete, }) } catch (error) { - if (this.isAbortedSessionError(error)) { + if (isAbortedSessionError(error)) { log("[background-agent] Parent session aborted while sending notification; continuing cleanup:", { taskId: task.id, parentSessionID: task.parentSessionID, @@ -1544,97 +1417,11 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea } private formatDuration(start: Date, end?: Date): string { - const duration = (end ?? 
new Date()).getTime() - start.getTime() - const seconds = Math.floor(duration / 1000) - const minutes = Math.floor(seconds / 60) - const hours = Math.floor(minutes / 60) - - if (hours > 0) { - return `${hours}h ${minutes % 60}m ${seconds % 60}s` - } else if (minutes > 0) { - return `${minutes}m ${seconds % 60}s` - } - return `${seconds}s` + return formatDuration(start, end) } private isAbortedSessionError(error: unknown): boolean { - const message = this.getErrorText(error) - return message.toLowerCase().includes("aborted") - } - - private getErrorText(error: unknown): string { - if (!error) return "" - if (typeof error === "string") return error - if (error instanceof Error) { - return `${error.name}: ${error.message}` - } - if (typeof error === "object" && error !== null) { - if ("message" in error && typeof error.message === "string") { - return error.message - } - if ("name" in error && typeof error.name === "string") { - return error.name - } - } - return "" - } - - private extractErrorName(error: unknown): string | undefined { - if (this.isRecord(error) && typeof error["name"] === "string") return error["name"] - if (error instanceof Error) return error.name - return undefined - } - - private extractErrorMessage(error: unknown): string | undefined { - if (!error) return undefined - if (typeof error === "string") return error - if (error instanceof Error) return error.message - - if (this.isRecord(error)) { - const dataRaw = error["data"] - const candidates: unknown[] = [ - error, - dataRaw, - error["error"], - this.isRecord(dataRaw) ? 
(dataRaw as Record)["error"] : undefined, - error["cause"], - ] - - for (const candidate of candidates) { - if (typeof candidate === "string" && candidate.length > 0) return candidate - if ( - this.isRecord(candidate) && - typeof candidate["message"] === "string" && - candidate["message"].length > 0 - ) { - return candidate["message"] - } - } - } - - try { - return JSON.stringify(error) - } catch { - return String(error) - } - } - - private isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null - } - - private getSessionErrorMessage(properties: EventProperties): string | undefined { - const errorRaw = properties["error"] - if (!this.isRecord(errorRaw)) return undefined - - const dataRaw = errorRaw["data"] - if (this.isRecord(dataRaw)) { - const message = dataRaw["message"] - if (typeof message === "string") return message - } - - const message = errorRaw["message"] - return typeof message === "string" ? message : undefined + return isAbortedSessionError(error) } private hasRunningTasks(): boolean { @@ -1645,25 +1432,12 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea } private pruneStaleTasksAndNotifications(): void { - const now = Date.now() - - for (const [taskId, task] of this.tasks.entries()) { - const wasPending = task.status === "pending" - const timestamp = task.status === "pending" - ? task.queuedAt?.getTime() - : task.startedAt?.getTime() - - if (!timestamp) { - continue - } - - const age = now - timestamp - if (age > TASK_TTL_MS) { - const errorMessage = task.status === "pending" - ? 
"Task timed out while queued (30 minutes)" - : "Task timed out after 30 minutes" - - log("[background-agent] Pruning stale task:", { taskId, status: task.status, age: Math.round(age / 1000) + "s" }) + pruneStaleTasksAndNotifications({ + tasks: this.tasks, + notifications: this.notifications, + onTaskPruned: (taskId, task, errorMessage) => { + const wasPending = task.status === "pending" + log("[background-agent] Pruning stale task:", { taskId, status: task.status, age: Math.round(((wasPending ? task.queuedAt?.getTime() : task.startedAt?.getTime()) ? (Date.now() - (wasPending ? task.queuedAt!.getTime() : task.startedAt!.getTime())) : 0) / 1000) + "s" }) task.status = "error" task.error = errorMessage task.completedAt = new Date() @@ -1671,7 +1445,6 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea this.concurrencyManager.release(task.concurrencyKey) task.concurrencyKey = undefined } - // Clean up pendingByParent to prevent stale entries this.cleanupPendingByParent(task) if (wasPending) { const key = task.model @@ -1698,97 +1471,21 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea subagentSessions.delete(task.sessionID) SessionCategoryRegistry.remove(task.sessionID) } - } - } - - for (const [sessionID, notifications] of this.notifications.entries()) { - if (notifications.length === 0) { - this.notifications.delete(sessionID) - continue - } - const validNotifications = notifications.filter((task) => { - if (!task.startedAt) return false - const age = now - task.startedAt.getTime() - return age <= TASK_TTL_MS - }) - if (validNotifications.length === 0) { - this.notifications.delete(sessionID) - } else if (validNotifications.length !== notifications.length) { - this.notifications.set(sessionID, validNotifications) - } - } + }, + }) } private async checkAndInterruptStaleTasks( allStatuses: Record = {}, ): Promise { - const staleTimeoutMs = this.config?.staleTimeoutMs ?? 
DEFAULT_STALE_TIMEOUT_MS - const messageStalenessMs = this.config?.messageStalenessTimeoutMs ?? DEFAULT_MESSAGE_STALENESS_TIMEOUT_MS - const now = Date.now() - - for (const task of this.tasks.values()) { - if (task.status !== "running") continue - - const startedAt = task.startedAt - const sessionID = task.sessionID - if (!startedAt || !sessionID) continue - - const sessionStatus = allStatuses[sessionID]?.type - const sessionIsRunning = sessionStatus !== undefined && sessionStatus !== "idle" - const runtime = now - startedAt.getTime() - - if (!task.progress?.lastUpdate) { - if (sessionIsRunning) continue - if (runtime <= messageStalenessMs) continue - - const staleMinutes = Math.round(runtime / 60000) - task.status = "cancelled" - task.error = `Stale timeout (no activity for ${staleMinutes}min since start)` - task.completedAt = new Date() - - if (task.concurrencyKey) { - this.concurrencyManager.release(task.concurrencyKey) - task.concurrencyKey = undefined - } - - this.client.session.abort({ path: { id: sessionID } }).catch(() => {}) - log(`[background-agent] Task ${task.id} interrupted: no progress since start`) - - try { - await this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) - } catch (err) { - log("[background-agent] Error in notifyParentSession for stale task:", { taskId: task.id, error: err }) - } - continue - } - - if (sessionIsRunning) continue - - if (runtime < MIN_RUNTIME_BEFORE_STALE_MS) continue - - const timeSinceLastUpdate = now - task.progress.lastUpdate.getTime() - if (timeSinceLastUpdate <= staleTimeoutMs) continue - if (task.status !== "running") continue - - const staleMinutes = Math.round(timeSinceLastUpdate / 60000) - task.status = "cancelled" - task.error = `Stale timeout (no activity for ${staleMinutes}min)` - task.completedAt = new Date() - - if (task.concurrencyKey) { - this.concurrencyManager.release(task.concurrencyKey) - task.concurrencyKey = undefined - } - - this.client.session.abort({ 
path: { id: sessionID } }).catch(() => {}) - log(`[background-agent] Task ${task.id} interrupted: stale timeout`) - - try { - await this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)) - } catch (err) { - log("[background-agent] Error in notifyParentSession for stale task:", { taskId: task.id, error: err }) - } - } + await checkAndInterruptStaleTasks({ + tasks: this.tasks.values(), + client: this.client, + config: this.config, + concurrencyManager: this.concurrencyManager, + notifyParentSession: (task) => this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)), + sessionStatuses: allStatuses, + }) } private async pollRunningTasks(): Promise { @@ -1948,89 +1645,3 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea return current } } - -function registerProcessSignal( - signal: ProcessCleanupEvent, - handler: () => void, - exitAfter: boolean -): () => void { - const listener = () => { - handler() - if (exitAfter) { - // Set exitCode and schedule exit after delay to allow other handlers to complete async cleanup - // Use 6s delay to accommodate LSP cleanup (5s timeout + 1s SIGKILL wait) - process.exitCode = 0 - setTimeout(() => process.exit(), 6000) - } - } - process.on(signal, listener) - return listener -} - - -function getMessageDir(sessionID: string): string | null { - if (!existsSync(MESSAGE_STORAGE)) return null - - const directPath = join(MESSAGE_STORAGE, sessionID) - if (existsSync(directPath)) return directPath - - for (const dir of readdirSync(MESSAGE_STORAGE)) { - const sessionPath = join(MESSAGE_STORAGE, dir, sessionID) - if (existsSync(sessionPath)) return sessionPath - } - return null -} - -function isCompactionAgent(agent: string | undefined): boolean { - return agent?.trim().toLowerCase() === "compaction" -} - -function hasFullAgentAndModel(message: StoredMessage): boolean { - return !!message.agent && - !isCompactionAgent(message.agent) && - 
!!message.model?.providerID && - !!message.model?.modelID -} - -function hasPartialAgentOrModel(message: StoredMessage): boolean { - const hasAgent = !!message.agent && !isCompactionAgent(message.agent) - const hasModel = !!message.model?.providerID && !!message.model?.modelID - return hasAgent || hasModel -} - -function findNearestMessageExcludingCompaction(messageDir: string): StoredMessage | null { - try { - const files = readdirSync(messageDir) - .filter((name) => name.endsWith(".json")) - .sort() - .reverse() - - for (const file of files) { - try { - const content = readFileSync(join(messageDir, file), "utf-8") - const parsed = JSON.parse(content) as StoredMessage - if (hasFullAgentAndModel(parsed)) { - return parsed - } - } catch { - continue - } - } - - for (const file of files) { - try { - const content = readFileSync(join(messageDir, file), "utf-8") - const parsed = JSON.parse(content) as StoredMessage - if (hasPartialAgentOrModel(parsed)) { - return parsed - } - } catch { - continue - } - } - } catch { - return null - } - - return null -} diff --git a/src/features/background-agent/notification-builder.ts b/src/features/background-agent/notification-builder.ts deleted file mode 100644 index f5bffa65..00000000 --- a/src/features/background-agent/notification-builder.ts +++ /dev/null @@ -1,41 +0,0 @@ -import type { BackgroundTask } from "./types" - -export function buildBackgroundTaskNotificationText(args: { - task: BackgroundTask - duration: string - allComplete: boolean - remainingCount: number - completedTasks: BackgroundTask[] -}): string { - const { task, duration, allComplete, remainingCount, completedTasks } = args - const statusText = - task.status === "completed" ? "COMPLETED" : task.status === "interrupt" ? "INTERRUPTED" : task.status === "error" ? "ERROR" : "CANCELLED" - const errorInfo = task.error ? 
`\n**Error:** ${task.error}` : "" - - if (allComplete) { - const completedTasksText = completedTasks - .map((t) => `- \`${t.id}\`: ${t.description}`) - .join("\n") - - return ` -[ALL BACKGROUND TASKS COMPLETE] - -**Completed:** -${completedTasksText || `- \`${task.id}\`: ${task.description}`} - -Use \`background_output(task_id="")\` to retrieve each result. -` - } - - return ` -[BACKGROUND TASK ${statusText}] -**ID:** \`${task.id}\` -**Description:** ${task.description} -**Duration:** ${duration}${errorInfo} - -**${remainingCount} task${remainingCount === 1 ? "" : "s"} still in progress.** You WILL be notified when ALL complete. -Do NOT poll - continue productive work. - -Use \`background_output(task_id="${task.id}")\` to retrieve this result when ready. -` -} diff --git a/src/features/background-agent/process-cleanup.test.ts b/src/features/background-agent/process-cleanup.test.ts new file mode 100644 index 00000000..9560852d --- /dev/null +++ b/src/features/background-agent/process-cleanup.test.ts @@ -0,0 +1,156 @@ +import { describe, test, expect, beforeEach, afterEach, mock } from "bun:test" +import { + registerManagerForCleanup, + unregisterManagerForCleanup, + _resetForTesting, +} from "./process-cleanup" + +describe("process-cleanup", () => { + const registeredManagers: Array<{ shutdown: () => void }> = [] + const mockShutdown = mock(() => {}) + + const processOnCalls: Array<[string, Function]> = [] + const processOffCalls: Array<[string, Function]> = [] + const originalProcessOn = process.on.bind(process) + const originalProcessOff = process.off.bind(process) + + beforeEach(() => { + mockShutdown.mockClear() + processOnCalls.length = 0 + processOffCalls.length = 0 + registeredManagers.length = 0 + + process.on = originalProcessOn as any + process.off = originalProcessOff as any + _resetForTesting() + + process.on = ((event: string, listener: Function) => { + processOnCalls.push([event, listener]) + return process + }) as any + + process.off = ((event: 
string, listener: Function) => { + processOffCalls.push([event, listener]) + return process + }) as any + }) + + afterEach(() => { + process.on = originalProcessOn as any + process.off = originalProcessOff as any + + for (const manager of [...registeredManagers]) { + unregisterManagerForCleanup(manager) + } + }) + + describe("registerManagerForCleanup", () => { + test("registers signal handlers on first manager", () => { + const manager = { shutdown: mockShutdown } + registeredManagers.push(manager) + + registerManagerForCleanup(manager) + + const signals = processOnCalls.map(([signal]) => signal) + expect(signals).toContain("SIGINT") + expect(signals).toContain("SIGTERM") + expect(signals).toContain("beforeExit") + expect(signals).toContain("exit") + }) + + test("signal listener calls shutdown on registered manager", () => { + const manager = { shutdown: mockShutdown } + registeredManagers.push(manager) + + registerManagerForCleanup(manager) + + const [, listener] = processOnCalls[0] + listener() + + expect(mockShutdown).toHaveBeenCalled() + }) + + test("multiple managers all get shutdown when signal fires", () => { + const shutdown1 = mock(() => {}) + const shutdown2 = mock(() => {}) + const shutdown3 = mock(() => {}) + const manager1 = { shutdown: shutdown1 } + const manager2 = { shutdown: shutdown2 } + const manager3 = { shutdown: shutdown3 } + registeredManagers.push(manager1, manager2, manager3) + + registerManagerForCleanup(manager1) + registerManagerForCleanup(manager2) + registerManagerForCleanup(manager3) + + const [, listener] = processOnCalls[0] + listener() + + expect(shutdown1).toHaveBeenCalledTimes(1) + expect(shutdown2).toHaveBeenCalledTimes(1) + expect(shutdown3).toHaveBeenCalledTimes(1) + }) + + test("does not re-register signal handlers for subsequent managers", () => { + const manager1 = { shutdown: mockShutdown } + const manager2 = { shutdown: mockShutdown } + registeredManagers.push(manager1, manager2) + + registerManagerForCleanup(manager1) + 
const callsAfterFirst = processOnCalls.length + + registerManagerForCleanup(manager2) + + expect(processOnCalls.length).toBe(callsAfterFirst) + }) + }) + + describe("unregisterManagerForCleanup", () => { + test("removes signal handlers when last manager unregisters", () => { + const manager = { shutdown: mockShutdown } + registeredManagers.push(manager) + + registerManagerForCleanup(manager) + unregisterManagerForCleanup(manager) + registeredManagers.length = 0 + + const offSignals = processOffCalls.map(([signal]) => signal) + expect(offSignals).toContain("SIGINT") + expect(offSignals).toContain("SIGTERM") + expect(offSignals).toContain("beforeExit") + expect(offSignals).toContain("exit") + }) + + test("keeps signal handlers when other managers remain", () => { + const manager1 = { shutdown: mockShutdown } + const manager2 = { shutdown: mockShutdown } + registeredManagers.push(manager1, manager2) + + registerManagerForCleanup(manager1) + registerManagerForCleanup(manager2) + + unregisterManagerForCleanup(manager2) + + expect(processOffCalls.length).toBe(0) + }) + + test("remaining managers still get shutdown after partial unregister", () => { + const shutdown1 = mock(() => {}) + const shutdown2 = mock(() => {}) + const manager1 = { shutdown: shutdown1 } + const manager2 = { shutdown: shutdown2 } + registeredManagers.push(manager1, manager2) + + registerManagerForCleanup(manager1) + registerManagerForCleanup(manager2) + + const [, listener] = processOnCalls[0] + unregisterManagerForCleanup(manager2) + + listener() + + expect(shutdown1).toHaveBeenCalledTimes(1) + expect(shutdown2).not.toHaveBeenCalled() + }) + }) +}) diff --git a/src/features/background-agent/process-cleanup.ts b/src/features/background-agent/process-cleanup.ts new file mode 100644 index 00000000..e5d43214 --- /dev/null +++ b/src/features/background-agent/process-cleanup.ts @@ -0,0 +1,81 @@ +import { log } from "../../shared" + +type ProcessCleanupEvent = NodeJS.Signals | "beforeExit" | "exit" + 
/**
 * Delay between a terminating signal and the forced `process.exit()`.
 * The delay lets other handlers finish their async cleanup; 6s accommodates
 * LSP cleanup (5s timeout + 1s SIGKILL wait).
 */
const FORCED_EXIT_DELAY_MS = 6000

/**
 * Installs `handler` on a process event and returns the exact listener that
 * was attached, so the caller can later remove it with `process.off`.
 *
 * @param signal - Process event to listen for.
 * @param handler - Synchronous cleanup callback run when the event fires.
 * @param exitAfter - When true, sets `process.exitCode` to 0 and schedules a
 *   forced `process.exit()` after `FORCED_EXIT_DELAY_MS`.
 * @returns The listener registered on `process`.
 */
function registerProcessSignal(
  signal: ProcessCleanupEvent,
  handler: () => void,
  exitAfter: boolean
): () => void {
  const listener = (): void => {
    handler()
    if (exitAfter) {
      process.exitCode = 0
      setTimeout(() => process.exit(), FORCED_EXIT_DELAY_MS)
    }
  }
  process.on(signal, listener)
  return listener
}

/** Anything that can be shut down synchronously when the process is ending. */
interface CleanupTarget {
  shutdown(): void
}

// Module-level singleton state shared by every registered manager.
const cleanupManagers = new Set<CleanupTarget>()
let cleanupRegistered = false
const cleanupHandlers = new Map<ProcessCleanupEvent, () => void>()

/** Removes every installed process listener and marks the module as unregistered. */
function detachCleanupHandlers(): void {
  for (const [signal, listener] of cleanupHandlers) {
    process.off(signal, listener)
  }
  cleanupHandlers.clear()
  cleanupRegistered = false
}

/**
 * Adds `manager` to the set that is shut down on process termination.
 * The process listeners are installed only once, on the first registration;
 * later registrations just join the set.
 *
 * @param manager - Target whose `shutdown()` is called when a cleanup event fires.
 */
export function registerManagerForCleanup(manager: CleanupTarget): void {
  cleanupManagers.add(manager)

  if (cleanupRegistered) return
  cleanupRegistered = true

  const cleanupAll = (): void => {
    for (const m of cleanupManagers) {
      try {
        m.shutdown()
      } catch (error) {
        // One failing manager must not prevent the remaining ones from shutting down.
        log("[background-agent] Error during shutdown cleanup:", error)
      }
    }
  }

  const registerSignal = (signal: ProcessCleanupEvent, exitAfter: boolean): void => {
    cleanupHandlers.set(signal, registerProcessSignal(signal, cleanupAll, exitAfter))
  }

  registerSignal("SIGINT", true)
  registerSignal("SIGTERM", true)
  if (process.platform === "win32") {
    registerSignal("SIGBREAK", true)
  }
  registerSignal("beforeExit", false)
  registerSignal("exit", false)
}

/**
 * Removes `manager` from the cleanup set. Once the set is empty, all process
 * listeners installed by `registerManagerForCleanup` are removed as well.
 *
 * @param manager - Target previously passed to `registerManagerForCleanup`.
 */
export function unregisterManagerForCleanup(manager: CleanupTarget): void {
  cleanupManagers.delete(manager)

  if (cleanupManagers.size > 0) return

  detachCleanupHandlers()
}

/** @internal — test-only reset for module-level singleton state */
export function _resetForTesting(): void {
  cleanupManagers.clear()
  detachCleanupHandlers()
}