应文浩wenhao.ying@xiaobao100.com bce2612b64 feat: 完善具体实现
2026-04-06 15:25:34 +08:00

307 lines
12 KiB
C#

using System.Runtime.CompilerServices;
using System.Text.Json;
using FluentAssertions;
using FreeCode.Core.Enums;
using FreeCode.Core.Interfaces;
using FreeCode.Core.Models;
using FreeCode.Engine;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
// Disables xUnit's parallel test execution for the whole test assembly.
// NOTE(review): the reason for serializing the tests is not visible in this file — confirm it is still required.
[assembly: CollectionBehavior(DisableTestParallelization = true)]
namespace FreeCode.Tests.Integration;
public sealed class QueryPipelineTests
{
/// <summary>
/// With a provider that streams a single "Hello" delta, submitting "Hi" should emit exactly one
/// user message, a streaming delta carrying "Hello", and an assistant message carrying "Hello".
/// </summary>
[Fact]
public async Task SubmitMessageAsync_SimpleText_ReturnsUserAndAssistantMessages()
{
    // Arrange
    var router = CreateRouter(new SDKMessage.StreamingDelta("Hello"));
    var engine = CreateSut(router);

    // Act
    var emitted = await CollectAsync(engine.SubmitMessageAsync("Hi"));

    // Assert
    var userMessages = emitted.OfType<SDKMessage.UserMessage>();
    var deltaTexts = emitted.OfType<SDKMessage.StreamingDelta>().Select(delta => delta.Text);
    var assistantTexts = emitted.OfType<SDKMessage.AssistantMessage>().Select(reply => reply.Text);

    userMessages.Should().ContainSingle();
    deltaTexts.Should().Contain("Hello");
    assistantTexts.Should().Contain("Hello");
}
/// <summary>
/// An assistant message returned by the provider should be recorded in the engine's
/// message history with its id, the assistant role, and its text content.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_AssistantMessage_AppendsToHistory()
{
    // Arrange
    var router = CreateRouter(new SDKMessage.AssistantMessage("Ready", "assistant-1"));
    var engine = CreateSut(router);

    // Act: drain the stream; only the recorded history is inspected afterwards.
    await CollectAsync(engine.SubmitMessageAsync("Hello"));

    // Assert
    var history = engine.GetMessages();
    history.Should().Contain(entry =>
        entry.MessageId == "assistant-1"
        && entry.Role == MessageRole.Assistant
        && (string?)entry.Content == "Ready");
}
/// <summary>
/// Three consecutive deltas ("Hello", " ", "world") should be emitted in order and
/// show up in the history as one assistant message with the content "Hello world".
/// </summary>
[Fact]
public async Task SubmitMessageAsync_StreamingDelta_AccumulatesText()
{
    // Arrange
    var router = CreateRouter(
        new SDKMessage.StreamingDelta("Hello"),
        new SDKMessage.StreamingDelta(" "),
        new SDKMessage.StreamingDelta("world"));
    var engine = CreateSut(router);

    // Act
    var emitted = await CollectAsync(engine.SubmitMessageAsync("Say hello"));

    // Assert: the deltas are passed through unchanged and in order ...
    var deltaTexts = emitted.OfType<SDKMessage.StreamingDelta>().Select(chunk => chunk.Text);
    deltaTexts.Should().Equal("Hello", " ", "world");

    // ... and the history holds the concatenated assistant text.
    var history = engine.GetMessages();
    history.Should().Contain(entry =>
        entry.Role == MessageRole.Assistant
        && (string?)entry.Content == "Hello world");
}
/// <summary>
/// A tool-use request from the provider should invoke the injected tool executor once,
/// emit the matching start/result messages, and record the tool output in the history.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_ToolUse_ExecutesToolAndReturnsResult()
{
    // Arrange: the executor records which tool it was asked to run.
    var invokedTools = new List<string>();
    var router = CreateRouter(new SDKMessage.ToolUseStart("tool-1", "read", Json("{\"path\":\"file.txt\"}")));
    var engine = CreateSut(
        router,
        toolExecutor: (toolName, _, _, _, _) =>
        {
            invokedTools.Add(toolName);
            return Task.FromResult(("tool output", true, false));
        });

    // Act
    var emitted = await CollectAsync(engine.SubmitMessageAsync("Use a tool"));

    // Assert
    invokedTools.Should().Equal("read");

    var toolStarts = emitted.OfType<SDKMessage.ToolUseStart>();
    toolStarts.Should().ContainSingle(start => start.ToolUseId == "tool-1" && start.ToolName == "read");

    var toolResults = emitted.OfType<SDKMessage.ToolUseResult>();
    toolResults.Should().ContainSingle(result => result.ToolUseId == "tool-1" && result.Output == "tool output" && !result.ShouldContinue);

    var history = engine.GetMessages();
    history.Should().Contain(entry =>
        entry.Role == MessageRole.Tool
        && entry.ToolUseId == "tool-1"
        && entry.ToolName == "read"
        && (string?)entry.Content == "tool output");
}
/// <summary>
/// When the tool executor reports the call as not allowed, the stream should contain one
/// permission denial for that tool use as well as one tool result carrying the executor output.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_ToolPermissionDenied_ReturnsPermissionDenial()
{
    // Arrange: the executor returns IsAllowed = false for every call.
    var router = CreateRouter(new SDKMessage.ToolUseStart("tool-2", "bash", Json("{}")));
    var engine = CreateSut(
        router,
        toolExecutor: (_, _, _, _, _) => Task.FromResult(("denied", false, false)));

    // Act
    var emitted = await CollectAsync(engine.SubmitMessageAsync("Run bash"));

    // Assert
    var denials = emitted.OfType<SDKMessage.PermissionDenial>();
    denials.Should().ContainSingle(denial => denial.ToolName == "bash" && denial.ToolUseId == "tool-2");

    var toolResults = emitted.OfType<SDKMessage.ToolUseResult>();
    toolResults.Should().ContainSingle(result => result.ToolUseId == "tool-2" && result.Output == "denied");
}
/// <summary>
/// Two tool-use requests in one provider response should both be executed, in order,
/// and produce tool results in the same order.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_MultipleToolUses_ExecutesAll()
{
    // Arrange: the executor records each tool name and echoes it in its output.
    var invokedTools = new List<string>();
    var router = CreateRouter(
        new SDKMessage.ToolUseStart("tool-1", "read", Json("{}")),
        new SDKMessage.ToolUseStart("tool-2", "write", Json("{}")));
    var engine = CreateSut(
        router,
        toolExecutor: (toolName, _, _, _, _) =>
        {
            invokedTools.Add(toolName);
            return Task.FromResult(($"completed-{toolName}", true, false));
        });

    // Act
    var emitted = await CollectAsync(engine.SubmitMessageAsync("Use multiple tools"));

    // Assert
    invokedTools.Should().Equal("read", "write");

    var resultIds = emitted.OfType<SDKMessage.ToolUseResult>().Select(toolResult => toolResult.ToolUseId);
    resultIds.Should().Equal("tool-1", "tool-2");
}
/// <summary>
/// Submitting with an already-cancelled token should complete without throwing to the
/// consumer and emit exactly one user message.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_Cancellation_StopsGracefully()
{
    // Arrange: a provider whose stream only ends through cancellation.
    var router = CreateRouter((_, ct) =>
    {
        // The start signal is required by the helper but nobody awaits it in this test.
        var unusedStartSignal = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        return StreamUntilCancelledAsync(unusedStartSignal, ct);
    });
    var engine = CreateSut(router);

    using var cancellation = new CancellationTokenSource();
    cancellation.Cancel();

    // Act
    var emitted = await CollectAsync(engine.SubmitMessageAsync("Cancel me", ct: cancellation.Token));

    // Assert
    emitted.Should().ContainSingle(item => item is SDKMessage.UserMessage);
}
/// <summary>
/// Enumerating a submission whose content is only whitespace should throw
/// <see cref="ArgumentException"/>.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_EmptyContent_ThrowsArgumentException()
{
    // Arrange
    var engine = CreateSut(CreateRouter());

    // Act: the stream has to be enumerated for the call to actually run.
    Func<Task> enumerate = async () =>
    {
        await foreach (var _ in engine.SubmitMessageAsync(" "))
        {
        }
    };

    // Assert
    await enumerate.Should().ThrowAsync<ArgumentException>();
}
/// <summary>
/// Calling <c>CancelAsync</c> while a query is streaming should end the running enumeration
/// with an <see cref="OperationCanceledException"/> or an <see cref="ObjectDisposedException"/>.
/// </summary>
/// <remarks>
/// Both waits are bounded: if the engine never reaches the provider, or does not react to
/// <c>CancelAsync</c>, the test fails with a timeout instead of hanging the whole test run.
/// </remarks>
[Fact]
public async Task CancelAsync_CancelsActiveQuery()
{
    // Upper bound for every wait below; generous enough to stay stable on a slow CI agent.
    var timeout = TimeSpan.FromSeconds(5);

    // Arrange: the provider signals `started` once the engine begins streaming, then blocks
    // until its cancellation token fires.
    var started = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
    var sut = CreateSut(CreateRouter((_, ct) => StreamUntilCancelledAsync(started, ct)));

    var runningQuery = Task.Run(async () =>
    {
        await foreach (var _ in sut.SubmitMessageAsync("Long running"))
        {
        }
    });

    // Act: only cancel after the provider stream is known to be active.
    await started.Task.WaitAsync(timeout);
    await sut.CancelAsync();

    // A TimeoutException recorded here means the query ignored the cancellation; it then
    // fails the type check below rather than blocking forever.
    var exception = await Record.ExceptionAsync(async () => await runningQuery.WaitAsync(timeout));

    // Assert
    exception.Should().NotBeNull();
    (exception is OperationCanceledException || exception is ObjectDisposedException)
        .Should().BeTrue("cancelling an active query should surface as a cancellation, but got {0}", exception?.GetType().Name);
}
/// <summary>
/// A freshly created engine should report an empty message history.
/// </summary>
[Fact]
public void GetMessages_InitiallyEmpty()
{
    // Arrange
    var engine = CreateSut(CreateRouter());

    // Act
    var history = engine.GetMessages();

    // Assert
    history.Should().BeEmpty();
}
/// <summary>
/// After one completed exchange the engine should report non-zero input and output token counts.
/// </summary>
[Fact]
public async Task GetCurrentUsage_TracksTokens()
{
    // Arrange
    var router = CreateRouter(new SDKMessage.StreamingDelta("Hello from assistant"));
    var engine = CreateSut(router);

    // Act: drain the stream so the exchange is complete before usage is read.
    await CollectAsync(engine.SubmitMessageAsync("Hello from user"));
    var reportedUsage = engine.GetCurrentUsage();

    // Assert
    reportedUsage.InputTokens.Should().BeGreaterThan(0);
    reportedUsage.OutputTokens.Should().BeGreaterThan(0);
}
/// <summary>
/// With the "EXTRACT_MEMORIES" flag enabled, a completed exchange should trigger exactly one
/// memory-extraction call that receives at least two messages.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_ExtractsMemories_WhenFeatureEnabled()
{
    // Arrange: a memory service with no stored memory whose extraction completes immediately.
    var memoryService = Substitute.For<ISessionMemoryService>();
    memoryService.GetCurrentMemoryAsync().Returns((string?)null);
    memoryService.TryExtractAsync(Arg.Any<IReadOnlyList<Message>>()).Returns(Task.CompletedTask);

    var router = CreateRouter(new SDKMessage.StreamingDelta("Remember this"));
    var engine = CreateSut(
        router,
        featureFlagService: new StubFeatureFlagService("EXTRACT_MEMORIES"),
        sessionMemoryService: memoryService);

    // Act
    await CollectAsync(engine.SubmitMessageAsync("Store memory"));

    // Assert
    await memoryService
        .Received(1)
        .TryExtractAsync(Arg.Is<IReadOnlyList<Message>>(history => history.Count >= 2));
}
/// <summary>
/// Builds a <see cref="QueryEngine"/> for a test, filling every dependency the caller does not
/// supply with a stub or an NSubstitute substitute.
/// </summary>
/// <param name="router">Provider router the engine streams responses from.</param>
/// <param name="toolRegistry">Optional tool registry; defaults to a new <c>StubToolRegistry</c>.</param>
/// <param name="permissionEngine">Optional permission engine; defaults to an unconfigured substitute.</param>
/// <param name="promptBuilder">
/// Optional prompt builder, used exactly as supplied. The default is a substitute whose
/// <c>BuildAsync</c> returns "system prompt".
/// </param>
/// <param name="sessionMemoryService">
/// Optional memory service, used exactly as supplied. The default is a substitute that reports
/// no current memory and completes extraction immediately.
/// </param>
/// <param name="featureFlagService">Optional feature flags; defaults to a stub with no flag enabled.</param>
/// <param name="toolExecutor">Optional tool executor delegate, passed to the engine unchanged (may be null).</param>
/// <returns>The engine under test.</returns>
private static QueryEngine CreateSut(
    IApiProviderRouter router,
    IToolRegistry? toolRegistry = null,
    IPermissionEngine? permissionEngine = null,
    IPromptBuilder? promptBuilder = null,
    ISessionMemoryService? sessionMemoryService = null,
    IFeatureFlagService? featureFlagService = null,
    Func<string, JsonElement, IPermissionEngine, ToolPermissionContext?, CancellationToken, Task<(string Output, bool IsAllowed, bool ShouldContinue)>>? toolExecutor = null)
{
    // Only the defaults created here are stubbed. Calling Returns(...) on a caller-supplied
    // instance would overwrite the caller's own configuration, and it does not work at all
    // on an implementation that is not an NSubstitute substitute.
    return new QueryEngine(
        router,
        toolRegistry ?? new StubToolRegistry(),
        permissionEngine ?? Substitute.For<IPermissionEngine>(),
        promptBuilder ?? CreateDefaultPromptBuilder(),
        sessionMemoryService ?? CreateDefaultMemoryService(),
        featureFlagService ?? new StubFeatureFlagService(),
        toolExecutor,
        Substitute.For<ILogger<QueryEngine>>());

    // Substitute prompt builder that always yields the fixed prompt "system prompt".
    static IPromptBuilder CreateDefaultPromptBuilder()
    {
        var builder = Substitute.For<IPromptBuilder>();
        builder.BuildAsync(Arg.Any<IReadOnlyList<Message>>(), Arg.Any<ToolPermissionContext?>(), Arg.Any<SubmitMessageOptions>())
            .Returns(Task.FromResult("system prompt"));
        return builder;
    }

    // Substitute memory service with no stored memory and a no-op extraction.
    static ISessionMemoryService CreateDefaultMemoryService()
    {
        var memoryService = Substitute.For<ISessionMemoryService>();
        memoryService.GetCurrentMemoryAsync().Returns((string?)null);
        memoryService.TryExtractAsync(Arg.Any<IReadOnlyList<Message>>()).Returns(Task.CompletedTask);
        return memoryService;
    }
}
/// <summary>
/// Creates a router whose provider answers the first request with <paramref name="responses"/>
/// and every later request with an empty stream.
/// </summary>
private static IApiProviderRouter CreateRouter(params SDKMessage[] responses)
{
    // Wrap the messages as the single batch of the batch-based overload.
    IReadOnlyList<SDKMessage[]> singleBatch = [responses];
    return CreateRouter(singleBatch);
}
/// <summary>
/// Creates a router whose provider hands out one batch of messages per request, in order.
/// Once all batches are consumed, further requests receive an empty stream.
/// </summary>
private static IApiProviderRouter CreateRouter(IReadOnlyList<SDKMessage[]> responseBatches)
{
    // Position of the batch that the next provider call will receive.
    var nextBatch = 0;
    return CreateRouter(StreamNextBatch);

    IAsyncEnumerable<SDKMessage> StreamNextBatch(ApiRequest request, CancellationToken ct)
    {
        SDKMessage[] batch = [];
        if (nextBatch < responseBatches.Count)
        {
            batch = responseBatches[nextBatch];
            nextBatch++;
        }

        return YieldMessages(batch, ct);
    }
}
/// <summary>
/// Creates a substitute router whose active provider produces its response stream by calling
/// <paramref name="streamFactory"/> with the request and cancellation token of each call.
/// </summary>
private static IApiProviderRouter CreateRouter(Func<ApiRequest, CancellationToken, IAsyncEnumerable<SDKMessage>> streamFactory)
{
    var fakeProvider = Substitute.For<IApiProvider>();
    fakeProvider
        .StreamAsync(Arg.Any<ApiRequest>(), Arg.Any<CancellationToken>())
        .Returns(callInfo =>
        {
            // Forward the actual call arguments so the factory sees the engine's token.
            var request = callInfo.Arg<ApiRequest>();
            var token = callInfo.Arg<CancellationToken>();
            return streamFactory(request, token);
        });

    var fakeRouter = Substitute.For<IApiProviderRouter>();
    fakeRouter.GetActiveProvider().Returns(fakeProvider);
    return fakeRouter;
}
/// <summary>
/// Enumerates <paramref name="messages"/> to the end and returns every item in arrival order.
/// </summary>
private static async Task<List<SDKMessage>> CollectAsync(IAsyncEnumerable<SDKMessage> messages)
{
    var collected = new List<SDKMessage>();

    // Drive the enumerator by hand; it is disposed when the method exits, also on failure.
    await using var cursor = messages.GetAsyncEnumerator();
    while (await cursor.MoveNextAsync())
    {
        collected.Add(cursor.Current);
    }

    return collected;
}
/// <summary>
/// Streams <paramref name="messages"/> one at a time, checking <paramref name="ct"/> before
/// each item and yielding control after each item.
/// </summary>
private static async IAsyncEnumerable<SDKMessage> YieldMessages(
    IReadOnlyList<SDKMessage> messages,
    [EnumeratorCancellation] CancellationToken ct = default)
{
    for (var position = 0; position < messages.Count; position++)
    {
        ct.ThrowIfCancellationRequested();
        yield return messages[position];

        // Hand control back to the scheduler between items.
        await Task.Yield();
    }
}
/// <summary>
/// Signals <paramref name="started"/>, emits one "working" delta, and then waits with no
/// time limit until <paramref name="ct"/> is cancelled.
/// </summary>
private static async IAsyncEnumerable<SDKMessage> StreamUntilCancelledAsync(
    TaskCompletionSource started,
    [EnumeratorCancellation] CancellationToken ct = default)
{
    // TrySetResult keeps a repeated invocation with the same signal harmless.
    started.TrySetResult();

    SDKMessage firstChunk = new SDKMessage.StreamingDelta("working");
    yield return firstChunk;

    // Never completes on its own; the only way out is cancellation of ct.
    await Task.Delay(Timeout.InfiniteTimeSpan, ct);
}
/// <summary>
/// Parses <paramref name="json"/> and returns its root element as a copy that remains usable
/// after the parsing document has been disposed.
/// </summary>
private static JsonElement Json(string json)
{
    using (var parsed = JsonDocument.Parse(json))
    {
        // Clone detaches the element from the document disposed at the end of this scope.
        var detachedRoot = parsed.RootElement.Clone();
        return detachedRoot;
    }
}
/// <summary>
/// Tool registry stub that never exposes any tools.
/// </summary>
private sealed class StubToolRegistry : IToolRegistry
{
    /// <summary>Returns a completed task holding an empty tool list; the context is ignored.</summary>
    public Task<IReadOnlyList<ITool>> GetToolsAsync(ToolPermissionContext? permissionContext = null)
    {
        IReadOnlyList<ITool> noTools = [];
        return Task.FromResult(noTools);
    }
}
/// <summary>
/// Feature flag stub backed by a fixed set of enabled flag names, compared case-insensitively.
/// </summary>
private sealed class StubFeatureFlagService : IFeatureFlagService
{
    private readonly HashSet<string> _enabledFlags;

    /// <summary>Creates the stub with the given flags enabled; with no arguments nothing is enabled.</summary>
    public StubFeatureFlagService(params string[] enabledFlags)
    {
        _enabledFlags = new HashSet<string>(enabledFlags, StringComparer.OrdinalIgnoreCase);
    }

    /// <summary>Reports whether <paramref name="featureFlag"/> is one of the enabled flags.</summary>
    public bool IsEnabled(string featureFlag)
    {
        return _enabledFlags.Contains(featureFlag);
    }

    /// <summary>Returns the set of enabled flags.</summary>
    public IReadOnlySet<string> GetEnabledFlags()
    {
        return _enabledFlags;
    }
}
}