"""Self-Improving Learning Loop — Ollama model fine-tuning pipeline.

Periodically improves the model using RAG feedback data and SR resolution
history.

Endpoints:
    GET  /api/learn/status   — learning status
    POST /api/learn/collect  — collect training data (manual trigger)
    POST /api/learn/train    — run fine-tuning (Ollama Modelfile)
    GET  /api/learn/history  — learning history
    GET  /api/learn/quality  — model quality metrics
    POST /api/learn/rollback — roll back to the previous model

NOTE(review): no handler for /rollback is visible in this part of the file —
confirm whether it lives elsewhere or is still to be implemented.
"""
from __future__ import annotations

import json
import logging
from collections import Counter
from datetime import datetime, timedelta
from typing import List, Optional

import httpx
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession

from core.auth import get_current_user, require_admin_role
from database import get_db
from models import User, RAGFeedback, SRRequest, SRStatus, LearningRun  # new

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/learn", tags=["Learning Loop"])

OLLAMA_URL = "http://localhost:11434"
BASE_MODEL = "llama3"

# Data-collection limits.
MIN_POSITIVE_RATING = 4          # feedback rated at or above this counts as positive
RAG_FEEDBACK_LIMIT = 200         # max feedback rows pulled per collection
SR_LOOKBACK_DAYS = 30            # only SRs updated within this window are used
SR_LIMIT = 100                   # max SR rows pulled per collection
SR_DESCRIPTION_MAX_CHARS = 200   # SR description is truncated to this length

# Training thresholds. /status reports readiness at the higher value while
# /train only refuses below the lower one.
# NOTE(review): confirm that the two different thresholds are intentional.
READY_TO_TRAIN_SAMPLES = 20
MIN_TRAINING_SAMPLES = 10

OLLAMA_TIMEOUT_SECONDS = 300
ERROR_MESSAGE_MAX_CHARS = 200    # stored error text is truncated to this length
MAX_HISTORY_LIMIT = 200          # upper bound for the /history `limit` parameter


async def _collect_training_data(db: AsyncSession) -> list[dict]:
    """Collect training samples: highly rated RAG feedback plus resolved SRs.

    Args:
        db: Async database session used for both queries.

    Returns:
        A list of sample dicts. ``rag_positive`` samples carry ``input``,
        ``output`` and ``rating``; ``sr_resolution`` samples carry ``input``
        and ``category`` (they have no ``output`` key).
    """
    samples: list[dict] = []

    # 1. RAG feedback rated MIN_POSITIVE_RATING or higher.
    feedback_result = await db.execute(
        select(RAGFeedback)
        .where(RAGFeedback.rating >= MIN_POSITIVE_RATING)
        .limit(RAG_FEEDBACK_LIMIT)
    )
    samples.extend(
        {
            "type": "rag_positive",
            "input": fb.query,
            "output": fb.comment,
            "rating": fb.rating,
        }
        for fb in feedback_result.scalars().all()
        # Rows without both a query and a comment cannot form a pair.
        if fb.query and fb.comment
    )

    # 2. SRs in DONE status, updated recently, that have a description.
    cutoff = datetime.utcnow() - timedelta(days=SR_LOOKBACK_DAYS)
    sr_result = await db.execute(
        select(SRRequest)
        .where(
            SRRequest.status == SRStatus.DONE,
            SRRequest.updated_at >= cutoff,
            SRRequest.description.isnot(None),
        )
        .limit(SR_LIMIT)
    )
    samples.extend(
        {
            "type": "sr_resolution",
            "input": f"SR: {sr.title}\n{sr.description[:SR_DESCRIPTION_MAX_CHARS]}",
            "category": sr.category,
        }
        for sr in sr_result.scalars().all()
        if sr.title and sr.description
    )

    return samples


async def _build_modelfile(samples: list[dict], base_model: str) -> str:
    """Build the Ollama Modelfile text.

    Args:
        samples: Collected training samples. Currently unused: the generated
            Modelfile contains only the base model, a fixed system prompt and
            fixed sampling parameters.
        base_model: Model name placed in the ``FROM`` line.

    Returns:
        The Modelfile contents as a single string.
    """
    system_prompt = (
        "당신은 GUARDiA ITSM 전문 어시스턴트입니다. "
        "IT 인프라 운영, 장애 대응, SR 처리에 특화된 한국어 응답을 제공합니다. "
        "외부 API 사용 없이 내부 지식베이스만 활용합니다."
    )
    lines = [
        f"FROM {base_model}",
        f'SYSTEM """{system_prompt}"""',
        "PARAMETER temperature 0.3",
        "PARAMETER top_p 0.9",
        "PARAMETER num_ctx 4096",
    ]
    # Every line, including the last, is newline-terminated.
    return "\n".join(lines) + "\n"


async def _run_training(run_id: int, samples: list[dict], db: AsyncSession) -> None:
    """Run one training job in the background and record its outcome.

    Looks up the ``LearningRun`` row, marks it RUNNING, asks Ollama to create
    a custom model from the generated Modelfile, then stores SUCCESS or
    FAILED together with the finish time.

    Args:
        run_id: Primary key of the ``LearningRun`` row to update.
        samples: Samples collected for this run; only their count is recorded.
        db: Async database session used to load and update the run.

    NOTE(review): the caller passes its request-scoped session. Depending on
    the FastAPI version and on how ``get_db`` is written, that session may
    already be closed when this task runs — confirm, and consider opening a
    dedicated session here.
    """
    run_result = await db.execute(
        select(LearningRun).where(LearningRun.id == run_id)
    )
    run = run_result.scalar_one_or_none()
    if run is None:
        return

    try:
        run.status = "RUNNING"
        await db.commit()

        modelfile = await _build_modelfile(samples, BASE_MODEL)
        # The tag is the UTC date, so two runs on the same day share a name.
        new_model_name = f"guardia-itsm:{datetime.utcnow().strftime('%Y%m%d')}"

        # Ollama create: build a custom model from the Modelfile.
        # "stream": False requests a single JSON reply, so a failed create
        # comes back as a non-200 status instead of an error record buried
        # in a streamed 200 body.
        # NOTE(review): newer Ollama releases changed the create request
        # schema — confirm the deployed server accepts `name`/`modelfile`.
        async with httpx.AsyncClient(timeout=OLLAMA_TIMEOUT_SECONDS) as client:
            response = await client.post(
                f"{OLLAMA_URL}/api/create",
                json={
                    "name": new_model_name,
                    "modelfile": modelfile,
                    "stream": False,
                },
            )

        if response.status_code == 200:
            run.status = "SUCCESS"
            run.model_name = new_model_name
            run.samples_used = len(samples)
            logger.info("학습 완료: %s", new_model_name)
        else:
            run.status = "FAILED"
            run.error_message = response.text[:ERROR_MESSAGE_MAX_CHARS]
            logger.error(
                "Ollama create returned HTTP %s for %s",
                response.status_code,
                new_model_name,
            )
    except Exception as e:
        # Top-level boundary of the background task: record the failure on
        # the run row rather than letting it propagate unseen.
        run.status = "FAILED"
        run.error_message = str(e)[:ERROR_MESSAGE_MAX_CHARS]
        logger.exception("학습 실패: %s", e)
    finally:
        run.finished_at = datetime.utcnow()
        await db.commit()


@router.get("/status")
async def learning_status(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Report learning status and how much training data is available.

    Returns:
        Sample counts by kind, a ``ready_to_train`` flag, a summary of the
        most recently started run (or ``None``) and the base model name.
    """
    samples = await _collect_training_data(db)
    rag_count = sum(1 for s in samples if s.get("type") == "rag_positive")

    latest = await db.execute(
        select(LearningRun).order_by(desc(LearningRun.started_at)).limit(1)
    )
    last_run = latest.scalar_one_or_none()

    last_run_summary = None
    if last_run is not None:
        last_run_summary = {
            "status": last_run.status,
            "model": last_run.model_name,
            "started_at": last_run.started_at,
        }

    return {
        "available_samples": len(samples),
        "high_quality_rag": rag_count,
        "sr_samples": len(samples) - rag_count,
        "ready_to_train": len(samples) >= READY_TO_TRAIN_SAMPLES,
        "last_run": last_run_summary,
        "base_model": BASE_MODEL,
    }


@router.post("/collect")
async def collect_data(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Preview the training data that would be collected.

    Returns:
        Total sample count, a count per sample type and the first three
        samples.
    """
    samples = await _collect_training_data(db)
    by_type = Counter(s["type"] for s in samples)
    return {
        "total": len(samples),
        "by_type": dict(by_type),
        "preview": samples[:3],
    }


@router.post("/train")
async def start_training(
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_admin_role),
):
    """Start a fine-tuning run in the background.

    Creates a PENDING ``LearningRun`` row and schedules ``_run_training``.

    Raises:
        HTTPException: 400 when fewer than ``MIN_TRAINING_SAMPLES`` samples
            are available.
    """
    samples = await _collect_training_data(db)
    if len(samples) < MIN_TRAINING_SAMPLES:
        raise HTTPException(
            400,
            f"학습 데이터 부족: {len(samples)}개 (최소 {MIN_TRAINING_SAMPLES}개)",
        )

    run = LearningRun(
        triggered_by=user.id,
        sample_count=len(samples),
        status="PENDING",
        started_at=datetime.utcnow(),
    )
    db.add(run)
    await db.commit()
    # Reload the row so the id assigned by the database is populated.
    await db.refresh(run)

    background_tasks.add_task(_run_training, run.id, samples, db)
    return {"ok": True, "run_id": run.id, "samples": len(samples)}


@router.get("/history")
async def learning_history(
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """List past learning runs, most recently started first.

    Args:
        limit: Maximum number of runs to return. Clamped to the range
            0..``MAX_HISTORY_LIMIT`` so that a negative or very large value
            never reaches the database.

    Returns:
        A list of dicts, one per run.
    """
    bounded_limit = max(0, min(limit, MAX_HISTORY_LIMIT))
    result = await db.execute(
        select(LearningRun)
        .order_by(desc(LearningRun.started_at))
        .limit(bounded_limit)
    )
    return [
        {
            "id": run.id,
            "status": run.status,
            "model_name": run.model_name,
            "samples_used": run.samples_used,
            "started_at": run.started_at,
            "finished_at": run.finished_at,
            "error": run.error_message,
        }
        for run in result.scalars().all()
    ]


@router.get("/quality")
async def model_quality(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Report model quality metrics derived from RAG feedback.

    Returns:
        Total feedback count, percentage of feedback rated
        ``MIN_POSITIVE_RATING`` or higher, average rating and a letter grade
        (A at 4.5 and above, B at 3.5, C at 2.5, otherwise D).
    """
    total_result = await db.execute(select(func.count(RAGFeedback.id)))
    total = total_result.scalar() or 0

    positive_result = await db.execute(
        select(func.count(RAGFeedback.id)).where(
            RAGFeedback.rating >= MIN_POSITIVE_RATING
        )
    )
    positive = positive_result.scalar() or 0

    avg_result = await db.execute(select(func.avg(RAGFeedback.rating)))
    # The driver may return Decimal (or None when there are no rows);
    # normalise to float once so rounding and grading use one type.
    avg = float(avg_result.scalar() or 0.0)

    if avg >= 4.5:
        grade = "A"
    elif avg >= 3.5:
        grade = "B"
    elif avg >= 2.5:
        grade = "C"
    else:
        grade = "D"

    return {
        "total_feedback": total,
        "positive_rate": round(positive / total * 100, 1) if total else 0,
        "avg_rating": round(avg, 2),
        "quality_grade": grade,
    }