From e9c8845833415204db993a3b0d0bf337fded23da Mon Sep 17 00:00:00 2001 From: Affaan Mustafa Date: Mon, 11 May 2026 23:04:03 -0400 Subject: [PATCH] feat: add Astraflow provider support --- .env.example | 10 ++ src/llm/core/types.py | 2 + src/llm/providers/__init__.py | 3 + src/llm/providers/astraflow.py | 148 +++++++++++++++++++++++++++ src/llm/providers/resolver.py | 3 + tests/hooks/mcp-health-check.test.js | 2 +- tests/test_astraflow_provider.py | 141 +++++++++++++++++++++++++ tests/test_resolver.py | 28 ++++- tests/test_types.py | 2 + 9 files changed, 337 insertions(+), 2 deletions(-) create mode 100644 src/llm/providers/astraflow.py create mode 100644 tests/test_astraflow_provider.py diff --git a/.env.example b/.env.example index c37740c3..ec0ddba5 100644 --- a/.env.example +++ b/.env.example @@ -20,6 +20,16 @@ GITHUB_TOKEN= # ─── Optional: Package manager override ────────────────────────────────────── # CLAUDE_CODE_PACKAGE_MANAGER=npm # npm | pnpm | yarn | bun +# --- Optional: Astraflow / UModelVerse (OpenAI-compatible) ------------------- +# Global endpoint: https://api.umodelverse.ai/v1 +ASTRAFLOW_API_KEY= +# ASTRAFLOW_MODEL=gpt-4o-mini +# ASTRAFLOW_BASE_URL=https://api.umodelverse.ai/v1 +# China endpoint: https://api.modelverse.cn/v1 +ASTRAFLOW_CN_API_KEY= +# ASTRAFLOW_CN_MODEL=gpt-4o-mini +# ASTRAFLOW_CN_BASE_URL=https://api.modelverse.cn/v1 + # ─── Session & Security ───────────────────────────────────────────────────── # GitHub username (used by CI scripts for credential context) GITHUB_USER="your-github-username" diff --git a/src/llm/core/types.py b/src/llm/core/types.py index 6b06adce..07e788bf 100644 --- a/src/llm/core/types.py +++ b/src/llm/core/types.py @@ -18,6 +18,8 @@ class ProviderType(str, Enum): CLAUDE = "claude" OPENAI = "openai" OLLAMA = "ollama" + ASTRAFLOW = "astraflow" + ASTRAFLOW_CN = "astraflow_cn" @dataclass(frozen=True) diff --git a/src/llm/providers/__init__.py b/src/llm/providers/__init__.py index c495fa1c..6775f8ca 100644 --- a/src/llm/providers/__init__.py +++ b/src/llm/providers/__init__.py @@ -1,11 +1,14 @@ """Provider adapters for multiple LLM backends.""" +from llm.providers.astraflow import AstraflowCNProvider, AstraflowProvider from llm.providers.claude import ClaudeProvider from llm.providers.openai import OpenAIProvider from llm.providers.ollama import OllamaProvider from llm.providers.resolver import get_provider, register_provider __all__ = ( + "AstraflowCNProvider", + "AstraflowProvider", "ClaudeProvider", "OpenAIProvider", "OllamaProvider", diff --git a/src/llm/providers/astraflow.py b/src/llm/providers/astraflow.py new file mode 100644 index 00000000..ba23517f --- /dev/null +++ b/src/llm/providers/astraflow.py @@ -0,0 +1,148 @@ +"""Astraflow/UModelVerse OpenAI-compatible provider adapters.""" + +from __future__ import annotations + +import json +import os +from typing import Any + +from openai import OpenAI + +from llm.core.interface import ( + AuthenticationError, + ContextLengthError, + LLMProvider, + RateLimitError, +) +from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall + +ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1" +ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1" +DEFAULT_ASTRAFLOW_MODEL = "gpt-4o-mini" + + +def _parse_tool_arguments(raw_arguments: str | None) -> dict[str, Any]: + if not raw_arguments: + return {} + + try: + arguments = json.loads(raw_arguments) + except json.JSONDecodeError: + return {"raw": raw_arguments} + + if isinstance(arguments, dict): + return arguments + return {"value": arguments} + + +class _AstraflowBaseProvider(LLMProvider): + provider_type: ProviderType + api_key_env: str + base_url_env: str + model_env: str + fallback_model_env: str | None = None + default_base_url: str + default_model = DEFAULT_ASTRAFLOW_MODEL + + def __init__( + self, + api_key: str | None = None, + base_url: str | None = None, + default_model: str | None = None, + ) -> None: + self.api_key = api_key or os.environ.get(self.api_key_env) or "" + self.base_url = base_url or os.environ.get(self.base_url_env, self.default_base_url) + env_model = os.environ.get(self.model_env) + fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None + self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL + self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) + self._models = [ + ModelInfo( + name=self.default_model, + provider=self.provider_type, + supports_tools=True, + supports_vision=False, + ) + ] + + def generate(self, llm_input: LLMInput) -> LLMOutput: + try: + params: dict[str, Any] = { + "model": llm_input.model or self.default_model, + "messages": [msg.to_dict() for msg in llm_input.messages], + } + if llm_input.temperature != 1.0: + params["temperature"] = llm_input.temperature + if llm_input.max_tokens is not None: + params["max_tokens"] = llm_input.max_tokens + if llm_input.tools: + params["tools"] = [tool.to_openai_tool() for tool in llm_input.tools] + + response = self.client.chat.completions.create(**params) + choice = response.choices[0] + + tool_calls = None + if choice.message.tool_calls: + tool_calls = [ + ToolCall( + id=tc.id or "", + name=tc.function.name, + arguments=_parse_tool_arguments(tc.function.arguments), + ) + for tc in choice.message.tool_calls + ] + + usage = None + if response.usage: + usage = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + } + + return LLMOutput( + content=choice.message.content or "", + tool_calls=tool_calls, + model=response.model, + usage=usage, + stop_reason=choice.finish_reason, + ) + except Exception as e: + msg = str(e) + if "401" in msg or "authentication" in msg.lower(): + raise AuthenticationError(msg, provider=self.provider_type) from e + if "429" in msg or "rate_limit" in msg.lower(): + raise RateLimitError(msg, provider=self.provider_type) from e + if "context" in msg.lower() and "length" in msg.lower(): + raise ContextLengthError(msg, provider=self.provider_type) from e + raise + + def list_models(self) -> list[ModelInfo]: + return self._models.copy() + + def validate_config(self) -> bool: + return bool(self.api_key) + + def get_default_model(self) -> str: + return self.default_model + + +class AstraflowProvider(_AstraflowBaseProvider): + """UModelVerse global endpoint using OpenAI-compatible chat completions.""" + + provider_type = ProviderType.ASTRAFLOW + api_key_env = "ASTRAFLOW_API_KEY" + base_url_env = "ASTRAFLOW_BASE_URL" + model_env = "ASTRAFLOW_MODEL" + default_base_url = ASTRAFLOW_BASE_URL + + +class AstraflowCNProvider(_AstraflowBaseProvider): + """UModelVerse China endpoint using OpenAI-compatible chat completions.""" + + provider_type = ProviderType.ASTRAFLOW_CN + api_key_env = "ASTRAFLOW_CN_API_KEY" + base_url_env = "ASTRAFLOW_CN_BASE_URL" + model_env = "ASTRAFLOW_CN_MODEL" + fallback_model_env = "ASTRAFLOW_MODEL" + default_base_url = ASTRAFLOW_CN_BASE_URL diff --git a/src/llm/providers/resolver.py b/src/llm/providers/resolver.py index 5967523e..0e3d1b23 100644 --- a/src/llm/providers/resolver.py +++ b/src/llm/providers/resolver.py @@ -7,12 +7,15 @@ from pathlib import Path from llm.core.interface import LLMProvider from llm.core.types import ProviderType +from llm.providers.astraflow import AstraflowCNProvider, AstraflowProvider from llm.providers.claude import ClaudeProvider from llm.providers.openai import OpenAIProvider from llm.providers.ollama import OllamaProvider _PROVIDER_MAP: dict[ProviderType, type[LLMProvider]] = { + ProviderType.ASTRAFLOW: AstraflowProvider, + ProviderType.ASTRAFLOW_CN: AstraflowCNProvider, ProviderType.CLAUDE: ClaudeProvider, ProviderType.OPENAI: OpenAIProvider, ProviderType.OLLAMA: OllamaProvider, diff --git a/tests/hooks/mcp-health-check.test.js b/tests/hooks/mcp-health-check.test.js index b637353f..2f19f29f 100644 --- a/tests/hooks/mcp-health-check.test.js +++ b/tests/hooks/mcp-health-check.test.js @@ -568,7 +568,7 @@ async function runTests() { CLAUDE_HOOK_EVENT_NAME: 'PreToolUse', ECC_MCP_CONFIG_PATH: configPath, ECC_MCP_HEALTH_STATE_PATH: statePath, - ECC_MCP_HEALTH_TIMEOUT_MS: '100' + ECC_MCP_HEALTH_TIMEOUT_MS: process.platform === 'win32' ? '1000' : '100' } ); diff --git a/tests/test_astraflow_provider.py b/tests/test_astraflow_provider.py new file mode 100644 index 00000000..b70c9bd5 --- /dev/null +++ b/tests/test_astraflow_provider.py @@ -0,0 +1,141 @@ +from types import SimpleNamespace + +from llm.core.types import LLMInput, Message, ProviderType, Role, ToolDefinition, ToolCall +from llm.providers.astraflow import ASTRAFLOW_BASE_URL, ASTRAFLOW_CN_BASE_URL, AstraflowCNProvider, AstraflowProvider + + +def _tool() -> ToolDefinition: + return ToolDefinition( + name="search", + description="Search", + parameters={"type": "object", "properties": {"query": {"type": "string"}}}, + ) + + +class _Completions: + def __init__(self, response: SimpleNamespace) -> None: + self.params = None + self.response = response + + def create(self, **params): + self.params = params + return self.response + + +class _Client: + def __init__(self, response: SimpleNamespace) -> None: + self.completions = _Completions(response) + self.chat = SimpleNamespace(completions=self.completions) + + +def _response(**overrides) -> SimpleNamespace: + message = SimpleNamespace(content="ok", tool_calls=None) + choice = SimpleNamespace(message=message, finish_reason="stop") + defaults = { + "choices": [choice], + "model": "gpt-4o-mini", + "usage": SimpleNamespace(prompt_tokens=1, completion_tokens=2, total_tokens=3), + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +def test_astraflow_provider_defaults_to_global_umodelverse_endpoint(monkeypatch): + monkeypatch.delenv("ASTRAFLOW_API_KEY", raising=False) + monkeypatch.delenv("ASTRAFLOW_BASE_URL", raising=False) + monkeypatch.delenv("ASTRAFLOW_MODEL", raising=False) + + provider = AstraflowProvider() + + assert provider.provider_type == ProviderType.ASTRAFLOW + assert provider.base_url == ASTRAFLOW_BASE_URL + assert provider.get_default_model() == "gpt-4o-mini" + assert provider.validate_config() is False + + +def test_astraflow_cn_provider_uses_cn_endpoint_and_model_fallback(monkeypatch): + monkeypatch.setenv("ASTRAFLOW_API_KEY", "global-key") + monkeypatch.setenv("ASTRAFLOW_MODEL", "deepseek-ai/DeepSeek-V3-0324") + monkeypatch.setenv("ASTRAFLOW_CN_API_KEY", "cn-key") + monkeypatch.delenv("ASTRAFLOW_CN_MODEL", raising=False) + monkeypatch.delenv("ASTRAFLOW_CN_BASE_URL", raising=False) + + provider = AstraflowCNProvider() + + assert provider.provider_type == ProviderType.ASTRAFLOW_CN + assert provider.base_url == ASTRAFLOW_CN_BASE_URL + assert provider.get_default_model() == "deepseek-ai/DeepSeek-V3-0324" + assert provider.validate_config() is True + + +def test_astraflow_provider_generates_openai_compatible_chat_completion(): + provider = AstraflowProvider(api_key="test", default_model="deepseek-ai/DeepSeek-V3-0324") + client = _Client(_response(model="deepseek-ai/DeepSeek-V3-0324")) + provider.client = client + + output = provider.generate( + LLMInput( + messages=[Message(role=Role.USER, content="hi")], + max_tokens=128, + tools=[_tool()], + ) + ) + + assert output.content == "ok" + assert output.model == "deepseek-ai/DeepSeek-V3-0324" + assert output.usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3} + assert client.completions.params["model"] == "deepseek-ai/DeepSeek-V3-0324" + assert client.completions.params["max_tokens"] == 128 + assert "temperature" not in client.completions.params + assert client.completions.params["tools"] == [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": {"type": "object", "properties": {"query": {"type": "string"}}}, + "strict": True, + }, + } + ] + + +def test_astraflow_provider_forwards_non_default_temperature(): + provider = AstraflowProvider(api_key="test") + client = _Client(_response()) + provider.client = client + + provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")], temperature=0.2)) + + assert client.completions.params["temperature"] == 0.2 + + +def test_astraflow_provider_parses_tool_calls(): + provider = AstraflowProvider(api_key="test") + tool_call = SimpleNamespace( + id="call_1", + function=SimpleNamespace(name="search", arguments='{"query":"ucloud"}'), + ) + message = SimpleNamespace(content="", tool_calls=[tool_call]) + client = _Client(_response(choices=[SimpleNamespace(message=message, finish_reason="tool_calls")], usage=None)) + provider.client = client + + output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")])) + + assert output.tool_calls == [ToolCall(id="call_1", name="search", arguments={"query": "ucloud"})] + assert output.usage is None + + +def test_astraflow_provider_preserves_malformed_tool_arguments(): + provider = AstraflowProvider(api_key="test") + tool_call = SimpleNamespace( + id="call_1", + function=SimpleNamespace(name="search", arguments="{not-json"), + ) + message = SimpleNamespace(content="", tool_calls=[tool_call]) + client = _Client(_response(choices=[SimpleNamespace(message=message, finish_reason="tool_calls")])) + provider.client = client + + output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")])) + + assert output.tool_calls == [ToolCall(id="call_1", name="search", arguments={"raw": "{not-json"})] diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 7a8b9b63..29f743e4 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -1,6 +1,6 @@ import pytest from llm.core.types import ProviderType -from llm.providers import ClaudeProvider, OpenAIProvider, OllamaProvider, get_provider +from llm.providers import AstraflowCNProvider, AstraflowProvider, ClaudeProvider, OpenAIProvider, OllamaProvider, get_provider class TestGetProvider: @@ -19,6 +19,16 @@ class TestGetProvider: assert isinstance(provider, OllamaProvider) assert provider.provider_type == ProviderType.OLLAMA + def test_get_astraflow_provider(self): + provider = get_provider("astraflow") + assert isinstance(provider, AstraflowProvider) + assert provider.provider_type == ProviderType.ASTRAFLOW + + def test_get_astraflow_cn_provider(self): + provider = get_provider("astraflow_cn") + assert isinstance(provider, AstraflowCNProvider) + assert provider.provider_type == ProviderType.ASTRAFLOW_CN + def test_get_provider_by_enum(self): provider = get_provider(ProviderType.CLAUDE) assert isinstance(provider, ClaudeProvider) @@ -52,6 +62,13 @@ class TestGetProvider: assert isinstance(provider, OllamaProvider) + def test_astraflow_env_provider_is_normalized(self, monkeypatch): + monkeypatch.setenv("LLM_PROVIDER", "ASTRAFLOW") + + provider = get_provider() + + assert isinstance(provider, AstraflowProvider) + def test_explicit_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path): monkeypatch.delenv("LLM_PROVIDER", raising=False) monkeypatch.chdir(tmp_path) @@ -60,3 +77,12 @@ class TestGetProvider: provider = get_provider("ollama") assert isinstance(provider, OllamaProvider) + + def test_saved_llm_env_selects_astraflow_cn_provider(self, monkeypatch, tmp_path): + monkeypatch.delenv("LLM_PROVIDER", raising=False) + monkeypatch.chdir(tmp_path) + tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=astraflow_cn\n") + + provider = get_provider() + + assert isinstance(provider, AstraflowCNProvider) diff --git a/tests/test_types.py b/tests/test_types.py index a065cfc8..8399a0ba 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -25,6 +25,8 @@ class TestProviderType: assert ProviderType.CLAUDE.value == "claude" assert ProviderType.OPENAI.value == "openai" assert ProviderType.OLLAMA.value == "ollama" + assert ProviderType.ASTRAFLOW.value == "astraflow" + assert ProviderType.ASTRAFLOW_CN.value == "astraflow_cn" class TestMessage: