fix(run): inherit main-session tool permissions for continuation prompts
This commit is contained in:
parent
7622eddb0d
commit
096db59399
@ -6,7 +6,14 @@ import type {
|
|||||||
ResumeInput,
|
ResumeInput,
|
||||||
} from "./types"
|
} from "./types"
|
||||||
import { TaskHistory } from "./task-history"
|
import { TaskHistory } from "./task-history"
|
||||||
import { log, getAgentToolRestrictions, normalizeSDKResponse, promptWithModelSuggestionRetry } from "../../shared"
|
import {
|
||||||
|
log,
|
||||||
|
getAgentToolRestrictions,
|
||||||
|
normalizePromptTools,
|
||||||
|
normalizeSDKResponse,
|
||||||
|
promptWithModelSuggestionRetry,
|
||||||
|
resolveInheritedPromptTools,
|
||||||
|
} from "../../shared"
|
||||||
import { setSessionTools } from "../../shared/session-tools-store"
|
import { setSessionTools } from "../../shared/session-tools-store"
|
||||||
import { ConcurrencyManager } from "./concurrency"
|
import { ConcurrencyManager } from "./concurrency"
|
||||||
import type { BackgroundTaskConfig, TmuxConfig } from "../../config/schema"
|
import type { BackgroundTaskConfig, TmuxConfig } from "../../config/schema"
|
||||||
@ -1246,12 +1253,19 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
|||||||
|
|
||||||
let agent: string | undefined = task.parentAgent
|
let agent: string | undefined = task.parentAgent
|
||||||
let model: { providerID: string; modelID: string } | undefined
|
let model: { providerID: string; modelID: string } | undefined
|
||||||
|
let tools: Record<string, boolean> | undefined = task.parentTools
|
||||||
|
|
||||||
if (this.enableParentSessionNotifications) {
|
if (this.enableParentSessionNotifications) {
|
||||||
try {
|
try {
|
||||||
const messagesResp = await this.client.session.messages({ path: { id: task.parentSessionID } })
|
const messagesResp = await this.client.session.messages({ path: { id: task.parentSessionID } })
|
||||||
const messages = normalizeSDKResponse(messagesResp, [] as Array<{
|
const messages = normalizeSDKResponse(messagesResp, [] as Array<{
|
||||||
info?: { agent?: string; model?: { providerID: string; modelID: string }; modelID?: string; providerID?: string }
|
info?: {
|
||||||
|
agent?: string
|
||||||
|
model?: { providerID: string; modelID: string }
|
||||||
|
modelID?: string
|
||||||
|
providerID?: string
|
||||||
|
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||||
|
}
|
||||||
}>)
|
}>)
|
||||||
for (let i = messages.length - 1; i >= 0; i--) {
|
for (let i = messages.length - 1; i >= 0; i--) {
|
||||||
const info = messages[i].info
|
const info = messages[i].info
|
||||||
@ -1261,6 +1275,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
|||||||
if (info?.agent || info?.model || (info?.modelID && info?.providerID)) {
|
if (info?.agent || info?.model || (info?.modelID && info?.providerID)) {
|
||||||
agent = info.agent ?? task.parentAgent
|
agent = info.agent ?? task.parentAgent
|
||||||
model = info.model ?? (info.providerID && info.modelID ? { providerID: info.providerID, modelID: info.modelID } : undefined)
|
model = info.model ?? (info.providerID && info.modelID ? { providerID: info.providerID, modelID: info.modelID } : undefined)
|
||||||
|
tools = normalizePromptTools(info.tools) ?? tools
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1277,8 +1292,11 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
|||||||
model = currentMessage?.model?.providerID && currentMessage?.model?.modelID
|
model = currentMessage?.model?.providerID && currentMessage?.model?.modelID
|
||||||
? { providerID: currentMessage.model.providerID, modelID: currentMessage.model.modelID }
|
? { providerID: currentMessage.model.providerID, modelID: currentMessage.model.modelID }
|
||||||
: undefined
|
: undefined
|
||||||
|
tools = normalizePromptTools(currentMessage?.tools) ?? tools
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tools = resolveInheritedPromptTools(task.parentSessionID, tools)
|
||||||
|
|
||||||
log("[background-agent] notifyParentSession context:", {
|
log("[background-agent] notifyParentSession context:", {
|
||||||
taskId: task.id,
|
taskId: task.id,
|
||||||
resolvedAgent: agent,
|
resolvedAgent: agent,
|
||||||
@ -1292,7 +1310,7 @@ Use \`background_output(task_id="${task.id}")\` to retrieve this result when rea
|
|||||||
noReply: !allComplete,
|
noReply: !allComplete,
|
||||||
...(agent !== undefined ? { agent } : {}),
|
...(agent !== undefined ? { agent } : {}),
|
||||||
...(model !== undefined ? { model } : {}),
|
...(model !== undefined ? { model } : {}),
|
||||||
...(task.parentTools ? { tools: task.parentTools } : {}),
|
...(tools ? { tools } : {}),
|
||||||
parts: [{ type: "text", text: notification }],
|
parts: [{ type: "text", text: notification }],
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import type { OpencodeClient } from "./constants"
|
|||||||
import type { BackgroundTask } from "./types"
|
import type { BackgroundTask } from "./types"
|
||||||
import { findNearestMessageWithFields } from "../hook-message-injector"
|
import { findNearestMessageWithFields } from "../hook-message-injector"
|
||||||
import { getMessageDir } from "../../shared"
|
import { getMessageDir } from "../../shared"
|
||||||
|
import { normalizePromptTools, resolveInheritedPromptTools } from "../../shared"
|
||||||
|
|
||||||
type AgentModel = { providerID: string; modelID: string }
|
type AgentModel = { providerID: string; modelID: string }
|
||||||
|
|
||||||
@ -12,6 +13,7 @@ function isObject(value: unknown): value is Record<string, unknown> {
|
|||||||
function extractAgentAndModelFromMessage(message: unknown): {
|
function extractAgentAndModelFromMessage(message: unknown): {
|
||||||
agent?: string
|
agent?: string
|
||||||
model?: AgentModel
|
model?: AgentModel
|
||||||
|
tools?: Record<string, boolean>
|
||||||
} {
|
} {
|
||||||
if (!isObject(message)) return {}
|
if (!isObject(message)) return {}
|
||||||
const info = message["info"]
|
const info = message["info"]
|
||||||
@ -19,31 +21,33 @@ function extractAgentAndModelFromMessage(message: unknown): {
|
|||||||
|
|
||||||
const agent = typeof info["agent"] === "string" ? info["agent"] : undefined
|
const agent = typeof info["agent"] === "string" ? info["agent"] : undefined
|
||||||
const modelObj = info["model"]
|
const modelObj = info["model"]
|
||||||
|
const tools = normalizePromptTools(isObject(info["tools"]) ? info["tools"] as Record<string, unknown> as Record<string, boolean | "allow" | "deny" | "ask"> : undefined)
|
||||||
if (isObject(modelObj)) {
|
if (isObject(modelObj)) {
|
||||||
const providerID = modelObj["providerID"]
|
const providerID = modelObj["providerID"]
|
||||||
const modelID = modelObj["modelID"]
|
const modelID = modelObj["modelID"]
|
||||||
if (typeof providerID === "string" && typeof modelID === "string") {
|
if (typeof providerID === "string" && typeof modelID === "string") {
|
||||||
return { agent, model: { providerID, modelID } }
|
return { agent, model: { providerID, modelID }, tools }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const providerID = info["providerID"]
|
const providerID = info["providerID"]
|
||||||
const modelID = info["modelID"]
|
const modelID = info["modelID"]
|
||||||
if (typeof providerID === "string" && typeof modelID === "string") {
|
if (typeof providerID === "string" && typeof modelID === "string") {
|
||||||
return { agent, model: { providerID, modelID } }
|
return { agent, model: { providerID, modelID }, tools }
|
||||||
}
|
}
|
||||||
|
|
||||||
return { agent }
|
return { agent, tools }
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function resolveParentSessionAgentAndModel(input: {
|
export async function resolveParentSessionAgentAndModel(input: {
|
||||||
client: OpencodeClient
|
client: OpencodeClient
|
||||||
task: BackgroundTask
|
task: BackgroundTask
|
||||||
}): Promise<{ agent?: string; model?: AgentModel }> {
|
}): Promise<{ agent?: string; model?: AgentModel; tools?: Record<string, boolean> }> {
|
||||||
const { client, task } = input
|
const { client, task } = input
|
||||||
|
|
||||||
let agent: string | undefined = task.parentAgent
|
let agent: string | undefined = task.parentAgent
|
||||||
let model: AgentModel | undefined
|
let model: AgentModel | undefined
|
||||||
|
let tools: Record<string, boolean> | undefined = task.parentTools
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const messagesResp = await client.session.messages({
|
const messagesResp = await client.session.messages({
|
||||||
@ -55,9 +59,10 @@ export async function resolveParentSessionAgentAndModel(input: {
|
|||||||
|
|
||||||
for (let i = messages.length - 1; i >= 0; i--) {
|
for (let i = messages.length - 1; i >= 0; i--) {
|
||||||
const extracted = extractAgentAndModelFromMessage(messages[i])
|
const extracted = extractAgentAndModelFromMessage(messages[i])
|
||||||
if (extracted.agent || extracted.model) {
|
if (extracted.agent || extracted.model || extracted.tools) {
|
||||||
agent = extracted.agent ?? task.parentAgent
|
agent = extracted.agent ?? task.parentAgent
|
||||||
model = extracted.model
|
model = extracted.model
|
||||||
|
tools = extracted.tools ?? tools
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -69,7 +74,8 @@ export async function resolveParentSessionAgentAndModel(input: {
|
|||||||
currentMessage?.model?.providerID && currentMessage?.model?.modelID
|
currentMessage?.model?.providerID && currentMessage?.model?.modelID
|
||||||
? { providerID: currentMessage.model.providerID, modelID: currentMessage.model.modelID }
|
? { providerID: currentMessage.model.providerID, modelID: currentMessage.model.modelID }
|
||||||
: undefined
|
: undefined
|
||||||
|
tools = normalizePromptTools(currentMessage?.tools) ?? tools
|
||||||
}
|
}
|
||||||
|
|
||||||
return { agent, model }
|
return { agent, model, tools: resolveInheritedPromptTools(task.parentSessionID, tools) }
|
||||||
}
|
}
|
||||||
|
|||||||
@ -56,7 +56,7 @@ export async function notifyParentSession(
|
|||||||
completedTasks,
|
completedTasks,
|
||||||
})
|
})
|
||||||
|
|
||||||
const { agent, model } = await resolveParentSessionAgentAndModel({ client, task })
|
const { agent, model, tools } = await resolveParentSessionAgentAndModel({ client, task })
|
||||||
|
|
||||||
log("[background-agent] notifyParentSession context:", {
|
log("[background-agent] notifyParentSession context:", {
|
||||||
taskId: task.id,
|
taskId: task.id,
|
||||||
@ -71,7 +71,7 @@ export async function notifyParentSession(
|
|||||||
noReply: !allComplete,
|
noReply: !allComplete,
|
||||||
...(agent !== undefined ? { agent } : {}),
|
...(agent !== undefined ? { agent } : {}),
|
||||||
...(model !== undefined ? { model } : {}),
|
...(model !== undefined ? { model } : {}),
|
||||||
...(task.parentTools ? { tools: task.parentTools } : {}),
|
...(tools ? { tools } : {}),
|
||||||
parts: [{ type: "text", text: notification }],
|
parts: [{ type: "text", text: notification }],
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import type { Client } from "./client"
|
|||||||
import { clearSessionState } from "./state"
|
import { clearSessionState } from "./state"
|
||||||
import { formatBytes } from "./message-builder"
|
import { formatBytes } from "./message-builder"
|
||||||
import { log } from "../../shared/logger"
|
import { log } from "../../shared/logger"
|
||||||
|
import { resolveInheritedPromptTools } from "../../shared"
|
||||||
|
|
||||||
export async function runAggressiveTruncationStrategy(params: {
|
export async function runAggressiveTruncationStrategy(params: {
|
||||||
sessionID: string
|
sessionID: string
|
||||||
@ -61,9 +62,13 @@ export async function runAggressiveTruncationStrategy(params: {
|
|||||||
clearSessionState(params.autoCompactState, params.sessionID)
|
clearSessionState(params.autoCompactState, params.sessionID)
|
||||||
setTimeout(async () => {
|
setTimeout(async () => {
|
||||||
try {
|
try {
|
||||||
|
const inheritedTools = resolveInheritedPromptTools(params.sessionID)
|
||||||
await params.client.session.promptAsync({
|
await params.client.session.promptAsync({
|
||||||
path: { id: params.sessionID },
|
path: { id: params.sessionID },
|
||||||
body: { auto: true } as never,
|
body: {
|
||||||
|
auto: true,
|
||||||
|
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||||
|
} as never,
|
||||||
query: { directory: params.directory },
|
query: { directory: params.directory },
|
||||||
})
|
})
|
||||||
} catch {}
|
} catch {}
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
import type { PluginInput } from "@opencode-ai/plugin"
|
import type { PluginInput } from "@opencode-ai/plugin"
|
||||||
import type { BackgroundManager } from "../../features/background-agent"
|
import type { BackgroundManager } from "../../features/background-agent"
|
||||||
import { log } from "../../shared/logger"
|
import { log } from "../../shared/logger"
|
||||||
|
import { resolveInheritedPromptTools } from "../../shared"
|
||||||
import { HOOK_NAME } from "./hook-name"
|
import { HOOK_NAME } from "./hook-name"
|
||||||
import { BOULDER_CONTINUATION_PROMPT } from "./system-reminder-templates"
|
import { BOULDER_CONTINUATION_PROMPT } from "./system-reminder-templates"
|
||||||
import { resolveRecentModelForSession } from "./recent-model-resolver"
|
import { resolveRecentPromptContextForSession } from "./recent-model-resolver"
|
||||||
import type { SessionState } from "./types"
|
import type { SessionState } from "./types"
|
||||||
|
|
||||||
export async function injectBoulderContinuation(input: {
|
export async function injectBoulderContinuation(input: {
|
||||||
@ -43,13 +44,15 @@ export async function injectBoulderContinuation(input: {
|
|||||||
try {
|
try {
|
||||||
log(`[${HOOK_NAME}] Injecting boulder continuation`, { sessionID, planName, remaining })
|
log(`[${HOOK_NAME}] Injecting boulder continuation`, { sessionID, planName, remaining })
|
||||||
|
|
||||||
const model = await resolveRecentModelForSession(ctx, sessionID)
|
const promptContext = await resolveRecentPromptContextForSession(ctx, sessionID)
|
||||||
|
const inheritedTools = resolveInheritedPromptTools(sessionID, promptContext.tools)
|
||||||
|
|
||||||
await ctx.client.session.promptAsync({
|
await ctx.client.session.promptAsync({
|
||||||
path: { id: sessionID },
|
path: { id: sessionID },
|
||||||
body: {
|
body: {
|
||||||
agent: agent ?? "atlas",
|
agent: agent ?? "atlas",
|
||||||
...(model !== undefined ? { model } : {}),
|
...(promptContext.model !== undefined ? { model: promptContext.model } : {}),
|
||||||
|
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||||
parts: [{ type: "text", text: prompt }],
|
parts: [{ type: "text", text: prompt }],
|
||||||
},
|
},
|
||||||
query: { directory: ctx.directory },
|
query: { directory: ctx.directory },
|
||||||
|
|||||||
@ -3,28 +3,39 @@ import {
|
|||||||
findNearestMessageWithFields,
|
findNearestMessageWithFields,
|
||||||
findNearestMessageWithFieldsFromSDK,
|
findNearestMessageWithFieldsFromSDK,
|
||||||
} from "../../features/hook-message-injector"
|
} from "../../features/hook-message-injector"
|
||||||
import { getMessageDir, isSqliteBackend, normalizeSDKResponse } from "../../shared"
|
import { getMessageDir, isSqliteBackend, normalizePromptTools, normalizeSDKResponse } from "../../shared"
|
||||||
import type { ModelInfo } from "./types"
|
import type { ModelInfo } from "./types"
|
||||||
|
|
||||||
export async function resolveRecentModelForSession(
|
type PromptContext = {
|
||||||
|
model?: ModelInfo
|
||||||
|
tools?: Record<string, boolean>
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function resolveRecentPromptContextForSession(
|
||||||
ctx: PluginInput,
|
ctx: PluginInput,
|
||||||
sessionID: string
|
sessionID: string
|
||||||
): Promise<ModelInfo | undefined> {
|
): Promise<PromptContext> {
|
||||||
try {
|
try {
|
||||||
const messagesResp = await ctx.client.session.messages({ path: { id: sessionID } })
|
const messagesResp = await ctx.client.session.messages({ path: { id: sessionID } })
|
||||||
const messages = normalizeSDKResponse(messagesResp, [] as Array<{
|
const messages = normalizeSDKResponse(messagesResp, [] as Array<{
|
||||||
info?: { model?: ModelInfo; modelID?: string; providerID?: string }
|
info?: {
|
||||||
|
model?: ModelInfo
|
||||||
|
modelID?: string
|
||||||
|
providerID?: string
|
||||||
|
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||||
|
}
|
||||||
}>)
|
}>)
|
||||||
|
|
||||||
for (let i = messages.length - 1; i >= 0; i--) {
|
for (let i = messages.length - 1; i >= 0; i--) {
|
||||||
const info = messages[i].info
|
const info = messages[i].info
|
||||||
const model = info?.model
|
const model = info?.model
|
||||||
|
const tools = normalizePromptTools(info?.tools)
|
||||||
if (model?.providerID && model?.modelID) {
|
if (model?.providerID && model?.modelID) {
|
||||||
return { providerID: model.providerID, modelID: model.modelID }
|
return { model: { providerID: model.providerID, modelID: model.modelID }, tools }
|
||||||
}
|
}
|
||||||
|
|
||||||
if (info?.providerID && info?.modelID) {
|
if (info?.providerID && info?.modelID) {
|
||||||
return { providerID: info.providerID, modelID: info.modelID }
|
return { model: { providerID: info.providerID, modelID: info.modelID }, tools }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
@ -39,8 +50,17 @@ export async function resolveRecentModelForSession(
|
|||||||
currentMessage = messageDir ? findNearestMessageWithFields(messageDir) : null
|
currentMessage = messageDir ? findNearestMessageWithFields(messageDir) : null
|
||||||
}
|
}
|
||||||
const model = currentMessage?.model
|
const model = currentMessage?.model
|
||||||
|
const tools = normalizePromptTools(currentMessage?.tools)
|
||||||
if (!model?.providerID || !model?.modelID) {
|
if (!model?.providerID || !model?.modelID) {
|
||||||
return undefined
|
return { tools }
|
||||||
}
|
}
|
||||||
return { providerID: model.providerID, modelID: model.modelID }
|
return { model: { providerID: model.providerID, modelID: model.modelID }, tools }
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function resolveRecentModelForSession(
|
||||||
|
ctx: PluginInput,
|
||||||
|
sessionID: string
|
||||||
|
): Promise<ModelInfo | undefined> {
|
||||||
|
const context = await resolveRecentPromptContextForSession(ctx, sessionID)
|
||||||
|
return context.model
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,13 +3,14 @@ import { log } from "../../shared/logger"
|
|||||||
import { findNearestMessageWithFields } from "../../features/hook-message-injector"
|
import { findNearestMessageWithFields } from "../../features/hook-message-injector"
|
||||||
import { getMessageDir } from "./message-storage-directory"
|
import { getMessageDir } from "./message-storage-directory"
|
||||||
import { withTimeout } from "./with-timeout"
|
import { withTimeout } from "./with-timeout"
|
||||||
import { normalizeSDKResponse } from "../../shared"
|
import { normalizeSDKResponse, resolveInheritedPromptTools } from "../../shared"
|
||||||
|
|
||||||
type MessageInfo = {
|
type MessageInfo = {
|
||||||
agent?: string
|
agent?: string
|
||||||
model?: { providerID: string; modelID: string }
|
model?: { providerID: string; modelID: string }
|
||||||
modelID?: string
|
modelID?: string
|
||||||
providerID?: string
|
providerID?: string
|
||||||
|
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function injectContinuationPrompt(
|
export async function injectContinuationPrompt(
|
||||||
@ -18,6 +19,7 @@ export async function injectContinuationPrompt(
|
|||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
let agent: string | undefined
|
let agent: string | undefined
|
||||||
let model: { providerID: string; modelID: string } | undefined
|
let model: { providerID: string; modelID: string } | undefined
|
||||||
|
let tools: Record<string, boolean | "allow" | "deny" | "ask"> | undefined
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const messagesResp = await withTimeout(
|
const messagesResp = await withTimeout(
|
||||||
@ -36,6 +38,7 @@ export async function injectContinuationPrompt(
|
|||||||
(info.providerID && info.modelID
|
(info.providerID && info.modelID
|
||||||
? { providerID: info.providerID, modelID: info.modelID }
|
? { providerID: info.providerID, modelID: info.modelID }
|
||||||
: undefined)
|
: undefined)
|
||||||
|
tools = info.tools
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -50,13 +53,17 @@ export async function injectContinuationPrompt(
|
|||||||
modelID: currentMessage.model.modelID,
|
modelID: currentMessage.model.modelID,
|
||||||
}
|
}
|
||||||
: undefined
|
: undefined
|
||||||
|
tools = currentMessage?.tools
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const inheritedTools = resolveInheritedPromptTools(options.sessionID, tools)
|
||||||
|
|
||||||
await ctx.client.session.promptAsync({
|
await ctx.client.session.promptAsync({
|
||||||
path: { id: options.sessionID },
|
path: { id: options.sessionID },
|
||||||
body: {
|
body: {
|
||||||
...(agent !== undefined ? { agent } : {}),
|
...(agent !== undefined ? { agent } : {}),
|
||||||
...(model !== undefined ? { model } : {}),
|
...(model !== undefined ? { model } : {}),
|
||||||
|
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||||
parts: [{ type: "text", text: options.prompt }],
|
parts: [{ type: "text", text: options.prompt }],
|
||||||
},
|
},
|
||||||
query: { directory: options.directory },
|
query: { directory: options.directory },
|
||||||
|
|||||||
48
src/hooks/session-recovery/resume.test.ts
Normal file
48
src/hooks/session-recovery/resume.test.ts
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
declare const require: (name: string) => any
|
||||||
|
const { describe, expect, test } = require("bun:test")
|
||||||
|
import { extractResumeConfig, resumeSession } from "./resume"
|
||||||
|
import type { MessageData } from "./types"
|
||||||
|
|
||||||
|
describe("session-recovery resume", () => {
|
||||||
|
test("extractResumeConfig carries tools from last user message", () => {
|
||||||
|
// given
|
||||||
|
const userMessage: MessageData = {
|
||||||
|
info: {
|
||||||
|
agent: "Hephaestus",
|
||||||
|
model: { providerID: "openai", modelID: "gpt-5.3-codex" },
|
||||||
|
tools: { question: false, bash: true },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// when
|
||||||
|
const config = extractResumeConfig(userMessage, "ses_resume_tools")
|
||||||
|
|
||||||
|
// then
|
||||||
|
expect(config.tools).toEqual({ question: false, bash: true })
|
||||||
|
})
|
||||||
|
|
||||||
|
test("resumeSession sends inherited tools with continuation prompt", async () => {
|
||||||
|
// given
|
||||||
|
let promptBody: Record<string, unknown> | undefined
|
||||||
|
const client = {
|
||||||
|
session: {
|
||||||
|
promptAsync: async (input: { body: Record<string, unknown> }) => {
|
||||||
|
promptBody = input.body
|
||||||
|
return {}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// when
|
||||||
|
const ok = await resumeSession(client as never, {
|
||||||
|
sessionID: "ses_resume_prompt",
|
||||||
|
agent: "Hephaestus",
|
||||||
|
model: { providerID: "openai", modelID: "gpt-5.3-codex" },
|
||||||
|
tools: { question: false, bash: true },
|
||||||
|
})
|
||||||
|
|
||||||
|
// then
|
||||||
|
expect(ok).toBe(true)
|
||||||
|
expect(promptBody?.tools).toEqual({ question: false, bash: true })
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -1,5 +1,6 @@
|
|||||||
import type { createOpencodeClient } from "@opencode-ai/sdk"
|
import type { createOpencodeClient } from "@opencode-ai/sdk"
|
||||||
import type { MessageData, ResumeConfig } from "./types"
|
import type { MessageData, ResumeConfig } from "./types"
|
||||||
|
import { resolveInheritedPromptTools } from "../../shared"
|
||||||
|
|
||||||
const RECOVERY_RESUME_TEXT = "[session recovered - continuing previous task]"
|
const RECOVERY_RESUME_TEXT = "[session recovered - continuing previous task]"
|
||||||
|
|
||||||
@ -19,17 +20,20 @@ export function extractResumeConfig(userMessage: MessageData | undefined, sessio
|
|||||||
sessionID,
|
sessionID,
|
||||||
agent: userMessage?.info?.agent,
|
agent: userMessage?.info?.agent,
|
||||||
model: userMessage?.info?.model,
|
model: userMessage?.info?.model,
|
||||||
|
tools: userMessage?.info?.tools,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function resumeSession(client: Client, config: ResumeConfig): Promise<boolean> {
|
export async function resumeSession(client: Client, config: ResumeConfig): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
|
const inheritedTools = resolveInheritedPromptTools(config.sessionID, config.tools)
|
||||||
await client.session.promptAsync({
|
await client.session.promptAsync({
|
||||||
path: { id: config.sessionID },
|
path: { id: config.sessionID },
|
||||||
body: {
|
body: {
|
||||||
parts: [{ type: "text", text: RECOVERY_RESUME_TEXT }],
|
parts: [{ type: "text", text: RECOVERY_RESUME_TEXT }],
|
||||||
agent: config.agent,
|
agent: config.agent,
|
||||||
model: config.model,
|
model: config.model,
|
||||||
|
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
return true
|
return true
|
||||||
|
|||||||
@ -95,4 +95,5 @@ export interface ResumeConfig {
|
|||||||
providerID: string
|
providerID: string
|
||||||
modelID: string
|
modelID: string
|
||||||
}
|
}
|
||||||
|
tools?: Record<string, boolean>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,41 @@
|
|||||||
|
declare const require: (name: string) => any
|
||||||
|
const { describe, expect, test } = require("bun:test")
|
||||||
|
|
||||||
|
import { injectContinuation } from "./continuation-injection"
|
||||||
|
|
||||||
|
describe("injectContinuation", () => {
|
||||||
|
test("inherits tools from resolved message info when reinjecting", async () => {
|
||||||
|
// given
|
||||||
|
let capturedTools: Record<string, boolean> | undefined
|
||||||
|
const ctx = {
|
||||||
|
directory: "/tmp/test",
|
||||||
|
client: {
|
||||||
|
session: {
|
||||||
|
todo: async () => ({ data: [{ id: "1", content: "todo", status: "pending", priority: "high" }] }),
|
||||||
|
promptAsync: async (input: { body: { tools?: Record<string, boolean> } }) => {
|
||||||
|
capturedTools = input.body.tools
|
||||||
|
return {}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
const sessionStateStore = {
|
||||||
|
getExistingState: () => ({ inFlight: false, lastInjectedAt: 0, consecutiveFailures: 0 }),
|
||||||
|
}
|
||||||
|
|
||||||
|
// when
|
||||||
|
await injectContinuation({
|
||||||
|
ctx: ctx as never,
|
||||||
|
sessionID: "ses_continuation_tools",
|
||||||
|
resolvedInfo: {
|
||||||
|
agent: "Hephaestus",
|
||||||
|
model: { providerID: "openai", modelID: "gpt-5.3-codex" },
|
||||||
|
tools: { question: "deny", bash: "allow" },
|
||||||
|
},
|
||||||
|
sessionStateStore: sessionStateStore as never,
|
||||||
|
})
|
||||||
|
|
||||||
|
// then
|
||||||
|
expect(capturedTools).toEqual({ question: false, bash: true })
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -1,7 +1,7 @@
|
|||||||
import type { PluginInput } from "@opencode-ai/plugin"
|
import type { PluginInput } from "@opencode-ai/plugin"
|
||||||
|
|
||||||
import type { BackgroundManager } from "../../features/background-agent"
|
import type { BackgroundManager } from "../../features/background-agent"
|
||||||
import { normalizeSDKResponse } from "../../shared"
|
import { normalizeSDKResponse, resolveInheritedPromptTools } from "../../shared"
|
||||||
import {
|
import {
|
||||||
findNearestMessageWithFields,
|
findNearestMessageWithFields,
|
||||||
findNearestMessageWithFieldsFromSDK,
|
findNearestMessageWithFieldsFromSDK,
|
||||||
@ -136,11 +136,14 @@ ${todoList}`
|
|||||||
incompleteCount: freshIncompleteCount,
|
incompleteCount: freshIncompleteCount,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const inheritedTools = resolveInheritedPromptTools(sessionID, tools)
|
||||||
|
|
||||||
await ctx.client.session.promptAsync({
|
await ctx.client.session.promptAsync({
|
||||||
path: { id: sessionID },
|
path: { id: sessionID },
|
||||||
body: {
|
body: {
|
||||||
agent: agentName,
|
agent: agentName,
|
||||||
...(model !== undefined ? { model } : {}),
|
...(model !== undefined ? { model } : {}),
|
||||||
|
...(inheritedTools ? { tools: inheritedTools } : {}),
|
||||||
parts: [{ type: "text", text: prompt }],
|
parts: [{ type: "text", text: prompt }],
|
||||||
},
|
},
|
||||||
query: { directory: ctx.directory },
|
query: { directory: ctx.directory },
|
||||||
|
|||||||
@ -8,6 +8,7 @@ type MessageInfo = {
|
|||||||
model?: { providerID: string; modelID: string }
|
model?: { providerID: string; modelID: string }
|
||||||
providerID?: string
|
providerID?: string
|
||||||
modelID?: string
|
modelID?: string
|
||||||
|
tools?: Record<string, boolean | "allow" | "deny" | "ask">
|
||||||
}
|
}
|
||||||
|
|
||||||
type MessagePart = {
|
type MessagePart = {
|
||||||
@ -40,6 +41,20 @@ export function getMessageInfo(value: unknown): MessageInfo | undefined {
|
|||||||
model,
|
model,
|
||||||
providerID: typeof info.providerID === "string" ? info.providerID : undefined,
|
providerID: typeof info.providerID === "string" ? info.providerID : undefined,
|
||||||
modelID: typeof info.modelID === "string" ? info.modelID : undefined,
|
modelID: typeof info.modelID === "string" ? info.modelID : undefined,
|
||||||
|
tools: isRecord(info.tools)
|
||||||
|
? Object.entries(info.tools).reduce<Record<string, boolean | "allow" | "deny" | "ask">>((acc, [key, value]) => {
|
||||||
|
if (
|
||||||
|
value === true ||
|
||||||
|
value === false ||
|
||||||
|
value === "allow" ||
|
||||||
|
value === "deny" ||
|
||||||
|
value === "ask"
|
||||||
|
) {
|
||||||
|
acc[key] = value
|
||||||
|
}
|
||||||
|
return acc
|
||||||
|
}, {})
|
||||||
|
: undefined,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import type { BackgroundManager } from "../../features/background-agent"
|
import type { BackgroundManager } from "../../features/background-agent"
|
||||||
import { getMainSessionID, getSessionAgent } from "../../features/claude-code-session-state"
|
import { getMainSessionID, getSessionAgent } from "../../features/claude-code-session-state"
|
||||||
import { log } from "../../shared/logger"
|
import { log } from "../../shared/logger"
|
||||||
|
import { resolveInheritedPromptTools } from "../../shared"
|
||||||
import {
|
import {
|
||||||
buildReminder,
|
buildReminder,
|
||||||
extractMessages,
|
extractMessages,
|
||||||
@ -29,6 +30,7 @@ type BabysitterContext = {
|
|||||||
parts: Array<{ type: "text"; text: string }>
|
parts: Array<{ type: "text"; text: string }>
|
||||||
agent?: string
|
agent?: string
|
||||||
model?: { providerID: string; modelID: string }
|
model?: { providerID: string; modelID: string }
|
||||||
|
tools?: Record<string, boolean>
|
||||||
}
|
}
|
||||||
query?: { directory?: string }
|
query?: { directory?: string }
|
||||||
}) => Promise<unknown>
|
}) => Promise<unknown>
|
||||||
@ -38,6 +40,7 @@ type BabysitterContext = {
|
|||||||
parts: Array<{ type: "text"; text: string }>
|
parts: Array<{ type: "text"; text: string }>
|
||||||
agent?: string
|
agent?: string
|
||||||
model?: { providerID: string; modelID: string }
|
model?: { providerID: string; modelID: string }
|
||||||
|
tools?: Record<string, boolean>
|
||||||
}
|
}
|
||||||
query?: { directory?: string }
|
query?: { directory?: string }
|
||||||
}) => Promise<unknown>
|
}) => Promise<unknown>
|
||||||
@ -54,9 +57,10 @@ type BabysitterOptions = {
|
|||||||
async function resolveMainSessionTarget(
|
async function resolveMainSessionTarget(
|
||||||
ctx: BabysitterContext,
|
ctx: BabysitterContext,
|
||||||
sessionID: string
|
sessionID: string
|
||||||
): Promise<{ agent?: string; model?: { providerID: string; modelID: string } }> {
|
): Promise<{ agent?: string; model?: { providerID: string; modelID: string }; tools?: Record<string, boolean> }> {
|
||||||
let agent = getSessionAgent(sessionID)
|
let agent = getSessionAgent(sessionID)
|
||||||
let model: { providerID: string; modelID: string } | undefined
|
let model: { providerID: string; modelID: string } | undefined
|
||||||
|
let tools: Record<string, boolean> | undefined
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const messagesResp = await ctx.client.session.messages({
|
const messagesResp = await ctx.client.session.messages({
|
||||||
@ -68,6 +72,7 @@ async function resolveMainSessionTarget(
|
|||||||
if (info?.agent || info?.model || (info?.providerID && info?.modelID)) {
|
if (info?.agent || info?.model || (info?.providerID && info?.modelID)) {
|
||||||
agent = agent ?? info?.agent
|
agent = agent ?? info?.agent
|
||||||
model = info?.model ?? (info?.providerID && info?.modelID ? { providerID: info.providerID, modelID: info.modelID } : undefined)
|
model = info?.model ?? (info?.providerID && info?.modelID ? { providerID: info.providerID, modelID: info.modelID } : undefined)
|
||||||
|
tools = resolveInheritedPromptTools(sessionID, info?.tools) ?? tools
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -75,7 +80,7 @@ async function resolveMainSessionTarget(
|
|||||||
log(`[${HOOK_NAME}] Failed to resolve main session agent`, { sessionID, error: String(error) })
|
log(`[${HOOK_NAME}] Failed to resolve main session agent`, { sessionID, error: String(error) })
|
||||||
}
|
}
|
||||||
|
|
||||||
return { agent, model }
|
return { agent, model, tools: resolveInheritedPromptTools(sessionID, tools) }
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getThinkingSummary(ctx: BabysitterContext, sessionID: string): Promise<string | null> {
|
async function getThinkingSummary(ctx: BabysitterContext, sessionID: string): Promise<string | null> {
|
||||||
@ -144,7 +149,7 @@ export function createUnstableAgentBabysitterHook(ctx: BabysitterContext, option
|
|||||||
|
|
||||||
const summary = task.sessionID ? await getThinkingSummary(ctx, task.sessionID) : null
|
const summary = task.sessionID ? await getThinkingSummary(ctx, task.sessionID) : null
|
||||||
const reminder = buildReminder(task, summary, idleMs)
|
const reminder = buildReminder(task, summary, idleMs)
|
||||||
const { agent, model } = await resolveMainSessionTarget(ctx, mainSessionID)
|
const { agent, model, tools } = await resolveMainSessionTarget(ctx, mainSessionID)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await ctx.client.session.promptAsync({
|
await ctx.client.session.promptAsync({
|
||||||
@ -152,6 +157,7 @@ export function createUnstableAgentBabysitterHook(ctx: BabysitterContext, option
|
|||||||
body: {
|
body: {
|
||||||
...(agent ? { agent } : {}),
|
...(agent ? { agent } : {}),
|
||||||
...(model ? { model } : {}),
|
...(model ? { model } : {}),
|
||||||
|
...(tools ? { tools } : {}),
|
||||||
parts: [{ type: "text", text: reminder }],
|
parts: [{ type: "text", text: reminder }],
|
||||||
},
|
},
|
||||||
query: { directory: ctx.directory },
|
query: { directory: ctx.directory },
|
||||||
|
|||||||
@ -56,3 +56,4 @@ export * from "./opencode-storage-paths"
|
|||||||
export * from "./opencode-message-dir"
|
export * from "./opencode-message-dir"
|
||||||
export * from "./normalize-sdk-response"
|
export * from "./normalize-sdk-response"
|
||||||
export * from "./session-directory-resolver"
|
export * from "./session-directory-resolver"
|
||||||
|
export * from "./prompt-tools"
|
||||||
|
|||||||
56
src/shared/prompt-tools.test.ts
Normal file
56
src/shared/prompt-tools.test.ts
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
declare const require: (name: string) => any
|
||||||
|
const { afterEach, describe, expect, test } = require("bun:test")
|
||||||
|
import { clearSessionTools, setSessionTools } from "./session-tools-store"
|
||||||
|
import { normalizePromptTools, resolveInheritedPromptTools } from "./prompt-tools"
|
||||||
|
|
||||||
|
describe("prompt-tools", () => {
|
||||||
|
afterEach(() => {
|
||||||
|
clearSessionTools()
|
||||||
|
})
|
||||||
|
|
||||||
|
test("normalizes allow/deny style permissions to boolean tools", () => {
|
||||||
|
// given
|
||||||
|
const tools = {
|
||||||
|
question: "deny",
|
||||||
|
bash: "allow",
|
||||||
|
task: "ask",
|
||||||
|
read: true,
|
||||||
|
edit: false,
|
||||||
|
} as const
|
||||||
|
|
||||||
|
// when
|
||||||
|
const normalized = normalizePromptTools(tools)
|
||||||
|
|
||||||
|
// then
|
||||||
|
expect(normalized).toEqual({
|
||||||
|
question: false,
|
||||||
|
bash: true,
|
||||||
|
task: true,
|
||||||
|
read: true,
|
||||||
|
edit: false,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test("prefers per-session stored tools over fallback tools", () => {
|
||||||
|
// given
|
||||||
|
const sessionID = "ses_prompt_tools"
|
||||||
|
setSessionTools(sessionID, { question: false, bash: true })
|
||||||
|
|
||||||
|
// when
|
||||||
|
const resolved = resolveInheritedPromptTools(sessionID, { question: true, bash: false })
|
||||||
|
|
||||||
|
// then
|
||||||
|
expect(resolved).toEqual({ question: false, bash: true })
|
||||||
|
})
|
||||||
|
|
||||||
|
test("uses fallback tools when no per-session tools exist", () => {
|
||||||
|
// given
|
||||||
|
const sessionID = "ses_fallback_only"
|
||||||
|
|
||||||
|
// when
|
||||||
|
const resolved = resolveInheritedPromptTools(sessionID, { question: "deny", write: "allow" })
|
||||||
|
|
||||||
|
// then
|
||||||
|
expect(resolved).toEqual({ question: false, write: true })
|
||||||
|
})
|
||||||
|
})
|
||||||
35
src/shared/prompt-tools.ts
Normal file
35
src/shared/prompt-tools.ts
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import { getSessionTools } from "./session-tools-store"
|
||||||
|
|
||||||
|
export type PromptToolPermission = boolean | "allow" | "deny" | "ask"
|
||||||
|
|
||||||
|
export function normalizePromptTools(
|
||||||
|
tools: Record<string, PromptToolPermission> | undefined
|
||||||
|
): Record<string, boolean> | undefined {
|
||||||
|
if (!tools) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
const normalized: Record<string, boolean> = {}
|
||||||
|
for (const [toolName, permission] of Object.entries(tools)) {
|
||||||
|
if (permission === false || permission === "deny") {
|
||||||
|
normalized[toolName] = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (permission === true || permission === "allow" || permission === "ask") {
|
||||||
|
normalized[toolName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Object.keys(normalized).length > 0 ? normalized : undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
export function resolveInheritedPromptTools(
|
||||||
|
sessionID: string,
|
||||||
|
fallbackTools?: Record<string, PromptToolPermission>
|
||||||
|
): Record<string, boolean> | undefined {
|
||||||
|
const sessionTools = getSessionTools(sessionID)
|
||||||
|
if (sessionTools && Object.keys(sessionTools).length > 0) {
|
||||||
|
return { ...sessionTools }
|
||||||
|
}
|
||||||
|
return normalizePromptTools(fallbackTools)
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user