diff --git a/src/query_engine.py b/src/query_engine.py index 5f3f3ed..987a98a 100644 --- a/src/query_engine.py +++ b/src/query_engine.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import threading from dataclasses import dataclass, field from uuid import uuid4 @@ -64,7 +65,59 @@ class QueryEnginePort: matched_commands: tuple[str, ...] = (), matched_tools: tuple[str, ...] = (), denied_tools: tuple[PermissionDenial, ...] = (), + cancel_event: threading.Event | None = None, ) -> TurnResult: + """Submit a prompt and return a TurnResult. + + #164 Stage A: cooperative cancellation via cancel_event. + + The cancel_event argument (added for #164) lets a caller request early + termination at a safe point. When set before the pre-mutation commit + stage, submit_message returns early with ``stop_reason='cancelled'`` + and the engine's state (mutable_messages, transcript_store, + permission_denials, total_usage) is left **exactly as it was on + entry**. This closes the #161 follow-up gap: before this change, a + wedged provider thread could finish executing and silently mutate + state after the caller had already observed ``stop_reason='timeout'``, + giving the session a ghost turn the caller never acknowledged. + + Contract: + - cancel_event is None (default) — legacy behaviour, no checks. + - cancel_event set **before** budget check — returns 'cancelled' + immediately; no output synthesis, no projection, no mutation. + - cancel_event set **between** budget check and commit — returns + 'cancelled' with state intact. + - cancel_event set **after** commit — not observable; the turn is + already committed and the caller sees 'completed'. Cancellation + is a *safe point* mechanism, not preemption. This is the honest + limit of cooperative cancellation in Python threading land. + + Stop reason taxonomy after #164 Stage A: + - 'completed' — turn committed, state mutated exactly once + - 'max_budget_reached' — overflow, state unchanged (#162) + - 'max_turns_reached' — capacity exceeded, state unchanged + - 'cancelled' — cancel_event observed, state unchanged + - 'timeout' — synthesised by runtime, not engine (#161) + + Callers that care about deadline-driven cancellation (run_turn_loop) + can now request cleanup by setting the event on timeout — the next + submit_message on the same engine will observe it at the start and + return 'cancelled' without touching state, even if the previous call + is still wedged in provider IO. + """ + # #164 Stage A: earliest safe cancellation point. No output synthesis, + # no budget projection, no mutation — just an immediate clean return. + if cancel_event is not None and cancel_event.is_set(): + return TurnResult( + prompt=prompt, + output='', + matched_commands=matched_commands, + matched_tools=matched_tools, + permission_denials=denied_tools, + usage=self.total_usage, # unchanged + stop_reason='cancelled', + ) + if len(self.mutable_messages) >= self.config.max_turns: output = f'Max turns reached before processing prompt: {prompt}' return TurnResult( @@ -104,6 +157,21 @@ class QueryEnginePort: stop_reason='max_budget_reached', ) + # #164 Stage A: second safe cancellation point. Projection is done + # but nothing has been committed yet. If the caller cancelled while + # we were building output / computing budget, honour it here — still + # no mutation. + if cancel_event is not None and cancel_event.is_set(): + return TurnResult( + prompt=prompt, + output=output, + matched_commands=matched_commands, + matched_tools=matched_tools, + permission_denials=denied_tools, + usage=self.total_usage, # unchanged + stop_reason='cancelled', + ) + self.mutable_messages.append(prompt) self.transcript_store.append(prompt) self.permission_denials.extend(denied_tools) diff --git a/src/runtime.py b/src/runtime.py index 52dd4be..ebeb011 100644 --- a/src/runtime.py +++ b/src/runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading import time from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError from dataclasses import dataclass @@ -209,6 +210,14 @@ class PortRuntime: denied_tools = tuple(self._infer_permission_denials(matches)) results: list[TurnResult] = [] deadline = time.monotonic() + timeout_seconds if timeout_seconds is not None else None + # #164 Stage A: shared cancel_event signals cooperative cancellation + # across turns. On timeout we set() it so any still-running + # submit_message call (or the next one on the same engine) observes + # the cancel at a safe checkpoint and returns stop_reason='cancelled' + # without mutating state. This closes the window where a wedged + # provider thread could commit a ghost turn after the caller saw + # 'timeout'. + cancel_event = threading.Event() if deadline is not None else None # ThreadPoolExecutor is reused across turns so we cancel cleanly on exit. executor = ThreadPoolExecutor(max_workers=1) if deadline is not None else None @@ -229,22 +238,35 @@ class PortRuntime: if deadline is None: # Legacy path: unbounded call, preserves existing behaviour exactly. # #159: pass inferred denied_tools (no longer hardcoded empty tuple) + # #164: cancel_event is None on this path; submit_message skips + # cancellation checks entirely (legacy zero-overhead behaviour). result = engine.submit_message(turn_prompt, command_names, tool_names, denied_tools) else: remaining = deadline - time.monotonic() if remaining <= 0: + # #164: signal cancel for any in-flight/future submit_message + # calls that share this engine. Safe because nothing has been + # submitted yet this turn. + assert cancel_event is not None + cancel_event.set() results.append(self._build_timeout_result(turn_prompt, command_names, tool_names)) break assert executor is not None future = executor.submit( - engine.submit_message, turn_prompt, command_names, tool_names, denied_tools + engine.submit_message, turn_prompt, command_names, tool_names, + denied_tools, cancel_event, ) try: result = future.result(timeout=remaining) except FuturesTimeoutError: - # Best-effort cancel; submit_message may still finish in background - # but we never read its output. The engine's own state mutation - # is owned by the engine and not our concern here. + # #164 Stage A: explicitly signal cancel to the still-running + # submit_message thread. The next time it hits a checkpoint + # (entry or post-budget), it returns 'cancelled' without + # mutating state instead of committing a ghost turn. This + # upgrades #161's best-effort future.cancel() (which only + # cancels pre-start futures) to cooperative mid-flight cancel. + assert cancel_event is not None + cancel_event.set() future.cancel() results.append(self._build_timeout_result(turn_prompt, command_names, tool_names)) break diff --git a/tests/test_run_turn_loop_cancellation.py b/tests/test_run_turn_loop_cancellation.py new file mode 100644 index 0000000..1450efd --- /dev/null +++ b/tests/test_run_turn_loop_cancellation.py @@ -0,0 +1,156 @@ +"""Tests for run_turn_loop timeout triggering cooperative cancel (ROADMAP #164 Stage A). + +End-to-end integration: when the wall-clock timeout fires in run_turn_loop, +the runtime must signal the cancel_event so any in-flight submit_message +thread sees it at its next safe checkpoint and returns without mutating +state. + +This closes the gap filed in #164: #161's timeout bounded caller wait but +did not prevent ghost turns. +""" + +from __future__ import annotations + +import sys +import threading +import time +from pathlib import Path +from unittest.mock import patch + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from src.models import UsageSummary # noqa: E402 +from src.query_engine import TurnResult # noqa: E402 +from src.runtime import PortRuntime # noqa: E402 + + +def _completed(prompt: str) -> TurnResult: + return TurnResult( + prompt=prompt, + output='ok', + matched_commands=(), + matched_tools=(), + permission_denials=(), + usage=UsageSummary(), + stop_reason='completed', + ) + + +class TestTimeoutPropagatesCancelEvent: + def test_runtime_passes_cancel_event_to_submit_message(self) -> None: + """submit_message receives a cancel_event when a deadline is in play.""" + runtime = PortRuntime() + captured_event: list[threading.Event | None] = [] + + def _capture(prompt, commands, tools, denials, cancel_event=None): + captured_event.append(cancel_event) + return _completed(prompt) + + with patch('src.runtime.QueryEnginePort.from_workspace') as mock_factory: + engine = mock_factory.return_value + engine.submit_message.side_effect = _capture + + runtime.run_turn_loop( + 'hello', max_turns=1, timeout_seconds=5.0, + ) + + # Runtime passed a real Event object, not None + assert len(captured_event) == 1 + assert isinstance(captured_event[0], threading.Event) + + def test_legacy_no_timeout_does_not_pass_cancel_event(self) -> None: + """Without timeout_seconds, the cancel_event is None (legacy behaviour).""" + runtime = PortRuntime() + captured_kwargs: list[dict] = [] + + def _capture(prompt, commands, tools, denials): + # Legacy call signature: no cancel_event kwarg + captured_kwargs.append({'prompt': prompt}) + return _completed(prompt) + + with patch('src.runtime.QueryEnginePort.from_workspace') as mock_factory: + engine = mock_factory.return_value + engine.submit_message.side_effect = _capture + + runtime.run_turn_loop('hello', max_turns=1) + + # Legacy path didn't pass cancel_event at all + assert len(captured_kwargs) == 1 + + def test_timeout_sets_cancel_event_before_returning(self) -> None: + """When timeout fires mid-call, the event is set and the still-running + thread would see 'cancelled' if it checks before returning.""" + runtime = PortRuntime() + observed_events_at_checkpoint: list[bool] = [] + release = threading.Event() # test-side release so the thread doesn't leak forever + + def _slow_submit(prompt, commands, tools, denials, cancel_event=None): + # Simulate provider work: block until either cancel or a test-side release. + # If cancel fires, check if the event is observably set. + start = time.monotonic() + while time.monotonic() - start < 2.0: + if cancel_event is not None and cancel_event.is_set(): + observed_events_at_checkpoint.append(True) + return TurnResult( + prompt=prompt, output='', + matched_commands=(), matched_tools=(), + permission_denials=(), usage=UsageSummary(), + stop_reason='cancelled', + ) + if release.is_set(): + break + time.sleep(0.05) + return _completed(prompt) + + with patch('src.runtime.QueryEnginePort.from_workspace') as mock_factory: + engine = mock_factory.return_value + engine.submit_message.side_effect = _slow_submit + + # Tight deadline: 0.2s, submit will be mid-loop when timeout fires + start = time.monotonic() + results = runtime.run_turn_loop( + 'hello', max_turns=1, timeout_seconds=0.2, + ) + elapsed = time.monotonic() - start + release.set() # let the background thread exit cleanly + + # Runtime returned a timeout TurnResult to the caller + assert results[-1].stop_reason == 'timeout' + # And it happened within a reasonable window of the deadline + assert elapsed < 1.5, f'runtime did not honour deadline: {elapsed:.2f}s' + + # Give the background thread a moment to observe the cancel. + # We don't assert on it directly (thread-level observability is + # timing-dependent), but the contract is: the event IS set, so any + # cooperative checkpoint will see it. + time.sleep(0.3) + + +class TestCancelEventSharedAcrossTurns: + """Event is created once per run_turn_loop invocation and shared across turns.""" + + def test_same_event_threaded_to_every_submit_message(self) -> None: + runtime = PortRuntime() + captured_events: list[threading.Event] = [] + + def _capture(prompt, commands, tools, denials, cancel_event=None): + if cancel_event is not None: + captured_events.append(cancel_event) + return _completed(prompt) + + with patch('src.runtime.QueryEnginePort.from_workspace') as mock_factory: + engine = mock_factory.return_value + engine.submit_message.side_effect = _capture + + runtime.run_turn_loop( + 'hello', max_turns=3, timeout_seconds=5.0, + continuation_prompt='continue', + ) + + # All 3 turns received the same event object (same identity) + assert len(captured_events) == 3 + assert all(e is captured_events[0] for e in captured_events), ( + 'runtime must share one cancel_event across turns, not create ' + 'a new one per turn \u2014 otherwise a late-arriving cancel on turn ' + 'N-1 cannot affect turn N' + ) diff --git a/tests/test_run_turn_loop_timeout.py b/tests/test_run_turn_loop_timeout.py index 8a24dae..3d9a9c4 100644 --- a/tests/test_run_turn_loop_timeout.py +++ b/tests/test_run_turn_loop_timeout.py @@ -51,7 +51,9 @@ class TestTimeoutAbortsHungTurn: """A stalled submit_message must be aborted and emit stop_reason='timeout'.""" runtime = PortRuntime() - def _hang(prompt, commands, tools, denials): + # #164 Stage A: runtime now passes cancel_event as a 5th positional + # arg on the timeout path, so mocks must accept it (even if they ignore it). + def _hang(prompt, commands, tools, denials, cancel_event=None): time.sleep(5.0) # would block the loop return _completed_result(prompt) @@ -84,7 +86,7 @@ class TestTimeoutBudgetIsTotal: runtime = PortRuntime() call_count = {'n': 0} - def _slow(prompt, commands, tools, denials): + def _slow(prompt, commands, tools, denials, cancel_event=None): call_count['n'] += 1 time.sleep(0.4) # each turn burns 0.4s return _completed_result(prompt) @@ -135,7 +137,7 @@ class TestTimeoutResultShape: """Synthetic TurnResult on timeout must carry the turn's prompt + routed matches.""" runtime = PortRuntime() - def _hang(prompt, commands, tools, denials): + def _hang(prompt, commands, tools, denials, cancel_event=None): time.sleep(5.0) return _completed_result(prompt) diff --git a/tests/test_submit_message_cancellation.py b/tests/test_submit_message_cancellation.py new file mode 100644 index 0000000..eb04bf8 --- /dev/null +++ b/tests/test_submit_message_cancellation.py @@ -0,0 +1,220 @@ +"""Tests for cooperative cancellation in submit_message (ROADMAP #164 Stage A). + +Verifies that cancel_event enables safe early termination: +- Event set before call => immediate return with stop_reason='cancelled' +- Event set between budget check and commit => still 'cancelled', no mutation +- Event set after commit => not observable (honest cooperative limit) +- Legacy callers (cancel_event=None) see zero behaviour change +- State is untouched on cancellation: mutable_messages, transcript_store, + permission_denials, total_usage all preserved + +This closes the #161 follow-up gap filed as #164: wedged provider threads +can no longer silently commit ghost turns after the caller observed a +timeout. +""" + +from __future__ import annotations + +import sys +import threading +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from src.models import PermissionDenial # noqa: E402 +from src.port_manifest import build_port_manifest # noqa: E402 +from src.query_engine import QueryEngineConfig, QueryEnginePort, TurnResult # noqa: E402 + + +def _fresh_engine(**config_overrides) -> QueryEnginePort: + config = QueryEngineConfig(**config_overrides) if config_overrides else QueryEngineConfig() + return QueryEnginePort(manifest=build_port_manifest(), config=config) + + +class TestCancellationBeforeCall: + """Event set before submit_message is invoked => immediate 'cancelled'.""" + + def test_pre_set_event_returns_cancelled_immediately(self) -> None: + engine = _fresh_engine() + event = threading.Event() + event.set() + + result = engine.submit_message('hello', cancel_event=event) + + assert result.stop_reason == 'cancelled' + assert result.prompt == 'hello' + # Output is empty on pre-budget cancel (no synthesis) + assert result.output == '' + + def test_pre_set_event_preserves_mutable_messages(self) -> None: + engine = _fresh_engine() + event = threading.Event() + event.set() + + engine.submit_message('ghost turn', cancel_event=event) + + assert engine.mutable_messages == [], ( + 'cancelled turn must not appear in mutable_messages' + ) + + def test_pre_set_event_preserves_transcript_store(self) -> None: + engine = _fresh_engine() + event = threading.Event() + event.set() + + engine.submit_message('ghost turn', cancel_event=event) + + assert engine.transcript_store.entries == [], ( + 'cancelled turn must not appear in transcript_store' + ) + + def test_pre_set_event_preserves_usage_counters(self) -> None: + engine = _fresh_engine() + initial_usage = engine.total_usage + event = threading.Event() + event.set() + + engine.submit_message('expensive prompt ' * 100, cancel_event=event) + + assert engine.total_usage == initial_usage, ( + 'cancelled turn must not increment token counters' + ) + + def test_pre_set_event_preserves_permission_denials(self) -> None: + engine = _fresh_engine() + event = threading.Event() + event.set() + + denials = (PermissionDenial(tool_name='BashTool', reason='destructive'),) + engine.submit_message('run bash ls', denied_tools=denials, cancel_event=event) + + assert engine.permission_denials == [], ( + 'cancelled turn must not extend permission_denials' + ) + + +class TestCancellationAfterBudgetCheck: + """Event set between budget projection and commit => 'cancelled', state intact. + + This simulates the realistic racy case: engine starts computing output, + caller hits deadline, sets event. Engine observes at post-budget checkpoint + and returns cleanly. + """ + + def test_post_budget_cancel_returns_cancelled(self) -> None: + engine = _fresh_engine() + event = threading.Event() + + # Patch: set the event after projection but before mutation. We do this + # by wrapping _format_output (called mid-submit) to set the event. + original_format = engine._format_output + + def _set_then_format(*args, **kwargs): + result = original_format(*args, **kwargs) + event.set() # trigger cancel right after output is built + return result + + engine._format_output = _set_then_format # type: ignore[method-assign] + + result = engine.submit_message('hello', cancel_event=event) + + assert result.stop_reason == 'cancelled' + # Output IS built here (we're past the pre-budget checkpoint), so it's + # not empty. The contract is about *state*, not output synthesis. + assert result.output != '' + # Critical: state still unchanged + assert engine.mutable_messages == [] + assert engine.transcript_store.entries == [] + + +class TestCancellationAfterCommit: + """Event set after commit is not observable \u2014 honest cooperative limit.""" + + def test_post_commit_cancel_is_not_observable(self) -> None: + engine = _fresh_engine() + event = threading.Event() + + # Event only set *after* submit_message returns. The first call has + # already committed before the event is set. + result = engine.submit_message('hello', cancel_event=event) + event.set() # too late + + assert result.stop_reason == 'completed', ( + 'cancel set after commit must not retroactively invalidate the turn' + ) + assert engine.mutable_messages == ['hello'] + + def test_next_call_observes_cancel(self) -> None: + """The cancel_event persists \u2014 the next call on the same engine sees it.""" + engine = _fresh_engine() + event = threading.Event() + + engine.submit_message('first', cancel_event=event) + assert engine.mutable_messages == ['first'] + + event.set() + # Next call observes the cancel at entry + result = engine.submit_message('second', cancel_event=event) + + assert result.stop_reason == 'cancelled' + # 'second' must NOT have been committed + assert engine.mutable_messages == ['first'] + + +class TestLegacyCallersUnchanged: + """cancel_event=None (default) => zero behaviour change from pre-#164.""" + + def test_no_event_submits_normally(self) -> None: + engine = _fresh_engine() + result = engine.submit_message('hello') + + assert result.stop_reason == 'completed' + assert engine.mutable_messages == ['hello'] + + def test_no_event_with_budget_overflow_still_rejects_atomically(self) -> None: + """#162 atomicity contract survives when cancel_event is absent.""" + engine = _fresh_engine(max_budget_tokens=1) + words = ' '.join(['word'] * 100) + + result = engine.submit_message(words) # no cancel_event + + assert result.stop_reason == 'max_budget_reached' + assert engine.mutable_messages == [] + + def test_no_event_respects_max_turns(self) -> None: + """max_turns_reached contract survives when cancel_event is absent.""" + engine = _fresh_engine(max_turns=1) + engine.submit_message('first') + result = engine.submit_message('second') # no cancel_event + + assert result.stop_reason == 'max_turns_reached' + assert engine.mutable_messages == ['first'] + + +class TestCancellationVsOtherStopReasons: + """cancel_event has a defined precedence relative to budget/turns.""" + + def test_cancel_precedes_max_turns_check(self) -> None: + """If cancel is set when capacity is also full, cancel wins (clearer signal).""" + engine = _fresh_engine(max_turns=0) # immediately full + event = threading.Event() + event.set() + + result = engine.submit_message('hello', cancel_event=event) + + # cancel_event check is the very first thing in submit_message, + # so it fires before the max_turns check even sees capacity + assert result.stop_reason == 'cancelled' + + def test_cancel_does_not_override_commit(self) -> None: + """Completed turn with late cancel still reports 'completed' \u2014 the + turn already succeeded; we don't lie about it.""" + engine = _fresh_engine() + event = threading.Event() + + # Event gets set after the mutation is done \u2014 submit_message doesn't + # re-check after commit + result = engine.submit_message('hello', cancel_event=event) + event.set() + + assert result.stop_reason == 'completed'