// QueryPipelineTests: xUnit tests that drive QueryEngine through SubmitMessageAsync / CancelAsync /
// GetMessages / GetCurrentUsage, using an NSubstitute-backed provider router and small local stubs.
// Test parallelization is disabled for the whole assembly by the [assembly: CollectionBehavior] attribute below.
//
// NOTE(review): this copy of the file appears to have lost its generic type arguments (see `OfType()`,
// `new List()`, `Substitute.For()`, `Arg.Any>()`, `Func>?`, `Task>`). The code is left untouched here;
// restore the type arguments from the original before compiling — TODO confirm against version control.
//
// Tests (each builds its engine through CreateSut and a scripted router from CreateRouter):
//   SubmitMessageAsync_SimpleText_ReturnsUserAndAssistantMessages
//       Provider emits StreamingDelta("Hello"); asserts one message of a (stripped) type is present and that
//       "Hello" appears among the Text values of the returned messages.
//   SubmitMessageAsync_AssistantMessage_AppendsToHistory
//       Provider emits AssistantMessage("Ready", "assistant-1"); asserts GetMessages() contains an Assistant-role
//       entry with MessageId "assistant-1" and Content "Ready".
//   SubmitMessageAsync_StreamingDelta_AccumulatesText
//       Provider emits "Hello", " ", "world" as three deltas; asserts the delta texts come back in that order and
//       that history contains an Assistant-role entry whose Content is "Hello world".
//   SubmitMessageAsync_ToolUse_ExecutesToolAndReturnsResult
//       Provider emits ToolUseStart("tool-1", "read", ...); the toolExecutor records the tool name and returns
//       ("tool output", true, false). Asserts the executor ran once for "read", that the stream carries one
//       message with ToolUseId "tool-1"/ToolName "read" and one with Output "tool output" and ShouldContinue false,
//       and that history contains a Tool-role entry for it.
//       (Presumably the tuple's last element maps to ShouldContinue — verify against QueryEngine.)
//   SubmitMessageAsync_ToolPermissionDenied_ReturnsPermissionDenial
//       toolExecutor returns ("denied", false, false) for "bash"/"tool-2"; asserts exactly one message matching
//       the tool name/id and exactly one message with that id and Output "denied".
//   SubmitMessageAsync_MultipleToolUses_ExecutesAll
//       Two ToolUseStart messages ("read", "write"); asserts the executor saw them in that order and that result
//       ToolUseIds are "tool-1", "tool-2" in order.
//   SubmitMessageAsync_Cancellation_StopsGracefully
//       Submits with an already-cancelled token; asserts enumeration completes and exactly one collected message
//       is an SDKMessage.UserMessage.
//   SubmitMessageAsync_EmptyContent_ThrowsArgumentException
//       Submits " " (whitespace only) and asserts that enumerating the result throws (exception type argument
//       stripped in this copy; the test name indicates ArgumentException — TODO confirm).
//   CancelAsync_CancelsActiveQuery
//       Starts a never-ending stream on a background task, waits until the stream signals `started`, calls
//       CancelAsync, and asserts the background task faulted with OperationCanceledException or
//       ObjectDisposedException.
//   GetMessages_InitiallyEmpty
//       A freshly built engine reports no messages.
//   GetCurrentUsage_TracksTokens
//       After one exchange, asserts InputTokens and OutputTokens are both greater than zero.
//   SubmitMessageAsync_ExtractsMemories_WhenFeatureEnabled
//       With the "EXTRACT_MEMORIES" flag enabled, asserts the session memory service received exactly one
//       TryExtractAsync call whose argument has Count >= 2.
//
// Helpers:
//   CreateSut(...)               Builds a QueryEngine. Any collaborator not supplied is replaced by an NSubstitute
//                                substitute or a local stub. The prompt builder is configured to return
//                                "system prompt"; the memory service is configured to return null memory and a
//                                completed task from TryExtractAsync. These Returns(...) calls are applied to
//                                caller-supplied instances too, so callers are expected to pass substitutes.
//   CreateRouter(params ...)     Wraps the given responses as a single batch and forwards to the batch overload.
//   CreateRouter(batches)        Each provider StreamAsync call yields the next batch (shared `index` is advanced
//                                per call); once the batches are exhausted, further calls yield an empty stream.
//   CreateRouter(streamFactory)  Creates a substitute provider whose StreamAsync delegates to the factory, and a
//                                substitute router whose GetActiveProvider() returns that provider.
//   CollectAsync                 Enumerates an async stream to completion and returns the items as a list.
//   YieldMessages                Yields each message in order; checks the cancellation token before each item
//                                and awaits Task.Yield() after each item.
//   StreamUntilCancelledAsync    Signals `started`, yields StreamingDelta("working"), then awaits an infinite
//                                Task.Delay bound to the token, so it only ends through cancellation.
//   Json                         Parses a JSON string and returns a Clone() of the root element so the value
//                                remains usable after the JsonDocument is disposed.
//   StubToolRegistry             GetToolsAsync always returns an empty collection.
//   StubFeatureFlagService       Holds the enabled flag names in a case-insensitive (OrdinalIgnoreCase) set.
using System.Runtime.CompilerServices; using System.Text.Json; using FluentAssertions; using FreeCode.Core.Enums; using FreeCode.Core.Interfaces; using FreeCode.Core.Models; using FreeCode.Engine; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; [assembly: CollectionBehavior(DisableTestParallelization = true)] namespace FreeCode.Tests.Integration; public sealed class QueryPipelineTests { [Fact] public async Task SubmitMessageAsync_SimpleText_ReturnsUserAndAssistantMessages() { var sut = CreateSut(CreateRouter(new SDKMessage.StreamingDelta("Hello"))); var messages = await CollectAsync(sut.SubmitMessageAsync("Hi")); messages.OfType().Should().ContainSingle(); messages.OfType().Select(message => message.Text).Should().Contain("Hello"); messages.OfType().Select(message => message.Text).Should().Contain("Hello"); } [Fact] public async Task SubmitMessageAsync_AssistantMessage_AppendsToHistory() { var sut = CreateSut(CreateRouter(new SDKMessage.AssistantMessage("Ready", "assistant-1"))); _ = await CollectAsync(sut.SubmitMessageAsync("Hello")); sut.GetMessages().Should().Contain(message => message.MessageId == "assistant-1" && message.Role == MessageRole.Assistant && (string?)message.Content == "Ready"); } [Fact] public async Task SubmitMessageAsync_StreamingDelta_AccumulatesText() { var sut = CreateSut(CreateRouter( new SDKMessage.StreamingDelta("Hello"), new SDKMessage.StreamingDelta(" "), new SDKMessage.StreamingDelta("world"))); var messages = await CollectAsync(sut.SubmitMessageAsync("Say hello")); messages.OfType().Select(delta => delta.Text).Should().Equal("Hello", " ", "world"); sut.GetMessages().Should().Contain(message => message.Role == MessageRole.Assistant && (string?)message.Content == "Hello world"); } [Fact] public async Task SubmitMessageAsync_ToolUse_ExecutesToolAndReturnsResult() { var toolCalls = new List(); var sut = CreateSut( CreateRouter(new SDKMessage.ToolUseStart("tool-1", "read", Json("{\"path\":\"file.txt\"}"))), toolExecutor: 
(toolName, _, _, _, _) => { toolCalls.Add(toolName); return Task.FromResult(("tool output", true, false)); }); var messages = await CollectAsync(sut.SubmitMessageAsync("Use a tool")); toolCalls.Should().Equal("read"); messages.OfType() .Should().ContainSingle(message => message.ToolUseId == "tool-1" && message.ToolName == "read"); messages.OfType() .Should().ContainSingle(message => message.ToolUseId == "tool-1" && message.Output == "tool output" && !message.ShouldContinue); sut.GetMessages().Should().Contain(message => message.Role == MessageRole.Tool && message.ToolUseId == "tool-1" && message.ToolName == "read" && (string?)message.Content == "tool output"); } [Fact] public async Task SubmitMessageAsync_ToolPermissionDenied_ReturnsPermissionDenial() { var sut = CreateSut( CreateRouter(new SDKMessage.ToolUseStart("tool-2", "bash", Json("{}"))), toolExecutor: (_, _, _, _, _) => Task.FromResult(("denied", false, false))); var messages = await CollectAsync(sut.SubmitMessageAsync("Run bash")); messages.OfType() .Should().ContainSingle(message => message.ToolName == "bash" && message.ToolUseId == "tool-2"); messages.OfType() .Should().ContainSingle(message => message.ToolUseId == "tool-2" && message.Output == "denied"); } [Fact] public async Task SubmitMessageAsync_MultipleToolUses_ExecutesAll() { var toolCalls = new List(); var sut = CreateSut( CreateRouter( new SDKMessage.ToolUseStart("tool-1", "read", Json("{}")), new SDKMessage.ToolUseStart("tool-2", "write", Json("{}"))), toolExecutor: (toolName, _, _, _, _) => { toolCalls.Add(toolName); return Task.FromResult(($"completed-{toolName}", true, false)); }); var messages = await CollectAsync(sut.SubmitMessageAsync("Use multiple tools")); toolCalls.Should().Equal("read", "write"); messages.OfType().Select(result => result.ToolUseId).Should().Equal("tool-1", "tool-2"); } [Fact] public async Task SubmitMessageAsync_Cancellation_StopsGracefully() { var sut = CreateSut(CreateRouter((_, ct) => StreamUntilCancelledAsync(new 
TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously), ct))); using var cts = new CancellationTokenSource(); cts.Cancel(); var messages = await CollectAsync(sut.SubmitMessageAsync("Cancel me", ct: cts.Token)); messages.Should().ContainSingle(message => message is SDKMessage.UserMessage); } [Fact] public async Task SubmitMessageAsync_EmptyContent_ThrowsArgumentException() { var sut = CreateSut(CreateRouter()); var act = async () => { await foreach (var _ in sut.SubmitMessageAsync(" ")) { } }; await act.Should().ThrowAsync(); } [Fact] public async Task CancelAsync_CancelsActiveQuery() { var started = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var sut = CreateSut(CreateRouter((_, ct) => StreamUntilCancelledAsync(started, ct))); var runningQuery = Task.Run(async () => { await foreach (var _ in sut.SubmitMessageAsync("Long running")) { } }); await started.Task; await sut.CancelAsync(); var exception = await Record.ExceptionAsync(async () => await runningQuery); exception.Should().NotBeNull(); (exception is OperationCanceledException || exception is ObjectDisposedException).Should().BeTrue(); } [Fact] public void GetMessages_InitiallyEmpty() { var sut = CreateSut(CreateRouter()); sut.GetMessages().Should().BeEmpty(); } [Fact] public async Task GetCurrentUsage_TracksTokens() { var sut = CreateSut(CreateRouter(new SDKMessage.StreamingDelta("Hello from assistant"))); _ = await CollectAsync(sut.SubmitMessageAsync("Hello from user")); var usage = sut.GetCurrentUsage(); usage.InputTokens.Should().BeGreaterThan(0); usage.OutputTokens.Should().BeGreaterThan(0); } [Fact] public async Task SubmitMessageAsync_ExtractsMemories_WhenFeatureEnabled() { var sessionMemoryService = Substitute.For(); sessionMemoryService.GetCurrentMemoryAsync().Returns((string?)null); sessionMemoryService.TryExtractAsync(Arg.Any>()).Returns(Task.CompletedTask); var sut = CreateSut( CreateRouter(new SDKMessage.StreamingDelta("Remember this")), 
featureFlagService: new StubFeatureFlagService("EXTRACT_MEMORIES"), sessionMemoryService: sessionMemoryService); _ = await CollectAsync(sut.SubmitMessageAsync("Store memory")); await sessionMemoryService.Received(1).TryExtractAsync(Arg.Is>(messages => messages.Count >= 2)); } private static QueryEngine CreateSut( IApiProviderRouter router, IToolRegistry? toolRegistry = null, IPermissionEngine? permissionEngine = null, IPromptBuilder? promptBuilder = null, ISessionMemoryService? sessionMemoryService = null, IFeatureFlagService? featureFlagService = null, Func>? toolExecutor = null) { var builder = promptBuilder ?? Substitute.For(); builder.BuildAsync(Arg.Any>(), Arg.Any(), Arg.Any()) .Returns(Task.FromResult("system prompt")); var memoryService = sessionMemoryService ?? Substitute.For(); memoryService.GetCurrentMemoryAsync().Returns((string?)null); memoryService.TryExtractAsync(Arg.Any>()).Returns(Task.CompletedTask); return new QueryEngine( router, toolRegistry ?? new StubToolRegistry(), permissionEngine ?? Substitute.For(), builder, memoryService, featureFlagService ?? new StubFeatureFlagService(), toolExecutor, Substitute.For>()); } private static IApiProviderRouter CreateRouter(params SDKMessage[] responses) => CreateRouter([responses]); private static IApiProviderRouter CreateRouter(IReadOnlyList responseBatches) { var index = 0; return CreateRouter((_, ct) => YieldMessages(index < responseBatches.Count ? 
responseBatches[index++] : [], ct)); } private static IApiProviderRouter CreateRouter(Func> streamFactory) { var provider = Substitute.For(); provider.StreamAsync(Arg.Any(), Arg.Any()) .Returns(call => streamFactory(call.Arg(), call.Arg())); var router = Substitute.For(); router.GetActiveProvider().Returns(provider); return router; } private static async Task> CollectAsync(IAsyncEnumerable messages) { var results = new List(); await foreach (var message in messages) { results.Add(message); } return results; } private static async IAsyncEnumerable YieldMessages( IReadOnlyList messages, [EnumeratorCancellation] CancellationToken ct = default) { foreach (var message in messages) { ct.ThrowIfCancellationRequested(); yield return message; await Task.Yield(); } } private static async IAsyncEnumerable StreamUntilCancelledAsync( TaskCompletionSource started, [EnumeratorCancellation] CancellationToken ct = default) { started.TrySetResult(); yield return new SDKMessage.StreamingDelta("working"); await Task.Delay(Timeout.InfiniteTimeSpan, ct); } private static JsonElement Json(string json) { using var document = JsonDocument.Parse(json); return document.RootElement.Clone(); } private sealed class StubToolRegistry : IToolRegistry { public Task> GetToolsAsync(ToolPermissionContext? permissionContext = null) => Task.FromResult>([]); } private sealed class StubFeatureFlagService(params string[] enabledFlags) : IFeatureFlagService { private readonly HashSet _enabledFlags = new(enabledFlags, StringComparer.OrdinalIgnoreCase); public bool IsEnabled(string featureFlag) => _enabledFlags.Contains(featureFlag); public IReadOnlySet GetEnabledFlags() => _enabledFlags; } }