Mirrors `todo-continuation-enforcer`, but reads from file-based task storage instead of OpenCode's todo API. Includes 19 tests covering all skip conditions, abort detection, countdown, and recovery scenarios.
531 lines
17 KiB
TypeScript
import type { PluginInput } from "@opencode-ai/plugin"
|
|
import { existsSync, readdirSync } from "node:fs"
|
|
import { join } from "node:path"
|
|
|
|
import type { BackgroundManager } from "../features/background-agent"
|
|
import { getMainSessionID, subagentSessions } from "../features/claude-code-session-state"
|
|
import {
|
|
findNearestMessageWithFields,
|
|
MESSAGE_STORAGE,
|
|
type ToolPermission,
|
|
} from "../features/hook-message-injector"
|
|
import { listTaskFiles, readJsonSafe, getTaskDir } from "../features/claude-tasks/storage"
|
|
import type { OhMyOpenCodeConfig } from "../config/schema"
|
|
import { TaskObjectSchema } from "../tools/task/types"
|
|
import type { TaskObject } from "../tools/task/types"
|
|
import { log } from "../shared/logger"
|
|
import { createSystemDirective, SystemDirectiveTypes } from "../shared/system-directive"
|
|
|
|
/** Prefix used on every log line emitted by this hook. */
const HOOK_NAME = "task-continuation-enforcer"

/** Agents that are never continued unless `options.skipAgents` supplies a different list. */
const DEFAULT_SKIP_AGENTS: string[] = ["prometheus", "compaction"]
|
|
|
|
/** Optional collaborators and overrides for the task continuation enforcer. */
export interface TaskContinuationEnforcerOptions {
  /** Predicate consulted before continuing a session; returning true suppresses continuation. */
  isContinuationStopped?: (sessionID: string) => boolean
  /** Agent names for which continuation is skipped; a built-in default list is used when omitted. */
  skipAgents?: string[]
  /** When supplied, sessions whose background tasks are still running are not continued. */
  backgroundManager?: BackgroundManager
}
|
|
|
|
/** Public surface returned by `createTaskContinuationEnforcer`. */
export interface TaskContinuationEnforcer {
  /** Flags a session as recovering; continuation is suppressed until recovery completes. */
  markRecovering: (sessionID: string) => void
  /** Clears the recovering flag set by `markRecovering`. */
  markRecoveryComplete: (sessionID: string) => void
  /** Cancels every pending continuation countdown across all tracked sessions. */
  cancelAllCountdowns: () => void
  /** Plugin event hook; receives every OpenCode event with its raw `properties` payload. */
  handler: (input: { event: { type: string; properties?: unknown } }) => Promise<void>
}
|
|
|
|
/** Per-session bookkeeping for countdowns, recovery, and abort detection. */
interface SessionState {
  /** True while the session is marked as recovering; suppresses continuation. */
  isRecovering?: boolean
  /** Epoch ms of the last abort error observed via `session.error`. */
  abortDetectedAt?: number
  /** Epoch ms at which the active countdown started; cleared when the countdown is cancelled. */
  countdownStartedAt?: number
  /** One-shot timer that ends the countdown and attempts injection. */
  countdownTimer?: ReturnType<typeof setTimeout>
  /** Repeating timer that refreshes the countdown toast. */
  countdownInterval?: ReturnType<typeof setInterval>
}
|
|
|
|
/**
 * Base text of the continuation prompt injected into an idle session.
 *
 * Blank lines separate the system directive, the instruction, and the rule list.
 * Stray `|` separator lines must not appear here: everything inside the template
 * literal is sent verbatim to the model. The status line and remaining-task list
 * are appended when the prompt is injected.
 */
const CONTINUATION_PROMPT = `${createSystemDirective(SystemDirectiveTypes.TASK_CONTINUATION)}

Incomplete tasks remain in your task list. Continue working on the next pending task.

- Proceed without asking for permission
- Mark each task complete when finished
- Do not stop until all tasks are done`
|
|
|
|
/** Length of the countdown shown before a continuation prompt is injected, in seconds. */
const COUNTDOWN_SECONDS = 2
/** Display time of each countdown toast, in milliseconds (just under the 1 s refresh tick). */
const TOAST_DURATION_MS = 900
/** User messages arriving within this many ms of the countdown starting do not cancel it. */
const COUNTDOWN_GRACE_PERIOD_MS = 500
|
|
|
|
/**
 * Locates the on-disk message directory for a session.
 *
 * Tries `MESSAGE_STORAGE/<sessionID>` first. Only if that is missing does it scan
 * the entries of `MESSAGE_STORAGE` and try `MESSAGE_STORAGE/<entry>/<sessionID>`.
 *
 * @param sessionID - Session whose message directory is wanted.
 * @returns The first existing path, or `null` when storage or the session directory is absent.
 */
function getMessageDir(sessionID: string): string | null {
  if (!existsSync(MESSAGE_STORAGE)) {
    return null
  }

  const flatPath = join(MESSAGE_STORAGE, sessionID)
  if (existsSync(flatPath)) {
    return flatPath
  }

  // Directory listing is only needed when the flat layout did not match.
  const nestedPath = readdirSync(MESSAGE_STORAGE)
    .map(entry => join(MESSAGE_STORAGE, entry, sessionID))
    .find(candidate => existsSync(candidate))

  return nestedPath ?? null
}
|
|
|
|
/**
 * Counts tasks that are still open.
 *
 * @param tasks - Tasks loaded from storage.
 * @returns Number of tasks whose status is neither "completed" nor "deleted".
 */
function getIncompleteCount(tasks: TaskObject[]): number {
  let remaining = 0
  for (const task of tasks) {
    const isClosed = task.status === "completed" || task.status === "deleted"
    if (!isClosed) {
      remaining += 1
    }
  }
  return remaining
}
|
|
|
|
/** Subset of a session message's `info` fields read by this module. */
interface MessageInfo {
  /** Author role, e.g. "assistant" or "user". */
  role?: string
  /** Error attached to the message, if any; only `name` is inspected for aborts. */
  error?: { name?: string; data?: unknown }
  /** Message identifier. */
  id?: string
}
|
|
|
|
/**
 * Reports whether the most recent assistant message ended with an abort error.
 *
 * Scans from the end for the last message whose role is "assistant" and checks
 * whether its error name is `MessageAbortedError` or `AbortError`.
 *
 * @param messages - Session messages in chronological order.
 * @returns `true` only when the last assistant message carries an abort error.
 */
function isLastAssistantMessageAborted(messages: Array<{ info?: MessageInfo }>): boolean {
  if (!messages?.length) {
    return false
  }

  for (let index = messages.length - 1; index >= 0; index--) {
    const info = messages[index]?.info
    if (info?.role !== "assistant") {
      continue
    }
    // Only the newest assistant message matters; decide on it and stop.
    const name = info.error?.name
    return name === "MessageAbortedError" || name === "AbortError"
  }

  return false
}
|
|
|
|
/**
 * Reads every task file from the configured task directory.
 *
 * Each id from `listTaskFiles` is resolved to `<taskDir>/<id>.json` and parsed with
 * `readJsonSafe` against `TaskObjectSchema`; files that yield no task are dropped.
 *
 * @param config - Plugin configuration used to locate the task directory.
 * @returns The successfully parsed tasks, in the order the ids were listed.
 */
function loadTasksFromDisk(config: Partial<OhMyOpenCodeConfig>): TaskObject[] {
  const ids = listTaskFiles(config)
  const directory = getTaskDir(config)

  return ids
    .map(id => readJsonSafe<TaskObject>(join(directory, `${id}.json`), TaskObjectSchema))
    .filter((task): task is TaskObject => Boolean(task))
}
|
|
|
|
/**
 * Creates the task continuation enforcer hook.
 *
 * Mirrors `todo-continuation-enforcer`, but reads incomplete work from the
 * file-based task storage (`loadTasksFromDisk`) instead of OpenCode's todo API.
 * When an eligible session goes idle while tasks remain incomplete, a countdown
 * toast is shown. Unless new activity cancels the countdown first, a continuation
 * prompt listing the remaining tasks is injected into the session.
 *
 * @param ctx - Plugin input. `ctx.client` is used for toasts, message fetches and
 *   prompt injection; `ctx.directory` is sent as the `directory` query on API calls.
 * @param config - Plugin configuration used to locate the task storage directory.
 * @param options - Optional background manager, agents to skip (defaults to
 *   `DEFAULT_SKIP_AGENTS`), and a per-session "continuation stopped" predicate.
 * @returns The event handler plus recovery and countdown controls.
 */
export function createTaskContinuationEnforcer(
  ctx: PluginInput,
  config: Partial<OhMyOpenCodeConfig>,
  options: TaskContinuationEnforcerOptions = {}
): TaskContinuationEnforcer {
  const { backgroundManager, skipAgents = DEFAULT_SKIP_AGENTS, isContinuationStopped } = options
  const sessions = new Map<string, SessionState>()

  /** An abort recorded via `session.error` suppresses continuation on idle for this long (ms). */
  const ABORT_WINDOW_MS = 3000

  type EventProps = Record<string, unknown> | undefined

  /** Agent/model/tool settings reused when injecting the continuation prompt. */
  interface ResolvedMessageInfo {
    agent?: string
    model?: { providerID: string; modelID: string }
    tools?: Record<string, ToolPermission>
  }

  /** The fields of a `session.messages` entry that this hook reads. */
  interface SessionMessage {
    info?: MessageInfo & {
      agent?: string
      model?: { providerID: string; modelID: string }
      modelID?: string
      providerID?: string
      tools?: Record<string, ToolPermission>
    }
  }

  /** Error names OpenCode uses for user-initiated aborts. */
  function isAbortErrorName(name: string | undefined): boolean {
    return name === "MessageAbortedError" || name === "AbortError"
  }

  /** Returns the state for a session, creating an empty one on first use. */
  function getState(sessionID: string): SessionState {
    let state = sessions.get(sessionID)
    if (!state) {
      state = {}
      sessions.set(sessionID, state)
    }
    return state
  }

  /** Stops the countdown timers for a session (no-op when none are running). */
  function cancelCountdown(sessionID: string): void {
    const state = sessions.get(sessionID)
    if (!state) return

    if (state.countdownTimer) {
      clearTimeout(state.countdownTimer)
      state.countdownTimer = undefined
    }
    if (state.countdownInterval) {
      clearInterval(state.countdownInterval)
      state.countdownInterval = undefined
    }
    state.countdownStartedAt = undefined
  }

  /** Cancels timers and forgets all state for a session. */
  function cleanup(sessionID: string): void {
    cancelCountdown(sessionID)
    sessions.delete(sessionID)
  }

  /** Records session activity: drops any pending abort marker and cancels the countdown. */
  function noteActivity(sessionID: string): void {
    const state = sessions.get(sessionID)
    if (state) state.abortDetectedAt = undefined
    cancelCountdown(sessionID)
  }

  /** True when a background task spawned from this session is still running. */
  function hasRunningBackgroundTasks(sessionID: string): boolean {
    return backgroundManager?.getTasksByParentSession(sessionID).some(t => t.status === "running") ?? false
  }

  /** A missing tool map counts as writable; an explicit `false`/"deny" on edit or write does not. */
  function hasWritePermission(tools: Record<string, ToolPermission> | undefined): boolean {
    if (!tools) return true
    const isDenied = (permission: ToolPermission | undefined): boolean => permission === false || permission === "deny"
    return !isDenied(tools.edit) && !isDenied(tools.write)
  }

  const markRecovering = (sessionID: string): void => {
    const state = getState(sessionID)
    state.isRecovering = true
    cancelCountdown(sessionID)
    log(`[${HOOK_NAME}] Session marked as recovering`, { sessionID })
  }

  const markRecoveryComplete = (sessionID: string): void => {
    const state = sessions.get(sessionID)
    if (state) {
      state.isRecovering = false
      log(`[${HOOK_NAME}] Session recovery complete`, { sessionID })
    }
  }

  /** Best-effort toast; failures are deliberately ignored. */
  async function showCountdownToast(seconds: number, incompleteCount: number): Promise<void> {
    await ctx.client.tui
      .showToast({
        body: {
          title: "Task Continuation",
          message: `Resuming in ${seconds}s... (${incompleteCount} tasks remaining)`,
          variant: "warning" as const,
          duration: TOAST_DURATION_MS,
        },
      })
      .catch(() => {})
  }

  /** Fetches the session's messages, scoped to `ctx.directory` like every other API call here. */
  async function fetchSessionMessages(sessionID: string): Promise<SessionMessage[]> {
    const response = await ctx.client.session.messages({
      path: { id: sessionID },
      query: { directory: ctx.directory },
    })
    const data = (response as { data?: unknown }).data
    return Array.isArray(data) ? (data as SessionMessage[]) : []
  }

  /**
   * Walks messages newest-first, skipping "compaction" messages, and returns the
   * agent/model/tools of the first message that carries agent or model info.
   */
  function resolveMessageInfo(messages: SessionMessage[]): {
    resolvedInfo?: ResolvedMessageInfo
    hasCompactionMessage: boolean
  } {
    let hasCompactionMessage = false
    for (let i = messages.length - 1; i >= 0; i--) {
      const info = messages[i]?.info
      if (info?.agent === "compaction") {
        hasCompactionMessage = true
        continue
      }
      if (info?.agent || info?.model || (info?.modelID && info?.providerID)) {
        return {
          resolvedInfo: {
            agent: info.agent,
            model:
              info.model ??
              (info.providerID && info.modelID ? { providerID: info.providerID, modelID: info.modelID } : undefined),
            tools: info.tools,
          },
          hasCompactionMessage,
        }
      }
    }
    return { hasCompactionMessage }
  }

  /**
   * Injects the continuation prompt once the countdown has elapsed.
   * Every skip condition is re-evaluated because state may have changed during the countdown.
   */
  async function injectContinuation(sessionID: string, resolvedInfo?: ResolvedMessageInfo): Promise<void> {
    const state = sessions.get(sessionID)

    if (state?.isRecovering) {
      log(`[${HOOK_NAME}] Skipped injection: in recovery`, { sessionID })
      return
    }

    // The user may have stopped continuation while the countdown was running.
    if (isContinuationStopped?.(sessionID)) {
      log(`[${HOOK_NAME}] Skipped injection: continuation stopped for session`, { sessionID })
      return
    }

    if (hasRunningBackgroundTasks(sessionID)) {
      log(`[${HOOK_NAME}] Skipped injection: background tasks running`, { sessionID })
      return
    }

    // Re-read from disk: tasks may have been completed during the countdown.
    const tasks = loadTasksFromDisk(config)
    const incompleteTasks = tasks.filter(t => t.status !== "completed" && t.status !== "deleted")
    const remainingCount = incompleteTasks.length
    if (remainingCount === 0) {
      log(`[${HOOK_NAME}] Skipped injection: no incomplete tasks`, { sessionID })
      return
    }

    let agentName = resolvedInfo?.agent
    let model = resolvedInfo?.model
    let tools = resolvedInfo?.tools

    // Fall back to the nearest stored message on disk for anything the API did not provide.
    if (!agentName || !model) {
      const messageDir = getMessageDir(sessionID)
      const prevMessage = messageDir ? findNearestMessageWithFields(messageDir) : null
      agentName = agentName ?? prevMessage?.agent
      model =
        model ??
        (prevMessage?.model?.providerID && prevMessage?.model?.modelID
          ? {
              providerID: prevMessage.model.providerID,
              modelID: prevMessage.model.modelID,
              ...(prevMessage.model.variant ? { variant: prevMessage.model.variant } : {}),
            }
          : undefined)
      tools = tools ?? prevMessage?.tools
    }

    if (agentName && skipAgents.includes(agentName)) {
      log(`[${HOOK_NAME}] Skipped: agent in skipAgents list`, { sessionID, agent: agentName })
      return
    }

    if (!hasWritePermission(tools)) {
      log(`[${HOOK_NAME}] Skipped: agent lacks write permission`, { sessionID, agent: agentName })
      return
    }

    const taskList = incompleteTasks.map(t => `- [${t.status}] ${t.subject}`).join("\n")
    const prompt = `${CONTINUATION_PROMPT}

[Status: ${tasks.length - remainingCount}/${tasks.length} completed, ${remainingCount} remaining]

Remaining tasks:
${taskList}`

    try {
      log(`[${HOOK_NAME}] Injecting continuation`, {
        sessionID,
        agent: agentName,
        model,
        incompleteCount: remainingCount,
      })

      await ctx.client.session.prompt({
        path: { id: sessionID },
        body: {
          agent: agentName,
          ...(model !== undefined ? { model } : {}),
          parts: [{ type: "text", text: prompt }],
        },
        query: { directory: ctx.directory },
      })

      log(`[${HOOK_NAME}] Injection successful`, { sessionID })
    } catch (err) {
      log(`[${HOOK_NAME}] Injection failed`, { sessionID, error: String(err) })
    }
  }

  /** Shows a per-second countdown toast, then attempts injection after COUNTDOWN_SECONDS. */
  function startCountdown(sessionID: string, incompleteCount: number, resolvedInfo?: ResolvedMessageInfo): void {
    const state = getState(sessionID)
    cancelCountdown(sessionID)

    let secondsRemaining = COUNTDOWN_SECONDS
    void showCountdownToast(secondsRemaining, incompleteCount)
    state.countdownStartedAt = Date.now()

    state.countdownInterval = setInterval(() => {
      secondsRemaining--
      if (secondsRemaining > 0) {
        void showCountdownToast(secondsRemaining, incompleteCount)
      }
    }, 1000)

    state.countdownTimer = setTimeout(() => {
      cancelCountdown(sessionID)
      // Runs detached from any caller: route unexpected errors (e.g. disk reads) to the
      // log instead of leaving an unhandled promise rejection.
      void injectContinuation(sessionID, resolvedInfo).catch(err => {
        log(`[${HOOK_NAME}] Injection failed`, { sessionID, error: String(err) })
      })
    }, COUNTDOWN_SECONDS * 1000)

    log(`[${HOOK_NAME}] Countdown started`, { sessionID, seconds: COUNTDOWN_SECONDS, incompleteCount })
  }

  /** Records aborts (for the idle check) and cancels any pending countdown. */
  function handleSessionError(props: EventProps): void {
    const sessionID = props?.sessionID as string | undefined
    if (!sessionID) return

    const error = props?.error as { name?: string } | undefined
    if (isAbortErrorName(error?.name)) {
      getState(sessionID).abortDetectedAt = Date.now()
      log(`[${HOOK_NAME}] Abort detected via session.error`, { sessionID, errorName: error?.name })
    }

    cancelCountdown(sessionID)
    log(`[${HOOK_NAME}] session.error`, { sessionID })
  }

  /** Decides whether an idle session should be continued and, if so, starts the countdown. */
  async function handleSessionIdle(props: EventProps): Promise<void> {
    const sessionID = props?.sessionID as string | undefined
    if (!sessionID) return

    log(`[${HOOK_NAME}] session.idle`, { sessionID })

    const mainSessionID = getMainSessionID()
    const isMainSession = sessionID === mainSessionID
    const isBackgroundTaskSession = subagentSessions.has(sessionID)
    if (mainSessionID && !isMainSession && !isBackgroundTaskSession) {
      log(`[${HOOK_NAME}] Skipped: not main or background task session`, { sessionID })
      return
    }

    const state = getState(sessionID)
    if (state.isRecovering) {
      log(`[${HOOK_NAME}] Skipped: in recovery`, { sessionID })
      return
    }

    // Check 1: event-based abort detection (primary, most reliable). The marker is consumed either way.
    if (state.abortDetectedAt !== undefined) {
      const timeSinceAbort = Date.now() - state.abortDetectedAt
      state.abortDetectedAt = undefined
      if (timeSinceAbort < ABORT_WINDOW_MS) {
        log(`[${HOOK_NAME}] Skipped: abort detected via event ${timeSinceAbort}ms ago`, { sessionID })
        return
      }
    }

    if (hasRunningBackgroundTasks(sessionID)) {
      log(`[${HOOK_NAME}] Skipped: background tasks running`, { sessionID })
      return
    }

    // Fetched once and reused for both the abort fallback and agent/model resolution.
    let messages: SessionMessage[] = []
    try {
      messages = await fetchSessionMessages(sessionID)
    } catch (err) {
      log(`[${HOOK_NAME}] Messages fetch failed, continuing`, { sessionID, error: String(err) })
    }

    // Check 2: API-based abort detection (fallback, for cases where the event was missed)
    if (isLastAssistantMessageAborted(messages)) {
      log(`[${HOOK_NAME}] Skipped: last assistant message was aborted (API fallback)`, { sessionID })
      return
    }

    const tasks = loadTasksFromDisk(config)
    if (tasks.length === 0) {
      log(`[${HOOK_NAME}] No tasks`, { sessionID })
      return
    }

    const incompleteCount = getIncompleteCount(tasks)
    if (incompleteCount === 0) {
      log(`[${HOOK_NAME}] All tasks complete`, { sessionID, total: tasks.length })
      return
    }

    const { resolvedInfo, hasCompactionMessage } = resolveMessageInfo(messages)

    log(`[${HOOK_NAME}] Agent check`, {
      sessionID,
      agentName: resolvedInfo?.agent,
      skipAgents,
      hasCompactionMessage,
    })

    if (resolvedInfo?.agent && skipAgents.includes(resolvedInfo.agent)) {
      log(`[${HOOK_NAME}] Skipped: agent in skipAgents list`, { sessionID, agent: resolvedInfo.agent })
      return
    }
    if (hasCompactionMessage && !resolvedInfo?.agent) {
      log(`[${HOOK_NAME}] Skipped: compaction occurred but no agent info resolved`, { sessionID })
      return
    }

    if (isContinuationStopped?.(sessionID)) {
      log(`[${HOOK_NAME}] Skipped: continuation stopped for session`, { sessionID })
      return
    }

    startCountdown(sessionID, incompleteCount, resolvedInfo)
  }

  /** User or assistant messages count as activity; user messages inside the grace period are ignored. */
  function handleMessageUpdated(props: EventProps): void {
    const info = props?.info as Record<string, unknown> | undefined
    const sessionID = info?.sessionID as string | undefined
    const role = info?.role as string | undefined
    if (!sessionID) return

    if (role === "user") {
      const startedAt = sessions.get(sessionID)?.countdownStartedAt
      if (startedAt !== undefined) {
        const elapsed = Date.now() - startedAt
        if (elapsed < COUNTDOWN_GRACE_PERIOD_MS) {
          log(`[${HOOK_NAME}] Ignoring user message in grace period`, { sessionID, elapsed })
          return
        }
      }
      noteActivity(sessionID)
    } else if (role === "assistant") {
      noteActivity(sessionID)
    }
  }

  /**
   * Assistant part updates count as activity.
   * NOTE(review): verify this event carries `properties.info`; if its payload is
   * `properties.part` instead, this branch never fires.
   */
  function handleMessagePartUpdated(props: EventProps): void {
    const info = props?.info as Record<string, unknown> | undefined
    const sessionID = info?.sessionID as string | undefined
    const role = info?.role as string | undefined
    if (sessionID && role === "assistant") {
      noteActivity(sessionID)
    }
  }

  /** Tool execution (before or after) counts as activity. */
  function handleToolExecution(props: EventProps): void {
    const sessionID = props?.sessionID as string | undefined
    if (sessionID) {
      noteActivity(sessionID)
    }
  }

  /** Drops all state for a deleted session. */
  function handleSessionDeleted(props: EventProps): void {
    const sessionInfo = props?.info as { id?: string } | undefined
    if (sessionInfo?.id) {
      cleanup(sessionInfo.id)
      log(`[${HOOK_NAME}] Session deleted: cleaned up`, { sessionID: sessionInfo.id })
    }
  }

  const handler = async ({ event }: { event: { type: string; properties?: unknown } }): Promise<void> => {
    const props = event.properties as EventProps

    switch (event.type) {
      case "session.error":
        handleSessionError(props)
        return
      case "session.idle":
        await handleSessionIdle(props)
        return
      case "message.updated":
        handleMessageUpdated(props)
        return
      case "message.part.updated":
        handleMessagePartUpdated(props)
        return
      case "tool.execute.before":
      case "tool.execute.after":
        handleToolExecution(props)
        return
      case "session.deleted":
        handleSessionDeleted(props)
        return
    }
  }

  const cancelAllCountdowns = (): void => {
    for (const sessionID of sessions.keys()) {
      cancelCountdown(sessionID)
    }
    log(`[${HOOK_NAME}] All countdowns cancelled`)
  }

  return {
    handler,
    markRecovering,
    markRecoveryComplete,
    cancelAllCountdowns,
  }
}
|