feat: implement conversation state management and persistence, enhance sidebar UI
This commit is contained in:
10
AGENTS.md
10
AGENTS.md
@@ -1,17 +1,18 @@
|
||||
# Repository Guidelines
|
||||
|
||||
## Project Structure & Module Organization
|
||||
- `main.py` wires the GTK `Application`; UI widgets live in `sidebar_window.py` and `message_widget.py`.
|
||||
- `main.py` wires the GTK `Application` and guards headless runs; UI widgets live in `sidebar_window.py` and `message_widget.py`.
|
||||
- `ollama_client.py` wraps streaming calls and threading helpers so GTK stays responsive.
|
||||
- Conversation state persists through `conversation_manager.py` and JSON files under `data/conversations/`; keep writes atomic.
|
||||
- Shared settings belong in `config.py` and styles in `styles.css`; prefer adding focused modules over bloating these.
|
||||
- Shared settings belong in `config.py`, styles in `styles.css`, and tooling defaults in `pyproject.toml`; prefer adding focused modules over bloating these.
|
||||
- Tests mirror the source tree under `tests/`, with fixtures in `tests/fixtures/` for reusable transcripts and metadata.
|
||||
|
||||
## Build, Test, and Development Commands
|
||||
- `python -m venv .venv && source .venv/bin/activate` — creates and activates the project’s virtual environment.
|
||||
- `pip install -r requirements.txt` — installs GTK, Ollama, and tooling dependencies.
|
||||
- `python main.py` — launches the sidebar in development mode; pass `--mock-ollama` when iterating without a local model.
|
||||
- `pytest` — runs the full test suite; combine with `pytest -k "conversation"` for targeted checks.
|
||||
- `python main.py` — launches the sidebar; requires a Wayland/X11 session.
|
||||
- `AI_SIDEBAR_HEADLESS=1 python main.py` — skips GTK startup for CI smoke checks.
|
||||
- `AI_SIDEBAR_HEADLESS=1 pytest` — runs the full test suite; combine with `-k "conversation"` for targeted checks.
|
||||
|
||||
## Coding Style & Naming Conventions
|
||||
- Use 4-space indentation and format with `black .`; avoid tab characters.
|
||||
@@ -23,6 +24,7 @@
|
||||
- Prefer `pytest` parameterized cases for conversation flows; store golden transcripts in `tests/fixtures/responses/`.
|
||||
- Name tests `test_<module>_<behavior>` (e.g., `test_conversation_manager_persists_history`).
|
||||
- Cover threading boundaries by mocking Ollama responses and asserting GTK updates via `GLib.idle_add`.
|
||||
- Use `AI_SIDEBAR_HEADLESS=1` when exercising tests or scripts in non-GUI environments.
|
||||
- Run `pytest --maxfail=1` before commits to catch regressions early.
|
||||
|
||||
## Commit & Pull Request Guidelines
|
||||
|
||||
@@ -1,8 +1,173 @@
|
||||
"""Conversation state management and persistence helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, Iterable, List, MutableMapping
|
||||
|
||||
DEFAULT_CONVERSATION_ID = "default"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationState:
|
||||
"""In-memory representation of a conversation transcript."""
|
||||
|
||||
conversation_id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
messages: List[Dict[str, str]] = field(default_factory=list)
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""Placeholder conversation manager awaiting implementation."""
|
||||
"""Load and persist conversation transcripts as JSON files."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError("Conversation manager not implemented yet.")
|
||||
VALID_ROLES: ClassVar[set[str]] = {"system", "user", "assistant"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_dir: str | Path | None = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> None:
|
||||
module_root = Path(__file__).resolve().parent
|
||||
default_storage = module_root / "data" / "conversations"
|
||||
self._storage_dir = Path(storage_dir) if storage_dir else default_storage
|
||||
self._storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._conversation_id = conversation_id or DEFAULT_CONVERSATION_ID
|
||||
self._path = self._storage_dir / f"{self._conversation_id}.json"
|
||||
|
||||
self._state = self._load_state()
|
||||
|
||||
# ------------------------------------------------------------------ properties
|
||||
@property
|
||||
def conversation_id(self) -> str:
|
||||
return self._state.conversation_id
|
||||
|
||||
@property
|
||||
def messages(self) -> List[Dict[str, str]]:
|
||||
return list(self._state.messages)
|
||||
|
||||
@property
|
||||
def chat_messages(self) -> List[Dict[str, str]]:
|
||||
"""Return messages formatted for the Ollama chat API."""
|
||||
return [
|
||||
{"role": msg["role"], "content": msg["content"]}
|
||||
for msg in self._state.messages
|
||||
]
|
||||
|
||||
# ------------------------------------------------------------------ public API
|
||||
def append_message(self, role: str, content: str) -> Dict[str, str]:
|
||||
"""Append a new message and persist the updated transcript."""
|
||||
normalized_role = role.lower()
|
||||
if normalized_role not in self.VALID_ROLES:
|
||||
raise ValueError(f"Invalid role '{role}'. Expected one of {self.VALID_ROLES}.")
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
message = {
|
||||
"role": normalized_role,
|
||||
"content": content,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
self._state.messages.append(message)
|
||||
self._state.updated_at = timestamp
|
||||
self._write_state()
|
||||
return message
|
||||
|
||||
def replace_messages(self, messages: Iterable[Dict[str, str]]) -> None:
|
||||
"""Replace the transcript contents. Useful for loading fixtures."""
|
||||
normalized: List[Dict[str, str]] = []
|
||||
for item in messages:
|
||||
role = item.get("role", "").lower()
|
||||
content = item.get("content", "")
|
||||
if role not in self.VALID_ROLES:
|
||||
continue
|
||||
normalized.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": item.get("timestamp")
|
||||
or datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
self._state.messages = normalized
|
||||
self._state.created_at = self._state.created_at or now
|
||||
self._state.updated_at = now
|
||||
self._write_state()
|
||||
|
||||
# ------------------------------------------------------------------ persistence
|
||||
def _load_state(self) -> ConversationState:
|
||||
"""Load the transcript from disk or create a fresh default."""
|
||||
if self._path.exists():
|
||||
try:
|
||||
with self._path.open("r", encoding="utf-8") as fh:
|
||||
payload = json.load(fh)
|
||||
return self._state_from_payload(payload)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
pass
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
return ConversationState(
|
||||
conversation_id=self._conversation_id,
|
||||
created_at=timestamp,
|
||||
updated_at=timestamp,
|
||||
messages=[],
|
||||
)
|
||||
|
||||
def _state_from_payload(self, payload: MutableMapping[str, object]) -> ConversationState:
|
||||
"""Normalize persisted data into ConversationState instances."""
|
||||
conversation_id = str(payload.get("id") or self._conversation_id)
|
||||
created_at = str(payload.get("created_at") or datetime.now(timezone.utc).isoformat())
|
||||
updated_at = str(payload.get("updated_at") or created_at)
|
||||
|
||||
messages_payload = payload.get("messages", [])
|
||||
messages: List[Dict[str, str]] = []
|
||||
if isinstance(messages_payload, list):
|
||||
for item in messages_payload:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
role = str(item.get("role", "")).lower()
|
||||
content = str(item.get("content", ""))
|
||||
if role not in self.VALID_ROLES:
|
||||
continue
|
||||
timestamp = str(
|
||||
item.get("timestamp") or datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
messages.append({"role": role, "content": content, "timestamp": timestamp})
|
||||
|
||||
return ConversationState(
|
||||
conversation_id=conversation_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
def _write_state(self) -> None:
|
||||
"""Persist the conversation state atomically."""
|
||||
payload = {
|
||||
"id": self._state.conversation_id,
|
||||
"created_at": self._state.created_at,
|
||||
"updated_at": self._state.updated_at,
|
||||
"messages": self._state.messages,
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w",
|
||||
encoding="utf-8",
|
||||
dir=self._storage_dir,
|
||||
delete=False,
|
||||
prefix=f"{self._conversation_id}.",
|
||||
suffix=".tmp",
|
||||
) as tmp_file:
|
||||
json.dump(payload, tmp_file, indent=2, ensure_ascii=False)
|
||||
tmp_file.flush()
|
||||
os.fsync(tmp_file.fileno())
|
||||
|
||||
os.replace(tmp_file.name, self._path)
|
||||
|
||||
141
ollama_client.py
141
ollama_client.py
@@ -1,8 +1,143 @@
|
||||
"""Client utilities for interacting with the Ollama API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable, Iterator
|
||||
|
||||
try: # pragma: no cover - optional dependency may not be installed in CI
|
||||
import ollama
|
||||
except ImportError: # pragma: no cover - fallback path for environments without Ollama
|
||||
ollama = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class OllamaClientError(RuntimeError):
|
||||
"""Base exception raised when Ollama operations fail."""
|
||||
|
||||
|
||||
class OllamaUnavailableError(OllamaClientError):
|
||||
"""Raised when the Ollama Python SDK is not available."""
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Placeholder client used until streaming integration is built."""
|
||||
"""Thin wrapper around the Ollama Python SDK with graceful degradation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError("Ollama client not implemented yet.")
|
||||
def __init__(self, host: str | None = None) -> None:
|
||||
self._host = host
|
||||
self._client = None
|
||||
self._cached_models: list[str] | None = None
|
||||
|
||||
if ollama is None:
|
||||
return
|
||||
|
||||
if host and hasattr(ollama, "Client"):
|
||||
self._client = ollama.Client(host=host) # type: ignore[call-arg]
|
||||
|
||||
# ------------------------------------------------------------------ helpers
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
return ollama is not None
|
||||
|
||||
@property
|
||||
def default_model(self) -> str | None:
|
||||
models = self.list_models()
|
||||
return models[0] if models else None
|
||||
|
||||
def list_models(self, force_refresh: bool = False) -> list[str]:
|
||||
"""Return the available model names, caching the result for quick reuse."""
|
||||
if not self.is_available:
|
||||
return []
|
||||
|
||||
if self._cached_models is not None and not force_refresh:
|
||||
return list(self._cached_models)
|
||||
|
||||
try:
|
||||
response = self._call_sdk("list") # type: ignore[arg-type]
|
||||
except OllamaClientError:
|
||||
return []
|
||||
|
||||
models: list[str] = []
|
||||
for item in response.get("models", []): # type: ignore[assignment]
|
||||
name = item.get("name") or item.get("model")
|
||||
if name:
|
||||
models.append(name)
|
||||
|
||||
self._cached_models = models
|
||||
return list(models)
|
||||
|
||||
# ------------------------------------------------------------------ chat APIs
|
||||
def chat(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
messages: Iterable[Dict[str, str]],
|
||||
) -> dict[str, str] | None:
|
||||
"""Execute a blocking chat call against Ollama."""
|
||||
if not self.is_available:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "Ollama SDK is not installed; install `ollama` to enable responses.",
|
||||
}
|
||||
|
||||
try:
|
||||
result = self._call_sdk(
|
||||
"chat",
|
||||
model=model,
|
||||
messages=list(messages),
|
||||
stream=False,
|
||||
)
|
||||
except OllamaClientError as exc:
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": f"Unable to reach Ollama: {exc}",
|
||||
}
|
||||
|
||||
message = result.get("message") if isinstance(result, dict) else None
|
||||
if not message:
|
||||
return {"role": "assistant", "content": ""}
|
||||
|
||||
role = message.get("role") or "assistant"
|
||||
content = message.get("content") or ""
|
||||
return {"role": role, "content": content}
|
||||
|
||||
def stream_chat(
|
||||
self, *, model: str, messages: Iterable[Dict[str, str]]
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""Placeholder that exposes the streaming API for future UI hooks."""
|
||||
if not self.is_available:
|
||||
raise OllamaUnavailableError(
|
||||
"Streaming requires the Ollama Python SDK to be installed."
|
||||
)
|
||||
|
||||
try:
|
||||
stream = self._call_sdk(
|
||||
"chat",
|
||||
model=model,
|
||||
messages=list(messages),
|
||||
stream=True,
|
||||
)
|
||||
except OllamaClientError as exc:
|
||||
raise OllamaClientError(f"Failed to start streaming chat: {exc}") from exc
|
||||
|
||||
if not hasattr(stream, "__iter__"):
|
||||
raise OllamaClientError("Ollama returned a non-iterable stream response.")
|
||||
return iter(stream)
|
||||
|
||||
# ------------------------------------------------------------------ internals
|
||||
def _call_sdk(self, method: str, *args: Any, **kwargs: Any) -> Any:
|
||||
if not self.is_available:
|
||||
raise OllamaUnavailableError(
|
||||
"Ollama Python SDK is not available in the environment."
|
||||
)
|
||||
|
||||
target = self._client if self._client is not None else ollama
|
||||
if target is None or not hasattr(target, method):
|
||||
raise OllamaClientError(
|
||||
f"Ollama SDK does not expose method '{method}'. Install or update the SDK."
|
||||
)
|
||||
|
||||
func = getattr(target, method)
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as exc: # pragma: no cover - network errors depend on runtime
|
||||
raise OllamaClientError(str(exc)) from exc
|
||||
|
||||
@@ -2,36 +2,208 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Iterable
|
||||
|
||||
import gi
|
||||
|
||||
gi.require_version("Gtk", "4.0")
|
||||
from gi.repository import Gtk # noqa: E402
|
||||
from gi.repository import GLib, Gtk # noqa: E402
|
||||
|
||||
try: # pragma: no cover - optional dependency may not be available in CI
|
||||
gi.require_version("Gtk4LayerShell", "1.0")
|
||||
from gi.repository import Gtk4LayerShell # type: ignore[attr-defined]
|
||||
except (ImportError, ValueError): # pragma: no cover - fallback path
|
||||
Gtk4LayerShell = None # type: ignore[misc]
|
||||
|
||||
from conversation_manager import ConversationManager
|
||||
from ollama_client import OllamaClient
|
||||
|
||||
|
||||
class SidebarWindow(Gtk.ApplicationWindow):
|
||||
"""Minimal window placeholder to confirm the GTK application starts."""
|
||||
"""Layer-shell anchored window hosting the chat interface."""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.set_default_size(360, 640)
|
||||
self.set_default_size(360, 720)
|
||||
self.set_title("Niri AI Sidebar")
|
||||
self.set_hide_on_close(False)
|
||||
|
||||
layout = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=12)
|
||||
layout.set_margin_top(24)
|
||||
layout.set_margin_bottom(24)
|
||||
layout.set_margin_start(24)
|
||||
layout.set_margin_end(24)
|
||||
self._conversation_manager = ConversationManager()
|
||||
self._ollama_client = OllamaClient()
|
||||
self._current_model = self._ollama_client.default_model
|
||||
|
||||
title = Gtk.Label(label="AI Sidebar")
|
||||
title.set_halign(Gtk.Align.START)
|
||||
title.get_style_context().add_class("title-1")
|
||||
self._setup_layer_shell()
|
||||
self._build_ui()
|
||||
self._populate_initial_messages()
|
||||
|
||||
message = Gtk.Label(
|
||||
label="GTK app is running. Replace this view with the chat interface."
|
||||
# ------------------------------------------------------------------ UI setup
|
||||
def _setup_layer_shell(self) -> None:
|
||||
"""Attach the window to the left edge via gtk4-layer-shell when available."""
|
||||
if Gtk4LayerShell is None:
|
||||
return
|
||||
|
||||
Gtk4LayerShell.init_for_window(self)
|
||||
Gtk4LayerShell.set_namespace(self, "niri-ai-sidebar")
|
||||
Gtk4LayerShell.set_layer(self, Gtk4LayerShell.Layer.TOP)
|
||||
Gtk4LayerShell.set_anchor(self, Gtk4LayerShell.Edge.LEFT, True)
|
||||
Gtk4LayerShell.set_anchor(self, Gtk4LayerShell.Edge.TOP, True)
|
||||
Gtk4LayerShell.set_anchor(self, Gtk4LayerShell.Edge.BOTTOM, True)
|
||||
Gtk4LayerShell.set_anchor(self, Gtk4LayerShell.Edge.RIGHT, False)
|
||||
Gtk4LayerShell.set_margin(self, Gtk4LayerShell.Edge.LEFT, 0)
|
||||
Gtk4LayerShell.set_keyboard_mode(
|
||||
self, Gtk4LayerShell.KeyboardMode.ON_DEMAND
|
||||
)
|
||||
message.set_wrap(True)
|
||||
message.set_halign(Gtk.Align.START)
|
||||
Gtk4LayerShell.set_exclusive_zone(self, -1)
|
||||
|
||||
layout.append(title)
|
||||
layout.append(message)
|
||||
self.set_child(layout)
|
||||
def _build_ui(self) -> None:
|
||||
"""Create the core layout: message history and input entry."""
|
||||
main_box = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=12)
|
||||
main_box.set_margin_top(16)
|
||||
main_box.set_margin_bottom(16)
|
||||
main_box.set_margin_start(16)
|
||||
main_box.set_margin_end(16)
|
||||
main_box.set_hexpand(True)
|
||||
main_box.set_vexpand(True)
|
||||
|
||||
header_box = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=4)
|
||||
header_title = Gtk.Label(label="Niri AI Sidebar")
|
||||
header_title.set_halign(Gtk.Align.START)
|
||||
header_title.get_style_context().add_class("title-2")
|
||||
|
||||
model_name = self._current_model or "No local model detected"
|
||||
self._model_label = Gtk.Label(label=f"Model: {model_name}")
|
||||
self._model_label.set_halign(Gtk.Align.START)
|
||||
self._model_label.get_style_context().add_class("dim-label")
|
||||
|
||||
header_box.append(header_title)
|
||||
header_box.append(self._model_label)
|
||||
|
||||
self._message_list = Gtk.Box(orientation=Gtk.Orientation.VERTICAL, spacing=8)
|
||||
self._message_list.set_hexpand(True)
|
||||
self._message_list.set_vexpand(True)
|
||||
self._message_list.set_valign(Gtk.Align.START)
|
||||
|
||||
scroller = Gtk.ScrolledWindow()
|
||||
scroller.set_hexpand(True)
|
||||
scroller.set_vexpand(True)
|
||||
scroller.set_child(self._message_list)
|
||||
scroller.set_min_content_height(300)
|
||||
self._scroller = scroller
|
||||
|
||||
input_box = Gtk.Box(orientation=Gtk.Orientation.HORIZONTAL, spacing=8)
|
||||
input_box.set_hexpand(True)
|
||||
|
||||
self._entry = Gtk.Entry()
|
||||
self._entry.set_hexpand(True)
|
||||
self._entry.set_placeholder_text("Ask a question…")
|
||||
self._entry.connect("activate", self._on_submit)
|
||||
|
||||
self._send_button = Gtk.Button(label="Send")
|
||||
self._send_button.connect("clicked", self._on_submit)
|
||||
|
||||
input_box.append(self._entry)
|
||||
input_box.append(self._send_button)
|
||||
|
||||
main_box.append(header_box)
|
||||
main_box.append(scroller)
|
||||
main_box.append(input_box)
|
||||
|
||||
self.set_child(main_box)
|
||||
|
||||
def _populate_initial_messages(self) -> None:
|
||||
"""Render conversation history stored on disk."""
|
||||
for message in self._conversation_manager.messages:
|
||||
self._append_message(message["role"], message["content"], persist=False)
|
||||
|
||||
if not self._conversation_manager.messages:
|
||||
self._append_message(
|
||||
"assistant",
|
||||
"Welcome! Ask a question to start a conversation.",
|
||||
persist=True,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ helpers
|
||||
def _append_message(
|
||||
self, role: str, content: str, *, persist: bool = True
|
||||
) -> None:
|
||||
"""Add a message bubble to the history and optionally persist it."""
|
||||
label_prefix = "You" if role == "user" else "Assistant"
|
||||
label = Gtk.Label(label=f"{label_prefix}: {content}")
|
||||
label.set_halign(Gtk.Align.START)
|
||||
label.set_xalign(0.0)
|
||||
label.set_wrap(True)
|
||||
label.set_wrap_mode(Gtk.WrapMode.WORD_CHAR)
|
||||
label.set_justify(Gtk.Justification.LEFT)
|
||||
|
||||
self._message_list.append(label)
|
||||
self._scroll_to_bottom()
|
||||
|
||||
if persist:
|
||||
self._conversation_manager.append_message(role, content)
|
||||
|
||||
def _scroll_to_bottom(self) -> None:
|
||||
"""Ensure the most recent message is visible."""
|
||||
def _scroll() -> bool:
|
||||
adjustment = self._scroller.get_vadjustment()
|
||||
if adjustment is not None:
|
||||
adjustment.set_value(adjustment.get_upper() - adjustment.get_page_size())
|
||||
return False
|
||||
|
||||
GLib.idle_add(_scroll)
|
||||
|
||||
def _set_input_enabled(self, enabled: bool) -> None:
|
||||
self._entry.set_sensitive(enabled)
|
||||
self._send_button.set_sensitive(enabled)
|
||||
|
||||
# ------------------------------------------------------------------ callbacks
|
||||
def _on_submit(self, _widget: Gtk.Widget) -> None:
|
||||
"""Handle send button clicks or entry activation."""
|
||||
text = self._entry.get_text().strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
self._entry.set_text("")
|
||||
self._append_message("user", text, persist=True)
|
||||
self._request_response()
|
||||
|
||||
def _request_response(self) -> None:
|
||||
"""Trigger a synchronous Ollama chat call on a worker thread."""
|
||||
model = self._current_model or self._ollama_client.default_model
|
||||
if not model:
|
||||
self._append_message(
|
||||
"assistant",
|
||||
"No Ollama models are available. Install a model to continue.",
|
||||
persist=True,
|
||||
)
|
||||
return
|
||||
|
||||
history = self._conversation_manager.chat_messages
|
||||
self._set_input_enabled(False)
|
||||
|
||||
def _worker(messages: Iterable[dict[str, str]]) -> None:
|
||||
response = self._ollama_client.chat(model=model, messages=list(messages))
|
||||
GLib.idle_add(self._handle_response, response, priority=GLib.PRIORITY_DEFAULT)
|
||||
|
||||
thread = threading.Thread(target=_worker, args=(history,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
def _handle_response(self, response: dict[str, str] | None) -> bool:
|
||||
"""Render the assistant reply and re-enable the entry."""
|
||||
self._set_input_enabled(True)
|
||||
|
||||
if not response:
|
||||
self._append_message(
|
||||
"assistant",
|
||||
"The model returned an empty response.",
|
||||
persist=True,
|
||||
)
|
||||
return False
|
||||
|
||||
role = response.get("role", "assistant")
|
||||
content = response.get("content") or ""
|
||||
if not content:
|
||||
content = "[No content received from Ollama]"
|
||||
|
||||
self._append_message(role, content, persist=True)
|
||||
return False
|
||||
|
||||
34
tests/test_conversation_manager.py
Normal file
34
tests/test_conversation_manager.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from conversation_manager import ConversationManager
|
||||
|
||||
|
||||
def test_conversation_manager_persists_history(tmp_path: Path) -> None:
|
||||
manager = ConversationManager(storage_dir=tmp_path, conversation_id="test")
|
||||
manager.append_message("user", "Hello there!")
|
||||
manager.append_message("assistant", "General Kenobi.")
|
||||
|
||||
conversation_file = tmp_path / "test.json"
|
||||
assert conversation_file.exists()
|
||||
|
||||
data = json.loads(conversation_file.read_text(encoding="utf-8"))
|
||||
assert len(data["messages"]) == 2
|
||||
assert data["messages"][0]["content"] == "Hello there!"
|
||||
|
||||
reloaded = ConversationManager(storage_dir=tmp_path, conversation_id="test")
|
||||
assert [msg["content"] for msg in reloaded.messages] == [
|
||||
"Hello there!",
|
||||
"General Kenobi.",
|
||||
]
|
||||
|
||||
|
||||
def test_conversation_manager_rejects_invalid_role(tmp_path: Path) -> None:
|
||||
manager = ConversationManager(storage_dir=tmp_path, conversation_id="invalid")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
manager.append_message("narrator", "This should fail.")
|
||||
Reference in New Issue
Block a user