Skip to content

Data

Dataset

yolo.data.dataset

logger = logging.getLogger('yolo') module-attribute

DataConfig dataclass

Source code in yolo/config/schemas/data.py
@dataclass
class DataConfig:
    """Dataloader configuration for one run: batching, workers, and augmentation settings."""

    shuffle: bool
    batch_size: int
    pin_memory: bool
    dataloader_workers: int
    image_size: List[int]  # target size pair; presumably (width, height) — confirm against consumers
    data_augment: Dict[str, int]  # augmentation class name -> constructor argument (e.g. probability)
    source: Optional[Union[str, int]]  # inference input; NOTE(review): str path/URL or int camera index — verify
    dynamic_shape: Optional[bool]  # enable per-batch image-size adjustment when True
    equivalent_batch_size: Optional[int] = 64
    drop_last: bool = True  # drop the final incomplete batch

DatasetConfig dataclass

Source code in yolo/config/schemas/data.py
@dataclass
class DatasetConfig:
    """Describes a dataset on disk: root path, class inventory, optional auto-download."""

    path: str  # root directory of the dataset
    class_num: int  # number of classes; presumably equals len(class_list) — confirm
    class_list: List[str]  # human-readable class names
    auto_download: Optional[DownloadOptions]  # download settings, or None to disable

RemoveOutliers

Removes outlier bounding boxes that are too small or have invalid dimensions.

Source code in yolo/data/augmentation.py
class RemoveOutliers:
    """Removes outlier bounding boxes that are too small or have invalid dimensions."""

    def __init__(self, min_box_area=1e-8):
        """
        Args:
            min_box_area (float): Smallest allowed box area, expressed as a fraction of the image area.
        """
        self.min_box_area = min_box_area

    def __call__(self, image, boxes):
        """
        Args:
            image (PIL.Image): The cropped image.
            boxes (torch.Tensor): Boxes as (class, x_min, y_min, x_max, y_max) in normalized coordinates.
        Returns:
            PIL.Image: The input image, untouched.
            torch.Tensor: Only the boxes that survived filtering.
        """
        widths = boxes[:, 3] - boxes[:, 1]
        heights = boxes[:, 4] - boxes[:, 2]
        # Keep boxes with positive extent and non-negligible area.
        keep = (widths * heights > self.min_box_area) & (widths > 0) & (heights > 0)
        return image, boxes[keep]

__init__(min_box_area=1e-08)

Parameters:

Name Type Description Default
min_box_area float

Minimum area for a box to be kept, as a fraction of the image area.

1e-08
Source code in yolo/data/augmentation.py
def __init__(self, min_box_area=1e-8):
    """Record the smallest allowed box area.

    Args:
        min_box_area (float): Minimum area for a box to be kept, as a fraction of the image area.
    """
    self.min_box_area = min_box_area

__call__(image, boxes)

Parameters:

Name Type Description Default
image Image

The cropped image.

required
boxes Tensor

Bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max).

required

Returns: PIL.Image: The input image (unchanged). torch.Tensor: Filtered bounding boxes.

Source code in yolo/data/augmentation.py
def __call__(self, image, boxes):
    """
    Args:
        image (PIL.Image): The cropped image.
        boxes (torch.Tensor): Boxes as (class, x_min, y_min, x_max, y_max) in normalized coordinates.
    Returns:
        PIL.Image: The input image, unchanged.
        torch.Tensor: The subset of boxes that passed the area and extent checks.
    """
    widths = boxes[:, 3] - boxes[:, 1]
    heights = boxes[:, 4] - boxes[:, 2]
    # Positive width/height and area above threshold.
    keep = (widths * heights > self.min_box_area) & (widths > 0) & (heights > 0)
    return image, boxes[keep]

PadAndResize

Source code in yolo/data/augmentation.py
class PadAndResize:
    """Letterbox an image to a fixed size: scale to fit, then center on a padded canvas."""

    def __init__(self, image_size, background_color=(114, 114, 114)):
        """Initialize the object with the target image size."""
        self.target_width, self.target_height = image_size
        self.background_color = background_color

    def set_size(self, image_size: List[int]):
        """Update the target (width, height) used on subsequent calls."""
        self.target_width, self.target_height = image_size

    def __call__(self, image: Image, boxes):
        source_w, source_h = image.size
        # Uniform scale so the whole image fits inside the target rectangle.
        scale = min(self.target_width / source_w, self.target_height / source_h)
        scaled_w, scaled_h = int(source_w * scale), int(source_h * scale)

        shrunk = image.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)

        pad_left = (self.target_width - scaled_w) // 2
        pad_top = (self.target_height - scaled_h) // 2
        canvas = Image.new("RGB", (self.target_width, self.target_height), self.background_color)
        canvas.paste(shrunk, (pad_left, pad_top))

        # Re-normalize box coordinates against the padded canvas.
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] * scaled_w + pad_left) / self.target_width
        boxes[:, [2, 4]] = (boxes[:, [2, 4]] * scaled_h + pad_top) / self.target_height

        transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top])
        return canvas, boxes, transform_info

__init__(image_size, background_color=(114, 114, 114))

Initialize the object with the target image size.

Source code in yolo/data/augmentation.py
def __init__(self, image_size, background_color=(114, 114, 114)):
    """Record the target (width, height) and the padding fill color."""
    self.target_width, self.target_height = image_size
    self.background_color = background_color

HorizontalFlip

Randomly horizontally flips the image along with the bounding boxes.

Source code in yolo/data/augmentation.py
class HorizontalFlip:
    """Randomly horizontally flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        # Probability of applying the flip on any given call.
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes
        # Mirror x-coordinates; swap the min/max columns so x_min stays <= x_max.
        boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return TF.hflip(image), boxes

VerticalFlip

Randomly vertically flips the image along with the bounding boxes.

Source code in yolo/data/augmentation.py
class VerticalFlip:
    """Randomly vertically flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        # Probability of applying the flip on any given call.
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes
        # Mirror y-coordinates; swap the min/max columns so y_min stays <= y_max.
        boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
        return TF.vflip(image), boxes

Mosaic

Applies the Mosaic augmentation to a batch of images and their corresponding boxes.

Source code in yolo/data/augmentation.py
class Mosaic:
    """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""

    def __init__(self, prob=0.5):
        self.prob = prob
        self.parent = None  # dataset that supplies extra samples; set via set_parent

    def set_parent(self, parent):
        self.parent = parent

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."

        img_sz = self.parent.base_size  # Assuming `image_size` is defined in parent
        extra = self.parent.get_more_data(3)  # get 3 more images randomly

        samples = [(image, boxes)] + extra
        # Double-sized canvas; each tile is anchored at the center via its offset vector.
        canvas = Image.new("RGB", (2 * img_sz, 2 * img_sz), (114, 114, 114))
        offsets = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
        center = np.array([img_sz, img_sz])
        merged = []

        for (tile, tile_boxes), offset in zip(samples, offsets):
            tile_w, tile_h = tile.size
            corner = tuple(center + offset * np.array([tile_w, tile_h]))
            canvas.paste(tile, corner)

            # Map normalized tile coordinates into the double-sized mosaic frame.
            xmin = (tile_boxes[:, 1] * tile_w + corner[0]) / (2 * img_sz)
            xmax = (tile_boxes[:, 3] * tile_w + corner[0]) / (2 * img_sz)
            ymin = (tile_boxes[:, 2] * tile_h + corner[1]) / (2 * img_sz)
            ymax = (tile_boxes[:, 4] * tile_h + corner[1]) / (2 * img_sz)
            merged.append(torch.stack([tile_boxes[:, 0], xmin, ymin, xmax, ymax], dim=1))

        all_labels = torch.cat(merged, dim=0)
        # Shrink the double-sized canvas back to the base size.
        return canvas.resize((img_sz, img_sz)), all_labels

MixUp

Applies the MixUp augmentation to a pair of images and their corresponding boxes.

Source code in yolo/data/augmentation.py
class MixUp:
    """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""

    def __init__(self, prob=0.5, alpha=1.0):
        self.alpha = alpha
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        """Set the parent dataset object for accessing dataset methods."""
        self.parent = parent

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."

        # Draw a second sample at random from the dataset.
        partner_image, partner_boxes = self.parent.get_more_data()[0]

        # Blend ratio: Beta(alpha, alpha) draw, or an even 50/50 when alpha is disabled.
        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5

        # Blend the two images pixel-wise.
        first, second = TF.to_tensor(image), TF.to_tensor(partner_image)
        blended = lam * first + (1 - lam) * second

        # Keep every box from both samples.
        combined = torch.cat((boxes, partner_boxes))
        return TF.to_pil_image(blended), combined

set_parent(parent)

Set the parent dataset object for accessing dataset methods.

Source code in yolo/data/augmentation.py
def set_parent(self, parent):
    """Attach the parent dataset so this transform can pull extra samples from it."""
    self.parent = parent

RandomCrop

Randomly crops the image to half its size along with adjusting the bounding boxes.

Source code in yolo/data/augmentation.py
class RandomCrop:
    """Randomly crops the image to half its size along with adjusting the bounding boxes."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): Probability of applying the crop.
        """
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes

        full_w, full_h = image.size
        crop_h, crop_w = full_h // 2, full_w // 2
        # Random top-left corner for a half-size crop window.
        top = torch.randint(0, full_h - crop_h + 1, (1,)).item()
        left = torch.randint(0, full_w - crop_w + 1, (1,)).item()

        image = TF.crop(image, top, left, crop_h, crop_w)

        # Shift to crop-local pixel coordinates, clamp into the crop, then re-normalize.
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] * full_w - left).clamp(0, crop_w) / crop_w
        boxes[:, [2, 4]] = (boxes[:, [2, 4]] * full_h - top).clamp(0, crop_h) / crop_h

        return image, boxes

__init__(prob=0.5)

Parameters:

Name Type Description Default
prob float

Probability of applying the crop.

0.5
Source code in yolo/data/augmentation.py
def __init__(self, prob=0.5):
    """
    Args:
        prob (float): Probability of applying the crop on each call.
    """
    self.prob = prob

AugmentationComposer

Composes several transforms together.

Source code in yolo/data/augmentation.py
class AugmentationComposer:
    """Composes several transforms together, finishing with letterbox pad/resize.

    Args:
        transforms (list): Augmentation callables applied in order; each takes and returns
            ``(image, boxes)``. Transforms exposing ``set_parent`` are linked back to this
            composer so they can pull extra samples (e.g. Mosaic, MixUp).
        image_size (List[int]): Target (width, height) handed to PadAndResize.
        base_size (int): Reference square size read by parent-aware transforms.
    """

    def __init__(self, transforms, image_size: List[int] = (640, 640), base_size: int = 640):
        # Default is a tuple (was a mutable list) and the annotation now matches the value;
        # previously annotated `int` while holding a two-element list.
        self.transforms = transforms
        self.pad_resize = PadAndResize(image_size)
        self.base_size = base_size

        for transform in self.transforms:
            if hasattr(transform, "set_parent"):
                transform.set_parent(self)

    def __call__(self, image, boxes=None):
        """Apply every transform in order, then pad/resize and convert the image to a tensor.

        Returns:
            (Tensor image, Tensor boxes, Tensor rev_tensor) from the final PadAndResize step.
        """
        # A fresh empty tensor per call: the old default `boxes=torch.zeros(0, 5)` was
        # evaluated once at import time, so any in-place edit leaked across calls.
        if boxes is None:
            boxes = torch.zeros(0, 5)
        for transform in self.transforms:
            image, boxes = transform(image, boxes)
        image, boxes, rev_tensor = self.pad_resize(image, boxes)
        image = TF.to_tensor(image)
        return image, boxes, rev_tensor

YoloDataset

Bases: Dataset

Source code in yolo/data/dataset.py
class YoloDataset(Dataset):
    """Detection dataset: pairs images with bounding boxes, with caching and augmentation."""

    def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
        augment_cfg = data_cfg.data_augment
        self.image_size = data_cfg.image_size
        # NOTE(review): DatasetConfig is a dataclass with no `.get`; this implies an
        # OmegaConf-style mapping at runtime — confirm against the caller.
        phase_name = dataset_cfg.get(phase, phase)
        self.batch_size = data_cfg.batch_size
        self.dynamic_shape = getattr(data_cfg, "dynamic_shape", False)
        self.base_size = mean(self.image_size)

        # HACK: eval() turns config keys (e.g. "Mosaic") into augmentation classes.
        # Safe only while the config is trusted; a name->class mapping would be safer.
        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
        self.transform = AugmentationComposer(transforms, self.image_size, self.base_size)
        # Let mix-style transforms (Mosaic/MixUp) pull extra samples through the composer.
        self.transform.get_more_data = self.get_more_data
        self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))

    def load_data(self, dataset_path: Path, phase_name: str) -> list:
        """
        Loads data from a cache or generates a new cache for a specific dataset phase.

        Parameters:
            dataset_path (Path): The root path to the dataset directory.
            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

        Returns:
            list: The loaded data from the cache for the specified phase.
        """
        # ".pache" is this project's cache-file suffix (presumably intentional spelling).
        cache_path = dataset_path / f"{phase_name}.pache"

        if not cache_path.exists():
            logger.info(f":factory: Generating {phase_name} cache")
            data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
            torch.save(data, cache_path)
        else:
            try:
                # weights_only=False: the cache stores arbitrary Python objects, not just tensors.
                data = torch.load(cache_path, weights_only=False)
            except Exception as e:
                logger.error(
                    f":rotating_light: Failed to load the cache at '{cache_path}'.\n"
                    ":rotating_light: This may be caused by using cache from different other YOLO.\n"
                    ":rotating_light: Please clean the cache and try running again."
                )
                raise e
            logger.info(f":package: Loaded {phase_name} cache, there are {len(data)} data in total.")
        return data

    def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = False) -> list:
        """
        Filters and collects dataset information by pairing images with their corresponding labels.

        Parameters:
            dataset_path (Path): Root path of the dataset directory.
            phase_name (str): Dataset split to load (e.g. ``'train'``, ``'validation'``).
            sort_image (bool): If True, sorts the dataset by the width-to-height ratio of images in descending order.

        Returns:
            list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
        """
        images_path = dataset_path / "images" / phase_name
        labels_path, data_type = locate_label_paths(dataset_path, phase_name)
        # A "<phase>.txt" file listing image paths overrides directory scanning.
        file_list, adjust_path = dataset_path / f"{phase_name}.txt", False
        if file_list.exists():
            data_type, adjust_path = "txt", True
            # TODO: should i sort by name?
            with open(file_list, "r") as file:
                images_list = [dataset_path / line.rstrip() for line in file]
            # Derive label paths by mirroring "images" -> "labels" with a .txt suffix.
            labels_list = [
                Path(str(image_path).replace("images", "labels")).with_suffix(".txt") for image_path in images_list
            ]
        else:
            images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()])

        if data_type == "json":
            annotations_index, image_info_dict = create_image_metadata(labels_path)

        data = []
        valid_inputs = 0
        for idx, image_name in enumerate(track(images_list, description="Filtering data")):
            if not adjust_path and not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_id = Path(image_name).stem

            if data_type == "json":
                image_info = image_info_dict.get(image_id, None)
                if image_info is None:
                    continue
                annotations = annotations_index.get(image_info["id"], [])
                image_seg_annotations = scale_segmentation(annotations, image_info)
            elif data_type == "txt":
                # When adjust_path is set, labels_list is index-aligned with images_list.
                label_path = labels_list[idx] if adjust_path else labels_path / f"{image_id}.txt"
                if not label_path.is_file():
                    image_seg_annotations = []
                else:
                    with open(label_path, "r") as file:
                        image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
            else:
                image_seg_annotations = []

            labels = self.load_valid_labels(image_id, image_seg_annotations)
            img_path = image_name if adjust_path else images_path / image_name
            if sort_image:
                # Real aspect ratio is only needed for ratio-sorted (dynamic shape) mode.
                with Image.open(img_path) as img:
                    width, height = img.size
            else:
                width, height = 0, 1  # ratio 0 for every entry keeps the original order under sort
            data.append((img_path, labels, width / height))
            if len(image_seg_annotations) != 0:
                valid_inputs += 1

        data = sorted(data, key=lambda x: x[2], reverse=True)

        logger.info(f"Recorded {valid_inputs}/{len(images_list)} valid inputs")
        return data

    def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
        """
        Loads valid COCO style segmentation data (values between [0, 1]) and converts it to bounding box coordinates
        by finding the minimum and maximum x and y values.

        Parameters:
            label_path (str): The filepath to the label file containing annotation data.
            seg_data_one_img (list): The actual list of annotations (in segmentation format)

        Returns:
            Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
        """
        bboxes = []
        for seg_data in seg_data_one_img:
            cls = seg_data[0]
            points = np.array(seg_data[1:]).reshape(-1, 2).clip(0, 1)
            # NOTE(review): after .clip(0, 1) this mask is always all-true, so
            # valid_points == points; the filter is effectively dead code.
            valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
            if valid_points.size > 1:
                bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
                bboxes.append(bbox)

        if bboxes:
            return torch.stack(bboxes)
        else:
            # Despite the annotation, the no-boxes case returns an empty (0, 5) tensor.
            logger.warning(f"No valid BBox in {label_path}")
            return torch.zeros((0, 5))

    def get_data(self, idx):
        """Return (RGB image, valid boxes tensor, path) for one sample index."""
        img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
        # Cached boxes are padded to a fixed count with -1 class rows; mask the padding off.
        valid_mask = bboxes[:, 0] != -1
        with Image.open(img_path) as img:
            img = img.convert("RGB")
        return img, torch.from_numpy(bboxes[valid_mask]), img_path

    def get_more_data(self, num: int = 1):
        """Return `num` random (image, boxes) pairs for mix-style augmentations."""
        indices = torch.randint(0, len(self), (num,))
        return [self.get_data(idx)[:2] for idx in indices]

    def _update_image_size(self, idx: int) -> None:
        """Update image size based on dynamic shape and batch settings."""
        # Use the first sample of the batch so every item in the batch shares one shape.
        batch_start_idx = (idx // self.batch_size) * self.batch_size
        image_ratio = self.ratios[batch_start_idx].clip(1 / 3, 3)
        # Shift width and height in opposite directions, rounded to a multiple of 32.
        shift = ((self.base_size / 32 * (image_ratio - 1)) // (image_ratio + 1)) * 32

        self.image_size = [int(self.base_size + shift), int(self.base_size - shift)]
        self.transform.pad_resize.set_size(self.image_size)

    def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
        img, bboxes, img_path = self.get_data(idx)

        if self.dynamic_shape:
            self._update_image_size(idx)

        img, bboxes, rev_tensor = self.transform(img, bboxes)
        # Scale normalized box coordinates up to pixel units of the final image size.
        bboxes[:, [1, 3]] *= self.image_size[0]
        bboxes[:, [2, 4]] *= self.image_size[1]
        return img, bboxes, rev_tensor, img_path

    def __len__(self) -> int:
        # One padded bbox array per image, so its length equals the dataset size.
        return len(self.bboxes)

load_data(dataset_path, phase_name)

Loads data from a cache or generates a new cache for a specific dataset phase.

Parameters:

Name Type Description Default
dataset_path Path

The root path to the dataset directory.

required
phase_name str

The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

required

Returns:

Name Type Description
list list

The loaded data from the cache for the specified phase.

Source code in yolo/data/dataset.py
def load_data(self, dataset_path: Path, phase_name: str) -> list:
    """
    Loads data from a cache or generates a new cache for a specific dataset phase.

    Parameters:
        dataset_path (Path): The root path to the dataset directory.
        phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

    Returns:
        list: The loaded data from the cache for the specified phase.
    """
    cache_path = dataset_path / f"{phase_name}.pache"

    # Fast path: an existing cache is loaded and returned directly.
    if cache_path.exists():
        try:
            data = torch.load(cache_path, weights_only=False)
        except Exception as e:
            logger.error(
                f":rotating_light: Failed to load the cache at '{cache_path}'.\n"
                ":rotating_light: This may be caused by using cache from different other YOLO.\n"
                ":rotating_light: Please clean the cache and try running again."
            )
            raise e
        logger.info(f":package: Loaded {phase_name} cache, there are {len(data)} data in total.")
        return data

    # Slow path: build the dataset index from scratch and persist it for next time.
    logger.info(f":factory: Generating {phase_name} cache")
    data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
    torch.save(data, cache_path)
    return data

filter_data(dataset_path, phase_name, sort_image=False)

Filters and collects dataset information by pairing images with their corresponding labels.

Parameters:

Name Type Description Default
dataset_path Path

Root path of the dataset directory.

required
phase_name str

Dataset split to load (e.g. 'train', 'validation').

required
sort_image bool

If True, sorts the dataset by the width-to-height ratio of images in descending order.

False

Returns:

Name Type Description
list list

A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.

Source code in yolo/data/dataset.py
def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = False) -> list:
    """
    Filters and collects dataset information by pairing images with their corresponding labels.

    Parameters:
        dataset_path (Path): Root path of the dataset directory.
        phase_name (str): Dataset split to load (e.g. ``'train'``, ``'validation'``).
        sort_image (bool): If True, sorts the dataset by the width-to-height ratio of images in descending order.

    Returns:
        list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
    """
    images_path = dataset_path / "images" / phase_name
    labels_path, data_type = locate_label_paths(dataset_path, phase_name)
    # A "<phase>.txt" file listing image paths overrides directory scanning.
    file_list, adjust_path = dataset_path / f"{phase_name}.txt", False
    if file_list.exists():
        data_type, adjust_path = "txt", True
        # TODO: should i sort by name?
        with open(file_list, "r") as file:
            images_list = [dataset_path / line.rstrip() for line in file]
        # Derive label paths by mirroring "images" -> "labels" with a .txt suffix.
        labels_list = [
            Path(str(image_path).replace("images", "labels")).with_suffix(".txt") for image_path in images_list
        ]
    else:
        images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()])

    if data_type == "json":
        annotations_index, image_info_dict = create_image_metadata(labels_path)

    data = []
    valid_inputs = 0
    for idx, image_name in enumerate(track(images_list, description="Filtering data")):
        # Extension filter only applies when scanning a directory (names are plain strings).
        if not adjust_path and not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        image_id = Path(image_name).stem

        if data_type == "json":
            image_info = image_info_dict.get(image_id, None)
            if image_info is None:
                continue
            annotations = annotations_index.get(image_info["id"], [])
            image_seg_annotations = scale_segmentation(annotations, image_info)
        elif data_type == "txt":
            # When adjust_path is set, labels_list is index-aligned with images_list.
            label_path = labels_list[idx] if adjust_path else labels_path / f"{image_id}.txt"
            if not label_path.is_file():
                image_seg_annotations = []
            else:
                with open(label_path, "r") as file:
                    image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
        else:
            image_seg_annotations = []

        labels = self.load_valid_labels(image_id, image_seg_annotations)
        img_path = image_name if adjust_path else images_path / image_name
        if sort_image:
            # Real aspect ratio is only needed when ratio-sorting is requested.
            with Image.open(img_path) as img:
                width, height = img.size
        else:
            width, height = 0, 1  # constant ratio keeps the original order under the sort below
        data.append((img_path, labels, width / height))
        if len(image_seg_annotations) != 0:
            valid_inputs += 1

    data = sorted(data, key=lambda x: x[2], reverse=True)

    logger.info(f"Recorded {valid_inputs}/{len(images_list)} valid inputs")
    return data

load_valid_labels(label_path, seg_data_one_img)

Loads valid COCO style segmentation data (values between [0, 1]) and converts it to bounding box coordinates by finding the minimum and maximum x and y values.

Parameters:

Name Type Description Default
label_path str

The filepath to the label file containing annotation data.

required
seg_data_one_img list

The actual list of annotations (in segmentation format)

required

Returns:

Type Description
Union[Tensor, None]

Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.

Source code in yolo/data/dataset.py
def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
    """
    Loads valid COCO style segmentation data (values between [0, 1]) and converts it to bounding box
    coordinates by taking the min/max x and y of each polygon.

    Parameters:
        label_path (str): Identifier used in the warning log when no valid boxes are found.
        seg_data_one_img (list): Annotations for one image; each entry is [class, x1, y1, x2, y2, ...].

    Returns:
        Tensor: (N, 5) tensor of [class, x_min, y_min, x_max, y_max] rows. When no valid boxes
        exist an empty (0, 5) tensor is returned — never None, despite the annotation, which is
        kept only for backward compatibility.
    """
    bboxes = []
    for seg_data in seg_data_one_img:
        cls = seg_data[0]
        # Clip coordinates into [0, 1]. After clipping, every point is in range, so the
        # original re-filtering mask ((points >= 0) & (points <= 1)) was dead code and
        # has been removed; the values are unchanged.
        points = np.array(seg_data[1:]).reshape(-1, 2).clip(0, 1)
        if points.size > 1:
            bbox = torch.tensor([cls, *points.min(axis=0), *points.max(axis=0)])
            bboxes.append(bbox)

    if bboxes:
        return torch.stack(bboxes)

    logger.warning(f"No valid BBox in {label_path}")
    return torch.zeros((0, 5))

prepare_dataset(dataset_cfg, task)

Prepares dataset by downloading and unzipping if necessary.

Source code in yolo/data/preparation.py
def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
    """
    Prepares dataset by downloading and unzipping if necessary.

    Args:
        dataset_cfg (DatasetConfig): Dataset description; ``auto_download`` maps data types
            (e.g. images/annotations) to download settings. NOTE(review): the ``.items()``
            and ``.get`` calls imply an OmegaConf-style mapping at runtime — confirm.
        task (str): Current task/phase; only matching dataset splits are downloaded.
    """
    # TODO: do EDA of dataset
    data_dir = Path(dataset_cfg.path)
    for data_type, settings in dataset_cfg.auto_download.items():
        base_url = settings["base_url"]
        for dataset_type, dataset_args in settings.items():
            # Skip splits that belong to other tasks; "annotations" is always fetched.
            if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type:
                continue
            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
            url = f"{base_url}{file_name}"
            local_zip_path = data_dir / file_name
            # Annotations unpack at the dataset root; everything else under its data type dir.
            extract_to = data_dir / data_type if data_type != "annotations" else data_dir
            final_place = extract_to / dataset_type

            final_place.mkdir(parents=True, exist_ok=True)
            # Already present with the expected file count: nothing to do for this split.
            if check_files(final_place, dataset_args.get("file_num")):
                logger.info(f":white_check_mark: Dataset {dataset_type: <12} already verified.")
                continue

            if not local_zip_path.exists():
                download_file(url, local_zip_path)
            unzip_file(local_zip_path, extract_to)

            # Verify the extraction produced the expected number of files.
            if not check_files(final_place, dataset_args.get("file_num")):
                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")

create_image_metadata(labels_path)

Create a dictionary containing image information and annotations indexed by image ID.

Parameters:

Name Type Description Default
labels_path str

The path to the annotation json file.

required

Returns:

Type Description
Dict[str, List]
  • annotations_index: A dictionary where keys are image IDs and values are lists of annotations.
Dict[str, Dict]
  • image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries.
Source code in yolo/utils/dataset_utils.py
def create_image_metadata(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
    """
    Create a dictionary containing image information and annotations indexed by image ID.

    Args:
        labels_path (str): The path to the annotation json file.

    Returns:
        - annotations_index: maps image IDs to their list of annotations.
        - image_info_dict: maps image file stems to their image-info dictionaries.
    """
    with open(labels_path, "r") as file:
        labels_data = json.load(file)

    # Category remapping only happens when the file actually declares categories.
    id_to_idx = discretize_categories(labels_data.get("categories", [])) if "categories" in labels_data else None
    annotations_index = organize_annotations_by_image(labels_data, id_to_idx)
    image_info_dict = {Path(img["file_name"]).stem: img for img in labels_data["images"]}
    return annotations_index, image_info_dict

locate_label_paths(dataset_path, phase_name)

Find the path to label files for a specified dataset and phase (e.g. training).

Parameters:

Name Type Description Default
dataset_path Path

The path to the root directory of the dataset.

required
phase_name Path

The name of the phase for which labels are being searched (e.g., "train", "val", "test").

required

Returns:

Type Description
Tuple[Path, Path]

Tuple[Path, Path]: A tuple containing the path to the labels file and the file format ("json" or "txt").

Source code in yolo/utils/dataset_utils.py
def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path]:
    """
    Find the path to label files for a specified dataset and phase (e.g. training).

    Args:
        dataset_path (Path): Root directory of the dataset.
        phase_name (Path): Phase being searched, e.g. "train", "val", "test".

    Returns:
        Tuple[Path, Path]: The labels location and its format ("json" or "txt").
    """
    json_labels_path = dataset_path / "annotations" / f"instances_{phase_name}.json"
    txt_labels_path = dataset_path / "labels" / phase_name

    # A COCO-style annotation file takes priority over a labels directory.
    if json_labels_path.is_file():
        return json_labels_path, "json"

    if txt_labels_path.is_dir() and any(f.endswith(".txt") for f in os.listdir(txt_labels_path)):
        return txt_labels_path, "txt"

    logger.warning("No labels found in the specified dataset path and phase name.")
    return [], None

scale_segmentation(annotations, image_dimensions)

Scale the segmentation data based on image dimensions and return a list of scaled segmentation data.

Parameters:

Name Type Description Default
annotations List[Dict[str, Any]]

A list of annotation dictionaries.

required
image_dimensions Dict[str, int]

A dictionary containing image dimensions (height and width).

required

Returns:

Type Description
Optional[List[List[float]]]

Optional[List[List[float]]]: A list of scaled segmentation data, where each sublist contains category_id followed by scaled (x, y) coordinates.

Source code in yolo/utils/dataset_utils.py
def scale_segmentation(
    annotations: List[Dict[str, Any]], image_dimensions: Dict[str, int]
) -> Optional[List[List[float]]]:
    """
    Scale the segmentation data based on image dimensions and return a list of scaled segmentation data.

    Args:
        annotations (List[Dict[str, Any]]): COCO-style annotation dictionaries; each needs a
            "category_id" and either a "segmentation" polygon list or a "bbox".
        image_dimensions (Dict[str, int]): A dictionary containing image dimensions (height and width).

    Returns:
        Optional[List[List[float]]]: A list of scaled segmentation data, where each sublist contains
        category_id followed by scaled (x, y) coordinates. None if annotations is None.
    """
    if annotations is None:
        return None

    seg_array_with_cat = []
    h, w = image_dimensions["height"], image_dimensions["width"]
    for anno in annotations:
        category_id = anno["category_id"]
        if "segmentation" in anno:
            # Flatten the list of polygon parts into one [x1, y1, x2, y2, ...] list.
            seg_list = [item for sublist in anno["segmentation"] for item in sublist]
        elif "bbox" in anno:
            # Build a rectangular polygon from the (x, y, w, h) bbox.
            x, y, width, height = anno["bbox"]
            seg_list = [x, y, x + width, y, x + width, y + height, x, y + height]
        else:
            # BUGFIX: previously seg_list leaked from the prior iteration (or raised
            # NameError on the first one); skip annotations with no usable geometry.
            continue

        # Group into (x, y) pairs and normalize by image width/height.
        scaled_seg_data = (np.array(seg_list).reshape(-1, 2) / [w, h]).tolist()
        scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data))  # flatten
        seg_array_with_cat.append(scaled_flat_seg_data)

    return seg_array_with_cat

tensorlize(data)

Source code in yolo/utils/dataset_utils.py
def tensorlize(data):
    """Convert a list of (img_path, bboxes, ratio) tuples into padded array form.

    Every per-image bbox tensor is padded to the same row count using -1 sentinel
    rows so the whole dataset can be held as one (N, max_box, 5) array.
    """
    try:
        img_paths, bboxes, img_ratios = zip(*data)
    except ValueError as e:
        logger.error(
            ":rotating_light: This may be caused by using old cache or another version of YOLO's cache.\n"
            ":rotating_light: Please clean the cache and try running again."
        )
        raise e

    max_box = max(bbox.size(0) for bbox in bboxes)

    def _pad(bbox):
        # -1 rows mark padding; real rows are copied over the top of the sentinel.
        padded = torch.full((max_box, 5), -1, dtype=torch.float32)
        padded[: bbox.size(0)] = bbox
        return padded

    return np.array(img_paths), np.stack([_pad(bbox) for bbox in bboxes]), np.array(img_ratios)

collate_fn(batch)

A collate function to handle batching of images and their corresponding targets.

Parameters:

Name Type Description Default
batch list of tuples

Each tuple contains: - image (Tensor): The image tensor. - labels (Tensor): The tensor of labels for the image.

required

Returns:

Type Description
Tuple[Tensor, List[Tensor]]

Tuple[Tensor, List[Tensor]]: A tuple containing: - A tensor of batched images. - A list of tensors, each corresponding to bboxes for each image in the batch.

Source code in yolo/data/dataset.py
def collate_fn(batch: List[Tuple[Tensor, Tensor, Tensor, str]]) -> Tuple[int, Tensor, Tensor, Tensor, Tuple]:
    """
    Batch images, padded targets, reverse transforms, and paths for the dataloader.

    Args:
        batch (list of tuples): Each tuple contains:
            - image (Tensor): The image tensor.
            - labels (Tensor): The (N, 5) tensor of labels for the image.
            - rev_tensor (Tensor): The reverse-transform tensor for the image.
            - img_path (str): The path of the image file.

    Returns:
        Tuple[int, Tensor, Tensor, Tensor, Tuple]: A tuple containing:
            - The batch size.
            - A tensor of batched images.
            - A (B, max_targets, 5) targets tensor; the class column is -1 for padding rows.
            - A tensor of batched reverse transforms.
            - A tuple of image paths.
    """
    # Cap per-image targets to bound loss-function memory usage.
    max_targets = 100
    batch_size = len(batch)
    target_sizes = [item[1].size(0) for item in batch]
    batch_targets = torch.zeros(batch_size, min(max(target_sizes), max_targets), 5)
    batch_targets[:, :, 0] = -1  # invalid class id marks padding rows
    for idx, target_size in enumerate(target_sizes):
        batch_targets[idx, : min(target_size, max_targets)] = batch[idx][1][:max_targets]

    batch_images, _, batch_reverse, batch_path = zip(*batch)
    batch_images = torch.stack(batch_images)
    batch_reverse = torch.stack(batch_reverse)

    return batch_size, batch_images, batch_targets, batch_reverse, batch_path

Loader

yolo.data.loader

DataConfig dataclass

Source code in yolo/config/schemas/data.py
@dataclass
class DataConfig:
    """Configuration for dataloader construction and per-batch preprocessing."""

    shuffle: bool
    batch_size: int
    pin_memory: bool
    dataloader_workers: int
    image_size: List[int]  # target (width, height) used for padding/resizing
    data_augment: Dict[str, int]  # augmentation class name -> constructor argument (probability)
    source: Optional[Union[str, int]]  # inference source: file/folder path, URL, or webcam index
    dynamic_shape: Optional[bool]  # if True, per-batch image sizes follow image aspect ratios
    equivalent_batch_size: Optional[int] = 64  # not used in this module — presumably for gradient accumulation; verify against trainer
    drop_last: bool = True  # drop the last incomplete batch in the DataLoader

DatasetConfig dataclass

Source code in yolo/config/schemas/data.py
@dataclass
class DatasetConfig:
    """Static description of a dataset: location, classes, and optional auto-download."""

    path: str  # root directory of the dataset
    class_num: int  # number of object classes
    class_list: List[str]  # class names, indexed by class id
    auto_download: Optional[DownloadOptions]  # download settings; falsy disables auto-download

AugmentationComposer

Composes several transforms together.

Source code in yolo/data/augmentation.py
class AugmentationComposer:
    """Composes several transforms together, then pads/resizes to the target size.

    Transforms exposing ``set_parent`` (e.g. Mosaic) receive a reference to this
    composer so they can pull extra samples and read ``base_size``.
    """

    def __init__(self, transforms, image_size: List[int] = (640, 640), base_size: int = 640):
        """
        Args:
            transforms: Sequence of callables taking and returning (image, boxes).
            image_size (List[int]): Target (width, height) for PadAndResize.
                (Fixed: previously a mutable list default mis-annotated as int.)
            base_size (int): Reference square size for parent-aware transforms.
        """
        self.transforms = transforms
        self.pad_resize = PadAndResize(image_size)
        self.base_size = base_size

        for transform in self.transforms:
            if hasattr(transform, "set_parent"):
                transform.set_parent(self)

    def __call__(self, image, boxes=None):
        """Apply every transform, then pad/resize and convert the image to a tensor."""
        # Avoid a shared mutable tensor default; create fresh empty boxes per call.
        if boxes is None:
            boxes = torch.zeros(0, 5)
        for transform in self.transforms:
            image, boxes = transform(image, boxes)
        image, boxes, rev_tensor = self.pad_resize(image, boxes)
        image = TF.to_tensor(image)
        return image, boxes, rev_tensor

YoloDataset

Bases: Dataset

Source code in yolo/data/dataset.py
class YoloDataset(Dataset):
    """Map-style dataset yielding (image, bboxes, reverse-transform, image-path) samples."""

    def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str = "train2017"):
        # Build augmentations from config, then load (or generate) the cached annotations.
        augment_cfg = data_cfg.data_augment
        self.image_size = data_cfg.image_size
        phase_name = dataset_cfg.get(phase, phase)
        self.batch_size = data_cfg.batch_size
        self.dynamic_shape = getattr(data_cfg, "dynamic_shape", False)
        self.base_size = mean(self.image_size)

        # NOTE(review): eval() resolves augmentation class names straight from config;
        # the config must be trusted input.
        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
        self.transform = AugmentationComposer(transforms, self.image_size, self.base_size)
        self.transform.get_more_data = self.get_more_data  # lets Mosaic pull extra samples
        self.img_paths, self.bboxes, self.ratios = tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))

    def load_data(self, dataset_path: Path, phase_name: str) -> list:
        """
        Loads data from a cache or generates a new cache for a specific dataset phase.

        Parameters:
            dataset_path (Path): The root path to the dataset directory.
            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

        Returns:
            list: The loaded data from the cache for the specified phase.
        """
        cache_path = dataset_path / f"{phase_name}.pache"  # NOTE(review): ".pache" looks like a typo for ".cache"

        if not cache_path.exists():
            logger.info(f":factory: Generating {phase_name} cache")
            data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
            torch.save(data, cache_path)
        else:
            try:
                # weights_only=False: the cache stores arbitrary Python objects, not just tensors.
                data = torch.load(cache_path, weights_only=False)
            except Exception as e:
                logger.error(
                    f":rotating_light: Failed to load the cache at '{cache_path}'.\n"
                    ":rotating_light: This may be caused by using cache from different other YOLO.\n"
                    ":rotating_light: Please clean the cache and try running again."
                )
                raise e
            logger.info(f":package: Loaded {phase_name} cache, there are {len(data)} data in total.")
        return data

    def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = False) -> list:
        """
        Filters and collects dataset information by pairing images with their corresponding labels.

        Parameters:
            dataset_path (Path): Root path of the dataset directory.
            phase_name (str): Dataset split to load (e.g. ``'train'``, ``'validation'``).
            sort_image (bool): If True, sorts the dataset by the width-to-height ratio of images in descending order.

        Returns:
            list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
        """
        images_path = dataset_path / "images" / phase_name
        labels_path, data_type = locate_label_paths(dataset_path, phase_name)
        # A split-level "<phase>.txt" file list takes precedence over directory scanning.
        file_list, adjust_path = dataset_path / f"{phase_name}.txt", False
        if file_list.exists():
            data_type, adjust_path = "txt", True
            # TODO: should i sort by name?
            with open(file_list, "r") as file:
                images_list = [dataset_path / line.rstrip() for line in file]
            # Derive each label path by swapping "images" -> "labels" and suffix -> .txt.
            labels_list = [
                Path(str(image_path).replace("images", "labels")).with_suffix(".txt") for image_path in images_list
            ]
        else:
            images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()])

        if data_type == "json":
            annotations_index, image_info_dict = create_image_metadata(labels_path)

        data = []
        valid_inputs = 0
        for idx, image_name in enumerate(track(images_list, description="Filtering data")):
            if not adjust_path and not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            image_id = Path(image_name).stem

            if data_type == "json":
                image_info = image_info_dict.get(image_id, None)
                if image_info is None:
                    continue
                annotations = annotations_index.get(image_info["id"], [])
                image_seg_annotations = scale_segmentation(annotations, image_info)
            elif data_type == "txt":
                label_path = labels_list[idx] if adjust_path else labels_path / f"{image_id}.txt"
                if not label_path.is_file():
                    image_seg_annotations = []
                else:
                    with open(label_path, "r") as file:
                        image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
            else:
                image_seg_annotations = []

            labels = self.load_valid_labels(image_id, image_seg_annotations)
            img_path = image_name if adjust_path else images_path / image_name
            if sort_image:
                with Image.open(img_path) as img:
                    width, height = img.size
            else:
                # Placeholder ratio (0) when aspect-ratio sorting is disabled.
                width, height = 0, 1
            data.append((img_path, labels, width / height))
            if len(image_seg_annotations) != 0:
                valid_inputs += 1

        # Sort by aspect ratio (width / height), widest first.
        data = sorted(data, key=lambda x: x[2], reverse=True)

        logger.info(f"Recorded {valid_inputs}/{len(images_list)} valid inputs")
        return data

    def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
        """
        Loads valid COCO style segmentation data (values between [0, 1]) and converts it to bounding box coordinates
        by finding the minimum and maximum x and y values.

        Parameters:
            label_path (str): The filepath to the label file containing annotation data.
            seg_data_one_img (list): The actual list of annotations (in segmentation format)

        Returns:
            Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
        """
        bboxes = []
        for seg_data in seg_data_one_img:
            cls = seg_data[0]
            # Coordinates are clamped into [0, 1]; the mask below is always all-True after clip.
            points = np.array(seg_data[1:]).reshape(-1, 2).clip(0, 1)
            valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
            if valid_points.size > 1:
                bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
                bboxes.append(bbox)

        if bboxes:
            return torch.stack(bboxes)
        else:
            logger.warning(f"No valid BBox in {label_path}")
            return torch.zeros((0, 5))

    def get_data(self, idx):
        """Return (PIL image, valid bboxes tensor, image path) for one sample."""
        img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
        valid_mask = bboxes[:, 0] != -1  # class -1 marks padding rows added by tensorlize
        with Image.open(img_path) as img:
            img = img.convert("RGB")
        return img, torch.from_numpy(bboxes[valid_mask]), img_path

    def get_more_data(self, num: int = 1):
        """Sample `num` random (image, bboxes) pairs; used by Mosaic-style augmentations."""
        indices = torch.randint(0, len(self), (num,))
        return [self.get_data(idx)[:2] for idx in indices]

    def _update_image_size(self, idx: int) -> None:
        """Update image size based on dynamic shape and batch settings."""
        # All items in a batch share the aspect ratio of the batch's first item.
        batch_start_idx = (idx // self.batch_size) * self.batch_size
        image_ratio = self.ratios[batch_start_idx].clip(1 / 3, 3)
        shift = ((self.base_size / 32 * (image_ratio - 1)) // (image_ratio + 1)) * 32

        self.image_size = [int(self.base_size + shift), int(self.base_size - shift)]
        self.transform.pad_resize.set_size(self.image_size)

    def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
        """Load and augment one sample; box coordinates are scaled to output pixels."""
        img, bboxes, img_path = self.get_data(idx)

        if self.dynamic_shape:
            self._update_image_size(idx)

        img, bboxes, rev_tensor = self.transform(img, bboxes)
        # Convert normalized coordinates to absolute pixels of the final image size.
        bboxes[:, [1, 3]] *= self.image_size[0]
        bboxes[:, [2, 4]] *= self.image_size[1]
        return img, bboxes, rev_tensor, img_path

    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.bboxes)

load_data(dataset_path, phase_name)

Loads data from a cache or generates a new cache for a specific dataset phase.

Parameters:

Name Type Description Default
dataset_path Path

The root path to the dataset directory.

required
phase_name str

The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

required

Returns:

Name Type Description
list list

The loaded data from the cache for the specified phase.

Source code in yolo/data/dataset.py
def load_data(self, dataset_path: Path, phase_name: str) -> list:
    """
    Loads data from a cache or generates a new cache for a specific dataset phase.

    Parameters:
        dataset_path (Path): The root path to the dataset directory.
        phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

    Returns:
        list: The loaded data from the cache for the specified phase.
    """
    # NOTE(review): ".pache" looks like a typo for ".cache" — kept so existing caches still load.
    cache_path = dataset_path / f"{phase_name}.pache"

    if not cache_path.exists():
        logger.info(f":factory: Generating {phase_name} cache")
        data = self.filter_data(dataset_path, phase_name, self.dynamic_shape)
        torch.save(data, cache_path)
    else:
        try:
            # weights_only=False: the cache stores arbitrary Python objects, not just tensors.
            data = torch.load(cache_path, weights_only=False)
        except Exception as e:
            # Fixed garbled wording of the original error message.
            logger.error(
                f":rotating_light: Failed to load the cache at '{cache_path}'.\n"
                ":rotating_light: This may be caused by a cache from a different YOLO version.\n"
                ":rotating_light: Please clean the cache and try running again."
            )
            raise e
        logger.info(f":package: Loaded {phase_name} cache, there are {len(data)} data in total.")
    return data

filter_data(dataset_path, phase_name, sort_image=False)

Filters and collects dataset information by pairing images with their corresponding labels.

Parameters:

Name Type Description Default
dataset_path Path

Root path of the dataset directory.

required
phase_name str

Dataset split to load (e.g. 'train', 'validation').

required
sort_image bool

If True, sorts the dataset by the width-to-height ratio of images in descending order.

False

Returns:

Name Type Description
list list

A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.

Source code in yolo/data/dataset.py
def filter_data(self, dataset_path: Path, phase_name: str, sort_image: bool = False) -> list:
    """
    Filters and collects dataset information by pairing images with their corresponding labels.

    Parameters:
        dataset_path (Path): Root path of the dataset directory.
        phase_name (str): Dataset split to load (e.g. ``'train'``, ``'validation'``).
        sort_image (bool): If True, sorts the dataset by the width-to-height ratio of images in descending order.

    Returns:
        list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
    """
    images_path = dataset_path / "images" / phase_name
    labels_path, data_type = locate_label_paths(dataset_path, phase_name)
    # A split-level "<phase>.txt" file list takes precedence over directory scanning.
    file_list, adjust_path = dataset_path / f"{phase_name}.txt", False
    if file_list.exists():
        data_type, adjust_path = "txt", True
        # TODO: should i sort by name?
        with open(file_list, "r") as file:
            images_list = [dataset_path / line.rstrip() for line in file]
        # Derive each label path by swapping "images" -> "labels" and suffix -> .txt.
        labels_list = [
            Path(str(image_path).replace("images", "labels")).with_suffix(".txt") for image_path in images_list
        ]
    else:
        images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()])

    if data_type == "json":
        annotations_index, image_info_dict = create_image_metadata(labels_path)

    data = []
    valid_inputs = 0
    for idx, image_name in enumerate(track(images_list, description="Filtering data")):
        if not adjust_path and not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        image_id = Path(image_name).stem

        if data_type == "json":
            image_info = image_info_dict.get(image_id, None)
            if image_info is None:
                continue
            annotations = annotations_index.get(image_info["id"], [])
            image_seg_annotations = scale_segmentation(annotations, image_info)
        elif data_type == "txt":
            label_path = labels_list[idx] if adjust_path else labels_path / f"{image_id}.txt"
            if not label_path.is_file():
                image_seg_annotations = []
            else:
                with open(label_path, "r") as file:
                    image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
        else:
            image_seg_annotations = []

        labels = self.load_valid_labels(image_id, image_seg_annotations)
        img_path = image_name if adjust_path else images_path / image_name
        if sort_image:
            with Image.open(img_path) as img:
                width, height = img.size
        else:
            # Placeholder ratio (0) when aspect-ratio sorting is disabled.
            width, height = 0, 1
        data.append((img_path, labels, width / height))
        if len(image_seg_annotations) != 0:
            valid_inputs += 1

    # Sort by aspect ratio (width / height), widest first.
    data = sorted(data, key=lambda x: x[2], reverse=True)

    logger.info(f"Recorded {valid_inputs}/{len(images_list)} valid inputs")
    return data

load_valid_labels(label_path, seg_data_one_img)

Loads valid COCO style segmentation data (values between [0, 1]) and converts it to bounding box coordinates by finding the minimum and maximum x and y values.

Parameters:

Name Type Description Default
label_path str

The filepath to the label file containing annotation data.

required
seg_data_one_img list

The actual list of annotations (in segmentation format)

required

Returns:

Type Description
Union[Tensor, None]

Tensor: A tensor of all valid bounding boxes if any are found; otherwise, an empty (0, 5) tensor.

Source code in yolo/data/dataset.py
def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
    """
    Loads valid COCO style segmentation data (values between [0, 1]) and converts it to bounding box coordinates
    by finding the minimum and maximum x and y values.

    Parameters:
        label_path (str): The filepath to the label file containing annotation data.
        seg_data_one_img (list): The actual list of annotations (in segmentation format)

    Returns:
        Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
    """
    bboxes = []
    for seg_data in seg_data_one_img:
        cls = seg_data[0]
        points = np.array(seg_data[1:]).reshape(-1, 2).clip(0, 1)
        valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
        if valid_points.size > 1:
            bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
            bboxes.append(bbox)

    if bboxes:
        return torch.stack(bboxes)
    else:
        logger.warning(f"No valid BBox in {label_path}")
        return torch.zeros((0, 5))

StreamDataLoader

Source code in yolo/data/loader.py
class StreamDataLoader:
    """Iterates frames from a live stream (webcam/RTMP), video file, image folder, or single image.

    Live stream sources are read synchronously in __next__; file-based sources are decoded
    on a background thread and consumed through a queue.
    """

    def __init__(self, data_cfg: DataConfig):
        self.source = data_cfg.source
        self.running = True  # consumed by load_video_file's read loop
        # Webcam indices and rtmp:// URLs are treated as live streams.
        self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")

        self.transform = AugmentationComposer([], data_cfg.image_size)
        self.stop_event = Event()

        if self.is_stream:
            import cv2

            self.cap = cv2.VideoCapture(self.source)
        else:
            # File-based sources: a producer thread fills self.queue with processed frames.
            self.source = Path(self.source)
            self.queue = Queue()
            self.thread = Thread(target=self.load_source)
            self.thread.start()

    def load_source(self):
        """Dispatch to the appropriate reader based on the source path type."""
        if self.source.is_dir():  # image folder
            self.load_image_folder(self.source)
        elif any(self.source.suffix.lower().endswith(ext) for ext in [".mp4", ".avi", ".mkv"]):  # Video file
            self.load_video_file(self.source)
        else:  # Single image
            self.process_image(self.source)

    def load_image_folder(self, folder):
        """Recursively process every supported image file under `folder`."""
        folder_path = Path(folder)
        for file_path in folder_path.rglob("*"):
            if self.stop_event.is_set():
                break
            if file_path.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp"]:
                self.process_image(file_path)

    def process_image(self, image_path):
        """Open one image file and push it through the frame pipeline."""
        image = Image.open(image_path).convert("RGB")
        # NOTE(review): Image.open raises on failure rather than returning None,
        # so this guard appears unreachable.
        if image is None:
            raise ValueError(f"Error loading image: {image_path}")
        self.process_frame(image)

    def load_video_file(self, video_path):
        """Decode a video file frame by frame until exhausted or stopped."""
        import cv2

        cap = cv2.VideoCapture(str(video_path))
        while self.running:
            ret, frame = cap.read()
            if not ret:
                break
            self.process_frame(frame)
        cap.release()

    def process_frame(self, frame):
        """Transform a frame (ndarray or PIL image) and queue/store it with its reverse tensor."""
        if isinstance(frame, np.ndarray):
            # TODO: we don't need cv2
            import cv2

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
        origin_frame = frame
        frame, _, rev_tensor = self.transform(frame, torch.zeros(0, 5))
        # Add a leading batch dimension of 1.
        frame = frame[None]
        rev_tensor = rev_tensor[None]
        if not self.is_stream:
            self.queue.put((frame, rev_tensor, origin_frame))
        else:
            self.current_frame = (frame, rev_tensor, origin_frame)

    def __iter__(self) -> Generator[Tensor, None, None]:
        return self

    def __next__(self) -> Tensor:
        """Return the next (frame, rev_tensor, origin_frame) tuple; raise StopIteration when done."""
        if self.is_stream:
            ret, frame = self.cap.read()
            if not ret:
                self.stop()
                raise StopIteration
            self.process_frame(frame)
            return self.current_frame
        else:
            try:
                # A 1-second timeout distinguishes "producer finished" from "still decoding".
                frame = self.queue.get(timeout=1)
                return frame
            except Empty:
                raise StopIteration

    def stop(self):
        """Stop capture/decoding and release resources."""
        self.running = False
        if self.is_stream:
            self.cap.release()
        else:
            self.thread.join(timeout=1)

    def __len__(self):
        """Frames currently buffered (0 for live streams)."""
        return self.queue.qsize() if not self.is_stream else 0

collate_fn(batch)

A collate function to handle batching of images and their corresponding targets.

Parameters:

Name Type Description Default
batch list of tuples

Each tuple contains: - image (Tensor): The image tensor. - labels (Tensor): The tensor of labels for the image.

required

Returns:

Type Description
Tuple[Tensor, List[Tensor]]

Tuple[Tensor, List[Tensor]]: A tuple containing: - A tensor of batched images. - A list of tensors, each corresponding to bboxes for each image in the batch.

Source code in yolo/data/dataset.py
def collate_fn(batch: List[Tuple[Tensor, Tensor, Tensor, str]]) -> Tuple[int, Tensor, Tensor, Tensor, Tuple]:
    """
    Batch images, padded targets, reverse transforms, and paths for the dataloader.

    Args:
        batch (list of tuples): Each tuple contains:
            - image (Tensor): The image tensor.
            - labels (Tensor): The (N, 5) tensor of labels for the image.
            - rev_tensor (Tensor): The reverse-transform tensor for the image.
            - img_path (str): The path of the image file.

    Returns:
        Tuple[int, Tensor, Tensor, Tensor, Tuple]: A tuple containing:
            - The batch size.
            - A tensor of batched images.
            - A (B, max_targets, 5) targets tensor; the class column is -1 for padding rows.
            - A tensor of batched reverse transforms.
            - A tuple of image paths.
    """
    # Cap per-image targets to bound loss-function memory usage.
    max_targets = 100
    batch_size = len(batch)
    target_sizes = [item[1].size(0) for item in batch]
    batch_targets = torch.zeros(batch_size, min(max(target_sizes), max_targets), 5)
    batch_targets[:, :, 0] = -1  # invalid class id marks padding rows
    for idx, target_size in enumerate(target_sizes):
        batch_targets[idx, : min(target_size, max_targets)] = batch[idx][1][:max_targets]

    batch_images, _, batch_reverse, batch_path = zip(*batch)
    batch_images = torch.stack(batch_images)
    batch_reverse = torch.stack(batch_reverse)

    return batch_size, batch_images, batch_targets, batch_reverse, batch_path

prepare_dataset(dataset_cfg, task)

Prepares dataset by downloading and unzipping if necessary.

Source code in yolo/data/preparation.py
def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
    """
    Prepares dataset by downloading and unzipping if necessary.

    Args:
        dataset_cfg (DatasetConfig): Dataset description; ``auto_download`` holds per-type
            download settings (with a "base_url" entry and per-split argument dicts).
        task (str): Split to prepare (e.g. 'train'); may be remapped via dataset_cfg.
    """
    # TODO: do EDA of dataset
    data_dir = Path(dataset_cfg.path)
    for data_type, settings in dataset_cfg.auto_download.items():
        base_url = settings["base_url"]
        for dataset_type, dataset_args in settings.items():
            # Skip splits other than the requested task; annotations are always processed.
            if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type:
                continue
            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
            url = f"{base_url}{file_name}"
            local_zip_path = data_dir / file_name
            extract_to = data_dir / data_type if data_type != "annotations" else data_dir
            final_place = extract_to / dataset_type

            final_place.mkdir(parents=True, exist_ok=True)
            # Skip download/extraction when the expected number of files is already present.
            if check_files(final_place, dataset_args.get("file_num")):
                logger.info(f":white_check_mark: Dataset {dataset_type: <12} already verified.")
                continue

            if not local_zip_path.exists():
                download_file(url, local_zip_path)
            unzip_file(local_zip_path, extract_to)

            if not check_files(final_place, dataset_args.get("file_num")):
                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")

create_dataloader(data_cfg, dataset_cfg, task='train')

Source code in yolo/data/loader.py
def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: str = "train"):
    """Build the data source for `task`: a StreamDataLoader for inference, else a DataLoader.

    Auto-download is triggered first when the dataset config enables it.
    """
    if task == "inference":
        return StreamDataLoader(data_cfg)

    if getattr(dataset_cfg, "auto_download", False):
        prepare_dataset(dataset_cfg, task)

    loader_kwargs = dict(
        batch_size=data_cfg.batch_size,
        num_workers=data_cfg.dataloader_workers,
        pin_memory=data_cfg.pin_memory,
        collate_fn=collate_fn,
        drop_last=data_cfg.drop_last,
    )
    return DataLoader(YoloDataset(data_cfg, dataset_cfg, task), **loader_kwargs)

Augmentation

yolo.data.augmentation

AugmentationComposer

Composes several transforms together.

Source code in yolo/data/augmentation.py
class AugmentationComposer:
    """Composes several transforms together, then pads/resizes to the target size.

    Transforms exposing ``set_parent`` (e.g. Mosaic) receive a reference to this
    composer so they can pull extra samples and read ``base_size``.
    """

    def __init__(self, transforms, image_size: List[int] = (640, 640), base_size: int = 640):
        """
        Args:
            transforms: Sequence of callables taking and returning (image, boxes).
            image_size (List[int]): Target (width, height) for PadAndResize.
                (Fixed: previously a mutable list default mis-annotated as int.)
            base_size (int): Reference square size for parent-aware transforms.
        """
        self.transforms = transforms
        self.pad_resize = PadAndResize(image_size)
        self.base_size = base_size

        for transform in self.transforms:
            if hasattr(transform, "set_parent"):
                transform.set_parent(self)

    def __call__(self, image, boxes=None):
        """Apply every transform, then pad/resize and convert the image to a tensor."""
        # Avoid a shared mutable tensor default; create fresh empty boxes per call.
        if boxes is None:
            boxes = torch.zeros(0, 5)
        for transform in self.transforms:
            image, boxes = transform(image, boxes)
        image, boxes, rev_tensor = self.pad_resize(image, boxes)
        image = TF.to_tensor(image)
        return image, boxes, rev_tensor

RemoveOutliers

Removes outlier bounding boxes that are too small or have invalid dimensions.

Source code in yolo/data/augmentation.py
class RemoveOutliers:
    """Drops degenerate bounding boxes: inverted corners or areas below a threshold."""

    def __init__(self, min_box_area=1e-8):
        """
        Args:
            min_box_area (float): Smallest normalized area a box may have and still be kept.
        """
        self.min_box_area = min_box_area

    def __call__(self, image, boxes):
        """
        Args:
            image (PIL.Image): The cropped image.
            boxes (torch.Tensor): Bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max).
        Returns:
            PIL.Image: The input image (unchanged).
            torch.Tensor: Boxes with positive extents and area above the threshold.
        """
        widths = boxes[:, 3] - boxes[:, 1]
        heights = boxes[:, 4] - boxes[:, 2]
        keep = (widths * heights > self.min_box_area) & (widths > 0) & (heights > 0)
        return image, boxes[keep]

__init__(min_box_area=1e-08)

Parameters:

Name Type Description Default
min_box_area float

Minimum area for a box to be kept, as a fraction of the image area.

1e-08
Source code in yolo/data/augmentation.py
def __init__(self, min_box_area=1e-8):
    """
    Args:
        min_box_area (float): Minimum area for a box to be kept, as a fraction of the image area.
    """
    self.min_box_area = min_box_area

__call__(image, boxes)

Parameters:

Name Type Description Default
image Image

The cropped image.

required
boxes Tensor

Bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max).

required

Returns: PIL.Image: The input image (unchanged). torch.Tensor: Filtered bounding boxes.

Source code in yolo/data/augmentation.py
def __call__(self, image, boxes):
    """
    Args:
        image (PIL.Image): The cropped image.
        boxes (torch.Tensor): Bounding boxes in normalized coordinates (x_min, y_min, x_max, y_max).
    Returns:
        PIL.Image: The input image (unchanged).
        torch.Tensor: Filtered bounding boxes.
    """
    box_areas = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 4] - boxes[:, 2])

    valid_boxes = (box_areas > self.min_box_area) & (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 4] > boxes[:, 2])

    return image, boxes[valid_boxes]

PadAndResize

Source code in yolo/data/augmentation.py
class PadAndResize:
    """Letterbox transform: scale the image to fit the target size, then center-pad."""

    def __init__(self, image_size, background_color=(114, 114, 114)):
        """Store the target (width, height) and the padding color."""
        self.target_width, self.target_height = image_size
        self.background_color = background_color

    def set_size(self, image_size: List[int]):
        """Update the target (width, height) in place."""
        self.target_width, self.target_height = image_size

    def __call__(self, image: Image, boxes):
        src_w, src_h = image.size
        # Uniform scale that fits the whole image inside the target rectangle.
        scale = min(self.target_width / src_w, self.target_height / src_h)
        scaled_w = int(src_w * scale)
        scaled_h = int(src_h * scale)

        shrunk = image.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)

        # Center the resized image on a solid-color canvas.
        pad_left = (self.target_width - scaled_w) // 2
        pad_top = (self.target_height - scaled_h) // 2
        canvas = Image.new("RGB", (self.target_width, self.target_height), self.background_color)
        canvas.paste(shrunk, (pad_left, pad_top))

        # Re-normalize box coordinates to the padded canvas (in place).
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] * scaled_w + pad_left) / self.target_width
        boxes[:, [2, 4]] = (boxes[:, [2, 4]] * scaled_h + pad_top) / self.target_height

        transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top])
        return canvas, boxes, transform_info

__init__(image_size, background_color=(114, 114, 114))

Initialize the object with the target image size.

Source code in yolo/data/augmentation.py
def __init__(self, image_size, background_color=(114, 114, 114)):
    """Initialize the object with the target image size."""
    # image_size is unpacked as (width, height); background_color is the RGB pad fill.
    self.target_width, self.target_height = image_size
    self.background_color = background_color

HorizontalFlip

Randomly horizontally flips the image along with the bounding boxes.

Source code in yolo/data/augmentation.py
class HorizontalFlip:
    """Randomly horizontally flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        # Chance of applying the flip on any given call.
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes
        # Mirror the image, then reflect and swap x_min/x_max (columns 1 and 3).
        mirrored = TF.hflip(image)
        boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return mirrored, boxes

VerticalFlip

Randomly vertically flips the image along with the bounding boxes.

Source code in yolo/data/augmentation.py
class VerticalFlip:
    """Randomly vertically flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        # Chance of applying the flip on any given call.
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes
        # Mirror the image vertically, then reflect and swap y_min/y_max (columns 2 and 4).
        mirrored = TF.vflip(image)
        boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
        return mirrored, boxes

Mosaic

Applies the Mosaic augmentation to a batch of images and their corresponding boxes.

Source code in yolo/data/augmentation.py
class Mosaic:
    """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""

    def __init__(self, prob=0.5):
        # Chance of applying the augmentation on any given call.
        self.prob = prob
        # Parent dataset; must be injected via set_parent() before use.
        self.parent = None

    def set_parent(self, parent):
        self.parent = parent

    def __call__(self, image, boxes):
        # Skip the augmentation with probability (1 - prob).
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."

        base = self.parent.base_size
        extra = self.parent.get_more_data(3)  # three more random samples

        tiles = [(image, boxes)] + extra
        # Quadrant placement offsets relative to the canvas center: TL, TR, BL, BR.
        offsets = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
        origin = np.array([base, base])
        canvas = Image.new("RGB", (2 * base, 2 * base), (114, 114, 114))
        merged = []

        for (tile, tile_boxes), offset in zip(tiles, offsets):
            tile_w, tile_h = tile.size
            paste_at = tuple(origin + offset * np.array([tile_w, tile_h]))

            canvas.paste(tile, paste_at)
            # Shift each normalized box into its quadrant, renormalized to the 2x canvas.
            cls_col = tile_boxes[:, 0]
            x1 = (tile_boxes[:, 1] * tile_w + paste_at[0]) / (2 * base)
            y1 = (tile_boxes[:, 2] * tile_h + paste_at[1]) / (2 * base)
            x2 = (tile_boxes[:, 3] * tile_w + paste_at[0]) / (2 * base)
            y2 = (tile_boxes[:, 4] * tile_h + paste_at[1]) / (2 * base)
            merged.append(torch.stack([cls_col, x1, y1, x2, y2], dim=1))

        # Downscale the 2x mosaic back to the base resolution.
        return canvas.resize((base, base)), torch.cat(merged, dim=0)

MixUp

Applies the MixUp augmentation to a pair of images and their corresponding boxes.

Source code in yolo/data/augmentation.py
class MixUp:
    """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""

    def __init__(self, prob=0.5, alpha=1.0):
        # Beta-distribution concentration for sampling the mixing weight.
        self.alpha = alpha
        # Chance of applying the augmentation on any given call.
        self.prob = prob
        # Parent dataset; must be injected via set_parent() before use.
        self.parent = None

    def set_parent(self, parent):
        """Set the parent dataset object for accessing dataset methods."""
        self.parent = parent

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."

        # Pull one more random sample from the dataset to blend with.
        other_image, other_boxes = self.parent.get_more_data()[0]

        # Mixing weight; falls back to an even blend when alpha is non-positive.
        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5

        # Blend pixel values; both label sets are kept (boxes are not reweighted).
        blended = lam * TF.to_tensor(image) + (1 - lam) * TF.to_tensor(other_image)
        combined_boxes = torch.cat((boxes, other_boxes))

        return TF.to_pil_image(blended), combined_boxes

set_parent(parent)

Set the parent dataset object for accessing dataset methods.

Source code in yolo/data/augmentation.py
def set_parent(self, parent):
    """Set the parent dataset object for accessing dataset methods."""
    # Stored so the augmentation can pull additional samples from the dataset later.
    self.parent = parent

RandomCrop

Randomly crops the image to half its size along with adjusting the bounding boxes.

Source code in yolo/data/augmentation.py
class RandomCrop:
    """Randomly crops the image to half its size along with adjusting the bounding boxes."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): Probability of applying the crop.
        """
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes

        full_w, full_h = image.size
        half_h, half_w = full_h // 2, full_w // 2
        # Pick a random top-left corner so the half-size window stays inside the image.
        top = torch.randint(0, full_h - half_h + 1, (1,)).item()
        left = torch.randint(0, full_w - half_w + 1, (1,)).item()

        cropped = TF.crop(image, top, left, half_h, half_w)

        # Map normalized coords to pixels, shift by the crop origin, clip to the
        # crop window, then renormalize to the crop size.
        boxes[:, [1, 3]] = ((boxes[:, [1, 3]] * full_w - left).clamp(0, half_w)) / half_w
        boxes[:, [2, 4]] = ((boxes[:, [2, 4]] * full_h - top).clamp(0, half_h)) / half_h

        return cropped, boxes

__init__(prob=0.5)

Parameters:

Name Type Description Default
prob float

Probability of applying the crop.

0.5
Source code in yolo/data/augmentation.py
def __init__(self, prob=0.5):
    """
    Args:
        prob (float): Probability of applying the crop.
    """
    # Threshold compared against torch.rand(1) on each call.
    self.prob = prob

Preparation

yolo.data.preparation

logger = logging.getLogger('yolo') module-attribute

DatasetConfig dataclass

Source code in yolo/config/schemas/data.py
@dataclass
class DatasetConfig:
    # Root directory of the dataset on disk.
    path: str
    # Number of object classes in the dataset.
    class_num: int
    # Human-readable class names; presumably length matches class_num — TODO confirm.
    class_list: List[str]
    # Auto-download settings; None disables automatic dataset download.
    auto_download: Optional[DownloadOptions]

download_file(url, destination)

Downloads a file from the specified URL to the destination path with progress logging.

Source code in yolo/data/preparation.py
def download_file(url, destination: Path):
    """
    Downloads a file from the specified URL to the destination path with progress logging.
    """
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # Zero when the server omits Content-Length; the bar then has no known total.
        expected_bytes = int(response.headers.get("content-length", 0))
        progress_columns = (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            "{task.completed}/{task.total} bytes",
            "•",
            TimeRemainingColumn(),
        )
        with Progress(*progress_columns) as progress:
            task = progress.add_task(f"📥 Downloading {destination.name }...", total=expected_bytes)
            with open(destination, "wb") as out_file:
                # Stream in 1 MB chunks so large files never sit fully in memory.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    out_file.write(chunk)
                    progress.update(task, advance=len(chunk))
    logger.info(":white_check_mark: Download completed.")

unzip_file(source, destination)

Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.

Source code in yolo/data/preparation.py
def unzip_file(source: Path, destination: Path):
    """
    Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
    """
    logger.info(f"Unzipping {source.name}...")
    archive = zipfile.ZipFile(source, "r")
    with archive:
        archive.extractall(destination)
    # The archive is no longer needed once its contents are on disk.
    source.unlink()
    logger.info(f"Removed {source}.")

check_files(directory, expected_count=None)

Returns True if the number of files in the directory matches expected_count, False otherwise.

Source code in yolo/data/preparation.py
def check_files(directory, expected_count=None):
    """
    Returns True if the number of files in the directory matches expected_count, False otherwise.
    """
    # Subdirectories are deliberately excluded from the count.
    file_names = [entry.name for entry in Path(directory).iterdir() if entry.is_file()]
    if expected_count is None:
        # No target count given: just report whether the directory holds any files.
        return bool(file_names)
    return len(file_names) == expected_count

prepare_dataset(dataset_cfg, task)

Prepares dataset by downloading and unzipping if necessary.

Source code in yolo/data/preparation.py
def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
    """
    Prepares dataset by downloading and unzipping if necessary.
    """
    # TODO: do EDA of dataset
    data_dir = Path(dataset_cfg.path)
    # NOTE(review): `.items()` and `.get()` are used on dataclass-typed fields below;
    # that only works if the config object is actually a mapping (e.g. an omegaconf
    # DictConfig) at runtime — confirm against the config loader.
    for data_type, settings in dataset_cfg.auto_download.items():
        base_url = settings["base_url"]
        for dataset_type, dataset_args in settings.items():
            # Skip split entries that don't match the requested task; "annotations"
            # entries are always processed regardless of task.
            if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type:
                continue
            # Archive name defaults to the entry key when no explicit file_name is set.
            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
            url = f"{base_url}{file_name}"
            local_zip_path = data_dir / file_name
            # Annotations unpack at the dataset root; image splits under <root>/<data_type>.
            extract_to = data_dir / data_type if data_type != "annotations" else data_dir
            final_place = extract_to / dataset_type

            final_place.mkdir(parents=True, exist_ok=True)
            # Skip download when the expected file count is already present.
            if check_files(final_place, dataset_args.get("file_num")):
                logger.info(f":white_check_mark: Dataset {dataset_type: <12} already verified.")
                continue

            # Reuse a previously downloaded archive if one is lying around.
            if not local_zip_path.exists():
                download_file(url, local_zip_path)
            unzip_file(local_zip_path, extract_to)

            # Verification failure is logged but does not abort the remaining entries.
            if not check_files(final_place, dataset_args.get("file_num")):
                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")

prepare_weight(download_link=None, weight_path=Path('v9-c.pt'))

Source code in yolo/data/preparation.py
def prepare_weight(download_link: Optional[str] = None, weight_path: Path = Path("v9-c.pt")):
    """
    Ensures a weight file exists locally, downloading it if necessary.

    Args:
        download_link: Base URL for the release assets; defaults to the project's
            GitHub release download page.
        weight_path: Local path where the weight file should live.
    """
    weight_name = weight_path.name
    if download_link is None:
        download_link = "https://github.com/shreyaskamathkm/yolo/releases/download/v1-trained_models/"
    weight_link = f"{download_link}{weight_name}"

    if not weight_path.parent.is_dir():
        weight_path.parent.mkdir(parents=True, exist_ok=True)

    if weight_path.exists():
        logger.info(f"Weight file '{weight_path}' already exists.")
        # Bug fix: previously the download still ran after this log line,
        # overwriting the existing file on every call. Skip it instead.
        return

    try:
        download_file(weight_link, weight_path)
    except requests.exceptions.RequestException as e:
        # Best-effort: a failed download is logged, not raised, so callers
        # can fall back to other weight sources.
        logger.warning(f"Failed to download the weight file: {e}")