"""Conversation state management and persistence helpers."""
|
|
|
|
from __future__ import annotations

import contextlib
import json
import logging
import os
import tempfile
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import ClassVar, Dict, Iterable, List, MutableMapping
|
|
|
|
# Conversation id used when the caller does not supply one; ConversationManager
# also uses it as the stem of the backing JSON file ("default.json").
DEFAULT_CONVERSATION_ID: str = "default"
|
|
|
|
|
|
@dataclass()
class ConversationState:
    """Mutable, in-memory snapshot of one conversation transcript.

    Timestamps are held as plain strings; ConversationManager writes
    ISO-8601 UTC values into them, but no parsing happens here.
    """

    # Identifier of the conversation (persisted under the "id" key).
    conversation_id: str
    # Timestamp string recorded when the transcript was first created.
    created_at: str
    # Timestamp string of the most recent modification.
    updated_at: str
    # Ordered transcript entries; every instance gets its own fresh list.
    messages: List[Dict[str, str]] = field(default_factory=list)
|
|
|
|
|
|
class ConversationManager:
    """Load and persist a single conversation transcript as a JSON file.

    The transcript for ``conversation_id`` lives at
    ``<storage_dir>/<conversation_id>.json``. Every mutation rewrites the
    whole file atomically (temp file, fsync, then ``os.replace``). If that
    write fails, the in-memory state is rolled back so memory and disk agree.
    """

    # Roles accepted by append_message / replace_messages and when loading.
    VALID_ROLES: ClassVar[set[str]] = {"system", "user", "assistant"}

    def __init__(
        self,
        storage_dir: str | Path | None = None,
        conversation_id: str | None = None,
    ) -> None:
        """Bind the manager to one conversation and load it from disk.

        Args:
            storage_dir: Directory holding transcript files. When falsy, a
                ``data/conversations`` directory next to this module is used.
                The directory is created if it does not exist.
            conversation_id: Transcript identifier, also used as the file
                stem. Falls back to ``DEFAULT_CONVERSATION_ID`` when falsy.
        """
        module_root = Path(__file__).resolve().parent
        default_storage = module_root / "data" / "conversations"
        self._storage_dir = Path(storage_dir) if storage_dir else default_storage
        self._storage_dir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): the id is used verbatim as a file name, so a value such
        # as "../other" would escape storage_dir. Validate it upstream if ids
        # can come from untrusted input.
        self._conversation_id = conversation_id or DEFAULT_CONVERSATION_ID
        self._path = self._storage_dir / f"{self._conversation_id}.json"

        self._state = self._load_state()

    # ------------------------------------------------------------------ properties
    @property
    def conversation_id(self) -> str:
        """Return the conversation id recorded in the loaded state.

        This is the ``id`` stored in the file when present, which may differ
        from the id used to locate the file.
        """
        return self._state.conversation_id

    @property
    def messages(self) -> List[Dict[str, str]]:
        """Return a copy of the transcript (role, content and timestamp).

        Both the list and each message dict are copied, so callers cannot
        mutate the manager's state through the returned value.
        """
        return [dict(message) for message in self._state.messages]

    @property
    def chat_messages(self) -> List[Dict[str, str]]:
        """Return messages formatted for the Ollama chat API (no timestamps)."""
        return [
            {"role": msg["role"], "content": msg["content"]}
            for msg in self._state.messages
        ]

    # ------------------------------------------------------------------ public API
    def append_message(self, role: str, content: str) -> Dict[str, str]:
        """Append a message and persist the updated transcript.

        Args:
            role: Message role, matched case-insensitively against VALID_ROLES.
            content: Message text, stored as given.

        Returns:
            A copy of the stored message, including its UTC timestamp.

        Raises:
            ValueError: If ``role`` is not one of VALID_ROLES.
            Exception: Whatever writing the transcript raises (e.g. OSError,
                or TypeError for non-JSON-serializable content). The message
                is rolled back out of memory in that case.
        """
        normalized_role = role.lower()
        if normalized_role not in self.VALID_ROLES:
            # Sorted so the message is stable across runs (set order is not).
            expected = ", ".join(sorted(self.VALID_ROLES))
            raise ValueError(f"Invalid role '{role}'. Expected one of {expected}.")

        timestamp = self._utc_now()
        message = {
            "role": normalized_role,
            "content": content,
            "timestamp": timestamp,
        }

        previous_updated_at = self._state.updated_at
        self._state.messages.append(message)
        self._state.updated_at = timestamp
        try:
            self._write_state()
        except BaseException:
            # Roll back so the in-memory transcript still matches the file.
            self._state.messages.pop()
            self._state.updated_at = previous_updated_at
            raise
        return dict(message)

    def replace_messages(self, messages: Iterable[Dict[str, str]]) -> None:
        """Replace the whole transcript and persist it. Useful for fixtures.

        Entries are normalized exactly as when loading from disk:
        - Non-mapping entries are dropped.
        - Entries whose role is outside VALID_ROLES are dropped.
        - Content and timestamp are coerced to strings.
        - A missing timestamp is filled with the current UTC time.

        ``created_at`` is kept (or set if empty), and ``updated_at`` becomes
        the current time. If the write fails, the previous state is restored
        and the exception propagates.
        """
        normalized = self._normalize_messages(messages)
        now = self._utc_now()

        previous = (
            self._state.messages,
            self._state.created_at,
            self._state.updated_at,
        )
        self._state.messages = normalized
        self._state.created_at = self._state.created_at or now
        self._state.updated_at = now
        try:
            self._write_state()
        except BaseException:
            # Roll back so the in-memory transcript still matches the file.
            (
                self._state.messages,
                self._state.created_at,
                self._state.updated_at,
            ) = previous
            raise

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _utc_now() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @classmethod
    def _normalize_messages(cls, items: Iterable[object]) -> List[Dict[str, str]]:
        """Coerce raw entries into ``{"role", "content", "timestamp"}`` dicts.

        Shared by replace_messages and _state_from_payload so both paths
        accept and reject exactly the same entries. Non-mapping entries and
        entries with a role outside VALID_ROLES are skipped.
        """
        normalized: List[Dict[str, str]] = []
        for item in items:
            if not isinstance(item, Mapping):
                continue
            # `or ""` so a null role is rejected instead of crashing .lower().
            role = str(item.get("role") or "").lower()
            if role not in cls.VALID_ROLES:
                continue
            content = item.get("content")
            timestamp = item.get("timestamp")
            normalized.append(
                {
                    "role": role,
                    # A null/missing content becomes "" rather than "None".
                    "content": "" if content is None else str(content),
                    "timestamp": str(timestamp) if timestamp else cls._utc_now(),
                }
            )
        return normalized

    # ------------------------------------------------------------------ persistence
    def _load_state(self) -> ConversationState:
        """Load the transcript from disk, or start a fresh one.

        A missing file yields an empty conversation. Some files are unusable:
        they cannot be read, are not valid UTF-8/JSON, or do not contain a JSON
        object. Such a file is logged and renamed aside before starting fresh,
        so the next write does not silently destroy the old data.
        """
        try:
            with self._path.open("r", encoding="utf-8") as fh:
                payload = json.load(fh)
        except FileNotFoundError:
            return self._fresh_state()
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            self._set_aside_unreadable_file(f"could not be read ({exc})")
            return self._fresh_state()

        if not isinstance(payload, dict):
            self._set_aside_unreadable_file("does not contain a JSON object")
            return self._fresh_state()
        return self._state_from_payload(payload)

    def _fresh_state(self) -> ConversationState:
        """Return an empty transcript for this manager's id, stamped now."""
        timestamp = self._utc_now()
        return ConversationState(
            conversation_id=self._conversation_id,
            created_at=timestamp,
            updated_at=timestamp,
            messages=[],
        )

    def _set_aside_unreadable_file(self, reason: str) -> None:
        """Log an unusable transcript and, if it is a file, rename it aside.

        The backup is named ``<name>.corrupt-<UTC stamp>``. Best effort: a
        failed rename is logged, never raised, so loading still falls back to
        a fresh conversation.
        """
        logger = logging.getLogger(__name__)
        if not self._path.is_file():
            logger.warning(
                "Transcript %s %s; starting a new conversation.", self._path, reason
            )
            return

        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%fZ")
        backup = self._path.with_name(f"{self._path.name}.corrupt-{stamp}")
        try:
            self._path.replace(backup)
        except OSError:
            logger.warning(
                "Transcript %s %s and could not be moved aside; "
                "starting a new conversation.",
                self._path,
                reason,
                exc_info=True,
            )
        else:
            logger.warning(
                "Transcript %s %s; moved it to %s and started a new conversation.",
                self._path,
                reason,
                backup,
            )

    def _state_from_payload(self, payload: MutableMapping[str, object]) -> ConversationState:
        """Normalize a decoded JSON object into a ConversationState.

        Missing fields fall back as follows:
        - ``id`` falls back to the manager's id.
        - ``created_at`` falls back to the current time.
        - ``updated_at`` falls back to ``created_at``.

        A non-list ``messages`` value is treated as an empty transcript.
        """
        conversation_id = str(payload.get("id") or self._conversation_id)
        created_at = str(payload.get("created_at") or self._utc_now())
        updated_at = str(payload.get("updated_at") or created_at)

        raw_messages = payload.get("messages")
        messages = (
            self._normalize_messages(raw_messages)
            if isinstance(raw_messages, list)
            else []
        )

        return ConversationState(
            conversation_id=conversation_id,
            created_at=created_at,
            updated_at=updated_at,
            messages=messages,
        )

    def _write_state(self) -> None:
        """Persist the conversation state atomically.

        The JSON is written to a temp file in the storage directory, flushed
        and fsync'd, then moved over the transcript with ``os.replace``. The
        temp file is removed if any step fails.
        """
        payload = {
            "id": self._state.conversation_id,
            "created_at": self._state.created_at,
            "updated_at": self._state.updated_at,
            "messages": self._state.messages,
        }

        tmp_file = tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            dir=self._storage_dir,
            delete=False,
            prefix=f"{self._conversation_id}.",
            suffix=".tmp",
        )
        try:
            # Close before os.replace; replacing an open file fails on Windows.
            with tmp_file:
                json.dump(payload, tmp_file, indent=2, ensure_ascii=False)
                tmp_file.flush()
                os.fsync(tmp_file.fileno())
            os.replace(tmp_file.name, self._path)
        except BaseException:
            # delete=False means nothing else will clean up the temp file.
            with contextlib.suppress(OSError):
                os.unlink(tmp_file.name)
            raise
|