281 lines
10 KiB
Python
281 lines
10 KiB
Python
"""
|
|
SR 채팅방 API + WebSocket (모바일 기능 #98).
|
|
|
|
WS /ws/sr-chat/{sr_id}?token={jwt} — SR별 실시간 채팅
|
|
POST /api/sr-chat/{sr_id}/messages — 메시지 전송 (REST)
|
|
GET /api/sr-chat/{sr_id}/messages — 메시지 이력
|
|
POST /api/sr-chat/{sr_id}/read — 읽음 처리
|
|
|
|
메시지 타입: text | image | sr_update
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Set
|
|
|
|
from fastapi import (
|
|
APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect,
|
|
)
|
|
from pydantic import BaseModel, ConfigDict
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from core.auth import get_current_user
|
|
from database import get_db, SessionLocal
|
|
from models import SRChatMessage, SRRequest, User
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router that the REST and WebSocket endpoints defined in this file attach to.
router = APIRouter(tags=["SR Chat"])

# Message types accepted when a chat message is saved.
_VALID_MSG_TYPE = {"text", "image", "sr_update"}
|
|
|
|
|
|
# ── WebSocket 연결 관리 (SR방별 그룹) ─────────────────────────────────────────
|
|
class _ChatRooms:
    """In-memory registry of open WebSocket connections, grouped by SR id.

    Each SR id maps to the set of sockets currently joined to that SR's
    chat room. A room entry exists only while it has at least one socket.
    """

    def __init__(self) -> None:
        # { sr_id: set(WebSocket) }
        self._rooms: Dict[str, Set[WebSocket]] = {}

    def join(self, sr_id: str, ws: WebSocket) -> None:
        """Add ``ws`` to the room for ``sr_id``, creating the room if needed."""
        self._rooms.setdefault(sr_id, set()).add(ws)

    def leave(self, sr_id: str, ws: WebSocket) -> None:
        """Remove ``ws`` from the room for ``sr_id``; drop the room once empty.

        Unknown rooms and sockets that never joined are ignored.
        """
        room = self._rooms.get(sr_id)
        if room is None:
            return
        room.discard(ws)
        # Checked even when the room was already empty, so a stale empty
        # entry can never survive a leave() call.
        if not room:
            self._rooms.pop(sr_id, None)

    async def broadcast(self, sr_id: str, payload: dict) -> None:
        """Send ``payload`` as a JSON text frame to every socket in the room.

        Sockets whose send raises are treated as dead and removed from the
        room; the failure is not propagated to the caller. Does nothing when
        the room does not exist.
        """
        room = self._rooms.get(sr_id)
        if not room:
            return
        msg = json.dumps(payload, ensure_ascii=False)
        dead = []
        # Iterate over a snapshot: join()/leave() can run while a send is
        # awaited, and mutating a set during iteration raises.
        for ws in list(room):
            try:
                await ws.send_text(msg)
            except Exception:
                # Best effort: one broken connection must not stop delivery
                # to the remaining subscribers.
                dead.append(ws)
        room.difference_update(dead)
        # Drop the room if pruning emptied it. The identity check guards
        # against removing a fresh room that replaced this one during a send.
        if not room and self._rooms.get(sr_id) is room:
            self._rooms.pop(sr_id, None)


rooms = _ChatRooms()
|
|
|
|
|
|
# ── 스키마 ────────────────────────────────────────────────────────────────────
|
|
class ChatMessageCreate(BaseModel):
    # Request body accepted by the REST "send message" endpoint.

    # Message text; rejected as empty by _save_message if blank.
    content: str

    # Message type; must be one of _VALID_MSG_TYPE (checked in _save_message).
    msg_type: str = "text"
|
|
|
|
|
|
class ChatMessageOut(BaseModel):
    # Response shape for a stored chat message.
    # from_attributes lets the model be populated from attribute access on
    # the object an endpoint returns, rather than from a dict.
    model_config = ConfigDict(from_attributes=True)

    # Identifier of the stored message.
    id: int

    # SR id the message belongs to (_save_message stores sr_id here).
    task_id: str

    # Sender identifier (the endpoints pass the sender's username).
    sender_id: str

    # Message text.
    content: str

    # Message type, one of _VALID_MSG_TYPE.
    msg_type: str

    # Creation timestamp; may be None.
    created_at: Optional[datetime]
|
|
|
|
|
|
# ── 헬퍼 ──────────────────────────────────────────────────────────────────────
|
|
async def _ensure_sr(sr_id: str, db: AsyncSession) -> SRRequest:
    """Return the SR whose ``sr_id`` matches, or fail with HTTP 404.

    Args:
        sr_id: Identifier compared against ``SRRequest.sr_id``.
        db: Session used to run the lookup.

    Raises:
        HTTPException: 404 when no matching SR row is found.
    """
    lookup = select(SRRequest).where(SRRequest.sr_id == sr_id)
    result = await db.execute(lookup)
    found = result.scalars().first()
    if found:
        return found
    raise HTTPException(404, "SR을 찾을 수 없습니다.")
|
|
|
|
|
|
async def _save_message(db: AsyncSession, sr_id: str, sender: str,
                        content: str, msg_type: str) -> SRChatMessage:
    """Validate and persist one chat message, returning the refreshed row.

    The sender is recorded as the first reader of the message. ``content``
    is stored exactly as given; it is only checked for being non-blank.

    Args:
        db: Session used to insert and commit the message.
        sr_id: SR the message belongs to (stored as ``task_id``).
        sender: Identifier of the sender (stored as ``sender_id``).
        content: Message text.
        msg_type: Must be one of ``_VALID_MSG_TYPE``.

    Raises:
        HTTPException: 422 when ``msg_type`` is not allowed or ``content``
            is empty or whitespace-only.
    """
    if msg_type not in _VALID_MSG_TYPE:
        # Render the allowed values in sorted order: interpolating the set
        # directly yields a different element order from run to run.
        allowed = "{" + ", ".join(repr(t) for t in sorted(_VALID_MSG_TYPE)) + "}"
        raise HTTPException(422, f"msg_type은 {allowed} 중 하나여야 합니다.")
    if not content or not content.strip():
        raise HTTPException(422, "메시지 내용이 비어 있습니다.")
    m = SRChatMessage(
        task_id=sr_id,
        sender_id=sender,
        content=content,
        msg_type=msg_type,
        # read_by is a JSON-encoded list of reader identifiers.
        read_by=json.dumps([sender], ensure_ascii=False),
    )
    db.add(m)
    await db.commit()
    # Reload the row so database-populated attributes are available to callers.
    await db.refresh(m)
    return m
|
|
|
|
|
|
# ── REST: 메시지 전송 ─────────────────────────────────────────────────────────
|
|
@router.post("/api/sr-chat/{sr_id}/messages", response_model=ChatMessageOut, status_code=201)
async def send_message(
    sr_id: str,
    payload: ChatMessageCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Store a chat message sent over REST and push it to WebSocket subscribers.

    The SR must exist (404 otherwise). The saved message is broadcast to
    every socket joined to the SR's room and then returned to the caller.
    """
    await _ensure_sr(sr_id, db)

    saved = await _save_message(
        db,
        sr_id,
        current_user.username,
        payload.content,
        payload.msg_type,
    )

    # Event pushed to room subscribers; same fields as the REST response
    # plus a "type" discriminator.
    event = {
        "type": "message",
        "id": saved.id,
        "task_id": sr_id,
        "sender_id": saved.sender_id,
        "content": saved.content,
        "msg_type": saved.msg_type,
        "created_at": saved.created_at.isoformat() if saved.created_at else None,
    }
    await rooms.broadcast(sr_id, event)
    return saved
|
|
|
|
|
|
# ── REST: 메시지 이력 ─────────────────────────────────────────────────────────
|
|
@router.get("/api/sr-chat/{sr_id}/messages", response_model=List[ChatMessageOut])
async def list_messages(
    sr_id: str,
    skip: int = 0,
    limit: int = 100,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the chat history of an SR, oldest message first.

    Args:
        sr_id: SR whose messages are listed; 404 when it does not exist.
        skip: Number of messages to skip; negative values are treated as 0.
        limit: Maximum number of messages to return, clamped to 0..500.
    """
    await _ensure_sr(sr_id, db)

    # Clamp both bounds. Without a lower bound a negative value reaches the
    # database as a negative OFFSET/LIMIT, which either fails or (for LIMIT
    # on some backends) removes the cap altogether.
    offset = max(skip, 0)
    page_size = max(0, min(limit, 500))

    rows = (await db.execute(
        select(SRChatMessage)
        .where(SRChatMessage.task_id == sr_id)
        # id breaks ties between equal timestamps so consecutive pages
        # neither repeat nor skip rows.
        .order_by(SRChatMessage.created_at.asc(), SRChatMessage.id.asc())
        .offset(offset).limit(page_size)
    )).scalars().all()
    return rows
|
|
|
|
|
|
# ── REST: 읽음 처리 ───────────────────────────────────────────────────────────
|
|
def _load_readers(raw: Optional[str]) -> List[str]:
    """Decode a stored ``read_by`` value into a list of reader identifiers.

    Returns an empty list when the value is missing, is not valid JSON, or
    decodes to anything other than a list.
    """
    if not raw:
        return []
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError, RecursionError):
        return []
    # "null", a number or an object are valid JSON but cannot be used as a
    # reader list; treating them as "nobody has read this" avoids a crash.
    return parsed if isinstance(parsed, list) else []


@router.post("/api/sr-chat/{sr_id}/read")
async def mark_read(
    sr_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Mark every message of the SR's chat as read by the current user.

    Returns:
        A dict with the SR id, the number of messages newly marked as read
        (``marked_read``) and the reader's username.

    Raises:
        HTTPException: 404 when the SR does not exist.
    """
    await _ensure_sr(sr_id, db)
    rows = (await db.execute(
        select(SRChatMessage).where(SRChatMessage.task_id == sr_id)
    )).scalars().all()

    username = current_user.username
    updated = 0
    for m in rows:
        readers = _load_readers(m.read_by)
        if username in readers:
            continue
        readers.append(username)
        m.read_by = json.dumps(readers, ensure_ascii=False)
        updated += 1

    await db.commit()
    return {"sr_id": sr_id, "marked_read": updated, "reader": username}
|
|
|
|
|
|
# ── WebSocket: 실시간 채팅 ────────────────────────────────────────────────────
|
|
async def _authenticate_ws(token: str, db: AsyncSession) -> Optional[User]:
    """Resolve a JWT passed on the WebSocket URL to an active user.

    Returns ``None`` instead of raising for every failure: missing token,
    undecodable token, a token flagged ``mfa_pending``, a token without a
    ``sub`` claim, an unknown user, an inactive user, or any error raised
    while checking.
    """
    if not token:
        return None

    try:
        # Imported inside the function, as in the original code.
        from core.auth import SECRET_KEY, ALGORITHM
        from jose import jwt

        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        # A token still waiting for its MFA step is not accepted.
        if claims.get("mfa_pending"):
            return None

        subject = claims.get("sub")
        if not subject:
            return None

        by_username = select(User).where(User.username == subject)
        account = (await db.execute(by_username)).scalars().first()
        if account and account.is_active:
            return account
        return None
    except Exception:
        return None
|
|
|
|
|
|
async def _ws_send_json(websocket: WebSocket, body: dict) -> None:
    """Serialize ``body`` as JSON (non-ASCII kept as-is) and send one text frame."""
    await websocket.send_text(json.dumps(body, ensure_ascii=False))


@router.websocket("/ws/sr-chat/{sr_id}")
async def sr_chat_ws(
    websocket: WebSocket,
    sr_id: str,
    token: str = Query(..., description="JWT access_token"),
    db: AsyncSession = Depends(get_db),
):
    """Real-time chat WebSocket for one SR.

    Handshake: the token is authenticated and the SR looked up; on failure
    the socket is closed with code 4001 (authentication) or 4004 (SR not
    found). On success the socket joins the SR's room and receives a
    ``connected`` frame.

    Incoming frames must be JSON objects:
      * ``{"type": "ping"}`` is answered with a ``pong`` frame.
      * anything else is treated as a chat message with ``content`` and an
        optional ``msg_type`` (default ``"text"``); it is saved and then
        broadcast to the room as a ``message`` frame.
    Malformed frames are answered with an ``error`` frame and the
    connection stays open.
    """
    user = await _authenticate_ws(token, db)
    if not user:
        await websocket.close(code=4001, reason="인증 실패: 유효한 토큰이 필요합니다.")
        return

    # Check that the SR exists. Done inline rather than through _ensure_sr,
    # which raises HTTPException instead of closing the socket.
    sr = (await db.execute(
        select(SRRequest).where(SRRequest.sr_id == sr_id)
    )).scalars().first()
    if not sr:
        await websocket.close(code=4004, reason="SR을 찾을 수 없습니다.")
        return

    # NOTE(review): `db` comes from Depends(get_db) and is presumably kept
    # open until this handler returns, i.e. for the whole connection —
    # confirm against get_db.
    await websocket.accept()
    rooms.join(sr_id, websocket)

    try:
        # Sent inside the try block so a failure here still reaches the
        # `finally` below and the socket does not stay registered.
        await _ws_send_json(websocket, {
            "type": "connected",
            "sr_id": sr_id,
            "username": user.username,
            "server_time": datetime.now().isoformat(),
        })

        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except (ValueError, RecursionError):
                await _ws_send_json(
                    websocket, {"type": "error", "message": "JSON 형식이 아닙니다."})
                continue

            # Valid JSON that is not an object (list, number, string, null)
            # has no .get(); answer with an error instead of dropping the
            # connection.
            if not isinstance(data, dict):
                await _ws_send_json(
                    websocket,
                    {"type": "error", "message": "content 또는 msg_type이 올바르지 않습니다."})
                continue

            if data.get("type") == "ping":
                await _ws_send_json(
                    websocket,
                    {"type": "pong", "server_time": datetime.now().isoformat()})
                continue

            raw_content = data.get("content")
            msg_type = data.get("msg_type", "text")
            # Non-string content counts as empty; a non-string msg_type is
            # rejected before the set lookup, which would raise for
            # unhashable values such as lists.
            content = raw_content.strip() if isinstance(raw_content, str) else ""
            if (not content or not isinstance(msg_type, str)
                    or msg_type not in _VALID_MSG_TYPE):
                await _ws_send_json(
                    websocket,
                    {"type": "error", "message": "content 또는 msg_type이 올바르지 않습니다."})
                continue

            # Save with an independent session, then broadcast to subscribers.
            async with SessionLocal() as _db:
                m = await _save_message(_db, sr_id, user.username, content, msg_type)
                # Built while the session is still open.
                payload = {
                    "type": "message",
                    "id": m.id,
                    "task_id": sr_id,
                    "sender_id": m.sender_id,
                    "content": m.content,
                    "msg_type": m.msg_type,
                    "created_at": m.created_at.isoformat() if m.created_at else None,
                }
            await rooms.broadcast(sr_id, payload)

    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in `finally`.
        pass
    except Exception as exc:
        logger.debug("SR 채팅 WS 오류: sr=%s err=%s", sr_id, exc)
    finally:
        rooms.leave(sr_id, websocket)
|