196 lines
6.8 KiB
Python
196 lines
6.8 KiB
Python
"""
GUARDiA persistent memory engine — Mem0-style, pgvector-based

Endpoints:
  POST /api/memory/remember       — store a memory
  GET  /api/memory/recall         — semantic search
  GET  /api/memory/context/{sid}  — session context
  POST /api/memory/forget         — delete a memory
  GET  /api/memory/stats          — memory statistics
  POST /api/memory/consolidate    — short-term → long-term consolidation
"""
|
|
from __future__ import annotations
|
|
import json, logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional
|
|
import httpx
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select, desc, func, text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from core.auth import get_current_user
|
|
from database import get_db
|
|
from models import User, AgentMemory, AgentSession
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Every route declared in this file is mounted under /api/memory.
router = APIRouter(prefix="/api/memory", tags=["영구 메모리"])

# Base URL of the Ollama server that _embed() posts to.
OLLAMA_URL = "http://localhost:11434"

# Model name sent in the embeddings request payload.
EMBED_MODEL = "nomic-embed-text"

# Length of the zero-vector fallback returned when embedding fails.
# NOTE(review): presumably equals the model's output dimension — confirm.
EMBED_DIM = 768
|
|
|
|
|
|
async def _embed(text_: str) -> list[float]:
    """Return the embedding of ``text_`` produced by the Ollama embeddings API.

    This is a best-effort helper: it never raises. On any failure — transport
    error, non-2xx HTTP status, unparsable body, or a response without a
    usable ``embedding`` value — it logs a warning and returns a zero vector
    of length ``EMBED_DIM``.

    Args:
        text_: The text to embed.

    Returns:
        The embedding reported by the server, or ``[0.0] * EMBED_DIM`` when
        no embedding could be obtained.
    """
    try:
        async with httpx.AsyncClient(timeout=15) as client:
            resp = await client.post(
                f"{OLLAMA_URL}/api/embeddings",
                json={"model": EMBED_MODEL, "prompt": text_},
            )
            # Without this check an error response (which carries no
            # "embedding" key) would silently turn into the zero vector
            # with nothing in the logs.
            resp.raise_for_status()
            embedding = resp.json().get("embedding")
            if not embedding:
                raise ValueError("response contained no embedding")
            return embedding
    except Exception as e:
        # Deliberately broad: callers rely on this helper never raising.
        logger.warning("임베딩 실패: %s", e)
        return [0.0] * EMBED_DIM
|
|
|
|
|
|
class MemoryIn(BaseModel):
    """Request body accepted by the POST /remember handler."""

    # Text to store; the handler also passes this text to _embed().
    content: str
    # Free-form string; the values named below are not enforced by this model.
    memory_type: str = "EPISODIC"  # EPISODIC|SEMANTIC|PROCEDURAL
    # Optional session identifier stored alongside the memory.
    session_id: Optional[str] = None
    # Arbitrary extra data; the handler serialises it with json.dumps().
    metadata: dict = {}
    # Lifetime in days; the handler treats any falsy value as "no expiry".
    ttl_days: Optional[int] = None  # None = never expires
|
|
|
|
|
|
class RecallQuery(BaseModel):
    """Search parameters for a recall request.

    NOTE(review): nothing in the visible part of this file references this
    model — the GET /recall handler reads its inputs from query parameters
    instead. Confirm whether it is used elsewhere.
    """

    # Search text.
    query: str
    # Maximum number of results.
    limit: int = 5
    # Optional memory-type filter.
    memory_type: Optional[str] = None
    # NOTE(review): presumably a lower bound on a memory's confidence — confirm.
    min_confidence: float = 0.0
|
|
|
|
|
|
@router.post("/remember", status_code=201)
async def remember(body: MemoryIn, db: AsyncSession = Depends(get_db),
                   user: User = Depends(get_current_user)):
    """Embed the submitted text and persist it as a new ``AgentMemory`` row.

    The row starts with ``confidence`` 0.5 and ``access_count`` 0 and is
    attributed to the authenticated user. When ``ttl_days`` is truthy the
    row expires that many days from now; otherwise it has no expiry.

    Returns:
        A dict with the new row's ``memory_id`` and the submitted ``type``.
    """
    vector = await _embed(body.content)

    # A falsy ttl_days (None or 0) means the memory never expires.
    if body.ttl_days:
        expiry = datetime.utcnow() + timedelta(days=body.ttl_days)
    else:
        expiry = None

    fields = {
        "session_id": body.session_id,
        "memory_type": body.memory_type,
        "content": body.content,
        # Stored as JSON text (used here as an alternative to pgvector).
        "embedding": json.dumps(vector),
        "metadata_json": json.dumps(body.metadata),
        "confidence": 0.5,
        "access_count": 0,
        "created_by": user.id,
        "created_at": datetime.utcnow(),
        "expires_at": expiry,
    }
    record = AgentMemory(**fields)

    db.add(record)
    await db.commit()
    # Reload the row so database-generated values such as the id are available.
    await db.refresh(record)

    return {"memory_id": record.id, "type": body.memory_type}
|
|
|
|
|
|
def _cosine_sim(a, b):
    """Return the cosine similarity of two numeric vectors.

    Returns 0.0 when either vector is empty or when their lengths differ:
    ``zip`` would otherwise truncate to the shorter vector and yield a
    meaningless score.
    """
    if not a or not b or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    # The small epsilon avoids division by zero for all-zero vectors.
    return dot / (norm_a * norm_b + 1e-8)


@router.get("/recall")
async def recall(
    q: str = Query(..., description="검색 쿼리"),
    limit: int = Query(5, ge=0, le=20),
    memory_type: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return the stored memories most similar to the query text.

    Candidate rows are the non-expired memories (optionally restricted to
    ``memory_type``) with the highest ``access_count``, capped at
    ``limit * 3`` rows. Similarity is then computed in Python against the
    JSON-encoded embedding of each candidate, so memories outside that
    candidate window are never scored.

    Each returned memory has its ``access_count`` incremented, and the
    change is committed before the response is sent.

    Args:
        q: Search text; it is embedded with ``_embed``.
        limit: Maximum number of results (0-20). Negative values are
            rejected because they would reach the SQL LIMIT and the result
            slice unchecked.
        memory_type: Optional exact-match filter on the memory type.

    Returns:
        A list of dicts ordered by descending similarity.
    """
    query_emb = await _embed(q)

    stmt = select(AgentMemory).where(
        AgentMemory.expires_at.is_(None) | (AgentMemory.expires_at > datetime.utcnow())
    )
    if memory_type:
        stmt = stmt.where(AgentMemory.memory_type == memory_type)
    stmt = stmt.order_by(desc(AgentMemory.access_count)).limit(limit * 3)

    memories = (await db.execute(stmt)).scalars().all()

    # Score every candidate in Python. A row whose stored embedding cannot
    # be parsed or scored is skipped, but logged so that data corruption
    # is visible.
    scored = []
    for m in memories:
        try:
            emb = json.loads(m.embedding or "[]")
            sim = _cosine_sim(query_emb, emb)
        except (ValueError, TypeError, ArithmeticError) as exc:
            logger.warning(
                "Skipping memory %s during recall: unusable embedding (%s)",
                m.id, exc,
            )
            continue
        scored.append((sim, m))
    scored.sort(key=lambda pair: pair[0], reverse=True)

    # Bump the access counters inline; the commit below runs within this
    # request, not in a background task. The response dicts are built
    # before the commit so no attribute is read from the ORM objects
    # afterwards.
    result = []
    for sim, m in scored[:limit]:
        m.access_count = (m.access_count or 0) + 1
        result.append({
            "id": m.id, "content": m.content, "type": m.memory_type,
            "similarity": round(sim, 3), "confidence": m.confidence,
            "access_count": m.access_count, "created_at": m.created_at,
        })
    await db.commit()
    return result
|
|
|
|
|
|
@router.get("/context/{session_id}")
async def get_context(session_id: str, db: AsyncSession = Depends(get_db),
                      user: User = Depends(get_current_user)):
    """Restore the stored context of a session together with its recent memories.

    Looks up the ``AgentSession`` row for ``session_id``. If none exists, an
    empty context is returned. Otherwise the response carries the decoded
    ``context_json``, the ten most recently created non-expired memories
    of that session, and the session's ``last_active`` value.

    Expired memories are excluded with the same condition the recall
    endpoint uses, so both endpoints agree on which memories are visible.
    Both response variants carry the same keys.
    """
    row = await db.execute(
        select(AgentSession).where(AgentSession.session_id == session_id)
    )
    session = row.scalar_one_or_none()
    if not session:
        return {
            "session_id": session_id,
            "context": {},
            "memories": [],
            "last_active": None,
        }

    # Memories attached to this session, newest first, skipping expired rows.
    mems = await db.execute(
        select(AgentMemory)
        .where(
            AgentMemory.session_id == session_id,
            AgentMemory.expires_at.is_(None)
            | (AgentMemory.expires_at > datetime.utcnow()),
        )
        .order_by(desc(AgentMemory.created_at))
        .limit(10)
    )
    return {
        "session_id": session_id,
        "context": json.loads(session.context_json or "{}"),
        "memories": [{"content": m.content, "type": m.memory_type}
                     for m in mems.scalars().all()],
        "last_active": session.last_active,
    }
|
|
|
|
|
|
@router.post("/forget")
async def forget(memory_id: int, db: AsyncSession = Depends(get_db),
                 user: User = Depends(get_current_user)):
    """Delete the memory whose primary key is ``memory_id``.

    Raises:
        HTTPException: 404 when no memory has that id.

    NOTE(review): the lookup does not filter on the caller, so any
    authenticated user can delete any memory — confirm this is intended.
    """
    lookup = select(AgentMemory).where(AgentMemory.id == memory_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(404)

    await db.delete(target)
    await db.commit()
    return {"ok": True}
|
|
|
|
|
|
@router.get("/stats")
async def memory_stats(db: AsyncSession = Depends(get_db), user: User = Depends(get_current_user)):
    """Report how many memories are stored, in total and per known type.

    A single GROUP BY query replaces one COUNT per type plus a separate
    total, so every figure comes from the same statement.

    Returns:
        A dict with ``total_memories`` (all rows, including types not
        listed in ``by_type``), ``by_type`` (counts for EPISODIC, SEMANTIC
        and PROCEDURAL, 0 when a type has no rows) and ``embed_model``.
    """
    rows = await db.execute(
        select(AgentMemory.memory_type, func.count(AgentMemory.id))
        .group_by(AgentMemory.memory_type)
    )
    counts = {mem_type: cnt for mem_type, cnt in rows.all()}

    by_type = {
        mt: counts.get(mt, 0)
        for mt in ("EPISODIC", "SEMANTIC", "PROCEDURAL")
    }
    return {
        # Summing every group (including unlisted or NULL types) equals COUNT(id).
        "total_memories": sum(counts.values()),
        "by_type": by_type,
        "embed_model": EMBED_MODEL,
    }
|
|
|
|
|
|
@router.post("/consolidate")
async def consolidate(db: AsyncSession = Depends(get_db), user: User = Depends(get_current_user)):
    """Delete expired memories that are old and rarely accessed.

    A memory is removed when it was created more than 30 days ago, has been
    accessed fewer than 2 times, and its ``expires_at`` is in the past. At
    most 50 rows are removed per call.

    The expiry condition is part of the query itself. Applying it in Python
    after ``LIMIT 50`` meant that a batch of old rows with no expiry could
    fill the limit on every call, so expired rows behind them were never
    reached.

    Returns:
        A dict with the number of deleted rows and a summary message.
    """
    # One timestamp for both conditions keeps the cutoff and the expiry
    # test consistent with each other.
    now = datetime.utcnow()
    cutoff = now - timedelta(days=30)

    rows = await db.execute(
        select(AgentMemory).where(
            AgentMemory.created_at < cutoff,
            AgentMemory.access_count < 2,
            # A NULL expires_at never satisfies this comparison, so
            # memories without an expiry are left alone.
            AgentMemory.expires_at < now,
        ).limit(50)
    )
    expired = rows.scalars().all()

    for m in expired:
        await db.delete(m)
    await db.commit()

    count = len(expired)
    return {"consolidated": count, "message": f"{count}개 만료 기억 정리됨"}
|