diff --git a/.gitignore b/.gitignore index e913cc4b..5c4708d6 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ yarn.lock test-injection/ notepad.md oauth-success.html +.188e87dbff6e7fd9-00000000.bun-build diff --git a/docs/cli-guide.md b/docs/cli-guide.md index 9cda5ac7..97d368ff 100644 --- a/docs/cli-guide.md +++ b/docs/cli-guide.md @@ -134,7 +134,41 @@ bunx oh-my-opencode run [prompt] --- -## 6. `auth` - Authentication Management +## 6. `mcp oauth` - MCP OAuth Management + +Manages OAuth 2.1 authentication for remote MCP servers. + +### Usage + +```bash +# Login to an OAuth-protected MCP server +bunx oh-my-opencode mcp oauth login --server-url https://api.example.com + +# Login with explicit client ID and scopes +bunx oh-my-opencode mcp oauth login my-api --server-url https://api.example.com --client-id my-client --scopes "read,write" + +# Remove stored OAuth tokens +bunx oh-my-opencode mcp oauth logout + +# Check OAuth token status +bunx oh-my-opencode mcp oauth status [server-name] +``` + +### Options + +| Option | Description | +|--------|-------------| +| `--server-url ` | MCP server URL (required for login) | +| `--client-id ` | OAuth client ID (optional if server supports Dynamic Client Registration) | +| `--scopes ` | Comma-separated OAuth scopes | + +### Token Storage + +Tokens are stored in `~/.config/opencode/mcp-oauth.json` with `0600` permissions (owner read/write only). Key format: `{serverHost}/{resource}`. + +--- + +## 7. `auth` - Authentication Management Manages Google Antigravity OAuth authentication. Required for using Gemini models. @@ -153,7 +187,7 @@ bunx oh-my-opencode auth status --- -## 7. Configuration Files +## 8. Configuration Files The CLI searches for configuration files in the following locations (in priority order): @@ -183,7 +217,7 @@ Configuration files support **JSONC (JSON with Comments)** format. You can use c --- -## 8. Troubleshooting +## 9. Troubleshooting ### "OpenCode version too old" Error @@ -213,7 +247,7 @@ bunx oh-my-opencode doctor --category authentication --- -## 9. Non-Interactive Mode +## 10. Non-Interactive Mode Use the `--no-tui` option for CI/CD environments. @@ -227,7 +261,7 @@ bunx oh-my-opencode doctor --json > doctor-report.json --- -## 10. Developer Information +## 11. Developer Information ### CLI Structure diff --git a/docs/features.md b/docs/features.md index c2ae0984..6b60bcad 100644 --- a/docs/features.md +++ b/docs/features.md @@ -521,6 +521,37 @@ mcp: The `skill_mcp` tool invokes these operations with full schema discovery. +#### OAuth-Enabled MCPs + +Skills can define OAuth-protected remote MCP servers. OAuth 2.1 with full RFC compliance (RFC 9728, 8414, 8707, 7591) is supported: + +```yaml +--- +description: My API skill +mcp: + my-api: + url: https://api.example.com/mcp + oauth: + clientId: ${CLIENT_ID} + scopes: ["read", "write"] +--- +``` + +When a skill MCP has `oauth` configured: +- **Auto-discovery**: Fetches `/.well-known/oauth-protected-resource` (RFC 9728), falls back to `/.well-known/oauth-authorization-server` (RFC 8414) +- **Dynamic Client Registration**: Auto-registers with servers supporting RFC 7591 (clientId becomes optional) +- **PKCE**: Mandatory for all flows +- **Resource Indicators**: Auto-generated from MCP URL per RFC 8707 +- **Token Storage**: Persisted in `~/.config/opencode/mcp-oauth.json` (chmod 0600) +- **Auto-refresh**: Tokens refresh on 401; step-up authorization on 403 with `WWW-Authenticate` +- **Dynamic Port**: OAuth callback server uses an auto-discovered available port + +Pre-authenticate via CLI: + +```bash +bunx oh-my-opencode mcp oauth login --server-url https://api.example.com +``` + --- ## Context Injection diff --git a/src/cli/doctor/checks/index.ts b/src/cli/doctor/checks/index.ts index d8d4b7e7..08927105 100644 --- a/src/cli/doctor/checks/index.ts +++ b/src/cli/doctor/checks/index.ts @@ -8,6 +8,7 @@ import { getDependencyCheckDefinitions } from "./dependencies" import { getGhCliCheckDefinition } from "./gh" import { getLspCheckDefinition } from "./lsp" import { getMcpCheckDefinitions } from "./mcp" +import { getMcpOAuthCheckDefinition } from "./mcp-oauth" import { getVersionCheckDefinition } from "./version" export * from "./opencode" @@ -19,6 +20,7 @@ export * from "./dependencies" export * from "./gh" export * from "./lsp" export * from "./mcp" +export * from "./mcp-oauth" export * from "./version" export function getAllCheckDefinitions(): CheckDefinition[] { @@ -32,6 +34,7 @@ export function getAllCheckDefinitions(): CheckDefinition[] { getGhCliCheckDefinition(), getLspCheckDefinition(), ...getMcpCheckDefinitions(), + getMcpOAuthCheckDefinition(), getVersionCheckDefinition(), ] } diff --git a/src/cli/doctor/checks/mcp-oauth.test.ts b/src/cli/doctor/checks/mcp-oauth.test.ts new file mode 100644 index 00000000..e564989c --- /dev/null +++ b/src/cli/doctor/checks/mcp-oauth.test.ts @@ -0,0 +1,133 @@ +import { describe, it, expect, spyOn, afterEach } from "bun:test" +import * as mcpOauth from "./mcp-oauth" + +describe("mcp-oauth check", () => { + describe("getMcpOAuthCheckDefinition", () => { + it("returns check definition with correct properties", () => { + // #given + // #when getting definition + const def = mcpOauth.getMcpOAuthCheckDefinition() + + // #then should have correct structure + expect(def.id).toBe("mcp-oauth-tokens") + expect(def.name).toBe("MCP OAuth Tokens") + expect(def.category).toBe("tools") + expect(def.critical).toBe(false) + expect(typeof def.check).toBe("function") + }) + }) + + describe("checkMcpOAuthTokens", () => { + let readStoreSpy: ReturnType + + afterEach(() => { + readStoreSpy?.mockRestore() + }) + + it("returns skip when no tokens stored", async () => { + // #given no OAuth tokens configured + readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue(null) + + // #when checking OAuth tokens + const result = await mcpOauth.checkMcpOAuthTokens() + + // #then should skip + expect(result.status).toBe("skip") + expect(result.message).toContain("No OAuth") + }) + + it("returns pass when all tokens valid", async () => { + // #given valid tokens with future expiry (expiresAt is in epoch seconds) + const futureTime = Math.floor(Date.now() / 1000) + 3600 + readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({ + "example.com/resource1": { + accessToken: "token1", + expiresAt: futureTime, + }, + "example.com/resource2": { + accessToken: "token2", + expiresAt: futureTime, + }, + }) + + // #when checking OAuth tokens + const result = await mcpOauth.checkMcpOAuthTokens() + + // #then should pass + expect(result.status).toBe("pass") + expect(result.message).toContain("2") + expect(result.message).toContain("valid") + }) + + it("returns warn when some tokens expired", async () => { + // #given mix of valid and expired tokens (expiresAt is in epoch seconds) + const futureTime = Math.floor(Date.now() / 1000) + 3600 + const pastTime = Math.floor(Date.now() / 1000) - 3600 + readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({ + "example.com/resource1": { + accessToken: "token1", + expiresAt: futureTime, + }, + "example.com/resource2": { + accessToken: "token2", + expiresAt: pastTime, + }, + }) + + // #when checking OAuth tokens + const result = await mcpOauth.checkMcpOAuthTokens() + + // #then should warn + expect(result.status).toBe("warn") + expect(result.message).toContain("1") + expect(result.message).toContain("expired") + expect(result.details?.some((d: string) => d.includes("Expired"))).toBe( + true + ) + }) + + it("returns pass when tokens have no expiry", async () => { + // #given tokens without expiry info + readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({ + "example.com/resource1": { + accessToken: "token1", + }, + }) + + // #when checking OAuth tokens + const result = await mcpOauth.checkMcpOAuthTokens() + + // #then should pass (no expiry = assume valid) + expect(result.status).toBe("pass") + expect(result.message).toContain("1") + }) + + it("includes token details in output", async () => { + // #given multiple tokens + const futureTime = Math.floor(Date.now() / 1000) + 3600 + readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({ + "api.example.com/v1": { + accessToken: "token1", + expiresAt: futureTime, + }, + "auth.example.com/oauth": { + accessToken: "token2", + expiresAt: futureTime, + }, + }) + + // #when checking OAuth tokens + const result = await mcpOauth.checkMcpOAuthTokens() + + // #then should list tokens in details + expect(result.details).toBeDefined() + expect(result.details?.length).toBeGreaterThan(0) + expect( + result.details?.some((d: string) => d.includes("api.example.com")) + ).toBe(true) + expect( + result.details?.some((d: string) => d.includes("auth.example.com")) + ).toBe(true) + }) + }) +}) diff --git a/src/cli/doctor/checks/mcp-oauth.ts b/src/cli/doctor/checks/mcp-oauth.ts new file mode 100644 index 00000000..9c1dd62a --- /dev/null +++ b/src/cli/doctor/checks/mcp-oauth.ts @@ -0,0 +1,80 @@ +import type { CheckResult, CheckDefinition } from "../types" +import { CHECK_IDS, CHECK_NAMES } from "../constants" +import { getMcpOauthStoragePath } from "../../../features/mcp-oauth/storage" +import { existsSync, readFileSync } from "node:fs" + +interface OAuthTokenData { + accessToken: string + refreshToken?: string + expiresAt?: number + clientInfo?: { + clientId: string + clientSecret?: string + } +} + +type TokenStore = Record + +export function readTokenStore(): TokenStore | null { + const filePath = getMcpOauthStoragePath() + if (!existsSync(filePath)) { + return null + } + + try { + const content = readFileSync(filePath, "utf-8") + return JSON.parse(content) as TokenStore + } catch { + return null + } +} + +export async function checkMcpOAuthTokens(): Promise { + const store = readTokenStore() + + if (!store || Object.keys(store).length === 0) { + return { + name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS], + status: "skip", + message: "No OAuth tokens configured", + details: ["Optional: Configure OAuth tokens for MCP servers"], + } + } + + const now = Math.floor(Date.now() / 1000) + const tokens = Object.entries(store) + const expiredTokens = tokens.filter( + ([, token]) => token.expiresAt && token.expiresAt < now + ) + + if (expiredTokens.length > 0) { + return { + name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS], + status: "warn", + message: `${expiredTokens.length} of ${tokens.length} token(s) expired`, + details: [ + ...tokens + .filter(([, token]) => !token.expiresAt || token.expiresAt >= now) + .map(([key]) => `Valid: ${key}`), + ...expiredTokens.map(([key]) => `Expired: ${key}`), + ], + } + } + + return { + name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS], + status: "pass", + message: `${tokens.length} OAuth token(s) valid`, + details: tokens.map(([key]) => `Configured: ${key}`), + } +} + +export function getMcpOAuthCheckDefinition(): CheckDefinition { + return { + id: CHECK_IDS.MCP_OAUTH_TOKENS, + name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS], + category: "tools", + check: checkMcpOAuthTokens, + critical: false, + } +} diff --git a/src/cli/doctor/constants.ts b/src/cli/doctor/constants.ts index 26dbcc01..df6f8800 100644 --- a/src/cli/doctor/constants.ts +++ b/src/cli/doctor/constants.ts @@ -32,6 +32,7 @@ export const CHECK_IDS = { LSP_SERVERS: "lsp-servers", MCP_BUILTIN: "mcp-builtin", MCP_USER: "mcp-user", + MCP_OAUTH_TOKENS: "mcp-oauth-tokens", VERSION_STATUS: "version-status", } as const @@ -50,6 +51,7 @@ export const CHECK_NAMES: Record = { [CHECK_IDS.LSP_SERVERS]: "LSP Servers", [CHECK_IDS.MCP_BUILTIN]: "Built-in MCP Servers", [CHECK_IDS.MCP_USER]: "User MCP Configuration", + [CHECK_IDS.MCP_OAUTH_TOKENS]: "MCP OAuth Tokens", [CHECK_IDS.VERSION_STATUS]: "Version Status", } as const diff --git a/src/cli/index.ts b/src/cli/index.ts index 7b415d89..ddf0dfb5 100644 --- a/src/cli/index.ts +++ b/src/cli/index.ts @@ -4,6 +4,7 @@ import { install } from "./install" import { run } from "./run" import { getLocalVersion } from "./get-local-version" import { doctor } from "./doctor" +import { createMcpOAuthCommand } from "./mcp-oauth" import type { InstallArgs } from "./types" import type { RunOptions } from "./run" import type { GetLocalVersionOptions } from "./get-local-version/types" @@ -150,4 +151,6 @@ program console.log(`oh-my-opencode v${VERSION}`) }) +program.addCommand(createMcpOAuthCommand()) + program.parse() diff --git a/src/cli/mcp-oauth/index.test.ts b/src/cli/mcp-oauth/index.test.ts new file mode 100644 index 00000000..ea88a019 --- /dev/null +++ b/src/cli/mcp-oauth/index.test.ts @@ -0,0 +1,123 @@ +import { describe, it, expect } from "bun:test" +import { Command } from "commander" +import { createMcpOAuthCommand } from "./index" + +describe("mcp oauth command", () => { + + describe("command structure", () => { + it("creates mcp command group with oauth subcommand", () => { + // given + const mcpCommand = createMcpOAuthCommand() + + // when + const subcommands = mcpCommand.commands.map((cmd: Command) => cmd.name()) + + // then + expect(subcommands).toContain("oauth") + }) + + it("oauth subcommand has login, logout, and status subcommands", () => { + // given + const mcpCommand = createMcpOAuthCommand() + const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth") + + // when + const subcommands = oauthCommand?.commands.map((cmd: Command) => cmd.name()) ?? [] + + // then + expect(subcommands).toContain("login") + expect(subcommands).toContain("logout") + expect(subcommands).toContain("status") + }) + }) + + describe("login subcommand", () => { + it("exists and has description", () => { + // given + const mcpCommand = createMcpOAuthCommand() + const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth") + const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login") + + // when + const description = loginCommand?.description() ?? "" + + // then + expect(loginCommand).toBeDefined() + expect(description).toContain("OAuth") + }) + + it("accepts --server-url option", () => { + // given + const mcpCommand = createMcpOAuthCommand() + const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth") + const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login") + + // when + const options = loginCommand?.options ?? [] + const serverUrlOption = options.find((opt: { long?: string }) => opt.long === "--server-url") + + // then + expect(serverUrlOption).toBeDefined() + }) + + it("accepts --client-id option", () => { + // given + const mcpCommand = createMcpOAuthCommand() + const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth") + const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login") + + // when + const options = loginCommand?.options ?? [] + const clientIdOption = options.find((opt: { long?: string }) => opt.long === "--client-id") + + // then + expect(clientIdOption).toBeDefined() + }) + + it("accepts --scopes option", () => { + // given + const mcpCommand = createMcpOAuthCommand() + const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth") + const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login") + + // when + const options = loginCommand?.options ?? [] + const scopesOption = options.find((opt: { long?: string }) => opt.long === "--scopes") + + // then + expect(scopesOption).toBeDefined() + }) + }) + + describe("logout subcommand", () => { + it("exists and has description", () => { + // given + const mcpCommand = createMcpOAuthCommand() + const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth") + const logoutCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "logout") + + // when + const description = logoutCommand?.description() ?? "" + + // then + expect(logoutCommand).toBeDefined() + expect(description).toContain("tokens") + }) + }) + + describe("status subcommand", () => { + it("exists and has description", () => { + // given + const mcpCommand = createMcpOAuthCommand() + const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth") + const statusCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "status") + + // when + const description = statusCommand?.description() ?? "" + + // then + expect(statusCommand).toBeDefined() + expect(description).toContain("status") + }) + }) +}) diff --git a/src/cli/mcp-oauth/index.ts b/src/cli/mcp-oauth/index.ts new file mode 100644 index 00000000..821037ee --- /dev/null +++ b/src/cli/mcp-oauth/index.ts @@ -0,0 +1,43 @@ +import { Command } from "commander" +import { login } from "./login" +import { logout } from "./logout" +import { status } from "./status" + +export function createMcpOAuthCommand(): Command { + const mcp = new Command("mcp").description("MCP server management") + + const oauth = new Command("oauth").description("OAuth token management for MCP servers") + + oauth + .command("login ") + .description("Authenticate with an MCP server using OAuth") + .option("--server-url ", "OAuth server URL (required if not in config)") + .option("--client-id ", "OAuth client ID (optional, uses DCR if not provided)") + .option("--scopes ", "OAuth scopes to request") + .action(async (serverName: string, options) => { + const exitCode = await login(serverName, options) + process.exit(exitCode) + }) + + oauth + .command("logout ") + .description("Remove stored OAuth tokens for an MCP server") + .option("--server-url ", "OAuth server URL (use if server name differs from URL)") + .action(async (serverName: string, options) => { + const exitCode = await logout(serverName, options) + process.exit(exitCode) + }) + + oauth + .command("status [server-name]") + .description("Show OAuth token status for MCP servers") + .action(async (serverName: string | undefined) => { + const exitCode = await status(serverName) + process.exit(exitCode) + }) + + mcp.addCommand(oauth) + return mcp +} + +export { login, logout, status } diff --git a/src/cli/mcp-oauth/login.test.ts b/src/cli/mcp-oauth/login.test.ts new file mode 100644 index 00000000..917652f7 --- /dev/null +++ b/src/cli/mcp-oauth/login.test.ts @@ -0,0 +1,80 @@ +import { describe, it, expect, beforeEach, afterEach, mock } from "bun:test" + +const mockLogin = mock(() => Promise.resolve({ accessToken: "test-token", expiresAt: 1710000000 })) + +mock.module("../../features/mcp-oauth/provider", () => ({ + McpOAuthProvider: class MockMcpOAuthProvider { + constructor(public options: { serverUrl: string; clientId?: string; scopes?: string[] }) {} + async login() { + return mockLogin() + } + }, +})) + +const { login } = await import("./login") + +describe("login command", () => { + beforeEach(() => { + mockLogin.mockClear() + }) + + afterEach(() => { + // cleanup + }) + + it("returns error code when server-url is not provided", async () => { + // given + const serverName = "test-server" + const options = {} + + // when + const exitCode = await login(serverName, options) + + // then + expect(exitCode).toBe(1) + }) + + it("returns success code when login succeeds", async () => { + // given + const serverName = "test-server" + const options = { + serverUrl: "https://oauth.example.com", + } + + // when + const exitCode = await login(serverName, options) + + // then + expect(exitCode).toBe(0) + expect(mockLogin).toHaveBeenCalledTimes(1) + }) + + it("returns error code when login throws", async () => { + // given + const serverName = "test-server" + const options = { + serverUrl: "https://oauth.example.com", + } + mockLogin.mockRejectedValueOnce(new Error("Network error")) + + // when + const exitCode = await login(serverName, options) + + // then + expect(exitCode).toBe(1) + }) + + it("returns error code when server-url is missing", async () => { + // given + const serverName = "test-server" + const options = { + clientId: "test-client-id", + } + + // when + const exitCode = await login(serverName, options) + + // then + expect(exitCode).toBe(1) + }) +}) diff --git a/src/cli/mcp-oauth/login.ts b/src/cli/mcp-oauth/login.ts new file mode 100644 index 00000000..1397900c --- /dev/null +++ b/src/cli/mcp-oauth/login.ts @@ -0,0 +1,38 @@ +import { McpOAuthProvider } from "../../features/mcp-oauth/provider" + +export interface LoginOptions { + serverUrl?: string + clientId?: string + scopes?: string[] +} + +export async function login(serverName: string, options: LoginOptions): Promise { + try { + const serverUrl = options.serverUrl + if (!serverUrl) { + console.error(`Error: --server-url is required for server "${serverName}"`) + return 1 + } + + const provider = new McpOAuthProvider({ + serverUrl, + clientId: options.clientId, + scopes: options.scopes, + }) + + console.log(`Authenticating with ${serverName}...`) + const tokenData = await provider.login() + + console.log(`✓ Successfully authenticated with ${serverName}`) + if (tokenData.expiresAt) { + const expiryDate = new Date(tokenData.expiresAt * 1000) + console.log(` Token expires at: ${expiryDate.toISOString()}`) + } + + return 0 + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + console.error(`Error: Failed to authenticate with ${serverName}: ${message}`) + return 1 + } +} diff --git a/src/cli/mcp-oauth/logout.test.ts b/src/cli/mcp-oauth/logout.test.ts new file mode 100644 index 00000000..b3d042a6 --- /dev/null +++ b/src/cli/mcp-oauth/logout.test.ts @@ -0,0 +1,65 @@ +import { describe, it, expect, beforeEach, afterEach, mock } from "bun:test" +import { existsSync, mkdirSync, rmSync } from "node:fs" +import { join } from "node:path" +import { tmpdir } from "node:os" +import { saveToken } from "../../features/mcp-oauth/storage" + +const { logout } = await import("./logout") + +describe("logout command", () => { + const TEST_CONFIG_DIR = join(tmpdir(), "mcp-oauth-logout-test-" + Date.now()) + let originalConfigDir: string | undefined + + beforeEach(() => { + originalConfigDir = process.env.OPENCODE_CONFIG_DIR + process.env.OPENCODE_CONFIG_DIR = TEST_CONFIG_DIR + if (!existsSync(TEST_CONFIG_DIR)) { + mkdirSync(TEST_CONFIG_DIR, { recursive: true }) + } + }) + + afterEach(() => { + if (originalConfigDir === undefined) { + delete process.env.OPENCODE_CONFIG_DIR + } else { + process.env.OPENCODE_CONFIG_DIR = originalConfigDir + } + if (existsSync(TEST_CONFIG_DIR)) { + rmSync(TEST_CONFIG_DIR, { recursive: true, force: true }) + } + }) + + it("returns success code when logout succeeds", async () => { + // given + const serverUrl = "https://test-server.example.com" + saveToken(serverUrl, serverUrl, { accessToken: "test-token" }) + + // when + const exitCode = await logout("test-server", { serverUrl }) + + // then + expect(exitCode).toBe(0) + }) + + it("handles non-existent server gracefully", async () => { + // given + const serverName = "non-existent-server" + + // when + const exitCode = await logout(serverName, { serverUrl: "https://nonexistent.example.com" }) + + // then + expect(exitCode).toBe(0) + }) + + it("returns error when --server-url is not provided", async () => { + // given + const serverName = "test-server" + + // when + const exitCode = await logout(serverName) + + // then + expect(exitCode).toBe(1) + }) +}) diff --git a/src/cli/mcp-oauth/logout.ts b/src/cli/mcp-oauth/logout.ts new file mode 100644 index 00000000..69398f35 --- /dev/null +++ b/src/cli/mcp-oauth/logout.ts @@ -0,0 +1,30 @@ +import { deleteToken } from "../../features/mcp-oauth/storage" + +export interface LogoutOptions { + serverUrl?: string +} + +export async function logout(serverName: string, options?: LogoutOptions): Promise { + try { + const serverUrl = options?.serverUrl + if (!serverUrl) { + console.error(`Error: --server-url is required for logout. Token storage uses server URLs, not names.`) + console.error(` Usage: mcp oauth logout ${serverName} --server-url https://your-server.example.com`) + return 1 + } + + const success = deleteToken(serverUrl, serverUrl) + + if (success) { + console.log(`✓ Successfully removed tokens for ${serverName}`) + return 0 + } + + console.error(`Error: Failed to remove tokens for ${serverName}`) + return 1 + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + console.error(`Error: Failed to remove tokens for ${serverName}: ${message}`) + return 1 + } +} diff --git a/src/cli/mcp-oauth/status.test.ts b/src/cli/mcp-oauth/status.test.ts new file mode 100644 index 00000000..36cc8bb4 --- /dev/null +++ b/src/cli/mcp-oauth/status.test.ts @@ -0,0 +1,48 @@ +import { describe, it, expect, beforeEach, afterEach } from "bun:test" +import { status } from "./status" + +describe("status command", () => { + beforeEach(() => { + // setup + }) + + afterEach(() => { + // cleanup + }) + + it("returns success code when checking status for specific server", async () => { + // given + const serverName = "test-server" + + // when + const exitCode = await status(serverName) + + // then + expect(typeof exitCode).toBe("number") + expect(exitCode).toBe(0) + }) + + it("returns success code when checking status for all servers", async () => { + // given + const serverName = undefined + + // when + const exitCode = await status(serverName) + + // then + expect(typeof exitCode).toBe("number") + expect(exitCode).toBe(0) + }) + + it("handles non-existent server gracefully", async () => { + // given + const serverName = "non-existent-server" + + // when + const exitCode = await status(serverName) + + // then + expect(typeof exitCode).toBe("number") + expect(exitCode).toBe(0) + }) +}) diff --git a/src/cli/mcp-oauth/status.ts b/src/cli/mcp-oauth/status.ts new file mode 100644 index 00000000..876507c1 --- /dev/null +++ b/src/cli/mcp-oauth/status.ts @@ -0,0 +1,50 @@ +import { listAllTokens, listTokensByHost } from "../../features/mcp-oauth/storage" + +export async function status(serverName: string | undefined): Promise { + try { + if (serverName) { + const tokens = listTokensByHost(serverName) + + if (Object.keys(tokens).length === 0) { + console.log(`No tokens found for ${serverName}`) + return 0 + } + + console.log(`OAuth Status for ${serverName}:`) + for (const [key, token] of Object.entries(tokens)) { + console.log(` ${key}:`) + console.log(` Access Token: [REDACTED]`) + if (token.refreshToken) { + console.log(` Refresh Token: [REDACTED]`) + } + if (token.expiresAt) { + const expiryDate = new Date(token.expiresAt * 1000) + const now = Date.now() / 1000 + const isExpired = token.expiresAt < now + const tokenStatus = isExpired ? "EXPIRED" : "VALID" + console.log(` Expiry: ${expiryDate.toISOString()} (${tokenStatus})`) + } + } + return 0 + } + + const tokens = listAllTokens() + if (Object.keys(tokens).length === 0) { + console.log("No OAuth tokens stored") + return 0 + } + + console.log("Stored OAuth Tokens:") + for (const [key, token] of Object.entries(tokens)) { + const isExpired = token.expiresAt && token.expiresAt < Date.now() / 1000 + const tokenStatus = isExpired ? "EXPIRED" : "VALID" + console.log(` ${key}: ${tokenStatus}`) + } + + return 0 + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + console.error(`Error: Failed to get token status: ${message}`) + return 1 + } +} diff --git a/src/features/claude-code-mcp-loader/types.ts b/src/features/claude-code-mcp-loader/types.ts index 838ff61f..66822e8d 100644 --- a/src/features/claude-code-mcp-loader/types.ts +++ b/src/features/claude-code-mcp-loader/types.ts @@ -7,6 +7,10 @@ export interface ClaudeCodeMcpServer { args?: string[] env?: Record headers?: Record + oauth?: { + clientId?: string + scopes?: string[] + } disabled?: boolean } diff --git a/src/features/mcp-oauth/callback-server.test.ts b/src/features/mcp-oauth/callback-server.test.ts new file mode 100644 index 00000000..3275430a --- /dev/null +++ b/src/features/mcp-oauth/callback-server.test.ts @@ -0,0 +1,129 @@ +import { afterEach, describe, expect, it } from "bun:test" +import { findAvailablePort, startCallbackServer, type CallbackServer } from "./callback-server" + +describe("findAvailablePort", () => { + it("returns the start port when it is available", async () => { + //#given + const startPort = 19877 + + //#when + const port = await findAvailablePort(startPort) + + //#then + expect(port).toBeGreaterThanOrEqual(startPort) + expect(port).toBeLessThan(startPort + 20) + }) + + it("skips busy ports and returns next available", async () => { + //#given + const blocker = Bun.serve({ + port: 19877, + hostname: "127.0.0.1", + fetch: () => new Response(), + }) + + //#when + const port = await findAvailablePort(19877) + + //#then + expect(port).toBeGreaterThan(19877) + blocker.stop(true) + }) +}) + +describe("startCallbackServer", () => { + let server: CallbackServer | null = null + + afterEach(() => { + server?.close() + server = null + }) + + it("starts server and returns port", async () => { + //#given - no preconditions + + //#when + server = await startCallbackServer() + + //#then + expect(server.port).toBeGreaterThanOrEqual(19877) + expect(typeof server.waitForCallback).toBe("function") + expect(typeof server.close).toBe("function") + }) + + it("resolves callback with code and state from query params", async () => { + //#given + server = await startCallbackServer() + const callbackUrl = `http://127.0.0.1:${server.port}/oauth/callback?code=test-code&state=test-state` + + //#when + const fetchPromise = fetch(callbackUrl) + const result = await server.waitForCallback() + const response = await fetchPromise + + //#then + expect(result).toEqual({ code: "test-code", state: "test-state" }) + expect(response.status).toBe(200) + const html = await response.text() + expect(html).toContain("Authorization successful") + }) + + it("returns 404 for non-callback routes", async () => { + //#given + server = await startCallbackServer() + + //#when + const response = await fetch(`http://127.0.0.1:${server.port}/other`) + + //#then + expect(response.status).toBe(404) + }) + + it("returns 400 and rejects when code is missing", async () => { + //#given + server = await startCallbackServer() + const callbackRejection = server.waitForCallback().catch((e: Error) => e) + + //#when + const response = await fetch(`http://127.0.0.1:${server.port}/oauth/callback?state=s`) + + //#then + expect(response.status).toBe(400) + const error = await callbackRejection + expect(error).toBeInstanceOf(Error) + expect((error as Error).message).toContain("missing code or state") + }) + + it("returns 400 and rejects when state is missing", async () => { + //#given + server = await startCallbackServer() + const callbackRejection = server.waitForCallback().catch((e: Error) => e) + + //#when + const response = await fetch(`http://127.0.0.1:${server.port}/oauth/callback?code=c`) + + //#then + expect(response.status).toBe(400) + const error = await callbackRejection + expect(error).toBeInstanceOf(Error) + expect((error as Error).message).toContain("missing code or state") + }) + + it("close stops the server immediately", async () => { + //#given + server = await startCallbackServer() + const port = server.port + + //#when + server.close() + server = null + + //#then + try { + await fetch(`http://127.0.0.1:${port}/oauth/callback?code=c&state=s`) + expect(true).toBe(false) + } catch (error) { + expect(error).toBeDefined() + } + }) +}) diff --git a/src/features/mcp-oauth/callback-server.ts b/src/features/mcp-oauth/callback-server.ts new file mode 100644 index 00000000..3f201202 --- /dev/null +++ b/src/features/mcp-oauth/callback-server.ts @@ -0,0 +1,124 @@ +const DEFAULT_PORT = 19877 +const MAX_PORT_ATTEMPTS = 20 +const TIMEOUT_MS = 5 * 60 * 1000 + +export type OAuthCallbackResult = { + code: string + state: string +} + +export type CallbackServer = { + port: number + waitForCallback: () => Promise + close: () => void +} + +const SUCCESS_HTML = ` + + + + OAuth Authorized + + + +
+

Authorization successful

+

You can close this window and return to your terminal.

+
+ +` + +async function isPortAvailable(port: number): Promise { + try { + const server = Bun.serve({ + port, + hostname: "127.0.0.1", + fetch: () => new Response(), + }) + server.stop(true) + return true + } catch { + return false + } +} + +export async function findAvailablePort(startPort: number = DEFAULT_PORT): Promise { + for (let attempt = 0; attempt < MAX_PORT_ATTEMPTS; attempt++) { + const port = startPort + attempt + if (await isPortAvailable(port)) { + return port + } + } + throw new Error(`No available port found in range ${startPort}-${startPort + MAX_PORT_ATTEMPTS - 1}`) +} + +export async function startCallbackServer(startPort: number = DEFAULT_PORT): Promise { + const port = await findAvailablePort(startPort) + + let resolveCallback: ((result: OAuthCallbackResult) => void) | null = null + let rejectCallback: ((error: Error) => void) | null = null + + const callbackPromise = new Promise((resolve, reject) => { + resolveCallback = resolve + rejectCallback = reject + }) + + const timeoutId = setTimeout(() => { + rejectCallback?.(new Error("OAuth callback timed out after 5 minutes")) + server.stop(true) + }, TIMEOUT_MS) + + const server = Bun.serve({ + port, + hostname: "127.0.0.1", + fetch(request: Request): Response { + const url = new URL(request.url) + + if (url.pathname !== "/oauth/callback") { + return new Response("Not Found", { status: 404 }) + } + + const oauthError = url.searchParams.get("error") + if (oauthError) { + const description = url.searchParams.get("error_description") ?? oauthError + clearTimeout(timeoutId) + rejectCallback?.(new Error(`OAuth authorization failed: ${description}`)) + setTimeout(() => server.stop(true), 100) + return new Response(`Authorization failed: ${description}`, { status: 400 }) + } + + const code = url.searchParams.get("code") + const state = url.searchParams.get("state") + + if (!code || !state) { + clearTimeout(timeoutId) + rejectCallback?.(new Error("OAuth callback missing code or state parameter")) + setTimeout(() => server.stop(true), 100) + return new Response("Missing code or state parameter", { status: 400 }) + } + + resolveCallback?.({ code, state }) + clearTimeout(timeoutId) + + setTimeout(() => server.stop(true), 100) + + return new Response(SUCCESS_HTML, { + headers: { "content-type": "text/html; charset=utf-8" }, + }) + }, + }) + + return { + port, + waitForCallback: () => callbackPromise, + close: () => { + clearTimeout(timeoutId) + server.stop(true) + }, + } +} diff --git a/src/features/mcp-oauth/dcr.test.ts b/src/features/mcp-oauth/dcr.test.ts new file mode 100644 index 00000000..28c3ec2c --- /dev/null +++ b/src/features/mcp-oauth/dcr.test.ts @@ -0,0 +1,164 @@ +import { describe, expect, it } from "bun:test" +import { + getOrRegisterClient, + type ClientCredentials, + type ClientRegistrationStorage, + type DcrFetch, +} from "./dcr" + +function createStorage(initial: ClientCredentials | null): + & ClientRegistrationStorage + & { getLastKey: () => string | null; getLastSet: () => ClientCredentials | null } { + let stored = initial + let lastKey: string | null = null + let lastSet: ClientCredentials | null = null + + return { + getClientRegistration: () => stored, + setClientRegistration: (serverIdentifier: string, credentials: ClientCredentials) => { + lastKey = serverIdentifier + lastSet = credentials + stored = credentials + }, + getLastKey: () => lastKey, + getLastSet: () => lastSet, + } +} + +describe("getOrRegisterClient", () => { + it("returns cached registration when available", async () => { + // #given + const storage = createStorage({ + clientId: "cached-client", + clientSecret: "cached-secret", + }) + const fetchMock: DcrFetch = async () => { + throw new Error("fetch should not be called") + } + + // #when + const result = await getOrRegisterClient({ + registrationEndpoint: "https://server.example.com/register", + serverIdentifier: "server-1", + clientName: "Test Client", + redirectUris: ["https://app.example.com/callback"], + tokenEndpointAuthMethod: "client_secret_post", + storage, + fetch: fetchMock, + }) + + // #then + expect(result).toEqual({ + clientId: "cached-client", + clientSecret: "cached-secret", + }) + }) + + it("registers client and stores credentials when endpoint available", async () => { + // #given + const storage = createStorage(null) + let fetchCalled = false + const fetchMock: DcrFetch = async ( + input: string, + init?: { method?: string; headers?: Record; body?: string } + ) => { + fetchCalled = true + expect(input).toBe("https://server.example.com/register") + if (typeof init?.body !== "string") { + throw new Error("Expected request body string") + } + const payload = JSON.parse(init.body) + expect(payload).toEqual({ + redirect_uris: ["https://app.example.com/callback"], + client_name: "Test Client", + grant_types: ["authorization_code", "refresh_token"], + response_types: ["code"], + token_endpoint_auth_method: "client_secret_post", + }) + + return { + ok: true, + json: async () => ({ + client_id: "registered-client", + client_secret: "registered-secret", + }), + } + } + + // #when + const result = await getOrRegisterClient({ + registrationEndpoint: "https://server.example.com/register", + serverIdentifier: "server-2", + clientName: "Test Client", + redirectUris: ["https://app.example.com/callback"], + tokenEndpointAuthMethod: "client_secret_post", + storage, + fetch: fetchMock, + }) + + // #then + expect(fetchCalled).toBe(true) + expect(result).toEqual({ + clientId: "registered-client", + clientSecret: "registered-secret", + }) + expect(storage.getLastKey()).toBe("server-2") + expect(storage.getLastSet()).toEqual({ + clientId: "registered-client", + clientSecret: "registered-secret", + }) + }) + + it("uses config client id when registration endpoint missing", async () => { + // #given + const storage = createStorage(null) + let fetchCalled = false + const fetchMock: DcrFetch = async () => { + fetchCalled = true + return { + ok: false, + json: async () => ({}), + } + } + + // #when + const result = await getOrRegisterClient({ + registrationEndpoint: undefined, + serverIdentifier: "server-3", + clientName: "Test Client", + redirectUris: ["https://app.example.com/callback"], + tokenEndpointAuthMethod: "client_secret_post", + clientId: "config-client", + storage, + fetch: fetchMock, + }) + + // #then + expect(fetchCalled).toBe(false) + expect(result).toEqual({ clientId: "config-client" }) + }) + + it("falls back to config client id when registration fails", async () => { + // #given + const storage = createStorage(null) + const fetchMock: DcrFetch = async () => { + throw new Error("network error") + } + + // #when + const result = await getOrRegisterClient({ + registrationEndpoint: "https://server.example.com/register", + serverIdentifier: "server-4", + clientName: "Test Client", + redirectUris: ["https://app.example.com/callback"], + tokenEndpointAuthMethod: "client_secret_post", + clientId: "fallback-client", + storage, + fetch: fetchMock, + }) + + // #then + expect(result).toEqual({ clientId: "fallback-client" }) + expect(storage.getLastSet()).toBeNull() + }) +}) diff --git a/src/features/mcp-oauth/dcr.ts b/src/features/mcp-oauth/dcr.ts new file mode 100644 index 00000000..b8281860 --- /dev/null +++ b/src/features/mcp-oauth/dcr.ts @@ -0,0 +1,98 @@ +export type ClientRegistrationRequest = { + redirect_uris: string[] + client_name: string + grant_types: ["authorization_code", "refresh_token"] + response_types: ["code"] + token_endpoint_auth_method: "none" | "client_secret_post" +} + +export type ClientCredentials = { + clientId: string + clientSecret?: string +} + +export type ClientRegistrationStorage = { + getClientRegistration: (serverIdentifier: string) => ClientCredentials | null + setClientRegistration: ( + serverIdentifier: string, + credentials: ClientCredentials + ) => void +} + +export type DynamicClientRegistrationOptions = { + registrationEndpoint?: string | null + serverIdentifier?: string + clientName: string + redirectUris: string[] + tokenEndpointAuthMethod: "none" | "client_secret_post" + clientId?: string | null + storage: ClientRegistrationStorage + fetch?: DcrFetch +} + +export type DcrFetch = ( + input: string, + init?: { method?: string; headers?: Record; body?: string } +) => Promise<{ ok: boolean; json: () => Promise }> + +export async function getOrRegisterClient( + options: DynamicClientRegistrationOptions +): Promise { + const serverIdentifier = + options.serverIdentifier ?? options.registrationEndpoint ?? "default" + const existing = options.storage.getClientRegistration(serverIdentifier) + if (existing) return existing + + if (!options.registrationEndpoint) { + return options.clientId ? { clientId: options.clientId } : null + } + + const fetchImpl = options.fetch ?? globalThis.fetch + const request: ClientRegistrationRequest = { + redirect_uris: options.redirectUris, + client_name: options.clientName, + grant_types: ["authorization_code", "refresh_token"], + response_types: ["code"], + token_endpoint_auth_method: options.tokenEndpointAuthMethod, + } + + try { + const response = await fetchImpl(options.registrationEndpoint, { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify(request), + }) + + if (!response.ok) { + return options.clientId ? { clientId: options.clientId } : null + } + + const data: unknown = await response.json() + const parsed = parseRegistrationResponse(data) + if (!parsed) { + return options.clientId ? { clientId: options.clientId } : null + } + + options.storage.setClientRegistration(serverIdentifier, parsed) + return parsed + } catch { + return options.clientId ? { clientId: options.clientId } : null + } +} + +function parseRegistrationResponse(data: unknown): ClientCredentials | null { + if (!isRecord(data)) return null + const clientId = data.client_id + if (typeof clientId !== "string" || clientId.length === 0) return null + + const clientSecret = data.client_secret + if (typeof clientSecret === "string" && clientSecret.length > 0) { + return { clientId, clientSecret } + } + + return { clientId } +} + +function isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null +} diff --git a/src/features/mcp-oauth/discovery.test.ts b/src/features/mcp-oauth/discovery.test.ts new file mode 100644 index 00000000..3edf93ef --- /dev/null +++ b/src/features/mcp-oauth/discovery.test.ts @@ -0,0 +1,175 @@ +import { describe, test, expect, beforeEach, afterEach } from "bun:test" +import { discoverOAuthServerMetadata, resetDiscoveryCache } from "./discovery" + +describe("discoverOAuthServerMetadata", () => { + const originalFetch = globalThis.fetch + + beforeEach(() => { + resetDiscoveryCache() + }) + + afterEach(() => { + Object.defineProperty(globalThis, "fetch", { value: originalFetch, configurable: true }) + }) + + test("returns endpoints from PRM + AS discovery", () => { + // #given + const resource = "https://mcp.example.com" + const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString() + const authServer = "https://auth.example.com" + const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString() + const calls: string[] = [] + const fetchMock = async (input: string | URL) => { + const url = typeof input === "string" ? input : input.toString() + calls.push(url) + if (url === prmUrl) { + return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 }) + } + if (url === asUrl) { + return new Response( + JSON.stringify({ + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + }), + { status: 200 } + ) + } + return new Response("not found", { status: 404 }) + } + Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true }) + + // #when + return discoverOAuthServerMetadata(resource).then((result) => { + // #then + expect(result).toEqual({ + authorizationEndpoint: "https://auth.example.com/authorize", + tokenEndpoint: "https://auth.example.com/token", + registrationEndpoint: "https://auth.example.com/register", + resource, + }) + expect(calls).toEqual([prmUrl, asUrl]) + }) + }) + + test("falls back to RFC 8414 when PRM returns 404", () => { + // #given + const resource = "https://mcp.example.com" + const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString() + const asUrl = new URL("/.well-known/oauth-authorization-server", resource).toString() + const calls: string[] = [] + const fetchMock = async (input: string | URL) => { + const url = typeof input === "string" ? input : input.toString() + calls.push(url) + if (url === prmUrl) { + return new Response("not found", { status: 404 }) + } + if (url === asUrl) { + return new Response( + JSON.stringify({ + authorization_endpoint: "https://mcp.example.com/authorize", + token_endpoint: "https://mcp.example.com/token", + }), + { status: 200 } + ) + } + return new Response("not found", { status: 404 }) + } + Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true }) + + // #when + return discoverOAuthServerMetadata(resource).then((result) => { + // #then + expect(result).toEqual({ + authorizationEndpoint: "https://mcp.example.com/authorize", + tokenEndpoint: "https://mcp.example.com/token", + registrationEndpoint: undefined, + resource, + }) + expect(calls).toEqual([prmUrl, asUrl]) + }) + }) + + test("throws when both PRM and AS discovery return 404", () => { + // #given + const resource = "https://mcp.example.com" + const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString() + const asUrl = new URL("/.well-known/oauth-authorization-server", resource).toString() + const fetchMock = async (input: string | URL) => { + const url = typeof input === "string" ? input : input.toString() + if (url === prmUrl || url === asUrl) { + return new Response("not found", { status: 404 }) + } + return new Response("not found", { status: 404 }) + } + Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true }) + + // #when + const result = discoverOAuthServerMetadata(resource) + + // #then + return expect(result).rejects.toThrow("OAuth authorization server metadata not found") + }) + + test("throws when AS metadata is malformed", () => { + // #given + const resource = "https://mcp.example.com" + const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString() + const authServer = "https://auth.example.com" + const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString() + const fetchMock = async (input: string | URL) => { + const url = typeof input === "string" ? input : input.toString() + if (url === prmUrl) { + return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 }) + } + if (url === asUrl) { + return new Response(JSON.stringify({ authorization_endpoint: "https://auth.example.com/authorize" }), { + status: 200, + }) + } + return new Response("not found", { status: 404 }) + } + Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true }) + + // #when + const result = discoverOAuthServerMetadata(resource) + + // #then + return expect(result).rejects.toThrow("token_endpoint") + }) + + test("caches discovery results per resource URL", () => { + // #given + const resource = "https://mcp.example.com" + const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString() + const authServer = "https://auth.example.com" + const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString() + const calls: string[] = [] + const fetchMock = async (input: string | URL) => { + const url = typeof input === "string" ? input : input.toString() + calls.push(url) + if (url === prmUrl) { + return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 }) + } + if (url === asUrl) { + return new Response( + JSON.stringify({ + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + }), + { status: 200 } + ) + } + return new Response("not found", { status: 404 }) + } + Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true }) + + // #when + return discoverOAuthServerMetadata(resource) + .then(() => discoverOAuthServerMetadata(resource)) + .then(() => { + // #then + expect(calls).toEqual([prmUrl, asUrl]) + }) + }) +}) diff --git a/src/features/mcp-oauth/discovery.ts b/src/features/mcp-oauth/discovery.ts new file mode 100644 index 00000000..619520d4 --- /dev/null +++ b/src/features/mcp-oauth/discovery.ts @@ -0,0 +1,123 @@ +export interface OAuthServerMetadata { + authorizationEndpoint: string + tokenEndpoint: string + registrationEndpoint?: string + resource: string +} + +const discoveryCache = new Map() +const pendingDiscovery = new Map>() + +function parseHttpsUrl(value: string, label: string): URL { + const parsed = new URL(value) + if (parsed.protocol !== "https:") { + throw new Error(`${label} must use https`) + } + return parsed +} + +function readStringField(source: Record, field: string): string { + const value = source[field] + if (typeof value !== "string" || value.length === 0) { + throw new Error(`OAuth metadata missing ${field}`) + } + return value +} + +async function fetchMetadata(url: string): Promise<{ ok: true; json: Record } | { ok: false; status: number }> { + const response = await fetch(url, { headers: { accept: "application/json" } }) + if (!response.ok) { + return { ok: false, status: response.status } + } + const json = (await response.json().catch(() => null)) as Record | null + if (!json || typeof json !== "object") { + throw new Error("OAuth metadata response is not valid JSON") + } + return { ok: true, json } +} + +async function fetchAuthorizationServerMetadata(issuer: string, resource: string): Promise { + const issuerUrl = parseHttpsUrl(issuer, "Authorization server URL") + const issuerPath = issuerUrl.pathname.replace(/\/+$/, "") + const metadataUrl = new URL(`/.well-known/oauth-authorization-server${issuerPath}`, issuerUrl).toString() + const metadata = await fetchMetadata(metadataUrl) + + if (!metadata.ok) { + if (metadata.status === 404) { + throw new Error("OAuth authorization server metadata not found") + } + throw new Error(`OAuth authorization server metadata fetch failed (${metadata.status})`) + } + + const authorizationEndpoint = parseHttpsUrl( + readStringField(metadata.json, "authorization_endpoint"), + "authorization_endpoint" + ).toString() + const tokenEndpoint = parseHttpsUrl( + readStringField(metadata.json, "token_endpoint"), + "token_endpoint" + ).toString() + const registrationEndpointValue = metadata.json.registration_endpoint + const registrationEndpoint = + typeof registrationEndpointValue === "string" && registrationEndpointValue.length > 0 + ? parseHttpsUrl(registrationEndpointValue, "registration_endpoint").toString() + : undefined + + return { + authorizationEndpoint, + tokenEndpoint, + registrationEndpoint, + resource, + } +} + +function parseAuthorizationServers(metadata: Record): string[] { + const servers = metadata.authorization_servers + if (!Array.isArray(servers)) return [] + return servers.filter((server): server is string => typeof server === "string" && server.length > 0) +} + +export async function discoverOAuthServerMetadata(resource: string): Promise { + const resourceUrl = parseHttpsUrl(resource, "Resource server URL") + const resourceKey = resourceUrl.toString() + + const cached = discoveryCache.get(resourceKey) + if (cached) return cached + + const pending = pendingDiscovery.get(resourceKey) + if (pending) return pending + + const discoveryPromise = (async () => { + const prmUrl = new URL("/.well-known/oauth-protected-resource", resourceUrl).toString() + const prmResponse = await fetchMetadata(prmUrl) + + if (prmResponse.ok) { + const authServers = parseAuthorizationServers(prmResponse.json) + if (authServers.length === 0) { + throw new Error("OAuth protected resource metadata missing authorization_servers") + } + return fetchAuthorizationServerMetadata(authServers[0], resource) + } + + if (prmResponse.status !== 404) { + throw new Error(`OAuth protected resource metadata fetch failed (${prmResponse.status})`) + } + + return fetchAuthorizationServerMetadata(resourceKey, resource) + })() + + pendingDiscovery.set(resourceKey, discoveryPromise) + + try { + const result = await discoveryPromise + discoveryCache.set(resourceKey, result) + return result + } finally { + pendingDiscovery.delete(resourceKey) + } +} + +export function resetDiscoveryCache(): void { + discoveryCache.clear() + pendingDiscovery.clear() +} diff --git a/src/features/mcp-oauth/index.ts b/src/features/mcp-oauth/index.ts new file mode 100644 index 00000000..06861aae --- /dev/null +++ b/src/features/mcp-oauth/index.ts @@ -0,0 +1 @@ +export * from "./schema" diff --git a/src/features/mcp-oauth/provider.test.ts b/src/features/mcp-oauth/provider.test.ts new file mode 100644 index 00000000..5f42c4e5 --- /dev/null +++ b/src/features/mcp-oauth/provider.test.ts @@ -0,0 +1,223 @@ +import { describe, expect, it, beforeEach, afterEach, mock } from "bun:test" +import { createHash, randomBytes } from "node:crypto" +import { McpOAuthProvider, generateCodeVerifier, generateCodeChallenge, buildAuthorizationUrl } from "./provider" +import type { OAuthTokenData } from "./storage" + +describe("McpOAuthProvider", () => { + describe("generateCodeVerifier", () => { + it("returns a base64url-encoded 32-byte random string", () => { + //#given + const verifier = generateCodeVerifier() + + //#when + const decoded = Buffer.from(verifier, "base64url") + + //#then + expect(decoded.length).toBe(32) + expect(verifier).toMatch(/^[A-Za-z0-9_-]+$/) + }) + + it("produces unique values on each call", () => { + //#given + const first = generateCodeVerifier() + + //#when + const second = generateCodeVerifier() + + //#then + expect(first).not.toBe(second) + }) + }) + + describe("generateCodeChallenge", () => { + it("returns SHA256 base64url digest of the verifier", () => { + //#given + const verifier = "test-verifier-value" + const expected = createHash("sha256").update(verifier).digest("base64url") + + //#when + const challenge = generateCodeChallenge(verifier) + + //#then + expect(challenge).toBe(expected) + }) + }) + + describe("buildAuthorizationUrl", () => { + it("builds URL with all required PKCE parameters", () => { + //#given + const endpoint = "https://auth.example.com/authorize" + + //#when + const url = buildAuthorizationUrl(endpoint, { + clientId: "my-client", + redirectUri: "http://127.0.0.1:8912/callback", + codeChallenge: "challenge-value", + state: "state-value", + scopes: ["openid", "profile"], + resource: "https://mcp.example.com", + }) + + //#then + const parsed = new URL(url) + expect(parsed.origin + parsed.pathname).toBe("https://auth.example.com/authorize") + expect(parsed.searchParams.get("response_type")).toBe("code") + expect(parsed.searchParams.get("client_id")).toBe("my-client") + expect(parsed.searchParams.get("redirect_uri")).toBe("http://127.0.0.1:8912/callback") + expect(parsed.searchParams.get("code_challenge")).toBe("challenge-value") + expect(parsed.searchParams.get("code_challenge_method")).toBe("S256") + expect(parsed.searchParams.get("state")).toBe("state-value") + expect(parsed.searchParams.get("scope")).toBe("openid profile") + expect(parsed.searchParams.get("resource")).toBe("https://mcp.example.com") + }) + + it("omits scope when empty", () => { + //#given + const endpoint = "https://auth.example.com/authorize" + + //#when + const url = buildAuthorizationUrl(endpoint, { + clientId: "my-client", + redirectUri: "http://127.0.0.1:8912/callback", + codeChallenge: "challenge-value", + state: "state-value", + scopes: [], + }) + + //#then + const parsed = new URL(url) + expect(parsed.searchParams.has("scope")).toBe(false) + }) + + it("omits resource when undefined", () => { + //#given + const endpoint = "https://auth.example.com/authorize" + + //#when + const url = buildAuthorizationUrl(endpoint, { + clientId: "my-client", + redirectUri: "http://127.0.0.1:8912/callback", + codeChallenge: "challenge-value", + state: "state-value", + }) + + //#then + const parsed = new URL(url) + expect(parsed.searchParams.has("resource")).toBe(false) + }) + }) + + describe("constructor and basic methods", () => { + it("stores serverUrl and optional clientId and scopes", () => { + //#given + const options = { + serverUrl: "https://mcp.example.com", + clientId: "my-client", + scopes: ["openid"], + } + + //#when + const provider = new McpOAuthProvider(options) + + //#then + expect(provider.tokens()).toBeNull() + expect(provider.clientInformation()).toBeNull() + expect(provider.codeVerifier()).toBeNull() + }) + + it("defaults scopes to empty array", () => { + //#given + const options = { serverUrl: "https://mcp.example.com" } + + //#when + const provider = new McpOAuthProvider(options) + + //#then + expect(provider.redirectUrl()).toBe("http://127.0.0.1:19877/callback") + }) + }) + + describe("saveCodeVerifier / codeVerifier", () => { + it("stores and retrieves code verifier", () => { + //#given + const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" }) + + //#when + provider.saveCodeVerifier("my-verifier") + + //#then + expect(provider.codeVerifier()).toBe("my-verifier") + }) + }) + + describe("saveTokens / tokens", () => { + let originalEnv: string | undefined + + beforeEach(() => { + originalEnv = process.env.OPENCODE_CONFIG_DIR + const { mkdirSync } = require("node:fs") + const { tmpdir } = require("node:os") + const { join } = require("node:path") + const testDir = join(tmpdir(), "mcp-oauth-provider-test-" + Date.now()) + mkdirSync(testDir, { recursive: true }) + process.env.OPENCODE_CONFIG_DIR = testDir + }) + + afterEach(() => { + if (originalEnv === undefined) { + delete process.env.OPENCODE_CONFIG_DIR + } else { + process.env.OPENCODE_CONFIG_DIR = originalEnv + } + }) + + it("persists and loads token data via storage", () => { + //#given + const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" }) + const tokenData: OAuthTokenData = { + accessToken: "access-token-123", + refreshToken: "refresh-token-456", + expiresAt: 1710000000, + } + + //#when + const saved = provider.saveTokens(tokenData) + const loaded = provider.tokens() + + //#then + expect(saved).toBe(true) + expect(loaded).toEqual(tokenData) + }) + }) + + describe("redirectToAuthorization", () => { + it("throws when no client information is set", async () => { + //#given + const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" }) + const metadata = { + authorizationEndpoint: "https://auth.example.com/authorize", + tokenEndpoint: "https://auth.example.com/token", + resource: "https://mcp.example.com", + } + + //#when + const result = provider.redirectToAuthorization(metadata) + + //#then + await expect(result).rejects.toThrow("No client information available") + }) + }) + + describe("redirectUrl", () => { + it("returns localhost callback URL with default port", () => { + //#given + const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" }) + + //#when + const url = provider.redirectUrl() + + //#then + expect(url).toBe("http://127.0.0.1:19877/callback") + }) + }) +}) diff --git a/src/features/mcp-oauth/provider.ts b/src/features/mcp-oauth/provider.ts new file mode 100644 index 00000000..6b4a69b3 --- /dev/null +++ b/src/features/mcp-oauth/provider.ts @@ -0,0 +1,295 @@ +import { createHash, randomBytes } from "node:crypto" +import { createServer } from "node:http" +import { spawn } from "node:child_process" +import type { OAuthTokenData } from "./storage" +import { loadToken, saveToken } from "./storage" +import { discoverOAuthServerMetadata } from "./discovery" +import type { OAuthServerMetadata } from "./discovery" +import { getOrRegisterClient } from "./dcr" +import type { ClientCredentials, ClientRegistrationStorage } from "./dcr" +import { findAvailablePort } from "./callback-server" + +export type McpOAuthProviderOptions = { + serverUrl: string + clientId?: string + scopes?: string[] +} + +type CallbackResult = { + code: string + state: string +} + +function generateCodeVerifier(): string { + return randomBytes(32).toString("base64url") +} + +function generateCodeChallenge(verifier: string): string { + return createHash("sha256").update(verifier).digest("base64url") +} + +function buildAuthorizationUrl( + authorizationEndpoint: string, + options: { + clientId: string + redirectUri: string + codeChallenge: string + state: string + scopes?: string[] + resource?: string + } +): string { + const url = new URL(authorizationEndpoint) + url.searchParams.set("response_type", "code") + url.searchParams.set("client_id", options.clientId) + url.searchParams.set("redirect_uri", options.redirectUri) + url.searchParams.set("code_challenge", options.codeChallenge) + url.searchParams.set("code_challenge_method", "S256") + url.searchParams.set("state", options.state) + if (options.scopes && options.scopes.length > 0) { + url.searchParams.set("scope", options.scopes.join(" ")) + } + if (options.resource) { + url.searchParams.set("resource", options.resource) + } + return url.toString() +} + +const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 + +function startCallbackServer(port: number): Promise { + return new Promise((resolve, reject) => { + let timeoutId: ReturnType + + const server = createServer((request, response) => { + clearTimeout(timeoutId) + + const requestUrl = new URL(request.url ?? "/", `http://localhost:${port}`) + const code = requestUrl.searchParams.get("code") + const state = requestUrl.searchParams.get("state") + const error = requestUrl.searchParams.get("error") + + if (error) { + const errorDescription = requestUrl.searchParams.get("error_description") ?? error + response.writeHead(400, { "content-type": "text/html" }) + response.end("

Authorization failed

") + server.close() + reject(new Error(`OAuth authorization error: ${errorDescription}`)) + return + } + + if (!code || !state) { + response.writeHead(400, { "content-type": "text/html" }) + response.end("

Missing code or state

") + server.close() + reject(new Error("OAuth callback missing code or state parameter")) + return + } + + response.writeHead(200, { "content-type": "text/html" }) + response.end("

Authorization successful. You can close this tab.

") + server.close() + resolve({ code, state }) + }) + + timeoutId = setTimeout(() => { + server.close() + reject(new Error("OAuth callback timed out after 5 minutes")) + }, CALLBACK_TIMEOUT_MS) + + server.listen(port, "127.0.0.1") + server.on("error", (err) => { + clearTimeout(timeoutId) + reject(err) + }) + }) +} + +function openBrowser(url: string): void { + const platform = process.platform + let cmd: string + let args: string[] + + if (platform === "darwin") { + cmd = "open" + args = [url] + } else if (platform === "win32") { + cmd = "explorer" + args = [url] + } else { + cmd = "xdg-open" + args = [url] + } + + try { + const child = spawn(cmd, args, { stdio: "ignore", detached: true }) + child.on("error", () => {}) + child.unref() + } catch { + // Browser open failed — user must navigate manually + } +} + +export class McpOAuthProvider { + private readonly serverUrl: string + private readonly configClientId: string | undefined + private readonly scopes: string[] + private storedCodeVerifier: string | null = null + private storedClientInfo: ClientCredentials | null = null + private callbackPort: number | null = null + + constructor(options: McpOAuthProviderOptions) { + this.serverUrl = options.serverUrl + this.configClientId = options.clientId + this.scopes = options.scopes ?? [] + } + + tokens(): OAuthTokenData | null { + return loadToken(this.serverUrl, this.serverUrl) + } + + saveTokens(tokenData: OAuthTokenData): boolean { + return saveToken(this.serverUrl, this.serverUrl, tokenData) + } + + clientInformation(): ClientCredentials | null { + if (this.storedClientInfo) return this.storedClientInfo + const tokenData = this.tokens() + if (tokenData?.clientInfo) { + this.storedClientInfo = tokenData.clientInfo + return this.storedClientInfo + } + return null + } + + redirectUrl(): string { + return `http://127.0.0.1:${this.callbackPort ?? 19877}/callback` + } + + saveCodeVerifier(verifier: string): void { + this.storedCodeVerifier = verifier + } + + codeVerifier(): string | null { + return this.storedCodeVerifier + } + + async redirectToAuthorization(metadata: OAuthServerMetadata): Promise { + const verifier = generateCodeVerifier() + this.saveCodeVerifier(verifier) + const challenge = generateCodeChallenge(verifier) + const state = randomBytes(16).toString("hex") + + const clientInfo = this.clientInformation() + if (!clientInfo) { + throw new Error("No client information available. Run login() or register a client first.") + } + + if (this.callbackPort === null) { + this.callbackPort = await findAvailablePort() + } + + const authUrl = buildAuthorizationUrl(metadata.authorizationEndpoint, { + clientId: clientInfo.clientId, + redirectUri: this.redirectUrl(), + codeChallenge: challenge, + state, + scopes: this.scopes, + resource: metadata.resource, + }) + + const callbackPromise = startCallbackServer(this.callbackPort) + openBrowser(authUrl) + + const result = await callbackPromise + if (result.state !== state) { + throw new Error("OAuth state mismatch") + } + + return result + } + + async login(): Promise { + const metadata = await discoverOAuthServerMetadata(this.serverUrl) + + const clientRegistrationStorage: ClientRegistrationStorage = { + getClientRegistration: () => this.storedClientInfo, + setClientRegistration: (_serverIdentifier: string, credentials: ClientCredentials) => { + this.storedClientInfo = credentials + }, + } + + const clientInfo = await getOrRegisterClient({ + registrationEndpoint: metadata.registrationEndpoint, + serverIdentifier: this.serverUrl, + clientName: "oh-my-opencode", + redirectUris: [this.redirectUrl()], + tokenEndpointAuthMethod: "none", + clientId: this.configClientId, + storage: clientRegistrationStorage, + }) + + if (!clientInfo) { + throw new Error("Failed to obtain client credentials. Provide a clientId or ensure the server supports DCR.") + } + + this.storedClientInfo = clientInfo + + const { code } = await this.redirectToAuthorization(metadata) + const verifier = this.codeVerifier() + if (!verifier) { + throw new Error("Code verifier not found") + } + + const tokenResponse = await fetch(metadata.tokenEndpoint, { + method: "POST", + headers: { "content-type": "application/x-www-form-urlencoded" }, + body: new URLSearchParams({ + grant_type: "authorization_code", + code, + redirect_uri: this.redirectUrl(), + client_id: clientInfo.clientId, + code_verifier: verifier, + ...(metadata.resource ? { resource: metadata.resource } : {}), + }).toString(), + }) + + if (!tokenResponse.ok) { + let errorDetail = `${tokenResponse.status}` + try { + const body = (await tokenResponse.json()) as Record + if (body.error) { + errorDetail = `${tokenResponse.status} ${body.error}` + if (body.error_description) { + errorDetail += `: ${body.error_description}` + } + } + } catch { + // Response body not JSON + } + throw new Error(`Token exchange failed: ${errorDetail}`) + } + + const tokenData = (await tokenResponse.json()) as Record + const accessToken = tokenData.access_token + if (typeof accessToken !== "string") { + throw new Error("Token response missing access_token") + } + + const oauthTokenData: OAuthTokenData = { + accessToken, + refreshToken: typeof tokenData.refresh_token === "string" ? tokenData.refresh_token : undefined, + expiresAt: + typeof tokenData.expires_in === "number" ? Math.floor(Date.now() / 1000) + tokenData.expires_in : undefined, + clientInfo: { + clientId: clientInfo.clientId, + clientSecret: clientInfo.clientSecret, + }, + } + + this.saveTokens(oauthTokenData) + return oauthTokenData + } +} + +export { generateCodeVerifier, generateCodeChallenge, buildAuthorizationUrl, startCallbackServer } diff --git a/src/features/mcp-oauth/resource-indicator.test.ts b/src/features/mcp-oauth/resource-indicator.test.ts new file mode 100644 index 00000000..1378e15c --- /dev/null +++ b/src/features/mcp-oauth/resource-indicator.test.ts @@ -0,0 +1,121 @@ +import { describe, expect, it } from "bun:test" +import { addResourceToParams, getResourceIndicator } from "./resource-indicator" + +describe("getResourceIndicator", () => { + it("returns URL unchanged when already normalized", () => { + // #given + const url = "https://mcp.example.com" + + // #when + const result = getResourceIndicator(url) + + // #then + expect(result).toBe("https://mcp.example.com") + }) + + it("strips trailing slash", () => { + // #given + const url = "https://mcp.example.com/" + + // #when + const result = getResourceIndicator(url) + + // #then + expect(result).toBe("https://mcp.example.com") + }) + + it("strips query parameters", () => { + // #given + const url = "https://mcp.example.com/v1?token=abc&debug=true" + + // #when + const result = getResourceIndicator(url) + + // #then + expect(result).toBe("https://mcp.example.com/v1") + }) + + it("strips fragment", () => { + // #given + const url = "https://mcp.example.com/v1#section" + + // #when + const result = getResourceIndicator(url) + + // #then + expect(result).toBe("https://mcp.example.com/v1") + }) + + it("strips query and trailing slash together", () => { + // #given + const url = "https://mcp.example.com/api/?key=val" + + // #when + const result = getResourceIndicator(url) + + // #then + expect(result).toBe("https://mcp.example.com/api") + }) + + it("preserves path segments", () => { + // #given + const url = "https://mcp.example.com/org/project/v2" + + // #when + const result = getResourceIndicator(url) + + // #then + expect(result).toBe("https://mcp.example.com/org/project/v2") + }) + + it("preserves port number", () => { + // #given + const url = "https://mcp.example.com:8443/api/" + + // #when + const result = getResourceIndicator(url) + + // #then + expect(result).toBe("https://mcp.example.com:8443/api") + }) +}) + +describe("addResourceToParams", () => { + it("sets resource parameter on empty params", () => { + // #given + const params = new URLSearchParams() + const resource = "https://mcp.example.com" + + // #when + addResourceToParams(params, resource) + + // #then + expect(params.get("resource")).toBe("https://mcp.example.com") + }) + + it("adds resource alongside existing parameters", () => { + // #given + const params = new URLSearchParams({ grant_type: "authorization_code" }) + const resource = "https://mcp.example.com/v1" + + // #when + addResourceToParams(params, resource) + + // #then + expect(params.get("grant_type")).toBe("authorization_code") + expect(params.get("resource")).toBe("https://mcp.example.com/v1") + }) + + it("overwrites existing resource parameter", () => { + // #given + const params = new URLSearchParams({ resource: "https://old.example.com" }) + const resource = "https://new.example.com" + + // #when + addResourceToParams(params, resource) + + // #then + expect(params.get("resource")).toBe("https://new.example.com") + expect(params.getAll("resource")).toHaveLength(1) + }) +}) diff --git a/src/features/mcp-oauth/resource-indicator.ts b/src/features/mcp-oauth/resource-indicator.ts new file mode 100644 index 00000000..bd73a1aa --- /dev/null +++ b/src/features/mcp-oauth/resource-indicator.ts @@ -0,0 +1,16 @@ +export function getResourceIndicator(url: string): string { + const parsed = new URL(url) + parsed.search = "" + parsed.hash = "" + + let normalized = parsed.toString() + if (normalized.endsWith("/")) { + normalized = normalized.slice(0, -1) + } + + return normalized +} + +export function addResourceToParams(params: URLSearchParams, resource: string): void { + params.set("resource", resource) +} diff --git a/src/features/mcp-oauth/schema.test.ts b/src/features/mcp-oauth/schema.test.ts new file mode 100644 index 00000000..2703aee3 --- /dev/null +++ b/src/features/mcp-oauth/schema.test.ts @@ -0,0 +1,60 @@ +/// +import { describe, expect, test } from "bun:test" +import { McpOauthSchema } from "./schema" + +describe("McpOauthSchema", () => { + test("parses empty oauth config", () => { + //#given + const input = {} + + //#when + const result = McpOauthSchema.parse(input) + + //#then + expect(result).toEqual({}) + }) + + test("parses oauth config with clientId", () => { + //#given + const input = { clientId: "client-123" } + + //#when + const result = McpOauthSchema.parse(input) + + //#then + expect(result).toEqual({ clientId: "client-123" }) + }) + + test("parses oauth config with scopes", () => { + //#given + const input = { scopes: ["openid", "profile"] } + + //#when + const result = McpOauthSchema.parse(input) + + //#then + expect(result).toEqual({ scopes: ["openid", "profile"] }) + }) + + test("rejects non-string clientId", () => { + //#given + const input = { clientId: 123 } + + //#when + const result = McpOauthSchema.safeParse(input) + + //#then + expect(result.success).toBe(false) + }) + + test("rejects non-string scopes", () => { + //#given + const input = { scopes: ["openid", 42] } + + //#when + const result = McpOauthSchema.safeParse(input) + + //#then + expect(result.success).toBe(false) + }) +}) diff --git a/src/features/mcp-oauth/schema.ts b/src/features/mcp-oauth/schema.ts new file mode 100644 index 00000000..c9db14a2 --- /dev/null +++ b/src/features/mcp-oauth/schema.ts @@ -0,0 +1,8 @@ +import { z } from "zod" + +export const McpOauthSchema = z.object({ + clientId: z.string().optional(), + scopes: z.array(z.string()).optional(), +}) + +export type McpOauth = z.infer diff --git a/src/features/mcp-oauth/step-up.test.ts b/src/features/mcp-oauth/step-up.test.ts new file mode 100644 index 00000000..550e2f81 --- /dev/null +++ b/src/features/mcp-oauth/step-up.test.ts @@ -0,0 +1,223 @@ +import { describe, expect, it } from "bun:test" +import { isStepUpRequired, mergeScopes, parseWwwAuthenticate } from "./step-up" + +describe("parseWwwAuthenticate", () => { + it("parses scope from simple Bearer header", () => { + // #given + const header = 'Bearer scope="read write"' + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toEqual({ requiredScopes: ["read", "write"] }) + }) + + it("parses scope with error fields", () => { + // #given + const header = 'Bearer error="insufficient_scope", scope="admin"' + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toEqual({ + requiredScopes: ["admin"], + error: "insufficient_scope", + }) + }) + + it("parses all fields including error_description", () => { + // #given + const header = + 'Bearer realm="example", error="insufficient_scope", error_description="Need admin access", scope="admin write"' + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toEqual({ + requiredScopes: ["admin", "write"], + error: "insufficient_scope", + errorDescription: "Need admin access", + }) + }) + + it("returns null for non-Bearer scheme", () => { + // #given + const header = 'Basic realm="example"' + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toBeNull() + }) + + it("returns null when no scope parameter present", () => { + // #given + const header = 'Bearer error="invalid_token"' + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toBeNull() + }) + + it("returns null for empty scope value", () => { + // #given + const header = 'Bearer scope=""' + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toBeNull() + }) + + it("returns null for bare Bearer with no params", () => { + // #given + const header = "Bearer" + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toBeNull() + }) + + it("handles case-insensitive Bearer prefix", () => { + // #given + const header = 'bearer scope="read"' + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toEqual({ requiredScopes: ["read"] }) + }) + + it("parses single scope value", () => { + // #given + const header = 'Bearer scope="admin"' + + // #when + const result = parseWwwAuthenticate(header) + + // #then + expect(result).toEqual({ requiredScopes: ["admin"] }) + }) +}) + +describe("mergeScopes", () => { + it("merges new scopes into existing", () => { + // #given + const existing = ["read", "write"] + const required = ["admin", "write"] + + // #when + const result = mergeScopes(existing, required) + + // #then + expect(result).toEqual(["read", "write", "admin"]) + }) + + it("returns required when existing is empty", () => { + // #given + const existing: string[] = [] + const required = ["read", "write"] + + // #when + const result = mergeScopes(existing, required) + + // #then + expect(result).toEqual(["read", "write"]) + }) + + it("returns existing when required is empty", () => { + // #given + const existing = ["read"] + const required: string[] = [] + + // #when + const result = mergeScopes(existing, required) + + // #then + expect(result).toEqual(["read"]) + }) + + it("deduplicates identical scopes", () => { + // #given + const existing = ["read", "write"] + const required = ["read", "write"] + + // #when + const result = mergeScopes(existing, required) + + // #then + expect(result).toEqual(["read", "write"]) + }) +}) + +describe("isStepUpRequired", () => { + it("returns step-up info for 403 with WWW-Authenticate", () => { + // #given + const statusCode = 403 + const headers = { "www-authenticate": 'Bearer scope="admin"' } + + // #when + const result = isStepUpRequired(statusCode, headers) + + // #then + expect(result).toEqual({ requiredScopes: ["admin"] }) + }) + + it("returns null for non-403 status", () => { + // #given + const statusCode = 401 + const headers = { "www-authenticate": 'Bearer scope="admin"' } + + // #when + const result = isStepUpRequired(statusCode, headers) + + // #then + expect(result).toBeNull() + }) + + it("returns null when no WWW-Authenticate header", () => { + // #given + const statusCode = 403 + const headers = { "content-type": "application/json" } + + // #when + const result = isStepUpRequired(statusCode, headers) + + // #then + expect(result).toBeNull() + }) + + it("handles capitalized WWW-Authenticate header", () => { + // #given + const statusCode = 403 + const headers = { "WWW-Authenticate": 'Bearer scope="read write"' } + + // #when + const result = isStepUpRequired(statusCode, headers) + + // #then + expect(result).toEqual({ requiredScopes: ["read", "write"] }) + }) + + it("returns null for 403 with unparseable WWW-Authenticate", () => { + // #given + const statusCode = 403 + const headers = { "www-authenticate": 'Basic realm="example"' } + + // #when + const result = isStepUpRequired(statusCode, headers) + + // #then + expect(result).toBeNull() + }) +}) diff --git a/src/features/mcp-oauth/step-up.ts b/src/features/mcp-oauth/step-up.ts new file mode 100644 index 00000000..093846ad --- /dev/null +++ b/src/features/mcp-oauth/step-up.ts @@ -0,0 +1,79 @@ +export interface StepUpInfo { + requiredScopes: string[] + error?: string + errorDescription?: string +} + +export function parseWwwAuthenticate(header: string): StepUpInfo | null { + const trimmed = header.trim() + const lowerHeader = trimmed.toLowerCase() + const bearerIndex = lowerHeader.indexOf("bearer") + if (bearerIndex === -1) { + return null + } + + const params = trimmed.slice(bearerIndex + "bearer".length).trim() + if (params.length === 0) { + return null + } + + const scope = extractParam(params, "scope") + if (scope === null) { + return null + } + + const requiredScopes = scope + .split(/\s+/) + .filter((s) => s.length > 0) + + if (requiredScopes.length === 0) { + return null + } + + const info: StepUpInfo = { requiredScopes } + + const error = extractParam(params, "error") + if (error !== null) { + info.error = error + } + + const errorDescription = extractParam(params, "error_description") + if (errorDescription !== null) { + info.errorDescription = errorDescription + } + + return info +} + +function extractParam(params: string, name: string): string | null { + const quotedPattern = new RegExp(`${name}="([^"]*)"`) + const quotedMatch = quotedPattern.exec(params) + if (quotedMatch) { + return quotedMatch[1] + } + + const unquotedPattern = new RegExp(`${name}=([^\\s,]+)`) + const unquotedMatch = unquotedPattern.exec(params) + return unquotedMatch?.[1] ?? null +} + +export function mergeScopes(existing: string[], required: string[]): string[] { + const set = new Set(existing) + for (const scope of required) { + set.add(scope) + } + return [...set] +} + +export function isStepUpRequired(statusCode: number, headers: Record): StepUpInfo | null { + if (statusCode !== 403) { + return null + } + + const wwwAuth = headers["www-authenticate"] ?? headers["WWW-Authenticate"] + if (!wwwAuth) { + return null + } + + return parseWwwAuthenticate(wwwAuth) +} diff --git a/src/features/mcp-oauth/storage.test.ts b/src/features/mcp-oauth/storage.test.ts new file mode 100644 index 00000000..e5570709 --- /dev/null +++ b/src/features/mcp-oauth/storage.test.ts @@ -0,0 +1,136 @@ +import { describe, expect, test, beforeEach, afterEach } from "bun:test" +import { existsSync, mkdirSync, rmSync, readFileSync, statSync, writeFileSync } from "node:fs" +import { join } from "node:path" +import { tmpdir } from "node:os" +import { + deleteToken, + getMcpOauthStoragePath, + listAllTokens, + listTokensByHost, + loadToken, + saveToken, +} from "./storage" +import type { OAuthTokenData } from "./storage" + +describe("mcp-oauth storage", () => { + const TEST_CONFIG_DIR = join(tmpdir(), "mcp-oauth-test-" + Date.now()) + let originalConfigDir: string | undefined + + beforeEach(() => { + originalConfigDir = process.env.OPENCODE_CONFIG_DIR + process.env.OPENCODE_CONFIG_DIR = TEST_CONFIG_DIR + if (!existsSync(TEST_CONFIG_DIR)) { + mkdirSync(TEST_CONFIG_DIR, { recursive: true }) + } + }) + + afterEach(() => { + if (originalConfigDir === undefined) { + delete process.env.OPENCODE_CONFIG_DIR + } else { + process.env.OPENCODE_CONFIG_DIR = originalConfigDir + } + if (existsSync(TEST_CONFIG_DIR)) { + rmSync(TEST_CONFIG_DIR, { recursive: true, force: true }) + } + }) + + test("should save tokens with {host}/{resource} key and set 0600 permissions", () => { + // #given + const token: OAuthTokenData = { + accessToken: "access-1", + refreshToken: "refresh-1", + expiresAt: 1710000000, + clientInfo: { clientId: "client-1", clientSecret: "secret-1" }, + } + + // #when + const success = saveToken("https://example.com:443", "mcp/v1", token) + const storagePath = getMcpOauthStoragePath() + const parsed = JSON.parse(readFileSync(storagePath, "utf-8")) as Record + const mode = statSync(storagePath).mode & 0o777 + + // #then + expect(success).toBe(true) + expect(Object.keys(parsed)).toEqual(["example.com/mcp/v1"]) + expect(parsed["example.com/mcp/v1"].accessToken).toBe("access-1") + expect(mode).toBe(0o600) + }) + + test("should load a saved token", () => { + // #given + const token: OAuthTokenData = { accessToken: "access-2", refreshToken: "refresh-2" } + saveToken("api.example.com", "resource-a", token) + + // #when + const loaded = loadToken("api.example.com:8443", "resource-a") + + // #then + expect(loaded).toEqual(token) + }) + + test("should delete a token", () => { + // #given + const token: OAuthTokenData = { accessToken: "access-3" } + saveToken("api.example.com", "resource-b", token) + + // #when + const success = deleteToken("api.example.com", "resource-b") + const loaded = loadToken("api.example.com", "resource-b") + + // #then + expect(success).toBe(true) + expect(loaded).toBeNull() + }) + + test("should list tokens by host", () => { + // #given + saveToken("api.example.com", "resource-a", { accessToken: "access-a" }) + saveToken("api.example.com", "resource-b", { accessToken: "access-b" }) + saveToken("other.example.com", "resource-c", { accessToken: "access-c" }) + + // #when + const entries = listTokensByHost("api.example.com:5555") + + // #then + expect(Object.keys(entries).sort()).toEqual([ + "api.example.com/resource-a", + "api.example.com/resource-b", + ]) + expect(entries["api.example.com/resource-a"].accessToken).toBe("access-a") + }) + + test("should handle missing storage file", () => { + // #given + const storagePath = getMcpOauthStoragePath() + if (existsSync(storagePath)) { + rmSync(storagePath, { force: true }) + } + + // #when + const loaded = loadToken("api.example.com", "resource-a") + const entries = listTokensByHost("api.example.com") + + // #then + expect(loaded).toBeNull() + expect(entries).toEqual({}) + }) + + test("should handle invalid JSON", () => { + // #given + const storagePath = getMcpOauthStoragePath() + const dir = join(storagePath, "..") + if (!existsSync(dir)) { + mkdirSync(dir, { recursive: true }) + } + writeFileSync(storagePath, "{not-valid-json", "utf-8") + + // #when + const loaded = loadToken("api.example.com", "resource-a") + const entries = listTokensByHost("api.example.com") + + // #then + expect(loaded).toBeNull() + expect(entries).toEqual({}) + }) +}) diff --git a/src/features/mcp-oauth/storage.ts b/src/features/mcp-oauth/storage.ts new file mode 100644 index 00000000..d041bdfd --- /dev/null +++ b/src/features/mcp-oauth/storage.ts @@ -0,0 +1,153 @@ +import { chmodSync, existsSync, mkdirSync, readFileSync, unlinkSync, writeFileSync } from "node:fs" +import { dirname, join } from "node:path" +import { getOpenCodeConfigDir } from "../../shared" + +export interface OAuthTokenData { + accessToken: string + refreshToken?: string + expiresAt?: number + clientInfo?: { + clientId: string + clientSecret?: string + } +} + +type TokenStore = Record + +const STORAGE_FILE_NAME = "mcp-oauth.json" + +export function getMcpOauthStoragePath(): string { + return join(getOpenCodeConfigDir({ binary: "opencode" }), STORAGE_FILE_NAME) +} + +function normalizeHost(serverHost: string): string { + let host = serverHost.trim() + if (!host) return host + + if (host.includes("://")) { + try { + host = new URL(host).hostname + } catch { + host = host.split("/")[0] + } + } else { + host = host.split("/")[0] + } + + if (host.startsWith("[")) { + const closing = host.indexOf("]") + if (closing !== -1) { + host = host.slice(0, closing + 1) + } + return host + } + + if (host.includes(":")) { + host = host.split(":")[0] + } + + return host +} + +function normalizeResource(resource: string): string { + return resource.replace(/^\/+/, "") +} + +function buildKey(serverHost: string, resource: string): string { + const host = normalizeHost(serverHost) + const normalizedResource = normalizeResource(resource) + return `${host}/${normalizedResource}` +} + +function readStore(): TokenStore | null { + const filePath = getMcpOauthStoragePath() + if (!existsSync(filePath)) { + return null + } + + try { + const content = readFileSync(filePath, "utf-8") + return JSON.parse(content) as TokenStore + } catch { + return null + } +} + +function writeStore(store: TokenStore): boolean { + const filePath = getMcpOauthStoragePath() + + try { + const dir = dirname(filePath) + if (!existsSync(dir)) { + mkdirSync(dir, { recursive: true }) + } + + writeFileSync(filePath, JSON.stringify(store, null, 2), { encoding: "utf-8", mode: 0o600 }) + chmodSync(filePath, 0o600) + return true + } catch { + return false + } +} + +export function loadToken(serverHost: string, resource: string): OAuthTokenData | null { + const store = readStore() + if (!store) return null + + const key = buildKey(serverHost, resource) + return store[key] ?? null +} + +export function saveToken(serverHost: string, resource: string, token: OAuthTokenData): boolean { + const store = readStore() ?? {} + const key = buildKey(serverHost, resource) + store[key] = token + return writeStore(store) +} + +export function deleteToken(serverHost: string, resource: string): boolean { + const store = readStore() + if (!store) return true + + const key = buildKey(serverHost, resource) + if (!(key in store)) { + return true + } + + delete store[key] + + if (Object.keys(store).length === 0) { + try { + const filePath = getMcpOauthStoragePath() + if (existsSync(filePath)) { + unlinkSync(filePath) + } + return true + } catch { + return false + } + } + + return writeStore(store) +} + +export function listTokensByHost(serverHost: string): TokenStore { + const store = readStore() + if (!store) return {} + + const host = normalizeHost(serverHost) + const prefix = `${host}/` + const result: TokenStore = {} + + for (const [key, value] of Object.entries(store)) { + if (key.startsWith(prefix)) { + result[key] = value + } + } + + return result +} + +export function listAllTokens(): TokenStore { + return readStore() ?? {} +} diff --git a/src/features/skill-mcp-manager/manager.test.ts b/src/features/skill-mcp-manager/manager.test.ts index 5c9120d4..4170b2eb 100644 --- a/src/features/skill-mcp-manager/manager.test.ts +++ b/src/features/skill-mcp-manager/manager.test.ts @@ -3,8 +3,6 @@ import { SkillMcpManager } from "./manager" import type { SkillMcpClientInfo, SkillMcpServerContext } from "./types" import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types" - - // Mock the MCP SDK transports to avoid network calls const mockHttpConnect = mock(() => Promise.reject(new Error("Mocked HTTP connection failure"))) const mockHttpClose = mock(() => Promise.resolve()) @@ -24,6 +22,21 @@ mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ }, })) +const mockTokens = mock(() => null as { accessToken: string; refreshToken?: string; expiresAt?: number } | null) +const mockLogin = mock(() => Promise.resolve({ accessToken: "new-token" })) + +mock.module("../mcp-oauth/provider", () => ({ + McpOAuthProvider: class MockMcpOAuthProvider { + constructor(public options: { serverUrl: string; clientId?: string; scopes?: string[] }) {} + tokens() { + return mockTokens() + } + async login() { + return mockLogin() + } + }, +})) + @@ -518,7 +531,6 @@ describe("SkillMcpManager", () => { skillName: "retry-skill", } - // Mock client that fails first time with "Not connected", then succeeds let callCount = 0 const mockClient = { callTool: mock(async () => { @@ -531,7 +543,6 @@ describe("SkillMcpManager", () => { close: mock(() => Promise.resolve()), } - // Spy on getOrCreateClientWithRetry to inject mock client const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry") getOrCreateSpy.mockResolvedValue(mockClient) @@ -539,9 +550,9 @@ describe("SkillMcpManager", () => { const result = await manager.callTool(info, context, "test-tool", {}) // #then - expect(callCount).toBe(2) // First call fails, second succeeds + expect(callCount).toBe(2) expect(result).toEqual([{ type: "text", text: "success" }]) - expect(getOrCreateSpy).toHaveBeenCalledTimes(2) // Called twice due to retry + expect(getOrCreateSpy).toHaveBeenCalledTimes(2) }) it("should fail after 3 retry attempts", async () => { @@ -558,7 +569,6 @@ describe("SkillMcpManager", () => { skillName: "fail-skill", } - // Mock client that always fails with "Not connected" const mockClient = { callTool: mock(async () => { throw new Error("Not connected") @@ -573,7 +583,7 @@ describe("SkillMcpManager", () => { await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow( /Failed after 3 reconnection attempts/ ) - expect(getOrCreateSpy).toHaveBeenCalledTimes(3) // Initial + 2 retries + expect(getOrCreateSpy).toHaveBeenCalledTimes(3) }) it("should not retry on non-connection errors", async () => { @@ -590,7 +600,6 @@ describe("SkillMcpManager", () => { skillName: "error-skill", } - // Mock client that fails with non-connection error const mockClient = { callTool: mock(async () => { throw new Error("Tool not found") @@ -605,7 +614,194 @@ describe("SkillMcpManager", () => { await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow( "Tool not found" ) - expect(getOrCreateSpy).toHaveBeenCalledTimes(1) // No retry + expect(getOrCreateSpy).toHaveBeenCalledTimes(1) + }) + }) + + describe("OAuth integration", () => { + beforeEach(() => { + mockTokens.mockClear() + mockLogin.mockClear() + }) + + it("injects Authorization header when oauth config has stored tokens", async () => { + // #given + const info: SkillMcpClientInfo = { + serverName: "oauth-server", + skillName: "oauth-skill", + sessionID: "session-oauth-1", + } + const config: ClaudeCodeMcpServer = { + url: "https://mcp.example.com/mcp", + oauth: { + clientId: "my-client", + scopes: ["read", "write"], + }, + } + mockTokens.mockReturnValue({ accessToken: "stored-access-token" }) + + // #when + try { + await manager.getOrCreateClient(info, config) + } catch { /* connection fails in test */ } + + // #then + const headers = lastTransportInstance.options?.requestInit?.headers as Record | undefined + expect(headers?.Authorization).toBe("Bearer stored-access-token") + }) + + it("does not inject Authorization header when no stored tokens exist and login fails", async () => { + // #given + const info: SkillMcpClientInfo = { + serverName: "oauth-no-token", + skillName: "oauth-skill", + sessionID: "session-oauth-2", + } + const config: ClaudeCodeMcpServer = { + url: "https://mcp.example.com/mcp", + oauth: { + clientId: "my-client", + }, + } + mockTokens.mockReturnValue(null) + mockLogin.mockRejectedValue(new Error("Login failed")) + + // #when + try { + await manager.getOrCreateClient(info, config) + } catch { /* connection fails in test */ } + + // #then + const headers = lastTransportInstance.options?.requestInit?.headers as Record | undefined + expect(headers?.Authorization).toBeUndefined() + }) + + it("preserves existing static headers alongside OAuth token", async () => { + // #given + const info: SkillMcpClientInfo = { + serverName: "oauth-with-headers", + skillName: "oauth-skill", + sessionID: "session-oauth-3", + } + const config: ClaudeCodeMcpServer = { + url: "https://mcp.example.com/mcp", + headers: { + "X-Custom": "custom-value", + }, + oauth: { + clientId: "my-client", + }, + } + mockTokens.mockReturnValue({ accessToken: "oauth-token" }) + + // #when + try { + await manager.getOrCreateClient(info, config) + } catch { /* connection fails in test */ } + + // #then + const headers = lastTransportInstance.options?.requestInit?.headers as Record | undefined + expect(headers?.["X-Custom"]).toBe("custom-value") + expect(headers?.Authorization).toBe("Bearer oauth-token") + }) + + it("does not create auth provider when oauth config is absent", async () => { + // #given + const info: SkillMcpClientInfo = { + serverName: "no-oauth-server", + skillName: "test-skill", + sessionID: "session-no-oauth", + } + const config: ClaudeCodeMcpServer = { + url: "https://mcp.example.com/mcp", + headers: { + Authorization: "Bearer static-token", + }, + } + + // #when + try { + await manager.getOrCreateClient(info, config) + } catch { /* connection fails in test */ } + + // #then + const headers = lastTransportInstance.options?.requestInit?.headers as Record | undefined + expect(headers?.Authorization).toBe("Bearer static-token") + expect(mockTokens).not.toHaveBeenCalled() + }) + + it("handles step-up auth by triggering re-login on 403 with scope", async () => { + // #given + const info: SkillMcpClientInfo = { + serverName: "stepup-server", + skillName: "stepup-skill", + sessionID: "session-stepup-1", + } + const config: ClaudeCodeMcpServer = { + url: "https://mcp.example.com/mcp", + oauth: { + clientId: "my-client", + scopes: ["read"], + }, + } + const context: SkillMcpServerContext = { + config, + skillName: "stepup-skill", + } + + mockTokens.mockReturnValue({ accessToken: "initial-token" }) + mockLogin.mockResolvedValue({ accessToken: "upgraded-token" }) + + let callCount = 0 + const mockClient = { + callTool: mock(async () => { + callCount++ + if (callCount === 1) { + throw new Error('403 WWW-Authenticate: Bearer scope="admin write"') + } + return { content: [{ type: "text", text: "success" }] } + }), + close: mock(() => Promise.resolve()), + } + + const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry") + getOrCreateSpy.mockResolvedValue(mockClient) + + // #when + const result = await manager.callTool(info, context, "test-tool", {}) + + // #then + expect(result).toEqual([{ type: "text", text: "success" }]) + expect(mockLogin).toHaveBeenCalled() + }) + + it("does not attempt step-up when oauth config is absent", async () => { + // #given + const info: SkillMcpClientInfo = { + serverName: "no-stepup-server", + skillName: "no-stepup-skill", + sessionID: "session-no-stepup", + } + const context: SkillMcpServerContext = { + config: { + url: "https://mcp.example.com/mcp", + }, + skillName: "no-stepup-skill", + } + + const mockClient = { + callTool: mock(async () => { + throw new Error('403 WWW-Authenticate: Bearer scope="admin"') + }), + close: mock(() => Promise.resolve()), + } + + const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry") + getOrCreateSpy.mockResolvedValue(mockClient) + + // #when / #then + await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow(/403/) + expect(mockLogin).not.toHaveBeenCalled() }) }) }) diff --git a/src/features/skill-mcp-manager/manager.ts b/src/features/skill-mcp-manager/manager.ts index b56cda8e..0b43ca0c 100644 --- a/src/features/skill-mcp-manager/manager.ts +++ b/src/features/skill-mcp-manager/manager.ts @@ -4,6 +4,8 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/ import type { Tool, Resource, Prompt } from "@modelcontextprotocol/sdk/types.js" import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types" import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander" +import { McpOAuthProvider } from "../mcp-oauth/provider" +import { isStepUpRequired, mergeScopes } from "../mcp-oauth/step-up" import { createCleanMcpEnvironment } from "./env-cleaner" import type { SkillMcpClientInfo, SkillMcpServerContext } from "./types" @@ -60,6 +62,7 @@ function getConnectionType(config: ClaudeCodeMcpServer): ConnectionType | null { export class SkillMcpManager { private clients: Map = new Map() private pendingConnections: Map> = new Map() + private authProviders: Map = new Map() private cleanupRegistered = false private cleanupInterval: ReturnType | null = null private readonly IDLE_TIMEOUT = 5 * 60 * 1000 @@ -68,6 +71,28 @@ export class SkillMcpManager { return `${info.sessionID}:${info.skillName}:${info.serverName}` } + /** + * Get or create an McpOAuthProvider for a given server URL + oauth config. + * Providers are cached by server URL to reuse tokens across reconnections. + */ + private getOrCreateAuthProvider( + serverUrl: string, + oauth: NonNullable + ): McpOAuthProvider { + const existing = this.authProviders.get(serverUrl) + if (existing) { + return existing + } + + const provider = new McpOAuthProvider({ + serverUrl, + clientId: oauth.clientId, + scopes: oauth.scopes, + }) + this.authProviders.set(serverUrl, provider) + return provider + } + private registerProcessCleanup(): void { if (this.cleanupRegistered) return this.cleanupRegistered = true @@ -204,7 +229,30 @@ export class SkillMcpManager { // Build request init with headers if provided const requestInit: RequestInit = {} if (config.headers && Object.keys(config.headers).length > 0) { - requestInit.headers = config.headers + requestInit.headers = { ...config.headers } + } + + let authProvider: McpOAuthProvider | undefined + if (config.oauth) { + authProvider = this.getOrCreateAuthProvider(config.url, config.oauth) + let tokenData = authProvider.tokens() + + const isExpired = tokenData?.expiresAt != null && tokenData.expiresAt < Math.floor(Date.now() / 1000) + if (!tokenData || isExpired) { + try { + tokenData = await authProvider.login() + } catch { + // Login failed — proceed without auth header + } + } + + if (tokenData) { + const existingHeaders = (requestInit.headers ?? {}) as Record + requestInit.headers = { + ...existingHeaders, + Authorization: `Bearer ${tokenData.accessToken}`, + } + } } const transport = new StreamableHTTPClientTransport(url, { @@ -460,6 +508,12 @@ export class SkillMcpManager { lastError = error instanceof Error ? error : new Error(String(error)) const errorMessage = lastError.message.toLowerCase() + const stepUpHandled = await this.handleStepUpIfNeeded(lastError, config) + if (stepUpHandled) { + await this.forceReconnect(info) + continue + } + if (!errorMessage.includes("not connected")) { throw lastError } @@ -470,23 +524,66 @@ export class SkillMcpManager { ) } - const key = this.getClientKey(info) - const existing = this.clients.get(key) - if (existing) { - this.clients.delete(key) - try { - await existing.client.close() - } catch { /* process may already be terminated */ } - try { - await existing.transport.close() - } catch { /* transport may already be terminated */ } - } + await this.forceReconnect(info) } } throw lastError || new Error("Operation failed with unknown error") } + private async handleStepUpIfNeeded( + error: Error, + config: ClaudeCodeMcpServer + ): Promise { + if (!config.oauth || !config.url) { + return false + } + + const statusMatch = /\b403\b/.exec(error.message) + if (!statusMatch) { + return false + } + + const headers: Record = {} + const wwwAuthMatch = /WWW-Authenticate:\s*(.+)/i.exec(error.message) + if (wwwAuthMatch?.[1]) { + headers["www-authenticate"] = wwwAuthMatch[1] + } + + const stepUp = isStepUpRequired(403, headers) + if (!stepUp) { + return false + } + + const currentScopes = config.oauth.scopes ?? [] + const merged = mergeScopes(currentScopes, stepUp.requiredScopes) + config.oauth.scopes = merged + + this.authProviders.delete(config.url) + const provider = this.getOrCreateAuthProvider(config.url, config.oauth) + + try { + await provider.login() + return true + } catch { + return false + } + } + + private async forceReconnect(info: SkillMcpClientInfo): Promise { + const key = this.getClientKey(info) + const existing = this.clients.get(key) + if (existing) { + this.clients.delete(key) + try { + await existing.client.close() + } catch { /* process may already be terminated */ } + try { + await existing.transport.close() + } catch { /* transport may already be terminated */ } + } + } + private async getOrCreateClientWithRetry( info: SkillMcpClientInfo, config: ClaudeCodeMcpServer diff --git a/src/types/request-info.d.ts b/src/types/request-info.d.ts new file mode 100644 index 00000000..69e4d481 --- /dev/null +++ b/src/types/request-info.d.ts @@ -0,0 +1,5 @@ +declare global { + type RequestInfo = string | URL +} + +export {}