diff --git a/notebook/HITS.ipynb b/notebook/HITS.ipynb new file mode 100644 index 0000000..6a44794 --- /dev/null +++ b/notebook/HITS.ipynb @@ -0,0 +1,264 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "HITS.ipynb", + "private_outputs": true, + "provenance": [], + "collapsed_sections": [], + "authorship_tag": "ABX9TyN4Xg++97xwPc2Nr8C+c6hx", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "import torch" + ], + "metadata": { + "id": "Yl-axB2vbPSR" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install git+https://github.com/nickprock/influencer.git" + ], + "metadata": { + "id": "wSSHSYpDbUYB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import influencer\n", + "influencer.__version__" + ], + "metadata": { + "id": "QzAd2MnIbnqB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade jax jaxlib" + ], + "metadata": { + "id": "4017MOi1o0UU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from influencer.centrality import hits as npHITS\n", + "from influencer.torch_centrality import hits as torchHITS" + ], + "metadata": { + "id": "JcdonChabqGi" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "torch.cuda.is_available()" + ], + "metadata": { + "id": "TQe7eU-pUXcU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# lazy_cerntrality version\n", + "\n", + "import jax.numpy as jnp\n", + "from jax import jit\n", + "\n", + "def jhits(adjMatrix, p: int = 100):\n", + " n = adjMatrix.shape[0]\n", + " \n", + " a = jnp.ones([1,n])\n", + " h = jnp.ones([1,n])\n", + " \n", + " pa=a\n", + " \n", + " authority = {}\n", + " hub = {}\n", + " \n", + " for k in range(1,p):\n", + " h1 = jnp.dot(adjMatrix, pa.T)/jnp.linalg.norm(jnp.dot(adjMatrix, pa.T))\n", + " a1 = jnp.dot(adjMatrix.T, h1)/jnp.linalg.norm(jnp.dot(adjMatrix.T , h1))\n", + " \n", + " h = jnp.vstack((h,jnp.dot(adjMatrix, a[k-1,:].T)/jnp.linalg.norm(jnp.dot(adjMatrix, a[k-1,:].T))))\n", + " a = jnp.vstack((a,jnp.dot(adjMatrix.T, h[k,:].T)/jnp.linalg.norm(jnp.dot(adjMatrix.T, h[k,:].T))))\n", + " \n", + " pa = a1.T\n", + " \n", + " for i in range(n):\n", + " authority[str(i)] = a[-1,i]\n", + " hub[str(i)] = h[-1,i]\n", + " \n", + " return hub, authority, h, a" + ], + "metadata": { + "id": "MfiQGaGVo1lL" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "jit_jhits = jit(jhits)" + ], + "metadata": { + "id": "RI6BjpI6o1ZL" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import time" + ], + "metadata": { + "id": "DLRt_OOEUGAP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "np.random.seed(42)\n", + "\n", + "num_nodes = [x for x in range(500,15000, 500)]\n", + "time_np = []\n", + "time_torch = []\n", + "time_torch_cpu = []\n", + "time_jnp = []" + ], + "metadata": { + "id": "eaDzCz8cUGk9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for N in num_nodes:\n", + " adjM = np.random.rand(N, N)\n", + " adjM[adjM>0.5]=1\n", + " adjM[adjM<=0.5]=0\n", + " start_time1 = time.time()\n", + " _, _,_,_ = npHITS(adjM, p=10)\n", + " exe_time1 = time.time() - start_time1\n", + " MT = torch.from_numpy(adjM).float().to(0)\n", + " start_time2 = time.time()\n", + " _,_,_,_ = torchHITS(MT, p=10)\n", + " exe_time2 = time.time() - start_time2\n", + " MT_cpu = torch.from_numpy(adjM).float()\n", + " start_time3 = time.time()\n", + " _,_,_,_ = torchHITS(MT_cpu, p=10, device='cpu')\n", + " exe_time3 = time.time() - start_time3\n", + " start_time4 = time.time()\n", + " _, _,_,_ = jhits(adjM, p=10)\n", + " exe_time4 = time.time() - start_time4\n", + " time_np.append(exe_time1)\n", + " time_torch.append(exe_time2)\n", + " time_torch_cpu.append(exe_time3)\n", + " time_jnp.append(exe_time4)" + ], + "metadata": { + "id": "98GaRZ5QmyYP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt" + ], + "metadata": { + "id": "M2fXTQ2UcK0-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "plt.figure(figsize=(18,10))\n", + "plt.plot(num_nodes,time_np, 'bo')\n", + "plt.plot(num_nodes,time_torch, 'ro')\n", + "plt.plot(num_nodes,time_torch_cpu, 'go')\n", + "plt.plot(num_nodes,time_jnp, 'ko')\n", + "plt.xlabel(\"nodes\")\n", + "plt.ylabel(\"seconds\")\n", + "plt.title(\"HITS algorithm execution time\")\n", + "plt.legend([\"numpy\", \"torch\", \"torch_CPU\", \"JAX\"])\n", + "plt.show()" + ], + "metadata": { + "id": "fPsHmIivcMMF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "plt.figure(figsize=(18,10))\n", + "plt.plot(num_nodes,time_np, 'bo')\n", + "plt.plot(num_nodes,time_torch, 'ro')\n", + "plt.plot(num_nodes,time_torch_cpu, 'go')\n", + "plt.xlabel(\"nodes\")\n", + "plt.ylabel(\"seconds\")\n", + "plt.title(\"HITS algorithm execution time\")\n", + "plt.legend([\"numpy\", \"torch\", \"torch_CPU\"])\n", + "plt.show()" + ], + "metadata": { + "id": "WsiYMTzS_NgL" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file