Common Layers in Models

Activation

mindcv.models.layers.activation.Swish

Bases: Cell

Swish activation function: x * sigmoid(x).

Returns

Tensor

Example

>>> x = Tensor(((20, 16), (50, 50)), mindspore.float32)
>>> Swish()(x)

Source code in mindcv/models/layers/activation.py
class Swish(nn.Cell):
    """
    Swish activation function: x * sigmoid(x).

    Args:
        None

    Return:
        Tensor

    Example:
        >>> x = Tensor(((20, 16), (50, 50)), mindspore.float32)
        >>> Swish()(x)
    """

    def __init__(self):
        super().__init__()
        self.result = None
        self.sigmoid = nn.Sigmoid()

    def construct(self, x):
        result = x * self.sigmoid(x)
        return result
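
A minimal usage sketch (the import path follows the module path shown above; the values in the comment are approximate):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindcv.models.layers.activation import Swish
>>> x = Tensor(np.array([[-1.0, 0.0, 1.0]]), mindspore.float32)
>>> Swish()(x)  # x * sigmoid(x), approximately [[-0.269, 0.0, 0.731]]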

DropPath

mindcv.models.layers.drop_path.DropPath

Bases: Cell

DropPath (Stochastic Depth) regularization layers

Source code in mindcv/models/layers/drop_path.py
class DropPath(nn.Cell):
    """DropPath (Stochastic Depth) regularization layers"""

    def __init__(
        self,
        drop_prob: float = 0.0,
        scale_by_keep: bool = True,
    ) -> None:
        super().__init__()
        self.keep_prob = 1.0 - drop_prob
        self.scale_by_keep = scale_by_keep
        self.dropout = Dropout(p=drop_prob)

    def construct(self, x: Tensor) -> Tensor:
        if self.keep_prob == 1.0 or not self.training:
            return x
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = self.dropout(ones(shape))
        if not self.scale_by_keep:
            random_tensor = ops.mul(random_tensor, self.keep_prob)
        return x * random_tensor
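
A hedged usage sketch: DropPath is normally wrapped around a residual branch, zeroing the whole branch per sample during training and acting as an identity in inference mode. The shapes and drop probability below are illustrative.

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.drop_path import DropPath
>>> drop_path = DropPath(drop_prob=0.2)
>>> x = ops.ones((4, 64, 14, 14), mindspore.float32)
>>> drop_path.set_train(False)  # inference: returns x unchanged
>>> drop_path(x).shape
(4, 64, 14, 14)
>>> drop_path.set_train(True)   # training: each sample's branch is kept with prob 0.8 and rescaled by 1/0.8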

Identity

mindcv.models.layers.identity.Identity

Bases: Cell

Identity

Source code in mindcv/models/layers/identity.py
class Identity(nn.Cell):
    """Identity"""

    def construct(self, x):
        return x
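
Identity is a pass-through Cell, typically used as a placeholder, for example to disable an optional classifier head or normalization layer. A tiny sketch:

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.identity import Identity
>>> x = ops.ones((2, 3), mindspore.float32)
>>> Identity()(x).shape
(2, 3)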

MLP

mindcv.models.layers.mlp.Mlp

Bases: Cell

Source code in mindcv/models/layers/mlp.py
class Mlp(nn.Cell):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[nn.Cell] = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Dense(in_channels=in_features, out_channels=hidden_features, has_bias=True)
        self.act = act_layer()
        self.fc2 = nn.Dense(in_channels=hidden_features, out_channels=out_features, has_bias=True)
        self.drop = Dropout(p=drop)

    def construct(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
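
A minimal construction sketch (the feature sizes are arbitrary; hidden_features and out_features fall back to in_features when omitted):

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.mlp import Mlp
>>> mlp = Mlp(in_features=64, hidden_features=128, drop=0.1)
>>> x = ops.ones((2, 64), mindspore.float32)
>>> mlp(x).shape  # fc1 -> GELU -> dropout -> fc2 -> dropout
(2, 64)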

Patch Embedding

mindcv.models.layers.patch_embed.PatchEmbed

Bases: Cell

Image to Patch Embedding

PARAMETERS

image_size (int): Image size. Default: 224.

patch_size (int): Patch token size. Default: 4.

in_chans (int): Number of input image channels. Default: 3.

embed_dim (int): Number of linear projection output channels. Default: 96.

norm_layer (Cell, optional): Normalization layer. Default: None.

Source code in mindcv/models/layers/patch_embed.py
class PatchEmbed(nn.Cell):
    """Image to Patch Embedding

    Args:
        image_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Cell, optional): Normalization layer. Default: None
    """
    output_fmt: Format

    def __init__(
        self,
        image_size: Optional[int] = 224,
        patch_size: int = 4,
        in_chans: int = 3,
        embed_dim: int = 96,
        norm_layer: Optional[nn.Cell] = None,
        flatten: bool = True,
        output_fmt: Optional[str] = None,
        bias: bool = True,
        strict_img_size: bool = True,
        dynamic_img_pad: bool = False,
    ) -> None:
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        if image_size is not None:
            self.image_size = to_2tuple(image_size)
            self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
            self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        else:
            self.image_size = None
            self.patches_resolution = None
            self.num_patches = None

        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            self.flatten = flatten
            self.output_fmt = Format.NCHW

        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size,
                              pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal")

        if norm_layer is not None:
            if isinstance(embed_dim, int):
                embed_dim = (embed_dim,)
            self.norm = norm_layer(embed_dim, epsilon=1e-5)
        else:
            self.norm = None

    def construct(self, x: Tensor) -> Tensor:
        """docstring"""
        B, C, H, W = x.shape
        if self.image_size is not None:
            if self.strict_img_size:
                if (H, W) != (self.image_size[0], self.image_size[1]):
                    raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]},"
                                     f"{self.image_size[1]}).")
            elif not self.dynamic_img_pad:
                if H % self.patch_size[0] != 0:
                    raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).")
                if W % self.patch_size[1] != 0:
                    raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).")
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = ops.pad(x, (0, pad_w, 0, pad_h))

        # FIXME look at relaxing size constraints
        x = self.proj(x)
        if self.flatten:
            x = ops.Reshape()(x, (B, self.embed_dim, -1))  # B Ph*Pw C
            x = ops.Transpose()(x, (0, 2, 1))
        elif self.output_fmt != "NCHW":
            x = nchw_to(x, self.output_fmt)
        if self.norm is not None:
            x = self.norm(x)
        return x
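
A shape-oriented sketch with the default configuration (224x224 input, 4x4 patches, 96-dim embedding), which yields 56 * 56 = 3136 patch tokens:

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.patch_embed import PatchEmbed
>>> patch_embed = PatchEmbed(image_size=224, patch_size=4, in_chans=3, embed_dim=96)
>>> x = ops.ones((1, 3, 224, 224), mindspore.float32)
>>> patch_embed(x).shape  # (batch, num_patches, embed_dim)
(1, 3136, 96)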

mindcv.models.layers.patch_embed.PatchEmbed.construct(x)

Compute patch embeddings for the input image.

Source code in mindcv/models/layers/patch_embed.py
def construct(self, x: Tensor) -> Tensor:
    """docstring"""
    B, C, H, W = x.shape
    if self.image_size is not None:
        if self.strict_img_size:
            if (H, W) != (self.image_size[0], self.image_size[1]):
                raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]},"
                                 f"{self.image_size[1]}).")
        elif not self.dynamic_img_pad:
            if H % self.patch_size[0] != 0:
                raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).")
            if W % self.patch_size[1] != 0:
                raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).")
    if self.dynamic_img_pad:
        pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
        pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
        x = ops.pad(x, (0, pad_w, 0, pad_h))

    # FIXME look at relaxing size constraints
    x = self.proj(x)
    if self.flatten:
        x = ops.Reshape()(x, (B, self.embed_dim, -1))  # B Ph*Pw C
        x = ops.Transpose()(x, (0, 2, 1))
    elif self.output_fmt != "NCHW":
        x = nchw_to(x, self.output_fmt)
    if self.norm is not None:
        x = self.norm(x)
    return x

Pooling

mindcv.models.layers.pooling.GlobalAvgPooling

Bases: Cell

GlobalAvgPooling, same as torch.nn.AdaptiveAvgPool2d when output shape is 1

Source code in mindcv/models/layers/pooling.py
class GlobalAvgPooling(nn.Cell):
    """
    GlobalAvgPooling, same as torch.nn.AdaptiveAvgPool2d when output shape is 1
    """

    def __init__(self, keep_dims: bool = False) -> None:
        super().__init__()
        self.keep_dims = keep_dims

    def construct(self, x):
        x = ops.mean(x, axis=(2, 3), keep_dims=self.keep_dims)
        return x
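
A small sketch showing the effect of keep_dims on the pooled output shape:

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.pooling import GlobalAvgPooling
>>> x = ops.ones((2, 64, 7, 7), mindspore.float32)
>>> GlobalAvgPooling(keep_dims=False)(x).shape
(2, 64)
>>> GlobalAvgPooling(keep_dims=True)(x).shape
(2, 64, 1, 1)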

Selective Kernel

mindcv.models.layers.selective_kernel.SelectiveKernelAttn

Bases: Cell

Selective Kernel Attention Module: the Selective Kernel attention mechanism factored out into its own module.

Source code in mindcv/models/layers/selective_kernel.py
class SelectiveKernelAttn(nn.Cell):
    """Selective Kernel Attention Module
    Selective Kernel attention mechanism factored out into its own module.
    """

    def __init__(
        self,
        channels: int,
        num_paths: int = 2,
        attn_channels: int = 32,
        activation: Optional[nn.Cell] = nn.ReLU,
        norm: Optional[nn.Cell] = nn.BatchNorm2d,
    ):
        super().__init__()
        self.num_paths = num_paths
        self.mean = GlobalAvgPooling(keep_dims=True)
        self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, has_bias=False)
        self.bn = norm(attn_channels)
        self.act = activation()
        self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1)
        self.softmax = nn.Softmax(axis=1)

    def construct(self, x: Tensor) -> Tensor:
        x = self.mean((x.sum(1)))
        x = self.fc_reduce(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.fc_select(x)
        b, c, h, w = x.shape
        x = x.reshape((b, self.num_paths, c // self.num_paths, h, w))
        x = self.softmax(x)
        return x
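
A hedged shape sketch (illustrative sizes): the module expects the branch outputs stacked along a path axis, i.e. an input of shape (B, num_paths, C, H, W), and returns softmax attention weights over the paths that broadcast against that input:

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.selective_kernel import SelectiveKernelAttn
>>> attn = SelectiveKernelAttn(channels=64, num_paths=2, attn_channels=32)
>>> attn.set_train(False)  # use BatchNorm running statistics
>>> paths = ops.ones((1, 2, 64, 8, 8), mindspore.float32)
>>> attn(paths).shape
(1, 2, 64, 1, 1)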

mindcv.models.layers.selective_kernel.SelectiveKernel

Bases: Cell

Selective Kernel Convolution Module

As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. The largest change is the input split, which divides the input channels across each convolution path; this can be viewed as a grouping of sorts, but the output channel counts expand to the module-level value. This keeps the parameter count from ballooning when the convolutions themselves don't have groups, but still provides a noteworthy increase in performance over similar-param-count models without this attention layer. -Ross W

Args:
    in_channels (int): module input (feature) channel count
    out_channels (int): module output (feature) channel count
    kernel_size (int, list): kernel size for each convolution branch
    stride (int): stride for convolutions
    dilation (int): dilation for the module as a whole, impacts dilation of each branch
    groups (int): number of groups for each branch
    rd_ratio (int, float): reduction factor for attention features
    rd_channels (int): reduction channels can be specified directly by arg (if rd_channels is set)
    rd_divisor (int): divisor used to keep the reduction channel count divisible
    keep_3x3 (bool): keep all branch convolution kernels as 3x3, using dilation in place of larger kernels
    split_input (bool): split input channels evenly across each convolution branch; keeps param count lower,
        can be viewed as grouping by path, output expands to module out_channels count
    activation (nn.Cell): activation layer to use
    norm (nn.Cell): batchnorm/norm layer to use

Source code in mindcv/models/layers/selective_kernel.py
class SelectiveKernel(nn.Cell):
    """Selective Kernel Convolution Module
    As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
    Largest change is the input split, which divides the input channels across each convolution path, this can
    be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps
    the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
    a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
    Args:
        in_channels (int):  module input (feature) channel count
        out_channels (int):  module output (feature) channel count
        kernel_size (int, list): kernel size for each convolution branch
        stride (int): stride for convolutions
        dilation (int): dilation for module as a whole, impacts dilation of each branch
        groups (int): number of groups for each branch
        rd_ratio (int, float): reduction factor for attention features
        rd_channels(int): reduction channels can be specified directly by arg (if rd_channels is set)
        rd_divisor(int): divisor can be specified to keep channels
        keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
        split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
            can be viewed as grouping by path, output expands to module out_channels count
        activation (nn.Module): activation layer to use
        norm (nn.Module): batchnorm/norm layer to use
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        kernel_size: Optional[Union[int, List]] = None,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        rd_ratio: float = 1.0 / 16,
        rd_channels: Optional[int] = None,
        rd_divisor: int = 8,
        keep_3x3: bool = True,
        split_input: bool = True,
        activation: Optional[nn.Cell] = nn.ReLU,
        norm: Optional[nn.Cell] = nn.BatchNorm2d,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        kernel_size = kernel_size or [3, 5]  # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
        _kernel_valid(kernel_size)
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size] * 2
        if keep_3x3:
            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
            kernel_size = [3] * len(kernel_size)
        else:
            dilation = [dilation] * len(kernel_size)
        self.num_paths = len(kernel_size)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.split_input = split_input
        if self.split_input:
            assert in_channels % self.num_paths == 0
            in_channels = in_channels // self.num_paths
        groups = min(out_channels, groups)
        self.split = Split(split_size_or_sections=self.in_channels // self.num_paths, output_num=self.num_paths, axis=1)

        self.paths = nn.CellList([
            Conv2dNormActivation(in_channels, out_channels, kernel_size=k, stride=stride, groups=groups,
                                 dilation=d, activation=activation, norm=norm)
            for k, d in zip(kernel_size, dilation)
        ])

        attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)

    def construct(self, x: Tensor) -> Tensor:
        x_paths = []
        if self.split_input:
            x_split = self.split(x)
            for i, op in enumerate(self.paths):
                x_paths.append(op(x_split[i]))
        else:
            for op in self.paths:
                x_paths.append(op(x))

        x = ops.stack(x_paths, axis=1)
        x_attn = self.attn(x)
        x = x * x_attn
        x = x.sum(1)
        return x
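
A hedged end-to-end sketch using the defaults (two branches, split_input=True, so in_channels must be divisible by the number of paths):

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.selective_kernel import SelectiveKernel
>>> sk = SelectiveKernel(in_channels=64, out_channels=64)
>>> x = ops.ones((1, 64, 32, 32), mindspore.float32)
>>> sk(x).shape  # per-path features are reweighted by the attention and summed
(1, 64, 32, 32)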

Squeeze and Excite

mindcv.models.layers.squeeze_excite.SqueezeExcite

Bases: Cell

SqueezeExcite Module as defined in original SE-Nets with a few additions. Additions include:

* divisor can be specified to keep channels % div == 0 (default: 8)
* reduction channels can be specified directly by arg (if rd_channels is set)
* reduction channels can be specified by float rd_ratio (default: 1/16)
* customizable activation, normalization, and gate layer

Source code in mindcv/models/layers/squeeze_excite.py
class SqueezeExcite(nn.Cell):
    """SqueezeExcite Module as defined in original SE-Nets with a few additions.
    Additions include:
        * divisor can be specified to keep channels % div == 0 (default: 8)
        * reduction channels can be specified directly by arg (if rd_channels is set)
        * reduction channels can be specified by float rd_ratio (default: 1/16)
        * customizable activation, normalization, and gate layer
    """

    def __init__(
        self,
        in_channels: int,
        rd_ratio: float = 1.0 / 16,
        rd_channels: Optional[int] = None,
        rd_divisor: int = 8,
        norm: Optional[nn.Cell] = None,
        act_layer: nn.Cell = nn.ReLU,
        gate_layer: nn.Cell = nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.norm = norm
        self.act = act_layer()
        self.gate = gate_layer()
        if not rd_channels:
            rd_channels = make_divisible(in_channels * rd_ratio, rd_divisor)

        self.conv_reduce = nn.Conv2d(
            in_channels=in_channels,
            out_channels=rd_channels,
            kernel_size=1,
            has_bias=True,
        )
        if self.norm:
            self.bn = nn.BatchNorm2d(rd_channels)
        self.conv_expand = nn.Conv2d(
            in_channels=rd_channels,
            out_channels=in_channels,
            kernel_size=1,
            has_bias=True,
        )
        self.pool = GlobalAvgPooling(keep_dims=True)

    def construct(self, x: Tensor) -> Tensor:
        x_se = self.pool(x)
        x_se = self.conv_reduce(x_se)
        if self.norm:
            x_se = self.bn(x_se)
        x_se = self.act(x_se)
        x_se = self.conv_expand(x_se)
        x_se = self.gate(x_se)
        x = x * x_se
        return x
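
A minimal sketch: the SE block rescales each channel by a gate computed from globally pooled features, so the output shape matches the input (sizes below are illustrative):

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.squeeze_excite import SqueezeExcite
>>> se = SqueezeExcite(in_channels=64)  # reduction channels derived from rd_ratio=1/16 and rd_divisor=8
>>> x = ops.ones((2, 64, 14, 14), mindspore.float32)
>>> se(x).shape
(2, 64, 14, 14)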

mindcv.models.layers.squeeze_excite.SqueezeExciteV2

Bases: Cell

SqueezeExcite Module as defined in the original SE-Nets, with a few additions. V1 replaces the fully-connected layers with 1x1 convolutions, whereas V2 implements them directly with nn.Dense.

Source code in mindcv/models/layers/squeeze_excite.py
class SqueezeExciteV2(nn.Cell):
    """SqueezeExcite Module as defined in original SE-Nets with a few additions.
    V1 uses 1x1conv to replace fc layers, and V2 uses nn.Dense to implement directly.
    """

    def __init__(
        self,
        in_channels: int,
        rd_ratio: float = 1.0 / 16,
        rd_channels: Optional[int] = None,
        rd_divisor: int = 8,
        norm: Optional[nn.Cell] = None,
        act_layer: nn.Cell = nn.ReLU,
        gate_layer: nn.Cell = nn.Sigmoid,
    ) -> None:
        super().__init__()
        self.norm = norm
        self.act = act_layer()
        self.gate = gate_layer()
        if not rd_channels:
            rd_channels = make_divisible(in_channels * rd_ratio, rd_divisor)

        self.conv_reduce = nn.Dense(
            in_channels=in_channels,
            out_channels=rd_channels,
            has_bias=True,
        )
        if self.norm:
            self.bn = nn.BatchNorm2d(rd_channels)
        self.conv_expand = nn.Dense(
            in_channels=rd_channels,
            out_channels=in_channels,
            has_bias=True,
        )
        self.pool = GlobalAvgPooling(keep_dims=False)

    def construct(self, x: Tensor) -> Tensor:
        x_se = self.pool(x)
        x_se = self.conv_reduce(x_se)
        if self.norm:
            x_se = self.bn(x_se)
        x_se = self.act(x_se)
        x_se = self.conv_expand(x_se)
        x_se = self.gate(x_se)
        x_se = ops.expand_dims(x_se, -1)
        x_se = ops.expand_dims(x_se, -1)
        x = x * x_se
        return x
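
The same sketch applies to the Dense-based variant; the gate is expanded back to 4-D before the channel-wise multiply, so the output shape again matches the input:

>>> import mindspore
>>> from mindspore import ops
>>> from mindcv.models.layers.squeeze_excite import SqueezeExciteV2
>>> se2 = SqueezeExciteV2(in_channels=64)
>>> x = ops.ones((2, 64, 14, 14), mindspore.float32)
>>> se2(x).shape
(2, 64, 14, 14)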