304 lines
12 KiB
Python
304 lines
12 KiB
Python
"""
|
|
RAG 엔진 (Retrieval-Augmented Generation) — 기존 KB 키워드 검색 고도화
|
|
|
|
기존 kb.py의 단순 키워드 매칭을 하이브리드 검색으로 업그레이드:
|
|
1. 키워드 기반 TF 스코어링 (BM25 근사, 애플리케이션 내 토큰 매칭)
|
|
2. 시맨틱 유사도 (pgvector 코사인 거리)
|
|
3. RRF(Reciprocal Rank Fusion)로 두 결과 결합
|
|
4. Ollama 최종 생성 응답
|
|
|
|
엔드포인트:
|
|
POST /api/rag/search — 하이브리드 RAG 검색
|
|
POST /api/rag/ask — 자연어 질문 → Ollama 답변 생성
|
|
POST /api/rag/feedback — 검색 결과 피드백 (품질 개선용)
|
|
GET /api/rag/stats — RAG 사용 통계
|
|
"""
|
|
from __future__ import annotations

import json
import logging
import os
import re
from collections import Counter
from datetime import datetime, timezone
from typing import List, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select, func, text, desc
from sqlalchemy.ext.asyncio import AsyncSession

from core.auth import get_current_user
from database import get_db
from models import (
    KBDocument, SRRequest, User,
    RAGFeedback,  # 신규 모델
)
|
|
|
|
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/rag", tags=["RAG Engine"])

# Ollama connection settings. Each value can be overridden with an environment
# variable. The defaults are the previously hard-coded values, so behaviour is
# unchanged when nothing is set. A trailing "/" is stripped from the URL because
# request paths such as "/api/generate" are appended to it.
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434").rstrip("/")
EMBED_MODEL = os.environ.get("RAG_EMBED_MODEL", "nomic-embed-text")  # embedding model name
CHAT_MODEL = os.environ.get("RAG_CHAT_MODEL", "llama3")  # generation model name
|
|
|
|
|
|
# ── Pydantic 스키마 ──────────────────────────────────────────────────────────
|
|
|
|
class RAGSearchRequest(BaseModel):
    # Request body for POST /api/rag/search.
    # Free-text query, 2-500 characters, required.
    query: str = Field(default=..., min_length=2, max_length=500)
    # Number of fused results to return (1-20).
    top_k: int = Field(default=5, ge=1, le=20)
    # When true, SR history is searched in addition to KB documents.
    include_sr: bool = Field(default=True)
|
|
|
|
class RAGAskRequest(BaseModel):
    # Request body for POST /api/rag/ask.
    # Natural-language question, 5-1000 characters, required.
    question: str = Field(default=..., min_length=5, max_length=1000)
    # How many retrieved documents to use as answer context (1-10).
    context_k: int = Field(default=3, ge=1, le=10)
|
|
|
|
class RAGFeedbackRequest(BaseModel):
    # Request body for POST /api/rag/feedback.
    # The query the feedback refers to.
    query: str
    # Optional id of the rated result (presumably a RAGResult.doc_id; verify with clients).
    doc_id: Optional[int] = Field(default=None)
    # Quality rating from 1 (poor) to 5 (good), required.
    rating: int = Field(default=..., ge=1, le=5)
    # Optional free-text comment.
    comment: Optional[str] = Field(default=None)
|
|
|
|
class RAGResult(BaseModel):
    # One entry of the /api/rag/search response.
    doc_id: int  # document id; the search code encodes SR entries as negative ids
    title: str
    excerpt: str  # short text excerpt of the matched document
    score: float  # ranking score of the entry
    source: str  # "kb" | "sr"
    tags: List[str] = Field(default=[])
|
|
|
|
|
|
# ── 유틸: 임베딩 생성 ────────────────────────────────────────────────────────
|
|
|
|
async def _embed(text: str) -> Optional[list[float]]:
    """Create an embedding for ``text`` with the Ollama embedding model.

    Posts ``{"model": EMBED_MODEL, "prompt": text}`` to
    ``{OLLAMA_URL}/api/embeddings`` with a 10 s timeout.

    Returns:
        The ``"embedding"`` field of the JSON reply, or ``None`` when the request
        fails, the server answers with a non-200 status, or the field is absent.
        This is best-effort: failures are logged, never raised.
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.post(
                f"{OLLAMA_URL}/api/embeddings",
                json={"model": EMBED_MODEL, "prompt": text},
            )
            if resp.status_code != 200:
                # Previously dropped silently; log it so a missing model or
                # server error is visible to operators.
                logger.warning("임베딩 생성 실패: HTTP %s", resp.status_code)
                return None
            return resp.json().get("embedding")
    except Exception as e:  # best-effort: caller falls back to keyword search
        logger.warning("임베딩 생성 실패: %s", e)
    return None
|
|
|
|
|
|
# Tokens dropped by _tokenize. Single-character entries are already removed by
# the length filter; they are kept here for parity with kb.py.
_STOPWORDS = frozenset({"이", "가", "을", "를", "의", "에", "the", "a", "an", "is"})

# Delimiters for _tokenize. "?" is included so that question text ("...down?")
# yields the same token as plain text ("down").
_TOKEN_SPLIT_RE = re.compile(r'[\s,;:.?(){}\[\]<>/\\|&!@#$%^*+=~`\-\'\"]+')


def _tokenize(text: str) -> list[str]:
    """Split ``text`` into lowercase search tokens for keyword scoring.

    Splits on whitespace and punctuation, then drops tokens shorter than two
    characters and the stopwords in ``_STOPWORDS``. The pattern follows the
    one in kb.py.

    Returns:
        Tokens in order of appearance; duplicates are kept.
    """
    tokens = _TOKEN_SPLIT_RE.split(text.lower())
    return [t for t in tokens if len(t) >= 2 and t not in _STOPWORDS]
|
|
|
|
|
|
def _rrf_merge(keyword_results: list, semantic_results: list, k: int = 60) -> list[dict]:
    """Fuse two ranked result lists with Reciprocal Rank Fusion (RRF).

    A document at 1-based position ``p`` in a list gains ``1 / (k + p)``. The
    gains from both lists are summed into ``rrf_score``:
    ``score = 1/(k + rank_keyword) + 1/(k + rank_semantic)``.

    Args:
        keyword_results: Ranked dicts, each carrying a ``"doc_id"`` key.
        semantic_results: Ranked dicts, each carrying a ``"doc_id"`` key.
        k: RRF damping constant.

    Returns:
        Shallow copies of the first dict seen for each ``doc_id`` (keyword list
        first), each extended with ``rrf_score``. Sorted by ``rrf_score``
        descending; ties keep first-seen order. Input dicts are not mutated.
    """
    fused: dict[int, dict] = {}
    for ranking in (keyword_results, semantic_results):
        for position, entry in enumerate(ranking, start=1):
            merged = fused.setdefault(entry["doc_id"], {**entry, "rrf_score": 0.0})
            merged["rrf_score"] += 1.0 / (k + position)
    return sorted(fused.values(), key=lambda e: e["rrf_score"], reverse=True)
|
|
|
|
|
|
# ── 엔드포인트 ───────────────────────────────────────────────────────────────
|
|
|
|
def _parse_tags(raw) -> list[str]:
    """Decode a stored tag value into ``list[str]`` without raising.

    ``raw`` is expected to be a JSON-encoded array string. An already-decoded
    list is accepted as-is. Empty/NULL values, malformed JSON, or a non-array
    value yield ``[]``, so one bad row cannot fail the whole search.
    """
    if not raw:
        return []
    if isinstance(raw, list):
        parsed = raw
    else:
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            logger.warning("tags 파싱 실패, 무시함: %r", raw)
            return []
    if not isinstance(parsed, list):
        return []
    return [str(tag) for tag in parsed if tag is not None]


def _term_hits(query_tokens: list[str], blob: str) -> tuple[int, int]:
    """Return ``(hits, total)`` for ``blob``.

    ``hits`` is the sum, over every query token (duplicates included), of how
    often that token occurs in ``blob``. ``total`` is the number of tokens in
    ``blob``.
    """
    doc_tokens = _tokenize(blob)
    freq = Counter(doc_tokens)
    return sum(freq[t] for t in query_tokens), len(doc_tokens)


async def _keyword_kb_hits(db: AsyncSession, query_tokens: list[str]) -> list[dict]:
    """Score KB documents by term frequency normalized by document length."""
    # NOTE(review): there is no ORDER BY, so the database picks which 200
    # documents are scanned; documents outside that window never keyword-match.
    result = await db.execute(select(KBDocument).limit(200))
    hits: list[dict] = []
    for doc in result.scalars().all():
        hit, total = _term_hits(
            query_tokens, f"{doc.title or ''} {doc.symptom or ''} {doc.solution or ''}"
        )
        if hit:  # hit > 0 implies total > 0
            hits.append({
                "doc_id": doc.id,
                "title": doc.title or "",  # RAGResult.title is a required str
                "excerpt": (doc.symptom or doc.solution or "")[:150],
                "score": hit / total,
                "source": "kb",
                "tags": _parse_tags(doc.tags),
            })
    return hits


async def _semantic_kb_hits(db: AsyncSession, query: str, limit: int) -> list[dict]:
    """Return up to ``limit`` KB documents closest to ``query`` by pgvector cosine distance.

    Best-effort: returns ``[]`` when no embedding is available or the vector
    query fails. The failure is logged and the search falls back to keywords.
    """
    embedding = await _embed(query)
    if not embedding:
        return []
    vec_str = "[" + ",".join(str(x) for x in embedding) + "]"
    try:
        # Run inside a SAVEPOINT. On PostgreSQL a failed statement aborts the
        # whole transaction, which would make the later SR query on this same
        # session fail. Rolling back only the savepoint keeps the session usable.
        async with db.begin_nested():
            raw = await db.execute(
                text("""
                    SELECT id, title, symptom, solution, tags,
                           (embedding <=> CAST(:vec AS vector)) AS distance
                    FROM tb_kb_document
                    WHERE embedding IS NOT NULL
                    ORDER BY embedding <=> CAST(:vec AS vector)
                    LIMIT :lim
                """),
                {"vec": vec_str, "lim": limit},
            )
            rows = raw.fetchall()
    except Exception as e:
        logger.warning("pgvector 검색 실패 (키워드만 사용): %s", e)
        return []
    return [
        {
            "doc_id": row.id,
            "title": row.title or "",
            "excerpt": (row.symptom or row.solution or "")[:150],
            # cosine distance: lower is more similar; map to a similarity >= 0
            "score": max(0.0, 1.0 - row.distance),
            "source": "kb",
            "tags": _parse_tags(row.tags),
        }
        for row in rows
    ]


async def _sr_hits(db: AsyncSession, query_tokens: list[str]) -> list[dict]:
    """Keyword-score the 100 most recently updated SRs with status "DONE"."""
    result = await db.execute(
        select(SRRequest)
        .where(SRRequest.status == "DONE")
        .order_by(desc(SRRequest.updated_at))
        .limit(100)
    )
    hits: list[dict] = []
    for sr in result.scalars().all():
        hit, total = _term_hits(query_tokens, f"{sr.title or ''} {sr.description or ''}")
        if hit:
            hits.append({
                "doc_id": -(sr.id),  # negative id marks an SR entry
                "title": f"[SR-{sr.id}] {sr.title}" if sr.title else f"[SR-{sr.id}]",
                "excerpt": (sr.description or "")[:150],
                "score": hit / total,
                "source": "sr",
                "tags": [sr.category] if sr.category else [],
            })
    return hits


@router.post("/search", response_model=List[RAGResult])
async def hybrid_search(
    req: RAGSearchRequest,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),  # authentication gate only
):
    """Hybrid search: keyword (TF, BM25 approximation) + vector (semantic), fused with RRF.

    1. Keyword-score KB documents (and, if ``req.include_sr``, completed SRs).
    2. Rank KB documents by pgvector cosine distance to the query embedding.
    3. Fuse both rankings with Reciprocal Rank Fusion.

    Returns:
        At most ``req.top_k`` ``RAGResult`` items, best first, with the RRF
        score rounded to 4 decimals.
    """
    query_tokens = _tokenize(req.query)
    candidate_k = req.top_k * 2  # candidates per ranking fed into RRF

    # ── 1. Keyword search over KB documents ──
    keyword_hits = await _keyword_kb_hits(db, query_tokens) if query_tokens else []

    # ── 2. Semantic search (pgvector) ──
    semantic_hits = await _semantic_kb_hits(db, req.query, candidate_k)

    # ── 3. SR history (keyword) ──
    if req.include_sr and query_tokens:
        keyword_hits.extend(await _sr_hits(db, query_tokens))

    # Rank KB and SR keyword matches together before truncating, so SR matches
    # compete on relevance instead of always ranking behind every KB match.
    keyword_hits = sorted(keyword_hits, key=lambda h: h["score"], reverse=True)[:candidate_k]

    # ── 4. RRF fusion ──
    merged = _rrf_merge(keyword_hits, semantic_hits)
    return [
        RAGResult(
            doc_id=r["doc_id"],
            title=r["title"],
            excerpt=r["excerpt"],
            score=round(r.get("rrf_score", r.get("score", 0.0)), 4),
            source=r["source"],
            tags=r.get("tags", []),
        )
        for r in merged[:req.top_k]
    ]
|
|
|
|
|
|
@router.post("/ask")
async def rag_ask(
    req: RAGAskRequest,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Answer a natural-language question using retrieved context and Ollama.

    Runs ``hybrid_search`` for ``req.context_k`` documents, builds a prompt from
    their titles and excerpts, and calls Ollama ``/api/generate`` with
    ``stream=False`` and a 30 s timeout. Ollama failures are logged, and the
    placeholder answer "Ollama 응답 실패" is returned instead of an error.

    Returns:
        dict with ``question``, ``answer``, ``sources`` (id/title/source of each
        context document) and ``model``.
    """
    # 1. Collect context with hybrid search.
    # RAGSearchRequest.query allows at most 500 characters, while a question may
    # be up to 1000. Passing a longer question raised a ValidationError (HTTP 500),
    # so the search query is truncated; the full question still goes to the LLM.
    search_req = RAGSearchRequest(query=req.question[:500], top_k=req.context_k)
    results = await hybrid_search(search_req, db, user)

    context = "\n\n".join(
        f"[{r.source.upper()} {r.doc_id}] {r.title}\n{r.excerpt}" for r in results
    ) or "관련 문서를 찾지 못했습니다."

    # 2. Build the grounded prompt.
    system_prompt = (
        "당신은 GUARDiA ITSM 운영 어시스턴트입니다. "
        "아래 문서만 참조하여 간결하고 정확한 한국어 답변을 제공하세요. "
        "문서에 없는 내용은 추측하지 마세요."
    )
    user_prompt = f"질문: {req.question}\n\n참조 문서:\n{context}"

    # 3. Call Ollama (best-effort: keep the placeholder answer on failure).
    answer = "Ollama 응답 실패"
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(
                f"{OLLAMA_URL}/api/generate",
                json={
                    "model": CHAT_MODEL,
                    "system": system_prompt,
                    "prompt": user_prompt,
                    "stream": False,
                },
            )
            if resp.status_code == 200:
                answer = resp.json().get("response", "응답 없음")
            else:
                # Previously silent; log the status so outages are diagnosable.
                logger.error("Ollama 호출 실패: HTTP %s", resp.status_code)
    except Exception as e:
        logger.error("Ollama 호출 실패: %s", e)

    return {
        "question": req.question,
        "answer": answer,
        "sources": [{"id": r.doc_id, "title": r.title, "source": r.source} for r in results],
        "model": CHAT_MODEL,
    }
|
|
|
|
|
|
@router.post("/feedback")
async def rag_feedback(
    req: RAGFeedbackRequest,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Store a quality rating for a search/answer result (learning-loop data).

    Inserts one ``RAGFeedback`` row for the current user and commits.

    Returns:
        ``{"ok": True}`` once the row is committed.
    """
    fb = RAGFeedback(
        user_id=user.id,
        query=req.query,
        doc_id=req.doc_id,
        rating=req.rating,
        comment=req.comment,
        # Naive UTC timestamp: the same value datetime.utcnow() returned, but
        # without that API's deprecation (Python 3.12+). Kept naive, presumably
        # to match a timezone-less column; confirm against the model.
        created_at=datetime.now(timezone.utc).replace(tzinfo=None),
    )
    db.add(fb)
    await db.commit()
    return {"ok": True}
|
|
|
|
|
|
@router.get("/stats")
async def rag_stats(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),  # authentication gate only
):
    """Return RAG feedback statistics and the configured model names.

    Returns:
        dict with ``total_feedback`` (row count), ``avg_rating`` (mean rating
        rounded to 2 decimals, 0.0 when there is no feedback), ``embed_model``
        and ``chat_model``.
    """
    # One aggregate query instead of two round-trips.
    result = await db.execute(
        select(func.count(RAGFeedback.id), func.avg(RAGFeedback.rating))
    )
    total, avg = result.one()
    return {
        "total_feedback": total or 0,
        # AVG() is None with no rows and may come back as Decimal (e.g. PostgreSQL
        # numeric); convert explicitly so the response is always a float.
        "avg_rating": round(float(avg or 0.0), 2),
        "embed_model": EMBED_MODEL,
        "chat_model": CHAT_MODEL,
    }
|