161 lines
5.8 KiB
Python
161 lines
5.8 KiB
Python
"""미래 준비 — LoRA 자동 파인튜닝 파이프라인 관리."""
|
|
from __future__ import annotations
|
|
import logging
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select, desc
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from core.auth import get_current_user, require_admin_role
|
|
from database import get_db
|
|
from models import User, AutoFinetuneJob
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every endpoint declared in this file is mounted under /api/finetune.
router = APIRouter(prefix="/api/finetune", tags=["미래준비-파인튜닝"])

# Minimum number of collected samples required before training is attempted;
# _run_finetune records an error message instead of training when fewer exist.
MIN_DATASET_SIZE = 50
|
|
|
|
|
|
class FinetuneRequest(BaseModel):
    """Request body for starting a LoRA fine-tuning job (POST /api/finetune/start)."""

    # Base model name; stored on the job row and later passed to the training
    # script as --model (and used to build the output directory name).
    model_name: str = "llama3"
    # Number of training epochs; stored on the job row and passed as --epochs.
    epochs: int = 3
    # Free-form note.
    # NOTE(review): nothing in this file reads this field — confirm whether it
    # is meant to be persisted on the job.
    note: str = ""
|
|
|
|
|
|
_FINETUNE_SCRIPT = "/opt/guardia/scripts/lora_finetune.py"
_MODELS_DIR = "/opt/guardia/models"
_FINETUNE_TIMEOUT_SEC = 3600
_MAX_SAMPLES = 500
_WEBHOOK_URL = "http://127.0.0.1:9001/api/messenger/webhook"


def _parse_loss(output: str) -> "float | None":
    """Return the last loss value found in the trainer output, or None.

    Accepts the common spellings ``loss=0.12``, ``loss: 0.12``, ``Loss 0.12``
    and ``{'loss': 0.12}``. Non-finite values (``nan``/``inf``) are not
    matched, so they are never stored on the job.
    """
    import re

    found = re.findall(
        r"loss['\"]?\s*[:=]*\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)",
        output,
        flags=re.IGNORECASE,
    )
    # The last match wins so that the final reported loss is the one recorded.
    return float(found[-1]) if found else None


def _finetune_output_dir(model_name: str) -> str:
    """Return the output directory for *model_name*, rejecting path escapes.

    Raises:
        ValueError: if the resulting path resolves outside the models directory
            (e.g. a model name containing ``..`` segments).
    """
    import os

    candidate = f"{_MODELS_DIR}/{model_name}-finetuned"
    base = os.path.realpath(_MODELS_DIR)
    if os.path.commonpath([base, os.path.realpath(candidate)]) != base:
        raise ValueError(f"허용되지 않는 모델 이름: {model_name!r}")
    return candidate


def _write_dataset(samples) -> str:
    """Write *samples* as prompt/completion JSONL to a temp file; return its path.

    The file is removed again if writing fails, so a failure here never leaks it.
    """
    import json
    import os
    import tempfile

    handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
    )
    try:
        with handle:
            for sr in samples:
                record = {
                    "prompt": getattr(sr, "description", ""),
                    "completion": getattr(sr, "resolution", ""),
                }
                handle.write(json.dumps(record, ensure_ascii=False) + "\n")
    except Exception:
        os.unlink(handle.name)
        raise
    return handle.name


async def _run_finetune(job_id: int):
    """Run one LoRA fine-tuning job end to end (background task).

    Steps: mark the job as running, collect completed service requests as
    training samples, run the external training script, persist the outcome
    on the job row, and post a webhook notification on success.

    Final job status:
        * ``success`` - the training script exited with code 0.
        * ``skipped`` - fewer than ``MIN_DATASET_SIZE`` samples were available.
        * ``failed``  - the script exited non-zero, timed out, or raised.

    Args:
        job_id: Primary key of the ``AutoFinetuneJob`` row to execute. If no
            such row exists the function returns without doing anything.
    """
    import asyncio
    import os

    from database import AsyncSessionLocal

    async with AsyncSessionLocal() as db:
        job = await db.get(AutoFinetuneJob, job_id)
        if not job:
            return
        # Copy the values needed later *before* committing: after commit the
        # instance may be expired and it is detached once the session closes,
        # so reading attributes from it afterwards is not safe.
        model_name = job.model_name
        epochs = job.epochs
        job.status = "running"
        job.started_at = datetime.utcnow()
        await db.commit()

    # Collect service-request history as training data (best effort).
    try:
        from models import ServiceRequest

        async with AsyncSessionLocal() as db:
            rows = await db.execute(
                select(ServiceRequest)
                .where(ServiceRequest.status == "완료")
                .limit(_MAX_SAMPLES)
            )
            samples = rows.scalars().all()
    except Exception:
        # Keep going with an empty dataset, but leave a trace of the cause.
        logger.exception("SR 이력 데이터 수집 실패")
        samples = []

    dataset_size = len(samples)
    success = False
    skipped = False
    loss = None
    error_msg = None

    if dataset_size < MIN_DATASET_SIZE:
        skipped = True
        error_msg = f"데이터 부족: {dataset_size}개 (최소 {MIN_DATASET_SIZE}개 필요). 다음 달 재시도 예정."
        logger.warning(error_msg)
    else:
        dataset_path = None
        proc = None
        try:
            output_dir = _finetune_output_dir(model_name)
            dataset_path = _write_dataset(samples)

            proc = await asyncio.create_subprocess_exec(
                "python3", _FINETUNE_SCRIPT,
                "--model", model_name,
                "--dataset", dataset_path,
                "--epochs", str(epochs),
                "--output", output_dir,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
            )
            stdout, _ = await asyncio.wait_for(
                proc.communicate(), timeout=_FINETUNE_TIMEOUT_SEC
            )
            output = stdout.decode("utf-8", "replace")
            success = proc.returncode == 0
            loss = _parse_loss(output)
            if not success:
                error_msg = f"파인튜닝 스크립트 비정상 종료 (exit code {proc.returncode})"
                # Only the tail is logged; full trainer output can be very long.
                logger.error("%s: %s", error_msg, output[-2000:])
        except asyncio.TimeoutError:
            # str() of a timeout error is empty, so spell the reason out.
            error_msg = f"파인튜닝 시간 초과 ({_FINETUNE_TIMEOUT_SEC}초)"
            logger.error("파인튜닝 실패: %s", error_msg)
        except Exception as e:
            error_msg = str(e) or type(e).__name__
            logger.exception("파인튜닝 실패: %s", e)
        finally:
            if proc is not None and proc.returncode is None:
                # Timed out or interrupted: do not leave the trainer running.
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass
                await proc.wait()
            if dataset_path is not None:
                try:
                    os.unlink(dataset_path)
                except OSError:
                    logger.warning("임시 데이터셋 파일 삭제 실패: %s", dataset_path)

    if success:
        status = "success"
    elif skipped:
        status = "skipped"
    else:
        status = "failed"

    async with AsyncSessionLocal() as db:
        job = await db.get(AutoFinetuneJob, job_id)
        if job:
            job.status = status
            job.dataset_size = dataset_size
            job.loss = loss
            job.error_msg = error_msg
            job.finished_at = datetime.utcnow()
            await db.commit()

    if success:
        # The trainer may not print a parsable loss; never format None as float.
        loss_text = f"{loss:.4f}" if loss is not None else "N/A"
        try:
            import httpx

            async with httpx.AsyncClient(timeout=5) as c:
                await c.post(_WEBHOOK_URL, json={
                    "event": "finetune_complete",
                    "room": "ops",
                    "message": f"🧠 LoRA 파인튜닝 완료: {model_name} (loss={loss_text})",
                })
        except Exception:
            # Notification is best effort; the job result is already stored.
            logger.warning("파인튜닝 완료 알림 전송 실패", exc_info=True)
|
|
|
|
|
|
@router.post("/start")
async def start_finetune(
    req: FinetuneRequest,
    background_tasks: BackgroundTasks,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(require_admin_role),
):
    """Create a LoRA fine-tuning job and schedule it as a background task.

    Access is gated by the ``require_admin_role`` dependency.

    Args:
        req: Model name and epoch count for the job.
        background_tasks: FastAPI task queue the job runner is added to.
        db: Request-scoped database session.
        user: Resolved by the auth dependency; not otherwise used here.

    Returns:
        A dict with ``ok``, the new ``job_id`` and a status ``message``.

    Raises:
        HTTPException: 422 if ``model_name`` is blank or ``epochs`` is below 1.
    """
    # Reject values the training script cannot do anything useful with, rather
    # than persisting a job that is certain to fail later in the background.
    if not req.model_name.strip():
        raise HTTPException(422, "model_name은 비어 있을 수 없습니다")
    if req.epochs < 1:
        raise HTTPException(422, "epochs는 1 이상이어야 합니다")

    job = AutoFinetuneJob(
        model_name=req.model_name,
        epochs=req.epochs,
        status="pending",
        created_at=datetime.utcnow(),
    )
    db.add(job)
    await db.commit()
    # Reload the row so the generated primary key is available.
    await db.refresh(job)
    background_tasks.add_task(_run_finetune, job.id)
    return {"ok": True, "job_id": job.id, "message": "파인튜닝 작업 시작됨 (백그라운드)"}
|
|
|
|
|
|
@router.get("/jobs")
async def list_jobs(
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return fine-tuning jobs ordered by creation time, newest first.

    Args:
        limit: Maximum number of jobs to return (default 20).
        db: Request-scoped database session.
        user: Resolved by the auth dependency; not otherwise used here.

    Returns:
        A list of dicts, one per job, with id, model name, status, dataset
        size, loss, epochs, start/finish timestamps and error message.

    Raises:
        HTTPException: 422 if ``limit`` is negative.
    """
    # A negative LIMIT is a SQL error on some backends and "no limit" on
    # others; reject it up front instead of relying on the database.
    if limit < 0:
        raise HTTPException(422, "limit은 0 이상이어야 합니다")

    rows = await db.execute(
        select(AutoFinetuneJob).order_by(desc(AutoFinetuneJob.created_at)).limit(limit)
    )
    return [
        {
            "id": j.id,
            "model_name": j.model_name,
            "status": j.status,
            "dataset_size": j.dataset_size,
            "loss": j.loss,
            "epochs": j.epochs,
            "started_at": j.started_at,
            "finished_at": j.finished_at,
            "error_msg": j.error_msg,
        }
        for j in rows.scalars().all()
    ]
|
|
|
|
|
|
@router.get("/jobs/{job_id}")
async def get_job(
    job_id: int,
    db: AsyncSession = Depends(get_db),
    user: User = Depends(get_current_user),
):
    """Return a single fine-tuning job by primary key.

    Raises:
        HTTPException: 404 if no job with ``job_id`` exists.
    """
    record = await db.get(AutoFinetuneJob, job_id)
    if record:
        # Field order defines the key order of the response body.
        exposed = (
            "id",
            "model_name",
            "status",
            "dataset_size",
            "loss",
            "epochs",
            "started_at",
            "finished_at",
            "error_msg",
        )
        return {field: getattr(record, field) for field in exposed}
    raise HTTPException(404, "작업 없음")
|