Fine-tune a pretrained model¶
There are significant benefits to using a pretrained model. It reduces computation costs and your carbon footprint, and it allows you to use state-of-the-art models without having to train one from scratch. 🤗 Transformers provides access to thousands of pretrained models for a wide range of tasks. When you use a pretrained model, you train it on a dataset specific to your task. This is known as fine-tuning, an incredibly powerful training technique. In this tutorial, you will fine-tune a pretrained model with the approach of your choice:
- Fine-tune a pretrained model with MindSpore Trainer.
- Fine-tune a pretrained model in native MindSpore.
Prepare a dataset¶
Before you can fine-tune a pretrained model, download a dataset and prepare it for training. The previous tutorial showed you how to process data for training, and now you get an opportunity to put those skills to the test!
Begin by loading the Yelp Reviews dataset:
>>> from datasets import load_dataset
>>> dataset = load_dataset("yelp_review_full")
>>> dataset["train"][100]
{'label': 0,
'text': 'My expectations for McDonalds are t rarely high. But for one to still fail so spectacularly...that takes something special!\\nThe cashier took my friends\'s order, then promptly ignored me. I had to force myself in front of a cashier who opened his register to wait on the person BEHIND me. I waited over five minutes for a gigantic order that included precisely one kid\'s meal. After watching two people who ordered after me be handed their food, I asked where mine was. The manager started yelling at the cashiers for \\"serving off their orders\\" when they didn\'t have their food. But neither cashier was anywhere near those controls, and the manager was the one serving food to customers and clearing the boards.\\nThe manager was rude when giving me my order. She didn\'t make sure that I had everything ON MY RECEIPT, and never even had the decency to apologize that I felt I was getting poor service.\\nI\'ve eaten at various McDonalds restaurants for over 30 years. I\'ve worked at more than one location. I expect bad days, bad moods, and the occasional mistake. But I have yet to have a decent experience at this store. It will remain a place I avoid unless someone in my party needs to avoid illness from low blood sugar. Perhaps I should go back to the racially biased service of Steak n Shake instead!'}
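Each review is labeled with one of five star ratings (0 to 4), which is why the models later in this tutorial are loaded with num_labels=5. As a quick check, assuming the label column is typed as a ClassLabel feature as on the Hub, you can read the number of classes from the dataset features:
>>> dataset["train"].features["label"].num_classes
5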
As you now know, you need a tokenizer to process the text and include a padding and truncation strategy to handle any variable sequence lengths. To process your dataset in one step, use 🤗 Datasets map method to apply a preprocessing function over the entire dataset:
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
>>> def tokenize_function(examples):
...     return tokenizer(examples["text"], padding="max_length", truncation=True)
>>> tokenized_datasets = dataset.map(tokenize_function, batched=True)
If you like, you can create a smaller subset of the full dataset to fine-tune on, which reduces the time training takes:
>>> small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
>>> small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
Train¶
At this point, you should follow the section corresponding to the approach you want to use. You can use the links in the right sidebar to jump to the one you want, and if you want to hide all of the content for a given approach, just use the button at the top-right of that block!
Train with MindSpore Trainer¶
Start by loading your model and specifying the number of expected labels; the Yelp Review dataset has five classes:
>>> from mindone.transformers.models.bert import BertForSequenceClassification
>>> model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
Next, create a TrainingArguments class containing the hyperparameters you can tune as well as flags for activating different training options. For this tutorial you can start with the default arguments and only specify where to save the checkpoints:
>>> from mindone.transformers.training_args import TrainingArguments
>>> training_args = TrainingArguments(output_dir="test_trainer")
Then initialize the MindSpore runtime environment; this example uses graph mode on an Ascend device:
>>> import mindspore as ms
>>> from mindone.transformers.mindspore_adapter import MindSporeArguments, init_environment
>>> env_args = MindSporeArguments(mode=ms.GRAPH_MODE, device_target="Ascend")
>>> init_environment(env_args)
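The Trainer below expects a compute_metrics function to report evaluation metrics, and this section does not define one. A minimal sketch, assuming you want simple accuracy computed with the 🤗 Evaluate library (an assumption, not prescribed by this tutorial):
>>> import numpy as np
>>> import evaluate
>>> metric = evaluate.load("accuracy")  # assumes the evaluate package is installed
>>> def compute_metrics(eval_pred):
...     logits, labels = eval_pred
...     predictions = np.argmax(logits, axis=-1)  # pick the highest-scoring class
...     return metric.compute(predictions=predictions, references=labels)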
Create a Trainer object with your model, training arguments, training and test datasets, and the evaluation function:
>>> from mindone.transformers.trainer import Trainer  # import path assumed to mirror transformers' layout
>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=small_train_dataset,
...     eval_dataset=small_eval_dataset,
...     compute_metrics=compute_metrics,
... )
Then fine-tune your model by calling train():
>>> trainer.train()
Train in native MindSpore¶
Next, manually postprocess tokenized_datasets to prepare it for training. Remove the text column, because the model does not accept raw text as an input, and rename the label column to labels, because the model expects that argument name. Then recreate the smaller subsets so they pick up these changes:
>>> tokenized_datasets = tokenized_datasets.remove_columns(["text"])
>>> tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
>>> small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
>>> small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
Wrap the tokenized dataset with HF2MSDataset and define a collate function that turns a list of examples into NumPy batches:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindone.transformers.mindspore_adapter import HF2MSDataset
>>> def ms_data_collator(features, batch_info):
...     batch = {}
...     # stack array fields along a new batch axis; collect scalar fields into 1-D arrays
...     for k, v in features[0].items():
...         batch[k] = np.stack([f[k] for f in features]) if isinstance(v, np.ndarray) else np.array([f[k] for f in features])
...     return batch
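To see what the collate function does, here is a hypothetical two-example batch (the field names and sizes are made up for illustration): NumPy array fields are stacked along a new batch axis, while scalar fields become 1-D arrays:
>>> fake_features = [
...     {"input_ids": np.zeros(8, dtype=np.int64), "labels": 0},
...     {"input_ids": np.ones(8, dtype=np.int64), "labels": 1},
... ]
>>> {k: v.shape for k, v in ms_data_collator(fake_features, None).items()}
{'input_ids': (2, 8), 'labels': (2,)}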
Next, build the training dataloader: wrap the subset in a MindSpore GeneratorDataset, batch it with the collate function, and create a dictionary iterator over it:
>>> batch_size, num_epochs = 1, 3
>>> train_dataloader = ms.dataset.GeneratorDataset(HF2MSDataset(small_train_dataset), column_names="item")
>>> train_dataloader = train_dataloader.batch(batch_size=batch_size, per_batch_map=ms_data_collator)
>>> train_dataloader = train_dataloader.repeat(1)
>>> train_dataloader = train_dataloader.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
Load your model with the number of expected labels:
>>> from mindone.transformers.models.bert import BertForSequenceClassification
>>> model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
Create an optimizer for the model's trainable parameters:
>>> from mindspore import nn
>>> optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=5e-6)
When labels are passed, the model returns the loss as its first output. Wrap the model so it returns only the loss, then hand it to TrainOneStepWrapper, which runs the forward and backward pass and the optimizer update in a single call:
>>> from mindone.transformers.mindspore_adapter import TrainOneStepWrapper
>>> class ReturnLoss(nn.Cell):
...     def __init__(self, model):
...         super(ReturnLoss, self).__init__(auto_prefix=False)
...         self.model = model
...
...     def construct(self, *args, **kwargs):
...         outputs = self.model(*args, **kwargs)
...         loss = outputs[0]
...         return loss
>>> train_model = TrainOneStepWrapper(ReturnLoss(model), optimizer)
Lastly, add a progress bar to track the training steps:
>>> from tqdm.auto import tqdm
>>> num_training_steps = len(small_train_dataset) * num_epochs // batch_size
>>> progress_bar = tqdm(range(num_training_steps))
>>> train_model.set_train()
>>> for step, batch in enumerate(train_dataloader):
...     batch = batch["item"]
...
...     # positional inputs follow the BERT signature:
...     # (input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels)
...     tuple_inputs = (
...         ms.Tensor(batch["input_ids"], ms.int32),
...         ms.Tensor(batch["attention_mask"], ms.bool_),
...         None,
...         None,
...         None,
...         None,
...         ms.Tensor(batch["labels"], ms.int32)
...     )
...
...     loss, _, overflow = train_model(*tuple_inputs)
...
...     progress_bar.update(1)
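When the loop finishes, you may want to keep the fine-tuned weights. One option is MindSpore's checkpoint utility; the file name here is only an example:
>>> ms.save_checkpoint(model, "bert_yelp_finetuned.ckpt")  # illustrative file name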