In my last few articles, I've looked at several ways one can apply machine learning, both supervised and unsupervised. This time, I want to bring your attention to a surprisingly simple—but powerful and widespread—use of machine learning, namely document classification.

You almost certainly have seen this technique used in day-to-day life. Actually, you might not have seen it in action, but you certainly have benefited from it, in the form of an email spam filter. You might remember that back in the earliest days of spam filters, you needed to "train" your email program, so that it would know what your real email looked like. Well, that was a machine-learning model in action, being told what "good" documents looked like, as opposed to "bad" documents. Of course, spam filters are far more sophisticated than that nowadays, but as you'll see over the course of this article, there are logical reasons why spammers include innocent-seeming (and irrelevant to their business) words in the text of their spam.

Text classification is a problem many businesses and organizations have to deal with. Whether it's classifying legal documents, medical records or tweets, machine learning can help you look through lots of text, separating it into different groups.

Now, text classification requires a bit more sophistication than working with purely numeric data. In particular, it requires that you spend some time collecting and organizing data into a format that a model can handle. Fortunately, Python's scikit-learn comes with a number of tools that can get you there fairly easily.

Organizing the Data

Many cases of text classification are supervised learning problems; that is, you train the model by giving it inputs (for example, text documents) along with the "right" output for each input (for example, categories). In scikit-learn, the general template for supervised learning is:

model = CLASS()
model.fit(X, y)
model.predict(new_data_X)

CLASS is one of the 30 or so Python classes that come with scikit-learn, each of which implements a different type of "estimator"—a machine-learning algorithm. Some estimators work best with supervised classification problems, some work with supervised regression problems, and still others work with clustering (that is, unsupervised classification) problems. You often will be able to choose from among several different estimators, but the general format remains the same.

Once you have created an instance of your estimator, you then have to train it. That's done using the "fit" method, to which you give X (the inputs, as a two-dimensional NumPy array or a Pandas data frame) and y (a one-dimensional NumPy array or a Pandas series). Once the model is trained, you then can invoke its "predict" method, passing it new_data_X, another two-dimensional NumPy array or Pandas data frame. The result is a NumPy array, listing the (numeric) categories into which the inputs should be classified.

One of my favorite parts of using scikit-learn is the fact that so much of it uses the same API. You almost always will be using some combination of "fit" and "predict" on your model, no matter what kind of model you're using.
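
To make the shared API concrete, here is a minimal sketch that runs the same fit/predict steps with two different estimators. The tiny numeric data set and the choice of GaussianNB and KNeighborsClassifier are assumptions made purely for illustration; they are not part of the experiment described in this article.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier

# Made-up training data: two well-separated clusters, labeled 0 and 1
X = np.array([[1.0, 1.1], [0.9, 1.0], [5.0, 5.2], [5.1, 4.9]])
y = np.array([0, 0, 1, 1])

# New inputs to classify, in the same two-column layout as X
new_data_X = np.array([[1.0, 0.9], [5.0, 5.0]])

# Two unrelated algorithms, driven through exactly the same method calls
results = {}
for model in (GaussianNB(), KNeighborsClassifier(n_neighbors=1)):
    model.fit(X, y)
    results[type(model).__name__] = model.predict(new_data_X)

for name, predictions in results.items():
    print(name, predictions)
```

Only the line that constructs the estimator changes from one algorithm to the next; the training and prediction calls stay the same.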

As a general rule, machine-learning models require that inputs be numeric. So, you turn category names into numbers, country names into numbers, color names into numbers—basically, everything has to be a number.
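
As a rough illustration of turning names into numbers, the following sketch assigns each distinct category name an integer in order of first appearance. The category names are invented for the example, and this is only one of several possible encodings.

```python
# Hypothetical category names, one per input document
categories = ["ruby", "python", "python", "ruby", "perl"]

# Give each distinct name the next unused integer
name_to_number = {}
for name in categories:
    name_to_number.setdefault(name, len(name_to_number))

# Replace every name with its number
numeric_categories = [name_to_number[name] for name in categories]

print(name_to_number)
print(numeric_categories)
```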

How, then, can you deal with textual data? It's true that bytes are numbers, but that won't really help here; you want to deal with words and sentences, not with individual characters.

The answer is to turn documents into a DTM, a "document term matrix" in which each column represents a word used anywhere across the documents, each row represents one document, and each cell records whether (and how many times) that word appeared in that document.

For example, take the following three sentences:

  • I'm hungry, and need to eat lunch.

  • Call me, and we'll go eat.

  • Do you need to eat?

Let's turn the above into a DTM:

i'm  hungry  and  need  to  eat  lunch  call  me  we'll  go  do  you
1    1       1    1     1   1    1      0     0   0      0   0   0
0    0       1    0     0   1    0      1     1   1      1   0   0
0    0       0    1     1   1    0      0     0   0      0   1   1
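
The table above can be reproduced with a few lines of plain Python. This is a sketch rather than how scikit-learn does it: the tokenize function and its rule (lowercase, then keep runs of letters and apostrophes) are assumptions chosen so the columns match the table.

```python
import re
from collections import Counter

sentences = [
    "I'm hungry, and need to eat lunch.",
    "Call me, and we'll go eat.",
    "Do you need to eat?",
]

def tokenize(text):
    # Simplistic rule (an assumption): lowercase, keep letters and apostrophes
    return re.findall(r"[a-z']+", text.lower())

# Columns: every distinct word, in order of first appearance
vocabulary = []
for sentence in sentences:
    for word in tokenize(sentence):
        if word not in vocabulary:
            vocabulary.append(word)

# Rows: one per sentence, holding the count of each vocabulary word
dtm = []
for sentence in sentences:
    counts = Counter(tokenize(sentence))
    dtm.append([counts[word] for word in vocabulary])

print(vocabulary)
for row in dtm:
    print(row)
```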

Now, this DTM certainly does a good job of summarizing which words appeared in which documents. But with just three short sentences that I constructed to have overlapping vocabulary, it's already starting to get fairly wide. Imagine what would happen if you were to categorize a large number of documents; the DTM would be massive! Moreover, the DTM would mostly consist of zeros.

For this reason, a DTM usually is implemented as a "sparse matrix", listing the coordinates of where the value is non-zero. That tends to crunch down its size and, thus, processing time, quite a lot.
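
The idea behind a sparse matrix can be sketched with an ordinary dictionary that stores only the non-zero cells, keyed by (row, column). This is just an illustration of the principle using the three-sentence DTM; the sparse matrices that scikit-learn hands back come from SciPy and use more compact layouts than a dictionary.

```python
# The dense DTM for the three example sentences (3 rows x 13 columns)
dense_dtm = [
    [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0],
    [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1],
]

# Keep only the non-zero cells, remembering where each one lives
sparse_dtm = {
    (row_index, col_index): value
    for row_index, row in enumerate(dense_dtm)
    for col_index, value in enumerate(row)
    if value != 0
}

total_cells = len(dense_dtm) * len(dense_dtm[0])
print(f"{len(sparse_dtm)} stored values instead of {total_cells} cells")
```

Even in this tiny example, fewer than half of the cells need to be stored, and the savings grow as the vocabulary widens.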

It's this DTM that you'll feed into your model. Because it's numeric, the model can handle it and, thus, can make predictions. Note that you'll actually need to make two different DTMs: one for training the model and another for handing it the text you want to categorize. Both must be built from the same vocabulary, so that their columns line up.

Creating a DTM

I decided to do a short experiment to see if I could create a machine-learning model that knows how to differentiate between Python and Ruby code. Not only do I have a fair amount of such code on my computer, but the languages have similar vocabularies, and I was wondering how accurately a model could actually do some categorization.

So, the first task was to create a Python list of text, with a parallel list of numeric categories. I did this using some list comprehensions, as follows:

from glob import glob

# read Ruby files
ruby_files = [open(filename).read()
              for filename in glob("Programs/*.rb")]

# read Python files
python_files = [open(filename).read()
                for filename in glob("Programs/*.py")]

# all input files
input_text = ruby_files + python_files

# set up categories
input_text_categories = ([0] * len(ruby_files) +
                         [1] * len(python_files))

After this code is run, I have a list (input_text) of strings and another list (input_text_categories) of integers representing the two categories into which these strings should be classified.

Now I have to turn this list of strings into a DTM. Fortunately, scikit-learn comes with a number of "feature extraction" tools to make this easy:

from sklearn.feature_extraction.text import CountVectorizer

cv = CountVectorizer()
cv_dtm = cv.fit_transform(input_text)

CountVectorizer isn't the only way to create a DTM. Indeed, there are different strategies you can use. Among other things, the granularity of one word, rather than multiple words, might not be appropriate for your text.
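
One way to change the granularity is CountVectorizer's ngram_range parameter, which adds multi-word sequences to the vocabulary. The sketch below compares the default (single words) with single words plus adjacent pairs; the two sample documents are invented for the example.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Made-up documents, just to show how the vocabulary changes
docs = ["need to eat lunch", "need to call home"]

# Default behavior: each column is a single word
unigrams = CountVectorizer()
unigrams.fit(docs)

# ngram_range=(1, 2): single words plus pairs of adjacent words
bigrams = CountVectorizer(ngram_range=(1, 2))
bigrams.fit(docs)

print(sorted(unigrams.vocabulary_))
print(sorted(bigrams.vocabulary_))
```

The second vocabulary is nearly twice as large, which hints at the trade-off: richer features, but a wider DTM.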

Notice that I use cv.fit_transform. This both teaches the vectorizer the vocabulary ("fit") and produces a DTM. I can create new DTMs with this same vocabulary using just "transform"—and I will indeed do this in a little bit, when I want to make a prediction or two.
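
The difference between the two calls is easiest to see on a tiny example. In this sketch (the documents are made up), the vocabulary is learned once from the training text; a word that appears only in the new text is simply ignored by "transform", and the new DTM has the same columns as the training one.

```python
from sklearn.feature_extraction.text import CountVectorizer

train_docs = ["eat lunch now", "call me now"]

vectorizer = CountVectorizer()

# fit_transform: learn the vocabulary AND build the training DTM
train_dtm = vectorizer.fit_transform(train_docs)

# transform: reuse the learned vocabulary; "dinner" was never seen, so it is dropped
new_dtm = vectorizer.transform(["eat dinner now"])

print(sorted(vectorizer.vocabulary_))
print(train_dtm.shape, new_dtm.shape)
print(new_dtm.toarray())
```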

Creating a Model

Now I have my inputs in a format that can be used to create a model! You potentially can use a number of algorithms, but one of the most common (and surprisingly accurate) is Naive Bayes. Scikit-learn actually comes with several different versions of Naive Bayes. The one that I use here is called MultinomialNB; it works well with this sort of textual data. (But, of course, it's generally a good idea to test your models and even tweak the inputs and parameters to squeeze better results out of them.) Here's how I create and then train my model:

from sklearn.naive_bayes import MultinomialNB
nb = MultinomialNB()
nb.fit(cv_dtm, input_text_categories)

Notice that I've used "fit" twice now: once (on CountVectorizer) to train and create a DTM from the input text and then (on MultinomialNB) to train the model based on that DTM.
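
To give a feel for what a multinomial Naive Bayes model does with word counts, here is a deliberately simplified sketch in plain Python. It is not scikit-learn's implementation: the train and predict functions, the whitespace tokenization, the add-one smoothing and the toy spam/ham messages are all assumptions made for illustration.

```python
import math
from collections import Counter

def train(docs, labels):
    """Count words per category, documents per category, and the vocabulary."""
    word_counts = {}
    class_counts = Counter(labels)
    vocabulary = set()
    for doc, label in zip(docs, labels):
        words = doc.lower().split()
        word_counts.setdefault(label, Counter()).update(words)
        vocabulary.update(words)
    return word_counts, class_counts, vocabulary

def predict(doc, word_counts, class_counts, vocabulary):
    """Return the category with the highest log-probability score."""
    total_docs = sum(class_counts.values())
    scores = {}
    for label, counts in word_counts.items():
        total_words = sum(counts.values())
        # Start from how common the category is overall
        score = math.log(class_counts[label] / total_docs)
        for word in doc.lower().split():
            if word in vocabulary:
                # Add-one smoothing keeps unseen words from zeroing the score
                score += math.log((counts[word] + 1) /
                                  (total_words + len(vocabulary)))
        scores[label] = score
    return max(scores, key=scores.get)

docs = ["buy cheap pills now", "cheap pills cheap",
        "lunch meeting tomorrow", "see you at lunch"]
labels = ["spam", "spam", "ham", "ham"]
model = train(docs, labels)

plain_spam = predict("cheap pills", *model)
padded_spam = predict("cheap pills lunch meeting tomorrow see you", *model)
print(plain_spam, padded_spam)
```

Because every word contributes to the score, padding a message with words that are common in legitimate mail can tip the result the other way, which is the reason spammers include innocent-seeming text.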

The model is now all set! Now I can make some predictions. I'll create some new documents:

docs_new = ['class Foo(object):\nprint "Hello, {}".format(\n',
            'x = [10, 20, 30]\n',
            '10.times do {|i| puts i}']

The docs_new variable contains three strings: the first is in Python, the second could be either Ruby or Python, and the third is in Ruby.

To see how the model categorizes them, I'll first need to create a DTM from these documents. Note that I'm going to reuse cv, the CountVectorizer object. However, I'm not going to use the "fit" method to train it with a new vocabulary. Rather, I'm going to use "transform" to use the existing vocabulary with the new documents. This will allow the model to compare the documents with the previous ones:

docs_new_dtm = cv.transform(docs_new)

Now to make a prediction:

nb.predict(docs_new_dtm)

The output is:

array([1, 1, 0])

In other words, the first two documents are seen as Python, and the third is seen as Ruby—not bad, for such a small training set. As you can imagine, the more documents with which you train, the more accurate your categorization is likely to be.
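
For readers who want to try the whole flow without a directory of source files, here is a self-contained sketch. It differs from the code above in two ways, both of which are my assumptions rather than part of the original experiment: the training text is a handful of in-memory snippets, and the vectorizer and model are chained with scikit-learn's make_pipeline so that a single "fit" and "predict" handle both steps.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Toy stand-ins for the contents of real source files (assumed data)
ruby_snippets = ["def hello\n  puts 'hi'\nend",
                 "[1, 2, 3].each do |n|\n  puts n\nend"]
python_snippets = ["def hello():\n    print('hi')",
                   "for n in [1, 2, 3]:\n    print(n)"]

input_text = ruby_snippets + python_snippets
input_text_categories = ([0] * len(ruby_snippets) +
                         [1] * len(python_snippets))

# The pipeline builds the DTM and passes it to the classifier
classifier = make_pipeline(CountVectorizer(), MultinomialNB())
classifier.fit(input_text, input_text_categories)

# New text is vectorized with the vocabulary learned during fit
predictions = classifier.predict(["puts 'x'", "print(x)"])
print(predictions)
```

With so little training text, the result hinges on a single word in each new snippet, so treat this as a demonstration of the mechanics rather than of accuracy.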

I tried a slight variation on the above code with the "20 newsgroups" data set, using 20,000 postings from 20 different Usenet newsgroups. After using CountVectorizer and MultinomialNB just as I did here, the model was able to predict, with a surprisingly high degree of accuracy, the most appropriate newsgroup for a variety of sentences and paragraphs.

Of course, as with everything statistical—including machine learning—the success rate never will be 100%. And indeed, you can (and probably will want to) update the model, tuning the inputs and the model's hyperparameters to try to improve it even more.


Conclusion

Document categorization is a practical application of machine learning that a large number of organizations use, not just in spam filters, but also for sorting through large volumes of text. As you can see, setting up such a model isn't especially difficult, and scikit-learn provides a large number of vectorizers, feature extraction tools and estimators that you can use to create one.


Resources

I used Python and the many parts of the SciPy stack (NumPy, SciPy, Pandas, Matplotlib and scikit-learn) in this article. All are available from PyPI or from

Reuven Lerner teaches Python, data science and Git to companies around the world. You can subscribe to his free, weekly "better developers" e-mail list, and learn from his books and courses at Reuven lives with his wife and children in Modi'in, Israel.
