Building a Naive Bayes spam classifier with NLTK

We'll follow the same logic as the program from chapter 3 of Machine Learning for Hackers, but with a workflow better suited to NLTK's functions. Instead of creating a term-document matrix and building our own Naive Bayes classifier, we'll build a features $\rightarrow$ label association for each training e-mail and feed a list of these to NLTK's NaiveBayesClassifier.

Some good references for this are:

Bird, Steven, et al., Natural Language Processing with Python

Perkins, Jacob, Python Text Processing with NLTK 2.0 Cookbook

In [1]:
import os
import re
from nltk import NaiveBayesClassifier
import nltk.classify
from nltk.tokenize import wordpunct_tokenize
from nltk.corpus import stopwords
from collections import defaultdict

Loading the e-mail messages into lists

E-mails of each type (spam, "easy" ham, and "hard" ham) are split across two directories per type. We'll use the first directories of spam and "easy" ham to train the classifier, then test it on the e-mails in the second directories.

In [2]:
data_path = os.path.abspath(os.path.join('.', 'data'))
spam_path = os.path.join(data_path, 'spam')
spam2_path = os.path.join(data_path, 'spam_2') 
easyham_path = os.path.join(data_path, 'easy_ham')
easyham2_path = os.path.join(data_path, 'easy_ham_2')
hardham_path = os.path.join(data_path, 'hard_ham')
hardham2_path = os.path.join(data_path, 'hard_ham_2')

The following functions load all the e-mail files in a directory, extract their message bodies, and return them in a list.

In [3]:
import io

def get_msgdir(path):
    '''
    Read all messages from files in a directory into
    a list where each item is the text of a message.
    Simply gets a list of e-mail files in a directory,
    and iterates get_msg() over them.

    Returns a list of strings.
    '''
    # Skip the 'cmds' file, which is not an e-mail.
    filelist = [f for f in os.listdir(path) if f != 'cmds']
    all_msgs = [get_msg(os.path.join(path, f)) for f in filelist]
    return all_msgs

def get_msg(path):
    '''
    Read in the 'message' portion of an e-mail, given
    its file path. The 'message' text begins after the first
    blank line; above is header information.

    Returns a string.
    '''
    # io.open works in Python 2 and 3 (text mode uses universal newlines).
    # latin-1 decodes any byte, so stray non-UTF-8 characters don't raise.
    with io.open(path, encoding='latin-1') as con:
        msg = con.readlines()
        first_blank_index = msg.index('\n')
        msg = msg[(first_blank_index + 1):]
        return ''.join(msg)
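As a quick check of the header/body rule, here is a minimal sketch on a made-up message held in a string (the message and the helper name body_after_first_blank are hypothetical). It applies the same first-blank-line split as get_msg() and compares it with Python's standard email module, which uses the same rule for a simple, non-multipart message.

```python
import email

# A made-up e-mail: headers, one blank line, then the body.
RAW = ("From: someone@example.com\n"
       "Subject: hello\n"
       "\n"
       "First body line.\n"
       "Second body line.\n")

def body_after_first_blank(raw):
    # Same rule as get_msg(): everything after the first blank line is the body.
    lines = raw.splitlines(True)  # keep the '\n' endings, like readlines()
    first_blank_index = lines.index('\n')
    return ''.join(lines[first_blank_index + 1:])

# The standard-library parser splits headers from body at the same place.
parsed = email.message_from_string(RAW)

print(repr(body_after_first_blank(RAW)))  # 'First body line.\nSecond body line.\n'
print(repr(parsed.get_payload()))         # 'First body line.\nSecond body line.\n'
```

Neither version decodes the body: transfer encodings such as quoted-printable are left in place, which is where artefacts like the ones discussed below come from.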

We'll use the functions to make training and testing message lists for each type of e-mail.

In [4]:
train_spam_messages    = get_msgdir(spam_path)
train_easyham_messages = get_msgdir(easyham_path)
# Only keep the first 500 to balance w/ number of spam messages.
train_easyham_messages = train_easyham_messages[:500]
train_hardham_messages = get_msgdir(hardham_path)

test_spam_messages    = get_msgdir(spam2_path)
test_easyham_messages = get_msgdir(easyham2_path)
test_hardham_messages = get_msgdir(hardham2_path)

Extracting word features from the e-mail messages

Each e-mail in our classifier's training data will have a label ("spam" or "ham") and a feature set. For this application, the feature set is simply the set of unique words in the e-mail. Below, we'll turn this into a dictionary to feed into the NaiveBayesClassifier, but first, let's get the set.

Parsing and tokenizing the e-mails

We're going to use NLTK's wordpunct_tokenize function to break the message into tokens. This splits tokens at white space and (most) punctuation marks, and returns the punctuation along with the tokens on each side. So "I don't know. Do you?" becomes ["I", "don", "'", "t", "know", ".", "Do", "you", "?"].
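NLTK defines wordpunct_tokenize through its WordPunctTokenizer, which splits text with the regular expression \w+|[^\w\s]+ (runs of word characters, or runs of non-space punctuation). Assuming that pattern, the behaviour can be sketched with the standard re module alone; wordpunct_sketch is a hypothetical name, not an NLTK function.

```python
import re

# Assumption: mirrors the pattern NLTK's WordPunctTokenizer uses.
WORDPUNCT = re.compile(r"\w+|[^\w\s]+")

def wordpunct_sketch(text):
    # Return word runs and punctuation runs, in order, dropping whitespace.
    return WORDPUNCT.findall(text)

print(wordpunct_sketch("I don't know. Do you?"))
# ['I', 'don', "'", 't', 'know', '.', 'Do', 'you', '?']
print(wordpunct_sketch("first_name"))
# ['first_name']
```

Because the underscore counts as a word character (\w), a token like first_name stays whole, which matters for the underscore handling in get_msg_words().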

If you look through some of the training e-mails in train_spam_messages and train_easyham_messages, you'll notice a few features that make extracting words tricky.

First, there are a couple of odd text artefacts. The string '3D' shows up in strange places, in HTML attributes and elsewhere, and we'll remove it. Furthermore, some words are broken across lines by a mid-word line wrap flagged with an '='. For example, the word 'apple' might be split across lines as 'app=\nle'. We want to strip these out so we can recover 'apple'. We'll deal with both artefacts before we apply the tokenizer.
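Both artefacts come from the quoted-printable transfer encoding used by many of these e-mails: '=3D' is the encoded form of '=', and an '=' at the end of a line is a "soft" line break. The sketch below, on a made-up HTML fragment, applies the two clean-up steps this notebook uses and, for comparison, Python's quopri module, a standard-library quoted-printable decoder that is not part of the original workflow.

```python
import quopri
import re

# Made-up fragment with both artefacts: '=3D' inside an HTML attribute,
# and a soft line break ('=' + newline) in the middle of 'apple'.
raw = '<td align=3D"center">app=\nle</td>'

# The notebook's approach: delete '3D', then delete '=\n'.
cleaned = re.sub('3D', '', raw).replace('=\n', '')
print(cleaned)  # <td align="center">apple</td>

# A real quoted-printable decoder gives the same result on this fragment.
decoded = quopri.decodestring(raw.encode('ascii')).decode('ascii')
print(decoded)  # <td align="center">apple</td>
```

The regex shortcut is cruder: it would also delete a literal '3D' in ordinary text (as in '3D glasses'), whereas quopri only touches '=' escape sequences.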

Second, there's a lot of HTML in the messages. We'll have to decide first whether we want to keep HTML info in our set of words. If we do, we'll apply wordpunct_tokenize to some HTML, for example:

"<HEAD></HEAD><BODY><!-- Comment -->"

and it will tokenize to:

["<", "HEAD", "></", "HEAD", "><", "BODY", "><!--", "Comment", "-->"]

So if we drop the punctuation tokens and take the unique set of what remains, we get {"HEAD", "BODY", "Comment"}, which seems like what we'd want. It's nice, for example, that this method doesn't make <HEAD> and </HEAD> separate words in our set, but just captures the existence of the tag with the term "HEAD". It might be a problem that we won't distinguish between the HTML tag <HEAD> and "head" used as an English word in the message. But for the moment I'm willing to bet that this sort of conflation won't have a big effect on the classifier.
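The tokenize, filter, and deduplicate steps can be sketched on the HTML snippet above, again assuming NLTK's \w+|[^\w\s]+ pattern for wordpunct_tokenize. The letter filter mirrors the one get_msg_words() applies, and the lowercasing (which get_msg_words() also does) is exactly what merges the tag HEAD with the English word "head".

```python
import re

WORDPUNCT = re.compile(r"\w+|[^\w\s]+")  # assumed equivalent of wordpunct_tokenize

html = "<HEAD></HEAD><BODY><!-- Comment -->"
tokens = WORDPUNCT.findall(html)
print(tokens)
# ['<', 'HEAD', '></', 'HEAD', '><', 'BODY', '><!--', 'Comment', '-->']

# Keep tokens containing at least one letter, lowercase them, deduplicate.
words = {t.lower() for t in tokens if re.search('[a-zA-Z]', t)}
print(sorted(words))  # ['body', 'comment', 'head']
```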

If we don't want to count HTML information in our set of words, we can set the strip_html argument to True, and we'll take all the HTML tags (along with character codes like &nbsp;) out before tokenizing.
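With strip_html=True, get_msg_words() applies two substitutions before tokenizing: one for tags and one for character codes. A sketch of their effect on a made-up snippet:

```python
import re

msg = 'Click <a href="http://example.com">here</a>&nbsp;now &lt;3'

# The two substitutions get_msg_words() uses when strip_html=True.
no_tags = re.sub(r'<(.|\n)*?>', ' ', msg)    # tags, including their attributes
no_entities = re.sub(r'&\w+;', ' ', no_tags)  # character codes like &nbsp; and &lt;
print(no_entities.split())  # ['Click', 'here', 'now', '3']
```

The (.|\n) alternative lets a tag that wraps across lines still match, and attribute text such as href disappears along with its tag.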

Lastly, we'll strip out any "stopwords" from the set. Stopwords are highly common, and therefore low-information, words like "a", "the", "he", etc. Below I'll use the English stopwords list from NLTK's corpus library, with a couple of additions to handle contractions. (In other programs I've used the stopwords exported from R's tm package.)

Note that because our tokenizer splits contractions ("she'll" $\rightarrow$ "she", "'", "ll"), we'd like to drop the ends ("ll"). Some of these are already in NLTK's stopwords list; others we'll add manually. It's an imperfect but easy solution. There are more sophisticated ways of dealing with this, but they would be overkill for our purposes.
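A sketch of how the contraction fragments fall out, using the assumed wordpunct pattern and a tiny hypothetical stopword set in place of NLTK's list. Single letters like 't' would also be removed by the length filter in get_msg_words(), but two-letter fragments such as 'll' and 've' only disappear if they are stopwords.

```python
import re

WORDPUNCT = re.compile(r"\w+|[^\w\s]+")  # assumed equivalent of wordpunct_tokenize

# Hypothetical stand-in for NLTK's English list plus the 'll' and 've' additions.
stop = {'she', 'we', 'it', 'don', 's', 't', 'll', 've'}

text = "She'll say we've tried it, don't worry"
tokens = set(WORDPUNCT.findall(text.lower()))
words = tokens.difference(stop)
# Same final filter as get_msg_words(): needs a letter and length > 1.
words = sorted(w for w in words if re.search('[a-zA-Z]', w) and len(w) > 1)
print(words)  # ['say', 'tried', 'worry']
```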

Tokenizing, as perhaps you can tell, is a non-trivial operation. NLTK has a host of other tokenizing functions of varying sophistication, and even lets you define your own tokenizing rule using regex.
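As one illustration of a custom rule, here is a hypothetical pattern that keeps contractions whole by allowing an apostrophe between word characters. It's shown with the standard re module; in NLTK, the same pattern string could be passed to nltk.tokenize.RegexpTokenizer.

```python
import re

# Hypothetical rule: a word may contain internal apostrophes ("don't"),
# otherwise split into word runs and punctuation runs as before.
CONTRACTION_AWARE = re.compile(r"\w+(?:'\w+)*|[^\w\s]+")

print(CONTRACTION_AWARE.findall("She'll say we've tried, don't worry."))
# ["She'll", 'say', "we've", 'tried', ',', "don't", 'worry', '.']
```

With a rule like this, the 'll'/'ve' stopword workaround would no longer be needed, though whole contractions such as "she'll" would then need their own stopword entries.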

In [5]:
def get_msg_words(msg, stopwords=(), strip_html=False):
    '''
    Returns a list of the unique words contained in an e-mail message.
    Excludes any that are in an optionally-provided list.

    NLTK's 'wordpunct' tokenizer is used, and this will break contractions.
    For example, don't -> (don, ', t). Therefore, it's advisable to supply
    a stopwords list that includes contraction parts, like 'don' and 't'.
    '''
    # Strip out weird '3D' artefacts.
    msg = re.sub('3D', '', msg)
    # Strip out html tags and attributes and html character codes,
    # like &nbsp; and &lt;.
    if strip_html:
        msg = re.sub(r'<(.|\n)*?>', ' ', msg)
        msg = re.sub(r'&\w+;', ' ', msg)
    # wordpunct_tokenize doesn't split on underscores. We don't
    # want to strip them, since the token first_name may be more informative
    # than 'first' and 'name' apart. But there are tokens with long
    # underscore strings (e.g. 'name_________'). We'll just replace the
    # multiple underscores with a single one, since 'name_____' is probably
    # not distinct from 'name___' or 'name_' in identifying spam.
    msg = re.sub('_+', '_', msg)

    # Remove '=' line-wrap markers ('=' followed by a newline) before
    # tokenizing, since these sometimes occur within words.
    msg_words = set(wordpunct_tokenize(msg.replace('=\n', '').lower()))
    # Get rid of stopwords
    msg_words = msg_words.difference(stopwords)
    # Get rid of punctuation tokens, numbers, and single letters.
    msg_words = [w for w in msg_words if re.search('[a-zA-Z]', w) and len(w) > 1]
    return msg_words

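The stopwords list. NLTK's English list already contains some contraction fragments (such as 'don' and 't'), and we'll add a couple more. Depending on your NLTK version, these may already be present; the duplicates are harmless.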

In [6]:
# Requires the NLTK stopwords corpus (download once with nltk.download('stopwords')).
sw = stopwords.words('english')
sw.extend(['ll', 've'])

Making a (features, label) list

The NaiveBayesClassifier.train method expects training data of the form [(features_1, label_1), (features_2, label_2), ..., (features_N, label_N)], where features_i is a dictionary of features for e-mail i and label_i is that e-mail's label ("spam" or "ham").

The function features_from_messages iterates through the messages to create this list, calling an outside function to extract the features from each e-mail. This keeps the function modular, in case we decide to try some other method of extracting features besides the set of words. It then pairs the features with the e-mail's label in a tuple and appends the tuple to the list.

The word_indicator function calls get_msg_words() to get an e-mail's unique words, then creates a dictionary with an entry {word: True} for each of them. This is a little counter-intuitive, since there are no {word: False} entries for words the e-mail doesn't contain, but NaiveBayesClassifier knows how to handle it: during training, it treats a feature missing from an e-mail's dictionary as having the value None.
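A toy illustration of the resulting data, with made-up word sets standing in for get_msg_words() output. It builds {word: True} dictionaries, pairs them with labels in the [(features, label), ...] shape, and tallies how many e-mails of each label contain (or lack) each word, which is the kind of count a presence-based Naive Bayes model is estimated from. This is a simplified sketch, not NLTK's training code; indicator_features is a hypothetical helper.

```python
from collections import Counter

# Made-up word sets standing in for get_msg_words() output.
spam_word_sets = [{'free', 'money'}, {'free', 'offer'}]
ham_word_sets = [{'meeting', 'monday'}, {'free', 'lunch', 'monday'}]

def indicator_features(words):
    # {word: True} for each word present; absent words are simply not keys.
    return dict((w, True) for w in words)

train_set = ([(indicator_features(ws), 'spam') for ws in spam_word_sets] +
             [(indicator_features(ws), 'ham') for ws in ham_word_sets])

# E-mails per label, and per-label counts of e-mails containing each word.
n_emails = Counter(label for _, label in train_set)
doc_freq = {'spam': Counter(), 'ham': Counter()}
for features, label in train_set:
    doc_freq[label].update(features.keys())

for word in ['free', 'monday']:
    for label in ['spam', 'ham']:
        present = doc_freq[label][word]
        # The "absent" e-mails are the ones NLTK's trainer records as None.
        print(word, label, 'present:', present, 'absent:', n_emails[label] - present)
```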

In [7]:
def features_from_messages(messages, label, feature_extractor, **kwargs):
    '''
    Make a (features, label) tuple for each message in a list of e-mails
    with a single label ('spam' or 'ham'), and return a list of these tuples.

    Note every e-mail in 'messages' should have the same label.
    '''
    features_labels = []
    for msg in messages:
        features = feature_extractor(msg, **kwargs)
        features_labels.append((features, label))
    return features_labels

def word_indicator(msg, **kwargs):
    '''
    Create a dictionary of entries {word: True} for every unique
    word in a message.

    Note **kwargs are options to the word-set creator, get_msg_words().
    '''
    msg_words = get_msg_words(msg, **kwargs)
    # A plain dict is enough: words not in the message are simply absent.
    features = dict((w, True) for w in msg_words)
    return features

Training and evaluating the classifier

The following is just a helper function to make training and testing data from the messages. Notice we combine the training spam and training ham into a single set, since we need to train our classifier on data with both spam and ham in it.

In [8]:
def make_train_test_sets(feature_extractor, **kwargs):
    '''
    Make (feature, label) lists for each of the training
    and testing lists.
    '''
    train_spam = features_from_messages(train_spam_messages, 'spam',
                                        feature_extractor, **kwargs)
    train_ham = features_from_messages(train_easyham_messages, 'ham',
                                       feature_extractor, **kwargs)
    train_set = train_spam + train_ham

    test_spam = features_from_messages(test_spam_messages, 'spam',
                                       feature_extractor, **kwargs)

    test_ham = features_from_messages(test_easyham_messages, 'ham',
                                      feature_extractor, **kwargs)

    test_hardham = features_from_messages(test_hardham_messages, 'ham',
                                          feature_extractor, **kwargs)
    return train_set, test_spam, test_ham, test_hardham

Finally we make a function to run the classifier and check its accuracy on test data. After training the classifier, we check how accurately it classifies data in new spam, "easy" ham, and "hard" ham datasets.
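nltk.classify.accuracy(classifier, gold) reports the fraction of (features, label) pairs in gold whose predicted label matches the given one. A minimal stand-alone sketch of that computation, with a hypothetical toy classifier in place of the trained one (accuracy_sketch and toy_classify are illustrative names, not NLTK's source):

```python
def accuracy_sketch(classify, gold):
    # gold: list of (features, label) pairs; classify: features dict -> label.
    correct = sum(1 for features, label in gold if classify(features) == label)
    return correct / float(len(gold))

def toy_classify(features):
    # Hypothetical stand-in: call anything containing 'free' spam.
    return 'spam' if features.get('free') else 'ham'

test_ham = [({'meeting': True}, 'ham'), ({'free': True, 'lunch': True}, 'ham')]
print('Test Ham accuracy: {0:.2f}%'.format(100 * accuracy_sketch(toy_classify, test_ham)))
# Test Ham accuracy: 50.00%
```

Because each test set here contains only one label, its accuracy is simply the share of that type labeled correctly; a hard ham accuracy of, say, 14% means roughly 86% of those legitimate e-mails were flagged as spam.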

The function then prints out the results of NaiveBayesClassifier's handy show_most_informative_features method. This shows which features are most strongly associated with one label over the other, ranked by the ratio of the feature's probability under each label. For example, if "viagra" shows up in 500 of the spam e-mails but only 2 of the ham e-mails in the training set (and the two training sets are about the same size), then the method will show "viagra" as one of the most informative features, with a spam:ham ratio of roughly 250:1.
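The ratio compares how often a feature appears within each label, not raw counts, so class sizes matter. A back-of-the-envelope sketch using plain counting (presence_ratio is a hypothetical helper; NLTK's classifier smooths its estimates, by default with ELEProbDist, so the numbers it prints will differ slightly):

```python
from fractions import Fraction

def presence_ratio(spam_with, n_spam, ham_with, n_ham):
    # P(word | spam) / P(word | ham), estimated by simple counting, no smoothing.
    return Fraction(spam_with, n_spam) / Fraction(ham_with, n_ham)

# The example above, assuming 500 spam and 500 ham training e-mails.
print(float(presence_ratio(500, 500, 2, 500)))   # 250.0

# Same counts, but a hypothetical ham set twice as large.
print(float(presence_ratio(500, 500, 2, 1000)))  # 500.0
```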

In [9]:
def check_classifier(feature_extractor, **kwargs):
    '''
    Train the classifier on the training spam and ham, then check its accuracy
    on the test data, and show the classifier's most informative features.
    '''
    # Make training and testing sets of (features, label) data
    train_set, test_spam, test_ham, test_hardham = \
        make_train_test_sets(feature_extractor, **kwargs)
    # Train the classifier on the training set
    classifier = NaiveBayesClassifier.train(train_set)
    # How accurate is the classifier on the test sets?
    print('Test Spam accuracy: {0:.2f}%'
          .format(100 * nltk.classify.accuracy(classifier, test_spam)))
    print('Test Ham accuracy: {0:.2f}%'
          .format(100 * nltk.classify.accuracy(classifier, test_ham)))
    print('Test Hard Ham accuracy: {0:.2f}%'
          .format(100 * nltk.classify.accuracy(classifier, test_hardham)))

    # Show the top 20 informative features. The method prints the table
    # itself and returns None, so it isn't wrapped in print.
    classifier.show_most_informative_features(20)

First, we run the classifier keeping all the HTML information in the feature set. The accuracy at identifying spam and ham is very high. Unsurprisingly, we do a lousy job at identifying hard ham.

This may be because our classifier is relying too much on HTML tags to identify spam. As we can see, all 20 of the most informative features are HTML terms: tags, attributes, font names, and color codes.

In [10]:
check_classifier(word_indicator, stopwords = sw)
Test Spam accuracy: 98.71%
Test Ham accuracy: 97.07%
Test Hard Ham accuracy: 13.71%
Most Informative Features
                   align = True             spam : ham    =    119.7 : 1.0
                      tr = True             spam : ham    =    115.7 : 1.0
                      td = True             spam : ham    =    111.7 : 1.0
                   arial = True             spam : ham    =    107.7 : 1.0
             cellpadding = True             spam : ham    =     97.0 : 1.0
             cellspacing = True             spam : ham    =     94.3 : 1.0
                     img = True             spam : ham    =     80.3 : 1.0
                 bgcolor = True             spam : ham    =     67.4 : 1.0
                    href = True             spam : ham    =     67.0 : 1.0
                    sans = True             spam : ham    =     62.3 : 1.0
                 colspan = True             spam : ham    =     61.0 : 1.0
                    font = True             spam : ham    =     61.0 : 1.0
                  valign = True             spam : ham    =     60.3 : 1.0
                      br = True             spam : ham    =     59.6 : 1.0
                 verdana = True             spam : ham    =     57.7 : 1.0
                    nbsp = True             spam : ham    =     57.4 : 1.0
                   color = True             spam : ham    =     54.4 : 1.0
                  ff0000 = True             spam : ham    =     53.0 : 1.0
                  ffffff = True             spam : ham    =     50.6 : 1.0
                  border = True             spam : ham    =     49.6 : 1.0

If we use just the text of the messages, stripping out the HTML tags and information, we lose about two percentage points of accuracy on spam but do much better with the hard ham (56% versus 14%), and slightly better on easy ham.

In [11]:
check_classifier(word_indicator, stopwords = sw, strip_html = True)
Test Spam accuracy: 96.64%
Test Ham accuracy: 98.64%
Test Hard Ham accuracy: 56.05%
Most Informative Features
                    dear = True             spam : ham    =     41.7 : 1.0
                     aug = True              ham : spam   =     38.3 : 1.0
              guaranteed = True             spam : ham    =     35.0 : 1.0
              assistance = True             spam : ham    =     29.7 : 1.0
                  groups = True              ham : spam   =     27.9 : 1.0
                mailings = True             spam : ham    =     25.0 : 1.0
               sincerely = True             spam : ham    =     23.0 : 1.0
                    fill = True             spam : ham    =     23.0 : 1.0
                mortgage = True             spam : ham    =     21.7 : 1.0
                     sir = True             spam : ham    =     21.0 : 1.0
                 sponsor = True              ham : spam   =     20.3 : 1.0
                 article = True              ham : spam   =     20.3 : 1.0
                  assist = True             spam : ham    =     19.0 : 1.0
                  income = True             spam : ham    =     18.6 : 1.0
                     tue = True              ham : spam   =     18.3 : 1.0
                   mails = True             spam : ham    =     18.3 : 1.0
                     iso = True             spam : ham    =     17.7 : 1.0
                   admin = True              ham : spam   =     17.7 : 1.0
                  monday = True              ham : spam   =     17.7 : 1.0
                    earn = True             spam : ham    =     17.0 : 1.0