237 lines
7.8 KiB
Python
237 lines
7.8 KiB
Python
"""
|
|
Self-Improving Learning Loop — Ollama 모델 파인튜닝 파이프라인
|
|
|
|
RAG 피드백 데이터 + SR 해결 이력으로 모델을 주기적으로 개선.
|
|
|
|
엔드포인트:
|
|
GET /api/learn/status — 학습 현황
|
|
POST /api/learn/collect — 학습 데이터 수집 (수동 트리거)
|
|
POST /api/learn/train — 파인튜닝 실행 (Ollama Modelfile)
|
|
GET /api/learn/history — 학습 이력
|
|
GET /api/learn/quality — 모델 품질 지표
|
|
POST /api/learn/rollback — 이전 모델로 롤백
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import List, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select, func, desc
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from core.auth import get_current_user, require_admin_role
|
|
from database import get_db
|
|
from models import User, RAGFeedback, SRRequest, SRStatus, LearningRun # 신규
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every route declared in this file is mounted under /api/learn and tagged
# "Learning Loop".
router = APIRouter(prefix="/api/learn", tags=["Learning Loop"])

# Base URL of the Ollama HTTP API; _run_training posts to "<OLLAMA_URL>/api/create".
# NOTE(review): hard-coded to localhost — confirm whether this should be configurable.
OLLAMA_URL = "http://localhost:11434"

# Model name placed on the Modelfile's FROM line and reported by GET /status.
BASE_MODEL = "llama3"
|
|
|
|
|
|
async def _collect_training_data(db: AsyncSession) -> list[dict]:
    """Gather candidate training samples from RAG feedback and completed SRs.

    Two sources are queried, in this order:

    1. ``RAGFeedback`` rows with ``rating >= 4`` (at most 200 rows). A row is
       kept only when both ``query`` and ``comment`` are truthy.
    2. ``SRRequest`` rows in ``SRStatus.DONE`` whose ``updated_at`` falls
       within the last 30 days and whose ``description`` is not NULL (at most
       100 rows). A row is kept only when both ``title`` and ``description``
       are truthy.

    Args:
        db: Async SQLAlchemy session used for both queries.

    Returns:
        A list of dicts, RAG samples first. RAG samples have the keys
        ``type`` ("rag_positive"), ``input``, ``output`` and ``rating``.
        SR samples have the keys ``type`` ("sr_resolution"), ``input`` and
        ``category``; note that they carry no ``output`` key.
    """
    # 1. Highly rated RAG feedback.
    feedback_stmt = select(RAGFeedback).where(RAGFeedback.rating >= 4).limit(200)
    feedback_result = await db.execute(feedback_stmt)
    rag_samples = [
        {
            "type": "rag_positive",
            "input": feedback.query,
            "output": feedback.comment,
            "rating": feedback.rating,
        }
        for feedback in feedback_result.scalars().all()
        if feedback.query and feedback.comment
    ]

    # 2. Completed SRs updated during the last 30 days.
    cutoff = datetime.utcnow() - timedelta(days=30)
    sr_stmt = (
        select(SRRequest)
        .where(
            SRRequest.status == SRStatus.DONE,
            SRRequest.updated_at >= cutoff,
            SRRequest.description.isnot(None),
        )
        .limit(100)
    )
    sr_result = await db.execute(sr_stmt)
    sr_samples = [
        {
            "type": "sr_resolution",
            # Only the first 200 characters of the description are kept.
            "input": f"SR: {request.title}\n{request.description[:200]}",
            "category": request.category,
        }
        for request in sr_result.scalars().all()
        if request.title and request.description
    ]

    return rag_samples + sr_samples
|
|
|
|
|
|
async def _build_modelfile(samples: list[dict], base_model: str) -> str:
    """Render the Ollama Modelfile text for the customised model.

    The output consists of a ``FROM`` line, a triple-quoted ``SYSTEM`` prompt
    and three fixed ``PARAMETER`` lines, each terminated by a newline.

    Args:
        samples: Collected training samples. Accepted for interface
            compatibility; the generated text does not depend on it.
        base_model: Model name written on the ``FROM`` line.

    Returns:
        The Modelfile contents as a single string.
    """
    persona = "".join((
        "당신은 GUARDiA ITSM 전문 어시스턴트입니다. ",
        "IT 인프라 운영, 장애 대응, SR 처리에 특화된 한국어 응답을 제공합니다. ",
        "외부 API 사용 없이 내부 지식베이스만 활용합니다.",
    ))

    directives = [
        f"FROM {base_model}",
        f'SYSTEM """{persona}"""',
        # Fixed sampling / context settings.
        "PARAMETER temperature 0.3",
        "PARAMETER top_p 0.9",
        "PARAMETER num_ctx 4096",
    ]
    return "\n".join(directives) + "\n"
|
|
|
|
|
|
async def _run_training(run_id: int, samples: list[dict], db: AsyncSession):
    """Execute one learning run in the background and record its outcome.

    Loads the ``LearningRun`` row, marks it ``RUNNING``, builds a Modelfile
    and asks Ollama to create a model named ``guardia-itsm:<YYYYMMDD>``. The
    row ends up ``SUCCESS`` (with ``model_name`` and ``samples_used`` set) or
    ``FAILED`` (with ``error_message`` truncated to 200 characters).
    ``finished_at`` is always stamped.

    Args:
        run_id: Primary key of the ``LearningRun`` row to update. If no such
            row exists the function returns without doing anything.
        samples: Training samples collected for this run.
        db: Async SQLAlchemy session used to read and update the run row.

    Returns:
        None. Failures are recorded on the row rather than raised.
    """
    # NOTE(review): callers pass the request-scoped session here; confirm it
    # is still usable when the background task executes, or open a dedicated
    # session for the task.
    lookup = await db.execute(select(LearningRun).where(LearningRun.id == run_id))
    run = lookup.scalar_one_or_none()
    if not run:
        return

    try:
        run.status = "RUNNING"
        await db.commit()

        modelfile = await _build_modelfile(samples, BASE_MODEL)
        # The tag is date-based, so a second run on the same day reuses the name.
        new_model_name = f"guardia-itsm:{datetime.utcnow().strftime('%Y%m%d')}"

        # Create the custom model from the Modelfile.
        # "stream": False requests one final JSON object. Without it Ollama
        # streams progress lines and answers HTTP 200 before the build ends,
        # so a failure reported in the body would be recorded as SUCCESS.
        # NOTE(review): newer Ollama releases replaced the "modelfile" field
        # with structured fields — confirm against the deployed version.
        async with httpx.AsyncClient(timeout=300) as client:
            r = await client.post(
                f"{OLLAMA_URL}/api/create",
                json={
                    "name": new_model_name,
                    "modelfile": modelfile,
                    "stream": False,
                },
            )

        failure = None
        if r.status_code != 200:
            failure = r.text
        else:
            # A 200 response may still carry an "error" field in its body.
            try:
                body = r.json()
            except ValueError:
                body = None
            if isinstance(body, dict) and body.get("error"):
                failure = str(body["error"])

        if failure is None:
            run.status = "SUCCESS"
            run.model_name = new_model_name
            run.samples_used = len(samples)
            logger.info(f"학습 완료: {new_model_name}")
        else:
            run.status = "FAILED"
            run.error_message = failure[:200]
            logger.error("학습 실패: %s", run.error_message)

    except Exception as e:
        # Top-level boundary of a background task: record the failure on the
        # run row instead of letting it propagate unobserved.
        run.status = "FAILED"
        run.error_message = str(e)[:200]
        logger.exception("학습 실패: %s", e)
    finally:
        run.finished_at = datetime.utcnow()
        await db.commit()
|
|
|
|
|
|
@router.get("/status")
async def learning_status(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Report how much training data is available and summarise the latest run.

    Args:
        db: Async SQLAlchemy session (injected).
        user: Authenticated user (injected; only used to require a login).

    Returns:
        A dict with the sample counts, a ``ready_to_train`` flag, the most
        recently started run (``None`` when no run exists) and the base model.
    """
    samples = await _collect_training_data(db)
    total = len(samples)
    rag_total = sum(1 for sample in samples if sample.get("type") == "rag_positive")

    newest_stmt = select(LearningRun).order_by(desc(LearningRun.started_at)).limit(1)
    last_run = (await db.execute(newest_stmt)).scalar_one_or_none()

    last_run_summary = None
    if last_run:
        last_run_summary = {
            "status": last_run.status,
            "model": last_run.model_name,
            "started_at": last_run.started_at,
        }

    return {
        "available_samples": total,
        "high_quality_rag": rag_total,
        # Everything that is not a RAG sample is counted as an SR sample.
        "sr_samples": total - rag_total,
        # NOTE(review): this threshold is 20 while POST /train accepts 10 or
        # more samples — confirm the difference is intentional.
        "ready_to_train": total >= 20,
        "last_run": last_run_summary,
        "base_model": BASE_MODEL,
    }
|
|
|
|
|
|
@router.post("/collect")
async def collect_data(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Preview the training data that would currently be collected.

    Args:
        db: Async SQLAlchemy session (injected).
        user: Authenticated user (injected; only used to require a login).

    Returns:
        A dict with the total sample count, a per-type tally and the first
        three samples.
    """
    samples = await _collect_training_data(db)

    # Tally samples per "type", keeping first-seen order of the types.
    tally = {}
    for sample in samples:
        kind = sample["type"]
        if kind in tally:
            tally[kind] += 1
        else:
            tally[kind] = 1

    return {
        "total": len(samples),
        "by_type": tally,
        "preview": samples[:3],
    }
|
|
|
|
|
|
@router.post("/train")
async def start_training(
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_admin_role),
):
    """Register a new learning run and schedule it as a background task.

    Args:
        background_tasks: FastAPI task queue that runs after the response.
        db: Async SQLAlchemy session (injected).
        user: Caller resolved through ``require_admin_role`` (injected).

    Returns:
        ``{"ok": True, "run_id": ..., "samples": ...}`` once the run is queued.

    Raises:
        HTTPException: 400 when fewer than 10 samples are available.
    """
    samples = await _collect_training_data(db)
    sample_total = len(samples)
    if sample_total < 10:
        raise HTTPException(400, f"학습 데이터 부족: {sample_total}개 (최소 10개)")

    # Persist the run as PENDING first so that it has an id to hand to the task.
    pending_run = LearningRun(
        triggered_by=user.id,
        sample_count=sample_total,
        status="PENDING",
        started_at=datetime.utcnow(),
    )
    db.add(pending_run)
    await db.commit()
    await db.refresh(pending_run)

    # NOTE(review): the request-scoped session is handed to the background
    # task — confirm it is still open when the task actually runs.
    background_tasks.add_task(_run_training, pending_run.id, samples, db)

    return {"ok": True, "run_id": pending_run.id, "samples": sample_total}
|
|
|
|
|
|
@router.get("/history")
async def learning_history(
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """List learning runs, most recently started first.

    Args:
        limit: Maximum number of runs to return. Values outside 1..200 are
            clamped into that range.
        db: Async SQLAlchemy session (injected).
        user: Authenticated user (injected; only used to require a login).

    Returns:
        A list of dicts with the keys ``id``, ``status``, ``model_name``,
        ``samples_used``, ``started_at``, ``finished_at`` and ``error``.
    """
    # ``limit`` comes straight from the query string. A negative value is an
    # error or "no limit" depending on the database, and a huge value would
    # return the whole table, so bound it before it reaches SQL.
    max_rows = 200
    bounded_limit = max(1, min(limit, max_rows))

    rows = await db.execute(
        select(LearningRun)
        .order_by(desc(LearningRun.started_at))
        .limit(bounded_limit)
    )
    return [
        {
            "id": r.id,
            "status": r.status,
            "model_name": r.model_name,
            "samples_used": r.samples_used,
            "started_at": r.started_at,
            "finished_at": r.finished_at,
            "error": r.error_message,
        }
        for r in rows.scalars().all()
    ]
|
|
|
|
|
|
@router.get("/quality")
async def model_quality(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Summarise model quality from RAG feedback ratings.

    Args:
        db: Async SQLAlchemy session (injected).
        user: Authenticated user (injected; only used to require a login).

    Returns:
        A dict with the feedback count, the percentage of feedback rated 4 or
        higher, the average rating and a letter grade derived from it
        (A >= 4.5, B >= 3.5, C >= 2.5, otherwise D).
    """
    count_stmt = select(func.count(RAGFeedback.id))
    total = (await db.execute(count_stmt)).scalar() or 0

    positive_stmt = select(func.count(RAGFeedback.id)).where(RAGFeedback.rating >= 4)
    positive = (await db.execute(positive_stmt)).scalar() or 0

    # AVG yields NULL when there is no feedback; fall back to 0.0.
    avg = (await db.execute(select(func.avg(RAGFeedback.rating)))).scalar() or 0.0

    if total:
        positive_rate = round(positive / total * 100, 1)
    else:
        positive_rate = 0

    # The first threshold the average reaches decides the grade.
    grade = "D"
    for floor, letter in ((4.5, "A"), (3.5, "B"), (2.5, "C")):
        if avg >= floor:
            grade = letter
            break

    return {
        "total_feedback": total,
        "positive_rate": positive_rate,
        "avg_rating": round(avg, 2),
        "quality_grade": grade,
    }
|