200 lines
8.5 KiB
Python
200 lines
8.5 KiB
Python
"""LoRA 파인튜닝 파이프라인 — GUARDiA 운영 데이터로 Ollama 모델 특화"""
|
|
from __future__ import annotations
|
|
import json, logging, os
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
|
from fastapi.responses import FileResponse
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select, desc, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from core.auth import get_current_user, require_admin_role
|
|
from database import get_db
|
|
from models import User, FeedbackSample, FinetuneJob, EvalResult
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every endpoint declared on this router is served under /api/finetune.
router = APIRouter(tags=["LoRA 파인튜닝"], prefix="/api/finetune")

# Directory that receives exported training files.
DATA_DIR = "/opt/guardia/app/finetune_data"

# Import-time side effect: make sure the export directory exists.
os.makedirs(DATA_DIR, exist_ok=True)
|
|
|
|
|
|
class FeedbackIn(BaseModel):
    """Request body for submitting one feedback sample."""

    # Question text; exported as the "instruction" field of a training record.
    question: str

    # Model response being reviewed (field name suggests Ollama output).
    ollama_response: str

    # Reviewer-approved answer; exported as the "output" of a training record.
    approved_answer: str

    # Expected values: POSITIVE | NEGATIVE | CORRECTED (not enforced here).
    label_type: str = "POSITIVE"

    # Free-form domain tag used for filtering and statistics.
    domain: str = "general"

    # Samples scoring below 0.7 are excluded from dataset/export/training.
    # NOTE(review): presumably a 0..1 scale — the range is not validated.
    quality_score: float = 0.8
|
|
|
|
|
|
class FinetuneStart(BaseModel):
    """Request body for launching a fine-tuning job."""

    # Name of the base model; also embedded in the generated output model name.
    base_model: str = "llama3"

    # Number of epochs recorded on the job row.
    epochs: int = 3

    # Minimum number of eligible samples required before a job is accepted.
    dataset_min: int = 50

    # Free-text notes stored with the job.
    notes: str = ""
|
|
|
|
|
|
@router.post("/feedback", status_code=201)
async def add_feedback(
    body: FeedbackIn,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Store one feedback sample as training data.

    Args:
        body: The reviewed question/answer pair and its labels.
        db: Async database session.
        user: Authenticated caller; recorded as the sample's creator.

    Returns:
        On success (HTTP 201): ``{"sample_id": ..., "total_samples": ...}``.
        When ``approved_answer`` is shorter than 20 characters: HTTP 400 with
        body ``{"ok": False, "message": ...}`` and nothing is stored.
    """
    if len(body.approved_answer) < 20:
        # A rejected sample must not be reported as "201 Created".  The JSON
        # body is unchanged so clients that inspect ``ok`` keep working.
        # Imported locally because JSONResponse is not a module-level import.
        from fastapi.responses import JSONResponse

        return JSONResponse(
            status_code=400,
            content={"ok": False, "message": "답변이 너무 짧습니다 (최소 20자)"},
        )

    sample = FeedbackSample(
        question=body.question,
        ollama_response=body.ollama_response,
        approved_answer=body.approved_answer,
        label_type=body.label_type,
        quality_score=body.quality_score,
        domain=body.domain,
        created_by=user.id,
        created_at=datetime.utcnow(),
    )
    db.add(sample)
    await db.commit()
    # Reload the row so the database-generated id is available.
    await db.refresh(sample)

    # Count of all stored samples, regardless of quality or label.
    count_result = await db.execute(select(func.count(FeedbackSample.id)))
    total = count_result.scalar() or 0
    return {"sample_id": sample.id, "total_samples": total}
|
|
|
|
|
|
@router.get("/dataset")
async def get_dataset(
    limit: int = 100,
    domain: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Preview the newest training-eligible samples.

    Eligible means ``quality_score >= 0.7`` and ``label_type != "NEGATIVE"``.

    Args:
        limit: Maximum number of samples to return.
        domain: When given (non-empty), restrict to this domain.
        db: Async database session.
        user: Authenticated caller.

    Returns:
        ``{"total": <number of returned samples>, "samples": [...]}`` where
        question and answer are truncated to 100 characters each.
    """
    conditions = [
        FeedbackSample.quality_score >= 0.7,
        FeedbackSample.label_type != "NEGATIVE",
    ]
    if domain:
        conditions.append(FeedbackSample.domain == domain)

    query = (
        select(FeedbackSample)
        .where(*conditions)
        .order_by(desc(FeedbackSample.created_at))
        .limit(limit)
    )
    result = await db.execute(query)

    preview = []
    for sample in result.scalars().all():
        preview.append(
            {
                "q": sample.question[:100],
                "a": sample.approved_answer[:100],
                "domain": sample.domain,
                "quality": sample.quality_score,
            }
        )
    # "total" is the size of this page, not the size of the whole table.
    return {"total": len(preview), "samples": preview}
|
|
|
|
|
|
@router.get("/export")
async def export_dataset(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Export the fine-tuning data as a JSONL file.

    Writes every sample with ``quality_score >= 0.7`` and
    ``label_type != "NEGATIVE"`` to ``DATA_DIR/training_data.jsonl``, one JSON
    object per line, overwriting any previous export.

    Returns:
        The number of exported samples, the file path, and a format label.
    """
    eligible = select(FeedbackSample).where(
        FeedbackSample.quality_score >= 0.7,
        FeedbackSample.label_type != "NEGATIVE",
    )
    samples = (await db.execute(eligible)).scalars().all()

    def _to_line(sample):
        # One instruction/input/output record per line; "input" is always empty.
        record = {
            "instruction": sample.question,
            "input": "",
            "output": sample.approved_answer,
            "domain": sample.domain,
        }
        return json.dumps(record, ensure_ascii=False) + "\n"

    output_path = f"{DATA_DIR}/training_data.jsonl"
    with open(output_path, "w", encoding="utf-8") as out:
        # Lines are serialized lazily while the file is open.
        out.writelines(_to_line(sample) for sample in samples)

    return {
        "exported": len(samples),
        "path": output_path,
        "format": "Alpaca JSONL (Unsloth 호환)",
    }
|
|
|
|
|
|
@router.post("/start")
async def start_finetune(
    body: FinetuneStart,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_admin_role),
):
    """Queue a LoRA fine-tuning job (admin only).

    Args:
        body: Base model, epochs, minimum dataset size and notes.
        background_tasks: FastAPI background task registry.
        db: Async database session.
        user: Caller resolved by ``require_admin_role``; recorded as creator.

    Returns:
        On success: job id, status ``"QUEUED"``, dataset size and a note.
        When fewer than ``body.dataset_min`` eligible samples exist: HTTP 400
        with body ``{"ok": False, "message": ...}`` and no job is created.
    """
    # Eligible samples: quality >= 0.7 and not labelled NEGATIVE.
    count_stmt = select(func.count(FeedbackSample.id)).where(
        FeedbackSample.quality_score >= 0.7,
        FeedbackSample.label_type != "NEGATIVE",
    )
    total_samples = (await db.execute(count_stmt)).scalar() or 0

    if total_samples < body.dataset_min:
        # Refusals are signalled with an error status (as the other endpoints
        # in this router do) instead of 200.  The JSON body is unchanged.
        # Imported locally because JSONResponse is not a module-level import.
        from fastapi.responses import JSONResponse

        return JSONResponse(
            status_code=400,
            content={
                "ok": False,
                "message": f"학습 데이터 부족 ({total_samples}/{body.dataset_min})",
            },
        )

    job = FinetuneJob(
        base_model=body.base_model,
        dataset_size=total_samples,
        epochs=body.epochs,
        status="QUEUED",
        notes=body.notes,
        created_by=user.id,
        created_at=datetime.utcnow(),
    )
    db.add(job)
    await db.commit()
    # Reload the row so the database-generated id is available.
    await db.refresh(job)

    # NOTE(review): the request-scoped session is handed to the background
    # task.  Depending on the FastAPI version and how get_db is written, that
    # session may already be closed when the task runs — confirm, and prefer
    # opening a dedicated session inside the task.
    background_tasks.add_task(
        _run_finetune, job.id, body.base_model, total_samples, db
    )
    return {
        "job_id": job.id,
        "status": "QUEUED",
        "dataset_size": total_samples,
        "note": "Unsloth + LoRA 학습 (8GB VRAM, ~1시간 예상)",
    }
|
|
|
|
|
|
async def _run_finetune(job_id: int, model: str, dataset_size: int, db: AsyncSession):
    """Background task that moves a job from RUNNING to COMPLETED.

    This is currently a simulation: it sleeps for three seconds in place of a
    real Unsloth training run and stores a fixed placeholder loss history.

    Args:
        job_id: Primary key of the ``FinetuneJob`` row to update.
        model: Base model name, embedded in the generated output model name.
        dataset_size: Number of training samples (currently unused).
        db: Async session used for the status updates.

    Raises:
        Exception: Any failure is logged, the job is marked ``"FAILED"`` on a
            best-effort basis, and the original exception is re-raised.
    """
    import asyncio

    from sqlalchemy import update as sa_update

    def _job_update(**values):
        # UPDATE statement targeting only this job's row.
        return sa_update(FinetuneJob).where(FinetuneJob.id == job_id).values(**values)

    try:
        # Plain execute + commit instead of ``async with db.begin()``:
        # ``begin()`` raises InvalidRequestError when the session has already
        # auto-begun a transaction, which is the case for a session that the
        # request handler has just used.
        await db.execute(_job_update(status="RUNNING"))
        await db.commit()

        await asyncio.sleep(3)  # placeholder for the real Unsloth training run

        output_model = f"guardia-{model}-lora-{job_id}"
        await db.execute(
            _job_update(
                status="COMPLETED",
                output_model=output_model,
                loss_history_json=json.dumps([2.5, 1.8, 1.2]),
                finished_at=datetime.utcnow(),
            )
        )
        await db.commit()
    except Exception:
        logger.exception("Fine-tune job %s failed", job_id)
        # Best effort: do not leave the job stuck in RUNNING forever.
        # NOTE(review): assumes ``status`` accepts the value "FAILED" — confirm
        # against the FinetuneJob model.
        try:
            await db.rollback()
            await db.execute(
                _job_update(status="FAILED", finished_at=datetime.utcnow())
            )
            await db.commit()
        except Exception:
            logger.exception("Could not mark fine-tune job %s as FAILED", job_id)
        raise
|
|
|
|
|
|
@router.get("/jobs")
async def list_jobs(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return summaries of the 20 most recently created fine-tuning jobs."""
    newest_first = (
        select(FinetuneJob).order_by(desc(FinetuneJob.created_at)).limit(20)
    )
    result = await db.execute(newest_first)

    summaries = []
    for job in result.scalars().all():
        summaries.append(
            {
                "id": job.id,
                "base_model": job.base_model,
                "dataset_size": job.dataset_size,
                "status": job.status,
                "output_model": job.output_model,
                "created_at": job.created_at,
            }
        )
    return summaries
|
|
|
|
|
|
@router.get("/jobs/{job_id}")
async def get_job(
    job_id: int,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return the details of one fine-tuning job.

    Raises:
        HTTPException: 404 when no job with ``job_id`` exists.
    """
    lookup = select(FinetuneJob).where(FinetuneJob.id == job_id)
    job = (await db.execute(lookup)).scalar_one_or_none()
    if not job:
        raise HTTPException(404)

    # The loss history is stored as a JSON string; a missing value decodes to [].
    loss_history = json.loads(job.loss_history_json or "[]")
    return {
        "id": job.id,
        "base_model": job.base_model,
        "status": job.status,
        "output_model": job.output_model,
        "loss_history": loss_history,
        "created_at": job.created_at,
        "finished_at": job.finished_at,
    }
|
|
|
|
|
|
@router.post("/deploy/{job_id}")
async def deploy_model(
    job_id: int,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_admin_role),
):
    """Return deployment instructions for a completed fine-tuning job.

    No deployment is performed here; the response only names the output model
    and describes the manual steps.

    Raises:
        HTTPException: 404 when the job does not exist (same as ``get_job``);
            400 when the job exists but its status is not ``"COMPLETED"``.
    """
    lookup = select(FinetuneJob).where(FinetuneJob.id == job_id)
    job = (await db.execute(lookup)).scalar_one_or_none()

    # A missing job is "not found", not a bad request about an unfinished job.
    if not job:
        raise HTTPException(404)
    if job.status != "COMPLETED":
        raise HTTPException(400, "완료된 작업만 배포 가능")

    return {
        "ok": True,
        "model": job.output_model,
        "instruction": f"ollama pull {job.output_model} 후 Modelfile로 서빙",
        "note": "실제 배포 시 GGUF 변환 + Ollama Modelfile 생성 필요",
    }
|
|
|
|
|
|
@router.get("/models")
async def list_trained_models(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return the 10 most recently finished models from completed jobs."""
    completed = (
        select(FinetuneJob)
        .where(FinetuneJob.status == "COMPLETED")
        .order_by(desc(FinetuneJob.finished_at))
        .limit(10)
    )
    result = await db.execute(completed)

    models = []
    for job in result.scalars().all():
        models.append(
            {
                "model": job.output_model,
                "base": job.base_model,
                "dataset_size": job.dataset_size,
                "finished": job.finished_at,
            }
        )
    return models
|
|
|
|
|
|
@router.get("/quality")
async def data_quality(
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Report dataset size and quality statistics.

    Returns:
        A dict with:
        - ``total_samples``: all stored samples.
        - ``high_quality``: samples with ``quality_score >= 0.7``.
        - ``quality_rate``: ``high_quality`` as a percentage of the total,
          rounded to one decimal place.
        - ``by_domain``: counts for the general/incident/deploy/security domains.
        - ``trainable_samples``: high-quality samples not labelled NEGATIVE,
          i.e. the samples the ``/start`` endpoint counts.
        - ``ready_for_training``: whether ``trainable_samples`` reaches 50.
    """

    async def _count(*conditions):
        # COUNT(id) over the samples matching every given condition.
        stmt = select(func.count(FeedbackSample.id))
        if conditions:
            stmt = stmt.where(*conditions)
        return (await db.execute(stmt)).scalar() or 0

    total = await _count()
    high_quality = await _count(FeedbackSample.quality_score >= 0.7)

    # Readiness must use the same filter as /start; otherwise this endpoint
    # can report "ready" while /start still refuses for lack of data.
    trainable = await _count(
        FeedbackSample.quality_score >= 0.7,
        FeedbackSample.label_type != "NEGATIVE",
    )

    by_domain = {}
    for domain in ["general", "incident", "deploy", "security"]:
        by_domain[domain] = await _count(FeedbackSample.domain == domain)

    return {
        "total_samples": total,
        "high_quality": high_quality,
        # max(total, 1) avoids a division by zero on an empty table.
        "quality_rate": round(high_quality / max(total, 1) * 100, 1),
        "by_domain": by_domain,
        "trainable_samples": trainable,
        "ready_for_training": trainable >= 50,
    }
|