Files
niri-ai-sidebar/aisidebar/conversation_manager.py
Melvin Ragusa 58bd935af0 feat(aisidebar): implement Ollama availability handling and graceful startup
- Add comprehensive Ollama connection error handling strategy
- Implement OllamaClient with non-blocking initialization and connection checks
- Create OllamaAvailabilityMonitor for periodic Ollama connection tracking
- Update design and requirements to support graceful Ollama unavailability
- Add new project structure for AI sidebar module with initial implementation
- Enhance error handling to prevent application crashes when Ollama is not running
- Prepare for future improvements in AI sidebar interaction and resilience
2025-10-25 22:28:54 +02:00

174 lines
6.1 KiB
Python

"""Conversation state management and persistence helpers."""
from __future__ import annotations
import json
import os
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import ClassVar, Dict, Iterable, List, MutableMapping
# Conversation name used when ConversationManager receives no explicit id;
# it also becomes the transcript file stem ("default.json").
DEFAULT_CONVERSATION_ID: str = "default"
@dataclass()
class ConversationState:
    """Mutable in-memory snapshot of a single conversation transcript.

    The fields correspond to the keys of the persisted JSON document
    (``conversation_id`` is stored under ``id``). When ``messages`` is not
    supplied, each instance receives its own fresh empty list.
    """

    # Identifier of the conversation this transcript belongs to.
    conversation_id: str
    # Creation / last-modification times, kept as strings
    # (the manager writes ISO-8601 UTC values here).
    created_at: str
    updated_at: str
    # Chat turns in order; each entry maps string keys to string values.
    messages: list[dict[str, str]] = field(default_factory=list)
class ConversationManager:
    """Load and persist conversation transcripts as JSON files.

    One transcript lives at ``<storage_dir>/<conversation_id>.json``. It is
    read once on construction, kept in memory, and rewritten in full after
    every mutation via a temporary file that is swapped in with
    ``os.replace``.
    """

    VALID_ROLES: ClassVar[set[str]] = {"system", "user", "assistant"}

    def __init__(
        self,
        storage_dir: str | Path | None = None,
        conversation_id: str | None = None,
    ) -> None:
        """Open (or start) the transcript for a conversation.

        Args:
            storage_dir: Directory holding transcript files. When omitted (or
                empty), ``data/conversations`` next to this module is used.
                The directory is created if it does not exist.
            conversation_id: Transcript name, also used as the file stem.
                Falls back to ``DEFAULT_CONVERSATION_ID`` when omitted/empty.
        """
        module_root = Path(__file__).resolve().parent
        default_storage = module_root / "data" / "conversations"
        self._storage_dir = Path(storage_dir) if storage_dir else default_storage
        self._storage_dir.mkdir(parents=True, exist_ok=True)
        self._conversation_id = conversation_id or DEFAULT_CONVERSATION_ID
        # NOTE(review): the id is used verbatim as a file name; an id containing
        # path separators would resolve outside storage_dir.
        self._path = self._storage_dir / f"{self._conversation_id}.json"
        self._state = self._load_state()

    # ------------------------------------------------------------------ properties
    @property
    def conversation_id(self) -> str:
        """Identifier stored in the transcript (the file's ``id`` field if present)."""
        return self._state.conversation_id

    @property
    def messages(self) -> List[Dict[str, str]]:
        """Return copies of all messages, including their timestamps.

        Each dict is copied so callers cannot mutate the in-memory transcript
        behind the manager's back (such edits would never be persisted).
        """
        return [dict(message) for message in self._state.messages]

    @property
    def chat_messages(self) -> List[Dict[str, str]]:
        """Return messages formatted for the Ollama chat API (role/content only)."""
        return [
            {"role": msg["role"], "content": msg["content"]}
            for msg in self._state.messages
        ]

    # ------------------------------------------------------------------ public API
    def append_message(self, role: str, content: str) -> Dict[str, str]:
        """Append a new message and persist the updated transcript.

        Args:
            role: One of ``VALID_ROLES``; matched case-insensitively and
                stored lower-cased.
            content: Message text.

        Returns:
            A copy of the stored message (``role``, ``content``, ``timestamp``).

        Raises:
            ValueError: If ``role`` is not a valid role.
            OSError: If the transcript cannot be written.
        """
        normalized_role = role.lower()
        if normalized_role not in self.VALID_ROLES:
            raise ValueError(f"Invalid role '{role}'. Expected one of {self.VALID_ROLES}.")
        timestamp = self._utc_timestamp()
        message = {
            "role": normalized_role,
            "content": content,
            "timestamp": timestamp,
        }
        self._state.messages.append(message)
        self._state.updated_at = timestamp
        self._write_state()
        # Hand back a copy so later edits by the caller don't leak into the transcript.
        return dict(message)

    def replace_messages(self, messages: Iterable[Dict[str, str]]) -> None:
        """Replace the transcript contents and persist them. Useful for loading fixtures.

        Entries that are not mappings, or whose role is not in ``VALID_ROLES``,
        are dropped. The remaining entries are normalized exactly like messages
        loaded from disk (see ``_normalize_messages``). ``created_at`` is kept if
        already set; ``updated_at`` becomes the current time.
        """
        normalized = self._normalize_messages(messages)
        now = self._utc_timestamp()
        self._state.messages = normalized
        self._state.created_at = self._state.created_at or now
        self._state.updated_at = now
        self._write_state()

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @classmethod
    def _normalize_messages(cls, items: Iterable[object]) -> List[Dict[str, str]]:
        """Validate raw message records, dropping the unusable ones.

        Shared by ``replace_messages`` and ``_state_from_payload`` so that both
        code paths apply identical rules:

        * non-mapping items are skipped;
        * the role is lower-cased and the item skipped unless it is in
          ``VALID_ROLES``;
        * content is coerced to ``str`` (a missing/None content becomes "");
        * a missing/empty timestamp is replaced by the current UTC time.
        """
        from collections.abc import Mapping

        normalized: List[Dict[str, str]] = []
        for item in items:
            if not isinstance(item, Mapping):
                continue
            # `or ""` keeps a None role from turning into the string "none".
            role = str(item.get("role") or "").lower()
            if role not in cls.VALID_ROLES:
                continue
            content = item.get("content")
            timestamp = item.get("timestamp")
            normalized.append(
                {
                    "role": role,
                    "content": "" if content is None else str(content),
                    "timestamp": str(timestamp) if timestamp else cls._utc_timestamp(),
                }
            )
        return normalized

    # ------------------------------------------------------------------ persistence
    def _load_state(self) -> ConversationState:
        """Load the transcript from disk, or start a fresh one.

        A missing or unreadable file yields an empty transcript. A file that
        exists but does not hold a JSON object (invalid JSON, invalid UTF-8,
        or a JSON value of another type) is first moved aside to
        ``<name>.corrupt`` so the next write does not silently destroy it.
        """
        payload: object = None
        try:
            with self._path.open("r", encoding="utf-8") as fh:
                payload = json.load(fh)
        except OSError:
            # Missing or unreadable file: nothing to preserve, start fresh.
            pass
        except ValueError:
            # Both json.JSONDecodeError and UnicodeDecodeError subclass ValueError.
            self._preserve_corrupt_file()
        else:
            if isinstance(payload, dict):
                return self._state_from_payload(payload)
            # Valid JSON but not an object (e.g. a list): unusable as a transcript.
            self._preserve_corrupt_file()

        timestamp = self._utc_timestamp()
        return ConversationState(
            conversation_id=self._conversation_id,
            created_at=timestamp,
            updated_at=timestamp,
            messages=[],
        )

    def _preserve_corrupt_file(self) -> None:
        """Move an unparseable transcript aside (best effort) so it is not overwritten."""
        backup = self._path.with_name(f"{self._path.name}.corrupt")
        try:
            os.replace(self._path, backup)
        except OSError:
            # Deliberately best-effort: failing to back up must not block startup.
            pass

    def _state_from_payload(self, payload: MutableMapping[str, object]) -> ConversationState:
        """Normalize persisted data into a ConversationState instance.

        Missing ids fall back to this manager's conversation id; a missing
        ``created_at`` becomes "now" and a missing ``updated_at`` copies
        ``created_at``. ``messages`` is used only when it is a list.
        """
        created_at = str(payload.get("created_at") or self._utc_timestamp())
        messages_payload = payload.get("messages", [])
        messages = (
            self._normalize_messages(messages_payload)
            if isinstance(messages_payload, list)
            else []
        )
        return ConversationState(
            conversation_id=str(payload.get("id") or self._conversation_id),
            created_at=created_at,
            updated_at=str(payload.get("updated_at") or created_at),
            messages=messages,
        )

    def _write_state(self) -> None:
        """Persist the conversation state atomically.

        The JSON is written and fsync'd to a temporary file in the storage
        directory (same filesystem as the target), which then replaces the
        transcript via ``os.replace``. On any failure the temporary file is
        removed and the exception propagates.
        """
        payload = {
            "id": self._state.conversation_id,
            "created_at": self._state.created_at,
            "updated_at": self._state.updated_at,
            "messages": self._state.messages,
        }
        tmp_file = tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            dir=self._storage_dir,
            delete=False,
            prefix=f"{self._conversation_id}.",
            suffix=".tmp",
        )
        try:
            with tmp_file:
                json.dump(payload, tmp_file, indent=2, ensure_ascii=False)
                tmp_file.flush()
                os.fsync(tmp_file.fileno())
            os.replace(tmp_file.name, self._path)
        except BaseException:
            # delete=False means nothing else cleans this up; don't leave
            # orphaned *.tmp files behind after a failed write.
            try:
                os.unlink(tmp_file.name)
            except OSError:
                pass
            raise