Dataset Consolidation using networkx

%%capture
!pip install networkx nltk
import nltk

nltk.download("punkt_tab")
[nltk_data] Downloading package punkt_tab to /home/runner/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.
True

This notebook demonstrates how to consolidate dataset mentions within a given text using networkx and other NLP techniques. The process involves identifying and grouping sentences that mention datasets, extracting the relevant spans, and structuring the results into a consistent output format.

Steps include:

  1. Tokenizing the text into sentences.

  2. Using TF-IDF and cosine similarity to find the best matching spans for dataset mentions.

  3. Building a graph of sentence indices to identify connected components representing grouped mentions.

  4. Consolidating the dataset mentions into a structured format.

The following sections will walk through the implementation and usage of the functions defined for this purpose.
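As a quick illustration of step 3, the toy sketch below (using made-up sentence indices rather than a real document) shows how overlapping sentence windows collapse into consolidated groups via networkx connected components.

import networkx as nx

# Each dataset mention contributes a small window of sentence indices; windows that
# share an index end up in the same connected component.
G = nx.Graph()
for window in [[2, 3, 4], [4, 5], [9, 10]]:
    G.add_edges_from(zip(window[:-1], window[1:]))

print(sorted(sorted(c) for c in nx.connected_components(G)))
# [[2, 3, 4, 5], [9, 10]] -> two consolidated "mentioned_in" groups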

import json
import os
from nltk.tokenize import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def find_best_matching_span(text, snippet, window: int = 1):
    """Find the sentence in `text` most similar to `snippet` using TF-IDF cosine
    similarity, and return it together with a window of surrounding sentences."""
    sents = sent_tokenize(text)
    # Fit the vectorizer on the snippet so its n-grams define the vocabulary,
    # then score every sentence of the text against that vocabulary.
    tfidf = TfidfVectorizer(ngram_range=(1, 3))
    mi_vec = tfidf.fit_transform([snippet])
    sents_vec = tfidf.transform(sents)

    # Index of the most similar sentence; the span adds `window` sentences of context on each side.
    mx_idx = cosine_similarity(mi_vec, sents_vec).flatten().argmax()
    span_sents = sents[max(mx_idx - window, 0) : min(mx_idx + window + 1, len(sents))]

    return {
        "match_idx": mx_idx,
        "match_sent": sents[mx_idx],
        "match_span_sents": span_sents,
        "match_span": " ".join(span_sents),
    }


def find_empirical_span(
    text: str, sentences: list, best_match_idx: int, window: int = 1
):
    # Define the start and end indices to include adjacent sentences for context
    start_idx = text.index(sentences[max(best_match_idx - window, 0)])
    last_sent = sentences[min(best_match_idx + window, len(sentences) - 1)]
    # Search for `last_sent` starting from `start_idx` so that an earlier occurrence
    # of the same sentence elsewhere in the text is not matched by mistake.
    end_idx = start_idx + text[start_idx:].index(last_sent) + len(last_sent)

    # Extract the final span
    context_span = text[start_idx:end_idx]

    return {
        "empirical_span": context_span,  # Extracted span
        "start_idx": start_idx,
        "end_idx": end_idx,
    }


def get_empirical_mentioned_in(
    text, mentioned_in, window: int = 1, with_match_output: bool = False
):
    """
    Extract the most relevant span of text from the original document (`text`)
    that matches the `mentioned_in` field. Returns the span, label, start, and end indices.
    """
    # Tokenize the text into sentences
    sentences = sent_tokenize(text)
    match_output = find_best_matching_span(text, mentioned_in, window=window)
    best_match_idx = match_output["match_idx"]

    output = find_empirical_span(text, sentences, best_match_idx, window=window)
    output["empirical_mentioned_in"] = output.pop("empirical_span")

    output = {
        "label": "mentioned_in",  # Label as "mentioned_in"
        **output,
    }

    if with_match_output:
        output.update(match_output)

    return output
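Before building the consolidation step, here is a quick sanity check of the helpers above on a small made-up passage (the text and mention below are purely illustrative):

sample_text = (
    "We evaluate our approach on two benchmarks. "
    "The model is trained on the ImageNet dataset. "
    "Results are reported in Table 2."
)
sample_mention = "trained on the ImageNet dataset"

match = find_best_matching_span(sample_text, sample_mention, window=1)
print(match["match_sent"])  # the single best-matching sentence
print(match["match_span"])  # that sentence plus one sentence of context on each side

span = get_empirical_mentioned_in(sample_text, sample_mention, window=1)
print(span["start_idx"], span["end_idx"], span["empirical_mentioned_in"])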
# consolidate dataset mentions using the helper functions above

from copy import deepcopy
import networkx as nx


def consolidate_dataset(raw_text: str, data: dict):
    text = raw_text
    page_data = {"dataset_used": data.get("dataset_used", False), "data_mentions": []}

    G = nx.Graph()
    sents = sent_tokenize(text)
    _datasets = []

    for ds in data.get("dataset", []):
        mentioned_in = ds.pop("mentioned_in") or ""

        try:
            mi = find_best_matching_span(mentioned_in, ds["raw_name"], window=0)
            mi = mi["match_span"]
            match_output = find_best_matching_span(text, mi, window=1)
        except ValueError:
            # The `mentioned_in` text is likely missing from the document or incorrect,
            # so fall back to searching the entire text for the raw dataset name.
            match_output = find_best_matching_span(text, ds["raw_name"], window=1)

        ds["sent_spans"] = match_output["match_span_sents"]
        sents_idx = sorted([sents.index(s) for s in ds["sent_spans"]])
        ds["sent"] = match_output["match_sent"]
        ds["sent_idx"] = sents_idx

        # Register the span's sentence indices as nodes (even if there is only one),
        # then connect consecutive indices so overlapping spans end up in one component.
        G.add_nodes_from(sents_idx)
        G.add_edges_from(zip(sents_idx[:-1], sents_idx[1:]))
        _datasets.append(ds)

    _datasets = sorted(_datasets, key=lambda x: x["sent_idx"][0])

    # The connected components of the graph form the consolidated `mentioned_in` groups.
    mentioned_ins = sorted(
        [sorted(x) for x in nx.connected_components(G)], key=lambda x: x[0]
    )
    updated_mentions = []

    for midx in mentioned_ins:
        _mi = {"mentioned_in": " ".join([sents[i] for i in midx]), "datasets": []}

        for ds in _datasets:
            ds = deepcopy(ds)
            if ds["sent_idx"][0] in midx:
                ds.pop("sent_idx")
                ds.pop("sent_spans")
                _mi["datasets"].append(ds)

        updated_mentions.append(_mi)

    page_data["data_mentions"] = updated_mentions

    return page_data


def save_output_per_document(raw_text, data, output_path, page_idx):
    """
    Save output data to a JSON file per document, appending new page data.

    Parameters:
        raw_text (str): The raw text of the page being processed.
        data (dict): The data to save, in the validated format (e.g. a LabelledResponseFormat dump).
        output_path (str): The output path for the document-wide JSON file.
        page_idx (int): The zero-based index of the current page being processed.

    Returns:
        None
    """

    # Restructure and consolidate dataset if possible
    page_data = consolidate_dataset(raw_text, data)

    # Initialize the new page's data structure
    page_data = {"page": page_idx + 1, **page_data}

    # Check if the file already exists
    if os.path.exists(output_path):
        with open(output_path, "r", encoding="utf-8") as existing_file:
            document_data = json.load(existing_file)
    else:
        # Create a new JSON structure
        document_data = {
            "source": os.path.splitext(os.path.basename(output_path))[0],
            "pages": [],
        }

    # Append the new page data
    document_data["pages"].append(page_data)

    # Save the updated document data back to the file
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)  # handle paths without a directory component
    with open(output_path, "w", encoding="utf-8") as output_file:
        json.dump(document_data, output_file, indent=4)
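A hypothetical driver loop over a multi-page document might look like the following; `pages` and the output path are placeholders rather than part of this notebook's pipeline.

# One (raw_text, extraction_dict) pair per page of a source document (placeholder values).
pages = []  # e.g. [(page_1_text, page_1_data), (page_2_text, page_2_data), ...]

for page_idx, (page_text, page_extraction) in enumerate(pages):
    save_output_per_document(
        page_text,
        page_extraction,
        output_path="output/example_document.json",
        page_idx=page_idx,
    )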
# Example raw text and data
raw_text = """
This is an introductory paragraph that has nothing to do with datasets. It talks about general topics 
and provides no useful information about machine learning or datasets.

The dataset used in this study is the MNIST dataset, which is widely used for image recognition tasks. 
Another dataset mentioned is the CIFAR-10 dataset, which is commonly used for object recognition. 
Both datasets are crucial for benchmarking machine learning models.

Here is another unrelated paragraph discussing the weather and how sunny days are great for outdoor activities. 
It has no connection to the datasets or the study being conducted.
"""
data = {
    "dataset_used": True,
    "dataset": [
        {
            "raw_name": "MNIST",
            "mentioned_in": "The dataset used in this study is the MNIST dataset.",
        },
        {
            "raw_name": "CIFAR-10",
            "mentioned_in": "Another dataset mentioned is the CIFAR-10 dataset.",
        },
    ],
}

# Use the consolidate_dataset function
consolidated_data = consolidate_dataset(raw_text, data)
# Print the consolidated data
print(json.dumps(consolidated_data, indent=4))
{
    "dataset_used": true,
    "data_mentions": [
        {
            "mentioned_in": "It talks about general topics \nand provides no useful information about machine learning or datasets. The dataset used in this study is the MNIST dataset, which is widely used for image recognition tasks. Another dataset mentioned is the CIFAR-10 dataset, which is commonly used for object recognition. Both datasets are crucial for benchmarking machine learning models.",
            "datasets": [
                {
                    "raw_name": "MNIST",
                    "sent": "The dataset used in this study is the MNIST dataset, which is widely used for image recognition tasks."
                },
                {
                    "raw_name": "CIFAR-10",
                    "sent": "Another dataset mentioned is the CIFAR-10 dataset, which is commonly used for object recognition."
                }
            ]
        }
    ]
}
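Note that MNIST and CIFAR-10 end up under a single mentioned_in group: with window=1, the sentence windows around their best-matching sentences overlap, so their sentence indices fall into the same connected component and the surrounding context is merged into one span. Mentions in non-overlapping parts of the document would instead produce separate entries in data_mentions.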