{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0d4ca8f3",
   "metadata": {},
   "source": [
    "# SentSim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "422561a2",
   "metadata": {},
   "source": [
    "## Imports and Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "822dfc8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer, util\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3797187e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column index of feedback in CSV\n",
    "data_column = 2\n",
    "# Header row number to start parsing after\n",
    "data_row = 1\n",
    "# How many total test examples are provided\n",
    "num_tests = 5\n",
    "test_file = 'tests.csv'\n",
    "# Path to desired SentenceTransformer model (relative to MYPATH)\n",
    "model_path = 'models/sentence-transformers_all-mpnet-base-v2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50d4e596",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the path to the install folder set in run.sh\n",
    "MYPATH = os.getenv('MYPATH')\n",
    "directory = MYPATH+'/data'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "561f350c",
   "metadata": {},
   "outputs": [],
   "source": [
    "file_list = os.listdir(directory)\n",
    "file_path = os.path.join(directory, file_list[0])\n",
    "print(file_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28b7fdc0",
   "metadata": {},
   "source": [
    "## Data Input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a7d8d2cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ID</td>\n",
       "      <td>Score</td>\n",
       "      <td>Feedback</td>\n",
       "      <td>Notes</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>I would change the dining hall to have more ty...</td>\n",
       "      <td>Apple</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>17</td>\n",
       "      <td>2</td>\n",
       "      <td>I dont like having to walk to class.</td>\n",
       "      <td>Banana</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>39</td>\n",
       "      <td>9</td>\n",
       "      <td>More buses</td>\n",
       "      <td>Cucumber</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>235423</td>\n",
       "      <td>6</td>\n",
       "      <td>There are too many mosquitoes on campus. They ...</td>\n",
       "      <td>Duck</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        0      1                                                  2         3\n",
       "0      ID  Score                                           Feedback     Notes\n",
       "1       1      2  I would change the dining hall to have more ty...     Apple\n",
       "2      17      2               I dont like having to walk to class.    Banana\n",
       "3      39      9                                         More buses  Cucumber\n",
       "4  235423      6  There are too many mosquitoes on campus. They ...      Duck"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv(file_path, keep_default_na=False, header=None, encoding = 'unicode_escape', engine ='python')\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "819e9d74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>size of classes</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>mass transit</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>public safety</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>quality of food</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>robot uprising</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 0\n",
       "0  size of classes\n",
       "1     mass transit\n",
       "2    public safety\n",
       "3  quality of food\n",
       "4   robot uprising"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tests = pd.read_csv(MYPATH+test_file, keep_default_na=False, header=None, encoding = 'unicode_escape', engine ='python')\n",
    "tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "04da4e6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    I would change the dining hall to have more ty...\n",
       "1                 I dont like having to walk to class.\n",
       "2                                           More buses\n",
       "3    There are too many mosquitoes on campus. They ...\n",
       "Name: 2, dtype: object"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tests_text = tests.loc[:][0]\n",
    "# SentenceTransformer wants indexing to start at 0, so we reset\n",
    "#     the index and drop the old indices\n",
    "feedback_text = data.loc[data_row:][data_column].reset_index(drop=True)\n",
    "feedback_text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91207471",
   "metadata": {},
   "source": [
    "## Using the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2fe1c46f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path to the folder that contains the config.json file\n",
    "# This uses the following HF model:\n",
    "#  https://huggingface.co/sentence-transformers/all-mpnet-base-v2\n",
    "model = SentenceTransformer(MYPATH+model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1535812d",
   "metadata": {},
   "outputs": [],
   "source": [
    "tests = model.encode(tests_text)\n",
    "feedback = model.encode(feedback_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "88bd9579",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 768)\n",
      "(4, 768)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 0.00531309,  0.01246637, -0.00990538, ...,  0.02759883,\n",
       "        -0.01668295,  0.01243234],\n",
       "       [-0.02030246,  0.01770941, -0.00373942, ..., -0.04288071,\n",
       "        -0.03075184, -0.01293491],\n",
       "       [-0.00247474,  0.00174774,  0.01997339, ..., -0.02406575,\n",
       "         0.01295806, -0.00414417],\n",
       "       [ 0.00664289,  0.0805131 , -0.02354824, ...,  0.03988045,\n",
       "         0.04549433,  0.00696457],\n",
       "       [ 0.00536655,  0.06831884,  0.02261301, ...,  0.04422512,\n",
       "        -0.03214771, -0.01640472]], dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(tests.shape)\n",
    "print(feedback.shape)\n",
    "tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d52d3dbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1447,  0.0726,  0.0951,  0.1393,  0.0647],\n",
       "        [ 0.1878,  0.3848,  0.1217, -0.0127,  0.0539],\n",
       "        [ 0.2149,  0.5412,  0.3327,  0.1264,  0.1854],\n",
       "        [ 0.1083,  0.0889,  0.1487,  0.0124,  0.0532]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "similarity = util.cos_sim(feedback, tests)\n",
    "similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6c1afb39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0    size of classes\n",
      "1       mass transit\n",
      "2      public safety\n",
      "3    quality of food\n",
      "4     robot uprising\n",
      "Name: 0, dtype: object\n",
      "0    I would change the dining hall to have more ty...\n",
      "1                 I dont like having to walk to class.\n",
      "2                                           More buses\n",
      "3    There are too many mosquitoes on campus. They ...\n",
      "Name: 2, dtype: object\n"
     ]
    }
   ],
   "source": [
    "print(tests_text)\n",
    "print(feedback_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46557f27",
   "metadata": {},
   "source": [
    "## Examples of behavior\n",
    "### World Knowledge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bbb0c13d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_compare(a,b):\n",
    "    temp =  util.cos_sim(model.encode(a), model.encode(b)).tolist()[0]\n",
    "    for num in temp:\n",
    "        print(f'{num:.4f}')\n",
    "\n",
    "def my_compare2(a,b):\n",
    "    return util.cos_sim(model.encode(a), model.encode(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ef019440",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7027\n"
     ]
    }
   ],
   "source": [
    "a = \"The first president of the United States had dogs.\"\n",
    "b = \"George Washington owned four French hounds.\"\n",
    "my_compare(a,b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8b6e415c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1135\n",
      "0.0883\n"
     ]
    }
   ],
   "source": [
    "c = \"Michael Phelps swam really well.\"\n",
    "my_compare(c,[a,b])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cff1f409",
   "metadata": {},
   "source": [
    "### Score scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "db751187",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.0642\n",
      "-0.0111\n"
     ]
    }
   ],
   "source": [
    "d = \"trichloroaetic acid isopropyl ester\"\n",
    "my_compare(d,[a,c])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "252f58ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2757\n",
      "0.3410\n",
      "0.0324\n"
     ]
    }
   ],
   "source": [
    "e = \"John Adams wrote 1100 letters to his wife.\"\n",
    "my_compare(e,[a,b,c])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "55b1fb13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4964\n",
      "0.6256\n",
      "0.0925\n"
     ]
    }
   ],
   "source": [
    "f = \"George Washington owned three houses.\"\n",
    "my_compare(f,[a,b,c])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c22731c",
   "metadata": {},
   "source": [
    "### Score dilution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "bc65563d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5964\n",
      "0.7646\n",
      "0.0587\n"
     ]
    }
   ],
   "source": [
    "g = \"George Washington owned four French hounds. A big old well-worn sofa is very comfy to sit in.\"\n",
    "my_compare(g,[a,b,c])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dea63911",
   "metadata": {},
   "source": [
    "### Score smearing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0a00284d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6073\n",
      "0.7359\n",
      "0.5617\n",
      "-0.0794\n",
      "0.2085\n"
     ]
    }
   ],
   "source": [
    "h = \"George Washington owned four French hounds. Michael Phelps swam really well.\"\n",
    "my_compare(h,[a,b,c,d,e])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb4d6e74",
   "metadata": {},
   "source": [
    "### Ordering matters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "da088090",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5532\n",
      "0.6585\n",
      "0.6619\n",
      "-0.0727\n",
      "0.2046\n"
     ]
    }
   ],
   "source": [
    "h = \"Michael Phelps swam really well. George Washington owned four French hounds.\"\n",
    "my_compare(h,[a,b,c,d,e])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5cd7937",
   "metadata": {},
   "source": [
    "### Further smearing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "8509ba87",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.5532,  0.6585,  0.6619, -0.0727,  0.2046],\n",
       "        [ 0.5001,  0.5478,  0.5280,  0.3385,  0.1875]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = \"Michael Phelps swam really well. George Washington owned four French hounds. trichloroaetic acid isopropyl ester\"\n",
    "my_compare2([h,i],[a,b,c,d,e])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "ba53e2fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1431\n"
     ]
    }
   ],
   "source": [
    "a = \"giant happy sweet lady dancing on summer day\"\n",
    "b = \"tiny miserable sour man comatose under winter night\"\n",
    "my_compare(a,b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adef0296",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
