
Models

Create Model

mindcv.models.model_factory.create_model(model_name, num_classes=1000, pretrained=False, in_channels=3, checkpoint_path='', ema=False, auto_mapping=False, **kwargs)

Creates a model by name.
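Name-based creation relies on a registry that maps each model name to the function that builds it. The source below calls `is_model` and `model_entrypoint` for this lookup. The following is a minimal, self-contained sketch of that registry pattern. The `_MODEL_REGISTRY` dict, the `register_model` decorator, the toy `toy_net` factory and the `create_by_name` helper are hypothetical illustrations. The bodies of `is_model` and `model_entrypoint` here are assumptions and are not mindcv's implementation.

```python
from typing import Callable, Dict

# Hypothetical registry: maps a model name to the function that builds it.
_MODEL_REGISTRY: Dict[str, Callable[..., dict]] = {}


def register_model(fn: Callable[..., dict]) -> Callable[..., dict]:
    """Register a factory under its function name (illustrative decorator)."""
    _MODEL_REGISTRY[fn.__name__] = fn
    return fn


def is_model(model_name: str) -> bool:
    """Return True if a factory is registered under model_name."""
    return model_name in _MODEL_REGISTRY


def model_entrypoint(model_name: str) -> Callable[..., dict]:
    """Return the factory registered under model_name."""
    return _MODEL_REGISTRY[model_name]


@register_model
def toy_net(num_classes: int = 1000, pretrained: bool = False, in_channels: int = 3, **kwargs) -> dict:
    # Stand-in for a real network: it only records how it was configured.
    return {"arch": "toy_net", "num_classes": num_classes, "in_channels": in_channels, **kwargs}


def create_by_name(model_name: str, **kwargs) -> dict:
    """Look up a factory by name and call it (hypothetical helper)."""
    if not is_model(model_name):
        raise RuntimeError(f"Unknown model {model_name}")
    return model_entrypoint(model_name)(**kwargs)


model = create_by_name("toy_net", num_classes=10)
print(model)  # {'arch': 'toy_net', 'num_classes': 10, 'in_channels': 3}
```

With this design, a new architecture becomes available by name as soon as its module is imported and its factory is decorated. The creation function itself never has to be edited.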

PARAMETER DESCRIPTION
model_name

The name of the model. An unknown name raises a RuntimeError.

TYPE: str

num_classes

The number of classes. Default: 1000.

TYPE: int DEFAULT: 1000

pretrained

Whether to load pretrained weights. Cannot be combined with checkpoint_path. Default: False.

TYPE: bool DEFAULT: False

in_channels

The number of input channels. Default: 3.

TYPE: int DEFAULT: 3

checkpoint_path

The path to a checkpoint file to load into the model. Cannot be combined with pretrained. Default: "" (no checkpoint).

TYPE: str DEFAULT: ''

ema

Whether to load the EMA (exponential moving average) weights from the checkpoint. Only used when checkpoint_path is set. Default: False.

TYPE: bool DEFAULT: False

auto_mapping

Whether to automatically map the names of checkpoint weights to the names of model weights when there are differences in names. Default: False.

TYPE: bool DEFAULT: False

**kwargs

Additional keyword arguments passed to the model's creation function, e.g., "features_only", "out_indices". Arguments whose value is None are dropped.

DEFAULT: {}
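Before calling the model's creation function, create_model combines num_classes, pretrained and in_channels with the extra keyword arguments. It discards any extra argument whose value is None. The following minimal sketch reproduces those two lines from the source in isolation. The helper name `build_factory_args` is hypothetical and exists only for illustration.

```python
from typing import Any, Dict


def build_factory_args(num_classes: int = 1000, pretrained: bool = False,
                       in_channels: int = 3, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper mirroring how create_model assembles factory arguments."""
    model_args = dict(num_classes=num_classes, pretrained=pretrained, in_channels=in_channels)
    # Extra options explicitly set to None are discarded, so the factory's own defaults apply.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return {**model_args, **kwargs}


args = build_factory_args(num_classes=10, features_only=True, out_indices=None)
print(args)  # {'num_classes': 10, 'pretrained': False, 'in_channels': 3, 'features_only': True}
```

Presumably this lets a caller, such as a config-driven script, forward every option unconditionally (for example out_indices=None when unset) without overriding the creation function's defaults.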

Source code in mindcv/models/model_factory.py
from .helpers import load_model_checkpoint
from .registry import is_model, model_entrypoint


def create_model(
    model_name: str,
    num_classes: int = 1000,
    pretrained: bool = False,
    in_channels: int = 3,
    checkpoint_path: str = "",
    ema: bool = False,
    auto_mapping: bool = False,
    **kwargs,
):
    r"""Creates a model by name.

    Args:
        model_name (str): The name of the model.
        num_classes (int): The number of classes. Default: 1000.
        pretrained (bool): Whether to load pretrained weights. Default: False.
        in_channels (int): The number of input channels. Default: 3.
        checkpoint_path (str): The path to a checkpoint file to load. Default: "".
        ema (bool): Whether to load the EMA weights from the checkpoint. Default: False.
        auto_mapping (bool): Whether to automatically map the names of checkpoint weights
            to the names of model weights when there are differences in names. Default: False.
        **kwargs: Additional keyword arguments passed to the model's creation function,
            e.g., "features_only", "out_indices".
    """

    # Pretrained weights and a user checkpoint are alternative weight sources.
    if checkpoint_path != "" and pretrained:
        raise ValueError("checkpoint_path is mutually exclusive with pretrained")

    model_args = dict(num_classes=num_classes, pretrained=pretrained, in_channels=in_channels)
    # Drop extra arguments explicitly set to None so the creation function's defaults apply.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    # Look up the creation function registered under model_name.
    if not is_model(model_name):
        raise RuntimeError(f"Unknown model {model_name}")

    create_fn = model_entrypoint(model_name)
    model = create_fn(**model_args, **kwargs)

    # Load user-supplied weights, if a checkpoint path was given.
    if checkpoint_path:
        load_model_checkpoint(model, checkpoint_path, ema, auto_mapping)

    return model
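When checkpoint_path is set, create_model delegates to load_model_checkpoint(model, checkpoint_path, ema, auto_mapping), whose body is not shown on this page. The following sketch illustrates, on plain dictionaries, what the ema and auto_mapping flags could do. It makes two assumptions, both for illustration only. First, EMA copies of the weights are stored under an "ema." name prefix. Second, auto-mapping pairs each model parameter that has no exact match with the most similar checkpoint name, using difflib.get_close_matches. `select_ema_params` and `auto_map_names` are hypothetical helpers, not mindcv functions.

```python
import difflib
from typing import Dict, List

EMA_PREFIX = "ema."  # assumed naming convention for EMA copies of the weights


def select_ema_params(ckpt: Dict[str, float]) -> Dict[str, float]:
    """Return the EMA entries with the assumed 'ema.' prefix stripped."""
    ema = {k[len(EMA_PREFIX):]: v for k, v in ckpt.items() if k.startswith(EMA_PREFIX)}
    if not ema:
        raise ValueError("checkpoint has no EMA parameters; set ema=False")
    return ema


def auto_map_names(model_names: List[str], ckpt: Dict[str, float]) -> Dict[str, float]:
    """Rename checkpoint entries to the closest-matching model parameter names."""
    # Exact name matches are taken first.
    mapped = {n: ckpt[n] for n in model_names if n in ckpt}
    unused = [k for k in ckpt if k not in mapped]
    for name in model_names:
        if name in mapped:
            continue
        # Fuzzy match against checkpoint entries not yet assigned.
        match = difflib.get_close_matches(name, unused, n=1, cutoff=0.6)
        if not match:
            raise ValueError(f"no checkpoint entry resembles {name!r}")
        mapped[name] = ckpt[match[0]]
        unused.remove(match[0])  # each checkpoint entry is used at most once
    return mapped


ckpt = {"ema.conv1.weight": 0.5, "ema.fc.bias": 0.1, "conv1.weight": 0.4, "fc.bias": 0.0}
print(select_ema_params(ckpt))  # {'conv1.weight': 0.5, 'fc.bias': 0.1}

old_ckpt = {"backbone.conv1.weight": 0.4, "head.fc.bias": 0.0}
print(auto_map_names(["conv1.weight", "fc.bias"], old_ckpt))  # {'conv1.weight': 0.4, 'fc.bias': 0.0}
```

Fuzzy matching is convenient when a checkpoint was saved with a wrapper prefix such as "backbone.". However, it can silently pair the wrong tensors when names are very similar, which is presumably why auto_mapping defaults to False.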

bit

cait

cmt

coat

convit

convnext

crossvit

densenet

dpn

edgenext

efficientnet

features

ghostnet

halonet

hrnet

inceptionv3

inceptionv4

mae

mixnet

mlpmixer

mnasnet

mobilenetv1

mobilenetv2

mobilenetv3

mobilevit

nasnet

pit

poolformer

pvt

pvtv2

regnet

repmlp

repvgg

res2net

resnest

resnet

resnetv2

rexnet

senet

shufflenetv1

shufflenetv2

sknet

squeezenet

swintransformer

swintransformerv2

vgg

visformer

vit

volo

xcit