Building a Document Classification System#

The NumPy (Numerical Python) library is used for working with arrays, and the Scikit-learn library is a Python library built on NumPy, SciPy and matplotlib for data analytics and machine learning. The NLTK (Natural Language Toolkit) provides access to over 50 corpora and lexical resources such as WordNet, along with a suite of text-processing libraries for classification, tokenization, stemming, tagging, parsing, and semantic reasoning, as well as wrappers for industrial-strength NLP libraries.

# Ensuring that you have the necessary libraries
# !pip install nltk
# !pip install numpy
# !pip install scikit-learn
import numpy as np
import nltk
from nltk.corpus import reuters
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score, classification_report

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
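
As a quick illustration of the NLTK text-processing utilities mentioned above (tokenization and stemming), here is a minimal sketch; it assumes the punkt tokenizer models can be downloaded in your environment:

# Minimal NLTK illustration: tokenize a sentence and stem each token
nltk.download('punkt')
from nltk.tokenize import word_tokenize
from nltk.stem import PorterStemmer

tokens = word_tokenize("Computer Terminal Systems completed the sale of 200,000 shares.")
stemmer = PorterStemmer()
print(tokens)
print([stemmer.stem(t) for t in tokens])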

1. Load your data#

The Reuters-21578 dataset is one of the most widely used data collections for text categorization research. It is a collection of news articles; the original corpus has 10,369 documents and a vocabulary of 29,930 words, with labeled categories such as “earnings”, “acquisitions”, etc. You can read metadata about the dataset on Hugging Face.

# download the dataset
nltk.download('reuters')
[nltk_data] Downloading package reuters to
[nltk_data]     /Users/dunstanmatekenya/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
True
# Load the Reuters-21578 dataset
documents = reuters.fileids()
train_docs = list(filter(lambda doc: doc.startswith("train"), documents))
test_docs = list(filter(lambda doc: doc.startswith("test"), documents))
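
As a quick sanity check on the split, you can print how many documents ended up in each partition (a minimal sketch using the variables defined above):

# How many documents are in the training and test splits?
print("Training documents:", len(train_docs))
print("Test documents:", len(test_docs))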

2. Prepare your data#

Prepare the data by extracting the raw text and category labels for both the training and testing documents. We assume that each document has only one category label, so we take only the first category label for each document.

# Prepare the data
train_data = [reuters.raw(doc_id) for doc_id in train_docs]
train_labels = [reuters.categories(doc_id)[0] for doc_id in train_docs]
test_data = [reuters.raw(doc_id) for doc_id in test_docs]
test_labels = [reuters.categories(doc_id)[0] for doc_id in test_docs]

Question: How many different classes are in the training data?#
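
One way to answer this yourself is to count the distinct labels in the training split (a minimal sketch using the labels prepared above):

# Number of distinct category labels in the training data
print("Number of classes:", len(set(train_labels)))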

Explore some of the training examples#

print("Article content: {} n\, Label: {}".format(train_data[1], train_labels[1]))
Article content: COMPUTER TERMINAL SYSTEMS <CPML> COMPLETES SALE
  Computer Terminal Systems Inc said
  it has completed the sale of 200,000 shares of its common
  stock, and warrants to acquire an additional one mln shares, to
  <Sedio N.V.> of Lugano, Switzerland for 50,000 dlrs.
      The company said the warrants are exercisable for five
  years at a purchase price of .125 dlrs per share.
      Computer Terminal said Sedio also has the right to buy
  additional shares and increase its total holdings up to 40 pct
  of the Computer Terminal's outstanding common stock under
  certain circumstances involving change of control at the
  company.
      The company said if the conditions occur the warrants would
  be exercisable at a price equal to 75 pct of its common stock's
  market price at the time, not to exceed 1.50 dlrs per share.
      Computer Terminal also said it sold the technolgy rights to
  its Dot Matrix impact technology, including any future
  improvements, to <Woodco Inc> of Houston, Tex. for 200,000
  dlrs. But, it said it would continue to be the exclusive
  worldwide licensee of the technology for Woodco.
      The company said the moves were part of its reorganization
  plan and would help pay current operation costs and ensure
  product delivery.
      Computer Terminal makes computer generated labels, forms,
  tags and ticket printers and terminals.
  

Label: acq

3. Vectorizing the text data#

  • Vectorize the text data using the TfidfVectorizer from scikit-learn. TF-IDF is an abbreviation for Term Frequency-Inverse Document Frequency. It is a very common algorithm for transforming text into a meaningful numerical representation that can be used to fit a machine learning algorithm for prediction.

  • It's worth noting that nowadays this vectorization approach is not commonly used. We will cover word embeddings tomorrow, which are a better way to represent words as numbers because vector embeddings can capture semantic meaning better.

You can learn more about the scikit-learn TF-IDF vectorizer here.

# Vectorize the text data
vectorizer = TfidfVectorizer(stop_words="english", max_features=1000)
X_train = vectorizer.fit_transform(train_data)
X_test = vectorizer.transform(test_data)
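
To get a feel for what the vectorizer produced, you can inspect the shape of the document-term matrix and a few of the learned feature names (a minimal sketch; get_feature_names_out is available in recent scikit-learn versions, older versions use get_feature_names):

# Inspect the TF-IDF representation: one row per document, one column per term
print("Training matrix shape:", X_train.shape)
print("Sample features:", vectorizer.get_feature_names_out()[:10])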

Question: What role are the stop words playing in the code above? You might have learned this from Prof. Mohamad Ali already.#
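
If you want to see which words are being filtered out, the fitted vectorizer exposes the stop-word list it is using (a minimal sketch):

# A few of the built-in English stop words that the vectorizer ignores
print(sorted(vectorizer.get_stop_words())[:15])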

4. Training a Linear Support Vector Machine (LinearSVC) classifier using the vectorized training data and the corresponding labels#

# Train the classifier
classifier = LinearSVC()
classifier.fit(X_train, train_labels)
LinearSVC()

5. Evaluate the classifier and calculate the accuracy score as well as some other metrics (precision, recall and F1 score)#

# Evaluate the classifier
y_pred = classifier.predict(X_test)
accuracy = accuracy_score(test_labels, y_pred)
print("Accuracy:", accuracy)
print(classification_report(test_labels, y_pred))
Accuracy: 0.876117919841007
                 precision    recall  f1-score   support

            acq       0.95      0.96      0.96       719
           alum       0.33      0.18      0.24        22
         barley       1.00      0.71      0.83        14
            bop       0.77      0.80      0.79        30
        carcass       0.79      0.65      0.71        17
     castor-oil       0.00      0.00      0.00         1
          cocoa       0.94      1.00      0.97        17
        coconut       0.00      0.00      0.00         2
    coconut-oil       0.00      0.00      0.00         2
         coffee       0.89      0.96      0.92        25
         copper       0.93      0.93      0.93        15
           corn       0.85      0.81      0.83        48
         cotton       1.00      0.86      0.92        14
            cpi       0.62      0.62      0.62        24
            cpu       0.00      0.00      0.00         1
          crude       0.79      0.93      0.86       182
            dfl       0.00      0.00      0.00         1
            dlr       0.70      0.72      0.71        43
            dmk       0.00      0.00      0.00         1
           earn       0.98      0.99      0.98      1083
           fuel       1.00      0.22      0.36         9
            gas       0.75      0.33      0.46         9
            gnp       0.59      0.89      0.71        19
           gold       0.96      0.96      0.96        26
          grain       0.71      0.77      0.74        77
      groundnut       0.00      0.00      0.00         3
           heat       1.00      0.75      0.86         4
            hog       1.00      0.50      0.67         4
        housing       1.00      0.67      0.80         3
         income       1.00      0.80      0.89         5
    instal-debt       1.00      1.00      1.00         1
       interest       0.78      0.76      0.77       124
            ipi       1.00      1.00      1.00        11
     iron-steel       0.69      0.64      0.67        14
            jet       0.00      0.00      0.00         1
           jobs       0.73      0.85      0.79        13
       l-cattle       0.00      0.00      0.00         2
           lead       0.83      0.42      0.56        12
            lei       1.00      1.00      1.00         3
      livestock       0.50      0.50      0.50         6
         lumber       0.00      0.00      0.00         5
      meal-feed       0.20      0.17      0.18         6
       money-fx       0.65      0.65      0.65        96
   money-supply       0.80      0.83      0.81        29
        naphtha       0.00      0.00      0.00         1
        nat-gas       0.64      0.54      0.58        13
         nickel       0.00      0.00      0.00         1
        oilseed       0.54      0.54      0.54        13
         orange       0.75      0.33      0.46         9
      palladium       0.00      0.00      0.00         1
       palm-oil       0.67      1.00      0.80         4
       pet-chem       1.00      0.50      0.67         6
       platinum       0.00      0.00      0.00         3
         potato       1.00      0.67      0.80         3
        propane       0.00      0.00      0.00         2
       rape-oil       0.00      0.00      0.00         1
       reserves       1.00      0.64      0.78        14
         retail       1.00      1.00      1.00         1
           rice       0.00      0.00      0.00         1
         rubber       0.69      1.00      0.82         9
           ship       0.39      0.41      0.40        39
         silver       0.00      0.00      0.00         0
        soy-oil       0.00      0.00      0.00         2
        soybean       0.00      0.00      0.00         2
strategic-metal       0.00      0.00      0.00         6
          sugar       0.71      0.96      0.81        25
            tea       0.00      0.00      0.00         3
            tin       0.71      0.50      0.59        10
          trade       0.70      0.93      0.80        76
        veg-oil       0.54      0.64      0.58        11
            wpi       0.62      0.56      0.59         9
            yen       0.00      0.00      0.00         6
           zinc       0.00      0.00      0.00         5

       accuracy                           0.88      3019
      macro avg       0.53      0.48      0.49      3019
   weighted avg       0.86      0.88      0.87      3019
/Users/dunstanmatekenya/anaconda3/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
/Users/dunstanmatekenya/anaconda3/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
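
The warnings above come from categories that never appear among the predictions (or never appear in the test set), so precision or recall is undefined for them. You can make the averaging explicit and silence the warnings with the zero_division parameter (a minimal sketch):

# Macro-averaged metrics, treating undefined scores as zero instead of warning
from sklearn.metrics import precision_recall_fscore_support
p, r, f1, _ = precision_recall_fscore_support(test_labels, y_pred, average="macro", zero_division=0)
print("Macro precision: {:.2f}, recall: {:.2f}, F1: {:.2f}".format(p, r, f1))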

6. Classify new documents (new BBC headlines) by vectorizing them using the same TfidfVectorizer and predicting their labels using the trained classifier#

# Classify new documents (recent headlines obtained from BBC news regarding Tunisia)
new_docs = [
    "Tunisia says 23 people missing in Mediterranean sea.",
    "Tunisia officials arrested in dispute over flag display.",
    "Tunisia lawyer arrested during live news broadcast."
]
new_docs_vectors = vectorizer.transform(new_docs)
predicted_labels = classifier.predict(new_docs_vectors)
print("Predicted labels:", predicted_labels)
Predicted labels: ['ship' 'ship' 'acq']
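
To see how confident the classifier is about these headlines, you can inspect the margin scores from LinearSVC's decision_function (a minimal sketch; it reuses the numpy import from the top of the notebook):

# Show the three highest-scoring classes for each new headline
scores = classifier.decision_function(new_docs_vectors)
for doc, row in zip(new_docs, scores):
    top = np.argsort(row)[::-1][:3]
    print(doc)
    print("  Top classes:", [classifier.classes_[i] for i in top])

These headlines use vocabulary quite different from the Reuters training articles, which helps explain the somewhat arbitrary 'ship' and 'acq' predictions.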

Discussion#

How did this classifier fare? What can you do to improve the model?
Ans: Experiment with different preprocessing techniques, feature extraction methods, and classification algorithms.
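
For example, one simple starting point is a small hyperparameter search over the SVM's regularization strength (a minimal sketch; the parameter grid is illustrative, not tuned):

# Illustrative grid search over the LinearSVC regularization parameter C
from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(LinearSVC(), {"C": [0.1, 1, 10]}, cv=3, scoring="accuracy")
grid.fit(X_train, train_labels)
print("Best parameters:", grid.best_params_, "CV accuracy:", grid.best_score_)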

Trying with a different classifier#

Steps 1 - 3 will be the same.

# Load the Reuters-21578 dataset
documents = reuters.fileids()
train_docs = list(filter(lambda doc: doc.startswith("train"), documents))
test_docs = list(filter(lambda doc: doc.startswith("test"), documents))

# Prepare the data
train_data = [reuters.raw(doc_id) for doc_id in train_docs]
train_labels = [reuters.categories(doc_id)[0] for doc_id in train_docs]
test_data = [reuters.raw(doc_id) for doc_id in test_docs]
test_labels = [reuters.categories(doc_id)[0] for doc_id in test_docs]

# Vectorize the text data
vectorizer = CountVectorizer(stop_words="english", max_features=1000)
X_train = vectorizer.fit_transform(train_data)
X_test = vectorizer.transform(test_data)

Different Classifier (Multinomial Naive Bayes)#

classifier = MultinomialNB()
classifier.fit(X_train, train_labels)
# Evaluate the classifier
y_pred = classifier.predict(X_test)
accuracy = accuracy_score(test_labels, y_pred)
print("Accuracy:", accuracy)
print(classification_report(test_labels, y_pred))
# Classify new documents (recent headlines obtained from BBC news regarding Tunisia)
new_docs = [
    "Tunisia says 23 people missing in Mediterranean sea.",
    "Tunisia officials arrested in dispute over flag display.",
    "Tunisia lawyer arrested during live news broadcast."
]
new_docs_vectors = vectorizer.transform(new_docs)
predicted_labels = classifier.predict(new_docs_vectors)
print("Predicted labels:", predicted_labels)

Discussion: Compare the results#

The choice of classifier depends on the specific characteristics of your dataset and the problem at hand. Multinomial Naive Bayes is known to work well with text data and can handle high-dimensional feature spaces efficiently. However, it assumes that the features are independent of each other, which may not always be the case in real-world scenarios.

You can also experiment with different classifiers, such as Logistic Regression, Random Forest, or Gradient Boosting, and compare their performance to find the best fit for your dataset. You can further refine the model by trying different feature extraction techniques and hyperparameters.
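
As one illustration, swapping in Logistic Regression only requires changing the classifier line (a minimal sketch; max_iter=1000 is an assumption to help convergence, not a tuned value):

# Same pipeline with a Logistic Regression classifier
from sklearn.linear_model import LogisticRegression
lr_classifier = LogisticRegression(max_iter=1000)
lr_classifier.fit(X_train, train_labels)
print("Accuracy:", accuracy_score(test_labels, lr_classifier.predict(X_test)))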

There are also other ways you can approach this, for example, Document Classification using BERT. Here is a notebook example on Kaggle that you can explore: https://www.kaggle.com/code/merishnasuwal/document-classification-using-bert#

BERT (Bidirectional Encoder Representations from Transformers) and other Transformer encoder architectures can also be used on a variety of tasks in NLP (natural language processing). They compute vector-space representations of natural language that are suitable for use in deep learning models. The BERT family of models uses the Transformer encoder architecture to process each token of input text in the full context of all tokens before and after. BERT models are usually pre-trained on a large corpus of text, then fine-tuned for specific tasks.
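
If you want to experiment in that direction, loading a pre-trained BERT encoder with a classification head via the Hugging Face transformers library looks roughly like this (a minimal sketch; it assumes transformers and PyTorch are installed, the model name is one common choice, and the fine-tuning loop itself is omitted):

# Sketch: pre-trained BERT with an (untrained) classification head on top
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=len(set(train_labels))
)
inputs = tokenizer("Tunisia says 23 people missing in Mediterranean sea.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.shape)  # one raw score per candidate category before fine-tuning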