"""Provider abstraction layer for multiple AI providers.""" from __future__ import annotations import json import webbrowser from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, Iterator from urllib.request import Request, urlopen from urllib.error import URLError, HTTPError from .ollama_client import OllamaClient class AIProvider(ABC): """Abstract base class for AI providers.""" @property @abstractmethod def name(self) -> str: """Return the provider name.""" pass @property @abstractmethod def is_available(self) -> bool: """Check if the provider is available.""" pass @property @abstractmethod def default_model(self) -> str | None: """Get the default model for this provider.""" pass @abstractmethod def list_models(self, force_refresh: bool = False) -> list[str]: """List available models for this provider.""" pass @abstractmethod def chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> dict[str, str] | None: """Execute a blocking chat call. Returns: Dictionary with 'role' and 'content' keys, or None on error """ pass @abstractmethod def stream_chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> Iterator[dict[str, Any]]: """Execute a streaming chat call. Yields dictionaries containing token data from the streaming response. Each yielded dict should follow the format: - 'message' dict with 'content' (and optionally 'thinking') - 'done' boolean flag - 'error' boolean flag (optional) """ pass def test_connection(self) -> tuple[bool, str]: """Test the connection to the provider. Returns: Tuple of (success: bool, message: str) """ try: models = self.list_models(force_refresh=True) if models: return True, f"Connected successfully. Found {len(models)} model(s)." else: return False, "Connected but no models available." except Exception as e: return False, f"Connection failed: {str(e)}" class OllamaProvider(AIProvider): """Ollama provider wrapper.""" def __init__(self, host: str | None = None): """Initialize Ollama provider. Args: host: Ollama server host (default: http://localhost:11434) """ self._client = OllamaClient(host) @property def name(self) -> str: return "ollama" @property def is_available(self) -> bool: return self._client.is_available @property def default_model(self) -> str | None: return self._client.default_model def list_models(self, force_refresh: bool = False) -> list[str]: return self._client.list_models(force_refresh=force_refresh) def chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> dict[str, str] | None: return self._client.chat(model=model, messages=messages, options=options) def stream_chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> Iterator[dict[str, Any]]: return self._client.stream_chat(model=model, messages=messages, options=options) class GeminiProvider(AIProvider): """Google Gemini provider.""" def __init__(self, api_key: str | None = None): """Initialize Gemini provider. Args: api_key: Google Gemini API key """ self._api_key = api_key self._is_available = False self._cached_models: list[str] | None = None self._base_url = "https://generativelanguage.googleapis.com/v1beta" if api_key: self._check_connection() def _check_connection(self) -> None: """Check if API key is valid.""" if not self._api_key: self._is_available = False return try: # Try to list models req = Request( f"{self._base_url}/models?key={self._api_key}", method="GET" ) with urlopen(req, timeout=5) as response: if response.status == 200: self._is_available = True else: self._is_available = False except Exception: self._is_available = False @property def name(self) -> str: return "gemini" @property def is_available(self) -> bool: return self._is_available and self._api_key is not None @property def default_model(self) -> str | None: models = self.list_models() # Prefer gemini-pro, fallback to first available if "gemini-pro" in models: return "gemini-pro" return models[0] if models else None def list_models(self, force_refresh: bool = False) -> list[str]: """List available Gemini models.""" if self._cached_models is not None and not force_refresh: return list(self._cached_models) if not self._api_key: return [] try: req = Request( f"{self._base_url}/models?key={self._api_key}", method="GET" ) with urlopen(req, timeout=10) as response: data = json.loads(response.read().decode()) models = [] for model in data.get("models", []): name = model.get("name", "") # Extract model name (e.g., "models/gemini-pro" -> "gemini-pro") if name.startswith("models/"): models.append(name[7:]) elif name: models.append(name) self._cached_models = models self._is_available = True return models except Exception: self._is_available = False return [] def chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> dict[str, str] | None: """Execute a blocking chat call.""" if not self._api_key: return { "role": "assistant", "content": "Gemini API key not configured. Please set your API key in settings.", } # Convert messages to Gemini format gemini_messages = [] for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") # Gemini uses "user" and "model" instead of "assistant" gemini_role = "user" if role == "user" else "model" gemini_messages.append({"role": gemini_role, "parts": [{"text": content}]}) payload = { "contents": gemini_messages, } try: req = Request( f"{self._base_url}/models/{model}:generateContent?key={self._api_key}", data=json.dumps(payload).encode("utf-8"), headers={"Content-Type": "application/json"}, method="POST", ) with urlopen(req, timeout=120) as response: result = json.loads(response.read().decode()) candidates = result.get("candidates", []) if candidates: content_parts = candidates[0].get("content", {}).get("parts", []) if content_parts: text = content_parts[0].get("text", "") return {"role": "assistant", "content": text} return {"role": "assistant", "content": ""} except Exception as exc: return { "role": "assistant", "content": f"Gemini API error: {exc}", } def stream_chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> Iterator[dict[str, Any]]: """Execute a streaming chat call.""" if not self._api_key: yield { "role": "assistant", "content": "Gemini API key not configured. Please set your API key in settings.", "done": True, "error": True, } return # Convert messages to Gemini format gemini_messages = [] for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") gemini_role = "user" if role == "user" else "model" gemini_messages.append({"role": gemini_role, "parts": [{"text": content}]}) payload = { "contents": gemini_messages, } try: req = Request( f"{self._base_url}/models/{model}:streamGenerateContent?key={self._api_key}", data=json.dumps(payload).encode("utf-8"), headers={"Content-Type": "application/json"}, method="POST", ) with urlopen(req, timeout=120) as response: for line in response: if not line: continue try: chunk_data = json.loads(line.decode("utf-8")) candidates = chunk_data.get("candidates", []) if candidates: content_parts = candidates[0].get("content", {}).get("parts", []) if content_parts: text = content_parts[0].get("text", "") if text: yield { "message": {"content": text}, "done": False, } # Check if this is the final chunk finish_reason = candidates[0].get("finishReason") if candidates else None if finish_reason: yield {"done": True} break except json.JSONDecodeError: continue except Exception as exc: yield { "role": "assistant", "content": f"Gemini API error: {exc}", "done": True, "error": True, } class OpenRouterProvider(AIProvider): """OpenRouter provider.""" def __init__(self, api_key: str | None = None): """Initialize OpenRouter provider. Args: api_key: OpenRouter API key """ self._api_key = api_key self._is_available = False self._cached_models: list[str] | None = None self._base_url = "https://openrouter.ai/api/v1" if api_key: self._check_connection() def _check_connection(self) -> None: """Check if API key is valid.""" if not self._api_key: self._is_available = False return try: req = Request( f"{self._base_url}/models", headers={"Authorization": f"Bearer {self._api_key}"}, method="GET" ) with urlopen(req, timeout=5) as response: if response.status == 200: self._is_available = True else: self._is_available = False except Exception: self._is_available = False @property def name(self) -> str: return "openrouter" @property def is_available(self) -> bool: return self._is_available and self._api_key is not None @property def default_model(self) -> str | None: models = self.list_models() # Prefer common models preferred = ["openai/gpt-4", "anthropic/claude-3-opus", "meta-llama/llama-3-70b-instruct"] for pref in preferred: if pref in models: return pref return models[0] if models else None def list_models(self, force_refresh: bool = False) -> list[str]: """List available OpenRouter models.""" if self._cached_models is not None and not force_refresh: return list(self._cached_models) if not self._api_key: return [] try: req = Request( f"{self._base_url}/models", headers={"Authorization": f"Bearer {self._api_key}"}, method="GET" ) with urlopen(req, timeout=10) as response: data = json.loads(response.read().decode()) models = [] for model in data.get("data", []): model_id = model.get("id") if model_id: models.append(model_id) self._cached_models = models self._is_available = True return models except Exception: self._is_available = False return [] def chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> dict[str, str] | None: """Execute a blocking chat call.""" if not self._api_key: return { "role": "assistant", "content": "OpenRouter API key not configured. Please set your API key in settings.", } payload = { "model": model, "messages": list(messages), } try: req = Request( f"{self._base_url}/chat/completions", data=json.dumps(payload).encode("utf-8"), headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self._api_key}", }, method="POST", ) with urlopen(req, timeout=120) as response: result = json.loads(response.read().decode()) choices = result.get("choices", []) if choices: message = choices[0].get("message", {}) content = message.get("content", "") return {"role": "assistant", "content": content} return {"role": "assistant", "content": ""} except Exception as exc: return { "role": "assistant", "content": f"OpenRouter API error: {exc}", } def stream_chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> Iterator[dict[str, Any]]: """Execute a streaming chat call.""" if not self._api_key: yield { "role": "assistant", "content": "OpenRouter API key not configured. Please set your API key in settings.", "done": True, "error": True, } return payload = { "model": model, "messages": list(messages), "stream": True, } try: req = Request( f"{self._base_url}/chat/completions", data=json.dumps(payload).encode("utf-8"), headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self._api_key}", }, method="POST", ) with urlopen(req, timeout=120) as response: for line in response: if not line: continue line_str = line.decode("utf-8").strip() if not line_str or line_str == "data: [DONE]": continue if line_str.startswith("data: "): line_str = line_str[6:] # Remove "data: " prefix try: chunk_data = json.loads(line_str) choices = chunk_data.get("choices", []) if choices: delta = choices[0].get("delta", {}) content = delta.get("content", "") if content: yield { "message": {"content": content}, "done": False, } # Check if finished finish_reason = choices[0].get("finish_reason") if finish_reason: yield {"done": True} break except json.JSONDecodeError: continue except Exception as exc: yield { "role": "assistant", "content": f"OpenRouter API error: {exc}", "done": True, "error": True, } class CopilotProvider(AIProvider): """GitHub Copilot provider with OAuth.""" def __init__(self, oauth_token: str | None = None): """Initialize Copilot provider. Args: oauth_token: GitHub OAuth access token """ self._oauth_token = oauth_token self._is_available = False self._cached_models: list[str] | None = None self._base_url = "https://api.githubcopilot.com" if oauth_token: self._check_connection() def _check_connection(self) -> None: """Check if OAuth token is valid.""" if not self._oauth_token: self._is_available = False return # Basic availability check - could be enhanced with actual API call self._is_available = True @property def name(self) -> str: return "copilot" @property def is_available(self) -> bool: return self._is_available and self._oauth_token is not None @property def default_model(self) -> str: # GitHub Copilot uses a fixed model return "github-copilot" def list_models(self, force_refresh: bool = False) -> list[str]: """List available models (Copilot uses a fixed model).""" if self._cached_models is None or force_refresh: self._cached_models = ["github-copilot"] return list(self._cached_models) def chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> dict[str, str] | None: """Execute a blocking chat call.""" if not self._oauth_token: return { "role": "assistant", "content": "GitHub Copilot not authenticated. Please authenticate via OAuth in settings.", } # Note: GitHub Copilot Chat API endpoint may vary # This is a placeholder implementation payload = { "model": model, "messages": list(messages), } try: req = Request( f"{self._base_url}/chat/completions", data=json.dumps(payload).encode("utf-8"), headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self._oauth_token}", }, method="POST", ) with urlopen(req, timeout=120) as response: result = json.loads(response.read().decode()) choices = result.get("choices", []) if choices: message = choices[0].get("message", {}) content = message.get("content", "") return {"role": "assistant", "content": content} return {"role": "assistant", "content": ""} except Exception as exc: return { "role": "assistant", "content": f"GitHub Copilot API error: {exc}. Note: GitHub Copilot Chat API access requires a Copilot subscription.", } def stream_chat( self, *, model: str, messages: Iterable[Dict[str, str]], options: Dict[str, Any] | None = None, ) -> Iterator[dict[str, Any]]: """Execute a streaming chat call.""" if not self._oauth_token: yield { "role": "assistant", "content": "GitHub Copilot not authenticated. Please authenticate via OAuth in settings.", "done": True, "error": True, } return payload = { "model": model, "messages": list(messages), "stream": True, } try: req = Request( f"{self._base_url}/chat/completions", data=json.dumps(payload).encode("utf-8"), headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self._oauth_token}", }, method="POST", ) with urlopen(req, timeout=120) as response: for line in response: if not line: continue line_str = line.decode("utf-8").strip() if not line_str or line_str == "data: [DONE]": continue if line_str.startswith("data: "): line_str = line_str[6:] try: chunk_data = json.loads(line_str) choices = chunk_data.get("choices", []) if choices: delta = choices[0].get("delta", {}) content = delta.get("content", "") if content: yield { "message": {"content": content}, "done": False, } finish_reason = choices[0].get("finish_reason") if finish_reason: yield {"done": True} break except json.JSONDecodeError: continue except Exception as exc: yield { "role": "assistant", "content": f"GitHub Copilot API error: {exc}. Note: GitHub Copilot Chat API access requires a Copilot subscription.", "done": True, "error": True, } @staticmethod def get_oauth_url(client_id: str, redirect_uri: str, scopes: list[str] = None) -> str: """Generate OAuth authorization URL. Args: client_id: GitHub OAuth app client ID redirect_uri: OAuth redirect URI scopes: List of OAuth scopes (default: ['copilot']) Returns: OAuth authorization URL """ if scopes is None: scopes = ["copilot"] scope_str = " ".join(scopes) return ( f"https://github.com/login/oauth/authorize" f"?client_id={client_id}" f"&redirect_uri={redirect_uri}" f"&scope={scope_str}" ) @staticmethod def exchange_code_for_token( client_id: str, client_secret: str, code: str, redirect_uri: str ) -> str | None: """Exchange authorization code for access token. Args: client_id: GitHub OAuth app client ID client_secret: GitHub OAuth app client secret code: Authorization code from OAuth callback redirect_uri: OAuth redirect URI Returns: Access token or None on error """ payload = { "client_id": client_id, "client_secret": client_secret, "code": code, "redirect_uri": redirect_uri, } try: req = Request( "https://github.com/login/oauth/access_token", data=json.dumps(payload).encode("utf-8"), headers={"Content-Type": "application/json", "Accept": "application/json"}, method="POST", ) with urlopen(req, timeout=10) as response: result = json.loads(response.read().decode()) return result.get("access_token") except Exception: return None