应文浩wenhao.ying@xiaobao100.com e25ac591a7 init easy-code
2026-04-06 07:24:24 +08:00

233 lines
9.4 KiB
C#

using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text;
using System.Text.Json;
using FreeCode.Core.Interfaces;
using FreeCode.Core.Models;
using Microsoft.Extensions.Configuration;
namespace FreeCode.ApiProviders;
/// <summary>
/// <see cref="FreeCode.Core.Interfaces.IApiProvider"/> that posts Anthropic-format message requests to a
/// Bedrock-style <c>/model/{id}/invoke-with-response-stream</c> endpoint and converts the server-sent events
/// in the response into <see cref="SDKMessage"/> values.
/// </summary>
/// <remarks>
/// Every setting is resolved key by key, checking the process environment before <see cref="IConfiguration"/>
/// for each key.
/// NOTE(review): native Amazon Bedrock expects <c>anthropic_version = "bedrock-2023-05-31"</c>, does not accept a
/// <c>stream</c> body field, answers with <c>application/vnd.amazon.eventstream</c> frames rather than SSE, and
/// authenticates with SigV4 or a Bedrock API key. This class sends Anthropic-API-style bodies, parses SSE, and may
/// send <c>accessKey:secretKey[:sessionToken]</c> as a bearer token, so it appears to target an SSE-emitting gateway.
/// Confirm the intended endpoint before pointing it at bedrock-runtime directly, and note that the secret key is then
/// transmitted to that endpoint.
/// </remarks>
public sealed class BedrockProvider : FreeCode.Core.Interfaces.IApiProvider
{
    private const string DefaultModel = "claude-sonnet-4-6";
    private const int DefaultMaxTokens = 4096;
    private const int MaxErrorBodyLength = 2048;

    private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web);

    private readonly HttpClient _httpClient;
    private readonly string _baseUrl;
    private readonly string _region;
    private readonly string _modelPrefix;
    private readonly string? _bearerToken;
    private readonly string? _apiKey;
    private readonly bool _skipAuth;
    private readonly int _maxTokens;
    private readonly IRateLimitService? _rateLimitService;

    /// <summary>Creates a provider, resolving endpoint, credentials and limits from the environment and <paramref name="configuration"/>.</summary>
    /// <param name="httpClient">Client used for requests; a new <see cref="HttpClient"/> is created when <see langword="null"/>.</param>
    /// <param name="configuration">Optional configuration consulted after the environment variable of each setting key.</param>
    /// <param name="rateLimitService">Optional service that inspects response headers and can reject a response.</param>
    public BedrockProvider(HttpClient? httpClient = null, IConfiguration? configuration = null, IRateLimitService? rateLimitService = null)
    {
        _httpClient = httpClient ?? new HttpClient();
        _rateLimitService = rateLimitService;
        _region = GetSetting(configuration, "AWS_REGION", "Bedrock:Region")
            ?? GetSetting(configuration, "AWS_DEFAULT_REGION", "Bedrock:DefaultRegion")
            ?? "us-east-1";
        _baseUrl = GetSetting(configuration, "ANTHROPIC_BEDROCK_BASE_URL", "Bedrock:BaseUrl")
            ?? $"https://bedrock-runtime.{_region}.amazonaws.com";
        _modelPrefix = GetSetting(configuration, "AWS_BEDROCK_MODEL_PREFIX", "Bedrock:ModelPrefix") ?? "us.anthropic";
        _bearerToken = ResolveBearerToken(configuration);
        _apiKey = GetSetting(configuration, "ANTHROPIC_BEDROCK_API_KEY", "Bedrock:ApiKey");
        // Accept "1" (env-var convention) and boolean text such as "true"/"True"; a JSON-config `true`
        // typically reaches IConfiguration as the string "True", which the old "1"-only check ignored.
        _skipAuth = IsTruthy(GetSetting(configuration, "CLAUDE_CODE_SKIP_BEDROCK_AUTH", "Bedrock:SkipAuth"));
        _maxTokens = ResolveMaxTokens(configuration);
    }

    /// <summary>
    /// Sends <paramref name="request"/> to the streaming invoke endpoint and yields text / tool-input deltas and
    /// tool-use starts as they arrive.
    /// </summary>
    /// <param name="request">Model, system prompt, messages and tools to send; a missing model falls back to the default model.</param>
    /// <param name="ct">Cancels the HTTP request and the stream read.</param>
    /// <returns>An async sequence that ends on <c>message_stop</c> / <c>response.completed</c> or at end of stream.</returns>
    /// <exception cref="HttpRequestException">
    /// The rate-limit service rejected the response, the endpoint returned a non-success status, or the stream
    /// delivered an <c>error</c> event.
    /// </exception>
    /// <exception cref="JsonException">An SSE <c>data</c> payload was not valid JSON.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled.</exception>
    public async IAsyncEnumerable<SDKMessage> StreamAsync(ApiRequest request, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default)
    {
        var model = request.Model ?? DefaultModel;
        // A model string containing ':' is used verbatim as the model id; otherwise the configured prefix is prepended.
        var modelId = model.Contains(':', StringComparison.Ordinal) ? model : $"{_modelPrefix}.{model}";

        using var httpRequest = new HttpRequestMessage(HttpMethod.Post, BuildInvokeUri(modelId));
        httpRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
        ApplyAuthentication(httpRequest);
        httpRequest.Content = JsonContent.Create(BuildPayload(request), options: SerializerOptions);

        using var response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, ct).ConfigureAwait(false);
        var responseHeaders = ToHeaderDictionary(response.Headers, response.Content.Headers);
        if (_rateLimitService?.CanProceed(responseHeaders) == false)
        {
            throw CreateRateLimitException(responseHeaders);
        }

        if (!response.IsSuccessStatusCode)
        {
            throw await CreateHttpFailureExceptionAsync(response, ct).ConfigureAwait(false);
        }

        var stream = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false);
        await foreach (var data in ReadSseDataAsync(stream, ct).ConfigureAwait(false))
        {
            using var document = JsonDocument.Parse(data);
            var root = document.RootElement;
            if (!root.TryGetProperty("type", out var typeProperty))
            {
                continue;
            }

            var eventType = typeProperty.GetString();
            if (eventType is "message_stop" or "response.completed")
            {
                yield break;
            }

            if (eventType == "error")
            {
                // Without this, an error event would end the stream looking like a normal (truncated) completion.
                throw CreateStreamErrorException(root);
            }

            var message = eventType switch
            {
                "content_block_delta" => ReadContentBlockDelta(root),
                "content_block_start" => ReadToolUseStart(root),
                _ => null,
            };
            if (message is not null)
            {
                yield return message;
            }
        }
    }

    /// <summary>
    /// Builds the invoke URL by appending the model path to the configured base URL, keeping any path prefix the
    /// base URL carries (combining with an absolute relative URI such as "/model/..." would discard it).
    /// </summary>
    private Uri BuildInvokeUri(string modelId)
        => new($"{_baseUrl.TrimEnd('/')}/model/{Uri.EscapeDataString(modelId)}/invoke-with-response-stream");

    /// <summary>Adds a bearer token, or failing that an x-api-key header, unless auth is skipped.</summary>
    private void ApplyAuthentication(HttpRequestMessage httpRequest)
    {
        if (_skipAuth)
        {
            return;
        }

        if (!string.IsNullOrWhiteSpace(_bearerToken))
        {
            httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _bearerToken);
        }
        else if (!string.IsNullOrWhiteSpace(_apiKey))
        {
            httpRequest.Headers.Add("x-api-key", _apiKey);
        }
    }

    /// <summary>
    /// Builds the request body. Optional fields (<c>system</c>, <c>tools</c>) are omitted when null rather than
    /// serialized as JSON <c>null</c>, which schema-validating endpoints may reject.
    /// </summary>
    private Dictionary<string, object?> BuildPayload(ApiRequest request)
    {
        var payload = new Dictionary<string, object?>(StringComparer.Ordinal)
        {
            ["anthropic_version"] = "2023-06-01",
        };
        if (request.SystemPrompt is not null)
        {
            payload["system"] = request.SystemPrompt;
        }

        payload["messages"] = request.Messages;
        if (request.Tools is not null)
        {
            payload["tools"] = request.Tools;
        }

        payload["max_tokens"] = _maxTokens;
        payload["stream"] = true;
        return payload;
    }

    /// <summary>Maps a <c>content_block_delta</c> event's <c>text</c> or <c>partial_json</c> to a streaming delta, or returns <see langword="null"/>.</summary>
    private static SDKMessage? ReadContentBlockDelta(JsonElement root)
    {
        if (!root.TryGetProperty("delta", out var delta))
        {
            return null;
        }

        if (delta.TryGetProperty("text", out var text))
        {
            return new SDKMessage.StreamingDelta(text.GetString() ?? string.Empty);
        }

        if (delta.TryGetProperty("partial_json", out var partialJson))
        {
            return new SDKMessage.StreamingDelta(partialJson.GetString() ?? string.Empty);
        }

        return null;
    }

    /// <summary>Maps a <c>content_block_start</c> event of type <c>tool_use</c> to a tool-use start; other block types yield <see langword="null"/>.</summary>
    private static SDKMessage? ReadToolUseStart(JsonElement root)
    {
        if (!root.TryGetProperty("content_block", out var contentBlock)
            || !contentBlock.TryGetProperty("type", out var blockType)
            || !string.Equals(blockType.GetString(), "tool_use", StringComparison.OrdinalIgnoreCase))
        {
            return null;
        }

        // Clone so the element outlives the per-event JsonDocument.
        var input = contentBlock.TryGetProperty("input", out var inputProperty)
            ? inputProperty.Clone()
            : CreateEmptyObject();
        var id = contentBlock.TryGetProperty("id", out var idProperty) ? idProperty.GetString() ?? string.Empty : string.Empty;
        var name = contentBlock.TryGetProperty("name", out var nameProperty) ? nameProperty.GetString() ?? string.Empty : string.Empty;
        return new SDKMessage.ToolUseStart(id, name, input);
    }

    /// <summary>Returns a standalone empty JSON object; the temporary document is disposed since the clone owns its data.</summary>
    private static JsonElement CreateEmptyObject()
    {
        using var document = JsonDocument.Parse("{}");
        return document.RootElement.Clone();
    }

    /// <summary>Builds an exception from an SSE <c>error</c> event, using its <c>error.type</c> and <c>error.message</c> when present.</summary>
    private static HttpRequestException CreateStreamErrorException(JsonElement root)
    {
        var errorType = "unknown_error";
        var errorMessage = "The stream reported an error.";
        if (root.TryGetProperty("error", out var error) && error.ValueKind == JsonValueKind.Object)
        {
            if (error.TryGetProperty("type", out var typeElement) && typeElement.ValueKind == JsonValueKind.String)
            {
                errorType = typeElement.GetString() ?? errorType;
            }

            if (error.TryGetProperty("message", out var messageElement) && messageElement.ValueKind == JsonValueKind.String)
            {
                errorMessage = messageElement.GetString() ?? errorMessage;
            }
        }

        return new HttpRequestException($"Bedrock stream error ({errorType}): {errorMessage}");
    }

    /// <summary>
    /// Builds an exception for a non-success response, including (a truncated copy of) the response body and
    /// preserving the status code, which <see cref="HttpResponseMessage.EnsureSuccessStatusCode"/> would discard.
    /// </summary>
    private static async Task<HttpRequestException> CreateHttpFailureExceptionAsync(HttpResponseMessage response, CancellationToken ct)
    {
        string body;
        try
        {
            body = await response.Content.ReadAsStringAsync(ct).ConfigureAwait(false);
        }
        catch (Exception ex) when (ex is HttpRequestException or IOException)
        {
            // Best effort: the status code alone is still worth reporting if the body cannot be read.
            body = string.Empty;
        }

        body = body.Trim();
        if (body.Length > MaxErrorBodyLength)
        {
            body = body[..MaxErrorBodyLength] + "...";
        }

        var reason = response.ReasonPhrase ?? response.StatusCode.ToString();
        var detail = body.Length == 0 ? string.Empty : $": {body}";
        return new HttpRequestException(
            $"Bedrock request failed with status {(int)response.StatusCode} ({reason}){detail}",
            null,
            response.StatusCode);
    }

    /// <summary>
    /// Resolves the bearer token: an explicit token first, then <c>accessKey:secretKey[:sessionToken]</c> when a key
    /// pair is configured, and finally a bare <c>AWS_SESSION_TOKEN</c>. The session token is checked last so that it
    /// can be combined with a key pair instead of shadowing it.
    /// </summary>
    private static string? ResolveBearerToken(IConfiguration? configuration)
    {
        var explicitToken = GetSetting(configuration,
            "AWS_BEARER_TOKEN_BEDROCK",
            "AWS_ACCESS_TOKEN",
            "Bedrock:BearerToken");
        if (!string.IsNullOrWhiteSpace(explicitToken))
        {
            return explicitToken;
        }

        var accessKey = GetSetting(configuration, "AWS_ACCESS_KEY_ID", "Bedrock:AccessKeyId");
        var secretKey = GetSetting(configuration, "AWS_SECRET_ACCESS_KEY", "Bedrock:SecretAccessKey");
        if (!string.IsNullOrWhiteSpace(accessKey) && !string.IsNullOrWhiteSpace(secretKey))
        {
            var sessionToken = GetSetting(configuration, "AWS_SESSION_TOKEN", "Bedrock:SessionToken");
            return string.IsNullOrWhiteSpace(sessionToken)
                ? $"{accessKey}:{secretKey}"
                : $"{accessKey}:{secretKey}:{sessionToken}";
        }

        // Preserve the earlier behaviour of sending a lone session token as the bearer when no key pair exists.
        return GetSetting(configuration, "AWS_SESSION_TOKEN");
    }

    /// <summary>Reads a positive max-token override, falling back to <see cref="DefaultMaxTokens"/>.</summary>
    private static int ResolveMaxTokens(IConfiguration? configuration)
    {
        var raw = GetSetting(configuration, "CLAUDE_CODE_MAX_OUTPUT_TOKENS", "Bedrock:MaxTokens");
        return int.TryParse(raw, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out var value) && value > 0
            ? value
            : DefaultMaxTokens;
    }

    /// <summary>True for "1" or any text <see cref="bool.TryParse(string?, out bool)"/> reads as true.</summary>
    private static bool IsTruthy(string? value)
        => value == "1" || (bool.TryParse(value, out var flag) && flag);

    /// <summary>
    /// Returns the first non-blank value among <paramref name="keys"/>, checking the environment variable and then
    /// the configuration entry for each key in order; <see langword="null"/> when none is set.
    /// </summary>
    private static string? GetSetting(IConfiguration? configuration, params string[] keys)
    {
        foreach (var key in keys)
        {
            var value = Environment.GetEnvironmentVariable(key) ?? configuration?[key];
            if (!string.IsNullOrWhiteSpace(value))
            {
                return value;
            }
        }

        return null;
    }

    /// <summary>Flattens response and content headers into a case-insensitive map, joining multi-values with ','; content headers win on clashes.</summary>
    private static Dictionary<string, string> ToHeaderDictionary(HttpResponseHeaders responseHeaders, HttpContentHeaders contentHeaders)
    {
        var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        foreach (var header in responseHeaders)
        {
            headers[header.Key] = string.Join(",", header.Value);
        }

        foreach (var header in contentHeaders)
        {
            headers[header.Key] = string.Join(",", header.Value);
        }

        return headers;
    }

    /// <summary>Builds the exception thrown when the rate-limit service rejects a response, mentioning the retry delay when one is known.</summary>
    private HttpRequestException CreateRateLimitException(IReadOnlyDictionary<string, string> headers)
    {
        var retryAfter = _rateLimitService?.GetRetryAfter(headers as IDictionary<string, string> ?? new Dictionary<string, string>(headers, StringComparer.OrdinalIgnoreCase));
        // Round up so a sub-second delay is not reported as "0 seconds".
        return retryAfter is { } delay && delay > TimeSpan.Zero
            ? new HttpRequestException($"Bedrock rate limit exceeded. Retry after {Math.Ceiling(delay.TotalSeconds):F0} seconds.")
            : new HttpRequestException("Bedrock rate limit exceeded.");
    }

    /// <summary>
    /// Reads SSE lines from <paramref name="stream"/> and yields the joined <c>data:</c> payload of each event
    /// (events end at a blank line or end of stream). Non-data lines are ignored. The stream is disposed when
    /// enumeration ends.
    /// </summary>
    /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled, rather than ending the sequence as if it had completed.</exception>
    private static async IAsyncEnumerable<string> ReadSseDataAsync(Stream stream, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct)
    {
        using var reader = new StreamReader(stream, Encoding.UTF8, detectEncodingFromByteOrderMarks: false, bufferSize: 4096, leaveOpen: false);
        var buffer = new StringBuilder();
        while (true)
        {
            ct.ThrowIfCancellationRequested();
            var line = await reader.ReadLineAsync(ct).ConfigureAwait(false);
            if (line is null)
            {
                if (buffer.Length > 0)
                {
                    yield return buffer.ToString();
                }

                yield break;
            }

            if (line.Length == 0)
            {
                if (buffer.Length > 0)
                {
                    yield return buffer.ToString();
                    buffer.Clear();
                }

                continue;
            }

            if (!line.StartsWith("data:", StringComparison.OrdinalIgnoreCase))
            {
                continue;
            }

            var data = line.AsSpan(5).TrimStart();
            if (data.Length == 0)
            {
                continue;
            }

            // Multiple data lines in one event are joined with '\n'.
            if (buffer.Length > 0)
            {
                buffer.Append('\n');
            }

            buffer.Append(data);
        }
    }
}