diff --git a/src/hooks/model-fallback/hook.test.ts b/src/hooks/model-fallback/hook.test.ts index 4d30d5b0..348f163a 100644 --- a/src/hooks/model-fallback/hook.test.ts +++ b/src/hooks/model-fallback/hook.test.ts @@ -3,12 +3,15 @@ import { beforeEach, describe, expect, test } from "bun:test" import { clearPendingModelFallback, createModelFallbackHook, + setSessionFallbackChain, setPendingModelFallback, } from "./hook" describe("model fallback hook", () => { beforeEach(() => { clearPendingModelFallback("ses_model_fallback_main") + clearPendingModelFallback("ses_model_fallback_ghcp") + clearPendingModelFallback("ses_model_fallback_google") }) test("applies pending fallback on chat.message by overriding model", async () => { @@ -138,4 +141,92 @@ describe("model fallback hook", () => { expect(toastCalls.length).toBe(1) expect(toastCalls[0]?.title).toBe("Model fallback") }) + + test("transforms model names for github-copilot provider via fallback chain", async () => { + //#given + const sessionID = "ses_model_fallback_ghcp" + clearPendingModelFallback(sessionID) + + const hook = createModelFallbackHook() as unknown as { + "chat.message"?: ( + input: { sessionID: string }, + output: { message: Record; parts: Array<{ type: string; text?: string }> }, + ) => Promise + } + + // Set a custom fallback chain that routes through github-copilot + setSessionFallbackChain(sessionID, [ + { providers: ["github-copilot"], model: "claude-sonnet-4-6" }, + ]) + + const set = setPendingModelFallback( + sessionID, + "Atlas (Plan Executor)", + "github-copilot", + "claude-sonnet-4-6", + ) + expect(set).toBe(true) + + const output = { + message: { + model: { providerID: "github-copilot", modelID: "claude-sonnet-4-6" }, + }, + parts: [{ type: "text", text: "continue" }], + } + + //#when + await hook["chat.message"]?.({ sessionID }, output) + + //#then — model name should be transformed from hyphen to dot notation + expect(output.message["model"]).toEqual({ + providerID: "github-copilot", + modelID: "claude-sonnet-4.6", + }) + + clearPendingModelFallback(sessionID) + }) + + test("transforms model names for google provider via fallback chain", async () => { + //#given + const sessionID = "ses_model_fallback_google" + clearPendingModelFallback(sessionID) + + const hook = createModelFallbackHook() as unknown as { + "chat.message"?: ( + input: { sessionID: string }, + output: { message: Record; parts: Array<{ type: string; text?: string }> }, + ) => Promise + } + + // Set a custom fallback chain that routes through google + setSessionFallbackChain(sessionID, [ + { providers: ["google"], model: "gemini-3-pro" }, + ]) + + const set = setPendingModelFallback( + sessionID, + "Oracle", + "google", + "gemini-3-pro", + ) + expect(set).toBe(true) + + const output = { + message: { + model: { providerID: "google", modelID: "gemini-3-pro" }, + }, + parts: [{ type: "text", text: "continue" }], + } + + //#when + await hook["chat.message"]?.({ sessionID }, output) + + //#then — model name should be transformed from gemini-3-pro to gemini-3-pro-preview + expect(output.message["model"]).toEqual({ + providerID: "google", + modelID: "gemini-3-pro-preview", + }) + + clearPendingModelFallback(sessionID) + }) }) diff --git a/src/hooks/model-fallback/hook.ts b/src/hooks/model-fallback/hook.ts index fbe9deab..bbb01825 100644 --- a/src/hooks/model-fallback/hook.ts +++ b/src/hooks/model-fallback/hook.ts @@ -3,6 +3,7 @@ import { getAgentConfigKey } from "../../shared/agent-display-names" import { AGENT_MODEL_REQUIREMENTS } from "../../shared/model-requirements" import { readConnectedProvidersCache, readProviderModelsCache } from "../../shared/connected-providers-cache" import { selectFallbackProvider } from "../../shared/model-error-classifier" +import { transformModelForProvider } from "../../shared/provider-model-id-transform" import { log } from "../../shared/logger" import { getTaskToastManager } from "../../features/task-toast-manager" import type { ChatMessageInput, ChatMessageHandlerOutput } from "../../plugin/chat-message" @@ -145,7 +146,7 @@ export function getNextFallback( return { providerID, - modelID: fallback.model, + modelID: transformModelForProvider(providerID, fallback.model), variant: fallback.variant, } }