应文浩wenhao.ying@xiaobao100.com bce2612b64 feat: 完善具体实现
2026-04-06 15:25:34 +08:00

289 lines
10 KiB
C#

using System.Collections.Concurrent;
using System.Text.Json;
using System.Threading.Channels;
namespace FreeCode.Mcp;
/// <summary>
/// JSON-RPC 2.0 client for a Model Context Protocol (MCP) server, layered over an <see cref="IMcpTransport"/>
/// that exchanges one JSON message per line.
/// </summary>
/// <remarks>
/// A background dispatch loop reads incoming lines and completes the pending request whose id matches each
/// response. If the incoming stream ends, faults, or the client is disconnected, every pending and subsequent
/// request fails with <see cref="InvalidOperationException"/> instead of waiting forever.
/// </remarks>
public sealed class McpClient : IAsyncDisposable
{
    private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web);

    /// <summary>Upper bound on the best-effort shutdown exchange performed by <see cref="DisconnectAsync"/>.</summary>
    private static readonly TimeSpan ShutdownTimeout = TimeSpan.FromSeconds(5);

    /// <summary>Shared JSON <c>null</c> value used as the fallback tool-call content.</summary>
    private static readonly JsonElement JsonNull = CreateJsonNull();

    private const string ConnectionClosedMessage = "The MCP connection is closed.";

    private readonly IMcpTransport _transport;
    private readonly ConcurrentDictionary<string, TaskCompletionSource<JsonElement?>> _pendingRequests = new();
    private readonly CancellationTokenSource _cts = new();
    private readonly SemaphoreSlim _sendLock = new(1, 1);
    private readonly Lazy<Task> _disconnect;
    private readonly Task _dispatchLoop;
    private int _requestCounter;

    // Set to 1 (via Interlocked) once the dispatch loop stops reading; no response can arrive after that.
    private int _readerClosed;

    /// <summary>Creates the client and immediately starts the background dispatch loop on <paramref name="transport"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="transport"/> is null.</exception>
    public McpClient(IMcpTransport transport)
    {
        ArgumentNullException.ThrowIfNull(transport);
        _transport = transport;
        _disconnect = new Lazy<Task>(DisconnectCoreAsync);
        _dispatchLoop = Task.Run(DispatchLoopAsync);
    }

    /// <summary>True after <see cref="ConnectAsync"/> completes the handshake; reset by <see cref="DisconnectAsync"/>.</summary>
    public bool IsConnected { get; private set; }

    /// <summary>Capabilities reported by the server's <c>initialize</c> result (empty until connected).</summary>
    public ServerCapabilities Capabilities { get; private set; } = new();

    /// <summary>Server identity from the <c>initialize</c> result; <c>("unknown", "unknown")</c> if the server sent none.</summary>
    public ServerInfo? ServerInfo { get; private set; }

    /// <summary>
    /// Starts the transport and performs the MCP handshake: an <c>initialize</c> request followed by the
    /// <c>notifications/initialized</c> notification.
    /// </summary>
    public async Task ConnectAsync(CancellationToken ct = default)
    {
        await _transport.StartAsync(ct).ConfigureAwait(false);

        // NOTE(review): roots/sampling/elicitation are advertised, but DispatchLoopAsync only handles responses;
        // server-initiated requests are currently dropped unanswered.
        var initResult = await SendRequestAsync<InitializeResult>("initialize", new
        {
            protocolVersion = "2024-11-05",
            clientInfo = new { name = "free-code", version = "10.0.0" },
            capabilities = new
            {
                roots = new { },
                sampling = new { },
                elicitation = new { }
            }
        }, ct).ConfigureAwait(false);

        Capabilities = initResult?.Capabilities ?? new ServerCapabilities();
        ServerInfo = initResult?.ServerInfo ?? new ServerInfo("unknown", "unknown");

        // The MCP lifecycle names this notification "notifications/initialized"; "initialized" is not an MCP method.
        await SendNotificationAsync("notifications/initialized", new { }, ct).ConfigureAwait(false);
        IsConnected = true;
    }

    /// <summary>Calls <c>tools/list</c>; returns an empty list when the server's result is missing.</summary>
    public async Task<ListToolsResult> ListToolsAsync(CancellationToken ct = default)
    {
        return await SendRequestAsync<ListToolsResult>("tools/list", new { }, ct).ConfigureAwait(false) ?? new ListToolsResult([]);
    }

    /// <summary>Calls <c>tools/call</c> with <paramref name="parameters"/> as the tool arguments.</summary>
    /// <returns>The server result, or a result with JSON <c>null</c> content and <c>false</c> when the result is missing.</returns>
    public async Task<CallToolResult> CallToolAsync(string toolName, object? parameters, CancellationToken ct = default)
    {
        return await SendRequestAsync<CallToolResult>("tools/call", new { name = toolName, arguments = parameters }, ct).ConfigureAwait(false)
            ?? new CallToolResult(JsonNull, false);
    }

    /// <summary>Calls <c>resources/list</c>; returns an empty list when the server's result is missing.</summary>
    public async Task<ListResourcesResult> ListResourcesAsync(CancellationToken ct = default)
    {
        return await SendRequestAsync<ListResourcesResult>("resources/list", new { }, ct).ConfigureAwait(false) ?? new ListResourcesResult([]);
    }

    /// <summary>Calls <c>prompts/list</c>; returns an empty list when the server's result is missing.</summary>
    public async Task<ListPromptsResult> ListPromptsAsync(CancellationToken ct = default)
    {
        return await SendRequestAsync<ListPromptsResult>("prompts/list", new { }, ct).ConfigureAwait(false) ?? new ListPromptsResult([]);
    }

    /// <summary>Calls <c>resources/read</c> for <paramref name="resourceUri"/>; returns empty contents when the result is missing.</summary>
    public async Task<ReadResourceResult> ReadResourceAsync(string resourceUri, CancellationToken ct = default)
    {
        return await SendRequestAsync<ReadResourceResult>("resources/read", new { uri = resourceUri }, ct).ConfigureAwait(false) ?? new ReadResourceResult([]);
    }

    /// <summary>
    /// Tears the connection down. It sends a best-effort shutdown/exit pair, bounded by a timeout. It then stops the
    /// dispatch loop and disposes the transport. The teardown runs once; repeated or concurrent calls share the same task.
    /// </summary>
    public Task DisconnectAsync() => _disconnect.Value;

    /// <summary>Disconnects (see <see cref="DisconnectAsync"/>) and releases the client's cancellation source.</summary>
    public async ValueTask DisposeAsync()
    {
        try
        {
            await DisconnectAsync().ConfigureAwait(false);
        }
        finally
        {
            _cts.Dispose();
        }
    }

    /// <summary>One-shot teardown body behind <see cref="DisconnectAsync"/>.</summary>
    private async Task DisconnectCoreAsync()
    {
        IsConnected = false;

        // Once the incoming stream has completed (successfully or not), no reply can arrive, so skip the exchange.
        if (!_transport.IncomingLines.Completion.IsCompleted)
        {
            // MCP defines no shutdown request; "shutdown"/"exit" are sent best-effort and any failure
            // (including a method-not-found error or the timeout) is ignored.
            using var shutdownCts = new CancellationTokenSource(ShutdownTimeout);
            try
            {
                await SendRequestAsync<JsonElement?>("shutdown", null, shutdownCts.Token).ConfigureAwait(false);
                await SendNotificationAsync("exit", null, shutdownCts.Token).ConfigureAwait(false);
            }
            catch (Exception)
            {
                /* best-effort shutdown notification */
            }
        }

        _cts.Cancel();
        // The dispatch loop handles its own exceptions, so this only waits for it to drain pending requests.
        await _dispatchLoop.ConfigureAwait(false);
        await _transport.DisposeAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Sends a JSON-RPC request and waits for the matching response.
    /// </summary>
    /// <returns>The deserialized result, or <c>default</c> when the response carried no <c>result</c> member.</returns>
    /// <exception cref="InvalidOperationException">The server returned an error, or the connection closed.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled.</exception>
    private async Task<T?> SendRequestAsync<T>(string method, object? parameters, CancellationToken ct)
    {
        ct.ThrowIfCancellationRequested();

        var id = Interlocked.Increment(ref _requestCounter).ToString(System.Globalization.CultureInfo.InvariantCulture);
        var tcs = new TaskCompletionSource<JsonElement?>(TaskCreationOptions.RunContinuationsAsynchronously);
        _pendingRequests[id] = tcs;
        try
        {
            // Full-fence read paired with Interlocked.Exchange in DispatchLoopAsync. Either the loop's final
            // drain sees this entry, or this check sees the closed flag, so the request cannot be orphaned.
            if (Interlocked.CompareExchange(ref _readerClosed, 0, 0) != 0)
            {
                throw new InvalidOperationException(ConnectionClosedMessage);
            }

            await _sendLock.WaitAsync(ct).ConfigureAwait(false);
            try
            {
                await _transport.SendLineAsync(SerializeMessage(id, method, parameters), ct).ConfigureAwait(false);
            }
            finally
            {
                _sendLock.Release();
            }

            using var registration = ct.Register(() => tcs.TrySetCanceled(ct));
            var result = await tcs.Task.ConfigureAwait(false);
            if (result is null)
            {
                return default;
            }

            return result.Value.Deserialize<T>(JsonOptions);
        }
        finally
        {
            // Covers cancellation, send failures and the closed check; a no-op if the dispatch loop already removed it.
            _pendingRequests.TryRemove(id, out _);
        }
    }

    /// <summary>Sends a JSON-RPC notification (no id, no response expected).</summary>
    private async Task SendNotificationAsync(string method, object? parameters, CancellationToken ct)
    {
        await _sendLock.WaitAsync(ct).ConfigureAwait(false);
        try
        {
            await _transport.SendLineAsync(SerializeMessage(null, method, parameters), ct).ConfigureAwait(false);
        }
        finally
        {
            _sendLock.Release();
        }
    }

    /// <summary>
    /// Reads incoming lines until the stream ends, faults, or the client is disconnected. It completes the pending
    /// request matching each response. On exit it fails every request still pending.
    /// </summary>
    private async Task DispatchLoopAsync()
    {
        Exception? fault = null;
        try
        {
            await foreach (var line in _transport.IncomingLines.ReadAllAsync(_cts.Token).ConfigureAwait(false))
            {
                JsonRpcMessage? message;
                try
                {
                    message = ParseMessage(line);
                }
                catch (JsonException)
                {
                    // One malformed line must not stop the reader (and thereby hang every pending request).
                    continue;
                }

                if (message is JsonRpcResponse response && _pendingRequests.TryRemove(response.Id, out var tcs))
                {
                    if (response.Error is not null)
                    {
                        tcs.TrySetException(new InvalidOperationException($"MCP error {response.Error.Code}: {response.Error.Message}"));
                    }
                    else
                    {
                        tcs.TrySetResult(response.Result);
                    }
                }
            }
        }
        catch (OperationCanceledException) when (_cts.IsCancellationRequested)
        {
            // Normal stop requested by DisconnectAsync.
        }
        catch (Exception ex)
        {
            // Transport fault: surface it to the waiting callers instead of leaving an unobserved task exception.
            fault = ex;
        }
        finally
        {
            Interlocked.Exchange(ref _readerClosed, 1);
            FailPendingRequests(fault);
        }
    }

    /// <summary>Fails every pending request with a connection-closed error, attaching <paramref name="cause"/> if any.</summary>
    private void FailPendingRequests(Exception? cause)
    {
        foreach (var id in _pendingRequests.Keys)
        {
            if (_pendingRequests.TryRemove(id, out var tcs))
            {
                tcs.TrySetException(new InvalidOperationException(ConnectionClosedMessage, cause));
            }
        }
    }

    /// <summary>
    /// Classifies one incoming line as a request, notification or response.
    /// </summary>
    /// <returns><c>null</c> for blank lines, non-object JSON, or objects with neither <c>method</c> nor <c>id</c>.</returns>
    /// <exception cref="JsonException">The line is not valid JSON.</exception>
    private static JsonRpcMessage? ParseMessage(string line)
    {
        if (string.IsNullOrWhiteSpace(line))
        {
            return null;
        }

        using var document = JsonDocument.Parse(line);
        var root = document.RootElement;
        if (root.ValueKind != JsonValueKind.Object)
        {
            // Arrays (batches) and scalars are not handled by this client; ignore them rather than throw.
            return null;
        }

        if (root.TryGetProperty("method", out var methodElement))
        {
            var method = methodElement.ValueKind == JsonValueKind.String
                ? methodElement.GetString() ?? string.Empty
                : string.Empty;
            object? parameters = root.TryGetProperty("params", out var paramsElement) ? paramsElement.Clone() : null;
            if (root.TryGetProperty("id", out var idElement))
            {
                return new JsonRpcRequest(idElement.ToString(), method, parameters);
            }

            return new JsonRpcNotification(method, parameters);
        }

        if (root.TryGetProperty("id", out var responseId))
        {
            JsonElement? result = root.TryGetProperty("result", out var resultElement) ? resultElement.Clone() : null;
            JsonRpcError? error = null;
            // Some peers send "error": null alongside a result; treat that as "no error".
            if (root.TryGetProperty("error", out var errorElement) && errorElement.ValueKind != JsonValueKind.Null)
            {
                error = ParseError(errorElement);
            }

            return new JsonRpcResponse(responseId.ToString(), result, error);
        }

        return null;
    }

    /// <summary>Reads <c>code</c>/<c>message</c> from an error member, falling back to <c>-1</c> and empty text when absent or mistyped.</summary>
    private static JsonRpcError ParseError(JsonElement errorElement)
    {
        if (errorElement.ValueKind != JsonValueKind.Object)
        {
            return new JsonRpcError(-1, errorElement.ToString());
        }

        var code = errorElement.TryGetProperty("code", out var codeElement)
            && codeElement.ValueKind == JsonValueKind.Number
            && codeElement.TryGetInt32(out var parsedCode)
                ? parsedCode
                : -1;
        var message = errorElement.TryGetProperty("message", out var messageElement)
            && messageElement.ValueKind == JsonValueKind.String
                ? messageElement.GetString() ?? string.Empty
                : string.Empty;
        return new JsonRpcError(code, message);
    }

    /// <summary>
    /// Serializes one JSON-RPC 2.0 message: a request when <paramref name="id"/> is non-null, otherwise a notification.
    /// The <c>params</c> member is omitted when <paramref name="parameters"/> is null.
    /// </summary>
    private static string SerializeMessage(string? id, string method, object? parameters)
    {
        using var stream = new MemoryStream();
        using (var writer = new Utf8JsonWriter(stream))
        {
            writer.WriteStartObject();
            // Required by JSON-RPC 2.0 (and therefore MCP) on every message.
            writer.WriteString("jsonrpc", "2.0");
            if (id is not null)
            {
                writer.WriteString("id", id);
            }

            writer.WriteString("method", method);
            if (parameters is not null)
            {
                writer.WritePropertyName("params");
                WriteValue(writer, parameters);
            }

            writer.WriteEndObject();
        }

        return System.Text.Encoding.UTF8.GetString(stream.ToArray());
    }

    /// <summary>
    /// Writes <paramref name="value"/> as JSON. <see cref="JsonElement"/>s are copied verbatim. String-keyed pair
    /// sequences (including dictionaries) become objects whose values recurse through this method. Anything else
    /// goes through <see cref="JsonSerializer"/>.
    /// </summary>
    private static void WriteValue(Utf8JsonWriter writer, object value)
    {
        switch (value)
        {
            case JsonElement element:
                element.WriteTo(writer);
                break;
            case IEnumerable<KeyValuePair<string, object?>> pairs:
                writer.WriteStartObject();
                foreach (var (key, itemValue) in pairs)
                {
                    writer.WritePropertyName(key);
                    WriteNullableValue(writer, itemValue);
                }

                writer.WriteEndObject();
                break;
            default:
                JsonSerializer.Serialize(writer, value, JsonOptions);
                break;
        }
    }

    /// <summary>Writes JSON <c>null</c> for a null <paramref name="value"/>, otherwise delegates to <see cref="WriteValue"/>.</summary>
    private static void WriteNullableValue(Utf8JsonWriter writer, object? value)
    {
        if (value is null)
        {
            writer.WriteNullValue();
            return;
        }

        WriteValue(writer, value);
    }

    /// <summary>Builds a standalone JSON <c>null</c> element, disposing the temporary document.</summary>
    private static JsonElement CreateJsonNull()
    {
        using var document = JsonDocument.Parse("null");
        return document.RootElement.Clone();
    }

    private sealed record InitializeResult(ServerCapabilities? Capabilities, ServerInfo? ServerInfo);
}