Fine-tuning a BERT model with transformers

Set up a custom Dataset, fine-tune BERT with the Transformers Trainer, and export the model to ONNX

This post describes a simple way to get started with fine-tuning transformer models. It will cover the basics and introduce you to the amazing Trainer class from the transformers library. You can run the code in Google Colab, but do not forget to enable GPU support.


We use a dataset built from the COVID-19 Open Research Dataset Challenge. This work is one small piece of a larger project to build the cord19 search app.

Install required libraries

!pip install pandas scikit-learn torch transformers

Load the dataset

To fine-tune the BERT models for the cord19 application, we need to generate a set of query-document features and labels that indicate which documents are relevant for the specific queries. For this exercise, we will use the query string to represent the query and the title string to represent the documents.

import pandas as pd

training_data = pd.read_csv("https://thigm85.github.io/data/cord19/cord19-query-title-label.csv")
training_data.head()
[table: first rows of training_data, with columns query, title and label]

There are 50 unique queries.

len(training_data["query"].unique())
50

For each query, we have a list of documents, divided between relevant (label=1) and irrelevant (label=0).

training_data[["title", "label"]].groupby("label").count()
[table: number of documents per label]

Data split

We are going to use a simple data split into train and validation sets for illustration purposes. Even though we have more than 50 thousand data points when considering unique query and document pairs, I believe this specific case would benefit from cross-validation, since it has only 50 queries with relevance judgments (a grouped-split sketch follows the code below).

from sklearn.model_selection import train_test_split
train_queries, val_queries, train_docs, val_docs, train_labels, val_labels = train_test_split(
    training_data["query"].tolist(), 
    training_data["title"].tolist(), 
    training_data["label"].tolist(), 
    test_size=.2
)
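
On the cross-validation point above: because the relevance judgments are grouped by query, a grouped splitter such as scikit-learn's GroupKFold keeps every document of a given query on the same side of each fold. A minimal sketch, not used in the rest of this post:

from sklearn.model_selection import GroupKFold

gkf = GroupKFold(n_splits=5)
for train_idx, val_idx in gkf.split(
    training_data["title"], training_data["label"], groups=training_data["query"]
):
    # every document of a given query lands on exactly one side of the split
    print(len(train_idx), len(val_idx))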

Create BERT encodings

Create the train and validation encodings. To do that, we need to choose which BERT model to use. We will use padding and truncation because the training routine expects all tensors within a batch to have the same dimensions.

from transformers import BertTokenizerFast

model_name = "google/bert_uncased_L-4_H-512_A-8"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

train_encodings = tokenizer(train_queries, train_docs, truncation=True, padding='max_length', max_length=128)
val_encodings = tokenizer(val_queries, val_docs, truncation=True, padding='max_length', max_length=128)
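
As a quick sanity check, you can decode one of the encoded pairs to see the [CLS] query [SEP] title [SEP] layout and the [PAD] tokens added up to max_length:

print(tokenizer.decode(train_encodings["input_ids"][0]))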

Create a custom dataset

Now that we have the encodings and the labels, we can create a Dataset object as described on the transformers webpage about custom datasets.

import torch

class Cord19Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = Cord19Dataset(train_encodings, train_labels)
val_dataset = Cord19Dataset(val_encodings, val_labels)
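
Each item is a dictionary of tensors with the keys the model expects. With the max_length=128 padding used above, the first training example should look roughly like this:

print(train_dataset[0].keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
print(train_dataset[0]["input_ids"].shape)
# torch.Size([128])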

Fine-tune the BERT model

We are going to use BertForSequenceClassification, since we are trying to classify query and document pairs into two distinct classes (non-relevant, relevant).

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(model_name)

We can set requires_grad to False for all the base model parameters to fine-tune only the task-specific parameters.

for param in model.base_model.parameters():
    param.requires_grad = False
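
A quick way to confirm that only the classification head remains trainable is to count the parameters that still have requires_grad set:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"{trainable:,} trainable out of {total:,} total parameters")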

We can then fine-tune the model with Trainer. Below is a basic routine with an out-of-the-box set of parameters. Care should be taken when choosing those parameters, but that is outside the scope of this piece.

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    evaluation_strategy="epoch",     # Evaluation is done at the end of each epoch.
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    save_total_limit=1,              # limit the total number of checkpoints; older checkpoints are deleted.
)


trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

trainer.train()
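
The setup above only reports the evaluation loss at the end of each epoch. If you also want a metric such as accuracy, Trainer accepts a compute_metrics function; a minimal sketch, which would be passed in as Trainer(..., compute_metrics=compute_metrics):

import numpy as np

def compute_metrics(eval_pred):
    # eval_pred bundles the model predictions (logits) and the true labels
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}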

Export the model to ONNX

Once training is complete, we can export the model to the ONNX format so it can be deployed elsewhere. The code below assumes that you have access to a GPU, which you can get from Google Colab, for example.

from torch.onnx import export

device = torch.device("cuda")
model.eval()      # disable dropout so the traced graph is deterministic
model.to(device)

model_onnx_path = "model.onnx"
# The positional order must match BertForSequenceClassification.forward:
# (input_ids, attention_mask, token_type_ids)
dummy_input = (
    train_dataset[0]["input_ids"].unsqueeze(0).to(device),
    train_dataset[0]["attention_mask"].unsqueeze(0).to(device),
    train_dataset[0]["token_type_ids"].unsqueeze(0).to(device)
)
input_names = ["input_ids", "attention_mask", "token_type_ids"]
output_names = ["logits"]
export(
    model, dummy_input, model_onnx_path, input_names=input_names,
    output_names=output_names, verbose=False, opset_version=11
)
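
To verify the exported file, you can load it with onnxruntime (a separate pip install onnxruntime) and feed numpy arrays keyed by the input names used above. A minimal sketch:

import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
feed = {
    "input_ids": train_dataset[0]["input_ids"].unsqueeze(0).numpy(),
    "attention_mask": train_dataset[0]["attention_mask"].unsqueeze(0).numpy(),
    "token_type_ids": train_dataset[0]["token_type_ids"].unsqueeze(0).numpy(),
}
logits = session.run(["logits"], feed)[0]
print(logits.shape)  # expected: (1, 2)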

Concluding remarks

As mentioned before, this post covered a basic training setup. It is a good starting point to be improved upon; it is better to start simple and build from there than the other way around, especially when learning something new. I left important topics such as hyperparameter tuning, cross-validation, and more detailed model validation to follow-up posts, but having a basic training setup in place is a good first step.

Character strings in R

This post deals with the basics of character strings in R. My main reference has been Gaston Sanchez’s ebook [1], which is excellent and which you should read if you are interested in manipulating text in R. I got the encoding section from [2], which is also a nice reference to have nearby. Text analysis will be one topic of interest on this blog, so expect more posts about it in the near future.

Creating character strings

The class of an object that holds character strings in R is “character”. A string in R can be created using single quotes or double quotes.

chr = 'this is a string'
chr = "this is a string"

chr = "this 'is' valid"
chr = 'this "is" valid'

We can create an empty string with empty_str = "" or an empty character vector with empty_chr = character(0). Both have class “character” but the empty string has length equal to 1 while the empty character vector has length equal to zero.

empty_str = ""
empty_chr = character(0)

class(empty_str)
[1] "character"
class(empty_chr)
[1] "character"

length(empty_str)
[1] 1
length(empty_chr)
[1] 0

The function character() creates a character vector with as many empty strings as we want. We can add new components to the character vector simply by assigning to an index outside the current valid range. The new index does not need to be consecutive with the existing ones; if it is not, R fills the gap with NA elements.

chr_vector = character(2) # create char vector
chr_vector
[1] "" ""

chr_vector[3] = "three" # add new element
chr_vector
[1] ""      ""      "three"

chr_vector[5] = "five" # do not need to 
                       # be consecutive
chr_vector
[1] ""      ""      "three" NA      "five" 

Auxiliary functions

The functions as.character() and is.character() can be used to convert non-character objects into character strings and to test whether an object is of type “character”, respectively.
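
For example:

as.character(3.14)
[1] "3.14"

is.character("text")
[1] TRUE

is.character(3.14)
[1] FALSE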

Strings and data objects

R has five main types of objects to store data: vector, factor, multi-dimensional array, data.frame and list. It is interesting to know how these objects behave when exposed to different types of data (e.g. character, numeric, logical).

  • vector: Vectors must have their values all of the same mode. If we combine mixed types of data in vectors, strings will dominate.
  • arrays: A matrix, which is a 2-dimensional array, has the same behavior found in vectors.
  • data.frame: By default, a column that contains character strings is converted to a factor. If we want to turn this default behavior off, we can use the argument stringsAsFactors = FALSE when constructing the data.frame object.
  • list: Each element of the list will maintain its corresponding mode.
# character dominates vector
c(1, 2, "text") 
[1] "1"    "2"    "text"

# character dominates arrays
rbind(1:3, letters[1:3]) 
    [,1] [,2] [,3]
[1,] "1"  "2"  "3" 
[2,] "a"  "b"  "c" 

# data.frame with stringsAsFactors = TRUE (default)
df1 = data.frame(numbers = 1:3, letters = letters[1:3])
df1
  numbers letters
1       1       a
2       2       b
3       3       c

str(df1, vec.len=1)
'data.frame':  3 obs. of  2 variables:
  $ numbers: int  1 2 ...
  $ letters: Factor w/ 3 levels "a","b","c": 1 2 ...

# data.frame with stringsAsFactors = FALSE
df2 = data.frame(numbers = 1:3, letters = letters[1:3], 
                 stringsAsFactors = FALSE)
df2
  numbers letters
1       1       a
2       2       b
3       3       c

str(df2, vec.len=1)
'data.frame':  3 obs. of  2 variables:
  $ numbers: int  1 2 ...
  $ letters: chr  "a" ...

# Each element in a list has its own type
list(1:3, letters[1:3])
[[1]]
[1] 1 2 3

[[2]]
[1] "a" "b" "c"

Character encoding

R provides functions to deal with various encoding schemes. The Encoding() function returns the encoding of a string, and iconv() converts a string from one encoding to another.

chr = "lá lá"
Encoding(chr)
[1] "UTF-8"

chr = iconv(chr, from = "UTF-8", 
            to = "latin1")
Encoding(chr)
[1] "latin1"

References:

[1] Gaston Sanchez’s ebook on Handling and Processing Strings in R.
[2] R Programming/Text Processing webpage.