diff --git a/src/hooks/runtime-fallback/index.ts b/src/hooks/runtime-fallback/index.ts index 3a4fe8da..06f7f9ec 100644 --- a/src/hooks/runtime-fallback/index.ts +++ b/src/hooks/runtime-fallback/index.ts @@ -3,6 +3,7 @@ import type { RuntimeFallbackConfig, OhMyOpenCodeConfig } from "../../config" import type { FallbackState, FallbackResult, RuntimeFallbackHook } from "./types" import { DEFAULT_CONFIG, RETRYABLE_ERROR_PATTERNS, HOOK_NAME } from "./constants" import { log } from "../../shared/logger" +import { SessionCategoryRegistry } from "../../shared/session-category-registry" function createFallbackState(originalModel: string): FallbackState { return { @@ -87,6 +88,15 @@ function getFallbackModelsForSession( ): string[] { if (!pluginConfig) return [] + //#when - session has category from delegate_task, try category fallback_models first + const sessionCategory = SessionCategoryRegistry.get(sessionID) + if (sessionCategory && pluginConfig.categories?.[sessionCategory]) { + const categoryConfig = pluginConfig.categories[sessionCategory] + if (categoryConfig?.fallback_models) { + return normalizeFallbackModels(categoryConfig.fallback_models) + } + } + const tryGetFallbackFromAgent = (agentName: string): string[] | undefined => { const agentConfig = pluginConfig.agents?.[agentName as keyof typeof pluginConfig.agents] if (!agentConfig) return undefined diff --git a/src/shared/index.ts b/src/shared/index.ts index ce8e69be..263d50ba 100644 --- a/src/shared/index.ts +++ b/src/shared/index.ts @@ -58,3 +58,4 @@ export * from "./normalize-sdk-response" export * from "./session-directory-resolver" export * from "./prompt-tools" export * from "./internal-initiator-marker" +export { SessionCategoryRegistry } from "./session-category-registry" diff --git a/src/shared/model-resolution-pipeline.ts b/src/shared/model-resolution-pipeline.ts index 34d1c13b..12b337cf 100644 --- a/src/shared/model-resolution-pipeline.ts +++ b/src/shared/model-resolution-pipeline.ts @@ -7,6 +7,7 @@ export type ModelResolutionRequest = { intent?: { uiSelectedModel?: string userModel?: string + userFallbackModels?: string[] categoryDefaultModel?: string } constraints: { @@ -97,6 +98,42 @@ export function resolveModelPipeline( }) } + //#when - user configured fallback_models, try them before hardcoded fallback chain + const userFallbackModels = intent?.userFallbackModels + if (userFallbackModels && userFallbackModels.length > 0) { + if (availableModels.size === 0) { + const connectedProviders = readConnectedProvidersCache() + const connectedSet = connectedProviders ? new Set(connectedProviders) : null + + if (connectedSet !== null) { + for (const model of userFallbackModels) { + attempted.push(model) + const parts = model.split("/") + if (parts.length >= 2) { + const provider = parts[0] + if (connectedSet.has(provider)) { + log("Model resolved via user fallback_models (connected provider)", { model }) + return { model, provenance: "provider-fallback", attempted } + } + } + } + log("No connected provider found in user fallback_models, falling through to hardcoded chain") + } + } else { + for (const model of userFallbackModels) { + attempted.push(model) + const parts = model.split("/") + const providerHint = parts.length >= 2 ? [parts[0]] : undefined + const match = fuzzyMatchModel(model, availableModels, providerHint) + if (match) { + log("Model resolved via user fallback_models (availability confirmed)", { model: model, match }) + return { model: match, provenance: "provider-fallback", attempted } + } + } + log("No available model found in user fallback_models, falling through to hardcoded chain") + } + } + if (fallbackChain && fallbackChain.length > 0) { if (availableModels.size === 0) { const connectedProviders = constraints.connectedProviders ?? connectedProvidersCache.readConnectedProvidersCache() diff --git a/src/shared/model-resolver.ts b/src/shared/model-resolver.ts index cbaa8c48..84cbcbe2 100644 --- a/src/shared/model-resolver.ts +++ b/src/shared/model-resolver.ts @@ -22,6 +22,7 @@ export type ModelResolutionResult = { export type ExtendedModelResolutionInput = { uiSelectedModel?: string userModel?: string + userFallbackModels?: string[] categoryDefaultModel?: string fallbackChain?: FallbackEntry[] availableModels: Set @@ -44,9 +45,9 @@ export function resolveModel(input: ModelResolutionInput): string | undefined { export function resolveModelWithFallback( input: ExtendedModelResolutionInput, ): ModelResolutionResult | undefined { - const { uiSelectedModel, userModel, categoryDefaultModel, fallbackChain, availableModels, systemDefaultModel } = input + const { uiSelectedModel, userModel, userFallbackModels, categoryDefaultModel, fallbackChain, availableModels, systemDefaultModel } = input const resolved = resolveModelPipeline({ - intent: { uiSelectedModel, userModel, categoryDefaultModel }, + intent: { uiSelectedModel, userModel, userFallbackModels, categoryDefaultModel }, constraints: { availableModels }, policy: { fallbackChain, systemDefaultModel }, }) diff --git a/src/shared/session-category-registry.ts b/src/shared/session-category-registry.ts new file mode 100644 index 00000000..ce19e1c0 --- /dev/null +++ b/src/shared/session-category-registry.ts @@ -0,0 +1,53 @@ +/** + * Session Category Registry + * + * Maintains a mapping of session IDs to their assigned categories. + * Used by runtime-fallback hook to lookup category-specific fallback_models. + */ + +// Map of sessionID -> category name +const sessionCategoryMap = new Map() + +export const SessionCategoryRegistry = { + /** + * Register a session with its category + */ + register: (sessionID: string, category: string): void => { + sessionCategoryMap.set(sessionID, category) + }, + + /** + * Get the category for a session + */ + get: (sessionID: string): string | undefined => { + return sessionCategoryMap.get(sessionID) + }, + + /** + * Remove a session from the registry (cleanup) + */ + remove: (sessionID: string): void => { + sessionCategoryMap.delete(sessionID) + }, + + /** + * Check if a session is registered + */ + has: (sessionID: string): boolean => { + return sessionCategoryMap.has(sessionID) + }, + + /** + * Get the size of the registry (for debugging) + */ + size: (): number => { + return sessionCategoryMap.size + }, + + /** + * Clear all entries (use with caution, mainly for testing) + */ + clear: (): void => { + sessionCategoryMap.clear() + }, +} diff --git a/src/tools/delegate-task/background-task.ts b/src/tools/delegate-task/background-task.ts index e724695b..62580541 100644 --- a/src/tools/delegate-task/background-task.ts +++ b/src/tools/delegate-task/background-task.ts @@ -4,6 +4,7 @@ import { getTimingConfig } from "./timing" import { storeToolMetadata } from "../../features/tool-metadata-store" import { formatDetailedError } from "./error-formatting" import { getSessionTools } from "../../shared/session-tools-store" +import { SessionCategoryRegistry } from "../../shared/session-category-registry" export async function executeBackgroundTask( args: DelegateTaskArgs, @@ -48,6 +49,10 @@ export async function executeBackgroundTask( sessionId = updated?.sessionID } + if (args.category && sessionId) { + SessionCategoryRegistry.register(sessionId, args.category) + } + const unstableMeta = { title: args.description, metadata: { diff --git a/src/tools/delegate-task/sync-task.ts b/src/tools/delegate-task/sync-task.ts index d9543786..13b701c2 100644 --- a/src/tools/delegate-task/sync-task.ts +++ b/src/tools/delegate-task/sync-task.ts @@ -5,6 +5,7 @@ import { getTaskToastManager } from "../../features/task-toast-manager" import { storeToolMetadata } from "../../features/tool-metadata-store" import { subagentSessions } from "../../features/claude-code-session-state" import { log } from "../../shared/logger" +import { SessionCategoryRegistry } from "../../shared/session-category-registry" import { formatDuration } from "./time-formatter" import { formatDetailedError } from "./error-formatting" import { syncTaskDeps, type SyncTaskDeps } from "./sync-task-deps" @@ -41,6 +42,10 @@ export async function executeSyncTask( syncSessionID = sessionID subagentSessions.add(sessionID) + if (args.category) { + SessionCategoryRegistry.register(sessionID, args.category) + } + if (onSyncSessionCreated) { log("[task] Invoking onSyncSessionCreated callback", { sessionID, parentID: parentContext.sessionID }) await onSyncSessionCreated({