
Inference & Deployment

YOLO supports three fast-inference backends in addition to the default PyTorch model, selectable at runtime via the task.fast_inference flag.

Modes

Mode                Flag                          Requires
PyTorch (default)   (omit the flag)               nothing extra
Deploy              task.fast_inference=deploy    nothing extra
ONNX                task.fast_inference=onnx      onnxruntime
TensorRT            task.fast_inference=trt       torch2trt, CUDA

Deploy mode

Strips the auxiliary head from the model before running inference. No extra dependencies — just a lighter forward pass.

python -m yolo task=inference task.fast_inference=deploy
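
Under the hood, deploy mode simply clears the auxiliary branch from the model config before the model is built (see FastModelLoader.load_model in the API section). A minimal Python sketch of the same thing, assuming cfg is the loaded Hydra Config object the CLI normally builds:

from yolo.utils.deploy_utils import FastModelLoader

cfg.task.fast_inference = "deploy"                       # same effect as the CLI override above
model = FastModelLoader(cfg).load_model(device="cuda")   # auxiliary head stripped, plain PyTorch module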

ONNX

Exports the model to ONNX on the first run, then reuses the .onnx file on subsequent runs.

pip install onnxruntime        # CPU
pip install onnxruntime-gpu    # GPU

python -m yolo task=inference task.fast_inference=onnx
python -m yolo task=inference task.fast_inference=onnx device=cpu

The exported file is saved as <weight_stem>.onnx in the current working directory (the weight file's directory is not kept; see model_path in the API section).
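
For the ONNX path, load_model returns an onnxruntime InferenceSession whose __call__ has been patched to accept a tensor and return a {"Main": ...} dict of torch tensors (see _load_onnx_model in the API section). A rough usage sketch, assuming cfg is the loaded Hydra Config:

import torch
from yolo.utils.deploy_utils import FastModelLoader

cfg.task.fast_inference = "onnx"
session = FastModelLoader(cfg).load_model(device="cpu")   # exports <weight_stem>.onnx on the first call
preds = session(torch.ones(1, 3, *cfg.image_size))        # {"Main": [...]}, outputs grouped in threes per head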

TensorRT

Builds a TensorRT engine on the first run (requires a CUDA GPU), then reuses the .trt file.

pip install torch2trt

python -m yolo task=inference task.fast_inference=trt

Note

TensorRT is not supported on MPS (Apple Silicon). The loader falls back to the standard PyTorch model automatically.
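
The saved engine is an ordinary torch2trt TRTModule, so it can also be loaded directly once it exists. A minimal sketch, using a hypothetical engine file name v9-c.trt and a 640x640 input:

import torch
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load("v9-c.trt"))      # hypothetical file produced on the first run
preds = model_trt(torch.ones(1, 3, 640, 640).cuda())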

API

yolo.utils.deploy_utils.FastModelLoader

Source code in yolo/utils/deploy_utils.py
class FastModelLoader:
    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.compiler = cfg.task.fast_inference
        self.class_num = cfg.dataset.class_num

        self._validate_compiler()
        if cfg.weight == True:
            cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
        self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}"

    def _validate_compiler(self):
        if self.compiler not in ["onnx", "trt", "deploy"]:
            logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
            self.compiler = None
        if self.cfg.trainer.device == "mps" and self.compiler == "trt":
            logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
            self.compiler = None

    def load_model(self, device):
        if self.compiler == "onnx":
            return self._load_onnx_model(device)
        elif self.compiler == "trt":
            return self._load_trt_model().to(device)
        elif self.compiler == "deploy":
            self.cfg.model.model.auxiliary = {}
        return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)

    def _load_onnx_model(self, device):
        from onnxruntime import InferenceSession

        def onnx_forward(self: InferenceSession, x: Tensor):
            x = {self.get_inputs()[0].name: x.cpu().numpy()}
            model_outputs, layer_output = [], []
            for idx, predict in enumerate(self.run(None, x)):
                layer_output.append(torch.from_numpy(predict).to(device))
                if idx % 3 == 2:
                    model_outputs.append(layer_output)
                    layer_output = []
            if len(model_outputs) == 6:
                model_outputs = model_outputs[:3]
            return {"Main": model_outputs}

        InferenceSession.__call__ = onnx_forward

        if device == "cpu":
            providers = ["CPUExecutionProvider"]
        else:
            providers = ["CUDAExecutionProvider"]
        try:
            ort_session = InferenceSession(self.model_path, providers=providers)
            logger.info(":rocket: Using ONNX as MODEL frameworks!")
        except Exception as e:
            logger.warning(f"🈳 Error loading ONNX model: {e}")
            ort_session = self._create_onnx_model(providers)
        return ort_session

    def _create_onnx_model(self, providers):
        from onnxruntime import InferenceSession
        from torch.onnx import export

        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
        export(
            model,
            dummy_input,
            self.model_path,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
        return InferenceSession(self.model_path, providers=providers)

    def _load_trt_model(self):
        from torch2trt import TRTModule

        try:
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(self.model_path))
            logger.info(":rocket: Using TensorRT as MODEL frameworks!")
        except FileNotFoundError:
            logger.warning(f"🈳 No found model weight at {self.model_path}")
            model_trt = self._create_trt_model()
        return model_trt

    def _create_trt_model(self):
        from torch2trt import torch2trt

        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
        logger.info(f"♻️ Creating TensorRT model")
        model_trt = torch2trt(model.cuda(), [dummy_input])
        torch.save(model_trt.state_dict(), self.model_path)
        logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
        return model_trt

Instance attributes:

cfg
compiler = cfg.task.fast_inference
class_num = cfg.dataset.class_num
model_path = f'{Path(cfg.weight).stem}.{self.compiler}'

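A minimal end-to-end usage sketch, assuming cfg is the Hydra Config object that python -m yolo normally builds (for example composed in a script); attribute and method names follow the source above:

import torch
from yolo.utils.deploy_utils import FastModelLoader

cfg.task.fast_inference = "onnx"              # "onnx", "trt", or "deploy"; anything else falls back to PyTorch
loader = FastModelLoader(cfg)
model = loader.load_model(device="cuda")      # InferenceSession, TRTModule, or a plain PyTorch module
batch = torch.ones(1, 3, *cfg.image_size)
preds = model(batch)                          # ONNX path returns {"Main": [...]}; the others return the model's usual outputs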