fix(model-fallback): apply transformModelForProvider in getNextFallback
fix(model-fallback): apply transformModelForProvider in getNextFallback
This commit is contained in:
commit
7a43737cd6
@ -3,12 +3,15 @@ import { beforeEach, describe, expect, test } from "bun:test"
|
||||
import {
|
||||
clearPendingModelFallback,
|
||||
createModelFallbackHook,
|
||||
setSessionFallbackChain,
|
||||
setPendingModelFallback,
|
||||
} from "./hook"
|
||||
|
||||
describe("model fallback hook", () => {
|
||||
beforeEach(() => {
|
||||
clearPendingModelFallback("ses_model_fallback_main")
|
||||
clearPendingModelFallback("ses_model_fallback_ghcp")
|
||||
clearPendingModelFallback("ses_model_fallback_google")
|
||||
})
|
||||
|
||||
test("applies pending fallback on chat.message by overriding model", async () => {
|
||||
@ -138,4 +141,92 @@ describe("model fallback hook", () => {
|
||||
expect(toastCalls.length).toBe(1)
|
||||
expect(toastCalls[0]?.title).toBe("Model fallback")
|
||||
})
|
||||
|
||||
test("transforms model names for github-copilot provider via fallback chain", async () => {
|
||||
//#given
|
||||
const sessionID = "ses_model_fallback_ghcp"
|
||||
clearPendingModelFallback(sessionID)
|
||||
|
||||
const hook = createModelFallbackHook() as unknown as {
|
||||
"chat.message"?: (
|
||||
input: { sessionID: string },
|
||||
output: { message: Record<string, unknown>; parts: Array<{ type: string; text?: string }> },
|
||||
) => Promise<void>
|
||||
}
|
||||
|
||||
// Set a custom fallback chain that routes through github-copilot
|
||||
setSessionFallbackChain(sessionID, [
|
||||
{ providers: ["github-copilot"], model: "claude-sonnet-4-6" },
|
||||
])
|
||||
|
||||
const set = setPendingModelFallback(
|
||||
sessionID,
|
||||
"Atlas (Plan Executor)",
|
||||
"github-copilot",
|
||||
"claude-sonnet-4-6",
|
||||
)
|
||||
expect(set).toBe(true)
|
||||
|
||||
const output = {
|
||||
message: {
|
||||
model: { providerID: "github-copilot", modelID: "claude-sonnet-4-6" },
|
||||
},
|
||||
parts: [{ type: "text", text: "continue" }],
|
||||
}
|
||||
|
||||
//#when
|
||||
await hook["chat.message"]?.({ sessionID }, output)
|
||||
|
||||
//#then — model name should be transformed from hyphen to dot notation
|
||||
expect(output.message["model"]).toEqual({
|
||||
providerID: "github-copilot",
|
||||
modelID: "claude-sonnet-4.6",
|
||||
})
|
||||
|
||||
clearPendingModelFallback(sessionID)
|
||||
})
|
||||
|
||||
test("transforms model names for google provider via fallback chain", async () => {
|
||||
//#given
|
||||
const sessionID = "ses_model_fallback_google"
|
||||
clearPendingModelFallback(sessionID)
|
||||
|
||||
const hook = createModelFallbackHook() as unknown as {
|
||||
"chat.message"?: (
|
||||
input: { sessionID: string },
|
||||
output: { message: Record<string, unknown>; parts: Array<{ type: string; text?: string }> },
|
||||
) => Promise<void>
|
||||
}
|
||||
|
||||
// Set a custom fallback chain that routes through google
|
||||
setSessionFallbackChain(sessionID, [
|
||||
{ providers: ["google"], model: "gemini-3-pro" },
|
||||
])
|
||||
|
||||
const set = setPendingModelFallback(
|
||||
sessionID,
|
||||
"Oracle",
|
||||
"google",
|
||||
"gemini-3-pro",
|
||||
)
|
||||
expect(set).toBe(true)
|
||||
|
||||
const output = {
|
||||
message: {
|
||||
model: { providerID: "google", modelID: "gemini-3-pro" },
|
||||
},
|
||||
parts: [{ type: "text", text: "continue" }],
|
||||
}
|
||||
|
||||
//#when
|
||||
await hook["chat.message"]?.({ sessionID }, output)
|
||||
|
||||
//#then — model name should be transformed from gemini-3-pro to gemini-3-pro-preview
|
||||
expect(output.message["model"]).toEqual({
|
||||
providerID: "google",
|
||||
modelID: "gemini-3-pro-preview",
|
||||
})
|
||||
|
||||
clearPendingModelFallback(sessionID)
|
||||
})
|
||||
})
|
||||
|
||||
@ -3,6 +3,7 @@ import { getAgentConfigKey } from "../../shared/agent-display-names"
|
||||
import { AGENT_MODEL_REQUIREMENTS } from "../../shared/model-requirements"
|
||||
import { readConnectedProvidersCache, readProviderModelsCache } from "../../shared/connected-providers-cache"
|
||||
import { selectFallbackProvider } from "../../shared/model-error-classifier"
|
||||
import { transformModelForProvider } from "../../shared/provider-model-id-transform"
|
||||
import { log } from "../../shared/logger"
|
||||
import { getTaskToastManager } from "../../features/task-toast-manager"
|
||||
import type { ChatMessageInput, ChatMessageHandlerOutput } from "../../plugin/chat-message"
|
||||
@ -145,7 +146,7 @@ export function getNextFallback(
|
||||
|
||||
return {
|
||||
providerID,
|
||||
modelID: fallback.model,
|
||||
modelID: transformModelForProvider(providerID, fallback.model),
|
||||
variant: fallback.variant,
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user