Files
niri-ai-sidebar/conversation_manager.py
Melvin Ragusa 239242e2fc refactor(aisidebar): restructure project and implement reasoning mode toggle
- Reorganize project structure and file locations
- Add ReasoningController to manage model selection and reasoning mode
- Update design and requirements for reasoning mode toggle
- Implement model switching between Qwen3-4B-Instruct and Qwen3-4B-Thinking models
- Remove deprecated files and consolidate project layout
- Add new steering and specification documentation
- Clean up and remove unnecessary files and directories
- Prepare for enhanced AI sidebar functionality with more flexible model handling
2025-10-26 09:10:31 +01:00

207 lines
7.3 KiB
Python

"""Conversation state management and persistence helpers."""
from __future__ import annotations

import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import ClassVar, Dict, Iterable, List, MutableMapping
# Fallback id used when ConversationManager is given no conversation id (or an
# empty one); it also serves as the transcript file stem ("default.json").
DEFAULT_CONVERSATION_ID: str = "default"
@dataclass()
class ConversationState:
    """Mutable, in-memory snapshot of a single conversation transcript.

    All metadata is kept as plain strings, and every message is a
    ``str -> str`` mapping.
    """

    # Identifier persisted as the transcript's "id" field.
    conversation_id: str
    created_at: str
    updated_at: str
    # default_factory ensures each instance owns a separate, initially empty list.
    messages: list[dict[str, str]] = field(default_factory=list)
class ConversationManager:
    """Load and persist conversation transcripts as JSON files.

    Each transcript lives at ``<storage_dir>/<conversation_id>.json``.
    Mutating methods persist the transcript immediately via ``_write_state``.
    """

    VALID_ROLES: ClassVar[set[str]] = {"system", "user", "assistant"}

    def __init__(
        self,
        storage_dir: str | Path | None = None,
        conversation_id: str | None = None,
    ) -> None:
        """Open the transcript for ``conversation_id``, loading it if present.

        Args:
            storage_dir: Directory holding transcript files. Falsy values select
                ``data/conversations`` next to this module. It is created if
                missing.
            conversation_id: Transcript name, used as the file stem. Falsy
                values select ``DEFAULT_CONVERSATION_ID``.
        """
        module_root = Path(__file__).resolve().parent
        default_storage = module_root / "data" / "conversations"
        self._storage_dir = Path(storage_dir) if storage_dir else default_storage
        self._storage_dir.mkdir(parents=True, exist_ok=True)
        self._conversation_id = conversation_id or DEFAULT_CONVERSATION_ID
        # NOTE(review): the id is used verbatim as a file name, so values with
        # path separators would resolve outside storage_dir. Confirm that
        # callers only pass trusted ids.
        self._path = self._storage_dir / f"{self._conversation_id}.json"
        self._state = self._load_state()

    # ------------------------------------------------------------------ properties
    @property
    def conversation_id(self) -> str:
        """Return the id recorded in the transcript (the file stem if absent)."""
        return self._state.conversation_id

    @property
    def messages(self) -> List[Dict[str, str]]:
        """Return a copy of the transcript messages (role, content, timestamp).

        Both the list and each message dict are copied, so mutating the result
        cannot silently alter the manager's state.
        """
        return [dict(message) for message in self._state.messages]

    @property
    def chat_messages(self) -> List[Dict[str, str]]:
        """Return messages formatted for the Ollama chat API (no timestamps)."""
        return [
            {"role": message["role"], "content": message["content"]}
            for message in self._state.messages
        ]

    # ------------------------------------------------------------------ public API
    def append_message(self, role: str, content: str) -> Dict[str, str]:
        """Append a new message and persist the updated transcript.

        Args:
            role: One of ``VALID_ROLES``; matched case-insensitively.
            content: Message text.

        Returns:
            A copy of the stored message, including its UTC ISO-8601 timestamp.

        Raises:
            ValueError: If ``role`` is not a valid role.
        """
        normalized_role = role.lower()
        if normalized_role not in self.VALID_ROLES:
            raise ValueError(f"Invalid role '{role}'. Expected one of {self.VALID_ROLES}.")
        timestamp = self._now()
        message = {
            "role": normalized_role,
            "content": content,
            "timestamp": timestamp,
        }
        self._state.messages.append(message)
        self._state.updated_at = timestamp
        self._write_state()
        # Return a copy so callers cannot mutate the stored message in place.
        return dict(message)

    def replace_messages(self, messages: Iterable[Dict[str, str]]) -> None:
        """Replace the transcript contents. Useful for loading fixtures.

        Entries are normalized exactly like persisted data. Non-dict entries
        and entries with an unknown role are dropped. Missing timestamps are
        set to the current UTC time, and missing/None content becomes ``""``.
        """
        normalized = [
            message
            for message in map(self._normalize_message, messages)
            if message is not None
        ]
        now = self._now()
        self._state.messages = normalized
        self._state.created_at = self._state.created_at or now
        self._state.updated_at = now
        self._write_state()

    def clear_messages(self) -> None:
        """Clear all messages, reset both timestamps to now, and persist."""
        self._state = self._fresh_state()
        self._write_state()

    def trim_to_recent(self, keep_count: int = 20) -> List[Dict[str, str]]:
        """Trim the conversation to its most recent messages.

        Args:
            keep_count: Number of most recent messages to keep. ``0`` removes
                every message.

        Returns:
            The removed (older) messages, oldest first. The list is empty when
            nothing was trimmed, in which case nothing is written to disk.

        Raises:
            ValueError: If ``keep_count`` is negative.
        """
        if keep_count < 0:
            raise ValueError(f"keep_count must be non-negative, got {keep_count}.")
        messages = self._state.messages
        if len(messages) <= keep_count:
            return []
        # Split on an explicit index. The slices [:-keep_count] / [-keep_count:]
        # break for keep_count == 0 because -0 == 0.
        split = len(messages) - keep_count
        removed_messages = messages[:split]
        self._state.messages = messages[split:]
        self._state.updated_at = self._now()
        self._write_state()
        return removed_messages

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _now() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @classmethod
    def _normalize_message(cls, item: object) -> Dict[str, str] | None:
        """Coerce one raw message into a ``{"role", "content", "timestamp"}`` dict.

        Returns ``None`` for entries that should be skipped: non-dicts and
        dicts whose (lower-cased) role is not in ``VALID_ROLES``.
        """
        if not isinstance(item, dict):
            return None
        role = str(item.get("role") or "").lower()
        if role not in cls.VALID_ROLES:
            return None
        content = item.get("content")
        return {
            "role": role,
            # Missing/None content becomes "" instead of the string "None".
            "content": "" if content is None else str(content),
            "timestamp": str(item.get("timestamp") or cls._now()),
        }

    def _fresh_state(self) -> ConversationState:
        """Build an empty transcript stamped with the current UTC time."""
        timestamp = self._now()
        return ConversationState(
            conversation_id=self._conversation_id,
            created_at=timestamp,
            updated_at=timestamp,
            messages=[],
        )

    # ------------------------------------------------------------------ persistence
    def _load_state(self) -> ConversationState:
        """Load the transcript from disk, or start a fresh one.

        A missing file silently yields an empty transcript. An unreadable,
        non-UTF-8, malformed, or non-object JSON file also yields an empty
        transcript (best effort), but a warning is logged because the next
        write will overwrite that file.
        """
        try:
            with self._path.open("r", encoding="utf-8") as fh:
                payload = json.load(fh)
        except FileNotFoundError:
            return self._fresh_state()
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logging.getLogger(__name__).warning(
                "Ignoring unreadable conversation file %s: %s", self._path, exc
            )
            return self._fresh_state()
        if isinstance(payload, dict):
            return self._state_from_payload(payload)
        logging.getLogger(__name__).warning(
            "Ignoring conversation file %s: top-level JSON value is not an object",
            self._path,
        )
        return self._fresh_state()

    def _state_from_payload(self, payload: MutableMapping[str, object]) -> ConversationState:
        """Normalize a decoded JSON object into a ConversationState.

        Missing ids and timestamps fall back to the manager's id and the
        current time. Malformed or unknown-role messages are dropped.
        """
        created_at = str(payload.get("created_at") or self._now())
        messages_payload = payload.get("messages", [])
        if not isinstance(messages_payload, list):
            messages_payload = []
        messages = [
            message
            for message in map(self._normalize_message, messages_payload)
            if message is not None
        ]
        return ConversationState(
            conversation_id=str(payload.get("id") or self._conversation_id),
            created_at=created_at,
            updated_at=str(payload.get("updated_at") or created_at),
            messages=messages,
        )

    def _write_state(self) -> None:
        """Persist the conversation state atomically.

        The JSON is written and fsync'd to a temporary file in the storage
        directory, then moved over the target with ``os.replace``. If any step
        fails, the temporary file is removed and the exception is re-raised.
        """
        payload = {
            "id": self._state.conversation_id,
            "created_at": self._state.created_at,
            "updated_at": self._state.updated_at,
            "messages": self._state.messages,
        }
        tmp_file = tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            dir=self._storage_dir,
            delete=False,
            prefix=f"{self._conversation_id}.",
            suffix=".tmp",
        )
        try:
            with tmp_file:
                json.dump(payload, tmp_file, indent=2, ensure_ascii=False)
                tmp_file.flush()
                os.fsync(tmp_file.fileno())
            # Rename only after the handle is closed (required on Windows).
            os.replace(tmp_file.name, self._path)
        except BaseException:
            # delete=False means nothing else removes a failed temp file.
            try:
                os.unlink(tmp_file.name)
            except OSError:
                pass
            raise