Files
LangBot/src/langbot/pkg/agent/runner/transcript_store.py

342 lines
12 KiB
Python

"""Transcript store for writing and querying conversation history."""
from __future__ import annotations
import json
import datetime
import typing
import uuid
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.orm import sessionmaker
from ...entity.persistence.transcript import Transcript
from langbot_plugin.api.entities.builtin.provider import message as provider_message
class TranscriptStore:
    """Store for Transcript records.

    Handles writing transcript items and querying them for the history API.
    Every database-facing method is async and opens its own short-lived
    session on the engine passed to the constructor.
    """

    engine: AsyncEngine

    # Hard limits
    MAX_CONTENT_LENGTH = 4000  # max characters kept in the plain-text ``content`` field
    HARD_LIMIT = 100  # max items a single page_transcript() call may return

    # Escape character used when user text is embedded in a LIKE pattern.
    _LIKE_ESCAPE = "/"

    def __init__(self, engine: AsyncEngine):
        """Create a store bound to ``engine``.

        Args:
            engine: Async SQLAlchemy engine used for every session.
        """
        self.engine = engine
        # expire_on_commit=False keeps attributes readable after commit
        # without triggering another round trip.
        self._session_factory = sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )

    @staticmethod
    def _escape_like(text: str, escape_char: str = "/") -> str:
        """Escape LIKE wildcards in ``text`` so it is matched literally.

        Args:
            text: Raw user-supplied search text.
            escape_char: Character that will be declared as the LIKE escape.

        Returns:
            ``text`` with ``escape_char``, ``%`` and ``_`` prefixed by
            ``escape_char``.
        """
        # The escape character itself must be doubled first, otherwise the
        # escapes added for % and _ would be escaped again.
        return (
            text.replace(escape_char, escape_char * 2)
            .replace("%", escape_char + "%")
            .replace("_", escape_char + "_")
        )

    async def append_transcript(
        self,
        transcript_id: str | None,
        event_id: str,
        conversation_id: str,
        role: str,
        content: str | None = None,
        content_json: dict[str, typing.Any] | None = None,
        artifact_refs: list[dict[str, typing.Any]] | None = None,
        thread_id: str | None = None,
        item_type: str = "message",
        run_id: str | None = None,
        runner_id: str | None = None,
        metadata: dict[str, typing.Any] | None = None,
    ) -> str:
        """Append a transcript item.

        Args:
            transcript_id: Unique transcript ID (generated if None)
            event_id: Source event ID
            conversation_id: Conversation ID
            role: Message role (user, assistant, system, tool)
            content: Text content; truncated to MAX_CONTENT_LENGTH characters
            content_json: Full structured content
            artifact_refs: Artifact references
            thread_id: Thread ID
            item_type: Item type
            run_id: Run ID that generated this
            runner_id: Runner ID that generated this
            metadata: Additional metadata

        Returns:
            The transcript_id
        """
        if transcript_id is None:
            transcript_id = str(uuid.uuid4())

        # Truncate content if too long; the trailing "..." is counted so the
        # stored text never exceeds MAX_CONTENT_LENGTH characters.
        if content and len(content) > self.MAX_CONTENT_LENGTH:
            content = content[: self.MAX_CONTENT_LENGTH - 3] + "..."

        # Same naive-UTC value utcnow() produced, without the deprecated call.
        created_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        async with self._session_factory() as session:
            item = Transcript(
                transcript_id=transcript_id,
                event_id=event_id,
                conversation_id=conversation_id,
                thread_id=thread_id,
                role=role,
                item_type=item_type,
                content=content,
                # Empty containers are stored as NULL, like missing values.
                content_json=json.dumps(content_json) if content_json else None,
                artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
                seq=0,  # placeholder until the row has been flushed
                run_id=run_id,
                runner_id=runner_id,
                created_at=created_at,
                metadata_json=json.dumps(metadata) if metadata else None,
            )
            session.add(item)
            # Flush so item.id is populated before it is reused as seq.
            await session.flush()
            item.seq = item.id or await self._get_next_seq(conversation_id)
            await session.commit()
        return transcript_id

    async def page_transcript(
        self,
        conversation_id: str,
        before_seq: int | None = None,
        after_seq: int | None = None,
        limit: int = 50,
        direction: str = "backward",
        include_artifacts: bool = False,
    ) -> tuple[list[dict[str, typing.Any]], int | None, int | None, bool]:
        """Page through transcript items.

        Args:
            conversation_id: Conversation ID
            before_seq: Get items before this sequence (backward)
            after_seq: Get items after this sequence (forward)
            limit: Maximum items to return (clamped to 0..HARD_LIMIT)
            direction: 'backward' (older, newest first) or 'forward'
                (newer, oldest first). Any other value returns the most
                recent items, newest first, ignoring both cursors.
            include_artifacts: Include artifact refs

        Returns:
            Tuple of (items, next_seq, prev_seq, has_more). ``next_seq`` is
            the seq of the last returned item when more items exist in the
            paging direction, otherwise None; ``prev_seq`` is the seq of the
            first returned item, or None when the page is empty.
        """
        # A negative limit would otherwise reach SQL as a negative/zero LIMIT
        # and could report has_more=True for an empty result.
        limit = max(0, min(limit, self.HARD_LIMIT))

        async with self._session_factory() as session:
            query = sqlalchemy.select(Transcript).where(
                Transcript.conversation_id == conversation_id
            )
            if direction == "forward":
                # Forward paging is always ascending; without a cursor it
                # starts at the oldest item so the next page does not overlap.
                if after_seq is not None:
                    query = query.where(Transcript.seq > after_seq)
                query = query.order_by(Transcript.seq.asc())
            else:
                if direction == "backward" and before_seq is not None:
                    query = query.where(Transcript.seq < before_seq)
                # Default: most recent items first (backward from latest)
                query = query.order_by(Transcript.seq.desc())

            # Fetch one extra row to learn whether another page exists.
            query = query.limit(limit + 1)
            result = await session.execute(query)
            rows = result.scalars().all()

        items = [self._row_to_dict(row, include_artifacts) for row in rows[:limit]]
        has_more = len(rows) > limit

        # Cursors are positional, so the same rule serves both orderings.
        next_seq = items[-1]['seq'] if items and has_more else None
        prev_seq = items[0]['seq'] if items else None
        return items, next_seq, prev_seq, has_more

    async def search_transcript(
        self,
        conversation_id: str,
        query_text: str,
        filters: dict[str, typing.Any] | None = None,
        top_k: int = 10,
    ) -> list[dict[str, typing.Any]]:
        """Search transcript items.

        Basic implementation using case-insensitive LIKE filtering on the
        text content. ``%`` and ``_`` in ``query_text`` are matched
        literally, not as wildcards.

        Args:
            conversation_id: Conversation ID
            query_text: Search query (substring match)
            filters: Optional filters; recognised keys are 'roles' and
                'item_types', each a collection of allowed values
            top_k: Maximum results

        Returns:
            List of matching items, highest seq first, with artifact refs
        """
        pattern = f"%{self._escape_like(query_text, self._LIKE_ESCAPE)}%"

        async with self._session_factory() as session:
            query = sqlalchemy.select(Transcript).where(
                Transcript.conversation_id == conversation_id,
                Transcript.content.ilike(pattern, escape=self._LIKE_ESCAPE),
            )
            # Apply additional filters
            if filters:
                if 'roles' in filters:
                    query = query.where(Transcript.role.in_(filters['roles']))
                if 'item_types' in filters:
                    query = query.where(Transcript.item_type.in_(filters['item_types']))
            query = query.order_by(Transcript.seq.desc()).limit(top_k)
            result = await session.execute(query)
            rows = result.scalars().all()

        return [self._row_to_dict(row, include_artifacts=True) for row in rows]

    async def get_latest_cursor(
        self,
        conversation_id: str,
    ) -> str | None:
        """Get the latest cursor for a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            Cursor string (highest seq number), or None if no items
        """
        async with self._session_factory() as session:
            result = await session.execute(
                sqlalchemy.select(Transcript.seq)
                .where(Transcript.conversation_id == conversation_id)
                .order_by(Transcript.seq.desc())
                .limit(1)
            )
            latest_seq = result.scalars().first()

        if latest_seq is None:
            return None
        return str(latest_seq)

    async def get_legacy_provider_messages(
        self,
        conversation_id: str,
        limit: int = HARD_LIMIT,
    ) -> list[provider_message.Message]:
        """Project Transcript rows into the legacy provider Message view.

        AgentRunner history is canonical in Transcript. This view exists for
        legacy Pipeline readers such as PromptPreProcessing that still expect
        query.messages.

        Args:
            conversation_id: Conversation ID
            limit: Maximum number of most-recent transcript items to read

        Returns:
            Messages in chronological order (oldest first). Items that are
            not user/assistant messages, or that have no usable content,
            are skipped.
        """
        items, _, _, _ = await self.page_transcript(
            conversation_id=conversation_id,
            limit=limit,
            direction="backward",
        )
        # The page is newest-first; reverse it to get chronological order.
        converted = (
            self._transcript_item_to_provider_message(item)
            for item in reversed(items)
        )
        return [message for message in converted if message is not None]

    async def has_history_before(
        self,
        conversation_id: str,
        seq: int,
    ) -> bool:
        """Check if there is history before a sequence number.

        Args:
            conversation_id: Conversation ID
            seq: Sequence number

        Returns:
            True if there are items before
        """
        async with self._session_factory() as session:
            # One matching row is enough; no need to count them all.
            result = await session.execute(
                sqlalchemy.select(Transcript.seq)
                .where(
                    Transcript.conversation_id == conversation_id,
                    Transcript.seq < seq,
                )
                .limit(1)
            )
            return result.scalars().first() is not None

    async def _get_next_seq(self, conversation_id: str) -> int:
        """Fallback next sequence number for stores that cannot expose autoincrement IDs.

        Args:
            conversation_id: Conversation ID

        Returns:
            One more than the highest seq in the conversation, or 1 if the
            conversation has no items (or only items with seq 0).
        """
        async with self._session_factory() as session:
            result = await session.execute(
                sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))
                .where(Transcript.conversation_id == conversation_id)
            )
            max_seq = result.scalar()
        return (max_seq or 0) + 1

    def _row_to_dict(
        self,
        row: Transcript,
        include_artifacts: bool = False,
    ) -> dict[str, typing.Any]:
        """Convert a Transcript row to dict.

        Args:
            row: Transcript row to convert
            include_artifacts: When True, decode and include artifact refs;
                otherwise 'artifact_refs' is an empty list

        Returns:
            Dict with the row's fields. JSON columns are decoded and
            'created_at' is an integer Unix timestamp (or None).
        """
        created_at = row.created_at
        if created_at is None:
            created_ts = None
        else:
            # Rows are written as naive UTC. A naive datetime's timestamp()
            # assumes local time, so attach UTC explicitly first.
            if created_at.tzinfo is None:
                created_at = created_at.replace(tzinfo=datetime.timezone.utc)
            created_ts = int(created_at.timestamp())

        result = {
            'transcript_id': row.transcript_id,
            'event_id': row.event_id,
            'conversation_id': row.conversation_id,
            'thread_id': row.thread_id,
            'role': row.role,
            'item_type': row.item_type,
            'content': row.content,
            'content_json': json.loads(row.content_json) if row.content_json else None,
            'seq': row.seq,
            'cursor': str(row.seq),
            'created_at': created_ts,
            'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
        }
        if include_artifacts and row.artifact_refs_json:
            result['artifact_refs'] = json.loads(row.artifact_refs_json)
        else:
            result['artifact_refs'] = []
        return result

    def _transcript_item_to_provider_message(
        self,
        item: dict[str, typing.Any],
    ) -> provider_message.Message | None:
        """Convert one Transcript API item into a provider Message.

        Args:
            item: Dict in the shape produced by _row_to_dict

        Returns:
            A provider Message, or None when the item is not a
            user/assistant message or has no usable content.
        """
        if item.get('item_type') != 'message':
            return None

        role = item.get('role')
        if role not in {'user', 'assistant'}:
            return None

        content_json = item.get('content_json')
        if isinstance(content_json, dict):
            # Prefer the structured payload, forcing the role from the row.
            message_data = dict(content_json)
            message_data['role'] = role
            try:
                return provider_message.Message.model_validate(message_data)
            except Exception:
                # Deliberate best effort: if the structured payload does not
                # validate, fall back to the plain-text content below.
                pass

        content = item.get('content')
        if content is None:
            return None
        return provider_message.Message(role=role, content=content)