diff --git a/pyproject.toml b/pyproject.toml index 34778aef..ee13baa0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,84 +1,78 @@ -[project] -name = "llm-abstraction" -version = "0.1.0" -description = "Provider-agnostic LLM abstraction layer" -readme = "README.md" -requires-python = ">=3.11" -license = {text = "MIT"} -authors = [ - {name = "Affaan Mustafa", email = "affaan@example.com"} -] -keywords = ["llm", "openai", "anthropic", "ollama", "ai"] -classifiers = [ - "Development Status :: 3 - Alpha", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", -] - -dependencies = [ - "anthropic>=0.25.0", - "openai>=1.30.0", -] - -[project.optional-dependencies] -dev = [ - "pytest>=8.0", - "pytest-asyncio>=0.23", - "pytest-cov>=4.1", - "ruff>=0.4", - "mypy>=1.10", - "ruff>=0.4", -] -test = [ - "pytest>=8.0", - "pytest-asyncio>=0.23", - "pytest-cov>=4.1", - "pytest-mock>=3.12", -] - -[project.urls] -Homepage = "https://github.com/affaan-m/everything-claude-code" -Repository = "https://github.com/affaan-m/everything-claude-code" - -[project.scripts] -llm-select = "llm.cli.selector:main" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["src/llm"] - -[tool.pytest.ini_options] -testpaths = ["tests"] -asyncio_mode = "auto" -filterwarnings = ["ignore::DeprecationWarning"] - -[tool.coverage.run] -source = ["src/llm"] -branch = true - -[tool.coverage.report] -exclude_lines = [ - "pragma: no cover", - "if TYPE_CHECKING:", - "raise NotImplementedError", -] - -[tool.ruff] -src-path = ["src"] -target-version = "py311" - -[tool.ruff.lint] -select = ["E", "F", "I", "N", "W", "UP"] -ignore = ["E501"] - -[tool.mypy] -python_version = "3.11" -src_paths = ["src"] -warn_return_any = true -warn_unused_ignores = true +[project] +name = "llm-abstraction" +version = "0.1.0" +description = "Provider-agnostic LLM abstraction layer" +readme = "README.md" +requires-python = ">=3.11" +license = {text = "MIT"} +authors = [ + {name = "Affaan Mustafa", email = "affaan@example.com"} +] +keywords = ["llm", "openai", "anthropic", "ollama", "ai"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +dependencies = [ + "anthropic>=0.25.0", + "openai>=1.30.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "pytest-cov>=4.1", + "pytest-mock>=3.12", + "ruff>=0.4", + "mypy>=1.10", +] + +[project.urls] +Homepage = "https://github.com/affaan-m/everything-claude-code" +Repository = "https://github.com/affaan-m/everything-claude-code" + +[project.scripts] +llm-select = "llm.cli.selector:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/llm"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" +filterwarnings = ["ignore::DeprecationWarning"] + +[tool.coverage.run] +source = ["src/llm"] +branch = true + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "raise NotImplementedError", +] + +[tool.ruff] +src-path = ["src"] +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP"] +ignore = ["E501"] + +[tool.mypy] +python_version = "3.11" +src_paths = ["src"] +warn_return_any = true +warn_unused_ignores = true diff --git a/src/llm/__init__.py b/src/llm/__init__.py index 3b25554b..e7c28925 100644 --- a/src/llm/__init__.py +++ b/src/llm/__init__.py @@ -1,33 +1,33 @@ -""" -LLM Abstraction Layer - -Provider-agnostic interface for multiple LLM backends. -""" - -from llm.core.interface import LLMProvider -from llm.core.types import LLMInput, LLMOutput, Message, ToolCall, ToolDefinition, ToolResult -from llm.providers import get_provider -from llm.tools import ToolExecutor, ToolRegistry -from llm.cli.selector import interactive_select - -__version__ = "0.1.0" - -__all__ = [ - "LLMProvider", - "LLMInput", - "LLMOutput", - "Message", - "get_provider", - "ToolCall", - "ToolDefinition", - "ToolResult", - "ToolExecutor", - "ToolRegistry", - "interactive_select", -] - - -def gui() -> None: - from llm.gui.selector import main - main() - +""" +LLM Abstraction Layer + +Provider-agnostic interface for multiple LLM backends. +""" + +from llm.core.interface import LLMProvider +from llm.core.types import LLMInput, LLMOutput, Message, ToolCall, ToolDefinition, ToolResult +from llm.providers import get_provider +from llm.tools import ToolExecutor, ToolRegistry +from llm.cli.selector import interactive_select + +__version__ = "0.1.0" + +__all__ = [ + "LLMProvider", + "LLMInput", + "LLMOutput", + "Message", + "get_provider", + "ToolCall", + "ToolDefinition", + "ToolResult", + "ToolExecutor", + "ToolRegistry", + "interactive_select", +] + + +def gui() -> None: + from llm.cli.selector import main + main() + diff --git a/src/llm/prompt/builder.py b/src/llm/prompt/builder.py index 4d588e27..2eb719c3 100644 --- a/src/llm/prompt/builder.py +++ b/src/llm/prompt/builder.py @@ -1,101 +1,102 @@ -"""Prompt builder for normalizing prompts across providers.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - -from llm.core.types import LLMInput, Message, Role, ToolDefinition -from llm.providers.claude import ClaudeProvider -from llm.providers.openai import OpenAIProvider -from llm.providers.ollama import OllamaProvider - - -@dataclass -class PromptConfig: - system_template: str | None = None - user_template: str | None = None - include_tools_in_system: bool = True - tool_format: str = "native" - - -class PromptBuilder: - def __init__(self, config: PromptConfig | None = None) -> None: - self.config = config or PromptConfig() - - def build(self, messages: list[Message], tools: list[ToolDefinition] | None = None) -> list[Message]: - if not messages: - return [] - - result: list[Message] = [] - system_parts: list[str] = [] - - if self.config.system_template: - system_parts.append(self.config.system_template) - - if tools and self.config.include_tools_in_system: - tools_desc = self._format_tools(tools) - system_parts.append(f"\n\n## Available Tools\n{tools_desc}") - - if messages[0].role == Role.SYSTEM: - system_parts.insert(0, messages[0].content) - result.extend(messages[1:]) - else: - if system_parts: - result.insert(0, Message(role=Role.SYSTEM, content="\n\n".join(system_parts))) - result.extend(messages) - - return result - - def _format_tools(self, tools: list[ToolDefinition]) -> str: - lines = [] - for tool in tools: - lines.append(f"### {tool.name}") - lines.append(tool.description) - if tool.parameters: - lines.append("Parameters:") - lines.append(self._format_parameters(tool.parameters)) - return "\n".join(lines) - - def _format_parameters(self, params: dict[str, Any]) -> str: - if "properties" not in params: - return str(params) - lines = [] - required = params.get("required", []) - for name, spec in params["properties"].items(): - prop_type = spec.get("type", "any") - desc = spec.get("description", "") - required_mark = "(required)" if name in required else "(optional)" - lines.append(f" - {name}: {prop_type} {required_mark} - {desc}") - return "\n".join(lines) if lines else str(params) - - -_PROVIDER_TEMPLATE_MAP: dict[str, dict[str, Any]] = { - "claude": { - "include_tools_in_system": False, - "tool_format": "anthropic", - }, - "openai": { - "include_tools_in_system": False, - "tool_format": "openai", - }, - "ollama": { - "include_tools_in_system": True, - "tool_format": "text", - }, -} - - -def get_provider_builder(provider_name: str) -> PromptBuilder: - config_dict = _PROVIDER_TEMPLATE_MAP.get(provider_name.lower(), {}) - config = PromptConfig(**config_dict) - return PromptBuilder(config) - - -def adapt_messages_for_provider( - messages: list[Message], - provider: str, - tools: list[ToolDefinition] | None = None, -) -> list[Message]: - builder = get_provider_builder(provider) - return builder.build(messages, tools) +"""Prompt builder for normalizing prompts across providers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from llm.core.types import LLMInput, Message, Role, ToolDefinition +from llm.providers.claude import ClaudeProvider +from llm.providers.openai import OpenAIProvider +from llm.providers.ollama import OllamaProvider + + +@dataclass +class PromptConfig: + system_template: str | None = None + user_template: str | None = None + include_tools_in_system: bool = True + tool_format: str = "native" + + +class PromptBuilder: + def __init__(self, config: PromptConfig | None = None) -> None: + self.config = config or PromptConfig() + + def build(self, messages: list[Message], tools: list[ToolDefinition] | None = None) -> list[Message]: + if not messages: + return [] + + result: list[Message] = [] + system_parts: list[str] = [] + + if self.config.system_template: + system_parts.append(self.config.system_template) + + if tools and self.config.include_tools_in_system: + tools_desc = self._format_tools(tools) + system_parts.append(f"\n\n## Available Tools\n{tools_desc}") + + if messages[0].role == Role.SYSTEM: + system_parts.insert(0, messages[0].content) + result.insert(0, Message(role=Role.SYSTEM, content="\n\n".join(system_parts))) + result.extend(messages[1:]) + else: + if system_parts: + result.insert(0, Message(role=Role.SYSTEM, content="\n\n".join(system_parts))) + result.extend(messages) + + return result + + def _format_tools(self, tools: list[ToolDefinition]) -> str: + lines = [] + for tool in tools: + lines.append(f"### {tool.name}") + lines.append(tool.description) + if tool.parameters: + lines.append("Parameters:") + lines.append(self._format_parameters(tool.parameters)) + return "\n".join(lines) + + def _format_parameters(self, params: dict[str, Any]) -> str: + if "properties" not in params: + return str(params) + lines = [] + required = params.get("required", []) + for name, spec in params["properties"].items(): + prop_type = spec.get("type", "any") + desc = spec.get("description", "") + required_mark = "(required)" if name in required else "(optional)" + lines.append(f" - {name}: {prop_type} {required_mark} - {desc}") + return "\n".join(lines) if lines else str(params) + + +_PROVIDER_TEMPLATE_MAP: dict[str, dict[str, Any]] = { + "claude": { + "include_tools_in_system": False, + "tool_format": "anthropic", + }, + "openai": { + "include_tools_in_system": False, + "tool_format": "openai", + }, + "ollama": { + "include_tools_in_system": True, + "tool_format": "text", + }, +} + + +def get_provider_builder(provider_name: str) -> PromptBuilder: + config_dict = _PROVIDER_TEMPLATE_MAP.get(provider_name.lower(), {}) + config = PromptConfig(**config_dict) + return PromptBuilder(config) + + +def adapt_messages_for_provider( + messages: list[Message], + provider: str, + tools: list[ToolDefinition] | None = None, +) -> list[Message]: + builder = get_provider_builder(provider) + return builder.build(messages, tools) diff --git a/src/llm/providers/claude.py b/src/llm/providers/claude.py index cb41ed4b..975f036b 100644 --- a/src/llm/providers/claude.py +++ b/src/llm/providers/claude.py @@ -1,103 +1,105 @@ -"""Claude provider adapter.""" - -from __future__ import annotations - -import os -from typing import Any - -from anthropic import Anthropic - -from llm.core.interface import ( - AuthenticationError, - ContextLengthError, - LLMProvider, - RateLimitError, -) -from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall - - -class ClaudeProvider(LLMProvider): - provider_type = ProviderType.CLAUDE - - def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: - self.client = Anthropic(api_key=api_key or os.environ.get("ANTHROPIC_API_KEY"), base_url=base_url) - self._models = [ - ModelInfo( - name="claude-opus-4-5", - provider=ProviderType.CLAUDE, - supports_tools=True, - supports_vision=True, - max_tokens=8192, - context_window=200000, - ), - ModelInfo( - name="claude-sonnet-4-7", - provider=ProviderType.CLAUDE, - supports_tools=True, - supports_vision=True, - max_tokens=8192, - context_window=200000, - ), - ModelInfo( - name="claude-haiku-4-7", - provider=ProviderType.CLAUDE, - supports_tools=True, - supports_vision=False, - max_tokens=4096, - context_window=200000, - ), - ] - - def generate(self, input: LLMInput) -> LLMOutput: - try: - params: dict[str, Any] = { - "model": input.model or "claude-sonnet-4-7", - "messages": [msg.to_dict() for msg in input.messages], - "temperature": input.temperature, - } - if input.max_tokens: - params["max_tokens"] = input.max_tokens - if input.tools: - params["tools"] = [tool.to_dict() for tool in input.tools] - - response = self.client.messages.create(**params) - - tool_calls = None - if response.content and hasattr(response.content[0], "type"): - if response.content[0].type == "tool_use": - tool_calls = [ - ToolCall( - id=getattr(response.content[0], "id", ""), - name=getattr(response.content[0], "name", ""), - arguments=getattr(response.content[0].input, "__dict__", {}), - ) - ] - - return LLMOutput( - content=response.content[0].text if response.content else "", - tool_calls=tool_calls, - model=response.model, - usage={ - "input_tokens": response.usage.input_tokens, - "output_tokens": response.usage.output_tokens, - }, - stop_reason=response.stop_reason, - ) - except Exception as e: - msg = str(e) - if "401" in msg or "authentication" in msg.lower(): - raise AuthenticationError(msg, provider=ProviderType.CLAUDE) from e - if "429" in msg or "rate_limit" in msg.lower(): - raise RateLimitError(msg, provider=ProviderType.CLAUDE) from e - if "context" in msg.lower() and "length" in msg.lower(): - raise ContextLengthError(msg, provider=ProviderType.CLAUDE) from e - raise - - def list_models(self) -> list[ModelInfo]: - return self._models.copy() - - def validate_config(self) -> bool: - return bool(self.client.api_key) - - def get_default_model(self) -> str: - return "claude-sonnet-4-7" +"""Claude provider adapter.""" + +from __future__ import annotations + +import os +from typing import Any + +from anthropic import Anthropic + +from llm.core.interface import ( + AuthenticationError, + ContextLengthError, + LLMProvider, + RateLimitError, +) +from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall + + +class ClaudeProvider(LLMProvider): + provider_type = ProviderType.CLAUDE + + def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: + self.client = Anthropic(api_key=api_key or os.environ.get("ANTHROPIC_API_KEY"), base_url=base_url) + self._models = [ + ModelInfo( + name="claude-opus-4-5", + provider=ProviderType.CLAUDE, + supports_tools=True, + supports_vision=True, + max_tokens=8192, + context_window=200000, + ), + ModelInfo( + name="claude-sonnet-4-7", + provider=ProviderType.CLAUDE, + supports_tools=True, + supports_vision=True, + max_tokens=8192, + context_window=200000, + ), + ModelInfo( + name="claude-haiku-4-7", + provider=ProviderType.CLAUDE, + supports_tools=True, + supports_vision=False, + max_tokens=4096, + context_window=200000, + ), + ] + + def generate(self, input: LLMInput) -> LLMOutput: + try: + params: dict[str, Any] = { + "model": input.model or "claude-sonnet-4-7", + "messages": [msg.to_dict() for msg in input.messages], + "temperature": input.temperature, + } + if input.max_tokens: + params["max_tokens"] = input.max_tokens + else: + params["max_tokens"] = 8192 # required by Anthropic API + if input.tools: + params["tools"] = [tool.to_dict() for tool in input.tools] + + response = self.client.messages.create(**params) + + tool_calls = None + if response.content and hasattr(response.content[0], "type"): + if response.content[0].type == "tool_use": + tool_calls = [ + ToolCall( + id=getattr(response.content[0], "id", ""), + name=getattr(response.content[0], "name", ""), + arguments=getattr(response.content[0].input, "__dict__", {}), + ) + ] + + return LLMOutput( + content=response.content[0].text if response.content else "", + tool_calls=tool_calls, + model=response.model, + usage={ + "input_tokens": response.usage.input_tokens, + "output_tokens": response.usage.output_tokens, + }, + stop_reason=response.stop_reason, + ) + except Exception as e: + msg = str(e) + if "401" in msg or "authentication" in msg.lower(): + raise AuthenticationError(msg, provider=ProviderType.CLAUDE) from e + if "429" in msg or "rate_limit" in msg.lower(): + raise RateLimitError(msg, provider=ProviderType.CLAUDE) from e + if "context" in msg.lower() and "length" in msg.lower(): + raise ContextLengthError(msg, provider=ProviderType.CLAUDE) from e + raise + + def list_models(self) -> list[ModelInfo]: + return self._models.copy() + + def validate_config(self) -> bool: + return bool(self.client.api_key) + + def get_default_model(self) -> str: + return "claude-sonnet-4-7" diff --git a/src/llm/providers/openai.py b/src/llm/providers/openai.py index ebbcf78d..019696cf 100644 --- a/src/llm/providers/openai.py +++ b/src/llm/providers/openai.py @@ -1,113 +1,114 @@ -"""OpenAI provider adapter.""" - -from __future__ import annotations - -import os -from typing import Any - -from openai import OpenAI - -from llm.core.interface import ( - AuthenticationError, - ContextLengthError, - LLMProvider, - RateLimitError, -) -from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall - - -class OpenAIProvider(LLMProvider): - provider_type = ProviderType.OPENAI - - def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: - self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"), base_url=base_url) - self._models = [ - ModelInfo( - name="gpt-4o", - provider=ProviderType.OPENAI, - supports_tools=True, - supports_vision=True, - max_tokens=4096, - context_window=128000, - ), - ModelInfo( - name="gpt-4o-mini", - provider=ProviderType.OPENAI, - supports_tools=True, - supports_vision=True, - max_tokens=4096, - context_window=128000, - ), - ModelInfo( - name="gpt-4-turbo", - provider=ProviderType.OPENAI, - supports_tools=True, - supports_vision=True, - max_tokens=4096, - context_window=128000, - ), - ModelInfo( - name="gpt-3.5-turbo", - provider=ProviderType.OPENAI, - supports_tools=True, - supports_vision=False, - max_tokens=4096, - context_window=16385, - ), - ] - - def generate(self, input: LLMInput) -> LLMOutput: - try: - params: dict[str, Any] = { - "model": input.model or "gpt-4o-mini", - "messages": [msg.to_dict() for msg in input.messages], - "temperature": input.temperature, - } - if input.max_tokens: - params["max_tokens"] = input.max_tokens - if input.tools: - params["tools"] = [tool.to_dict() for tool in input.tools] - - response = self.client.chat.completions.create(**params) - choice = response.choices[0] - - tool_calls = None - if choice.message.tool_calls: - tool_calls = [ - ToolCall( - id=tc.id or "", - name=tc.function.name, - arguments={} if tc.function.arguments == "" else tc.function.arguments, - ) - for tc in choice.message.tool_calls - ] - - return LLMOutput( - content=choice.message.content or "", - tool_calls=tool_calls, - model=response.model, - usage={ - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - }, - stop_reason=choice.finish_reason, - ) - except Exception as e: - msg = str(e) - if "401" in msg or "authentication" in msg.lower(): - raise AuthenticationError(msg, provider=ProviderType.OPENAI) from e - if "429" in msg or "rate_limit" in msg.lower(): - raise RateLimitError(msg, provider=ProviderType.OPENAI) from e - if "context" in msg.lower() and "length" in msg.lower(): - raise ContextLengthError(msg, provider=ProviderType.OPENAI) from e - raise - - def list_models(self) -> list[ModelInfo]: - return self._models.copy() - - def validate_config(self) -> bool: - return bool(self.client.api_key) - - def get_default_model(self) -> str: - return "gpt-4o-mini" +"""OpenAI provider adapter.""" + +from __future__ import annotations + +import json +import os +from typing import Any + +from openai import OpenAI + +from llm.core.interface import ( + AuthenticationError, + ContextLengthError, + LLMProvider, + RateLimitError, +) +from llm.core.types import LLMInput, LLMOutput, Message, ModelInfo, ProviderType, ToolCall + + +class OpenAIProvider(LLMProvider): + provider_type = ProviderType.OPENAI + + def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: + self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"), base_url=base_url) + self._models = [ + ModelInfo( + name="gpt-4o", + provider=ProviderType.OPENAI, + supports_tools=True, + supports_vision=True, + max_tokens=4096, + context_window=128000, + ), + ModelInfo( + name="gpt-4o-mini", + provider=ProviderType.OPENAI, + supports_tools=True, + supports_vision=True, + max_tokens=4096, + context_window=128000, + ), + ModelInfo( + name="gpt-4-turbo", + provider=ProviderType.OPENAI, + supports_tools=True, + supports_vision=True, + max_tokens=4096, + context_window=128000, + ), + ModelInfo( + name="gpt-3.5-turbo", + provider=ProviderType.OPENAI, + supports_tools=True, + supports_vision=False, + max_tokens=4096, + context_window=16385, + ), + ] + + def generate(self, input: LLMInput) -> LLMOutput: + try: + params: dict[str, Any] = { + "model": input.model or "gpt-4o-mini", + "messages": [msg.to_dict() for msg in input.messages], + "temperature": input.temperature, + } + if input.max_tokens: + params["max_tokens"] = input.max_tokens + if input.tools: + params["tools"] = [tool.to_dict() for tool in input.tools] + + response = self.client.chat.completions.create(**params) + choice = response.choices[0] + + tool_calls = None + if choice.message.tool_calls: + tool_calls = [ + ToolCall( + id=tc.id or "", + name=tc.function.name, + arguments={} if not tc.function.arguments else json.loads(tc.function.arguments), + ) + for tc in choice.message.tool_calls + ] + + return LLMOutput( + content=choice.message.content or "", + tool_calls=tool_calls, + model=response.model, + usage={ + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + }, + stop_reason=choice.finish_reason, + ) + except Exception as e: + msg = str(e) + if "401" in msg or "authentication" in msg.lower(): + raise AuthenticationError(msg, provider=ProviderType.OPENAI) from e + if "429" in msg or "rate_limit" in msg.lower(): + raise RateLimitError(msg, provider=ProviderType.OPENAI) from e + if "context" in msg.lower() and "length" in msg.lower(): + raise ContextLengthError(msg, provider=ProviderType.OPENAI) from e + raise + + def list_models(self) -> list[ModelInfo]: + return self._models.copy() + + def validate_config(self) -> bool: + return bool(self.client.api_key) + + def get_default_model(self) -> str: + return "gpt-4o-mini" diff --git a/tests/test_builder.py b/tests/test_builder.py index 4ba4c0e1..ae022dbc 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -1,61 +1,69 @@ -import pytest -from llm.core.types import LLMInput, Message, Role, ToolDefinition -from llm.prompt import PromptBuilder, adapt_messages_for_provider - - -class TestPromptBuilder: - def test_build_without_system(self): - messages = [Message(role=Role.USER, content="Hello")] - builder = PromptBuilder() - result = builder.build(messages) - - assert len(result) == 1 - assert result[0].role == Role.USER - - def test_build_with_system(self): - messages = [ - Message(role=Role.SYSTEM, content="You are helpful."), - Message(role=Role.USER, content="Hello"), - ] - builder = PromptBuilder() - result = builder.build(messages) - - assert len(result) == 2 - assert result[0].role == Role.SYSTEM - - def test_build_adds_system_from_config(self): - messages = [Message(role=Role.USER, content="Hello")] - builder = PromptBuilder(system_template="You are a pirate.") - result = builder.build(messages) - - assert len(result) == 2 - assert "pirate" in result[0].content - - def test_build_with_tools(self): - messages = [Message(role=Role.USER, content="Search for something")] - tools = [ - ToolDefinition(name="search", description="Search the web", parameters={}), - ] - builder = PromptBuilder(include_tools_in_system=True) - result = builder.build(messages, tools) - - assert len(result) == 2 - assert "search" in result[0].content - assert "Available Tools" in result[0].content - - -class TestAdaptMessagesForProvider: - def test_adapt_for_claude(self): - messages = [Message(role=Role.USER, content="Hello")] - result = adapt_messages_for_provider(messages, "claude") - assert len(result) == 1 - - def test_adapt_for_openai(self): - messages = [Message(role=Role.USER, content="Hello")] - result = adapt_messages_for_provider(messages, "openai") - assert len(result) == 1 - - def test_adapt_for_ollama(self): - messages = [Message(role=Role.USER, content="Hello")] - result = adapt_messages_for_provider(messages, "ollama") - assert len(result) == 1 +import pytest +from llm.core.types import LLMInput, Message, Role, ToolDefinition +from llm.prompt import PromptBuilder, adapt_messages_for_provider +from llm.prompt.builder import PromptConfig + + +class TestPromptBuilder: + def test_build_without_system(self): + messages = [Message(role=Role.USER, content="Hello")] + builder = PromptBuilder() + result = builder.build(messages) + + assert len(result) == 1 + assert result[0].role == Role.USER + + def test_build_with_system(self): + messages = [ + Message(role=Role.SYSTEM, content="You are helpful."), + Message(role=Role.USER, content="Hello"), + ] + builder = PromptBuilder() + result = builder.build(messages) + + assert len(result) == 2 + assert result[0].role == Role.SYSTEM + + def test_build_adds_system_from_config(self): + messages = [Message(role=Role.USER, content="Hello")] + builder = PromptBuilder(system_template="You are a pirate.") + result = builder.build(messages) + + assert len(result) == 2 + assert "pirate" in result[0].content + + def test_build_adds_system_from_config(self): + messages = [Message(role=Role.USER, content="Hello")] + builder = PromptBuilder(config=PromptConfig(system_template="You are a pirate.")) + result = builder.build(messages) + + assert len(result) == 2 + assert "pirate" in result[0].content + def test_build_with_tools(self): + messages = [Message(role=Role.USER, content="Search for something")] + tools = [ + ToolDefinition(name="search", description="Search the web", parameters={}), + ] + builder = PromptBuilder(include_tools_in_system=True) + result = builder.build(messages, tools) + + assert len(result) == 2 + assert "search" in result[0].content + assert "Available Tools" in result[0].content + + +class TestAdaptMessagesForProvider: + def test_adapt_for_claude(self): + messages = [Message(role=Role.USER, content="Hello")] + result = adapt_messages_for_provider(messages, "claude") + assert len(result) == 1 + + def test_adapt_for_openai(self): + messages = [Message(role=Role.USER, content="Hello")] + result = adapt_messages_for_provider(messages, "openai") + assert len(result) == 1 + + def test_adapt_for_ollama(self): + messages = [Message(role=Role.USER, content="Hello")] + result = adapt_messages_for_provider(messages, "ollama") + assert len(result) == 1