diff --git a/src/shared/model-availability.test.ts b/src/shared/model-availability.test.ts index cdd41b03..cbaed0f5 100644 --- a/src/shared/model-availability.test.ts +++ b/src/shared/model-availability.test.ts @@ -1,38 +1,43 @@ -import { describe, it, expect, beforeEach, afterEach, beforeAll, afterAll, mock } from "bun:test" +declare const require: (name: string) => any +const { describe, it, expect, beforeEach, afterEach, beforeAll } = require("bun:test") import { mkdtempSync, writeFileSync, rmSync } from "fs" import { tmpdir } from "os" import { join } from "path" -import { fuzzyMatchModel, isModelAvailable } from "./model-name-matcher" -let activeCacheHomeDir: string | null = null -const DEFAULT_CACHE_HOME_DIR = join(tmpdir(), "opencode-test-default-cache") +let __resetModelCache: () => void +let fetchAvailableModels: (client?: unknown, options?: { connectedProviders?: string[] | null }) => Promise> +let fuzzyMatchModel: (target: string, available: Set, providers?: string[]) => string | null +let isModelAvailable: (targetModel: string, availableModels: Set) => boolean +let getConnectedProviders: (client: unknown) => Promise -mock.module("./data-path", () => ({ - getDataDir: () => activeCacheHomeDir ?? DEFAULT_CACHE_HOME_DIR, - getOpenCodeStorageDir: () => join(activeCacheHomeDir ?? DEFAULT_CACHE_HOME_DIR, "opencode", "storage"), - getCacheDir: () => activeCacheHomeDir ?? DEFAULT_CACHE_HOME_DIR, - getOmoOpenCodeCacheDir: () => join(activeCacheHomeDir ?? DEFAULT_CACHE_HOME_DIR, "oh-my-opencode"), - getOpenCodeCacheDir: () => join(activeCacheHomeDir ?? DEFAULT_CACHE_HOME_DIR, "opencode"), -})) +beforeAll(async () => { + ;({ + __resetModelCache, + fetchAvailableModels, + fuzzyMatchModel, + isModelAvailable, + getConnectedProviders, + } = await import("./model-availability")) +}) describe("fetchAvailableModels", () => { let tempDir: string - let fetchAvailableModels: (client?: unknown, options?: { connectedProviders?: string[] | null }) => Promise> - let __resetModelCache: () => void + let originalXdgCache: string | undefined - beforeAll(async () => { - ;({ fetchAvailableModels } = await import("./available-models-fetcher")) - ;({ __resetModelCache } = await import("./model-cache-availability")) - }) beforeEach(() => { __resetModelCache() tempDir = mkdtempSync(join(tmpdir(), "opencode-test-")) - activeCacheHomeDir = tempDir + originalXdgCache = process.env.XDG_CACHE_HOME + process.env.XDG_CACHE_HOME = tempDir }) afterEach(() => { - activeCacheHomeDir = null + if (originalXdgCache !== undefined) { + process.env.XDG_CACHE_HOME = originalXdgCache + } else { + delete process.env.XDG_CACHE_HOME + } rmSync(tempDir, { recursive: true, force: true }) }) diff --git a/src/shared/model-availability.ts b/src/shared/model-availability.ts index 6fa2fb17..1ff696ee 100644 --- a/src/shared/model-availability.ts +++ b/src/shared/model-availability.ts @@ -2,7 +2,7 @@ import { existsSync, readFileSync } from "fs" import { join } from "path" import { log } from "./logger" import { getOpenCodeCacheDir } from "./data-path" -import { readProviderModelsCache, hasProviderModelsCache, readConnectedProvidersCache } from "./connected-providers-cache" +import * as connectedProvidersCache from "./connected-providers-cache" /** * Fuzzy match a target model name against available models @@ -181,7 +181,7 @@ export async function fetchAvailableModels( const connectedSet = new Set(connectedProvidersList) const modelSet = new Set() - const providerModelsCache = readProviderModelsCache() + const providerModelsCache = connectedProvidersCache.readProviderModelsCache() if (providerModelsCache) { const providerCount = Object.keys(providerModelsCache.models).length if (providerCount === 0) { @@ -189,7 +189,8 @@ export async function fetchAvailableModels( } else { log("[fetchAvailableModels] using provider-models cache (whitelist-filtered)") - for (const [providerId, modelIds] of Object.entries(providerModelsCache.models)) { + const modelsByProvider = providerModelsCache.models as Record> + for (const [providerId, modelIds] of Object.entries(modelsByProvider)) { if (!connectedSet.has(providerId)) { continue } @@ -300,7 +301,7 @@ export function isAnyFallbackModelAvailable( // Fallback: check if any provider in the chain is connected // This handles race conditions where availableModels is empty or incomplete // but we know the provider is connected. - const connectedProviders = readConnectedProvidersCache() + const connectedProviders = connectedProvidersCache.readConnectedProvidersCache() if (connectedProviders) { const connectedSet = new Set(connectedProviders) for (const entry of fallbackChain) { @@ -332,7 +333,7 @@ export function isAnyProviderConnected( } } - const connectedProviders = readConnectedProvidersCache() + const connectedProviders = connectedProvidersCache.readConnectedProvidersCache() if (connectedProviders) { const connectedSet = new Set(connectedProviders) for (const provider of providers) { @@ -349,7 +350,7 @@ export function isAnyProviderConnected( export function __resetModelCache(): void {} export function isModelCacheAvailable(): boolean { - if (hasProviderModelsCache()) { + if (connectedProvidersCache.hasProviderModelsCache()) { return true } const cacheFile = join(getOpenCodeCacheDir(), "models.json") diff --git a/src/shared/model-resolution-pipeline.ts b/src/shared/model-resolution-pipeline.ts index 552746c8..34d1c13b 100644 --- a/src/shared/model-resolution-pipeline.ts +++ b/src/shared/model-resolution-pipeline.ts @@ -1,5 +1,5 @@ import { log } from "./logger" -import { readConnectedProvidersCache } from "./connected-providers-cache" +import * as connectedProvidersCache from "./connected-providers-cache" import { fuzzyMatchModel } from "./model-availability" import type { FallbackEntry } from "./model-requirements" @@ -11,6 +11,7 @@ export type ModelResolutionRequest = { } constraints: { availableModels: Set + connectedProviders?: string[] | null } policy?: { fallbackChain?: FallbackEntry[] @@ -73,7 +74,7 @@ export function resolveModelPipeline( return { model: match, provenance: "category-default", attempted } } } else { - const connectedProviders = readConnectedProvidersCache() + const connectedProviders = constraints.connectedProviders ?? connectedProvidersCache.readConnectedProvidersCache() if (connectedProviders === null) { log("Model resolved via category default (no cache, first run)", { model: normalizedCategoryDefault, @@ -98,7 +99,7 @@ export function resolveModelPipeline( if (fallbackChain && fallbackChain.length > 0) { if (availableModels.size === 0) { - const connectedProviders = readConnectedProvidersCache() + const connectedProviders = constraints.connectedProviders ?? connectedProvidersCache.readConnectedProvidersCache() const connectedSet = connectedProviders ? new Set(connectedProviders) : null if (connectedSet === null) {