guardia-itsm/routers/learning_loop.py
2026-06-02 06:07:36 +09:00

237 lines
7.8 KiB
Python

"""
Self-Improving Learning Loop — Ollama 모델 파인튜닝 파이프라인
RAG 피드백 데이터 + SR 해결 이력으로 모델을 주기적으로 개선.
엔드포인트:
GET /api/learn/status — 학습 현황
POST /api/learn/collect — 학습 데이터 수집 (수동 트리거)
POST /api/learn/train — 파인튜닝 실행 (Ollama Modelfile)
GET /api/learn/history — 학습 이력
GET /api/learn/quality — 모델 품질 지표
POST /api/learn/rollback — 이전 모델로 롤백
"""
from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timedelta
from typing import List, Optional

import httpx
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession

from core.auth import get_current_user, require_admin_role
from database import get_db
from models import User, RAGFeedback, SRRequest, SRStatus, LearningRun # 신규
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/learn", tags=["Learning Loop"])

# Base URL of the Ollama server. Override with the OLLAMA_URL environment
# variable (e.g. when Ollama runs in another container/host). A trailing "/"
# is stripped because request paths are appended as f"{OLLAMA_URL}/api/...".
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434").rstrip("/")
# Ollama model the generated Modelfile builds FROM.
# Override with the LEARNING_BASE_MODEL environment variable.
BASE_MODEL = os.getenv("LEARNING_BASE_MODEL", "llama3")
async def _collect_training_data(
    db: AsyncSession,
    *,
    rag_limit: int = 200,
    sr_limit: int = 100,
    sr_days: int = 30,
) -> list[dict]:
    """Collect training samples from highly rated RAG feedback and recently resolved SRs.

    Args:
        db: Async SQLAlchemy session used for both queries.
        rag_limit: Maximum number of RAG feedback rows to read, newest (highest id) first.
        sr_limit: Maximum number of resolved SR rows to read, most recently updated first.
        sr_days: Only SRs whose ``updated_at`` falls within this many days are used.

    Returns:
        A list of sample dicts:
        - ``{"type": "rag_positive", "input", "output", "rating"}`` for feedback rated >= 4
          that has both a query and a comment;
        - ``{"type": "sr_resolution", "input", "category"}`` for DONE SRs with a title and
          description (``input`` holds the title plus the first 200 description chars).

    NOTE(review): the feedback ``comment`` is used as the target ``output``; confirm it
    holds the preferred answer rather than free-form user remarks. ``sr_resolution``
    samples carry no resolution/output field at all.
    """
    samples: list[dict] = []

    # 1. RAG feedback rated 4 or higher. NULL query/comment rows are excluded in SQL
    #    so they don't consume the LIMIT, and ordering makes the LIMIT deterministic.
    fb_rows = await db.execute(
        select(RAGFeedback)
        .where(
            RAGFeedback.rating >= 4,
            RAGFeedback.query.isnot(None),
            RAGFeedback.comment.isnot(None),
        )
        .order_by(desc(RAGFeedback.id))
        .limit(rag_limit)
    )
    for fb in fb_rows.scalars().all():
        # Empty strings pass IS NOT NULL, so keep the truthiness check.
        if fb.query and fb.comment:
            samples.append({
                "type": "rag_positive",
                "input": fb.query,
                "output": fb.comment,
                "rating": fb.rating,
            })

    # 2. SRs marked DONE within the window that have a title and description.
    since = datetime.utcnow() - timedelta(days=sr_days)
    sr_rows = await db.execute(
        select(SRRequest)
        .where(
            SRRequest.status == SRStatus.DONE,
            SRRequest.updated_at >= since,
            SRRequest.title.isnot(None),
            SRRequest.description.isnot(None),
        )
        .order_by(desc(SRRequest.updated_at))
        .limit(sr_limit)
    )
    for sr in sr_rows.scalars().all():
        if sr.title and sr.description:
            samples.append({
                "type": "sr_resolution",
                "input": f"SR: {sr.title}\n{sr.description[:200]}",
                "category": sr.category,
            })
    return samples
async def _build_modelfile(samples: list[dict], base_model: str) -> str:
    """Render the Ollama Modelfile for the GUARDiA ITSM assistant.

    The Modelfile declares *base_model* as ``FROM``, sets a fixed Korean system
    prompt and applies fixed sampling/context parameters.

    NOTE(review): *samples* is accepted but never used, so no training data ends
    up in the Modelfile; the resulting model differs from *base_model* only by
    its system prompt and parameters.

    Args:
        samples: Collected training samples (currently unused).
        base_model: Name of the Ollama model to build on.

    Returns:
        The Modelfile text; every directive line ends with a newline.
    """
    prompt_text = (
        "당신은 GUARDiA ITSM 전문 어시스턴트입니다. "
        "IT 인프라 운영, 장애 대응, SR 처리에 특화된 한국어 응답을 제공합니다. "
        "외부 API 사용 없이 내부 지식베이스만 활용합니다."
    )
    directives = [
        f"FROM {base_model}\n",
        f'SYSTEM """{prompt_text}"""\n',
        # Fixed generation parameters.
        "PARAMETER temperature 0.3\n",
        "PARAMETER top_p 0.9\n",
        "PARAMETER num_ctx 4096\n",
    ]
    return "".join(directives)
def _ollama_create_error(response: httpx.Response) -> str | None:
    """Return an error message for a failed Ollama ``/api/create`` call, or ``None`` on success.

    A non-200 status is a failure. A 200 response can still carry a failure: when
    Ollama streams progress it sends newline-delimited JSON objects and reports a
    failure as an object with an ``"error"`` key, so every line is inspected.
    """
    if response.status_code != 200:
        return response.text[:200] or f"HTTP {response.status_code}"
    for line in response.text.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            payload = json.loads(line)
        except ValueError:
            continue
        if isinstance(payload, dict) and payload.get("error"):
            return str(payload["error"])[:200]
    return None


async def _run_training(run_id: int, samples: list[dict], db: AsyncSession):
    """Background job: build a Modelfile and register it with Ollama as a new model.

    Updates the ``LearningRun`` row identified by *run_id*: ``RUNNING`` while the
    Ollama call is in flight, then ``SUCCESS`` (with ``model_name`` and
    ``samples_used``) or ``FAILED`` (with a truncated ``error_message``).
    ``finished_at`` is always stamped.

    Args:
        run_id: Primary key of the LearningRun row created by the caller.
        samples: Training samples, passed through to ``_build_modelfile``.
        db: Session used to load and update the run row.

    NOTE(review): *db* is the request-scoped session handed over by ``/train``.
    Depending on the FastAPI version, ``get_db``'s cleanup may run before
    background tasks; consider opening a dedicated session here.
    NOTE(review): newer Ollama releases reworked ``/api/create`` around
    structured fields (``from``/``system``/``parameters``); confirm the deployed
    server still honours ``modelfile``.
    """
    run_row = await db.execute(select(LearningRun).where(LearningRun.id == run_id))
    run = run_row.scalar_one_or_none()
    if run is None:
        logger.warning("LearningRun %s not found; skipping training", run_id)
        return
    try:
        run.status = "RUNNING"
        await db.commit()
        modelfile = await _build_modelfile(samples, BASE_MODEL)
        # The run id keeps several runs on the same day from overwriting one
        # another's model tag (which would also make rollback between them impossible).
        new_model_name = f"guardia-itsm:{datetime.utcnow().strftime('%Y%m%d')}-{run_id}"
        # Ollama create (custom model from the Modelfile). stream=False makes Ollama
        # answer with a single final object instead of a progress stream.
        async with httpx.AsyncClient(timeout=300) as client:
            r = await client.post(f"{OLLAMA_URL}/api/create", json={
                "model": new_model_name,
                "name": new_model_name,  # legacy field name for older servers
                "modelfile": modelfile,
                "stream": False,
            })
        error = _ollama_create_error(r)
        if error is None:
            run.status = "SUCCESS"
            run.model_name = new_model_name
            run.samples_used = len(samples)
            logger.info("학습 완료: %s", new_model_name)
        else:
            run.status = "FAILED"
            run.error_message = error
    except Exception as e:
        # A failed flush/commit leaves the session unusable until it is rolled back;
        # otherwise the commit in `finally` would fail as well.
        await db.rollback()
        run.status = "FAILED"
        run.error_message = str(e)[:200]
        logger.exception("학습 실패: %s", e)
    finally:
        run.finished_at = datetime.utcnow()
        await db.commit()
@router.get("/status")
async def learning_status(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Report how much training data is available and the most recent learning run.

    Runs the full sample collection to count what is currently available.

    Returns:
        ``available_samples``, ``high_quality_rag`` (``rag_positive`` samples),
        ``sr_samples`` (the rest), ``ready_to_train``, ``last_run`` (status, model
        and start time of the newest run by ``started_at``, or ``None``) and
        ``base_model``.
    """
    samples = await _collect_training_data(db)
    rag_count = sum(1 for s in samples if s.get("type") == "rag_positive")
    latest = await db.execute(
        select(LearningRun).order_by(desc(LearningRun.started_at)).limit(1)
    )
    last_run = latest.scalar_one_or_none()
    return {
        "available_samples": len(samples),
        "high_quality_rag": rag_count,
        "sr_samples": len(samples) - rag_count,
        # Same minimum that POST /train enforces (10 samples); a different value
        # here would report "not ready" for data that /train actually accepts.
        "ready_to_train": len(samples) >= 10,
        "last_run": (
            {
                "status": last_run.status,
                "model": last_run.model_name,
                "started_at": last_run.started_at,
            }
            if last_run is not None
            else None
        ),
        "base_model": BASE_MODEL,
    }
@router.post("/collect")
async def collect_data(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Preview the training data that would currently be collected.

    The handler itself only reads, via ``_collect_training_data``; nothing is stored.

    Returns:
        ``total`` sample count, ``by_type`` counts keyed by sample ``type`` (in
        first-seen order) and ``preview`` holding the first three samples.
    """
    from collections import Counter

    samples = await _collect_training_data(db)
    by_type = dict(Counter(sample["type"] for sample in samples))
    return {"total": len(samples), "by_type": by_type, "preview": samples[:3]}
@router.post("/train")
async def start_training(
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_admin_role),
):
    """Start a model-building run in the background (admin only).

    Collects samples and rejects the request with HTTP 400 when fewer than 10
    are available. Otherwise it inserts a PENDING ``LearningRun`` row and
    schedules ``_run_training`` for it.

    Returns:
        ``{"ok": True, "run_id": ..., "samples": ...}``.

    NOTE(review): the request-scoped *db* session is handed to the background
    task; depending on the FastAPI version its cleanup may already have run.
    """
    samples = await _collect_training_data(db)
    sample_total = len(samples)
    if sample_total < 10:
        raise HTTPException(400, f"학습 데이터 부족: {sample_total}개 (최소 10개)")

    pending_run = LearningRun(
        triggered_by=user.id,
        sample_count=sample_total,
        status="PENDING",
        started_at=datetime.utcnow(),
    )
    db.add(pending_run)
    await db.commit()
    # Reload the row so its database-generated id is available.
    await db.refresh(pending_run)

    background_tasks.add_task(_run_training, pending_run.id, samples, db)
    return {"ok": True, "run_id": pending_run.id, "samples": sample_total}
@router.get("/history")
async def learning_history(
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return recent learning runs, newest ``started_at`` first.

    Args:
        limit: Maximum number of runs to return. It is clamped to ``0..100``,
            because databases such as PostgreSQL reject a negative LIMIT and an
            unbounded value would let any authenticated user dump the table.

    Returns:
        A list of run summaries with ``id``, ``status``, ``model_name``,
        ``samples_used``, ``started_at``, ``finished_at`` and ``error``.
    """
    limit = max(0, min(limit, 100))
    rows = await db.execute(
        select(LearningRun).order_by(desc(LearningRun.started_at)).limit(limit)
    )
    return [
        {
            "id": run.id,
            "status": run.status,
            "model_name": run.model_name,
            "samples_used": run.samples_used,
            "started_at": run.started_at,
            "finished_at": run.finished_at,
            "error": run.error_message,
        }
        for run in rows.scalars().all()
    ]
@router.get("/quality")
async def model_quality(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Report model quality metrics derived from RAG feedback ratings.

    All aggregates come from one query, so they describe a single consistent
    snapshot. With three separate queries, feedback inserted between them could
    push ``positive`` above ``total``.

    Returns:
        ``total_feedback``, ``positive_rate`` (percentage of ratings >= 4, one
        decimal place), ``avg_rating`` (two decimal places) and
        ``quality_grade``: A (>= 4.5), B (>= 3.5), C (>= 2.5), otherwise D.
    """
    from sqlalchemy import case

    result = await db.execute(
        select(
            func.count(RAGFeedback.id),
            func.sum(case((RAGFeedback.rating >= 4, 1), else_=0)),
            func.avg(RAGFeedback.rating),
        )
    )
    total_raw, positive_raw, avg_raw = result.one()
    total = int(total_raw or 0)
    # SUM over zero rows is NULL, and some backends return SUM/AVG as Decimal.
    positive = int(positive_raw or 0)
    avg = float(avg_raw) if avg_raw is not None else 0.0

    if avg >= 4.5:
        grade = "A"
    elif avg >= 3.5:
        grade = "B"
    elif avg >= 2.5:
        grade = "C"
    else:
        grade = "D"

    return {
        "total_feedback": total,
        "positive_rate": round(positive / total * 100, 1) if total else 0,
        "avg_rating": round(avg, 2),
        "quality_grade": grade,
    }