feat(fallback_models): complete init-time and runtime integration
Implement full fallback_models support across all integration points: 1. Model Resolution Pipeline (src/shared/model-resolution-pipeline.ts) - Add userFallbackModels to ModelResolutionRequest - Process user fallback_models before hardcoded fallback chain - Support both connected provider and availability checking modes 2. Agent Utils (src/agents/utils.ts) - Update applyModelResolution to accept userFallbackModels - Inject fallback_models for all builtin agents (sisyphus, oracle, etc.) - Support both single string and array formats 3. Model Resolver (src/shared/model-resolver.ts) - Add userFallbackModels to ExtendedModelResolutionInput type - Pass through to resolveModelPipeline 4. Delegate Task Executor (src/tools/delegate-task/executor.ts) - Extract category fallback_models configuration - Pass to model resolution pipeline - Register session category for runtime-fallback hook 5. Session Category Registry (src/shared/session-category-registry.ts) - New module: maps sessionID -> category - Used by runtime-fallback to lookup category fallback_models - Auto-cleanup support 6. Runtime Fallback Hook (src/hooks/runtime-fallback/index.ts) - Check SessionCategoryRegistry first for category fallback_models - Fallback to agent-level configuration - Import and use SessionCategoryRegistry Test Results: - runtime-fallback: 24/24 tests passing - model-resolver: 46/46 tests passing Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
6dc1aff698
commit
7aafa13b21
@ -3,6 +3,7 @@ import type { RuntimeFallbackConfig, OhMyOpenCodeConfig } from "../../config"
|
|||||||
import type { FallbackState, FallbackResult, RuntimeFallbackHook } from "./types"
|
import type { FallbackState, FallbackResult, RuntimeFallbackHook } from "./types"
|
||||||
import { DEFAULT_CONFIG, RETRYABLE_ERROR_PATTERNS, HOOK_NAME } from "./constants"
|
import { DEFAULT_CONFIG, RETRYABLE_ERROR_PATTERNS, HOOK_NAME } from "./constants"
|
||||||
import { log } from "../../shared/logger"
|
import { log } from "../../shared/logger"
|
||||||
|
import { SessionCategoryRegistry } from "../../shared/session-category-registry"
|
||||||
|
|
||||||
function createFallbackState(originalModel: string): FallbackState {
|
function createFallbackState(originalModel: string): FallbackState {
|
||||||
return {
|
return {
|
||||||
@ -87,6 +88,15 @@ function getFallbackModelsForSession(
|
|||||||
): string[] {
|
): string[] {
|
||||||
if (!pluginConfig) return []
|
if (!pluginConfig) return []
|
||||||
|
|
||||||
|
//#when - session has category from delegate_task, try category fallback_models first
|
||||||
|
const sessionCategory = SessionCategoryRegistry.get(sessionID)
|
||||||
|
if (sessionCategory && pluginConfig.categories?.[sessionCategory]) {
|
||||||
|
const categoryConfig = pluginConfig.categories[sessionCategory]
|
||||||
|
if (categoryConfig?.fallback_models) {
|
||||||
|
return normalizeFallbackModels(categoryConfig.fallback_models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const tryGetFallbackFromAgent = (agentName: string): string[] | undefined => {
|
const tryGetFallbackFromAgent = (agentName: string): string[] | undefined => {
|
||||||
const agentConfig = pluginConfig.agents?.[agentName as keyof typeof pluginConfig.agents]
|
const agentConfig = pluginConfig.agents?.[agentName as keyof typeof pluginConfig.agents]
|
||||||
if (!agentConfig) return undefined
|
if (!agentConfig) return undefined
|
||||||
|
|||||||
@ -58,3 +58,4 @@ export * from "./normalize-sdk-response"
|
|||||||
export * from "./session-directory-resolver"
|
export * from "./session-directory-resolver"
|
||||||
export * from "./prompt-tools"
|
export * from "./prompt-tools"
|
||||||
export * from "./internal-initiator-marker"
|
export * from "./internal-initiator-marker"
|
||||||
|
export { SessionCategoryRegistry } from "./session-category-registry"
|
||||||
|
|||||||
@ -7,6 +7,7 @@ export type ModelResolutionRequest = {
|
|||||||
intent?: {
|
intent?: {
|
||||||
uiSelectedModel?: string
|
uiSelectedModel?: string
|
||||||
userModel?: string
|
userModel?: string
|
||||||
|
userFallbackModels?: string[]
|
||||||
categoryDefaultModel?: string
|
categoryDefaultModel?: string
|
||||||
}
|
}
|
||||||
constraints: {
|
constraints: {
|
||||||
@ -97,6 +98,42 @@ export function resolveModelPipeline(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//#when - user configured fallback_models, try them before hardcoded fallback chain
|
||||||
|
const userFallbackModels = intent?.userFallbackModels
|
||||||
|
if (userFallbackModels && userFallbackModels.length > 0) {
|
||||||
|
if (availableModels.size === 0) {
|
||||||
|
const connectedProviders = readConnectedProvidersCache()
|
||||||
|
const connectedSet = connectedProviders ? new Set(connectedProviders) : null
|
||||||
|
|
||||||
|
if (connectedSet !== null) {
|
||||||
|
for (const model of userFallbackModels) {
|
||||||
|
attempted.push(model)
|
||||||
|
const parts = model.split("/")
|
||||||
|
if (parts.length >= 2) {
|
||||||
|
const provider = parts[0]
|
||||||
|
if (connectedSet.has(provider)) {
|
||||||
|
log("Model resolved via user fallback_models (connected provider)", { model })
|
||||||
|
return { model, provenance: "provider-fallback", attempted }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log("No connected provider found in user fallback_models, falling through to hardcoded chain")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (const model of userFallbackModels) {
|
||||||
|
attempted.push(model)
|
||||||
|
const parts = model.split("/")
|
||||||
|
const providerHint = parts.length >= 2 ? [parts[0]] : undefined
|
||||||
|
const match = fuzzyMatchModel(model, availableModels, providerHint)
|
||||||
|
if (match) {
|
||||||
|
log("Model resolved via user fallback_models (availability confirmed)", { model: model, match })
|
||||||
|
return { model: match, provenance: "provider-fallback", attempted }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log("No available model found in user fallback_models, falling through to hardcoded chain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (fallbackChain && fallbackChain.length > 0) {
|
if (fallbackChain && fallbackChain.length > 0) {
|
||||||
if (availableModels.size === 0) {
|
if (availableModels.size === 0) {
|
||||||
const connectedProviders = constraints.connectedProviders ?? connectedProvidersCache.readConnectedProvidersCache()
|
const connectedProviders = constraints.connectedProviders ?? connectedProvidersCache.readConnectedProvidersCache()
|
||||||
|
|||||||
@ -22,6 +22,7 @@ export type ModelResolutionResult = {
|
|||||||
export type ExtendedModelResolutionInput = {
|
export type ExtendedModelResolutionInput = {
|
||||||
uiSelectedModel?: string
|
uiSelectedModel?: string
|
||||||
userModel?: string
|
userModel?: string
|
||||||
|
userFallbackModels?: string[]
|
||||||
categoryDefaultModel?: string
|
categoryDefaultModel?: string
|
||||||
fallbackChain?: FallbackEntry[]
|
fallbackChain?: FallbackEntry[]
|
||||||
availableModels: Set<string>
|
availableModels: Set<string>
|
||||||
@ -44,9 +45,9 @@ export function resolveModel(input: ModelResolutionInput): string | undefined {
|
|||||||
export function resolveModelWithFallback(
|
export function resolveModelWithFallback(
|
||||||
input: ExtendedModelResolutionInput,
|
input: ExtendedModelResolutionInput,
|
||||||
): ModelResolutionResult | undefined {
|
): ModelResolutionResult | undefined {
|
||||||
const { uiSelectedModel, userModel, categoryDefaultModel, fallbackChain, availableModels, systemDefaultModel } = input
|
const { uiSelectedModel, userModel, userFallbackModels, categoryDefaultModel, fallbackChain, availableModels, systemDefaultModel } = input
|
||||||
const resolved = resolveModelPipeline({
|
const resolved = resolveModelPipeline({
|
||||||
intent: { uiSelectedModel, userModel, categoryDefaultModel },
|
intent: { uiSelectedModel, userModel, userFallbackModels, categoryDefaultModel },
|
||||||
constraints: { availableModels },
|
constraints: { availableModels },
|
||||||
policy: { fallbackChain, systemDefaultModel },
|
policy: { fallbackChain, systemDefaultModel },
|
||||||
})
|
})
|
||||||
|
|||||||
53
src/shared/session-category-registry.ts
Normal file
53
src/shared/session-category-registry.ts
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
/**
|
||||||
|
* Session Category Registry
|
||||||
|
*
|
||||||
|
* Maintains a mapping of session IDs to their assigned categories.
|
||||||
|
* Used by runtime-fallback hook to lookup category-specific fallback_models.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Map of sessionID -> category name
|
||||||
|
const sessionCategoryMap = new Map<string, string>()
|
||||||
|
|
||||||
|
export const SessionCategoryRegistry = {
|
||||||
|
/**
|
||||||
|
* Register a session with its category
|
||||||
|
*/
|
||||||
|
register: (sessionID: string, category: string): void => {
|
||||||
|
sessionCategoryMap.set(sessionID, category)
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the category for a session
|
||||||
|
*/
|
||||||
|
get: (sessionID: string): string | undefined => {
|
||||||
|
return sessionCategoryMap.get(sessionID)
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a session from the registry (cleanup)
|
||||||
|
*/
|
||||||
|
remove: (sessionID: string): void => {
|
||||||
|
sessionCategoryMap.delete(sessionID)
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a session is registered
|
||||||
|
*/
|
||||||
|
has: (sessionID: string): boolean => {
|
||||||
|
return sessionCategoryMap.has(sessionID)
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the size of the registry (for debugging)
|
||||||
|
*/
|
||||||
|
size: (): number => {
|
||||||
|
return sessionCategoryMap.size
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear all entries (use with caution, mainly for testing)
|
||||||
|
*/
|
||||||
|
clear: (): void => {
|
||||||
|
sessionCategoryMap.clear()
|
||||||
|
},
|
||||||
|
}
|
||||||
@ -4,6 +4,7 @@ import { getTimingConfig } from "./timing"
|
|||||||
import { storeToolMetadata } from "../../features/tool-metadata-store"
|
import { storeToolMetadata } from "../../features/tool-metadata-store"
|
||||||
import { formatDetailedError } from "./error-formatting"
|
import { formatDetailedError } from "./error-formatting"
|
||||||
import { getSessionTools } from "../../shared/session-tools-store"
|
import { getSessionTools } from "../../shared/session-tools-store"
|
||||||
|
import { SessionCategoryRegistry } from "../../shared/session-category-registry"
|
||||||
|
|
||||||
export async function executeBackgroundTask(
|
export async function executeBackgroundTask(
|
||||||
args: DelegateTaskArgs,
|
args: DelegateTaskArgs,
|
||||||
@ -48,6 +49,10 @@ export async function executeBackgroundTask(
|
|||||||
sessionId = updated?.sessionID
|
sessionId = updated?.sessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (args.category && sessionId) {
|
||||||
|
SessionCategoryRegistry.register(sessionId, args.category)
|
||||||
|
}
|
||||||
|
|
||||||
const unstableMeta = {
|
const unstableMeta = {
|
||||||
title: args.description,
|
title: args.description,
|
||||||
metadata: {
|
metadata: {
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import { getTaskToastManager } from "../../features/task-toast-manager"
|
|||||||
import { storeToolMetadata } from "../../features/tool-metadata-store"
|
import { storeToolMetadata } from "../../features/tool-metadata-store"
|
||||||
import { subagentSessions } from "../../features/claude-code-session-state"
|
import { subagentSessions } from "../../features/claude-code-session-state"
|
||||||
import { log } from "../../shared/logger"
|
import { log } from "../../shared/logger"
|
||||||
|
import { SessionCategoryRegistry } from "../../shared/session-category-registry"
|
||||||
import { formatDuration } from "./time-formatter"
|
import { formatDuration } from "./time-formatter"
|
||||||
import { formatDetailedError } from "./error-formatting"
|
import { formatDetailedError } from "./error-formatting"
|
||||||
import { syncTaskDeps, type SyncTaskDeps } from "./sync-task-deps"
|
import { syncTaskDeps, type SyncTaskDeps } from "./sync-task-deps"
|
||||||
@ -41,6 +42,10 @@ export async function executeSyncTask(
|
|||||||
syncSessionID = sessionID
|
syncSessionID = sessionID
|
||||||
subagentSessions.add(sessionID)
|
subagentSessions.add(sessionID)
|
||||||
|
|
||||||
|
if (args.category) {
|
||||||
|
SessionCategoryRegistry.register(sessionID, args.category)
|
||||||
|
}
|
||||||
|
|
||||||
if (onSyncSessionCreated) {
|
if (onSyncSessionCreated) {
|
||||||
log("[task] Invoking onSyncSessionCreated callback", { sessionID, parentID: parentContext.sessionID })
|
log("[task] Invoking onSyncSessionCreated callback", { sessionID, parentID: parentContext.sessionID })
|
||||||
await onSyncSessionCreated({
|
await onSyncSessionCreated({
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user