Merge pull request #2034 from code-yeongyu/refactor/background-manager-extraction
Extract inline logic from BackgroundManager into focused modules
This commit is contained in:
commit
ed43cd4c85
@ -0,0 +1,190 @@
|
||||
import { describe, test, expect, beforeEach, afterEach } from "bun:test"
|
||||
import { mkdtempSync, writeFileSync, rmSync } from "node:fs"
|
||||
import { join } from "node:path"
|
||||
import { tmpdir } from "node:os"
|
||||
import { isCompactionAgent, findNearestMessageExcludingCompaction } from "./compaction-aware-message-resolver"
|
||||
|
||||
describe("isCompactionAgent", () => {
|
||||
describe("#given agent name variations", () => {
|
||||
test("returns true for 'compaction'", () => {
|
||||
// when
|
||||
const result = isCompactionAgent("compaction")
|
||||
|
||||
// then
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
test("returns true for 'Compaction' (case insensitive)", () => {
|
||||
// when
|
||||
const result = isCompactionAgent("Compaction")
|
||||
|
||||
// then
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
test("returns true for ' compaction ' (with whitespace)", () => {
|
||||
// when
|
||||
const result = isCompactionAgent(" compaction ")
|
||||
|
||||
// then
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
test("returns false for undefined", () => {
|
||||
// when
|
||||
const result = isCompactionAgent(undefined)
|
||||
|
||||
// then
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for null", () => {
|
||||
// when
|
||||
const result = isCompactionAgent(null as unknown as string)
|
||||
|
||||
// then
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for non-compaction agent like 'sisyphus'", () => {
|
||||
// when
|
||||
const result = isCompactionAgent("sisyphus")
|
||||
|
||||
// then
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("findNearestMessageExcludingCompaction", () => {
|
||||
let tempDir: string
|
||||
|
||||
beforeEach(() => {
|
||||
tempDir = mkdtempSync(join(tmpdir(), "compaction-test-"))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
rmSync(tempDir, { force: true, recursive: true })
|
||||
})
|
||||
|
||||
describe("#given directory with messages", () => {
|
||||
test("finds message with full agent and model", () => {
|
||||
// given
|
||||
const message = {
|
||||
agent: "sisyphus",
|
||||
model: { providerID: "anthropic", modelID: "claude-opus-4-6" },
|
||||
}
|
||||
writeFileSync(join(tempDir, "001.json"), JSON.stringify(message))
|
||||
|
||||
// when
|
||||
const result = findNearestMessageExcludingCompaction(tempDir)
|
||||
|
||||
// then
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.agent).toBe("sisyphus")
|
||||
expect(result?.model?.providerID).toBe("anthropic")
|
||||
expect(result?.model?.modelID).toBe("claude-opus-4-6")
|
||||
})
|
||||
|
||||
test("skips compaction agent messages", () => {
|
||||
// given
|
||||
const compactionMessage = {
|
||||
agent: "compaction",
|
||||
model: { providerID: "anthropic", modelID: "claude-opus-4-6" },
|
||||
}
|
||||
const validMessage = {
|
||||
agent: "sisyphus",
|
||||
model: { providerID: "anthropic", modelID: "claude-opus-4-6" },
|
||||
}
|
||||
writeFileSync(join(tempDir, "001.json"), JSON.stringify(compactionMessage))
|
||||
writeFileSync(join(tempDir, "002.json"), JSON.stringify(validMessage))
|
||||
|
||||
// when
|
||||
const result = findNearestMessageExcludingCompaction(tempDir)
|
||||
|
||||
// then
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.agent).toBe("sisyphus")
|
||||
})
|
||||
|
||||
test("falls back to partial agent/model match", () => {
|
||||
// given
|
||||
const messageWithAgentOnly = {
|
||||
agent: "hephaestus",
|
||||
}
|
||||
const messageWithModelOnly = {
|
||||
model: { providerID: "openai", modelID: "gpt-5.3" },
|
||||
}
|
||||
writeFileSync(join(tempDir, "001.json"), JSON.stringify(messageWithModelOnly))
|
||||
writeFileSync(join(tempDir, "002.json"), JSON.stringify(messageWithAgentOnly))
|
||||
|
||||
// when
|
||||
const result = findNearestMessageExcludingCompaction(tempDir)
|
||||
|
||||
// then
|
||||
expect(result).not.toBeNull()
|
||||
// Should find the one with agent first (sorted reverse, so 002 is checked first)
|
||||
expect(result?.agent).toBe("hephaestus")
|
||||
})
|
||||
|
||||
test("returns null for empty directory", () => {
|
||||
// given - empty directory (tempDir is already empty)
|
||||
|
||||
// when
|
||||
const result = findNearestMessageExcludingCompaction(tempDir)
|
||||
|
||||
// then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
test("returns null for non-existent directory", () => {
|
||||
// given
|
||||
const nonExistentDir = join(tmpdir(), "non-existent-dir-12345")
|
||||
|
||||
// when
|
||||
const result = findNearestMessageExcludingCompaction(nonExistentDir)
|
||||
|
||||
// then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
test("skips invalid JSON files and finds valid message", () => {
|
||||
// given
|
||||
const invalidJson = "{ invalid json"
|
||||
const validMessage = {
|
||||
agent: "oracle",
|
||||
model: { providerID: "google", modelID: "gemini-2-flash" },
|
||||
}
|
||||
writeFileSync(join(tempDir, "001.json"), invalidJson)
|
||||
writeFileSync(join(tempDir, "002.json"), JSON.stringify(validMessage))
|
||||
|
||||
// when
|
||||
const result = findNearestMessageExcludingCompaction(tempDir)
|
||||
|
||||
// then
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.agent).toBe("oracle")
|
||||
})
|
||||
|
||||
test("finds newest valid message (sorted by filename reverse)", () => {
|
||||
// given
|
||||
const olderMessage = {
|
||||
agent: "older",
|
||||
model: { providerID: "a", modelID: "b" },
|
||||
}
|
||||
const newerMessage = {
|
||||
agent: "newer",
|
||||
model: { providerID: "c", modelID: "d" },
|
||||
}
|
||||
writeFileSync(join(tempDir, "001.json"), JSON.stringify(olderMessage))
|
||||
writeFileSync(join(tempDir, "010.json"), JSON.stringify(newerMessage))
|
||||
|
||||
// when
|
||||
const result = findNearestMessageExcludingCompaction(tempDir)
|
||||
|
||||
// then
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.agent).toBe("newer")
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,57 @@
|
||||
import { readdirSync, readFileSync } from "node:fs"
|
||||
import { join } from "node:path"
|
||||
import type { StoredMessage } from "../hook-message-injector"
|
||||
|
||||
export function isCompactionAgent(agent: string | undefined): boolean {
|
||||
return agent?.trim().toLowerCase() === "compaction"
|
||||
}
|
||||
|
||||
function hasFullAgentAndModel(message: StoredMessage): boolean {
|
||||
return !!message.agent &&
|
||||
!isCompactionAgent(message.agent) &&
|
||||
!!message.model?.providerID &&
|
||||
!!message.model?.modelID
|
||||
}
|
||||
|
||||
function hasPartialAgentOrModel(message: StoredMessage): boolean {
|
||||
const hasAgent = !!message.agent && !isCompactionAgent(message.agent)
|
||||
const hasModel = !!message.model?.providerID && !!message.model?.modelID
|
||||
return hasAgent || hasModel
|
||||
}
|
||||
|
||||
export function findNearestMessageExcludingCompaction(messageDir: string): StoredMessage | null {
|
||||
try {
|
||||
const files = readdirSync(messageDir)
|
||||
.filter((name) => name.endsWith(".json"))
|
||||
.sort()
|
||||
.reverse()
|
||||
|
||||
for (const file of files) {
|
||||
try {
|
||||
const content = readFileSync(join(messageDir, file), "utf-8")
|
||||
const parsed = JSON.parse(content) as StoredMessage
|
||||
if (hasFullAgentAndModel(parsed)) {
|
||||
return parsed
|
||||
}
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
try {
|
||||
const content = readFileSync(join(messageDir, file), "utf-8")
|
||||
const parsed = JSON.parse(content) as StoredMessage
|
||||
if (hasPartialAgentOrModel(parsed)) {
|
||||
return parsed
|
||||
}
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
351
src/features/background-agent/error-classifier.test.ts
Normal file
351
src/features/background-agent/error-classifier.test.ts
Normal file
@ -0,0 +1,351 @@
|
||||
import { describe, test, expect } from "bun:test"
|
||||
import {
|
||||
isRecord,
|
||||
isAbortedSessionError,
|
||||
getErrorText,
|
||||
extractErrorName,
|
||||
extractErrorMessage,
|
||||
getSessionErrorMessage,
|
||||
} from "./error-classifier"
|
||||
|
||||
describe("isRecord", () => {
|
||||
describe("#given null or primitive values", () => {
|
||||
test("returns false for null", () => {
|
||||
expect(isRecord(null)).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for undefined", () => {
|
||||
expect(isRecord(undefined)).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for string", () => {
|
||||
expect(isRecord("hello")).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for number", () => {
|
||||
expect(isRecord(42)).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for boolean", () => {
|
||||
expect(isRecord(true)).toBe(false)
|
||||
})
|
||||
|
||||
test("returns true for array (arrays are objects)", () => {
|
||||
expect(isRecord([1, 2, 3])).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given plain objects", () => {
|
||||
test("returns true for empty object", () => {
|
||||
expect(isRecord({})).toBe(true)
|
||||
})
|
||||
|
||||
test("returns true for object with properties", () => {
|
||||
expect(isRecord({ key: "value" })).toBe(true)
|
||||
})
|
||||
|
||||
test("returns true for object with nested objects", () => {
|
||||
expect(isRecord({ nested: { deep: true } })).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given Error instances", () => {
|
||||
test("returns true for Error instance", () => {
|
||||
expect(isRecord(new Error("test"))).toBe(true)
|
||||
})
|
||||
|
||||
test("returns true for TypeError instance", () => {
|
||||
expect(isRecord(new TypeError("test"))).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("isAbortedSessionError", () => {
|
||||
describe("#given error with aborted message", () => {
|
||||
test("returns true for string containing aborted", () => {
|
||||
expect(isAbortedSessionError("Session aborted")).toBe(true)
|
||||
})
|
||||
|
||||
test("returns true for string with ABORTED uppercase", () => {
|
||||
expect(isAbortedSessionError("Session ABORTED")).toBe(true)
|
||||
})
|
||||
|
||||
test("returns true for Error with aborted in message", () => {
|
||||
expect(isAbortedSessionError(new Error("Session aborted"))).toBe(true)
|
||||
})
|
||||
|
||||
test("returns true for object with message containing aborted", () => {
|
||||
expect(isAbortedSessionError({ message: "The session was aborted" })).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given error without aborted message", () => {
|
||||
test("returns false for string without aborted", () => {
|
||||
expect(isAbortedSessionError("Session completed")).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for Error without aborted", () => {
|
||||
expect(isAbortedSessionError(new Error("Something went wrong"))).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for empty string", () => {
|
||||
expect(isAbortedSessionError("")).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given invalid inputs", () => {
|
||||
test("returns false for null", () => {
|
||||
expect(isAbortedSessionError(null)).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for undefined", () => {
|
||||
expect(isAbortedSessionError(undefined)).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false for object without message", () => {
|
||||
expect(isAbortedSessionError({ code: "ABORTED" })).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("getErrorText", () => {
|
||||
describe("#given string input", () => {
|
||||
test("returns the string as-is", () => {
|
||||
expect(getErrorText("Something went wrong")).toBe("Something went wrong")
|
||||
})
|
||||
|
||||
test("returns empty string for empty string", () => {
|
||||
expect(getErrorText("")).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given Error instance", () => {
|
||||
test("returns name and message format", () => {
|
||||
expect(getErrorText(new Error("test message"))).toBe("Error: test message")
|
||||
})
|
||||
|
||||
test("returns TypeError format", () => {
|
||||
expect(getErrorText(new TypeError("type error"))).toBe("TypeError: type error")
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given object with message property", () => {
|
||||
test("returns message property as string", () => {
|
||||
expect(getErrorText({ message: "custom error" })).toBe("custom error")
|
||||
})
|
||||
|
||||
test("returns name property when message not available", () => {
|
||||
expect(getErrorText({ name: "CustomError" })).toBe("CustomError")
|
||||
})
|
||||
|
||||
test("prefers message over name", () => {
|
||||
expect(getErrorText({ name: "CustomError", message: "error message" })).toBe("error message")
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given invalid inputs", () => {
|
||||
test("returns empty string for null", () => {
|
||||
expect(getErrorText(null)).toBe("")
|
||||
})
|
||||
|
||||
test("returns empty string for undefined", () => {
|
||||
expect(getErrorText(undefined)).toBe("")
|
||||
})
|
||||
|
||||
test("returns empty string for object without message or name", () => {
|
||||
expect(getErrorText({ code: 500 })).toBe("")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("extractErrorName", () => {
|
||||
describe("#given Error instance", () => {
|
||||
test("returns Error for generic Error", () => {
|
||||
expect(extractErrorName(new Error("test"))).toBe("Error")
|
||||
})
|
||||
|
||||
test("returns TypeError name", () => {
|
||||
expect(extractErrorName(new TypeError("test"))).toBe("TypeError")
|
||||
})
|
||||
|
||||
test("returns RangeError name", () => {
|
||||
expect(extractErrorName(new RangeError("test"))).toBe("RangeError")
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given plain object with name property", () => {
|
||||
test("returns name property when string", () => {
|
||||
expect(extractErrorName({ name: "CustomError" })).toBe("CustomError")
|
||||
})
|
||||
|
||||
test("returns undefined when name is not string", () => {
|
||||
expect(extractErrorName({ name: 123 })).toBe(undefined)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given invalid inputs", () => {
|
||||
test("returns undefined for null", () => {
|
||||
expect(extractErrorName(null)).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined for undefined", () => {
|
||||
expect(extractErrorName(undefined)).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined for string", () => {
|
||||
expect(extractErrorName("Error message")).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined for object without name property", () => {
|
||||
expect(extractErrorName({ message: "test" })).toBe(undefined)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("extractErrorMessage", () => {
|
||||
describe("#given string input", () => {
|
||||
test("returns the string as-is", () => {
|
||||
expect(extractErrorMessage("error message")).toBe("error message")
|
||||
})
|
||||
|
||||
test("returns undefined for empty string", () => {
|
||||
expect(extractErrorMessage("")).toBe(undefined)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given Error instance", () => {
|
||||
test("returns error message", () => {
|
||||
expect(extractErrorMessage(new Error("test error"))).toBe("test error")
|
||||
})
|
||||
|
||||
test("returns empty string for Error with no message", () => {
|
||||
expect(extractErrorMessage(new Error())).toBe("")
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given object with message property", () => {
|
||||
test("returns message property", () => {
|
||||
expect(extractErrorMessage({ message: "custom message" })).toBe("custom message")
|
||||
})
|
||||
|
||||
test("falls through to JSON.stringify for empty message value", () => {
|
||||
expect(extractErrorMessage({ message: "" })).toBe('{"message":""}')
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given nested error structure", () => {
|
||||
test("extracts message from nested error object", () => {
|
||||
expect(extractErrorMessage({ error: { message: "nested error" } })).toBe("nested error")
|
||||
})
|
||||
|
||||
test("extracts message from data.error structure", () => {
|
||||
expect(extractErrorMessage({ data: { error: "data error" } })).toBe("data error")
|
||||
})
|
||||
|
||||
test("extracts message from cause property", () => {
|
||||
expect(extractErrorMessage({ cause: "cause error" })).toBe("cause error")
|
||||
})
|
||||
|
||||
test("extracts message from cause object with message", () => {
|
||||
expect(extractErrorMessage({ cause: { message: "cause message" } })).toBe("cause message")
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given complex error with data wrapper", () => {
|
||||
test("extracts from error.data.message", () => {
|
||||
const error = {
|
||||
data: {
|
||||
message: "data message",
|
||||
},
|
||||
}
|
||||
expect(extractErrorMessage(error)).toBe("data message")
|
||||
})
|
||||
|
||||
test("prefers top over nested-level message", () => {
|
||||
const error = {
|
||||
message: "top level",
|
||||
data: { message: "nested" },
|
||||
}
|
||||
expect(extractErrorMessage(error)).toBe("top level")
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given invalid inputs", () => {
|
||||
test("returns undefined for null", () => {
|
||||
expect(extractErrorMessage(null)).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined for undefined", () => {
|
||||
expect(extractErrorMessage(undefined)).toBe(undefined)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given object without extractable message", () => {
|
||||
test("falls back to JSON.stringify for object", () => {
|
||||
const obj = { code: 500, details: "error" }
|
||||
const result = extractErrorMessage(obj)
|
||||
expect(result).toContain('"code":500')
|
||||
})
|
||||
|
||||
test("falls back to String() for non-serializable object", () => {
|
||||
const circular: Record<string, unknown> = { a: 1 }
|
||||
circular.self = circular
|
||||
const result = extractErrorMessage(circular)
|
||||
expect(result).toBe("[object Object]")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe("getSessionErrorMessage", () => {
|
||||
describe("#given valid error properties", () => {
|
||||
test("extracts message from error.message", () => {
|
||||
const properties = { error: { message: "session error" } }
|
||||
expect(getSessionErrorMessage(properties)).toBe("session error")
|
||||
})
|
||||
|
||||
test("extracts message from error.data.message", () => {
|
||||
const properties = {
|
||||
error: {
|
||||
data: { message: "data error message" },
|
||||
},
|
||||
}
|
||||
expect(getSessionErrorMessage(properties)).toBe("data error message")
|
||||
})
|
||||
|
||||
test("prefers error.data.message over error.message", () => {
|
||||
const properties = {
|
||||
error: {
|
||||
message: "top level",
|
||||
data: { message: "nested" },
|
||||
},
|
||||
}
|
||||
expect(getSessionErrorMessage(properties)).toBe("nested")
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given missing or invalid properties", () => {
|
||||
test("returns undefined when error is missing", () => {
|
||||
expect(getSessionErrorMessage({})).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined when error is null", () => {
|
||||
expect(getSessionErrorMessage({ error: null })).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined when error is string", () => {
|
||||
expect(getSessionErrorMessage({ error: "error string" })).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined when data is not an object", () => {
|
||||
expect(getSessionErrorMessage({ error: { data: "not an object" } })).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined when message is not string", () => {
|
||||
expect(getSessionErrorMessage({ error: { message: 123 } })).toBe(undefined)
|
||||
})
|
||||
|
||||
test("returns undefined when data.message is not string", () => {
|
||||
expect(getSessionErrorMessage({ error: { data: { message: null } } })).toBe(undefined)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -1,3 +1,7 @@
|
||||
export function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
|
||||
export function isAbortedSessionError(error: unknown): boolean {
|
||||
const message = getErrorText(error)
|
||||
return message.toLowerCase().includes("aborted")
|
||||
@ -19,3 +23,61 @@ export function getErrorText(error: unknown): string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
export function extractErrorName(error: unknown): string | undefined {
|
||||
if (isRecord(error) && typeof error["name"] === "string") return error["name"]
|
||||
if (error instanceof Error) return error.name
|
||||
return undefined
|
||||
}
|
||||
|
||||
export function extractErrorMessage(error: unknown): string | undefined {
|
||||
if (!error) return undefined
|
||||
if (typeof error === "string") return error
|
||||
if (error instanceof Error) return error.message
|
||||
|
||||
if (isRecord(error)) {
|
||||
const dataRaw = error["data"]
|
||||
const candidates: unknown[] = [
|
||||
error,
|
||||
dataRaw,
|
||||
error["error"],
|
||||
isRecord(dataRaw) ? (dataRaw as Record<string, unknown>)["error"] : undefined,
|
||||
error["cause"],
|
||||
]
|
||||
|
||||
for (const candidate of candidates) {
|
||||
if (typeof candidate === "string" && candidate.length > 0) return candidate
|
||||
if (
|
||||
isRecord(candidate) &&
|
||||
typeof candidate["message"] === "string" &&
|
||||
candidate["message"].length > 0
|
||||
) {
|
||||
return candidate["message"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.stringify(error)
|
||||
} catch {
|
||||
return String(error)
|
||||
}
|
||||
}
|
||||
|
||||
interface EventPropertiesLike {
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
export function getSessionErrorMessage(properties: EventPropertiesLike): string | undefined {
|
||||
const errorRaw = properties["error"]
|
||||
if (!isRecord(errorRaw)) return undefined
|
||||
|
||||
const dataRaw = errorRaw["data"]
|
||||
if (isRecord(dataRaw)) {
|
||||
const message = dataRaw["message"]
|
||||
if (typeof message === "string") return message
|
||||
}
|
||||
|
||||
const message = errorRaw["message"]
|
||||
return typeof message === "string" ? message : undefined
|
||||
}
|
||||
|
||||
270
src/features/background-agent/fallback-retry-handler.test.ts
Normal file
270
src/features/background-agent/fallback-retry-handler.test.ts
Normal file
@ -0,0 +1,270 @@
|
||||
import { describe, test, expect, mock, beforeEach } from "bun:test"
|
||||
|
||||
mock.module("../../shared", () => ({
|
||||
log: mock(() => {}),
|
||||
readConnectedProvidersCache: mock(() => null),
|
||||
readProviderModelsCache: mock(() => null),
|
||||
}))
|
||||
|
||||
mock.module("../../shared/model-error-classifier", () => ({
|
||||
shouldRetryError: mock(() => true),
|
||||
getNextFallback: mock((chain: Array<{ model: string }>, attempt: number) => chain[attempt]),
|
||||
hasMoreFallbacks: mock((chain: Array<{ model: string }>, attempt: number) => attempt < chain.length),
|
||||
selectFallbackProvider: mock((providers: string[]) => providers[0]),
|
||||
}))
|
||||
|
||||
mock.module("../../shared/provider-model-id-transform", () => ({
|
||||
transformModelForProvider: mock((_provider: string, model: string) => model),
|
||||
}))
|
||||
|
||||
import { tryFallbackRetry } from "./fallback-retry-handler"
|
||||
import { shouldRetryError } from "../../shared/model-error-classifier"
|
||||
import type { BackgroundTask } from "./types"
|
||||
import type { ConcurrencyManager } from "./concurrency"
|
||||
|
||||
function createMockTask(overrides: Partial<BackgroundTask> = {}): BackgroundTask {
|
||||
return {
|
||||
id: "test-task-1",
|
||||
description: "test task",
|
||||
prompt: "test prompt",
|
||||
agent: "sisyphus-junior",
|
||||
status: "error",
|
||||
parentSessionID: "parent-session-1",
|
||||
parentMessageID: "parent-message-1",
|
||||
fallbackChain: [
|
||||
{ model: "fallback-model-1", providers: ["provider-a"], variant: undefined },
|
||||
{ model: "fallback-model-2", providers: ["provider-b"], variant: undefined },
|
||||
],
|
||||
attemptCount: 0,
|
||||
concurrencyKey: "provider-a/original-model",
|
||||
model: { providerID: "provider-a", modelID: "original-model" },
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
function createMockConcurrencyManager(): ConcurrencyManager {
|
||||
return {
|
||||
release: mock(() => {}),
|
||||
acquire: mock(async () => {}),
|
||||
getQueueLength: mock(() => 0),
|
||||
getActiveCount: mock(() => 0),
|
||||
} as unknown as ConcurrencyManager
|
||||
}
|
||||
|
||||
function createMockClient() {
|
||||
return {
|
||||
session: {
|
||||
abort: mock(async () => ({})),
|
||||
},
|
||||
} as any
|
||||
}
|
||||
|
||||
function createDefaultArgs(taskOverrides: Partial<BackgroundTask> = {}) {
|
||||
const processKeyFn = mock(() => {})
|
||||
const queuesByKey = new Map<string, Array<{ task: BackgroundTask; input: any }>>()
|
||||
const idleDeferralTimers = new Map<string, ReturnType<typeof setTimeout>>()
|
||||
const concurrencyManager = createMockConcurrencyManager()
|
||||
const client = createMockClient()
|
||||
const task = createMockTask(taskOverrides)
|
||||
|
||||
return {
|
||||
task,
|
||||
errorInfo: { name: "OverloadedError", message: "model overloaded" },
|
||||
source: "polling",
|
||||
concurrencyManager,
|
||||
client,
|
||||
idleDeferralTimers,
|
||||
queuesByKey,
|
||||
processKey: processKeyFn,
|
||||
}
|
||||
}
|
||||
|
||||
describe("tryFallbackRetry", () => {
|
||||
beforeEach(() => {
|
||||
;(shouldRetryError as any).mockImplementation(() => true)
|
||||
})
|
||||
|
||||
describe("#given retryable error with fallback chain", () => {
|
||||
test("returns true and enqueues retry", () => {
|
||||
const args = createDefaultArgs()
|
||||
|
||||
const result = tryFallbackRetry(args)
|
||||
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
test("resets task status to pending", () => {
|
||||
const args = createDefaultArgs()
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.task.status).toBe("pending")
|
||||
})
|
||||
|
||||
test("increments attemptCount", () => {
|
||||
const args = createDefaultArgs()
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.task.attemptCount).toBe(1)
|
||||
})
|
||||
|
||||
test("updates task model to fallback", () => {
|
||||
const args = createDefaultArgs()
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.task.model?.modelID).toBe("fallback-model-1")
|
||||
expect(args.task.model?.providerID).toBe("provider-a")
|
||||
})
|
||||
|
||||
test("clears sessionID and startedAt", () => {
|
||||
const args = createDefaultArgs({
|
||||
sessionID: "old-session",
|
||||
startedAt: new Date(),
|
||||
})
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.task.sessionID).toBeUndefined()
|
||||
expect(args.task.startedAt).toBeUndefined()
|
||||
})
|
||||
|
||||
test("clears error field", () => {
|
||||
const args = createDefaultArgs({ error: "previous error" })
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.task.error).toBeUndefined()
|
||||
})
|
||||
|
||||
test("sets new queuedAt", () => {
|
||||
const args = createDefaultArgs()
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.task.queuedAt).toBeInstanceOf(Date)
|
||||
})
|
||||
|
||||
test("releases concurrency slot", () => {
|
||||
const args = createDefaultArgs()
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.concurrencyManager.release).toHaveBeenCalledWith("provider-a/original-model")
|
||||
})
|
||||
|
||||
test("clears concurrencyKey after release", () => {
|
||||
const args = createDefaultArgs()
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.task.concurrencyKey).toBeUndefined()
|
||||
})
|
||||
|
||||
test("aborts existing session", () => {
|
||||
const args = createDefaultArgs({ sessionID: "session-to-abort" })
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.client.session.abort).toHaveBeenCalledWith({
|
||||
path: { id: "session-to-abort" },
|
||||
})
|
||||
})
|
||||
|
||||
test("adds retry input to queue and calls processKey", () => {
|
||||
const args = createDefaultArgs()
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
const key = `${args.task.model!.providerID}/${args.task.model!.modelID}`
|
||||
const queue = args.queuesByKey.get(key)
|
||||
expect(queue).toBeDefined()
|
||||
expect(queue!.length).toBe(1)
|
||||
expect(queue![0].task).toBe(args.task)
|
||||
expect(args.processKey).toHaveBeenCalledWith(key)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given non-retryable error", () => {
|
||||
test("returns false when shouldRetryError returns false", () => {
|
||||
;(shouldRetryError as any).mockImplementation(() => false)
|
||||
const args = createDefaultArgs()
|
||||
|
||||
const result = tryFallbackRetry(args)
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given no fallback chain", () => {
|
||||
test("returns false when fallbackChain is undefined", () => {
|
||||
const args = createDefaultArgs({ fallbackChain: undefined })
|
||||
|
||||
const result = tryFallbackRetry(args)
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
test("returns false when fallbackChain is empty", () => {
|
||||
const args = createDefaultArgs({ fallbackChain: [] })
|
||||
|
||||
const result = tryFallbackRetry(args)
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given exhausted fallbacks", () => {
|
||||
test("returns false when attemptCount exceeds chain length", () => {
|
||||
const args = createDefaultArgs({ attemptCount: 5 })
|
||||
|
||||
const result = tryFallbackRetry(args)
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given task without concurrency key", () => {
|
||||
test("skips concurrency release", () => {
|
||||
const args = createDefaultArgs({ concurrencyKey: undefined })
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.concurrencyManager.release).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given task without session", () => {
|
||||
test("skips session abort", () => {
|
||||
const args = createDefaultArgs({ sessionID: undefined })
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.client.session.abort).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given active idle deferral timer", () => {
|
||||
test("clears the timer and removes from map", () => {
|
||||
const args = createDefaultArgs()
|
||||
const timerId = setTimeout(() => {}, 10000)
|
||||
args.idleDeferralTimers.set("test-task-1", timerId)
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.idleDeferralTimers.has("test-task-1")).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("#given second attempt", () => {
|
||||
test("uses second fallback in chain", () => {
|
||||
const args = createDefaultArgs({ attemptCount: 1 })
|
||||
|
||||
tryFallbackRetry(args)
|
||||
|
||||
expect(args.task.model?.modelID).toBe("fallback-model-2")
|
||||
expect(args.task.attemptCount).toBe(2)
|
||||
})
|
||||
})
|
||||
})
|
||||
125
src/features/background-agent/fallback-retry-handler.ts
Normal file
125
src/features/background-agent/fallback-retry-handler.ts
Normal file
@ -0,0 +1,125 @@
|
||||
import type { BackgroundTask, LaunchInput } from "./types"
|
||||
import type { FallbackEntry } from "../../shared/model-requirements"
|
||||
import type { ConcurrencyManager } from "./concurrency"
|
||||
import type { OpencodeClient, QueueItem } from "./constants"
|
||||
import { log, readConnectedProvidersCache, readProviderModelsCache } from "../../shared"
|
||||
import {
|
||||
shouldRetryError,
|
||||
getNextFallback,
|
||||
hasMoreFallbacks,
|
||||
selectFallbackProvider,
|
||||
} from "../../shared/model-error-classifier"
|
||||
import { transformModelForProvider } from "../../shared/provider-model-id-transform"
|
||||
|
||||
export function tryFallbackRetry(args: {
|
||||
task: BackgroundTask
|
||||
errorInfo: { name?: string; message?: string }
|
||||
source: string
|
||||
concurrencyManager: ConcurrencyManager
|
||||
client: OpencodeClient
|
||||
idleDeferralTimers: Map<string, ReturnType<typeof setTimeout>>
|
||||
queuesByKey: Map<string, QueueItem[]>
|
||||
processKey: (key: string) => void
|
||||
}): boolean {
|
||||
const { task, errorInfo, source, concurrencyManager, client, idleDeferralTimers, queuesByKey, processKey } = args
|
||||
const fallbackChain = task.fallbackChain
|
||||
const canRetry =
|
||||
shouldRetryError(errorInfo) &&
|
||||
fallbackChain &&
|
||||
fallbackChain.length > 0 &&
|
||||
hasMoreFallbacks(fallbackChain, task.attemptCount ?? 0)
|
||||
|
||||
if (!canRetry) return false
|
||||
|
||||
const attemptCount = task.attemptCount ?? 0
|
||||
const providerModelsCache = readProviderModelsCache()
|
||||
const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache()
|
||||
const connectedSet = connectedProviders ? new Set(connectedProviders.map(p => p.toLowerCase())) : null
|
||||
|
||||
const isReachable = (entry: FallbackEntry): boolean => {
|
||||
if (!connectedSet) return true
|
||||
return entry.providers.some((p) => connectedSet.has(p.toLowerCase()))
|
||||
}
|
||||
|
||||
let selectedAttemptCount = attemptCount
|
||||
let nextFallback: FallbackEntry | undefined
|
||||
while (fallbackChain && selectedAttemptCount < fallbackChain.length) {
|
||||
const candidate = getNextFallback(fallbackChain, selectedAttemptCount)
|
||||
if (!candidate) break
|
||||
selectedAttemptCount++
|
||||
if (!isReachable(candidate)) {
|
||||
log("[background-agent] Skipping unreachable fallback:", {
|
||||
taskId: task.id,
|
||||
source,
|
||||
model: candidate.model,
|
||||
providers: candidate.providers,
|
||||
})
|
||||
continue
|
||||
}
|
||||
nextFallback = candidate
|
||||
break
|
||||
}
|
||||
if (!nextFallback) return false
|
||||
|
||||
const providerID = selectFallbackProvider(
|
||||
nextFallback.providers,
|
||||
task.model?.providerID,
|
||||
)
|
||||
|
||||
log("[background-agent] Retryable error, attempting fallback:", {
|
||||
taskId: task.id,
|
||||
source,
|
||||
errorName: errorInfo.name,
|
||||
errorMessage: errorInfo.message?.slice(0, 100),
|
||||
attemptCount: selectedAttemptCount,
|
||||
nextModel: `${providerID}/${nextFallback.model}`,
|
||||
})
|
||||
|
||||
if (task.concurrencyKey) {
|
||||
concurrencyManager.release(task.concurrencyKey)
|
||||
task.concurrencyKey = undefined
|
||||
}
|
||||
|
||||
if (task.sessionID) {
|
||||
client.session.abort({ path: { id: task.sessionID } }).catch(() => {})
|
||||
}
|
||||
|
||||
const idleTimer = idleDeferralTimers.get(task.id)
|
||||
if (idleTimer) {
|
||||
clearTimeout(idleTimer)
|
||||
idleDeferralTimers.delete(task.id)
|
||||
}
|
||||
|
||||
task.attemptCount = selectedAttemptCount
|
||||
const transformedModelId = transformModelForProvider(providerID, nextFallback.model)
|
||||
task.model = {
|
||||
providerID,
|
||||
modelID: transformedModelId,
|
||||
variant: nextFallback.variant,
|
||||
}
|
||||
task.status = "pending"
|
||||
task.sessionID = undefined
|
||||
task.startedAt = undefined
|
||||
task.queuedAt = new Date()
|
||||
task.error = undefined
|
||||
|
||||
const key = task.model ? `${task.model.providerID}/${task.model.modelID}` : task.agent
|
||||
const queue = queuesByKey.get(key) ?? []
|
||||
const retryInput: LaunchInput = {
|
||||
description: task.description,
|
||||
prompt: task.prompt,
|
||||
agent: task.agent,
|
||||
parentSessionID: task.parentSessionID,
|
||||
parentMessageID: task.parentMessageID,
|
||||
parentModel: task.parentModel,
|
||||
parentAgent: task.parentAgent,
|
||||
parentTools: task.parentTools,
|
||||
model: task.model,
|
||||
fallbackChain: task.fallbackChain,
|
||||
category: task.category,
|
||||
}
|
||||
queue.push({ task, input: retryInput })
|
||||
queuesByKey.set(key, queue)
|
||||
processKey(key)
|
||||
return true
|
||||
}
|
||||
@ -5,16 +5,14 @@ import type {
|
||||
LaunchInput,
|
||||
ResumeInput,
|
||||
} from "./types"
|
||||
import type { FallbackEntry } from "../../shared/model-requirements"
|
||||
import { TaskHistory } from "./task-history"
|
||||
import {
|
||||
log,
|
||||
getAgentToolRestrictions,
|
||||
getMessageDir,
|
||||
normalizePromptTools,
|
||||
normalizeSDKResponse,
|
||||
promptWithModelSuggestionRetry,
|
||||
readConnectedProvidersCache,
|
||||
readProviderModelsCache,
|
||||
resolveInheritedPromptTools,
|
||||
createInternalAgentTextPart,
|
||||
} from "../../shared"
|
||||
@ -25,28 +23,29 @@ import type { BackgroundTaskConfig, TmuxConfig } from "../../config/schema"
|
||||
import { isInsideTmux } from "../../shared/tmux"
|
||||
import {
|
||||
shouldRetryError,
|
||||
getNextFallback,
|
||||
hasMoreFallbacks,
|
||||
selectFallbackProvider,
|
||||
} from "../../shared/model-error-classifier"
|
||||
import { transformModelForProvider } from "../../shared/provider-model-id-transform"
|
||||
import {
|
||||
DEFAULT_MESSAGE_STALENESS_TIMEOUT_MS,
|
||||
DEFAULT_STALE_TIMEOUT_MS,
|
||||
MIN_IDLE_TIME_MS,
|
||||
MIN_RUNTIME_BEFORE_STALE_MS,
|
||||
POLLING_INTERVAL_MS,
|
||||
TASK_CLEANUP_DELAY_MS,
|
||||
TASK_TTL_MS,
|
||||
} from "./constants"
|
||||
|
||||
import { subagentSessions } from "../claude-code-session-state"
|
||||
import { getTaskToastManager } from "../task-toast-manager"
|
||||
import { MESSAGE_STORAGE, type StoredMessage } from "../hook-message-injector"
|
||||
import { existsSync, readFileSync, readdirSync } from "node:fs"
|
||||
import { join } from "node:path"
|
||||
|
||||
type ProcessCleanupEvent = NodeJS.Signals | "beforeExit" | "exit"
|
||||
import { formatDuration } from "./duration-formatter"
|
||||
import {
|
||||
isAbortedSessionError,
|
||||
extractErrorName,
|
||||
extractErrorMessage,
|
||||
getSessionErrorMessage,
|
||||
isRecord,
|
||||
} from "./error-classifier"
|
||||
import { tryFallbackRetry } from "./fallback-retry-handler"
|
||||
import { registerManagerForCleanup, unregisterManagerForCleanup } from "./process-cleanup"
|
||||
import { isCompactionAgent, findNearestMessageExcludingCompaction } from "./compaction-aware-message-resolver"
|
||||
import { pruneStaleTasksAndNotifications } from "./task-poller"
|
||||
import { checkAndInterruptStaleTasks } from "./task-poller"
|
||||
|
||||
type OpencodeClient = PluginInput["client"]
|
||||
|
||||
@ -89,9 +88,7 @@ export interface SubagentSessionCreatedEvent {
|
||||
export type OnSubagentSessionCreated = (event: SubagentSessionCreatedEvent) => Promise<void>
|
||||
|
||||
export class BackgroundManager {
|
||||
private static cleanupManagers = new Set<BackgroundManager>()
|
||||
private static cleanupRegistered = false
|
||||
private static cleanupHandlers = new Map<ProcessCleanupEvent, () => void>()
|
||||
|
||||
|
||||
private tasks: Map<string, BackgroundTask>
|
||||
private notifications: Map<string, BackgroundTask[]>
|
||||
@ -705,8 +702,8 @@ export class BackgroundManager {
|
||||
if (!assistantError) return
|
||||
|
||||
const errorInfo = {
|
||||
name: this.extractErrorName(assistantError),
|
||||
message: this.extractErrorMessage(assistantError),
|
||||
name: extractErrorName(assistantError),
|
||||
message: extractErrorMessage(assistantError),
|
||||
}
|
||||
this.tryFallbackRetry(task, errorInfo, "message.updated")
|
||||
}
|
||||
@ -809,7 +806,7 @@ export class BackgroundManager {
|
||||
|
||||
const errorObj = props?.error as { name?: string; message?: string } | undefined
|
||||
const errorName = errorObj?.name
|
||||
const errorMessage = props ? this.getSessionErrorMessage(props) : undefined
|
||||
const errorMessage = props ? getSessionErrorMessage(props) : undefined
|
||||
|
||||
const errorInfo = { name: errorName, message: errorMessage }
|
||||
if (this.tryFallbackRetry(task, errorInfo, "session.error")) return
|
||||
@ -934,110 +931,20 @@ export class BackgroundManager {
|
||||
errorInfo: { name?: string; message?: string },
|
||||
source: string,
|
||||
): boolean {
|
||||
const fallbackChain = task.fallbackChain
|
||||
const canRetry =
|
||||
shouldRetryError(errorInfo) &&
|
||||
fallbackChain &&
|
||||
fallbackChain.length > 0 &&
|
||||
hasMoreFallbacks(fallbackChain, task.attemptCount ?? 0)
|
||||
|
||||
if (!canRetry) return false
|
||||
|
||||
const attemptCount = task.attemptCount ?? 0
|
||||
const providerModelsCache = readProviderModelsCache()
|
||||
const connectedProviders = providerModelsCache?.connected ?? readConnectedProvidersCache()
|
||||
const connectedSet = connectedProviders ? new Set(connectedProviders.map(p => p.toLowerCase())) : null
|
||||
|
||||
const isReachable = (entry: FallbackEntry): boolean => {
|
||||
if (!connectedSet) return true
|
||||
|
||||
// Gate only on provider connectivity. Provider model lists can be stale/incomplete,
|
||||
// especially after users manually add models to opencode.json.
|
||||
return entry.providers.some((p) => connectedSet.has(p.toLowerCase()))
|
||||
}
|
||||
|
||||
let selectedAttemptCount = attemptCount
|
||||
let nextFallback: FallbackEntry | undefined
|
||||
while (fallbackChain && selectedAttemptCount < fallbackChain.length) {
|
||||
const candidate = getNextFallback(fallbackChain, selectedAttemptCount)
|
||||
if (!candidate) break
|
||||
selectedAttemptCount++
|
||||
if (!isReachable(candidate)) {
|
||||
log("[background-agent] Skipping unreachable fallback:", {
|
||||
taskId: task.id,
|
||||
source,
|
||||
model: candidate.model,
|
||||
providers: candidate.providers,
|
||||
})
|
||||
continue
|
||||
}
|
||||
nextFallback = candidate
|
||||
break
|
||||
}
|
||||
if (!nextFallback) return false
|
||||
|
||||
const providerID = selectFallbackProvider(
|
||||
nextFallback.providers,
|
||||
task.model?.providerID,
|
||||
)
|
||||
|
||||
log("[background-agent] Retryable error, attempting fallback:", {
|
||||
taskId: task.id,
|
||||
const result = tryFallbackRetry({
|
||||
task,
|
||||
errorInfo,
|
||||
source,
|
||||
errorName: errorInfo.name,
|
||||
errorMessage: errorInfo.message?.slice(0, 100),
|
||||
attemptCount: selectedAttemptCount,
|
||||
nextModel: `${providerID}/${nextFallback.model}`,
|
||||
concurrencyManager: this.concurrencyManager,
|
||||
client: this.client,
|
||||
idleDeferralTimers: this.idleDeferralTimers,
|
||||
queuesByKey: this.queuesByKey,
|
||||
processKey: (key: string) => this.processKey(key),
|
||||
})
|
||||
|
||||
if (task.concurrencyKey) {
|
||||
this.concurrencyManager.release(task.concurrencyKey)
|
||||
task.concurrencyKey = undefined
|
||||
}
|
||||
|
||||
if (task.sessionID) {
|
||||
this.client.session.abort({ path: { id: task.sessionID } }).catch(() => {})
|
||||
if (result && task.sessionID) {
|
||||
subagentSessions.delete(task.sessionID)
|
||||
}
|
||||
|
||||
const idleTimer = this.idleDeferralTimers.get(task.id)
|
||||
if (idleTimer) {
|
||||
clearTimeout(idleTimer)
|
||||
this.idleDeferralTimers.delete(task.id)
|
||||
}
|
||||
|
||||
task.attemptCount = selectedAttemptCount
|
||||
const transformedModelId = transformModelForProvider(providerID, nextFallback.model)
|
||||
task.model = {
|
||||
providerID,
|
||||
modelID: transformedModelId,
|
||||
variant: nextFallback.variant,
|
||||
}
|
||||
task.status = "pending"
|
||||
task.sessionID = undefined
|
||||
task.startedAt = undefined
|
||||
task.queuedAt = new Date()
|
||||
task.error = undefined
|
||||
|
||||
const key = task.model ? `${task.model.providerID}/${task.model.modelID}` : task.agent
|
||||
const queue = this.queuesByKey.get(key) ?? []
|
||||
const retryInput: LaunchInput = {
|
||||
description: task.description,
|
||||
prompt: task.prompt,
|
||||
agent: task.agent,
|
||||
parentSessionID: task.parentSessionID,
|
||||
parentMessageID: task.parentMessageID,
|
||||
parentModel: task.parentModel,
|
||||
parentAgent: task.parentAgent,
|
||||
parentTools: task.parentTools,
|
||||
model: task.model,
|
||||
fallbackChain: task.fallbackChain,
|
||||
category: task.category,
|
||||
}
|
||||
queue.push({ task, input: retryInput })
|
||||
this.queuesByKey.set(key, queue)
|
||||
this.processKey(key)
|
||||
return true
|
||||
return result
|
||||
}
|
||||
|
||||
markForNotification(task: BackgroundTask): void {
|
||||
@ -1256,45 +1163,11 @@ export class BackgroundManager {
|
||||
}
|
||||
|
||||
private registerProcessCleanup(): void {
|
||||
BackgroundManager.cleanupManagers.add(this)
|
||||
|
||||
if (BackgroundManager.cleanupRegistered) return
|
||||
BackgroundManager.cleanupRegistered = true
|
||||
|
||||
const cleanupAll = () => {
|
||||
for (const manager of BackgroundManager.cleanupManagers) {
|
||||
try {
|
||||
manager.shutdown()
|
||||
} catch (error) {
|
||||
log("[background-agent] Error during shutdown cleanup:", error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const registerSignal = (signal: ProcessCleanupEvent, exitAfter: boolean): void => {
|
||||
const listener = registerProcessSignal(signal, cleanupAll, exitAfter)
|
||||
BackgroundManager.cleanupHandlers.set(signal, listener)
|
||||
}
|
||||
|
||||
registerSignal("SIGINT", true)
|
||||
registerSignal("SIGTERM", true)
|
||||
if (process.platform === "win32") {
|
||||
registerSignal("SIGBREAK", true)
|
||||
}
|
||||
registerSignal("beforeExit", false)
|
||||
registerSignal("exit", false)
|
||||
registerManagerForCleanup(this)
|
||||
}
|
||||
|
||||
private unregisterProcessCleanup(): void {
|
||||
BackgroundManager.cleanupManagers.delete(this)
|
||||
|
||||
if (BackgroundManager.cleanupManagers.size > 0) return
|
||||
|
||||
for (const [signal, listener] of BackgroundManager.cleanupHandlers.entries()) {
|
||||
process.off(signal, listener)
|
||||
}
|
||||
BackgroundManager.cleanupHandlers.clear()
|
||||
BackgroundManager.cleanupRegistered = false
|
||||
unregisterManagerForCleanup(this)
|
||||
}
|
||||
|
||||
|
||||
@ -1368,7 +1241,7 @@ export class BackgroundManager {
|
||||
// Note: Callers must release concurrency before calling this method
|
||||
// to ensure slots are freed even if notification fails
|
||||
|
||||
const duration = this.formatDuration(task.startedAt ?? new Date(), task.completedAt)
|
||||
const duration = formatDuration(task.startedAt ?? new Date(), task.completedAt)
|
||||
|
||||
log("[background-agent] notifyParentSession called for task:", task.id)
|
||||
|
||||
@ -1455,7 +1328,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
if (isCompactionAgent(info?.agent)) {
|
||||
continue
|
||||
}
|
||||
const normalizedTools = this.isRecord(info?.tools)
|
||||
const normalizedTools = isRecord(info?.tools)
|
||||
? normalizePromptTools(info.tools as Record<string, boolean | "allow" | "deny" | "ask">)
|
||||
: undefined
|
||||
if (info?.agent || info?.model || (info?.modelID && info?.providerID) || normalizedTools) {
|
||||
@ -1466,7 +1339,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (this.isAbortedSessionError(error)) {
|
||||
if (isAbortedSessionError(error)) {
|
||||
log("[background-agent] Parent session aborted while loading messages; using messageDir fallback:", {
|
||||
taskId: task.id,
|
||||
parentSessionID: task.parentSessionID,
|
||||
@ -1506,7 +1379,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
noReply: !allComplete,
|
||||
})
|
||||
} catch (error) {
|
||||
if (this.isAbortedSessionError(error)) {
|
||||
if (isAbortedSessionError(error)) {
|
||||
log("[background-agent] Parent session aborted while sending notification; continuing cleanup:", {
|
||||
taskId: task.id,
|
||||
parentSessionID: task.parentSessionID,
|
||||
@ -1544,97 +1417,11 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
}
|
||||
|
||||
private formatDuration(start: Date, end?: Date): string {
|
||||
const duration = (end ?? new Date()).getTime() - start.getTime()
|
||||
const seconds = Math.floor(duration / 1000)
|
||||
const minutes = Math.floor(seconds / 60)
|
||||
const hours = Math.floor(minutes / 60)
|
||||
|
||||
if (hours > 0) {
|
||||
return `${hours}h ${minutes % 60}m ${seconds % 60}s`
|
||||
} else if (minutes > 0) {
|
||||
return `${minutes}m ${seconds % 60}s`
|
||||
}
|
||||
return `${seconds}s`
|
||||
return formatDuration(start, end)
|
||||
}
|
||||
|
||||
private isAbortedSessionError(error: unknown): boolean {
|
||||
const message = this.getErrorText(error)
|
||||
return message.toLowerCase().includes("aborted")
|
||||
}
|
||||
|
||||
private getErrorText(error: unknown): string {
|
||||
if (!error) return ""
|
||||
if (typeof error === "string") return error
|
||||
if (error instanceof Error) {
|
||||
return `${error.name}: ${error.message}`
|
||||
}
|
||||
if (typeof error === "object" && error !== null) {
|
||||
if ("message" in error && typeof error.message === "string") {
|
||||
return error.message
|
||||
}
|
||||
if ("name" in error && typeof error.name === "string") {
|
||||
return error.name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
private extractErrorName(error: unknown): string | undefined {
|
||||
if (this.isRecord(error) && typeof error["name"] === "string") return error["name"]
|
||||
if (error instanceof Error) return error.name
|
||||
return undefined
|
||||
}
|
||||
|
||||
private extractErrorMessage(error: unknown): string | undefined {
|
||||
if (!error) return undefined
|
||||
if (typeof error === "string") return error
|
||||
if (error instanceof Error) return error.message
|
||||
|
||||
if (this.isRecord(error)) {
|
||||
const dataRaw = error["data"]
|
||||
const candidates: unknown[] = [
|
||||
error,
|
||||
dataRaw,
|
||||
error["error"],
|
||||
this.isRecord(dataRaw) ? (dataRaw as Record<string, unknown>)["error"] : undefined,
|
||||
error["cause"],
|
||||
]
|
||||
|
||||
for (const candidate of candidates) {
|
||||
if (typeof candidate === "string" && candidate.length > 0) return candidate
|
||||
if (
|
||||
this.isRecord(candidate) &&
|
||||
typeof candidate["message"] === "string" &&
|
||||
candidate["message"].length > 0
|
||||
) {
|
||||
return candidate["message"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.stringify(error)
|
||||
} catch {
|
||||
return String(error)
|
||||
}
|
||||
}
|
||||
|
||||
private isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
|
||||
private getSessionErrorMessage(properties: EventProperties): string | undefined {
|
||||
const errorRaw = properties["error"]
|
||||
if (!this.isRecord(errorRaw)) return undefined
|
||||
|
||||
const dataRaw = errorRaw["data"]
|
||||
if (this.isRecord(dataRaw)) {
|
||||
const message = dataRaw["message"]
|
||||
if (typeof message === "string") return message
|
||||
}
|
||||
|
||||
const message = errorRaw["message"]
|
||||
return typeof message === "string" ? message : undefined
|
||||
return isAbortedSessionError(error)
|
||||
}
|
||||
|
||||
private hasRunningTasks(): boolean {
|
||||
@ -1645,25 +1432,12 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
}
|
||||
|
||||
private pruneStaleTasksAndNotifications(): void {
|
||||
const now = Date.now()
|
||||
|
||||
for (const [taskId, task] of this.tasks.entries()) {
|
||||
const wasPending = task.status === "pending"
|
||||
const timestamp = task.status === "pending"
|
||||
? task.queuedAt?.getTime()
|
||||
: task.startedAt?.getTime()
|
||||
|
||||
if (!timestamp) {
|
||||
continue
|
||||
}
|
||||
|
||||
const age = now - timestamp
|
||||
if (age > TASK_TTL_MS) {
|
||||
const errorMessage = task.status === "pending"
|
||||
? "Task timed out while queued (30 minutes)"
|
||||
: "Task timed out after 30 minutes"
|
||||
|
||||
log("[background-agent] Pruning stale task:", { taskId, status: task.status, age: Math.round(age / 1000) + "s" })
|
||||
pruneStaleTasksAndNotifications({
|
||||
tasks: this.tasks,
|
||||
notifications: this.notifications,
|
||||
onTaskPruned: (taskId, task, errorMessage) => {
|
||||
const wasPending = task.status === "pending"
|
||||
log("[background-agent] Pruning stale task:", { taskId, status: task.status, age: Math.round(((wasPending ? task.queuedAt?.getTime() : task.startedAt?.getTime()) ? (Date.now() - (wasPending ? task.queuedAt!.getTime() : task.startedAt!.getTime())) : 0) / 1000) + "s" })
|
||||
task.status = "error"
|
||||
task.error = errorMessage
|
||||
task.completedAt = new Date()
|
||||
@ -1671,7 +1445,6 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
this.concurrencyManager.release(task.concurrencyKey)
|
||||
task.concurrencyKey = undefined
|
||||
}
|
||||
// Clean up pendingByParent to prevent stale entries
|
||||
this.cleanupPendingByParent(task)
|
||||
if (wasPending) {
|
||||
const key = task.model
|
||||
@ -1698,97 +1471,21 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
subagentSessions.delete(task.sessionID)
|
||||
SessionCategoryRegistry.remove(task.sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const [sessionID, notifications] of this.notifications.entries()) {
|
||||
if (notifications.length === 0) {
|
||||
this.notifications.delete(sessionID)
|
||||
continue
|
||||
}
|
||||
const validNotifications = notifications.filter((task) => {
|
||||
if (!task.startedAt) return false
|
||||
const age = now - task.startedAt.getTime()
|
||||
return age <= TASK_TTL_MS
|
||||
})
|
||||
if (validNotifications.length === 0) {
|
||||
this.notifications.delete(sessionID)
|
||||
} else if (validNotifications.length !== notifications.length) {
|
||||
this.notifications.set(sessionID, validNotifications)
|
||||
}
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
private async checkAndInterruptStaleTasks(
|
||||
allStatuses: Record<string, { type: string }> = {},
|
||||
): Promise<void> {
|
||||
const staleTimeoutMs = this.config?.staleTimeoutMs ?? DEFAULT_STALE_TIMEOUT_MS
|
||||
const messageStalenessMs = this.config?.messageStalenessTimeoutMs ?? DEFAULT_MESSAGE_STALENESS_TIMEOUT_MS
|
||||
const now = Date.now()
|
||||
|
||||
for (const task of this.tasks.values()) {
|
||||
if (task.status !== "running") continue
|
||||
|
||||
const startedAt = task.startedAt
|
||||
const sessionID = task.sessionID
|
||||
if (!startedAt || !sessionID) continue
|
||||
|
||||
const sessionStatus = allStatuses[sessionID]?.type
|
||||
const sessionIsRunning = sessionStatus !== undefined && sessionStatus !== "idle"
|
||||
const runtime = now - startedAt.getTime()
|
||||
|
||||
if (!task.progress?.lastUpdate) {
|
||||
if (sessionIsRunning) continue
|
||||
if (runtime <= messageStalenessMs) continue
|
||||
|
||||
const staleMinutes = Math.round(runtime / 60000)
|
||||
task.status = "cancelled"
|
||||
task.error = `Stale timeout (no activity for ${staleMinutes}min since start)`
|
||||
task.completedAt = new Date()
|
||||
|
||||
if (task.concurrencyKey) {
|
||||
this.concurrencyManager.release(task.concurrencyKey)
|
||||
task.concurrencyKey = undefined
|
||||
}
|
||||
|
||||
this.client.session.abort({ path: { id: sessionID } }).catch(() => {})
|
||||
log(`[background-agent] Task ${task.id} interrupted: no progress since start`)
|
||||
|
||||
try {
|
||||
await this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task))
|
||||
} catch (err) {
|
||||
log("[background-agent] Error in notifyParentSession for stale task:", { taskId: task.id, error: err })
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if (sessionIsRunning) continue
|
||||
|
||||
if (runtime < MIN_RUNTIME_BEFORE_STALE_MS) continue
|
||||
|
||||
const timeSinceLastUpdate = now - task.progress.lastUpdate.getTime()
|
||||
if (timeSinceLastUpdate <= staleTimeoutMs) continue
|
||||
if (task.status !== "running") continue
|
||||
|
||||
const staleMinutes = Math.round(timeSinceLastUpdate / 60000)
|
||||
task.status = "cancelled"
|
||||
task.error = `Stale timeout (no activity for ${staleMinutes}min)`
|
||||
task.completedAt = new Date()
|
||||
|
||||
if (task.concurrencyKey) {
|
||||
this.concurrencyManager.release(task.concurrencyKey)
|
||||
task.concurrencyKey = undefined
|
||||
}
|
||||
|
||||
this.client.session.abort({ path: { id: sessionID } }).catch(() => {})
|
||||
log(`[background-agent] Task ${task.id} interrupted: stale timeout`)
|
||||
|
||||
try {
|
||||
await this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task))
|
||||
} catch (err) {
|
||||
log("[background-agent] Error in notifyParentSession for stale task:", { taskId: task.id, error: err })
|
||||
}
|
||||
}
|
||||
await checkAndInterruptStaleTasks({
|
||||
tasks: this.tasks.values(),
|
||||
client: this.client,
|
||||
config: this.config,
|
||||
concurrencyManager: this.concurrencyManager,
|
||||
notifyParentSession: (task) => this.enqueueNotificationForParent(task.parentSessionID, () => this.notifyParentSession(task)),
|
||||
sessionStatuses: allStatuses,
|
||||
})
|
||||
}
|
||||
|
||||
private async pollRunningTasks(): Promise<void> {
|
||||
@ -1948,89 +1645,3 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
||||
return current
|
||||
}
|
||||
}
|
||||
|
||||
function registerProcessSignal(
|
||||
signal: ProcessCleanupEvent,
|
||||
handler: () => void,
|
||||
exitAfter: boolean
|
||||
): () => void {
|
||||
const listener = () => {
|
||||
handler()
|
||||
if (exitAfter) {
|
||||
// Set exitCode and schedule exit after delay to allow other handlers to complete async cleanup
|
||||
// Use 6s delay to accommodate LSP cleanup (5s timeout + 1s SIGKILL wait)
|
||||
process.exitCode = 0
|
||||
setTimeout(() => process.exit(), 6000)
|
||||
}
|
||||
}
|
||||
process.on(signal, listener)
|
||||
return listener
|
||||
}
|
||||
|
||||
|
||||
function getMessageDir(sessionID: string): string | null {
|
||||
if (!existsSync(MESSAGE_STORAGE)) return null
|
||||
|
||||
const directPath = join(MESSAGE_STORAGE, sessionID)
|
||||
if (existsSync(directPath)) return directPath
|
||||
|
||||
for (const dir of readdirSync(MESSAGE_STORAGE)) {
|
||||
const sessionPath = join(MESSAGE_STORAGE, dir, sessionID)
|
||||
if (existsSync(sessionPath)) return sessionPath
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function isCompactionAgent(agent: string | undefined): boolean {
|
||||
return agent?.trim().toLowerCase() === "compaction"
|
||||
}
|
||||
|
||||
function hasFullAgentAndModel(message: StoredMessage): boolean {
|
||||
return !!message.agent &&
|
||||
!isCompactionAgent(message.agent) &&
|
||||
!!message.model?.providerID &&
|
||||
!!message.model?.modelID
|
||||
}
|
||||
|
||||
function hasPartialAgentOrModel(message: StoredMessage): boolean {
|
||||
const hasAgent = !!message.agent && !isCompactionAgent(message.agent)
|
||||
const hasModel = !!message.model?.providerID && !!message.model?.modelID
|
||||
return hasAgent || hasModel
|
||||
}
|
||||
|
||||
function findNearestMessageExcludingCompaction(messageDir: string): StoredMessage | null {
|
||||
try {
|
||||
const files = readdirSync(messageDir)
|
||||
.filter((name) => name.endsWith(".json"))
|
||||
.sort()
|
||||
.reverse()
|
||||
|
||||
for (const file of files) {
|
||||
try {
|
||||
const content = readFileSync(join(messageDir, file), "utf-8")
|
||||
const parsed = JSON.parse(content) as StoredMessage
|
||||
if (hasFullAgentAndModel(parsed)) {
|
||||
return parsed
|
||||
}
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for (const file of files) {
|
||||
try {
|
||||
const content = readFileSync(join(messageDir, file), "utf-8")
|
||||
const parsed = JSON.parse(content) as StoredMessage
|
||||
if (hasPartialAgentOrModel(parsed)) {
|
||||
return parsed
|
||||
}
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
@ -1,41 +0,0 @@
|
||||
import type { BackgroundTask } from "./types"
|
||||
|
||||
export function buildBackgroundTaskNotificationText(args: {
|
||||
task: BackgroundTask
|
||||
duration: string
|
||||
allComplete: boolean
|
||||
remainingCount: number
|
||||
completedTasks: BackgroundTask[]
|
||||
}): string {
|
||||
const { task, duration, allComplete, remainingCount, completedTasks } = args
|
||||
const statusText =
|
||||
task.status === "completed" ? "COMPLETED" : task.status === "interrupt" ? "INTERRUPTED" : task.status === "error" ? "ERROR" : "CANCELLED"
|
||||
const errorInfo = task.error ? `\n**Error:** ${task.error}` : ""
|
||||
|
||||
if (allComplete) {
|
||||
const completedTasksText = completedTasks
|
||||
.map((t) => `- \`${t.id}\`: ${t.description}`)
|
||||
.join("\n")
|
||||
|
||||
return `<system-reminder>
|
||||
[ALL BACKGROUND TASKS COMPLETE]
|
||||
|
||||
**Completed:**
|
||||
${completedTasksText || `- \`${task.id}\`: ${task.description}`}
|
||||
|
||||
Use \`background_output(task_id="<id>")\` to retrieve each result.
|
||||
</system-reminder>`
|
||||
}
|
||||
|
||||
return `<system-reminder>
|
||||
[BACKGROUND TASK ${statusText}]
|
||||
**ID:** \`${task.id}\`
|
||||
**Description:** ${task.description}
|
||||
**Duration:** ${duration}${errorInfo}
|
||||
|
||||
**${remainingCount} task${remainingCount === 1 ? "" : "s"} still in progress.** You WILL be notified when ALL complete.
|
||||
Do NOT poll - continue productive work.
|
||||
|
||||
Use \`background_output(task_id="${task.id}")\` to retrieve this result when ready.
|
||||
</system-reminder>`
|
||||
}
|
||||
156
src/features/background-agent/process-cleanup.test.ts
Normal file
156
src/features/background-agent/process-cleanup.test.ts
Normal file
@ -0,0 +1,156 @@
|
||||
import { describe, test, expect, beforeEach, afterEach, mock } from "bun:test"
|
||||
import {
|
||||
registerManagerForCleanup,
|
||||
unregisterManagerForCleanup,
|
||||
_resetForTesting,
|
||||
} from "./process-cleanup"
|
||||
|
||||
describe("process-cleanup", () => {
|
||||
const registeredManagers: Array<{ shutdown: () => void }> = []
|
||||
const mockShutdown = mock(() => {})
|
||||
|
||||
const processOnCalls: Array<[string, Function]> = []
|
||||
const processOffCalls: Array<[string, Function]> = []
|
||||
const originalProcessOn = process.on.bind(process)
|
||||
const originalProcessOff = process.off.bind(process)
|
||||
|
||||
beforeEach(() => {
|
||||
mockShutdown.mockClear()
|
||||
processOnCalls.length = 0
|
||||
processOffCalls.length = 0
|
||||
registeredManagers.length = 0
|
||||
|
||||
process.on = originalProcessOn as any
|
||||
process.off = originalProcessOff as any
|
||||
_resetForTesting()
|
||||
|
||||
process.on = ((event: string, listener: Function) => {
|
||||
processOnCalls.push([event, listener])
|
||||
return process
|
||||
}) as any
|
||||
|
||||
process.off = ((event: string, listener: Function) => {
|
||||
processOffCalls.push([event, listener])
|
||||
return process
|
||||
}) as any
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
process.on = originalProcessOn as any
|
||||
process.off = originalProcessOff as any
|
||||
|
||||
for (const manager of [...registeredManagers]) {
|
||||
unregisterManagerForCleanup(manager)
|
||||
}
|
||||
})
|
||||
|
||||
describe("registerManagerForCleanup", () => {
|
||||
test("registers signal handlers on first manager", () => {
|
||||
const manager = { shutdown: mockShutdown }
|
||||
registeredManagers.push(manager)
|
||||
|
||||
registerManagerForCleanup(manager)
|
||||
|
||||
const signals = processOnCalls.map(([signal]) => signal)
|
||||
expect(signals).toContain("SIGINT")
|
||||
expect(signals).toContain("SIGTERM")
|
||||
expect(signals).toContain("beforeExit")
|
||||
expect(signals).toContain("exit")
|
||||
})
|
||||
|
||||
test("signal listener calls shutdown on registered manager", () => {
|
||||
const manager = { shutdown: mockShutdown }
|
||||
registeredManagers.push(manager)
|
||||
|
||||
registerManagerForCleanup(manager)
|
||||
|
||||
const [, listener] = processOnCalls[0]
|
||||
listener()
|
||||
|
||||
expect(mockShutdown).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
test("multiple managers all get shutdown when signal fires", () => {
|
||||
const shutdown1 = mock(() => {})
|
||||
const shutdown2 = mock(() => {})
|
||||
const shutdown3 = mock(() => {})
|
||||
const manager1 = { shutdown: shutdown1 }
|
||||
const manager2 = { shutdown: shutdown2 }
|
||||
const manager3 = { shutdown: shutdown3 }
|
||||
registeredManagers.push(manager1, manager2, manager3)
|
||||
|
||||
registerManagerForCleanup(manager1)
|
||||
registerManagerForCleanup(manager2)
|
||||
registerManagerForCleanup(manager3)
|
||||
|
||||
const [, listener] = processOnCalls[0]
|
||||
listener()
|
||||
|
||||
expect(shutdown1).toHaveBeenCalledTimes(1)
|
||||
expect(shutdown2).toHaveBeenCalledTimes(1)
|
||||
expect(shutdown3).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
test("does not re-register signal handlers for subsequent managers", () => {
|
||||
const manager1 = { shutdown: mockShutdown }
|
||||
const manager2 = { shutdown: mockShutdown }
|
||||
registeredManagers.push(manager1, manager2)
|
||||
|
||||
registerManagerForCleanup(manager1)
|
||||
const callsAfterFirst = processOnCalls.length
|
||||
|
||||
registerManagerForCleanup(manager2)
|
||||
|
||||
expect(processOnCalls.length).toBe(callsAfterFirst)
|
||||
})
|
||||
})
|
||||
|
||||
describe("unregisterManagerForCleanup", () => {
|
||||
test("removes signal handlers when last manager unregisters", () => {
|
||||
const manager = { shutdown: mockShutdown }
|
||||
registeredManagers.push(manager)
|
||||
|
||||
registerManagerForCleanup(manager)
|
||||
unregisterManagerForCleanup(manager)
|
||||
registeredManagers.length = 0
|
||||
|
||||
const offSignals = processOffCalls.map(([signal]) => signal)
|
||||
expect(offSignals).toContain("SIGINT")
|
||||
expect(offSignals).toContain("SIGTERM")
|
||||
expect(offSignals).toContain("beforeExit")
|
||||
expect(offSignals).toContain("exit")
|
||||
})
|
||||
|
||||
test("keeps signal handlers when other managers remain", () => {
|
||||
const manager1 = { shutdown: mockShutdown }
|
||||
const manager2 = { shutdown: mockShutdown }
|
||||
registeredManagers.push(manager1, manager2)
|
||||
|
||||
registerManagerForCleanup(manager1)
|
||||
registerManagerForCleanup(manager2)
|
||||
|
||||
unregisterManagerForCleanup(manager2)
|
||||
|
||||
expect(processOffCalls.length).toBe(0)
|
||||
})
|
||||
|
||||
test("remaining managers still get shutdown after partial unregister", () => {
|
||||
const shutdown1 = mock(() => {})
|
||||
const shutdown2 = mock(() => {})
|
||||
const manager1 = { shutdown: shutdown1 }
|
||||
const manager2 = { shutdown: shutdown2 }
|
||||
registeredManagers.push(manager1, manager2)
|
||||
|
||||
registerManagerForCleanup(manager1)
|
||||
registerManagerForCleanup(manager2)
|
||||
|
||||
const [, listener] = processOnCalls[0]
|
||||
unregisterManagerForCleanup(manager2)
|
||||
|
||||
listener()
|
||||
|
||||
expect(shutdown1).toHaveBeenCalledTimes(1)
|
||||
expect(shutdown2).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
81
src/features/background-agent/process-cleanup.ts
Normal file
81
src/features/background-agent/process-cleanup.ts
Normal file
@ -0,0 +1,81 @@
|
||||
import { log } from "../../shared"
|
||||
|
||||
type ProcessCleanupEvent = NodeJS.Signals | "beforeExit" | "exit"
|
||||
|
||||
function registerProcessSignal(
|
||||
signal: ProcessCleanupEvent,
|
||||
handler: () => void,
|
||||
exitAfter: boolean
|
||||
): () => void {
|
||||
const listener = () => {
|
||||
handler()
|
||||
if (exitAfter) {
|
||||
process.exitCode = 0
|
||||
setTimeout(() => process.exit(), 6000)
|
||||
}
|
||||
}
|
||||
process.on(signal, listener)
|
||||
return listener
|
||||
}
|
||||
|
||||
interface CleanupTarget {
|
||||
shutdown(): void
|
||||
}
|
||||
|
||||
const cleanupManagers = new Set<CleanupTarget>()
|
||||
let cleanupRegistered = false
|
||||
const cleanupHandlers = new Map<ProcessCleanupEvent, () => void>()
|
||||
|
||||
export function registerManagerForCleanup(manager: CleanupTarget): void {
|
||||
cleanupManagers.add(manager)
|
||||
|
||||
if (cleanupRegistered) return
|
||||
cleanupRegistered = true
|
||||
|
||||
const cleanupAll = () => {
|
||||
for (const m of cleanupManagers) {
|
||||
try {
|
||||
m.shutdown()
|
||||
} catch (error) {
|
||||
log("[background-agent] Error during shutdown cleanup:", error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const registerSignal = (signal: ProcessCleanupEvent, exitAfter: boolean): void => {
|
||||
const listener = registerProcessSignal(signal, cleanupAll, exitAfter)
|
||||
cleanupHandlers.set(signal, listener)
|
||||
}
|
||||
|
||||
registerSignal("SIGINT", true)
|
||||
registerSignal("SIGTERM", true)
|
||||
if (process.platform === "win32") {
|
||||
registerSignal("SIGBREAK", true)
|
||||
}
|
||||
registerSignal("beforeExit", false)
|
||||
registerSignal("exit", false)
|
||||
}
|
||||
|
||||
export function unregisterManagerForCleanup(manager: CleanupTarget): void {
|
||||
cleanupManagers.delete(manager)
|
||||
|
||||
if (cleanupManagers.size > 0) return
|
||||
|
||||
for (const [signal, listener] of cleanupHandlers.entries()) {
|
||||
process.off(signal, listener)
|
||||
}
|
||||
cleanupHandlers.clear()
|
||||
cleanupRegistered = false
|
||||
}
|
||||
|
||||
/** @internal — test-only reset for module-level singleton state */
|
||||
export function _resetForTesting(): void {
|
||||
for (const manager of [...cleanupManagers]) {
|
||||
cleanupManagers.delete(manager)
|
||||
}
|
||||
for (const [signal, listener] of cleanupHandlers.entries()) {
|
||||
process.off(signal, listener)
|
||||
}
|
||||
cleanupHandlers.clear()
|
||||
cleanupRegistered = false
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user