{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "35a989b1",
   "metadata": {},
   "source": [
    "# Model - SentenceTransformers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1a813d7",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8eb77ef5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from iqual import iqualnlp, evaluation, crossval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17a42bc8",
   "metadata": {},
   "source": [
    "### Load `annotated (human-coded)` and `unannotated` datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a7d035ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both CSVs are read from data_dir, a path relative to the working directory.\n",
    "data_dir = \"../../data\"\n",
    "human_coded_df = pd.read_csv(os.path.join(data_dir,\"annotated.csv\"))\n",
    "uncoded_df = pd.read_csv(os.path.join(data_dir,\"unannotated.csv\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b308f43",
   "metadata": {},
   "source": [
    "### Split the data into training and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e0f95233",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Size: 7470\n",
      "Test Size: 2490\n"
     ]
    }
   ],
   "source": [
    "# Fixed seed: without random_state the 75/25 split changes on every run,\n",
    "# so downstream results could not be reproduced after a kernel restart.\n",
    "RANDOM_SEED = 42\n",
    "\n",
    "train_df, test_df = train_test_split(\n",
    "    human_coded_df, test_size=0.25, random_state=RANDOM_SEED\n",
    ")\n",
    "print(f\"Train Size: {len(train_df)}\\nTest Size: {len(test_df)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0ea3ad2",
   "metadata": {},
   "source": [
    "### Configure training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ebffbf2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Select Question and Answer Columns\n",
    "question_col = 'Q_en'\n",
    "answer_col = 'A_en'\n",
    "\n",
    "### Select a code\n",
    "code_variable = 'marriage'\n",
    "\n",
    "### Create X and y\n",
    "# X: the question/answer columns of the training split; y: the selected code column.\n",
    "X = train_df[[question_col,answer_col]]\n",
    "y = train_df[code_variable]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23910543",
   "metadata": {},
   "source": [
    "### Initiate model"
   ]
  },
  {
   "cell_type": "code",
"execution_count": 5, "id": "d4b10060", "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 12, "id": "ea7cd53b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('Input',\n",
       "                 FeatureUnion(transformer_list=[('question',\n",
       "                                                 Pipeline(steps=[('selector',\n",
       "                                                                  FunctionTransformer(func=<function column_selector at 0x00000254018E8820>,\n",
       "                                                                                      kw_args={'column_name': 'Q_en'})),\n",
       "                                                                 ('vectorizer',\n",
       "                                                                  Vectorizer(env='sentence-transformers',\n",
       "                                                                             model='all-MiniLM-L6-v2'))])),\n",
       "                                                ('answer',\n",
       "                                                 Pipeline(steps=[('selector',\n",
       "                                                                  FunctionTransformer(func=<fun...\n",
       "                                                                  Vectorizer(env='sentence-transformers',\n",
       "                                                                             model='all-MiniLM-L6-v2'))]))])),\n",
       "                ('Classifier',\n",
       "                 Classifier(C=1.0, class_weight=None, dual=False,\n",
       "                            fit_intercept=True, intercept_scaling=1,\n",
       "                            l1_ratio=None, max_iter=100,\n",
       "                            model='LogisticRegression', multi_class='auto',\n",
       "                            n_jobs=None, penalty='l2', random_state=None,\n",
       "                            solver='lbfgs', tol=0.0001, verbose=0,\n",
       "                            warm_start=False)),\n",
       "                ('Threshold', BinaryThresholder())])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('Input',\n", " FeatureUnion(transformer_list=[('question',\n", " Pipeline(steps=[('selector',\n", " FunctionTransformer(func=,\n", " kw_args={'column_name': 'Q_en'})),\n", " ('vectorizer',\n", " Vectorizer(env='sentence-transformers',\n", " model='all-MiniLM-L6-v2'))])),\n", " ('answer',\n", " Pipeline(steps=[('selector',\n", " FunctionTransformer(func=" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAN8AAADCCAYAAADJsRdpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAM6klEQVR4nO3dX4xc9XnG8e9TUwh1m2LH7coyCBvVUmSCCngVrLZqTaiwcaSaqFVklNaGuCEppn9UX9QpF0TQqHBBIzlNU9HGAksUQ2kj3MbUdR2vol4sYFoXY1LHi3EUW47dYAI1SKRGby/Ob8mxmfHOzM7MOzt+PtJoz/zO75x552ifnTNnxn4VEZhZ//1EdgFmFyqHzyyJw2eWxOEzS+LwmSVx+MySXJRdQKfmzZsXCxcubLjurbfeYvbs2f0tqAWuqz0zta4XXnjhBxHxc1PuKCJm5G3p0qXRzJ49e5quy+S62jNT6wL2Rgu/wz7tNEvi8JklcfjMkjh8ZkkcPrMkM/ajhvPZf+wNbt/0jewy3mfjNWfaruvIAx/vUTWWza98ZkkcPrMkDp9ZEofPLInDZ5bE4TNLMmX4JF0haY+klyUdkPSHZXyupF2SDpWfc8q4JG2WNCHpRUnX1/a1rsw/JGldbXyppP1lm82S1IsnazZIWnnlOwNsjIglwDJgg6QlwCZgd0QsBnaX+wC3AIvL7U7gq1CFFbgXuAH4KHDvZGDLnM/Utls5/admNtimDF9EHI+I/yjL/wt8G1gArAYeLdMeBW4ty6uBreVfV4wDl0maD6wAdkXEqYh4HdgFrCzrPhgR4+WfY2yt7ctsaLX1nk/SQuA64FlgJCKOl1XfB0bK8gLge7XNjpax840fbTBuNtRa/nqZpJ8G/gH4o4h4s/62LCJCUs//911Jd1KdyjIyMsLY2FjDeSOXVl/lGjSd1NXsOXbT6dOn+/I47Rr2uloKn6SfpAreYxHxj2X4hKT5EXG8nDqeLOPHgCtqm19exo4By88ZHyvjlzeY/z4R8TDwMMDo6GgsX7680TS+/NjTPLR/8L62uvGaM23XdeRTy3tTTM3Y2BjNjmWmYa+rlaudAr4GfDsi/qK2ajswecVyHfB0bXxtueq5DHijnJ7uBG6WNKdcaLkZ2FnWvSlpWXmstbV9mQ2tVv4M/zLwO8B+SfvK2J8CDwBPSloPfBf4ZFm3A1gFTABvA3cARMQpSfcDz5d590XEqbJ8F/AIcCnwTLmZDbUpwxcR/w40+9ztpgbzA9jQZF9bgC0NxvcCH5mqFrNh4m+4mCVx+MySOHxmSRw+syQOn1kSh88sicNnlsThM0vi8JklcfjMkjh8ZkkcPrMkDp9ZEofPLInDZ5bE4TNL4vCZJXH4zJI4fGZJHD6zJA6fWRKHzyyJw2eWxOEzS+LwmSVx+MySOHxmSRw+sySttAjbIumkpJdqY1+QdEzSvnJbVVv3eUkTkg5KWlEbX1nGJiRtqo0vkvRsGX9C0sX
dfIJmg6qVV75HgJUNxr8UEdeW2w4ASUuANcDVZZu/kjRL0izgK8AtwBLgtjIX4MGyr18AXgfWT+cJmc0UU4YvIr4FnJpqXrEa2BYR70TEq1Q9+j5abhMRcTgifgRsA1aXZpgfA54q2z8K3NreUzCbmabTO/luSWuBvcDGiHgdWACM1+YcLWMA3ztn/AbgQ8API+JMg/nv457svTHsvc+7ra892Rv4KnA/EOXnQ8Cnp13NFNyTvTeGvfd5t3Wrro5+QyPixOSypL8B/rncPQZcUZt6eRmjyfhrwGWSLiqvfvX5ZkOto48aJM2v3f0EMHkldDuwRtIlkhYBi4HnqPqwLy5XNi+muiizvbSQ3gP8Vtl+HfB0JzWZzTRTvvJJehxYDsyTdBS4F1gu6Vqq084jwGcBIuKApCeBl4EzwIaIeLfs525gJzAL2BIRB8pD/AmwTdKfAf8JfK1bT85skE0Zvoi4rcFw04BExBeBLzYY3wHsaDB+mOpqqNkFxd9wMUvi8JklcfjMkjh8ZkkcPrMkDp9ZEofPLInDZ5bE4TNL4vCZJXH4zJI4fGZJHD6zJA6fWRKHzyyJw2eWxOEzS+LwmSVx+MySOHxmSRw+syQOn1kSh88sicNnlsThM0vi8JklcfjMknTak32upF2SDpWfc8q4JG0u/dVflHR9bZt1Zf4hSetq40sl7S/bbC7das2GXqc92TcBuyNiMbC73Ieq5/ricruTqokmkuZSdTe6gaopyr2TgS1zPlPbrlH/d7Oh02lP9tVU/dPh7D7qq4GtURmnanw5H1gB7IqIU6V99C5gZVn3wYgYL736tuKe7HaB6LR38khEHC/L3wdGyvIC3t97fcEU40cbjDfknuy9Mey9z7stuyf7eyIiJMW0K2ntsdyTvQeGvfd5t3Wrrk6vdp6YbA1dfp4s4816sp9v/PIG42ZDr9Pwbafqnw5n91HfDqwtVz2XAW+U09OdwM2S5pQLLTcDO8u6NyUtK1c51+Ke7HaB6LQn+wPAk5LWA98FPlmm7wBWARPA28AdABFxStL9wPNl3n0RMXkR5y6qK6qXAs+Um9nQ67QnO8BNDeYGsKHJfrYAWxqM7wU+MlUdZsPG33AxS+LwmSVx+MySOHxmSRw+syQOn1kSh88sicNnlsThM0vi8JklcfjMkjh8ZkkcPrMkDp9ZEofPLInDZ5bE4TNL4vCZJXH4zJI4fGZJHD6zJA6fWRKHzyyJw2eWxOEzS+LwmSVx+MySTCt8ko6Ufur7JO0tY13r1242zLrxyndjRFwbEaPlfjf7tZsNrV6cdnalX3sP6jIbKNMNXwD/KumF0i8dutev3WyoTbdx+a9ExDFJPw/skvTf9ZXd7tdeAn4nwMjISNOm9COXVv3PB00ndTV7jt10+vTpvjxOu4a9rmmFLyKOlZ8nJX2d6j3bCUnzI+J4G/3al58zPtbk8R4GHgYYHR2NZk3pv/zY0zy0f7p/V7pv4zVn2q7ryKeW96aYmrGxMZody0zDXlfHp52SZkv6mcllqj7rL9Glfu2d1mU2U0zn5WEE+Lqkyf38XUT8i6Tn6V6/drOh1XH4IuIw8IsNxl+jS/3azYaZv+FilsThM0vi8JklcfjMkjh8ZkkcPrMkDp9ZEofPLInDZ5bE4TNL4vCZJXH4zJI4fGZJHD6zJA6fWRKHzyyJw2eWxOEzS+LwmSVx+MySOHxmSRw+syQOn1mSwfs/1W1oLdz0jbbmb7zmDLe3uU0/PLJydlf241c+syQOn1kSh88sycCET9JKSQdLz/ZNU29hNrMNRPgkzQK+QtW3fQlwm6QluVWZ9dZAhI+qqeZERByOiB8B26h6uJsNrUEJn/uy2wVnRn3OV+/JDpyWdLDJ1HnAD/pTVev+oIO69GCPijnb0ByvfrjxwSnrurKV/QxK+Jr1az9LvSf7+UjaGxGj3SuvO1xXe4a9rkE57XweWCxpkaSLgTVUPdzNhtZAvPJFxBlJdwM7gVnAlog
4kFyWWU8NRPgAImIHsKNLu5vy1DSJ62rPUNeliOjGfsysTYPyns/sgjPjwjfV19AkXSLpibL+WUkLa+s+X8YPSlrR57r+WNLLkl6UtFvSlbV170raV25dvdDUQl23S/qf2uP/bm3dOkmHym1dn+v6Uq2m70j6YW1dL4/XFkknJb3UZL0kbS51vyjp+tq69o5XRMyYG9XFmFeAq4CLgf8Clpwz5y7gr8vyGuCJsrykzL8EWFT2M6uPdd0I/FRZ/r3Jusr904nH63bgLxtsOxc4XH7OKctz+lXXOfN/n+oiXE+PV9n3rwLXAy81Wb8KeAYQsAx4ttPjNdNe+Vr5Gtpq4NGy/BRwkySV8W0R8U5EvApMlP31pa6I2BMRb5e741SfZfbadL62twLYFRGnIuJ1YBewMqmu24DHu/TY5xUR3wJOnWfKamBrVMaByyTNp4PjNdPC18rX0N6bExFngDeAD7W4bS/rqltP9ddz0gck7ZU0LunWLtXUTl2/WU6hnpI0+WWHgThe5fR8EfDN2nCvjlcrmtXe9vEamI8aLhSSfhsYBX6tNnxlRByTdBXwTUn7I+KVPpX0T8DjEfGOpM9SnTV8rE+P3Yo1wFMR8W5tLPN4dc1Me+Vr5Wto782RdBHws8BrLW7by7qQ9OvAPcBvRMQ7k+MRcaz8PAyMAdf1q66IeK1Wy98CS1vdtpd11azhnFPOHh6vVjSrvf3j1as3rj16M3wR1RvZRfz4jfrV58zZwNkXXJ4sy1dz9gWXw3TvgksrdV1HdZFh8Tnjc4BLyvI84BDnufjQg7rm15Y/AYzHjy8gvFrqm1OW5/arrjLvw8ARyufRvT5etcdYSPMLLh/n7Asuz3V6vNID1cGBWQV8p/wi31PG7qN6NQH4APD3VBdUngOuqm17T9nuIHBLn+v6N+AEsK/ctpfxXwL2l1/A/cD6Ptf158CB8vh7gA/Xtv10OY4TwB39rKvc/wLwwDnb9fp4PQ4cB/6P6n3beuBzwOfKelH9w+9XyuOPdnq8/A0XsyQz7T2f2dBw+MySOHxmSRw+syQOn1kSh88sicNnlsThM0vy/40CaXdUSIPZAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
   "source": [
    "# Predict the selected code for every unannotated row, reading the same\n",
    "# question/answer columns that were selected for X in the config cell.\n",
    "pred_col = code_variable + '_pred'\n",
    "uncoded_df[pred_col] = iqual_model.predict(uncoded_df[[question_col, answer_col]])\n",
    "\n",
    "uncoded_df[pred_col].hist(figsize=(3,3),bins=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31bcd498",
   "metadata": {},
   "source": [
    "### Examples for positive predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b1e02a70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q: If you want to get married?\n",
      "A: Yes, I want to get married. But the proposal actually says that people have to pay 1 lakh rupees. I don't mean without money.\n",
      "\n",
      "Q: What other dreams about him?\n",
      "A: My dream is to study while I'm alive and get married soon.\n",
      "\n",
      "Q: How do you plan to achieve these?\n",
      "A: I have saved some money for them. I will borrow some money from people and marry the girls.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print up to 3 randomly sampled question/answer pairs whose prediction is 1.\n",
    "positive_rows = uncoded_df.loc[uncoded_df[pred_col] == 1, [question_col, answer_col]]\n",
    "# .sample(3) raises ValueError when fewer than 3 rows match, so cap the sample size.\n",
    "n_examples = min(3, len(positive_rows))\n",
    "for _, row in positive_rows.sample(n_examples).iterrows():\n",
    "    print(\"Q: \",row[question_col],\"\\n\",\"A: \", row[answer_col],sep='')\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}