Email Writing Assistant
People send more than 300 billion emails every day, and there are more than 4 billion email users worldwide (source: Statista). Simplifying the email-writing process and improving the user experience is therefore a top priority for major email service providers such as Gmail, which does this with features like Smart Reply and Smart Compose.
Smart Reply uses neural networks (LSTMs) to capture the context, sentiment etc. of the previous email and generates a short reply from it. Smart replies are certainly useful, but people usually write emails that are longer than such 'replies', so we also need something that helps them 'compose' these emails.
Smart Compose helps with sentence completion: given the current prefix and the context (subject + previous email), it predicts the next most likely words. Released by Google in 2018, it saves users from typing over 1 billion characters per week, according to this paper.
Want to know how to build such a system? Read along!
Given the current email prefix, we have to generate the next few words while taking the context into account. The business needs predictions of good quality (grammatically correct) that are also the words the user was most likely going to type.
We know that we have to improve the user experience, but how do we quantify this? Our aim is to reduce the amount of repetitive typing. We can track:
- Average time taken per user to write an email (from the moment they start typing to the moment they hit 'Send').
- Total number of suggested tokens accepted across all users per week (a metric used by Google).
We can propose these two business metrics to SMEs and stakeholders.
Mapping to ML Problem and metric
This is simply a Natural Language Generation task. More specifically this will be a “Conditional Text Generation” task where given some tokens we will predict the next most likely tokens.
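We will use the BLEU score (BLEU-4) as our metric. BLEU is the brevity penalty multiplied by the geometric mean of the n-gram precision scores. For numerical stability it is usually computed in log form:
$\log \mathrm{BLEU} = \min\left(1 - \frac{r}{c},\ 0\right) + \sum_{n=1}^{4} w_n \log p_n$
where r and c are the lengths of the reference and predicted sentences, p_n is the n-gram precision and w_n = 1/4. The min term is the brevity penalty, which penalizes predictions shorter than the reference; the sum is the precision score for overlaps of up to 4-grams between the reference and predicted sentences.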
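The metric is simple enough to compute by hand. Below is a minimal sketch of sentence-level BLEU-4 with a single reference, uniform weights of 1/4 and no smoothing. These settings are assumptions: the post does not say which BLEU implementation or options were used. The sketch follows the log form: the brevity penalty min(1 - r/c, 0) plus the weighted log n-gram precisions.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Return a Counter of all n-grams (as tuples) in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu4(reference, candidate):
    """Sentence-level BLEU-4 with uniform weights and one reference (no smoothing).

    Computed in log form: min(1 - r/c, 0) + sum_n 0.25 * log(p_n).
    Returns 0.0 as soon as any n-gram precision is zero.
    """
    ref, cand = reference.split(), candidate.split()
    if not cand:
        return 0.0
    log_precision_sum = 0.0
    for n in range(1, 5):
        cand_counts = ngrams(cand, n)
        ref_counts = ngrams(ref, n)
        total = sum(cand_counts.values())
        # "Modified" precision: clip each candidate n-gram count by its reference count.
        clipped = sum(min(count, ref_counts[gram]) for gram, count in cand_counts.items())
        if total == 0 or clipped == 0:
            return 0.0
        log_precision_sum += 0.25 * math.log(clipped / total)
    r, c = len(ref), len(cand)
    log_bp = min(1 - r / c, 0.0)  # brevity penalty, only negative when c < r
    return math.exp(log_bp + log_precision_sum)
```

Without smoothing, a completion that matches no reference word scores exactly 0, even if it is a sensible reply. For corpus-level scores with smoothing, NLTK's `nltk.translate.bleu_score` module provides `corpus_bleu` and `SmoothingFunction`.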
There are a few business constraints to keep in mind, especially if we are to deploy such a system:
- Predictions should be grammatically correct and aligned with context.
- Very low latency: in Google's Smart Compose, suggestions appear as the user types (even at the character level!). Since inference (a request to the server) has to run on almost every keystroke, faster models are preferred.
- Fairness and Privacy: We wouldn’t want our model to output names, phone numbers, addresses etc. in predictions.
The dataset we use is the Enron email corpus, one of the largest publicly available email datasets, with 500,000+ emails. The data is highly unstructured, so cleaning and preprocessing take a lot of heavy lifting. The dataset is available on Kaggle in CSV format here. It consists of:
- Email index: this looks something like "emp_name/directory_name/mail_index_number".
- Message: contains all the information about the email, such as timestamp, from_address, sent_address, subject etc., stored as a raw email message.
We’re only concerned with the second column (i.e. Message) of the dataset.
Data cleaning and structuring
We use Python's built-in email library to extract the fields of each email while skipping attachments such as images and documents.
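A minimal sketch of that extraction with the standard-library `email` package. The keys of the returned dict are hypothetical names chosen for illustration. Only `text/plain` parts are kept, and anything marked as an attachment is skipped. Decoding of quoted-printable or base64 bodies is left out for brevity.

```python
import email


def parse_message(raw_message):
    """Parse one raw message from the 'Message' column into a dict of fields."""
    msg = email.message_from_string(raw_message)
    body_parts = []
    for part in msg.walk():
        # Skip multipart containers and any part explicitly marked as an attachment.
        if part.is_multipart() or part.get_content_disposition() == "attachment":
            continue
        if part.get_content_type() == "text/plain":
            body_parts.append(part.get_payload())
    return {
        "date": msg.get("Date"),
        "from": msg.get("From"),
        "to": msg.get("To"),
        "subject": msg.get("Subject"),
        "body": "\n".join(body_parts).strip(),
    }
```

Applied row by row to the Message column, this gives one structured record per email. The body is then ready for the thread splitting and cleaning steps.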
Extracting text from the body: many email bodies contain not a single email but a whole thread. If a mail has many replies, we need to extract each of them separately. On manual inspection, I found that the emails in a thread were separated by lines containing text like "----- Forwarded By -----" or "-----Original Message-----". These lines act as borders between emails, so we take all the content between one occurrence of such a pattern and the next.
We do this with the regex function finditer(), which finds the positions of all such patterns, and then extract each email independently. After this, we remove names, phone numbers, tags, HTML code and special characters from the email bodies, expand contractions, and strip any other junk text. (This is the code if you want to have a look; it is quite long, so I've added comments at every step for readability.)
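Here is a sketch of the thread splitting with `re.finditer`, followed by a light cleaning pass. The separator regex only approximates the two patterns described above. The cleaning rules are illustrative assumptions, not the author's exact code: the regexes and the small contraction table are examples. Name removal usually needs a NER model or a name list, so it is left out.

```python
import re

# Approximates separator lines such as "-----Original Message-----" and
# "---------- Forwarded by ... ----------" (assumed patterns).
THREAD_SEPARATOR = re.compile(r"-{3,}\s*(?:Original Message|Forwarded by)[^\n]*", re.IGNORECASE)

# Small illustrative contraction table; order matters ("can't" before "n't").
CONTRACTIONS = {"won't": "will not", "can't": "can not", "n't": " not",
                "'re": " are", "'ll": " will", "'ve": " have", "'m": " am"}


def split_thread(body):
    """Return the individual emails found between separator lines."""
    pieces, start = [], 0
    for match in THREAD_SEPARATOR.finditer(body):
        pieces.append(body[start:match.start()])
        start = match.end()
    pieces.append(body[start:])
    # Drop empty fragments, e.g. when a body starts with a separator.
    return [p.strip() for p in pieces if p.strip()]


def clean(text):
    """Remove tags, addresses, phone-like numbers and special characters."""
    text = re.sub(r"<[^>]+>", " ", text)                  # HTML tags
    text = re.sub(r"\S+@\S+", " ", text)                  # email addresses
    text = re.sub(r"\+?\d[\d\-\s().]{7,}\d", " ", text)   # phone-number-like digit runs
    text = text.lower()
    for short, full in CONTRACTIONS.items():              # expand contractions
        text = text.replace(short, full)
    text = re.sub(r"[^a-z0-9.,?!' ]", " ", text)          # special characters
    return re.sub(r"\s+", " ", text).strip()
```

Splitting before cleaning matters, because the cleaning step removes the dashes that mark the email borders.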
Q. What didn’t work in Data cleaning?
A. While modelling, I tried concatenating the subject + previous emails (if replying) + the current email prefix and using this to predict the next words, but it didn't work. The longer sequences made the models computationally expensive, and the prediction quality was also low. So I decided to keep things simple at first and extracted each email separately.
Exploratory data analysis
We’ll ask questions and try to answer them visually.
Q. What is the peak time for email activity?
As expected, email conversations mostly happen on working days, between 8 a.m. and 10 a.m. This can help us decide when to retrain our models.
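One way to produce these counts, sketched with the standard library: parse each raw `Date` header with `email.utils.parsedate_to_datetime` and tally weekdays and hours. The hour is the one written in the header, i.e. in the sender's own UTC offset. Converting everything to a single time zone is a design choice left out here.

```python
from collections import Counter
from email.utils import parsedate_to_datetime

# Locale-independent weekday names (datetime.weekday(): Monday == 0).
DAY_NAMES = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")


def activity_profile(date_headers):
    """Count emails per weekday and per hour of day from raw Date headers."""
    by_day, by_hour = Counter(), Counter()
    for header in date_headers:
        dt = parsedate_to_datetime(header)
        by_day[DAY_NAMES[dt.weekday()]] += 1
        by_hour[dt.hour] += 1
    return by_day, by_hour
```

Plotting `by_day` and `by_hour` as bar charts gives the weekday and hourly views of email activity.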
Q. Who are the ‘active’ users?
This analysis might not be directly helpful to us, but it is useful in a real-world deployment of an email writing assistant, where we might also want personalized models for our 'active users' (similar to what Google has done in Gmail).
Q. What are the most frequent words in the email bodies?
Words like time, process, information, new and issue are among the most common, which suggests that our data cleaning worked properly and the text makes sense.
Since ML/DL models only understand numbers, we need to convert our data into different formats depending on the models we feed it into. We will try two preprocessing strategies for our models:
- Encoder-decoder based: here we need to prepare our data as input and output sequences (Seq2Seq).
- Decoder-only based: an autoregressive model that predicts the next tokens given the previous ones. Here we can feed our email sentences directly and have the model predict the next likely tokens.
Before we augment our data, we apply a few restrictions to our sentences so that we can train our models properly.
1. Removing rare words that occur below a certain frequency threshold. The word-frequency histogram shows that most words are very rare and therefore not useful; we keep only words that occur more than 125 times.
2. Restricting sentence length to 30 words. The elbow point of the length distribution seems to be around 70, but that would be too computationally expensive, especially for the seq2seq data (you'll see why in a moment).
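A sketch of both restrictions. The thresholds of 125 occurrences and 30 words come from the text. The text does not say what happens to a sentence that contains a rare word. This version assumes the whole sentence is dropped, rather than deleting the word or mapping it to an `<unk>` token.

```python
from collections import Counter

MIN_COUNT = 125  # keep words seen more than this many times (from the text)
MAX_LEN = 30     # maximum sentence length in words (from the text)


def build_vocab(sentences, min_count=MIN_COUNT):
    """Return the set of words occurring more than `min_count` times."""
    counts = Counter(word for s in sentences for word in s.split())
    return {word for word, count in counts.items() if count > min_count}


def filter_sentences(sentences, vocab, max_len=MAX_LEN):
    """Keep sentences that are short enough and use only in-vocabulary words."""
    kept = []
    for s in sentences:
        words = s.split()
        if len(words) <= max_len and all(w in vocab for w in words):
            kept.append(s)
    return kept
```

Dropping whole sentences keeps the remaining text grammatical, at the cost of discarding some data.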
Duplicate emails, junk emails (such as log data and Outlook mails) and auto-generated no-reply emails were all dropped. Now, to prepare the data for the models:
- Sequence data for encoder-decoder models: we create encoder and decoder sequences by splitting each sentence at every word position after the first 5 words, so that the encoder always gets at least 5 words of context. Now you know why we restricted sentence length to 30: every sentence yields one pair per remaining word. Even with these restrictions, the train and test data had ~750k and ~200k sequences respectively.
- Sentence data for decoder-only models: no extra preprocessing is required, as we can feed our email bodies directly as individual sentences.
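Here is a sketch of the every-word split. The `<start>` and `<end>` decoder markers for teacher forcing are hypothetical names, since the post does not show its exact token format. A 30-word sentence yields 25 pairs (cuts at positions 5 through 29), which shows how quickly the data grows with sentence length.

```python
MIN_CONTEXT = 5  # the first five words always go to the encoder


def make_seq2seq_pairs(sentence, min_context=MIN_CONTEXT):
    """Split a sentence at every word position after the first `min_context` words.

    Returns (encoder_input, decoder_input, decoder_target) triples. The decoder
    sequences use hypothetical <start>/<end> markers for teacher forcing.
    """
    words = sentence.split()
    pairs = []
    for cut in range(min_context, len(words)):
        prefix, suffix = words[:cut], words[cut:]
        pairs.append((
            " ".join(prefix),                   # what the user has typed so far
            " ".join(["<start>"] + suffix),     # decoder input, shifted right
            " ".join(suffix + ["<end>"]),       # decoder target
        ))
    return pairs
```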
Q. What didn’t work in preprocessing?
A. Random splitting of sentences for the sequence data didn't work. In random splitting, instead of splitting at every word of a sentence (as we did), we split k times at random positions pos, where pos ∈ (0, sentence_length).
This reduced the size of the data significantly, but the encoder-decoder models performed quite badly.
One possible reason is that this kind of splitting suits pre-training on an unsupervised upstream objective, as in this paper, where the technique is used for a 'prefix language modelling' objective. For a specific downstream task like ours, we need to give the model a well-defined pattern to learn.
1- RNN Encoder-Decoder model
We will code the encoder-decoder model from scratch (don't worry, it's quite simple) so that we have full control over how it predicts. First, we need to decide how to feed sequences to the model.
Word featurization with GloVe embeddings:
We featurize our data at the word level and use pre-trained GloVe embeddings, taking the largest available version: a 2.2M-word vocabulary with 300-dimensional vectors. The fundamental idea behind W2V-style embeddings is that similar words lie close together geometrically, while relationships between entities are also captured (King and Queen have the same relationship as Prince and Princess).
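A sketch of loading GloVe vectors and building the embedding matrix for our vocabulary. It assumes NumPy and a Keras-style `word_index` that maps words to integer ids starting at 1, with row 0 left for padding. Giving words missing from GloVe a zero vector is also an assumption.

```python
import numpy as np


def load_glove(lines, dim=300):
    """Parse GloVe text-format lines ("word v1 v2 ... v_dim") into a dict."""
    vectors = {}
    for line in lines:
        parts = line.rstrip().split(" ")
        # Take the last `dim` fields as the vector, so a token containing spaces still parses.
        word, values = " ".join(parts[:-dim]), parts[-dim:]
        vectors[word] = np.asarray(values, dtype=np.float32)
    return vectors


def build_embedding_matrix(word_index, vectors, dim=300):
    """Row i holds the vector of the word with id i; unknown words stay all-zero."""
    matrix = np.zeros((len(word_index) + 1, dim), dtype=np.float32)  # row 0: padding
    for word, idx in word_index.items():
        vec = vectors.get(word)
        if vec is not None:
            matrix[idx] = vec
    return matrix
```

The matrix can initialize a Keras `Embedding` layer, for example through `embeddings_initializer=keras.initializers.Constant(matrix)`, with `trainable=False` to keep the vectors frozen.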
Model Architecture :
Even though encoder-decoder models are mostly used for machine translation, we can also use them for conditional text generation. We have already prepared the sequences for both the encoder and the decoder and chosen the embeddings, so we just have to code the model. The key design choices in the code are:
- Both the encoder and the decoder consist of an Embedding layer (pretrained GloVe) followed by an LSTM layer with 256 units.
- We set "trainable=False" for the embedding layers in both the encoder and the decoder, since we only had free resources (Colab) to train on. Moreover, GloVe embeddings are already pre-trained on a huge corpus (840B tokens), so they should carry good global context.
- The final encoder-decoder model consists of the encoder, the decoder and an additional dense softmax layer (units = decoder vocabulary size), which produces the final output predictions.
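Because the model is built from scratch, inference is also written by hand. The sketch below shows greedy decoding in plain Python. `encode` and `decode_step` are hypothetical stand-ins for separate encoder and decoder inference models, which is the usual Keras seq2seq pattern. The post does not say whether greedy or beam search was used.

```python
def greedy_decode(encode, decode_step, prefix_ids, start_id, end_id, max_len=10):
    """Greedy decoding for an encoder-decoder model.

    encode(prefix_ids) -> initial decoder state (e.g. the LSTM [h, c] states)
    decode_step(token_id, state) -> (probabilities over the vocabulary, new state)
    """
    state = encode(prefix_ids)
    token, output = start_id, []
    for _ in range(max_len):
        probs, state = decode_step(token, state)
        token = max(range(len(probs)), key=probs.__getitem__)  # arg-max token id
        if token == end_id:
            break
        output.append(token)
    return output
```

Capping `max_len` keeps the latency of each suggestion bounded, which matters when a request is sent on almost every keystroke.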
2- T5 transformer
T5 is a text-to-text transformer (encoder + decoder stack) released by Google in 2019 (published in JMLR in 2020). The most novel thing about this model is its unified text-to-text framework: every language problem is cast in a sequence-to-sequence format where both the input and the output are text.
You might be wondering: what about regression? Even regression is cast as multi-class classification over string representations by bucketizing the output range. (For example, if the output ranges from 1 to 5, every value is rounded to the nearest increment of 0.2.)
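For concreteness, here is the rounding applied to a 1-to-5 score before it becomes a text target. With a step of 0.2 there are 21 possible labels (1.0, 1.2, ..., 5.0), so the regression turns into 21-class classification. The function name is hypothetical.

```python
def t5_regression_label(score):
    """Round a 1-5 score to the nearest multiple of 0.2 and render it as text."""
    # Multiplying by 5 (rather than dividing by 0.2) avoids some float round-off.
    return f"{round(score * 5) / 5:.1f}"
```

At inference time, the generated string is simply parsed back with `float()`.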
Unsupervised Pretraining objective —
The pretraining objective is a slightly different flavor of masked language modelling: besides single tokens, contiguous spans of tokens are masked as well. T5 was trained on the huge C4 dataset (~750GB) and tokenized with a SentencePiece (unigram vocab) tokenizer.
(For more information, I highly recommend reading the paper or this blog.)
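A small sketch of T5-style span corruption. Each masked span is replaced by a sentinel token in the input, and the target lists the sentinels followed by the tokens they hide. The spans are passed in explicitly to keep the example deterministic. T5 itself samples them randomly, masking about 15% of tokens with a mean span length of 3. The `<extra_id_N>` sentinel names follow the T5 vocabulary.

```python
def span_corrupt(tokens, spans):
    """Build a T5-style (input, target) pair.

    `spans` is a list of non-overlapping (start, end) token-index pairs to mask.
    """
    inputs, targets, prev_end = [], [], 0
    for i, (start, end) in enumerate(sorted(spans)):
        sentinel = f"<extra_id_{i}>"
        inputs += tokens[prev_end:start] + [sentinel]   # keep text, mask the span
        targets += [sentinel] + tokens[start:end]       # target reveals the span
        prev_end = end
    inputs += tokens[prev_end:]
    targets.append(f"<extra_id_{len(spans)}>")          # final sentinel ends the target
    return " ".join(inputs), " ".join(targets)
```

The target contains only the masked spans, not the full sentence, so it is much shorter than the input. This differs from next-word prediction, where the model continues the text left to right.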
We use T5-Base as our model. It has a total of 220M parameters, 12 layers each in the encoder and decoder, 12 attention heads and a token embedding size of 768. We use the SimpleT5 library to fine-tune the model on our task.
Note: we feed this model the same sequence data that we gave our RNN-based seq2seq model. Also, we can't use T5's task-specific prefixes, since sentence completion is not one of the available tasks.
3- GPT2 decoder-only transformer
GPT2 is a decoder-only, autoregressive transformer released by OpenAI in 2019. To learn about the inner workings of this model, have a look at this terrific illustration, which explains it far better than I could.
Unsupervised Pretraining objective —
GPT2 was pre-trained on next-word prediction but also generalizes well to generating long sequences (which is in fact very similar to our task). It was trained on 40GB of WebText data and tokenized using byte-level Byte Pair Encoding (BPE).
We use GPT2 small for our task. It has a total of 124M parameters, 12 decoder layers, 12 attention heads and a token embedding size of 768. We use the SimpleGPT2 library to fine-tune the model on our data.
Note: we feed the individual email bodies we extracted directly to this model, since the fine-tuning objective is the same as the pre-training one, i.e. next-word prediction.
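A sketch of turning the email bodies into a single fine-tuning text. Each email is closed with GPT-2's `<|endoftext|>` token so the model learns where an email ends. Whether SimpleGPT2 expects exactly this delimiter is an assumption, so check the library's documentation for its input format.

```python
END_OF_TEXT = "<|endoftext|>"  # GPT-2's end-of-text token


def build_finetune_text(email_bodies):
    """Join cleaned email bodies into one training string, one email per document."""
    docs = [body.strip() for body in email_bodies if body.strip()]
    return "\n".join(doc + END_OF_TEXT for doc in docs)
```

The resulting string can be written to a plain text file and passed to the fine-tuning script as its dataset.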
After training the vanilla encoder-decoder for 40 epochs, fine-tuning T5-base for 2 epochs and fine-tuning GPT2 for 2500 steps (1 step = a random sample of 2048 sequences from the training data), here is a summary of the results:
As we can easily see, GPT2 is the MVP for this problem, with a BLEU score of 18%. However, its latency is the worst: the mean latency is > 10 s/query and the 99th-percentile latency is > 20 s/query, which isn't acceptable for a system like Smart Compose. (The latency could be reduced dramatically by using an efficient implementation from popular libraries like Hugging Face.)
The model with the best latency is our vanilla encoder-decoder, with a 99th-percentile latency < 150 ms (and a BLEU score of 10%). Given the scale of this problem and the strict latency constraints, this model would logically be preferred for production. However, we will go ahead with the model that scores better on our metric, i.e. GPT2.
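Latency figures like these can be measured with a small harness. The sketch below times an arbitrary `predict` callable and reports the mean and 99th-percentile latency with `statistics.quantiles`. The callable is hypothetical: plug in any of the three models. `statistics.quantiles` needs at least two queries. Warm-up calls and batching are left out for brevity.

```python
import statistics
import time


def latency_profile(predict, queries):
    """Time `predict` on each query; return (mean_ms, p99_ms)."""
    timings = []
    for q in queries:
        t0 = time.perf_counter()
        predict(q)
        timings.append((time.perf_counter() - t0) * 1000.0)
    # method="inclusive" interpolates within the observed timings.
    p99 = statistics.quantiles(timings, n=100, method="inclusive")[98]
    return statistics.mean(timings), p99
```

Reporting the 99th percentile alongside the mean matters here because a keystroke-level feature is judged by its slowest responses, not its average one.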
Now you must be wondering: why is a BLEU score of just 18% enough?
Short answer: it's not, and I'll discuss how to improve it in a while. But it is good enough for our task. This is not exactly a translation task but a sentence-completion problem: we want to predict words the user is likely to type, not necessarily the exact same words. A very high BLEU score could indicate that our model had simply "memorized" the Enron data. Instead, we want it to make general, grammatically correct predictions that carry some context information. For example, look at this prediction:
In my opinion, even though the BLEU score for this prediction is 0, the output generated by the model is appropriate here (and so would be predictions like 'alright', 'fine' etc.). Also, as discussed above, the brevity penalty penalizes shorter outputs. All this leads to a lower overall BLEU score.
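To see how much the brevity penalty costs a short completion, here is a worked example with a hypothetical 3-word prediction against a 6-word reference:

```latex
\mathrm{BP} =
\begin{cases}
1 & \text{if } c > r \\
e^{\,1 - r/c} & \text{if } c \le r
\end{cases}
\qquad
c = 3,\; r = 6 \;\Rightarrow\; \mathrm{BP} = e^{\,1 - 6/3} = e^{-1} \approx 0.368
```

So even a prediction whose words all match the reference loses almost two thirds of its score purely for being half as long.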
Here are a few more test-sample predictions from GPT2:
Not satisfied yet? Have a look at the notebook (SOTA_modelling.ipynb) in this repository for sample predictions from GPT2 and the other models.
Q. What didn’t work in modelling?
- Subword-based embeddings with LSTMs: I tried using Byte Pair embeddings to tokenize the input for the LSTM-based encoder-decoder. I experimented with various vocabulary sizes, embedding dimensions and preprocessing strategies, and tuned the model itself, but it just didn't work: the predictions were of lower quality. Even while researching online, I didn't find many people using subword tokenization with LSTMs, so I decided to drop this altogether.
- T5 Base (220M parameters) showed the worst results: I had high hopes for T5, since it's the largest model we trained, yet it got the worst BLEU score of just 7% (even lower than our encoder-decoder's 10%)!
Large LMs trained on lots of data generally show impressive results, but this one took a nose-dive. A possible reason is that our problem, sentence completion, is a causal language modelling problem. It is therefore better to choose a transformer that has causal LM as its pretraining objective rather than masked language modelling. GPT, on the other hand, is pre-trained on a causal LM objective, i.e. next-word prediction (and fine-tuned on it as well), which may explain why it performs so well.
- None of the models worked well for very long sentences or very rare words. People generally don't write very long emails and mostly use common words, so the models weren't able to generalize to those cases (another reason for the strict constraints in preprocessing).
DEPLOYMENT & PRODUCTIONIZATION
We chose GPT2 because it had the best language generation capabilities for our task. If you're more concerned about latency, either use Hugging Face's implementation of GPT2 (or the even smaller DistilGPT2) or use the RNN-based encoder-decoder we tried earlier. I used Streamlit to create the web app but wasn't able to deploy it to the cloud (e.g. Heroku), since the deployment size was > 2.5GB, and the providers that did allow it (AWS/GCP) required a card. So I decided to run it locally and recorded a video to show how our model performs in real time…
We take the raw text from the input box, apply all the cleaning, transformation and preprocessing steps, and feed it to the model. Then we show the prediction below. (If you want to look at the code, refer to App.py.)
Note: the predictions are slow because the model is running locally on a CPU and HDD, and (obviously) our model is slow too :(
Yes! we are done. But
Possible next steps…
Our case study is done here. But if you’re building something like this as a product/service in the real world, the actual work will start from here.
No stakeholder or SME will ever care about the model metrics; what they want is actual value generated by our data science projects. So we will need to run a lot of A/B tests (again, a broad topic in itself; here's a starter) to show improvements in the real-world business metrics.
Other important steps would be retraining, monitoring, further optimization etc.
- As I mentioned earlier, our task is a causal language modelling one, whereas the largest model we tried had masked language modelling as its pre-training objective. A good next step would be something like the XLNet transformer. Its permutation language modelling objective is autoregressive yet captures bidirectional context, aiming to get the best of both worlds.
- We don't have access to GPT-3, but we know that autoregressive decoder-only language models trained on next-word prediction are very good for our problem. So if you have access to more resources, you can do one of two things. You can use a larger version of GPT2 (Medium, 355M, or Large, 774M). Or, if you have TPUs at your disposal, you can go with EleutherAI's GPT-Neo, whose performance is close to that of similarly sized GPT-3 models.
- Fine-tune language models that were pre-trained on the Pile dataset. Why the Pile? Because it already includes the Enron email data. Models pre-trained on it might perform better, even in a zero-shot setting! (GPT-Neo, mentioned above, is pre-trained on the Pile.)
- All the ideas above use the transfer learning approach. If you would rather build your own custom architecture (although that's not recommended), you can add more layers to our RNN encoder-decoder or make it bidirectional and see if performance improves. You are also encouraged to change the preprocessing and cleaning steps to produce higher-quality text from the Enron corpus.
This is the second machine learning (or, more specifically, deep learning) problem I have taken on and solved end to end (the first one, if you're interested). I really enjoyed it, as it challenged me at every step and let me experiment as much as I could, though at times I felt limited by a lack of resources. I hope you learned something from this case study that will help you in your work or research.
Note: I tried to summarize as much of my project as I could in this blog, but more inferences on models and errors, plus all the code, are available in this repository. If you have any concept- or code-related questions about this project, feel free to comment below or DM me on LinkedIn.
If you’re someone who likes solving real world problems end-to-end, or enjoys understanding concepts from ground-up, hit that green FOLLOW button as hard as you can ! so that my content can reach you🙂
- Google’s Paper
- Google AI Blog
- Prithvi Da’s inspiring first cut approach
- T5 paper
- Hands down best explanation on GPT2