feat(mcp-oauth): add full OAuth 2.1 authentication for MCP servers (#1169)
* feat(mcp-oauth): add oauth field to ClaudeCodeMcpServer schema Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai> * feat(mcp-oauth): add RFC 7591 Dynamic Client Registration * feat(mcp-oauth): add RFC 9728 PRM + RFC 8414 AS discovery * feat(mcp-oauth): add secure token storage with {host}/{resource} key format * feat(mcp-oauth): add dynamic port OAuth callback server * feat(mcp-oauth): add RFC 8707 Resource Indicators * feat(mcp-oauth): implement full-spec McpOAuthProvider * feat(mcp-oauth): add step-up authorization handler * feat(mcp-oauth): integrate authProvider into SkillMcpManager * feat(doctor): add MCP OAuth token status check * feat(cli): add mcp oauth subcommand structure * feat(cli): implement mcp oauth login command * fix(mcp-oauth): address cubic review — security, correctness, and test issues - Remove @ts-nocheck from provider.ts, storage.ts, provider.test.ts - Fix server resource leak on missing code/state (close + reject) - Fix command injection in openBrowser (spawn array args, cross-platform) - Mock McpOAuthProvider in login.test.ts for deterministic CI - Recreate auth provider with merged scopes in step-up flow - Add listAllTokens() for global status listing - Fix logout to accept --server-url for correct token deletion - Support both quoted and unquoted WWW-Authenticate params (RFC 2617) - Save/restore OPENCODE_CONFIG_DIR in storage.test.ts - Fix index.test.ts: vitest → bun:test * fix(mcp-oauth): use explorer instead of cmd /c start on Windows to prevent shell injection * fix(mcp-oauth): address remaining cubic review issues - Add 5-minute timeout to provider callback server to prevent indefinite hangs - Persist client registration from token storage across process restarts - Require --server-url for logout to match token storage key format - Use listTokensByHost for server-specific status lookups - Fix callback-server test to handle promise rejection ordering - Fix provider test port expectations (8912 → 19877) - Fix cli-guide.md duplicate Section 7 numbering - Fix manager test for login-on-missing-tokens behavior * fix(mcp-oauth): address final review issues - P1: Redact token values in status.ts output to prevent credential leakage - P2: Read OAuth error response body before throwing in token exchange - Test: Fix mcp-oauth doctor test to use epoch seconds (not milliseconds) --------- Co-authored-by: justsisyphus <justsisyphus@users.noreply.github.com> Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
a94fbadd57
commit
dcda8769cc
1
.gitignore
vendored
1
.gitignore
vendored
@ -33,3 +33,4 @@ yarn.lock
|
||||
test-injection/
|
||||
notepad.md
|
||||
oauth-success.html
|
||||
.188e87dbff6e7fd9-00000000.bun-build
|
||||
|
||||
@ -134,7 +134,41 @@ bunx oh-my-opencode run [prompt]
|
||||
|
||||
---
|
||||
|
||||
## 6. `auth` - Authentication Management
|
||||
## 6. `mcp oauth` - MCP OAuth Management
|
||||
|
||||
Manages OAuth 2.1 authentication for remote MCP servers.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
# Login to an OAuth-protected MCP server
|
||||
bunx oh-my-opencode mcp oauth login <server-name> --server-url https://api.example.com
|
||||
|
||||
# Login with explicit client ID and scopes
|
||||
bunx oh-my-opencode mcp oauth login my-api --server-url https://api.example.com --client-id my-client --scopes "read,write"
|
||||
|
||||
# Remove stored OAuth tokens
|
||||
bunx oh-my-opencode mcp oauth logout <server-name>
|
||||
|
||||
# Check OAuth token status
|
||||
bunx oh-my-opencode mcp oauth status [server-name]
|
||||
```
|
||||
|
||||
### Options
|
||||
|
||||
| Option | Description |
|
||||
|--------|-------------|
|
||||
| `--server-url <url>` | MCP server URL (required for login) |
|
||||
| `--client-id <id>` | OAuth client ID (optional if server supports Dynamic Client Registration) |
|
||||
| `--scopes <scopes>` | Comma-separated OAuth scopes |
|
||||
|
||||
### Token Storage
|
||||
|
||||
Tokens are stored in `~/.config/opencode/mcp-oauth.json` with `0600` permissions (owner read/write only). Key format: `{serverHost}/{resource}`.
|
||||
|
||||
---
|
||||
|
||||
## 7. `auth` - Authentication Management
|
||||
|
||||
Manages Google Antigravity OAuth authentication. Required for using Gemini models.
|
||||
|
||||
@ -153,7 +187,7 @@ bunx oh-my-opencode auth status
|
||||
|
||||
---
|
||||
|
||||
## 7. Configuration Files
|
||||
## 8. Configuration Files
|
||||
|
||||
The CLI searches for configuration files in the following locations (in priority order):
|
||||
|
||||
@ -183,7 +217,7 @@ Configuration files support **JSONC (JSON with Comments)** format. You can use c
|
||||
|
||||
---
|
||||
|
||||
## 8. Troubleshooting
|
||||
## 9. Troubleshooting
|
||||
|
||||
### "OpenCode version too old" Error
|
||||
|
||||
@ -213,7 +247,7 @@ bunx oh-my-opencode doctor --category authentication
|
||||
|
||||
---
|
||||
|
||||
## 9. Non-Interactive Mode
|
||||
## 10. Non-Interactive Mode
|
||||
|
||||
Use the `--no-tui` option for CI/CD environments.
|
||||
|
||||
@ -227,7 +261,7 @@ bunx oh-my-opencode doctor --json > doctor-report.json
|
||||
|
||||
---
|
||||
|
||||
## 10. Developer Information
|
||||
## 11. Developer Information
|
||||
|
||||
### CLI Structure
|
||||
|
||||
|
||||
@ -521,6 +521,37 @@ mcp:
|
||||
|
||||
The `skill_mcp` tool invokes these operations with full schema discovery.
|
||||
|
||||
#### OAuth-Enabled MCPs
|
||||
|
||||
Skills can define OAuth-protected remote MCP servers. OAuth 2.1 with full RFC compliance (RFC 9728, 8414, 8707, 7591) is supported:
|
||||
|
||||
```yaml
|
||||
---
|
||||
description: My API skill
|
||||
mcp:
|
||||
my-api:
|
||||
url: https://api.example.com/mcp
|
||||
oauth:
|
||||
clientId: ${CLIENT_ID}
|
||||
scopes: ["read", "write"]
|
||||
---
|
||||
```
|
||||
|
||||
When a skill MCP has `oauth` configured:
|
||||
- **Auto-discovery**: Fetches `/.well-known/oauth-protected-resource` (RFC 9728), falls back to `/.well-known/oauth-authorization-server` (RFC 8414)
|
||||
- **Dynamic Client Registration**: Auto-registers with servers supporting RFC 7591 (clientId becomes optional)
|
||||
- **PKCE**: Mandatory for all flows
|
||||
- **Resource Indicators**: Auto-generated from MCP URL per RFC 8707
|
||||
- **Token Storage**: Persisted in `~/.config/opencode/mcp-oauth.json` (chmod 0600)
|
||||
- **Auto-refresh**: Tokens refresh on 401; step-up authorization on 403 with `WWW-Authenticate`
|
||||
- **Dynamic Port**: OAuth callback server uses an auto-discovered available port
|
||||
|
||||
Pre-authenticate via CLI:
|
||||
|
||||
```bash
|
||||
bunx oh-my-opencode mcp oauth login <server-name> --server-url https://api.example.com
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Context Injection
|
||||
|
||||
@ -8,6 +8,7 @@ import { getDependencyCheckDefinitions } from "./dependencies"
|
||||
import { getGhCliCheckDefinition } from "./gh"
|
||||
import { getLspCheckDefinition } from "./lsp"
|
||||
import { getMcpCheckDefinitions } from "./mcp"
|
||||
import { getMcpOAuthCheckDefinition } from "./mcp-oauth"
|
||||
import { getVersionCheckDefinition } from "./version"
|
||||
|
||||
export * from "./opencode"
|
||||
@ -19,6 +20,7 @@ export * from "./dependencies"
|
||||
export * from "./gh"
|
||||
export * from "./lsp"
|
||||
export * from "./mcp"
|
||||
export * from "./mcp-oauth"
|
||||
export * from "./version"
|
||||
|
||||
export function getAllCheckDefinitions(): CheckDefinition[] {
|
||||
@ -32,6 +34,7 @@ export function getAllCheckDefinitions(): CheckDefinition[] {
|
||||
getGhCliCheckDefinition(),
|
||||
getLspCheckDefinition(),
|
||||
...getMcpCheckDefinitions(),
|
||||
getMcpOAuthCheckDefinition(),
|
||||
getVersionCheckDefinition(),
|
||||
]
|
||||
}
|
||||
|
||||
133
src/cli/doctor/checks/mcp-oauth.test.ts
Normal file
133
src/cli/doctor/checks/mcp-oauth.test.ts
Normal file
@ -0,0 +1,133 @@
|
||||
import { describe, it, expect, spyOn, afterEach } from "bun:test"
|
||||
import * as mcpOauth from "./mcp-oauth"
|
||||
|
||||
describe("mcp-oauth check", () => {
|
||||
describe("getMcpOAuthCheckDefinition", () => {
|
||||
it("returns check definition with correct properties", () => {
|
||||
// #given
|
||||
// #when getting definition
|
||||
const def = mcpOauth.getMcpOAuthCheckDefinition()
|
||||
|
||||
// #then should have correct structure
|
||||
expect(def.id).toBe("mcp-oauth-tokens")
|
||||
expect(def.name).toBe("MCP OAuth Tokens")
|
||||
expect(def.category).toBe("tools")
|
||||
expect(def.critical).toBe(false)
|
||||
expect(typeof def.check).toBe("function")
|
||||
})
|
||||
})
|
||||
|
||||
describe("checkMcpOAuthTokens", () => {
|
||||
let readStoreSpy: ReturnType<typeof spyOn>
|
||||
|
||||
afterEach(() => {
|
||||
readStoreSpy?.mockRestore()
|
||||
})
|
||||
|
||||
it("returns skip when no tokens stored", async () => {
|
||||
// #given no OAuth tokens configured
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue(null)
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should skip
|
||||
expect(result.status).toBe("skip")
|
||||
expect(result.message).toContain("No OAuth")
|
||||
})
|
||||
|
||||
it("returns pass when all tokens valid", async () => {
|
||||
// #given valid tokens with future expiry (expiresAt is in epoch seconds)
|
||||
const futureTime = Math.floor(Date.now() / 1000) + 3600
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({
|
||||
"example.com/resource1": {
|
||||
accessToken: "token1",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
"example.com/resource2": {
|
||||
accessToken: "token2",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
})
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should pass
|
||||
expect(result.status).toBe("pass")
|
||||
expect(result.message).toContain("2")
|
||||
expect(result.message).toContain("valid")
|
||||
})
|
||||
|
||||
it("returns warn when some tokens expired", async () => {
|
||||
// #given mix of valid and expired tokens (expiresAt is in epoch seconds)
|
||||
const futureTime = Math.floor(Date.now() / 1000) + 3600
|
||||
const pastTime = Math.floor(Date.now() / 1000) - 3600
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({
|
||||
"example.com/resource1": {
|
||||
accessToken: "token1",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
"example.com/resource2": {
|
||||
accessToken: "token2",
|
||||
expiresAt: pastTime,
|
||||
},
|
||||
})
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should warn
|
||||
expect(result.status).toBe("warn")
|
||||
expect(result.message).toContain("1")
|
||||
expect(result.message).toContain("expired")
|
||||
expect(result.details?.some((d: string) => d.includes("Expired"))).toBe(
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it("returns pass when tokens have no expiry", async () => {
|
||||
// #given tokens without expiry info
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({
|
||||
"example.com/resource1": {
|
||||
accessToken: "token1",
|
||||
},
|
||||
})
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should pass (no expiry = assume valid)
|
||||
expect(result.status).toBe("pass")
|
||||
expect(result.message).toContain("1")
|
||||
})
|
||||
|
||||
it("includes token details in output", async () => {
|
||||
// #given multiple tokens
|
||||
const futureTime = Math.floor(Date.now() / 1000) + 3600
|
||||
readStoreSpy = spyOn(mcpOauth, "readTokenStore").mockReturnValue({
|
||||
"api.example.com/v1": {
|
||||
accessToken: "token1",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
"auth.example.com/oauth": {
|
||||
accessToken: "token2",
|
||||
expiresAt: futureTime,
|
||||
},
|
||||
})
|
||||
|
||||
// #when checking OAuth tokens
|
||||
const result = await mcpOauth.checkMcpOAuthTokens()
|
||||
|
||||
// #then should list tokens in details
|
||||
expect(result.details).toBeDefined()
|
||||
expect(result.details?.length).toBeGreaterThan(0)
|
||||
expect(
|
||||
result.details?.some((d: string) => d.includes("api.example.com"))
|
||||
).toBe(true)
|
||||
expect(
|
||||
result.details?.some((d: string) => d.includes("auth.example.com"))
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
80
src/cli/doctor/checks/mcp-oauth.ts
Normal file
80
src/cli/doctor/checks/mcp-oauth.ts
Normal file
@ -0,0 +1,80 @@
|
||||
import type { CheckResult, CheckDefinition } from "../types"
|
||||
import { CHECK_IDS, CHECK_NAMES } from "../constants"
|
||||
import { getMcpOauthStoragePath } from "../../../features/mcp-oauth/storage"
|
||||
import { existsSync, readFileSync } from "node:fs"
|
||||
|
||||
interface OAuthTokenData {
|
||||
accessToken: string
|
||||
refreshToken?: string
|
||||
expiresAt?: number
|
||||
clientInfo?: {
|
||||
clientId: string
|
||||
clientSecret?: string
|
||||
}
|
||||
}
|
||||
|
||||
type TokenStore = Record<string, OAuthTokenData>
|
||||
|
||||
export function readTokenStore(): TokenStore | null {
|
||||
const filePath = getMcpOauthStoragePath()
|
||||
if (!existsSync(filePath)) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(filePath, "utf-8")
|
||||
return JSON.parse(content) as TokenStore
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkMcpOAuthTokens(): Promise<CheckResult> {
|
||||
const store = readTokenStore()
|
||||
|
||||
if (!store || Object.keys(store).length === 0) {
|
||||
return {
|
||||
name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS],
|
||||
status: "skip",
|
||||
message: "No OAuth tokens configured",
|
||||
details: ["Optional: Configure OAuth tokens for MCP servers"],
|
||||
}
|
||||
}
|
||||
|
||||
const now = Math.floor(Date.now() / 1000)
|
||||
const tokens = Object.entries(store)
|
||||
const expiredTokens = tokens.filter(
|
||||
([, token]) => token.expiresAt && token.expiresAt < now
|
||||
)
|
||||
|
||||
if (expiredTokens.length > 0) {
|
||||
return {
|
||||
name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS],
|
||||
status: "warn",
|
||||
message: `${expiredTokens.length} of ${tokens.length} token(s) expired`,
|
||||
details: [
|
||||
...tokens
|
||||
.filter(([, token]) => !token.expiresAt || token.expiresAt >= now)
|
||||
.map(([key]) => `Valid: ${key}`),
|
||||
...expiredTokens.map(([key]) => `Expired: ${key}`),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS],
|
||||
status: "pass",
|
||||
message: `${tokens.length} OAuth token(s) valid`,
|
||||
details: tokens.map(([key]) => `Configured: ${key}`),
|
||||
}
|
||||
}
|
||||
|
||||
export function getMcpOAuthCheckDefinition(): CheckDefinition {
|
||||
return {
|
||||
id: CHECK_IDS.MCP_OAUTH_TOKENS,
|
||||
name: CHECK_NAMES[CHECK_IDS.MCP_OAUTH_TOKENS],
|
||||
category: "tools",
|
||||
check: checkMcpOAuthTokens,
|
||||
critical: false,
|
||||
}
|
||||
}
|
||||
@ -32,6 +32,7 @@ export const CHECK_IDS = {
|
||||
LSP_SERVERS: "lsp-servers",
|
||||
MCP_BUILTIN: "mcp-builtin",
|
||||
MCP_USER: "mcp-user",
|
||||
MCP_OAUTH_TOKENS: "mcp-oauth-tokens",
|
||||
VERSION_STATUS: "version-status",
|
||||
} as const
|
||||
|
||||
@ -50,6 +51,7 @@ export const CHECK_NAMES: Record<string, string> = {
|
||||
[CHECK_IDS.LSP_SERVERS]: "LSP Servers",
|
||||
[CHECK_IDS.MCP_BUILTIN]: "Built-in MCP Servers",
|
||||
[CHECK_IDS.MCP_USER]: "User MCP Configuration",
|
||||
[CHECK_IDS.MCP_OAUTH_TOKENS]: "MCP OAuth Tokens",
|
||||
[CHECK_IDS.VERSION_STATUS]: "Version Status",
|
||||
} as const
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import { install } from "./install"
|
||||
import { run } from "./run"
|
||||
import { getLocalVersion } from "./get-local-version"
|
||||
import { doctor } from "./doctor"
|
||||
import { createMcpOAuthCommand } from "./mcp-oauth"
|
||||
import type { InstallArgs } from "./types"
|
||||
import type { RunOptions } from "./run"
|
||||
import type { GetLocalVersionOptions } from "./get-local-version/types"
|
||||
@ -150,4 +151,6 @@ program
|
||||
console.log(`oh-my-opencode v${VERSION}`)
|
||||
})
|
||||
|
||||
program.addCommand(createMcpOAuthCommand())
|
||||
|
||||
program.parse()
|
||||
|
||||
123
src/cli/mcp-oauth/index.test.ts
Normal file
123
src/cli/mcp-oauth/index.test.ts
Normal file
@ -0,0 +1,123 @@
|
||||
import { describe, it, expect } from "bun:test"
|
||||
import { Command } from "commander"
|
||||
import { createMcpOAuthCommand } from "./index"
|
||||
|
||||
describe("mcp oauth command", () => {
|
||||
|
||||
describe("command structure", () => {
|
||||
it("creates mcp command group with oauth subcommand", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
|
||||
// when
|
||||
const subcommands = mcpCommand.commands.map((cmd: Command) => cmd.name())
|
||||
|
||||
// then
|
||||
expect(subcommands).toContain("oauth")
|
||||
})
|
||||
|
||||
it("oauth subcommand has login, logout, and status subcommands", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
|
||||
// when
|
||||
const subcommands = oauthCommand?.commands.map((cmd: Command) => cmd.name()) ?? []
|
||||
|
||||
// then
|
||||
expect(subcommands).toContain("login")
|
||||
expect(subcommands).toContain("logout")
|
||||
expect(subcommands).toContain("status")
|
||||
})
|
||||
})
|
||||
|
||||
describe("login subcommand", () => {
|
||||
it("exists and has description", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login")
|
||||
|
||||
// when
|
||||
const description = loginCommand?.description() ?? ""
|
||||
|
||||
// then
|
||||
expect(loginCommand).toBeDefined()
|
||||
expect(description).toContain("OAuth")
|
||||
})
|
||||
|
||||
it("accepts --server-url option", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login")
|
||||
|
||||
// when
|
||||
const options = loginCommand?.options ?? []
|
||||
const serverUrlOption = options.find((opt: { long?: string }) => opt.long === "--server-url")
|
||||
|
||||
// then
|
||||
expect(serverUrlOption).toBeDefined()
|
||||
})
|
||||
|
||||
it("accepts --client-id option", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login")
|
||||
|
||||
// when
|
||||
const options = loginCommand?.options ?? []
|
||||
const clientIdOption = options.find((opt: { long?: string }) => opt.long === "--client-id")
|
||||
|
||||
// then
|
||||
expect(clientIdOption).toBeDefined()
|
||||
})
|
||||
|
||||
it("accepts --scopes option", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const loginCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "login")
|
||||
|
||||
// when
|
||||
const options = loginCommand?.options ?? []
|
||||
const scopesOption = options.find((opt: { long?: string }) => opt.long === "--scopes")
|
||||
|
||||
// then
|
||||
expect(scopesOption).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe("logout subcommand", () => {
|
||||
it("exists and has description", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const logoutCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "logout")
|
||||
|
||||
// when
|
||||
const description = logoutCommand?.description() ?? ""
|
||||
|
||||
// then
|
||||
expect(logoutCommand).toBeDefined()
|
||||
expect(description).toContain("tokens")
|
||||
})
|
||||
})
|
||||
|
||||
describe("status subcommand", () => {
|
||||
it("exists and has description", () => {
|
||||
// given
|
||||
const mcpCommand = createMcpOAuthCommand()
|
||||
const oauthCommand = mcpCommand.commands.find((cmd: Command) => cmd.name() === "oauth")
|
||||
const statusCommand = oauthCommand?.commands.find((cmd: Command) => cmd.name() === "status")
|
||||
|
||||
// when
|
||||
const description = statusCommand?.description() ?? ""
|
||||
|
||||
// then
|
||||
expect(statusCommand).toBeDefined()
|
||||
expect(description).toContain("status")
|
||||
})
|
||||
})
|
||||
})
|
||||
43
src/cli/mcp-oauth/index.ts
Normal file
43
src/cli/mcp-oauth/index.ts
Normal file
@ -0,0 +1,43 @@
|
||||
import { Command } from "commander"
|
||||
import { login } from "./login"
|
||||
import { logout } from "./logout"
|
||||
import { status } from "./status"
|
||||
|
||||
export function createMcpOAuthCommand(): Command {
|
||||
const mcp = new Command("mcp").description("MCP server management")
|
||||
|
||||
const oauth = new Command("oauth").description("OAuth token management for MCP servers")
|
||||
|
||||
oauth
|
||||
.command("login <server-name>")
|
||||
.description("Authenticate with an MCP server using OAuth")
|
||||
.option("--server-url <url>", "OAuth server URL (required if not in config)")
|
||||
.option("--client-id <id>", "OAuth client ID (optional, uses DCR if not provided)")
|
||||
.option("--scopes <scopes...>", "OAuth scopes to request")
|
||||
.action(async (serverName: string, options) => {
|
||||
const exitCode = await login(serverName, options)
|
||||
process.exit(exitCode)
|
||||
})
|
||||
|
||||
oauth
|
||||
.command("logout <server-name>")
|
||||
.description("Remove stored OAuth tokens for an MCP server")
|
||||
.option("--server-url <url>", "OAuth server URL (use if server name differs from URL)")
|
||||
.action(async (serverName: string, options) => {
|
||||
const exitCode = await logout(serverName, options)
|
||||
process.exit(exitCode)
|
||||
})
|
||||
|
||||
oauth
|
||||
.command("status [server-name]")
|
||||
.description("Show OAuth token status for MCP servers")
|
||||
.action(async (serverName: string | undefined) => {
|
||||
const exitCode = await status(serverName)
|
||||
process.exit(exitCode)
|
||||
})
|
||||
|
||||
mcp.addCommand(oauth)
|
||||
return mcp
|
||||
}
|
||||
|
||||
export { login, logout, status }
|
||||
80
src/cli/mcp-oauth/login.test.ts
Normal file
80
src/cli/mcp-oauth/login.test.ts
Normal file
@ -0,0 +1,80 @@
|
||||
import { describe, it, expect, beforeEach, afterEach, mock } from "bun:test"
|
||||
|
||||
const mockLogin = mock(() => Promise.resolve({ accessToken: "test-token", expiresAt: 1710000000 }))
|
||||
|
||||
mock.module("../../features/mcp-oauth/provider", () => ({
|
||||
McpOAuthProvider: class MockMcpOAuthProvider {
|
||||
constructor(public options: { serverUrl: string; clientId?: string; scopes?: string[] }) {}
|
||||
async login() {
|
||||
return mockLogin()
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
const { login } = await import("./login")
|
||||
|
||||
describe("login command", () => {
|
||||
beforeEach(() => {
|
||||
mockLogin.mockClear()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
// cleanup
|
||||
})
|
||||
|
||||
it("returns error code when server-url is not provided", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
const options = {}
|
||||
|
||||
// when
|
||||
const exitCode = await login(serverName, options)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(1)
|
||||
})
|
||||
|
||||
it("returns success code when login succeeds", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
const options = {
|
||||
serverUrl: "https://oauth.example.com",
|
||||
}
|
||||
|
||||
// when
|
||||
const exitCode = await login(serverName, options)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(0)
|
||||
expect(mockLogin).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it("returns error code when login throws", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
const options = {
|
||||
serverUrl: "https://oauth.example.com",
|
||||
}
|
||||
mockLogin.mockRejectedValueOnce(new Error("Network error"))
|
||||
|
||||
// when
|
||||
const exitCode = await login(serverName, options)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(1)
|
||||
})
|
||||
|
||||
it("returns error code when server-url is missing", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
const options = {
|
||||
clientId: "test-client-id",
|
||||
}
|
||||
|
||||
// when
|
||||
const exitCode = await login(serverName, options)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(1)
|
||||
})
|
||||
})
|
||||
38
src/cli/mcp-oauth/login.ts
Normal file
38
src/cli/mcp-oauth/login.ts
Normal file
@ -0,0 +1,38 @@
|
||||
import { McpOAuthProvider } from "../../features/mcp-oauth/provider"
|
||||
|
||||
export interface LoginOptions {
|
||||
serverUrl?: string
|
||||
clientId?: string
|
||||
scopes?: string[]
|
||||
}
|
||||
|
||||
export async function login(serverName: string, options: LoginOptions): Promise<number> {
|
||||
try {
|
||||
const serverUrl = options.serverUrl
|
||||
if (!serverUrl) {
|
||||
console.error(`Error: --server-url is required for server "${serverName}"`)
|
||||
return 1
|
||||
}
|
||||
|
||||
const provider = new McpOAuthProvider({
|
||||
serverUrl,
|
||||
clientId: options.clientId,
|
||||
scopes: options.scopes,
|
||||
})
|
||||
|
||||
console.log(`Authenticating with ${serverName}...`)
|
||||
const tokenData = await provider.login()
|
||||
|
||||
console.log(`✓ Successfully authenticated with ${serverName}`)
|
||||
if (tokenData.expiresAt) {
|
||||
const expiryDate = new Date(tokenData.expiresAt * 1000)
|
||||
console.log(` Token expires at: ${expiryDate.toISOString()}`)
|
||||
}
|
||||
|
||||
return 0
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
console.error(`Error: Failed to authenticate with ${serverName}: ${message}`)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
65
src/cli/mcp-oauth/logout.test.ts
Normal file
65
src/cli/mcp-oauth/logout.test.ts
Normal file
@ -0,0 +1,65 @@
|
||||
import { describe, it, expect, beforeEach, afterEach, mock } from "bun:test"
|
||||
import { existsSync, mkdirSync, rmSync } from "node:fs"
|
||||
import { join } from "node:path"
|
||||
import { tmpdir } from "node:os"
|
||||
import { saveToken } from "../../features/mcp-oauth/storage"
|
||||
|
||||
const { logout } = await import("./logout")
|
||||
|
||||
describe("logout command", () => {
|
||||
const TEST_CONFIG_DIR = join(tmpdir(), "mcp-oauth-logout-test-" + Date.now())
|
||||
let originalConfigDir: string | undefined
|
||||
|
||||
beforeEach(() => {
|
||||
originalConfigDir = process.env.OPENCODE_CONFIG_DIR
|
||||
process.env.OPENCODE_CONFIG_DIR = TEST_CONFIG_DIR
|
||||
if (!existsSync(TEST_CONFIG_DIR)) {
|
||||
mkdirSync(TEST_CONFIG_DIR, { recursive: true })
|
||||
}
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
if (originalConfigDir === undefined) {
|
||||
delete process.env.OPENCODE_CONFIG_DIR
|
||||
} else {
|
||||
process.env.OPENCODE_CONFIG_DIR = originalConfigDir
|
||||
}
|
||||
if (existsSync(TEST_CONFIG_DIR)) {
|
||||
rmSync(TEST_CONFIG_DIR, { recursive: true, force: true })
|
||||
}
|
||||
})
|
||||
|
||||
it("returns success code when logout succeeds", async () => {
|
||||
// given
|
||||
const serverUrl = "https://test-server.example.com"
|
||||
saveToken(serverUrl, serverUrl, { accessToken: "test-token" })
|
||||
|
||||
// when
|
||||
const exitCode = await logout("test-server", { serverUrl })
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
|
||||
it("handles non-existent server gracefully", async () => {
|
||||
// given
|
||||
const serverName = "non-existent-server"
|
||||
|
||||
// when
|
||||
const exitCode = await logout(serverName, { serverUrl: "https://nonexistent.example.com" })
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
|
||||
it("returns error when --server-url is not provided", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
|
||||
// when
|
||||
const exitCode = await logout(serverName)
|
||||
|
||||
// then
|
||||
expect(exitCode).toBe(1)
|
||||
})
|
||||
})
|
||||
30
src/cli/mcp-oauth/logout.ts
Normal file
30
src/cli/mcp-oauth/logout.ts
Normal file
@ -0,0 +1,30 @@
|
||||
import { deleteToken } from "../../features/mcp-oauth/storage"
|
||||
|
||||
export interface LogoutOptions {
|
||||
serverUrl?: string
|
||||
}
|
||||
|
||||
export async function logout(serverName: string, options?: LogoutOptions): Promise<number> {
|
||||
try {
|
||||
const serverUrl = options?.serverUrl
|
||||
if (!serverUrl) {
|
||||
console.error(`Error: --server-url is required for logout. Token storage uses server URLs, not names.`)
|
||||
console.error(` Usage: mcp oauth logout ${serverName} --server-url https://your-server.example.com`)
|
||||
return 1
|
||||
}
|
||||
|
||||
const success = deleteToken(serverUrl, serverUrl)
|
||||
|
||||
if (success) {
|
||||
console.log(`✓ Successfully removed tokens for ${serverName}`)
|
||||
return 0
|
||||
}
|
||||
|
||||
console.error(`Error: Failed to remove tokens for ${serverName}`)
|
||||
return 1
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
console.error(`Error: Failed to remove tokens for ${serverName}: ${message}`)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
48
src/cli/mcp-oauth/status.test.ts
Normal file
48
src/cli/mcp-oauth/status.test.ts
Normal file
@ -0,0 +1,48 @@
|
||||
import { describe, it, expect, beforeEach, afterEach } from "bun:test"
|
||||
import { status } from "./status"
|
||||
|
||||
describe("status command", () => {
|
||||
beforeEach(() => {
|
||||
// setup
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
// cleanup
|
||||
})
|
||||
|
||||
it("returns success code when checking status for specific server", async () => {
|
||||
// given
|
||||
const serverName = "test-server"
|
||||
|
||||
// when
|
||||
const exitCode = await status(serverName)
|
||||
|
||||
// then
|
||||
expect(typeof exitCode).toBe("number")
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
|
||||
it("returns success code when checking status for all servers", async () => {
|
||||
// given
|
||||
const serverName = undefined
|
||||
|
||||
// when
|
||||
const exitCode = await status(serverName)
|
||||
|
||||
// then
|
||||
expect(typeof exitCode).toBe("number")
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
|
||||
it("handles non-existent server gracefully", async () => {
|
||||
// given
|
||||
const serverName = "non-existent-server"
|
||||
|
||||
// when
|
||||
const exitCode = await status(serverName)
|
||||
|
||||
// then
|
||||
expect(typeof exitCode).toBe("number")
|
||||
expect(exitCode).toBe(0)
|
||||
})
|
||||
})
|
||||
50
src/cli/mcp-oauth/status.ts
Normal file
50
src/cli/mcp-oauth/status.ts
Normal file
@ -0,0 +1,50 @@
|
||||
import { listAllTokens, listTokensByHost } from "../../features/mcp-oauth/storage"
|
||||
|
||||
export async function status(serverName: string | undefined): Promise<number> {
|
||||
try {
|
||||
if (serverName) {
|
||||
const tokens = listTokensByHost(serverName)
|
||||
|
||||
if (Object.keys(tokens).length === 0) {
|
||||
console.log(`No tokens found for ${serverName}`)
|
||||
return 0
|
||||
}
|
||||
|
||||
console.log(`OAuth Status for ${serverName}:`)
|
||||
for (const [key, token] of Object.entries(tokens)) {
|
||||
console.log(` ${key}:`)
|
||||
console.log(` Access Token: [REDACTED]`)
|
||||
if (token.refreshToken) {
|
||||
console.log(` Refresh Token: [REDACTED]`)
|
||||
}
|
||||
if (token.expiresAt) {
|
||||
const expiryDate = new Date(token.expiresAt * 1000)
|
||||
const now = Date.now() / 1000
|
||||
const isExpired = token.expiresAt < now
|
||||
const tokenStatus = isExpired ? "EXPIRED" : "VALID"
|
||||
console.log(` Expiry: ${expiryDate.toISOString()} (${tokenStatus})`)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const tokens = listAllTokens()
|
||||
if (Object.keys(tokens).length === 0) {
|
||||
console.log("No OAuth tokens stored")
|
||||
return 0
|
||||
}
|
||||
|
||||
console.log("Stored OAuth Tokens:")
|
||||
for (const [key, token] of Object.entries(tokens)) {
|
||||
const isExpired = token.expiresAt && token.expiresAt < Date.now() / 1000
|
||||
const tokenStatus = isExpired ? "EXPIRED" : "VALID"
|
||||
console.log(` ${key}: ${tokenStatus}`)
|
||||
}
|
||||
|
||||
return 0
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
console.error(`Error: Failed to get token status: ${message}`)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
@ -7,6 +7,10 @@ export interface ClaudeCodeMcpServer {
|
||||
args?: string[]
|
||||
env?: Record<string, string>
|
||||
headers?: Record<string, string>
|
||||
oauth?: {
|
||||
clientId?: string
|
||||
scopes?: string[]
|
||||
}
|
||||
disabled?: boolean
|
||||
}
|
||||
|
||||
|
||||
129
src/features/mcp-oauth/callback-server.test.ts
Normal file
129
src/features/mcp-oauth/callback-server.test.ts
Normal file
@ -0,0 +1,129 @@
|
||||
import { afterEach, describe, expect, it } from "bun:test"
|
||||
import { findAvailablePort, startCallbackServer, type CallbackServer } from "./callback-server"
|
||||
|
||||
describe("findAvailablePort", () => {
|
||||
it("returns the start port when it is available", async () => {
|
||||
//#given
|
||||
const startPort = 19877
|
||||
|
||||
//#when
|
||||
const port = await findAvailablePort(startPort)
|
||||
|
||||
//#then
|
||||
expect(port).toBeGreaterThanOrEqual(startPort)
|
||||
expect(port).toBeLessThan(startPort + 20)
|
||||
})
|
||||
|
||||
it("skips busy ports and returns next available", async () => {
|
||||
//#given
|
||||
const blocker = Bun.serve({
|
||||
port: 19877,
|
||||
hostname: "127.0.0.1",
|
||||
fetch: () => new Response(),
|
||||
})
|
||||
|
||||
//#when
|
||||
const port = await findAvailablePort(19877)
|
||||
|
||||
//#then
|
||||
expect(port).toBeGreaterThan(19877)
|
||||
blocker.stop(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe("startCallbackServer", () => {
|
||||
let server: CallbackServer | null = null
|
||||
|
||||
afterEach(() => {
|
||||
server?.close()
|
||||
server = null
|
||||
})
|
||||
|
||||
it("starts server and returns port", async () => {
|
||||
//#given - no preconditions
|
||||
|
||||
//#when
|
||||
server = await startCallbackServer()
|
||||
|
||||
//#then
|
||||
expect(server.port).toBeGreaterThanOrEqual(19877)
|
||||
expect(typeof server.waitForCallback).toBe("function")
|
||||
expect(typeof server.close).toBe("function")
|
||||
})
|
||||
|
||||
it("resolves callback with code and state from query params", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
const callbackUrl = `http://127.0.0.1:${server.port}/oauth/callback?code=test-code&state=test-state`
|
||||
|
||||
//#when
|
||||
const fetchPromise = fetch(callbackUrl)
|
||||
const result = await server.waitForCallback()
|
||||
const response = await fetchPromise
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({ code: "test-code", state: "test-state" })
|
||||
expect(response.status).toBe(200)
|
||||
const html = await response.text()
|
||||
expect(html).toContain("Authorization successful")
|
||||
})
|
||||
|
||||
it("returns 404 for non-callback routes", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
|
||||
//#when
|
||||
const response = await fetch(`http://127.0.0.1:${server.port}/other`)
|
||||
|
||||
//#then
|
||||
expect(response.status).toBe(404)
|
||||
})
|
||||
|
||||
it("returns 400 and rejects when code is missing", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
const callbackRejection = server.waitForCallback().catch((e: Error) => e)
|
||||
|
||||
//#when
|
||||
const response = await fetch(`http://127.0.0.1:${server.port}/oauth/callback?state=s`)
|
||||
|
||||
//#then
|
||||
expect(response.status).toBe(400)
|
||||
const error = await callbackRejection
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
expect((error as Error).message).toContain("missing code or state")
|
||||
})
|
||||
|
||||
it("returns 400 and rejects when state is missing", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
const callbackRejection = server.waitForCallback().catch((e: Error) => e)
|
||||
|
||||
//#when
|
||||
const response = await fetch(`http://127.0.0.1:${server.port}/oauth/callback?code=c`)
|
||||
|
||||
//#then
|
||||
expect(response.status).toBe(400)
|
||||
const error = await callbackRejection
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
expect((error as Error).message).toContain("missing code or state")
|
||||
})
|
||||
|
||||
it("close stops the server immediately", async () => {
|
||||
//#given
|
||||
server = await startCallbackServer()
|
||||
const port = server.port
|
||||
|
||||
//#when
|
||||
server.close()
|
||||
server = null
|
||||
|
||||
//#then
|
||||
try {
|
||||
await fetch(`http://127.0.0.1:${port}/oauth/callback?code=c&state=s`)
|
||||
expect(true).toBe(false)
|
||||
} catch (error) {
|
||||
expect(error).toBeDefined()
|
||||
}
|
||||
})
|
||||
})
|
||||
124
src/features/mcp-oauth/callback-server.ts
Normal file
124
src/features/mcp-oauth/callback-server.ts
Normal file
@ -0,0 +1,124 @@
|
||||
const DEFAULT_PORT = 19877
|
||||
const MAX_PORT_ATTEMPTS = 20
|
||||
const TIMEOUT_MS = 5 * 60 * 1000
|
||||
|
||||
export type OAuthCallbackResult = {
|
||||
code: string
|
||||
state: string
|
||||
}
|
||||
|
||||
export type CallbackServer = {
|
||||
port: number
|
||||
waitForCallback: () => Promise<OAuthCallbackResult>
|
||||
close: () => void
|
||||
}
|
||||
|
||||
const SUCCESS_HTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>OAuth Authorized</title>
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #0a0a0a; color: #fafafa; }
|
||||
.container { text-align: center; }
|
||||
h1 { font-size: 1.5rem; margin-bottom: 0.5rem; }
|
||||
p { color: #888; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Authorization successful</h1>
|
||||
<p>You can close this window and return to your terminal.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
async function isPortAvailable(port: number): Promise<boolean> {
|
||||
try {
|
||||
const server = Bun.serve({
|
||||
port,
|
||||
hostname: "127.0.0.1",
|
||||
fetch: () => new Response(),
|
||||
})
|
||||
server.stop(true)
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export async function findAvailablePort(startPort: number = DEFAULT_PORT): Promise<number> {
|
||||
for (let attempt = 0; attempt < MAX_PORT_ATTEMPTS; attempt++) {
|
||||
const port = startPort + attempt
|
||||
if (await isPortAvailable(port)) {
|
||||
return port
|
||||
}
|
||||
}
|
||||
throw new Error(`No available port found in range ${startPort}-${startPort + MAX_PORT_ATTEMPTS - 1}`)
|
||||
}
|
||||
|
||||
export async function startCallbackServer(startPort: number = DEFAULT_PORT): Promise<CallbackServer> {
|
||||
const port = await findAvailablePort(startPort)
|
||||
|
||||
let resolveCallback: ((result: OAuthCallbackResult) => void) | null = null
|
||||
let rejectCallback: ((error: Error) => void) | null = null
|
||||
|
||||
const callbackPromise = new Promise<OAuthCallbackResult>((resolve, reject) => {
|
||||
resolveCallback = resolve
|
||||
rejectCallback = reject
|
||||
})
|
||||
|
||||
const timeoutId = setTimeout(() => {
|
||||
rejectCallback?.(new Error("OAuth callback timed out after 5 minutes"))
|
||||
server.stop(true)
|
||||
}, TIMEOUT_MS)
|
||||
|
||||
const server = Bun.serve({
|
||||
port,
|
||||
hostname: "127.0.0.1",
|
||||
fetch(request: Request): Response {
|
||||
const url = new URL(request.url)
|
||||
|
||||
if (url.pathname !== "/oauth/callback") {
|
||||
return new Response("Not Found", { status: 404 })
|
||||
}
|
||||
|
||||
const oauthError = url.searchParams.get("error")
|
||||
if (oauthError) {
|
||||
const description = url.searchParams.get("error_description") ?? oauthError
|
||||
clearTimeout(timeoutId)
|
||||
rejectCallback?.(new Error(`OAuth authorization failed: ${description}`))
|
||||
setTimeout(() => server.stop(true), 100)
|
||||
return new Response(`Authorization failed: ${description}`, { status: 400 })
|
||||
}
|
||||
|
||||
const code = url.searchParams.get("code")
|
||||
const state = url.searchParams.get("state")
|
||||
|
||||
if (!code || !state) {
|
||||
clearTimeout(timeoutId)
|
||||
rejectCallback?.(new Error("OAuth callback missing code or state parameter"))
|
||||
setTimeout(() => server.stop(true), 100)
|
||||
return new Response("Missing code or state parameter", { status: 400 })
|
||||
}
|
||||
|
||||
resolveCallback?.({ code, state })
|
||||
clearTimeout(timeoutId)
|
||||
|
||||
setTimeout(() => server.stop(true), 100)
|
||||
|
||||
return new Response(SUCCESS_HTML, {
|
||||
headers: { "content-type": "text/html; charset=utf-8" },
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
return {
|
||||
port,
|
||||
waitForCallback: () => callbackPromise,
|
||||
close: () => {
|
||||
clearTimeout(timeoutId)
|
||||
server.stop(true)
|
||||
},
|
||||
}
|
||||
}
|
||||
164
src/features/mcp-oauth/dcr.test.ts
Normal file
164
src/features/mcp-oauth/dcr.test.ts
Normal file
@ -0,0 +1,164 @@
|
||||
import { describe, expect, it } from "bun:test"
|
||||
import {
|
||||
getOrRegisterClient,
|
||||
type ClientCredentials,
|
||||
type ClientRegistrationStorage,
|
||||
type DcrFetch,
|
||||
} from "./dcr"
|
||||
|
||||
function createStorage(initial: ClientCredentials | null):
|
||||
& ClientRegistrationStorage
|
||||
& { getLastKey: () => string | null; getLastSet: () => ClientCredentials | null } {
|
||||
let stored = initial
|
||||
let lastKey: string | null = null
|
||||
let lastSet: ClientCredentials | null = null
|
||||
|
||||
return {
|
||||
getClientRegistration: () => stored,
|
||||
setClientRegistration: (serverIdentifier: string, credentials: ClientCredentials) => {
|
||||
lastKey = serverIdentifier
|
||||
lastSet = credentials
|
||||
stored = credentials
|
||||
},
|
||||
getLastKey: () => lastKey,
|
||||
getLastSet: () => lastSet,
|
||||
}
|
||||
}
|
||||
|
||||
describe("getOrRegisterClient", () => {
|
||||
it("returns cached registration when available", async () => {
|
||||
// #given
|
||||
const storage = createStorage({
|
||||
clientId: "cached-client",
|
||||
clientSecret: "cached-secret",
|
||||
})
|
||||
const fetchMock: DcrFetch = async () => {
|
||||
throw new Error("fetch should not be called")
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = await getOrRegisterClient({
|
||||
registrationEndpoint: "https://server.example.com/register",
|
||||
serverIdentifier: "server-1",
|
||||
clientName: "Test Client",
|
||||
redirectUris: ["https://app.example.com/callback"],
|
||||
tokenEndpointAuthMethod: "client_secret_post",
|
||||
storage,
|
||||
fetch: fetchMock,
|
||||
})
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
clientId: "cached-client",
|
||||
clientSecret: "cached-secret",
|
||||
})
|
||||
})
|
||||
|
||||
it("registers client and stores credentials when endpoint available", async () => {
|
||||
// #given
|
||||
const storage = createStorage(null)
|
||||
let fetchCalled = false
|
||||
const fetchMock: DcrFetch = async (
|
||||
input: string,
|
||||
init?: { method?: string; headers?: Record<string, string>; body?: string }
|
||||
) => {
|
||||
fetchCalled = true
|
||||
expect(input).toBe("https://server.example.com/register")
|
||||
if (typeof init?.body !== "string") {
|
||||
throw new Error("Expected request body string")
|
||||
}
|
||||
const payload = JSON.parse(init.body)
|
||||
expect(payload).toEqual({
|
||||
redirect_uris: ["https://app.example.com/callback"],
|
||||
client_name: "Test Client",
|
||||
grant_types: ["authorization_code", "refresh_token"],
|
||||
response_types: ["code"],
|
||||
token_endpoint_auth_method: "client_secret_post",
|
||||
})
|
||||
|
||||
return {
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
client_id: "registered-client",
|
||||
client_secret: "registered-secret",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = await getOrRegisterClient({
|
||||
registrationEndpoint: "https://server.example.com/register",
|
||||
serverIdentifier: "server-2",
|
||||
clientName: "Test Client",
|
||||
redirectUris: ["https://app.example.com/callback"],
|
||||
tokenEndpointAuthMethod: "client_secret_post",
|
||||
storage,
|
||||
fetch: fetchMock,
|
||||
})
|
||||
|
||||
// #then
|
||||
expect(fetchCalled).toBe(true)
|
||||
expect(result).toEqual({
|
||||
clientId: "registered-client",
|
||||
clientSecret: "registered-secret",
|
||||
})
|
||||
expect(storage.getLastKey()).toBe("server-2")
|
||||
expect(storage.getLastSet()).toEqual({
|
||||
clientId: "registered-client",
|
||||
clientSecret: "registered-secret",
|
||||
})
|
||||
})
|
||||
|
||||
it("uses config client id when registration endpoint missing", async () => {
|
||||
// #given
|
||||
const storage = createStorage(null)
|
||||
let fetchCalled = false
|
||||
const fetchMock: DcrFetch = async () => {
|
||||
fetchCalled = true
|
||||
return {
|
||||
ok: false,
|
||||
json: async () => ({}),
|
||||
}
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = await getOrRegisterClient({
|
||||
registrationEndpoint: undefined,
|
||||
serverIdentifier: "server-3",
|
||||
clientName: "Test Client",
|
||||
redirectUris: ["https://app.example.com/callback"],
|
||||
tokenEndpointAuthMethod: "client_secret_post",
|
||||
clientId: "config-client",
|
||||
storage,
|
||||
fetch: fetchMock,
|
||||
})
|
||||
|
||||
// #then
|
||||
expect(fetchCalled).toBe(false)
|
||||
expect(result).toEqual({ clientId: "config-client" })
|
||||
})
|
||||
|
||||
it("falls back to config client id when registration fails", async () => {
|
||||
// #given
|
||||
const storage = createStorage(null)
|
||||
const fetchMock: DcrFetch = async () => {
|
||||
throw new Error("network error")
|
||||
}
|
||||
|
||||
// #when
|
||||
const result = await getOrRegisterClient({
|
||||
registrationEndpoint: "https://server.example.com/register",
|
||||
serverIdentifier: "server-4",
|
||||
clientName: "Test Client",
|
||||
redirectUris: ["https://app.example.com/callback"],
|
||||
tokenEndpointAuthMethod: "client_secret_post",
|
||||
clientId: "fallback-client",
|
||||
storage,
|
||||
fetch: fetchMock,
|
||||
})
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ clientId: "fallback-client" })
|
||||
expect(storage.getLastSet()).toBeNull()
|
||||
})
|
||||
})
|
||||
98
src/features/mcp-oauth/dcr.ts
Normal file
98
src/features/mcp-oauth/dcr.ts
Normal file
@ -0,0 +1,98 @@
|
||||
export type ClientRegistrationRequest = {
|
||||
redirect_uris: string[]
|
||||
client_name: string
|
||||
grant_types: ["authorization_code", "refresh_token"]
|
||||
response_types: ["code"]
|
||||
token_endpoint_auth_method: "none" | "client_secret_post"
|
||||
}
|
||||
|
||||
export type ClientCredentials = {
|
||||
clientId: string
|
||||
clientSecret?: string
|
||||
}
|
||||
|
||||
export type ClientRegistrationStorage = {
|
||||
getClientRegistration: (serverIdentifier: string) => ClientCredentials | null
|
||||
setClientRegistration: (
|
||||
serverIdentifier: string,
|
||||
credentials: ClientCredentials
|
||||
) => void
|
||||
}
|
||||
|
||||
export type DynamicClientRegistrationOptions = {
|
||||
registrationEndpoint?: string | null
|
||||
serverIdentifier?: string
|
||||
clientName: string
|
||||
redirectUris: string[]
|
||||
tokenEndpointAuthMethod: "none" | "client_secret_post"
|
||||
clientId?: string | null
|
||||
storage: ClientRegistrationStorage
|
||||
fetch?: DcrFetch
|
||||
}
|
||||
|
||||
export type DcrFetch = (
|
||||
input: string,
|
||||
init?: { method?: string; headers?: Record<string, string>; body?: string }
|
||||
) => Promise<{ ok: boolean; json: () => Promise<unknown> }>
|
||||
|
||||
export async function getOrRegisterClient(
|
||||
options: DynamicClientRegistrationOptions
|
||||
): Promise<ClientCredentials | null> {
|
||||
const serverIdentifier =
|
||||
options.serverIdentifier ?? options.registrationEndpoint ?? "default"
|
||||
const existing = options.storage.getClientRegistration(serverIdentifier)
|
||||
if (existing) return existing
|
||||
|
||||
if (!options.registrationEndpoint) {
|
||||
return options.clientId ? { clientId: options.clientId } : null
|
||||
}
|
||||
|
||||
const fetchImpl = options.fetch ?? globalThis.fetch
|
||||
const request: ClientRegistrationRequest = {
|
||||
redirect_uris: options.redirectUris,
|
||||
client_name: options.clientName,
|
||||
grant_types: ["authorization_code", "refresh_token"],
|
||||
response_types: ["code"],
|
||||
token_endpoint_auth_method: options.tokenEndpointAuthMethod,
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetchImpl(options.registrationEndpoint, {
|
||||
method: "POST",
|
||||
headers: { "content-type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
return options.clientId ? { clientId: options.clientId } : null
|
||||
}
|
||||
|
||||
const data: unknown = await response.json()
|
||||
const parsed = parseRegistrationResponse(data)
|
||||
if (!parsed) {
|
||||
return options.clientId ? { clientId: options.clientId } : null
|
||||
}
|
||||
|
||||
options.storage.setClientRegistration(serverIdentifier, parsed)
|
||||
return parsed
|
||||
} catch {
|
||||
return options.clientId ? { clientId: options.clientId } : null
|
||||
}
|
||||
}
|
||||
|
||||
function parseRegistrationResponse(data: unknown): ClientCredentials | null {
|
||||
if (!isRecord(data)) return null
|
||||
const clientId = data.client_id
|
||||
if (typeof clientId !== "string" || clientId.length === 0) return null
|
||||
|
||||
const clientSecret = data.client_secret
|
||||
if (typeof clientSecret === "string" && clientSecret.length > 0) {
|
||||
return { clientId, clientSecret }
|
||||
}
|
||||
|
||||
return { clientId }
|
||||
}
|
||||
|
||||
function isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null
|
||||
}
|
||||
175
src/features/mcp-oauth/discovery.test.ts
Normal file
175
src/features/mcp-oauth/discovery.test.ts
Normal file
@ -0,0 +1,175 @@
|
||||
import { describe, test, expect, beforeEach, afterEach } from "bun:test"
|
||||
import { discoverOAuthServerMetadata, resetDiscoveryCache } from "./discovery"
|
||||
|
||||
describe("discoverOAuthServerMetadata", () => {
|
||||
const originalFetch = globalThis.fetch
|
||||
|
||||
beforeEach(() => {
|
||||
resetDiscoveryCache()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
Object.defineProperty(globalThis, "fetch", { value: originalFetch, configurable: true })
|
||||
})
|
||||
|
||||
test("returns endpoints from PRM + AS discovery", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const authServer = "https://auth.example.com"
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString()
|
||||
const calls: string[] = []
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
calls.push(url)
|
||||
if (url === prmUrl) {
|
||||
return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 })
|
||||
}
|
||||
if (url === asUrl) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
authorization_endpoint: "https://auth.example.com/authorize",
|
||||
token_endpoint: "https://auth.example.com/token",
|
||||
registration_endpoint: "https://auth.example.com/register",
|
||||
}),
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
return discoverOAuthServerMetadata(resource).then((result) => {
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
authorizationEndpoint: "https://auth.example.com/authorize",
|
||||
tokenEndpoint: "https://auth.example.com/token",
|
||||
registrationEndpoint: "https://auth.example.com/register",
|
||||
resource,
|
||||
})
|
||||
expect(calls).toEqual([prmUrl, asUrl])
|
||||
})
|
||||
})
|
||||
|
||||
test("falls back to RFC 8414 when PRM returns 404", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", resource).toString()
|
||||
const calls: string[] = []
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
calls.push(url)
|
||||
if (url === prmUrl) {
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
if (url === asUrl) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
authorization_endpoint: "https://mcp.example.com/authorize",
|
||||
token_endpoint: "https://mcp.example.com/token",
|
||||
}),
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
return discoverOAuthServerMetadata(resource).then((result) => {
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
authorizationEndpoint: "https://mcp.example.com/authorize",
|
||||
tokenEndpoint: "https://mcp.example.com/token",
|
||||
registrationEndpoint: undefined,
|
||||
resource,
|
||||
})
|
||||
expect(calls).toEqual([prmUrl, asUrl])
|
||||
})
|
||||
})
|
||||
|
||||
test("throws when both PRM and AS discovery return 404", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", resource).toString()
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
if (url === prmUrl || url === asUrl) {
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
const result = discoverOAuthServerMetadata(resource)
|
||||
|
||||
// #then
|
||||
return expect(result).rejects.toThrow("OAuth authorization server metadata not found")
|
||||
})
|
||||
|
||||
test("throws when AS metadata is malformed", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const authServer = "https://auth.example.com"
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString()
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
if (url === prmUrl) {
|
||||
return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 })
|
||||
}
|
||||
if (url === asUrl) {
|
||||
return new Response(JSON.stringify({ authorization_endpoint: "https://auth.example.com/authorize" }), {
|
||||
status: 200,
|
||||
})
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
const result = discoverOAuthServerMetadata(resource)
|
||||
|
||||
// #then
|
||||
return expect(result).rejects.toThrow("token_endpoint")
|
||||
})
|
||||
|
||||
test("caches discovery results per resource URL", () => {
|
||||
// #given
|
||||
const resource = "https://mcp.example.com"
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resource).toString()
|
||||
const authServer = "https://auth.example.com"
|
||||
const asUrl = new URL("/.well-known/oauth-authorization-server", authServer).toString()
|
||||
const calls: string[] = []
|
||||
const fetchMock = async (input: string | URL) => {
|
||||
const url = typeof input === "string" ? input : input.toString()
|
||||
calls.push(url)
|
||||
if (url === prmUrl) {
|
||||
return new Response(JSON.stringify({ authorization_servers: [authServer] }), { status: 200 })
|
||||
}
|
||||
if (url === asUrl) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
authorization_endpoint: "https://auth.example.com/authorize",
|
||||
token_endpoint: "https://auth.example.com/token",
|
||||
}),
|
||||
{ status: 200 }
|
||||
)
|
||||
}
|
||||
return new Response("not found", { status: 404 })
|
||||
}
|
||||
Object.defineProperty(globalThis, "fetch", { value: fetchMock, configurable: true })
|
||||
|
||||
// #when
|
||||
return discoverOAuthServerMetadata(resource)
|
||||
.then(() => discoverOAuthServerMetadata(resource))
|
||||
.then(() => {
|
||||
// #then
|
||||
expect(calls).toEqual([prmUrl, asUrl])
|
||||
})
|
||||
})
|
||||
})
|
||||
123
src/features/mcp-oauth/discovery.ts
Normal file
123
src/features/mcp-oauth/discovery.ts
Normal file
@ -0,0 +1,123 @@
|
||||
export interface OAuthServerMetadata {
|
||||
authorizationEndpoint: string
|
||||
tokenEndpoint: string
|
||||
registrationEndpoint?: string
|
||||
resource: string
|
||||
}
|
||||
|
||||
const discoveryCache = new Map<string, OAuthServerMetadata>()
|
||||
const pendingDiscovery = new Map<string, Promise<OAuthServerMetadata>>()
|
||||
|
||||
function parseHttpsUrl(value: string, label: string): URL {
|
||||
const parsed = new URL(value)
|
||||
if (parsed.protocol !== "https:") {
|
||||
throw new Error(`${label} must use https`)
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
function readStringField(source: Record<string, unknown>, field: string): string {
|
||||
const value = source[field]
|
||||
if (typeof value !== "string" || value.length === 0) {
|
||||
throw new Error(`OAuth metadata missing ${field}`)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
async function fetchMetadata(url: string): Promise<{ ok: true; json: Record<string, unknown> } | { ok: false; status: number }> {
|
||||
const response = await fetch(url, { headers: { accept: "application/json" } })
|
||||
if (!response.ok) {
|
||||
return { ok: false, status: response.status }
|
||||
}
|
||||
const json = (await response.json().catch(() => null)) as Record<string, unknown> | null
|
||||
if (!json || typeof json !== "object") {
|
||||
throw new Error("OAuth metadata response is not valid JSON")
|
||||
}
|
||||
return { ok: true, json }
|
||||
}
|
||||
|
||||
async function fetchAuthorizationServerMetadata(issuer: string, resource: string): Promise<OAuthServerMetadata> {
|
||||
const issuerUrl = parseHttpsUrl(issuer, "Authorization server URL")
|
||||
const issuerPath = issuerUrl.pathname.replace(/\/+$/, "")
|
||||
const metadataUrl = new URL(`/.well-known/oauth-authorization-server${issuerPath}`, issuerUrl).toString()
|
||||
const metadata = await fetchMetadata(metadataUrl)
|
||||
|
||||
if (!metadata.ok) {
|
||||
if (metadata.status === 404) {
|
||||
throw new Error("OAuth authorization server metadata not found")
|
||||
}
|
||||
throw new Error(`OAuth authorization server metadata fetch failed (${metadata.status})`)
|
||||
}
|
||||
|
||||
const authorizationEndpoint = parseHttpsUrl(
|
||||
readStringField(metadata.json, "authorization_endpoint"),
|
||||
"authorization_endpoint"
|
||||
).toString()
|
||||
const tokenEndpoint = parseHttpsUrl(
|
||||
readStringField(metadata.json, "token_endpoint"),
|
||||
"token_endpoint"
|
||||
).toString()
|
||||
const registrationEndpointValue = metadata.json.registration_endpoint
|
||||
const registrationEndpoint =
|
||||
typeof registrationEndpointValue === "string" && registrationEndpointValue.length > 0
|
||||
? parseHttpsUrl(registrationEndpointValue, "registration_endpoint").toString()
|
||||
: undefined
|
||||
|
||||
return {
|
||||
authorizationEndpoint,
|
||||
tokenEndpoint,
|
||||
registrationEndpoint,
|
||||
resource,
|
||||
}
|
||||
}
|
||||
|
||||
function parseAuthorizationServers(metadata: Record<string, unknown>): string[] {
|
||||
const servers = metadata.authorization_servers
|
||||
if (!Array.isArray(servers)) return []
|
||||
return servers.filter((server): server is string => typeof server === "string" && server.length > 0)
|
||||
}
|
||||
|
||||
export async function discoverOAuthServerMetadata(resource: string): Promise<OAuthServerMetadata> {
|
||||
const resourceUrl = parseHttpsUrl(resource, "Resource server URL")
|
||||
const resourceKey = resourceUrl.toString()
|
||||
|
||||
const cached = discoveryCache.get(resourceKey)
|
||||
if (cached) return cached
|
||||
|
||||
const pending = pendingDiscovery.get(resourceKey)
|
||||
if (pending) return pending
|
||||
|
||||
const discoveryPromise = (async () => {
|
||||
const prmUrl = new URL("/.well-known/oauth-protected-resource", resourceUrl).toString()
|
||||
const prmResponse = await fetchMetadata(prmUrl)
|
||||
|
||||
if (prmResponse.ok) {
|
||||
const authServers = parseAuthorizationServers(prmResponse.json)
|
||||
if (authServers.length === 0) {
|
||||
throw new Error("OAuth protected resource metadata missing authorization_servers")
|
||||
}
|
||||
return fetchAuthorizationServerMetadata(authServers[0], resource)
|
||||
}
|
||||
|
||||
if (prmResponse.status !== 404) {
|
||||
throw new Error(`OAuth protected resource metadata fetch failed (${prmResponse.status})`)
|
||||
}
|
||||
|
||||
return fetchAuthorizationServerMetadata(resourceKey, resource)
|
||||
})()
|
||||
|
||||
pendingDiscovery.set(resourceKey, discoveryPromise)
|
||||
|
||||
try {
|
||||
const result = await discoveryPromise
|
||||
discoveryCache.set(resourceKey, result)
|
||||
return result
|
||||
} finally {
|
||||
pendingDiscovery.delete(resourceKey)
|
||||
}
|
||||
}
|
||||
|
||||
export function resetDiscoveryCache(): void {
|
||||
discoveryCache.clear()
|
||||
pendingDiscovery.clear()
|
||||
}
|
||||
1
src/features/mcp-oauth/index.ts
Normal file
1
src/features/mcp-oauth/index.ts
Normal file
@ -0,0 +1 @@
|
||||
export * from "./schema"
|
||||
223
src/features/mcp-oauth/provider.test.ts
Normal file
223
src/features/mcp-oauth/provider.test.ts
Normal file
@ -0,0 +1,223 @@
|
||||
import { describe, expect, it, beforeEach, afterEach, mock } from "bun:test"
|
||||
import { createHash, randomBytes } from "node:crypto"
|
||||
import { McpOAuthProvider, generateCodeVerifier, generateCodeChallenge, buildAuthorizationUrl } from "./provider"
|
||||
import type { OAuthTokenData } from "./storage"
|
||||
|
||||
describe("McpOAuthProvider", () => {
|
||||
describe("generateCodeVerifier", () => {
|
||||
it("returns a base64url-encoded 32-byte random string", () => {
|
||||
//#given
|
||||
const verifier = generateCodeVerifier()
|
||||
|
||||
//#when
|
||||
const decoded = Buffer.from(verifier, "base64url")
|
||||
|
||||
//#then
|
||||
expect(decoded.length).toBe(32)
|
||||
expect(verifier).toMatch(/^[A-Za-z0-9_-]+$/)
|
||||
})
|
||||
|
||||
it("produces unique values on each call", () => {
|
||||
//#given
|
||||
const first = generateCodeVerifier()
|
||||
|
||||
//#when
|
||||
const second = generateCodeVerifier()
|
||||
|
||||
//#then
|
||||
expect(first).not.toBe(second)
|
||||
})
|
||||
})
|
||||
|
||||
describe("generateCodeChallenge", () => {
|
||||
it("returns SHA256 base64url digest of the verifier", () => {
|
||||
//#given
|
||||
const verifier = "test-verifier-value"
|
||||
const expected = createHash("sha256").update(verifier).digest("base64url")
|
||||
|
||||
//#when
|
||||
const challenge = generateCodeChallenge(verifier)
|
||||
|
||||
//#then
|
||||
expect(challenge).toBe(expected)
|
||||
})
|
||||
})
|
||||
|
||||
describe("buildAuthorizationUrl", () => {
|
||||
it("builds URL with all required PKCE parameters", () => {
|
||||
//#given
|
||||
const endpoint = "https://auth.example.com/authorize"
|
||||
|
||||
//#when
|
||||
const url = buildAuthorizationUrl(endpoint, {
|
||||
clientId: "my-client",
|
||||
redirectUri: "http://127.0.0.1:8912/callback",
|
||||
codeChallenge: "challenge-value",
|
||||
state: "state-value",
|
||||
scopes: ["openid", "profile"],
|
||||
resource: "https://mcp.example.com",
|
||||
})
|
||||
|
||||
//#then
|
||||
const parsed = new URL(url)
|
||||
expect(parsed.origin + parsed.pathname).toBe("https://auth.example.com/authorize")
|
||||
expect(parsed.searchParams.get("response_type")).toBe("code")
|
||||
expect(parsed.searchParams.get("client_id")).toBe("my-client")
|
||||
expect(parsed.searchParams.get("redirect_uri")).toBe("http://127.0.0.1:8912/callback")
|
||||
expect(parsed.searchParams.get("code_challenge")).toBe("challenge-value")
|
||||
expect(parsed.searchParams.get("code_challenge_method")).toBe("S256")
|
||||
expect(parsed.searchParams.get("state")).toBe("state-value")
|
||||
expect(parsed.searchParams.get("scope")).toBe("openid profile")
|
||||
expect(parsed.searchParams.get("resource")).toBe("https://mcp.example.com")
|
||||
})
|
||||
|
||||
it("omits scope when empty", () => {
|
||||
//#given
|
||||
const endpoint = "https://auth.example.com/authorize"
|
||||
|
||||
//#when
|
||||
const url = buildAuthorizationUrl(endpoint, {
|
||||
clientId: "my-client",
|
||||
redirectUri: "http://127.0.0.1:8912/callback",
|
||||
codeChallenge: "challenge-value",
|
||||
state: "state-value",
|
||||
scopes: [],
|
||||
})
|
||||
|
||||
//#then
|
||||
const parsed = new URL(url)
|
||||
expect(parsed.searchParams.has("scope")).toBe(false)
|
||||
})
|
||||
|
||||
it("omits resource when undefined", () => {
|
||||
//#given
|
||||
const endpoint = "https://auth.example.com/authorize"
|
||||
|
||||
//#when
|
||||
const url = buildAuthorizationUrl(endpoint, {
|
||||
clientId: "my-client",
|
||||
redirectUri: "http://127.0.0.1:8912/callback",
|
||||
codeChallenge: "challenge-value",
|
||||
state: "state-value",
|
||||
})
|
||||
|
||||
//#then
|
||||
const parsed = new URL(url)
|
||||
expect(parsed.searchParams.has("resource")).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe("constructor and basic methods", () => {
|
||||
it("stores serverUrl and optional clientId and scopes", () => {
|
||||
//#given
|
||||
const options = {
|
||||
serverUrl: "https://mcp.example.com",
|
||||
clientId: "my-client",
|
||||
scopes: ["openid"],
|
||||
}
|
||||
|
||||
//#when
|
||||
const provider = new McpOAuthProvider(options)
|
||||
|
||||
//#then
|
||||
expect(provider.tokens()).toBeNull()
|
||||
expect(provider.clientInformation()).toBeNull()
|
||||
expect(provider.codeVerifier()).toBeNull()
|
||||
})
|
||||
|
||||
it("defaults scopes to empty array", () => {
|
||||
//#given
|
||||
const options = { serverUrl: "https://mcp.example.com" }
|
||||
|
||||
//#when
|
||||
const provider = new McpOAuthProvider(options)
|
||||
|
||||
//#then
|
||||
expect(provider.redirectUrl()).toBe("http://127.0.0.1:19877/callback")
|
||||
})
|
||||
})
|
||||
|
||||
describe("saveCodeVerifier / codeVerifier", () => {
|
||||
it("stores and retrieves code verifier", () => {
|
||||
//#given
|
||||
const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" })
|
||||
|
||||
//#when
|
||||
provider.saveCodeVerifier("my-verifier")
|
||||
|
||||
//#then
|
||||
expect(provider.codeVerifier()).toBe("my-verifier")
|
||||
})
|
||||
})
|
||||
|
||||
describe("saveTokens / tokens", () => {
|
||||
let originalEnv: string | undefined
|
||||
|
||||
beforeEach(() => {
|
||||
originalEnv = process.env.OPENCODE_CONFIG_DIR
|
||||
const { mkdirSync } = require("node:fs")
|
||||
const { tmpdir } = require("node:os")
|
||||
const { join } = require("node:path")
|
||||
const testDir = join(tmpdir(), "mcp-oauth-provider-test-" + Date.now())
|
||||
mkdirSync(testDir, { recursive: true })
|
||||
process.env.OPENCODE_CONFIG_DIR = testDir
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
if (originalEnv === undefined) {
|
||||
delete process.env.OPENCODE_CONFIG_DIR
|
||||
} else {
|
||||
process.env.OPENCODE_CONFIG_DIR = originalEnv
|
||||
}
|
||||
})
|
||||
|
||||
it("persists and loads token data via storage", () => {
|
||||
//#given
|
||||
const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" })
|
||||
const tokenData: OAuthTokenData = {
|
||||
accessToken: "access-token-123",
|
||||
refreshToken: "refresh-token-456",
|
||||
expiresAt: 1710000000,
|
||||
}
|
||||
|
||||
//#when
|
||||
const saved = provider.saveTokens(tokenData)
|
||||
const loaded = provider.tokens()
|
||||
|
||||
//#then
|
||||
expect(saved).toBe(true)
|
||||
expect(loaded).toEqual(tokenData)
|
||||
})
|
||||
})
|
||||
|
||||
describe("redirectToAuthorization", () => {
|
||||
it("throws when no client information is set", async () => {
|
||||
//#given
|
||||
const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" })
|
||||
const metadata = {
|
||||
authorizationEndpoint: "https://auth.example.com/authorize",
|
||||
tokenEndpoint: "https://auth.example.com/token",
|
||||
resource: "https://mcp.example.com",
|
||||
}
|
||||
|
||||
//#when
|
||||
const result = provider.redirectToAuthorization(metadata)
|
||||
|
||||
//#then
|
||||
await expect(result).rejects.toThrow("No client information available")
|
||||
})
|
||||
})
|
||||
|
||||
describe("redirectUrl", () => {
|
||||
it("returns localhost callback URL with default port", () => {
|
||||
//#given
|
||||
const provider = new McpOAuthProvider({ serverUrl: "https://mcp.example.com" })
|
||||
|
||||
//#when
|
||||
const url = provider.redirectUrl()
|
||||
|
||||
//#then
|
||||
expect(url).toBe("http://127.0.0.1:19877/callback")
|
||||
})
|
||||
})
|
||||
})
|
||||
295
src/features/mcp-oauth/provider.ts
Normal file
295
src/features/mcp-oauth/provider.ts
Normal file
@ -0,0 +1,295 @@
|
||||
import { createHash, randomBytes } from "node:crypto"
|
||||
import { createServer } from "node:http"
|
||||
import { spawn } from "node:child_process"
|
||||
import type { OAuthTokenData } from "./storage"
|
||||
import { loadToken, saveToken } from "./storage"
|
||||
import { discoverOAuthServerMetadata } from "./discovery"
|
||||
import type { OAuthServerMetadata } from "./discovery"
|
||||
import { getOrRegisterClient } from "./dcr"
|
||||
import type { ClientCredentials, ClientRegistrationStorage } from "./dcr"
|
||||
import { findAvailablePort } from "./callback-server"
|
||||
|
||||
export type McpOAuthProviderOptions = {
|
||||
serverUrl: string
|
||||
clientId?: string
|
||||
scopes?: string[]
|
||||
}
|
||||
|
||||
type CallbackResult = {
|
||||
code: string
|
||||
state: string
|
||||
}
|
||||
|
||||
function generateCodeVerifier(): string {
|
||||
return randomBytes(32).toString("base64url")
|
||||
}
|
||||
|
||||
function generateCodeChallenge(verifier: string): string {
|
||||
return createHash("sha256").update(verifier).digest("base64url")
|
||||
}
|
||||
|
||||
function buildAuthorizationUrl(
|
||||
authorizationEndpoint: string,
|
||||
options: {
|
||||
clientId: string
|
||||
redirectUri: string
|
||||
codeChallenge: string
|
||||
state: string
|
||||
scopes?: string[]
|
||||
resource?: string
|
||||
}
|
||||
): string {
|
||||
const url = new URL(authorizationEndpoint)
|
||||
url.searchParams.set("response_type", "code")
|
||||
url.searchParams.set("client_id", options.clientId)
|
||||
url.searchParams.set("redirect_uri", options.redirectUri)
|
||||
url.searchParams.set("code_challenge", options.codeChallenge)
|
||||
url.searchParams.set("code_challenge_method", "S256")
|
||||
url.searchParams.set("state", options.state)
|
||||
if (options.scopes && options.scopes.length > 0) {
|
||||
url.searchParams.set("scope", options.scopes.join(" "))
|
||||
}
|
||||
if (options.resource) {
|
||||
url.searchParams.set("resource", options.resource)
|
||||
}
|
||||
return url.toString()
|
||||
}
|
||||
|
||||
const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000
|
||||
|
||||
function startCallbackServer(port: number): Promise<CallbackResult> {
|
||||
return new Promise((resolve, reject) => {
|
||||
let timeoutId: ReturnType<typeof setTimeout>
|
||||
|
||||
const server = createServer((request, response) => {
|
||||
clearTimeout(timeoutId)
|
||||
|
||||
const requestUrl = new URL(request.url ?? "/", `http://localhost:${port}`)
|
||||
const code = requestUrl.searchParams.get("code")
|
||||
const state = requestUrl.searchParams.get("state")
|
||||
const error = requestUrl.searchParams.get("error")
|
||||
|
||||
if (error) {
|
||||
const errorDescription = requestUrl.searchParams.get("error_description") ?? error
|
||||
response.writeHead(400, { "content-type": "text/html" })
|
||||
response.end("<html><body><h1>Authorization failed</h1></body></html>")
|
||||
server.close()
|
||||
reject(new Error(`OAuth authorization error: ${errorDescription}`))
|
||||
return
|
||||
}
|
||||
|
||||
if (!code || !state) {
|
||||
response.writeHead(400, { "content-type": "text/html" })
|
||||
response.end("<html><body><h1>Missing code or state</h1></body></html>")
|
||||
server.close()
|
||||
reject(new Error("OAuth callback missing code or state parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
response.writeHead(200, { "content-type": "text/html" })
|
||||
response.end("<html><body><h1>Authorization successful. You can close this tab.</h1></body></html>")
|
||||
server.close()
|
||||
resolve({ code, state })
|
||||
})
|
||||
|
||||
timeoutId = setTimeout(() => {
|
||||
server.close()
|
||||
reject(new Error("OAuth callback timed out after 5 minutes"))
|
||||
}, CALLBACK_TIMEOUT_MS)
|
||||
|
||||
server.listen(port, "127.0.0.1")
|
||||
server.on("error", (err) => {
|
||||
clearTimeout(timeoutId)
|
||||
reject(err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
function openBrowser(url: string): void {
|
||||
const platform = process.platform
|
||||
let cmd: string
|
||||
let args: string[]
|
||||
|
||||
if (platform === "darwin") {
|
||||
cmd = "open"
|
||||
args = [url]
|
||||
} else if (platform === "win32") {
|
||||
cmd = "explorer"
|
||||
args = [url]
|
||||
} else {
|
||||
cmd = "xdg-open"
|
||||
args = [url]
|
||||
}
|
||||
|
||||
try {
|
||||
const child = spawn(cmd, args, { stdio: "ignore", detached: true })
|
||||
child.on("error", () => {})
|
||||
child.unref()
|
||||
} catch {
|
||||
// Browser open failed — user must navigate manually
|
||||
}
|
||||
}
|
||||
|
||||
export class McpOAuthProvider {
|
||||
private readonly serverUrl: string
|
||||
private readonly configClientId: string | undefined
|
||||
private readonly scopes: string[]
|
||||
private storedCodeVerifier: string | null = null
|
||||
private storedClientInfo: ClientCredentials | null = null
|
||||
private callbackPort: number | null = null
|
||||
|
||||
constructor(options: McpOAuthProviderOptions) {
|
||||
this.serverUrl = options.serverUrl
|
||||
this.configClientId = options.clientId
|
||||
this.scopes = options.scopes ?? []
|
||||
}
|
||||
|
||||
tokens(): OAuthTokenData | null {
|
||||
return loadToken(this.serverUrl, this.serverUrl)
|
||||
}
|
||||
|
||||
saveTokens(tokenData: OAuthTokenData): boolean {
|
||||
return saveToken(this.serverUrl, this.serverUrl, tokenData)
|
||||
}
|
||||
|
||||
clientInformation(): ClientCredentials | null {
|
||||
if (this.storedClientInfo) return this.storedClientInfo
|
||||
const tokenData = this.tokens()
|
||||
if (tokenData?.clientInfo) {
|
||||
this.storedClientInfo = tokenData.clientInfo
|
||||
return this.storedClientInfo
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
redirectUrl(): string {
|
||||
return `http://127.0.0.1:${this.callbackPort ?? 19877}/callback`
|
||||
}
|
||||
|
||||
saveCodeVerifier(verifier: string): void {
|
||||
this.storedCodeVerifier = verifier
|
||||
}
|
||||
|
||||
codeVerifier(): string | null {
|
||||
return this.storedCodeVerifier
|
||||
}
|
||||
|
||||
async redirectToAuthorization(metadata: OAuthServerMetadata): Promise<CallbackResult> {
|
||||
const verifier = generateCodeVerifier()
|
||||
this.saveCodeVerifier(verifier)
|
||||
const challenge = generateCodeChallenge(verifier)
|
||||
const state = randomBytes(16).toString("hex")
|
||||
|
||||
const clientInfo = this.clientInformation()
|
||||
if (!clientInfo) {
|
||||
throw new Error("No client information available. Run login() or register a client first.")
|
||||
}
|
||||
|
||||
if (this.callbackPort === null) {
|
||||
this.callbackPort = await findAvailablePort()
|
||||
}
|
||||
|
||||
const authUrl = buildAuthorizationUrl(metadata.authorizationEndpoint, {
|
||||
clientId: clientInfo.clientId,
|
||||
redirectUri: this.redirectUrl(),
|
||||
codeChallenge: challenge,
|
||||
state,
|
||||
scopes: this.scopes,
|
||||
resource: metadata.resource,
|
||||
})
|
||||
|
||||
const callbackPromise = startCallbackServer(this.callbackPort)
|
||||
openBrowser(authUrl)
|
||||
|
||||
const result = await callbackPromise
|
||||
if (result.state !== state) {
|
||||
throw new Error("OAuth state mismatch")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
async login(): Promise<OAuthTokenData> {
|
||||
const metadata = await discoverOAuthServerMetadata(this.serverUrl)
|
||||
|
||||
const clientRegistrationStorage: ClientRegistrationStorage = {
|
||||
getClientRegistration: () => this.storedClientInfo,
|
||||
setClientRegistration: (_serverIdentifier: string, credentials: ClientCredentials) => {
|
||||
this.storedClientInfo = credentials
|
||||
},
|
||||
}
|
||||
|
||||
const clientInfo = await getOrRegisterClient({
|
||||
registrationEndpoint: metadata.registrationEndpoint,
|
||||
serverIdentifier: this.serverUrl,
|
||||
clientName: "oh-my-opencode",
|
||||
redirectUris: [this.redirectUrl()],
|
||||
tokenEndpointAuthMethod: "none",
|
||||
clientId: this.configClientId,
|
||||
storage: clientRegistrationStorage,
|
||||
})
|
||||
|
||||
if (!clientInfo) {
|
||||
throw new Error("Failed to obtain client credentials. Provide a clientId or ensure the server supports DCR.")
|
||||
}
|
||||
|
||||
this.storedClientInfo = clientInfo
|
||||
|
||||
const { code } = await this.redirectToAuthorization(metadata)
|
||||
const verifier = this.codeVerifier()
|
||||
if (!verifier) {
|
||||
throw new Error("Code verifier not found")
|
||||
}
|
||||
|
||||
const tokenResponse = await fetch(metadata.tokenEndpoint, {
|
||||
method: "POST",
|
||||
headers: { "content-type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "authorization_code",
|
||||
code,
|
||||
redirect_uri: this.redirectUrl(),
|
||||
client_id: clientInfo.clientId,
|
||||
code_verifier: verifier,
|
||||
...(metadata.resource ? { resource: metadata.resource } : {}),
|
||||
}).toString(),
|
||||
})
|
||||
|
||||
if (!tokenResponse.ok) {
|
||||
let errorDetail = `${tokenResponse.status}`
|
||||
try {
|
||||
const body = (await tokenResponse.json()) as Record<string, unknown>
|
||||
if (body.error) {
|
||||
errorDetail = `${tokenResponse.status} ${body.error}`
|
||||
if (body.error_description) {
|
||||
errorDetail += `: ${body.error_description}`
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Response body not JSON
|
||||
}
|
||||
throw new Error(`Token exchange failed: ${errorDetail}`)
|
||||
}
|
||||
|
||||
const tokenData = (await tokenResponse.json()) as Record<string, unknown>
|
||||
const accessToken = tokenData.access_token
|
||||
if (typeof accessToken !== "string") {
|
||||
throw new Error("Token response missing access_token")
|
||||
}
|
||||
|
||||
const oauthTokenData: OAuthTokenData = {
|
||||
accessToken,
|
||||
refreshToken: typeof tokenData.refresh_token === "string" ? tokenData.refresh_token : undefined,
|
||||
expiresAt:
|
||||
typeof tokenData.expires_in === "number" ? Math.floor(Date.now() / 1000) + tokenData.expires_in : undefined,
|
||||
clientInfo: {
|
||||
clientId: clientInfo.clientId,
|
||||
clientSecret: clientInfo.clientSecret,
|
||||
},
|
||||
}
|
||||
|
||||
this.saveTokens(oauthTokenData)
|
||||
return oauthTokenData
|
||||
}
|
||||
}
|
||||
|
||||
export { generateCodeVerifier, generateCodeChallenge, buildAuthorizationUrl, startCallbackServer }
|
||||
121
src/features/mcp-oauth/resource-indicator.test.ts
Normal file
121
src/features/mcp-oauth/resource-indicator.test.ts
Normal file
@ -0,0 +1,121 @@
|
||||
import { describe, expect, it } from "bun:test"
|
||||
import { addResourceToParams, getResourceIndicator } from "./resource-indicator"
|
||||
|
||||
describe("getResourceIndicator", () => {
|
||||
it("returns URL unchanged when already normalized", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com")
|
||||
})
|
||||
|
||||
it("strips trailing slash", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com")
|
||||
})
|
||||
|
||||
it("strips query parameters", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/v1?token=abc&debug=true"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com/v1")
|
||||
})
|
||||
|
||||
it("strips fragment", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/v1#section"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com/v1")
|
||||
})
|
||||
|
||||
it("strips query and trailing slash together", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/api/?key=val"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com/api")
|
||||
})
|
||||
|
||||
it("preserves path segments", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com/org/project/v2"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com/org/project/v2")
|
||||
})
|
||||
|
||||
it("preserves port number", () => {
|
||||
// #given
|
||||
const url = "https://mcp.example.com:8443/api/"
|
||||
|
||||
// #when
|
||||
const result = getResourceIndicator(url)
|
||||
|
||||
// #then
|
||||
expect(result).toBe("https://mcp.example.com:8443/api")
|
||||
})
|
||||
})
|
||||
|
||||
describe("addResourceToParams", () => {
|
||||
it("sets resource parameter on empty params", () => {
|
||||
// #given
|
||||
const params = new URLSearchParams()
|
||||
const resource = "https://mcp.example.com"
|
||||
|
||||
// #when
|
||||
addResourceToParams(params, resource)
|
||||
|
||||
// #then
|
||||
expect(params.get("resource")).toBe("https://mcp.example.com")
|
||||
})
|
||||
|
||||
it("adds resource alongside existing parameters", () => {
|
||||
// #given
|
||||
const params = new URLSearchParams({ grant_type: "authorization_code" })
|
||||
const resource = "https://mcp.example.com/v1"
|
||||
|
||||
// #when
|
||||
addResourceToParams(params, resource)
|
||||
|
||||
// #then
|
||||
expect(params.get("grant_type")).toBe("authorization_code")
|
||||
expect(params.get("resource")).toBe("https://mcp.example.com/v1")
|
||||
})
|
||||
|
||||
it("overwrites existing resource parameter", () => {
|
||||
// #given
|
||||
const params = new URLSearchParams({ resource: "https://old.example.com" })
|
||||
const resource = "https://new.example.com"
|
||||
|
||||
// #when
|
||||
addResourceToParams(params, resource)
|
||||
|
||||
// #then
|
||||
expect(params.get("resource")).toBe("https://new.example.com")
|
||||
expect(params.getAll("resource")).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
16
src/features/mcp-oauth/resource-indicator.ts
Normal file
16
src/features/mcp-oauth/resource-indicator.ts
Normal file
@ -0,0 +1,16 @@
|
||||
export function getResourceIndicator(url: string): string {
|
||||
const parsed = new URL(url)
|
||||
parsed.search = ""
|
||||
parsed.hash = ""
|
||||
|
||||
let normalized = parsed.toString()
|
||||
if (normalized.endsWith("/")) {
|
||||
normalized = normalized.slice(0, -1)
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
export function addResourceToParams(params: URLSearchParams, resource: string): void {
|
||||
params.set("resource", resource)
|
||||
}
|
||||
60
src/features/mcp-oauth/schema.test.ts
Normal file
60
src/features/mcp-oauth/schema.test.ts
Normal file
@ -0,0 +1,60 @@
|
||||
/// <reference types="bun-types" />
|
||||
import { describe, expect, test } from "bun:test"
|
||||
import { McpOauthSchema } from "./schema"
|
||||
|
||||
describe("McpOauthSchema", () => {
|
||||
test("parses empty oauth config", () => {
|
||||
//#given
|
||||
const input = {}
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.parse(input)
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
test("parses oauth config with clientId", () => {
|
||||
//#given
|
||||
const input = { clientId: "client-123" }
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.parse(input)
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({ clientId: "client-123" })
|
||||
})
|
||||
|
||||
test("parses oauth config with scopes", () => {
|
||||
//#given
|
||||
const input = { scopes: ["openid", "profile"] }
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.parse(input)
|
||||
|
||||
//#then
|
||||
expect(result).toEqual({ scopes: ["openid", "profile"] })
|
||||
})
|
||||
|
||||
test("rejects non-string clientId", () => {
|
||||
//#given
|
||||
const input = { clientId: 123 }
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.safeParse(input)
|
||||
|
||||
//#then
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
test("rejects non-string scopes", () => {
|
||||
//#given
|
||||
const input = { scopes: ["openid", 42] }
|
||||
|
||||
//#when
|
||||
const result = McpOauthSchema.safeParse(input)
|
||||
|
||||
//#then
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
})
|
||||
8
src/features/mcp-oauth/schema.ts
Normal file
8
src/features/mcp-oauth/schema.ts
Normal file
@ -0,0 +1,8 @@
|
||||
import { z } from "zod"
|
||||
|
||||
export const McpOauthSchema = z.object({
|
||||
clientId: z.string().optional(),
|
||||
scopes: z.array(z.string()).optional(),
|
||||
})
|
||||
|
||||
export type McpOauth = z.infer<typeof McpOauthSchema>
|
||||
223
src/features/mcp-oauth/step-up.test.ts
Normal file
223
src/features/mcp-oauth/step-up.test.ts
Normal file
@ -0,0 +1,223 @@
|
||||
import { describe, expect, it } from "bun:test"
|
||||
import { isStepUpRequired, mergeScopes, parseWwwAuthenticate } from "./step-up"
|
||||
|
||||
describe("parseWwwAuthenticate", () => {
|
||||
it("parses scope from simple Bearer header", () => {
|
||||
// #given
|
||||
const header = 'Bearer scope="read write"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["read", "write"] })
|
||||
})
|
||||
|
||||
it("parses scope with error fields", () => {
|
||||
// #given
|
||||
const header = 'Bearer error="insufficient_scope", scope="admin"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
requiredScopes: ["admin"],
|
||||
error: "insufficient_scope",
|
||||
})
|
||||
})
|
||||
|
||||
it("parses all fields including error_description", () => {
|
||||
// #given
|
||||
const header =
|
||||
'Bearer realm="example", error="insufficient_scope", error_description="Need admin access", scope="admin write"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({
|
||||
requiredScopes: ["admin", "write"],
|
||||
error: "insufficient_scope",
|
||||
errorDescription: "Need admin access",
|
||||
})
|
||||
})
|
||||
|
||||
it("returns null for non-Bearer scheme", () => {
|
||||
// #given
|
||||
const header = 'Basic realm="example"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("returns null when no scope parameter present", () => {
|
||||
// #given
|
||||
const header = 'Bearer error="invalid_token"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("returns null for empty scope value", () => {
|
||||
// #given
|
||||
const header = 'Bearer scope=""'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("returns null for bare Bearer with no params", () => {
|
||||
// #given
|
||||
const header = "Bearer"
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("handles case-insensitive Bearer prefix", () => {
|
||||
// #given
|
||||
const header = 'bearer scope="read"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["read"] })
|
||||
})
|
||||
|
||||
it("parses single scope value", () => {
|
||||
// #given
|
||||
const header = 'Bearer scope="admin"'
|
||||
|
||||
// #when
|
||||
const result = parseWwwAuthenticate(header)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["admin"] })
|
||||
})
|
||||
})
|
||||
|
||||
describe("mergeScopes", () => {
|
||||
it("merges new scopes into existing", () => {
|
||||
// #given
|
||||
const existing = ["read", "write"]
|
||||
const required = ["admin", "write"]
|
||||
|
||||
// #when
|
||||
const result = mergeScopes(existing, required)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual(["read", "write", "admin"])
|
||||
})
|
||||
|
||||
it("returns required when existing is empty", () => {
|
||||
// #given
|
||||
const existing: string[] = []
|
||||
const required = ["read", "write"]
|
||||
|
||||
// #when
|
||||
const result = mergeScopes(existing, required)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual(["read", "write"])
|
||||
})
|
||||
|
||||
it("returns existing when required is empty", () => {
|
||||
// #given
|
||||
const existing = ["read"]
|
||||
const required: string[] = []
|
||||
|
||||
// #when
|
||||
const result = mergeScopes(existing, required)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual(["read"])
|
||||
})
|
||||
|
||||
it("deduplicates identical scopes", () => {
|
||||
// #given
|
||||
const existing = ["read", "write"]
|
||||
const required = ["read", "write"]
|
||||
|
||||
// #when
|
||||
const result = mergeScopes(existing, required)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual(["read", "write"])
|
||||
})
|
||||
})
|
||||
|
||||
describe("isStepUpRequired", () => {
|
||||
it("returns step-up info for 403 with WWW-Authenticate", () => {
|
||||
// #given
|
||||
const statusCode = 403
|
||||
const headers = { "www-authenticate": 'Bearer scope="admin"' }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["admin"] })
|
||||
})
|
||||
|
||||
it("returns null for non-403 status", () => {
|
||||
// #given
|
||||
const statusCode = 401
|
||||
const headers = { "www-authenticate": 'Bearer scope="admin"' }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("returns null when no WWW-Authenticate header", () => {
|
||||
// #given
|
||||
const statusCode = 403
|
||||
const headers = { "content-type": "application/json" }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
|
||||
it("handles capitalized WWW-Authenticate header", () => {
|
||||
// #given
|
||||
const statusCode = 403
|
||||
const headers = { "WWW-Authenticate": 'Bearer scope="read write"' }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toEqual({ requiredScopes: ["read", "write"] })
|
||||
})
|
||||
|
||||
it("returns null for 403 with unparseable WWW-Authenticate", () => {
|
||||
// #given
|
||||
const statusCode = 403
|
||||
const headers = { "www-authenticate": 'Basic realm="example"' }
|
||||
|
||||
// #when
|
||||
const result = isStepUpRequired(statusCode, headers)
|
||||
|
||||
// #then
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
79
src/features/mcp-oauth/step-up.ts
Normal file
79
src/features/mcp-oauth/step-up.ts
Normal file
@ -0,0 +1,79 @@
|
||||
export interface StepUpInfo {
|
||||
requiredScopes: string[]
|
||||
error?: string
|
||||
errorDescription?: string
|
||||
}
|
||||
|
||||
export function parseWwwAuthenticate(header: string): StepUpInfo | null {
|
||||
const trimmed = header.trim()
|
||||
const lowerHeader = trimmed.toLowerCase()
|
||||
const bearerIndex = lowerHeader.indexOf("bearer")
|
||||
if (bearerIndex === -1) {
|
||||
return null
|
||||
}
|
||||
|
||||
const params = trimmed.slice(bearerIndex + "bearer".length).trim()
|
||||
if (params.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
const scope = extractParam(params, "scope")
|
||||
if (scope === null) {
|
||||
return null
|
||||
}
|
||||
|
||||
const requiredScopes = scope
|
||||
.split(/\s+/)
|
||||
.filter((s) => s.length > 0)
|
||||
|
||||
if (requiredScopes.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
const info: StepUpInfo = { requiredScopes }
|
||||
|
||||
const error = extractParam(params, "error")
|
||||
if (error !== null) {
|
||||
info.error = error
|
||||
}
|
||||
|
||||
const errorDescription = extractParam(params, "error_description")
|
||||
if (errorDescription !== null) {
|
||||
info.errorDescription = errorDescription
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
function extractParam(params: string, name: string): string | null {
|
||||
const quotedPattern = new RegExp(`${name}="([^"]*)"`)
|
||||
const quotedMatch = quotedPattern.exec(params)
|
||||
if (quotedMatch) {
|
||||
return quotedMatch[1]
|
||||
}
|
||||
|
||||
const unquotedPattern = new RegExp(`${name}=([^\\s,]+)`)
|
||||
const unquotedMatch = unquotedPattern.exec(params)
|
||||
return unquotedMatch?.[1] ?? null
|
||||
}
|
||||
|
||||
export function mergeScopes(existing: string[], required: string[]): string[] {
|
||||
const set = new Set(existing)
|
||||
for (const scope of required) {
|
||||
set.add(scope)
|
||||
}
|
||||
return [...set]
|
||||
}
|
||||
|
||||
export function isStepUpRequired(statusCode: number, headers: Record<string, string>): StepUpInfo | null {
|
||||
if (statusCode !== 403) {
|
||||
return null
|
||||
}
|
||||
|
||||
const wwwAuth = headers["www-authenticate"] ?? headers["WWW-Authenticate"]
|
||||
if (!wwwAuth) {
|
||||
return null
|
||||
}
|
||||
|
||||
return parseWwwAuthenticate(wwwAuth)
|
||||
}
|
||||
136
src/features/mcp-oauth/storage.test.ts
Normal file
136
src/features/mcp-oauth/storage.test.ts
Normal file
@ -0,0 +1,136 @@
|
||||
import { describe, expect, test, beforeEach, afterEach } from "bun:test"
|
||||
import { existsSync, mkdirSync, rmSync, readFileSync, statSync, writeFileSync } from "node:fs"
|
||||
import { join } from "node:path"
|
||||
import { tmpdir } from "node:os"
|
||||
import {
|
||||
deleteToken,
|
||||
getMcpOauthStoragePath,
|
||||
listAllTokens,
|
||||
listTokensByHost,
|
||||
loadToken,
|
||||
saveToken,
|
||||
} from "./storage"
|
||||
import type { OAuthTokenData } from "./storage"
|
||||
|
||||
describe("mcp-oauth storage", () => {
|
||||
const TEST_CONFIG_DIR = join(tmpdir(), "mcp-oauth-test-" + Date.now())
|
||||
let originalConfigDir: string | undefined
|
||||
|
||||
beforeEach(() => {
|
||||
originalConfigDir = process.env.OPENCODE_CONFIG_DIR
|
||||
process.env.OPENCODE_CONFIG_DIR = TEST_CONFIG_DIR
|
||||
if (!existsSync(TEST_CONFIG_DIR)) {
|
||||
mkdirSync(TEST_CONFIG_DIR, { recursive: true })
|
||||
}
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
if (originalConfigDir === undefined) {
|
||||
delete process.env.OPENCODE_CONFIG_DIR
|
||||
} else {
|
||||
process.env.OPENCODE_CONFIG_DIR = originalConfigDir
|
||||
}
|
||||
if (existsSync(TEST_CONFIG_DIR)) {
|
||||
rmSync(TEST_CONFIG_DIR, { recursive: true, force: true })
|
||||
}
|
||||
})
|
||||
|
||||
test("should save tokens with {host}/{resource} key and set 0600 permissions", () => {
|
||||
// #given
|
||||
const token: OAuthTokenData = {
|
||||
accessToken: "access-1",
|
||||
refreshToken: "refresh-1",
|
||||
expiresAt: 1710000000,
|
||||
clientInfo: { clientId: "client-1", clientSecret: "secret-1" },
|
||||
}
|
||||
|
||||
// #when
|
||||
const success = saveToken("https://example.com:443", "mcp/v1", token)
|
||||
const storagePath = getMcpOauthStoragePath()
|
||||
const parsed = JSON.parse(readFileSync(storagePath, "utf-8")) as Record<string, OAuthTokenData>
|
||||
const mode = statSync(storagePath).mode & 0o777
|
||||
|
||||
// #then
|
||||
expect(success).toBe(true)
|
||||
expect(Object.keys(parsed)).toEqual(["example.com/mcp/v1"])
|
||||
expect(parsed["example.com/mcp/v1"].accessToken).toBe("access-1")
|
||||
expect(mode).toBe(0o600)
|
||||
})
|
||||
|
||||
test("should load a saved token", () => {
|
||||
// #given
|
||||
const token: OAuthTokenData = { accessToken: "access-2", refreshToken: "refresh-2" }
|
||||
saveToken("api.example.com", "resource-a", token)
|
||||
|
||||
// #when
|
||||
const loaded = loadToken("api.example.com:8443", "resource-a")
|
||||
|
||||
// #then
|
||||
expect(loaded).toEqual(token)
|
||||
})
|
||||
|
||||
test("should delete a token", () => {
|
||||
// #given
|
||||
const token: OAuthTokenData = { accessToken: "access-3" }
|
||||
saveToken("api.example.com", "resource-b", token)
|
||||
|
||||
// #when
|
||||
const success = deleteToken("api.example.com", "resource-b")
|
||||
const loaded = loadToken("api.example.com", "resource-b")
|
||||
|
||||
// #then
|
||||
expect(success).toBe(true)
|
||||
expect(loaded).toBeNull()
|
||||
})
|
||||
|
||||
test("should list tokens by host", () => {
|
||||
// #given
|
||||
saveToken("api.example.com", "resource-a", { accessToken: "access-a" })
|
||||
saveToken("api.example.com", "resource-b", { accessToken: "access-b" })
|
||||
saveToken("other.example.com", "resource-c", { accessToken: "access-c" })
|
||||
|
||||
// #when
|
||||
const entries = listTokensByHost("api.example.com:5555")
|
||||
|
||||
// #then
|
||||
expect(Object.keys(entries).sort()).toEqual([
|
||||
"api.example.com/resource-a",
|
||||
"api.example.com/resource-b",
|
||||
])
|
||||
expect(entries["api.example.com/resource-a"].accessToken).toBe("access-a")
|
||||
})
|
||||
|
||||
test("should handle missing storage file", () => {
|
||||
// #given
|
||||
const storagePath = getMcpOauthStoragePath()
|
||||
if (existsSync(storagePath)) {
|
||||
rmSync(storagePath, { force: true })
|
||||
}
|
||||
|
||||
// #when
|
||||
const loaded = loadToken("api.example.com", "resource-a")
|
||||
const entries = listTokensByHost("api.example.com")
|
||||
|
||||
// #then
|
||||
expect(loaded).toBeNull()
|
||||
expect(entries).toEqual({})
|
||||
})
|
||||
|
||||
test("should handle invalid JSON", () => {
|
||||
// #given
|
||||
const storagePath = getMcpOauthStoragePath()
|
||||
const dir = join(storagePath, "..")
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true })
|
||||
}
|
||||
writeFileSync(storagePath, "{not-valid-json", "utf-8")
|
||||
|
||||
// #when
|
||||
const loaded = loadToken("api.example.com", "resource-a")
|
||||
const entries = listTokensByHost("api.example.com")
|
||||
|
||||
// #then
|
||||
expect(loaded).toBeNull()
|
||||
expect(entries).toEqual({})
|
||||
})
|
||||
})
|
||||
153
src/features/mcp-oauth/storage.ts
Normal file
153
src/features/mcp-oauth/storage.ts
Normal file
@ -0,0 +1,153 @@
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, unlinkSync, writeFileSync } from "node:fs"
|
||||
import { dirname, join } from "node:path"
|
||||
import { getOpenCodeConfigDir } from "../../shared"
|
||||
|
||||
export interface OAuthTokenData {
|
||||
accessToken: string
|
||||
refreshToken?: string
|
||||
expiresAt?: number
|
||||
clientInfo?: {
|
||||
clientId: string
|
||||
clientSecret?: string
|
||||
}
|
||||
}
|
||||
|
||||
type TokenStore = Record<string, OAuthTokenData>
|
||||
|
||||
const STORAGE_FILE_NAME = "mcp-oauth.json"
|
||||
|
||||
export function getMcpOauthStoragePath(): string {
|
||||
return join(getOpenCodeConfigDir({ binary: "opencode" }), STORAGE_FILE_NAME)
|
||||
}
|
||||
|
||||
function normalizeHost(serverHost: string): string {
|
||||
let host = serverHost.trim()
|
||||
if (!host) return host
|
||||
|
||||
if (host.includes("://")) {
|
||||
try {
|
||||
host = new URL(host).hostname
|
||||
} catch {
|
||||
host = host.split("/")[0]
|
||||
}
|
||||
} else {
|
||||
host = host.split("/")[0]
|
||||
}
|
||||
|
||||
if (host.startsWith("[")) {
|
||||
const closing = host.indexOf("]")
|
||||
if (closing !== -1) {
|
||||
host = host.slice(0, closing + 1)
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
if (host.includes(":")) {
|
||||
host = host.split(":")[0]
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
function normalizeResource(resource: string): string {
|
||||
return resource.replace(/^\/+/, "")
|
||||
}
|
||||
|
||||
function buildKey(serverHost: string, resource: string): string {
|
||||
const host = normalizeHost(serverHost)
|
||||
const normalizedResource = normalizeResource(resource)
|
||||
return `${host}/${normalizedResource}`
|
||||
}
|
||||
|
||||
function readStore(): TokenStore | null {
|
||||
const filePath = getMcpOauthStoragePath()
|
||||
if (!existsSync(filePath)) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(filePath, "utf-8")
|
||||
return JSON.parse(content) as TokenStore
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
function writeStore(store: TokenStore): boolean {
|
||||
const filePath = getMcpOauthStoragePath()
|
||||
|
||||
try {
|
||||
const dir = dirname(filePath)
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true })
|
||||
}
|
||||
|
||||
writeFileSync(filePath, JSON.stringify(store, null, 2), { encoding: "utf-8", mode: 0o600 })
|
||||
chmodSync(filePath, 0o600)
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export function loadToken(serverHost: string, resource: string): OAuthTokenData | null {
|
||||
const store = readStore()
|
||||
if (!store) return null
|
||||
|
||||
const key = buildKey(serverHost, resource)
|
||||
return store[key] ?? null
|
||||
}
|
||||
|
||||
export function saveToken(serverHost: string, resource: string, token: OAuthTokenData): boolean {
|
||||
const store = readStore() ?? {}
|
||||
const key = buildKey(serverHost, resource)
|
||||
store[key] = token
|
||||
return writeStore(store)
|
||||
}
|
||||
|
||||
export function deleteToken(serverHost: string, resource: string): boolean {
|
||||
const store = readStore()
|
||||
if (!store) return true
|
||||
|
||||
const key = buildKey(serverHost, resource)
|
||||
if (!(key in store)) {
|
||||
return true
|
||||
}
|
||||
|
||||
delete store[key]
|
||||
|
||||
if (Object.keys(store).length === 0) {
|
||||
try {
|
||||
const filePath = getMcpOauthStoragePath()
|
||||
if (existsSync(filePath)) {
|
||||
unlinkSync(filePath)
|
||||
}
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return writeStore(store)
|
||||
}
|
||||
|
||||
export function listTokensByHost(serverHost: string): TokenStore {
|
||||
const store = readStore()
|
||||
if (!store) return {}
|
||||
|
||||
const host = normalizeHost(serverHost)
|
||||
const prefix = `${host}/`
|
||||
const result: TokenStore = {}
|
||||
|
||||
for (const [key, value] of Object.entries(store)) {
|
||||
if (key.startsWith(prefix)) {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
export function listAllTokens(): TokenStore {
|
||||
return readStore() ?? {}
|
||||
}
|
||||
@ -3,8 +3,6 @@ import { SkillMcpManager } from "./manager"
|
||||
import type { SkillMcpClientInfo, SkillMcpServerContext } from "./types"
|
||||
import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types"
|
||||
|
||||
|
||||
|
||||
// Mock the MCP SDK transports to avoid network calls
|
||||
const mockHttpConnect = mock(() => Promise.reject(new Error("Mocked HTTP connection failure")))
|
||||
const mockHttpClose = mock(() => Promise.resolve())
|
||||
@ -24,6 +22,21 @@ mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
const mockTokens = mock(() => null as { accessToken: string; refreshToken?: string; expiresAt?: number } | null)
|
||||
const mockLogin = mock(() => Promise.resolve({ accessToken: "new-token" }))
|
||||
|
||||
mock.module("../mcp-oauth/provider", () => ({
|
||||
McpOAuthProvider: class MockMcpOAuthProvider {
|
||||
constructor(public options: { serverUrl: string; clientId?: string; scopes?: string[] }) {}
|
||||
tokens() {
|
||||
return mockTokens()
|
||||
}
|
||||
async login() {
|
||||
return mockLogin()
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
|
||||
|
||||
|
||||
@ -518,7 +531,6 @@ describe("SkillMcpManager", () => {
|
||||
skillName: "retry-skill",
|
||||
}
|
||||
|
||||
// Mock client that fails first time with "Not connected", then succeeds
|
||||
let callCount = 0
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
@ -531,7 +543,6 @@ describe("SkillMcpManager", () => {
|
||||
close: mock(() => Promise.resolve()),
|
||||
}
|
||||
|
||||
// Spy on getOrCreateClientWithRetry to inject mock client
|
||||
const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry")
|
||||
getOrCreateSpy.mockResolvedValue(mockClient)
|
||||
|
||||
@ -539,9 +550,9 @@ describe("SkillMcpManager", () => {
|
||||
const result = await manager.callTool(info, context, "test-tool", {})
|
||||
|
||||
// #then
|
||||
expect(callCount).toBe(2) // First call fails, second succeeds
|
||||
expect(callCount).toBe(2)
|
||||
expect(result).toEqual([{ type: "text", text: "success" }])
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(2) // Called twice due to retry
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it("should fail after 3 retry attempts", async () => {
|
||||
@ -558,7 +569,6 @@ describe("SkillMcpManager", () => {
|
||||
skillName: "fail-skill",
|
||||
}
|
||||
|
||||
// Mock client that always fails with "Not connected"
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
throw new Error("Not connected")
|
||||
@ -573,7 +583,7 @@ describe("SkillMcpManager", () => {
|
||||
await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow(
|
||||
/Failed after 3 reconnection attempts/
|
||||
)
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(3) // Initial + 2 retries
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
|
||||
it("should not retry on non-connection errors", async () => {
|
||||
@ -590,7 +600,6 @@ describe("SkillMcpManager", () => {
|
||||
skillName: "error-skill",
|
||||
}
|
||||
|
||||
// Mock client that fails with non-connection error
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
throw new Error("Tool not found")
|
||||
@ -605,7 +614,194 @@ describe("SkillMcpManager", () => {
|
||||
await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow(
|
||||
"Tool not found"
|
||||
)
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(1) // No retry
|
||||
expect(getOrCreateSpy).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe("OAuth integration", () => {
|
||||
beforeEach(() => {
|
||||
mockTokens.mockClear()
|
||||
mockLogin.mockClear()
|
||||
})
|
||||
|
||||
it("injects Authorization header when oauth config has stored tokens", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "oauth-server",
|
||||
skillName: "oauth-skill",
|
||||
sessionID: "session-oauth-1",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
oauth: {
|
||||
clientId: "my-client",
|
||||
scopes: ["read", "write"],
|
||||
},
|
||||
}
|
||||
mockTokens.mockReturnValue({ accessToken: "stored-access-token" })
|
||||
|
||||
// #when
|
||||
try {
|
||||
await manager.getOrCreateClient(info, config)
|
||||
} catch { /* connection fails in test */ }
|
||||
|
||||
// #then
|
||||
const headers = lastTransportInstance.options?.requestInit?.headers as Record<string, string> | undefined
|
||||
expect(headers?.Authorization).toBe("Bearer stored-access-token")
|
||||
})
|
||||
|
||||
it("does not inject Authorization header when no stored tokens exist and login fails", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "oauth-no-token",
|
||||
skillName: "oauth-skill",
|
||||
sessionID: "session-oauth-2",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
oauth: {
|
||||
clientId: "my-client",
|
||||
},
|
||||
}
|
||||
mockTokens.mockReturnValue(null)
|
||||
mockLogin.mockRejectedValue(new Error("Login failed"))
|
||||
|
||||
// #when
|
||||
try {
|
||||
await manager.getOrCreateClient(info, config)
|
||||
} catch { /* connection fails in test */ }
|
||||
|
||||
// #then
|
||||
const headers = lastTransportInstance.options?.requestInit?.headers as Record<string, string> | undefined
|
||||
expect(headers?.Authorization).toBeUndefined()
|
||||
})
|
||||
|
||||
it("preserves existing static headers alongside OAuth token", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "oauth-with-headers",
|
||||
skillName: "oauth-skill",
|
||||
sessionID: "session-oauth-3",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
headers: {
|
||||
"X-Custom": "custom-value",
|
||||
},
|
||||
oauth: {
|
||||
clientId: "my-client",
|
||||
},
|
||||
}
|
||||
mockTokens.mockReturnValue({ accessToken: "oauth-token" })
|
||||
|
||||
// #when
|
||||
try {
|
||||
await manager.getOrCreateClient(info, config)
|
||||
} catch { /* connection fails in test */ }
|
||||
|
||||
// #then
|
||||
const headers = lastTransportInstance.options?.requestInit?.headers as Record<string, string> | undefined
|
||||
expect(headers?.["X-Custom"]).toBe("custom-value")
|
||||
expect(headers?.Authorization).toBe("Bearer oauth-token")
|
||||
})
|
||||
|
||||
it("does not create auth provider when oauth config is absent", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "no-oauth-server",
|
||||
skillName: "test-skill",
|
||||
sessionID: "session-no-oauth",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
headers: {
|
||||
Authorization: "Bearer static-token",
|
||||
},
|
||||
}
|
||||
|
||||
// #when
|
||||
try {
|
||||
await manager.getOrCreateClient(info, config)
|
||||
} catch { /* connection fails in test */ }
|
||||
|
||||
// #then
|
||||
const headers = lastTransportInstance.options?.requestInit?.headers as Record<string, string> | undefined
|
||||
expect(headers?.Authorization).toBe("Bearer static-token")
|
||||
expect(mockTokens).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it("handles step-up auth by triggering re-login on 403 with scope", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "stepup-server",
|
||||
skillName: "stepup-skill",
|
||||
sessionID: "session-stepup-1",
|
||||
}
|
||||
const config: ClaudeCodeMcpServer = {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
oauth: {
|
||||
clientId: "my-client",
|
||||
scopes: ["read"],
|
||||
},
|
||||
}
|
||||
const context: SkillMcpServerContext = {
|
||||
config,
|
||||
skillName: "stepup-skill",
|
||||
}
|
||||
|
||||
mockTokens.mockReturnValue({ accessToken: "initial-token" })
|
||||
mockLogin.mockResolvedValue({ accessToken: "upgraded-token" })
|
||||
|
||||
let callCount = 0
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
throw new Error('403 WWW-Authenticate: Bearer scope="admin write"')
|
||||
}
|
||||
return { content: [{ type: "text", text: "success" }] }
|
||||
}),
|
||||
close: mock(() => Promise.resolve()),
|
||||
}
|
||||
|
||||
const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry")
|
||||
getOrCreateSpy.mockResolvedValue(mockClient)
|
||||
|
||||
// #when
|
||||
const result = await manager.callTool(info, context, "test-tool", {})
|
||||
|
||||
// #then
|
||||
expect(result).toEqual([{ type: "text", text: "success" }])
|
||||
expect(mockLogin).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it("does not attempt step-up when oauth config is absent", async () => {
|
||||
// #given
|
||||
const info: SkillMcpClientInfo = {
|
||||
serverName: "no-stepup-server",
|
||||
skillName: "no-stepup-skill",
|
||||
sessionID: "session-no-stepup",
|
||||
}
|
||||
const context: SkillMcpServerContext = {
|
||||
config: {
|
||||
url: "https://mcp.example.com/mcp",
|
||||
},
|
||||
skillName: "no-stepup-skill",
|
||||
}
|
||||
|
||||
const mockClient = {
|
||||
callTool: mock(async () => {
|
||||
throw new Error('403 WWW-Authenticate: Bearer scope="admin"')
|
||||
}),
|
||||
close: mock(() => Promise.resolve()),
|
||||
}
|
||||
|
||||
const getOrCreateSpy = spyOn(manager as any, "getOrCreateClientWithRetry")
|
||||
getOrCreateSpy.mockResolvedValue(mockClient)
|
||||
|
||||
// #when / #then
|
||||
await expect(manager.callTool(info, context, "test-tool", {})).rejects.toThrow(/403/)
|
||||
expect(mockLogin).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -4,6 +4,8 @@ import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/
|
||||
import type { Tool, Resource, Prompt } from "@modelcontextprotocol/sdk/types.js"
|
||||
import type { ClaudeCodeMcpServer } from "../claude-code-mcp-loader/types"
|
||||
import { expandEnvVarsInObject } from "../claude-code-mcp-loader/env-expander"
|
||||
import { McpOAuthProvider } from "../mcp-oauth/provider"
|
||||
import { isStepUpRequired, mergeScopes } from "../mcp-oauth/step-up"
|
||||
import { createCleanMcpEnvironment } from "./env-cleaner"
|
||||
import type { SkillMcpClientInfo, SkillMcpServerContext } from "./types"
|
||||
|
||||
@ -60,6 +62,7 @@ function getConnectionType(config: ClaudeCodeMcpServer): ConnectionType | null {
|
||||
export class SkillMcpManager {
|
||||
private clients: Map<string, ManagedClient> = new Map()
|
||||
private pendingConnections: Map<string, Promise<Client>> = new Map()
|
||||
private authProviders: Map<string, McpOAuthProvider> = new Map()
|
||||
private cleanupRegistered = false
|
||||
private cleanupInterval: ReturnType<typeof setInterval> | null = null
|
||||
private readonly IDLE_TIMEOUT = 5 * 60 * 1000
|
||||
@ -68,6 +71,28 @@ export class SkillMcpManager {
|
||||
return `${info.sessionID}:${info.skillName}:${info.serverName}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create an McpOAuthProvider for a given server URL + oauth config.
|
||||
* Providers are cached by server URL to reuse tokens across reconnections.
|
||||
*/
|
||||
private getOrCreateAuthProvider(
|
||||
serverUrl: string,
|
||||
oauth: NonNullable<ClaudeCodeMcpServer["oauth"]>
|
||||
): McpOAuthProvider {
|
||||
const existing = this.authProviders.get(serverUrl)
|
||||
if (existing) {
|
||||
return existing
|
||||
}
|
||||
|
||||
const provider = new McpOAuthProvider({
|
||||
serverUrl,
|
||||
clientId: oauth.clientId,
|
||||
scopes: oauth.scopes,
|
||||
})
|
||||
this.authProviders.set(serverUrl, provider)
|
||||
return provider
|
||||
}
|
||||
|
||||
private registerProcessCleanup(): void {
|
||||
if (this.cleanupRegistered) return
|
||||
this.cleanupRegistered = true
|
||||
@ -204,7 +229,30 @@ export class SkillMcpManager {
|
||||
// Build request init with headers if provided
|
||||
const requestInit: RequestInit = {}
|
||||
if (config.headers && Object.keys(config.headers).length > 0) {
|
||||
requestInit.headers = config.headers
|
||||
requestInit.headers = { ...config.headers }
|
||||
}
|
||||
|
||||
let authProvider: McpOAuthProvider | undefined
|
||||
if (config.oauth) {
|
||||
authProvider = this.getOrCreateAuthProvider(config.url, config.oauth)
|
||||
let tokenData = authProvider.tokens()
|
||||
|
||||
const isExpired = tokenData?.expiresAt != null && tokenData.expiresAt < Math.floor(Date.now() / 1000)
|
||||
if (!tokenData || isExpired) {
|
||||
try {
|
||||
tokenData = await authProvider.login()
|
||||
} catch {
|
||||
// Login failed — proceed without auth header
|
||||
}
|
||||
}
|
||||
|
||||
if (tokenData) {
|
||||
const existingHeaders = (requestInit.headers ?? {}) as Record<string, string>
|
||||
requestInit.headers = {
|
||||
...existingHeaders,
|
||||
Authorization: `Bearer ${tokenData.accessToken}`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const transport = new StreamableHTTPClientTransport(url, {
|
||||
@ -460,6 +508,12 @@ export class SkillMcpManager {
|
||||
lastError = error instanceof Error ? error : new Error(String(error))
|
||||
const errorMessage = lastError.message.toLowerCase()
|
||||
|
||||
const stepUpHandled = await this.handleStepUpIfNeeded(lastError, config)
|
||||
if (stepUpHandled) {
|
||||
await this.forceReconnect(info)
|
||||
continue
|
||||
}
|
||||
|
||||
if (!errorMessage.includes("not connected")) {
|
||||
throw lastError
|
||||
}
|
||||
@ -470,23 +524,66 @@ export class SkillMcpManager {
|
||||
)
|
||||
}
|
||||
|
||||
const key = this.getClientKey(info)
|
||||
const existing = this.clients.get(key)
|
||||
if (existing) {
|
||||
this.clients.delete(key)
|
||||
try {
|
||||
await existing.client.close()
|
||||
} catch { /* process may already be terminated */ }
|
||||
try {
|
||||
await existing.transport.close()
|
||||
} catch { /* transport may already be terminated */ }
|
||||
}
|
||||
await this.forceReconnect(info)
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError || new Error("Operation failed with unknown error")
|
||||
}
|
||||
|
||||
private async handleStepUpIfNeeded(
|
||||
error: Error,
|
||||
config: ClaudeCodeMcpServer
|
||||
): Promise<boolean> {
|
||||
if (!config.oauth || !config.url) {
|
||||
return false
|
||||
}
|
||||
|
||||
const statusMatch = /\b403\b/.exec(error.message)
|
||||
if (!statusMatch) {
|
||||
return false
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {}
|
||||
const wwwAuthMatch = /WWW-Authenticate:\s*(.+)/i.exec(error.message)
|
||||
if (wwwAuthMatch?.[1]) {
|
||||
headers["www-authenticate"] = wwwAuthMatch[1]
|
||||
}
|
||||
|
||||
const stepUp = isStepUpRequired(403, headers)
|
||||
if (!stepUp) {
|
||||
return false
|
||||
}
|
||||
|
||||
const currentScopes = config.oauth.scopes ?? []
|
||||
const merged = mergeScopes(currentScopes, stepUp.requiredScopes)
|
||||
config.oauth.scopes = merged
|
||||
|
||||
this.authProviders.delete(config.url)
|
||||
const provider = this.getOrCreateAuthProvider(config.url, config.oauth)
|
||||
|
||||
try {
|
||||
await provider.login()
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private async forceReconnect(info: SkillMcpClientInfo): Promise<void> {
|
||||
const key = this.getClientKey(info)
|
||||
const existing = this.clients.get(key)
|
||||
if (existing) {
|
||||
this.clients.delete(key)
|
||||
try {
|
||||
await existing.client.close()
|
||||
} catch { /* process may already be terminated */ }
|
||||
try {
|
||||
await existing.transport.close()
|
||||
} catch { /* transport may already be terminated */ }
|
||||
}
|
||||
}
|
||||
|
||||
private async getOrCreateClientWithRetry(
|
||||
info: SkillMcpClientInfo,
|
||||
config: ClaudeCodeMcpServer
|
||||
|
||||
5
src/types/request-info.d.ts
vendored
Normal file
5
src/types/request-info.d.ts
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
declare global {
|
||||
type RequestInfo = string | URL
|
||||
}
|
||||
|
||||
export {}
|
||||
Loading…
x
Reference in New Issue
Block a user