guardia-itsm/routers/auto_finetune.py

161 lines
5.8 KiB
Python

"""미래 준비 — LoRA 자동 파인튜닝 파이프라인 관리."""
from __future__ import annotations
import logging
from datetime import datetime
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select, desc
from sqlalchemy.ext.asyncio import AsyncSession
from core.auth import get_current_user, require_admin_role
from database import get_db
from models import User, AutoFinetuneJob
# Module-level logger, plus the router that mounts every endpoint below
# under /api/finetune.
logger = logging.getLogger(__name__)
router = APIRouter(tags=["미래준비-파인튜닝"], prefix="/api/finetune")
# Minimum number of completed SRs needed before a training run is attempted.
# _run_finetune records the job as skipped when fewer samples are available.
MIN_DATASET_SIZE = 50
class FinetuneRequest(BaseModel):
    # Request body for POST /api/finetune/start.
    # Comments are used instead of a docstring because pydantic publishes a
    # model docstring as the OpenAPI schema description.
    #
    # Model identifier. _run_finetune passes it to the training script as
    # --model and also uses it in the output directory name
    # ("<model_name>-finetuned").
    model_name: str = "llama3"
    # Number of training epochs, forwarded to the training script as --epochs.
    epochs: int = 3
    # Free-text note.
    # NOTE(review): start_finetune does not copy this onto the AutoFinetuneJob
    # row, so it is currently discarded. Confirm whether it should be stored.
    note: str = ""
# Upper bound on completed SRs exported into a single training dataset.
_MAX_FINETUNE_SAMPLES = 500
# Wall-clock limit (seconds) for the training subprocess; it is killed after this.
_FINETUNE_TIMEOUT_SEC = 3600
# Training script run as a subprocess, and the directory its output goes to.
_LORA_SCRIPT = "/opt/guardia/scripts/lora_finetune.py"
_FINETUNED_MODEL_DIR = "/opt/guardia/models"
# Matches loss values such as "loss: 0.123", "loss=1.2e-3", "Loss 0.5",
# "'loss': 0.41" and "'train_loss': 0.8" (case-insensitive, single line).
# Group 1 is the number.
_LOSS_PATTERN = r"""loss['"]?[ \t]*[:=]?[ \t]*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"""


def _parse_last_loss(output: str) -> float | None:
    """Return the last loss value printed in *output*, or None if none is found."""
    import re

    matches = re.findall(_LOSS_PATTERN, output, flags=re.IGNORECASE)
    # The regex only admits valid float literals, so float() cannot raise here.
    return float(matches[-1]) if matches else None


async def _train_lora(samples, model_name: str, epochs: int) -> tuple[bool, float | None, str | None]:
    """Export *samples* to a temporary JSONL file and run the LoRA training script.

    Each sample becomes one ``{"prompt": description, "completion": resolution}``
    line. Returns ``(success, loss, error_msg)``. ``success`` is True only when
    the script exits with code 0.

    The dataset file is always deleted, because it contains SR text. The child
    process is killed if it outlives ``_FINETUNE_TIMEOUT_SEC`` or if this task
    is cancelled.
    """
    import asyncio
    import json
    import os
    import tempfile

    fd, dataset_path = tempfile.mkstemp(suffix=".jsonl")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            for sr in samples:
                # `or ""` keeps NULL columns from being exported as JSON null.
                record = {
                    "prompt": getattr(sr, "description", "") or "",
                    "completion": getattr(sr, "resolution", "") or "",
                }
                fh.write(json.dumps(record, ensure_ascii=False) + "\n")

        proc = await asyncio.create_subprocess_exec(
            "python3", _LORA_SCRIPT,
            "--model", model_name,
            "--dataset", dataset_path,
            "--epochs", str(epochs),
            "--output", f"{_FINETUNED_MODEL_DIR}/{model_name}-finetuned",
            stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT,
        )
        try:
            stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=_FINETUNE_TIMEOUT_SEC)
        except asyncio.TimeoutError:
            timeout_msg = f"타임아웃: {_FINETUNE_TIMEOUT_SEC}초 초과"
            logger.error("파인튜닝 실패: %s", timeout_msg)
            return False, None, timeout_msg
        finally:
            # wait_for only cancels communicate(); the child keeps running unless
            # we kill and reap it here (timeout or task cancellation).
            if proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # exited between the check and the kill
                await proc.wait()

        output = stdout.decode("utf-8", "replace")
        loss = _parse_last_loss(output)
        if proc.returncode != 0:
            # Keep only the tail of the log.
            # NOTE(review): kept short in case error_msg is a bounded column.
            tail = output[-200:].strip()
            error_msg = f"학습 스크립트 비정상 종료 (code={proc.returncode}): {tail}"
            logger.error("파인튜닝 실패: %s", error_msg)
            return False, loss, error_msg
        return True, loss, None
    except Exception as e:
        logger.exception("파인튜닝 실패: %s", e)
        # Some exceptions stringify to "" (e.g. bare TimeoutError); keep a hint.
        return False, None, str(e) or type(e).__name__
    finally:
        try:
            os.unlink(dataset_path)
        except OSError:
            pass  # already removed; nothing else to clean up


async def _notify_finetune_complete(model_name: str, loss: float | None) -> None:
    """Best-effort "finetune_complete" post to the ops room; failures are only logged."""
    # Formatting None with :.4f raises TypeError, so fall back to "n/a".
    loss_text = f"{loss:.4f}" if loss is not None else "n/a"
    try:
        import httpx

        async with httpx.AsyncClient(timeout=5) as c:
            await c.post("http://127.0.0.1:9001/api/messenger/webhook", json={
                "event": "finetune_complete",
                "room": "ops",
                "message": f"🧠 LoRA 파인튜닝 완료: {model_name} (loss={loss_text})",
            })
    except Exception:
        logger.warning("파인튜닝 완료 알림 전송 실패", exc_info=True)


async def _run_finetune(job_id: int):
    """Run one LoRA fine-tuning job end to end (scheduled as a background task).

    Steps:
    1. Mark the job ``running``.
    2. Collect up to ``_MAX_FINETUNE_SAMPLES`` SRs with status "완료".
    3. Train via ``_train_lora`` when at least ``MIN_DATASET_SIZE`` samples exist.
    4. Store the outcome on the job row.
    5. On success, notify the ops room.

    The final status is one of:
    - ``"skipped"``: not enough data.
    - ``"failed"``: data collection error, script error, non-zero exit or timeout.
    - ``"success"``: the script exited with code 0.

    Returns immediately if the job row does not exist.
    """
    from database import AsyncSessionLocal

    # Copy the job parameters into locals while the session is open. The ORM
    # instance may be expired by commit and is detached once the block exits,
    # so reading job.model_name later could fail.
    async with AsyncSessionLocal() as db:
        job = await db.get(AutoFinetuneJob, job_id)
        if not job:
            return
        model_name = job.model_name
        epochs = job.epochs
        job.status = "running"
        job.started_at = datetime.utcnow()
        await db.commit()

    # Collect completed SR history. A failure here is logged and recorded as
    # "failed" rather than being misreported as "not enough data".
    # NOTE(review): there is no ORDER BY, so which rows are picked is backend-defined.
    samples = []
    collect_error = None
    try:
        from models import ServiceRequest

        async with AsyncSessionLocal() as db:
            rows = await db.execute(
                select(ServiceRequest)
                .where(ServiceRequest.status == "완료")
                .limit(_MAX_FINETUNE_SAMPLES)
            )
            samples = rows.scalars().all()
    except Exception as e:
        logger.exception("파인튜닝 데이터 수집 실패 (job_id=%s)", job_id)
        collect_error = f"데이터 수집 실패: {e}"

    dataset_size = len(samples)
    loss = None
    if collect_error is not None:
        status, error_msg = "failed", collect_error
    elif dataset_size < MIN_DATASET_SIZE:
        status = "skipped"
        error_msg = f"데이터 부족: {dataset_size}개 (최소 {MIN_DATASET_SIZE}개 필요). 다음 달 재시도 예정."
        logger.warning(error_msg)
    else:
        success, loss, error_msg = await _train_lora(samples, model_name, epochs)
        status = "success" if success else "failed"

    async with AsyncSessionLocal() as db:
        job = await db.get(AutoFinetuneJob, job_id)
        if job:
            job.status = status
            job.dataset_size = dataset_size
            job.loss = loss
            job.error_msg = error_msg
            job.finished_at = datetime.utcnow()
            await db.commit()

    if status == "success":
        await _notify_finetune_complete(model_name, loss)
def _validate_finetune_request(req: FinetuneRequest) -> None:
    """Reject request values that are unsafe to hand to the training subprocess.

    ``model_name`` becomes a CLI argument and part of the output directory path.
    It must therefore start with an alphanumeric character, so it is not parsed
    as an option flag. It may not contain ``..`` path segments, so it cannot
    escape the models directory. ``epochs`` must be at least 1.

    Raises:
        HTTPException: 422 when a value is rejected.
    """
    import re

    name = req.model_name
    if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9._:/-]{0,199}", name) or ".." in name.split("/"):
        raise HTTPException(422, "잘못된 모델 이름")
    if req.epochs < 1:
        raise HTTPException(422, "epochs는 1 이상이어야 합니다")


@router.post("/start")
async def start_finetune(
    req: FinetuneRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_admin_role),
):
    """LoRA 파인튜닝 작업 시작.\f
    (Text after the form feed is excluded from the OpenAPI description.)

    Validate the request, create a ``pending`` AutoFinetuneJob row, and schedule
    ``_run_finetune`` for it as a background task. Requires the admin role via
    ``require_admin_role``.

    Raises:
        HTTPException: 422 if ``model_name`` or ``epochs`` is invalid or unsafe.
    """
    _validate_finetune_request(req)
    # NOTE(review): req.note is accepted but not stored on the job. Confirm
    # whether AutoFinetuneJob has a column for it.
    job = AutoFinetuneJob(
        model_name=req.model_name, epochs=req.epochs,
        status="pending", created_at=datetime.utcnow(),
    )
    db.add(job)
    await db.commit()
    # Refresh to load the database-assigned primary key.
    await db.refresh(job)
    background_tasks.add_task(_run_finetune, job.id)
    return {"ok": True, "job_id": job.id, "message": "파인튜닝 작업 시작됨 (백그라운드)"}
# Largest page size list_jobs will return; larger requests are capped.
_MAX_JOB_LIST_LIMIT = 100


@router.get("/jobs")
async def list_jobs(
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """파인튜닝 작업 이력.\f
    (Text after the form feed is excluded from the OpenAPI description.)

    Return up to ``limit`` fine-tuning jobs, newest ``created_at`` first.
    ``limit`` is clamped to ``0.._MAX_JOB_LIST_LIMIT``, for two reasons: a
    negative LIMIT is a database error on some backends (e.g. PostgreSQL), and
    an unbounded one lets any caller dump the whole table in one request.
    """
    limit = max(0, min(limit, _MAX_JOB_LIST_LIMIT))
    rows = await db.execute(
        select(AutoFinetuneJob).order_by(desc(AutoFinetuneJob.created_at)).limit(limit)
    )
    return [{
        "id": j.id, "model_name": j.model_name, "status": j.status,
        "dataset_size": j.dataset_size, "loss": j.loss, "epochs": j.epochs,
        "started_at": j.started_at, "finished_at": j.finished_at,
        "error_msg": j.error_msg,
    } for j in rows.scalars().all()]
@router.get("/jobs/{job_id}")
async def get_job(
    job_id: int,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    # Fetch one fine-tuning job by primary key and return its public fields;
    # responds 404 when no such job exists. Guarded by get_current_user
    # (not require_admin_role). No docstring, to keep the OpenAPI description empty.
    found = await db.get(AutoFinetuneJob, job_id)
    if not found:
        raise HTTPException(status_code=404, detail="작업 없음")
    # Field order matches the response keys exposed by this endpoint.
    exposed = (
        "id", "model_name", "status", "dataset_size", "loss", "epochs",
        "started_at", "finished_at", "error_msg",
    )
    return {attr: getattr(found, attr) for attr in exposed}