ONNX (Open Neural Network Exchange)

By | 2024년 2월 22일
Table of Contents

ONNX (Open Neural Network Exchange)

ONNX 는 서로 다른 DNN 프레임워크 환경(ex Tensorflow, PyTorch, etc..)에서 만들어진 모델들을 서로 호환되게 사용할 수 있도록 만들어진 공유 플랫폼이다.

ONNX 또한 DNN 프레임워크라고 부른다.

호환

PyTorch 로 만들어진 모델을 TensorRT 에서 실행하고자 할때도 중간단계에 ONNX model 로 변환하여 다른 엔진으로 넘길 때도 좋다.

PyTorch model → ONNX model → TensorRT engine

PyTorch model → ONNX model 변환

기학습된 resnet50 모델을 불어와서 onnx 모델로 저장합니다.

convert.py

import torch
import torchvision.models as models

model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "resnet18.onnx")

pytorch 모델과 onnx 모델의 결과값을 비교합니다.
소수점단위 오차가 발생하므로 오차범위를 보정해 줍니다.
아무것도 출력되지 않는다면 테스트가 성공한 것입니다.

compare.py

import torch
import torchvision.models as models

import numpy as np
import onnxruntime as ort

dummy_input = torch.randn(1, 3, 224, 224)

model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
model.eval()
torch_output = model(dummy_input)

ort_session = ort.InferenceSession("resnet18.onnx")
ort_outputs = ort_session.run(None, {"input.1": dummy_input.numpy()})

np.testing.assert_allclose(torch_output.detach().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)

답글 남기기