mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-12 08:46:02 +00:00
feat(agent-runner): add artifact store pull APIs
This commit is contained in:
300
src/langbot/pkg/agent/runner/artifact_store.py
Normal file
300
src/langbot/pkg/agent/runner/artifact_store.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Artifact store for managing Host-owned artifacts."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import datetime
|
||||
import typing
|
||||
import uuid
|
||||
import base64
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from ...entity.persistence.artifact import AgentArtifact
|
||||
from ...entity.persistence.bstorage import BinaryStorage
|
||||
|
||||
|
||||
class ArtifactStore:
|
||||
"""Store for AgentArtifact records.
|
||||
|
||||
Handles artifact metadata registration and content retrieval.
|
||||
Actual blob storage is delegated to BinaryStorage or external storage.
|
||||
|
||||
All methods are async and use the provided database engine.
|
||||
"""
|
||||
|
||||
engine: AsyncEngine
|
||||
|
||||
# Hard limits
|
||||
MAX_INLINE_READ_BYTES = 1024 * 1024 # 1MB max for inline base64
|
||||
MAX_RANGE_READ_BYTES = 10 * 1024 * 1024 # 10MB max for range reads
|
||||
|
||||
def __init__(self, engine: AsyncEngine):
|
||||
self.engine = engine
|
||||
self._session_factory = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def register_artifact(
|
||||
self,
|
||||
artifact_id: str | None,
|
||||
artifact_type: str,
|
||||
source: str,
|
||||
storage_key: str | None = None,
|
||||
storage_type: str = 'binary_storage',
|
||||
mime_type: str | None = None,
|
||||
name: str | None = None,
|
||||
size_bytes: int | None = None,
|
||||
sha256: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
run_id: str | None = None,
|
||||
runner_id: str | None = None,
|
||||
bot_id: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
expires_at: datetime.datetime | None = None,
|
||||
metadata: dict[str, typing.Any] | None = None,
|
||||
content: bytes | None = None,
|
||||
) -> str:
|
||||
"""Register a new artifact.
|
||||
|
||||
If content is provided and storage_key is None, stores content
|
||||
in BinaryStorage automatically.
|
||||
|
||||
Args:
|
||||
artifact_id: Unique artifact ID (generated if None)
|
||||
artifact_type: Type of artifact (image, file, voice, tool_result, etc.)
|
||||
source: Source of artifact (platform, runner, tool, system)
|
||||
storage_key: Key in BinaryStorage or external reference
|
||||
storage_type: Storage type (binary_storage, file, url)
|
||||
mime_type: MIME type
|
||||
name: Original file name
|
||||
size_bytes: Size in bytes
|
||||
sha256: SHA256 hash
|
||||
conversation_id: Conversation ID
|
||||
run_id: Run ID that created this
|
||||
runner_id: Runner ID that created this
|
||||
bot_id: Bot UUID
|
||||
workspace_id: Workspace ID
|
||||
expires_at: Expiration time
|
||||
metadata: Additional metadata
|
||||
content: Optional content to store in BinaryStorage
|
||||
|
||||
Returns:
|
||||
The artifact_id
|
||||
"""
|
||||
if artifact_id is None:
|
||||
artifact_id = str(uuid.uuid4())
|
||||
|
||||
# If content provided, store in BinaryStorage
|
||||
if content is not None and storage_key is None:
|
||||
storage_key = f"artifact:{artifact_id}"
|
||||
storage_type = 'binary_storage'
|
||||
if size_bytes is None:
|
||||
size_bytes = len(content)
|
||||
|
||||
async with self._session_factory() as session:
|
||||
# Store content in BinaryStorage if provided
|
||||
if content is not None:
|
||||
binary_storage = BinaryStorage(
|
||||
unique_key=f'artifact:{artifact_id}',
|
||||
key=storage_key,
|
||||
owner_type='artifact',
|
||||
owner='host',
|
||||
value=content,
|
||||
)
|
||||
session.add(binary_storage)
|
||||
|
||||
# Store artifact metadata
|
||||
artifact = AgentArtifact(
|
||||
artifact_id=artifact_id,
|
||||
artifact_type=artifact_type,
|
||||
mime_type=mime_type,
|
||||
name=name,
|
||||
size_bytes=size_bytes,
|
||||
sha256=sha256,
|
||||
source=source,
|
||||
storage_key=storage_key,
|
||||
storage_type=storage_type,
|
||||
conversation_id=conversation_id,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
expires_at=expires_at,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
session.add(artifact)
|
||||
await session.commit()
|
||||
|
||||
return artifact_id
|
||||
|
||||
async def get_metadata(
|
||||
self,
|
||||
artifact_id: str,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Get artifact metadata (public fields only, no internal storage info).
|
||||
|
||||
Args:
|
||||
artifact_id: Artifact ID
|
||||
|
||||
Returns:
|
||||
Artifact metadata dict compatible with SDK ArtifactMetadata, or None if not found
|
||||
"""
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(AgentArtifact).where(
|
||||
AgentArtifact.artifact_id == artifact_id
|
||||
)
|
||||
)
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_public_dict(row)
|
||||
|
||||
async def _get_internal_record(
|
||||
self,
|
||||
artifact_id: str,
|
||||
) -> AgentArtifact | None:
|
||||
"""Get full artifact record including internal fields.
|
||||
|
||||
Used internally by read_artifact to access storage_key/storage_type.
|
||||
|
||||
Args:
|
||||
artifact_id: Artifact ID
|
||||
|
||||
Returns:
|
||||
AgentArtifact ORM instance, or None if not found
|
||||
"""
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(AgentArtifact).where(
|
||||
AgentArtifact.artifact_id == artifact_id
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
artifact_id: str,
|
||||
offset: int = 0,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, typing.Any] | None:
|
||||
"""Read artifact content.
|
||||
|
||||
For small artifacts, returns content_base64 directly.
|
||||
For large artifacts, returns file_key for chunked transfer.
|
||||
|
||||
Args:
|
||||
artifact_id: Artifact ID
|
||||
offset: Byte offset to start reading from (must be >= 0)
|
||||
limit: Maximum bytes to read (must be > 0 if provided)
|
||||
|
||||
Returns:
|
||||
ArtifactReadResult dict, or None if not found
|
||||
|
||||
Raises:
|
||||
ValueError: If offset < 0 or limit <= 0
|
||||
"""
|
||||
# Validate offset and limit
|
||||
if offset < 0:
|
||||
raise ValueError("offset must be >= 0")
|
||||
|
||||
if limit is not None and limit <= 0:
|
||||
raise ValueError("limit must be > 0")
|
||||
|
||||
# Get internal record (includes storage_key/storage_type)
|
||||
record = await self._get_internal_record(artifact_id)
|
||||
if record is None:
|
||||
return None
|
||||
|
||||
storage_type = record.storage_type or 'binary_storage'
|
||||
storage_key = record.storage_key
|
||||
size_bytes = record.size_bytes or 0
|
||||
|
||||
# Cap limit at hard limit
|
||||
if limit is None:
|
||||
limit = self.MAX_INLINE_READ_BYTES
|
||||
limit = min(limit, self.MAX_RANGE_READ_BYTES)
|
||||
|
||||
# For binary_storage, read content
|
||||
if storage_type == 'binary_storage' and storage_key:
|
||||
content = await self._read_binary_storage(storage_key)
|
||||
if content is None:
|
||||
return None
|
||||
|
||||
# Apply offset and limit
|
||||
if offset > 0:
|
||||
content = content[offset:]
|
||||
if limit and len(content) > limit:
|
||||
content = content[:limit]
|
||||
has_more = True
|
||||
else:
|
||||
has_more = False
|
||||
|
||||
return {
|
||||
'artifact_id': artifact_id,
|
||||
'mime_type': record.mime_type,
|
||||
'size_bytes': size_bytes,
|
||||
'offset': offset,
|
||||
'length': len(content),
|
||||
'content_base64': base64.b64encode(content).decode('utf-8'),
|
||||
'file_key': None,
|
||||
'has_more': has_more,
|
||||
}
|
||||
|
||||
# For other storage types, return storage reference
|
||||
# (caller can use file_key for chunked transfer)
|
||||
return {
|
||||
'artifact_id': artifact_id,
|
||||
'mime_type': record.mime_type,
|
||||
'size_bytes': size_bytes,
|
||||
'offset': offset,
|
||||
'length': None,
|
||||
'content_base64': None,
|
||||
'file_key': storage_key,
|
||||
'has_more': False,
|
||||
}
|
||||
|
||||
async def _read_binary_storage(self, key: str) -> bytes | None:
|
||||
"""Read content from BinaryStorage.
|
||||
|
||||
Uses unique_key for isolation to prevent cross-artifact access.
|
||||
|
||||
Args:
|
||||
key: The unique_key used when storing the artifact
|
||||
|
||||
Returns:
|
||||
Content bytes, or None if not found
|
||||
"""
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(BinaryStorage).where(BinaryStorage.unique_key == key)
|
||||
)
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return row.value
|
||||
|
||||
def _row_to_public_dict(self, row: AgentArtifact) -> dict[str, typing.Any]:
|
||||
"""Convert an AgentArtifact row to public dict.
|
||||
|
||||
Returns only fields that match SDK ArtifactMetadata entity.
|
||||
Host-only fields (bot_id, workspace_id, storage_key, storage_type) are excluded.
|
||||
"""
|
||||
return {
|
||||
'artifact_id': row.artifact_id,
|
||||
'artifact_type': row.artifact_type,
|
||||
'mime_type': row.mime_type,
|
||||
'name': row.name,
|
||||
'size_bytes': row.size_bytes,
|
||||
'sha256': row.sha256,
|
||||
'source': row.source,
|
||||
'conversation_id': row.conversation_id,
|
||||
'run_id': row.run_id,
|
||||
'runner_id': row.runner_id,
|
||||
'created_at': int(row.created_at.timestamp()) if row.created_at else None,
|
||||
'expires_at': int(row.expires_at.timestamp()) if row.expires_at else None,
|
||||
'metadata': json.loads(row.metadata_json) if row.metadata_json else {},
|
||||
}
|
||||
@@ -891,11 +891,14 @@ class AgentRunContextBuilder:
|
||||
permissions = descriptor.permissions or {}
|
||||
history_permissions = permissions.get('history', [])
|
||||
event_permissions = permissions.get('events', [])
|
||||
artifact_permissions = permissions.get('artifacts', [])
|
||||
|
||||
history_page_enabled = 'page' in history_permissions and conversation_id is not None
|
||||
history_search_enabled = 'search' in history_permissions and conversation_id is not None
|
||||
event_get_enabled = 'get' in event_permissions
|
||||
event_page_enabled = 'page' in event_permissions and conversation_id is not None
|
||||
artifact_metadata_enabled = 'metadata' in artifact_permissions
|
||||
artifact_read_enabled = 'read' in artifact_permissions
|
||||
|
||||
# Get latest cursor and has_history_before if conversation exists
|
||||
latest_cursor = None
|
||||
@@ -931,8 +934,8 @@ class AgentRunContextBuilder:
|
||||
'history_search': history_search_enabled,
|
||||
'event_get': event_get_enabled,
|
||||
'event_page': event_page_enabled,
|
||||
'artifact_metadata': False, # TODO: Implement artifact store
|
||||
'artifact_read': False,
|
||||
'artifact_metadata': artifact_metadata_enabled,
|
||||
'artifact_read': artifact_read_enabled,
|
||||
'state': True,
|
||||
'storage': True,
|
||||
},
|
||||
|
||||
@@ -7,7 +7,8 @@ import typing
|
||||
import uuid
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from ...entity.persistence.event_log import EventLog
|
||||
from ...entity.persistence.transcript import Transcript
|
||||
@@ -27,6 +28,9 @@ class EventLogStore:
|
||||
|
||||
def __init__(self, engine: AsyncEngine):
|
||||
self.engine = engine
|
||||
self._session_factory = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def append_event(
|
||||
self,
|
||||
@@ -83,32 +87,31 @@ class EventLogStore:
|
||||
if input_summary and len(input_summary) > self.MAX_INPUT_SUMMARY_LENGTH:
|
||||
input_summary = input_summary[:self.MAX_INPUT_SUMMARY_LENGTH - 3] + "..."
|
||||
|
||||
async with self.engine.connect() as conn:
|
||||
await conn.execute(
|
||||
sqlalchemy.insert(EventLog).values(
|
||||
event_id=event_id,
|
||||
event_type=event_type,
|
||||
event_time=event_time,
|
||||
source=source,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=conversation_id,
|
||||
thread_id=thread_id,
|
||||
actor_type=actor_type,
|
||||
actor_id=actor_id,
|
||||
actor_name=actor_name,
|
||||
subject_type=subject_type,
|
||||
subject_id=subject_id,
|
||||
input_summary=input_summary,
|
||||
input_json=json.dumps(input_json) if input_json else None,
|
||||
raw_ref=raw_ref,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
)
|
||||
async with self._session_factory() as session:
|
||||
event = EventLog(
|
||||
event_id=event_id,
|
||||
event_type=event_type,
|
||||
event_time=event_time,
|
||||
source=source,
|
||||
bot_id=bot_id,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=conversation_id,
|
||||
thread_id=thread_id,
|
||||
actor_type=actor_type,
|
||||
actor_id=actor_id,
|
||||
actor_name=actor_name,
|
||||
subject_type=subject_type,
|
||||
subject_id=subject_id,
|
||||
input_summary=input_summary,
|
||||
input_json=json.dumps(input_json) if input_json else None,
|
||||
raw_ref=raw_ref,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
)
|
||||
await conn.commit()
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
|
||||
return event_id
|
||||
|
||||
@@ -124,14 +127,14 @@ class EventLogStore:
|
||||
Returns:
|
||||
Event record as dict, or None if not found
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(EventLog).where(EventLog.event_id == event_id)
|
||||
)
|
||||
row = result.fetchone()
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_dict(row[0])
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def page_events(
|
||||
self,
|
||||
@@ -153,7 +156,7 @@ class EventLogStore:
|
||||
"""
|
||||
limit = min(limit, 100) # Hard cap
|
||||
|
||||
async with self.engine.connect() as conn:
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(EventLog)
|
||||
|
||||
if conversation_id is not None:
|
||||
@@ -167,10 +170,10 @@ class EventLogStore:
|
||||
|
||||
query = query.order_by(EventLog.id.desc()).limit(limit + 1)
|
||||
|
||||
result = await conn.execute(query)
|
||||
rows = result.fetchall()
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
items = [self._row_to_dict(row[0]) for row in rows[:limit]]
|
||||
items = [self._row_to_dict(row) for row in rows[:limit]]
|
||||
has_more = len(rows) > limit
|
||||
next_seq = items[-1]['id'] if items and has_more else None
|
||||
|
||||
@@ -188,17 +191,17 @@ class EventLogStore:
|
||||
Returns:
|
||||
Cursor string (seq number), or None if no events
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(EventLog.id)
|
||||
.where(EventLog.conversation_id == conversation_id)
|
||||
.order_by(EventLog.id.desc())
|
||||
.limit(1)
|
||||
)
|
||||
row = result.fetchone()
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return str(row[0])
|
||||
return str(row)
|
||||
|
||||
async def has_events_before(
|
||||
self,
|
||||
@@ -214,8 +217,8 @@ class EventLogStore:
|
||||
Returns:
|
||||
True if there are events before
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(sqlalchemy.func.count())
|
||||
.select_from(EventLog)
|
||||
.where(
|
||||
|
||||
@@ -125,6 +125,7 @@ class AgentRunOrchestrator:
|
||||
query_id=None, # No query_id in event-first mode
|
||||
plugin_identity=descriptor.get_plugin_id(),
|
||||
resources=resources,
|
||||
permissions=descriptor.permissions or {},
|
||||
conversation_id=event.conversation_id,
|
||||
)
|
||||
|
||||
@@ -222,6 +223,7 @@ class AgentRunOrchestrator:
|
||||
query_id=query.query_id,
|
||||
plugin_identity=descriptor.get_plugin_id(),
|
||||
resources=resources,
|
||||
permissions=descriptor.permissions or {},
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class AgentRunSession(typing.TypedDict):
|
||||
plugin_identity: Plugin identifier (author/name) of the runner
|
||||
conversation_id: Conversation ID for history/event access
|
||||
resources: Authorized resources for this run (from AgentResources)
|
||||
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
|
||||
status: Session status tracking
|
||||
_authorized_ids: Pre-computed authorized resource IDs for O(1) lookup
|
||||
"""
|
||||
@@ -36,6 +37,7 @@ class AgentRunSession(typing.TypedDict):
|
||||
plugin_identity: str # author/name
|
||||
conversation_id: str | None
|
||||
resources: AgentResources
|
||||
permissions: dict[str, list[str]]
|
||||
status: AgentRunSessionStatus
|
||||
_authorized_ids: dict[str, set[str]] # Pre-computed sets for O(1) lookup
|
||||
|
||||
@@ -67,6 +69,7 @@ class AgentRunSessionRegistry:
|
||||
plugin_identity: str,
|
||||
resources: AgentResources,
|
||||
conversation_id: str | None = None,
|
||||
permissions: dict[str, list[str]] | None = None,
|
||||
) -> None:
|
||||
"""Register a new agent run session.
|
||||
|
||||
@@ -77,9 +80,13 @@ class AgentRunSessionRegistry:
|
||||
plugin_identity: Plugin identifier (author/name)
|
||||
resources: Authorized resources for this run
|
||||
conversation_id: Conversation ID for history/event access
|
||||
permissions: Runner permissions from descriptor (artifacts, history, events, etc.)
|
||||
"""
|
||||
now = int(time.time())
|
||||
|
||||
# Normalize permissions to empty dict if None
|
||||
permissions = permissions or {}
|
||||
|
||||
# Pre-compute authorized resource IDs for O(1) lookup
|
||||
authorized_ids: dict[str, set[str]] = {
|
||||
'model': {m.get('model_id') for m in resources.get('models', [])},
|
||||
@@ -95,6 +102,7 @@ class AgentRunSessionRegistry:
|
||||
'plugin_identity': plugin_identity,
|
||||
'conversation_id': conversation_id,
|
||||
'resources': resources,
|
||||
'permissions': permissions,
|
||||
'status': {
|
||||
'started_at': now,
|
||||
'last_activity_at': now,
|
||||
|
||||
@@ -7,7 +7,8 @@ import typing
|
||||
import uuid
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from ...entity.persistence.transcript import Transcript
|
||||
|
||||
@@ -27,6 +28,9 @@ class TranscriptStore:
|
||||
|
||||
def __init__(self, engine: AsyncEngine):
|
||||
self.engine = engine
|
||||
self._session_factory = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def append_transcript(
|
||||
self,
|
||||
@@ -72,26 +76,25 @@ class TranscriptStore:
|
||||
# Get next sequence number for this conversation
|
||||
seq = await self._get_next_seq(conversation_id)
|
||||
|
||||
async with self.engine.connect() as conn:
|
||||
await conn.execute(
|
||||
sqlalchemy.insert(Transcript).values(
|
||||
transcript_id=transcript_id,
|
||||
event_id=event_id,
|
||||
conversation_id=conversation_id,
|
||||
thread_id=thread_id,
|
||||
role=role,
|
||||
item_type=item_type,
|
||||
content=content,
|
||||
content_json=json.dumps(content_json) if content_json else None,
|
||||
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
|
||||
seq=seq,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
async with self._session_factory() as session:
|
||||
item = Transcript(
|
||||
transcript_id=transcript_id,
|
||||
event_id=event_id,
|
||||
conversation_id=conversation_id,
|
||||
thread_id=thread_id,
|
||||
role=role,
|
||||
item_type=item_type,
|
||||
content=content,
|
||||
content_json=json.dumps(content_json) if content_json else None,
|
||||
artifact_refs_json=json.dumps(artifact_refs) if artifact_refs else None,
|
||||
seq=seq,
|
||||
run_id=run_id,
|
||||
runner_id=runner_id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
await conn.commit()
|
||||
session.add(item)
|
||||
await session.commit()
|
||||
|
||||
return transcript_id
|
||||
|
||||
@@ -119,7 +122,7 @@ class TranscriptStore:
|
||||
"""
|
||||
limit = min(limit, self.HARD_LIMIT)
|
||||
|
||||
async with self.engine.connect() as conn:
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(Transcript).where(
|
||||
Transcript.conversation_id == conversation_id
|
||||
)
|
||||
@@ -136,10 +139,10 @@ class TranscriptStore:
|
||||
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await conn.execute(query)
|
||||
rows = result.fetchall()
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
items = [self._row_to_dict(row[0], include_artifacts) for row in rows[:limit]]
|
||||
items = [self._row_to_dict(row, include_artifacts) for row in rows[:limit]]
|
||||
has_more = len(rows) > limit
|
||||
|
||||
# Calculate cursors
|
||||
@@ -179,7 +182,7 @@ class TranscriptStore:
|
||||
Returns:
|
||||
List of matching items
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
async with self._session_factory() as session:
|
||||
query = sqlalchemy.select(Transcript).where(
|
||||
Transcript.conversation_id == conversation_id,
|
||||
Transcript.content.ilike(f"%{query_text}%"),
|
||||
@@ -194,10 +197,10 @@ class TranscriptStore:
|
||||
|
||||
query = query.order_by(Transcript.seq.desc()).limit(top_k)
|
||||
|
||||
result = await conn.execute(query)
|
||||
rows = result.fetchall()
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
return [self._row_to_dict(row[0], include_artifacts=True) for row in rows]
|
||||
return [self._row_to_dict(row, include_artifacts=True) for row in rows]
|
||||
|
||||
async def get_latest_cursor(
|
||||
self,
|
||||
@@ -211,17 +214,17 @@ class TranscriptStore:
|
||||
Returns:
|
||||
Cursor string (seq number), or None if no items
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(Transcript.seq)
|
||||
.where(Transcript.conversation_id == conversation_id)
|
||||
.order_by(Transcript.seq.desc())
|
||||
.limit(1)
|
||||
)
|
||||
row = result.fetchone()
|
||||
row = result.scalars().first()
|
||||
if row is None:
|
||||
return None
|
||||
return str(row[0])
|
||||
return str(row)
|
||||
|
||||
async def has_history_before(
|
||||
self,
|
||||
@@ -237,8 +240,8 @@ class TranscriptStore:
|
||||
Returns:
|
||||
True if there are items before
|
||||
"""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(sqlalchemy.func.count())
|
||||
.select_from(Transcript)
|
||||
.where(
|
||||
@@ -251,8 +254,8 @@ class TranscriptStore:
|
||||
|
||||
async def _get_next_seq(self, conversation_id: str) -> int:
|
||||
"""Get the next sequence number for a conversation."""
|
||||
async with self.engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
sqlalchemy.select(sqlalchemy.func.max(Transcript.seq))
|
||||
.where(Transcript.conversation_id == conversation_id)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user