Mirror of https://github.com/affaan-m/everything-claude-code.git (synced 2026-05-13 18:00:35 +08:00)

feat: add Astraflow provider support

parent 03108bea62
commit e9c8845833
.env.example (+10 lines)

@@ -20,6 +20,16 @@ GITHUB_TOKEN=

# ─── Optional: Package manager override ──────────────────────────────────────
# CLAUDE_CODE_PACKAGE_MANAGER=npm  # npm | pnpm | yarn | bun

# --- Optional: Astraflow / UModelVerse (OpenAI-compatible) -------------------
# Global endpoint: https://api.umodelverse.ai/v1
ASTRAFLOW_API_KEY=
# ASTRAFLOW_MODEL=gpt-4o-mini
# ASTRAFLOW_BASE_URL=https://api.umodelverse.ai/v1
# China endpoint: https://api.modelverse.cn/v1
ASTRAFLOW_CN_API_KEY=
# ASTRAFLOW_CN_MODEL=gpt-4o-mini
# ASTRAFLOW_CN_BASE_URL=https://api.modelverse.cn/v1

# ─── Session & Security ─────────────────────────────────────────────────────
# GitHub username (used by CI scripts for credential context)
GITHUB_USER="your-github-username"
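For reference, the settings above can be consumed directly with the openai client that the new provider uses; a minimal sketch, not part of this commit, where the key and prompt are placeholders and a real key would be needed for the call to succeed:

import os

from openai import OpenAI

# Read the Astraflow settings introduced above; fall back to the documented defaults.
client = OpenAI(
    api_key=os.environ.get("ASTRAFLOW_API_KEY", ""),
    base_url=os.environ.get("ASTRAFLOW_BASE_URL", "https://api.umodelverse.ai/v1"),
)
reply = client.chat.completions.create(
    model=os.environ.get("ASTRAFLOW_MODEL", "gpt-4o-mini"),
    messages=[{"role": "user", "content": "ping"}],
)
print(reply.choices[0].message.content)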
@@ -18,6 +18,8 @@ class ProviderType(str, Enum):
    CLAUDE = "claude"
    OPENAI = "openai"
    OLLAMA = "ollama"
    ASTRAFLOW = "astraflow"
    ASTRAFLOW_CN = "astraflow_cn"


@dataclass(frozen=True)
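Because ProviderType subclasses str, the new members round-trip from plain strings, which is what lets a value such as "astraflow" taken from the environment be mapped back to the enum; a small illustration, not part of the commit:

from llm.core.types import ProviderType

# Value lookup and string comparison both work for a str-backed Enum.
assert ProviderType("astraflow") is ProviderType.ASTRAFLOW
assert ProviderType.ASTRAFLOW_CN == "astraflow_cn"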
@@ -1,11 +1,14 @@
"""Provider adapters for multiple LLM backends."""

from llm.providers.astraflow import AstraflowCNProvider, AstraflowProvider
from llm.providers.claude import ClaudeProvider
from llm.providers.openai import OpenAIProvider
from llm.providers.ollama import OllamaProvider
from llm.providers.resolver import get_provider, register_provider

__all__ = (
    "AstraflowCNProvider",
    "AstraflowProvider",
    "ClaudeProvider",
    "OpenAIProvider",
    "OllamaProvider",
src/llm/providers/astraflow.py (new file, 148 lines)

@@ -0,0 +1,148 @@
"""Astraflow/UModelVerse OpenAI-compatible provider adapters."""

from __future__ import annotations

import json
import os
from typing import Any

from openai import OpenAI

from llm.core.interface import (
    AuthenticationError,
    ContextLengthError,
    LLMProvider,
    RateLimitError,
)
from llm.core.types import LLMInput, LLMOutput, ModelInfo, ProviderType, ToolCall

ASTRAFLOW_BASE_URL = "https://api.umodelverse.ai/v1"
ASTRAFLOW_CN_BASE_URL = "https://api.modelverse.cn/v1"
DEFAULT_ASTRAFLOW_MODEL = "gpt-4o-mini"


def _parse_tool_arguments(raw_arguments: str | None) -> dict[str, Any]:
    if not raw_arguments:
        return {}

    try:
        arguments = json.loads(raw_arguments)
    except json.JSONDecodeError:
        return {"raw": raw_arguments}

    if isinstance(arguments, dict):
        return arguments
    return {"value": arguments}


class _AstraflowBaseProvider(LLMProvider):
    provider_type: ProviderType
    api_key_env: str
    base_url_env: str
    model_env: str
    fallback_model_env: str | None = None
    default_base_url: str
    default_model = DEFAULT_ASTRAFLOW_MODEL

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        default_model: str | None = None,
    ) -> None:
        self.api_key = api_key or os.environ.get(self.api_key_env) or ""
        self.base_url = base_url or os.environ.get(self.base_url_env, self.default_base_url)
        env_model = os.environ.get(self.model_env)
        fallback_model = os.environ.get(self.fallback_model_env) if self.fallback_model_env else None
        self.default_model = default_model or env_model or fallback_model or DEFAULT_ASTRAFLOW_MODEL
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
        self._models = [
            ModelInfo(
                name=self.default_model,
                provider=self.provider_type,
                supports_tools=True,
                supports_vision=False,
            )
        ]

    def generate(self, llm_input: LLMInput) -> LLMOutput:
        try:
            params: dict[str, Any] = {
                "model": llm_input.model or self.default_model,
                "messages": [msg.to_dict() for msg in llm_input.messages],
            }
            if llm_input.temperature != 1.0:
                params["temperature"] = llm_input.temperature
            if llm_input.max_tokens is not None:
                params["max_tokens"] = llm_input.max_tokens
            if llm_input.tools:
                params["tools"] = [tool.to_openai_tool() for tool in llm_input.tools]

            response = self.client.chat.completions.create(**params)
            choice = response.choices[0]

            tool_calls = None
            if choice.message.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id or "",
                        name=tc.function.name,
                        arguments=_parse_tool_arguments(tc.function.arguments),
                    )
                    for tc in choice.message.tool_calls
                ]

            usage = None
            if response.usage:
                usage = {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                }

            return LLMOutput(
                content=choice.message.content or "",
                tool_calls=tool_calls,
                model=response.model,
                usage=usage,
                stop_reason=choice.finish_reason,
            )
        except Exception as e:
            msg = str(e)
            if "401" in msg or "authentication" in msg.lower():
                raise AuthenticationError(msg, provider=self.provider_type) from e
            if "429" in msg or "rate_limit" in msg.lower():
                raise RateLimitError(msg, provider=self.provider_type) from e
            if "context" in msg.lower() and "length" in msg.lower():
                raise ContextLengthError(msg, provider=self.provider_type) from e
            raise

    def list_models(self) -> list[ModelInfo]:
        return self._models.copy()

    def validate_config(self) -> bool:
        return bool(self.api_key)

    def get_default_model(self) -> str:
        return self.default_model


class AstraflowProvider(_AstraflowBaseProvider):
    """UModelVerse global endpoint using OpenAI-compatible chat completions."""

    provider_type = ProviderType.ASTRAFLOW
    api_key_env = "ASTRAFLOW_API_KEY"
    base_url_env = "ASTRAFLOW_BASE_URL"
    model_env = "ASTRAFLOW_MODEL"
    default_base_url = ASTRAFLOW_BASE_URL


class AstraflowCNProvider(_AstraflowBaseProvider):
    """UModelVerse China endpoint using OpenAI-compatible chat completions."""

    provider_type = ProviderType.ASTRAFLOW_CN
    api_key_env = "ASTRAFLOW_CN_API_KEY"
    base_url_env = "ASTRAFLOW_CN_BASE_URL"
    model_env = "ASTRAFLOW_CN_MODEL"
    fallback_model_env = "ASTRAFLOW_MODEL"
    default_base_url = ASTRAFLOW_CN_BASE_URL
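Taken together, a minimal end-to-end sketch of the new provider, using only names defined in this file and the request types exercised by the tests further down; the API key, model, and prompt are placeholders, and a real key (or the ASTRAFLOW_* environment variables) would be needed for the call to succeed:

from llm.core.types import LLMInput, Message, Role
from llm.providers.astraflow import AstraflowProvider

# Constructor arguments override the ASTRAFLOW_* environment variables.
provider = AstraflowProvider(api_key="sk-example", default_model="gpt-4o-mini")
output = provider.generate(
    LLMInput(messages=[Message(role=Role.USER, content="Say hello")], max_tokens=64)
)
print(output.content, output.stop_reason, output.usage)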
@@ -7,12 +7,15 @@ from pathlib import Path

from llm.core.interface import LLMProvider
from llm.core.types import ProviderType
from llm.providers.astraflow import AstraflowCNProvider, AstraflowProvider
from llm.providers.claude import ClaudeProvider
from llm.providers.openai import OpenAIProvider
from llm.providers.ollama import OllamaProvider


_PROVIDER_MAP: dict[ProviderType, type[LLMProvider]] = {
    ProviderType.ASTRAFLOW: AstraflowProvider,
    ProviderType.ASTRAFLOW_CN: AstraflowCNProvider,
    ProviderType.CLAUDE: ClaudeProvider,
    ProviderType.OPENAI: OpenAIProvider,
    ProviderType.OLLAMA: OllamaProvider,
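With the map extended, resolution goes through get_provider, as the updated tests below exercise; a short sketch, with the environment value as a placeholder:

import os

from llm.providers import get_provider

provider = get_provider("astraflow")        # explicit provider name
os.environ["LLM_PROVIDER"] = "ASTRAFLOW"    # or via the environment; the resolver normalizes the value
provider_from_env = get_provider()
print(type(provider).__name__, type(provider_from_env).__name__)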
@@ -568,7 +568,7 @@ async function runTests() {
        CLAUDE_HOOK_EVENT_NAME: 'PreToolUse',
        ECC_MCP_CONFIG_PATH: configPath,
        ECC_MCP_HEALTH_STATE_PATH: statePath,
-       ECC_MCP_HEALTH_TIMEOUT_MS: '100'
+       ECC_MCP_HEALTH_TIMEOUT_MS: process.platform === 'win32' ? '1000' : '100'
      }
    );
tests/test_astraflow_provider.py (new file, 141 lines)

@@ -0,0 +1,141 @@
from types import SimpleNamespace

from llm.core.types import LLMInput, Message, ProviderType, Role, ToolDefinition, ToolCall
from llm.providers.astraflow import ASTRAFLOW_BASE_URL, ASTRAFLOW_CN_BASE_URL, AstraflowCNProvider, AstraflowProvider


def _tool() -> ToolDefinition:
    return ToolDefinition(
        name="search",
        description="Search",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}},
    )


class _Completions:
    def __init__(self, response: SimpleNamespace) -> None:
        self.params = None
        self.response = response

    def create(self, **params):
        self.params = params
        return self.response


class _Client:
    def __init__(self, response: SimpleNamespace) -> None:
        self.completions = _Completions(response)
        self.chat = SimpleNamespace(completions=self.completions)


def _response(**overrides) -> SimpleNamespace:
    message = SimpleNamespace(content="ok", tool_calls=None)
    choice = SimpleNamespace(message=message, finish_reason="stop")
    defaults = {
        "choices": [choice],
        "model": "gpt-4o-mini",
        "usage": SimpleNamespace(prompt_tokens=1, completion_tokens=2, total_tokens=3),
    }
    defaults.update(overrides)
    return SimpleNamespace(**defaults)


def test_astraflow_provider_defaults_to_global_umodelverse_endpoint(monkeypatch):
    monkeypatch.delenv("ASTRAFLOW_API_KEY", raising=False)
    monkeypatch.delenv("ASTRAFLOW_BASE_URL", raising=False)
    monkeypatch.delenv("ASTRAFLOW_MODEL", raising=False)

    provider = AstraflowProvider()

    assert provider.provider_type == ProviderType.ASTRAFLOW
    assert provider.base_url == ASTRAFLOW_BASE_URL
    assert provider.get_default_model() == "gpt-4o-mini"
    assert provider.validate_config() is False


def test_astraflow_cn_provider_uses_cn_endpoint_and_model_fallback(monkeypatch):
    monkeypatch.setenv("ASTRAFLOW_API_KEY", "global-key")
    monkeypatch.setenv("ASTRAFLOW_MODEL", "deepseek-ai/DeepSeek-V3-0324")
    monkeypatch.setenv("ASTRAFLOW_CN_API_KEY", "cn-key")
    monkeypatch.delenv("ASTRAFLOW_CN_MODEL", raising=False)
    monkeypatch.delenv("ASTRAFLOW_CN_BASE_URL", raising=False)

    provider = AstraflowCNProvider()

    assert provider.provider_type == ProviderType.ASTRAFLOW_CN
    assert provider.base_url == ASTRAFLOW_CN_BASE_URL
    assert provider.get_default_model() == "deepseek-ai/DeepSeek-V3-0324"
    assert provider.validate_config() is True


def test_astraflow_provider_generates_openai_compatible_chat_completion():
    provider = AstraflowProvider(api_key="test", default_model="deepseek-ai/DeepSeek-V3-0324")
    client = _Client(_response(model="deepseek-ai/DeepSeek-V3-0324"))
    provider.client = client

    output = provider.generate(
        LLMInput(
            messages=[Message(role=Role.USER, content="hi")],
            max_tokens=128,
            tools=[_tool()],
        )
    )

    assert output.content == "ok"
    assert output.model == "deepseek-ai/DeepSeek-V3-0324"
    assert output.usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
    assert client.completions.params["model"] == "deepseek-ai/DeepSeek-V3-0324"
    assert client.completions.params["max_tokens"] == 128
    assert "temperature" not in client.completions.params
    assert client.completions.params["tools"] == [
        {
            "type": "function",
            "function": {
                "name": "search",
                "description": "Search",
                "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
                "strict": True,
            },
        }
    ]


def test_astraflow_provider_forwards_non_default_temperature():
    provider = AstraflowProvider(api_key="test")
    client = _Client(_response())
    provider.client = client

    provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")], temperature=0.2))

    assert client.completions.params["temperature"] == 0.2


def test_astraflow_provider_parses_tool_calls():
    provider = AstraflowProvider(api_key="test")
    tool_call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="search", arguments='{"query":"ucloud"}'),
    )
    message = SimpleNamespace(content="", tool_calls=[tool_call])
    client = _Client(_response(choices=[SimpleNamespace(message=message, finish_reason="tool_calls")], usage=None))
    provider.client = client

    output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))

    assert output.tool_calls == [ToolCall(id="call_1", name="search", arguments={"query": "ucloud"})]
    assert output.usage is None


def test_astraflow_provider_preserves_malformed_tool_arguments():
    provider = AstraflowProvider(api_key="test")
    tool_call = SimpleNamespace(
        id="call_1",
        function=SimpleNamespace(name="search", arguments="{not-json"),
    )
    message = SimpleNamespace(content="", tool_calls=[tool_call])
    client = _Client(_response(choices=[SimpleNamespace(message=message, finish_reason="tool_calls")]))
    provider.client = client

    output = provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")]))

    assert output.tool_calls == [ToolCall(id="call_1", name="search", arguments={"raw": "{not-json"})]
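These tests stub the OpenAI client, so no network access or real API key is needed; assuming the repository's usual pytest setup with src/ on the import path, they can be run in isolation with:

pytest tests/test_astraflow_provider.py -q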
@@ -1,6 +1,6 @@
import pytest
from llm.core.types import ProviderType
-from llm.providers import ClaudeProvider, OpenAIProvider, OllamaProvider, get_provider
+from llm.providers import AstraflowCNProvider, AstraflowProvider, ClaudeProvider, OpenAIProvider, OllamaProvider, get_provider


class TestGetProvider:
@@ -19,6 +19,16 @@ class TestGetProvider:
        assert isinstance(provider, OllamaProvider)
        assert provider.provider_type == ProviderType.OLLAMA

    def test_get_astraflow_provider(self):
        provider = get_provider("astraflow")
        assert isinstance(provider, AstraflowProvider)
        assert provider.provider_type == ProviderType.ASTRAFLOW

    def test_get_astraflow_cn_provider(self):
        provider = get_provider("astraflow_cn")
        assert isinstance(provider, AstraflowCNProvider)
        assert provider.provider_type == ProviderType.ASTRAFLOW_CN

    def test_get_provider_by_enum(self):
        provider = get_provider(ProviderType.CLAUDE)
        assert isinstance(provider, ClaudeProvider)
@@ -52,6 +62,13 @@ class TestGetProvider:

        assert isinstance(provider, OllamaProvider)

    def test_astraflow_env_provider_is_normalized(self, monkeypatch):
        monkeypatch.setenv("LLM_PROVIDER", "ASTRAFLOW")

        provider = get_provider()

        assert isinstance(provider, AstraflowProvider)

    def test_explicit_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path):
        monkeypatch.delenv("LLM_PROVIDER", raising=False)
        monkeypatch.chdir(tmp_path)
@@ -60,3 +77,12 @@ class TestGetProvider:
        provider = get_provider("ollama")

        assert isinstance(provider, OllamaProvider)

    def test_saved_llm_env_selects_astraflow_cn_provider(self, monkeypatch, tmp_path):
        monkeypatch.delenv("LLM_PROVIDER", raising=False)
        monkeypatch.chdir(tmp_path)
        tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=astraflow_cn\n")

        provider = get_provider()

        assert isinstance(provider, AstraflowCNProvider)
@@ -25,6 +25,8 @@ class TestProviderType:
        assert ProviderType.CLAUDE.value == "claude"
        assert ProviderType.OPENAI.value == "openai"
        assert ProviderType.OLLAMA.value == "ollama"
        assert ProviderType.ASTRAFLOW.value == "astraflow"
        assert ProviderType.ASTRAFLOW_CN.value == "astraflow_cn"


class TestMessage: