Classifying Text

In my last few articles, I've looked at several ways one can apply machine learning, both supervised and unsupervised. This time, I want to bring your attention to a surprisingly simple—but powerful and widespread—use of machine learning, namely document classification.

You almost certainly have benefited from this technique in day-to-day life, even if you haven't seen it in action, in the form of an email spam filter. You might remember that in the earliest days of spam filters, you needed to "train" your email program so that it would know what your real email looked like. Well, that was a machine-learning model in action, being told what "good" documents looked like, as opposed to "bad" documents. Of course, spam filters are far more sophisticated than that nowadays, but as you'll see over the course of this article, there are logical reasons why spammers include innocent-seeming words, irrelevant to their business, in the text of their spam.

Text classification is a problem many businesses and organizations have to deal with. Whether it's classifying legal documents, medical records or tweets, machine learning can help you look through lots of text, separating it into different groups.

Now, text classification requires a bit more sophistication than working with purely numeric data. In particular, it requires that you spend some time collecting and organizing data into a format that a model can handle. Fortunately, Python's scikit-learn comes with a number of tools that can get you there fairly easily.

Organizing the Data

Many cases of text classification are supervised learning problems: you train the model by giving it inputs (for example, text documents) along with the "right" output for each input (for example, categories). In scikit-learn, the general template for supervised learning is:

    model = CLASS()
    model.fit(X, y)
    model.predict(new_data_X)

CLASS is one of the dozens of Python classes that come with scikit-learn, each of which implements a different type of "estimator" (a machine-learning algorithm). Some estimators work best with supervised classification problems, some work with supervised regression problems, and still others work with clustering (that is, unsupervised classification) problems. You often will be able to choose from among several different estimators, but the general format remains the same.
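To see that the format really does stay the same, here is a sketch with one estimator of each kind mentioned above: LogisticRegression for classification, LinearRegression for regression and KMeans for clustering. The choice of these three classes and the tiny one-feature data set are assumptions made for illustration. Note that the clustering estimator's fit takes only X, since unsupervised learning has no "right" answers.

```python
# Minimal sketch: three kinds of estimators, one create/fit/predict pattern.
# Toy data invented for illustration.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression, LogisticRegression

# One feature per row; two obvious groups (around 2 and around 11)
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])

# Supervised classification: y holds a category for each row
y_class = np.array([0, 0, 0, 1, 1, 1])
classifier = LogisticRegression()
classifier.fit(X, y_class)
print(classifier.predict([[2.5], [11.5]]))

# Supervised regression: y holds a number for each row (here, 2 * x)
y_value = np.array([2.0, 4.0, 6.0, 20.0, 22.0, 24.0])
regressor = LinearRegression()
regressor.fit(X, y_value)
print(regressor.predict([[5.0]]))

# Unsupervised clustering: there is no y at all
clusterer = KMeans(n_clusters=2, n_init=10, random_state=0)
clusterer.fit(X)
print(clusterer.predict([[2.5], [11.5]]))   # cluster numbers are arbitrary
```

The cluster numbers KMeans assigns carry no meaning of their own; what matters is that the two new points land in different clusters.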

Once you have created an instance of your estimator, you then have to train it. That's done using the "fit" method, to which you give X (the inputs, as a two-dimensional NumPy array or a Pandas data frame) and y (a one-dimensional NumPy array or a Pandas series). Once the model is trained, you then can invoke its "predict" method, passing it new_data_X, another two-dimensional NumPy array or Pandas data frame. The result is a one-dimensional NumPy array listing, for each row of new_data_X, the category (in the same form as the values in y) into which that input should be classified.
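Here is a minimal sketch of that fit/predict cycle with Pandas inputs. KNeighborsClassifier stands in for CLASS, and the column names ("length", "width") and values are hypothetical, chosen only so that the two groups are easy to tell apart.

```python
# Minimal sketch of fit/predict with Pandas inputs.
# Column names and values are hypothetical.
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier

# X: two-dimensional inputs, one row per sample, one column per feature
X = pd.DataFrame({"length": [1.0, 1.2, 0.9, 5.0, 5.3, 4.8],
                  "width":  [0.5, 0.4, 0.6, 2.0, 2.2, 1.9]})

# y: one-dimensional "right answers", one per row of X
y = pd.Series([0, 0, 0, 1, 1, 1])

model = KNeighborsClassifier(n_neighbors=3)
model.fit(X, y)

# new_data_X must have the same columns as X
new_data_X = pd.DataFrame({"length": [1.1, 5.1],
                           "width":  [0.5, 2.1]})
predictions = model.predict(new_data_X)
print(predictions)   # one-dimensional NumPy array, one category per row
```

The first new row sits among the small samples and the second among the large ones, so the model classifies them as 0 and 1, respectively.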

One of my favorite parts of using scikit-learn is the fact that so much of it uses the same API. You almost always will be using some combination of "fit" and "predict" on your model, no matter what kind of model you're using.
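Because the API is shared, swapping one model for another is often a one-word change. The sketch below, using toy data invented for illustration and four classifiers picked arbitrarily from scikit-learn, runs the identical fit and predict calls on each of them in a single loop.

```python
# Minimal sketch: the same fit/predict loop works for any classifier.
# Toy data invented for illustration.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

X = np.array([[1.0, 0.5], [1.2, 0.4], [0.9, 0.6],
              [5.0, 2.0], [5.3, 2.2], [4.8, 1.9]])
y = np.array([0, 0, 0, 1, 1, 1])
new_data_X = np.array([[1.1, 0.5], [5.1, 2.1]])

results = {}
for model in [LogisticRegression(), GaussianNB(),
              KNeighborsClassifier(n_neighbors=3), DecisionTreeClassifier()]:
    model.fit(X, y)                       # identical call for every estimator
    results[type(model).__name__] = model.predict(new_data_X).tolist()

for name, predicted in results.items():
    print(f"{name}: {predicted}")
```

On data this cleanly separated, all four models agree; on real data, a loop like this is a quick way to compare estimators before settling on one.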

As a general rule, machine-learning models require that inputs be numeric. So, you turn category names into numbers, country names into numbers, color names into numbers—basically, everything has to be a number.
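One way to turn category names into numbers is scikit-learn's LabelEncoder, sketched below with a made-up list of color names. It assigns integers in sorted (alphabetical) order of the names, and it can translate numbers back into names afterward.

```python
# Minimal sketch: mapping category names to integers and back.
# The color list is made up for illustration.
from sklearn.preprocessing import LabelEncoder

colors = ["red", "green", "blue", "green", "red"]

encoder = LabelEncoder()
codes = encoder.fit_transform(colors)     # learn the mapping and apply it

print(encoder.classes_.tolist())          # ['blue', 'green', 'red']
print(codes.tolist())                     # [2, 1, 0, 1, 2]
print(encoder.inverse_transform([0, 2]).tolist())   # ['blue', 'red']
```

LabelEncoder is intended for the target labels (y); for categorical input columns in X, scikit-learn's OrdinalEncoder and OneHotEncoder play the corresponding role.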