Customize transformer models to your domain


This article summarizes the key ideas from the paper Don't Stop Pretraining: Adapt Language Models to Domains and Tasks. If you are building a transformer model for a specific domain such as healthcare or finance, or are looking for ideas to improve your model's performance, this article is for you.


Transformers like BERT are pre-trained on large, heterogeneous text collections covering news articles, English literature and web content. But do the embeddings learned by these models work universally?

The problem

Consider the vocabulary overlap matrix below.

Credits: The paper

It shows the percentage of vocabulary overlap between corpora, where each corpus (News, Reviews, CS, BioMed) belongs to a particular domain. PT is the pre-training corpus of RoBERTa, the transformer model selected for the experiments reported in the paper. The vocabulary of each corpus consists of its 10k most frequent words. The PT corpus has the least overlap with CS (computer science publications) and the most overlap with News.
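To make the overlap measure concrete, here is a minimal sketch of how such a percentage could be computed. The tokenizer, the toy corpora and the choice to normalize by the smaller vocabulary are assumptions made for illustration; the paper's exact preprocessing may differ.

```python
import re
from collections import Counter


def top_k_vocab(documents, k=10_000):
    """Return the set of the k most frequent lowercase word tokens."""
    counts = Counter()
    for doc in documents:
        # Simple regex tokenizer (an assumption, not the paper's tokenizer).
        counts.update(re.findall(r"[a-z0-9']+", doc.lower()))
    return {word for word, _ in counts.most_common(k)}


def vocab_overlap(vocab_a, vocab_b):
    """Percentage of shared words, relative to the smaller vocabulary."""
    if not vocab_a or not vocab_b:
        return 0.0
    shared = len(vocab_a & vocab_b)
    return 100.0 * shared / min(len(vocab_a), len(vocab_b))


# Toy corpora, invented for illustration only.
news = [
    "The government announced a new budget today.",
    "The team won the match in the final minute.",
]
cs = [
    "We propose a new algorithm for graph partitioning.",
    "The model converges in linear time.",
]

overlap = vocab_overlap(top_k_vocab(news), top_k_vocab(cs))
print(f"News vs CS overlap: {overlap:.1f}%")
```

With real corpora of comparable size, both vocabularies would hold exactly 10k words, so the normalization choice would no longer matter.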

How is this useful? A small vocabulary overlap suggests that RoBERTa’s embeddings, learned from its original corpus (PT), might not work well on tasks in a dissimilar domain such as CS or BioMed. In other words, there may be a gap between a transformer’s original embeddings and the embeddings that a domain-specific task requires.

The question is, can further pre-training improve model performance on specific domains? If so, how do we implement it? The paper studies two approaches:

  • Domain Adaptive Pre-Training (DAPT)

  • Task Adaptive Pre-Training (TAPT)

Domain Adaptive Pre-Training

DAPT means continuing to pre-train a transformer the same way it was pre-trained originally, but on a large corpus of domain-related text. The authors compare the DAPT-updated model with the base model on eight classification tasks: one high-resource and one low-resource task in each of the four domains. On all eight tasks, the DAPT-updated model either outperformed the original model or did equally well.
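For BERT-style models such as RoBERTa, "the same way it was done originally" means the masked language modelling objective: hide a fraction of the tokens and train the model to predict them. The sketch below illustrates only the masking step, using the 15% masking rate and the 80/10/10 replacement rule from BERT. The whitespace tokens, the `[MASK]` string and the tiny vocabulary are simplifying assumptions; a real setup would use the model's own subword tokenizer and mask token.

```python
import random

MASK = "[MASK]"  # placeholder mask token (assumption for this sketch)


def mask_tokens(tokens, vocab, mask_prob=0.15, rng=None):
    """Build one masked-language-modelling example.

    Returns (inputs, labels). labels[i] is the original token at the
    positions the model must predict, and None everywhere else.
    """
    rng = rng or random.Random()
    inputs, labels = [], []
    for token in tokens:
        if rng.random() < mask_prob:
            labels.append(token)
            roll = rng.random()
            if roll < 0.8:
                inputs.append(MASK)               # 80%: mask token
            elif roll < 0.9:
                inputs.append(rng.choice(vocab))  # 10%: random token
            else:
                inputs.append(token)              # 10%: unchanged
        else:
            inputs.append(token)
            labels.append(None)
    return inputs, labels


# A made-up sentence standing in for unlabelled domain text.
sentence = "the patient was treated with a low dose of aspirin".split()
toy_vocab = sorted(set(sentence))

inputs, labels = mask_tokens(sentence, toy_vocab, rng=random.Random(0))
print(inputs)
print(labels)
```

In DAPT, examples like these would be drawn from the domain corpus instead of the original pre-training corpus; the objective itself does not change.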

To rule out the possibility that the better results come simply from more pre-training data rather than from domain adaptation, the authors also continued pre-training the original model on text from an unrelated domain. That model did not outperform the one pre-trained on relevant unlabelled text, and on some tasks it even underperformed the original model.

Task Adaptive Pre-Training

Task Adaptive Pre-Training (TAPT) asks: what if, instead of further pre-training on a large unlabelled domain-specific corpus (DAPT), we used a smaller unlabelled corpus that is more closely focused on the downstream task? Does this approach compete with DAPT?

Yes.

In the experiments, the authors continue pre-training on the task's training set, treated as unlabelled text. The results on the eight tasks show that TAPT works better on some tasks and DAPT on others. However, TAPT outperforms the original model every time.

What about doing both? Sure. Experiments show that DAPT+TAPT outperforms the individual approaches. However, it is important to note that TAPT should follow DAPT and not vice versa: if DAPT came last, the model could suffer catastrophic forgetting of what it learned from the task-specific text.
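The ordering can be written down as a small planning helper. This is only a sketch of the phase order described above; the function name, the phase labels and the corpus placeholders are hypothetical and do not come from the paper.

```python
def training_schedule(domain_corpus=None, task_corpus=None):
    """Return the training phases in the order they should run.

    DAPT (if a domain corpus is given) always comes before TAPT
    (if a task corpus is given), and supervised fine-tuning comes last.
    """
    phases = []
    if domain_corpus:
        phases.append(("DAPT", domain_corpus))
    if task_corpus:
        phases.append(("TAPT", task_corpus))
    phases.append(("fine-tune", "labelled task data"))
    return phases


# Hypothetical corpus names, used only as labels here.
for phase, data in training_schedule("biomed_papers", "task_train_text"):
    print(f"{phase}: {data}")
```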

Caveats

TAPT requires less data and compute than DAPT, which suggests it is the more attractive option if you have to choose between the two. Here’s the catch though: TAPT works well when the supervised task you evaluate on uses the same dataset you adapted on. If you take a model that was TAPTed on a different dataset, even one within the same domain, performance will likely degrade. The paper compares two models, one TAPTed on dataset A and the other on dataset B, where A and B belong to the same domain, and evaluates both on A. The model TAPTed on the other dataset underperformed, indicating that TAPT optimizes for single-task performance.

What if you don’t have enough data for TAPT?

There might be situations where you have neither enough task-specific data for TAPT nor enough compute to perform DAPT. For this case, the authors propose using a lightweight bag-of-words model to embed both the task data and the domain corpus. Then, for each example in the task dataset, they select its k nearest neighbours from the domain corpus and add them to the text used for TAPT. According to the experiments, this approach improves performance over TAPT on the task data alone.
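The selection step can be sketched as follows. This is a simplified stand-in, not the paper's implementation: it uses raw word-count vectors and cosine similarity in place of a learned bag-of-words model, and the sentences are invented for illustration.

```python
import math
import re
from collections import Counter


def bow(text):
    """Bag-of-words vector: word -> count (plain counts, an assumption)."""
    return Counter(re.findall(r"[a-z0-9']+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def select_neighbours(task_examples, domain_corpus, k=2):
    """For each task example, pick its k most similar domain documents.

    Returns the union of selected documents, in domain-corpus order.
    """
    domain_vecs = [bow(doc) for doc in domain_corpus]
    selected = set()
    for example in task_examples:
        query = bow(example)
        ranked = sorted(
            range(len(domain_corpus)),
            key=lambda i: cosine(query, domain_vecs[i]),
            reverse=True,
        )
        selected.update(ranked[:k])
    return [domain_corpus[i] for i in sorted(selected)]


# Invented toy data.
task = ["The protein binds to the receptor."]
domain = [
    "Receptor binding of the protein was measured.",
    "Stock prices fell sharply on Monday.",
    "The protein receptor complex is stable.",
    "The election results were announced.",
]

extra_text = select_neighbours(task, domain, k=2)
print(extra_text)
```

The selected documents would then be added to the task text before running TAPT, giving the model more task-like data without pre-training on the whole domain corpus.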
