fix(runtime-fallback): 9 critical bug fixes for auto-retry, agent preservation, and model override
Bug fixes: 1. extractStatusCode: handle nested data.statusCode (Anthropic error structure) 2. Error regex: relax credit.*balance.*too.*low pattern for multi-char gaps 3. Zod schema: bump max_fallback_attempts from 10 to 20 (config rejected silently) 4. getFallbackModelsForSession: fallback to sisyphus/any agent when session.error lacks agent 5. Model detection: derive model from agent config when session.error lacks model info 6. Auto-retry: resend last user message with fallback model via promptAsync 7. Persistent fallback: override model on every chat.message (not just pendingFallbackModel) 8. Manual model change: detect UI model changes and reset fallback state 9. Agent preservation: include agent in promptAsync body to prevent defaulting to sisyphus Additional: - Add sessionRetryInFlight guard to prevent double-retries - Add resolveAgentForSession with 3-tier resolution (event → session memory → session ID) - Add normalizeAgentName for display names like "Prometheus (Planner)" → "prometheus" - Add resolveAgentForSessionFromContext to fetch agent from session messages - Move AGENT_NAMES and agentPattern to module scope for reuse - Register runtime-fallback hooks in event.ts and chat-message.ts - Remove diagnostic debug logging from isRetryableError - Add 400 to default retry_on_errors and credit/balance patterns to RETRYABLE_ERROR_PATTERNS
This commit is contained in:
parent
708b9ce9ff
commit
fbafb8cf67
@ -1,11 +1,16 @@
|
||||
import { z } from "zod"
|
||||
|
||||
export const RuntimeFallbackConfigSchema = z.object({
|
||||
enabled: z.boolean().default(true),
|
||||
retry_on_errors: z.array(z.number()).default([429, 503, 529]),
|
||||
max_fallback_attempts: z.number().min(1).max(10).default(3),
|
||||
cooldown_seconds: z.number().min(0).default(60),
|
||||
notify_on_fallback: z.boolean().default(true),
|
||||
/** Enable runtime fallback (default: true) */
|
||||
enabled: z.boolean().optional(),
|
||||
/** HTTP status codes that trigger fallback (default: [429, 503, 529]) */
|
||||
retry_on_errors: z.array(z.number()).optional(),
|
||||
/** Maximum fallback attempts per session (default: 3) */
|
||||
max_fallback_attempts: z.number().min(1).max(20).optional(),
|
||||
/** Cooldown in seconds before retrying a failed model (default: 60) */
|
||||
cooldown_seconds: z.number().min(0).optional(),
|
||||
/** Show toast notification when switching to fallback model (default: true) */
|
||||
notify_on_fallback: z.boolean().optional(),
|
||||
})
|
||||
|
||||
export type RuntimeFallbackConfig = z.infer<typeof RuntimeFallbackConfigSchema>
|
||||
|
||||
@ -11,7 +11,7 @@ import type { RuntimeFallbackConfig } from "../../config"
|
||||
*/
|
||||
export const DEFAULT_CONFIG: Required<RuntimeFallbackConfig> = {
|
||||
enabled: true,
|
||||
retry_on_errors: [429, 503, 529],
|
||||
retry_on_errors: [400, 429, 503, 529],
|
||||
max_fallback_attempts: 3,
|
||||
cooldown_seconds: 60,
|
||||
notify_on_fallback: true,
|
||||
@ -29,6 +29,8 @@ export const RETRYABLE_ERROR_PATTERNS = [
|
||||
/overloaded/i,
|
||||
/temporarily.?unavailable/i,
|
||||
/try.?again/i,
|
||||
/credit.*balance.*too.*low/i,
|
||||
/insufficient.?(?:credits?|funds?|balance)/i,
|
||||
/(?:^|\s)429(?:\s|$)/,
|
||||
/(?:^|\s)503(?:\s|$)/,
|
||||
/(?:^|\s)529(?:\s|$)/,
|
||||
|
||||
@ -23,7 +23,12 @@ describe("runtime-fallback", () => {
|
||||
logSpy?.mockRestore()
|
||||
})
|
||||
|
||||
function createMockPluginInput() {
|
||||
function createMockPluginInput(overrides?: {
|
||||
session?: {
|
||||
messages?: (args: unknown) => Promise<unknown>
|
||||
promptAsync?: (args: unknown) => Promise<unknown>
|
||||
}
|
||||
}) {
|
||||
return {
|
||||
client: {
|
||||
tui: {
|
||||
@ -35,6 +40,10 @@ describe("runtime-fallback", () => {
|
||||
})
|
||||
},
|
||||
},
|
||||
session: {
|
||||
messages: overrides?.session?.messages ?? (async () => ({ data: [] })),
|
||||
promptAsync: overrides?.session?.promptAsync ?? (async () => ({})),
|
||||
},
|
||||
},
|
||||
directory: "/test/dir",
|
||||
} as any
|
||||
@ -174,7 +183,10 @@ describe("runtime-fallback", () => {
|
||||
})
|
||||
|
||||
test("should log when no fallback models configured", async () => {
|
||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), { config: createMockConfig() })
|
||||
const hook = createRuntimeFallbackHook(createMockPluginInput(), {
|
||||
config: createMockConfig(),
|
||||
pluginConfig: {},
|
||||
})
|
||||
const sessionID = "test-session-no-fallbacks"
|
||||
|
||||
await hook.event({
|
||||
@ -487,7 +499,7 @@ describe("runtime-fallback", () => {
|
||||
|
||||
const output = { message: {}, parts: [] }
|
||||
await hook["chat.message"]?.(
|
||||
{ sessionID, model: { providerID: "anthropic", modelID: "claude-opus-4-5" } },
|
||||
{ sessionID },
|
||||
output
|
||||
)
|
||||
|
||||
@ -588,6 +600,50 @@ describe("runtime-fallback", () => {
|
||||
expect(fallbackLog).toBeDefined()
|
||||
expect(fallbackLog?.data).toMatchObject({ to: "openai/gpt-5.2" })
|
||||
})
|
||||
|
||||
test("should preserve resolved agent during auto-retry", async () => {
|
||||
const promptCalls: Array<Record<string, unknown>> = []
|
||||
const hook = createRuntimeFallbackHook(
|
||||
createMockPluginInput({
|
||||
session: {
|
||||
messages: async () => ({
|
||||
data: [
|
||||
{
|
||||
info: { role: "user" },
|
||||
parts: [{ type: "text", text: "test" }],
|
||||
},
|
||||
],
|
||||
}),
|
||||
promptAsync: async (args: unknown) => {
|
||||
promptCalls.push(args as Record<string, unknown>)
|
||||
return {}
|
||||
},
|
||||
},
|
||||
}),
|
||||
{
|
||||
config: createMockConfig({ notify_on_fallback: false }),
|
||||
pluginConfig: createMockPluginConfigWithAgentFallback("prometheus", ["github-copilot/claude-opus-4.6"]),
|
||||
},
|
||||
)
|
||||
const sessionID = "test-preserve-agent-on-retry"
|
||||
|
||||
await hook.event({
|
||||
event: {
|
||||
type: "session.error",
|
||||
properties: {
|
||||
sessionID,
|
||||
model: "anthropic/claude-opus-4-6",
|
||||
error: { statusCode: 503, message: "Service unavailable" },
|
||||
agent: "prometheus",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(promptCalls.length).toBe(1)
|
||||
const callBody = promptCalls[0]?.body as Record<string, unknown>
|
||||
expect(callBody?.agent).toBe("prometheus")
|
||||
expect(callBody?.model).toEqual({ providerID: "github-copilot", modelID: "claude-opus-4.6" })
|
||||
})
|
||||
})
|
||||
|
||||
describe("cooldown mechanism", () => {
|
||||
|
||||
@ -5,6 +5,7 @@ import { DEFAULT_CONFIG, RETRYABLE_ERROR_PATTERNS, HOOK_NAME } from "./constants
|
||||
import { log } from "../../shared/logger"
|
||||
import { SessionCategoryRegistry } from "../../shared/session-category-registry"
|
||||
import { normalizeFallbackModels } from "../../shared/model-resolver"
|
||||
import { getSessionAgent } from "../../features/claude-code-session-state"
|
||||
|
||||
function createFallbackState(originalModel: string): FallbackState {
|
||||
return {
|
||||
@ -56,7 +57,7 @@ function extractStatusCode(error: unknown): number | undefined {
|
||||
}
|
||||
|
||||
const message = getErrorMessage(error)
|
||||
const statusMatch = message.match(/\b(429|503|529)\b/)
|
||||
const statusMatch = message.match(/\b(400|402|429|503|529)\b/)
|
||||
if (statusMatch) {
|
||||
return parseInt(statusMatch[1], 10)
|
||||
}
|
||||
@ -66,15 +67,68 @@ function extractStatusCode(error: unknown): number | undefined {
|
||||
|
||||
function isRetryableError(error: unknown, retryOnErrors: number[]): boolean {
|
||||
const statusCode = extractStatusCode(error)
|
||||
const message = getErrorMessage(error)
|
||||
|
||||
if (statusCode && retryOnErrors.includes(statusCode)) {
|
||||
return true
|
||||
}
|
||||
|
||||
const message = getErrorMessage(error)
|
||||
return RETRYABLE_ERROR_PATTERNS.some((pattern) => pattern.test(message))
|
||||
}
|
||||
|
||||
const AGENT_NAMES = [
|
||||
"sisyphus",
|
||||
"oracle",
|
||||
"librarian",
|
||||
"explore",
|
||||
"prometheus",
|
||||
"atlas",
|
||||
"metis",
|
||||
"momus",
|
||||
"hephaestus",
|
||||
"sisyphus-junior",
|
||||
"build",
|
||||
"plan",
|
||||
"multimodal-looker",
|
||||
]
|
||||
|
||||
const agentPattern = new RegExp(
|
||||
`\\b(${AGENT_NAMES
|
||||
.sort((a, b) => b.length - a.length)
|
||||
.map((a) => a.replace(/-/g, "\\-"))
|
||||
.join("|")})\\b`,
|
||||
"i",
|
||||
)
|
||||
|
||||
function detectAgentFromSession(sessionID: string): string | undefined {
|
||||
const match = sessionID.match(agentPattern)
|
||||
if (match) {
|
||||
return match[1].toLowerCase()
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
function normalizeAgentName(agent: string | undefined): string | undefined {
|
||||
if (!agent) return undefined
|
||||
const normalized = agent.toLowerCase().trim()
|
||||
if (AGENT_NAMES.includes(normalized)) {
|
||||
return normalized
|
||||
}
|
||||
const match = normalized.match(agentPattern)
|
||||
if (match) {
|
||||
return match[1].toLowerCase()
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
function resolveAgentForSession(sessionID: string, eventAgent?: string): string | undefined {
|
||||
return (
|
||||
normalizeAgentName(eventAgent) ??
|
||||
normalizeAgentName(getSessionAgent(sessionID)) ??
|
||||
detectAgentFromSession(sessionID)
|
||||
)
|
||||
}
|
||||
|
||||
function getFallbackModelsForSession(
|
||||
sessionID: string,
|
||||
agent: string | undefined,
|
||||
@ -115,28 +169,6 @@ function getFallbackModelsForSession(
|
||||
if (result) return result
|
||||
}
|
||||
|
||||
const AGENT_NAMES = [
|
||||
"sisyphus",
|
||||
"oracle",
|
||||
"librarian",
|
||||
"explore",
|
||||
"prometheus",
|
||||
"atlas",
|
||||
"metis",
|
||||
"momus",
|
||||
"hephaestus",
|
||||
"sisyphus-junior",
|
||||
"build",
|
||||
"plan",
|
||||
"multimodal-looker",
|
||||
]
|
||||
const agentPattern = new RegExp(
|
||||
`(?:^|[^a-zA-Z0-9_-])(${AGENT_NAMES
|
||||
.sort((a, b) => b.length - a.length)
|
||||
.map((a) => a.replace(/-/g, "\\-"))
|
||||
.join("|")})(?:$|[^a-zA-Z0-9_-])`,
|
||||
"i",
|
||||
)
|
||||
const sessionAgentMatch = sessionID.match(agentPattern)
|
||||
if (sessionAgentMatch) {
|
||||
const detectedAgent = sessionAgentMatch[1].toLowerCase()
|
||||
@ -144,6 +176,22 @@ function getFallbackModelsForSession(
|
||||
if (result) return result
|
||||
}
|
||||
|
||||
// Fallback: if no agent detected, try main agent "sisyphus" then any agent with fallback_models
|
||||
const sisyphusFallback = tryGetFallbackFromAgent("sisyphus")
|
||||
if (sisyphusFallback) {
|
||||
log(`[${HOOK_NAME}] Using sisyphus fallback models (no agent detected)`, { sessionID })
|
||||
return sisyphusFallback
|
||||
}
|
||||
|
||||
// Last resort: try all known agents until we find one with fallback_models
|
||||
for (const agentName of AGENT_NAMES) {
|
||||
const result = tryGetFallbackFromAgent(agentName)
|
||||
if (result) {
|
||||
log(`[${HOOK_NAME}] Using ${agentName} fallback models (no agent detected)`, { sessionID })
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return []
|
||||
}
|
||||
|
||||
@ -221,6 +269,30 @@ export function createRuntimeFallbackHook(
|
||||
}
|
||||
|
||||
const sessionStates = new Map<string, FallbackState>()
|
||||
const sessionLastAccess = new Map<string, number>()
|
||||
const sessionRetryInFlight = new Set<string>()
|
||||
const SESSION_TTL_MS = 30 * 60 * 1000 // 30 minutes TTL for stale sessions
|
||||
|
||||
// Periodic cleanup of stale session states to prevent memory leaks
|
||||
const cleanupStaleSessions = () => {
|
||||
const now = Date.now()
|
||||
let cleanedCount = 0
|
||||
for (const [sessionID, lastAccess] of sessionLastAccess.entries()) {
|
||||
if (now - lastAccess > SESSION_TTL_MS) {
|
||||
sessionStates.delete(sessionID)
|
||||
sessionLastAccess.delete(sessionID)
|
||||
sessionRetryInFlight.delete(sessionID)
|
||||
SessionCategoryRegistry.remove(sessionID)
|
||||
cleanedCount++
|
||||
}
|
||||
}
|
||||
if (cleanedCount > 0) {
|
||||
log(`[${HOOK_NAME}] Cleaned up ${cleanedCount} stale session states`)
|
||||
}
|
||||
}
|
||||
|
||||
// Run cleanup every 5 minutes
|
||||
const cleanupInterval = setInterval(cleanupStaleSessions, 5 * 60 * 1000)
|
||||
|
||||
let pluginConfig: OhMyOpenCodeConfig | undefined
|
||||
if (options?.pluginConfig) {
|
||||
@ -234,6 +306,36 @@ export function createRuntimeFallbackHook(
|
||||
}
|
||||
}
|
||||
|
||||
const resolveAgentForSessionFromContext = async (
|
||||
sessionID: string,
|
||||
eventAgent?: string,
|
||||
): Promise<string | undefined> => {
|
||||
const resolved = resolveAgentForSession(sessionID, eventAgent)
|
||||
if (resolved) return resolved
|
||||
|
||||
try {
|
||||
const messagesResp = await ctx.client.session.messages({
|
||||
path: { id: sessionID },
|
||||
query: { directory: ctx.directory },
|
||||
})
|
||||
const msgs = (messagesResp as { data?: Array<{ info?: Record<string, unknown> }> }).data
|
||||
if (!msgs || msgs.length === 0) return undefined
|
||||
|
||||
for (let i = msgs.length - 1; i >= 0; i--) {
|
||||
const info = msgs[i]?.info
|
||||
const infoAgent = typeof info?.agent === "string" ? info.agent : undefined
|
||||
const normalized = normalizeAgentName(infoAgent)
|
||||
if (normalized) {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
const eventHandler = async ({ event }: { event: { type: string; properties?: unknown } }) => {
|
||||
if (!config.enabled) return
|
||||
|
||||
@ -247,6 +349,7 @@ export function createRuntimeFallbackHook(
|
||||
if (sessionID && model) {
|
||||
log(`[${HOOK_NAME}] Session created with model`, { sessionID, model })
|
||||
sessionStates.set(sessionID, createFallbackState(model))
|
||||
sessionLastAccess.set(sessionID, Date.now())
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -258,6 +361,8 @@ export function createRuntimeFallbackHook(
|
||||
if (sessionID) {
|
||||
log(`[${HOOK_NAME}] Cleaning up session state`, { sessionID })
|
||||
sessionStates.delete(sessionID)
|
||||
sessionLastAccess.delete(sessionID)
|
||||
sessionRetryInFlight.delete(sessionID)
|
||||
SessionCategoryRegistry.remove(sessionID)
|
||||
}
|
||||
return
|
||||
@ -273,7 +378,14 @@ export function createRuntimeFallbackHook(
|
||||
return
|
||||
}
|
||||
|
||||
log(`[${HOOK_NAME}] session.error received`, { sessionID, agent, statusCode: extractStatusCode(error) })
|
||||
const resolvedAgent = await resolveAgentForSessionFromContext(sessionID, agent)
|
||||
|
||||
log(`[${HOOK_NAME}] session.error received`, {
|
||||
sessionID,
|
||||
agent,
|
||||
resolvedAgent,
|
||||
statusCode: extractStatusCode(error),
|
||||
})
|
||||
|
||||
if (!isRetryableError(error, config.retry_on_errors)) {
|
||||
log(`[${HOOK_NAME}] Error not retryable, skipping fallback`, { sessionID })
|
||||
@ -281,7 +393,7 @@ export function createRuntimeFallbackHook(
|
||||
}
|
||||
|
||||
let state = sessionStates.get(sessionID)
|
||||
const fallbackModels = getFallbackModelsForSession(sessionID, agent, pluginConfig)
|
||||
const fallbackModels = getFallbackModelsForSession(sessionID, resolvedAgent, pluginConfig)
|
||||
|
||||
if (fallbackModels.length === 0) {
|
||||
log(`[${HOOK_NAME}] No fallback models configured`, { sessionID, agent })
|
||||
@ -293,10 +405,26 @@ export function createRuntimeFallbackHook(
|
||||
if (currentModel) {
|
||||
state = createFallbackState(currentModel)
|
||||
sessionStates.set(sessionID, state)
|
||||
sessionLastAccess.set(sessionID, Date.now())
|
||||
} else {
|
||||
log(`[${HOOK_NAME}] No model info available, cannot fallback`, { sessionID })
|
||||
return
|
||||
// session.error doesn't include model — derive from agent config
|
||||
const detectedAgent = resolvedAgent
|
||||
const agentConfig = detectedAgent
|
||||
? pluginConfig?.agents?.[detectedAgent as keyof typeof pluginConfig.agents]
|
||||
: undefined
|
||||
const agentModel = agentConfig?.model as string | undefined
|
||||
if (agentModel) {
|
||||
log(`[${HOOK_NAME}] Derived model from agent config`, { sessionID, agent: detectedAgent, model: agentModel })
|
||||
state = createFallbackState(agentModel)
|
||||
sessionStates.set(sessionID, state)
|
||||
sessionLastAccess.set(sessionID, Date.now())
|
||||
} else {
|
||||
log(`[${HOOK_NAME}] No model info available, cannot fallback`, { sessionID })
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
sessionLastAccess.set(sessionID, Date.now())
|
||||
}
|
||||
|
||||
const result = prepareFallback(sessionID, state, fallbackModels, config)
|
||||
@ -314,6 +442,68 @@ export function createRuntimeFallbackHook(
|
||||
.catch(() => {})
|
||||
}
|
||||
|
||||
if (result.success && result.newModel) {
|
||||
if (sessionRetryInFlight.has(sessionID)) {
|
||||
log(`[${HOOK_NAME}] Retry already in flight, skipping`, { sessionID })
|
||||
} else {
|
||||
const modelParts = result.newModel.split("/")
|
||||
if (modelParts.length >= 2) {
|
||||
const fallbackModelObj = {
|
||||
providerID: modelParts[0],
|
||||
modelID: modelParts.slice(1).join("/"),
|
||||
}
|
||||
|
||||
sessionRetryInFlight.add(sessionID)
|
||||
try {
|
||||
const messagesResp = await ctx.client.session.messages({
|
||||
path: { id: sessionID },
|
||||
query: { directory: ctx.directory },
|
||||
})
|
||||
const msgs = (messagesResp as {
|
||||
data?: Array<{
|
||||
info?: Record<string, unknown>
|
||||
parts?: Array<{ type?: string; text?: string }>
|
||||
}>
|
||||
}).data
|
||||
const lastUserMsg = msgs?.filter((m) => m.info?.role === "user").pop()
|
||||
const lastUserPartsRaw =
|
||||
lastUserMsg?.parts ??
|
||||
(lastUserMsg?.info?.parts as Array<{ type?: string; text?: string }> | undefined)
|
||||
|
||||
if (lastUserPartsRaw && lastUserPartsRaw.length > 0) {
|
||||
log(`[${HOOK_NAME}] Auto-retrying with fallback model`, {
|
||||
sessionID,
|
||||
model: result.newModel,
|
||||
})
|
||||
|
||||
const retryParts = lastUserPartsRaw
|
||||
.filter((p) => p.type === "text" && typeof p.text === "string" && p.text.length > 0)
|
||||
.map((p) => ({ type: "text" as const, text: p.text! }))
|
||||
|
||||
if (retryParts.length > 0) {
|
||||
const retryAgent = resolvedAgent ?? getSessionAgent(sessionID)
|
||||
await ctx.client.session.promptAsync({
|
||||
path: { id: sessionID },
|
||||
body: {
|
||||
...(retryAgent ? { agent: retryAgent } : {}),
|
||||
model: fallbackModelObj,
|
||||
parts: retryParts,
|
||||
},
|
||||
query: { directory: ctx.directory },
|
||||
})
|
||||
}
|
||||
} else {
|
||||
log(`[${HOOK_NAME}] No user message found for auto-retry`, { sessionID })
|
||||
}
|
||||
} catch (retryError) {
|
||||
log(`[${HOOK_NAME}] Auto-retry failed`, { sessionID, error: String(retryError) })
|
||||
} finally {
|
||||
sessionRetryInFlight.delete(sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!result.success) {
|
||||
log(`[${HOOK_NAME}] Fallback preparation failed`, { sessionID, error: result.error })
|
||||
}
|
||||
@ -337,7 +527,8 @@ export function createRuntimeFallbackHook(
|
||||
|
||||
let state = sessionStates.get(sessionID)
|
||||
const agent = info?.agent as string | undefined
|
||||
const fallbackModels = getFallbackModelsForSession(sessionID, agent, pluginConfig)
|
||||
const resolvedAgent = await resolveAgentForSessionFromContext(sessionID, agent)
|
||||
const fallbackModels = getFallbackModelsForSession(sessionID, resolvedAgent, pluginConfig)
|
||||
|
||||
if (fallbackModels.length === 0) {
|
||||
return
|
||||
@ -346,6 +537,9 @@ export function createRuntimeFallbackHook(
|
||||
if (!state) {
|
||||
state = createFallbackState(model)
|
||||
sessionStates.set(sessionID, state)
|
||||
sessionLastAccess.set(sessionID, Date.now())
|
||||
} else {
|
||||
sessionLastAccess.set(sessionID, Date.now())
|
||||
}
|
||||
|
||||
const result = prepareFallback(sessionID, state, fallbackModels, config)
|
||||
@ -362,6 +556,66 @@ export function createRuntimeFallbackHook(
|
||||
})
|
||||
.catch(() => {})
|
||||
}
|
||||
|
||||
if (result.success && result.newModel) {
|
||||
if (sessionRetryInFlight.has(sessionID)) {
|
||||
log(`[${HOOK_NAME}] Retry already in flight, skipping (message.updated)`, { sessionID })
|
||||
} else {
|
||||
const modelParts = result.newModel.split("/")
|
||||
if (modelParts.length >= 2) {
|
||||
const fallbackModelObj = {
|
||||
providerID: modelParts[0],
|
||||
modelID: modelParts.slice(1).join("/"),
|
||||
}
|
||||
|
||||
sessionRetryInFlight.add(sessionID)
|
||||
try {
|
||||
const messagesResp = await ctx.client.session.messages({
|
||||
path: { id: sessionID },
|
||||
query: { directory: ctx.directory },
|
||||
})
|
||||
const msgs = (messagesResp as {
|
||||
data?: Array<{
|
||||
info?: Record<string, unknown>
|
||||
parts?: Array<{ type?: string; text?: string }>
|
||||
}>
|
||||
}).data
|
||||
const lastUserMsg = msgs?.filter((m) => m.info?.role === "user").pop()
|
||||
const lastUserPartsRaw =
|
||||
lastUserMsg?.parts ??
|
||||
(lastUserMsg?.info?.parts as Array<{ type?: string; text?: string }> | undefined)
|
||||
|
||||
if (lastUserPartsRaw && lastUserPartsRaw.length > 0) {
|
||||
log(`[${HOOK_NAME}] Auto-retrying with fallback model (message.updated)`, {
|
||||
sessionID,
|
||||
model: result.newModel,
|
||||
})
|
||||
|
||||
const retryParts = lastUserPartsRaw
|
||||
.filter((p) => p.type === "text" && typeof p.text === "string" && p.text.length > 0)
|
||||
.map((p) => ({ type: "text" as const, text: p.text! }))
|
||||
|
||||
if (retryParts.length > 0) {
|
||||
const retryAgent = resolvedAgent ?? getSessionAgent(sessionID)
|
||||
await ctx.client.session.promptAsync({
|
||||
path: { id: sessionID },
|
||||
body: {
|
||||
...(retryAgent ? { agent: retryAgent } : {}),
|
||||
model: fallbackModelObj,
|
||||
parts: retryParts,
|
||||
},
|
||||
query: { directory: ctx.directory },
|
||||
})
|
||||
}
|
||||
}
|
||||
} catch (retryError) {
|
||||
log(`[${HOOK_NAME}] Auto-retry failed (message.updated)`, { sessionID, error: String(retryError) })
|
||||
} finally {
|
||||
sessionRetryInFlight.delete(sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -374,21 +628,38 @@ export function createRuntimeFallbackHook(
|
||||
if (!config.enabled) return
|
||||
|
||||
const { sessionID } = input
|
||||
const state = sessionStates.get(sessionID)
|
||||
let state = sessionStates.get(sessionID)
|
||||
|
||||
if (!state?.pendingFallbackModel) return
|
||||
if (!state) return
|
||||
|
||||
const fallbackModel = state.pendingFallbackModel
|
||||
state.pendingFallbackModel = undefined
|
||||
const requestedModel = input.model
|
||||
? `${input.model.providerID}/${input.model.modelID}`
|
||||
: undefined
|
||||
|
||||
log(`[${HOOK_NAME}] Applying fallback model for next request`, {
|
||||
if (requestedModel && requestedModel !== state.currentModel) {
|
||||
log(`[${HOOK_NAME}] Detected manual model change, resetting fallback state`, {
|
||||
sessionID,
|
||||
from: state.currentModel,
|
||||
to: requestedModel,
|
||||
})
|
||||
state = createFallbackState(requestedModel)
|
||||
sessionStates.set(sessionID, state)
|
||||
sessionLastAccess.set(sessionID, Date.now())
|
||||
return
|
||||
}
|
||||
|
||||
if (state.currentModel === state.originalModel) return
|
||||
|
||||
const activeModel = state.currentModel
|
||||
|
||||
log(`[${HOOK_NAME}] Applying fallback model override`, {
|
||||
sessionID,
|
||||
from: input.model,
|
||||
to: fallbackModel,
|
||||
to: activeModel,
|
||||
})
|
||||
|
||||
if (output.message && fallbackModel) {
|
||||
const parts = fallbackModel.split("/")
|
||||
if (output.message && activeModel) {
|
||||
const parts = activeModel.split("/")
|
||||
if (parts.length >= 2) {
|
||||
output.message.model = {
|
||||
providerID: parts[0],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user