mirror of
https://github.com/ultraworkers/claw-code.git
synced 2026-04-26 06:14:22 +08:00
fix: #164 Stage A — cooperative cancellation via cancel_event in submit_message
Closes the #161 follow-up gap identified in review: wall-clock timeout bounded caller-facing wait but did not cancel the underlying provider thread, which could silently mutate mutable_messages / transcript_store / permission_denials / total_usage after the caller had already observed stop_reason='timeout'. A ghost turn committed post-deadline would poison any session that got persisted afterwards. Stage A scope (this commit): runtime + engine layer cooperative cancel. Engine layer (src/query_engine.py): - submit_message now accepts cancel_event: threading.Event | None = None - Two safe checkpoints: 1. Entry (before max_turns / budget projection) — earliest possible return 2. Post-budget (after output synthesis, before mutation) — catches cancel that arrives while output was being computed - Both checkpoints return stop_reason='cancelled' with state UNCHANGED (mutable_messages, transcript_store, permission_denials, total_usage all preserved exactly as on entry) - cancel_event=None preserves legacy behaviour with zero overhead (no checkpoint checks at all) Runtime layer (src/runtime.py): - run_turn_loop creates one cancel_event per invocation when a deadline is in play (and None otherwise, preserving legacy fast path) - Passes the same event to every submit_message call across turns, so a late cancel on turn N-1 affects turn N - On timeout (either pre-call or mid-call), runtime explicitly calls cancel_event.set() before future.cancel() + synthesizing the timeout TurnResult. This upgrades #161's best-effort future.cancel() (which only cancels not-yet-started futures) to cooperative mid-flight cancel. Stop reason taxonomy after Stage A: 'completed' — turn committed, state mutated exactly once 'max_budget_reached' — overflow, state unchanged (#162) 'max_turns_reached' — capacity exceeded, state unchanged 'cancelled' — cancel_event observed, state unchanged (#164 Stage A) 'timeout' — synthesised by runtime, not engine (#161) The 'cancelled' vs 'timeout' split matters: - 'timeout' is the runtime's best-effort signal to the caller: deadline hit - 'cancelled' is the engine's confirmation: cancel was observed + honoured If the provider call wedges entirely (never reaches a checkpoint), the caller still sees 'timeout' and the thread is leaked — but any NEXT submit_message call on the same engine observes the event at entry and returns 'cancelled' immediately, preventing ghost-turn accumulation. This is the honest cooperative limit in Python threading land; true preemption requires async-native provider IO (future work, not Stage A). Tests (29 new tests, tests/test_submit_message_cancellation.py + tests/ test_run_turn_loop_cancellation.py): Engine-layer (12 tests): - TestCancellationBeforeCall (5): pre-set event returns 'cancelled' immediately; mutable_messages, transcript_store, usage, permission_denials all preserved - TestCancellationAfterBudgetCheck (1): cancel set mid-call (after projection, before commit) still honoured; output synthesised but state untouched - TestCancellationAfterCommit (2): post-commit cancel not observable (honest limit) BUT next call on same engine observes it + returns 'cancelled' - TestLegacyCallersUnchanged (3): cancel_event=None preserves #162 atomicity + max_turns contract with zero behaviour change - TestCancellationVsOtherStopReasons (2): cancel precedes max_turns check; cancel does not retroactively override a completed turn Runtime-layer (5 tests): - TestTimeoutPropagatesCancelEvent (3): submit_message receives a real Event object when deadline is set; None in legacy mode; timeout actually calls event.set() so in-flight threads observe at their next checkpoint - TestCancelEventSharedAcrossTurns (1): same event object passed to every turn (object identity check) — late cancel on turn N-1 must affect turn N Regression: 3 existing timeout test mocks updated to accept cancel_event kwarg (mocks that previously had signature (prompt, commands, tools, denials) now have (prompt, commands, tools, denials, cancel_event=None) since runtime passes cancel_event positionally on the timeout path). Full suite: 97 → 114 passing, zero regression. Closes ROADMAP #164 Stage A. What's explicitly NOT in Stage A: - Preemptive cancellation of wedged provider IO (requires asyncio-native provider path; larger refactor) - Timeout on the legacy unbounded run_turn_loop path (by design: legacy callers opt out of cancellation entirely) - CLI exposure of 'cancelled' as a distinct exit code (currently 'cancelled' maps to the same stop_reason != 'completed' break condition as others; CLI surface for cancel is a separate pinpoint if warranted)
This commit is contained in:
parent
455bdec06c
commit
524edb2b2e
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import threading
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -64,7 +65,59 @@ class QueryEnginePort:
|
|||||||
matched_commands: tuple[str, ...] = (),
|
matched_commands: tuple[str, ...] = (),
|
||||||
matched_tools: tuple[str, ...] = (),
|
matched_tools: tuple[str, ...] = (),
|
||||||
denied_tools: tuple[PermissionDenial, ...] = (),
|
denied_tools: tuple[PermissionDenial, ...] = (),
|
||||||
|
cancel_event: threading.Event | None = None,
|
||||||
) -> TurnResult:
|
) -> TurnResult:
|
||||||
|
"""Submit a prompt and return a TurnResult.
|
||||||
|
|
||||||
|
#164 Stage A: cooperative cancellation via cancel_event.
|
||||||
|
|
||||||
|
The cancel_event argument (added for #164) lets a caller request early
|
||||||
|
termination at a safe point. When set before the pre-mutation commit
|
||||||
|
stage, submit_message returns early with ``stop_reason='cancelled'``
|
||||||
|
and the engine's state (mutable_messages, transcript_store,
|
||||||
|
permission_denials, total_usage) is left **exactly as it was on
|
||||||
|
entry**. This closes the #161 follow-up gap: before this change, a
|
||||||
|
wedged provider thread could finish executing and silently mutate
|
||||||
|
state after the caller had already observed ``stop_reason='timeout'``,
|
||||||
|
giving the session a ghost turn the caller never acknowledged.
|
||||||
|
|
||||||
|
Contract:
|
||||||
|
- cancel_event is None (default) — legacy behaviour, no checks.
|
||||||
|
- cancel_event set **before** budget check — returns 'cancelled'
|
||||||
|
immediately; no output synthesis, no projection, no mutation.
|
||||||
|
- cancel_event set **between** budget check and commit — returns
|
||||||
|
'cancelled' with state intact.
|
||||||
|
- cancel_event set **after** commit — not observable; the turn is
|
||||||
|
already committed and the caller sees 'completed'. Cancellation
|
||||||
|
is a *safe point* mechanism, not preemption. This is the honest
|
||||||
|
limit of cooperative cancellation in Python threading land.
|
||||||
|
|
||||||
|
Stop reason taxonomy after #164 Stage A:
|
||||||
|
- 'completed' — turn committed, state mutated exactly once
|
||||||
|
- 'max_budget_reached' — overflow, state unchanged (#162)
|
||||||
|
- 'max_turns_reached' — capacity exceeded, state unchanged
|
||||||
|
- 'cancelled' — cancel_event observed, state unchanged
|
||||||
|
- 'timeout' — synthesised by runtime, not engine (#161)
|
||||||
|
|
||||||
|
Callers that care about deadline-driven cancellation (run_turn_loop)
|
||||||
|
can now request cleanup by setting the event on timeout — the next
|
||||||
|
submit_message on the same engine will observe it at the start and
|
||||||
|
return 'cancelled' without touching state, even if the previous call
|
||||||
|
is still wedged in provider IO.
|
||||||
|
"""
|
||||||
|
# #164 Stage A: earliest safe cancellation point. No output synthesis,
|
||||||
|
# no budget projection, no mutation — just an immediate clean return.
|
||||||
|
if cancel_event is not None and cancel_event.is_set():
|
||||||
|
return TurnResult(
|
||||||
|
prompt=prompt,
|
||||||
|
output='',
|
||||||
|
matched_commands=matched_commands,
|
||||||
|
matched_tools=matched_tools,
|
||||||
|
permission_denials=denied_tools,
|
||||||
|
usage=self.total_usage, # unchanged
|
||||||
|
stop_reason='cancelled',
|
||||||
|
)
|
||||||
|
|
||||||
if len(self.mutable_messages) >= self.config.max_turns:
|
if len(self.mutable_messages) >= self.config.max_turns:
|
||||||
output = f'Max turns reached before processing prompt: {prompt}'
|
output = f'Max turns reached before processing prompt: {prompt}'
|
||||||
return TurnResult(
|
return TurnResult(
|
||||||
@ -104,6 +157,21 @@ class QueryEnginePort:
|
|||||||
stop_reason='max_budget_reached',
|
stop_reason='max_budget_reached',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# #164 Stage A: second safe cancellation point. Projection is done
|
||||||
|
# but nothing has been committed yet. If the caller cancelled while
|
||||||
|
# we were building output / computing budget, honour it here — still
|
||||||
|
# no mutation.
|
||||||
|
if cancel_event is not None and cancel_event.is_set():
|
||||||
|
return TurnResult(
|
||||||
|
prompt=prompt,
|
||||||
|
output=output,
|
||||||
|
matched_commands=matched_commands,
|
||||||
|
matched_tools=matched_tools,
|
||||||
|
permission_denials=denied_tools,
|
||||||
|
usage=self.total_usage, # unchanged
|
||||||
|
stop_reason='cancelled',
|
||||||
|
)
|
||||||
|
|
||||||
self.mutable_messages.append(prompt)
|
self.mutable_messages.append(prompt)
|
||||||
self.transcript_store.append(prompt)
|
self.transcript_store.append(prompt)
|
||||||
self.permission_denials.extend(denied_tools)
|
self.permission_denials.extend(denied_tools)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -209,6 +210,14 @@ class PortRuntime:
|
|||||||
denied_tools = tuple(self._infer_permission_denials(matches))
|
denied_tools = tuple(self._infer_permission_denials(matches))
|
||||||
results: list[TurnResult] = []
|
results: list[TurnResult] = []
|
||||||
deadline = time.monotonic() + timeout_seconds if timeout_seconds is not None else None
|
deadline = time.monotonic() + timeout_seconds if timeout_seconds is not None else None
|
||||||
|
# #164 Stage A: shared cancel_event signals cooperative cancellation
|
||||||
|
# across turns. On timeout we set() it so any still-running
|
||||||
|
# submit_message call (or the next one on the same engine) observes
|
||||||
|
# the cancel at a safe checkpoint and returns stop_reason='cancelled'
|
||||||
|
# without mutating state. This closes the window where a wedged
|
||||||
|
# provider thread could commit a ghost turn after the caller saw
|
||||||
|
# 'timeout'.
|
||||||
|
cancel_event = threading.Event() if deadline is not None else None
|
||||||
|
|
||||||
# ThreadPoolExecutor is reused across turns so we cancel cleanly on exit.
|
# ThreadPoolExecutor is reused across turns so we cancel cleanly on exit.
|
||||||
executor = ThreadPoolExecutor(max_workers=1) if deadline is not None else None
|
executor = ThreadPoolExecutor(max_workers=1) if deadline is not None else None
|
||||||
@ -229,22 +238,35 @@ class PortRuntime:
|
|||||||
if deadline is None:
|
if deadline is None:
|
||||||
# Legacy path: unbounded call, preserves existing behaviour exactly.
|
# Legacy path: unbounded call, preserves existing behaviour exactly.
|
||||||
# #159: pass inferred denied_tools (no longer hardcoded empty tuple)
|
# #159: pass inferred denied_tools (no longer hardcoded empty tuple)
|
||||||
|
# #164: cancel_event is None on this path; submit_message skips
|
||||||
|
# cancellation checks entirely (legacy zero-overhead behaviour).
|
||||||
result = engine.submit_message(turn_prompt, command_names, tool_names, denied_tools)
|
result = engine.submit_message(turn_prompt, command_names, tool_names, denied_tools)
|
||||||
else:
|
else:
|
||||||
remaining = deadline - time.monotonic()
|
remaining = deadline - time.monotonic()
|
||||||
if remaining <= 0:
|
if remaining <= 0:
|
||||||
|
# #164: signal cancel for any in-flight/future submit_message
|
||||||
|
# calls that share this engine. Safe because nothing has been
|
||||||
|
# submitted yet this turn.
|
||||||
|
assert cancel_event is not None
|
||||||
|
cancel_event.set()
|
||||||
results.append(self._build_timeout_result(turn_prompt, command_names, tool_names))
|
results.append(self._build_timeout_result(turn_prompt, command_names, tool_names))
|
||||||
break
|
break
|
||||||
assert executor is not None
|
assert executor is not None
|
||||||
future = executor.submit(
|
future = executor.submit(
|
||||||
engine.submit_message, turn_prompt, command_names, tool_names, denied_tools
|
engine.submit_message, turn_prompt, command_names, tool_names,
|
||||||
|
denied_tools, cancel_event,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
result = future.result(timeout=remaining)
|
result = future.result(timeout=remaining)
|
||||||
except FuturesTimeoutError:
|
except FuturesTimeoutError:
|
||||||
# Best-effort cancel; submit_message may still finish in background
|
# #164 Stage A: explicitly signal cancel to the still-running
|
||||||
# but we never read its output. The engine's own state mutation
|
# submit_message thread. The next time it hits a checkpoint
|
||||||
# is owned by the engine and not our concern here.
|
# (entry or post-budget), it returns 'cancelled' without
|
||||||
|
# mutating state instead of committing a ghost turn. This
|
||||||
|
# upgrades #161's best-effort future.cancel() (which only
|
||||||
|
# cancels pre-start futures) to cooperative mid-flight cancel.
|
||||||
|
assert cancel_event is not None
|
||||||
|
cancel_event.set()
|
||||||
future.cancel()
|
future.cancel()
|
||||||
results.append(self._build_timeout_result(turn_prompt, command_names, tool_names))
|
results.append(self._build_timeout_result(turn_prompt, command_names, tool_names))
|
||||||
break
|
break
|
||||||
|
|||||||
156
tests/test_run_turn_loop_cancellation.py
Normal file
156
tests/test_run_turn_loop_cancellation.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
"""Tests for run_turn_loop timeout triggering cooperative cancel (ROADMAP #164 Stage A).
|
||||||
|
|
||||||
|
End-to-end integration: when the wall-clock timeout fires in run_turn_loop,
|
||||||
|
the runtime must signal the cancel_event so any in-flight submit_message
|
||||||
|
thread sees it at its next safe checkpoint and returns without mutating
|
||||||
|
state.
|
||||||
|
|
||||||
|
This closes the gap filed in #164: #161's timeout bounded caller wait but
|
||||||
|
did not prevent ghost turns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
|
||||||
|
from src.models import UsageSummary # noqa: E402
|
||||||
|
from src.query_engine import TurnResult # noqa: E402
|
||||||
|
from src.runtime import PortRuntime # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _completed(prompt: str) -> TurnResult:
|
||||||
|
return TurnResult(
|
||||||
|
prompt=prompt,
|
||||||
|
output='ok',
|
||||||
|
matched_commands=(),
|
||||||
|
matched_tools=(),
|
||||||
|
permission_denials=(),
|
||||||
|
usage=UsageSummary(),
|
||||||
|
stop_reason='completed',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTimeoutPropagatesCancelEvent:
|
||||||
|
def test_runtime_passes_cancel_event_to_submit_message(self) -> None:
|
||||||
|
"""submit_message receives a cancel_event when a deadline is in play."""
|
||||||
|
runtime = PortRuntime()
|
||||||
|
captured_event: list[threading.Event | None] = []
|
||||||
|
|
||||||
|
def _capture(prompt, commands, tools, denials, cancel_event=None):
|
||||||
|
captured_event.append(cancel_event)
|
||||||
|
return _completed(prompt)
|
||||||
|
|
||||||
|
with patch('src.runtime.QueryEnginePort.from_workspace') as mock_factory:
|
||||||
|
engine = mock_factory.return_value
|
||||||
|
engine.submit_message.side_effect = _capture
|
||||||
|
|
||||||
|
runtime.run_turn_loop(
|
||||||
|
'hello', max_turns=1, timeout_seconds=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Runtime passed a real Event object, not None
|
||||||
|
assert len(captured_event) == 1
|
||||||
|
assert isinstance(captured_event[0], threading.Event)
|
||||||
|
|
||||||
|
def test_legacy_no_timeout_does_not_pass_cancel_event(self) -> None:
|
||||||
|
"""Without timeout_seconds, the cancel_event is None (legacy behaviour)."""
|
||||||
|
runtime = PortRuntime()
|
||||||
|
captured_kwargs: list[dict] = []
|
||||||
|
|
||||||
|
def _capture(prompt, commands, tools, denials):
|
||||||
|
# Legacy call signature: no cancel_event kwarg
|
||||||
|
captured_kwargs.append({'prompt': prompt})
|
||||||
|
return _completed(prompt)
|
||||||
|
|
||||||
|
with patch('src.runtime.QueryEnginePort.from_workspace') as mock_factory:
|
||||||
|
engine = mock_factory.return_value
|
||||||
|
engine.submit_message.side_effect = _capture
|
||||||
|
|
||||||
|
runtime.run_turn_loop('hello', max_turns=1)
|
||||||
|
|
||||||
|
# Legacy path didn't pass cancel_event at all
|
||||||
|
assert len(captured_kwargs) == 1
|
||||||
|
|
||||||
|
def test_timeout_sets_cancel_event_before_returning(self) -> None:
|
||||||
|
"""When timeout fires mid-call, the event is set and the still-running
|
||||||
|
thread would see 'cancelled' if it checks before returning."""
|
||||||
|
runtime = PortRuntime()
|
||||||
|
observed_events_at_checkpoint: list[bool] = []
|
||||||
|
release = threading.Event() # test-side release so the thread doesn't leak forever
|
||||||
|
|
||||||
|
def _slow_submit(prompt, commands, tools, denials, cancel_event=None):
|
||||||
|
# Simulate provider work: block until either cancel or a test-side release.
|
||||||
|
# If cancel fires, check if the event is observably set.
|
||||||
|
start = time.monotonic()
|
||||||
|
while time.monotonic() - start < 2.0:
|
||||||
|
if cancel_event is not None and cancel_event.is_set():
|
||||||
|
observed_events_at_checkpoint.append(True)
|
||||||
|
return TurnResult(
|
||||||
|
prompt=prompt, output='',
|
||||||
|
matched_commands=(), matched_tools=(),
|
||||||
|
permission_denials=(), usage=UsageSummary(),
|
||||||
|
stop_reason='cancelled',
|
||||||
|
)
|
||||||
|
if release.is_set():
|
||||||
|
break
|
||||||
|
time.sleep(0.05)
|
||||||
|
return _completed(prompt)
|
||||||
|
|
||||||
|
with patch('src.runtime.QueryEnginePort.from_workspace') as mock_factory:
|
||||||
|
engine = mock_factory.return_value
|
||||||
|
engine.submit_message.side_effect = _slow_submit
|
||||||
|
|
||||||
|
# Tight deadline: 0.2s, submit will be mid-loop when timeout fires
|
||||||
|
start = time.monotonic()
|
||||||
|
results = runtime.run_turn_loop(
|
||||||
|
'hello', max_turns=1, timeout_seconds=0.2,
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
release.set() # let the background thread exit cleanly
|
||||||
|
|
||||||
|
# Runtime returned a timeout TurnResult to the caller
|
||||||
|
assert results[-1].stop_reason == 'timeout'
|
||||||
|
# And it happened within a reasonable window of the deadline
|
||||||
|
assert elapsed < 1.5, f'runtime did not honour deadline: {elapsed:.2f}s'
|
||||||
|
|
||||||
|
# Give the background thread a moment to observe the cancel.
|
||||||
|
# We don't assert on it directly (thread-level observability is
|
||||||
|
# timing-dependent), but the contract is: the event IS set, so any
|
||||||
|
# cooperative checkpoint will see it.
|
||||||
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancelEventSharedAcrossTurns:
|
||||||
|
"""Event is created once per run_turn_loop invocation and shared across turns."""
|
||||||
|
|
||||||
|
def test_same_event_threaded_to_every_submit_message(self) -> None:
|
||||||
|
runtime = PortRuntime()
|
||||||
|
captured_events: list[threading.Event] = []
|
||||||
|
|
||||||
|
def _capture(prompt, commands, tools, denials, cancel_event=None):
|
||||||
|
if cancel_event is not None:
|
||||||
|
captured_events.append(cancel_event)
|
||||||
|
return _completed(prompt)
|
||||||
|
|
||||||
|
with patch('src.runtime.QueryEnginePort.from_workspace') as mock_factory:
|
||||||
|
engine = mock_factory.return_value
|
||||||
|
engine.submit_message.side_effect = _capture
|
||||||
|
|
||||||
|
runtime.run_turn_loop(
|
||||||
|
'hello', max_turns=3, timeout_seconds=5.0,
|
||||||
|
continuation_prompt='continue',
|
||||||
|
)
|
||||||
|
|
||||||
|
# All 3 turns received the same event object (same identity)
|
||||||
|
assert len(captured_events) == 3
|
||||||
|
assert all(e is captured_events[0] for e in captured_events), (
|
||||||
|
'runtime must share one cancel_event across turns, not create '
|
||||||
|
'a new one per turn \u2014 otherwise a late-arriving cancel on turn '
|
||||||
|
'N-1 cannot affect turn N'
|
||||||
|
)
|
||||||
@ -51,7 +51,9 @@ class TestTimeoutAbortsHungTurn:
|
|||||||
"""A stalled submit_message must be aborted and emit stop_reason='timeout'."""
|
"""A stalled submit_message must be aborted and emit stop_reason='timeout'."""
|
||||||
runtime = PortRuntime()
|
runtime = PortRuntime()
|
||||||
|
|
||||||
def _hang(prompt, commands, tools, denials):
|
# #164 Stage A: runtime now passes cancel_event as a 5th positional
|
||||||
|
# arg on the timeout path, so mocks must accept it (even if they ignore it).
|
||||||
|
def _hang(prompt, commands, tools, denials, cancel_event=None):
|
||||||
time.sleep(5.0) # would block the loop
|
time.sleep(5.0) # would block the loop
|
||||||
return _completed_result(prompt)
|
return _completed_result(prompt)
|
||||||
|
|
||||||
@ -84,7 +86,7 @@ class TestTimeoutBudgetIsTotal:
|
|||||||
runtime = PortRuntime()
|
runtime = PortRuntime()
|
||||||
call_count = {'n': 0}
|
call_count = {'n': 0}
|
||||||
|
|
||||||
def _slow(prompt, commands, tools, denials):
|
def _slow(prompt, commands, tools, denials, cancel_event=None):
|
||||||
call_count['n'] += 1
|
call_count['n'] += 1
|
||||||
time.sleep(0.4) # each turn burns 0.4s
|
time.sleep(0.4) # each turn burns 0.4s
|
||||||
return _completed_result(prompt)
|
return _completed_result(prompt)
|
||||||
@ -135,7 +137,7 @@ class TestTimeoutResultShape:
|
|||||||
"""Synthetic TurnResult on timeout must carry the turn's prompt + routed matches."""
|
"""Synthetic TurnResult on timeout must carry the turn's prompt + routed matches."""
|
||||||
runtime = PortRuntime()
|
runtime = PortRuntime()
|
||||||
|
|
||||||
def _hang(prompt, commands, tools, denials):
|
def _hang(prompt, commands, tools, denials, cancel_event=None):
|
||||||
time.sleep(5.0)
|
time.sleep(5.0)
|
||||||
return _completed_result(prompt)
|
return _completed_result(prompt)
|
||||||
|
|
||||||
|
|||||||
220
tests/test_submit_message_cancellation.py
Normal file
220
tests/test_submit_message_cancellation.py
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
"""Tests for cooperative cancellation in submit_message (ROADMAP #164 Stage A).
|
||||||
|
|
||||||
|
Verifies that cancel_event enables safe early termination:
|
||||||
|
- Event set before call => immediate return with stop_reason='cancelled'
|
||||||
|
- Event set between budget check and commit => still 'cancelled', no mutation
|
||||||
|
- Event set after commit => not observable (honest cooperative limit)
|
||||||
|
- Legacy callers (cancel_event=None) see zero behaviour change
|
||||||
|
- State is untouched on cancellation: mutable_messages, transcript_store,
|
||||||
|
permission_denials, total_usage all preserved
|
||||||
|
|
||||||
|
This closes the #161 follow-up gap filed as #164: wedged provider threads
|
||||||
|
can no longer silently commit ghost turns after the caller observed a
|
||||||
|
timeout.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
|
||||||
|
from src.models import PermissionDenial # noqa: E402
|
||||||
|
from src.port_manifest import build_port_manifest # noqa: E402
|
||||||
|
from src.query_engine import QueryEngineConfig, QueryEnginePort, TurnResult # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _fresh_engine(**config_overrides) -> QueryEnginePort:
|
||||||
|
config = QueryEngineConfig(**config_overrides) if config_overrides else QueryEngineConfig()
|
||||||
|
return QueryEnginePort(manifest=build_port_manifest(), config=config)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancellationBeforeCall:
|
||||||
|
"""Event set before submit_message is invoked => immediate 'cancelled'."""
|
||||||
|
|
||||||
|
def test_pre_set_event_returns_cancelled_immediately(self) -> None:
|
||||||
|
engine = _fresh_engine()
|
||||||
|
event = threading.Event()
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
result = engine.submit_message('hello', cancel_event=event)
|
||||||
|
|
||||||
|
assert result.stop_reason == 'cancelled'
|
||||||
|
assert result.prompt == 'hello'
|
||||||
|
# Output is empty on pre-budget cancel (no synthesis)
|
||||||
|
assert result.output == ''
|
||||||
|
|
||||||
|
def test_pre_set_event_preserves_mutable_messages(self) -> None:
|
||||||
|
engine = _fresh_engine()
|
||||||
|
event = threading.Event()
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
engine.submit_message('ghost turn', cancel_event=event)
|
||||||
|
|
||||||
|
assert engine.mutable_messages == [], (
|
||||||
|
'cancelled turn must not appear in mutable_messages'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_pre_set_event_preserves_transcript_store(self) -> None:
|
||||||
|
engine = _fresh_engine()
|
||||||
|
event = threading.Event()
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
engine.submit_message('ghost turn', cancel_event=event)
|
||||||
|
|
||||||
|
assert engine.transcript_store.entries == [], (
|
||||||
|
'cancelled turn must not appear in transcript_store'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_pre_set_event_preserves_usage_counters(self) -> None:
|
||||||
|
engine = _fresh_engine()
|
||||||
|
initial_usage = engine.total_usage
|
||||||
|
event = threading.Event()
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
engine.submit_message('expensive prompt ' * 100, cancel_event=event)
|
||||||
|
|
||||||
|
assert engine.total_usage == initial_usage, (
|
||||||
|
'cancelled turn must not increment token counters'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_pre_set_event_preserves_permission_denials(self) -> None:
|
||||||
|
engine = _fresh_engine()
|
||||||
|
event = threading.Event()
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
denials = (PermissionDenial(tool_name='BashTool', reason='destructive'),)
|
||||||
|
engine.submit_message('run bash ls', denied_tools=denials, cancel_event=event)
|
||||||
|
|
||||||
|
assert engine.permission_denials == [], (
|
||||||
|
'cancelled turn must not extend permission_denials'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancellationAfterBudgetCheck:
|
||||||
|
"""Event set between budget projection and commit => 'cancelled', state intact.
|
||||||
|
|
||||||
|
This simulates the realistic racy case: engine starts computing output,
|
||||||
|
caller hits deadline, sets event. Engine observes at post-budget checkpoint
|
||||||
|
and returns cleanly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_post_budget_cancel_returns_cancelled(self) -> None:
|
||||||
|
engine = _fresh_engine()
|
||||||
|
event = threading.Event()
|
||||||
|
|
||||||
|
# Patch: set the event after projection but before mutation. We do this
|
||||||
|
# by wrapping _format_output (called mid-submit) to set the event.
|
||||||
|
original_format = engine._format_output
|
||||||
|
|
||||||
|
def _set_then_format(*args, **kwargs):
|
||||||
|
result = original_format(*args, **kwargs)
|
||||||
|
event.set() # trigger cancel right after output is built
|
||||||
|
return result
|
||||||
|
|
||||||
|
engine._format_output = _set_then_format # type: ignore[method-assign]
|
||||||
|
|
||||||
|
result = engine.submit_message('hello', cancel_event=event)
|
||||||
|
|
||||||
|
assert result.stop_reason == 'cancelled'
|
||||||
|
# Output IS built here (we're past the pre-budget checkpoint), so it's
|
||||||
|
# not empty. The contract is about *state*, not output synthesis.
|
||||||
|
assert result.output != ''
|
||||||
|
# Critical: state still unchanged
|
||||||
|
assert engine.mutable_messages == []
|
||||||
|
assert engine.transcript_store.entries == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancellationAfterCommit:
|
||||||
|
"""Event set after commit is not observable \u2014 honest cooperative limit."""
|
||||||
|
|
||||||
|
def test_post_commit_cancel_is_not_observable(self) -> None:
|
||||||
|
engine = _fresh_engine()
|
||||||
|
event = threading.Event()
|
||||||
|
|
||||||
|
# Event only set *after* submit_message returns. The first call has
|
||||||
|
# already committed before the event is set.
|
||||||
|
result = engine.submit_message('hello', cancel_event=event)
|
||||||
|
event.set() # too late
|
||||||
|
|
||||||
|
assert result.stop_reason == 'completed', (
|
||||||
|
'cancel set after commit must not retroactively invalidate the turn'
|
||||||
|
)
|
||||||
|
assert engine.mutable_messages == ['hello']
|
||||||
|
|
||||||
|
def test_next_call_observes_cancel(self) -> None:
|
||||||
|
"""The cancel_event persists \u2014 the next call on the same engine sees it."""
|
||||||
|
engine = _fresh_engine()
|
||||||
|
event = threading.Event()
|
||||||
|
|
||||||
|
engine.submit_message('first', cancel_event=event)
|
||||||
|
assert engine.mutable_messages == ['first']
|
||||||
|
|
||||||
|
event.set()
|
||||||
|
# Next call observes the cancel at entry
|
||||||
|
result = engine.submit_message('second', cancel_event=event)
|
||||||
|
|
||||||
|
assert result.stop_reason == 'cancelled'
|
||||||
|
# 'second' must NOT have been committed
|
||||||
|
assert engine.mutable_messages == ['first']
|
||||||
|
|
||||||
|
|
||||||
|
class TestLegacyCallersUnchanged:
|
||||||
|
"""cancel_event=None (default) => zero behaviour change from pre-#164."""
|
||||||
|
|
||||||
|
def test_no_event_submits_normally(self) -> None:
|
||||||
|
engine = _fresh_engine()
|
||||||
|
result = engine.submit_message('hello')
|
||||||
|
|
||||||
|
assert result.stop_reason == 'completed'
|
||||||
|
assert engine.mutable_messages == ['hello']
|
||||||
|
|
||||||
|
def test_no_event_with_budget_overflow_still_rejects_atomically(self) -> None:
|
||||||
|
"""#162 atomicity contract survives when cancel_event is absent."""
|
||||||
|
engine = _fresh_engine(max_budget_tokens=1)
|
||||||
|
words = ' '.join(['word'] * 100)
|
||||||
|
|
||||||
|
result = engine.submit_message(words) # no cancel_event
|
||||||
|
|
||||||
|
assert result.stop_reason == 'max_budget_reached'
|
||||||
|
assert engine.mutable_messages == []
|
||||||
|
|
||||||
|
def test_no_event_respects_max_turns(self) -> None:
|
||||||
|
"""max_turns_reached contract survives when cancel_event is absent."""
|
||||||
|
engine = _fresh_engine(max_turns=1)
|
||||||
|
engine.submit_message('first')
|
||||||
|
result = engine.submit_message('second') # no cancel_event
|
||||||
|
|
||||||
|
assert result.stop_reason == 'max_turns_reached'
|
||||||
|
assert engine.mutable_messages == ['first']
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancellationVsOtherStopReasons:
|
||||||
|
"""cancel_event has a defined precedence relative to budget/turns."""
|
||||||
|
|
||||||
|
def test_cancel_precedes_max_turns_check(self) -> None:
|
||||||
|
"""If cancel is set when capacity is also full, cancel wins (clearer signal)."""
|
||||||
|
engine = _fresh_engine(max_turns=0) # immediately full
|
||||||
|
event = threading.Event()
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
result = engine.submit_message('hello', cancel_event=event)
|
||||||
|
|
||||||
|
# cancel_event check is the very first thing in submit_message,
|
||||||
|
# so it fires before the max_turns check even sees capacity
|
||||||
|
assert result.stop_reason == 'cancelled'
|
||||||
|
|
||||||
|
def test_cancel_does_not_override_commit(self) -> None:
|
||||||
|
"""Completed turn with late cancel still reports 'completed' \u2014 the
|
||||||
|
turn already succeeded; we don't lie about it."""
|
||||||
|
engine = _fresh_engine()
|
||||||
|
event = threading.Event()
|
||||||
|
|
||||||
|
# Event gets set after the mutation is done \u2014 submit_message doesn't
|
||||||
|
# re-check after commit
|
||||||
|
result = engine.submit_message('hello', cancel_event=event)
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
assert result.stop_reason == 'completed'
|
||||||
Loading…
x
Reference in New Issue
Block a user