307 lines
12 KiB
C#
307 lines
12 KiB
C#
using System.Runtime.CompilerServices;
|
|
using System.Text.Json;
|
|
using FluentAssertions;
|
|
using FreeCode.Core.Enums;
|
|
using FreeCode.Core.Interfaces;
|
|
using FreeCode.Core.Models;
|
|
using FreeCode.Engine;
|
|
using Microsoft.Extensions.Logging;
|
|
using NSubstitute;
|
|
using Xunit;
|
|
|
|
// Assembly-level xUnit setting: disables parallel execution of test collections for the WHOLE test
// assembly, not only for the tests declared in this file.
// NOTE(review): presumably required because these integration tests touch shared/process-wide state —
// confirm, and consider moving this attribute to a dedicated AssemblyInfo-style file so its global
// effect is discoverable.
[assembly: CollectionBehavior(DisableTestParallelization = true)]
|
|
|
|
namespace FreeCode.Tests.Integration;
|
|
|
|
/// <summary>
/// Integration tests that drive <see cref="QueryEngine"/> through its public surface against a
/// substituted <see cref="IApiProviderRouter"/> whose active provider replays scripted
/// <see cref="SDKMessage"/> streams.
/// </summary>
public sealed class QueryPipelineTests
{
    /// <summary>
    /// Upper bound for awaits that would otherwise block forever if the engine stops honouring
    /// cancellation, so a regression fails the test with a <see cref="TimeoutException"/>
    /// instead of hanging the test run.
    /// </summary>
    private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10);

    /// <summary>A plain text reply surfaces the user message, the streamed delta and an assistant message.</summary>
    [Fact]
    public async Task SubmitMessageAsync_SimpleText_ReturnsUserAndAssistantMessages()
    {
        var sut = CreateSut(CreateRouter(new SDKMessage.StreamingDelta("Hello")));

        var messages = await CollectAsync(sut.SubmitMessageAsync("Hi"));

        messages.OfType<SDKMessage.UserMessage>().Should().ContainSingle();
        messages.OfType<SDKMessage.StreamingDelta>().Select(message => message.Text).Should().Contain("Hello");
        messages.OfType<SDKMessage.AssistantMessage>().Select(message => message.Text).Should().Contain("Hello");
    }

    /// <summary>An assistant message emitted by the provider is recorded in the conversation history.</summary>
    [Fact]
    public async Task SubmitMessageAsync_AssistantMessage_AppendsToHistory()
    {
        var sut = CreateSut(CreateRouter(new SDKMessage.AssistantMessage("Ready", "assistant-1")));

        _ = await CollectAsync(sut.SubmitMessageAsync("Hello"));

        sut.GetMessages().Should().Contain(message =>
            message.MessageId == "assistant-1"
            && message.Role == MessageRole.Assistant
            && (string?)message.Content == "Ready");
    }

    /// <summary>Deltas are forwarded in order and concatenated into a single assistant history entry.</summary>
    [Fact]
    public async Task SubmitMessageAsync_StreamingDelta_AccumulatesText()
    {
        var sut = CreateSut(CreateRouter(
            new SDKMessage.StreamingDelta("Hello"),
            new SDKMessage.StreamingDelta(" "),
            new SDKMessage.StreamingDelta("world")));

        var messages = await CollectAsync(sut.SubmitMessageAsync("Say hello"));

        messages.OfType<SDKMessage.StreamingDelta>().Select(delta => delta.Text).Should().Equal("Hello", " ", "world");
        sut.GetMessages().Should().Contain(message =>
            message.Role == MessageRole.Assistant
            && (string?)message.Content == "Hello world");
    }

    /// <summary>A tool-use request is executed once and its output is surfaced and stored as a tool message.</summary>
    [Fact]
    public async Task SubmitMessageAsync_ToolUse_ExecutesToolAndReturnsResult()
    {
        var toolCalls = new List<string>();
        var sut = CreateSut(
            CreateRouter(new SDKMessage.ToolUseStart("tool-1", "read", Json("{\"path\":\"file.txt\"}"))),
            toolExecutor: (toolName, _, _, _, _) =>
            {
                toolCalls.Add(toolName);
                return Task.FromResult(("tool output", true, false));
            });

        var messages = await CollectAsync(sut.SubmitMessageAsync("Use a tool"));

        toolCalls.Should().Equal("read");
        messages.OfType<SDKMessage.ToolUseStart>()
            .Should().ContainSingle(message => message.ToolUseId == "tool-1" && message.ToolName == "read");
        messages.OfType<SDKMessage.ToolUseResult>()
            .Should().ContainSingle(message => message.ToolUseId == "tool-1" && message.Output == "tool output" && !message.ShouldContinue);
        sut.GetMessages().Should().Contain(message =>
            message.Role == MessageRole.Tool
            && message.ToolUseId == "tool-1"
            && message.ToolName == "read"
            && (string?)message.Content == "tool output");
    }

    /// <summary>When the executor reports the tool as not allowed, a permission denial and its result are emitted.</summary>
    [Fact]
    public async Task SubmitMessageAsync_ToolPermissionDenied_ReturnsPermissionDenial()
    {
        var sut = CreateSut(
            CreateRouter(new SDKMessage.ToolUseStart("tool-2", "bash", Json("{}"))),
            toolExecutor: (_, _, _, _, _) => Task.FromResult(("denied", false, false)));

        var messages = await CollectAsync(sut.SubmitMessageAsync("Run bash"));

        messages.OfType<SDKMessage.PermissionDenial>()
            .Should().ContainSingle(message => message.ToolName == "bash" && message.ToolUseId == "tool-2");
        messages.OfType<SDKMessage.ToolUseResult>()
            .Should().ContainSingle(message => message.ToolUseId == "tool-2" && message.Output == "denied");
    }

    /// <summary>Several tool-use requests in one stream are all executed, in stream order.</summary>
    [Fact]
    public async Task SubmitMessageAsync_MultipleToolUses_ExecutesAll()
    {
        var toolCalls = new List<string>();
        var sut = CreateSut(
            CreateRouter(
                new SDKMessage.ToolUseStart("tool-1", "read", Json("{}")),
                new SDKMessage.ToolUseStart("tool-2", "write", Json("{}"))),
            toolExecutor: (toolName, _, _, _, _) =>
            {
                toolCalls.Add(toolName);
                return Task.FromResult(($"completed-{toolName}", true, false));
            });

        var messages = await CollectAsync(sut.SubmitMessageAsync("Use multiple tools"));

        toolCalls.Should().Equal("read", "write");
        messages.OfType<SDKMessage.ToolUseResult>().Select(result => result.ToolUseId).Should().Equal("tool-1", "tool-2");
    }

    /// <summary>
    /// With an already-cancelled token, enumeration completes without throwing and only the user
    /// message is produced, even though the provider stream would otherwise block until cancelled.
    /// </summary>
    [Fact]
    public async Task SubmitMessageAsync_Cancellation_StopsGracefully()
    {
        var sut = CreateSut(CreateRouter((_, ct) => StreamUntilCancelledAsync(new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously), ct)));
        using var cts = new CancellationTokenSource();
        cts.Cancel();

        // Bounded so that a regression in token propagation fails fast instead of hanging.
        var messages = await CollectAsync(sut.SubmitMessageAsync("Cancel me", ct: cts.Token)).WaitAsync(TestTimeout);

        messages.Should().ContainSingle(message => message is SDKMessage.UserMessage);
    }

    /// <summary>Whitespace-only content is rejected with an <see cref="ArgumentException"/> during enumeration.</summary>
    [Fact]
    public async Task SubmitMessageAsync_EmptyContent_ThrowsArgumentException()
    {
        var sut = CreateSut(CreateRouter());

        var act = async () =>
        {
            await foreach (var _ in sut.SubmitMessageAsync(" "))
            {
            }
        };

        await act.Should().ThrowAsync<ArgumentException>();
    }

    /// <summary>
    /// <c>CancelAsync</c> aborts a query whose provider stream is still running; the enumeration
    /// ends with a cancellation or disposal exception.
    /// </summary>
    [Fact]
    public async Task CancelAsync_CancelsActiveQuery()
    {
        var started = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
        var sut = CreateSut(CreateRouter((_, ct) => StreamUntilCancelledAsync(started, ct)));

        var runningQuery = Task.Run(async () =>
        {
            await foreach (var _ in sut.SubmitMessageAsync("Long running"))
            {
            }
        });

        // Both waits are bounded: if the engine never reaches the provider, or ignores
        // cancellation, the test fails with a TimeoutException rather than hanging.
        await started.Task.WaitAsync(TestTimeout);
        await sut.CancelAsync();

        var exception = await Record.ExceptionAsync(() => runningQuery.WaitAsync(TestTimeout));

        exception.Should().NotBeNull();
        (exception is OperationCanceledException or ObjectDisposedException).Should().BeTrue(
            "cancelling an active query should surface as cancellation or disposal, but got {0}", exception);
    }

    /// <summary>A freshly constructed engine has no conversation history.</summary>
    [Fact]
    public void GetMessages_InitiallyEmpty()
    {
        var sut = CreateSut(CreateRouter());

        sut.GetMessages().Should().BeEmpty();
    }

    /// <summary>After one round-trip, both input and output token counts are non-zero.</summary>
    [Fact]
    public async Task GetCurrentUsage_TracksTokens()
    {
        var sut = CreateSut(CreateRouter(new SDKMessage.StreamingDelta("Hello from assistant")));

        _ = await CollectAsync(sut.SubmitMessageAsync("Hello from user"));
        var usage = sut.GetCurrentUsage();

        usage.InputTokens.Should().BeGreaterThan(0);
        usage.OutputTokens.Should().BeGreaterThan(0);
    }

    /// <summary>With the <c>EXTRACT_MEMORIES</c> flag on, memory extraction runs once over the conversation.</summary>
    [Fact]
    public async Task SubmitMessageAsync_ExtractsMemories_WhenFeatureEnabled()
    {
        // CreateSut uses a caller-supplied service as-is, so its behaviour is configured here.
        var sessionMemoryService = Substitute.For<ISessionMemoryService>();
        sessionMemoryService.GetCurrentMemoryAsync().Returns((string?)null);
        sessionMemoryService.TryExtractAsync(Arg.Any<IReadOnlyList<Message>>()).Returns(Task.CompletedTask);

        var sut = CreateSut(
            CreateRouter(new SDKMessage.StreamingDelta("Remember this")),
            featureFlagService: new StubFeatureFlagService("EXTRACT_MEMORIES"),
            sessionMemoryService: sessionMemoryService);

        _ = await CollectAsync(sut.SubmitMessageAsync("Store memory"));

        await sessionMemoryService.Received(1).TryExtractAsync(Arg.Is<IReadOnlyList<Message>>(messages => messages.Count >= 2));
    }

    /// <summary>
    /// Builds the engine under test. Any collaborator left <c>null</c> is replaced by a default
    /// (substitute or stub). Caller-supplied collaborators are passed through untouched, so their
    /// configuration is never overwritten and non-substitute implementations are supported.
    /// </summary>
    /// <param name="router">Router whose active provider supplies the scripted stream.</param>
    /// <param name="toolExecutor">
    /// Optional tool executor returning <c>(Output, IsAllowed, ShouldContinue)</c>; passed to the engine as-is (may be <c>null</c>).
    /// </param>
    private static QueryEngine CreateSut(
        IApiProviderRouter router,
        IToolRegistry? toolRegistry = null,
        IPermissionEngine? permissionEngine = null,
        IPromptBuilder? promptBuilder = null,
        ISessionMemoryService? sessionMemoryService = null,
        IFeatureFlagService? featureFlagService = null,
        Func<string, JsonElement, IPermissionEngine, ToolPermissionContext?, CancellationToken, Task<(string Output, bool IsAllowed, bool ShouldContinue)>>? toolExecutor = null)
    {
        return new QueryEngine(
            router,
            toolRegistry ?? new StubToolRegistry(),
            permissionEngine ?? Substitute.For<IPermissionEngine>(),
            promptBuilder ?? CreateDefaultPromptBuilder(),
            sessionMemoryService ?? CreateDefaultSessionMemoryService(),
            featureFlagService ?? new StubFeatureFlagService(),
            toolExecutor,
            Substitute.For<ILogger<QueryEngine>>());
    }

    /// <summary>Substitute prompt builder that always returns the fixed prompt "system prompt".</summary>
    private static IPromptBuilder CreateDefaultPromptBuilder()
    {
        var builder = Substitute.For<IPromptBuilder>();
        builder.BuildAsync(Arg.Any<IReadOnlyList<Message>>(), Arg.Any<ToolPermissionContext?>(), Arg.Any<SubmitMessageOptions>())
            .Returns(Task.FromResult("system prompt"));
        return builder;
    }

    /// <summary>Substitute memory service with no stored memory whose extraction completes immediately.</summary>
    private static ISessionMemoryService CreateDefaultSessionMemoryService()
    {
        var memoryService = Substitute.For<ISessionMemoryService>();
        memoryService.GetCurrentMemoryAsync().Returns((string?)null);
        memoryService.TryExtractAsync(Arg.Any<IReadOnlyList<Message>>()).Returns(Task.CompletedTask);
        return memoryService;
    }

    /// <summary>Router whose first stream replays <paramref name="responses"/>; later streams are empty.</summary>
    private static IApiProviderRouter CreateRouter(params SDKMessage[] responses)
        => CreateRouter([responses]);

    /// <summary>
    /// Router whose N-th stream replays <c>responseBatches[N]</c>. Once the batches are exhausted,
    /// further streams yield nothing. The batch index is unsynchronised, so it assumes sequential calls.
    /// </summary>
    private static IApiProviderRouter CreateRouter(IReadOnlyList<SDKMessage[]> responseBatches)
    {
        var index = 0;
        return CreateRouter((_, ct) => YieldMessages(index < responseBatches.Count ? responseBatches[index++] : [], ct));
    }

    /// <summary>
    /// Router whose active provider's <c>StreamAsync</c> delegates to <paramref name="streamFactory"/>,
    /// forwarding the request and cancellation token it was called with.
    /// </summary>
    private static IApiProviderRouter CreateRouter(Func<ApiRequest, CancellationToken, IAsyncEnumerable<SDKMessage>> streamFactory)
    {
        var provider = Substitute.For<IApiProvider>();
        provider.StreamAsync(Arg.Any<ApiRequest>(), Arg.Any<CancellationToken>())
            .Returns(call => streamFactory(call.Arg<ApiRequest>(), call.Arg<CancellationToken>()));

        var router = Substitute.For<IApiProviderRouter>();
        router.GetActiveProvider().Returns(provider);
        return router;
    }

    /// <summary>Drains an async stream into a list, preserving order.</summary>
    private static async Task<List<SDKMessage>> CollectAsync(IAsyncEnumerable<SDKMessage> messages)
    {
        var results = new List<SDKMessage>();
        await foreach (var message in messages)
        {
            results.Add(message);
        }

        return results;
    }

    /// <summary>
    /// Yields <paramref name="messages"/> one by one, checking <paramref name="ct"/> before each
    /// item and yielding control to the scheduler after each so consumers observe real asynchrony.
    /// </summary>
    private static async IAsyncEnumerable<SDKMessage> YieldMessages(
        IReadOnlyList<SDKMessage> messages,
        [EnumeratorCancellation] CancellationToken ct = default)
    {
        foreach (var message in messages)
        {
            ct.ThrowIfCancellationRequested();
            yield return message;
            await Task.Yield();
        }
    }

    /// <summary>
    /// Signals <paramref name="started"/>, yields a single "working" delta, then waits until
    /// <paramref name="ct"/> is cancelled. The wait then throws, since <c>Task.Delay</c> observes the token.
    /// </summary>
    private static async IAsyncEnumerable<SDKMessage> StreamUntilCancelledAsync(
        TaskCompletionSource started,
        [EnumeratorCancellation] CancellationToken ct = default)
    {
        started.TrySetResult();
        yield return new SDKMessage.StreamingDelta("working");
        await Task.Delay(Timeout.InfiniteTimeSpan, ct);
    }

    /// <summary>Parses <paramref name="json"/> into a detached element that outlives its <see cref="JsonDocument"/>.</summary>
    private static JsonElement Json(string json)
    {
        using var document = JsonDocument.Parse(json);
        return document.RootElement.Clone();
    }

    /// <summary>Tool registry that exposes no tools.</summary>
    private sealed class StubToolRegistry : IToolRegistry
    {
        public Task<IReadOnlyList<ITool>> GetToolsAsync(ToolPermissionContext? permissionContext = null)
            => Task.FromResult<IReadOnlyList<ITool>>([]);
    }

    /// <summary>Feature-flag service where exactly the given flags are enabled (case-insensitive).</summary>
    private sealed class StubFeatureFlagService(params string[] enabledFlags) : IFeatureFlagService
    {
        private readonly HashSet<string> _enabledFlags = new(enabledFlags, StringComparer.OrdinalIgnoreCase);

        public bool IsEnabled(string featureFlag) => _enabledFlags.Contains(featureFlag);

        public IReadOnlySet<string> GetEnabledFlags() => _enabledFlags;
    }
}
|