444 lines
14 KiB
Python
444 lines
14 KiB
Python
"""
|
|
테넌트별 개인화 AI — 파인튜닝·질의·KB 관리
|
|
|
|
기능:
|
|
- 테넌트별 Ollama 모델 현황 조회
|
|
- 파인튜닝(LoRA) 시작 및 진행 상황 추적
|
|
- 개인화 AI 질의 (테넌트 KB 컨텍스트 주입)
|
|
- 테넌트 전용 지식베이스(KB) CRUD
|
|
- 사용 통계
|
|
|
|
보안:
|
|
- 테넌트 데이터 완전 격리 (tenant_id 필터 강제)
|
|
- 외부 API 완전 금지 — Ollama localhost:11434 only
|
|
|
|
엔드포인트:
|
|
GET /api/tenant-ai/models — 테넌트별 모델 현황
|
|
POST /api/tenant-ai/train — 파인튜닝 시작
|
|
GET /api/tenant-ai/train/{id} — 학습 진행 상황
|
|
POST /api/tenant-ai/query — 개인화 AI 질의
|
|
GET /api/tenant-ai/kb — 테넌트 KB 문서 목록
|
|
POST /api/tenant-ai/kb — KB 문서 추가
|
|
GET /api/tenant-ai/stats — 사용 통계
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy import func, select, desc
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from core.auth import get_current_user
|
|
from database import get_db
|
|
from models import TenantAIModel, TenantKBDoc, User
|
|
|
|
logger = logging.getLogger(__name__)
|
|
router = APIRouter(prefix="/api/tenant-ai", tags=["Tenant AI"])
|
|
|
|
OLLAMA_URL = "http://localhost:11434"
|
|
|
|
# ── 파인튜닝 진행 상태 인메모리 캐시 (운영 환경에서는 DB/Redis로 대체 가능)
|
|
_train_jobs: Dict[int, Dict[str, Any]] = {}
|
|
|
|
|
|
# ── Pydantic 스키마 ──────────────────────────────────────────────────────────
|
|
|
|
class TrainRequest(BaseModel):
|
|
model_name: str = Field(..., max_length=100, description="신규 모델 이름 (테넌트 전용)")
|
|
base_model: str = Field("llama3", description="베이스 Ollama 모델")
|
|
description: Optional[str] = None
|
|
|
|
|
|
class TrainStatusOut(BaseModel):
|
|
id: int
|
|
tenant_id: str
|
|
model_name: str
|
|
base_model: str
|
|
status: str
|
|
accuracy: Optional[float]
|
|
dataset_size: int
|
|
created_at: datetime
|
|
|
|
|
|
class QueryRequest(BaseModel):
|
|
question: str = Field(..., min_length=1, max_length=2000)
|
|
model_name: Optional[str] = Field(None, description="사용할 테넌트 모델 이름 (미지정 시 기본 llama3)")
|
|
use_kb: bool = Field(True, description="테넌트 KB 컨텍스트 주입 여부")
|
|
top_k: int = Field(3, ge=1, le=10, description="KB 문서 최대 참조 수")
|
|
|
|
|
|
class QueryResponse(BaseModel):
|
|
answer: str
|
|
sources: List[str]
|
|
model_used: str
|
|
|
|
|
|
class KBDocCreate(BaseModel):
|
|
title: str = Field(..., max_length=300)
|
|
content: str = Field(..., min_length=1)
|
|
|
|
|
|
class KBDocOut(BaseModel):
|
|
id: int
|
|
tenant_id: str
|
|
title: str
|
|
content: str
|
|
created_at: datetime
|
|
|
|
|
|
class ModelOut(BaseModel):
|
|
id: int
|
|
tenant_id: str
|
|
model_name: str
|
|
base_model: str
|
|
status: str
|
|
accuracy: Optional[float]
|
|
dataset_size: int
|
|
created_at: datetime
|
|
|
|
|
|
# ── 내부 헬퍼 ────────────────────────────────────────────────────────────────
|
|
|
|
def _get_tenant_id(user: User) -> str:
|
|
"""현재 사용자의 테넌트 ID 반환 (inst_code 우선, 없으면 username)."""
|
|
return user.inst_code or user.username
|
|
|
|
|
|
async def _simulate_training(model_id: int, tenant_id: str) -> None:
|
|
"""
|
|
실제 LoRA 파인튜닝 대신 상태 전이만 시뮬레이션한다.
|
|
운영 환경에서는 Unsloth/LoRA 학습 프로세스로 교체한다.
|
|
"""
|
|
import asyncio
|
|
from database import SessionLocal
|
|
|
|
_train_jobs[model_id] = {"progress": 0, "message": "데이터셋 준비 중"}
|
|
await asyncio.sleep(2)
|
|
|
|
_train_jobs[model_id] = {"progress": 30, "message": "학습 진행 중 (30%)"}
|
|
await asyncio.sleep(3)
|
|
|
|
_train_jobs[model_id] = {"progress": 70, "message": "학습 진행 중 (70%)"}
|
|
await asyncio.sleep(2)
|
|
|
|
async with SessionLocal() as db:
|
|
row = await db.execute(
|
|
select(TenantAIModel).where(TenantAIModel.id == model_id)
|
|
)
|
|
model = row.scalar_one_or_none()
|
|
if model:
|
|
model.status = "ready"
|
|
model.accuracy = 0.91
|
|
await db.commit()
|
|
|
|
_train_jobs[model_id] = {"progress": 100, "message": "학습 완료"}
|
|
logger.info(f"[TenantAI] 모델 {model_id} 학습 완료 (tenant={tenant_id})")
|
|
|
|
|
|
# ── 엔드포인트 ───────────────────────────────────────────────────────────────
|
|
|
|
@router.get("/models", response_model=List[ModelOut])
|
|
async def list_models(
|
|
db: AsyncSession = Depends(get_db),
|
|
user: User = Depends(get_current_user),
|
|
):
|
|
"""테넌트별 AI 모델 현황 조회."""
|
|
tenant_id = _get_tenant_id(user)
|
|
rows = await db.execute(
|
|
select(TenantAIModel)
|
|
.where(TenantAIModel.tenant_id == tenant_id)
|
|
.order_by(desc(TenantAIModel.created_at))
|
|
)
|
|
models = rows.scalars().all()
|
|
return [
|
|
ModelOut(
|
|
id=m.id,
|
|
tenant_id=m.tenant_id,
|
|
model_name=m.model_name,
|
|
base_model=m.base_model,
|
|
status=m.status,
|
|
accuracy=m.accuracy,
|
|
dataset_size=m.dataset_size,
|
|
created_at=m.created_at,
|
|
)
|
|
for m in models
|
|
]
|
|
|
|
|
|
@router.post("/train", response_model=TrainStatusOut)
|
|
async def start_training(
|
|
req: TrainRequest,
|
|
background_tasks: BackgroundTasks,
|
|
db: AsyncSession = Depends(get_db),
|
|
user: User = Depends(get_current_user),
|
|
):
|
|
"""파인튜닝 작업 시작."""
|
|
tenant_id = _get_tenant_id(user)
|
|
|
|
# 동일 테넌트 내 학습 중인 모델 중복 방지
|
|
running_row = await db.execute(
|
|
select(TenantAIModel).where(
|
|
TenantAIModel.tenant_id == tenant_id,
|
|
TenantAIModel.status == "training",
|
|
)
|
|
)
|
|
if running_row.scalar_one_or_none():
|
|
raise HTTPException(409, "이미 학습 중인 모델이 있습니다. 완료 후 다시 시도하세요.")
|
|
|
|
# KB 문서 수 확인
|
|
kb_count_row = await db.execute(
|
|
select(func.count(TenantKBDoc.id)).where(TenantKBDoc.tenant_id == tenant_id)
|
|
)
|
|
kb_count = kb_count_row.scalar() or 0
|
|
|
|
model = TenantAIModel(
|
|
tenant_id=tenant_id,
|
|
model_name=req.model_name,
|
|
base_model=req.base_model,
|
|
dataset_size=kb_count,
|
|
status="training",
|
|
created_at=datetime.utcnow(),
|
|
)
|
|
db.add(model)
|
|
await db.commit()
|
|
await db.refresh(model)
|
|
|
|
# 백그라운드 학습
|
|
background_tasks.add_task(_simulate_training, model.id, tenant_id)
|
|
logger.info(f"[TenantAI] 파인튜닝 시작 (tenant={tenant_id}, model={req.model_name})")
|
|
|
|
return TrainStatusOut(
|
|
id=model.id,
|
|
tenant_id=model.tenant_id,
|
|
model_name=model.model_name,
|
|
base_model=model.base_model,
|
|
status=model.status,
|
|
accuracy=model.accuracy,
|
|
dataset_size=model.dataset_size,
|
|
created_at=model.created_at,
|
|
)
|
|
|
|
|
|
@router.get("/train/{model_id}", response_model=TrainStatusOut)
|
|
async def get_training_status(
|
|
model_id: int,
|
|
db: AsyncSession = Depends(get_db),
|
|
user: User = Depends(get_current_user),
|
|
):
|
|
"""학습 진행 상황 조회."""
|
|
tenant_id = _get_tenant_id(user)
|
|
row = await db.execute(
|
|
select(TenantAIModel).where(
|
|
TenantAIModel.id == model_id,
|
|
TenantAIModel.tenant_id == tenant_id, # 테넌트 격리
|
|
)
|
|
)
|
|
model = row.scalar_one_or_none()
|
|
if not model:
|
|
raise HTTPException(404, "모델을 찾을 수 없습니다")
|
|
|
|
# 인메모리 진행률 주입
|
|
job_info = _train_jobs.get(model_id, {})
|
|
progress = job_info.get("progress", 100 if model.status == "ready" else 0)
|
|
|
|
return TrainStatusOut(
|
|
id=model.id,
|
|
tenant_id=model.tenant_id,
|
|
model_name=model.model_name,
|
|
base_model=model.base_model,
|
|
status=model.status,
|
|
accuracy=model.accuracy,
|
|
dataset_size=model.dataset_size,
|
|
created_at=model.created_at,
|
|
)
|
|
|
|
|
|
@router.post("/query", response_model=QueryResponse)
|
|
async def query_ai(
|
|
req: QueryRequest,
|
|
db: AsyncSession = Depends(get_db),
|
|
user: User = Depends(get_current_user),
|
|
):
|
|
"""개인화 AI 질의 — 테넌트 KB 컨텍스트 주입 후 Ollama 호출."""
|
|
tenant_id = _get_tenant_id(user)
|
|
|
|
# 1. 테넌트 KB에서 관련 문서 검색 (단순 키워드 매칭)
|
|
kb_context = ""
|
|
sources: List[str] = []
|
|
if req.use_kb:
|
|
kb_rows = await db.execute(
|
|
select(TenantKBDoc)
|
|
.where(TenantKBDoc.tenant_id == tenant_id)
|
|
.order_by(desc(TenantKBDoc.created_at))
|
|
.limit(50)
|
|
)
|
|
kb_docs = kb_rows.scalars().all()
|
|
keywords = set(req.question.lower().split())
|
|
scored: List[tuple[int, TenantKBDoc]] = []
|
|
for doc in kb_docs:
|
|
score = sum(1 for k in keywords if k in (doc.content or "").lower())
|
|
if score > 0:
|
|
scored.append((score, doc))
|
|
scored.sort(key=lambda x: -x[0])
|
|
top_docs = [d for _, d in scored[: req.top_k]]
|
|
if top_docs:
|
|
kb_context = "\n\n".join(
|
|
f"[문서: {d.title}]\n{d.content[:500]}" for d in top_docs
|
|
)
|
|
sources = [d.title for d in top_docs]
|
|
|
|
# 2. 사용할 모델 결정 (테넌트 ready 모델 → 기본 llama3)
|
|
model_name = req.model_name
|
|
if not model_name:
|
|
ready_row = await db.execute(
|
|
select(TenantAIModel).where(
|
|
TenantAIModel.tenant_id == tenant_id,
|
|
TenantAIModel.status == "ready",
|
|
).order_by(desc(TenantAIModel.created_at))
|
|
)
|
|
ready_model = ready_row.scalar_one_or_none()
|
|
model_name = ready_model.model_name if ready_model else "llama3"
|
|
|
|
# 3. Ollama 호출 (localhost only)
|
|
system_prompt = (
|
|
"당신은 GUARDiA ITSM 전문 AI 어시스턴트입니다. "
|
|
"한국어로 간결하고 정확하게 답변하세요."
|
|
)
|
|
if kb_context:
|
|
system_prompt += f"\n\n참고 문서:\n{kb_context}"
|
|
|
|
prompt = f"{system_prompt}\n\n질문: {req.question}"
|
|
|
|
try:
|
|
async with httpx.AsyncClient(timeout=30) as client:
|
|
resp = await client.post(
|
|
f"{OLLAMA_URL}/api/generate",
|
|
json={
|
|
"model": model_name,
|
|
"prompt": prompt,
|
|
"stream": False,
|
|
"options": {"temperature": 0.3, "num_predict": 512},
|
|
},
|
|
)
|
|
if resp.status_code == 200:
|
|
answer = resp.json().get("response", "").strip()
|
|
else:
|
|
answer = "AI 응답을 가져오지 못했습니다. 잠시 후 다시 시도하세요."
|
|
except Exception as e:
|
|
logger.warning(f"[TenantAI] Ollama 호출 실패: {e}")
|
|
answer = "AI 서비스에 일시적 문제가 발생했습니다."
|
|
|
|
return QueryResponse(answer=answer, sources=sources, model_used=model_name)
|
|
|
|
|
|
@router.get("/kb", response_model=List[KBDocOut])
|
|
async def list_kb(
|
|
limit: int = 50,
|
|
offset: int = 0,
|
|
db: AsyncSession = Depends(get_db),
|
|
user: User = Depends(get_current_user),
|
|
):
|
|
"""테넌트 KB 문서 목록."""
|
|
tenant_id = _get_tenant_id(user)
|
|
rows = await db.execute(
|
|
select(TenantKBDoc)
|
|
.where(TenantKBDoc.tenant_id == tenant_id)
|
|
.order_by(desc(TenantKBDoc.created_at))
|
|
.offset(offset)
|
|
.limit(limit)
|
|
)
|
|
docs = rows.scalars().all()
|
|
return [
|
|
KBDocOut(
|
|
id=d.id,
|
|
tenant_id=d.tenant_id,
|
|
title=d.title,
|
|
content=d.content,
|
|
created_at=d.created_at,
|
|
)
|
|
for d in docs
|
|
]
|
|
|
|
|
|
@router.post("/kb", response_model=KBDocOut, status_code=201)
|
|
async def add_kb_doc(
|
|
req: KBDocCreate,
|
|
db: AsyncSession = Depends(get_db),
|
|
user: User = Depends(get_current_user),
|
|
):
|
|
"""KB 문서 추가."""
|
|
tenant_id = _get_tenant_id(user)
|
|
doc = TenantKBDoc(
|
|
tenant_id=tenant_id,
|
|
title=req.title,
|
|
content=req.content,
|
|
created_at=datetime.utcnow(),
|
|
)
|
|
db.add(doc)
|
|
await db.commit()
|
|
await db.refresh(doc)
|
|
logger.info(f"[TenantAI] KB 문서 추가 (tenant={tenant_id}, id={doc.id})")
|
|
return KBDocOut(
|
|
id=doc.id,
|
|
tenant_id=doc.tenant_id,
|
|
title=doc.title,
|
|
content=doc.content,
|
|
created_at=doc.created_at,
|
|
)
|
|
|
|
|
|
@router.get("/stats")
|
|
async def get_stats(
|
|
db: AsyncSession = Depends(get_db),
|
|
user: User = Depends(get_current_user),
|
|
):
|
|
"""테넌트 AI 사용 통계."""
|
|
tenant_id = _get_tenant_id(user)
|
|
|
|
# 모델 통계
|
|
model_count_row = await db.execute(
|
|
select(func.count(TenantAIModel.id)).where(TenantAIModel.tenant_id == tenant_id)
|
|
)
|
|
model_count = model_count_row.scalar() or 0
|
|
|
|
ready_count_row = await db.execute(
|
|
select(func.count(TenantAIModel.id)).where(
|
|
TenantAIModel.tenant_id == tenant_id,
|
|
TenantAIModel.status == "ready",
|
|
)
|
|
)
|
|
ready_count = ready_count_row.scalar() or 0
|
|
|
|
# KB 통계
|
|
kb_count_row = await db.execute(
|
|
select(func.count(TenantKBDoc.id)).where(TenantKBDoc.tenant_id == tenant_id)
|
|
)
|
|
kb_count = kb_count_row.scalar() or 0
|
|
|
|
# 최신 모델 정보
|
|
latest_row = await db.execute(
|
|
select(TenantAIModel)
|
|
.where(TenantAIModel.tenant_id == tenant_id)
|
|
.order_by(desc(TenantAIModel.created_at))
|
|
)
|
|
latest = latest_row.scalar_one_or_none()
|
|
|
|
return {
|
|
"tenant_id": tenant_id,
|
|
"total_models": model_count,
|
|
"ready_models": ready_count,
|
|
"kb_documents": kb_count,
|
|
"latest_model": {
|
|
"id": latest.id,
|
|
"name": latest.model_name,
|
|
"status": latest.status,
|
|
"accuracy": latest.accuracy,
|
|
} if latest else None,
|
|
}
|