
Data

Auto Augmentation

mindcv.data.auto_augment.auto_augment_transform(configs, hparams)

Create an AutoAugment transform.

Args: configs: A string that defines the automatic augmentation configuration. It is composed of multiple parts separated by dashes ("-"). The first part defines the AutoAugment policy ('autoaug', 'autoaugr' or '3a'): 'autoaug' is the original AutoAugment policy with PosterizeOriginal, 'autoaugr' is the AutoAugment policy with the PosterizeIncreasing operation, and '3a' is AutoAugment with only 3 augmentations. There is no order requirement for the remaining config parts.

    - mstd: Float standard deviation of applied magnitude noise.

    Example: 'autoaug-mstd0.5' applies AutoAugment with the autoaug policy
    and magnitude_std 0.5.
hparams: Other hparams of the automatic augmentation scheme.
Source code in mindcv/data/auto_augment.py
def auto_augment_transform(configs, hparams):
    """
    Create an AutoAugment transform
    Args:
        configs: A string that defines the automatic augmentation configuration.
            It is composed of multiple parts separated by dashes ("-"). The first part defines
            the AutoAugment policy ('autoaug', 'autoaugr' or '3a'):
            'autoaug' for the original AutoAugment policy with PosterizeOriginal,
            'autoaugr' for the AutoAugment policy with the PosterizeIncreasing operation,
            '3a' for AutoAugment with only 3 augmentations.
            There is no order requirement for the remaining config parts.

            - mstd: Float standard deviation of applied magnitude noise.

            Example: 'autoaug-mstd0.5' will apply AutoAugment with the autoaug policy
            and magnitude_std 0.5.
        hparams: Other hparams of the automatic augmentation scheme.
    """
    config = configs.split("-")
    policy_name = config[0]
    config = config[1:]
    hparams.setdefault("magnitude_std", 0.5)  # default magnitude_std is set to 0.5
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "mstd":
            # noise param injected via hparams for now
            hparams.setdefault("magnitude_std", float(val))
        else:
            assert False, "Unknown AutoAugment config section"
    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
    return AutoAugment(aa_policy)
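
A minimal usage sketch (not part of the library docs): it builds an AutoAugment transform from a config string and applies it to a PIL image, assuming the returned object is callable on a PIL image as in the timm implementation this module is adapted from. The sample file path and the img_mean hparams entry are hypothetical; which hparams keys are needed depends on the ops used by the chosen policy.

from PIL import Image
from mindcv.data.auto_augment import auto_augment_transform

# 'autoaug' policy with magnitude noise std 0.5; magnitude_std is injected
# into hparams via setdefault, so the dict is shared with the created ops.
hparams = {"img_mean": (124, 116, 104)}  # hypothetical fill color for geometric ops
aa_transform = auto_augment_transform("autoaug-mstd0.5", hparams)

img = Image.open("sample.jpg").convert("RGB")  # hypothetical input image
img_aug = aa_transform(img)  # augmented PIL image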

mindcv.data.auto_augment.rand_augment_transform(configs, hparams)

Create a RandAugment transform.

Args: configs: A string that defines the random augmentation configuration. It is composed of multiple parts separated by dashes ("-"). The first part defines the AutoAugment policy ('randaug' policy). There is no order requirement for the remaining config parts.

    - m: Integer magnitude of rand augment. Default: 10
    - n: Integer num layer (number of transform operations selected for each image). Default: 2
    - w: Integer probability weight index (the index that affects a group of weights selected by operations).
    - mstd: Float standard deviation of the applied magnitude noise; values greater than 100
        are treated as infinity, which enables uniform sampling over [0, magnitude].
    - mmax: Set the upper range limit for magnitude to a value
        other than the default value of _LEVEL_DENOM (10).
    - inc: Integer (used as bool); if non-zero, use the transform set whose severity increases with magnitude (default: 0).

    Example: 'randaug-w0-n3-mstd0.5' applies RandAugment
        with weight index 0, num_layers 3, and magnitude_std 0.5.
hparams: Other hparams (kwargs) for the RandAugmentation scheme.
Source code in mindcv/data/auto_augment.py
def rand_augment_transform(configs, hparams):
    """
    Create a RandAugment transform
    Args:
        configs: A string that defines the random augmentation configuration.
            It is composed of multiple parts separated by dashes ("-").
            The first part defines the AutoAugment policy ('randaug' policy).
            There is no order requirement for the remaining config parts.

            - m: Integer magnitude of rand augment. Default: 10
            - n: Integer num layer (number of transform operations selected for each image). Default: 2
            - w: Integer probability weight index (the index that affects a group of weights selected by operations).
            - mstd: Float standard deviation of the applied magnitude noise;
                values greater than 100 are treated as infinity (uniform sampling over [0, magnitude]).
            - mmax: Set the upper range limit for magnitude to a value
                other than the default value of _LEVEL_DENOM (10).
            - inc: Integer (used as bool); if non-zero, use the transform set whose
                severity increases with magnitude (default: 0).

            Example: 'randaug-w0-n3-mstd0.5' applies RandAugment
                with weight index 0, num_layers 3, and magnitude_std 0.5.
        hparams: Other hparams (kwargs) for the RandAugmentation scheme.
    """
    magnitude = _LEVEL_DENOM  # default to _LEVEL_DENOM for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    hparams.setdefault("magnitude_std", 0.5)  # default magnitude_std is set to 0.5
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = configs.split("-")
    assert config[0] == "randaug"
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "mstd":
            # noise param / randomization of magnitude values
            mstd = float(val)
            if mstd > 100:
                # use uniform sampling in 0 to magnitude if mstd is > 100
                mstd = float("inf")
            hparams.setdefault("magnitude_std", mstd)
        elif key == "mmax":
            # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
            hparams.setdefault("magnitude_max", int(val))
        elif key == "inc":
            if bool(val):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == "m":
            magnitude = int(val)
        elif key == "n":
            num_layers = int(val)
        elif key == "w":
            weight_idx = int(val)
        else:
            assert False, "Unknown RandAugment config section"
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
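
A brief usage sketch illustrating the config parsing described above; the config values are arbitrary, and the resulting transform is assumed to be appended to the training image transform list fed to the data pipeline.

from mindcv.data.auto_augment import rand_augment_transform

# magnitude 9, 2 ops per image, increasing-severity transform set
hparams = {}  # magnitude_std defaults to 0.5 via setdefault
ra_transform = rand_augment_transform("randaug-m9-n2-inc1", hparams)
# ra_transform can then be placed into the training image transforms.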

mindcv.data.auto_augment.trivial_augment_wide_transform(configs, hparams)

Create a TrivialAugmentWide transform.

Args: configs: A string that defines the TrivialAugmentWide configuration. It is composed of multiple parts separated by dashes ("-"). The first part defines the AutoAugment name; it should be 'trivialaugwide'. The second (optional) part sets the maximum value of magnitude.

    - m: the final magnitude of an operation is uniformly sampled from [0, m]. Default: 31

    Example: 'trivialaugwide-m20' applies TrivialAugmentWide
    with magnitude uniformly sampled from [0, 20].
hparams: Other hparams (kwargs) for the TrivialAugment scheme.

Returns: A MindSpore-compatible Transform

Source code in mindcv/data/auto_augment.py
def trivial_augment_wide_transform(configs, hparams):
    """
    Create a TrivialAugmentWide transform
    Args:
        configs: A string that defines the TrivialAugmentWide configuration.
            It is composed of multiple parts separated by dashes ("-").
            The first part defines the AutoAugment name; it should be 'trivialaugwide'.
            The second (optional) part sets the maximum value of magnitude.

            - m: the final magnitude of an operation is uniformly sampled from [0, m]. Default: 31

            Example: 'trivialaugwide-m20' applies TrivialAugmentWide
            with magnitude uniformly sampled from [0, 20].
        hparams: Other hparams (kwargs) for the TrivialAugment scheme.
    Returns:
        A Mindspore compatible Transform
    """
    magnitude = 31
    transforms = _TRIVIALAUGMENT_WIDE_TRANSFORMS
    config = configs.split("-")
    assert config[0] == "trivialaugwide"
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "m":
            magnitude = int(val)
        else:
            assert False, "Unknown TrivialAugmentWide config section"
    if not hparams:
        hparams = dict()
    hparams["magnitude_max"] = magnitude
    hparams["magnitude_std"] = float("inf")  # default to uniform sampling
    hparams["trivialaugwide"] = True
    ta_ops = trivial_augment_wide_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
    return TrivialAugmentWide(ta_ops)
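
A short sketch of the config parsing described above; the transform is assumed to be used like the other auto-augmentation transforms, i.e. inserted into the training image transform list.

from mindcv.data.auto_augment import trivial_augment_wide_transform

# magnitude is uniformly sampled from [0, 20] for every applied op
taw_transform = trivial_augment_wide_transform("trivialaugwide-m20", hparams={})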

mindcv.data.auto_augment.augment_and_mix_transform(configs, hparams=None)

Create an AugMix transform

PARAMETER DESCRIPTION
configs

String defining the configuration of the AugMix augmentation. Consists of multiple sections separated by dashes ('-'). The first section defines the specific name of the augmentation; it should be 'augmix'. The remaining sections are not order-specific and determine:

    - m: integer magnitude (severity) of augmentation mix (default: 3)
    - w: integer width of augmentation chain (default: 3)
    - d: integer depth of augmentation chain (-1 is random [1, 3], default: -1)
    - a: integer or float, the parameter of the Beta distribution used to generate the mixing weight (default: 1.0)

Example: 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2.

TYPE: str

hparams

Other hparams (kwargs) for the Augmentation transforms

DEFAULT: None

RETURNS DESCRIPTION

A MindSpore-compatible Transform

Source code in mindcv/data/auto_augment.py
def augment_and_mix_transform(configs, hparams=None):
    """Create AugMix PyTorch transform

    Args:
        configs (str): String defining configuration of AugMix augmentation. Consists of multiple sections separated
            by dashes ('-'). The first section defines the specific name of the augmentation; it should be 'augmix'.
            The remaining sections are not order-specific and determine
                'm' - integer magnitude (severity) of augmentation mix (default: 3)
                'w' - integer width of augmentation chain (default: 3)
                'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
                'a' - integer or float, the parameter of the Beta distribution used to generate the mixing weight (default: 1.0)
            Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2

        hparams: Other hparams (kwargs) for the Augmentation transforms

    Returns:
         A Mindspore compatible Transform
    """
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.0
    config = configs.split("-")
    assert config[0] == "augmix"
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "m":
            magnitude = int(val)
        elif key == "w":
            width = int(val)
        elif key == "d":
            depth = int(val)
        elif key == "a":
            alpha = float(val)
        else:
            assert False, "Unknown AugMix config section"
    if not hparams:
        hparams = dict()
    hparams["magnitude_std"] = float("inf")  # default to uniform sampling (if not set via mstd arg)
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth)
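
A sketch showing how the config string maps onto the AugMix parameters; the specific values are arbitrary illustrations, not recommendations.

from mindcv.data.auto_augment import augment_and_mix_transform

# severity 5, chain width 4, chain depth 2, mixing-weight parameter 0.7
augmix_transform = augment_and_mix_transform("augmix-m5-w4-d2-a0.7", hparams=None)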

Dataset Factory

mindcv.data.dataset_factory.create_dataset(name='', root=None, split='train', shuffle=True, num_samples=None, num_shards=None, shard_id=None, num_parallel_workers=None, download=False, num_aug_repeats=0, **kwargs)

Creates dataset by name.

PARAMETER DESCRIPTION
name

dataset name like MNIST, CIFAR10, ImageNet, ''. '' means a customized dataset. Default: ''.

TYPE: str DEFAULT: ''

root

dataset root dir. Default: None.

TYPE: Optional[str] DEFAULT: None

split

data split: '' or split name string (train/val/test). If it is '', no split is used. Otherwise, it is a subfolder of the root dir, e.g., train, val, test. Default: 'train'.

TYPE: str DEFAULT: 'train'

shuffle

whether to shuffle the dataset. Default: True.

TYPE: bool DEFAULT: True

num_samples

Number of elements to sample (default=None, which means sample all elements).

TYPE: Optional[int] DEFAULT: None

num_shards

Number of shards that the dataset will be divided into (default=None). When this argument is specified, num_samples reflects the maximum number of samples per shard.

TYPE: Optional[int] DEFAULT: None

shard_id

The shard ID within num_shards (default=None). This argument can only be specified when num_shards is also specified.

TYPE: Optional[int] DEFAULT: None

num_parallel_workers

Number of workers to read the data (default=None, set in the config).

TYPE: Optional[int] DEFAULT: None

download

whether to download the dataset. Default: False

TYPE: bool DEFAULT: False

num_aug_repeats

Number of dataset repetitions for repeated augmentation. If 0 or 1, repeated augmentation is disabled. Otherwise, repeated augmentation is enabled and the common choice is 3. (Default: 0)

TYPE: int DEFAULT: 0

Note

For custom datasets and imagenet, the dataset dir should follow the structure like:
.dataset_name/
├── split1/
│  ├── class1/
│  │   ├── 000001.jpg
│  │   ├── 000002.jpg
│  │   └── ....
│  └── class2/
│      ├── 000001.jpg
│      ├── 000002.jpg
│      └── ....
└── split2/
   ├── class1/
   │   ├── 000001.jpg
   │   ├── 000002.jpg
   │   └── ....
   └── class2/
       ├── 000001.jpg
       ├── 000002.jpg
       └── ....

RETURNS DESCRIPTION

Dataset object

Source code in mindcv/data/dataset_factory.py
def create_dataset(
    name: str = "",
    root: Optional[str] = None,
    split: str = "train",
    shuffle: bool = True,
    num_samples: Optional[int] = None,
    num_shards: Optional[int] = None,
    shard_id: Optional[int] = None,
    num_parallel_workers: Optional[int] = None,
    download: bool = False,
    num_aug_repeats: int = 0,
    **kwargs,
):
    r"""Creates dataset by name.

    Args:
        name: dataset name like MNIST, CIFAR10, ImageNet, ''. '' means a customized dataset. Default: ''.
        root: dataset root dir. Default: None.
        split: data split: '' or split name string (train/val/test), if it is '', no split is used.
            Otherwise, it is a subfolder of root dir, e.g., train, val, test. Default: 'train'.
        shuffle: whether to shuffle the dataset. Default: True.
        num_samples: Number of elements to sample (default=None, which means sample all elements).
        num_shards: Number of shards that the dataset will be divided into (default=None).
            When this argument is specified, `num_samples` reflects the maximum number of samples per shard.
        shard_id: The shard ID within `num_shards` (default=None).
            This argument can only be specified when `num_shards` is also specified.
        num_parallel_workers: Number of workers to read the data (default=None, set in the config).
        download: whether to download the dataset. Default: False
        num_aug_repeats: Number of dataset repetitions for repeated augmentation.
            If 0 or 1, repeated augmentation is disabled.
            Otherwise, repeated augmentation is enabled and the common choice is 3. (Default: 0)

    Note:
        For custom datasets and imagenet, the dataset dir should follow the structure like:
        .dataset_name/
        ├── split1/
        │  ├── class1/
        │  │   ├── 000001.jpg
        │  │   ├── 000002.jpg
        │  │   └── ....
        │  └── class2/
        │      ├── 000001.jpg
        │      ├── 000002.jpg
        │      └── ....
        └── split2/
           ├── class1/
           │   ├── 000001.jpg
           │   ├── 000002.jpg
           │   └── ....
           └── class2/
               ├── 000001.jpg
               ├── 000002.jpg
               └── ....

    Returns:
        Dataset object
    """
    name = name.lower()
    if root is None:
        root = os.path.join(get_dataset_download_root(), name)

    assert (num_samples is None) or (num_aug_repeats == 0), "num_samples and num_aug_repeats can NOT be set together."

    # subset sampling
    if num_samples is not None and num_samples > 0:
        # TODO: rewrite ordered distributed sampler (subset sampling in distributed mode is not tested)
        if num_shards is not None and num_shards > 1:  # distributed
            _logger.info(f"number of shards: {num_shards}, number of samples: {num_samples}")
            sampler = DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        else:  # standalone
            if shuffle:
                sampler = ds.RandomSampler(replacement=False, num_samples=num_samples)
            else:
                sampler = ds.SequentialSampler(num_samples=num_samples)
        mindspore_kwargs = dict(
            shuffle=None,
            sampler=sampler,
            num_parallel_workers=num_parallel_workers,
            **kwargs,
        )
    else:
        sampler = None
        mindspore_kwargs = dict(
            shuffle=shuffle,
            sampler=sampler,
            num_shards=num_shards,
            shard_id=shard_id,
            num_parallel_workers=num_parallel_workers,
            **kwargs,
        )

    # sampler for repeated augmentation
    if num_aug_repeats > 0:
        dataset_size = get_dataset_size(name, root, split)
        _logger.info(
            f"Repeated augmentation is enabled, num_aug_repeats: {num_aug_repeats}, "
            f"original dataset size: {dataset_size}."
        )
        # since drop_remainder is usually True, we don't need to do rounding in sampling
        sampler = RepeatAugSampler(
            dataset_size,
            num_shards=num_shards,
            rank_id=shard_id,
            num_repeats=num_aug_repeats,
            selected_round=0,
            shuffle=shuffle,
        )
        mindspore_kwargs = dict(shuffle=None, sampler=sampler, num_shards=None, shard_id=None, **kwargs)

    # create dataset
    if name in _MINDSPORE_BASIC_DATASET:
        dataset_class = _MINDSPORE_BASIC_DATASET[name][0]
        dataset_download = _MINDSPORE_BASIC_DATASET[name][1]
        dataset_new_path = None
        if download:
            if shard_id is not None:
                root = os.path.join(root, f"dataset_{str(shard_id)}")
            dataset_download = dataset_download(root)
            dataset_download.download()
            dataset_new_path = dataset_download.path

        dataset = dataset_class(
            dataset_dir=dataset_new_path if dataset_new_path else root,
            usage=split,
            **mindspore_kwargs,
        )
        # address ms dataset num_classes empty issue
        if name == "mnist":
            dataset.num_classes = lambda: 10
        elif name == "cifar10":
            dataset.num_classes = lambda: 10
        elif name == "cifar100":
            dataset.num_classes = lambda: 100

    else:
        if name == "imagenet" and download:
            raise ValueError(
                "Imagenet dataset download is not supported. "
                "Please download imagenet from https://www.image-net.org/download.php, "
                "and parse the path of dateset directory via args.data_dir."
            )

        if os.path.isdir(root):
            root = os.path.join(root, split)
        dataset = ImageFolderDataset(dataset_dir=root, **mindspore_kwargs)
        """ Another implementation which a bit slower than ImageFolderDataset
            imagenet_dataset = ImageNetDataset(dataset_dir=root)
            sampler = RepeatAugSampler(len(imagenet_dataset), num_shards=num_shards, rank_id=shard_id,
                                       num_repeats=repeated_aug, selected_round=1, shuffle=shuffle)
            dataset = ds.GeneratorDataset(imagenet_dataset, column_names=imagenet_dataset.column_names, sampler=sampler)
        """
    return dataset
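
A usage sketch, assuming create_dataset is re-exported from mindcv.data (as in the package's top-level API) and an ImageNet-style folder layout as described in the Note above; the dataset root path is hypothetical.

from mindcv.data import create_dataset

# expects /path/to/imagenet/train/<class_name>/*.jpg
dataset = create_dataset(
    name="imagenet",
    root="/path/to/imagenet",  # hypothetical path
    split="train",
    shuffle=True,
    num_parallel_workers=8,
)
print(dataset.get_dataset_size())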

Sampler

mindcv.data.distributed_sampler.RepeatAugSampler

Sampler that restricts data loading to a subset of the dataset for distributed training, with repeated augmentation. It ensures that each augmented version of a sample will be visible to a different process.

This sampler was adapted from https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py

PARAMETER DESCRIPTION
dataset_size

dataset size.

num_shards

num devices.

DEFAULT: None

rank_id

device id.

DEFAULT: None

shuffle

whether to shuffle the sample indices. Default: True.

num_repeats

number of repeated instances per sample in repeated augmentation. Default: 3.

selected_round

round the total number of selected samples by this factor. Default: 256.

Source code in mindcv/data/distributed_sampler.py
class RepeatAugSampler:
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process.

    This sampler was adapted from https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py

    Args:
        dataset_size: dataset size.
        num_shards: num devices.
        rank_id: device id.
        shuffle(bool): True for using shuffle, False for not using.
        num_repeats(int): number of repeated instances per sample in repeated augmentation. Default: 3.
        selected_round(int): round the total number of selected samples by this factor. Default: 256.
    """

    def __init__(
        self,
        dataset_size,
        num_shards=None,
        rank_id=None,
        shuffle=True,
        num_repeats=3,
        selected_round=256,
    ):
        if num_shards is None:
            _logger.warning("num_shards is set to 1 in RepeatAugSampler since it is not passed in")
            num_shards = 1
        if rank_id is None:
            rank_id = 0

        # assert isinstance(num_repeats, int), f'num_repeats should be Type integer, but got {type(num_repeats)}'

        self.dataset_size = dataset_size
        self.num_shards = num_shards
        self.rank_id = rank_id
        self.shuffle = shuffle
        self.num_repeats = int(num_repeats)
        self.epoch = 0
        self.num_samples = int(math.ceil(self.dataset_size * num_repeats / self.num_shards))
        self.total_size = self.num_samples * self.num_shards
        # Determine the number of samples to select per epoch for each rank.
        if selected_round:
            self.num_selected_samples = int(
                math.floor(self.dataset_size // selected_round * selected_round / num_shards)
            )
        else:
            self.num_selected_samples = int(math.ceil(self.dataset_size / num_shards))

    def __iter__(self):
        # deterministically shuffle based on epoch
        # print('__iter__  generating new shuffled indices: ', self.epoch)
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            indices = indices.tolist()
            self.epoch += 1
            # print(indices[:30])
        else:
            indices = list(range(self.dataset_size))
        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
        indices = [ele for ele in indices for i in range(self.num_repeats)]

        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size > 0:
            indices += indices[:padding_size]
        assert len(indices) == self.total_size

        # subsample per rank
        indices = indices[self.rank_id : self.total_size : self.num_shards]
        assert len(indices) == self.num_samples

        # return up to num selected samples
        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
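
A small sketch that makes the index stream concrete, following the arithmetic in __iter__ above: with 8 samples, 2 shards, 3 repeats, shuffle disabled and selected_round=0, each rank draws ceil(8/2) = 4 indices, and repeated copies of the same sample land on different ranks.

from mindcv.data.distributed_sampler import RepeatAugSampler

sampler = RepeatAugSampler(
    dataset_size=8, num_shards=2, rank_id=0,
    shuffle=False, num_repeats=3, selected_round=0,
)
print(list(iter(sampler)))  # [0, 0, 1, 2] for rank 0; rank 1 would see [0, 1, 1, 2]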

DataLoader

mindcv.data.loader.create_loader(dataset, batch_size, drop_remainder=False, is_training=False, mixup=0.0, cutmix=0.0, cutmix_prob=0.0, num_classes=1000, transform=None, target_transform=None, num_parallel_workers=None, python_multiprocessing=False, separate=False)

Creates dataloader.

Applies operations such as transform and batch to the ms.dataset.Dataset object created by the create_dataset function to get the dataloader.

PARAMETER DESCRIPTION
dataset

dataset object created by create_dataset.

TYPE: Dataset

batch_size

The number of rows each batch is created with. An int or callable object which takes exactly 1 parameter, BatchInfo.

TYPE: int or function

drop_remainder

Determines whether to drop the last block whose data row number is less than batch size (default=False). If True, and if there are less than batch_size rows available to make the last batch, then those rows will be dropped and not propagated to the child node.

TYPE: bool DEFAULT: False

is_training

whether it is in train mode. Default: False.

TYPE: bool DEFAULT: False

mixup

mixup alpha, mixup will be enabled if > 0. (default=0.0).

TYPE: float DEFAULT: 0.0

cutmix

cutmix alpha, cutmix will be enabled if > 0. (default=0.0). This operation is experimental.

TYPE: float DEFAULT: 0.0

cutmix_prob

prob of doing cutmix for an image (default=0.0)

TYPE: float DEFAULT: 0.0

num_classes

the number of classes. Default: 1000.

TYPE: int DEFAULT: 1000

transform

the list of transformations that will be applied on the image, which is obtained by create_transforms. If None, the default imagenet transformation for evaluation will be applied. Default: None.

TYPE: list or None DEFAULT: None

target_transform

the list of transformations that will be applied on the label. If None, the label will be converted to the type of ms.int32. Default: None.

TYPE: list or None DEFAULT: None

num_parallel_workers

Number of workers(threads) to process the dataset in parallel (default=None).

TYPE: int DEFAULT: None

python_multiprocessing

Parallelize Python operations with multiple worker processes. This option could be beneficial if the Python operation is computationally heavy (default=False).

TYPE: bool DEFAULT: False

separate

whether to separate the clean image from the transformed image. If separate==True, the dataset returned has 3 parts: * the first part is the "clean" image, i.e. the image without auto_augment (e.g., auto-aug) * the second and third parts are the transformed images, i.e. with the auto_augment transform applied. Refer to ".transforms_factory.create_transforms" for more information.

TYPE: bool DEFAULT: False

Note
  1. cutmix is currently experimental (which means a performance gain is not guaranteed) and cannot be used together with mixup due to the label int type conflict.
  2. is_training, mixup and num_classes are used for MixUp, which is a kind of transform operation. However, we are not able to merge it into transform, due to the limitations of the mindspore.dataset API.
RETURNS DESCRIPTION

BatchDataset, dataset batched.

Source code in mindcv/data/loader.py
def create_loader(
    dataset,
    batch_size,
    drop_remainder=False,
    is_training=False,
    mixup=0.0,
    cutmix=0.0,
    cutmix_prob=0.0,
    num_classes=1000,
    transform=None,
    target_transform=None,
    num_parallel_workers=None,
    python_multiprocessing=False,
    separate=False,
):
    r"""Creates dataloader.

    Applies operations such as transform and batch to the `ms.dataset.Dataset` object
    created by the `create_dataset` function to get the dataloader.

    Args:
        dataset (ms.dataset.Dataset): dataset object created by `create_dataset`.
        batch_size (int or function): The number of rows each batch is created with. An
            int or callable object which takes exactly 1 parameter, BatchInfo.
        drop_remainder (bool, optional): Determines whether to drop the last block
            whose data row number is less than batch size (default=False). If True, and if there are less
            than batch_size rows available to make the last batch, then those rows will
            be dropped and not propagated to the child node.
        is_training (bool): whether it is in train mode. Default: False.
        mixup (float): mixup alpha, mixup will be enabled if > 0. (default=0.0).
        cutmix (float): cutmix alpha, cutmix will be enabled if > 0. (default=0.0). This operation is experimental.
        cutmix_prob (float): prob of doing cutmix for an image (default=0.0)
        num_classes (int): the number of classes. Default: 1000.
        transform (list or None): the list of transformations that will be applied on the image,
            which is obtained by `create_transforms`. If None, the default imagenet transformation
            for evaluation will be applied. Default: None.
        target_transform (list or None): the list of transformations that will be applied on the label.
            If None, the label will be converted to the type of ms.int32. Default: None.
        num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel
            (default=None).
        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
            option could be beneficial if the Python operation is computational heavy (default=False).
        separate(bool, optional): whether to separate the clean image from the transformed image.
            If separate==True, the dataset returned has 3 parts:
            * the first part is the "clean" image, i.e. the image without auto_augment (e.g., auto-aug)
            * the second and third parts are the transformed images, i.e. with the auto_augment transform applied.
            Refer to ".transforms_factory.create_transforms" for more information.

    Note:
        1. cutmix is currently experimental (which means a performance gain is not guaranteed)
            and cannot be used together with mixup due to the label int type conflict.
        2. `is_training`, `mixup` and `num_classes` are used for MixUp, which is a kind of transform operation.
            However, we are not able to merge it into `transform`, due to the limitations of the `mindspore.dataset` API.


    Returns:
        BatchDataset, dataset batched.
    """

    if target_transform is None:
        target_transform = transforms.TypeCast(ms.int32)
    target_input_columns = "label" if "label" in dataset.get_col_names() else "fine_label"
    dataset = dataset.map(
        operations=target_transform,
        input_columns=target_input_columns,
        num_parallel_workers=num_parallel_workers,
        python_multiprocessing=python_multiprocessing,
    )

    if transform is None:
        warnings.warn(
            "Using None as the default value of transform will set it back to "
            "traditional image transform, which is not recommended. "
            "You should explicitly call `create_transforms` and pass it to `create_loader`."
        )
        transform = create_transforms("imagenet", is_training=False)

    # only apply augment splits to train dataset
    if separate and is_training:
        assert isinstance(transform, tuple) and len(transform) == 3

        # Note: mindspore-2.0 delete the parameter column_order
        sig = inspect.signature(dataset.map)
        pass_column_order = False if "kwargs" in sig.parameters else True

        # map all the transform
        dataset = map_transform_splits(
            dataset, transform, num_parallel_workers, python_multiprocessing, pass_column_order
        )
        # after batch, datasets has 4 columns
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
        # concat the 3 columns of image
        dataset = dataset.map(
            operations=concat_per_batch_map,
            input_columns=["image_clean", "image_aug1", "image_aug2", "label"],
            output_columns=["image", "label"],
            column_order=["image", "label"] if pass_column_order else None,
            num_parallel_workers=num_parallel_workers,
            python_multiprocessing=python_multiprocessing,
        )

    else:
        dataset = dataset.map(
            operations=transform,
            input_columns="image",
            num_parallel_workers=num_parallel_workers,
            python_multiprocessing=python_multiprocessing,
        )

        dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)

    if is_training:
        if (mixup + cutmix > 0.0) and batch_size > 1:
            # TODO: use mindspore vision cutmix and mixup after the confliction fixed in later release
            # set label_smoothing 0 here since label smoothing is computed in loss module
            mixup_fn = Mixup(
                mixup_alpha=mixup,
                cutmix_alpha=cutmix,
                cutmix_minmax=None,
                prob=cutmix_prob,
                switch_prob=0.5,
                label_smoothing=0.0,
                num_classes=num_classes,
            )
            # images in a batch are mixed. labels are converted soft onehot labels.
            dataset = dataset.map(
                operations=mixup_fn,
                input_columns=["image", target_input_columns],
                num_parallel_workers=num_parallel_workers,
            )

    return dataset
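
An end-to-end sketch tying create_dataset, create_transforms and create_loader together; the dataset path, batch size and augmentation choices are illustrative only, and the shape comments assume the default ImageNet transforms.

from mindcv.data import create_dataset, create_transforms, create_loader

dataset = create_dataset(name="imagenet", root="/path/to/imagenet", split="train")  # hypothetical path
transform = create_transforms(
    dataset_name="imagenet", image_resize=224,
    is_training=True, auto_augment="randaug-m7",
)
loader = create_loader(
    dataset,
    batch_size=64,
    drop_remainder=True,
    is_training=True,
    mixup=0.2,        # MixUp alpha; mixing is enabled since mixup + cutmix > 0
    cutmix_prob=1.0,  # forwarded to Mixup's prob (chance of applying the mix per batch)
    num_classes=1000,
    transform=transform,
)
for images, labels in loader.create_tuple_iterator():
    # images: (64, 3, 224, 224); labels become soft (64, 1000) vectors once MixUp is enabled
    break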

MixUp

mindcv.data.mixup.Mixup

Mixup/Cutmix that applies different params to each element or whole batch

PARAMETER DESCRIPTION
mixup_alpha

mixup alpha value, mixup is active if > 0.

TYPE: float DEFAULT: 1.0

cutmix_alpha

cutmix alpha value, cutmix is active if > 0.

TYPE: float DEFAULT: 0.0

cutmix_minmax

cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.

TYPE: List[float] DEFAULT: None

prob

probability of applying mixup or cutmix per batch or element

TYPE: float DEFAULT: 1.0

switch_prob

probability of switching to cutmix instead of mixup when both are active

TYPE: float DEFAULT: 0.5

mode

how to apply mixup/cutmix params: per 'batch', 'pair' (pair of elements), or 'elem' (per element).

TYPE: str DEFAULT: 'batch'

correct_lam

apply lambda correction when the cutmix bbox is clipped by image borders

TYPE: bool DEFAULT: True

label_smoothing

apply label smoothing to the mixed target tensor

TYPE: float DEFAULT: 0.1

num_classes

number of classes for target

TYPE: int DEFAULT: 1000

Source code in mindcv/data/mixup.py
class Mixup:
    """Mixup/Cutmix that applies different params to each element or whole batch

    Args:
        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
        prob (float): probability of applying mixup or cutmix per batch or element
        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
        mode (str): how to apply mixup/cutmix params: per 'batch', 'pair' (pair of elements), or 'elem' (per element).
        correct_lam (bool): apply lambda correction when the cutmix bbox is clipped by image borders
        label_smoothing (float): apply label smoothing to the mixed target tensor
        num_classes (int): number of classes for target
    """

    def __init__(
        self,
        mixup_alpha=1.0,
        cutmix_alpha=0.0,
        cutmix_minmax=None,
        prob=1.0,
        switch_prob=0.5,
        mode="batch",
        correct_lam=True,
        label_smoothing=0.1,
        num_classes=1000,
    ):
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
            self.cutmix_alpha = 1.0
        self.mix_prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.mode = mode
        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set false to disable mixing (intended to be set by the train loop)

    def _params_per_elem(self, batch_size):
        """_params_per_elem"""
        lam = np.ones(batch_size, dtype=np.float32)
        use_cutmix = np.zeros(batch_size, dtype=bool)  # np.bool was removed in NumPy 1.24; builtin bool is equivalent
        if self.mixup_enabled:
            if self.mixup_alpha > 0.0 and self.cutmix_alpha > 0.0:
                use_cutmix = np.random.rand(batch_size) < self.switch_prob
                lam_mix = np.where(
                    use_cutmix,
                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size),
                )
            elif self.mixup_alpha > 0.0:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
            elif self.cutmix_alpha > 0.0:
                use_cutmix = np.ones(batch_size, dtype=bool)
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
        return lam, use_cutmix

    def _params_per_batch(self):
        """_params_per_batch"""
        lam = 1.0
        use_cutmix = False
        if self.mixup_enabled and np.random.rand() < self.mix_prob:
            if self.mixup_alpha > 0.0 and self.cutmix_alpha > 0.0:
                use_cutmix = np.random.rand() < self.switch_prob
                lam_mix = (
                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
                    if use_cutmix
                    else np.random.beta(self.mixup_alpha, self.mixup_alpha)
                )
            elif self.mixup_alpha > 0.0:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.cutmix_alpha > 0.0:
                use_cutmix = True
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam = float(lam_mix)
        return lam, use_cutmix

    def _mix_elem(self, x):
        """_mix_elem"""
        batch_size = len(x)
        lam_batch, use_cutmix = self._params_per_elem(batch_size)
        x_orig = x.clone()  # need to keep an unmodified original for mixing source
        for i in range(batch_size):
            j = batch_size - i - 1
            lam = lam_batch[i]
            if lam != 1.0:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam
                    )
                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
                    lam_batch[i] = lam
                else:
                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
        return P.ExpandDims()(Tensor(lam_batch, dtype=mstype.float32), 1)

    def _mix_pair(self, x):
        """_mix_pair"""
        batch_size = len(x)
        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
        x_orig = x.clone()  # need to keep an unmodified original for mixing source
        for i in range(batch_size // 2):
            j = batch_size - i - 1
            lam = lam_batch[i]
            if lam != 1.0:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam
                    )
                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
                    lam_batch[i] = lam
                else:
                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
        return P.ExpandDims()(Tensor(lam_batch, dtype=mstype.float32), 1)

    def _mix_batch(self, x):
        """_mix_batch"""
        lam, use_cutmix = self._params_per_batch()
        if lam == 1.0:
            return 1.0
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam
            )
            x[:, :, yl:yh, xl:xh] = np.flip(x, axis=0)[:, :, yl:yh, xl:xh]
        else:
            x_flipped = np.flip(x, axis=0) * (1.0 - lam)
            x *= lam
            x += x_flipped
        return lam

    def __call__(self, x, target):
        """Mixup apply"""
        # the same to image, label
        assert len(x) % 2 == 0, "Batch size should be even when using this"
        if self.mode == "elem":
            lam = self._mix_elem(x)
        elif self.mode == "pair":
            lam = self._mix_pair(x)
        else:
            lam = self._mix_batch(x)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
        return x.astype(np.float32), target.astype(np.float32)
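
A standalone sketch of how create_loader applies Mixup in 'batch' mode, assuming NumPy inputs in NCHW layout (which is what the mapped dataset columns provide); shapes, alphas and the class count are illustrative.

import numpy as np
from mindcv.data.mixup import Mixup

mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=1.0, switch_prob=0.5, mode="batch",
    label_smoothing=0.0, num_classes=10,
)

images = np.random.rand(4, 3, 224, 224).astype(np.float32)  # even batch size required
labels = np.array([0, 1, 2, 3])
mixed_images, soft_labels = mixup_fn(images, labels)
print(mixed_images.shape, soft_labels.shape)  # expected: (4, 3, 224, 224) (4, 10)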

Transform Factory

mindcv.data.transforms_factory.create_transforms(dataset_name='', image_resize=224, is_training=False, auto_augment=None, separate=False, **kwargs)

Creates a list of transform operation on image data.

PARAMETER DESCRIPTION
dataset_name

if '', a customized dataset is assumed, and the same transform pipeline as ImageNet is currently applied. If a standard dataset name is given, including imagenet, cifar10 and mnist, preset transforms will be returned. Default: ''.

TYPE: str DEFAULT: ''

image_resize

the image size after resize for adapting to network. Default: 224.

TYPE: int DEFAULT: 224

is_training

if True, augmentation will be applied if supported. Default: False.

TYPE: bool DEFAULT: False

auto_augment

augmentation strategies, such as "augmix", "autoaug" etc. Default: None.

TYPE: str DEFAULT: None

separate

separate the clean transform from the augmentation transform. If separate==True, the transforms are returned as a tuple of 3 separate transform lists for use in a mixing dataset that passes: * all data through the primary transform, called "clean" data * a portion of the data through the secondary transform (e.g., auto-aug) * both branches above through the third transform, which normalizes and converts them.

DEFAULT: False

**kwargs

additional args parsed to transforms_imagenet_train and transforms_imagenet_eval

DEFAULT: {}

RETURNS DESCRIPTION

A list of transformation operations

Source code in mindcv/data/transforms_factory.py
def create_transforms(
    dataset_name="",
    image_resize=224,
    is_training=False,
    auto_augment=None,
    separate=False,
    **kwargs,
):
    r"""Creates a list of transform operation on image data.

    Args:
        dataset_name (str): if '', a customized dataset is assumed, and the same transform pipeline as ImageNet
            is currently applied. If a standard dataset name is given, including imagenet, cifar10 and mnist,
            preset transforms will be returned. Default: ''.
        image_resize (int): the image size after resize for adapting to network. Default: 224.
        is_training (bool): if True, augmentation will be applied if supported. Default: False.
        auto_augment (str): augmentation strategies, such as "augmix", "autoaug" etc.
        separate: separate the clean transform from the augmentation transform. If separate==True, the transforms
            are returned as a tuple of 3 separate transform lists for use in a mixing dataset that passes:
            * all data through the primary transform, called "clean" data
            * a portion of the data through the secondary transform (e.g., auto-aug)
            * both branches above through the third transform, which normalizes and converts them
        **kwargs: additional args parsed to `transforms_imagenet_train` and `transforms_imagenet_eval`

    Returns:
        A list of transformation operations
    """

    dataset_name = dataset_name.lower()

    if dataset_name in ("imagenet", ""):
        trans_args = dict(image_resize=image_resize, **kwargs)
        if is_training:
            return transforms_imagenet_train(auto_augment=auto_augment, separate=separate, **trans_args)

        return transforms_imagenet_eval(**trans_args)
    elif dataset_name in ("cifar10", "cifar100"):
        trans_list = transforms_cifar(resize=image_resize, is_training=is_training)
        return trans_list
    elif dataset_name == "mnist":
        trans_list = transforms_mnist(resize=image_resize)
        return trans_list
    else:
        raise NotImplementedError(
            f"Only supports creating transforms for ['imagenet'] datasets, but got {dataset_name}."
        )
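
A short sketch of the two common calls; the comments describe the typical ImageNet pipeline (decode, random crop/flip or resize/center crop, normalize, HWC-to-CHW) and are descriptive rather than an exhaustive listing of the returned ops.

from mindcv.data import create_transforms

# training transforms with AutoAugment (random crop/flip + auto-aug + normalize + HWC->CHW)
train_trans = create_transforms(
    dataset_name="imagenet", image_resize=224,
    is_training=True, auto_augment="autoaug-mstd0.5",
)

# evaluation transforms (resize + center crop + normalize + HWC->CHW)
eval_trans = create_transforms(dataset_name="imagenet", image_resize=224, is_training=False)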