"""Client utilities for interacting with the Ollama API."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, Iterable, Iterator
|
|
|
|
try: # pragma: no cover - optional dependency may not be installed in CI
|
|
import ollama
|
|
except ImportError: # pragma: no cover - fallback path for environments without Ollama
|
|
ollama = None # type: ignore[assignment]
|
|
|
|
|
|
class OllamaClientError(RuntimeError):
    """Base exception raised when Ollama operations fail."""


class OllamaUnavailableError(OllamaClientError):
    """Raised when the Ollama Python SDK is not available."""


class OllamaClient:
    """Thin wrapper around the Ollama Python SDK with graceful degradation."""

    def __init__(self, host: str | None = None) -> None:
        self._host = host
        self._client = None
        self._cached_models: list[str] | None = None

        if ollama is None:
            return

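        # Only build a dedicated client for a custom host; with no host we fall
        # back to the module-level SDK functions, which target the SDK's
        # default (typically local) server.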
        if host and hasattr(ollama, "Client"):
            self._client = ollama.Client(host=host)  # type: ignore[call-arg]

    # ------------------------------------------------------------------ helpers
    @property
    def is_available(self) -> bool:
        return ollama is not None

    @property
    def default_model(self) -> str | None:
        models = self.list_models()
        return models[0] if models else None

    def list_models(self, force_refresh: bool = False) -> list[str]:
        """Return the available model names, caching the result for quick reuse."""
        if not self.is_available:
            return []

        if self._cached_models is not None and not force_refresh:
            return list(self._cached_models)

        try:
            response = self._call_sdk("list")  # type: ignore[arg-type]
        except OllamaClientError:
            return []

        models: list[str] = []
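        # Depending on the SDK/server version, a model's identifier may appear
        # under "name" or "model", hence the dual lookup below.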
        for item in response.get("models", []):  # type: ignore[assignment]
            name = item.get("name") or item.get("model")
            if name:
                models.append(name)

        self._cached_models = models
        return list(models)

    # ------------------------------------------------------------------ chat APIs
    def chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
    ) -> dict[str, str]:
"""Execute a blocking chat call against Ollama."""
|
|
if not self.is_available:
|
|
return {
|
|
"role": "assistant",
|
|
"content": "Ollama SDK is not installed; install `ollama` to enable responses.",
|
|
}
|
|
|
|
try:
|
|
result = self._call_sdk(
|
|
"chat",
|
|
model=model,
|
|
messages=list(messages),
|
|
stream=False,
|
|
)
|
|
except OllamaClientError as exc:
|
|
return {
|
|
"role": "assistant",
|
|
"content": f"Unable to reach Ollama: {exc}",
|
|
}
|
|
|
|
message = result.get("message") if isinstance(result, dict) else None
|
|
if not message:
|
|
return {"role": "assistant", "content": ""}
|
|
|
|
role = message.get("role") or "assistant"
|
|
content = message.get("content") or ""
|
|
return {"role": role, "content": content}
|
|
|
|
    def stream_chat(
        self, *, model: str, messages: Iterable[Dict[str, str]]
    ) -> Iterator[dict[str, Any]]:
        """Placeholder that exposes the streaming API for future UI hooks."""
        if not self.is_available:
            raise OllamaUnavailableError(
                "Streaming requires the Ollama Python SDK to be installed."
            )

        try:
            stream = self._call_sdk(
                "chat",
                model=model,
                messages=list(messages),
                stream=True,
            )
        except OllamaClientError as exc:
            raise OllamaClientError(f"Failed to start streaming chat: {exc}") from exc

        if not hasattr(stream, "__iter__"):
            raise OllamaClientError("Ollama returned a non-iterable stream response.")
        return iter(stream)

    # ------------------------------------------------------------------ internals
    def _call_sdk(self, method: str, *args: Any, **kwargs: Any) -> Any:
        if not self.is_available:
            raise OllamaUnavailableError(
                "Ollama Python SDK is not available in the environment."
            )

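        # Prefer the host-specific client when one was configured; otherwise
        # dispatch to the module-level SDK functions by attribute name.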
        target = self._client if self._client is not None else ollama
        if target is None or not hasattr(target, method):
            raise OllamaClientError(
                f"Ollama SDK does not expose method '{method}'. Install or update the SDK."
            )

        func = getattr(target, method)

        try:
            return func(*args, **kwargs)
        except Exception as exc:  # pragma: no cover - network errors depend on runtime
            raise OllamaClientError(str(exc)) from exc
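

# ------------------------------------------------------------------ example
# Minimal usage sketch: the prompt below is illustrative, and the demo assumes
# a local Ollama server with at least one model pulled. It exercises only the
# public surface above: the availability check, model discovery, and a
# blocking chat call.
if __name__ == "__main__":
    client = OllamaClient()
    if not client.is_available:
        print("Install the `ollama` package to talk to a local server.")
    elif client.default_model is None:
        print("No models found; pull one first (e.g. via the Ollama CLI).")
    else:
        reply = client.chat(
            model=client.default_model,
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
        )
        print(reply["content"])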