Skip to content

Utilities

Logging

yolo.utils.logging_utils

Module for initializing logging tools used in machine learning and data processing. Supports integration with Weights & Biases (wandb), Loguru, TensorBoard, and other logging frameworks as needed.

This setup ensures consistent logging across various platforms, facilitating effective monitoring and debugging.

Example

from tools.logger import custom_logger

custom_logger()

logger = logging.getLogger('yolo') module-attribute

Config dataclass

Source code in yolo/config/config.py
@dataclass
class Config:
    """Top-level run configuration composed from the task, dataset and model configs."""

    task: Union[TrainConfig, InferenceConfig, ValidationConfig]  # active task (train / inference / validation)
    dataset: DatasetConfig  # dataset definition
    model: ModelConfig  # model architecture definition
    name: str  # experiment name, used for the log directory

    trainer: TrainerConfig  # Lightning trainer options

    image_size: List[int]  # input resolution, presumably [height, width] -- confirm against usage

    out_path: str  # root directory for run outputs
    exist_ok: bool  # if False, an existing experiment dir gets an index suffix

    lucky_number: int  # RNG seed value
    use_wandb: bool  # enable Weights & Biases logging
    use_tensorboard: bool  # enable TensorBoard logging

    task_type: str  # task category identifier
    weight: Optional[str]  # path to pretrained weights, or None

YOLOLayer dataclass

Bases: Module

Source code in yolo/config/schemas/model.py
@dataclass
class YOLOLayer(nn.Module):
    """Per-layer bookkeeping attributes attached to each built module (set via setattr in YOLO.create_layer)."""

    source: Union[int, str, List[int]]  # input: layer index/indices or a named tag
    output: bool  # if True, this layer's tensor is included in forward()'s output dict
    tags: str  # unique name used for tag -> index lookup
    layer_type: str  # layer-type string the module was created from
    usable: bool  # set when a later layer consumes this output (cached in forward)
    external: Optional[dict]  # extra named inputs passed as keyword arguments

YOLO

Bases: Module

A preliminary YOLO (You Only Look Once) model class still under development.

Parameters:

Name Type Description Default
model_cfg ModelConfig

Configuration for the YOLO model. Expected to define the layers, parameters, and any other relevant configuration details.

required
Source code in yolo/model/builder.py
class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    Builds an ``nn.ModuleList`` from a nested architecture dict and executes it
    as a DAG: each layer declares where its input comes from (``source``) and
    whether its output is exposed (``output``/``tags``).

    Parameters:
        model_cfg: Configuration for the YOLO model. Expected to define the layers,
                   parameters, and any other relevant configuration details.
        class_num: Number of detection classes (default 80).
    """

    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
        super(YOLO, self).__init__()
        self.num_classes = class_num
        self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
        self.model: List[YOLOLayer] = nn.ModuleList()
        # reg_max is forwarded to the detection heads (DFL bin count).
        self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
        self.build_model(model_cfg.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        """Instantiate every layer described in ``model_arch`` and record tag -> layer index."""
        self.layer_index = {}
        # Index 0 of output_dim is the RGB input (3 channels); layer indices are 1-based.
        output_dim, layer_idx = [3], 1
        logger.info(f":tractor: Building YOLO")
        for arch_name in model_arch:
            if model_arch[arch_name]:
                logger.info(f"  :building_construction:  Building {arch_name}")
            for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx):
                layer_type, layer_info = next(iter(layer_spec.items()))
                layer_args = layer_info.get("args", {})

                # Get input source
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Find in channels
                if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
                    layer_args["in_channels"] = output_dim[source]
                if any(module in layer_type for module in ["Detection", "Segmentation", "Classification"]):
                    if isinstance(source, list):
                        layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    else:
                        # NOTE(review): singular key "in_channel" differs from the plural
                        # used above -- confirm against the head modules' kwargs.
                        layer_args["in_channel"] = output_dim[source]
                    layer_args["num_classes"] = self.num_classes
                    layer_args["reg_max"] = self.reg_max

                # create layers
                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                setattr(layer, "out_c", out_channels)
            layer_idx += 1

    def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = None):
        """Run the layer DAG.

        Args:
            x: Input tensor.
            external: Optional extra named tensors merged into the output cache.
            shortcut: Tag name; return as soon as that tagged output is produced.

        Returns:
            Dict mapping each output layer's tag to its tensor.
        """
        # y caches intermediate outputs: key 0 is the input, -1 the previous layer,
        # positive keys are layer indices (only kept when a later layer consumes them).
        y = {0: x, **(external or {})}
        output = dict()
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                model_input = [y[idx] for idx in layer.source]
            else:
                model_input = y[layer.source]

            external_input = {source_name: y[source_name] for source_name in layer.external}

            x = layer(model_input, **external_input)
            y[-1] = x
            if layer.usable:  # only cache outputs that later layers actually consume
                y[index] = x
            if layer.output:
                output[layer.tags] = x
                if layer.tags == shortcut:
                    return output
        return output

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        """Infer a layer's output channel count from its args or its input dims."""
        # Fixed: `hasattr(layer_args, "out_channels")` on a dict is always False,
        # so an explicit `out_channels` argument was silently ignored.
        if "out_channels" in layer_args:
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if isinstance(source, int):
            return output_dim[source]
        if isinstance(source, list):
            # Concatenation-style layers: channels add up.
            return sum(output_dim[idx] for idx in source)

    def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
        """Resolve a source spec (tag, relative index, or list of those) to absolute indices."""
        if isinstance(source, ListConfig):
            return [self.get_source_idx(index, layer_idx) for index in source]
        if isinstance(source, str):
            source = self.layer_index[source]
        if source < -1:
            # Relative reference further back than the previous layer.
            source += layer_idx
        if source > 0:  # Using Previous Layer's Output
            self.model[source - 1].usable = True
        return source

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
        """Instantiate ``layer_type`` from the layer map and attach bookkeeping attributes.

        Raises:
            ValueError: If ``layer_type`` is not present in the layer map.
        """
        if layer_type in self.layer_map:
            layer = self.layer_map[layer_type](**kwargs)
            setattr(layer, "layer_type", layer_type)
            setattr(layer, "source", source)
            setattr(layer, "in_c", kwargs.get("in_channels", None))
            setattr(layer, "output", layer_info.get("output", False))
            setattr(layer, "tags", layer_info.get("tags", None))
            setattr(layer, "external", layer_info.get("external", []))
            setattr(layer, "usable", 0)  # falsy until a later layer claims this output
            return layer
        else:
            raise ValueError(f"Unsupported layer type: {layer_type}")

    def save_load_weights(self, weights: Union[Path, OrderedDict]):
        """
        Update the model's weights with the provided weights.

        Args:
            weights: A checkpoint path or an OrderedDict containing the new weights.
        """
        if isinstance(weights, Path):
            weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
        if "state_dict" in weights:
            # Lightning checkpoint: strip the "model.model." prefix from keys.
            weights = {name.removeprefix("model.model."): key for name, key in weights["state_dict"].items()}
        model_state_dict = self.model.state_dict()

        # TODO1: autoload old version weight
        # TODO2: weight transform if num_class difference

        error_dict = {"Mismatch": set(), "Not Found": set()}
        for model_key, model_weight in model_state_dict.items():
            if model_key not in weights:
                error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
                continue
            if model_weight.shape != weights[model_key].shape:
                error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
                continue
            model_state_dict[model_key] = weights[model_key]

        # Group failures by layer index so each layer logs a single warning line.
        # (Fixed: the original rebound `error_dict` while iterating its items.)
        for error_name, error_set in error_dict.items():
            grouped = {}
            for layer_idx, *layer_name in error_set:
                grouped.setdefault(layer_idx, []).append(".".join(layer_name))
            for layer_idx, layer_names in grouped.items():
                layer_names.sort()
                logger.warning(f":warning: Weight {error_name} for Layer {layer_idx}: {', '.join(layer_names)}")

        self.model.load_state_dict(model_state_dict)

save_load_weights(weights)

Update the model's weights with the provided weights.

Parameters:

Name Type Description Default
weights Union[Path, OrderedDict]

An OrderedDict containing the new weights.

required
Source code in yolo/model/builder.py
def save_load_weights(self, weights: Union[Path, OrderedDict]):
    """
    Update the model's weights with the provided weights.

    Keys may come from a Lightning checkpoint (``state_dict`` entry with a
    ``model.model.`` prefix) or a plain state dict. Tensors whose key is
    missing or whose shape mismatches are skipped and reported per layer.

    Args:
        weights: A checkpoint path or an OrderedDict containing the new weights.
    """
    if isinstance(weights, Path):
        weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
    if "state_dict" in weights:
        # Lightning checkpoint: strip the "model.model." prefix from keys.
        weights = {name.removeprefix("model.model."): key for name, key in weights["state_dict"].items()}
    model_state_dict = self.model.state_dict()

    # TODO1: autoload old version weight
    # TODO2: weight transform if num_class difference

    error_dict = {"Mismatch": set(), "Not Found": set()}
    for model_key, model_weight in model_state_dict.items():
        if model_key not in weights:
            error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
            continue
        if model_weight.shape != weights[model_key].shape:
            error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
            continue
        model_state_dict[model_key] = weights[model_key]

    # Group failures by layer index so each layer logs a single warning line.
    # (Fixed: the original rebound `error_dict` while iterating its items.)
    for error_name, error_set in error_dict.items():
        grouped = {}
        for layer_idx, *layer_name in error_set:
            grouped.setdefault(layer_idx, []).append(".".join(layer_name))
        for layer_idx, layer_names in grouped.items():
            layer_names.sort()
            logger.warning(f":warning: Weight {error_name} for Layer {layer_idx}: {', '.join(layer_names)}")

    self.model.load_state_dict(model_state_dict)

EMA

Bases: Callback

Exponential Moving Average of model weights as a Lightning Callback.

Keeps a shadow copy of model parameters smoothed over training steps

beta = decay * (1 - exp(-step / tau))
shadow = beta * shadow + (1 - beta) * model

The tau warmup ramps beta up from ~0 at step 0, so early noisy updates don't dominate the shadow. Validation always runs on shadow weights; training weights are swapped back immediately after.

Source code in yolo/training/callbacks.py
class EMA(Callback):
    """Exponential Moving Average of model weights as a Lightning Callback.

    Keeps a shadow copy of model parameters smoothed over training steps:
        beta = decay * (1 - exp(-step / tau))
        shadow = beta * shadow + (1 - beta) * model

    The tau warmup ramps beta up from ~0 at step 0, so early noisy updates
    don't dominate the shadow. Validation always runs on shadow weights;
    training weights are swapped back immediately after.
    """

    # Key under which the shadow weights are stored in Lightning checkpoints.
    _CHECKPOINT_KEY = "ema_shadow"

    def __init__(self, decay: float = 0.9999, tau: float = 2000.0) -> None:
        super().__init__()
        logger.info(":chart_with_upwards_trend: Enable Model EMA")
        self.decay = decay  # asymptotic smoothing coefficient
        self.tau = tau  # warmup time constant, in optimizer steps
        self.step: int = 0  # optimizer steps folded into the shadow so far
        self.batch_count: int = 0  # batches since last accumulation boundary
        self.shadow: Optional[dict] = None  # EMA copy of model.state_dict()
        self._training_weights: Optional[dict] = None  # snapshot while shadow is applied

    def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
        """Initialise shadow from the model before training begins."""
        if self.shadow is None:
            self.shadow = {k: v.detach().clone() for k, v in pl_module.model.state_dict().items()}

    def _beta(self) -> float:
        """Effective smoothing coefficient at the current step."""
        return self.decay * (1 - exp(-self.step / self.tau))

    @torch.no_grad()
    def update(self, pl_module: "LightningModule") -> None:
        """Blend model parameters into shadow; copy buffers directly."""
        current = pl_module.model.state_dict()
        if self.shadow is None:
            # Lazy init in case update() fires before setup().
            self.shadow = {k: v.detach().clone() for k, v in current.items()}
            return

        beta = self._beta()

        # Parameters are lerped; fused _foreach ops avoid a Python-level loop
        # where the running torch version provides them.
        param_keys = [k for k, _ in pl_module.model.named_parameters()]
        shadow_params = [self.shadow[k] for k in param_keys]
        model_params = [current[k].detach().to(self.shadow[k].device) for k in param_keys]
        if hasattr(torch, "_foreach_lerp_"):
            # lerp(shadow, model, 1-beta) == beta*shadow + (1-beta)*model
            torch._foreach_lerp_(shadow_params, model_params, 1.0 - beta)
        elif hasattr(torch, "_foreach_mul_"):
            torch._foreach_mul_(shadow_params, beta)
            torch._foreach_add_(shadow_params, model_params, alpha=1.0 - beta)
        else:
            for s, m in zip(shadow_params, model_params):
                s.mul_(beta).add_(m, alpha=1.0 - beta)

        # Buffers (e.g. BatchNorm running stats) are copied verbatim, not averaged.
        for key, buf in pl_module.model.named_buffers():
            self.shadow[key].copy_(buf.detach().to(self.shadow[key].device))

    @torch.no_grad()
    def apply_shadow(self, pl_module: "LightningModule") -> None:
        """Snapshot training weights then load shadow weights into the model."""
        if self.shadow is None:
            return
        self._training_weights = {k: v.detach().clone() for k, v in pl_module.model.state_dict().items()}
        pl_module.model.load_state_dict(self.shadow, strict=True)

    @torch.no_grad()
    def restore(self, pl_module: "LightningModule") -> None:
        """Reload the training-weight snapshot, discarding the shadow swap."""
        if self._training_weights is None:
            return
        pl_module.model.load_state_dict(self._training_weights, strict=True)
        self._training_weights = None

    @torch.no_grad()
    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
        # Only fold the weights in once per effective optimizer step
        # (i.e. respecting gradient accumulation).
        self.batch_count += 1
        if self.batch_count % trainer.accumulate_grad_batches != 0:
            return
        self.step += 1
        self.update(pl_module)

    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        # NOTE(review): resetting batch_count here restarts the accumulation
        # phase every validation -- confirm this interaction is intended.
        self.batch_count = 0
        self.apply_shadow(pl_module)

    def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        self.restore(pl_module)

    def on_save_checkpoint(self, trainer: "Trainer", pl_module: "LightningModule", checkpoint: dict) -> None:
        # Persist shadow on CPU so checkpoints stay device-agnostic.
        if self.shadow is None:
            return
        checkpoint[self._CHECKPOINT_KEY] = {k: v.detach().cpu() for k, v in self.shadow.items()}
        checkpoint["ema_step"] = self.step
        checkpoint["ema_batch_count"] = self.batch_count

    def on_load_checkpoint(self, trainer: "Trainer", pl_module: "LightningModule", checkpoint: dict) -> None:
        self.step = checkpoint.get("ema_step", 0)
        self.batch_count = checkpoint.get("ema_batch_count", 0)
        if self._CHECKPOINT_KEY not in checkpoint:
            return
        # Move the restored shadow onto the model's current device.
        target_device = next(pl_module.model.parameters()).device
        self.shadow = {k: v.detach().clone().to(target_device) for k, v in checkpoint[self._CHECKPOINT_KEY].items()}

setup(trainer, pl_module, stage)

Initialise shadow from the model before training begins.

Source code in yolo/training/callbacks.py
def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
    """Initialise shadow from the model before training begins."""
    if self.shadow is None:
        self.shadow = {k: v.detach().clone() for k, v in pl_module.model.state_dict().items()}

update(pl_module)

Blend model parameters into shadow; copy buffers directly.

Source code in yolo/training/callbacks.py
@torch.no_grad()
def update(self, pl_module: "LightningModule") -> None:
    """Blend model parameters into shadow; copy buffers directly."""
    state = pl_module.model.state_dict()
    if self.shadow is None:
        # First call: seed the shadow with a detached copy of the weights.
        self.shadow = {name: tensor.detach().clone() for name, tensor in state.items()}
        return

    beta = self._beta()

    names = [name for name, _ in pl_module.model.named_parameters()]
    shadows = [self.shadow[name] for name in names]
    currents = [state[name].detach().to(self.shadow[name].device) for name in names]

    # Prefer the fused multi-tensor kernels when this torch build has them.
    if hasattr(torch, "_foreach_lerp_"):
        torch._foreach_lerp_(shadows, currents, 1.0 - beta)
    elif hasattr(torch, "_foreach_mul_"):
        torch._foreach_mul_(shadows, beta)
        torch._foreach_add_(shadows, currents, alpha=1.0 - beta)
    else:
        for shadow_t, current_t in zip(shadows, currents):
            shadow_t.mul_(beta).add_(current_t, alpha=1.0 - beta)

    # Buffers (e.g. BatchNorm running stats) are copied verbatim, not averaged.
    for name, buffer in pl_module.model.named_buffers():
        self.shadow[name].copy_(buffer.detach().to(self.shadow[name].device))

apply_shadow(pl_module)

Snapshot training weights then load shadow weights into the model.

Source code in yolo/training/callbacks.py
@torch.no_grad()
def apply_shadow(self, pl_module: "LightningModule") -> None:
    """Swap the EMA (shadow) weights into the model, keeping a training-weight snapshot."""
    if self.shadow is None:
        return
    current_state = pl_module.model.state_dict()
    self._training_weights = {name: tensor.detach().clone() for name, tensor in current_state.items()}
    pl_module.model.load_state_dict(self.shadow, strict=True)

restore(pl_module)

Reload the training-weight snapshot, discarding the shadow swap.

Source code in yolo/training/callbacks.py
@torch.no_grad()
def restore(self, pl_module: "LightningModule") -> None:
    """Put the saved training weights back, ending the shadow swap."""
    snapshot = self._training_weights
    if snapshot is None:
        return
    pl_module.model.load_state_dict(snapshot, strict=True)
    self._training_weights = None

GradientAccumulation

Bases: Callback

Source code in yolo/training/callbacks.py
class GradientAccumulation(Callback):
    """Ramp `accumulate_grad_batches` up during warmup to emulate a larger batch size."""

    def __init__(self, data_cfg: DataConfig, scheduler_cfg: SchedulerConfig):
        super().__init__()
        self.equivalent_batch_size = data_cfg.equivalent_batch_size
        self.actual_batch_size = data_cfg.batch_size
        self.warmup_epochs = getattr(scheduler_cfg.warmup, "epochs", 0)
        # Filled in during setup()/on_train_start().
        self.max_accumulation = 1
        self.warmup_batches = 0
        self.steps_per_epoch = 1
        logger.info(":arrows_counterclockwise: Enable Gradient Accumulation")

    def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
        # The per-step batch size across all ranks.
        effective_batch = self.actual_batch_size * trainer.world_size
        self.max_accumulation = max(1, round(self.equivalent_batch_size / effective_batch))

    def on_train_start(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        # Take steps_per_epoch from the first scheduler that exposes it.
        candidates = (
            cfg.scheduler.steps_per_epoch
            for cfg in trainer.lr_scheduler_configs
            if hasattr(cfg.scheduler, "steps_per_epoch")
        )
        self.steps_per_epoch = next(candidates, self.steps_per_epoch)
        self.warmup_batches = int(self.warmup_epochs * self.steps_per_epoch)

    def on_train_batch_start(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
        step = trainer.global_step
        if step >= self.warmup_batches:
            trainer.accumulate_grad_batches = self.max_accumulation
        else:
            # Linearly interpolate from 1 up to max_accumulation over warmup.
            trainer.accumulate_grad_batches = round(lerp(1, self.max_accumulation, step, self.warmup_batches))

YOLOCustomProgress

Bases: CustomProgress

Source code in yolo/utils/logging_utils.py
class YOLOCustomProgress(CustomProgress):
    """Progress renderer that appends an optional results table below the bars."""

    def get_renderable(self):
        parts = list(self.get_renderables())
        # A results table is attached lazily by the progress bar callback.
        if hasattr(self, "table"):
            parts.append(self.table)
        return Group(*parts)

YOLORichProgressBar

Bases: RichProgressBar

Source code in yolo/utils/logging_utils.py
class YOLORichProgressBar(RichProgressBar):
    """Rich progress bar showing per-loss columns during training and an AP table after validation."""

    @override
    @rank_zero_only
    def _init_progress(self, trainer: "Trainer") -> None:
        # (Re)create the progress renderer whenever it is missing or was stopped.
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            reconfigure(**self._console_kwargs)
            self._console = Console()
            self.progress = YOLOCustomProgress(
                *self.configure_columns(trainer),
                auto_refresh=False,
                disable=self.is_disabled,
                console=self._console,
            )
            self.progress.start()

            self._progress_stopped = False

            # Best scores so far and the last 5 epochs' headline AP rows.
            self.max_result = 0
            self.past_results = deque(maxlen=5)
            self.progress.table = Table()

    @override
    def _get_train_description(self, current_epoch: int) -> str:
        # NOTE(review): returns a rich Text despite the `str` annotation --
        # rich rendering accepts it, but the type hint is inaccurate.
        return Text("[cyan]Train [white]|")

    @override
    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        self._init_progress(trainer)
        # NOTE(review): `max_epochs - 1` -- confirm the off-by-one is intended.
        num_epochs = trainer.max_epochs - 1
        self.task_epoch = self._add_task(
            total_batches=num_epochs,
            description=f"[cyan]Start Training {num_epochs} epochs",
        )
        self.max_result = 0
        self.past_results.clear()

    @override
    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
        self._update(self.train_progress_bar_id, batch_idx + 1)
        self._update_metrics(trainer, pl_module)
        epoch_descript = "[cyan]Train [white]|"
        batch_descript = "[green]Batch [white]|"
        metrics = self.get_metrics(trainer, pl_module)
        # NOTE(review): raises KeyError if "v_num" is ever absent.
        metrics.pop("v_num")
        # One column per step-level loss: its name in the epoch row, its value below.
        for metrics_name, metrics_val in metrics.items():
            if "Loss_step" in metrics_name:
                epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"
                batch_descript += f"   {metrics_val:2.2f}  |"

        self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
        self.progress.update(self.train_progress_bar_id, description=batch_descript)
        self.refresh()

    @override
    @rank_zero_only
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        if self.is_disabled:
            return
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
        elif self.val_progress_bar_id is not None:
            self._update(self.val_progress_bar_id, batch_idx + 1)
            # Validation step outputs are expected as (predictions, mAP-dict).
            _, mAP = outputs
            mAP_desc = f" mAP :{mAP['map']*100:6.2f} | mAP50 :{mAP['map_50']*100:6.2f} |"
            self.progress.update(self.val_progress_bar_id, description=f"[green]Valid [white]|{mAP_desc}")
        self.refresh()

    @override
    @rank_zero_only
    def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        self._update_metrics(trainer, pl_module)
        self.progress.remove_task(self.train_progress_bar_id)
        self.train_progress_bar_id = None

    @override
    @rank_zero_only
    def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)
        self.reset_dataloader_idx_tracker()
        all_metrics = self.get_metrics(trainer, pl_module)

        # Collect epoch-level COCO metrics (map*/mar*) and render the AP table.
        ap_ar_list = [
            key
            for key in all_metrics.keys()
            if key.startswith(("map", "mar")) and not key.endswith(("_step", "_epoch"))
        ]
        score = np.array([all_metrics[key] for key in ap_ar_list]) * 100

        self.progress.table, ap_main = make_ap_table(score, self.past_results, self.max_result, trainer.current_epoch)
        self.max_result = np.maximum(score, self.max_result)
        self.past_results.append((trainer.current_epoch, ap_main))

    @override
    def refresh(self, **kwargs) -> None:
        if self.progress:
            self.progress.refresh()

    @property
    def validation_description(self) -> str:
        return "[green]Validation"

YOLORichModelSummary

Bases: RichModelSummary

Source code in yolo/utils/logging_utils.py
class YOLORichModelSummary(RichModelSummary):
    """Model summary that renders the per-layer table and totals with rich."""

    @staticmethod
    @override
    def summarize(
        summary_data: List[Tuple[str, List[str]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
        total_training_modes: Dict[str, int],
        **summarize_kwargs: Any,
    ) -> None:
        """Print a per-layer table followed by a parameter-count summary grid.

        Args:
            summary_data: Column-oriented data: ``[(column_name, [cell, ...]), ...]``.
            total_parameters: Total parameter count.
            trainable_parameters: Trainable parameter count.
            model_size: Estimated model parameter size in MB.
            total_training_modes: Counts of modules in "train" / "eval" mode.
            **summarize_kwargs: Supports ``header_style`` (rich style string).
        """
        from lightning.pytorch.utilities.model_summary import get_human_readable_count

        console = get_console()

        header_style: str = summarize_kwargs.get("header_style", "bold magenta")
        table = Table(header_style=header_style)
        table.add_column(" ", style="dim")
        table.add_column("Name", justify="left", no_wrap=True)
        table.add_column("Type")
        table.add_column("Params", justify="right")
        table.add_column("Mode")

        column_names = list(zip(*summary_data))[0]

        # Only present when example input arrays were provided to the summary.
        for column_name in ["In sizes", "Out sizes"]:
            if column_name in column_names:
                table.add_column(column_name, justify="right", style="white")

        # Transpose the column-oriented data into rows.
        rows = list(zip(*(arr[1] for arr in summary_data)))
        for row in rows:
            table.add_row(*row)

        console.print(table)

        parameters = []
        for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
            parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))

        grid = Table(header_style=header_style)
        # Fixed: the original added a stray blank column to `table` (already
        # printed above) here -- a copy-paste bug -- instead of configuring `grid`.
        grid.add_column("[bold]Attributes[/]")
        grid.add_column("Value")

        grid.add_row("[bold]Trainable params[/]", f"{parameters[0]}")
        grid.add_row("[bold]Non-trainable params[/]", f"{parameters[1]}")
        grid.add_row("[bold]Total params[/]", f"{parameters[2]}")
        grid.add_row("[bold]Total estimated model params size (MB)[/]", f"{parameters[3]}")
        grid.add_row("[bold]Modules in train mode[/]", f"{total_training_modes['train']}")
        grid.add_row("[bold]Modules in eval mode[/]", f"{total_training_modes['eval']}")

        console.print(grid)

ImageLogger

Bases: Callback

Source code in yolo/utils/logging_utils.py
class ImageLogger(Callback):
    """Log the first validation image with ground-truth and predicted boxes to wandb."""

    def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx) -> None:
        # Only the first batch of each validation epoch is logged.
        if batch_idx != 0:
            return
        batch_size, images, targets, rev_tensor, img_paths = batch
        predicts, _ = outputs
        gt_boxes = targets[0] if targets.ndim == 3 else targets
        pred_boxes = predicts[0] if isinstance(predicts, list) else predicts
        first_image = [images[0]]
        epoch = trainer.current_epoch
        # Renamed loop variable to avoid shadowing the module-level `logger`.
        for active_logger in trainer.loggers:
            if not isinstance(active_logger, WandbLogger):
                continue
            active_logger.log_image("Input Image", first_image, step=epoch)
            active_logger.log_image("Ground Truth", first_image, step=epoch, boxes=[log_bbox(gt_boxes)])
            active_logger.log_image("Prediction", first_image, step=epoch, boxes=[log_bbox(pred_boxes)])

make_ap_table(score, past_result=[], max_result=None, epoch=-1)

Source code in yolo/utils/solver_utils.py
def make_ap_table(score, past_result=(), max_result=None, epoch=-1):
    """Build a rich table of COCO AP/AR metrics for the current epoch.

    Args:
        score: Sequence of 12 metric values (in percent) in COCO summarize
            order: 6 AP entries followed by 6 AR entries. (The previous
            ``Dict[str, float]`` annotation was wrong -- the values are
            indexed numerically and compared elementwise.)
        past_result: Iterable of ``(epoch, row)`` tuples from previous epochs,
            where ``row`` matches the ``this_ap`` layout returned here.
            (Default changed from a mutable ``[]`` literal to an immutable
            tuple; it is only iterated, so behavior is unchanged.)
        max_result: Best-so-far values per metric; new scores >= the best are
            colored green, otherwise red. NOTE(review): the ``None`` default
            would fail inside ``np.where`` -- callers always pass a value.
        epoch: Epoch number shown in the first column.

    Returns:
        Tuple of (ap_table, this_ap) where ``this_ap`` is the headline
        AP@.5:.95 / AP@.5 row for history tracking.
    """
    ap_table = Table()
    ap_table.add_column("Epoch", justify="center", style="white", width=5)
    ap_table.add_column("Avg. Precision", justify="left", style="cyan")
    ap_table.add_column("%", justify="right", style="green", width=5)
    ap_table.add_column("Avg. Recall", justify="left", style="cyan")
    ap_table.add_column("%", justify="right", style="green", width=5)

    # Replay the history rows, then a separator before the current epoch.
    for eps, (ap_name1, ap_color1, ap_value1, ap_name2, ap_color2, ap_value2) in past_result:
        ap_table.add_row(f"{eps: 3d}", ap_name1, f"{ap_color1}{ap_value1:.2f}", ap_name2, f"{ap_color2}{ap_value2:.2f}")
    if past_result:
        ap_table.add_row()

    # Green where the new score ties or beats the best so far, red otherwise.
    color = np.where(max_result <= score, "[green]", "[red]")

    this_ap = ("AP @ .5:.95", color[0], score[0], "AP @        .5", color[1], score[1])
    metrics = [
        ("AP @ .5:.95", color[0], score[0], "AR maxDets   1", color[6], score[6]),
        ("AP @     .5", color[1], score[1], "AR maxDets  10", color[7], score[7]),
        ("AP @    .75", color[2], score[2], "AR maxDets 100", color[8], score[8]),
        ("AP  (small)", color[3], score[3], "AR     (small)", color[9], score[9]),
        ("AP (medium)", color[4], score[4], "AR    (medium)", color[10], score[10]),
        ("AP  (large)", color[5], score[5], "AR     (large)", color[11], score[11]),
    ]

    for ap_name, ap_color, ap_value, ar_name, ar_color, ar_value in metrics:
        ap_table.add_row(f"{epoch: 3d}", ap_name, f"{ap_color}{ap_value:.2f}", ar_name, f"{ar_color}{ar_value:.2f}")

    return ap_table, this_ap

set_seed(seed)

Source code in yolo/utils/logging_utils.py
def set_seed(seed):
    """Seed the RNGs (via ``seed_everything`` plus CUDA generators) and pin cuDNN to deterministic kernels."""
    seed_everything(seed)
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)  # `_all` covers every device in multi-GPU runs
    # Reproducibility over autotuned kernel speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

setup_logger(logger_name, quiet=False)

Source code in yolo/utils/logging_utils.py
def setup_logger(logger_name, quiet=False):
    """Attach a rich, emoji-prefixed handler to the named logger.

    Args:
        logger_name: Name of the logger to configure (e.g. "lightning.pytorch").
        quiet: When True, raise the logger's level to ERROR.
    """

    class EmojiFormatter(logging.Formatter):
        # Prefix every record with a lightning-bolt emoji shortcode.
        def format(self, record, emoji=":high_voltage:"):
            return f"{emoji} {super().format(record)}"

    rich_handler = RichHandler(markup=True)
    rich_handler.setFormatter(EmojiFormatter("%(message)s"))
    rich_logger = logging.getLogger(logger_name)
    # `logging.getLogger` always returns a (truthy) Logger, so the previous
    # `if rich_logger:` guard was dead code and has been removed.
    rich_logger.handlers.clear()
    rich_logger.addHandler(rich_handler)
    if quiet:
        rich_logger.setLevel(logging.ERROR)

    # faster-coco-eval is chatty at INFO; keep only its errors.
    coco_logger = logging.getLogger("faster_coco_eval.core.cocoeval")
    coco_logger.setLevel(logging.ERROR)

setup(cfg)

Source code in yolo/utils/logging_utils.py
def setup(cfg: Config):
    """Build the Lightning callbacks, loggers and run directory for this config.

    Returns:
        Tuple of (progress, loggers, save_path): callback list, logger list,
        and the directory where logs are written.
    """
    # `quiet` is detected by attribute *presence*, not value -- NOTE(review):
    # an explicit `quiet: false` in the config would still enable quiet mode.
    quiet = hasattr(cfg, "quiet")
    setup_logger("lightning.fabric", quiet=quiet)
    setup_logger("lightning.pytorch", quiet=quiet)

    # Route wandb's terminal output through our logger with a globe prefix.
    # NOTE(review): the `level=int` default looks like a placeholder mirroring
    # wandb's signature; the argument is unused here.
    def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):
        if silent:
            return
        for line in string.split("\n"):
            logger.info(Text.from_ansi(":globe_with_meridians: " + line))

    wandb.errors.term._log = custom_wandb_log

    save_path = validate_log_directory(cfg, cfg.name)

    progress, loggers = [], []

    if cfg.task.task == "train":
        progress.append(LearningRateMonitor(logging_interval="step"))
    if cfg.task.task == "train" and hasattr(cfg.task.data, "equivalent_batch_size"):
        progress.append(GradientAccumulation(data_cfg=cfg.task.data, scheduler_cfg=cfg.task.scheduler))

    # NOTE(review): cfg.task.ema.tau is not forwarded to EMA -- confirm the
    # default tau is intended here.
    if hasattr(cfg.task, "ema") and cfg.task.ema.enable:
        progress.append(EMA(cfg.task.ema.decay))
    if quiet:
        # Quiet mode: no progress bars, summaries, or external loggers.
        logger.setLevel(logging.ERROR)
        return progress, loggers, save_path

    progress.append(YOLORichProgressBar())
    progress.append(YOLORichModelSummary())
    progress.append(ImageLogger())
    if cfg.use_tensorboard:
        loggers.append(TensorBoardLogger(log_graph="all", save_dir=save_path))
    if cfg.use_wandb:
        # Resolve the full config so the wandb run records all hyperparameters.
        wandb_cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
        loggers.append(WandbLogger(project="YOLO", name=cfg.name, save_dir=save_path, id=None, config=wandb_cfg))

    return progress, loggers, save_path

log_model_structure(model)

Source code in yolo/utils/logging_utils.py
def log_model_structure(model: Union[ModuleList, YOLOLayer, YOLO]):
    """Print a rich table summarizing every layer of the model.

    Shows, per layer: its index, type, tags, parameter count, and input/output
    channels ("M" when a channel spec is a list, "-" when unknown).
    """
    if isinstance(model, YOLO):
        model = model.model
    console = Console()
    table = Table(title="Model Layers")

    for header, justify in (
        ("Index", "center"),
        ("Layer Type", "center"),
        ("Tags", "center"),
        ("Params", "right"),
        ("Channels (IN->OUT)", "center"),
    ):
        table.add_column(header, justify=justify)

    for row_idx, layer in enumerate(model, start=1):
        param_count = sum(param.numel() for param in layer.parameters())
        in_c = getattr(layer, "in_c", None)
        out_c = getattr(layer, "out_c", None)
        if in_c and out_c:
            # List-valued channel specs (multi-input layers) are shown as "M".
            if isinstance(in_c, (list, ListConfig)):
                in_c = "M"
            if isinstance(out_c, (list, ListConfig)):
                out_c = "M"
            channels = f"{str(in_c): >4} -> {str(out_c): >4}"
        else:
            channels = "-"
        table.add_row(str(row_idx), layer.layer_type, layer.tags, f"{param_count:,}", channels)
    console.print(table)

validate_log_directory(cfg, exp_name)

Source code in yolo/utils/logging_utils.py
@rank_zero_only
def validate_log_directory(cfg: Config, exp_name: str) -> Path:
    """Resolve (and create) the experiment's log directory.

    When ``cfg.exist_ok`` is False and the directory already exists, a numeric
    suffix is appended (name1, name2, ...) until a free name is found.

    Args:
        cfg (Config): Global config; reads out_path, task.task, exist_ok, quiet.
        exp_name (str): Desired experiment (directory) name.

    Returns:
        Path: The created log directory; also attaches a FileHandler writing
        "output.log" inside it.
    """
    base_path = Path(cfg.out_path, cfg.task.task)
    save_path = base_path / exp_name

    if not cfg.exist_ok:
        index = 1
        old_exp_name = exp_name
        while save_path.is_dir():
            exp_name = f"{old_exp_name}{index}"
            save_path = base_path / exp_name
            index += 1
        if index > 1:
            # Fix: `logger` is a stdlib logging.Logger (logging.getLogger("yolo")),
            # which has no `.opt()` — that is loguru API and raised AttributeError
            # on name collision. Use a plain warning with rich-style markup,
            # matching the other messages in this module.
            logger.warning(
                f"🔀 Experiment directory exists! Changed [red]{old_exp_name}[/] to [green]{exp_name}[/]"
            )

    save_path.mkdir(parents=True, exist_ok=True)
    if not getattr(cfg, "quiet", False):
        logger.info(f"📄 Created log folder: [blue b u]{save_path}[/]")
    logger.addHandler(FileHandler(save_path / "output.log"))
    return save_path

log_bbox(bboxes, class_list=None, image_size=(640, 640))

Convert bounding boxes tensor to a list of dictionaries for logging, normalized by the image size.

Parameters:

Name Type Description Default
bboxes Tensor

Bounding boxes with shape (N, 5) or (N, 6), where each box is [class_id, x_min, y_min, x_max, y_max, (confidence)].

required
class_list Optional[List[str]]

List of class names. Defaults to None.

None
image_size Tuple[int, int]

The size of the image, used for normalization. Defaults to (640, 640).

(640, 640)

Returns:

Type Description
dict

dict: A wandb-style payload {"predictions": {"box_data": [...]}} containing the normalized bounding box dictionaries.

Source code in yolo/utils/logging_utils.py
def log_bbox(
    bboxes: Tensor, class_list: Optional[List[str]] = None, image_size: Tuple[int, int] = (640, 640)
) -> dict:
    """
    Convert bounding boxes tensor to a wandb-style payload for logging, normalized by the image size.

    Args:
        bboxes (Tensor): Bounding boxes with shape (N, 5) or (N, 6), where each box is [class_id, x_min, y_min, x_max, y_max, (confidence)].
        class_list (Optional[List[str]]): List of class names. Defaults to None.
        image_size (Tuple[int, int]): The (width, height) of the image, used for normalization. Defaults to (640, 640).

    Returns:
        dict: ``{"predictions": {"box_data": [...]}}`` with one dict of
        normalized coordinates (plus optional caption/score) per box.
    """
    bbox_list = []
    # Divisors per column: class untouched, then x/W, y/H, x/W, y/H.
    scale_tensor = torch.Tensor([1, *image_size, *image_size]).to(bboxes.device)
    normalized_bboxes = bboxes[:, :5] / scale_tensor
    # Fix: the confidence column was sliced away before unpacking, so the
    # `scores` entry was never emitted; read it from the raw tensor instead.
    has_conf = bboxes.size(1) > 5
    for idx, bbox in enumerate(normalized_bboxes):
        class_id, x_min, y_min, x_max, y_max = [float(val) for val in bbox]
        if class_id == -1:  # -1 marks padding; all later rows are padding too
            break
        bbox_entry = {
            "position": {"minX": x_min, "maxX": x_max, "minY": y_min, "maxY": y_max},
            "class_id": int(class_id),
        }
        if class_list:
            bbox_entry["box_caption"] = class_list[int(class_id)]
        if has_conf:
            bbox_entry["scores"] = {"confidence": float(bboxes[idx, 5])}
        bbox_list.append(bbox_entry)

    return {"predictions": {"box_data": bbox_list}}

Deployment

yolo.utils.deploy_utils

logger = logging.getLogger('yolo') module-attribute

Config dataclass

Source code in yolo/config/config.py
@dataclass
class Config:
    """Top-level run configuration aggregating task, dataset, model and trainer."""

    # Active task configuration: exactly one of train/inference/validation.
    task: Union[TrainConfig, InferenceConfig, ValidationConfig]
    dataset: DatasetConfig
    model: ModelConfig
    # Experiment name; also used as the log-directory name.
    name: str

    trainer: TrainerConfig

    # Input image size as [width, height] (unpacked as W, H by the converters).
    image_size: List[int]

    # Root output directory for logs and weights.
    out_path: str
    # True: reuse an existing experiment directory instead of suffixing a new one.
    exist_ok: bool

    lucky_number: int
    # Experiment-tracking backend toggles.
    use_wandb: bool
    use_tensorboard: bool

    task_type: str
    weight: Optional[str]

FastModelLoader

Source code in yolo/utils/deploy_utils.py
class FastModelLoader:
    """Loads a YOLO model through an accelerated inference backend.

    Depending on ``cfg.task.fast_inference`` the model is exported/loaded as
    ONNX ("onnx"), TensorRT ("trt"), or built without auxiliary heads
    ("deploy"); any other value falls back to the original PyTorch model.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.compiler = cfg.task.fast_inference
        self.class_num = cfg.dataset.class_num

        self._validate_compiler()
        # weight == True means "use the default weight file for this model name".
        if cfg.weight == True:
            cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
        # Compiled artifact path, e.g. "v9-c.onnx" / "v9-c.trt".
        self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}"

    def _validate_compiler(self):
        # Unsupported backend (or TensorRT on MPS) degrades to the plain model.
        if self.compiler not in ["onnx", "trt", "deploy"]:
            logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
            self.compiler = None
        if self.cfg.trainer.device == "mps" and self.compiler == "trt":
            logger.warning(":red_apple: TensorRT does not support MPS devices. Using original model.")
            self.compiler = None

    def load_model(self, device):
        """Return the model for ``device`` using the selected backend."""
        if self.compiler == "onnx":
            return self._load_onnx_model(device)
        elif self.compiler == "trt":
            return self._load_trt_model().to(device)
        elif self.compiler == "deploy":
            # Strip auxiliary heads before building the deploy model.
            self.cfg.model.model.auxiliary = {}
        return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)

    def _load_onnx_model(self, device):
        from onnxruntime import InferenceSession

        # Adapter so the ONNX session can be called like a torch module and
        # return outputs in the {"Main": [...]} shape the pipeline expects.
        def onnx_forward(self: InferenceSession, x: Tensor):
            x = {self.get_inputs()[0].name: x.cpu().numpy()}
            model_outputs, layer_output = [], []
            # Outputs come back flat; regroup every 3 tensors into one head.
            for idx, predict in enumerate(self.run(None, x)):
                layer_output.append(torch.from_numpy(predict).to(device))
                if idx % 3 == 2:
                    model_outputs.append(layer_output)
                    layer_output = []
            # 6 groups presumably means main + auxiliary heads; keep the
            # first 3 — TODO confirm ordering against the export graph.
            if len(model_outputs) == 6:
                model_outputs = model_outputs[:3]
            return {"Main": model_outputs}

        InferenceSession.__call__ = onnx_forward

        if device == "cpu":
            providers = ["CPUExecutionProvider"]
        else:
            providers = ["CUDAExecutionProvider"]
        try:
            ort_session = InferenceSession(self.model_path, providers=providers)
            logger.info(":rocket: Using ONNX as MODEL frameworks!")
        except Exception as e:
            # Missing or broken artifact: export a fresh ONNX model instead.
            logger.warning(f"🈳 Error loading ONNX model: {e}")
            ort_session = self._create_onnx_model(providers)
        return ort_session

    def _create_onnx_model(self, providers):
        from onnxruntime import InferenceSession
        from torch.onnx import export

        # Build the PyTorch model, export to ONNX with a dynamic batch axis,
        # then reopen it as an inference session.
        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
        export(
            model,
            dummy_input,
            self.model_path,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
        logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
        return InferenceSession(self.model_path, providers=providers)

    def _load_trt_model(self):
        from torch2trt import TRTModule

        try:
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(self.model_path))
            logger.info(":rocket: Using TensorRT as MODEL frameworks!")
        except FileNotFoundError:
            # No cached engine yet: compile one from the PyTorch model.
            logger.warning(f"🈳 No found model weight at {self.model_path}")
            model_trt = self._create_trt_model()
        return model_trt

    def _create_trt_model(self):
        from torch2trt import torch2trt

        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
        dummy_input = torch.ones((1, 3, *self.cfg.image_size)).cuda()
        logger.info(f"♻️ Creating TensorRT model")
        model_trt = torch2trt(model.cuda(), [dummy_input])
        torch.save(model_trt.state_dict(), self.model_path)
        logger.info(f":inbox_tray: TensorRT model saved to {self.model_path}")
        return model_trt

create_model(model_cfg, weight_path=True, class_num=80)

Constructs and returns a YOLO model from a model config.

Parameters:

Name Type Description Default
model_cfg ModelConfig

The model configuration (architecture definition).

required
weight_path Union[bool, Path]

Path to pretrained weights. True loads the default weights for the model name; False trains from scratch.

True
class_num int

Number of output classes.

80

Returns:

Name Type Description
YOLO YOLO

An instance of the model defined by the given configuration.

Source code in yolo/model/builder.py
def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
    """Constructs and returns a YOLO model from a model config.

    Args:
        model_cfg (ModelConfig): The model configuration (architecture definition).
        weight_path (Union[bool, Path]): Path to pretrained weights. ``True`` loads
            the default weights for the model name; ``False`` trains from scratch.
        class_num (int): Number of output classes.

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    OmegaConf.set_struct(model_cfg, False)  # allow the builder to add keys at runtime
    model = YOLO(model_cfg, class_num)
    if weight_path:
        # `is True` (not `== True`): only the literal True selects the default
        # weight file; any other truthy value is treated as a path.
        if weight_path is True:
            weight_path = Path("weights") / f"{model_cfg.name}.pt"
        elif isinstance(weight_path, str):
            weight_path = Path(weight_path)

        if not weight_path.exists():
            logger.info(f"🌐 Weight {weight_path} not found, try downloading")
            prepare_weight(weight_path=weight_path)
        # Re-check: the download is best-effort and may have failed.
        if weight_path.exists():
            model.save_load_weights(weight_path)
            logger.info(":white_check_mark: Success load model & weight")
    else:
        logger.info(":white_check_mark: Success load model")
    return model

Model Utilities

yolo.utils.model_utils

IDX_TO_ID = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] module-attribute

logger = logging.getLogger('yolo') module-attribute

NMSConfig dataclass

Source code in yolo/config/schemas/task.py
@dataclass
class NMSConfig:
    """Thresholds controlling confidence filtering and non-maximum suppression."""

    # Score threshold below which (box, class) pairs are discarded before NMS.
    min_confidence: float
    # IoU threshold used by batched NMS to suppress overlapping boxes.
    min_iou: float
    # Maximum number of detections kept per image after NMS.
    max_bbox: int

Anc2Box

Source code in yolo/tasks/detection/postprocess.py
class Anc2Box:
    """Decodes anchor-based head outputs into class scores and xyxy boxes.

    Strides are taken from ``anchor_cfg`` when present, otherwise inferred
    from a dummy forward pass. Per-stride cell grids and per-head anchor
    scales are precomputed for the decode in ``__call__``.
    """

    def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
        self.device = device

        if hasattr(anchor_cfg, "strides"):
            logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
            self.strides = anchor_cfg.strides
        else:
            logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
            self.strides = self.create_auto_anchor(model, image_size)

        self.head_num = len(anchor_cfg.anchor)
        self.anchor_grids = self.generate_anchors(image_size)
        # Shape (head, 1, anchors-per-head, 1, 1, 2): broadcasts over (B, L, h, w, 2).
        self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2)
        self.anchor_num = self.anchor_scale.size(2)
        self.class_num = model.num_classes

    def create_auto_anchor(self, model: YOLO, image_size):
        # Run a dummy forward pass and derive each head's stride from the
        # ratio of input width to feature-map width.
        W, H = image_size
        dummy_input = torch.zeros(1, 3, H, W, device=self.device)
        dummy_output = model(dummy_input)
        strides = []
        for predict_head in dummy_output["Main"]:
            _, _, *anchor_num = predict_head.shape
            strides.append(W // anchor_num[1])
        return strides

    def generate_anchors(self, image_size: List[int]):
        """Build one (1, 1, H, W, 2) grid of cell (x, y) indices per stride."""
        anchor_grids = []
        for stride in self.strides:
            W, H = image_size[0] // stride, image_size[1] // stride
            anchor_h, anchor_w = torch.meshgrid([torch.arange(H), torch.arange(W)], indexing="ij")
            anchor_grid = torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float().to(self.device)
            anchor_grids.append(anchor_grid)
        return anchor_grids

    def update(self, image_size):
        # Regenerate the cell grids for a new input resolution.
        self.anchor_grids = self.generate_anchors(image_size)

    def __call__(self, predicts: List[Tensor]):
        """Decode raw head outputs to (cls_logits, None, xyxy boxes, confidence)."""
        preds_box, preds_cls, preds_cnf = [], [], []
        for layer_idx, predict in enumerate(predicts):
            predict = rearrange(predict, "B (L C) h w -> B L h w C", L=self.anchor_num)
            pred_box, pred_cnf, pred_cls = predict.split((4, 1, self.class_num), dim=-1)
            pred_box = pred_box.sigmoid()
            # Center: (2*sigmoid - 0.5) offset from the cell index, times stride.
            pred_box[..., 0:2] = (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grids[layer_idx]) * self.strides[
                layer_idx
            ]
            # Size: (2*sigmoid)^2 times the per-head anchor dimensions.
            pred_box[..., 2:4] = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]
            preds_box.append(rearrange(pred_box, "B L h w A -> B (L h w) A"))
            preds_cls.append(rearrange(pred_cls, "B L h w C -> B (L h w) C"))
            preds_cnf.append(rearrange(pred_cnf, "B L h w C -> B (L h w) C"))

        preds_box = torch.concat(preds_box, dim=1)
        preds_cls = torch.concat(preds_cls, dim=1)
        preds_cnf = torch.concat(preds_cnf, dim=1)

        preds_box = transform_bbox(preds_box, "xycwh -> xyxy")
        return preds_cls, None, preds_box, preds_cnf.sigmoid()

Vec2Box

Source code in yolo/tasks/detection/postprocess.py
class Vec2Box:
    """Decodes anchor-free head outputs (class logits, anchor distribution,
    LTRB distances) into absolute xyxy boxes.

    Strides are taken from ``anchor_cfg`` when present, otherwise inferred
    from a dummy forward pass; the anchor grid and stride scaler are cached
    per image size.
    """

    def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
        self.device = device

        if hasattr(anchor_cfg, "strides"):
            logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
            self.strides = anchor_cfg.strides
        else:
            logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
            self.strides = self.create_auto_anchor(model, image_size)

        anchor_grid, scaler = generate_anchors(image_size, self.strides)
        self.image_size = image_size
        self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)

    def create_auto_anchor(self, model: YOLO, image_size):
        # Derive each head's stride from a dummy forward pass (input width
        # divided by the feature-map width).
        W, H = image_size
        # TODO: need accelerate dummy test
        device = next(model.parameters()).device
        dummy_input = torch.zeros(1, 3, H, W, device=device)
        dummy_output = model(dummy_input)
        strides = []
        for predict_head in dummy_output["Main"]:
            _, _, *anchor_num = predict_head[2].shape
            strides.append(W // anchor_num[1])
        return strides

    def update(self, image_size):
        """
        Regenerate cached anchor grid and scaler for a new resolution.

        image_size: W, H
        """
        # No-op when the resolution is unchanged; grids are cached.
        if self.image_size == image_size:
            return
        anchor_grid, scaler = generate_anchors(image_size, self.strides)
        self.image_size = image_size
        self.anchor_grid, self.scaler = anchor_grid.to(self.device), scaler.to(self.device)

    def __call__(self, predicts):
        """Flatten per-head outputs and convert LTRB distances to xyxy boxes."""
        preds_cls, preds_anc, preds_box = [], [], []
        for layer_output in predicts:
            pred_cls, pred_anc, pred_box = layer_output
            preds_cls.append(rearrange(pred_cls, "B C h w -> B (h w) C"))
            preds_anc.append(rearrange(pred_anc, "B A R h w -> B (h w) R A"))
            preds_box.append(rearrange(pred_box, "B X h w -> B (h w) X"))
        preds_cls = torch.concat(preds_cls, dim=1)
        preds_anc = torch.concat(preds_anc, dim=1)
        preds_box = torch.concat(preds_box, dim=1)

        # Scale the left/top/right/bottom distances by each cell's stride,
        # then offset them from the anchor centers to get corner coordinates.
        pred_LTRB = preds_box * self.scaler.view(1, -1, 1)
        lt, rb = pred_LTRB.chunk(2, dim=-1)
        preds_box = torch.cat([self.anchor_grid - lt, self.anchor_grid + rb], dim=-1)
        return preds_cls, preds_anc, preds_box

update(image_size)

image_size: W, H

Source code in yolo/tasks/detection/postprocess.py
def update(self, image_size):
    """
    Regenerate the cached anchor grid and scaler when the resolution changes.

    image_size: W, H
    """
    if self.image_size != image_size:
        anchor_grid, scaler = generate_anchors(image_size, self.strides)
        self.image_size = image_size
        self.anchor_grid = anchor_grid.to(self.device)
        self.scaler = scaler.to(self.device)

PostProcess

Scales predictions back to the original image coordinates and applies non-maximum suppression to the predicted bounding boxes.

Source code in yolo/utils/model_utils.py
class PostProcess:
    """Scale predictions back to image coordinates and apply NMS.

    Wraps a converter (Vec2Box or Anc2Box) that decodes raw head outputs,
    optionally undoes the letterbox transform, then runs ``bbox_nms``.
    """

    def __init__(self, converter: Union[Vec2Box, Anc2Box], nms_cfg: NMSConfig) -> None:
        self.converter = converter
        self.nms = nms_cfg

    def __call__(
        self, predict, rev_tensor: Optional[Tensor] = None, image_size: Optional[List[int]] = None
    ) -> List[Tensor]:
        if image_size is not None:
            self.converter.update(image_size)
        converted = self.converter(predict["Main"])
        pred_class, _, pred_bbox = converted[:3]
        # Anc2Box yields a 4th element (objectness confidence); Vec2Box does not.
        pred_conf = converted[3] if len(converted) == 4 else None
        if rev_tensor is not None:
            # rev_tensor columns: [scale, shift x4] — undo shift, then scale.
            pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
        return bbox_nms(pred_class, pred_bbox, self.nms, pred_conf)

bbox_nms(cls_dist, bbox, nms_cfg, confidence=None)

Source code in yolo/tasks/detection/postprocess.py
def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
    """Confidence-filter and run class-aware NMS over a batch of predictions.

    Args:
        cls_dist (Tensor): Class logits, shape (B, N, C); sigmoid applied here.
        bbox (Tensor): Boxes in xyxy, shape (B, N, 4).
        nms_cfg (NMSConfig): min_confidence / min_iou / max_bbox thresholds.
        confidence (Optional[Tensor]): Objectness to multiply into the scores.

    Returns:
        List[Tensor]: Per image, an (n, 6) tensor [class, x1, y1, x2, y2, conf].
    """
    cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)

    # Keep every (batch, box, class) triple above the confidence threshold.
    batch_idx, valid_grid, valid_cls = torch.where(cls_dist > nms_cfg.min_confidence)
    valid_con = cls_dist[batch_idx, valid_grid, valid_cls]
    valid_box = bbox[batch_idx, valid_grid]

    # Group id mixes batch and class so NMS never merges across either.
    nms_idx = batched_nms(valid_box, valid_con, batch_idx + valid_cls * bbox.size(0), nms_cfg.min_iou)
    predicts_nms = []
    for idx in range(cls_dist.size(0)):
        # Select the surviving detections that belong to image `idx`.
        instance_idx = nms_idx[idx == batch_idx[nms_idx]]

        predict_nms = torch.cat(
            [valid_cls[instance_idx][:, None], valid_box[instance_idx], valid_con[instance_idx][:, None]], dim=-1
        )

        # Cap detections per image (batched_nms returns score-sorted indices).
        predicts_nms.append(predict_nms[: nms_cfg.max_bbox])
    return predicts_nms

transform_bbox(bbox, indicator='xywh -> xyxy')

Source code in yolo/tasks/detection/postprocess.py
def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
    """Convert bounding boxes between coordinate formats.

    Supported formats: "xyxy" (corners), "xywh" (top-left + size), and
    "xycwh" (center + size). ``indicator`` is e.g. "xycwh -> xyxy".

    Raises:
        ValueError: If either side of the indicator is not a supported format.
    """
    dtype = bbox.dtype
    src, dst = indicator.replace(" ", "").split("->")

    supported = ("xyxy", "xywh", "xycwh")
    if src not in supported or dst not in supported:
        raise ValueError("Invalid input or output format")

    # Normalize the source format to corner coordinates.
    if src == "xyxy":
        x1, y1, x2, y2 = bbox[..., 0], bbox[..., 1], bbox[..., 2], bbox[..., 3]
    elif src == "xywh":
        x1, y1 = bbox[..., 0], bbox[..., 1]
        x2 = x1 + bbox[..., 2]
        y2 = y1 + bbox[..., 3]
    else:  # xycwh
        half_w = bbox[..., 2] / 2
        half_h = bbox[..., 3] / 2
        x1, y1 = bbox[..., 0] - half_w, bbox[..., 1] - half_h
        x2, y2 = bbox[..., 0] + half_w, bbox[..., 1] + half_h

    # Emit corners in the requested target format.
    if dst == "xyxy":
        result = torch.stack([x1, y1, x2, y2], dim=-1)
    elif dst == "xywh":
        result = torch.stack([x1, y1, x2 - x1, y2 - y1], dim=-1)
    else:  # xycwh
        result = torch.stack([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], dim=-1)

    return result.to(dtype=dtype)

initialize_distributed()

Source code in yolo/utils/model_utils.py
def initialize_distributed() -> int:
    """Initialize the NCCL process group from launcher environment variables.

    Reads RANK / LOCAL_RANK / WORLD_SIZE (defaulting to a single-process
    setup), binds this process to its local GPU, and joins the process group.

    Returns:
        int: The local rank, i.e. the CUDA device index for this process.
        (The annotation previously said ``None``, contradicting the code.)
    """
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    logger.info(f"🔢 Initialized process group; rank: {rank}, size: {world_size}")
    return local_rank

get_device(device_spec)

Source code in yolo/utils/model_utils.py
def get_device(device_spec: Union[str, int, List[int]]) -> tuple:
    """Resolve a device specification to a ``(torch.device, ddp_flag)`` pair.

    Args:
        device_spec: "cpu", "cuda"/"cuda:N", a device index, or a list of
            indices (a list triggers distributed initialization).

    Returns:
        tuple: (torch.device, bool) — the resolved device and whether DDP was
        initialized. Falls back to CPU (with a warning) when CUDA is
        unavailable. (The annotation previously said ``torch.device`` alone,
        contradicting the code, which always returns a 2-tuple.)
    """
    ddp_flag = False
    if isinstance(device_spec, (list, ListConfig)):
        # Multi-GPU request: set up the process group, use the local rank.
        ddp_flag = True
        device_spec = initialize_distributed()
    if torch.cuda.is_available() and "cuda" in str(device_spec):
        return torch.device(device_spec), ddp_flag
    if not torch.cuda.is_available():
        if device_spec != "cpu":
            logger.warning(f"❎ Device spec: {device_spec} not support, Choosing CPU instead")
        return torch.device("cpu"), False

    device = torch.device(device_spec)
    return device, ddp_flag

collect_prediction(predict_json, local_rank)

Collects predictions from all distributed processes and gathers them on the main process (rank 0).

Parameters:

Name Type Description Default
predict_json List

The prediction data (can be of any type) generated by the current process.

required
local_rank int

The rank of the current process. Typically, rank 0 is the main process.

required

Returns:

Name Type Description
List List

The combined list of predictions from all processes if on rank 0, otherwise predict_json.

Source code in yolo/utils/model_utils.py
def collect_prediction(predict_json: List, local_rank: int) -> List:
    """
    Collects predictions from all distributed processes and gathers them on the main process (rank 0).

    Args:
        predict_json (List): The prediction data (can be of any type) generated by the current process.
        local_rank (int): The rank of the current process. Typically, rank 0 is the main process.

    Returns:
        List: The combined list of predictions from all processes if on rank 0, otherwise predict_json.
    """
    # Single-process run: nothing to gather.
    if not dist.is_initialized():
        return predict_json
    if local_rank == 0:
        gathered = [None for _ in range(dist.get_world_size())]
        dist.gather_object(predict_json, gathered, dst=0)
        return [prediction for worker_preds in gathered for prediction in worker_preds]
    # Non-zero ranks only contribute; their local list is returned unchanged.
    dist.gather_object(predict_json, None, dst=0)
    return predict_json

predicts_to_json(img_paths, predicts, rev_tensor)

Turns a batch of image paths and predictions (an n × 6 tensor per image) into a list of detection-output dictionaries.

Source code in yolo/utils/model_utils.py
def predicts_to_json(img_paths, predicts, rev_tensor):
    """Convert a batch of predictions into COCO-style detection dicts.

    Args:
        img_paths: Image paths; the filename stem must be the numeric COCO
            image id (``int(...)`` raises on non-numeric stems — NOTE(review):
            confirm inputs always follow COCO naming).
        predicts: Per-image (n, 6) tensors of [class, x1, y1, x2, y2, conf].
        rev_tensor: Per-image tensors of [scale, shift x4] undoing letterboxing.

    Returns:
        list: Flat list of {"image_id", "category_id", "bbox", "score"} dicts.
    """
    batch_json = []
    for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor):
        scale, shift = box_reverse.split([1, 4])
        bboxes = bboxes.clone()  # keep the caller's tensor unmodified
        # Undo the letterbox shift/scale, then convert corners to COCO xywh.
        bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None]
        bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh")
        for cls, *pos, conf in bboxes:
            bbox = {
                "image_id": int(Path(img_path).stem),
                "category_id": IDX_TO_ID[int(cls)],  # 80-class index -> COCO category id
                "bbox": [float(p) for p in pos],
                "score": float(conf),
            }
            batch_json.append(bbox)
    return batch_json

Format Converters

yolo.utils.format_converters

convert_dict = {'19.cv1': '19.conv', '16.cv1': '16.conv', '.7.cv1': '.7.conv', '.5.cv1': '.5.conv', '.3.cv1': '.3.conv', '.28.': '.29.', '.25.': '.26.', '.22.': '.23.', 'cv': 'conv', '.m.': '.bottleneck.'} module-attribute

HEAD_NUM = '29' module-attribute

head_converter = {'head_conv': 'm', 'implicit_a': 'ia', 'implicit_m': 'im'} module-attribute

SPP_converter = {'pre_conv.0': 'cv1', 'pre_conv.1': 'cv3', 'pre_conv.2': 'cv4', 'post_conv.0': 'cv5', 'post_conv.1': 'cv6', 'short_conv': 'cv2', 'merge_conv': 'cv7'} module-attribute

REP_converter = {'conv1': 'rbr_dense', 'conv2': 'rbr_1x1', 'conv': '0', 'bn': '1'} module-attribute

replace_dict = {'cv': 'conv', '.m.': '.bottleneck.'} module-attribute

convert_weight(old_state_dict, new_state_dict, model_size=38)

Source code in yolo/utils/format_converters.py
def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
    """Remap an official checkpoint's keys onto this repo's layer names.

    Head weights (keys containing ``HEAD_NUM``) are re-routed to the detection
    heads; all other keys go through the substring replacements in
    ``convert_dict``. Remapped keys absent from ``new_state_dict`` are dropped
    silently.

    NOTE(review): ``model_size`` is currently unused — confirm whether it was
    meant to select the head layout.
    """
    new_weight_set = set(new_state_dict.keys())
    for weight_name, weight_value in old_state_dict.items():
        if HEAD_NUM in weight_name:
            _, _, conv_name, conv_id, *post_fix = weight_name.split(".")
            # cv2/cv3 route to layer 30, the rest to layer 22 (presumably the
            # two detection heads — verify against the model config).
            head_id = 30 if conv_name in ["cv2", "cv3"] else 22
            head_type = "anchor_conv" if conv_name in ["cv2", "cv4"] else "class_conv"
            weight_name = ".".join(["model", str(head_id), "heads", conv_id, head_type, *post_fix])
        else:
            for old_name, new_name in convert_dict.items():
                if old_name in weight_name:
                    weight_name = weight_name.replace(old_name, new_name)
        if weight_name in new_weight_set:
            assert new_state_dict[weight_name].shape == weight_value.shape, f"shape miss match {weight_name}"
            new_state_dict[weight_name] = weight_value
            new_weight_set.remove(weight_name)
    return new_state_dict

convert_weight_v7(old_state_dict, new_state_dict)

Source code in yolo/utils/format_converters.py
def convert_weight_v7(old_state_dict, new_state_dict):
    """Remap an official v7-style checkpoint onto this repo's state-dict keys.

    For each target key, first tries the direct ``"model." + key`` name, then
    falls back to the head / SPP / RepConv naming tables. Name or shape
    mismatches abort via assert.
    """
    for key_name in new_state_dict.keys():
        new_shape = new_state_dict[key_name].shape
        old_key_name = "model." + key_name
        if old_key_name not in old_state_dict.keys():
            if "heads" in key_name:
                # Detection-head convs live under a different sub-module name.
                layer_idx, _, conv_idx, conv_name, *details = key_name.split(".")
                old_key_name = ".".join(["model", str(layer_idx), head_converter[conv_name], conv_idx, *details])
            elif any(k in key_name for k in SPP_converter):
                for key, value in SPP_converter.items():
                    if key in key_name:
                        key_name = key_name.replace(key, value)
                old_key_name = "model." + key_name
            elif "conv1" in key_name or "conv2" in key_name:
                # RepConv branches: conv1 -> rbr_dense, conv2 -> rbr_1x1.
                for key, value in REP_converter.items():
                    if key in key_name:
                        key_name = key_name.replace(key, value)
                old_key_name = "model." + key_name
        assert old_key_name in old_state_dict.keys(), f"Weight Name Mismatch!! {old_key_name}"
        old_shape = old_state_dict[old_key_name].shape
        assert new_shape == old_shape, f"Weight Shape Mismatch!! {old_key_name}"
        # NOTE(review): key_name may have been mutated by the SPP/REP branches
        # above, so this write targets the converted name — verify intended.
        new_state_dict[key_name] = old_state_dict[old_key_name]
    return new_state_dict

convert_weight_seg(old_state_dict, new_state_dict)

Source code in yolo/utils/format_converters.py
def convert_weight_seg(old_state_dict, new_state_dict):
    """Remap an official segmentation checkpoint onto this repo's keys.

    Layer indices are shifted by a running offset (``diff``) that changes at
    source layers 23 (+3) and 41 (-19); head convs are re-routed to the
    anchor/class/mask conv sub-modules. Unmatched or shape-mismatched keys are
    printed (best-effort), not raised.
    """
    diff = -1
    for old_weight_name in old_state_dict.keys():
        old_idx = int(old_weight_name.split(".")[1])
        # Offset boundaries between sections of the source model — presumably
        # backbone/head splits; TODO confirm against the model configs.
        if old_idx == 23:
            diff = 3
        elif old_idx == 41:
            diff = -19
        new_idx = old_idx + diff
        new_weight_name = old_weight_name.replace(f".{old_idx}.", f".{new_idx}.")
        for key, val in replace_dict.items():
            new_weight_name = new_weight_name.replace(key, val)

        if new_weight_name not in new_state_dict.keys():
            heads = "heads"
            _, _, conv_name, conv_idx, *details = old_weight_name.split(".")
            # proto/dfl layers have no counterpart here; skip them.
            if "proto" in conv_name or "dfl" in old_weight_name:
                continue
            if conv_name in ("cv2", "cv3", "cv6"):
                layer_idx = 44
                heads = "detect.heads"
            if conv_name in ("cv4", "cv5", "cv7"):
                layer_idx = 25
                heads = "detect.heads"

            if conv_name in ("cv2", "cv4"):
                conv_task = "anchor_conv"
            elif conv_name in ("cv3", "cv5"):
                conv_task = "class_conv"
            elif conv_name in ("cv6", "cv7"):
                conv_task = "mask_conv"
                heads = "heads"
            else:
                continue

            new_weight_name = ".".join(["model", str(layer_idx), heads, conv_idx, conv_task, *details])

        if (
            new_weight_name not in new_state_dict.keys()
            or new_state_dict[new_weight_name].shape != old_state_dict[old_weight_name].shape
        ):
            # Report rather than fail so the rest of the weights still load.
            print(f"new: {new_weight_name}, old: {old_weight_name}")
        new_state_dict[new_weight_name] = old_state_dict[old_weight_name]
    return new_state_dict

discretize_categories(categories)

Maps each category id to a sequential integer index.

Source code in yolo/utils/format_converters.py
def discretize_categories(categories: List[Dict]) -> Dict[int, int]:
    """Maps each category id to a sequential integer index (sorted by id)."""
    ordered_ids = sorted(category["id"] for category in categories)
    return {category_id: index for index, category_id in enumerate(ordered_ids)}

normalize_segmentation(segmentation, img_width, img_height)

Source code in yolo/utils/format_converters.py
def normalize_segmentation(segmentation: List[float], img_width: int, img_height: int) -> List[str]:
    """Normalize alternating x,y polygon coordinates to [0, 1] strings (6 dp).

    Even indices are x values (divided by width), odd are y (divided by height).
    """
    divisors = (img_width, img_height)
    return [f"{coord / divisors[index % 2]:.6f}" for index, coord in enumerate(segmentation)]

process_annotation(annotation, image_dims, id_to_idx, file)

Source code in yolo/utils/format_converters.py
def process_annotation(annotation: Dict, image_dims: tuple, id_to_idx: Optional[Dict[int, int]], file) -> None:
    """Write one polygon annotation as a normalized segmentation label line.

    Annotations whose segmentation is empty or not a polygon list (e.g. RLE
    masks — presumably; confirm against the dataset) are skipped silently.
    """
    raw_segmentation = annotation["segmentation"]
    if not (raw_segmentation and isinstance(raw_segmentation[0], list)):
        return
    polygon = raw_segmentation[0]

    img_width, img_height = image_dims
    coords = normalize_segmentation(polygon, img_width, img_height)

    category_id = annotation["category_id"]
    if id_to_idx:
        # Remap COCO category id to a sequential class index when provided.
        category_id = id_to_idx.get(category_id, category_id)

    file.write(f"{category_id} {' '.join(coords)}\n")

process_annotations(image_annotations, image_info_dict, output_dir, id_to_idx=None)

Source code in yolo/utils/format_converters.py
def process_annotations(
    image_annotations: Dict[int, List[Dict]],
    image_info_dict: Dict[int, tuple],
    output_dir: Path,
    id_to_idx: Optional[Dict[int, int]] = None,
) -> None:
    """Write one zero-padded label file per annotated image into output_dir.

    Images with no annotations get no file. Filenames are the 12-digit
    zero-padded image id, matching COCO image naming.
    """
    for image_id, annotations in track(image_annotations.items(), description="Processing annotations"):
        if not annotations:
            continue
        label_path = output_dir / f"{image_id:0>12}.txt"
        with open(label_path, "w") as label_file:
            for annotation in annotations:
                process_annotation(annotation, image_info_dict[image_id], id_to_idx, label_file)

convert_annotations(json_file, output_dir)

Source code in yolo/utils/format_converters.py
def convert_annotations(json_file: str, output_dir: str) -> None:
    """Convert a COCO-style JSON annotation file into per-image label files.

    Crowd annotations (``iscrowd``) are dropped; category ids are remapped to
    sequential indices when the JSON declares categories.
    """
    with open(json_file) as file:
        data = json.load(file)
    out_dir = Path(output_dir)
    out_dir.mkdir(exist_ok=True)

    image_info_dict = {image["id"]: (image["width"], image["height"]) for image in data.get("images", [])}
    id_to_idx = discretize_categories(data.get("categories", [])) if "categories" in data else None

    image_annotations = {image_id: [] for image_id in image_info_dict}
    for annotation in data.get("annotations", []):
        if annotation.get("iscrowd", False):
            continue
        image_annotations[annotation["image_id"]].append(annotation)

    process_annotations(image_annotations, image_info_dict, out_dir, id_to_idx)