LAMBADA, developed by IBM Research AI, is a novel data augmentation method for text classification. The acronym LAMBADA stands for "language-model-based data augmentation". The idea is to fine-tune a pretrained language model so that it generates synthetic training data for text classification tasks such as intent classification in conversational systems.
Chatbots must be able to perform the following three key tasks in general:
For task 1, we face the challenge of identifying all possible statements and questions that a user might address to our chatbot. Natural Language Understanding (NLU) algorithms can be used to overcome this issue. These machine learning algorithms are given sample utterances for a set of predefined intents that a user might have. The aim is to generalize from these samples so that other wordings of the same intent are recognized as well.
The challenge in this approach is to gather a sufficiently large and diverse set of sample utterances for training. These utterances usually have to come from real people, so writing or labelling the training data requires manual work. It would therefore be beneficial to generate additional training data automatically from a small sample set in order to reduce time and cost. The LAMBADA method addresses this problem.
Task 2 is a complex problem that we will discuss in more detail in upcoming articles. Task 3 can be handled by applying Natural Language Generation (NLG) techniques. However, for task 3 it is often sufficient to let the chatbot choose from a finite set of predefined answers.
An interesting approach to generating such training utterances, called LAMBADA (language-model-based data augmentation), has been published by IBM Research AI.
The underlying idea is to take a language model that has been pretrained on large corpora such as Wikipedia and books and is therefore able to generate text of good quality. This language model is then fine-tuned on the available domain-specific data. After fine-tuning, the model can be used to generate additional utterances, which in turn can improve the training of intent classification models.
To generate new utterances and leverage them to improve intent classification, we take the following steps.
We start with a small set of training utterances for each intent defined for our chatbot. The intent ‘greet’ might for example be expressed by the utterances „hello“, „hi“, „Good morning!“, the intent ‘inform_hungry’ by „I am starving“, „My tummy is growling“, „I want a snack“ etc. Next, we train an intent classification model – using BERT for example. This model acts as our baseline that we want to improve.
Next, we take a generative model such as GPT-2 and fine-tune it on the small set of intent/utterance data. To generate text with GPT-2, we must provide some input text as a seed. GPT-2 then predicts the words that follow and outputs its prediction. To get an impression of NLG, you can let GPT-2 generate text from your own input here.
Here is an example: we provided GPT-2 with the phrase „I would like to order some coffee“ and repeatedly let it add more text. The result was:
„I would like to order some coffee to go and we will discuss how to proceed from there. It will take around 15 minutes to get a table, so it's not going to be a huge rush but if you want to try a beer I 'd suggest waiting to place the order and then you can just pick it up when I'm ready to go. . .. I 've ordered a few times now and it's always been okay . ... The waitress is always nice.“
The presented text is not the result of cherry-picking but of our first try. Even though the generated story has no real message, the language is correct, respects the context and is quite close to what a human might write.
How do we use this model to generate new utterances for our intents? We fine-tune it on our intent/utterance pairs, formatted as one pair per line: the intent name, a comma, and the utterance, e.g. „greet,hello“ or „inform_hungry,I am starving“.
Our aim is that, given the input string „greet,“, the model completes it to something like „greet, Good evening!“, using its general ability to generate language while respecting our comma-separated notation.
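A minimal sketch of this comma-separated format, assuming one `intent,utterance` pair per line and intent names that contain no commas. The helper names `to_training_lines`, `prompt_for` and `parse_generated` are hypothetical, and the GPT-2 fine-tuning itself is not shown; the sketch only covers how training lines, prompts and generated lines could be built and read back.

```python
def to_training_lines(seed):
    """Turn {intent: [utterances]} into 'intent,utterance' lines for fine-tuning."""
    return [f"{intent},{utt}" for intent, utts in seed.items() for utt in utts]


def prompt_for(intent):
    """The generation prompt is just the intent name followed by the separator."""
    return f"{intent},"


def parse_generated(line, known_intents):
    """Split a generated line at the first comma; return None if it is malformed."""
    intent, sep, utterance = line.partition(",")
    utterance = utterance.strip()
    if not sep or intent not in known_intents or not utterance:
        return None
    return intent, utterance


seed = {
    "greet": ["hello", "hi", "Good morning!"],
    "inform_hungry": ["I am starving", "My tummy is growling", "I want a snack"],
}
lines = to_training_lines(seed)
print(lines[0])                                         # greet,hello
print(prompt_for("greet"))                              # greet,
print(parse_generated("greet, Good evening!", seed))    # ('greet', 'Good evening!')
print(parse_generated("no separator here", seed))       # None
```

Splitting at the first comma only means that utterances which themselves contain commas survive the round trip.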
We prompt GPT-2 with each intent name multiple times. If the model always picked the single most probable next word, it would generate the same output every time, because evaluating a neural network is deterministic.
However, when generating with GPT-2 we can instead sample each next word at random from the k most probable predictions, weighted by their probabilities (top-k sampling), which adds variability to the output.
We generate many more utterances than we want to keep in the end because we will still have to drop some of the lower quality generations.
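The loop below sketches seeded generation with top-k sampling. It is not GPT-2: `toy_next` is a hypothetical stand-in that returns a fixed next-word distribution depending only on the last word, whereas GPT-2 conditions on the whole prompt and works on subword tokens. The function names are likewise illustrative.

```python
import random


def top_k_sample(probs, k, rng):
    """Keep the k most probable words and sample one of them in proportion to
    its probability (the weights are relative, so no renormalization is needed)."""
    top = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    words = [word for word, _ in top]
    weights = [p for _, p in top]
    return rng.choices(words, weights=weights, k=1)[0]


def generate(prompt, next_word_probs, k, max_words, rng, stop="<eos>"):
    """Autoregressive loop: predict a distribution, sample a word, append, repeat."""
    words = prompt.split()
    for _ in range(max_words):
        word = top_k_sample(next_word_probs(words), k, rng)
        if word == stop:
            break
        words.append(word)
    return " ".join(words)


# Hypothetical stand-in for GPT-2: a fixed next-word distribution per last word.
TOY_MODEL = {
    "greet,": {"Good": 0.5, "hello": 0.3, "hi": 0.2},
    "Good": {"evening!": 0.4, "morning!": 0.4, "day!": 0.2},
}


def toy_next(words):
    return TOY_MODEL.get(words[-1], {"<eos>": 1.0})


rng = random.Random(0)
print(generate("greet,", toy_next, k=1, max_words=5, rng=rng))  # greet, Good evening!
for _ in range(3):
    print(generate("greet,", toy_next, k=2, max_words=5, rng=rng))  # may differ per call
```

With `k=1` this is greedy decoding and always yields the same line, which is the repetition problem described above; a larger `k` trades some quality for variety, which is one reason to oversample and filter afterwards.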
Since we will use the generated utterances to train our classifier again, we must ensure that they are a) correct and b) of good quality. It cannot be ruled out that our GPT-2 model generates intent/utterance pairs with the wrong intent, like “inform_hungry,Good afternoon”, or something completely out of scope, like “inform_hungry,the most important thing for any food processor is the size and the speed of the machine.” (again, taken from here).
Therefore, we filter the generated utterances by predicting their intent with our old baseline Intent Classification model and applying the following rules:
a) If the predicted intent for the utterance does not match the intent for which it was generated, we drop the utterance.
b) If the prediction has a confidence score below a certain threshold, we drop the utterance as well.
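Rules a) and b) can be sketched as a single filter. `classify` is an assumed interface standing in for the baseline classifier: it returns a predicted intent and a confidence score. The confidences in the usage example are made up for illustration, and the threshold of 0.6 is arbitrary.

```python
def filter_generated(pairs, classify, threshold):
    """Keep (intent, utterance) pairs that the baseline classifier confirms.

    classify(utterance) -> (predicted_intent, confidence) is a stand-in for the
    baseline intent classification model.
    """
    kept = []
    for intent, utterance in pairs:
        predicted, confidence = classify(utterance)
        if predicted != intent:        # rule a): predicted intent does not match
            continue
        if confidence < threshold:     # rule b): confidence below the threshold
            continue
        kept.append((intent, utterance, confidence))
    return kept


# Made-up baseline predictions, for illustration only.
FOOD = "the most important thing for any food processor is the size and the speed of the machine."
TOY_PREDICTIONS = {
    "Good evening!": ("greet", 0.97),
    "Good afternoon": ("greet", 0.91),
    "I could eat a horse": ("inform_hungry", 0.62),
    FOOD: ("inform_hungry", 0.55),
}
generated = [
    ("greet", "Good evening!"),
    ("inform_hungry", "Good afternoon"),       # dropped by rule a)
    ("inform_hungry", "I could eat a horse"),
    ("inform_hungry", FOOD),                   # dropped by rule b)
]
kept = filter_generated(generated, TOY_PREDICTIONS.__getitem__, threshold=0.6)
print(kept)  # [('greet', 'Good evening!', 0.97), ('inform_hungry', 'I could eat a horse', 0.62)]
```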
Filtering the data with the very model we want to improve may look like going in circles. However, the approach is borrowed from semi-supervised learning, where a large pool of unlabelled data already exists and only needs to be labelled.
The filtered utterances are then added to the initial training data, and the intent classification model (in our case BERT) is trained again on the enriched data set. We can then compare the performance of the two classifiers on a test set that was used neither in the first or second training run nor for fine-tuning GPT-2.
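Putting the steps together, one augmentation round might look like the sketch below. Everything in it is a toy and hypothetical: `toy_train` is a word-overlap classifier standing in for BERT, `toy_fine_tune` returns canned generations standing in for fine-tuned GPT-2, and the data is invented. The point is the control flow: train a baseline, generate candidates, filter them with the baseline, retrain, and evaluate both classifiers on the same held-out test set.

```python
def lambada_round(seed, train_classifier, fine_tune_generator, n_per_intent, threshold):
    """One round of LAMBADA-style augmentation (sketch, not IBM's implementation).

    train_classifier(pairs) -> classify, with classify(utt) -> (intent, confidence)
    fine_tune_generator(pairs) -> generate, with generate(intent) -> utterance
    """
    baseline = train_classifier(seed)                     # 1. baseline classifier
    generate = fine_tune_generator(seed)                  # 2. fine-tune the generator
    intents = sorted({intent for intent, _ in seed})
    candidates = [(i, generate(i)) for i in intents for _ in range(n_per_intent)]  # 3.
    kept = []
    for intent, utt in candidates:                        # 4. filter with the baseline
        predicted, confidence = baseline(utt)
        if predicted == intent and confidence >= threshold:
            kept.append((intent, utt))
    return baseline, train_classifier(seed + kept), kept  # 5. retrain on enriched data


def accuracy(classify, test_pairs):
    """Fraction of test utterances whose predicted intent is correct."""
    return sum(classify(utt)[0] == intent for intent, utt in test_pairs) / len(test_pairs)


def toy_train(pairs):
    """Hypothetical stand-in for BERT: picks the intent whose training words overlap most."""
    vocab = {}
    for intent, utt in pairs:
        vocab.setdefault(intent, set()).update(utt.lower().split())

    def classify(utt):
        words = set(utt.lower().split())
        scores = {i: len(words & v) / max(len(words), 1) for i, v in vocab.items()}
        best = max(scores, key=scores.get)
        return best, scores[best]
    return classify


def toy_fine_tune(pairs):
    """Hypothetical stand-in for fine-tuned GPT-2: cycles through canned outputs."""
    canned = {"greet": ["good evening", "hello friend"],
              "inform_hungry": ["i am so hungry", "good morning"]}
    counters = dict.fromkeys(canned, 0)

    def generate(intent):
        out = canned[intent][counters[intent] % len(canned[intent])]
        counters[intent] += 1
        return out
    return generate


seed = [("greet", "hello there"), ("greet", "good morning"),
        ("inform_hungry", "i am starving"), ("inform_hungry", "i want a snack")]
test = [("greet", "good evening everyone"), ("inform_hungry", "so hungry right now")]

baseline, improved, kept = lambada_round(seed, toy_train, toy_fine_tune,
                                         n_per_intent=2, threshold=0.5)
print(kept)  # [('greet', 'good evening'), ('greet', 'hello friend'), ('inform_hungry', 'i am so hungry')]
print(accuracy(baseline, test), accuracy(improved, test))  # 0.5 1.0
```

The candidate „good morning“ generated for `inform_hungry` is rejected because the baseline labels it `greet`. The toy accuracies only show where the comparison happens; they say nothing about real models.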
In their paper, the IBM Research AI team runs the LAMBADA algorithm on three data sets (ATIS, TREC, WVA) using three different models for intent classification (SVM, LSTM, BERT) and compares LAMBADA's performance with other data augmentation techniques (EDA, CVAE, CBERT). If you are interested in more details, have a look at the publication; here are their main conclusions:
For our own experiment we used one of the Chitchat data sets provided for the Microsoft Azure QnA Maker. We took ten intents with at least 65 utterances each, 1147 utterances in total. For each intent, we randomly chose ten utterances to train BERT for intent classification and to fine-tune GPT-2. We kept the remaining 1047 intent/utterance pairs as a test set to measure our performance.
The ten utterances per intent were split again into six for training and four for validation during training, leaving us with a total of 60 intent/utterance pairs for training and 40 for validation. We fine-tuned a distilBERT model on this data and evaluated it on our test set. The model reached an accuracy of 86.3%.
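The split described above can be sketched as follows, assuming the data is a list of `(intent, utterance)` pairs. `split_per_intent` is a hypothetical helper, and the pool is a synthetic placeholder (ten intents with 70 utterances each), not the Chitchat data. Applied to the 1147 Chitchat utterances, the same procedure leaves 1147 − 10 × 10 = 1047 pairs for testing.

```python
import random
from collections import defaultdict


def split_per_intent(pairs, n_seed, n_train, rng):
    """Per intent: sample n_seed utterances as the small seed set (the rest becomes
    the test set), then split the seed set into n_train training and
    n_seed - n_train validation utterances."""
    by_intent = defaultdict(list)
    for intent, utt in pairs:
        by_intent[intent].append(utt)
    train, val, test = [], [], []
    for intent, utts in by_intent.items():
        rng.shuffle(utts)                       # random choice of seed utterances
        seed, rest = utts[:n_seed], utts[n_seed:]
        train += [(intent, u) for u in seed[:n_train]]
        val += [(intent, u) for u in seed[n_train:]]
        test += [(intent, u) for u in rest]
    return train, val, test


# Synthetic placeholder pool: 10 intents x 70 utterances.
pool = [(f"intent_{i}", f"utterance {i}-{j}") for i in range(10) for j in range(70)]
train, val, test = split_per_intent(pool, n_seed=10, n_train=6, rng=random.Random(0))
print(len(train), len(val), len(test))  # 60 40 600
```

Splitting per intent rather than over the whole pool keeps the seed set balanced, with exactly six training and four validation utterances for every intent.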
Next, we fine-tuned the medium-size GPT-2 using the same 10 utterances per intent. We then used this model to generate 100 new utterances per intent.
These 1000 utterances were fed into our distilBERT model to predict their intent. Utterances whose predicted intent did not match the intent they were generated for were dropped. From the remaining ones, we took the 30 utterances per intent with the highest prediction probability and added them to our training data, resulting in a new data set with 40 utterances per intent. We then split this set into a training and a validation set with 32 and 8 utterances per intent and trained distilBERT again.
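This selection step differs slightly from rules a) and b): instead of a fixed confidence threshold, the 30 most probable matching utterances per intent are kept. A sketch, assuming `classify` returns the predicted intent and its probability. `toy_classify`, the `#`-numbered placeholder utterances and both helper names are hypothetical, constructed so that every fifth generation is misclassified.

```python
import random
from collections import defaultdict


def select_top(generated, classify, n_keep):
    """Drop pairs whose predicted intent differs, then keep the n_keep most
    probable remaining utterances per intent."""
    by_intent = defaultdict(list)
    for intent, utt in generated:
        predicted, prob = classify(utt)
        if predicted == intent:
            by_intent[intent].append((prob, utt))
    selected = []
    for intent, scored in by_intent.items():
        scored.sort(key=lambda pair: pair[0], reverse=True)  # highest probability first
        selected += [(intent, utt) for _, utt in scored[:n_keep]]
    return selected


def split_train_val(pairs, n_train, rng):
    """Shuffle each intent's utterances and split them into train and validation."""
    by_intent = defaultdict(list)
    for intent, utt in pairs:
        by_intent[intent].append(utt)
    train, val = [], []
    for intent, utts in by_intent.items():
        rng.shuffle(utts)
        train += [(intent, u) for u in utts[:n_train]]
        val += [(intent, u) for u in utts[n_train:]]
    return train, val


def toy_classify(utt):
    """Hypothetical stand-in for distilBERT: 'intent_3#42' -> ('intent_3', 0.42),
    except that every fifth generation gets a wrong label."""
    intent, j = utt.split("#")
    j = int(j)
    return (intent if j % 5 else "other"), j / 100


seed = [(f"intent_{i}", f"seed {i}-{j}") for i in range(10) for j in range(10)]
generated = [(f"intent_{i}", f"intent_{i}#{j}") for i in range(10) for j in range(100)]

selected = select_top(generated, toy_classify, n_keep=30)    # 30 per intent
train, val = split_train_val(seed + selected, n_train=32, rng=random.Random(0))
print(len(selected), len(train), len(val))  # 300 320 80
```

With 10 seed and 30 selected utterances per intent, each intent ends up with 40 utterances, which the 32/8 split turns into 320 training and 80 validation pairs.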
Finally, we used the same test set as before to determine the accuracy of our new model. The model predicted 90.3% of the utterances correctly, an improvement of 4 percentage points over the baseline model.
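Since both accuracies are percentages, the reported gain is an absolute difference of 4 percentage points. Measured against the baseline error rate, the same two numbers amount to a considerably larger relative reduction:

```latex
\begin{aligned}
\Delta_{\text{acc}} &= 90.3\% - 86.3\% = 4.0 \text{ percentage points} \\
\text{relative error reduction} &= \frac{(100 - 86.3) - (100 - 90.3)}{100 - 86.3}
  = \frac{13.7 - 9.7}{13.7} \approx 0.29
\end{aligned}
```

The error rate on the test set thus drops from 13.7% to 9.7%, about 29% in relative terms, while the accuracy itself rises by about 4.6% relative to 86.3%.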
The LAMBADA method is promising when you need to train an intent classification model with only a small amount of training data and additional data is expensive to obtain. It uses pretrained generative models to produce more diverse utterances, which, as our experiment suggests, can improve your classifier's performance.
In order to use the LAMBADA method, you need to be familiar with pretrained models such as BERT and GPT-2. However, while knowing the theory behind a method is nice, translating it into code takes some effort. To help you get started with your own use cases, look out for the next article in our series about conversational systems, where we will describe in detail how we implemented our experiment on the Microsoft Chitchat data set.