diff --git a/src/shared/available-models-fetcher.ts b/src/shared/available-models-fetcher.ts new file mode 100644 index 00000000..b19defce --- /dev/null +++ b/src/shared/available-models-fetcher.ts @@ -0,0 +1,114 @@ +import { addModelsFromModelsJsonCache } from "./models-json-cache-reader" +import { getModelListFunction, getProviderListFunction } from "./open-code-client-accessors" +import { addModelsFromProviderModelsCache } from "./provider-models-cache-model-reader" +import { log } from "./logger" + +export async function getConnectedProviders(client: unknown): Promise { + const providerList = getProviderListFunction(client) + if (!providerList) { + log("[getConnectedProviders] client.provider.list not available") + return [] + } + + try { + const result = await providerList() + const connected = result.data?.connected ?? [] + log("[getConnectedProviders] connected providers", { + count: connected.length, + providers: connected, + }) + return connected + } catch (err) { + log("[getConnectedProviders] SDK error", { error: String(err) }) + return [] + } +} + +export async function fetchAvailableModels( + client?: unknown, + options?: { connectedProviders?: string[] | null }, +): Promise> { + let connectedProviders = options?.connectedProviders ?? null + let connectedProvidersUnknown = connectedProviders === null + + log("[fetchAvailableModels] CALLED", { + connectedProvidersUnknown, + connectedProviders: options?.connectedProviders, + }) + + if (connectedProvidersUnknown && client !== undefined) { + const liveConnected = await getConnectedProviders(client) + if (liveConnected.length > 0) { + connectedProviders = liveConnected + connectedProvidersUnknown = false + log("[fetchAvailableModels] connected providers fetched from client", { + count: liveConnected.length, + }) + } + } + + if (connectedProvidersUnknown) { + const modelList = client === undefined ? null : getModelListFunction(client) + if (modelList) { + const modelSet = new Set() + try { + const modelsResult = await modelList() + const models = modelsResult.data ?? [] + for (const model of models) { + if (model.provider && model.id) { + modelSet.add(`${model.provider}/${model.id}`) + } + } + log( + "[fetchAvailableModels] fetched models from client without provider filter", + { count: modelSet.size }, + ) + return modelSet + } catch (err) { + log("[fetchAvailableModels] client.model.list error", { + error: String(err), + }) + } + } + log( + "[fetchAvailableModels] connected providers unknown, returning empty set for fallback resolution", + ) + return new Set() + } + + const connectedProvidersList = connectedProviders ?? [] + const connectedSet = new Set(connectedProvidersList) + const modelSet = new Set() + + if (addModelsFromProviderModelsCache(connectedSet, modelSet)) { + return modelSet + } + log("[fetchAvailableModels] provider-models cache not found, falling back to models.json") + if (addModelsFromModelsJsonCache(connectedSet, modelSet)) { + return modelSet + } + + const modelList = client === undefined ? null : getModelListFunction(client) + if (modelList) { + try { + const modelsResult = await modelList() + const models = modelsResult.data ?? [] + + for (const model of models) { + if (!model.provider || !model.id) continue + if (connectedSet.has(model.provider)) { + modelSet.add(`${model.provider}/${model.id}`) + } + } + + log("[fetchAvailableModels] fetched models from client (filtered)", { + count: modelSet.size, + connectedProviders: connectedProvidersList.slice(0, 5), + }) + } catch (err) { + log("[fetchAvailableModels] client.model.list error", { error: String(err) }) + } + } + + return modelSet +} diff --git a/src/shared/fallback-model-availability.ts b/src/shared/fallback-model-availability.ts new file mode 100644 index 00000000..f6fc30cc --- /dev/null +++ b/src/shared/fallback-model-availability.ts @@ -0,0 +1,67 @@ +import { readConnectedProvidersCache } from "./connected-providers-cache" +import { log } from "./logger" +import { fuzzyMatchModel } from "./model-name-matcher" + +export function isAnyFallbackModelAvailable( + fallbackChain: Array<{ providers: string[]; model: string }>, + availableModels: Set, +): boolean { + if (availableModels.size > 0) { + for (const entry of fallbackChain) { + const hasAvailableProvider = entry.providers.some((provider) => { + return fuzzyMatchModel(entry.model, availableModels, [provider]) !== null + }) + if (hasAvailableProvider) { + return true + } + } + } + + const connectedProviders = readConnectedProvidersCache() + if (connectedProviders) { + const connectedSet = new Set(connectedProviders) + for (const entry of fallbackChain) { + if (entry.providers.some((p) => connectedSet.has(p))) { + log( + "[isAnyFallbackModelAvailable] model not in available set, but provider is connected", + { model: entry.model, availableCount: availableModels.size }, + ) + return true + } + } + } + + return false +} + +export function isAnyProviderConnected( + providers: string[], + availableModels: Set, +): boolean { + if (availableModels.size > 0) { + const providerSet = new Set(providers) + for (const model of availableModels) { + const [provider] = model.split("/") + if (providerSet.has(provider)) { + log("[isAnyProviderConnected] found model from required provider", { + provider, + model, + }) + return true + } + } + } + + const connectedProviders = readConnectedProvidersCache() + if (connectedProviders) { + const connectedSet = new Set(connectedProviders) + for (const provider of providers) { + if (connectedSet.has(provider)) { + log("[isAnyProviderConnected] provider connected via cache", { provider }) + return true + } + } + } + + return false +} diff --git a/src/shared/model-availability.ts b/src/shared/model-availability.ts index 6fa2fb17..ad5df3b4 100644 --- a/src/shared/model-availability.ts +++ b/src/shared/model-availability.ts @@ -1,357 +1,4 @@ -import { existsSync, readFileSync } from "fs" -import { join } from "path" -import { log } from "./logger" -import { getOpenCodeCacheDir } from "./data-path" -import { readProviderModelsCache, hasProviderModelsCache, readConnectedProvidersCache } from "./connected-providers-cache" - -/** - * Fuzzy match a target model name against available models - * - * @param target - The model name or substring to search for (e.g., "gpt-5.2", "claude-opus") - * @param available - Set of available model names in format "provider/model-name" - * @param providers - Optional array of provider names to filter by (e.g., ["openai", "anthropic"]) - * @returns The matched model name or null if no match found - * - * Matching priority: - * 1. Exact match (if exists) - * 2. Shorter model name (more specific) - * - * Matching is case-insensitive substring match. - * If providers array is given, only models starting with "provider/" are considered. - * - * @example - * const available = new Set(["openai/gpt-5.2", "openai/gpt-5.3-codex", "anthropic/claude-opus-4-6"]) - * fuzzyMatchModel("gpt-5.2", available) // → "openai/gpt-5.2" - * fuzzyMatchModel("claude", available, ["openai"]) // → null (provider filter excludes anthropic) - */ -function normalizeModelName(name: string): string { - return name - .toLowerCase() - .replace(/claude-(opus|sonnet|haiku)-4-5/g, "claude-$1-4.5") - .replace(/claude-(opus|sonnet|haiku)-4\.5/g, "claude-$1-4.5") -} - -export function fuzzyMatchModel( - target: string, - available: Set, - providers?: string[], -): string | null { - log("[fuzzyMatchModel] called", { target, availableCount: available.size, providers }) - - if (available.size === 0) { - log("[fuzzyMatchModel] empty available set") - return null - } - - const targetNormalized = normalizeModelName(target) - - // Filter by providers if specified - let candidates = Array.from(available) - if (providers && providers.length > 0) { - const providerSet = new Set(providers) - candidates = candidates.filter((model) => { - const [provider] = model.split("/") - return providerSet.has(provider) - }) - log("[fuzzyMatchModel] filtered by providers", { candidateCount: candidates.length, candidates: candidates.slice(0, 10) }) - } - - if (candidates.length === 0) { - log("[fuzzyMatchModel] no candidates after filter") - return null - } - - // Find all matches (case-insensitive substring match with normalization) - const matches = candidates.filter((model) => - normalizeModelName(model).includes(targetNormalized), - ) - - log("[fuzzyMatchModel] substring matches", { targetNormalized, matchCount: matches.length, matches }) - - if (matches.length === 0) { - return null - } - - // Priority 1: Exact match (normalized full model string) - const exactMatch = matches.find((model) => normalizeModelName(model) === targetNormalized) - if (exactMatch) { - log("[fuzzyMatchModel] exact match found", { exactMatch }) - return exactMatch - } - - // Priority 2: Exact model ID match (part after provider/) - // This ensures "glm-4.7-free" matches "zai-coding-plan/glm-4.7-free" over "zai-coding-plan/glm-4.7" - // Use filter + shortest to handle multi-provider cases (e.g., openai/gpt-5.2 + opencode/gpt-5.2) - const exactModelIdMatches = matches.filter((model) => { - const modelId = model.split("/").slice(1).join("/") - return normalizeModelName(modelId) === targetNormalized - }) - if (exactModelIdMatches.length > 0) { - const result = exactModelIdMatches.reduce((shortest, current) => - current.length < shortest.length ? current : shortest, - ) - log("[fuzzyMatchModel] exact model ID match found", { result, candidateCount: exactModelIdMatches.length }) - return result - } - - // Priority 3: Shorter model name (more specific, fallback for partial matches) - const result = matches.reduce((shortest, current) => - current.length < shortest.length ? current : shortest, - ) - log("[fuzzyMatchModel] shortest match", { result }) - return result -} - -/** - * Check if a target model is available (fuzzy match by model name, no provider filtering) - * - * @param targetModel - Model name to check (e.g., "gpt-5.3-codex") - * @param availableModels - Set of available models in "provider/model" format - * @returns true if model is available, false otherwise - */ -export function isModelAvailable( - targetModel: string, - availableModels: Set, -): boolean { - return fuzzyMatchModel(targetModel, availableModels) !== null -} - -export async function getConnectedProviders(client: any): Promise { - if (!client?.provider?.list) { - log("[getConnectedProviders] client.provider.list not available") - return [] - } - - try { - const result = await client.provider.list() - const connected = result.data?.connected ?? [] - log("[getConnectedProviders] connected providers", { count: connected.length, providers: connected }) - return connected - } catch (err) { - log("[getConnectedProviders] SDK error", { error: String(err) }) - return [] - } -} - -export async function fetchAvailableModels( - client?: any, - options?: { connectedProviders?: string[] | null } -): Promise> { - let connectedProviders = options?.connectedProviders ?? null - let connectedProvidersUnknown = connectedProviders === null - - log("[fetchAvailableModels] CALLED", { - connectedProvidersUnknown, - connectedProviders: options?.connectedProviders - }) - - if (connectedProvidersUnknown && client) { - const liveConnected = await getConnectedProviders(client) - if (liveConnected.length > 0) { - connectedProviders = liveConnected - connectedProvidersUnknown = false - log("[fetchAvailableModels] connected providers fetched from client", { count: liveConnected.length }) - } - } - - if (connectedProvidersUnknown) { - if (client?.model?.list) { - const modelSet = new Set() - try { - const modelsResult = await client.model.list() - const models = modelsResult.data ?? [] - for (const model of models) { - if (model?.provider && model?.id) { - modelSet.add(`${model.provider}/${model.id}`) - } - } - log("[fetchAvailableModels] fetched models from client without provider filter", { - count: modelSet.size, - }) - return modelSet - } catch (err) { - log("[fetchAvailableModels] client.model.list error", { error: String(err) }) - } - } - log("[fetchAvailableModels] connected providers unknown, returning empty set for fallback resolution") - return new Set() - } - - const connectedProvidersList = connectedProviders ?? [] - const connectedSet = new Set(connectedProvidersList) - const modelSet = new Set() - - const providerModelsCache = readProviderModelsCache() - if (providerModelsCache) { - const providerCount = Object.keys(providerModelsCache.models).length - if (providerCount === 0) { - log("[fetchAvailableModels] provider-models cache empty, falling back to models.json") - } else { - log("[fetchAvailableModels] using provider-models cache (whitelist-filtered)") - - for (const [providerId, modelIds] of Object.entries(providerModelsCache.models)) { - if (!connectedSet.has(providerId)) { - continue - } - for (const modelItem of modelIds) { - // Handle both string[] (legacy) and object[] (with metadata) formats - const modelId = typeof modelItem === 'string' - ? modelItem - : (modelItem as any)?.id - - if (modelId) { - modelSet.add(`${providerId}/${modelId}`) - } - } - } - - log("[fetchAvailableModels] parsed from provider-models cache", { - count: modelSet.size, - connectedProviders: connectedProvidersList.slice(0, 5) - }) - - if (modelSet.size > 0) { - return modelSet - } - log("[fetchAvailableModels] provider-models cache produced no models for connected providers, falling back to models.json") - } - } - - log("[fetchAvailableModels] provider-models cache not found, falling back to models.json") - const cacheFile = join(getOpenCodeCacheDir(), "models.json") - - if (!existsSync(cacheFile)) { - log("[fetchAvailableModels] models.json cache file not found, falling back to client") - } else { - try { - const content = readFileSync(cacheFile, "utf-8") - const data = JSON.parse(content) as Record }> - - const providerIds = Object.keys(data) - log("[fetchAvailableModels] providers found in models.json", { count: providerIds.length, providers: providerIds.slice(0, 10) }) - - for (const providerId of providerIds) { - if (!connectedSet.has(providerId)) { - continue - } - - const provider = data[providerId] - const models = provider?.models - if (!models || typeof models !== "object") continue - - for (const modelKey of Object.keys(models)) { - modelSet.add(`${providerId}/${modelKey}`) - } - } - - log("[fetchAvailableModels] parsed models from models.json (NO whitelist filtering)", { - count: modelSet.size, - connectedProviders: connectedProvidersList.slice(0, 5) - }) - - if (modelSet.size > 0) { - return modelSet - } - } catch (err) { - log("[fetchAvailableModels] error", { error: String(err) }) - } - } - - if (client?.model?.list) { - try { - const modelsResult = await client.model.list() - const models = modelsResult.data ?? [] - - for (const model of models) { - if (!model?.provider || !model?.id) continue - if (connectedSet.has(model.provider)) { - modelSet.add(`${model.provider}/${model.id}`) - } - } - - log("[fetchAvailableModels] fetched models from client (filtered)", { - count: modelSet.size, - connectedProviders: connectedProvidersList.slice(0, 5), - }) - } catch (err) { - log("[fetchAvailableModels] client.model.list error", { error: String(err) }) - } - } - - return modelSet -} - -export function isAnyFallbackModelAvailable( - fallbackChain: Array<{ providers: string[]; model: string }>, - availableModels: Set, -): boolean { - // If we have models, check them first - if (availableModels.size > 0) { - for (const entry of fallbackChain) { - const hasAvailableProvider = entry.providers.some((provider) => { - return fuzzyMatchModel(entry.model, availableModels, [provider]) !== null - }) - if (hasAvailableProvider) { - return true - } - } - } - - // Fallback: check if any provider in the chain is connected - // This handles race conditions where availableModels is empty or incomplete - // but we know the provider is connected. - const connectedProviders = readConnectedProvidersCache() - if (connectedProviders) { - const connectedSet = new Set(connectedProviders) - for (const entry of fallbackChain) { - if (entry.providers.some((p) => connectedSet.has(p))) { - log("[isAnyFallbackModelAvailable] model not in available set, but provider is connected", { - model: entry.model, - availableCount: availableModels.size, - }) - return true - } - } - } - - return false -} - -export function isAnyProviderConnected( - providers: string[], - availableModels: Set, -): boolean { - if (availableModels.size > 0) { - const providerSet = new Set(providers) - for (const model of availableModels) { - const [provider] = model.split("/") - if (providerSet.has(provider)) { - log("[isAnyProviderConnected] found model from required provider", { provider, model }) - return true - } - } - } - - const connectedProviders = readConnectedProvidersCache() - if (connectedProviders) { - const connectedSet = new Set(connectedProviders) - for (const provider of providers) { - if (connectedSet.has(provider)) { - log("[isAnyProviderConnected] provider connected via cache", { provider }) - return true - } - } - } - - return false -} - -export function __resetModelCache(): void {} - -export function isModelCacheAvailable(): boolean { - if (hasProviderModelsCache()) { - return true - } - const cacheFile = join(getOpenCodeCacheDir(), "models.json") - return existsSync(cacheFile) -} +export { fetchAvailableModels, getConnectedProviders } from "./available-models-fetcher" +export { isAnyFallbackModelAvailable, isAnyProviderConnected } from "./fallback-model-availability" +export { __resetModelCache, isModelCacheAvailable } from "./model-cache-availability" +export { fuzzyMatchModel, isModelAvailable } from "./model-name-matcher" diff --git a/src/shared/model-cache-availability.ts b/src/shared/model-cache-availability.ts new file mode 100644 index 00000000..d0a2807c --- /dev/null +++ b/src/shared/model-cache-availability.ts @@ -0,0 +1,14 @@ +import { existsSync } from "fs" +import { join } from "path" +import { getOpenCodeCacheDir } from "./data-path" +import { hasProviderModelsCache } from "./connected-providers-cache" + +export function __resetModelCache(): void {} + +export function isModelCacheAvailable(): boolean { + if (hasProviderModelsCache()) { + return true + } + const cacheFile = join(getOpenCodeCacheDir(), "models.json") + return existsSync(cacheFile) +} diff --git a/src/shared/model-name-matcher.ts b/src/shared/model-name-matcher.ts new file mode 100644 index 00000000..4cbc0381 --- /dev/null +++ b/src/shared/model-name-matcher.ts @@ -0,0 +1,91 @@ +import { log } from "./logger" + +function normalizeModelName(name: string): string { + return name + .toLowerCase() + .replace(/claude-(opus|sonnet|haiku)-4-5/g, "claude-$1-4.5") + .replace(/claude-(opus|sonnet|haiku)-4\.5/g, "claude-$1-4.5") +} + +export function fuzzyMatchModel( + target: string, + available: Set, + providers?: string[], +): string | null { + log("[fuzzyMatchModel] called", { target, availableCount: available.size, providers }) + + if (available.size === 0) { + log("[fuzzyMatchModel] empty available set") + return null + } + + const targetNormalized = normalizeModelName(target) + + let candidates = Array.from(available) + if (providers && providers.length > 0) { + const providerSet = new Set(providers) + candidates = candidates.filter((model) => { + const [provider] = model.split("/") + return providerSet.has(provider) + }) + log("[fuzzyMatchModel] filtered by providers", { + candidateCount: candidates.length, + candidates: candidates.slice(0, 10), + }) + } + + if (candidates.length === 0) { + log("[fuzzyMatchModel] no candidates after filter") + return null + } + + const matches = candidates.filter((model) => + normalizeModelName(model).includes(targetNormalized), + ) + + log("[fuzzyMatchModel] substring matches", { + targetNormalized, + matchCount: matches.length, + matches, + }) + + if (matches.length === 0) { + return null + } + + const exactMatch = matches.find( + (model) => normalizeModelName(model) === targetNormalized, + ) + if (exactMatch) { + log("[fuzzyMatchModel] exact match found", { exactMatch }) + return exactMatch + } + + const exactModelIdMatches = matches.filter((model) => { + const modelId = model.split("/").slice(1).join("/") + return normalizeModelName(modelId) === targetNormalized + }) + if (exactModelIdMatches.length > 0) { + const result = exactModelIdMatches.reduce((shortest, current) => + current.length < shortest.length ? current : shortest, + ) + log("[fuzzyMatchModel] exact model ID match found", { + result, + candidateCount: exactModelIdMatches.length, + }) + return result + } + + const result = matches.reduce((shortest, current) => + current.length < shortest.length ? current : shortest, + ) + log("[fuzzyMatchModel] shortest match", { result }) + return result +} + +export function isModelAvailable( + targetModel: string, + availableModels: Set, +): boolean { + return fuzzyMatchModel(targetModel, availableModels) !== null +} diff --git a/src/shared/models-json-cache-reader.ts b/src/shared/models-json-cache-reader.ts new file mode 100644 index 00000000..d2291f28 --- /dev/null +++ b/src/shared/models-json-cache-reader.ts @@ -0,0 +1,52 @@ +import { existsSync, readFileSync } from "fs" +import { join } from "path" +import { getOpenCodeCacheDir } from "./data-path" +import { log } from "./logger" +import { isRecord } from "./record-type-guard" + +export function addModelsFromModelsJsonCache( + connectedProviders: Set, + modelSet: Set, +): boolean { + const cacheFile = join(getOpenCodeCacheDir(), "models.json") + if (!existsSync(cacheFile)) { + log("[fetchAvailableModels] models.json cache file not found, falling back to client") + return false + } + + try { + const content = readFileSync(cacheFile, "utf-8") + const data: unknown = JSON.parse(content) + if (!isRecord(data)) { + return false + } + + const providerIds = Object.keys(data) + log("[fetchAvailableModels] providers found in models.json", { + count: providerIds.length, + providers: providerIds.slice(0, 10), + }) + + const previousSize = modelSet.size + for (const providerId of providerIds) { + if (!connectedProviders.has(providerId)) continue + const providerValue = data[providerId] + if (!isRecord(providerValue)) continue + const modelsValue = providerValue["models"] + if (!isRecord(modelsValue)) continue + for (const modelKey of Object.keys(modelsValue)) { + modelSet.add(`${providerId}/${modelKey}`) + } + } + + log("[fetchAvailableModels] parsed models from models.json (NO whitelist filtering)", { + count: modelSet.size, + connectedProviders: Array.from(connectedProviders).slice(0, 5), + }) + + return modelSet.size > previousSize + } catch (err) { + log("[fetchAvailableModels] error", { error: String(err) }) + return false + } +} diff --git a/src/shared/open-code-client-accessors.ts b/src/shared/open-code-client-accessors.ts new file mode 100644 index 00000000..d20f9290 --- /dev/null +++ b/src/shared/open-code-client-accessors.ts @@ -0,0 +1,20 @@ +import type { ModelListFunction, ProviderListFunction } from "./open-code-client-shapes" +import { isRecord } from "./record-type-guard" + +export function getProviderListFunction(client: unknown): ProviderListFunction | null { + if (!isRecord(client)) return null + const provider = client["provider"] + if (!isRecord(provider)) return null + const list = provider["list"] + if (typeof list !== "function") return null + return list as ProviderListFunction +} + +export function getModelListFunction(client: unknown): ModelListFunction | null { + if (!isRecord(client)) return null + const model = client["model"] + if (!isRecord(model)) return null + const list = model["list"] + if (typeof list !== "function") return null + return list as ModelListFunction +} diff --git a/src/shared/open-code-client-shapes.ts b/src/shared/open-code-client-shapes.ts new file mode 100644 index 00000000..701091a3 --- /dev/null +++ b/src/shared/open-code-client-shapes.ts @@ -0,0 +1,7 @@ +export type ProviderListResponse = { data?: { connected?: string[] } } +export type ModelListResponse = { + data?: Array<{ id?: string; provider?: string }> +} + +export type ProviderListFunction = () => Promise +export type ModelListFunction = () => Promise diff --git a/src/shared/provider-models-cache-model-reader.ts b/src/shared/provider-models-cache-model-reader.ts new file mode 100644 index 00000000..c012b94e --- /dev/null +++ b/src/shared/provider-models-cache-model-reader.ts @@ -0,0 +1,39 @@ +import { readProviderModelsCache } from "./connected-providers-cache" +import { log } from "./logger" + +export function addModelsFromProviderModelsCache( + connectedProviders: Set, + modelSet: Set, +): boolean { + const providerModelsCache = readProviderModelsCache() + if (!providerModelsCache) { + return false + } + + const providerCount = Object.keys(providerModelsCache.models).length + if (providerCount === 0) { + log("[fetchAvailableModels] provider-models cache empty, falling back to models.json") + return false + } + + log("[fetchAvailableModels] using provider-models cache (whitelist-filtered)") + const previousSize = modelSet.size + + for (const [providerId, modelIds] of Object.entries(providerModelsCache.models)) { + if (!connectedProviders.has(providerId)) continue + for (const modelItem of modelIds) { + if (!modelItem) continue + const modelId = typeof modelItem === "string" ? modelItem : modelItem.id + if (modelId) { + modelSet.add(`${providerId}/${modelId}`) + } + } + } + + log("[fetchAvailableModels] parsed from provider-models cache", { + count: modelSet.size, + connectedProviders: Array.from(connectedProviders).slice(0, 5), + }) + + return modelSet.size > previousSize +} diff --git a/src/shared/record-type-guard.ts b/src/shared/record-type-guard.ts new file mode 100644 index 00000000..a901f1a6 --- /dev/null +++ b/src/shared/record-type-guard.ts @@ -0,0 +1,3 @@ +export function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null +}