Skip to content

Models

Create Model

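`mindyolo.models.model_factory.create_model(model_name, model_cfg=None, in_channels=3, num_classes=80, checkpoint_path='', **kwargs)`

Create a registered model by name and, if `checkpoint_path` is given, load its parameters from a `.ckpt` checkpoint file.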

Source code in `mindyolo/models/model_factory.py` (lines 16–41):
def create_model(
    model_name: str,
    model_cfg: dict = None,
    in_channels: int = 3,
    num_classes: int = 80,
    checkpoint_path: str = "",
    **kwargs,
):
    """Build a registered model by name and optionally load checkpoint weights.

    Args:
        model_name: Name under which the model's entrypoint is registered.
        model_cfg: Model configuration, forwarded to the entrypoint as ``cfg``.
        in_channels: Number of input channels, forwarded to the entrypoint.
        num_classes: Number of classes, forwarded to the entrypoint.
        checkpoint_path: Optional path to a ``.ckpt`` file whose parameters
            are loaded into the created model. An empty string skips loading.
        **kwargs: Extra keyword arguments forwarded to the entrypoint. Entries
            whose value is ``None`` are dropped before forwarding.

    Returns:
        The model instance returned by the registered entrypoint.

    Raises:
        RuntimeError: If ``model_name`` is not a registered model.
        AssertionError: If ``checkpoint_path`` is non-empty but is not an
            existing file whose name ends in ``.ckpt``.
    """
    model_args = dict(cfg=model_cfg, num_classes=num_classes, in_channels=in_channels)
    # Drop None-valued extras, presumably so the entrypoint's own defaults
    # apply for anything the caller left unset.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    if not is_model(model_name):
        raise RuntimeError(f"Unknown model {model_name}")

    create_fn = model_entrypoint(model_name)
    model = create_fn(**model_args, **kwargs)

    if checkpoint_path:
        # Use an explicit raise rather than ``assert``, because asserts are
        # stripped under ``python -O``, which would silently skip this check.
        # The AssertionError type and message are kept unchanged so existing
        # callers see the same exception.
        if not (os.path.isfile(checkpoint_path) and checkpoint_path.endswith(".ckpt")):
            raise AssertionError(f"[{checkpoint_path}] not a ckpt file.")
        checkpoint_param = load_checkpoint(checkpoint_path)
        load_param_into_net(model, checkpoint_param)
        logger.info(f"Load checkpoint from [{checkpoint_path}] success.")

    return model

yolov3_head

yolov4_head

yolov5_head

yolov7_head

yolov8_head

yolox_head

initializer

focal_loss

iou_loss

loss_factory

yolov3_loss

yolov4_loss

yolov5_loss

yolov7_loss

yolov8_loss

yolox_loss

yolov3

yolov4

yolov5

yolov7

yolov8

yolox