Inference & Deployment
YOLO supports three fast inference backends selectable at runtime via task.fast_inference.
Modes
| Mode | Flag | Requires |
|------|------|----------|
| PyTorch (default) | (omit flag) | nothing extra |
| Deploy | `task.fast_inference=deploy` | nothing extra |
| ONNX | `task.fast_inference=onnx` | `onnxruntime` |
| TensorRT | `task.fast_inference=trt` | `torch2trt`, CUDA |
Deploy mode
Strips the auxiliary head from the model before running inference. No extra dependencies — just a lighter forward pass.
python -m yolo task=inference task.fast_inference=deploy
ONNX
Exports the model to ONNX on the first run, then reuses the .onnx file on subsequent runs.
pip install onnxruntime # CPU
pip install onnxruntime-gpu # GPU
python -m yolo task=inference task.fast_inference=onnx
python -m yolo task=inference task.fast_inference=onnx device=cpu
The exported file is saved as <weight_stem>.onnx in the current working directory (only the weight file's stem is reused, not its directory).
TensorRT
Builds a TensorRT engine on the first run (requires a CUDA GPU), then reuses the .trt file.
pip install torch2trt
python -m yolo task=inference task.fast_inference=trt
Note
TensorRT is not supported on MPS (Apple Silicon). The loader falls back to the standard PyTorch model automatically.
API
yolo.utils.deploy_utils.FastModelLoader
Source code in yolo/utils/deploy_utils.py
class FastModelLoader:
    """Build or load a model compiled for fast inference.

    The backend is selected by ``cfg.task.fast_inference``:

    * ``"deploy"`` -- PyTorch model with the auxiliary head stripped.
    * ``"onnx"``   -- onnxruntime session (exported on first use).
    * ``"trt"``    -- torch2trt TensorRT module (built on first use).

    Any other value (or TensorRT on an MPS device) disables fast
    inference, and callers fall back to the original PyTorch model.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.compiler = cfg.task.fast_inference
        self.class_num = cfg.dataset.class_num
        self._validate_compiler()
        # `weight: True` is shorthand for the default weight location.
        if cfg.weight is True:
            cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
        # NOTE: only the stem is kept, so the compiled artifact is written
        # to the current working directory, not next to the weight file.
        self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}"

    def _validate_compiler(self):
        """Reset ``self.compiler`` to ``None`` when the backend cannot run."""
        if self.compiler not in ("onnx", "trt", "deploy"):
            logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
            self.compiler = None
        if self.cfg.trainer.device == "mps" and self.compiler == "trt":
            logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
            self.compiler = None

    def load_model(self, device):
        """Return the backend-specific model moved to *device*.

        Falls through to ``None`` when no fast-inference backend is
        active; the caller is then expected to build the vanilla model.
        """
        if self.compiler == "onnx":
            return self._load_onnx_model(device)
        if self.compiler == "trt":
            return self._load_trt_model().to(device)
        if self.compiler == "deploy":
            # Dropping the auxiliary branch gives a lighter forward pass.
            self.cfg.model.model.auxiliary = {}
            return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)

    def _load_onnx_model(self, device):
        """Open an onnxruntime session, exporting the model first if needed."""
        from onnxruntime import InferenceSession

        def onnx_forward(self: InferenceSession, x: Tensor):
            x = {self.get_inputs()[0].name: x.cpu().numpy()}
            model_outputs, layer_output = [], []
            # Outputs come back flat; regroup them three per detection layer.
            for idx, predict in enumerate(self.run(None, x)):
                layer_output.append(torch.from_numpy(predict).to(device))
                if idx % 3 == 2:
                    model_outputs.append(layer_output)
                    layer_output = []
            # Six groups means main + auxiliary heads; keep only the main three.
            if len(model_outputs) == 6:
                model_outputs = model_outputs[:3]
            return {"Main": model_outputs}

        # Patch __call__ so the session can be invoked like an nn.Module.
        InferenceSession.__call__ = onnx_forward

        if device == "cpu":
            providers = ["CPUExecutionProvider"]
        else:
            providers = ["CUDAExecutionProvider"]
        try:
            ort_session = InferenceSession(self.model_path, providers=providers)
            logger.info(":rocket: Using ONNX as MODEL frameworks!")
        except Exception as e:
            logger.warning(f"🈳 Error loading ONNX model: {e}")
            ort_session = self._create_onnx_model(providers)
        return ort_session

    def _create_onnx_model(self, providers):
        """Export the PyTorch model to ONNX and return a fresh session."""
        from onnxruntime import InferenceSession
        from torch.onnx import export

        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
        # Batch dimension is dynamic; spatial size is fixed at export time.
        export(
            model,
            dummy_input,
            self.model_path,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
        return InferenceSession(self.model_path, providers=providers)

    def _load_trt_model(self):
        """Load the cached TensorRT engine, building one on a cache miss."""
        from torch2trt import TRTModule

        try:
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(self.model_path))
            logger.info(":rocket: Using TensorRT as MODEL frameworks!")
        except FileNotFoundError:
            logger.warning(f"🈳 No model weight found at {self.model_path}")
            model_trt = self._create_trt_model()
        return model_trt

    def _create_trt_model(self):
        """Convert the PyTorch model with torch2trt and cache the engine."""
        from torch2trt import torch2trt

        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
        # TensorRT conversion requires model and sample input on the GPU.
        dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
        logger.info("♻️ Creating TensorRT model")
        model_trt = torch2trt(model.cuda(), [dummy_input])
        torch.save(model_trt.state_dict(), self.model_path)
        logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
        return model_trt
cfg = cfg
instance-attribute
compiler = cfg.task.fast_inference
instance-attribute
class_num = cfg.dataset.class_num
instance-attribute
model_path = f'{Path(cfg.weight).stem}.{self.compiler}'
instance-attribute
__init__(cfg)
Source code in yolo/utils/deploy_utils.py
def __init__(self, cfg: Config):
    """Resolve the requested backend and the compiled-artifact path.

    Args:
        cfg: Global config; reads ``task.fast_inference``,
            ``dataset.class_num``, ``trainer.device`` and ``weight``.
    """
    self.cfg = cfg
    self.compiler = cfg.task.fast_inference
    self.class_num = cfg.dataset.class_num
    self._validate_compiler()
    # `weight: True` is shorthand for the default weight location.
    if cfg.weight is True:
        cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
    # Only the stem is kept, so the artifact is written to the CWD.
    self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}"
_validate_compiler()
Source code in yolo/utils/deploy_utils.py
| def _validate_compiler(self):
if self.compiler not in ["onnx", "trt", "deploy"]:
logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
self.compiler = None
if self.cfg.trainer.device == "mps" and self.compiler == "trt":
logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
self.compiler = None
|
load_model(device)
Source code in yolo/utils/deploy_utils.py
def load_model(self, device):
    """Return the model for the active backend on *device*.

    Returns ``None`` when no fast-inference backend is selected.
    """
    backend = self.compiler
    if backend == "onnx":
        return self._load_onnx_model(device)
    if backend == "trt":
        return self._load_trt_model().to(device)
    if backend == "deploy":
        # Strip the auxiliary head for a lighter forward pass.
        self.cfg.model.model.auxiliary = {}
        return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)
_load_onnx_model(device)
Source code in yolo/utils/deploy_utils.py
def _load_onnx_model(self, device):
    """Return an onnxruntime session, exporting the model on a cache miss."""
    from onnxruntime import InferenceSession

    def onnx_forward(self: InferenceSession, x: Tensor):
        # Feed the tensor as a numpy array keyed by the first input name.
        x = {self.get_inputs()[0].name: x.cpu().numpy()}
        model_outputs, layer_output = [], []
        # Regroup the flat output list, three entries per detection layer.
        for idx, predict in enumerate(self.run(None, x)):
            layer_output.append(torch.from_numpy(predict).to(device))
            if idx % 3 == 2:
                model_outputs.append(layer_output)
                layer_output = []
        # Six groups -> main + auxiliary heads; keep the main three only.
        if len(model_outputs) == 6:
            model_outputs = model_outputs[:3]
        return {"Main": model_outputs}

    # Make the session callable like an nn.Module.
    InferenceSession.__call__ = onnx_forward

    providers = ["CPUExecutionProvider"] if device == "cpu" else ["CUDAExecutionProvider"]
    try:
        session = InferenceSession(self.model_path, providers=providers)
        logger.info(":rocket: Using ONNX as MODEL frameworks!")
    except Exception as e:
        logger.warning(f"🈳 Error loading ONNX model: {e}")
        session = self._create_onnx_model(providers)
    return session
_create_onnx_model(providers)
Source code in yolo/utils/deploy_utils.py
def _create_onnx_model(self, providers):
    """Export the PyTorch model to ONNX and open a session over the file."""
    from onnxruntime import InferenceSession
    from torch.onnx import export

    torch_model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
    sample = torch.ones((1, 3, *self.cfg.image_size))
    # Batch dimension is dynamic; the spatial size is fixed at export time.
    export(
        torch_model,
        sample,
        self.model_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
    logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
    return InferenceSession(self.model_path, providers=providers)
_load_trt_model()
Source code in yolo/utils/deploy_utils.py
def _load_trt_model(self):
    """Load the cached TensorRT engine, building it on a cache miss."""
    from torch2trt import TRTModule

    try:
        model_trt = TRTModule()
        model_trt.load_state_dict(torch.load(self.model_path))
        logger.info(":rocket: Using TensorRT as MODEL frameworks!")
    except FileNotFoundError:
        # First run: no cached engine yet -- build and save one.
        logger.warning(f"🈳 No model weight found at {self.model_path}")
        model_trt = self._create_trt_model()
    return model_trt
_create_trt_model()
Source code in yolo/utils/deploy_utils.py
def _create_trt_model(self):
    """Convert the PyTorch model with torch2trt and cache the engine to disk."""
    from torch2trt import torch2trt

    model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
    # TensorRT conversion requires the model and sample input on the GPU.
    dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
    logger.info("♻️ Creating TensorRT model")
    model_trt = torch2trt(model.cuda(), [dummy_input])
    torch.save(model_trt.state_dict(), self.model_path)
    logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
    return model_trt