Files
niri-ai-sidebar/provider_client.py
Melvin Ragusa 6cc11fc9e4 feat: add multi-provider support to chat widget
- Added support for multiple AI providers (Ollama, Gemini, OpenRouter, Copilot) with provider abstraction layer
- Created settings view with provider configuration and API key management
- Updated UI to show current provider status and handle provider-specific availability
- Modified reasoning mode to work exclusively with Ollama provider
- Added provider switching functionality with persistent settings
- Updated error messages and placeholders to be
2025-10-31 00:08:04 +01:00

757 lines
25 KiB
Python

"""Provider abstraction layer for multiple AI providers."""
from __future__ import annotations

import json
import webbrowser
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Iterator
from urllib.error import URLError, HTTPError
from urllib.parse import quote, urlencode
from urllib.request import Request, urlopen

from .ollama_client import OllamaClient
class AIProvider(ABC):
    """Abstract base class for AI providers.

    Subclasses adapt one backend to a common interface: model discovery
    (``list_models`` / ``default_model``), a blocking ``chat`` call and a
    token-streaming ``stream_chat`` call.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the short provider identifier (e.g. ``"ollama"``)."""

    @property
    @abstractmethod
    def is_available(self) -> bool:
        """Return whether the provider is currently usable."""

    @property
    @abstractmethod
    def default_model(self) -> str | None:
        """Return the model to use when none is selected, or ``None``."""

    @abstractmethod
    def list_models(self, force_refresh: bool = False) -> list[str]:
        """List available models for this provider.

        Args:
            force_refresh: Query the backend again instead of returning a
                previously cached list.
        """

    @abstractmethod
    def chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> dict[str, str] | None:
        """Execute a blocking chat call.

        Args:
            model: Name of the model to query.
            messages: Conversation history as ``{"role": ..., "content": ...}``
                dicts.
            options: Provider-specific generation options; implementations
                may ignore them.

        Returns:
            Dictionary with 'role' and 'content' keys, or None on error.
        """

    @abstractmethod
    def stream_chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Execute a streaming chat call.

        Yields dictionaries containing token data from the streaming response.
        Each yielded dict should follow the format:

        - 'message' dict with 'content' (and optionally 'thinking')
        - 'done' boolean flag
        - 'error' boolean flag (optional)

        Error chunks produced by the hosted providers in this module carry
        'role' and 'content' at the top level instead of inside 'message'.
        """

    def test_connection(self) -> tuple[bool, str]:
        """Probe the provider by forcing a fresh model listing.

        Returns:
            Tuple of ``(success, message)`` suitable for showing to the user.
        """
        try:
            models = self.list_models(force_refresh=True)
            # Several providers swallow request errors and return an empty
            # list, so an empty result alone does not prove we connected.
            reachable = bool(models) or self.is_available
        except Exception as exc:
            return False, f"Connection failed: {exc}"
        if models:
            return True, f"Connected successfully. Found {len(models)} model(s)."
        if not reachable:
            return (
                False,
                "Connection failed: provider is unreachable or rejected the credentials.",
            )
        return False, "Connected but no models available."
class OllamaProvider(AIProvider):
    """Thin adapter exposing an ``OllamaClient`` through the AIProvider API.

    No logic lives here: every property and method forwards to the wrapped
    client and hands its result back untouched.
    """

    def __init__(self, host: str | None = None):
        """Create the wrapped Ollama client.

        Args:
            host: Ollama server host (default: http://localhost:11434);
                passed straight through to ``OllamaClient``.
        """
        self._client = OllamaClient(host)

    @property
    def name(self) -> str:
        """Short identifier of this provider."""
        return "ollama"

    @property
    def is_available(self) -> bool:
        """Availability exactly as the wrapped client reports it."""
        client = self._client
        return client.is_available

    @property
    def default_model(self) -> str | None:
        """Default model exactly as the wrapped client reports it."""
        client = self._client
        return client.default_model

    def list_models(self, force_refresh: bool = False) -> list[str]:
        """Return the client's model list; ``force_refresh`` is passed through."""
        client = self._client
        return client.list_models(force_refresh=force_refresh)

    def chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> dict[str, str] | None:
        """Forward a blocking chat request to the client."""
        request = {"model": model, "messages": messages, "options": options}
        return self._client.chat(**request)

    def stream_chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Forward a streaming chat request; the client's iterator is returned as-is."""
        request = {"model": model, "messages": messages, "options": options}
        return self._client.stream_chat(**request)
class GeminiProvider(AIProvider):
    """Google Gemini provider backed by the Generative Language REST API.

    The API key is sent in the ``x-goog-api-key`` header rather than in the
    URL query string, so it does not end up in URLs, logs or proxies.
    """

    # Checked in order by ``default_model``; "gemini-pro" stays first so the
    # historical default is kept whenever the account still lists it.
    _PREFERRED_MODELS = (
        "gemini-pro",
        "gemini-2.5-flash",
        "gemini-2.0-flash",
        "gemini-1.5-flash",
    )

    def __init__(self, api_key: str | None = None):
        """Initialize Gemini provider.

        Args:
            api_key: Google Gemini API key. When given, it is validated
                immediately with a blocking model-list request.
        """
        self._api_key = api_key
        self._is_available = False
        self._cached_models: list[str] | None = None
        self._base_url = "https://generativelanguage.googleapis.com/v1beta"
        if api_key:
            self._check_connection()

    def _request(self, path: str, payload: dict[str, Any] | None = None) -> Request:
        """Build an authenticated request for ``{base_url}/{path}``.

        A JSON ``payload`` makes it a POST; without one it is a GET.
        """
        url = f"{self._base_url}/{path}"
        headers = {"x-goog-api-key": self._api_key or ""}
        if payload is None:
            return Request(url, headers=headers, method="GET")
        headers["Content-Type"] = "application/json"
        return Request(
            url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )

    def _fetch_models(self, timeout: float) -> list[str]:
        """Return the names of all models that can serve ``generateContent``.

        Follows ``nextPageToken`` so every page of the listing is collected.
        Network and decoding errors propagate to the caller.
        """
        models: list[str] = []
        page_token: str | None = None
        while True:
            query: dict[str, Any] = {"pageSize": 1000}
            if page_token:
                query["pageToken"] = page_token
            req = self._request(f"models?{urlencode(query)}")
            with urlopen(req, timeout=timeout) as response:
                data = json.loads(response.read().decode("utf-8"))
            for entry in data.get("models", []):
                methods = entry.get("supportedGenerationMethods")
                # Embedding-only (and similar) models cannot answer chat requests.
                if methods is not None and "generateContent" not in methods:
                    continue
                name = entry.get("name", "")
                # Extract model name (e.g., "models/gemini-pro" -> "gemini-pro")
                if name.startswith("models/"):
                    name = name[len("models/"):]
                if name:
                    models.append(name)
            page_token = data.get("nextPageToken")
            if not page_token:
                return models

    def _check_connection(self) -> None:
        """Check if the API key is valid; success also primes the model cache."""
        if not self._api_key:
            self._is_available = False
            return
        try:
            self._cached_models = self._fetch_models(timeout=5)
        except Exception:
            self._is_available = False
        else:
            self._is_available = True

    @property
    def name(self) -> str:
        return "gemini"

    @property
    def is_available(self) -> bool:
        return self._is_available and self._api_key is not None

    @property
    def default_model(self) -> str | None:
        """Return a preferred model if listed, else the first available one."""
        models = self.list_models()
        for preferred in self._PREFERRED_MODELS:
            if preferred in models:
                return preferred
        return models[0] if models else None

    def list_models(self, force_refresh: bool = False) -> list[str]:
        """List available Gemini chat models (cached unless ``force_refresh``).

        Returns an empty list, and marks the provider unavailable, on failure.
        """
        if self._cached_models is not None and not force_refresh:
            return list(self._cached_models)
        if not self._api_key:
            return []
        try:
            models = self._fetch_models(timeout=10)
        except Exception:
            self._is_available = False
            return []
        self._cached_models = models
        self._is_available = True
        return list(models)

    @staticmethod
    def _build_payload(messages: Iterable[Dict[str, str]]) -> dict[str, Any]:
        """Translate chat messages into a generateContent request body.

        Gemini ``contents`` only accept the roles "user" and "model"; system
        prompts are sent as ``systemInstruction`` rather than as a model turn.
        """
        contents: list[dict[str, Any]] = []
        system_parts: list[dict[str, str]] = []
        for msg in messages:
            role = msg.get("role", "user")
            part = {"text": msg.get("content") or ""}
            if role == "system":
                system_parts.append(part)
            else:
                # Gemini uses "user" and "model" instead of "assistant"
                gemini_role = "user" if role == "user" else "model"
                contents.append({"role": gemini_role, "parts": [part]})
        payload: dict[str, Any] = {"contents": contents}
        if system_parts:
            payload["systemInstruction"] = {"parts": system_parts}
        return payload

    @staticmethod
    def _extract_text(data: dict[str, Any]) -> str:
        """Concatenate the text of every part of the first candidate."""
        candidates = data.get("candidates") or []
        if not candidates:
            return ""
        parts = candidates[0].get("content", {}).get("parts", [])
        return "".join(part.get("text", "") for part in parts)

    def chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> dict[str, str] | None:
        """Execute a blocking chat call.

        ``options`` is currently ignored. API failures are reported as the
        assistant message content rather than raised.
        """
        if not self._api_key:
            return {
                "role": "assistant",
                "content": "Gemini API key not configured. Please set your API key in settings.",
            }
        payload = self._build_payload(messages)
        try:
            req = self._request(f"models/{model}:generateContent", payload)
            with urlopen(req, timeout=120) as response:
                result = json.loads(response.read().decode("utf-8"))
            return {"role": "assistant", "content": self._extract_text(result)}
        except Exception as exc:
            return {
                "role": "assistant",
                "content": f"Gemini API error: {exc}",
            }

    def stream_chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Execute a streaming chat call.

        Requests server-sent events (``alt=sse``) so each event arrives as a
        single ``data: {json}`` line. Yields content chunks followed by one
        ``{"done": True}``, or a single error chunk on failure. ``options`` is
        currently ignored.
        """
        if not self._api_key:
            yield {
                "role": "assistant",
                "content": "Gemini API key not configured. Please set your API key in settings.",
                "done": True,
                "error": True,
            }
            return
        payload = self._build_payload(messages)
        try:
            req = self._request(f"models/{model}:streamGenerateContent?alt=sse", payload)
            with urlopen(req, timeout=120) as response:
                for raw in response:
                    line = raw.decode("utf-8").strip()
                    # Blank event separators and non-data fields carry no payload.
                    if not line.startswith("data:"):
                        continue
                    try:
                        chunk = json.loads(line[len("data:"):].strip())
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(chunk, dict):
                        continue
                    text = self._extract_text(chunk)
                    if text:
                        yield {
                            "message": {"content": text},
                            "done": False,
                        }
                    candidates = chunk.get("candidates") or []
                    # A finishReason marks the final chunk of the response.
                    if candidates and candidates[0].get("finishReason"):
                        break
        except Exception as exc:
            yield {
                "role": "assistant",
                "content": f"Gemini API error: {exc}",
                "done": True,
                "error": True,
            }
            return
        # Always terminate the stream, even if no finishReason was received.
        yield {"done": True}
class OpenRouterProvider(AIProvider):
    """OpenRouter provider (OpenAI-compatible ``/chat/completions`` API)."""

    # Checked in order by ``default_model`` before falling back to the first model.
    _PREFERRED_MODELS = (
        "openai/gpt-4",
        "anthropic/claude-3-opus",
        "meta-llama/llama-3-70b-instruct",
    )

    def __init__(self, api_key: str | None = None):
        """Initialize OpenRouter provider.

        Args:
            api_key: OpenRouter API key. When given, the API is probed
                immediately with a blocking model-list request.
        """
        self._api_key = api_key
        self._is_available = False
        self._cached_models: list[str] | None = None
        self._base_url = "https://openrouter.ai/api/v1"
        if api_key:
            self._check_connection()

    def _request(self, path: str, payload: dict[str, Any] | None = None) -> Request:
        """Build a bearer-authenticated request; a JSON ``payload`` makes it a POST."""
        url = f"{self._base_url}/{path}"
        headers = {"Authorization": f"Bearer {self._api_key}"}
        if payload is None:
            return Request(url, headers=headers, method="GET")
        headers["Content-Type"] = "application/json"
        return Request(
            url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )

    def _fetch_models(self, timeout: float) -> list[str]:
        """Return all model ids from ``/models``; errors propagate."""
        with urlopen(self._request("models"), timeout=timeout) as response:
            data = json.loads(response.read().decode("utf-8"))
        return [entry["id"] for entry in data.get("data", []) if entry.get("id")]

    def _check_connection(self) -> None:
        """Probe the API; on success the model list is cached as a side effect.

        NOTE(review): OpenRouter's model listing may be served without
        authentication, in which case success here does not prove the key is
        valid; confirm against the OpenRouter API docs.
        """
        if not self._api_key:
            self._is_available = False
            return
        try:
            self._cached_models = self._fetch_models(timeout=5)
        except Exception:
            self._is_available = False
        else:
            self._is_available = True

    @property
    def name(self) -> str:
        return "openrouter"

    @property
    def is_available(self) -> bool:
        return self._is_available and self._api_key is not None

    @property
    def default_model(self) -> str | None:
        """Return a preferred model if listed, else the first available one."""
        models = self.list_models()
        for preferred in self._PREFERRED_MODELS:
            if preferred in models:
                return preferred
        return models[0] if models else None

    def list_models(self, force_refresh: bool = False) -> list[str]:
        """List available OpenRouter models (cached unless ``force_refresh``).

        Returns an empty list, and marks the provider unavailable, on failure.
        """
        if self._cached_models is not None and not force_refresh:
            return list(self._cached_models)
        if not self._api_key:
            return []
        try:
            models = self._fetch_models(timeout=10)
        except Exception:
            self._is_available = False
            return []
        self._cached_models = models
        self._is_available = True
        return list(models)

    @staticmethod
    def _error_detail(error: Any) -> str:
        """Return a readable message from an API ``error`` field."""
        if isinstance(error, dict):
            return str(error.get("message") or error)
        return str(error)

    def _iter_stream(self, response: Iterable[bytes]) -> Iterator[str]:
        """Yield content deltas from an OpenAI-style server-sent-event body.

        Stops at the first ``finish_reason`` or at the ``[DONE]`` sentinel.
        An ``error`` object sent mid-stream is raised as ``RuntimeError`` so
        the caller reports it like any other failure.
        """
        for raw in response:
            line = raw.decode("utf-8").strip()
            if line.startswith("data:"):
                line = line[len("data:"):].strip()  # Remove "data:" prefix
            if not line:
                continue
            if line == "[DONE]":
                return
            try:
                chunk = json.loads(line)
            except json.JSONDecodeError:
                continue  # SSE comments / keep-alive lines carry no JSON
            if not isinstance(chunk, dict):
                continue
            if chunk.get("error"):
                raise RuntimeError(self._error_detail(chunk["error"]))
            choices = chunk.get("choices") or []
            if not choices:
                continue
            content = (choices[0].get("delta") or {}).get("content")
            if content:
                yield content
            if choices[0].get("finish_reason"):
                return

    def chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> dict[str, str] | None:
        """Execute a blocking chat call.

        ``options`` is currently ignored. API failures are reported as the
        assistant message content rather than raised.
        """
        if not self._api_key:
            return {
                "role": "assistant",
                "content": "OpenRouter API key not configured. Please set your API key in settings.",
            }
        payload = {
            "model": model,
            "messages": list(messages),
        }
        try:
            req = self._request("chat/completions", payload)
            with urlopen(req, timeout=120) as response:
                result = json.loads(response.read().decode("utf-8"))
            if result.get("error"):
                return {
                    "role": "assistant",
                    "content": f"OpenRouter API error: {self._error_detail(result['error'])}",
                }
            choices = result.get("choices") or []
            message = (choices[0].get("message") or {}) if choices else {}
            # "content" may be present but null; normalise to a string.
            return {"role": "assistant", "content": message.get("content") or ""}
        except Exception as exc:
            return {
                "role": "assistant",
                "content": f"OpenRouter API error: {exc}",
            }

    def stream_chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Execute a streaming chat call.

        Yields content chunks followed by one ``{"done": True}``, or a single
        error chunk on failure. ``options`` is currently ignored.
        """
        if not self._api_key:
            yield {
                "role": "assistant",
                "content": "OpenRouter API key not configured. Please set your API key in settings.",
                "done": True,
                "error": True,
            }
            return
        payload = {
            "model": model,
            "messages": list(messages),
            "stream": True,
        }
        try:
            req = self._request("chat/completions", payload)
            with urlopen(req, timeout=120) as response:
                for content in self._iter_stream(response):
                    yield {
                        "message": {"content": content},
                        "done": False,
                    }
        except Exception as exc:
            yield {
                "role": "assistant",
                "content": f"OpenRouter API error: {exc}",
                "done": True,
                "error": True,
            }
            return
        # Always terminate the stream, including when only [DONE] was sent.
        yield {"done": True}
class CopilotProvider(AIProvider):
    """GitHub Copilot provider with OAuth.

    NOTE(review): as the original comments say, the chat calls are a
    placeholder. They post the GitHub OAuth token directly to the Copilot
    chat endpoint with a fixed model name. Confirm the endpoint and auth flow
    before relying on them.
    """

    def __init__(self, oauth_token: str | None = None):
        """Initialize Copilot provider.

        Args:
            oauth_token: GitHub OAuth access token
        """
        self._oauth_token = oauth_token
        self._is_available = False
        self._cached_models: list[str] | None = None
        self._base_url = "https://api.githubcopilot.com"
        if oauth_token:
            self._check_connection()

    def _check_connection(self) -> None:
        """Mark the provider available whenever a token is present.

        No request is made, so the token is only exercised by the first chat.
        """
        self._is_available = bool(self._oauth_token)

    @property
    def name(self) -> str:
        return "copilot"

    @property
    def is_available(self) -> bool:
        return self._is_available and self._oauth_token is not None

    @property
    def default_model(self) -> str:
        # GitHub Copilot uses a fixed model
        return "github-copilot"

    def list_models(self, force_refresh: bool = False) -> list[str]:
        """List available models (Copilot uses a fixed model)."""
        if self._cached_models is None or force_refresh:
            self._cached_models = ["github-copilot"]
        return list(self._cached_models)

    def _request(self, payload: dict[str, Any]) -> Request:
        """Build the authenticated POST to the chat completions endpoint."""
        return Request(
            f"{self._base_url}/chat/completions",
            data=json.dumps(payload).encode("utf-8"),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._oauth_token}",
            },
            method="POST",
        )

    @staticmethod
    def _error_detail(error: Any) -> str:
        """Return a readable message from an API ``error`` field."""
        if isinstance(error, dict):
            return str(error.get("message") or error)
        return str(error)

    def _iter_stream(self, response: Iterable[bytes]) -> Iterator[str]:
        """Yield content deltas from an OpenAI-style server-sent-event body.

        Stops at the first ``finish_reason`` or at the ``[DONE]`` sentinel.
        An ``error`` object sent mid-stream is raised as ``RuntimeError`` so
        the caller reports it like any other failure.
        """
        for raw in response:
            line = raw.decode("utf-8").strip()
            if line.startswith("data:"):
                line = line[len("data:"):].strip()
            if not line:
                continue
            if line == "[DONE]":
                return
            try:
                chunk = json.loads(line)
            except json.JSONDecodeError:
                continue  # SSE comments / keep-alive lines carry no JSON
            if not isinstance(chunk, dict):
                continue
            if chunk.get("error"):
                raise RuntimeError(self._error_detail(chunk["error"]))
            choices = chunk.get("choices") or []
            if not choices:
                continue
            content = (choices[0].get("delta") or {}).get("content")
            if content:
                yield content
            if choices[0].get("finish_reason"):
                return

    def chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> dict[str, str] | None:
        """Execute a blocking chat call.

        ``options`` is currently ignored. API failures are reported as the
        assistant message content rather than raised.
        """
        if not self._oauth_token:
            return {
                "role": "assistant",
                "content": "GitHub Copilot not authenticated. Please authenticate via OAuth in settings.",
            }
        # Note: GitHub Copilot Chat API endpoint may vary
        # This is a placeholder implementation
        payload = {
            "model": model,
            "messages": list(messages),
        }
        try:
            with urlopen(self._request(payload), timeout=120) as response:
                result = json.loads(response.read().decode("utf-8"))
            if result.get("error"):
                raise RuntimeError(self._error_detail(result["error"]))
            choices = result.get("choices") or []
            message = (choices[0].get("message") or {}) if choices else {}
            # "content" may be present but null; normalise to a string.
            return {"role": "assistant", "content": message.get("content") or ""}
        except Exception as exc:
            return {
                "role": "assistant",
                "content": f"GitHub Copilot API error: {exc}. Note: GitHub Copilot Chat API access requires a Copilot subscription.",
            }

    def stream_chat(
        self,
        *,
        model: str,
        messages: Iterable[Dict[str, str]],
        options: Dict[str, Any] | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Execute a streaming chat call.

        Yields content chunks followed by one ``{"done": True}``, or a single
        error chunk on failure. ``options`` is currently ignored.
        """
        if not self._oauth_token:
            yield {
                "role": "assistant",
                "content": "GitHub Copilot not authenticated. Please authenticate via OAuth in settings.",
                "done": True,
                "error": True,
            }
            return
        payload = {
            "model": model,
            "messages": list(messages),
            "stream": True,
        }
        try:
            with urlopen(self._request(payload), timeout=120) as response:
                for content in self._iter_stream(response):
                    yield {
                        "message": {"content": content},
                        "done": False,
                    }
        except Exception as exc:
            yield {
                "role": "assistant",
                "content": f"GitHub Copilot API error: {exc}. Note: GitHub Copilot Chat API access requires a Copilot subscription.",
                "done": True,
                "error": True,
            }
            return
        # Always terminate the stream, including when only [DONE] was sent.
        yield {"done": True}

    @staticmethod
    def get_oauth_url(
        client_id: str,
        redirect_uri: str,
        scopes: list[str] | None = None,
        state: str | None = None,
    ) -> str:
        """Generate OAuth authorization URL.

        All query values are percent-encoded, so redirect URIs containing
        ``?``/``&`` and space-separated scope lists stay well-formed.

        Args:
            client_id: GitHub OAuth app client ID
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes (default: ['copilot'])
            state: Optional unguessable value echoed back on the callback,
                used to protect against cross-site request forgery.

        Returns:
            OAuth authorization URL
        """
        if scopes is None:
            scopes = ["copilot"]
        params = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scope": " ".join(scopes),
        }
        if state is not None:
            params["state"] = state
        # quote (not quote_plus) encodes spaces as %20, which every server decodes.
        return f"https://github.com/login/oauth/authorize?{urlencode(params, quote_via=quote)}"

    @staticmethod
    def exchange_code_for_token(
        client_id: str, client_secret: str, code: str, redirect_uri: str
    ) -> str | None:
        """Exchange authorization code for access token.

        Args:
            client_id: GitHub OAuth app client ID
            client_secret: GitHub OAuth app client secret
            code: Authorization code from OAuth callback
            redirect_uri: OAuth redirect URI

        Returns:
            Access token, or None on any error (including a response that
            carries no ``access_token`` field).
        """
        payload = {
            "client_id": client_id,
            "client_secret": client_secret,
            "code": code,
            "redirect_uri": redirect_uri,
        }
        try:
            req = Request(
                "https://github.com/login/oauth/access_token",
                data=json.dumps(payload).encode("utf-8"),
                headers={"Content-Type": "application/json", "Accept": "application/json"},
                method="POST",
            )
            with urlopen(req, timeout=10) as response:
                result = json.loads(response.read().decode("utf-8"))
            return result.get("access_token")
        except Exception:
            return None