Merge PR #1976 provider response guards

This commit is contained in:
Affaan Mustafa 2026-05-18 01:05:37 -04:00
commit 80f6c27957
4 changed files with 181 additions and 122 deletions

View File

@ -15,6 +15,7 @@ from llm.core.interface import (
RateLimitError, RateLimitError,
) )
from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1" ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1"
ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1" ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1"
@ -55,7 +56,7 @@ class _AstraflowBaseProvider(LLMProvider):
env_model = os.environ.get(self.model_env) env_model = os.environ.get(self.model_env)
fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None
self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) self.client = OpenAI(api_key=self.api_key, base_url=self.base_url, _enforce_credentials=False)
self._models = [ self._models = [
ModelInfo( ModelInfo(
name=self.default_model, name=self.default_model,
@ -79,6 +80,8 @@ class _AstraflowBaseProvider(LLMProvider):
params["tools"] = [tool.to_openai_tool() for tool in llm_input.tools] params["tools"] = [tool.to_openai_tool() for tool in llm_input.tools]
response = self.client.chat.completions.create(**params) response = self.client.chat.completions.create(**params)
if not response.choices or response.choices[0].message is None:
raise ValueError(EMPTY_FILTERED_RESPONSE_ERROR)
choice = response.choices[0] choice = response.choices[0]
tool_calls = None tool_calls = None

View File

@ -0,0 +1,3 @@
"""Shared provider constants."""
EMPTY_FILTERED_RESPONSE_ERROR = "LLM returned empty or filtered response"

View File

@ -15,13 +15,18 @@ from llm.core.interface import (
RateLimitError, RateLimitError,
) )
from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
class OpenAIProvider(LLMProvider): class OpenAIProvider(LLMProvider):
provider_type = ProviderType.OPENAI provider_type = ProviderType.OPENAI
def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None:
self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"), base_url=base_url) self.client = OpenAI(
api_key=api_key or os.environ.get("OPENAI_API_KEY"),
base_url=base_url,
_enforce_credentials=False,
)
self._models = [ self._models = [
ModelInfo( ModelInfo(
name="gpt-4o", name="gpt-4o",
@ -70,6 +75,8 @@ class OpenAIProvider(LLMProvider):
params["tools"] = [tool.to_openai_tool() for tool in input.tools] params["tools"] = [tool.to_openai_tool() for tool in input.tools]
response = self.client.chat.completions.create(**params) response = self.client.chat.completions.create(**params)
if not response.choices or response.choices[0].message is None:
raise ValueError(EMPTY_FILTERED_RESPONSE_ERROR)
choice = response.choices[0] choice = response.choices[0]
tool_calls = None tool_calls = None
@ -83,15 +90,19 @@ class OpenAIProvider(LLMProvider):
for tc in choice.message.tool_calls for tc in choice.message.tool_calls
] ]
usage = None
if response.usage:
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
return LLMOutput( return LLMOutput(
content=choice.message.content or "", content=choice.message.content or "",
tool_calls=tool_calls, tool_calls=tool_calls,
model=response.model, model=response.model,
usage={ usage=usage,
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
},
stop_reason=choice.finish_reason, stop_reason=choice.finish_reason,
) )
except Exception as e: except Exception as e:

View File

@ -1,7 +1,10 @@
from types import SimpleNamespace from types import SimpleNamespace
import pytest
from llm.core.types import LLMInput, Message, Role, ToolDefinition from llm.core.types import LLMInput, Message, Role, ToolDefinition
from llm.providers.claude import ClaudeProvider from llm.providers.claude import ClaudeProvider
from llm.providers.constants import EMPTY_FILTERED_RESPONSE_ERROR
from llm.providers.openai import OpenAIProvider from llm.providers.openai import OpenAIProvider
@ -14,21 +17,20 @@ def _tool() -> ToolDefinition:
class _OpenAICompletions: class _OpenAICompletions:
def __init__(self) -> None: def __init__(self, response: SimpleNamespace | None = None) -> None:
self.params = None self.params = None
self.response = response
def create(self, **params): def create(self, **params):
self.params = params self.params = params
return SimpleNamespace( if self.response:
choices=[SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")], return self.response
model=params["model"], return _openai_response(model=params["model"])
usage=SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2),
)
class _OpenAIClient: class _OpenAIClient:
def __init__(self) -> None: def __init__(self, response: SimpleNamespace | None = None) -> None:
self.completions = _OpenAICompletions() self.completions = _OpenAICompletions(response=response)
self.chat = SimpleNamespace(completions=self.completions) self.chat = SimpleNamespace(completions=self.completions)
@ -52,6 +54,16 @@ class _AnthropicClient:
self.api_key = "test" self.api_key = "test"
def _openai_response(**overrides) -> SimpleNamespace:
defaults = {
"choices": [SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")],
"model": "gpt-4o-mini",
"usage": SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2),
}
defaults.update(overrides)
return SimpleNamespace(**defaults)
def test_openai_provider_serializes_tools_for_chat_completions(): def test_openai_provider_serializes_tools_for_chat_completions():
provider = OpenAIProvider(api_key="test") provider = OpenAIProvider(api_key="test")
client = _OpenAIClient() client = _OpenAIClient()
@ -72,6 +84,36 @@ def test_openai_provider_serializes_tools_for_chat_completions():
] ]
def test_openai_provider_can_be_constructed_without_credentials(monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
provider = OpenAIProvider()
assert provider.validate_config() is False
def test_openai_provider_rejects_empty_or_filtered_responses():
provider = OpenAIProvider(api_key="test")
for response in [
_openai_response(choices=[]),
_openai_response(choices=[SimpleNamespace(message=None, finish_reason="content_filter")]),
]:
provider.client = _OpenAIClient(response=response)
with pytest.raises(ValueError, match=EMPTY_FILTERED_RESPONSE_ERROR):
provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))
def test_openai_provider_allows_missing_usage():
provider = OpenAIProvider(api_key="test")
provider.client = _OpenAIClient(response=_openai_response(usage=None))
output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))
assert output.content == "ok"
assert output.usage is None
def test_claude_provider_serializes_tools_for_messages_api(): def test_claude_provider_serializes_tools_for_messages_api():
provider = ClaudeProvider(api_key="test") provider = ClaudeProvider(api_key="test")
client = _AnthropicClient() client = _AnthropicClient()