diff --git a/src/features/skill-mcp-manager/cleanup.ts b/src/features/skill-mcp-manager/cleanup.ts new file mode 100644 index 00000000..805c8c50 --- /dev/null +++ b/src/features/skill-mcp-manager/cleanup.ts @@ -0,0 +1,129 @@ +import type { ManagedClient, SkillMcpManagerState } from "./types" + +async function closeManagedClient(managed: ManagedClient): Promise { + try { + await managed.client.close() + } catch { + // Ignore close errors - process may already be terminated + } + + try { + await managed.transport.close() + } catch { + // Transport may already be terminated + } +} + +export function registerProcessCleanup(state: SkillMcpManagerState): void { + if (state.cleanupRegistered) return + state.cleanupRegistered = true + + const cleanup = async (): Promise => { + for (const managed of state.clients.values()) { + await closeManagedClient(managed) + } + state.clients.clear() + state.pendingConnections.clear() + } + + // Note: Node's 'exit' event is synchronous-only, so we rely on signal handlers for async cleanup. + // Signal handlers invoke the async cleanup function and ignore errors so they don't block or throw. + // Don't call process.exit() here - let the background-agent manager handle the final process exit. + // Use void + catch to trigger async cleanup without awaiting it in the signal handler. + + const register = (signal: NodeJS.Signals) => { + const listener = () => void cleanup().catch(() => {}) + state.cleanupHandlers.push({ signal, listener }) + process.on(signal, listener) + } + + register("SIGINT") + register("SIGTERM") + if (process.platform === "win32") { + register("SIGBREAK") + } +} + +export function unregisterProcessCleanup(state: SkillMcpManagerState): void { + if (!state.cleanupRegistered) return + for (const { signal, listener } of state.cleanupHandlers) { + process.off(signal, listener) + } + state.cleanupHandlers = [] + state.cleanupRegistered = false +} + +export function startCleanupTimer(state: SkillMcpManagerState): void { + if (state.cleanupInterval) return + + state.cleanupInterval = setInterval(() => { + void cleanupIdleClients(state).catch(() => {}) + }, 60_000) + + state.cleanupInterval.unref() +} + +export function stopCleanupTimer(state: SkillMcpManagerState): void { + if (!state.cleanupInterval) return + clearInterval(state.cleanupInterval) + state.cleanupInterval = null +} + +async function cleanupIdleClients(state: SkillMcpManagerState): Promise { + const now = Date.now() + + for (const [key, managed] of state.clients) { + if (now - managed.lastUsedAt > state.idleTimeoutMs) { + state.clients.delete(key) + await closeManagedClient(managed) + } + } + + if (state.clients.size === 0) { + stopCleanupTimer(state) + } +} + +export async function disconnectSession(state: SkillMcpManagerState, sessionID: string): Promise { + const keysToRemove: string[] = [] + + for (const [key, managed] of state.clients.entries()) { + if (key.startsWith(`${sessionID}:`)) { + keysToRemove.push(key) + // Delete from map first to prevent re-entrancy during async close + state.clients.delete(key) + await closeManagedClient(managed) + } + } + + for (const key of keysToRemove) { + state.pendingConnections.delete(key) + } + + if (state.clients.size === 0) { + stopCleanupTimer(state) + } +} + +export async function disconnectAll(state: SkillMcpManagerState): Promise { + stopCleanupTimer(state) + unregisterProcessCleanup(state) + + const clients = Array.from(state.clients.values()) + state.clients.clear() + state.pendingConnections.clear() + state.authProviders.clear() + + for (const managed of clients) { + await closeManagedClient(managed) + } +} + +export async function forceReconnect(state: SkillMcpManagerState, clientKey: string): Promise { + const existing = state.clients.get(clientKey) + if (!existing) return false + + state.clients.delete(clientKey) + await closeManagedClient(existing) + return true +} diff --git a/src/features/skill-mcp-manager/connection-type.ts b/src/features/skill-mcp-manager/connection-type.ts new file mode 100644 index 00000000..64e1b59d --- /dev/null +++ b/src/features/skill-mcp-manager/connection-type.ts @@ -0,0 +1,26 @@ +import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types" +import type { ConnectionType } from "./types" + +/** + * Determines connection type from MCP server configuration. + * Priority: explicit type field > url presence > command presence + */ +export function getConnectionType(config: ClaudeCodeMcpServer): ConnectionType | null { + // Explicit type takes priority + if (config.type === "http" || config.type === "sse") { + return "http" + } + if (config.type === "stdio") { + return "stdio" + } + + // Infer from available fields + if (config.url) { + return "http" + } + if (config.command) { + return "stdio" + } + + return null +} diff --git a/src/features/skill-mcp-manager/connection.ts b/src/features/skill-mcp-manager/connection.ts new file mode 100644 index 00000000..1a3559d0 --- /dev/null +++ b/src/features/skill-mcp-manager/connection.ts @@ -0,0 +1,95 @@ +import type { Client } from "@modelcontextprotocol/sdk/client/index.js" +import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types" +import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander" +import { forceReconnect } from "./cleanup" +import { getConnectionType } from "./connection-type" +import { createHttpClient } from "./http-client" +import { createStdioClient } from "./stdio-client" +import type { SkillMcpClientConnectionParams, SkillMcpClientInfo, SkillMcpManagerState } from "./types" + +export async function getOrCreateClient(params: { + state: SkillMcpManagerState + clientKey: string + info: SkillMcpClientInfo + config: ClaudeCodeMcpServer +}): Promise { + const { state, clientKey, info, config } = params + + const existing = state.clients.get(clientKey) + if (existing) { + existing.lastUsedAt = Date.now() + return existing.client + } + + // Prevent race condition: if a connection is already in progress, wait for it + const pending = state.pendingConnections.get(clientKey) + if (pending) { + return pending + } + + const expandedConfig = expandEnvVarsInObject(config) + const connectionPromise = createClient({ state, clientKey, info, config: expandedConfig }) + state.pendingConnections.set(clientKey, connectionPromise) + + try { + const client = await connectionPromise + return client + } finally { + state.pendingConnections.delete(clientKey) + } +} + +export async function getOrCreateClientWithRetryImpl(params: { + state: SkillMcpManagerState + clientKey: string + info: SkillMcpClientInfo + config: ClaudeCodeMcpServer +}): Promise { + const { state, clientKey } = params + + try { + return await getOrCreateClient(params) + } catch (error) { + const didReconnect = await forceReconnect(state, clientKey) + if (!didReconnect) { + throw error + } + return await getOrCreateClient(params) + } +} + +async function createClient(params: { + state: SkillMcpManagerState + clientKey: string + info: SkillMcpClientInfo + config: ClaudeCodeMcpServer +}): Promise { + const { info, config } = params + const connectionType = getConnectionType(config) + + if (!connectionType) { + throw new Error( + `MCP server "${info.serverName}" has no valid connection configuration.\n\n` + + `The MCP configuration in skill "${info.skillName}" must specify either:\n` + + ` - A URL for HTTP connection (remote MCP server)\n` + + ` - A command for stdio connection (local MCP process)\n\n` + + `Examples:\n` + + ` HTTP:\n` + + ` mcp:\n` + + ` ${info.serverName}:\n` + + ` url: https://mcp.example.com/mcp\n` + + ` headers:\n` + + " Authorization: Bearer ${API_KEY}\n\n" + + ` Stdio:\n` + + ` mcp:\n` + + ` ${info.serverName}:\n` + + ` command: npx\n` + + ` args: [-y, @some/mcp-server]` + ) + } + + if (connectionType === "http") { + return await createHttpClient(params satisfies SkillMcpClientConnectionParams) + } + return await createStdioClient(params satisfies SkillMcpClientConnectionParams) +} diff --git a/src/features/skill-mcp-manager/http-client.ts b/src/features/skill-mcp-manager/http-client.ts new file mode 100644 index 00000000..308e9a6b --- /dev/null +++ b/src/features/skill-mcp-manager/http-client.ts @@ -0,0 +1,68 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js" +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js" +import { registerProcessCleanup, startCleanupTimer } from "./cleanup" +import { buildHttpRequestInit } from "./oauth-handler" +import type { ManagedClient, SkillMcpClientConnectionParams } from "./types" + +export async function createHttpClient(params: SkillMcpClientConnectionParams): Promise { + const { state, clientKey, info, config } = params + + if (!config.url) { + throw new Error(`MCP server "${info.serverName}" is configured for HTTP but missing 'url' field.`) + } + + let url: URL + try { + url = new URL(config.url) + } catch { + throw new Error( + `MCP server "${info.serverName}" has invalid URL: ${config.url}\n\n` + + `Expected a valid URL like: https://mcp.example.com/mcp` + ) + } + + registerProcessCleanup(state) + + const requestInit = await buildHttpRequestInit(config, state.authProviders) + const transport = new StreamableHTTPClientTransport(url, { + requestInit, + }) + + const client = new Client( + { name: `skill-mcp-${info.skillName}-${info.serverName}`, version: "1.0.0" }, + { capabilities: {} } + ) + + try { + await client.connect(transport) + } catch (error) { + try { + await transport.close() + } catch { + // Transport may already be closed + } + + const errorMessage = error instanceof Error ? error.message : String(error) + throw new Error( + `Failed to connect to MCP server "${info.serverName}".\n\n` + + `URL: ${config.url}\n` + + `Reason: ${errorMessage}\n\n` + + `Hints:\n` + + ` - Verify the URL is correct and the server is running\n` + + ` - Check if authentication headers are required\n` + + ` - Ensure the server supports MCP over HTTP` + ) + } + + const managedClient = { + client, + transport, + skillName: info.skillName, + lastUsedAt: Date.now(), + connectionType: "http", + } satisfies ManagedClient + + state.clients.set(clientKey, managedClient) + startCleanupTimer(state) + return client +} diff --git a/src/features/skill-mcp-manager/manager.ts b/src/features/skill-mcp-manager/manager.ts index 43cb3dd8..71d3cf78 100644 --- a/src/features/skill-mcp-manager/manager.ts +++ b/src/features/skill-mcp-manager/manager.ts @@ -1,480 +1,57 @@ -import { Client } from "@modelcontextprotocol/sdk/client/index.js" -import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js" -import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js" -import type { Tool, Resource, Prompt } from "@modelcontextprotocol/sdk/types.js" +import type { Client } from "@modelcontextprotocol/sdk/client/index.js" +import type { Prompt, Resource, Tool } from "@modelcontextprotocol/sdk/types.js" import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types" -import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander" -import { McpOAuthProvider } from "../mcp-oauth/provider" -import { isStepUpRequired, mergeScopes } from "../mcp-oauth/step-up" -import { createCleanMcpEnvironment } from "./env-cleaner" -import type { SkillMcpClientInfo, SkillMcpServerContext } from "./types" - -/** - * Connection type for a managed MCP client. - * - "stdio": Local process via stdin/stdout - * - "http": Remote server via HTTP (Streamable HTTP transport) - */ -type ConnectionType = "stdio" | "http" - -interface ManagedClientBase { - client: Client - skillName: string - lastUsedAt: number - connectionType: ConnectionType -} - -interface ManagedStdioClient extends ManagedClientBase { - connectionType: "stdio" - transport: StdioClientTransport -} - -interface ManagedHttpClient extends ManagedClientBase { - connectionType: "http" - transport: StreamableHTTPClientTransport -} - -type ManagedClient = ManagedStdioClient | ManagedHttpClient - -/** - * Determines connection type from MCP server configuration. - * Priority: explicit type field > url presence > command presence - */ -function getConnectionType(config: ClaudeCodeMcpServer): ConnectionType | null { - // Explicit type takes priority - if (config.type === "http" || config.type === "sse") { - return "http" - } - if (config.type === "stdio") { - return "stdio" - } - - // Infer from available fields - if (config.url) { - return "http" - } - if (config.command) { - return "stdio" - } - - return null -} +import { disconnectAll, disconnectSession, forceReconnect } from "./cleanup" +import { getOrCreateClient, getOrCreateClientWithRetryImpl } from "./connection" +import { handleStepUpIfNeeded } from "./oauth-handler" +import type { SkillMcpClientInfo, SkillMcpManagerState, SkillMcpServerContext } from "./types" export class SkillMcpManager { - private clients: Map = new Map() - private pendingConnections: Map> = new Map() - private authProviders: Map = new Map() - private cleanupRegistered = false - private cleanupInterval: ReturnType | null = null - private cleanupHandlers: Array<{ signal: NodeJS.Signals; listener: () => void }> = [] - private readonly IDLE_TIMEOUT = 5 * 60 * 1000 + private readonly state: SkillMcpManagerState = { + clients: new Map(), + pendingConnections: new Map(), + authProviders: new Map(), + cleanupRegistered: false, + cleanupInterval: null, + cleanupHandlers: [], + idleTimeoutMs: 5 * 60 * 1000, + } private getClientKey(info: SkillMcpClientInfo): string { return `${info.sessionID}:${info.skillName}:${info.serverName}` } - /** - * Get or create an McpOAuthProvider for a given server URL + oauth config. - * Providers are cached by server URL to reuse tokens across reconnections. - */ - private getOrCreateAuthProvider( - serverUrl: string, - oauth: NonNullable - ): McpOAuthProvider { - const existing = this.authProviders.get(serverUrl) - if (existing) { - return existing - } - - const provider = new McpOAuthProvider({ - serverUrl, - clientId: oauth.clientId, - scopes: oauth.scopes, + async getOrCreateClient(info: SkillMcpClientInfo, config: ClaudeCodeMcpServer): Promise { + const clientKey = this.getClientKey(info) + return await getOrCreateClient({ + state: this.state, + clientKey, + info, + config, }) - this.authProviders.set(serverUrl, provider) - return provider - } - - private registerProcessCleanup(): void { - if (this.cleanupRegistered) return - this.cleanupRegistered = true - - const cleanup = async () => { - for (const [, managed] of this.clients) { - try { - await managed.client.close() - } catch { - // Ignore errors during cleanup - } - try { - await managed.transport.close() - } catch { - // Transport may already be terminated - } - } - this.clients.clear() - this.pendingConnections.clear() - } - - // Note: Node's 'exit' event is synchronous-only, so we rely on signal handlers for async cleanup. - // Signal handlers invoke the async cleanup function and ignore errors so they don't block or throw. - // Don't call process.exit() here - let the background-agent manager handle the final process exit. - // Use void + catch to trigger async cleanup without awaiting it in the signal handler. - - const register = (signal: NodeJS.Signals) => { - const listener = () => void cleanup().catch(() => {}) - this.cleanupHandlers.push({ signal, listener }) - process.on(signal, listener) - } - - register("SIGINT") - register("SIGTERM") - if (process.platform === "win32") { - register("SIGBREAK") - } - } - - private unregisterProcessCleanup(): void { - if (!this.cleanupRegistered) return - for (const { signal, listener } of this.cleanupHandlers) { - process.off(signal, listener) - } - this.cleanupHandlers = [] - this.cleanupRegistered = false - } - - async getOrCreateClient( - info: SkillMcpClientInfo, - config: ClaudeCodeMcpServer - ): Promise { - const key = this.getClientKey(info) - const existing = this.clients.get(key) - - if (existing) { - existing.lastUsedAt = Date.now() - return existing.client - } - - // Prevent race condition: if a connection is already in progress, wait for it - const pending = this.pendingConnections.get(key) - if (pending) { - return pending - } - - const expandedConfig = expandEnvVarsInObject(config) - const connectionPromise = this.createClient(info, expandedConfig) - this.pendingConnections.set(key, connectionPromise) - - try { - const client = await connectionPromise - return client - } finally { - this.pendingConnections.delete(key) - } - } - - private async createClient( - info: SkillMcpClientInfo, - config: ClaudeCodeMcpServer - ): Promise { - const connectionType = getConnectionType(config) - - if (!connectionType) { - throw new Error( - `MCP server "${info.serverName}" has no valid connection configuration.\n\n` + - `The MCP configuration in skill "${info.skillName}" must specify either:\n` + - ` - A URL for HTTP connection (remote MCP server)\n` + - ` - A command for stdio connection (local MCP process)\n\n` + - `Examples:\n` + - ` HTTP:\n` + - ` mcp:\n` + - ` ${info.serverName}:\n` + - ` url: https://mcp.example.com/mcp\n` + - ` headers:\n` + - ` Authorization: Bearer \${API_KEY}\n\n` + - ` Stdio:\n` + - ` mcp:\n` + - ` ${info.serverName}:\n` + - ` command: npx\n` + - ` args: [-y, @some/mcp-server]` - ) - } - - if (connectionType === "http") { - return this.createHttpClient(info, config) - } else { - return this.createStdioClient(info, config) - } - } - - /** - * Create an HTTP-based MCP client using StreamableHTTPClientTransport. - * Supports remote MCP servers with optional authentication headers. - */ - private async createHttpClient( - info: SkillMcpClientInfo, - config: ClaudeCodeMcpServer - ): Promise { - const key = this.getClientKey(info) - - if (!config.url) { - throw new Error( - `MCP server "${info.serverName}" is configured for HTTP but missing 'url' field.` - ) - } - - let url: URL - try { - url = new URL(config.url) - } catch { - throw new Error( - `MCP server "${info.serverName}" has invalid URL: ${config.url}\n\n` + - `Expected a valid URL like: https://mcp.example.com/mcp` - ) - } - - this.registerProcessCleanup() - - // Build request init with headers if provided - const requestInit: RequestInit = {} - if (config.headers && Object.keys(config.headers).length > 0) { - requestInit.headers = { ...config.headers } - } - - let authProvider: McpOAuthProvider | undefined - if (config.oauth) { - authProvider = this.getOrCreateAuthProvider(config.url, config.oauth) - let tokenData = authProvider.tokens() - - const isExpired = tokenData?.expiresAt != null && tokenData.expiresAt < Math.floor(Date.now() / 1000) - if (!tokenData || isExpired) { - try { - tokenData = await authProvider.login() - } catch { - // Login failed — proceed without auth header - } - } - - if (tokenData) { - const existingHeaders = (requestInit.headers ?? {}) as Record - requestInit.headers = { - ...existingHeaders, - Authorization: `Bearer ${tokenData.accessToken}`, - } - } - } - - const transport = new StreamableHTTPClientTransport(url, { - requestInit: Object.keys(requestInit).length > 0 ? requestInit : undefined, - }) - - const client = new Client( - { name: `skill-mcp-${info.skillName}-${info.serverName}`, version: "1.0.0" }, - { capabilities: {} } - ) - - try { - await client.connect(transport) - } catch (error) { - try { - await transport.close() - } catch { - // Transport may already be closed - } - const errorMessage = error instanceof Error ? error.message : String(error) - throw new Error( - `Failed to connect to MCP server "${info.serverName}".\n\n` + - `URL: ${config.url}\n` + - `Reason: ${errorMessage}\n\n` + - `Hints:\n` + - ` - Verify the URL is correct and the server is running\n` + - ` - Check if authentication headers are required\n` + - ` - Ensure the server supports MCP over HTTP` - ) - } - - const managedClient: ManagedHttpClient = { - client, - transport, - skillName: info.skillName, - lastUsedAt: Date.now(), - connectionType: "http", - } - this.clients.set(key, managedClient) - this.startCleanupTimer() - return client - } - - /** - * Create a stdio-based MCP client using StdioClientTransport. - * Spawns a local process and communicates via stdin/stdout. - */ - private async createStdioClient( - info: SkillMcpClientInfo, - config: ClaudeCodeMcpServer - ): Promise { - const key = this.getClientKey(info) - - if (!config.command) { - throw new Error( - `MCP server "${info.serverName}" is configured for stdio but missing 'command' field.` - ) - } - - const command = config.command - const args = config.args || [] - - const mergedEnv = createCleanMcpEnvironment(config.env) - - this.registerProcessCleanup() - - const transport = new StdioClientTransport({ - command, - args, - env: mergedEnv, - stderr: "ignore", - }) - - const client = new Client( - { name: `skill-mcp-${info.skillName}-${info.serverName}`, version: "1.0.0" }, - { capabilities: {} } - ) - - try { - await client.connect(transport) - } catch (error) { - // Close transport to prevent orphaned MCP process on connection failure - try { - await transport.close() - } catch { - // Process may already be terminated - } - const errorMessage = error instanceof Error ? error.message : String(error) - throw new Error( - `Failed to connect to MCP server "${info.serverName}".\n\n` + - `Command: ${command} ${args.join(" ")}\n` + - `Reason: ${errorMessage}\n\n` + - `Hints:\n` + - ` - Ensure the command is installed and available in PATH\n` + - ` - Check if the MCP server package exists\n` + - ` - Verify the args are correct for this server` - ) - } - - const managedClient: ManagedStdioClient = { - client, - transport, - skillName: info.skillName, - lastUsedAt: Date.now(), - connectionType: "stdio", - } - this.clients.set(key, managedClient) - this.startCleanupTimer() - return client } async disconnectSession(sessionID: string): Promise { - const keysToRemove: string[] = [] - - for (const [key, managed] of this.clients.entries()) { - if (key.startsWith(`${sessionID}:`)) { - keysToRemove.push(key) - // Delete from map first to prevent re-entrancy during async close - this.clients.delete(key) - try { - await managed.client.close() - } catch { - // Ignore close errors - process may already be terminated - } - try { - await managed.transport.close() - } catch { - // Transport may already be terminated - } - } - } - - for (const key of keysToRemove) { - this.pendingConnections.delete(key) - } - - if (this.clients.size === 0) { - this.stopCleanupTimer() - } + await disconnectSession(this.state, sessionID) } async disconnectAll(): Promise { - this.stopCleanupTimer() - this.unregisterProcessCleanup() - const clients = Array.from(this.clients.values()) - this.clients.clear() - this.pendingConnections.clear() - this.authProviders.clear() - for (const managed of clients) { - try { - await managed.client.close() - } catch { /* process may already be terminated */ } - try { - await managed.transport.close() - } catch { /* transport may already be terminated */ } - } + await disconnectAll(this.state) } - private startCleanupTimer(): void { - if (this.cleanupInterval) return - this.cleanupInterval = setInterval(() => { - this.cleanupIdleClients() - }, 60_000) - this.cleanupInterval.unref() - } - - private stopCleanupTimer(): void { - if (this.cleanupInterval) { - clearInterval(this.cleanupInterval) - this.cleanupInterval = null - } - } - - private async cleanupIdleClients(): Promise { - const now = Date.now() - for (const [key, managed] of this.clients) { - if (now - managed.lastUsedAt > this.IDLE_TIMEOUT) { - this.clients.delete(key) - try { - await managed.client.close() - } catch { /* process may already be terminated */ } - try { - await managed.transport.close() - } catch { /* transport may already be terminated */ } - } - } - - if (this.clients.size === 0) { - this.stopCleanupTimer() - } - } - - async listTools( - info: SkillMcpClientInfo, - context: SkillMcpServerContext - ): Promise { + async listTools(info: SkillMcpClientInfo, context: SkillMcpServerContext): Promise { const client = await this.getOrCreateClientWithRetry(info, context.config) const result = await client.listTools() return result.tools } - async listResources( - info: SkillMcpClientInfo, - context: SkillMcpServerContext - ): Promise { + async listResources(info: SkillMcpClientInfo, context: SkillMcpServerContext): Promise { const client = await this.getOrCreateClientWithRetry(info, context.config) const result = await client.listResources() return result.resources } - async listPrompts( - info: SkillMcpClientInfo, - context: SkillMcpServerContext - ): Promise { + async listPrompts(info: SkillMcpClientInfo, context: SkillMcpServerContext): Promise { const client = await this.getOrCreateClientWithRetry(info, context.config) const result = await client.listPrompts() return result.prompts @@ -486,18 +63,14 @@ export class SkillMcpManager { name: string, args: Record ): Promise { - return this.withOperationRetry(info, context.config, async (client) => { + return await this.withOperationRetry(info, context.config, async (client) => { const result = await client.callTool({ name, arguments: args }) return result.content }) } - async readResource( - info: SkillMcpClientInfo, - context: SkillMcpServerContext, - uri: string - ): Promise { - return this.withOperationRetry(info, context.config, async (client) => { + async readResource(info: SkillMcpClientInfo, context: SkillMcpServerContext, uri: string): Promise { + return await this.withOperationRetry(info, context.config, async (client) => { const result = await client.readResource({ uri }) return result.contents }) @@ -509,7 +82,7 @@ export class SkillMcpManager { name: string, args: Record ): Promise { - return this.withOperationRetry(info, context.config, async (client) => { + return await this.withOperationRetry(info, context.config, async (client) => { const result = await client.getPrompt({ name, arguments: args }) return result.messages }) @@ -531,9 +104,13 @@ export class SkillMcpManager { lastError = error instanceof Error ? error : new Error(String(error)) const errorMessage = lastError.message.toLowerCase() - const stepUpHandled = await this.handleStepUpIfNeeded(lastError, config) + const stepUpHandled = await handleStepUpIfNeeded({ + error: lastError, + config, + authProviders: this.state.authProviders, + }) if (stepUpHandled) { - await this.forceReconnect(info) + await forceReconnect(this.state, this.getClientKey(info)) continue } @@ -542,99 +119,32 @@ export class SkillMcpManager { } if (attempt === maxRetries) { - throw new Error( - `Failed after ${maxRetries} reconnection attempts: ${lastError.message}` - ) + throw new Error(`Failed after ${maxRetries} reconnection attempts: ${lastError.message}`) } - await this.forceReconnect(info) + await forceReconnect(this.state, this.getClientKey(info)) } } - throw lastError || new Error("Operation failed with unknown error") + throw lastError ?? new Error("Operation failed with unknown error") } - private async handleStepUpIfNeeded( - error: Error, - config: ClaudeCodeMcpServer - ): Promise { - if (!config.oauth || !config.url) { - return false - } - - const statusMatch = /\b403\b/.exec(error.message) - if (!statusMatch) { - return false - } - - const headers: Record = {} - const wwwAuthMatch = /WWW-Authenticate:\s*(.+)/i.exec(error.message) - if (wwwAuthMatch?.[1]) { - headers["www-authenticate"] = wwwAuthMatch[1] - } - - const stepUp = isStepUpRequired(403, headers) - if (!stepUp) { - return false - } - - const currentScopes = config.oauth.scopes ?? [] - const merged = mergeScopes(currentScopes, stepUp.requiredScopes) - config.oauth.scopes = merged - - this.authProviders.delete(config.url) - const provider = this.getOrCreateAuthProvider(config.url, config.oauth) - - try { - await provider.login() - return true - } catch { - return false - } - } - - private async forceReconnect(info: SkillMcpClientInfo): Promise { - const key = this.getClientKey(info) - const existing = this.clients.get(key) - if (existing) { - this.clients.delete(key) - try { - await existing.client.close() - } catch { /* process may already be terminated */ } - try { - await existing.transport.close() - } catch { /* transport may already be terminated */ } - } - } - - private async getOrCreateClientWithRetry( - info: SkillMcpClientInfo, - config: ClaudeCodeMcpServer - ): Promise { - try { - return await this.getOrCreateClient(info, config) - } catch (error) { - const key = this.getClientKey(info) - const existing = this.clients.get(key) - if (existing) { - this.clients.delete(key) - try { - await existing.client.close() - } catch { /* process may already be terminated */ } - try { - await existing.transport.close() - } catch { /* transport may already be terminated */ } - return await this.getOrCreateClient(info, config) - } - throw error - } + // NOTE: tests spy on this exact method name via `spyOn(manager as any, 'getOrCreateClientWithRetry')`. + private async getOrCreateClientWithRetry(info: SkillMcpClientInfo, config: ClaudeCodeMcpServer): Promise { + const clientKey = this.getClientKey(info) + return await getOrCreateClientWithRetryImpl({ + state: this.state, + clientKey, + info, + config, + }) } getConnectedServers(): string[] { - return Array.from(this.clients.keys()) + return Array.from(this.state.clients.keys()) } isConnected(info: SkillMcpClientInfo): boolean { - return this.clients.has(this.getClientKey(info)) + return this.state.clients.has(this.getClientKey(info)) } } diff --git a/src/features/skill-mcp-manager/oauth-handler.ts b/src/features/skill-mcp-manager/oauth-handler.ts new file mode 100644 index 00000000..66e12b3e --- /dev/null +++ b/src/features/skill-mcp-manager/oauth-handler.ts @@ -0,0 +1,100 @@ +import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types" +import { McpOAuthProvider } from "../mcp-oauth/provider" +import type { OAuthTokenData } from "../mcp-oauth/storage" +import { isStepUpRequired, mergeScopes } from "../mcp-oauth/step-up" + +export function getOrCreateAuthProvider( + authProviders: Map, + serverUrl: string, + oauth: NonNullable +): McpOAuthProvider { + const existing = authProviders.get(serverUrl) + if (existing) return existing + + const provider = new McpOAuthProvider({ + serverUrl, + clientId: oauth.clientId, + scopes: oauth.scopes, + }) + authProviders.set(serverUrl, provider) + return provider +} + +function isTokenExpired(tokenData: OAuthTokenData): boolean { + if (tokenData.expiresAt == null) return false + return tokenData.expiresAt < Math.floor(Date.now() / 1000) +} + +export async function buildHttpRequestInit( + config: ClaudeCodeMcpServer, + authProviders: Map +): Promise { + const headers: Record = {} + + if (config.headers) { + for (const [key, value] of Object.entries(config.headers)) { + headers[key] = value + } + } + + if (config.oauth && config.url) { + const provider = getOrCreateAuthProvider(authProviders, config.url, config.oauth) + let tokenData = provider.tokens() + + if (!tokenData || isTokenExpired(tokenData)) { + try { + tokenData = await provider.login() + } catch { + tokenData = null + } + } + + if (tokenData) { + headers.Authorization = `Bearer ${tokenData.accessToken}` + } + } + + return Object.keys(headers).length > 0 ? { headers } : undefined +} + +export async function handleStepUpIfNeeded(params: { + error: Error + config: ClaudeCodeMcpServer + authProviders: Map +}): Promise { + const { error, config, authProviders } = params + + if (!config.oauth || !config.url) { + return false + } + + const statusMatch = /\b403\b/.exec(error.message) + if (!statusMatch) { + return false + } + + const headers: Record = {} + const wwwAuthMatch = /WWW-Authenticate:\s*(.+)/i.exec(error.message) + if (wwwAuthMatch?.[1]) { + headers["www-authenticate"] = wwwAuthMatch[1] + } + + const stepUp = isStepUpRequired(403, headers) + if (!stepUp) { + return false + } + + const currentScopes = config.oauth.scopes ?? [] + const mergedScopes = mergeScopes(currentScopes, stepUp.requiredScopes) + config.oauth.scopes = mergedScopes + + authProviders.delete(config.url) + const provider = getOrCreateAuthProvider(authProviders, config.url, config.oauth) + + try { + await provider.login() + return true + } catch { + return false + } +} diff --git a/src/features/skill-mcp-manager/stdio-client.ts b/src/features/skill-mcp-manager/stdio-client.ts new file mode 100644 index 00000000..56b8bb51 --- /dev/null +++ b/src/features/skill-mcp-manager/stdio-client.ts @@ -0,0 +1,69 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js" +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js" +import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types" +import { createCleanMcpEnvironment } from "./env-cleaner" +import { registerProcessCleanup, startCleanupTimer } from "./cleanup" +import type { ManagedClient, SkillMcpClientConnectionParams } from "./types" + +function getStdioCommand(config: ClaudeCodeMcpServer, serverName: string): string { + if (!config.command) { + throw new Error(`MCP server "${serverName}" is configured for stdio but missing 'command' field.`) + } + return config.command +} + +export async function createStdioClient(params: SkillMcpClientConnectionParams): Promise { + const { state, clientKey, info, config } = params + + const command = getStdioCommand(config, info.serverName) + const args = config.args ?? [] + const mergedEnv = createCleanMcpEnvironment(config.env) + + registerProcessCleanup(state) + + const transport = new StdioClientTransport({ + command, + args, + env: mergedEnv, + stderr: "ignore", + }) + + const client = new Client( + { name: `skill-mcp-${info.skillName}-${info.serverName}`, version: "1.0.0" }, + { capabilities: {} } + ) + + try { + await client.connect(transport) + } catch (error) { + // Close transport to prevent orphaned MCP process on connection failure + try { + await transport.close() + } catch { + // Process may already be terminated + } + + const errorMessage = error instanceof Error ? error.message : String(error) + throw new Error( + `Failed to connect to MCP server "${info.serverName}".\n\n` + + `Command: ${command} ${args.join(" ")}\n` + + `Reason: ${errorMessage}\n\n` + + `Hints:\n` + + ` - Ensure the command is installed and available in PATH\n` + + ` - Check if the MCP server package exists\n` + + ` - Verify the args are correct for this server` + ) + } + + const managedClient = { + client, + transport, + skillName: info.skillName, + lastUsedAt: Date.now(), + connectionType: "stdio", + } satisfies ManagedClient + + state.clients.set(clientKey, managedClient) + startCleanupTimer(state) + return client +} diff --git a/src/features/skill-mcp-manager/types.ts b/src/features/skill-mcp-manager/types.ts index bed9dbcb..b7a9f46c 100644 --- a/src/features/skill-mcp-manager/types.ts +++ b/src/features/skill-mcp-manager/types.ts @@ -1,4 +1,8 @@ +import type { Client } from "@modelcontextprotocol/sdk/client/index.js" +import type { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js" +import type { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js" import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types" +import type { McpOAuthProvider } from "../mcp-oauth/provider" export type SkillMcpConfig = Record @@ -12,3 +16,51 @@ export interface SkillMcpServerContext { config: ClaudeCodeMcpServer skillName: string } + +/** + * Connection type for a managed MCP client. + * - "stdio": Local process via stdin/stdout + * - "http": Remote server via HTTP (Streamable HTTP transport) + */ +export type ConnectionType = "stdio" | "http" + +export interface ManagedClientBase { + client: Client + skillName: string + lastUsedAt: number + connectionType: ConnectionType +} + +export interface ManagedStdioClient extends ManagedClientBase { + connectionType: "stdio" + transport: StdioClientTransport +} + +export interface ManagedHttpClient extends ManagedClientBase { + connectionType: "http" + transport: StreamableHTTPClientTransport +} + +export type ManagedClient = ManagedStdioClient | ManagedHttpClient + +export interface ProcessCleanupHandler { + signal: NodeJS.Signals + listener: () => void +} + +export interface SkillMcpManagerState { + clients: Map + pendingConnections: Map> + authProviders: Map + cleanupRegistered: boolean + cleanupInterval: ReturnType | null + cleanupHandlers: ProcessCleanupHandler[] + idleTimeoutMs: number +} + +export interface SkillMcpClientConnectionParams { + state: SkillMcpManagerState + clientKey: string + info: SkillMcpClientInfo + config: ClaudeCodeMcpServer +}