diff --git a/src/hooks/agent-switch/hook.test.ts b/src/hooks/agent-switch/hook.test.ts index cb2e7c64..30cd6c1e 100644 --- a/src/hooks/agent-switch/hook.test.ts +++ b/src/hooks/agent-switch/hook.test.ts @@ -139,6 +139,54 @@ describe("agent-switch hook", () => { expect(getPendingSwitch("ses-4")).toBeUndefined() }) + test("clears pending switch on session.error with info.id", async () => { + const ctx = { + client: { + session: { + promptAsync: async () => {}, + messages: async () => ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-10", "atlas", "fix this") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "session.error", + properties: { info: { id: "ses-10" } }, + }, + }) + + expect(getPendingSwitch("ses-10")).toBeUndefined() + }) + + test("clears pending switch on session.error with sessionID property", async () => { + const ctx = { + client: { + session: { + promptAsync: async () => {}, + messages: async () => ({ data: [] }), + message: async () => ({ data: { parts: [] } }), + }, + }, + } as any + + setPendingSwitch("ses-11", "atlas", "fix this") + const hook = createAgentSwitchHook(ctx) + + await hook.event({ + event: { + type: "session.error", + properties: { sessionID: "ses-11" }, + }, + }) + + expect(getPendingSwitch("ses-11")).toBeUndefined() + }) + test("recovers missing switch_agent tool call from Athena handoff text", async () => { const promptAsyncCalls: Array> = [] let switched = false diff --git a/src/hooks/agent-switch/hook.ts b/src/hooks/agent-switch/hook.ts index bc200ce8..2084eee4 100644 --- a/src/hooks/agent-switch/hook.ts +++ b/src/hooks/agent-switch/hook.ts @@ -56,6 +56,21 @@ export function createAgentSwitchHook(ctx: PluginInput) { return } + if (input.event.type === "session.error") { + const props = input.event.properties as Record | undefined + const info = props?.info as Record | undefined + const erroredSessionID = info?.id ?? props?.sessionID + if (typeof erroredSessionID === "string") { + clearPendingSwitchRuntime(erroredSessionID) + for (const key of Array.from(processedFallbackMessages)) { + if (key.startsWith(`${erroredSessionID}:`)) { + processedFallbackMessages.delete(key) + } + } + } + return + } + if (input.event.type === "message.updated") { const props = input.event.properties as Record | undefined const info = props?.info as Record | undefined