Pydantic, 데이터 서빙시, Interface Data Validation 관련 라이브러리

Pydantic, 데이터 서빙시, Interface Data Validation 관련 라이브러리
Photo by Daria Nepriakhina 🇺🇦 / Unsplash

Motivation

  • Pydantic은 Validation Check를 위한 라이브러리로, 잘못된 데이터가 시스템에서 유입되고 운용되는 것을 막기 위한 라이브러리
  • 비동기 웹 프레임워크인 FastAPI와 함께 많이 쓰임
@app.post("/items/")
async def create_item(item: Item):
    # 비동기 처리를 포함한 작업 수행
    return item

Pros & Cons

Pros

  • 데이터 모델을 정의 후에 자동으로 데이터 검증 후 변환
  • 타입힌트를 활용함으로써 코드 가독성이 높은 편
class Item(BaseModel):
    name: str
    price: float
    description: str = None
    tax: float = None

@app.post("/items/")
async def create_item(item: Item):
    return item
  • Cython을 이용해서 성능을 최적화하였음
  • FastAPI 통합이 원활하여 웹 어플리케이션 개발시 유용
import torch
import torch.nn as nn
import numpy as np

class PredictionRequest(BaseModel):
    data: list

class PredictionResponse(BaseModel):
    prediction: list

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(10, 10)

    def forward(self, x):
        return self.layer1(x)

model = SimpleModel()
model.load_state_dict(torch.load("simple_model.pth"))
model.eval()

@app.post("/predict/", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    data = torch.tensor(np.array(request.data), dtype=torch.float32)
    with torch.no_grad():
        prediction = model(data).numpy().tolist()
    return PredictionResponse(prediction=prediction)

Cons

  • 검증 및 Parsing 등 이슈로 다소 속도 저하에 영향이 있을 수 있음(참고, v2로 업데이트하고 테스트한 내용)
  • 복잡한 모델 구현시 Pydantic으로 하기 쉽지 않음
    • 중첩된 데이터 구조가 깊어질수록 모델의 정의가 복잡해지고, 이를 관리하기 어려워짐
    • 중첩된 필드 간의 상호 의존성을 고려한 유효성 검사를 수행하는 것이 복잡해질 수 있음
    • 필드 값에 따라 다른 필드의 유효성을 동적으로 검사해야 할 때, Pydantic 모델에서 이를 구현하기가 까다로울 수 있습니다.
  • 문서화 관련 기능이 좀 아쉬운 편임

Alternatives

  • Marshmallow: 데이터 검증 외에도 Serialization 지원, Flask와 함께 많이 쓰임
  • Cerberus: Dictionary 기반 검증하는 경량화된 라이브러리
  • attrs: 간단하고 가벼운 데이터 클래스 라이브러리로, 데이터 검증 기능도 포함

Sample

import torch
import torch.nn as nn
import torch.optim as optim

# PyTorch 모델 정의
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(10, 10)

    def forward(self, x):
        return self.layer1(x)

# 데이터 생성 (임의의 데이터 사용)
X = torch.randn(100, 10)
y = torch.randn(100, 10)

# 모델 초기화
model = SimpleModel()

# 손실 함수 및 옵티마이저 정의
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 모델 학습
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

# 학습된 모델 저장
torch.save(model.state_dict(), "simple_model.pth")

from pydantic import BaseModel
from fastapi import FastAPI
import torch
import torch.nn as nn
import numpy as np

# Pydantic 데이터 모델 정의
class Item(BaseModel):
    name: str
    price: float
    description: str = None
    tax: float = None

class PredictionRequest(BaseModel):
    data: list

class PredictionResponse(BaseModel):
    prediction: list

# FastAPI 인스턴스 생성
app = FastAPI()

# PyTorch 모델 정의
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(10, 10)

    def forward(self, x):
        return self.layer1(x)

# 학습된 모델 로드
model = SimpleModel()
model.load_state_dict(torch.load("simple_model.pth"))
model.eval()

@app.post("/items/")
async def create_item(item: Item):
    return item

@app.post("/predict/", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    data = torch.tensor(np.array(request.data), dtype=torch.float32)
    with torch.no_grad():
        prediction = model(data).numpy().tolist()
    return PredictionResponse(prediction=prediction)

# 서버 실행: uvicorn main:app --reload