233 lines
9.4 KiB
C#
233 lines
9.4 KiB
C#
using System.Net.Http.Headers;
|
|
using System.Net.Http.Json;
|
|
using System.Text;
|
|
using System.Text.Json;
|
|
using FreeCode.Core.Interfaces;
|
|
using FreeCode.Core.Models;
|
|
using Microsoft.Extensions.Configuration;
|
|
|
|
namespace FreeCode.ApiProviders;
|
|
|
|
/// <summary>
/// <see cref="FreeCode.Core.Interfaces.IApiProvider"/> implementation that POSTs a request to a
/// Bedrock-style <c>/model/{id}/invoke-with-response-stream</c> endpoint and turns the
/// server-sent events of the response into <see cref="SDKMessage"/> values.
/// </summary>
/// <remarks>
/// Settings are resolved once, in the constructor. Each setting key is looked up first as an
/// environment variable and then in the supplied <see cref="IConfiguration"/>.
/// </remarks>
public sealed class BedrockProvider : FreeCode.Core.Interfaces.IApiProvider
{
    /// <summary>Upper bound on how much of an error response body is copied into an exception message.</summary>
    private const int MaxErrorBodyLength = 2048;

    private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web);

    /// <summary>Values (compared case-insensitively) that switch a boolean setting on.</summary>
    private static readonly string[] TruthyValues = { "1", "true", "yes", "on" };

    /// <summary>
    /// Detached <c>{}</c> element used when a <c>tool_use</c> block carries no <c>input</c>.
    /// Created once so that no <see cref="JsonDocument"/> has to be allocated (and disposed) per event.
    /// </summary>
    private static readonly JsonElement EmptyJsonObject = CreateEmptyJsonObject();

    private readonly HttpClient _httpClient;
    private readonly string _baseUrl;
    private readonly string _region;
    private readonly string _modelPrefix;
    private readonly string? _bearerToken;
    private readonly string? _apiKey;
    private readonly bool _skipAuth;
    private readonly IRateLimitService? _rateLimitService;

    /// <summary>
    /// Creates the provider and resolves its settings from the environment and <paramref name="configuration"/>.
    /// </summary>
    /// <param name="httpClient">Client used for all requests; a new <see cref="HttpClient"/> is created when <c>null</c>.</param>
    /// <param name="configuration">Optional configuration consulted after environment variables.</param>
    /// <param name="rateLimitService">Optional service asked, per response, whether processing may proceed.</param>
    public BedrockProvider(HttpClient? httpClient = null, IConfiguration? configuration = null, IRateLimitService? rateLimitService = null)
    {
        _httpClient = httpClient ?? new HttpClient();
        _rateLimitService = rateLimitService;

        _region = GetSetting(configuration, "AWS_REGION", "Bedrock:Region")
            ?? GetSetting(configuration, "AWS_DEFAULT_REGION", "Bedrock:DefaultRegion")
            ?? "us-east-1";

        // The default endpoint is derived from the region resolved above.
        _baseUrl = GetSetting(configuration, "ANTHROPIC_BEDROCK_BASE_URL", "Bedrock:BaseUrl")
            ?? $"https://bedrock-runtime.{_region}.amazonaws.com";

        _modelPrefix = GetSetting(configuration, "AWS_BEDROCK_MODEL_PREFIX", "Bedrock:ModelPrefix") ?? "us.anthropic";
        _bearerToken = ResolveBearerToken(configuration);
        _apiKey = GetSetting(configuration, "ANTHROPIC_BEDROCK_API_KEY", "Bedrock:ApiKey");

        // "1" keeps working as before; "true"/"yes"/"on" are accepted as well so that a boolean
        // written in a configuration file (which surfaces as the string "True") is honoured.
        _skipAuth = IsTruthy(GetSetting(configuration, "CLAUDE_CODE_SKIP_BEDROCK_AUTH", "Bedrock:SkipAuth"));
    }

    /// <summary>
    /// Sends <paramref name="request"/> to the streaming invoke endpoint and yields one message per
    /// recognised event: <c>content_block_delta</c> (text or partial JSON) and <c>content_block_start</c>
    /// for <c>tool_use</c> blocks. Enumeration ends on <c>message_stop</c>, <c>response.completed</c>,
    /// or the end of the response stream.
    /// </summary>
    /// <param name="request">Request whose model, system prompt, messages and tools are forwarded.</param>
    /// <param name="ct">Token that cancels the HTTP call and the reading of the response stream.</param>
    /// <returns>An asynchronous sequence of streamed messages.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is <c>null</c>.</exception>
    /// <exception cref="HttpRequestException">
    /// The rate-limit service rejected the response headers, or the response status was not successful.
    /// </exception>
    /// <exception cref="JsonException">An event's data payload is not valid JSON.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled.</exception>
    public async IAsyncEnumerable<SDKMessage> StreamAsync(ApiRequest request, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default)
    {
        ArgumentNullException.ThrowIfNull(request);

        var model = request.Model ?? "claude-sonnet-4-6";

        // A model name containing ':' is used verbatim; anything else gets the configured prefix.
        var modelId = model.Contains(':', StringComparison.Ordinal) ? model : $"{_modelPrefix}.{model}";

        // Append to the base URL instead of resolving a root-relative path against it, so a base URL
        // that carries its own path segment (for example a gateway prefix) keeps that segment.
        var requestUri = new Uri($"{_baseUrl.TrimEnd('/')}/model/{Uri.EscapeDataString(modelId)}/invoke-with-response-stream");

        using var httpRequest = new HttpRequestMessage(HttpMethod.Post, requestUri);
        httpRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));

        if (!_skipAuth)
        {
            // The bearer token wins over the API key when both are configured.
            if (!string.IsNullOrWhiteSpace(_bearerToken))
            {
                httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _bearerToken);
            }
            else if (!string.IsNullOrWhiteSpace(_apiKey))
            {
                httpRequest.Headers.Add("x-api-key", _apiKey);
            }
        }

        var payload = new
        {
            anthropic_version = "2023-06-01",
            system = request.SystemPrompt,
            messages = request.Messages,
            tools = request.Tools,
            max_tokens = 4096,
            stream = true
        };

        httpRequest.Content = JsonContent.Create(payload, options: SerializerOptions);

        // ResponseHeadersRead: return as soon as headers arrive so the body can be consumed incrementally.
        using var response = await _httpClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead, ct).ConfigureAwait(false);

        // The rate-limit check runs before the status check, so a rejected response surfaces
        // as the rate-limit error rather than a generic HTTP failure.
        var responseHeaders = ToHeaderDictionary(response.Headers, response.Content.Headers);
        if (_rateLimitService?.CanProceed(responseHeaders) == false)
        {
            throw CreateRateLimitException(responseHeaders);
        }

        if (!response.IsSuccessStatusCode)
        {
            throw await CreateErrorResponseExceptionAsync(response, ct).ConfigureAwait(false);
        }

        var body = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false);

        await foreach (var data in ReadSseDataAsync(body, ct).ConfigureAwait(false))
        {
            using var document = JsonDocument.Parse(data);
            var root = document.RootElement;

            // Events without a "type" discriminator are ignored.
            if (!root.TryGetProperty("type", out var typeProperty))
            {
                continue;
            }

            switch (typeProperty.GetString())
            {
                case "content_block_delta":
                    if (root.TryGetProperty("delta", out var delta))
                    {
                        // Text deltas and tool-input JSON fragments are both surfaced as StreamingDelta.
                        if (delta.TryGetProperty("text", out var text))
                        {
                            yield return new SDKMessage.StreamingDelta(text.GetString() ?? string.Empty);
                        }
                        else if (delta.TryGetProperty("partial_json", out var partialJson))
                        {
                            yield return new SDKMessage.StreamingDelta(partialJson.GetString() ?? string.Empty);
                        }
                    }

                    break;

                case "content_block_start":
                    if (root.TryGetProperty("content_block", out var contentBlock)
                        && contentBlock.TryGetProperty("type", out var blockType)
                        && string.Equals(blockType.GetString(), "tool_use", StringComparison.OrdinalIgnoreCase))
                    {
                        // Clone so the element stays valid after `document` is disposed.
                        var input = contentBlock.TryGetProperty("input", out var inputProperty)
                            ? inputProperty.Clone()
                            : EmptyJsonObject;

                        yield return new SDKMessage.ToolUseStart(
                            GetStringOrEmpty(contentBlock, "id"),
                            GetStringOrEmpty(contentBlock, "name"),
                            input);
                    }

                    break;

                case "message_stop":
                case "response.completed":
                    yield break;
            }
        }
    }

    /// <summary>
    /// Resolves the value sent as the <c>Bearer</c> credential, or <c>null</c> when none is configured.
    /// </summary>
    /// <remarks>
    /// A directly configured token wins. Otherwise, when both an access key id and a secret key are
    /// present, they are joined as <c>accessKey:secretKey</c>, with <c>:sessionToken</c> appended when
    /// a session token is found.
    /// </remarks>
    private static string? ResolveBearerToken(IConfiguration? configuration)
    {
        // NOTE(review): AWS_SESSION_TOKEN is accepted here as a stand-alone bearer token. Because of
        // that, whenever it is set this method returns before the composite branch below is reached.
        var directToken = GetSetting(
            configuration,
            "AWS_BEARER_TOKEN_BEDROCK",
            "AWS_SESSION_TOKEN",
            "AWS_ACCESS_TOKEN",
            "Bedrock:BearerToken");
        if (!string.IsNullOrWhiteSpace(directToken))
        {
            return directToken;
        }

        var accessKey = GetSetting(configuration, "AWS_ACCESS_KEY_ID", "Bedrock:AccessKeyId");
        var secretKey = GetSetting(configuration, "AWS_SECRET_ACCESS_KEY", "Bedrock:SecretAccessKey");
        if (string.IsNullOrWhiteSpace(accessKey) || string.IsNullOrWhiteSpace(secretKey))
        {
            return null;
        }

        // NOTE(review): this places the secret access key in the Authorization header verbatim -
        // confirm the target endpoint really expects this format before relying on it.
        var sessionToken = GetSetting(configuration, "AWS_SESSION_TOKEN", "Bedrock:SessionToken");
        return string.IsNullOrWhiteSpace(sessionToken)
            ? $"{accessKey}:{secretKey}"
            : $"{accessKey}:{secretKey}:{sessionToken}";
    }

    /// <summary>
    /// Returns the first non-blank value found for <paramref name="keys"/>, trying each key in order,
    /// first as an environment variable and then in <paramref name="configuration"/>.
    /// </summary>
    /// <param name="configuration">Optional configuration source; skipped when <c>null</c>.</param>
    /// <param name="keys">Candidate keys in priority order.</param>
    /// <returns>The first non-blank value, or <c>null</c> when every lookup is blank or missing.</returns>
    private static string? GetSetting(IConfiguration? configuration, params string[] keys)
    {
        foreach (var key in keys)
        {
            // The two sources are tested separately: an environment variable that exists but is
            // empty must not hide a usable configuration value for the same key.
            var fromEnvironment = Environment.GetEnvironmentVariable(key);
            if (!string.IsNullOrWhiteSpace(fromEnvironment))
            {
                return fromEnvironment;
            }

            var fromConfiguration = configuration?[key];
            if (!string.IsNullOrWhiteSpace(fromConfiguration))
            {
                return fromConfiguration;
            }
        }

        return null;
    }

    /// <summary>Returns <c>true</c> when <paramref name="value"/> is one of <c>1</c>, <c>true</c>, <c>yes</c> or <c>on</c> (case-insensitive, surrounding whitespace ignored).</summary>
    private static bool IsTruthy(string? value)
    {
        if (string.IsNullOrWhiteSpace(value))
        {
            return false;
        }

        var trimmed = value.Trim();
        foreach (var candidate in TruthyValues)
        {
            if (string.Equals(trimmed, candidate, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>Builds a <c>{}</c> element that outlives the document it was parsed from.</summary>
    private static JsonElement CreateEmptyJsonObject()
    {
        using var document = JsonDocument.Parse("{}");
        return document.RootElement.Clone();
    }

    /// <summary>Reads the named property as a string; yields an empty string when the property is absent or JSON <c>null</c>.</summary>
    private static string GetStringOrEmpty(JsonElement element, string propertyName) =>
        element.TryGetProperty(propertyName, out var property)
            ? property.GetString() ?? string.Empty
            : string.Empty;

    /// <summary>
    /// Flattens response and content headers into one case-insensitive dictionary. Multiple values of
    /// a header are joined with <c>,</c>; a content header replaces a response header of the same name.
    /// </summary>
    private static Dictionary<string, string> ToHeaderDictionary(HttpResponseHeaders responseHeaders, HttpContentHeaders contentHeaders)
    {
        var flattened = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);

        // Order matters: content headers are written last and therefore win on a name clash.
        foreach (var source in new HttpHeaders[] { responseHeaders, contentHeaders })
        {
            foreach (var (name, values) in source)
            {
                flattened[name] = string.Join(",", values);
            }
        }

        return flattened;
    }

    /// <summary>
    /// Creates the exception thrown when the rate-limit service rejects a response, mentioning the
    /// retry delay when the service reports a positive one.
    /// </summary>
    /// <param name="headers">Flattened response headers handed to the rate-limit service.</param>
    private Exception CreateRateLimitException(IReadOnlyDictionary<string, string> headers)
    {
        // Reuse the instance when it already is an IDictionary; otherwise copy it.
        var lookup = headers as IDictionary<string, string>
            ?? new Dictionary<string, string>(headers, StringComparer.OrdinalIgnoreCase);
        var retryAfter = _rateLimitService?.GetRetryAfter(lookup);

        if (retryAfter is { } delay && delay > TimeSpan.Zero)
        {
            // Invariant formatting keeps the message identical regardless of the current culture.
            return new HttpRequestException(
                FormattableString.Invariant($"Bedrock rate limit exceeded. Retry after {delay.TotalSeconds:F0} seconds."));
        }

        return new HttpRequestException("Bedrock rate limit exceeded.");
    }

    /// <summary>
    /// Builds the exception for a non-success response, carrying the status code and a bounded
    /// excerpt of the response body so the failure reason is not lost.
    /// </summary>
    private static async Task<HttpRequestException> CreateErrorResponseExceptionAsync(HttpResponseMessage response, CancellationToken ct)
    {
        string detail;
        try
        {
            detail = (await response.Content.ReadAsStringAsync(ct).ConfigureAwait(false)).Trim();
        }
        catch (Exception ex) when (ex is HttpRequestException or IOException)
        {
            // The body is diagnostic only; failing to read it must not mask the status error.
            detail = string.Empty;
        }

        if (detail.Length > MaxErrorBodyLength)
        {
            detail = detail[..MaxErrorBodyLength] + "...";
        }

        var message = new StringBuilder(FormattableString.Invariant($"Bedrock request failed with status {(int)response.StatusCode}"));
        if (!string.IsNullOrWhiteSpace(response.ReasonPhrase))
        {
            message.Append(" (").Append(response.ReasonPhrase).Append(')');
        }

        if (detail.Length > 0)
        {
            message.Append(": ").Append(detail);
        }
        else
        {
            message.Append('.');
        }

        return new HttpRequestException(message.ToString(), null, response.StatusCode);
    }

    /// <summary>
    /// Reads <paramref name="stream"/> as UTF-8 server-sent events and yields the data payload of each
    /// event. Consecutive <c>data:</c> lines are joined with <c>\n</c>; a blank line ends the event;
    /// lines that do not start with <c>data:</c> are ignored. A payload still buffered at the end of
    /// the stream is yielded as a final event. The stream is disposed when enumeration finishes.
    /// </summary>
    /// <param name="stream">Response body to read.</param>
    /// <param name="ct">Token observed before and during every line read.</param>
    /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled.</exception>
    private static async IAsyncEnumerable<string> ReadSseDataAsync(Stream stream, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct)
    {
        const int DataPrefixLength = 5; // length of "data:"

        using var reader = new StreamReader(stream, Encoding.UTF8, detectEncodingFromByteOrderMarks: false, bufferSize: 4096, leaveOpen: false);
        var pending = new StringBuilder();

        while (true)
        {
            // Throw instead of leaving the loop quietly: a cancelled read must not look like
            // a stream that completed normally.
            ct.ThrowIfCancellationRequested();

            var line = await reader.ReadLineAsync(ct).ConfigureAwait(false);
            if (line is null)
            {
                // End of stream: flush an event that was not terminated by a blank line.
                if (pending.Length > 0)
                {
                    yield return pending.ToString();
                }

                yield break;
            }

            if (line.Length == 0)
            {
                // A blank line terminates the current event.
                if (pending.Length > 0)
                {
                    yield return pending.ToString();
                    pending.Clear();
                }

                continue;
            }

            if (!line.StartsWith("data:", StringComparison.OrdinalIgnoreCase))
            {
                continue;
            }

            // Skip whitespace after the prefix by index, so no span or substring is needed.
            var start = DataPrefixLength;
            while (start < line.Length && char.IsWhiteSpace(line[start]))
            {
                start++;
            }

            if (start == line.Length)
            {
                continue;
            }

            if (pending.Length > 0)
            {
                pending.Append('\n');
            }

            pending.Append(line, start, line.Length - start);
        }
    }
}
|