diff --git a/src/hooks/todo-continuation-enforcer.test.ts b/src/hooks/todo-continuation-enforcer.test.ts index 32e28bf2..e680cfd6 100644 --- a/src/hooks/todo-continuation-enforcer.test.ts +++ b/src/hooks/todo-continuation-enforcer.test.ts @@ -548,4 +548,263 @@ describe("todo-continuation-enforcer", () => { // #then - no continuation (abort error detected) expect(promptCalls).toHaveLength(0) }) + + test("should skip injection when abort detected via session.error event (event-based, primary)", async () => { + // #given - session with incomplete todos + const sessionID = "main-event-abort" + setMainSession(sessionID) + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant" } }, + ] + + const hook = createTodoContinuationEnforcer(createMockPluginInput(), {}) + + // #when - abort error event fires + await hook.handler({ + event: { + type: "session.error", + properties: { sessionID, error: { name: "MessageAbortedError" } }, + }, + }) + + // #when - session goes idle immediately after + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 3000)) + + // #then - no continuation (abort detected via event) + expect(promptCalls).toHaveLength(0) + }) + + test("should skip injection when AbortError detected via session.error event", async () => { + // #given - session with incomplete todos + const sessionID = "main-event-abort-dom" + setMainSession(sessionID) + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant" } }, + ] + + const hook = createTodoContinuationEnforcer(createMockPluginInput(), {}) + + // #when - AbortError event fires + await hook.handler({ + event: { + type: "session.error", + properties: { sessionID, error: { name: "AbortError" } }, + }, + }) + + // #when - session goes idle + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 3000)) + + // #then - no continuation (abort detected via event) + expect(promptCalls).toHaveLength(0) + }) + + test("should inject when abort flag is stale (>3s old)", async () => { + // #given - session with incomplete todos and old abort timestamp + const sessionID = "main-stale-abort" + setMainSession(sessionID) + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant" } }, + ] + + const hook = createTodoContinuationEnforcer(createMockPluginInput(), {}) + + // #when - abort error fires + await hook.handler({ + event: { + type: "session.error", + properties: { sessionID, error: { name: "MessageAbortedError" } }, + }, + }) + + // #when - wait >3s then idle fires + await new Promise(r => setTimeout(r, 3100)) + + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 3000)) + + // #then - continuation injected (abort flag is stale) + expect(promptCalls.length).toBeGreaterThan(0) + }, 10000) + + test("should clear abort flag on user message activity", async () => { + // #given - session with abort detected + const sessionID = "main-clear-on-user" + setMainSession(sessionID) + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant" } }, + ] + + const hook = createTodoContinuationEnforcer(createMockPluginInput(), {}) + + // #when - abort error fires + await hook.handler({ + event: { + type: "session.error", + properties: { sessionID, error: { name: "MessageAbortedError" } }, + }, + }) + + // #when - user sends new message (clears abort flag) + await new Promise(r => setTimeout(r, 600)) + await hook.handler({ + event: { + type: "message.updated", + properties: { info: { sessionID, role: "user" } }, + }, + }) + + // #when - session goes idle + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 3000)) + + // #then - continuation injected (abort flag was cleared by user activity) + expect(promptCalls.length).toBeGreaterThan(0) + }) + + test("should clear abort flag on assistant message activity", async () => { + // #given - session with abort detected + const sessionID = "main-clear-on-assistant" + setMainSession(sessionID) + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant" } }, + ] + + const hook = createTodoContinuationEnforcer(createMockPluginInput(), {}) + + // #when - abort error fires + await hook.handler({ + event: { + type: "session.error", + properties: { sessionID, error: { name: "MessageAbortedError" } }, + }, + }) + + // #when - assistant starts responding (clears abort flag) + await hook.handler({ + event: { + type: "message.updated", + properties: { info: { sessionID, role: "assistant" } }, + }, + }) + + // #when - session goes idle + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 3000)) + + // #then - continuation injected (abort flag was cleared by assistant activity) + expect(promptCalls.length).toBeGreaterThan(0) + }) + + test("should clear abort flag on tool execution", async () => { + // #given - session with abort detected + const sessionID = "main-clear-on-tool" + setMainSession(sessionID) + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant" } }, + ] + + const hook = createTodoContinuationEnforcer(createMockPluginInput(), {}) + + // #when - abort error fires + await hook.handler({ + event: { + type: "session.error", + properties: { sessionID, error: { name: "MessageAbortedError" } }, + }, + }) + + // #when - tool executes (clears abort flag) + await hook.handler({ + event: { + type: "tool.execute.before", + properties: { sessionID }, + }, + }) + + // #when - session goes idle + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 3000)) + + // #then - continuation injected (abort flag was cleared by tool execution) + expect(promptCalls.length).toBeGreaterThan(0) + }) + + test("should use event-based detection even when API indicates no abort (event wins)", async () => { + // #given - session with abort event but API shows no error + const sessionID = "main-event-wins" + setMainSession(sessionID) + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant" } }, + ] + + const hook = createTodoContinuationEnforcer(createMockPluginInput(), {}) + + // #when - abort error event fires (but API doesn't have it yet) + await hook.handler({ + event: { + type: "session.error", + properties: { sessionID, error: { name: "MessageAbortedError" } }, + }, + }) + + // #when - session goes idle + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 3000)) + + // #then - no continuation (event-based detection wins over API) + expect(promptCalls).toHaveLength(0) + }) + + test("should use API fallback when event is missed but API shows abort", async () => { + // #given - session where event was missed but API shows abort + const sessionID = "main-api-fallback" + setMainSession(sessionID) + mockMessages = [ + { info: { id: "msg-1", role: "user" } }, + { info: { id: "msg-2", role: "assistant", error: { name: "MessageAbortedError" } } }, + ] + + const hook = createTodoContinuationEnforcer(createMockPluginInput(), {}) + + // #when - session goes idle without prior session.error event + await hook.handler({ + event: { type: "session.idle", properties: { sessionID } }, + }) + + await new Promise(r => setTimeout(r, 3000)) + + // #then - no continuation (API fallback detected the abort) + expect(promptCalls).toHaveLength(0) + }) }) diff --git a/src/hooks/todo-continuation-enforcer.ts b/src/hooks/todo-continuation-enforcer.ts index 4c5fa694..b551a7ca 100644 --- a/src/hooks/todo-continuation-enforcer.ts +++ b/src/hooks/todo-continuation-enforcer.ts @@ -36,6 +36,7 @@ interface SessionState { countdownInterval?: ReturnType isRecovering?: boolean countdownStartedAt?: number + abortDetectedAt?: number } const CONTINUATION_PROMPT = `[SYSTEM REMINDER - TODO CONTINUATION] @@ -254,6 +255,13 @@ export function createTodoContinuationEnforcer( const sessionID = props?.sessionID as string | undefined if (!sessionID) return + const error = props?.error as { name?: string } | undefined + if (error?.name === "MessageAbortedError" || error?.name === "AbortError") { + const state = getState(sessionID) + state.abortDetectedAt = Date.now() + log(`[${HOOK_NAME}] Abort detected via session.error`, { sessionID, errorName: error.name }) + } + cancelCountdown(sessionID) log(`[${HOOK_NAME}] session.error`, { sessionID }) return @@ -281,6 +289,18 @@ export function createTodoContinuationEnforcer( return } + // Check 1: Event-based abort detection (primary, most reliable) + if (state.abortDetectedAt) { + const timeSinceAbort = Date.now() - state.abortDetectedAt + const ABORT_WINDOW_MS = 3000 + if (timeSinceAbort < ABORT_WINDOW_MS) { + log(`[${HOOK_NAME}] Skipped: abort detected via event ${timeSinceAbort}ms ago`, { sessionID }) + state.abortDetectedAt = undefined + return + } + state.abortDetectedAt = undefined + } + const hasRunningBgTasks = backgroundManager ? backgroundManager.getTasksByParentSession(sessionID).some(t => t.status === "running") : false @@ -290,6 +310,7 @@ export function createTodoContinuationEnforcer( return } + // Check 2: API-based abort detection (fallback, for cases where event was missed) try { const messagesResp = await ctx.client.session.messages({ path: { id: sessionID }, @@ -298,7 +319,7 @@ export function createTodoContinuationEnforcer( const messages = (messagesResp as { data?: Array<{ info?: MessageInfo }> }).data ?? [] if (isLastAssistantMessageAborted(messages)) { - log(`[${HOOK_NAME}] Skipped: last assistant message was aborted`, { sessionID }) + log(`[${HOOK_NAME}] Skipped: last assistant message was aborted (API fallback)`, { sessionID }) return } } catch (err) { @@ -367,10 +388,13 @@ export function createTodoContinuationEnforcer( return } } + if (state) state.abortDetectedAt = undefined cancelCountdown(sessionID) } if (role === "assistant") { + const state = sessions.get(sessionID) + if (state) state.abortDetectedAt = undefined cancelCountdown(sessionID) } return @@ -382,6 +406,8 @@ export function createTodoContinuationEnforcer( const role = info?.role as string | undefined if (sessionID && role === "assistant") { + const state = sessions.get(sessionID) + if (state) state.abortDetectedAt = undefined cancelCountdown(sessionID) } return @@ -390,6 +416,8 @@ export function createTodoContinuationEnforcer( if (event.type === "tool.execute.before" || event.type === "tool.execute.after") { const sessionID = props?.sessionID as string | undefined if (sessionID) { + const state = sessions.get(sessionID) + if (state) state.abortDetectedAt = undefined cancelCountdown(sessionID) } return