fix: honor agent variant overrides (#1394)
* fix(shared): honor agent variant overrides * test(shared): use model in fallback chain to verify override precedence Address PR review: test now uses claude-opus-4-5 (which has default variant 'max' in sisyphus chain) to properly verify that agent override 'high' takes precedence over the fallback chain's default variant.
This commit is contained in:
parent
d165a6821d
commit
5c68ae3bee
@ -83,6 +83,23 @@ describe("applyAgentVariant", () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
describe("resolveVariantForModel", () => {
|
describe("resolveVariantForModel", () => {
|
||||||
|
test("returns agent override variant when configured", () => {
|
||||||
|
// given - use a model in sisyphus chain (claude-opus-4-5 has default variant "max")
|
||||||
|
// to verify override takes precedence over fallback chain
|
||||||
|
const config = {
|
||||||
|
agents: {
|
||||||
|
sisyphus: { variant: "high" },
|
||||||
|
},
|
||||||
|
} as OhMyOpenCodeConfig
|
||||||
|
const model = { providerID: "anthropic", modelID: "claude-opus-4-5" }
|
||||||
|
|
||||||
|
// when
|
||||||
|
const variant = resolveVariantForModel(config, "sisyphus", model)
|
||||||
|
|
||||||
|
// then
|
||||||
|
expect(variant).toBe("high")
|
||||||
|
})
|
||||||
|
|
||||||
test("returns correct variant for anthropic provider", () => {
|
test("returns correct variant for anthropic provider", () => {
|
||||||
// given
|
// given
|
||||||
const config = {} as OhMyOpenCodeConfig
|
const config = {} as OhMyOpenCodeConfig
|
||||||
|
|||||||
@ -37,23 +37,26 @@ export function resolveVariantForModel(
|
|||||||
agentName: string,
|
agentName: string,
|
||||||
currentModel: { providerID: string; modelID: string },
|
currentModel: { providerID: string; modelID: string },
|
||||||
): string | undefined {
|
): string | undefined {
|
||||||
const agentRequirement = AGENT_MODEL_REQUIREMENTS[agentName]
|
|
||||||
if (agentRequirement) {
|
|
||||||
return findVariantInChain(agentRequirement.fallbackChain, currentModel.providerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
const agentOverrides = config.agents as
|
const agentOverrides = config.agents as
|
||||||
| Record<string, { category?: string }>
|
| Record<string, { variant?: string; category?: string }>
|
||||||
| undefined
|
| undefined
|
||||||
const agentOverride = agentOverrides
|
const agentOverride = agentOverrides
|
||||||
? agentOverrides[agentName]
|
? agentOverrides[agentName]
|
||||||
?? Object.entries(agentOverrides).find(([key]) => key.toLowerCase() === agentName.toLowerCase())?.[1]
|
?? Object.entries(agentOverrides).find(([key]) => key.toLowerCase() === agentName.toLowerCase())?.[1]
|
||||||
: undefined
|
: undefined
|
||||||
|
if (agentOverride?.variant) {
|
||||||
|
return agentOverride.variant
|
||||||
|
}
|
||||||
|
|
||||||
|
const agentRequirement = AGENT_MODEL_REQUIREMENTS[agentName]
|
||||||
|
if (agentRequirement) {
|
||||||
|
return findVariantInChain(agentRequirement.fallbackChain, currentModel)
|
||||||
|
}
|
||||||
const categoryName = agentOverride?.category
|
const categoryName = agentOverride?.category
|
||||||
if (categoryName) {
|
if (categoryName) {
|
||||||
const categoryRequirement = CATEGORY_MODEL_REQUIREMENTS[categoryName]
|
const categoryRequirement = CATEGORY_MODEL_REQUIREMENTS[categoryName]
|
||||||
if (categoryRequirement) {
|
if (categoryRequirement) {
|
||||||
return findVariantInChain(categoryRequirement.fallbackChain, currentModel.providerID)
|
return findVariantInChain(categoryRequirement.fallbackChain, currentModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,10 +65,13 @@ export function resolveVariantForModel(
|
|||||||
|
|
||||||
function findVariantInChain(
|
function findVariantInChain(
|
||||||
fallbackChain: { providers: string[]; model: string; variant?: string }[],
|
fallbackChain: { providers: string[]; model: string; variant?: string }[],
|
||||||
providerID: string,
|
currentModel: { providerID: string; modelID: string },
|
||||||
): string | undefined {
|
): string | undefined {
|
||||||
for (const entry of fallbackChain) {
|
for (const entry of fallbackChain) {
|
||||||
if (entry.providers.includes(providerID)) {
|
if (
|
||||||
|
entry.providers.includes(currentModel.providerID)
|
||||||
|
&& entry.model === currentModel.modelID
|
||||||
|
) {
|
||||||
return entry.variant
|
return entry.variant
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user