应文浩wenhao.ying@xiaobao100.com bce2612b64 feat: 完善具体实现
2026-04-06 15:25:34 +08:00

191 lines
7.5 KiB
C#

using Xunit;
using FluentAssertions;
using NSubstitute;
using System.Text.Json;
using FreeCode.Core.Enums;
using FreeCode.Core.Interfaces;
using FreeCode.Core.Models;
using FreeCode.Engine;
using FreeCode.Tests.Unit.Helpers;
namespace FreeCode.Tests.Unit.Engine;
/// <summary>
/// Unit tests for <see cref="QueryEngine"/>. Covers:
/// <list type="bullet">
/// <item>input validation;</item>
/// <item>construction of the request sent to the active provider;</item>
/// <item>aggregation of streaming deltas;</item>
/// <item>tool execution, both success and failure;</item>
/// <item>post-processing when the <c>EXTRACT_MEMORIES</c> feature flag is enabled.</item>
/// </list>
/// </summary>
public sealed class QueryEngineTests
{
    /// <summary>
    /// Upper bound for waiting on work the engine may finish after the message stream has
    /// completed. It is generous so slow CI agents do not cause flakes. A passing test
    /// returns as soon as its condition holds, so the bound costs nothing in the normal case.
    /// </summary>
    private static readonly TimeSpan BackgroundWorkTimeout = TimeSpan.FromSeconds(5);

    [Fact]
    public async Task SubmitMessageAsync_WithBlankContent_ThrowsArgumentException()
    {
        var engine = CreateEngine(new CapturingApiProvider([]));

        // Enumerate inside the delegate so the exception is observed whether validation
        // happens eagerly or only once the returned sequence is iterated.
        var act = async () => await engine.SubmitMessageAsync(" ").ToListAsync();

        await act.Should().ThrowAsync<ArgumentException>();
    }

    [Fact]
    public async Task SubmitMessageAsync_YieldsUserAndAssistantMessages_AndBuildsRequest()
    {
        var provider = new CapturingApiProvider([new SDKMessage.AssistantMessage("Hello back", "assistant-1")]);
        var toolRegistry = new StubToolRegistry
        {
            Tools =
            [
                new StubTool
                {
                    Name = "search",
                    Category = ToolCategory.FileSystem,
                    DescriptionFactory = _ => Task.FromResult("Search files"),
                    InputSchema = ParseJson("{\"type\":\"object\"}")
                }
            ]
        };
        var engine = CreateEngine(provider, toolRegistry: toolRegistry);

        var messages = await engine.SubmitMessageAsync("Hello there", new SubmitMessageOptions(Model: "gpt-test")).ToListAsync();

        // Yielded stream: the user's message followed by the provider's reply.
        messages.Should().HaveCount(2);
        messages[0].Should().BeOfType<SDKMessage.UserMessage>();
        messages[1].Should().BeOfType<SDKMessage.AssistantMessage>().Which.Text.Should().Be("Hello back");

        // Stored history contains the same two turns, in order.
        engine.GetMessages().Should().SatisfyRespectively(
            user =>
            {
                user.Role.Should().Be(MessageRole.User);
                user.Content.Should().Be("Hello there");
            },
            assistant =>
            {
                assistant.Role.Should().Be(MessageRole.Assistant);
                assistant.Content.Should().Be("Hello back");
            });

        // Exactly one request reached the provider. It carries the requested model,
        // the user's text and the registered tool.
        provider.Requests.Should().ContainSingle();
        provider.Requests[0].Model.Should().Be("gpt-test");
        provider.Requests[0].Messages[0].GetProperty("content").GetString().Should().Be("Hello there");
        provider.Requests[0].Tools[0].GetProperty("name").GetString().Should().Be("search");
    }

    [Fact]
    public async Task SubmitMessageAsync_WithStreamingDeltas_SynthesizesAssistantMessage()
    {
        var provider = new CapturingApiProvider([
            new SDKMessage.StreamingDelta("Hello"),
            new SDKMessage.StreamingDelta(" world")
        ]);
        var engine = CreateEngine(provider);

        var messages = await engine.SubmitMessageAsync("Hi").ToListAsync();

        // Deltas are passed through unchanged and also concatenated into one assistant message.
        messages.OfType<SDKMessage.StreamingDelta>().Select(x => x.Text).Should().Equal("Hello", " world");
        messages.OfType<SDKMessage.AssistantMessage>().Should().ContainSingle(x => x.Text == "Hello world");
        engine.GetMessages().Should().ContainSingle(x => x.Role == MessageRole.Assistant && Equals(x.Content, "Hello world"));
    }

    [Fact]
    public async Task SubmitMessageAsync_WithToolUse_ExecutesToolAndStoresToolMessage()
    {
        var provider = new CapturingApiProvider([
            new SDKMessage.ToolUseStart("tool-1", "Read", ParseJson("{\"path\":\"sample.txt\"}"))
        ]);
        var engine = CreateEngine(
            provider,
            toolExecutor: static (_, _, _, _, _) => Task.FromResult(("file content", true, false)));

        var messages = await engine.SubmitMessageAsync("show file").ToListAsync();

        messages.OfType<SDKMessage.ToolUseStart>().Should().ContainSingle();
        messages.OfType<SDKMessage.ToolUseResult>().Should().ContainSingle(x => x.Output == "file content");
        engine.GetMessages().Should().Contain(x => x.Role == MessageRole.Tool && Equals(x.Content, "file content") && x.ToolName == "Read");
    }

    [Fact]
    public async Task SubmitMessageAsync_WhenToolExecutorThrows_ReturnsFailureToolResult()
    {
        var provider = new CapturingApiProvider([
            new SDKMessage.ToolUseStart("tool-2", "Bash", ParseJson("{}"))
        ]);
        var engine = CreateEngine(
            provider,
            toolExecutor: static (_, _, _, _, _) => throw new InvalidOperationException("boom"));

        var messages = await engine.SubmitMessageAsync("run command").ToListAsync();

        // The executor's exception is surfaced as a failed tool result, not propagated.
        messages.OfType<SDKMessage.ToolUseResult>().Should().ContainSingle(x => x.Output == "Tool 'Bash' failed: boom");
    }

    [Fact]
    public async Task SubmitMessageAsync_WhenExtractMemoriesEnabled_TriggersPostProcessing()
    {
        var provider = new CapturingApiProvider([new SDKMessage.AssistantMessage("done", "assistant-2")]);
        var sessionMemory = new StubSessionMemoryService();
        var featureFlags = new StubFeatureFlagService();
        featureFlags.EnabledFlags.Add("EXTRACT_MEMORIES");
        var engine = CreateEngine(provider, sessionMemoryService: sessionMemory, featureFlagService: featureFlags);

        _ = await engine.SubmitMessageAsync("remember this").ToListAsync();

        // Post-processing may complete after the stream ends. Poll with a bounded timeout
        // instead of sleeping a fixed 50 ms, which could fail spuriously on a loaded machine.
        await WaitUntilAsync(() => sessionMemory.ExtractedMessages.Any(), BackgroundWorkTimeout);

        sessionMemory.ExtractedMessages.Should().ContainSingle();
        sessionMemory.ExtractedMessages[0].Select(message => message.Content?.ToString()).Should().Contain(["remember this", "done"]);
    }

    /// <summary>
    /// Builds a <see cref="QueryEngine"/> wired to in-memory stubs.
    /// </summary>
    /// <param name="provider">Provider returned by the router as the active provider.</param>
    /// <param name="toolRegistry">Tool registry; defaults to an empty <c>StubToolRegistry</c>.</param>
    /// <param name="sessionMemoryService">Session-memory service; defaults to a fresh stub.</param>
    /// <param name="featureFlagService">Feature-flag service; defaults to a fresh stub.</param>
    /// <param name="toolExecutor">
    /// Tool executor forwarded to the engine as-is. When <see langword="null"/>, the engine's
    /// own handling applies (assumed to be a default executor; confirm in <c>QueryEngine</c>).
    /// </param>
    private static QueryEngine CreateEngine(
        IApiProvider provider,
        IToolRegistry? toolRegistry = null,
        ISessionMemoryService? sessionMemoryService = null,
        IFeatureFlagService? featureFlagService = null,
        Func<string, JsonElement, IPermissionEngine, ToolPermissionContext?, CancellationToken, Task<(string Output, bool IsAllowed, bool ShouldContinue)>>? toolExecutor = null)
    {
        var router = new StubApiProviderRouter(provider);
        var promptBuilder = new StubPromptBuilder();
        return new QueryEngine(
            router,
            toolRegistry ?? new StubToolRegistry(),
            new StubPermissionEngine(),
            promptBuilder,
            sessionMemoryService ?? new StubSessionMemoryService(),
            featureFlagService ?? new StubFeatureFlagService(),
            toolExecutor,
            new TestLogger<QueryEngine>());
    }

    /// <summary>
    /// Parses <paramref name="json"/> and returns a detached copy of its root element.
    /// The copy stays valid after the owning <see cref="JsonDocument"/> is disposed.
    /// </summary>
    private static JsonElement ParseJson(string json)
    {
        using var document = JsonDocument.Parse(json);
        return document.RootElement.Clone();
    }

    /// <summary>
    /// Evaluates <paramref name="condition"/> repeatedly, pausing briefly between checks, until
    /// it returns <see langword="true"/> or <paramref name="timeout"/> elapses.
    /// On timeout it returns without throwing, so the caller's assertion produces the failure
    /// message.
    /// </summary>
    private static async Task WaitUntilAsync(Func<bool> condition, TimeSpan timeout)
    {
        using var timeoutSource = new CancellationTokenSource(timeout);
        while (!condition() && !timeoutSource.IsCancellationRequested)
        {
            await Task.Delay(10);
        }
    }

    /// <summary>
    /// Fake <see cref="IApiProvider"/>. It records every request it receives and replays
    /// the same fixed sequence of messages for each call.
    /// </summary>
    private sealed class CapturingApiProvider(IEnumerable<SDKMessage> messages) : IApiProvider
    {
        /// <summary>Requests passed to <see cref="StreamAsync"/>, in call order.</summary>
        public List<ApiRequest> Requests { get; } = [];

        public IAsyncEnumerable<SDKMessage> StreamAsync(ApiRequest request, CancellationToken ct = default)
        {
            Requests.Add(request);
            return ToAsync(messages, ct);
        }

        // Checks cancellation before each item. Yields control after each item so consumers
        // see a genuinely asynchronous stream.
        private static async IAsyncEnumerable<SDKMessage> ToAsync(IEnumerable<SDKMessage> items, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct)
        {
            foreach (var item in items)
            {
                ct.ThrowIfCancellationRequested();
                yield return item;
                await Task.Yield();
            }
        }
    }

    /// <summary>Router that always returns the single provider it was constructed with.</summary>
    private sealed class StubApiProviderRouter(IApiProvider provider) : IApiProviderRouter
    {
        public IApiProvider GetActiveProvider() => provider;
    }

    /// <summary>Prompt builder that ignores its inputs and always returns <c>"system"</c>.</summary>
    private sealed class StubPromptBuilder : IPromptBuilder
    {
        public Task<string> BuildAsync(IReadOnlyList<Message> messages, ToolPermissionContext? permissionContext, SubmitMessageOptions options)
            => Task.FromResult("system");
    }
}