feat: add Astraflow provider support

This commit is contained in:
Affaan Mustafa 2026-05-11 23:04:03 -04:00 committed by Affaan Mustafa
parent 03108bea62
commit e9c8845833
9 changed files with 337 additions and 2 deletions

View File

@ -20,6 +20,16 @@ GITHUB_TOKEN=
# ─── Optional: Package manager override ────────────────────────────────────── # ─── Optional: Package manager override ──────────────────────────────────────
# CLAUDE_CODE_PACKAGE_MANAGER=npm # npm | pnpm | yarn | bun # CLAUDE_CODE_PACKAGE_MANAGER=npm # npm | pnpm | yarn | bun
# --- Optional: Astraflow / UModelVerse (OpenAI-compatible) -------------------
# Global endpoint: https://api.umodelverse.ai/v1
ASTRAFLOW_API_KEY=
# ASTRAFLOW_MODEL=gpt-4o-mini
# ASTRAFLOW_BASE_URL=https://api.umodelverse.ai/v1
# China endpoint: https://api.modelverse.cn/v1
ASTRAFLOW_CN_API_KEY=
# ASTRAFLOW_CN_MODEL=gpt-4o-mini
# ASTRAFLOW_CN_BASE_URL=https://api.modelverse.cn/v1
# ─── Session & Security ───────────────────────────────────────────────────── # ─── Session & Security ─────────────────────────────────────────────────────
# GitHub username (used by CI scripts for credential context) # GitHub username (used by CI scripts for credential context)
GITHUB_USER="your-github-username" GITHUB_USER="your-github-username"

View File

@ -18,6 +18,8 @@ class ProviderType(str, Enum):
CLAUDE = "claude" CLAUDE = "claude"
OPENAI = "openai" OPENAI = "openai"
OLLAMA = "ollama" OLLAMA = "ollama"
ASTRAFLOW = "astraflow"
ASTRAFLOW_CN = "astraflow_cn"
@dataclass(frozen=True) @dataclass(frozen=True)

View File

@ -1,11 +1,14 @@
"""Provider adapters for multiple LLM backends.""" """Provider adapters for multiple LLM backends."""
from llm.providers.astraflow import AstraflowCNProvider, AstraflowProvider
from llm.providers.claude import ClaudeProvider from llm.providers.claude import ClaudeProvider
from llm.providers.openai import OpenAIProvider from llm.providers.openai import OpenAIProvider
from llm.providers.ollama import OllamaProvider from llm.providers.ollama import OllamaProvider
from llm.providers.resolver import get_provider, register_provider from llm.providers.resolver import get_provider, register_provider
__all__ = ( __all__ = (
"AstraflowCNProvider",
"AstraflowProvider",
"ClaudeProvider", "ClaudeProvider",
"OpenAIProvider", "OpenAIProvider",
"OllamaProvider", "OllamaProvider",

View File

@ -0,0 +1,148 @@
"""Astraflow/UModelVerse OpenAI-compatible provider adapters."""
from __future__ import annotations
import json
import os
from typing import Any
from openai import OpenAI
from llm.core.interface import (
AuthenticationError,
ContextLengthError,
LLMProvider,
RateLimitError,
)
from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall
ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1"
ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1"
DEFAULT_ASTRAFLOW_MODEL = "gpt-4o-mini"
def _parse_tool_arguments(raw_arguments: str | None) -> dict[str, Any]:
if not raw_arguments:
return {}
try:
arguments = json.loads(raw_arguments)
except json.JSONDecodeError:
return {"raw": raw_arguments}
if isinstance(arguments, dict):
return arguments
return {"value": arguments}
class _AstraflowBaseProvider(LLMProvider):
provider_type: ProviderType
api_key_env: str
base_url_env: str
model_env: str
fallback_model_env: str | None = None
default_base_url: str
default_model = DEFAULT_ASTRAFLOW_MODEL
def __init__(
self,
api_key: str | None = None,
base_url: str | None = None,
default_model: str | None = None,
) -> None:
self.api_key = api_key or os.environ.get(self.api_key_env) or ""
self.base_url = base_url or os.environ.get(self.base_url_env, self.default_base_url)
env_model = os.environ.get(self.model_env)
fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None
self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
self._models = [
ModelInfo(
name=self.default_model,
provider=self.provider_type,
supports_tools=True,
supports_vision=False,
)
]
def generate(self, llm_input: LLMInput) -> LLMOutput:
try:
params: dict[str, Any] = {
"model": llm_input.model or self.default_model,
"messages": [msg.to_dict() for msg in llm_input.messages],
}
if llm_input.temperature != 1.0:
params["temperature"] = llm_input.temperature
if llm_input.max_tokens is not None:
params["max_tokens"] = llm_input.max_tokens
if llm_input.tools:
params["tools"] = [tool.to_openai_tool() for tool in llm_input.tools]
response = self.client.chat.completions.create(**params)
choice = response.choices[0]
tool_calls = None
if choice.message.tool_calls:
tool_calls = [
ToolCall(
id=tc.id or "",
name=tc.function.name,
arguments=_parse_tool_arguments(tc.function.arguments),
)
for tc in choice.message.tool_calls
]
usage = None
if response.usage:
usage = {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
return LLMOutput(
content=choice.message.content or "",
tool_calls=tool_calls,
model=response.model,
usage=usage,
stop_reason=choice.finish_reason,
)
except Exception as e:
msg = str(e)
if "401" in msg or "authentication" in msg.lower():
raise AuthenticationError(msg, provider=self.provider_type) from e
if "429" in msg or "rate_limit" in msg.lower():
raise RateLimitError(msg, provider=self.provider_type) from e
if "context" in msg.lower() and "length" in msg.lower():
raise ContextLengthError(msg, provider=self.provider_type) from e
raise
def list_models(self) -> list[ModelInfo]:
return self._models.copy()
def validate_config(self) -> bool:
return bool(self.api_key)
def get_default_model(self) -> str:
return self.default_model
class AstraflowProvider(_AstraflowBaseProvider):
"""UModelVerse global endpoint using OpenAI-compatible chat completions."""
provider_type = ProviderType.ASTRAFLOW
api_key_env = "ASTRAFLOW_API_KEY"
base_url_env = "ASTRAFLOW_BASE_URL"
model_env = "ASTRAFLOW_MODEL"
default_base_url = ASTRAFLOW_BASE_URL
class AstraflowCNProvider(_AstraflowBaseProvider):
"""UModelVerse China endpoint using OpenAI-compatible chat completions."""
provider_type = ProviderType.ASTRAFLOW_CN
api_key_env = "ASTRAFLOW_CN_API_KEY"
base_url_env = "ASTRAFLOW_CN_BASE_URL"
model_env = "ASTRAFLOW_CN_MODEL"
fallback_model_env = "ASTRAFLOW_MODEL"
default_base_url = ASTRAFLOW_CN_BASE_URL

View File

@ -7,12 +7,15 @@ from pathlib import Path
from llm.core.interface import LLMProvider from llm.core.interface import LLMProvider
from llm.core.types import ProviderType from llm.core.types import ProviderType
from llm.providers.astraflow import AstraflowCNProvider, AstraflowProvider
from llm.providers.claude import ClaudeProvider from llm.providers.claude import ClaudeProvider
from llm.providers.openai import OpenAIProvider from llm.providers.openai import OpenAIProvider
from llm.providers.ollama import OllamaProvider from llm.providers.ollama import OllamaProvider
_PROVIDER_MAP: dict[ProviderType, type[LLMProvider]] = { _PROVIDER_MAP: dict[ProviderType, type[LLMProvider]] = {
ProviderType.ASTRAFLOW: AstraflowProvider,
ProviderType.ASTRAFLOW_CN: AstraflowCNProvider,
ProviderType.CLAUDE: ClaudeProvider, ProviderType.CLAUDE: ClaudeProvider,
ProviderType.OPENAI: OpenAIProvider, ProviderType.OPENAI: OpenAIProvider,
ProviderType.OLLAMA: OllamaProvider, ProviderType.OLLAMA: OllamaProvider,

View File

@ -568,7 +568,7 @@ async function runTests() {
CLAUDE_HOOK_EVENT_NAME: 'PreToolUse', CLAUDE_HOOK_EVENT_NAME: 'PreToolUse',
ECC_MCP_CONFIG_PATH: configPath, ECC_MCP_CONFIG_PATH: configPath,
ECC_MCP_HEALTH_STATE_PATH: statePath, ECC_MCP_HEALTH_STATE_PATH: statePath,
ECC_MCP_HEALTH_TIMEOUT_MS: '100' ECC_MCP_HEALTH_TIMEOUT_MS: process.platform === 'win32' ? '1000' : '100'
} }
); );

View File

@ -0,0 +1,141 @@
from types import SimpleNamespace
from llm.core.types import LLMInput, Message, ProviderType, Role, ToolDefinition, ToolCall
from llm.providers.astraflow import ASTRAFLOW_BASE_URL, ASTRAFLOW_CN_BASE_URL, AstraflowCNProvider, AstraflowProvider
def _tool() -> ToolDefinition:
return ToolDefinition(
name="search",
description="Search",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
)
class _Completions:
def __init__(self, response: SimpleNamespace) -> None:
self.params = None
self.response = response
def create(self, **params):
self.params = params
return self.response
class _Client:
def __init__(self, response: SimpleNamespace) -> None:
self.completions = _Completions(response)
self.chat = SimpleNamespace(completions=self.completions)
def _response(**overrides) -> SimpleNamespace:
message = SimpleNamespace(content="ok", tool_calls=None)
choice = SimpleNamespace(message=message, finish_reason="stop")
defaults = {
"choices": [choice],
"model": "gpt-4o-mini",
"usage": SimpleNamespace(prompt_tokens=1, completion_tokens=2, total_tokens=3),
}
defaults.update(overrides)
return SimpleNamespace(**defaults)
def test_astraflow_provider_defaults_to_global_umodelverse_endpoint(monkeypatch):
monkeypatch.delenv("ASTRAFLOW_API_KEY", raising=False)
monkeypatch.delenv("ASTRAFLOW_BASE_URL", raising=False)
monkeypatch.delenv("ASTRAFLOW_MODEL", raising=False)
provider = AstraflowProvider()
assert provider.provider_type == ProviderType.ASTRAFLOW
assert provider.base_url == ASTRAFLOW_BASE_URL
assert provider.get_default_model() == "gpt-4o-mini"
assert provider.validate_config() is False
def test_astraflow_cn_provider_uses_cn_endpoint_and_model_fallback(monkeypatch):
monkeypatch.setenv("ASTRAFLOW_API_KEY", "global-key")
monkeypatch.setenv("ASTRAFLOW_MODEL", "deepseek-ai/DeepSeek-V3-0324")
monkeypatch.setenv("ASTRAFLOW_CN_API_KEY", "cn-key")
monkeypatch.delenv("ASTRAFLOW_CN_MODEL", raising=False)
monkeypatch.delenv("ASTRAFLOW_CN_BASE_URL", raising=False)
provider = AstraflowCNProvider()
assert provider.provider_type == ProviderType.ASTRAFLOW_CN
assert provider.base_url == ASTRAFLOW_CN_BASE_URL
assert provider.get_default_model() == "deepseek-ai/DeepSeek-V3-0324"
assert provider.validate_config() is True
def test_astraflow_provider_generates_openai_compatible_chat_completion():
provider = AstraflowProvider(api_key="test", default_model="deepseek-ai/DeepSeek-V3-0324")
client = _Client(_response(model="deepseek-ai/DeepSeek-V3-0324"))
provider.client = client
output = provider.generate(
LLMInput(
messages=[Message(role=Role.USER, content="hi")],
max_tokens=128,
tools=[_tool()],
)
)
assert output.content == "ok"
assert output.model == "deepseek-ai/DeepSeek-V3-0324"
assert output.usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
assert client.completions.params["model"] == "deepseek-ai/DeepSeek-V3-0324"
assert client.completions.params["max_tokens"] == 128
assert "temperature" not in client.completions.params
assert client.completions.params["tools"] == [
{
"type": "function",
"function": {
"name": "search",
"description": "Search",
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
"strict": True,
},
}
]
def test_astraflow_provider_forwards_non_default_temperature():
provider = AstraflowProvider(api_key="test")
client = _Client(_response())
provider.client = client
provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")], temperature=0.2))
assert client.completions.params["temperature"] == 0.2
def test_astraflow_provider_parses_tool_calls():
provider = AstraflowProvider(api_key="test")
tool_call = SimpleNamespace(
id="call_1",
function=SimpleNamespace(name="search", arguments='{"query":"ucloud"}'),
)
message = SimpleNamespace(content="", tool_calls=[tool_call])
client = _Client(_response(choices=[SimpleNamespace(message=message, finish_reason="tool_calls")], usage=None))
provider.client = client
output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))
assert output.tool_calls == [ToolCall(id="call_1", name="search", arguments={"query": "ucloud"})]
assert output.usage is None
def test_astraflow_provider_preserves_malformed_tool_arguments():
provider = AstraflowProvider(api_key="test")
tool_call = SimpleNamespace(
id="call_1",
function=SimpleNamespace(name="search", arguments="{not-json"),
)
message = SimpleNamespace(content="", tool_calls=[tool_call])
client = _Client(_response(choices=[SimpleNamespace(message=message, finish_reason="tool_calls")]))
provider.client = client
output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))
assert output.tool_calls == [ToolCall(id="call_1", name="search", arguments={"raw": "{not-json"})]

View File

@ -1,6 +1,6 @@
import pytest import pytest
from llm.core.types import ProviderType from llm.core.types import ProviderType
from llm.providers import ClaudeProvider, OpenAIProvider, OllamaProvider, get_provider from llm.providers import AstraflowCNProvider, AstraflowProvider, ClaudeProvider, OpenAIProvider, OllamaProvider, get_provider
class TestGetProvider: class TestGetProvider:
@ -19,6 +19,16 @@ class TestGetProvider:
assert isinstance(provider, OllamaProvider) assert isinstance(provider, OllamaProvider)
assert provider.provider_type == ProviderType.OLLAMA assert provider.provider_type == ProviderType.OLLAMA
def test_get_astraflow_provider(self):
provider = get_provider("astraflow")
assert isinstance(provider, AstraflowProvider)
assert provider.provider_type == ProviderType.ASTRAFLOW
def test_get_astraflow_cn_provider(self):
provider = get_provider("astraflow_cn")
assert isinstance(provider, AstraflowCNProvider)
assert provider.provider_type == ProviderType.ASTRAFLOW_CN
def test_get_provider_by_enum(self): def test_get_provider_by_enum(self):
provider = get_provider(ProviderType.CLAUDE) provider = get_provider(ProviderType.CLAUDE)
assert isinstance(provider, ClaudeProvider) assert isinstance(provider, ClaudeProvider)
@ -52,6 +62,13 @@ class TestGetProvider:
assert isinstance(provider, OllamaProvider) assert isinstance(provider, OllamaProvider)
def test_astraflow_env_provider_is_normalized(self, monkeypatch):
monkeypatch.setenv("LLM_PROVIDER", "ASTRAFLOW")
provider = get_provider()
assert isinstance(provider, AstraflowProvider)
def test_explicit_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path): def test_explicit_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path):
monkeypatch.delenv("LLM_PROVIDER", raising=False) monkeypatch.delenv("LLM_PROVIDER", raising=False)
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)
@ -60,3 +77,12 @@ class TestGetProvider:
provider = get_provider("ollama") provider = get_provider("ollama")
assert isinstance(provider, OllamaProvider) assert isinstance(provider, OllamaProvider)
def test_saved_llm_env_selects_astraflow_cn_provider(self, monkeypatch, tmp_path):
monkeypatch.delenv("LLM_PROVIDER", raising=False)
monkeypatch.chdir(tmp_path)
tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=astraflow_cn\n")
provider = get_provider()
assert isinstance(provider, AstraflowCNProvider)

View File

@ -25,6 +25,8 @@ class TestProviderType:
assert ProviderType.CLAUDE.value == "claude" assert ProviderType.CLAUDE.value == "claude"
assert ProviderType.OPENAI.value == "openai" assert ProviderType.OPENAI.value == "openai"
assert ProviderType.OLLAMA.value == "ollama" assert ProviderType.OLLAMA.value == "ollama"
assert ProviderType.ASTRAFLOW.value == "astraflow"
assert ProviderType.ASTRAFLOW_CN.value == "astraflow_cn"
class TestMessage: class TestMessage: