Onnx is
- Open Neural Network Exchange의 약자로 Pytorchs나 Tensorflow 등 다양한 Framework를 통해 생성된 모델을 결합하여 사용할 수 있도록 일관성과 상호 운용성을 보장하는 오픈소스 프레임워크
- ONNX는 표준화된 연산자 및 데이터 타입을 활용하여 다양한 플랫폼에서 호환성을 보장함
Motivation
- Interoperability: 다양한 ML 프레임워크 간 모델을 호환해서 사용할 수 있도록 지원
- Standardization: 표준화된 형식을 통해서 모델의 재사용성을 높일 수 있음
- Productivity: 여러 환경에서 모델을 쉽게 배포할 수 있도록 함으로써 생산성 향상
Pros & Cons
Pros
- 다양한 프레임워크 지원 → Tensorflow, Pytorch, Caffe2등 다양한 프레임워크 지원
- 하드웨어 가속 → ONNX Runtime 등 툴을 이용해서 모델 실행 속도를 높일 수 있음
- ONNX Runtime은 ONNX 형식의 모델을 실행하기 위한 고성능 엔진으로 ONNX가 모델을 정의하고 변화하는 기능에 초점을 맞추고 있다면 Runtime은 추론 속도 및 효율성 최적화에 초점을 맞추고 있음
- ONNX는 Meta 및 MS를 포함, 여러 회사가 함께 개발하고 지원하는데 반해 ONNX Runtime은 MS가 주도
- CPU, GPU, 그리고 TensorRT, OpenVINO, DirectML 등을 직접 지원
- 오픈소스 커뮤니티 → 활발한 오픈소스 커뮤니티가 있어 지속 업데이트 가능
Cons
- 제한된 연산자 지원 → 일부 연산자 호환이 되지 않을 수 있음
- TensorFlow의
tf.image.non_max_suppression
: 커스텀 구현이나 우회 방법이 필요할 수 있dㅡㅁ
- PyTorch의 고급 인덱싱 또는 커스텀 CUDA 커널
- Scikit-learn의 특정 전처리 단계
PolynomialFeatures
:
FunctionTransformer
: 사용자 정의 함수로 데이터를 변환할 수 있게 하는 이 도구는 ONNX에서 지원되지 않을 수 있음
ColumnTransformer
: 여러 유형의 전처리를 열마다 다르게 적용할 수 있게 하는 이 도구는 ONNX로 변환 시 복잡성을 증가시킬 수 있습니다.
- 복잡성 → 모델 변환과정에서 발생할 수 있는 호환성과 이를 해결하기 위한 디버깅이 어려을 수 있음
- 성능저하 가능성: 프레임워크간 변환시 성능 저하될 수 있음
Alternative
- Tensorflow SavedModel: Tensorflow에서 모델 저장 및 배포를 위한 표준 형식
- TorchScript: Pytorch 모델을 저장하고 배포하기 위한 형식
- PMML(Predictive Model Markup Language): 데이터 마이닝 및 예측 모델링을 위한 표준화된 XML기반 형식
Sample Code
import torch
import torch.onnx
import onnx
import onnxruntime
import matplotlib.pyplot as plt
# 간단한 PyTorch 모델 정의
class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = torch.nn.Linear(2, 1)
def forward(self, x):
return self.fc(x)
# 모델 인스턴스 생성 및 더미 입력 정의
model = SimpleModel()
dummy_input = torch.randn(1, 2)
# 모델을 ONNX 형식으로 변환 및 저장
torch.onnx.export(model, dummy_input, "simple_model.onnx", verbose=True)
# ONNX 모델 로드 및 확인
onnx_model = onnx.load("simple_model.onnx")
onnx.checker.check_model(onnx_model)
# ONNX Runtime을 사용하여 모델 실행
ort_session = onnxruntime.InferenceSession("simple_model.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
ort_outs = ort_session.run(None, ort_inputs)
print("ONNX Runtime output:", ort_outs)
# 추가 설명을 위한 코드 예제 확장
print(f"PyTorch Model Output: {model(dummy_input).detach().numpy()}")
print(f"ONNX Model Output: {ort_outs[0]}")
🥤Source