"""Conversation state management and persistence helpers."""

from __future__ import annotations

import contextlib
import json
import logging
import os
import tempfile
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import ClassVar, Dict, Iterable, List, MutableMapping

logger = logging.getLogger(__name__)

DEFAULT_CONVERSATION_ID = "default"


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass
class ConversationState:
    """In-memory representation of a conversation transcript."""

    conversation_id: str
    created_at: str  # ISO-8601 timestamp of creation (or of the last clear)
    updated_at: str  # ISO-8601 timestamp of the last mutation
    messages: List[Dict[str, str]] = field(default_factory=list)


class ConversationManager:
    """Load and persist a single conversation transcript as a JSON file.

    The transcript lives at ``<storage_dir>/<conversation_id>.json``. Every
    mutating method rewrites that file via a temp file and ``os.replace``.
    """

    VALID_ROLES: ClassVar[set[str]] = {"system", "user", "assistant"}

    def __init__(
        self,
        storage_dir: str | Path | None = None,
        conversation_id: str | None = None,
    ) -> None:
        """Open the transcript for *conversation_id*, loading it if present.

        Nothing is written until the first mutating call.

        Args:
            storage_dir: Directory holding transcript files. Defaults to
                ``data/conversations`` next to this module; created if missing.
            conversation_id: File stem of the transcript. Defaults to
                ``DEFAULT_CONVERSATION_ID``.
        """
        # NOTE(review): conversation_id is used verbatim as a file name; if it
        # can come from untrusted input, validate it (e.g. reject path
        # separators) so writes stay inside storage_dir.
        module_root = Path(__file__).resolve().parent
        default_storage = module_root / "data" / "conversations"
        self._storage_dir = Path(storage_dir) if storage_dir else default_storage
        self._storage_dir.mkdir(parents=True, exist_ok=True)
        self._conversation_id = conversation_id or DEFAULT_CONVERSATION_ID
        self._path = self._storage_dir / f"{self._conversation_id}.json"
        self._state = self._load_state()

    # ------------------------------------------------------------------ properties
    @property
    def conversation_id(self) -> str:
        """Identifier stored in the transcript (falls back to the file stem)."""
        return self._state.conversation_id

    @property
    def messages(self) -> List[Dict[str, str]]:
        """Return a copy of the transcript; mutating it does not affect state."""
        return [dict(msg) for msg in self._state.messages]

    @property
    def chat_messages(self) -> List[Dict[str, str]]:
        """Return messages formatted for the Ollama chat API (role + content)."""
        return [
            {"role": msg["role"], "content": msg["content"]}
            for msg in self._state.messages
        ]

    # ------------------------------------------------------------------ public API
    def append_message(self, role: str, content: str) -> Dict[str, str]:
        """Append a new message and persist the updated transcript.

        Args:
            role: One of ``VALID_ROLES`` (case-insensitive).
            content: Message text.

        Returns:
            A copy of the stored message, including its UTC ``timestamp``.

        Raises:
            ValueError: If *role* is not one of ``VALID_ROLES``.
        """
        normalized_role = role.lower()
        if normalized_role not in self.VALID_ROLES:
            raise ValueError(
                f"Invalid role '{role}'. Expected one of {self.VALID_ROLES}."
            )
        timestamp = _utc_now_iso()
        message = {
            "role": normalized_role,
            "content": content,
            "timestamp": timestamp,
        }
        self._state.messages.append(message)
        self._state.updated_at = timestamp
        self._write_state()
        # Hand back a copy so callers cannot mutate the stored transcript.
        return dict(message)

    def replace_messages(self, messages: Iterable[Dict[str, str]]) -> None:
        """Replace the transcript contents and persist them. Useful for fixtures.

        Entries that are not mappings, or whose role is not in ``VALID_ROLES``,
        are skipped. Missing timestamps are filled with the current UTC time.
        """
        normalized = [
            msg
            for msg in (self._normalize_message(item) for item in messages)
            if msg is not None
        ]
        now = _utc_now_iso()
        self._state.messages = normalized
        self._state.created_at = self._state.created_at or now
        self._state.updated_at = now
        self._write_state()

    def clear_messages(self) -> None:
        """Clear all messages, reset both timestamps, and persist."""
        self._state = self._fresh_state()
        self._write_state()

    def trim_to_recent(self, keep_count: int = 20) -> List[Dict[str, str]]:
        """Drop all but the *keep_count* most recent messages.

        Args:
            keep_count: Number of recent messages to keep (``0`` removes all).

        Returns:
            The removed (oldest) messages in original order. The list is empty
            when nothing needed trimming; in that case nothing is written.

        Raises:
            ValueError: If *keep_count* is negative.
        """
        if keep_count < 0:
            raise ValueError(f"keep_count must be non-negative, got {keep_count}")
        # Use an explicit split index: slicing with ``-keep_count`` misbehaves
        # for keep_count == 0 because ``-0 == 0``.
        split = len(self._state.messages) - keep_count
        if split <= 0:
            return []
        removed_messages = self._state.messages[:split]
        self._state.messages = self._state.messages[split:]
        self._state.updated_at = _utc_now_iso()
        self._write_state()
        return removed_messages

    # ------------------------------------------------------------------ persistence
    def _fresh_state(self) -> ConversationState:
        """Build an empty state for this conversation stamped with 'now'."""
        timestamp = _utc_now_iso()
        return ConversationState(
            conversation_id=self._conversation_id,
            created_at=timestamp,
            updated_at=timestamp,
            messages=[],
        )

    def _load_state(self) -> ConversationState:
        """Load the transcript from disk, or start a fresh empty one.

        A missing file yields a fresh state. An unreadable file, one that is not
        valid UTF-8 JSON, or one whose top level is not an object also yields a
        fresh state. The problem is logged, and the file is overwritten on the
        next write.
        """
        try:
            with self._path.open("r", encoding="utf-8") as fh:
                payload = json.load(fh)
        except FileNotFoundError:
            return self._fresh_state()
        except (OSError, ValueError) as exc:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(
                "Ignoring unreadable conversation file %s: %s", self._path, exc
            )
            return self._fresh_state()
        if not isinstance(payload, dict):
            logger.warning(
                "Ignoring conversation file %s: top-level JSON is not an object",
                self._path,
            )
            return self._fresh_state()
        return self._state_from_payload(payload)

    def _state_from_payload(self, payload: MutableMapping[str, object]) -> ConversationState:
        """Normalize persisted data into a ConversationState.

        A missing id falls back to the file stem, and missing timestamps fall
        back to 'now' / ``created_at``. Invalid message entries are dropped.
        """
        conversation_id = str(payload.get("id") or self._conversation_id)
        created_at = str(payload.get("created_at") or _utc_now_iso())
        updated_at = str(payload.get("updated_at") or created_at)
        messages_payload = payload.get("messages", [])
        messages: List[Dict[str, str]] = []
        if isinstance(messages_payload, list):
            for item in messages_payload:
                normalized = self._normalize_message(item)
                if normalized is not None:
                    messages.append(normalized)
        return ConversationState(
            conversation_id=conversation_id,
            created_at=created_at,
            updated_at=updated_at,
            messages=messages,
        )

    def _normalize_message(self, item: object) -> Dict[str, str] | None:
        """Coerce one raw entry into ``{"role", "content", "timestamp"}`` strings.

        Returns ``None`` for non-mapping entries and roles outside ``VALID_ROLES``.
        ``None`` content becomes ``""``, and a missing timestamp becomes 'now'.
        """
        if not isinstance(item, Mapping):
            return None
        role = str(item.get("role") or "").lower()
        if role not in self.VALID_ROLES:
            return None
        content = item.get("content")
        timestamp = item.get("timestamp")
        return {
            "role": role,
            "content": "" if content is None else str(content),
            "timestamp": str(timestamp) if timestamp else _utc_now_iso(),
        }

    def _write_state(self) -> None:
        """Persist the conversation state atomically.

        The JSON is written to a temp file in the storage directory and fsynced.
        It then replaces the target in one step via ``os.replace``. If any step
        fails, the temp file is removed and the error propagates.
        """
        payload = {
            "id": self._state.conversation_id,
            "created_at": self._state.created_at,
            "updated_at": self._state.updated_at,
            "messages": self._state.messages,
        }
        tmp_path: str | None = None
        try:
            with tempfile.NamedTemporaryFile(
                "w",
                encoding="utf-8",
                dir=self._storage_dir,
                delete=False,
                prefix=f"{self._conversation_id}.",
                suffix=".tmp",
            ) as tmp_file:
                tmp_path = tmp_file.name
                json.dump(payload, tmp_file, indent=2, ensure_ascii=False)
                tmp_file.flush()
                os.fsync(tmp_file.fileno())
            os.replace(tmp_path, self._path)
            tmp_path = None  # the temp file now *is* the target; nothing to clean
        finally:
            if tmp_path is not None:
                with contextlib.suppress(OSError):
                    os.unlink(tmp_path)