mirror of https://github.com/affaan-m/everything-claude-code.git
synced 2026-05-14 10:43:20 +08:00
fix: port LLM provider config and tool schemas
This commit is contained in:
parent f442bac8c9
commit 7fa1e5b6db
@@ -57,6 +57,24 @@ class ToolDefinition:
             "strict": self.strict,
         }
 
+    def to_openai_tool(self) -> dict[str, Any]:
+        return {
+            "type": "function",
+            "function": {
+                "name": self.name,
+                "description": self.description,
+                "parameters": self.parameters,
+                "strict": self.strict,
+            },
+        }
+
+    def to_anthropic_tool(self) -> dict[str, Any]:
+        return {
+            "name": self.name,
+            "description": self.description,
+            "input_schema": self.parameters,
+        }
+
 
 @dataclass(frozen=True)
 class ToolCall:
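Read together with the provider hunks below, the two serializers target the two wire formats: OpenAI wraps the schema in a "function" envelope and carries strict, while Anthropic takes a flat object keyed by input_schema. A minimal sketch of the shapes they produce, assuming strict defaults to True on ToolDefinition (the new tests assert exactly these payloads); the tool itself is illustrative:

    tool = ToolDefinition(
        name="search",
        description="Search",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}},
    )

    tool.to_openai_tool()
    # {"type": "function",
    #  "function": {"name": "search", "description": "Search",
    #               "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
    #               "strict": True}}

    tool.to_anthropic_tool()
    # {"name": "search", "description": "Search",
    #  "input_schema": {"type": "object", "properties": {"query": {"type": "string"}}}}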
@@ -60,7 +60,7 @@ class ClaudeProvider(LLMProvider):
         else:
             params["max_tokens"] = 8192  # required by Anthropic API
         if input.tools:
-            params["tools"] = [tool.to_dict() for tool in input.tools]
+            params["tools"] = [tool.to_anthropic_tool() for tool in input.tools]
 
         response = self.client.messages.create(**params)
 
@@ -67,7 +67,7 @@ class OpenAIProvider(LLMProvider):
         if input.max_tokens:
             params["max_tokens"] = input.max_tokens
         if input.tools:
-            params["tools"] = [tool.to_dict() for tool in input.tools]
+            params["tools"] = [tool.to_openai_tool() for tool in input.tools]
 
         response = self.client.chat.completions.create(**params)
         choice = response.choices[0]
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import os
+from pathlib import Path
 
 from llm.core.interface import LLMProvider
 from llm.core.types import ProviderType
@@ -17,10 +18,45 @@ _PROVIDER_MAP: dict[ProviderType, type[LLMProvider]] = {
     ProviderType.OLLAMA: OllamaProvider,
 }
 
+LLM_ENV_FILE = ".llm.env"
+
+
+def _strip_env_value(value: str) -> str:
+    value = value.strip()
+    if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}:
+        return value[1:-1]
+    return value
+
+
+def _read_saved_llm_config(env_path: str | Path = LLM_ENV_FILE) -> dict[str, str]:
+    path = Path(env_path)
+    if not path.is_file():
+        return {}
+
+    config: dict[str, str] = {}
+    for line in path.read_text().splitlines():
+        stripped = line.strip()
+        if not stripped or stripped.startswith("#") or "=" not in stripped:
+            continue
+        key, value = stripped.split("=", 1)
+        config[key.strip()] = _strip_env_value(value)
+    return config
+
+
+def _resolve_provider_type(provider_type: ProviderType | str | None) -> ProviderType | str:
+    if provider_type is not None:
+        return provider_type
+
+    env_provider = os.environ.get("LLM_PROVIDER")
+    if env_provider:
+        return _strip_env_value(env_provider).lower()
+
+    saved_config = _read_saved_llm_config()
+    return saved_config.get("LLM_PROVIDER", "claude").lower()
+
 
 def get_provider(provider_type: ProviderType | str | None = None, **kwargs: str) -> LLMProvider:
-    if provider_type is None:
-        provider_type = os.environ.get("LLM_PROVIDER", "claude").lower()
+    provider_type = _resolve_provider_type(provider_type)
 
     if isinstance(provider_type, str):
         try:
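With the factory changes above, provider selection resolves in a fixed order: an explicit argument wins, then the LLM_PROVIDER environment variable, then a saved .llm.env file, then the "claude" default. A minimal usage sketch under those rules (file contents and values illustrative; get_provider and the provider classes imported as in the tests below):

    # .llm.env in the working directory; quotes are stripped by _strip_env_value:
    #   LLM_PROVIDER=ollama
    #   LLM_MODEL="llama3.2"

    import os

    os.environ.pop("LLM_PROVIDER", None)
    get_provider()            # falls back to .llm.env -> OllamaProvider

    os.environ["LLM_PROVIDER"] = "OPENAI"
    get_provider()            # env var beats the file; value is lowercased -> OpenAIProvider

    get_provider("claude")    # explicit argument beats both -> ClaudeProvider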
tests/test_provider_tools.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from types import SimpleNamespace
+
+from llm.core.types import LLMInput, Message, Role, ToolDefinition
+from llm.providers.claude import ClaudeProvider
+from llm.providers.openai import OpenAIProvider
+
+
+def _tool() -> ToolDefinition:
+    return ToolDefinition(
+        name="search",
+        description="Search",
+        parameters={"type": "object", "properties": {"query": {"type": "string"}}},
+    )
+
+
+class _OpenAICompletions:
+    def __init__(self) -> None:
+        self.params = None
+
+    def create(self, **params):
+        self.params = params
+        return SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(content="ok", tool_calls=None), finish_reason="stop")],
+            model=params["model"],
+            usage=SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2),
+        )
+
+
+class _OpenAIClient:
+    def __init__(self) -> None:
+        self.completions = _OpenAICompletions()
+        self.chat = SimpleNamespace(completions=self.completions)
+
+
+class _AnthropicMessages:
+    def __init__(self) -> None:
+        self.params = None
+
+    def create(self, **params):
+        self.params = params
+        return SimpleNamespace(
+            content=[SimpleNamespace(text="ok", type="text")],
+            model=params["model"],
+            usage=SimpleNamespace(input_tokens=1, output_tokens=1),
+            stop_reason="end_turn",
+        )
+
+
+class _AnthropicClient:
+    def __init__(self) -> None:
+        self.messages = _AnthropicMessages()
+        self.api_key = "test"
+
+
+def test_openai_provider_serializes_tools_for_chat_completions():
+    provider = OpenAIProvider(api_key="test")
+    client = _OpenAIClient()
+    provider.client = client
+
+    provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")], tools=[_tool()]))
+
+    assert client.completions.params["tools"] == [
+        {
+            "type": "function",
+            "function": {
+                "name": "search",
+                "description": "Search",
+                "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
+                "strict": True,
+            },
+        }
+    ]
+
+
+def test_claude_provider_serializes_tools_for_messages_api():
+    provider = ClaudeProvider(api_key="test")
+    client = _AnthropicClient()
+    provider.client = client
+
+    provider.generate(LLMInput(messages=[Message(role=Role.USER, content="hi")], tools=[_tool()]))
+
+    assert client.messages.params["tools"] == [
+        {
+            "name": "search",
+            "description": "Search",
+            "input_schema": {"type": "object", "properties": {"query": {"type": "string"}}},
+        }
+    ]
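The file above stubs each SDK client with SimpleNamespace fakes whose create() records its keyword arguments, so the assertions can inspect the exact payload a provider sends. A bare-bones sketch of that recording pattern (names hypothetical) for porting further providers:

    from types import SimpleNamespace

    class _RecordingEndpoint:
        # Stands in for an SDK endpoint; captures whatever the provider sends.
        def __init__(self) -> None:
            self.params = None

        def create(self, **params):
            self.params = params
            # Return a namespace shaped like the real SDK response so the
            # provider's response parsing keeps working.
            return SimpleNamespace(model=params.get("model"))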
@@ -26,3 +26,37 @@ class TestGetProvider:
     def test_invalid_provider_raises(self):
         with pytest.raises(ValueError, match="Unknown provider type"):
             get_provider("invalid")
+
+    def test_saved_llm_env_selects_provider(self, monkeypatch, tmp_path):
+        monkeypatch.delenv("LLM_PROVIDER", raising=False)
+        monkeypatch.chdir(tmp_path)
+        tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=ollama\nLLM_MODEL=llama3.2\n")
+
+        provider = get_provider()
+
+        assert isinstance(provider, OllamaProvider)
+
+    def test_env_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path):
+        monkeypatch.setenv("LLM_PROVIDER", "ollama")
+        monkeypatch.chdir(tmp_path)
+        tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=openai\n")
+
+        provider = get_provider()
+
+        assert isinstance(provider, OllamaProvider)
+
+    def test_env_provider_is_normalized(self, monkeypatch):
+        monkeypatch.setenv("LLM_PROVIDER", "OLLAMA")
+
+        provider = get_provider()
+
+        assert isinstance(provider, OllamaProvider)
+
+    def test_explicit_provider_overrides_saved_llm_env(self, monkeypatch, tmp_path):
+        monkeypatch.delenv("LLM_PROVIDER", raising=False)
+        monkeypatch.chdir(tmp_path)
+        tmp_path.joinpath(".llm.env").write_text("LLM_PROVIDER=openai\n")
+
+        provider = get_provider("ollama")
+
+        assert isinstance(provider, OllamaProvider)
@@ -63,6 +63,37 @@ class TestToolDefinition:
         assert result["name"] == "search"
         assert result["strict"] is True
+
+    def test_tool_to_openai_tool(self):
+        tool = ToolDefinition(
+            name="search",
+            description="Search",
+            parameters={"type": "object"},
+            strict=False,
+        )
+
+        assert tool.to_openai_tool() == {
+            "type": "function",
+            "function": {
+                "name": "search",
+                "description": "Search",
+                "parameters": {"type": "object"},
+                "strict": False,
+            },
+        }
+
+    def test_tool_to_anthropic_tool(self):
+        tool = ToolDefinition(
+            name="search",
+            description="Search",
+            parameters={"type": "object"},
+        )
+
+        assert tool.to_anthropic_tool() == {
+            "name": "search",
+            "description": "Search",
+            "input_schema": {"type": "object"},
+        }
 
 
 class TestToolCall:
     def test_create_tool_call(self):