191 lines
7.5 KiB
C#
191 lines
7.5 KiB
C#
using Xunit;
|
|
using FluentAssertions;
|
|
using NSubstitute;
|
|
using System.Text.Json;
|
|
using FreeCode.Core.Enums;
|
|
using FreeCode.Core.Interfaces;
|
|
using FreeCode.Core.Models;
|
|
using FreeCode.Engine;
|
|
using FreeCode.Tests.Unit.Helpers;
|
|
|
|
namespace FreeCode.Tests.Unit.Engine;
|
|
|
|
public sealed class QueryEngineTests
|
|
{
|
|
/// <summary>
/// Submitting a whitespace-only message must surface an <see cref="ArgumentException"/>.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_WithBlankContent_ThrowsArgumentException()
{
    // Arrange: the provider has nothing to replay.
    var emptyProvider = new CapturingApiProvider([]);
    var sut = CreateEngine(emptyProvider);

    // Act: the stream is drained inside the delegate, so the exception is observed
    // whether it is raised when the call is made or while the result is enumerated.
    var submitBlank = async () => await sut.SubmitMessageAsync(" ").ToListAsync();

    // Assert
    await submitBlank.Should().ThrowAsync<ArgumentException>();
}
|
|
|
|
/// <summary>
/// A plain exchange yields the user message followed by the provider's assistant reply,
/// stores both in the engine's history, and sends exactly one request carrying the
/// requested model, the user content and the registered tool.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_YieldsUserAndAssistantMessages_AndBuildsRequest()
{
    // Arrange
    var provider = new CapturingApiProvider([new SDKMessage.AssistantMessage("Hello back", "assistant-1")]);
    var searchTool = new StubTool
    {
        Name = "search",
        Category = ToolCategory.FileSystem,
        DescriptionFactory = _ => Task.FromResult("Search files"),
        InputSchema = ParseJson("{\"type\":\"object\"}")
    };
    var registry = new StubToolRegistry { Tools = [searchTool] };
    var sut = CreateEngine(provider, toolRegistry: registry);
    var options = new SubmitMessageOptions(Model: "gpt-test");

    // Act
    var yielded = await sut.SubmitMessageAsync("Hello there", options).ToListAsync();

    // Assert: what the caller sees on the stream.
    yielded.Should().HaveCount(2);
    yielded[0].Should().BeOfType<SDKMessage.UserMessage>();
    yielded[1].Should().BeOfType<SDKMessage.AssistantMessage>().Which.Text.Should().Be("Hello back");

    // Assert: what the engine kept as conversation history.
    sut.GetMessages().Should().SatisfyRespectively(
        first =>
        {
            first.Role.Should().Be(MessageRole.User);
            first.Content.Should().Be("Hello there");
        },
        second =>
        {
            second.Role.Should().Be(MessageRole.Assistant);
            second.Content.Should().Be("Hello back");
        });

    // Assert: what was sent to the provider.
    provider.Requests.Should().ContainSingle();
    var sent = provider.Requests[0];
    sent.Model.Should().Be("gpt-test");
    sent.Messages[0].GetProperty("content").GetString().Should().Be("Hello there");
    sent.Tools[0].GetProperty("name").GetString().Should().Be("search");
}
|
|
|
|
/// <summary>
/// When the provider only emits streaming deltas, the deltas are forwarded as-is and a
/// single assistant message with the concatenated text is both yielded and stored.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_WithStreamingDeltas_SynthesizesAssistantMessage()
{
    // Arrange
    SDKMessage[] scripted =
    [
        new SDKMessage.StreamingDelta("Hello"),
        new SDKMessage.StreamingDelta(" world")
    ];
    var sut = CreateEngine(new CapturingApiProvider(scripted));

    // Act
    var yielded = await sut.SubmitMessageAsync("Hi").ToListAsync();

    // Assert: the deltas pass through in order...
    var deltaTexts = yielded.OfType<SDKMessage.StreamingDelta>().Select(delta => delta.Text);
    deltaTexts.Should().Equal("Hello", " world");

    // ...and exactly one assistant message carries their concatenation, on the stream and in history.
    yielded.OfType<SDKMessage.AssistantMessage>().Should().ContainSingle(reply => reply.Text == "Hello world");
    sut.GetMessages().Should().ContainSingle(entry => entry.Role == MessageRole.Assistant && Equals(entry.Content, "Hello world"));
}
|
|
|
|
/// <summary>
/// A tool-use request from the provider is run through the injected executor; the
/// executor's output is yielded as a tool result and stored as a tool-role message.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_WithToolUse_ExecutesToolAndStoresToolMessage()
{
    // Arrange: one scripted tool call, and an executor that always succeeds with fixed output.
    var readCall = new SDKMessage.ToolUseStart("tool-1", "Read", ParseJson("{\"path\":\"sample.txt\"}"));
    var provider = new CapturingApiProvider([readCall]);
    var sut = CreateEngine(
        provider,
        toolExecutor: static (_, _, _, _, _) => Task.FromResult(("file content", true, false)));

    // Act
    var yielded = await sut.SubmitMessageAsync("show file").ToListAsync();

    // Assert
    yielded.OfType<SDKMessage.ToolUseStart>().Should().ContainSingle();
    yielded.OfType<SDKMessage.ToolUseResult>().Should().ContainSingle(result => result.Output == "file content");
    sut.GetMessages().Should().Contain(entry => entry.Role == MessageRole.Tool && Equals(entry.Content, "file content") && entry.ToolName == "Read");
}
|
|
|
|
/// <summary>
/// An exception thrown by the tool executor does not escape the stream; it is reported
/// as a tool result whose output names the tool and the exception message.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_WhenToolExecutorThrows_ReturnsFailureToolResult()
{
    // Arrange: one scripted tool call, and an executor that always fails.
    var bashCall = new SDKMessage.ToolUseStart("tool-2", "Bash", ParseJson("{}"));
    var provider = new CapturingApiProvider([bashCall]);
    var sut = CreateEngine(
        provider,
        toolExecutor: static (_, _, _, _, _) => throw new InvalidOperationException("boom"));

    // Act
    var yielded = await sut.SubmitMessageAsync("run command").ToListAsync();

    // Assert
    var toolResults = yielded.OfType<SDKMessage.ToolUseResult>();
    toolResults.Should().ContainSingle(result => result.Output == "Tool 'Bash' failed: boom");
}
|
|
|
|
/// <summary>
/// With the <c>EXTRACT_MEMORIES</c> feature flag enabled, completing a submission hands the
/// conversation (user message and assistant reply) to the session memory service exactly once.
/// </summary>
[Fact]
public async Task SubmitMessageAsync_WhenExtractMemoriesEnabled_TriggersPostProcessing()
{
    // Arrange
    var provider = new CapturingApiProvider([new SDKMessage.AssistantMessage("done", "assistant-2")]);
    var sessionMemory = new StubSessionMemoryService();
    var featureFlags = new StubFeatureFlagService();
    featureFlags.EnabledFlags.Add("EXTRACT_MEMORIES");
    var engine = CreateEngine(provider, sessionMemoryService: sessionMemory, featureFlagService: featureFlags);

    // Act
    _ = await engine.SubmitMessageAsync("remember this").ToListAsync();

    // Extraction is presumably kicked off without being awaited by the stream (the original
    // test slept here for that reason) - confirm against QueryEngine. A fixed 50 ms sleep is a
    // race on a loaded machine and always costs the full delay, so poll with an upper bound
    // instead: return as soon as the stub has recorded a call, give up after ~2 s and let
    // the assertions below report the failure.
    const int maxPolls = 200;
    const int pollIntervalMs = 10;
    for (var poll = 0; poll < maxPolls && sessionMemory.ExtractedMessages.Count == 0; poll++)
    {
        await Task.Delay(pollIntervalMs);
    }

    // Assert
    sessionMemory.ExtractedMessages.Should().ContainSingle();
    sessionMemory.ExtractedMessages[0].Select(message => message.Content?.ToString()).Should().Contain(["remember this", "done"]);
}
|
|
|
|
/// <summary>
/// Builds a <see cref="QueryEngine"/> wired to test doubles.
/// </summary>
/// <param name="provider">Provider handed back by the stub router.</param>
/// <param name="toolRegistry">Tool registry to use; a fresh <c>StubToolRegistry</c> when omitted.</param>
/// <param name="sessionMemoryService">Session memory service to use; a fresh <c>StubSessionMemoryService</c> when omitted.</param>
/// <param name="featureFlagService">Feature flag service to use; a fresh <c>StubFeatureFlagService</c> when omitted.</param>
/// <param name="toolExecutor">Tool execution delegate, forwarded to the engine unchanged (<c>null</c> when omitted).</param>
/// <returns>The engine under test.</returns>
private static QueryEngine CreateEngine(
    IApiProvider provider,
    IToolRegistry? toolRegistry = null,
    ISessionMemoryService? sessionMemoryService = null,
    IFeatureFlagService? featureFlagService = null,
    Func<string, JsonElement, IPermissionEngine, ToolPermissionContext?, CancellationToken, Task<(string Output, bool IsAllowed, bool ShouldContinue)>>? toolExecutor = null)
{
    // Resolve the optional collaborators up front so the constructor call below reads as a flat list.
    toolRegistry ??= new StubToolRegistry();
    sessionMemoryService ??= new StubSessionMemoryService();
    featureFlagService ??= new StubFeatureFlagService();

    return new QueryEngine(
        new StubApiProviderRouter(provider),
        toolRegistry,
        new StubPermissionEngine(),
        new StubPromptBuilder(),
        sessionMemoryService,
        featureFlagService,
        toolExecutor,
        new TestLogger<QueryEngine>());
}
|
|
|
|
/// <summary>
/// Parses <paramref name="json"/> and returns its root element as a standalone value.
/// </summary>
/// <param name="json">JSON text to parse.</param>
/// <returns>A clone of the root element that stays valid after the backing document is disposed.</returns>
private static JsonElement ParseJson(string json)
{
    using (JsonDocument parsed = JsonDocument.Parse(json))
    {
        // The document is disposed on the way out, so hand back a clone rather than
        // an element that still points into it.
        JsonElement root = parsed.RootElement;
        return root.Clone();
    }
}
|
|
|
|
/// <summary>
/// <see cref="IApiProvider"/> test double that records every request it receives and
/// replays the message sequence supplied at construction.
/// </summary>
private sealed class CapturingApiProvider(IEnumerable<SDKMessage> messages) : IApiProvider
{
    /// <summary>Requests passed to <see cref="StreamAsync"/>, in call order.</summary>
    public List<ApiRequest> Requests { get; } = [];

    /// <summary>
    /// Records <paramref name="request"/> and returns a stream over the scripted messages.
    /// </summary>
    public IAsyncEnumerable<SDKMessage> StreamAsync(ApiRequest request, CancellationToken ct = default)
    {
        // This method is not itself an iterator, so the request is captured when the call
        // is made rather than when the returned stream is first enumerated.
        Requests.Add(request);
        return Replay(messages, ct);
    }

    /// <summary>
    /// Yields each scripted message in order, checking for cancellation before every item
    /// and yielding control after it.
    /// </summary>
    private static async IAsyncEnumerable<SDKMessage> Replay(
        IEnumerable<SDKMessage> scripted,
        [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct)
    {
        using var cursor = scripted.GetEnumerator();
        while (cursor.MoveNext())
        {
            ct.ThrowIfCancellationRequested();
            yield return cursor.Current;
            await Task.Yield();
        }
    }
}
|
|
|
|
/// <summary>
/// <see cref="IApiProviderRouter"/> test double that always resolves to the provider it was built with.
/// </summary>
private sealed class StubApiProviderRouter(IApiProvider provider) : IApiProviderRouter
{
    /// <summary>Returns the provider supplied at construction.</summary>
    public IApiProvider GetActiveProvider()
    {
        return provider;
    }
}
|
|
|
|
/// <summary>
/// <see cref="IPromptBuilder"/> test double that ignores its inputs and always produces the same prompt.
/// </summary>
private sealed class StubPromptBuilder : IPromptBuilder
{
    /// <summary>Returns a completed task holding the fixed prompt text; all arguments are unused.</summary>
    public Task<string> BuildAsync(IReadOnlyList<Message> messages, ToolPermissionContext? permissionContext, SubmitMessageOptions options)
    {
        const string fixedPrompt = "system";
        return Task.FromResult(fixedPrompt);
    }
}
|
|
}
|