Skip to content

Training

Solver

yolo.tasks.detection.solver

Config dataclass

Source code in yolo/config/config.py
@dataclass
class Config:
    """Top-level run configuration aggregating task, data, model and output settings."""

    task: Union[TrainConfig, InferenceConfig, ValidationConfig]  # mode-specific settings (train/val/inference)
    dataset: DatasetConfig
    model: ModelConfig
    name: str  # experiment / run name

    trainer: TrainerConfig

    image_size: List[int]  # network input size; element order (W,H vs H,W) not evident here — TODO confirm

    out_path: str  # root directory for run outputs
    exist_ok: bool  # allow reusing an existing out_path

    lucky_number: int  # presumably the RNG seed — verify at call sites
    use_wandb: bool
    use_tensorboard: bool

    task_type: str  # e.g. "detection"
    weight: Optional[str]  # path to pretrained weights (forwarded to create_model), or None

PostProcess

TODO: complete the function documentation — scales the predictions back to the original image size and applies NMS to `pred_bbox`.

Source code in yolo/utils/model_utils.py
class PostProcess:
    """Scale raw network outputs back to image coordinates and apply NMS.

    TODO: complete this documentation.
    """

    def __init__(self, converter: Union[Vec2Box, Anc2Box], nms_cfg: NMSConfig) -> None:
        self.converter = converter
        self.nms = nms_cfg

    def __call__(
        self, predict, rev_tensor: Optional[Tensor] = None, image_size: Optional[List[int]] = None
    ) -> List[Tensor]:
        # Refresh the converter's grids when an explicit image size is given.
        if image_size is not None:
            self.converter.update(image_size)
        decoded = self.converter(predict["Main"])
        pred_class, _, pred_bbox = decoded[:3]
        # A fourth element, when present, carries per-box confidences.
        pred_conf = decoded[3] if len(decoded) == 4 else None
        if rev_tensor is not None:
            # Map boxes back to the original frame: subtract offsets, divide by scale.
            pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
        return bbox_nms(pred_class, pred_bbox, self.nms, pred_conf)

BaseModel

Bases: LightningModule

Source code in yolo/tasks/detection/solver.py
class BaseModel(LightningModule):
    """Shared Lightning base: builds the YOLO network from the run config."""

    def __init__(self, cfg: Config):
        super().__init__()
        # Network is constructed once here; weights come from cfg.weight if set.
        self.model = create_model(
            cfg.model,
            class_num=cfg.dataset.class_num,
            weight_path=cfg.weight,
        )

    def forward(self, x):
        """Delegate directly to the underlying YOLO network."""
        return self.model(x)

DetectionValidateModel

Bases: BaseModel

Source code in yolo/tasks/detection/solver.py
@register("detection", "validation")
class DetectionValidateModel(BaseModel):
    """LightningModule that runs detection validation and accumulates COCO-style mAP."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        # A standalone validation run keeps its settings directly on `task`;
        # otherwise (e.g. validation during training) they live under `task.validation`.
        if self.cfg.task.task == "validation":
            self.validation_cfg = self.cfg.task
        else:
            self.validation_cfg = self.cfg.task.validation
        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy", backend="faster_coco_eval")
        self.metric.warn_on_many_detections = False  # suppress torchmetrics' many-detections warning
        self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)
        # Alias used by validation_step; presumably a hook point for an EMA
        # model — here it is simply the raw model. TODO confirm.
        self.ema = self.model

    def setup(self, stage):
        # Built here (not in __init__) so self.device is the actual runtime device.
        self.vec2box = create_converter(
            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
        )
        self.post_process = PostProcess(self.vec2box, self.validation_cfg.nms)

    def val_dataloader(self):
        return self.val_loader

    def validation_step(self, batch, batch_idx):
        batch_size, images, targets, rev_tensor, img_paths = batch
        H, W = images.shape[2:]
        # Note: image_size is passed as [W, H].
        predicts = self.post_process(self.ema(images), image_size=[W, H])
        # Feed this batch into the running mAP accumulator.
        mAP = self.metric(
            [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
        )
        return predicts, mAP

    def on_validation_epoch_end(self):
        epoch_metrics = self.metric.compute()
        del epoch_metrics["classes"]  # per-class tensor is not loggable as a scalar
        # NOTE(review): combining sync_dist=True with rank_zero_only=True can hang
        # under DDP according to Lightning's logging docs — confirm intended.
        self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True, logger=True)
        self.log_dict(
            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
            sync_dist=True,
            rank_zero_only=True,
            logger=True,
        )
        self.metric.reset()

DetectionTrainModel

Bases: DetectionValidateModel

Source code in yolo/tasks/detection/solver.py
@register("detection", "train")
class DetectionTrainModel(DetectionValidateModel):
    """LightningModule for detection training; validation behavior is inherited."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)

    def setup(self, stage):
        # Parent builds vec2box/post_process; the loss needs vec2box, so it comes after.
        super().setup(stage)
        self.loss_fn = create_loss_function(self.cfg, self.vec2box)

    def train_dataloader(self):
        return self.train_loader

    def on_train_epoch_start(self):
        # Re-sync the box converter with the configured image size each epoch.
        self.vec2box.update(self.cfg.image_size)

    def training_step(self, batch, batch_idx):
        batch_size, images, targets, *_ = batch
        predicts = self(images)
        # Decode auxiliary and main head outputs separately for the dual loss.
        aux_predicts = self.vec2box(predicts["AUX"])
        main_predicts = self.vec2box(predicts["Main"])
        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
        self.log_dict(
            loss_item,
            logger=True,
            prog_bar=True,
            on_epoch=True,
            batch_size=batch_size,
            rank_zero_only=True,
        )
        return loss

    def configure_optimizers(self):
        optimizer = create_optimizer(self.model, self.cfg.task.optimizer)

        # Derive an accumulation factor from an optional "equivalent" (effective)
        # batch size. NOTE(review): max_accum only affects the step count below;
        # it does not configure Trainer gradient accumulation — confirm intended.
        batch_size = self.cfg.task.data.batch_size
        world_size = getattr(self.trainer, "world_size", 1) if self.trainer else 1
        equivalent_batch_size = getattr(self.cfg.task.data, "equivalent_batch_size", None)
        if equivalent_batch_size is not None:
            max_accum = max(1, round(equivalent_batch_size / (batch_size * world_size)))
        else:
            max_accum = 1

        # Use dataset length — invariant to loader sharding (e.g. Ray Train or Distributed Sampler
        # wraps the loader per rank, so len(train_loader) would be the per-rank count).
        if hasattr(self.train_loader, "dataset"):
            n_samples = len(self.train_loader.dataset)
            global_batch = batch_size * world_size * max_accum
            drop_last = getattr(self.cfg.task.data, "drop_last", False)
            if drop_last:
                steps_per_epoch = max(1, n_samples // global_batch)
            else:
                steps_per_epoch = max(1, ceil(n_samples / global_batch))
        else:
            steps_per_epoch = max(1, ceil(len(self.train_loader) / max_accum))

        scheduler = create_scheduler(optimizer, self.cfg.task.scheduler, steps_per_epoch, self.cfg.task.epoch)
        # interval="step" so the batch-level scheduler advances once per optimizer step.
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

DetectionInferenceModel

Bases: BaseModel

Source code in yolo/tasks/detection/solver.py
@register("detection", "inference")
class DetectionInferenceModel(BaseModel):
    """LightningModule running detection inference with optional visualization/saving."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        # TODO: Add FastModel
        self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)

    def setup(self, stage):
        # Built here (not in __init__) so self.device is the actual runtime device.
        self.vec2box = create_converter(
            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
        )
        self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)

    def predict_dataloader(self):
        return self.predict_loader

    def predict_step(self, batch, batch_idx):
        images, rev_tensor, origin_frame = batch
        # rev_tensor undoes preprocessing scaling so boxes land on the original frame.
        predicts = self.post_process(self(images), rev_tensor=rev_tensor)
        img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
        if getattr(self.predict_loader, "is_stream", None):
            # NOTE(review): _display_stream is not defined in the class hierarchy
            # visible here — presumably provided elsewhere; confirm.
            fps = self._display_stream(img)
        else:
            fps = None
        if getattr(self.cfg.task, "save_predict", None):
            self._save_image(img, batch_idx)
        return img, fps

    def _save_image(self, img, batch_idx):
        # Frames are written under the trainer's root dir as frame000.png, frame001.png, ...
        save_image_path = Path(self.trainer.default_root_dir) / f"frame{batch_idx:03d}.png"
        img.save(save_image_path)
        print(f"💾 Saved visualize image at {save_image_path}")

create_dataloader(data_cfg, dataset_cfg, task='train')

Source code in yolo/data/loader.py
def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
    """Build the dataloader for a task.

    Inference gets a streaming loader; every other task wraps a YoloDataset
    in a standard PyTorch DataLoader.
    """
    if task == "inference":
        return StreamDataLoader(data_cfg)

    # Optionally fetch the dataset before constructing it.
    if getattr(dataset_cfg, "auto_download", False):
        prepare_dataset(dataset_cfg, task)

    loader_kwargs = dict(
        batch_size=data_cfg.batch_size,
        num_workers=data_cfg.dataloader_workers,
        pin_memory=data_cfg.pin_memory,
        collate_fn=collate_fn,
        drop_last=data_cfg.drop_last,
    )
    return DataLoader(YoloDataset(data_cfg, dataset_cfg, task), **loader_kwargs)

create_model(model_cfg, weight_path=True, class_num=80)

Constructs and returns a YOLO model from a model config.

Parameters:

Name Type Description Default
model_cfg ModelConfig

The model configuration (architecture definition).

required
weight_path Union[bool, Path]

Path to pretrained weights. True loads the default weights for the model name; False trains from scratch.

True
class_num int

Number of output classes.

80

Returns:

Name Type Description
YOLO YOLO

An instance of the model defined by the given configuration.

Source code in yolo/model/builder.py
def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
    """Constructs and returns a YOLO model from a model config.

    Args:
        model_cfg (ModelConfig): The model configuration (architecture definition).
        weight_path (Union[bool, Path]): Path to pretrained weights. ``True`` loads
            the default weights for the model name; ``False`` trains from scratch.
        class_num (int): Number of output classes.

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    OmegaConf.set_struct(model_cfg, False)  # allow the builder to add keys to the config
    model = YOLO(model_cfg, class_num)
    if weight_path:
        # `is True` (not `== True`): only the literal True sentinel selects the
        # default weight path; a truthy str/Path must fall through below.
        if weight_path is True:
            weight_path = Path("weights") / f"{model_cfg.name}.pt"
        elif isinstance(weight_path, str):
            weight_path = Path(weight_path)

        if not weight_path.exists():
            logger.info(f"🌐 Weight {weight_path} not found, try downloading")
            prepare_weight(weight_path=weight_path)
        # Load only if the file exists (the download above is best-effort).
        if weight_path.exists():
            model.save_load_weights(weight_path)
            logger.info(":white_check_mark: Success load model & weight")
    else:
        logger.info(":white_check_mark: Success load model")
    return model

create_loss_function(cfg, vec2box)

Source code in yolo/tasks/detection/loss.py
def create_loss_function(cfg: Config, vec2box) -> DualLoss:
    """Instantiate the training loss.

    TODO: make it flexible — fall back to SingleLoss when cfg has no aux head.
    """
    dual_loss = DualLoss(cfg, vec2box)
    logger.info(":white_check_mark: Success load loss function")
    return dual_loss

create_converter(model_version='v9-c', *args, **kwargs)

Source code in yolo/tasks/detection/postprocess.py
def create_converter(model_version: str = "v9-c", *args, **kwargs) -> Union[Anc2Box, Vec2Box]:
    """Pick the box converter matching the model family.

    v7 models use anchor-based decoding (Anc2Box); everything else uses the
    anchor-free Vec2Box decoder.
    """
    converter_cls = Anc2Box if "v7" in model_version else Vec2Box
    return converter_cls(*args, **kwargs)

to_metrics_format(prediction)

Source code in yolo/tasks/detection/postprocess.py
def to_metrics_format(prediction: Tensor) -> Dict[str, Union[float, Tensor]]:
    """Convert an [N, 5|6] prediction tensor to torchmetrics' dict format.

    Columns are [class_id, x1, y1, x2, y2, (score)]; rows whose class id is
    -1 are treated as padding and dropped.
    """
    valid = prediction[prediction[:, 0] != -1]
    result = {"boxes": valid[:, 1:5], "labels": valid[:, 0].int()}
    # Only post-NMS predictions carry a sixth (score) column; targets do not.
    if valid.size(1) == 6:
        result["scores"] = valid[:, 5]
    return result

register(task_type, mode)

Decorator that registers a solver class for a (task_type, mode) pair.

Parameters:

Name Type Description Default
task_type str

Task name, e.g. "detection", "segmentation".

required
mode str

Run mode — "train", "validation", or "inference".

required
Example
@register("detection", "train")
class DetectionTrainModel(BaseModel): ...
Source code in yolo/tasks/registry.py
def register(task_type: str, mode: str):
    """Decorator registering a solver class under a (task_type, mode) key.

    Args:
        task_type (str): Task name, e.g. ``"detection"``, ``"segmentation"``.
        mode (str): Run mode — ``"train"``, ``"validation"``, or ``"inference"``.

    Example:
        ```python
        @register("detection", "train")
        class DetectionTrainModel(BaseModel): ...
        ```
    """

    def _register(solver_cls: Type[LightningModule]) -> Type[LightningModule]:
        # Record the class in the global registry and hand it back unchanged.
        SOLVERS[(task_type, mode)] = solver_cls
        return solver_cls

    return _register

create_optimizer(model, optim_cfg)

Create an optimizer for the given model parameters based on the configuration.

Returns:

Type Description
Optimizer

An instance of the optimizer configured according to the provided settings.

Source code in yolo/training/optim.py
def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
    """Create an optimizer with YOLO-style parameter groups.

    Parameters are split into biases, conv weights, and batch-norm weights so
    that weight decay is disabled for biases and norm weights.

    Returns:
        An instance of the optimizer configured according to the provided settings.
    """
    optimizer_cls: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)

    biases = [p for name, p in model.named_parameters() if "bias" in name]
    norms = [p for name, p in model.named_parameters() if "weight" in name and "bn" in name]
    convs = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]

    momentum = optim_cfg.args.momentum
    param_groups = [
        {"params": biases, "momentum": momentum, "weight_decay": 0},
        {"params": convs, "momentum": momentum},
        {"params": norms, "momentum": momentum, "weight_decay": 0},
    ]
    return optimizer_cls(param_groups, **optim_cfg.args)

create_scheduler(optimizer, schedule_cfg, steps_per_epoch=None, epochs=None)

Create a learning rate scheduler for the given optimizer based on the configuration.

Returns:

Type Description
_LRScheduler

An instance of the scheduler configured according to the provided settings.

Source code in yolo/training/optim.py
def create_scheduler(
    optimizer: Optimizer,
    schedule_cfg: SchedulerConfig,
    steps_per_epoch: Optional[int] = None,
    epochs: Optional[int] = None,
) -> _LRScheduler:
    """Create a learning rate scheduler for the given optimizer based on the configuration.

    Returns:
        An instance of the scheduler configured according to the provided settings.
    """
    scheduler_cls: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedule_cfg.type)
    base_scheduler = scheduler_cls(optimizer, **schedule_cfg.args)

    # Warmup settings only apply when the config carries a `warmup` section.
    if hasattr(schedule_cfg, "warmup"):
        warmup_epochs = int(schedule_cfg.warmup.epochs)
        warmup_policy = YOLOWarmupPolicy(warmup_epochs=warmup_epochs)
        start_momentum = getattr(schedule_cfg.warmup, "start_momentum", 0.8)
        end_momentum = getattr(schedule_cfg.warmup, "end_momentum", 0.937)
    else:
        warmup_epochs = 0
        warmup_policy = None
        start_momentum = 0.8
        end_momentum = 0.937

    return WarmupBatchScheduler(
        optimizer=optimizer,
        scheduler=base_scheduler,
        steps_per_epoch=steps_per_epoch or 1,
        warmup_epochs=warmup_epochs,
        warmup_policy=warmup_policy,
        start_momentum=start_momentum,
        end_momentum=end_momentum,
    )

draw_bboxes(img, bboxes, *, idx2label=None)

Draw bounding boxes on an image.

Args: img (PIL Image or torch.Tensor) — the image on which to draw the bounding boxes; bboxes (list of lists/tensors) — bounding boxes as [class_id, x_min, y_min, x_max, y_max], with coordinates normalized to [0, 1].

Source code in yolo/utils/drawer.py
def draw_bboxes(
    img: Union[Image.Image, torch.Tensor],
    bboxes: List[List[Union[int, float]]],
    *,
    idx2label: Optional[list] = None,
):
    """
    Draw bounding boxes on an image.

    Args:
    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
      plus an optional trailing confidence rendered as a percentage.
      NOTE(review): coordinates are drawn as-is in pixel space below; the old
      claim that they are normalized [0, 1] looks stale — confirm with callers.

    Returns:
        A new PIL image with boxes and labels drawn (the input is not modified).
    """
    # Convert tensor image to PIL Image if necessary
    if isinstance(img, torch.Tensor):
        if img.dim() > 3:
            logger.warning("🔍 >3 dimension tensor detected, using the 0-idx image.")
            img = img[0]
        img = to_pil_image(img)

    # Batched predictions: draw only the first image's boxes.
    if isinstance(bboxes, list) or bboxes.ndim == 3:
        bboxes = bboxes[0]

    img = img.copy()  # never draw on the caller's image
    label_size = img.size[1] / 30
    draw = ImageDraw.Draw(img, "RGBA")

    try:
        font = ImageFont.truetype("arial.ttf", int(label_size))
    except IOError:
        font = ImageFont.load_default(int(label_size))

    for bbox in bboxes:
        class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
        # Normalize corner ordering so min/max really are min/max.
        x_min, x_max = min(x_min, x_max), max(x_min, x_max)
        y_min, y_max = min(y_min, y_max), max(y_min, y_max)
        bbox = [(x_min, y_min), (x_max, y_max)]

        # Deterministic per-class color. FIX: use a local Random instance so the
        # global random state is not clobbered as a side effect (produces the
        # same sequence as the old random.seed(...) + random.randint(...)).
        rng = random.Random(int(class_id))
        color_map = (rng.randint(0, 200), rng.randint(0, 200), rng.randint(0, 200))

        draw.rounded_rectangle(bbox, outline=(*color_map, 200), radius=5, width=2)
        draw.rounded_rectangle(bbox, fill=(*color_map, 100), radius=5)

        class_text = str(idx2label[int(class_id)] if idx2label else int(class_id))
        label_text = f"{class_text}" + (f" {conf[0]: .0%}" if conf else "")

        text_bbox = font.getbbox(label_text)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = (text_bbox[3] - text_bbox[1]) * 1.5

        text_background = [(x_min, y_min), (x_min + text_width, y_min + text_height)]
        draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
        draw.text((x_min, y_min), label_text, fill="white", font=font)

    return img

Optimizer & Scheduler

yolo.training.optim

OptimizerConfig dataclass

Source code in yolo/config/schemas/training.py
@dataclass
class OptimizerConfig:
    """Optimizer choice plus its constructor arguments."""

    type: str  # torch.optim class name, e.g. "SGD" (resolved via getattr)
    args: OptimizerArgs  # kwargs forwarded to the optimizer constructor (includes momentum)

SchedulerConfig dataclass

Source code in yolo/config/schemas/training.py
@dataclass
class SchedulerConfig:
    """LR scheduler choice, warmup settings, and constructor arguments."""

    type: str  # torch.optim.lr_scheduler class name (resolved via getattr)
    warmup: Dict[str, Union[int, float]]  # e.g. "epochs", optional "start_momentum"/"end_momentum"
    args: Dict[str, Any]  # kwargs forwarded to the scheduler constructor

YOLO

Bases: Module

A preliminary YOLO (You Only Look Once) model class still under development.

Parameters:

Name Type Description Default
model_cfg ModelConfig

Configuration for the YOLO model. Expected to define the layers, parameters, and any other relevant configuration details.

required
Source code in yolo/model/builder.py
class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    Parameters:
        model_cfg: Configuration for the YOLO model. Expected to define the layers,
                   parameters, and any other relevant configuration details.
    """

    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
        super(YOLO, self).__init__()
        self.num_classes = class_num
        self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
        self.model: List[YOLOLayer] = nn.ModuleList()
        # reg_max: bin count for the box-regression heads (default 16).
        self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
        self.build_model(model_cfg.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        """Instantiate every layer described by ``model_arch`` into ``self.model``.

        Tracks per-layer output channels so each layer's input channels can be
        wired from its source layer(s), and records tag -> index mappings in
        ``self.layer_index`` for name-based source lookups.
        """
        self.layer_index = {}
        output_dim, layer_idx = [3], 1  # index 0 = raw RGB input (3 channels)
        logger.info(f":tractor: Building YOLO")
        for arch_name in model_arch:
            if model_arch[arch_name]:
                logger.info(f"  :building_construction:  Building {arch_name}")
            for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx):
                layer_type, layer_info = next(iter(layer_spec.items()))
                layer_args = layer_info.get("args", {})

                # Get input source
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Find in channels
                if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
                    layer_args["in_channels"] = output_dim[source]
                if any(module in layer_type for module in ["Detection", "Segmentation", "Classification"]):
                    if isinstance(source, list):
                        layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    else:
                        layer_args["in_channel"] = output_dim[source]
                    layer_args["num_classes"] = self.num_classes
                    layer_args["reg_max"] = self.reg_max

                # create layers
                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                setattr(layer, "out_c", out_channels)
            # NOTE(review): this increment also fires for an *empty* arch section,
            # which would desynchronize build indices from forward()'s continuous
            # 1-based numbering — confirm configs never contain empty sections.
            layer_idx += 1

    def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = None):
        """Run the network and return a dict of tagged outputs.

        Args:
            x: Input tensor, cached at index 0 of the activation map.
            external: Optional named activations injected from outside.
            shortcut: If set, return as soon as the layer with this tag runs.
        """
        y = {0: x, **(external or {})}
        output = dict()
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                model_input = [y[idx] for idx in layer.source]
            else:
                model_input = y[layer.source]

            external_input = {source_name: y[source_name] for source_name in layer.external}

            x = layer(model_input, **external_input)
            y[-1] = x  # -1 always holds the previous layer's output
            if layer.usable:
                y[index] = x  # cache only outputs some later layer consumes
            if layer.output:
                output[layer.tags] = x
                if layer.tags == shortcut:
                    return output
        return output

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        """Infer a layer's output channel count from its args or its source(s)."""
        # NOTE(review): hasattr() only sees "out_channels" when layer_args is an
        # attribute-style config (e.g. OmegaConf DictConfig); on a plain dict it
        # is always False — confirm intended.
        if hasattr(layer_args, "out_channels"):
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if isinstance(source, int):
            return output_dim[source]
        if isinstance(source, list):
            # Concatenation-style layers: channels sum over all sources.
            return sum(output_dim[idx] for idx in source)

    def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
        """Resolve a source spec (tag, relative index, or list thereof) to absolute indices."""
        if isinstance(source, ListConfig):
            return [self.get_source_idx(index, layer_idx) for index in source]
        if isinstance(source, str):
            source = self.layer_index[source]
        if source < -1:
            # Relative reference (e.g. -2) -> absolute layer index.
            source += layer_idx
        if source > 0:  # Using Previous Layer's Output
            # Mark the producer so forward() caches its activation.
            self.model[source - 1].usable = True
        return source

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
        """Instantiate a layer by type name and attach routing metadata to it."""
        if layer_type in self.layer_map:
            layer = self.layer_map[layer_type](**kwargs)
            setattr(layer, "layer_type", layer_type)
            setattr(layer, "source", source)
            setattr(layer, "in_c", kwargs.get("in_channels", None))
            setattr(layer, "output", layer_info.get("output", False))
            setattr(layer, "tags", layer_info.get("tags", None))
            setattr(layer, "external", layer_info.get("external", []))
            setattr(layer, "usable", 0)
            return layer
        else:
            raise ValueError(f"Unsupported layer type: {layer_type}")

    def save_load_weights(self, weights: Union[Path, OrderedDict]):
        """
        Update the model's weights with the provided weights.

        args:
            weights: An OrderedDict containing the new weights, or a Path to a checkpoint.
        """
        if isinstance(weights, Path):
            # NOTE: weights_only=False deserializes arbitrary pickled objects —
            # only load checkpoints from trusted sources.
            weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
        if "state_dict" in weights:
            # Lightning checkpoint: strip the "model.model." prefix it adds.
            weights = {name.removeprefix("model.model."): key for name, key in weights["state_dict"].items()}
        model_state_dict = self.model.state_dict()

        # TODO1: autoload old version weight
        # TODO2: weight transform if num_class difference

        error_dict = {"Mismatch": set(), "Not Found": set()}
        for model_key, model_weight in model_state_dict.items():
            if model_key not in weights:
                error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
                continue
            if model_weight.shape != weights[model_key].shape:
                error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
                continue
            model_state_dict[model_key] = weights[model_key]

        for error_name, error_set in error_dict.items():
            # FIX: group into a fresh dict instead of rebinding `error_dict`,
            # which confusingly shadowed the dict being iterated.
            grouped = dict()
            for layer_idx, *layer_name in error_set:
                grouped.setdefault(layer_idx, []).append(".".join(layer_name))
            for layer_idx, layer_names in grouped.items():
                layer_names.sort()
                logger.warning(f":warning: Weight {error_name} for Layer {layer_idx}: {', '.join(layer_names)}")

        self.model.load_state_dict(model_state_dict)

save_load_weights(weights)

Update the model's weights with the provided weights.

Parameters:

Name Type Description Default
weights Union[Path, OrderedDict]

An OrderedDict containing the new weights.

required
Source code in yolo/model/builder.py
def save_load_weights(self, weights: Union[Path, OrderedDict]):
    """
    Update the model's weights with the provided weights.

    args:
        weights: An OrderedDict containing the new weights, or a Path to a checkpoint.
    """
    if isinstance(weights, Path):
        # NOTE: weights_only=False deserializes arbitrary pickled objects —
        # only load checkpoints from trusted sources.
        weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
    if "state_dict" in weights:
        # Lightning checkpoint: strip the "model.model." prefix it adds.
        weights = {name.removeprefix("model.model."): key for name, key in weights["state_dict"].items()}
    model_state_dict = self.model.state_dict()

    # TODO1: autoload old version weight
    # TODO2: weight transform if num_class difference

    error_dict = {"Mismatch": set(), "Not Found": set()}
    for model_key, model_weight in model_state_dict.items():
        if model_key not in weights:
            error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
            continue
        if model_weight.shape != weights[model_key].shape:
            error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
            continue
        model_state_dict[model_key] = weights[model_key]

    for error_name, error_set in error_dict.items():
        # FIX: group into a fresh dict instead of rebinding `error_dict`,
        # which confusingly shadowed the dict being iterated.
        grouped = dict()
        for layer_idx, *layer_name in error_set:
            grouped.setdefault(layer_idx, []).append(".".join(layer_name))
        for layer_idx, layer_names in grouped.items():
            layer_names.sort()
            logger.warning(f":warning: Weight {error_name} for Layer {layer_idx}: {', '.join(layer_names)}")

    self.model.load_state_dict(model_state_dict)

WarmupLRPolicy

Base strategy for per-group LR shape during warmup.

Subclass this to define custom warmup curves without touching WarmupBatchScheduler (Open/Closed Principle).

Source code in yolo/training/optim.py
class WarmupLRPolicy:
    """Strategy interface defining each param group's LR curve during warmup.

    Subclass to provide custom warmup shapes; ``WarmupBatchScheduler`` itself
    never needs to change (Open/Closed Principle).
    """

    def start_lr(self, group_idx: int, initial_lr: float) -> float:
        """LR at virtual epoch -1 — the point epoch 0's batch interpolation starts from."""
        raise NotImplementedError

    def target_lr(self, epoch: int, group_idx: int, initial_lr: float) -> float:
        """LR the group should reach by the end of warmup ``epoch``."""
        raise NotImplementedError

start_lr(group_idx, initial_lr)

LR at virtual epoch -1 — where epoch 0 batch interpolation begins.

Source code in yolo/training/optim.py
def start_lr(self, group_idx: int, initial_lr: float) -> float:
    """LR at virtual epoch -1 — where epoch 0 batch interpolation begins."""
    # Abstract hook: concrete warmup policies must override this.
    raise NotImplementedError

target_lr(epoch, group_idx, initial_lr)

LR target at the end of warmup epoch.

Source code in yolo/training/optim.py
def target_lr(self, epoch: int, group_idx: int, initial_lr: float) -> float:
    """LR target at the end of warmup ``epoch``."""
    # Abstract hook: concrete warmup policies must override this.
    raise NotImplementedError

LinearWarmupPolicy

Bases: WarmupLRPolicy

Uniform ramp: all param groups rise from 0 → initial_lr over warmup.

Source code in yolo/training/optim.py
class LinearWarmupPolicy(WarmupLRPolicy):
    """Uniform ramp: every param group climbs from 0 to its initial LR over warmup."""

    def __init__(self, warmup_epochs: int):
        self.warmup_epochs = int(warmup_epochs)

    def start_lr(self, group_idx: int, initial_lr: float) -> float:
        # Every group begins the ramp at zero.
        return 0.0

    def target_lr(self, epoch: int, group_idx: int, initial_lr: float) -> float:
        # Linear interpolation toward initial_lr across the warmup epochs.
        return lerp(0.0, initial_lr, epoch + 1, self.warmup_epochs)

YOLOWarmupPolicy

Bases: WarmupLRPolicy

YOLO-style warmup: bias group drops, all other groups rise.

Group 0 (bias): starts at 10× initial_lr and ramps down to 1×. Groups 1+ (conv, bn): start at 0 and ramp up to 1×.

This mirrors the original lambda2/lambda1 scheme from YOLO.

Source code in yolo/training/optim.py
class YOLOWarmupPolicy(WarmupLRPolicy):
    """YOLO-style warmup: the bias group descends while all other groups ascend.

    Group 0 (bias): starts at 10× initial_lr and ramps down to 1×.
    Groups 1+ (conv, bn): start at 0 and ramp up to 1×.

    This mirrors the original lambda2/lambda1 scheme from YOLO.
    """

    def __init__(self, warmup_epochs: int):
        self.warmup_epochs = int(warmup_epochs)

    def _bias_factor(self, epoch: int) -> float:
        """Multiplier for group 0: ramps 10 → 1 across warmup."""
        if epoch >= self.warmup_epochs:
            return 1.0
        return 10 - 9 * ((epoch + 1) / self.warmup_epochs)

    def _rise_factor(self, epoch: int) -> float:
        """Multiplier for groups 1+: ramps 0 → 1 across warmup."""
        if epoch >= self.warmup_epochs:
            return 1.0
        return (epoch + 1) / self.warmup_epochs

    def start_lr(self, group_idx: int, initial_lr: float) -> float:
        # Virtual epoch -1: bias factor is 10, every other group's factor is 0.
        factor = 10.0 if group_idx == 0 else 0.0
        return factor * initial_lr

    def target_lr(self, epoch: int, group_idx: int, initial_lr: float) -> float:
        if group_idx == 0:
            return self._bias_factor(epoch) * initial_lr
        return self._rise_factor(epoch) * initial_lr

WarmupBatchScheduler

Bases: _LRScheduler

Batch-level LR and momentum scheduler with epoch-aware warmup.

Wraps an epoch-level scheduler and linearly interpolates LR across batches within each epoch. Momentum is also interpolated during warmup epochs.

Must be called with interval="step" in Lightning so that scheduler.step() fires once per optimizer step.

Source code in yolo/training/optim.py
class WarmupBatchScheduler(_LRScheduler):
    """Batch-level LR and momentum scheduler with epoch-aware warmup.

    Wraps an epoch-level ``scheduler`` and linearly interpolates LR
    across batches within each epoch.  Momentum is also interpolated during
    warmup epochs.

    Must be called with ``interval="step"`` in Lightning so that
    ``scheduler.step()`` fires once per optimizer step.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        steps_per_epoch: int,
        warmup_epochs: int = 3,
        warmup_policy: Optional[WarmupLRPolicy] = None,
        start_momentum: float = 0.8,
        end_momentum: float = 0.937,
        last_epoch: int = -1,
    ):
        self.scheduler = scheduler
        self.steps_per_epoch = max(1, int(steps_per_epoch))
        self.warmup_epochs = int(warmup_epochs)
        self.warmup_policy = warmup_policy
        self.start_momentum = float(start_momentum)
        self.end_momentum = float(end_momentum)

        # Capture base LRs before any scheduler modifies them
        self._initial_lr: List[float] = [group["lr"] for group in optimizer.param_groups]

        # Epoch 0 interpolation endpoints — derived from policy or defaults
        if warmup_policy is not None:
            self._start_lr: List[float] = [warmup_policy.start_lr(i, lr) for i, lr in enumerate(self._initial_lr)]
            self._end_lr: List[float] = [warmup_policy.target_lr(0, i, lr) for i, lr in enumerate(self._initial_lr)]
        else:
            # Without a policy: ramp from 0 when warming up, otherwise hold flat.
            self._start_lr = [0.0] * len(self._initial_lr) if warmup_epochs > 0 else list(self._initial_lr)
            self._end_lr = list(self._initial_lr)

        self._epoch: int = 0
        # NOTE: _LRScheduler.__init__ invokes self.step() once, so LRs are
        # already set to the first interpolation point on construction.
        super().__init__(optimizer, last_epoch)

    def _position(self) -> tuple:
        """Return (epoch, batch) for the current last_epoch value."""
        # In step-interval mode, last_epoch counts optimizer steps, not epochs.
        step = max(self.last_epoch, 0)
        return step // self.steps_per_epoch, step % self.steps_per_epoch

    def get_lr(self) -> List[float]:
        # _LRScheduler hook: per-group LR interpolated within the current epoch.
        _, batch = self._position()
        return [lerp(s, e, batch + 1, self.steps_per_epoch) for s, e in zip(self._start_lr, self._end_lr)]

    def _set_lr_momentum(self, epoch: int, batch: int) -> None:
        # LR: interpolate within epoch
        for group, s, e in zip(self.optimizer.param_groups, self._start_lr, self._end_lr):
            group["lr"] = lerp(s, e, batch + 1, self.steps_per_epoch)

        # Momentum: linear interpolation over total warmup steps (global-step progress)
        if self.warmup_epochs > 0:
            warmup_steps = self.warmup_epochs * self.steps_per_epoch
            global_step = epoch * self.steps_per_epoch + batch
            # min() clamps so momentum holds at end_momentum after warmup ends.
            progress = min(global_step + 1, warmup_steps) / warmup_steps
            momentum = self.start_momentum + (self.end_momentum - self.start_momentum) * progress
        else:
            momentum = self.end_momentum
        # Only groups exposing a "momentum" key (e.g. SGD) are touched.
        for group in self.optimizer.param_groups:
            if "momentum" in group:
                group["momentum"] = float(momentum)

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def _advance_epoch(self) -> None:
        """Roll the interpolation window forward by one epoch."""
        # The previous end becomes the next start, keeping LR continuous.
        self._start_lr = list(self._end_lr)
        self._epoch += 1
        if self.warmup_policy is not None and self._epoch < self.warmup_epochs:
            # Still warming up: the policy dictates the next endpoint.
            self._end_lr = [self.warmup_policy.target_lr(self._epoch, i, lr) for i, lr in enumerate(self._initial_lr)]
        else:
            # Warmup done: step the wrapped epoch scheduler and chase its LRs.
            self.scheduler.step()
            self._end_lr = [group["lr"] for group in self.optimizer.param_groups]

    def step(self, epoch: Optional[int] = None) -> None:
        # Called once per optimizer step (interval="step" in Lightning).
        self.last_epoch = self.last_epoch + 1 if epoch is None else int(epoch)
        current_epoch, batch = self._position()

        # Catch up if one or more epoch boundaries were crossed.
        while self._epoch < current_epoch:
            self._advance_epoch()

        self._set_lr_momentum(epoch=current_epoch, batch=batch)

    def state_dict(self) -> dict:
        state = super().state_dict()
        state["initial_lr"] = list(self._initial_lr)
        state["start_lr"] = list(self._start_lr)
        state["end_lr"] = list(self._end_lr)
        state["epoch"] = int(self._epoch)
        # Persist the wrapped scheduler's own state alongside ours.
        state["scheduler"] = self.scheduler.state_dict()
        return state

    def load_state_dict(self, state_dict: dict) -> None:
        # NOTE(review): the pops below mutate the caller's dict in place.
        self._initial_lr = state_dict.pop("initial_lr", self._initial_lr)
        self._start_lr = state_dict.pop("start_lr", self._start_lr)
        self._end_lr = state_dict.pop("end_lr", self._end_lr)
        self._epoch = state_dict.pop("epoch", self._epoch)
        inner = state_dict.pop("scheduler", None)
        if inner is not None:
            self.scheduler.load_state_dict(inner)
        super().load_state_dict(state_dict)
        # Re-apply LR/momentum for the restored position immediately.
        epoch, batch = self._position()
        self._set_lr_momentum(epoch=epoch, batch=batch)

lerp(start, end, step, total=1)

Linearly interpolates between start and end values.

start + (end - start) * step / total

Parameters:

Name Type Description Default
start float

The starting value.

required
end float

The ending value.

required
step int

The current step in the interpolation process.

required
total int

The total number of steps.

1

Returns:

Name Type Description
float float

The interpolated value.

Source code in yolo/training/optim.py
def lerp(start: float, end: float, step: Union[int, float], total: int = 1) -> float:
    """
    Linearly interpolates between start and end values.

    Equivalent to ``start + (end - start) * step / total``, i.e.
    ``start * (1 - step / total) + end * (step / total)``.

    Parameters:
        start (float): The starting value.
        end (float): The ending value.
        step (int): The current step in the interpolation process.
        total (int): The total number of steps.

    Returns:
        float: The interpolated value.  ``step`` outside ``[0, total]``
        extrapolates beyond the endpoints; a non-positive ``total``
        returns ``end``.
    """
    # Guard against division by zero (and nonsensical negative spans).
    if total <= 0:
        return end
    return start + (end - start) * step / total

create_optimizer(model, optim_cfg)

Create an optimizer for the given model parameters based on the configuration.

Returns:

Type Description
Optimizer

An instance of the optimizer configured according to the provided settings.

Source code in yolo/training/optim.py
def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
    """Instantiate the optimizer named in ``optim_cfg`` over YOLO-style param groups.

    Parameters are partitioned into three groups — biases, conv weights,
    and batch-norm weights — so that weight decay is disabled for biases
    and norm weights, and warmup policies can treat the bias group
    (index 0) specially.

    Returns:
        An instance of the optimizer configured according to the provided settings.
    """
    optimizer_class: Type[Optimizer] = getattr(torch.optim, optim_cfg.type)

    named = list(model.named_parameters())
    biases = [param for name, param in named if "bias" in name]
    norm_weights = [param for name, param in named if "weight" in name and "bn" in name]
    conv_weights = [param for name, param in named if "weight" in name and "bn" not in name]

    momentum = optim_cfg.args.momentum
    # Group order matters: index 0 must be the bias group (see YOLOWarmupPolicy).
    param_groups = [
        {"params": biases, "momentum": momentum, "weight_decay": 0},
        {"params": conv_weights, "momentum": momentum},
        {"params": norm_weights, "momentum": momentum, "weight_decay": 0},
    ]

    return optimizer_class(param_groups, **optim_cfg.args)

create_scheduler(optimizer, schedule_cfg, steps_per_epoch=None, epochs=None)

Create a learning rate scheduler for the given optimizer based on the configuration.

Returns:

Type Description
_LRScheduler

An instance of the scheduler configured according to the provided settings.

Source code in yolo/training/optim.py
def create_scheduler(
    optimizer: Optimizer,
    schedule_cfg: SchedulerConfig,
    steps_per_epoch: Optional[int] = None,
    epochs: Optional[int] = None,
) -> _LRScheduler:
    """Build the epoch-level scheduler and wrap it for batch-level warmup.

    Returns:
        An instance of the scheduler configured according to the provided settings.
    """
    scheduler_class: Type[_LRScheduler] = getattr(torch.optim.lr_scheduler, schedule_cfg.type)
    inner_scheduler = scheduler_class(optimizer, **schedule_cfg.args)

    if hasattr(schedule_cfg, "warmup"):
        # NOTE(review): attribute-style access on ``warmup`` assumes an
        # omegaconf-like mapping; a plain dict would need ``["epochs"]``.
        warmup_epochs = int(schedule_cfg.warmup.epochs)
        warmup_policy = YOLOWarmupPolicy(warmup_epochs=warmup_epochs)
        start_momentum = getattr(schedule_cfg.warmup, "start_momentum", 0.8)
        end_momentum = getattr(schedule_cfg.warmup, "end_momentum", 0.937)
    else:
        warmup_epochs = 0
        warmup_policy = None
        start_momentum = 0.8
        end_momentum = 0.937

    return WarmupBatchScheduler(
        optimizer=optimizer,
        scheduler=inner_scheduler,
        steps_per_epoch=steps_per_epoch or 1,
        warmup_epochs=warmup_epochs,
        warmup_policy=warmup_policy,
        start_momentum=start_momentum,
        end_momentum=end_momentum,
    )

Callbacks

yolo.training.callbacks

logger = logging.getLogger('yolo') module-attribute

DataConfig dataclass

Source code in yolo/config/schemas/data.py
@dataclass
class DataConfig:
    """Dataloader and augmentation configuration for one data source."""

    # Whether the dataloader shuffles samples.
    shuffle: bool
    # Per-device batch size (GradientAccumulation scales it by world_size).
    batch_size: int
    # Forwarded to the dataloader's pin_memory flag.
    pin_memory: bool
    # Number of dataloader worker processes.
    dataloader_workers: int
    # Input image size as a list of ints — presumably [height, width]; TODO confirm order.
    image_size: List[int]
    # Mapping of augmentation name -> integer setting (presumably a
    # strength or count; verify against the augmentation pipeline).
    data_augment: Dict[str, int]
    # Inference source: a path/URL string or a device index — TODO confirm.
    source: Optional[Union[str, int]]
    # Looks like it toggles variable input shapes — verify against usage.
    dynamic_shape: Optional[bool]
    # Global batch size to emulate; GradientAccumulation derives
    # accumulate_grad_batches from it.
    equivalent_batch_size: Optional[int] = 64
    # Drop the final partial batch each epoch.
    drop_last: bool = True

SchedulerConfig dataclass

Source code in yolo/config/schemas/training.py
@dataclass
class SchedulerConfig:
    """Learning-rate scheduler configuration."""

    # Name of a torch.optim.lr_scheduler class, resolved via getattr
    # in create_scheduler.
    type: str
    # Warmup settings; "epochs" plus optional "start_momentum" and
    # "end_momentum".  Read via attribute access in create_scheduler —
    # presumably an omegaconf mapping rather than a plain dict.
    warmup: Dict[str, Union[int, float]]
    # Keyword arguments forwarded to the scheduler constructor.
    args: Dict[str, Any]

EMA

Bases: Callback

Exponential Moving Average of model weights as a Lightning Callback.

Keeps a shadow copy of model parameters smoothed over training steps

beta = decay * (1 - exp(-step / tau)); shadow = beta * shadow + (1 - beta) * model

The tau warmup ramps beta up from ~0 at step 0, so early noisy updates don't dominate the shadow. Validation always runs on shadow weights; training weights are swapped back immediately after.

Source code in yolo/training/callbacks.py
class EMA(Callback):
    """Exponential Moving Average of model weights as a Lightning Callback.

    Keeps a shadow copy of model parameters smoothed over training steps:
        beta = decay * (1 - exp(-step / tau))
        shadow = beta * shadow + (1 - beta) * model

    The tau warmup ramps beta up from ~0 at step 0, so early noisy updates
    don't dominate the shadow. Validation always runs on shadow weights;
    training weights are swapped back immediately after.
    """

    # Key under which the shadow weights are stored in Lightning checkpoints.
    _CHECKPOINT_KEY = "ema_shadow"

    def __init__(self, decay: float = 0.9999, tau: float = 2000.0) -> None:
        super().__init__()
        logger.info(":chart_with_upwards_trend: Enable Model EMA")
        self.decay = decay  # asymptotic smoothing coefficient
        self.tau = tau  # time constant (in optimizer steps) of the beta ramp
        self.step: int = 0  # optimizer steps folded into the shadow
        self.batch_count: int = 0  # raw train batches since the last reset
        self.shadow: Optional[dict] = None  # EMA copy of model.state_dict()
        self._training_weights: Optional[dict] = None  # snapshot while shadow is applied

    def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
        """Initialise shadow from the model before training begins."""
        # Leave an already-populated shadow (e.g. from on_load_checkpoint) untouched.
        if self.shadow is None:
            self.shadow = {k: v.detach().clone() for k, v in pl_module.model.state_dict().items()}

    def _beta(self) -> float:
        """Effective smoothing coefficient at the current step."""
        # Ramps from 0 at step 0 toward `decay` as step >> tau.
        return self.decay * (1 - exp(-self.step / self.tau))

    @torch.no_grad()
    def update(self, pl_module: "LightningModule") -> None:
        """Blend model parameters into shadow; copy buffers directly."""
        current = pl_module.model.state_dict()
        if self.shadow is None:
            # No shadow yet: seed it from the current weights and return.
            self.shadow = {k: v.detach().clone() for k, v in current.items()}
            return

        beta = self._beta()

        # Only learnable parameters are EMA-blended; buffers are copied below.
        param_keys = [k for k, _ in pl_module.model.named_parameters()]
        shadow_params = [self.shadow[k] for k in param_keys]
        model_params = [current[k].detach().to(self.shadow[k].device) for k in param_keys]
        # Prefer fused multi-tensor ops when this torch build provides them.
        if hasattr(torch, "_foreach_lerp_"):
            # lerp(shadow, model, 1-beta) == beta*shadow + (1-beta)*model
            torch._foreach_lerp_(shadow_params, model_params, 1.0 - beta)
        elif hasattr(torch, "_foreach_mul_"):
            torch._foreach_mul_(shadow_params, beta)
            torch._foreach_add_(shadow_params, model_params, alpha=1.0 - beta)
        else:
            # Per-tensor in-place fallback.
            for s, m in zip(shadow_params, model_params):
                s.mul_(beta).add_(m, alpha=1.0 - beta)

        # Buffers track the live model exactly (no smoothing).
        for key, buf in pl_module.model.named_buffers():
            self.shadow[key].copy_(buf.detach().to(self.shadow[key].device))

    @torch.no_grad()
    def apply_shadow(self, pl_module: "LightningModule") -> None:
        """Snapshot training weights then load shadow weights into the model."""
        if self.shadow is None:
            return
        self._training_weights = {k: v.detach().clone() for k, v in pl_module.model.state_dict().items()}
        pl_module.model.load_state_dict(self.shadow, strict=True)

    @torch.no_grad()
    def restore(self, pl_module: "LightningModule") -> None:
        """Reload the training-weight snapshot, discarding the shadow swap."""
        if self._training_weights is None:
            return
        pl_module.model.load_state_dict(self._training_weights, strict=True)
        self._training_weights = None

    @torch.no_grad()
    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
        # Advance the EMA only once per *optimizer* step, i.e. after a
        # full gradient-accumulation cycle.
        self.batch_count += 1
        if self.batch_count % trainer.accumulate_grad_batches != 0:
            return
        self.step += 1
        self.update(pl_module)

    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        # NOTE(review): batch_count is reset here, re-aligning the
        # accumulation modulus at every validation boundary — confirm intended.
        self.batch_count = 0
        self.apply_shadow(pl_module)

    def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        # Swap training weights back so optimisation continues unaffected.
        self.restore(pl_module)

    def on_save_checkpoint(self, trainer: "Trainer", pl_module: "LightningModule", checkpoint: dict) -> None:
        if self.shadow is None:
            return
        # Store the shadow on CPU so checkpoints are device-agnostic.
        checkpoint[self._CHECKPOINT_KEY] = {k: v.detach().cpu() for k, v in self.shadow.items()}
        checkpoint["ema_step"] = self.step
        checkpoint["ema_batch_count"] = self.batch_count

    def on_load_checkpoint(self, trainer: "Trainer", pl_module: "LightningModule", checkpoint: dict) -> None:
        self.step = checkpoint.get("ema_step", 0)
        self.batch_count = checkpoint.get("ema_batch_count", 0)
        if self._CHECKPOINT_KEY not in checkpoint:
            return
        # Move the stored shadow onto whichever device the model now lives on.
        target_device = next(pl_module.model.parameters()).device
        self.shadow = {k: v.detach().clone().to(target_device) for k, v in checkpoint[self._CHECKPOINT_KEY].items()}

setup(trainer, pl_module, stage)

Initialise shadow from the model before training begins.

Source code in yolo/training/callbacks.py
def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
    """Initialise shadow from the model before training begins."""
    # Leave an already-populated shadow (e.g. restored from a checkpoint) untouched.
    if self.shadow is None:
        self.shadow = {k: v.detach().clone() for k, v in pl_module.model.state_dict().items()}

update(pl_module)

Blend model parameters into shadow; copy buffers directly.

Source code in yolo/training/callbacks.py
@torch.no_grad()
def update(self, pl_module: "LightningModule") -> None:
    """Blend model parameters into shadow; copy buffers directly."""
    current = pl_module.model.state_dict()
    if self.shadow is None:
        # No shadow yet: seed it from the current weights and return.
        self.shadow = {k: v.detach().clone() for k, v in current.items()}
        return

    beta = self._beta()

    # Blend only learnable parameters; buffers are copied verbatim below.
    param_keys = [k for k, _ in pl_module.model.named_parameters()]
    shadow_params = [self.shadow[k] for k in param_keys]
    model_params = [current[k].detach().to(self.shadow[k].device) for k in param_keys]
    # Use fused multi-tensor ops when available in this torch build.
    if hasattr(torch, "_foreach_lerp_"):
        # lerp(shadow, model, 1-beta) == beta*shadow + (1-beta)*model
        torch._foreach_lerp_(shadow_params, model_params, 1.0 - beta)
    elif hasattr(torch, "_foreach_mul_"):
        torch._foreach_mul_(shadow_params, beta)
        torch._foreach_add_(shadow_params, model_params, alpha=1.0 - beta)
    else:
        # Per-tensor in-place fallback.
        for s, m in zip(shadow_params, model_params):
            s.mul_(beta).add_(m, alpha=1.0 - beta)

    # Buffers track the live model exactly (no smoothing).
    for key, buf in pl_module.model.named_buffers():
        self.shadow[key].copy_(buf.detach().to(self.shadow[key].device))

apply_shadow(pl_module)

Snapshot training weights then load shadow weights into the model.

Source code in yolo/training/callbacks.py
@torch.no_grad()
def apply_shadow(self, pl_module: "LightningModule") -> None:
    """Snapshot training weights then load shadow weights into the model."""
    if self.shadow is None:
        return
    # Keep a copy of the live weights so restore() can undo the swap.
    self._training_weights = {k: v.detach().clone() for k, v in pl_module.model.state_dict().items()}
    pl_module.model.load_state_dict(self.shadow, strict=True)

restore(pl_module)

Reload the training-weight snapshot, discarding the shadow swap.

Source code in yolo/training/callbacks.py
@torch.no_grad()
def restore(self, pl_module: "LightningModule") -> None:
    """Reload the training-weight snapshot, discarding the shadow swap."""
    # No-op unless apply_shadow() previously stashed the training weights.
    if self._training_weights is None:
        return
    pl_module.model.load_state_dict(self._training_weights, strict=True)
    self._training_weights = None

GradientAccumulation

Bases: Callback

Source code in yolo/training/callbacks.py
class GradientAccumulation(Callback):
    """Ramps gradient accumulation up toward an equivalent global batch size.

    During warmup the accumulation count grows linearly from 1; afterwards
    it stays at the value that makes ``batch_size * world_size * accumulation``
    approximate ``equivalent_batch_size``.
    """

    def __init__(self, data_cfg: DataConfig, scheduler_cfg: SchedulerConfig):
        super().__init__()
        self.equivalent_batch_size = data_cfg.equivalent_batch_size
        self.actual_batch_size = data_cfg.batch_size
        self.warmup_epochs = getattr(scheduler_cfg.warmup, "epochs", 0)
        self.max_accumulation = 1
        self.warmup_batches = 0
        self.steps_per_epoch = 1
        logger.info(":arrows_counterclockwise: Enable Gradient Accumulation")

    def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
        # Effective per-step batch across all ranks.
        global_batch = self.actual_batch_size * trainer.world_size
        self.max_accumulation = max(1, round(self.equivalent_batch_size / global_batch))

    def on_train_start(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
        # Borrow steps_per_epoch from the first batch-aware LR scheduler.
        for sched_cfg in trainer.lr_scheduler_configs:
            candidate = sched_cfg.scheduler
            if hasattr(candidate, "steps_per_epoch"):
                self.steps_per_epoch = candidate.steps_per_epoch
                break
        self.warmup_batches = int(self.warmup_epochs * self.steps_per_epoch)

    def on_train_batch_start(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
        step = trainer.global_step
        if step >= self.warmup_batches:
            trainer.accumulate_grad_batches = self.max_accumulation
        else:
            # Linear ramp from 1 to max_accumulation across the warmup batches.
            trainer.accumulate_grad_batches = round(lerp(1, self.max_accumulation, step, self.warmup_batches))

lerp(start, end, step, total=1)

Linearly interpolates between start and end values.

start + (end - start) * step / total

Parameters:

Name Type Description Default
start float

The starting value.

required
end float

The ending value.

required
step int

The current step in the interpolation process.

required
total int

The total number of steps.

1

Returns:

Name Type Description
float float

The interpolated value.

Source code in yolo/training/optim.py
def lerp(start: float, end: float, step: Union[int, float], total: int = 1) -> float:
    """
    Linearly interpolates between start and end values.

    Equivalent to ``start + (end - start) * step / total``, i.e.
    ``start * (1 - step / total) + end * (step / total)``.

    Parameters:
        start (float): The starting value.
        end (float): The ending value.
        step (int): The current step in the interpolation process.
        total (int): The total number of steps.

    Returns:
        float: The interpolated value; ``end`` when ``total`` is non-positive.
    """
    # Guard against division by zero (and nonsensical negative spans).
    if total <= 0:
        return end
    return start + (end - start) * step / total