diff --git a/chat_widget.py b/chat_widget.py index 955cb4a..0789c22 100644 --- a/chat_widget.py +++ b/chat_widget.py @@ -9,6 +9,14 @@ from .ollama_monitor import OllamaAvailabilityMonitor from .streaming_handler import StreamingHandler from .command_processor import CommandProcessor, CommandResult from .reasoning_controller import ReasoningController +from .provider_client import ( + AIProvider, + OllamaProvider, + GeminiProvider, + OpenRouterProvider, + CopilotProvider, +) +from .settings_widget import SettingsWidget class ChatWidget(widgets.Box): @@ -19,23 +27,30 @@ class ChatWidget(widgets.Box): self._load_css() self._conversation_manager = ConversationManager() self._conversation_archive = ConversationArchive() - self._ollama_client = OllamaClient() - self._current_model = self._ollama_client.default_model - - # Initialize availability monitor - self._ollama_monitor = OllamaAvailabilityMonitor(self._ollama_client) - self._ollama_monitor.add_callback(self._on_ollama_availability_changed) - - # Initialize command processor - self._command_processor = CommandProcessor() - self._register_commands() # Initialize reasoning controller self._reasoning_controller = ReasoningController() - # Set initial model based on reasoning preference - if self._ollama_client.is_available: - self._current_model = self._reasoning_controller.get_model_name() + # Initialize provider abstraction + self._current_provider: AIProvider | None = None + self._current_model: str | None = None + self._provider_instances: dict[str, AIProvider] = {} + self._initialize_provider() + + # Initialize availability monitor (only for Ollama) + ollama_provider = self._get_provider("ollama") + if isinstance(ollama_provider, OllamaProvider): + self._ollama_client = ollama_provider._client + self._ollama_monitor = OllamaAvailabilityMonitor(self._ollama_client) + self._ollama_monitor.add_callback(self._on_ollama_availability_changed) + self._ollama_monitor.start() + else: + self._ollama_client = None + 
self._ollama_monitor = None + + # Initialize command processor + self._command_processor = CommandProcessor() + self._register_commands() # Header with title and model header_title = widgets.Label( @@ -44,17 +59,13 @@ class ChatWidget(widgets.Box): css_classes=["title-2"], ) - # Display connection status if Ollama unavailable at startup - if not self._ollama_client.is_available: - model_name = "Ollama not running" - else: - model_name = self._current_model or "No local model detected" - + # Display provider and model status self._model_label = widgets.Label( - label=f"Model: {model_name}", + label="", halign="start", css_classes=["dim-label"], ) + self._update_model_label() # Reasoning mode toggle button (using regular button with state tracking) self._reasoning_enabled = self._reasoning_controller.is_enabled() @@ -66,11 +77,19 @@ class ChatWidget(widgets.Box): hexpand=False, ) - # Header top row with title and toggle + # Settings button (gear icon) + settings_button = widgets.Button( + label="⚙️", + on_click=lambda x: self._show_settings(), + halign="end", + hexpand=False, + ) + + # Header top row with title, settings, and toggle header_top = widgets.Box( spacing=8, hexpand=True, - child=[header_title, self._reasoning_toggle], + child=[header_title, settings_button, self._reasoning_toggle], ) header_box = widgets.Box( @@ -109,8 +128,7 @@ class ChatWidget(widgets.Box): self._text_view.set_size_request(300, 60) # Set explicit width and height # Set placeholder text - placeholder = "Ask a question…" if self._ollama_client.is_available else "Ollama not running - start with: ollama serve" - self._placeholder_text = placeholder + self._update_placeholder_text() self._is_placeholder_shown = False self._updating_placeholder = False @@ -180,12 +198,46 @@ class ChatWidget(widgets.Box): # Initialize placeholder display self._update_placeholder() - # Disable input if Ollama unavailable at startup - if not self._ollama_client.is_available: + # Disable input if provider 
unavailable at startup + if not self._current_provider or not self._current_provider.is_available: self._set_input_enabled(False) - # Start monitoring Ollama availability - self._ollama_monitor.start() + # Create settings widget + self._settings_widget = SettingsWidget( + self._reasoning_controller, + on_provider_changed=self._on_provider_changed_from_settings, + on_back=self._show_chat + ) + + # Create view stack for switching between chat and settings + self._view_stack = Gtk.Stack() + self._view_stack.set_transition_type(Gtk.StackTransitionType.SLIDE_LEFT_RIGHT) + self._view_stack.set_transition_duration(200) + + # Add chat view + chat_container = widgets.Box( + vertical=True, + spacing=12, + hexpand=True, + vexpand=True, + child=[header_box, self._scroller, input_box], + css_classes=["ai-sidebar-content"], + ) + chat_container.set_margin_top(16) + chat_container.set_margin_bottom(16) + chat_container.set_margin_start(16) + chat_container.set_margin_end(16) + + self._view_stack.add_named(chat_container, "chat") + self._view_stack.add_named(self._settings_widget, "settings") + self._view_stack.set_visible_child_name("chat") + + # Replace main container with stack + super().__init__( + hexpand=True, + vexpand=True, + child=[self._view_stack], + ) def _auto_archive_old_messages(self, keep_recent: int = 20): """Auto-archive old messages on startup, keeping only recent ones. @@ -462,11 +514,16 @@ class ChatWidget(widgets.Box): self._append_system_message(result.message) return - # Check Ollama availability before processing regular messages - if not self._ollama_client.is_available: + # Check provider availability before processing regular messages + if not self._current_provider or not self._current_provider.is_available: + provider_name = self._current_provider.name if self._current_provider else "Provider" + if provider_name == "ollama": + error_msg = "Ollama is not running. 
Please start Ollama with: ollama serve" + else: + error_msg = f"{provider_name.capitalize()} is not configured. Please check settings." self._append_message( "assistant", - "Ollama is not running. Please start Ollama with: ollama serve", + error_msg, persist=False, ) return @@ -477,19 +534,29 @@ class ChatWidget(widgets.Box): def _request_response(self): """Request AI response in background thread with streaming""" # Double-check availability before making request - if not self._ollama_client.is_available: + if not self._current_provider or not self._current_provider.is_available: + provider_name = self._current_provider.name if self._current_provider else "Provider" + if provider_name == "ollama": + error_msg = "Ollama is not running. Please start Ollama with: ollama serve" + else: + error_msg = f"{provider_name.capitalize()} is not configured. Please check settings." self._append_message( "assistant", - "Ollama is not running. Please start Ollama with: ollama serve", + error_msg, persist=False, ) return - model = self._current_model or self._ollama_client.default_model + model = self._current_model or self._current_provider.default_model if not model: + provider_name = self._current_provider.name + if provider_name == "ollama": + error_msg = "No Ollama models are available. Install a model with: ollama pull llama2" + else: + error_msg = f"No {provider_name.capitalize()} models are available. Check settings." self._append_message( "assistant", - "No Ollama models are available. 
Install a model with: ollama pull llama2", + error_msg, persist=True, ) return @@ -597,11 +664,11 @@ class ChatWidget(widgets.Box): try: handler.start_stream() - # Get model-specific options - options = self._reasoning_controller.get_model_options() + # Get model-specific options (only for Ollama) + options = self._reasoning_controller.get_model_options() if self._current_provider.name == "ollama" else None # Stream response tokens - for chunk in self._ollama_client.stream_chat( + for chunk in self._current_provider.stream_chat( model=model, messages=list(messages), options=options @@ -727,14 +794,117 @@ class ChatWidget(widgets.Box): """Focus the input text view""" self._text_view.grab_focus() + def _initialize_provider(self): + """Initialize the current provider based on preferences.""" + provider_id = self._reasoning_controller.get_provider() + self._current_provider = self._get_provider(provider_id) + + if self._current_provider: + if provider_id == "ollama": + # Use reasoning controller model for Ollama + self._current_model = self._reasoning_controller.get_model_name() + else: + self._current_model = self._current_provider.default_model + else: + self._current_model = None + + def _get_provider(self, provider_id: str) -> AIProvider | None: + """Get or create provider instance.""" + if provider_id in self._provider_instances: + return self._provider_instances[provider_id] + + if provider_id == "ollama": + provider = OllamaProvider() + elif provider_id == "gemini": + api_key = self._reasoning_controller.get_api_key("gemini") + provider = GeminiProvider(api_key=api_key) if api_key else None + elif provider_id == "openrouter": + api_key = self._reasoning_controller.get_api_key("openrouter") + provider = OpenRouterProvider(api_key=api_key) if api_key else None + elif provider_id == "copilot": + token = self._reasoning_controller.get_copilot_token() + provider = CopilotProvider(oauth_token=token) if token else None + else: + return None + + if provider: + 
self._provider_instances[provider_id] = provider + return provider + + def _update_model_label(self): + """Update the model label with current provider and model.""" + if not hasattr(self, '_model_label'): + return + + if not self._current_provider: + self._model_label.label = "Provider: Not configured" + return + + provider_name = self._current_provider.name.capitalize() + if self._current_provider.is_available: + model_name = self._current_model or "No model selected" + self._model_label.label = f"{provider_name}: {model_name}" + else: + if provider_name == "Ollama": + self._model_label.label = f"{provider_name}: Not running" + else: + self._model_label.label = f"{provider_name}: Not configured" + + def _update_placeholder_text(self): + """Update placeholder text based on provider availability.""" + if not hasattr(self, '_text_view'): + return + + if not self._current_provider or not self._current_provider.is_available: + provider_name = self._current_provider.name if self._current_provider else "Provider" + if provider_name == "ollama": + self._placeholder_text = "Ollama not running - start with: ollama serve" + else: + self._placeholder_text = f"{provider_name.capitalize()} not configured. Check settings." 
+ else: + self._placeholder_text = "Ask a question…" + + def _show_settings(self): + """Show settings view.""" + if hasattr(self, '_view_stack'): + self._view_stack.set_visible_child_name("settings") + + def _show_chat(self): + """Show chat view.""" + if hasattr(self, '_view_stack'): + self._view_stack.set_visible_child_name("chat") + + def _on_provider_changed_from_settings(self, provider_id: str): + """Handle provider change from settings widget.""" + self._current_provider = self._get_provider(provider_id) + if self._current_provider: + if provider_id == "ollama": + self._current_model = self._reasoning_controller.get_model_name() + else: + self._current_model = self._current_provider.default_model + else: + self._current_model = None + + self._update_model_label() + self._update_placeholder_text() + self._update_placeholder() + + if self._current_provider and self._current_provider.is_available: + self._set_input_enabled(True) + else: + self._set_input_enabled(False) + def _on_ollama_availability_changed(self, is_available: bool): """Handle Ollama availability state changes""" + # Only handle if Ollama is the current provider + if self._reasoning_controller.get_provider() != "ollama": + return + if is_available: # Ollama became available - use model from reasoning controller self._current_model = self._reasoning_controller.get_model_name() - model_name = self._current_model or "No local model detected" - self._model_label.label = f"Model: {model_name}" - self._placeholder_text = "Ask a question…" + self._update_model_label() + self._update_placeholder_text() self._update_placeholder() self._set_input_enabled(True) @@ -746,13 +916,18 @@ class ChatWidget(widgets.Box): ) else: # Ollama became unavailable - self._model_label.label = "Model: Ollama not running" - self._placeholder_text = "Ollama not running - start with: ollama serve" + self._update_model_label() + self._update_placeholder_text() self._update_placeholder() self._set_input_enabled(False) def 
_on_reasoning_toggled(self): """Handle reasoning mode toggle button state changes""" + # Only work for Ollama provider + if self._reasoning_controller.get_provider() != "ollama": + self._append_system_message("Reasoning mode is only available for Ollama provider.") + return + # Toggle the reasoning mode new_state = self._reasoning_controller.toggle() self._reasoning_enabled = new_state @@ -761,14 +936,12 @@ class ChatWidget(widgets.Box): new_model = self._reasoning_controller.get_model_name() self._current_model = new_model - - # Update button label toggle_label = "🧠 Reasoning: ON" if new_state else "🧠 Reasoning: OFF" self._reasoning_toggle.label = toggle_label # Update model label in header - self._model_label.label = f"Model: {new_model}" + self._update_model_label() # Show feedback message status = "enabled" if new_state else "disabled" @@ -807,7 +980,7 @@ class ChatWidget(widgets.Box): Returns: Generated title or empty string if generation fails """ - if not messages or not self._ollama_client.is_available: + if not messages or not self._current_provider or not self._current_provider.is_available: return "" # Extract first few user messages for context @@ -832,14 +1005,14 @@ class ChatWidget(widgets.Box): ] try: - model = self._current_model or self._ollama_client.default_model + model = self._current_model or self._current_provider.default_model if not model: return "" # Use non-streaming chat for title generation - response = self._ollama_client.chat(model=model, messages=title_prompt) - if response and response.get("message"): - title = response["message"].get("content", "").strip() + response = self._current_provider.chat(model=model, messages=title_prompt) + if response and response.get("content"): + title = response["content"].strip() # Clean up the title (remove quotes, limit length) title = title.strip('"\'').strip() # Limit to 50 characters @@ -905,13 +1078,14 @@ class ChatWidget(widgets.Box): Returns: CommandResult with model list """ - if not 
self._ollama_client.is_available: + if not self._current_provider or not self._current_provider.is_available: + provider_name = self._current_provider.name if self._current_provider else "Provider" return CommandResult( success=False, - message="Ollama is not running. Start Ollama with: ollama serve" + message=f"{provider_name.capitalize()} is not available. Check settings." ) - models = self._ollama_client.list_models(force_refresh=True) + models = self._current_provider.list_models(force_refresh=True) if not models: return CommandResult( @@ -945,10 +1119,11 @@ class ChatWidget(widgets.Box): Returns: CommandResult with success status """ - if not self._ollama_client.is_available: + if not self._current_provider or not self._current_provider.is_available: + provider_name = self._current_provider.name if self._current_provider else "Provider" return CommandResult( success=False, - message="Ollama is not running. Start Ollama with: ollama serve" + message=f"{provider_name.capitalize()} is not available. Check settings." 
) model_name = args.strip() @@ -960,7 +1135,7 @@ class ChatWidget(widgets.Box): ) # Validate model exists - available_models = self._ollama_client.list_models(force_refresh=True) + available_models = self._current_provider.list_models(force_refresh=True) if model_name not in available_models: return CommandResult( diff --git a/provider_client.py b/provider_client.py new file mode 100644 index 0000000..0600237 --- /dev/null +++ b/provider_client.py @@ -0,0 +1,756 @@ +"""Provider abstraction layer for multiple AI providers.""" + +from __future__ import annotations + +import json +import webbrowser +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterable, Iterator +from urllib.request import Request, urlopen +from urllib.error import URLError, HTTPError + +from .ollama_client import OllamaClient + + +class AIProvider(ABC): + """Abstract base class for AI providers.""" + + @property + @abstractmethod + def name(self) -> str: + """Return the provider name.""" + pass + + @property + @abstractmethod + def is_available(self) -> bool: + """Check if the provider is available.""" + pass + + @property + @abstractmethod + def default_model(self) -> str | None: + """Get the default model for this provider.""" + pass + + @abstractmethod + def list_models(self, force_refresh: bool = False) -> list[str]: + """List available models for this provider.""" + pass + + @abstractmethod + def chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | None = None, + ) -> dict[str, str] | None: + """Execute a blocking chat call. + + Returns: + Dictionary with 'role' and 'content' keys, or None on error + """ + pass + + @abstractmethod + def stream_chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | None = None, + ) -> Iterator[dict[str, Any]]: + """Execute a streaming chat call. + + Yields dictionaries containing token data from the streaming response. 
+ Each yielded dict should follow the format: + - 'message' dict with 'content' (and optionally 'thinking') + - 'done' boolean flag + - 'error' boolean flag (optional) + """ + pass + + def test_connection(self) -> tuple[bool, str]: + """Test the connection to the provider. + + Returns: + Tuple of (success: bool, message: str) + """ + try: + models = self.list_models(force_refresh=True) + if models: + return True, f"Connected successfully. Found {len(models)} model(s)." + else: + return False, "Connected but no models available." + except Exception as e: + return False, f"Connection failed: {str(e)}" + + +class OllamaProvider(AIProvider): + """Ollama provider wrapper.""" + + def __init__(self, host: str | None = None): + """Initialize Ollama provider. + + Args: + host: Ollama server host (default: http://localhost:11434) + """ + self._client = OllamaClient(host) + + @property + def name(self) -> str: + return "ollama" + + @property + def is_available(self) -> bool: + return self._client.is_available + + @property + def default_model(self) -> str | None: + return self._client.default_model + + def list_models(self, force_refresh: bool = False) -> list[str]: + return self._client.list_models(force_refresh=force_refresh) + + def chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | None = None, + ) -> dict[str, str] | None: + return self._client.chat(model=model, messages=messages, options=options) + + def stream_chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | None = None, + ) -> Iterator[dict[str, Any]]: + return self._client.stream_chat(model=model, messages=messages, options=options) + + +class GeminiProvider(AIProvider): + """Google Gemini provider.""" + + def __init__(self, api_key: str | None = None): + """Initialize Gemini provider. 
+ + Args: + api_key: Google Gemini API key + """ + self._api_key = api_key + self._is_available = False + self._cached_models: list[str] | None = None + self._base_url = "https://generativelanguage.googleapis.com/v1beta" + + if api_key: + self._check_connection() + + def _check_connection(self) -> None: + """Check if API key is valid.""" + if not self._api_key: + self._is_available = False + return + + try: + # Try to list models + req = Request( + f"{self._base_url}/models?key={self._api_key}", + method="GET" + ) + with urlopen(req, timeout=5) as response: + if response.status == 200: + self._is_available = True + else: + self._is_available = False + except Exception: + self._is_available = False + + @property + def name(self) -> str: + return "gemini" + + @property + def is_available(self) -> bool: + return self._is_available and self._api_key is not None + + @property + def default_model(self) -> str | None: + models = self.list_models() + # Prefer gemini-pro, fallback to first available + if "gemini-pro" in models: + return "gemini-pro" + return models[0] if models else None + + def list_models(self, force_refresh: bool = False) -> list[str]: + """List available Gemini models.""" + if self._cached_models is not None and not force_refresh: + return list(self._cached_models) + + if not self._api_key: + return [] + + try: + req = Request( + f"{self._base_url}/models?key={self._api_key}", + method="GET" + ) + with urlopen(req, timeout=10) as response: + data = json.loads(response.read().decode()) + models = [] + for model in data.get("models", []): + name = model.get("name", "") + # Extract model name (e.g., "models/gemini-pro" -> "gemini-pro") + if name.startswith("models/"): + models.append(name[7:]) + elif name: + models.append(name) + + self._cached_models = models + self._is_available = True + return models + except Exception: + self._is_available = False + return [] + + def chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: 
Dict[str, Any] | None = None, + ) -> dict[str, str] | None: + """Execute a blocking chat call.""" + if not self._api_key: + return { + "role": "assistant", + "content": "Gemini API key not configured. Please set your API key in settings.", + } + + # Convert messages to Gemini format + gemini_messages = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + # Gemini uses "user" and "model" instead of "assistant" + gemini_role = "user" if role == "user" else "model" + gemini_messages.append({"role": gemini_role, "parts": [{"text": content}]}) + + payload = { + "contents": gemini_messages, + } + + try: + req = Request( + f"{self._base_url}/models/{model}:generateContent?key={self._api_key}", + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urlopen(req, timeout=120) as response: + result = json.loads(response.read().decode()) + candidates = result.get("candidates", []) + if candidates: + content_parts = candidates[0].get("content", {}).get("parts", []) + if content_parts: + text = content_parts[0].get("text", "") + return {"role": "assistant", "content": text} + return {"role": "assistant", "content": ""} + except Exception as exc: + return { + "role": "assistant", + "content": f"Gemini API error: {exc}", + } + + def stream_chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | None = None, + ) -> Iterator[dict[str, Any]]: + """Execute a streaming chat call.""" + if not self._api_key: + yield { + "role": "assistant", + "content": "Gemini API key not configured. 
Please set your API key in settings.", + "done": True, + "error": True, + } + return + + # Convert messages to Gemini format + gemini_messages = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + gemini_role = "user" if role == "user" else "model" + gemini_messages.append({"role": gemini_role, "parts": [{"text": content}]}) + + payload = { + "contents": gemini_messages, + } + + try: + req = Request( + f"{self._base_url}/models/{model}:streamGenerateContent?key={self._api_key}", + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urlopen(req, timeout=120) as response: + for line in response: + if not line: + continue + + try: + chunk_data = json.loads(line.decode("utf-8")) + candidates = chunk_data.get("candidates", []) + if candidates: + content_parts = candidates[0].get("content", {}).get("parts", []) + if content_parts: + text = content_parts[0].get("text", "") + if text: + yield { + "message": {"content": text}, + "done": False, + } + + # Check if this is the final chunk + finish_reason = candidates[0].get("finishReason") if candidates else None + if finish_reason: + yield {"done": True} + break + except json.JSONDecodeError: + continue + except Exception as exc: + yield { + "role": "assistant", + "content": f"Gemini API error: {exc}", + "done": True, + "error": True, + } + + +class OpenRouterProvider(AIProvider): + """OpenRouter provider.""" + + def __init__(self, api_key: str | None = None): + """Initialize OpenRouter provider. 
+ + Args: + api_key: OpenRouter API key + """ + self._api_key = api_key + self._is_available = False + self._cached_models: list[str] | None = None + self._base_url = "https://openrouter.ai/api/v1" + + if api_key: + self._check_connection() + + def _check_connection(self) -> None: + """Check if API key is valid.""" + if not self._api_key: + self._is_available = False + return + + try: + req = Request( + f"{self._base_url}/models", + headers={"Authorization": f"Bearer {self._api_key}"}, + method="GET" + ) + with urlopen(req, timeout=5) as response: + if response.status == 200: + self._is_available = True + else: + self._is_available = False + except Exception: + self._is_available = False + + @property + def name(self) -> str: + return "openrouter" + + @property + def is_available(self) -> bool: + return self._is_available and self._api_key is not None + + @property + def default_model(self) -> str | None: + models = self.list_models() + # Prefer common models + preferred = ["openai/gpt-4", "anthropic/claude-3-opus", "meta-llama/llama-3-70b-instruct"] + for pref in preferred: + if pref in models: + return pref + return models[0] if models else None + + def list_models(self, force_refresh: bool = False) -> list[str]: + """List available OpenRouter models.""" + if self._cached_models is not None and not force_refresh: + return list(self._cached_models) + + if not self._api_key: + return [] + + try: + req = Request( + f"{self._base_url}/models", + headers={"Authorization": f"Bearer {self._api_key}"}, + method="GET" + ) + with urlopen(req, timeout=10) as response: + data = json.loads(response.read().decode()) + models = [] + for model in data.get("data", []): + model_id = model.get("id") + if model_id: + models.append(model_id) + + self._cached_models = models + self._is_available = True + return models + except Exception: + self._is_available = False + return [] + + def chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | 
None = None, + ) -> dict[str, str] | None: + """Execute a blocking chat call.""" + if not self._api_key: + return { + "role": "assistant", + "content": "OpenRouter API key not configured. Please set your API key in settings.", + } + + payload = { + "model": model, + "messages": list(messages), + } + + try: + req = Request( + f"{self._base_url}/chat/completions", + data=json.dumps(payload).encode("utf-8"), + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self._api_key}", + }, + method="POST", + ) + with urlopen(req, timeout=120) as response: + result = json.loads(response.read().decode()) + choices = result.get("choices", []) + if choices: + message = choices[0].get("message", {}) + content = message.get("content", "") + return {"role": "assistant", "content": content} + return {"role": "assistant", "content": ""} + except Exception as exc: + return { + "role": "assistant", + "content": f"OpenRouter API error: {exc}", + } + + def stream_chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | None = None, + ) -> Iterator[dict[str, Any]]: + """Execute a streaming chat call.""" + if not self._api_key: + yield { + "role": "assistant", + "content": "OpenRouter API key not configured. 
Please set your API key in settings.", + "done": True, + "error": True, + } + return + + payload = { + "model": model, + "messages": list(messages), + "stream": True, + } + + try: + req = Request( + f"{self._base_url}/chat/completions", + data=json.dumps(payload).encode("utf-8"), + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self._api_key}", + }, + method="POST", + ) + with urlopen(req, timeout=120) as response: + for line in response: + if not line: + continue + + line_str = line.decode("utf-8").strip() + if not line_str or line_str == "data: [DONE]": + continue + + if line_str.startswith("data: "): + line_str = line_str[6:] # Remove "data: " prefix + + try: + chunk_data = json.loads(line_str) + choices = chunk_data.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + if content: + yield { + "message": {"content": content}, + "done": False, + } + + # Check if finished + finish_reason = choices[0].get("finish_reason") + if finish_reason: + yield {"done": True} + break + except json.JSONDecodeError: + continue + except Exception as exc: + yield { + "role": "assistant", + "content": f"OpenRouter API error: {exc}", + "done": True, + "error": True, + } + + +class CopilotProvider(AIProvider): + """GitHub Copilot provider with OAuth.""" + + def __init__(self, oauth_token: str | None = None): + """Initialize Copilot provider. 
+ + Args: + oauth_token: GitHub OAuth access token + """ + self._oauth_token = oauth_token + self._is_available = False + self._cached_models: list[str] | None = None + self._base_url = "https://api.githubcopilot.com" + + if oauth_token: + self._check_connection() + + def _check_connection(self) -> None: + """Check if OAuth token is valid.""" + if not self._oauth_token: + self._is_available = False + return + + # Basic availability check - could be enhanced with actual API call + self._is_available = True + + @property + def name(self) -> str: + return "copilot" + + @property + def is_available(self) -> bool: + return self._is_available and self._oauth_token is not None + + @property + def default_model(self) -> str: + # GitHub Copilot uses a fixed model + return "github-copilot" + + def list_models(self, force_refresh: bool = False) -> list[str]: + """List available models (Copilot uses a fixed model).""" + if self._cached_models is None or force_refresh: + self._cached_models = ["github-copilot"] + return list(self._cached_models) + + def chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | None = None, + ) -> dict[str, str] | None: + """Execute a blocking chat call.""" + if not self._oauth_token: + return { + "role": "assistant", + "content": "GitHub Copilot not authenticated. 
Please authenticate via OAuth in settings.", + } + + # Note: GitHub Copilot Chat API endpoint may vary + # This is a placeholder implementation + payload = { + "model": model, + "messages": list(messages), + } + + try: + req = Request( + f"{self._base_url}/chat/completions", + data=json.dumps(payload).encode("utf-8"), + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self._oauth_token}", + }, + method="POST", + ) + with urlopen(req, timeout=120) as response: + result = json.loads(response.read().decode()) + choices = result.get("choices", []) + if choices: + message = choices[0].get("message", {}) + content = message.get("content", "") + return {"role": "assistant", "content": content} + return {"role": "assistant", "content": ""} + except Exception as exc: + return { + "role": "assistant", + "content": f"GitHub Copilot API error: {exc}. Note: GitHub Copilot Chat API access requires a Copilot subscription.", + } + + def stream_chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + options: Dict[str, Any] | None = None, + ) -> Iterator[dict[str, Any]]: + """Execute a streaming chat call.""" + if not self._oauth_token: + yield { + "role": "assistant", + "content": "GitHub Copilot not authenticated. 
Please authenticate via OAuth in settings.", + "done": True, + "error": True, + } + return + + payload = { + "model": model, + "messages": list(messages), + "stream": True, + } + + try: + req = Request( + f"{self._base_url}/chat/completions", + data=json.dumps(payload).encode("utf-8"), + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self._oauth_token}", + }, + method="POST", + ) + with urlopen(req, timeout=120) as response: + for line in response: + if not line: + continue + + line_str = line.decode("utf-8").strip() + if not line_str or line_str == "data: [DONE]": + continue + + if line_str.startswith("data: "): + line_str = line_str[6:] + + try: + chunk_data = json.loads(line_str) + choices = chunk_data.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + if content: + yield { + "message": {"content": content}, + "done": False, + } + + finish_reason = choices[0].get("finish_reason") + if finish_reason: + yield {"done": True} + break + except json.JSONDecodeError: + continue + except Exception as exc: + yield { + "role": "assistant", + "content": f"GitHub Copilot API error: {exc}. Note: GitHub Copilot Chat API access requires a Copilot subscription.", + "done": True, + "error": True, + } + + @staticmethod + def get_oauth_url(client_id: str, redirect_uri: str, scopes: list[str] = None) -> str: + """Generate OAuth authorization URL. 
+ + Args: + client_id: GitHub OAuth app client ID + redirect_uri: OAuth redirect URI + scopes: List of OAuth scopes (default: ['copilot']) + + Returns: + OAuth authorization URL + """ + if scopes is None: + scopes = ["copilot"] + + scope_str = " ".join(scopes) + return ( + f"https://github.com/login/oauth/authorize" + f"?client_id={client_id}" + f"&redirect_uri={redirect_uri}" + f"&scope={scope_str}" + ) + + @staticmethod + def exchange_code_for_token( + client_id: str, client_secret: str, code: str, redirect_uri: str + ) -> str | None: + """Exchange authorization code for access token. + + Args: + client_id: GitHub OAuth app client ID + client_secret: GitHub OAuth app client secret + code: Authorization code from OAuth callback + redirect_uri: OAuth redirect URI + + Returns: + Access token or None on error + """ + payload = { + "client_id": client_id, + "client_secret": client_secret, + "code": code, + "redirect_uri": redirect_uri, + } + + try: + req = Request( + "https://github.com/login/oauth/access_token", + data=json.dumps(payload).encode("utf-8"), + headers={"Content-Type": "application/json", "Accept": "application/json"}, + method="POST", + ) + with urlopen(req, timeout=10) as response: + result = json.loads(response.read().decode()) + return result.get("access_token") + except Exception: + return None + diff --git a/reasoning_controller.py b/reasoning_controller.py index f9eb55d..3ba0e96 100644 --- a/reasoning_controller.py +++ b/reasoning_controller.py @@ -15,6 +15,14 @@ class PreferencesState: reasoning_enabled: bool = False default_model: str | None = None theme: str = "default" + provider: str = "ollama" # AI provider: "ollama", "gemini", "openrouter", "copilot" + api_keys: Dict[str, str] = None # API keys for providers (gemini, openrouter) + copilot_oauth_token: str | None = None # GitHub Copilot OAuth token + + def __post_init__(self): + """Initialize api_keys if None.""" + if self.api_keys is None: + self.api_keys = {} class ReasoningController: @@ 
-66,6 +74,10 @@ class ReasoningController: Returns: Dictionary of model-specific parameters """ + # Only return options for Ollama (other providers don't use these) + if self._preferences.provider != "ollama": + return {} + if self._preferences.reasoning_enabled: # Thinking model settings return { @@ -85,6 +97,33 @@ class ReasoningController: "num_predict": 32768, } + def get_provider(self) -> str: + """Get the current AI provider.""" + return self._preferences.provider + + def set_provider(self, provider: str) -> None: + """Set the AI provider.""" + self._preferences.provider = provider + self._save_preferences() + + def get_api_key(self, provider: str) -> str | None: + """Get API key for a provider.""" + return self._preferences.api_keys.get(provider) + + def set_api_key(self, provider: str, api_key: str) -> None: + """Set API key for a provider.""" + self._preferences.api_keys[provider] = api_key + self._save_preferences() + + def get_copilot_token(self) -> str | None: + """Get GitHub Copilot OAuth token.""" + return self._preferences.copilot_oauth_token + + def set_copilot_token(self, token: str | None) -> None: + """Set GitHub Copilot OAuth token.""" + self._preferences.copilot_oauth_token = token + self._save_preferences() + def _load_preferences(self) -> PreferencesState: """Load preferences from disk or create defaults. 
@@ -101,6 +140,9 @@ class ReasoningController: reasoning_enabled=data.get("reasoning_enabled", False), default_model=data.get("default_model"), theme=data.get("theme", "default"), + provider=data.get("provider", "ollama"), + api_keys=data.get("api_keys", {}), + copilot_oauth_token=data.get("copilot_oauth_token"), ) except (json.JSONDecodeError, OSError): # If file is corrupted or unreadable, return defaults diff --git a/settings_widget.py b/settings_widget.py new file mode 100644 index 0000000..80851b3 --- /dev/null +++ b/settings_widget.py @@ -0,0 +1,366 @@ +"""Settings widget for AI provider configuration.""" + +from __future__ import annotations + +import webbrowser +from gi.repository import GLib, Gtk +from ignis import widgets + +from .provider_client import ( + OllamaProvider, + GeminiProvider, + OpenRouterProvider, + CopilotProvider, +) +from .reasoning_controller import ReasoningController + + +class SettingsWidget(widgets.Box): + """Settings view for configuring AI providers.""" + + def __init__(self, preferences: ReasoningController, on_provider_changed=None, on_back=None): + """Initialize settings widget. 
+
+        Args:
+            preferences: ReasoningController instance for managing preferences
+            on_provider_changed: Callback function called when provider changes
+            on_back: Callback function called when back button is clicked
+        """
+        self._preferences = preferences
+        self._on_provider_changed = on_provider_changed
+        self._on_back_callback = on_back
+
+        # Provider instances (will be created as needed)
+        self._providers = {}
+
+        # Header
+        header_title = widgets.Label(
+            label="Settings",
+            halign="start",
+            css_classes=["title-2"],
+        )
+
+        back_button = widgets.Button(
+            label="← Back",
+            on_click=lambda x: self._on_back(),
+            halign="start",
+            hexpand=False,
+        )
+
+        header_box = widgets.Box(
+            spacing=8,
+            hexpand=True,
+            child=[back_button, header_title],
+        )
+
+        # Provider selection
+        provider_label = widgets.Label(
+            label="AI Provider",
+            halign="start",
+            css_classes=["title-3"],
+        )
+
+        # Radio buttons for provider selection
+        self._provider_buttons = {}
+        provider_names = {
+            "ollama": "Ollama (Local)",
+            "gemini": "Google Gemini",
+            "openrouter": "OpenRouter",
+            "copilot": "GitHub Copilot",
+        }
+
+        current_provider = self._preferences.get_provider()
+        provider_box = widgets.Box(vertical=True, spacing=4)
+
+        for provider_id, provider_label_text in provider_names.items():
+            button = Gtk.CheckButton(label=provider_label_text)
+            button.set_group(next(iter(self._provider_buttons.values()), None))  # radio-group with first button so only one provider can be active (GTK4)
+            button.set_active(provider_id == current_provider)
+            button.connect("toggled", lambda btn, pid=provider_id: self._on_provider_selected(pid) if btn.get_active() else None)
+            self._provider_buttons[provider_id] = button; provider_box.append(button)
+
+        # API key inputs
+        self._api_key_entries = {}
+
+        # Gemini API key
+        gemini_label = widgets.Label(
+            label="Gemini API Key",
+            halign="start",
+        )
+        gemini_entry = Gtk.Entry()
+        gemini_entry.set_placeholder_text("Enter your Gemini API key")
+        gemini_entry.set_visibility(False) # Password mode
+        gemini_entry.set_input_purpose(Gtk.InputPurpose.PASSWORD)
+        api_key = self._preferences.get_api_key("gemini")
+ if api_key: + gemini_entry.set_text(api_key) + gemini_entry.connect("changed", lambda e: self._on_api_key_changed("gemini", e.get_text())) + self._api_key_entries["gemini"] = gemini_entry + + # Toggle visibility button for Gemini + gemini_visibility_button = Gtk.Button(icon_name="view-reveal-symbolic") + gemini_visibility_button.connect("clicked", lambda b: self._toggle_password_visibility(gemini_entry, b)) + + gemini_box = widgets.Box( + spacing=8, + hexpand=True, + child=[gemini_entry, gemini_visibility_button], + ) + + # OpenRouter API key + openrouter_label = widgets.Label( + label="OpenRouter API Key", + halign="start", + ) + openrouter_entry = Gtk.Entry() + openrouter_entry.set_placeholder_text("Enter your OpenRouter API key") + openrouter_entry.set_visibility(False) # Password mode + openrouter_entry.set_input_purpose(Gtk.InputPurpose.PASSWORD) + api_key = self._preferences.get_api_key("openrouter") + if api_key: + openrouter_entry.set_text(api_key) + openrouter_entry.connect("changed", lambda e: self._on_api_key_changed("openrouter", e.get_text())) + self._api_key_entries["openrouter"] = openrouter_entry + + # Toggle visibility button for OpenRouter + openrouter_visibility_button = Gtk.Button(icon_name="view-reveal-symbolic") + openrouter_visibility_button.connect("clicked", lambda b: self._toggle_password_visibility(openrouter_entry, b)) + + openrouter_box = widgets.Box( + spacing=8, + hexpand=True, + child=[openrouter_entry, openrouter_visibility_button], + ) + + # GitHub Copilot OAuth + copilot_label = widgets.Label( + label="GitHub Copilot", + halign="start", + ) + copilot_status_label = widgets.Label( + label="Not authenticated" if not self._preferences.get_copilot_token() else "Authenticated", + halign="start", + css_classes=["dim-label"], + ) + + copilot_auth_button = widgets.Button( + label="Authenticate with GitHub" if not self._preferences.get_copilot_token() else "Re-authenticate", + on_click=lambda x: self._on_copilot_auth(), + ) + + # Test 
connection buttons + test_ollama_button = widgets.Button( + label="Test Connection", + on_click=lambda x: self._test_connection("ollama"), + ) + + test_gemini_button = widgets.Button( + label="Test Connection", + on_click=lambda x: self._test_connection("gemini"), + ) + + test_openrouter_button = widgets.Button( + label="Test Connection", + on_click=lambda x: self._test_connection("openrouter"), + ) + + test_copilot_button = widgets.Button( + label="Test Connection", + on_click=lambda x: self._test_connection("copilot"), + ) + + # Status labels for test results + self._test_status_labels = { + "ollama": widgets.Label(label="", halign="start", css_classes=["dim-label"]), + "gemini": widgets.Label(label="", halign="start", css_classes=["dim-label"]), + "openrouter": widgets.Label(label="", halign="start", css_classes=["dim-label"]), + "copilot": widgets.Label(label="", halign="start", css_classes=["dim-label"]), + } + + # Layout + settings_content = widgets.Box( + vertical=True, + spacing=16, + hexpand=True, + vexpand=True, + ) + + # Provider section + settings_content.append(provider_label) + settings_content.append(provider_box) + spacer1 = widgets.Box(hexpand=True) + spacer1.set_size_request(-1, 12) + settings_content.append(spacer1) + + # Gemini section + settings_content.append(gemini_label) + settings_content.append(gemini_box) + test_gemini_box = widgets.Box(spacing=8, hexpand=True, child=[test_gemini_button, self._test_status_labels["gemini"]]) + settings_content.append(test_gemini_box) + spacer2 = widgets.Box(hexpand=True) + spacer2.set_size_request(-1, 8) + settings_content.append(spacer2) + + # OpenRouter section + settings_content.append(openrouter_label) + settings_content.append(openrouter_box) + test_openrouter_box = widgets.Box(spacing=8, hexpand=True, child=[test_openrouter_button, self._test_status_labels["openrouter"]]) + settings_content.append(test_openrouter_box) + spacer3 = widgets.Box(hexpand=True) + spacer3.set_size_request(-1, 8) + 
settings_content.append(spacer3) + + # Ollama section + ollama_label = widgets.Label( + label="Ollama", + halign="start", + ) + ollama_info = widgets.Label( + label="Local Ollama server. Start with: ollama serve", + halign="start", + css_classes=["dim-label"], + ) + settings_content.append(ollama_label) + settings_content.append(ollama_info) + test_ollama_box = widgets.Box(spacing=8, hexpand=True, child=[test_ollama_button, self._test_status_labels["ollama"]]) + settings_content.append(test_ollama_box) + spacer4 = widgets.Box(hexpand=True) + spacer4.set_size_request(-1, 8) + settings_content.append(spacer4) + + # Copilot section + settings_content.append(copilot_label) + settings_content.append(copilot_status_label) + settings_content.append(copilot_auth_button) + test_copilot_box = widgets.Box(spacing=8, hexpand=True, child=[test_copilot_button, self._test_status_labels["copilot"]]) + settings_content.append(test_copilot_box) + + # Scrolled window for settings + scroller = widgets.Scroll( + hexpand=True, + vexpand=True, + child=settings_content, + ) + + # Main container + super().__init__( + vertical=True, + spacing=12, + hexpand=True, + vexpand=True, + child=[header_box, scroller], + css_classes=["ai-sidebar-content"], + ) + + # Set margins + self.set_margin_top(16) + self.set_margin_bottom(16) + self.set_margin_start(16) + self.set_margin_end(16) + + def _toggle_password_visibility(self, entry: Gtk.Entry, button: Gtk.Button): + """Toggle password visibility in entry.""" + visible = entry.get_visibility() + entry.set_visibility(not visible) + button.set_icon_name("view-reveal-symbolic" if visible else "view-conceal-symbolic") + + def _on_provider_selected(self, provider_id: str): + """Handle provider selection.""" + self._preferences.set_provider(provider_id) + if self._on_provider_changed: + self._on_provider_changed(provider_id) + + def _on_api_key_changed(self, provider: str, api_key: str): + """Handle API key changes.""" + 
self._preferences.set_api_key(provider, api_key) + # Clear test status + self._test_status_labels[provider].label = "" + + def _on_copilot_auth(self): + """Handle GitHub Copilot OAuth authentication.""" + # Note: For a full OAuth implementation, you would need: + # 1. A GitHub OAuth app registered + # 2. A local HTTP server to receive the callback + # 3. Exchange authorization code for token + + # For now, we'll show a message that OAuth setup is required + status_label = self._test_status_labels["copilot"] + status_label.label = "OAuth setup required. See documentation for GitHub OAuth app configuration." + status_label.css_classes = ["dim-label"] + + # TODO: Implement full OAuth flow + # This would involve: + # 1. Generate OAuth URL + # 2. Open browser + # 3. Start local server to receive callback + # 4. Exchange code for token + # 5. Save token + + def _test_connection(self, provider_id: str): + """Test connection to a provider.""" + status_label = self._test_status_labels[provider_id] + status_label.label = "Testing..." 
+        status_label.css_classes = ["dim-label"]
+
+        def _test():
+            try:
+                provider = self._get_provider(provider_id)
+                if provider:
+                    success, message = provider.test_connection()
+                    GLib.idle_add(
+                        lambda: self._update_test_status(provider_id, success, message),
+                        priority=GLib.PRIORITY_DEFAULT
+                    )
+                else:
+                    GLib.idle_add(
+                        lambda: self._update_test_status(provider_id, False, "Provider not configured"),
+                        priority=GLib.PRIORITY_DEFAULT
+                    )
+            except Exception as e:
+                GLib.idle_add(
+                    lambda: self._update_test_status(provider_id, False, f"Error: {str(e)}"),
+                    priority=GLib.PRIORITY_DEFAULT
+                )
+
+        import threading
+        thread = threading.Thread(target=_test, daemon=True)
+        thread.start()
+
+    def _update_test_status(self, provider_id: str, success: bool, message: str):
+        """Update test status label."""
+        status_label = self._test_status_labels[provider_id]
+        status_label.label = message
+        # NOTE(review): the previous if/else set the identical css class on
+        # both branches (dead branching); collapsed to one assignment. The
+        # `success` flag is unused until a distinct failure style is defined.
+        status_label.css_classes = ["dim-label"]
+
+    def _get_provider(self, provider_id: str):
+        """Get or create provider instance."""
+        if provider_id in self._providers:
+            return self._providers[provider_id]
+
+        if provider_id == "ollama":
+            provider = OllamaProvider()
+        elif provider_id == "gemini":
+            api_key = self._preferences.get_api_key("gemini")
+            provider = GeminiProvider(api_key=api_key) if api_key else None
+        elif provider_id == "openrouter":
+            api_key = self._preferences.get_api_key("openrouter")
+            provider = OpenRouterProvider(api_key=api_key) if api_key else None
+        elif provider_id == "copilot":
+            token = self._preferences.get_copilot_token()
+            provider = CopilotProvider(oauth_token=token) if token else None
+        else:
+            return None
+
+        if provider:
+            self._providers[provider_id] = provider
+        return provider
+
+    def _on_back(self):
+        """Handle back button click."""
+        # __init__ always sets _on_back_callback, so guard on None, not hasattr
+        if self._on_back_callback is not None:
+            self._on_back_callback()
+