Classifying Text

Now I have to turn this list of strings into a DTM (document-term matrix). Fortunately, scikit-learn comes with a number of "feature extraction" tools that make this easy:

from sklearn.feature_extraction.text import CountVectorizer

cv = CountVectorizer()
cv_dtm = cv.fit_transform(input_text)

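CountVectorizer isn't the only way to create a DTM. scikit-learn also provides TfidfVectorizer, which scales down the counts of terms that appear in many documents, and HashingVectorizer, which maps terms to columns with a hash function rather than storing a vocabulary. Among other things, a granularity of one word per feature, rather than sequences of multiple words, might not be appropriate for your text; CountVectorizer's ngram_range parameter lets you count those sequences, too.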
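The granularity question can be sketched with a few CountVectorizer settings. The documents here are invented for illustration; ngram_range and token_pattern are real CountVectorizer parameters, but which setting suits your data is something to test, not assume.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Hypothetical documents, for illustration only
docs = ["puts foo end", "print foo"]

# Default: single words of two or more alphanumeric characters
default_vocab = sorted(CountVectorizer().fit(docs).vocabulary_)
print(default_vocab)
# ['end', 'foo', 'print', 'puts']

# ngram_range=(1, 2) also counts adjacent word pairs (bigrams)
bigram_vocab = sorted(CountVectorizer(ngram_range=(1, 2)).fit(docs).vocabulary_)
print(bigram_vocab)
# ['end', 'foo', 'foo end', 'print', 'print foo', 'puts', 'puts foo']

# Splitting on whitespace keeps punctuation, which can carry signal in source code
code_vocab = sorted(
    CountVectorizer(token_pattern=r"\S+").fit(["10.times {|i| puts i}"]).vocabulary_
)
print(code_vocab)
# ['10.times', 'i}', 'puts', '{|i|']
```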

Notice that I use cv.fit_transform. This both teaches the vectorizer the vocabulary ("fit") and produces a DTM. I can create new DTMs with this same vocabulary using just "transform"—and I will indeed do this in a little bit, when I want to make a prediction or two.
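The difference between fit and transform is easiest to see on a toy vocabulary. In this sketch (the documents are invented), a word that never appeared during fitting is ignored by transform, and the new matrix keeps exactly the same columns as the original one:

```python
from sklearn.feature_extraction.text import CountVectorizer

cv = CountVectorizer()
# fit_transform: learn a 5-term vocabulary AND build the training DTM
train_dtm = cv.fit_transform(["def foo return foo", "puts foo end"])
print(train_dtm.shape)   # (2, 5)

# transform: reuse the learned vocabulary; "lambda" was never seen, so it is dropped
new_dtm = cv.transform(["def lambda foo"])
print(new_dtm.shape)     # (1, 5): same 5 columns as the training DTM
print(new_dtm.toarray()) # [[1 0 1 0 0]]
```

Calling fit_transform on the new documents instead would replace the vocabulary, and the columns would no longer line up with what a model trained on train_dtm expects.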

Creating a Model

Now I have my inputs in a format that can be used to create a model! You can choose from a number of algorithms, but one of the most common (and surprisingly accurate) is Naive Bayes. Scikit-learn actually comes with several different versions of Naive Bayes. The one that I use here is called MultinomialNB; it works well with word counts like those in a DTM. (But, of course, it's generally a good idea to test your models and even tweak the inputs and parameters to squeeze better results out of them.) Here's how I create and then train my model:

from sklearn.naive_bayes import MultinomialNB

nb = MultinomialNB()
nb.fit(cv_dtm, input_text_categories)

Notice that I've used "fit" twice now: once (on CountVectorizer) to learn the vocabulary and create a DTM from the input text, and then (on MultinomialNB) to train the model based on that DTM.
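Because input_text and input_text_categories aren't shown in this section, the following sketch substitutes four invented snippets and assumes the labels 1 = Python and 0 = Ruby. It also recomputes the prediction by hand from the fitted model's feature_log_prob_ and class_log_prior_ attributes, to illustrate what multinomial Naive Bayes does: for each class, add the log prior to the sum of each word's count times the log probability of that word in that class, then pick the class with the highest total.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

# Hypothetical training data; labels assume 1 = Python, 0 = Ruby
input_text = [
    "def foo(): return None",
    "import os\nprint(os.name)",
    "puts 'hi'",
    "def bar\n  puts 1\nend",
]
input_text_categories = [1, 1, 0, 0]

cv = CountVectorizer()
cv_dtm = cv.fit_transform(input_text)  # fit #1: learn vocabulary, build DTM

nb = MultinomialNB()  # default alpha=1.0 smoothing, so unseen words don't zero a class out
nb.fit(cv_dtm, input_text_categories)  # fit #2: learn per-class word statistics

new_dtm = cv.transform(["import sys", "puts 'bye'"])
print(nb.predict(new_dtm))  # [1 0]

# Same decision by hand: log prior + sum of (count * log P(word | class))
scores = new_dtm.toarray() @ nb.feature_log_prob_.T + nb.class_log_prior_
manual = nb.classes_[np.argmax(scores, axis=1)]
print(manual)  # [1 0]
```

Working in log space turns the product of many small probabilities into a sum, which avoids numeric underflow on long documents.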

The model is now all set! Now I can make some predictions. I'll create some new documents:

docs_new = ['class Foo(object):\nprint "Hello, {}".format(\n',
            'x = [10, 20, 30]\n',
            '10.times do {|i| puts i}']

The docs_new variable contains three strings: the first is in Python, the second could be either Ruby or Python, and the third is in Ruby.

To see how the model categorizes them, I'll first need to create a DTM from these documents. Note that I'm going to reuse cv, the CountVectorizer object. However, I'm not going to use the "fit" method to train it with a new vocabulary. Rather, I'm going to use "transform" to apply the existing vocabulary to the new documents. This guarantees that the new DTM has the same columns, in the same order, as the one the model was trained on; words that weren't in the training vocabulary are simply ignored:

docs_new_dtm = cv.transform(docs_new)

Now to make a prediction:

nb.predict(docs_new_dtm)

The output is:

array([1, 1, 0])

In other words, the first two documents are seen as Python, and the third is seen as Ruby—not bad, for such a small training set. As you can imagine, the more documents with which you train, the more accurate your categorization is likely to be.
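Numeric labels have to be mapped back to language names by hand. A possible alternative, sketched here with invented training data, is to train directly on string labels: MultinomialNB accepts them, predict then returns the names, and predict_proba gives one probability per class, with columns in the order of nb.classes_.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

# Hypothetical training data, labeled with strings instead of 0/1
train_docs = [
    "def foo(): return None",
    "import os\nprint(os.name)",
    "puts 'hi'",
    "def bar\n  puts 1\nend",
]
train_labels = ["python", "python", "ruby", "ruby"]

cv = CountVectorizer()
nb = MultinomialNB().fit(cv.fit_transform(train_docs), train_labels)

new_dtm = cv.transform(["import sys", "puts 'bye'"])
print(nb.predict(new_dtm))  # ['python' 'ruby']
print(nb.classes_)          # ['python' 'ruby'], the column order of predict_proba

# Each row holds one probability per class and sums to 1
for doc_probs in nb.predict_proba(new_dtm):
    print(", ".join(f"{label}: {p:.2f}" for label, p in zip(nb.classes_, doc_probs)))
```

The probabilities are useful for flagging low-confidence documents, although Naive Bayes probabilities are known to be poorly calibrated, so treat them as a ranking rather than exact likelihoods.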

I tried a slight variation on the above code with the "20 newsgroups" data set, which contains about 20,000 postings from 20 different Usenet forums. After using CountVectorizer and MultinomialNB just as I did here, the model was able to predict, with a surprisingly high degree of accuracy, the most appropriate newsgroup for a variety of sentences and paragraphs.
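One way to run that kind of experiment is to bundle both steps in a scikit-learn Pipeline and score it on documents the model never saw. This is a sketch, not the article's code: the toy corpus and the helper names build_classifier and evaluate_on_20newsgroups are hypothetical, while make_pipeline, accuracy_score and fetch_20newsgroups are real scikit-learn functions.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline


def build_classifier():
    """Hypothetical helper: CountVectorizer followed by MultinomialNB."""
    return make_pipeline(CountVectorizer(), MultinomialNB())


# Hypothetical toy corpus; real data is needed for a meaningful score
train_docs = [
    "def foo(): return None",
    "import os",
    "print(len(items))",
    "for item in items: print(item)",
    "puts 'hi'",
    "def bar; puts 1; end",
    "items.each do |item| puts item end",
    "require 'json'",
]
train_labels = ["python"] * 4 + ["ruby"] * 4
test_docs = ["import sys", "puts 'bye'"]
test_labels = ["python", "ruby"]

model = build_classifier()
model.fit(train_docs, train_labels)   # fit_transform on the vectorizer, fit on the model
predicted = model.predict(test_docs)  # transform only, then predict
accuracy = accuracy_score(test_labels, predicted)
print(accuracy)  # 1.0 on this toy data


def evaluate_on_20newsgroups():
    """Hypothetical helper: downloads the data set on first use."""
    from sklearn.datasets import fetch_20newsgroups

    train = fetch_20newsgroups(subset="train")
    test = fetch_20newsgroups(subset="test")
    clf = build_classifier().fit(train.data, train.target)
    return accuracy_score(test.target, clf.predict(test.data))
```

The pipeline calls fit_transform on the vectorizer during training and only transform during prediction, so it enforces the fit/transform distinction automatically. Calling evaluate_on_20newsgroups() needs network access the first time; no particular accuracy figure is claimed for it here.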

Of course, as with everything statistical (including machine learning), the success rate will never be 100%. And indeed, you can (and probably will want to) update the model, tuning the inputs and the model's hyperparameters to try to improve it even more.
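Hyperparameter tuning can be automated with GridSearchCV, which cross-validates every combination in a grid. In this minimal sketch the corpus is invented and the grid values are arbitrary; "countvectorizer" and "multinomialnb" are the step names that make_pipeline generates from the lowercased class names.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Hypothetical toy corpus, for illustration only
train_docs = [
    "def foo(): return None",
    "import os",
    "print(len(items))",
    "for item in items: print(item)",
    "puts 'hi'",
    "def bar; puts 1; end",
    "items.each do |item| puts item end",
    "require 'json'",
]
train_labels = ["python"] * 4 + ["ruby"] * 4

pipe = make_pipeline(CountVectorizer(), MultinomialNB())

# Tune the inputs (n-gram size) and the model (smoothing strength) together
param_grid = {
    "countvectorizer__ngram_range": [(1, 1), (1, 2)],
    "multinomialnb__alpha": [0.1, 0.5, 1.0],
}

# cv=2 only because the toy corpus is tiny; use more folds on real data
search = GridSearchCV(pipe, param_grid, cv=2)
search.fit(train_docs, train_labels)

print(search.best_params_)
print(search.best_score_)  # mean cross-validated accuracy of the best combination
```

After fitting, search behaves like the best pipeline retrained on all of the training data, so search.predict can be used directly.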


Document categorization is a practical application of machine learning that a large number of organizations use, not just in spam filters, but also for sorting through large volumes of text. As you can see, setting up such a model isn't especially difficult, and scikit-learn provides a large number of vectorizers, feature extraction tools and estimators that you can use to build one.


I used Python and the many parts of the SciPy stack (NumPy, SciPy, Pandas, Matplotlib and scikit-learn) in this article. All are available from PyPI or from