From da57c43e6972d30fe3598c3beef5af29d11e68cb Mon Sep 17 00:00:00 2001 From: Melvin Ragusa Date: Sat, 25 Oct 2025 18:22:07 +0200 Subject: [PATCH] feat: implement conversation state management and persistence, enhance sidebar UI --- AGENTS.md | 10 +- conversation_manager.py | 171 +++++++++++++++++++++++- ollama_client.py | 141 ++++++++++++++++++- sidebar_window.py | 208 ++++++++++++++++++++++++++--- tests/test_conversation_manager.py | 34 +++++ 5 files changed, 536 insertions(+), 28 deletions(-) create mode 100644 tests/test_conversation_manager.py diff --git a/AGENTS.md b/AGENTS.md index e6bf76a..0d9b2e7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,17 +1,18 @@ # Repository Guidelines ## Project Structure & Module Organization -- `main.py` wires the GTK `Application`; UI widgets live in `sidebar_window.py` and `message_widget.py`. +- `main.py` wires the GTK `Application` and guards headless runs; UI widgets live in `sidebar_window.py` and `message_widget.py`. - `ollama_client.py` wraps streaming calls and threading helpers so GTK stays responsive. - Conversation state persists through `conversation_manager.py` and JSON files under `data/conversations/`; keep writes atomic. -- Shared settings belong in `config.py` and styles in `styles.css`; prefer adding focused modules over bloating these. +- Shared settings belong in `config.py`, styles in `styles.css`, and tooling defaults in `pyproject.toml`; prefer adding focused modules over bloating these. - Tests mirror the source tree under `tests/`, with fixtures in `tests/fixtures/` for reusable transcripts and metadata. ## Build, Test, and Development Commands - `python -m venv .venv && source .venv/bin/activate` — creates and activates the project’s virtual environment. - `pip install -r requirements.txt` — installs GTK, Ollama, and tooling dependencies. -- `python main.py` — launches the sidebar in development mode; pass `--mock-ollama` when iterating without a local model. -- `pytest` — runs the full test suite; combine with `pytest -k "conversation"` for targeted checks. +- `python main.py` — launches the sidebar; requires a Wayland/X11 session. +- `AI_SIDEBAR_HEADLESS=1 python main.py` — skips GTK startup for CI smoke checks. +- `AI_SIDEBAR_HEADLESS=1 pytest` — runs the full test suite; combine with `-k "conversation"` for targeted checks. ## Coding Style & Naming Conventions - Use 4-space indentation and format with `black .`; avoid tab characters. @@ -23,6 +24,7 @@ - Prefer `pytest` parameterized cases for conversation flows; store golden transcripts in `tests/fixtures/responses/`. - Name tests `test__` (e.g., `test_conversation_manager_persists_history`). - Cover threading boundaries by mocking Ollama responses and asserting GTK updates via `GLib.idle_add`. +- Use `AI_SIDEBAR_HEADLESS=1` when exercising tests or scripts in non-GUI environments. - Run `pytest --maxfail=1` before commits to catch regressions early. ## Commit & Pull Request Guidelines diff --git a/conversation_manager.py b/conversation_manager.py index 3e1aa16..7e0b365 100644 --- a/conversation_manager.py +++ b/conversation_manager.py @@ -1,8 +1,173 @@ """Conversation state management and persistence helpers.""" +from __future__ import annotations + +import json +import os +import tempfile +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import ClassVar, Dict, Iterable, List, MutableMapping + +DEFAULT_CONVERSATION_ID = "default" + + +@dataclass +class ConversationState: + """In-memory representation of a conversation transcript.""" + + conversation_id: str + created_at: str + updated_at: str + messages: List[Dict[str, str]] = field(default_factory=list) + class ConversationManager: - """Placeholder conversation manager awaiting implementation.""" + """Load and persist conversation transcripts as JSON files.""" - def __init__(self) -> None: - raise NotImplementedError("Conversation manager not implemented yet.") + VALID_ROLES: ClassVar[set[str]] = {"system", "user", "assistant"} + + def __init__( + self, + storage_dir: str | Path | None = None, + conversation_id: str | None = None, + ) -> None: + module_root = Path(__file__).resolve().parent + default_storage = module_root / "data" / "conversations" + self._storage_dir = Path(storage_dir) if storage_dir else default_storage + self._storage_dir.mkdir(parents=True, exist_ok=True) + + self._conversation_id = conversation_id or DEFAULT_CONVERSATION_ID + self._path = self._storage_dir / f"{self._conversation_id}.json" + + self._state = self._load_state() + + # ------------------------------------------------------------------ properties + @property + def conversation_id(self) -> str: + return self._state.conversation_id + + @property + def messages(self) -> List[Dict[str, str]]: + return list(self._state.messages) + + @property + def chat_messages(self) -> List[Dict[str, str]]: + """Return messages formatted for the Ollama chat API.""" + return [ + {"role": msg["role"], "content": msg["content"]} + for msg in self._state.messages + ] + + # ------------------------------------------------------------------ public API + def append_message(self, role: str, content: str) -> Dict[str, str]: + """Append a new message and persist the updated transcript.""" + normalized_role = role.lower() + if normalized_role not in self.VALID_ROLES: + raise ValueError(f"Invalid role '{role}'. Expected one of {self.VALID_ROLES}.") + + timestamp = datetime.now(timezone.utc).isoformat() + message = { + "role": normalized_role, + "content": content, + "timestamp": timestamp, + } + + self._state.messages.append(message) + self._state.updated_at = timestamp + self._write_state() + return message + + def replace_messages(self, messages: Iterable[Dict[str, str]]) -> None: + """Replace the transcript contents. Useful for loading fixtures.""" + normalized: List[Dict[str, str]] = [] + for item in messages: + role = item.get("role", "").lower() + content = item.get("content", "") + if role not in self.VALID_ROLES: + continue + normalized.append( + { + "role": role, + "content": content, + "timestamp": item.get("timestamp") + or datetime.now(timezone.utc).isoformat(), + } + ) + + now = datetime.now(timezone.utc).isoformat() + self._state.messages = normalized + self._state.created_at = self._state.created_at or now + self._state.updated_at = now + self._write_state() + + # ------------------------------------------------------------------ persistence + def _load_state(self) -> ConversationState: + """Load the transcript from disk or create a fresh default.""" + if self._path.exists(): + try: + with self._path.open("r", encoding="utf-8") as fh: + payload = json.load(fh) + return self._state_from_payload(payload) + except (json.JSONDecodeError, OSError): + pass + + timestamp = datetime.now(timezone.utc).isoformat() + return ConversationState( + conversation_id=self._conversation_id, + created_at=timestamp, + updated_at=timestamp, + messages=[], + ) + + def _state_from_payload(self, payload: MutableMapping[str, object]) -> ConversationState: + """Normalize persisted data into ConversationState instances.""" + conversation_id = str(payload.get("id") or self._conversation_id) + created_at = str(payload.get("created_at") or datetime.now(timezone.utc).isoformat()) + updated_at = str(payload.get("updated_at") or created_at) + + messages_payload = payload.get("messages", []) + messages: List[Dict[str, str]] = [] + if isinstance(messages_payload, list): + for item in messages_payload: + if not isinstance(item, dict): + continue + role = str(item.get("role", "")).lower() + content = str(item.get("content", "")) + if role not in self.VALID_ROLES: + continue + timestamp = str( + item.get("timestamp") or datetime.now(timezone.utc).isoformat() + ) + messages.append({"role": role, "content": content, "timestamp": timestamp}) + + return ConversationState( + conversation_id=conversation_id, + created_at=created_at, + updated_at=updated_at, + messages=messages, + ) + + def _write_state(self) -> None: + """Persist the conversation state atomically.""" + payload = { + "id": self._state.conversation_id, + "created_at": self._state.created_at, + "updated_at": self._state.updated_at, + "messages": self._state.messages, + } + + with tempfile.NamedTemporaryFile( + "w", + encoding="utf-8", + dir=self._storage_dir, + delete=False, + prefix=f"{self._conversation_id}.", + suffix=".tmp", + ) as tmp_file: + json.dump(payload, tmp_file, indent=2, ensure_ascii=False) + tmp_file.flush() + os.fsync(tmp_file.fileno()) + + os.replace(tmp_file.name, self._path) diff --git a/ollama_client.py b/ollama_client.py index a94e947..d6410aa 100644 --- a/ollama_client.py +++ b/ollama_client.py @@ -1,8 +1,143 @@ """Client utilities for interacting with the Ollama API.""" +from __future__ import annotations + +from typing import Any, Dict, Iterable, Iterator + +try: # pragma: no cover - optional dependency may not be installed in CI + import ollama +except ImportError: # pragma: no cover - fallback path for environments without Ollama + ollama = None # type: ignore[assignment] + + +class OllamaClientError(RuntimeError): + """Base exception raised when Ollama operations fail.""" + + +class OllamaUnavailableError(OllamaClientError): + """Raised when the Ollama Python SDK is not available.""" + class OllamaClient: - """Placeholder client used until streaming integration is built.""" + """Thin wrapper around the Ollama Python SDK with graceful degradation.""" - def __init__(self) -> None: - raise NotImplementedError("Ollama client not implemented yet.") + def __init__(self, host: str | None = None) -> None: + self._host = host + self._client = None + self._cached_models: list[str] | None = None + + if ollama is None: + return + + if host and hasattr(ollama, "Client"): + self._client = ollama.Client(host=host) # type: ignore[call-arg] + + # ------------------------------------------------------------------ helpers + @property + def is_available(self) -> bool: + return ollama is not None + + @property + def default_model(self) -> str | None: + models = self.list_models() + return models[0] if models else None + + def list_models(self, force_refresh: bool = False) -> list[str]: + """Return the available model names, caching the result for quick reuse.""" + if not self.is_available: + return [] + + if self._cached_models is not None and not force_refresh: + return list(self._cached_models) + + try: + response = self._call_sdk("list") # type: ignore[arg-type] + except OllamaClientError: + return [] + + models: list[str] = [] + for item in response.get("models", []): # type: ignore[assignment] + name = item.get("name") or item.get("model") + if name: + models.append(name) + + self._cached_models = models + return list(models) + + # ------------------------------------------------------------------ chat APIs + def chat( + self, + *, + model: str, + messages: Iterable[Dict[str, str]], + ) -> dict[str, str] | None: + """Execute a blocking chat call against Ollama.""" + if not self.is_available: + return { + "role": "assistant", + "content": "Ollama SDK is not installed; install `ollama` to enable responses.", + } + + try: + result = self._call_sdk( + "chat", + model=model, + messages=list(messages), + stream=False, + ) + except OllamaClientError as exc: + return { + "role": "assistant", + "content": f"Unable to reach Ollama: {exc}", + } + + message = result.get("message") if isinstance(result, dict) else None + if not message: + return {"role": "assistant", "content": ""} + + role = message.get("role") or "assistant" + content = message.get("content") or "" + return {"role": role, "content": content} + + def stream_chat( + self, *, model: str, messages: Iterable[Dict[str, str]] + ) -> Iterator[dict[str, Any]]: + """Placeholder that exposes the streaming API for future UI hooks.""" + if not self.is_available: + raise OllamaUnavailableError( + "Streaming requires the Ollama Python SDK to be installed." + ) + + try: + stream = self._call_sdk( + "chat", + model=model, + messages=list(messages), + stream=True, + ) + except OllamaClientError as exc: + raise OllamaClientError(f"Failed to start streaming chat: {exc}") from exc + + if not hasattr(stream, "__iter__"): + raise OllamaClientError("Ollama returned a non-iterable stream response.") + return iter(stream) + + # ------------------------------------------------------------------ internals + def _call_sdk(self, method: str, *args: Any, **kwargs: Any) -> Any: + if not self.is_available: + raise OllamaUnavailableError( + "Ollama Python SDK is not available in the environment." + ) + + target = self._client if self._client is not None else ollama + if target is None or not hasattr(target, method): + raise OllamaClientError( + f"Ollama SDK does not expose method '{method}'. Install or update the SDK." + ) + + func = getattr(target, method) + + try: + return func(*args, **kwargs) + except Exception as exc: # pragma: no cover - network errors depend on runtime + raise OllamaClientError(str(exc)) from exc diff --git a/sidebar_window.py b/sidebar_window.py index f6d9b1c..9e8b6ce 100644 --- a/sidebar_window.py +++ b/sidebar_window.py @@ -2,36 +2,208 @@ from __future__ import annotations +import threading +from typing import Iterable + import gi gi.require_version("Gtk", "4.0") -from gi.repository import Gtk # noqa: E402 +from gi.repository import GLib, Gtk # noqa: E402 + +try: # pragma: no cover - optional dependency may not be available in CI + gi.require_version("Gtk4LayerShell", "1.0") + from gi.repository import Gtk4LayerShell # type: ignore[attr-defined] +except (ImportError, ValueError): # pragma: no cover - fallback path + Gtk4LayerShell = None # type: ignore[misc] + +from conversation_manager import ConversationManager +from ollama_client import OllamaClient class SidebarWindow(Gtk.ApplicationWindow): - """Minimal window placeholder to confirm the GTK application starts.""" + """Layer-shell anchored window hosting the chat interface.""" def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.set_default_size(360, 640) + self.set_default_size(360, 720) self.set_title("Niri AI Sidebar") + self.set_hide_on_close(False) - layout = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=12) - layout.set_margin_top(24) - layout.set_margin_bottom(24) - layout.set_margin_start(24) - layout.set_margin_end(24) + self._conversation_manager = ConversationManager() + self._ollama_client = OllamaClient() + self._current_model = self._ollama_client.default_model - title = Gtk.Label(label="AI Sidebar") - title.set_halign(Gtk.Align.START) - title.get_style_context().add_class("title-1") + self._setup_layer_shell() + self._build_ui() + self._populate_initial_messages() - message = Gtk.Label( - label="GTK app is running. Replace this view with the chat interface." + # ------------------------------------------------------------------ UI setup + def _setup_layer_shell(self) -> None: + """Attach the window to the left edge via gtk4-layer-shell when available.""" + if Gtk4LayerShell is None: + return + + Gtk4LayerShell.init_for_window(self) + Gtk4LayerShell.set_namespace(self, "niri-ai-sidebar") + Gtk4LayerShell.set_layer(self, Gtk4LayerShell.Layer.TOP) + Gtk4LayerShell.set_anchor(self, Gtk4LayerShell.Edge.LEFT, True) + Gtk4LayerShell.set_anchor(self, Gtk4LayerShell.Edge.TOP, True) + Gtk4LayerShell.set_anchor(self, Gtk4LayerShell.Edge.BOTTOM, True) + Gtk4LayerShell.set_anchor(self, Gtk4LayerShell.Edge.RIGHT, False) + Gtk4LayerShell.set_margin(self, Gtk4LayerShell.Edge.LEFT, 0) + Gtk4LayerShell.set_keyboard_mode( + self, Gtk4LayerShell.KeyboardMode.ON_DEMAND ) - message.set_wrap(True) - message.set_halign(Gtk.Align.START) + Gtk4LayerShell.set_exclusive_zone(self, -1) - layout.append(title) - layout.append(message) - self.set_child(layout) + def _build_ui(self) -> None: + """Create the core layout: message history and input entry.""" + main_box = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=12) + main_box.set_margin_top(16) + main_box.set_margin_bottom(16) + main_box.set_margin_start(16) + main_box.set_margin_end(16) + main_box.set_hexpand(True) + main_box.set_vexpand(True) + + header_box = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=4) + header_title = Gtk.Label(label="Niri AI Sidebar") + header_title.set_halign(Gtk.Align.START) + header_title.get_style_context().add_class("title-2") + + model_name = self._current_model or "No local model detected" + self._model_label = Gtk.Label(label=f"Model: {model_name}") + self._model_label.set_halign(Gtk.Align.START) + self._model_label.get_style_context().add_class("dim-label") + + header_box.append(header_title) + header_box.append(self._model_label) + + self._message_list = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=8) + self._message_list.set_hexpand(True) + self._message_list.set_vexpand(True) + self._message_list.set_valign(Gtk.Align.START) + + scroller = Gtk.ScrolledWindow() + scroller.set_hexpand(True) + scroller.set_vexpand(True) + scroller.set_child(self._message_list) + scroller.set_min_content_height(300) + self._scroller = scroller + + input_box = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=8) + input_box.set_hexpand(True) + + self._entry = Gtk.Entry() + self._entry.set_hexpand(True) + self._entry.set_placeholder_text("Ask a question…") + self._entry.connect("activate", self._on_submit) + + self._send_button = Gtk.Button(label="Send") + self._send_button.connect("clicked", self._on_submit) + + input_box.append(self._entry) + input_box.append(self._send_button) + + main_box.append(header_box) + main_box.append(scroller) + main_box.append(input_box) + + self.set_child(main_box) + + def _populate_initial_messages(self) -> None: + """Render conversation history stored on disk.""" + for message in self._conversation_manager.messages: + self._append_message(message["role"], message["content"], persist=False) + + if not self._conversation_manager.messages: + self._append_message( + "assistant", + "Welcome! Ask a question to start a conversation.", + persist=True, + ) + + # ------------------------------------------------------------------ helpers + def _append_message( + self, role: str, content: str, *, persist: bool = True + ) -> None: + """Add a message bubble to the history and optionally persist it.""" + label_prefix = "You" if role == "user" else "Assistant" + label = Gtk.Label(label=f"{label_prefix}: {content}") + label.set_halign(Gtk.Align.START) + label.set_xalign(0.0) + label.set_wrap(True) + label.set_wrap_mode(Gtk.WrapMode.WORD_CHAR) + label.set_justify(Gtk.Justification.LEFT) + + self._message_list.append(label) + self._scroll_to_bottom() + + if persist: + self._conversation_manager.append_message(role, content) + + def _scroll_to_bottom(self) -> None: + """Ensure the most recent message is visible.""" + def _scroll() -> bool: + adjustment = self._scroller.get_vadjustment() + if adjustment is not None: + adjustment.set_value(adjustment.get_upper() - adjustment.get_page_size()) + return False + + GLib.idle_add(_scroll) + + def _set_input_enabled(self, enabled: bool) -> None: + self._entry.set_sensitive(enabled) + self._send_button.set_sensitive(enabled) + + # ------------------------------------------------------------------ callbacks + def _on_submit(self, _widget: Gtk.Widget) -> None: + """Handle send button clicks or entry activation.""" + text = self._entry.get_text().strip() + if not text: + return + + self._entry.set_text("") + self._append_message("user", text, persist=True) + self._request_response() + + def _request_response(self) -> None: + """Trigger a synchronous Ollama chat call on a worker thread.""" + model = self._current_model or self._ollama_client.default_model + if not model: + self._append_message( + "assistant", + "No Ollama models are available. Install a model to continue.", + persist=True, + ) + return + + history = self._conversation_manager.chat_messages + self._set_input_enabled(False) + + def _worker(messages: Iterable[dict[str, str]]) -> None: + response = self._ollama_client.chat(model=model, messages=list(messages)) + GLib.idle_add(self._handle_response, response, priority=GLib.PRIORITY_DEFAULT) + + thread = threading.Thread(target=_worker, args=(history,), daemon=True) + thread.start() + + def _handle_response(self, response: dict[str, str] | None) -> bool: + """Render the assistant reply and re-enable the entry.""" + self._set_input_enabled(True) + + if not response: + self._append_message( + "assistant", + "The model returned an empty response.", + persist=True, + ) + return False + + role = response.get("role", "assistant") + content = response.get("content") or "" + if not content: + content = "[No content received from Ollama]" + + self._append_message(role, content, persist=True) + return False diff --git a/tests/test_conversation_manager.py b/tests/test_conversation_manager.py new file mode 100644 index 0000000..4559a19 --- /dev/null +++ b/tests/test_conversation_manager.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from conversation_manager import ConversationManager + + +def test_conversation_manager_persists_history(tmp_path: Path) -> None: + manager = ConversationManager(storage_dir=tmp_path, conversation_id="test") + manager.append_message("user", "Hello there!") + manager.append_message("assistant", "General Kenobi.") + + conversation_file = tmp_path / "test.json" + assert conversation_file.exists() + + data = json.loads(conversation_file.read_text(encoding="utf-8")) + assert len(data["messages"]) == 2 + assert data["messages"][0]["content"] == "Hello there!" + + reloaded = ConversationManager(storage_dir=tmp_path, conversation_id="test") + assert [msg["content"] for msg in reloaded.messages] == [ + "Hello there!", + "General Kenobi.", + ] + + +def test_conversation_manager_rejects_invalid_role(tmp_path: Path) -> None: + manager = ConversationManager(storage_dir=tmp_path, conversation_id="invalid") + + with pytest.raises(ValueError): + manager.append_message("narrator", "This should fail.")