|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 2, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [], |
| 8 | + "source": [ |
| 9 | + "import jax.numpy as jnp\n", |
| 10 | + "from jax import random, jit, vmap, grad\n", |
| 11 | + "from tensorflow.keras.datasets import mnist\n" |
| 12 | + ] |
| 13 | + }, |
| 14 | + { |
| 15 | + "cell_type": "code", |
| 16 | + "execution_count": 3, |
| 17 | + "metadata": {}, |
| 18 | + "outputs": [], |
| 19 | + "source": [ |
| 20 | + "# Create a program to classify MNIST data, using JAX\n", |
| 21 | + "# First load data\n", |
| 22 | + "\n", |
| 23 | + "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", |
| 24 | + "x_train, x_test = x_train / 255.0, x_test / 255.0\n", |
| 25 | + "x_train, x_test, y_train, y_test = jnp.array(x_train), jnp.array(x_test), jnp.array(y_train), jnp.array(y_test)\n" |
| 26 | + ] |
| 27 | + }, |
| 28 | + { |
| 29 | + "cell_type": "code", |
| 30 | + "execution_count": 17, |
| 31 | + "metadata": {}, |
| 32 | + "outputs": [], |
| 33 | + "source": [ |
| 34 | + "# Define functions to create network\n", |
| 35 | + "\n", |
| 36 | + "def initialise_layer_params(inp, out, key=None, scale=1e-2):\n", |
| 37 | + " if key == None:\n", |
| 38 | + " raise RuntimeError(\"Key must be provided for generating weights\")\n", |
| 39 | + " \n", |
| 40 | + " key1, key2 = random.split(key)\n", |
| 41 | + "\n", |
| 42 | + " return scale*random.normal(key1, (out, inp)), scale*random.normal(key2)\n", |
| 43 | + "\n", |
| 44 | + "def initialise_nn(layers, key):\n", |
| 45 | + " keys = random.split(key, num=len(layers))\n", |
| 46 | + " return [initialise_layer_params(i, o, k) for i, o, k in zip(layers[:-1], layers[1:], keys)]\n", |
| 47 | + "\n", |
| 48 | + "layer_sizes = [784,256,256,256,10]\n", |
| 49 | + "\n", |
| 50 | + "nn = initialise_nn(layer_sizes, random.key(1878))" |
| 51 | + ] |
| 52 | + }, |
| 53 | + { |
| 54 | + "cell_type": "code", |
| 55 | + "execution_count": 72, |
| 56 | + "metadata": {}, |
| 57 | + "outputs": [], |
| 58 | + "source": [ |
| 59 | + "# Define one-hot encoding for y data, prediction function and loss function\n", |
| 60 | + "\n", |
| 61 | + "def one_hot(y, num_classes, dtype=jnp.float32):\n", |
| 62 | + " return jnp.array(y[:,None] == jnp.arange(num_classes), dtype=dtype)\n", |
| 63 | + "\n", |
| 64 | + "def relu(x):\n", |
| 65 | + " return jnp.max(0,x)\n", |
| 66 | + "\n", |
| 67 | + "def predict(params, image):\n", |
| 68 | + " res = image\n", |
| 69 | + " for weights, bias in params[:-1]:\n", |
| 70 | + " res = jnp.dot(weights, res) + bias\n", |
| 71 | + " res = relu(res)\n", |
| 72 | + " res = jnp.dot(params[-1][0], res) + params[-1][1]\n", |
| 73 | + " return jnp.exp(-res) / jnp.sum(jnp.exp(-res))\n", |
| 74 | + "\n", |
| 75 | + "batch_pred = vmap(predict, in_axes=(None, 0))\n", |
| 76 | + "\n", |
| 77 | + "def loss(params, images, targets):\n", |
| 78 | + " preds = batch_pred(params, images)\n", |
| 79 | + " return -jnp.mean(jnp.log(preds)*targets)\n", |
| 80 | + "\n", |
| 81 | + "def accuracy(params, images, targets):\n", |
| 82 | + " preds = batch_pred(params, images)\n", |
| 83 | + " return jnp.mean(jnp.argmax(preds, axis=1)==jnp.argmax(targets, axis=1))\n", |
| 84 | + "\n", |
| 85 | + "@jit\n", |
| 86 | + "def update(params, x, y, lr):\n", |
| 87 | + " grads = grad(loss)(params, x, y)\n", |
| 88 | + " return [(w - lr * dw, b - lr * db)\n", |
| 89 | + " for (w, b), (dw, db) in zip(params, grads)]\n" |
| 90 | + ] |
| 91 | + }, |
| 92 | + { |
| 93 | + "cell_type": "code", |
| 94 | + "execution_count": 73, |
| 95 | + "metadata": {}, |
| 96 | + "outputs": [], |
| 97 | + "source": [ |
| 98 | + "# Define some parameters\n", |
| 99 | + "epochs = 100\n", |
| 100 | + "batch_size = 128\n", |
| 101 | + "learning_rate = 0.01\n", |
| 102 | + "num_digits = 10" |
| 103 | + ] |
| 104 | + }, |
| 105 | + { |
| 106 | + "cell_type": "code", |
| 107 | + "execution_count": 74, |
| 108 | + "metadata": {}, |
| 109 | + "outputs": [], |
| 110 | + "source": [ |
| 111 | + "# One-hot encode data\n", |
| 112 | + "y_train, y_test = one_hot(y_train, num_digits), one_hot(y_test, num_digits)\n", |
| 113 | + "x_train, x_test = jnp.reshape(x_train, (-1, 28*28)), jnp.reshape(x_test, (-1, 28*28))\n" |
| 114 | + ] |
| 115 | + }, |
| 116 | + { |
| 117 | + "cell_type": "code", |
| 118 | + "execution_count": 75, |
| 119 | + "metadata": {}, |
| 120 | + "outputs": [], |
| 121 | + "source": [ |
| 122 | + "def batch_data(images, key, bsize=128):\n", |
| 123 | + " order = random.permutation(key, len(images))\n", |
| 124 | + " print(order[0])\n", |
| 125 | + " for i in range(jnp.floor(len(images)/bsize)):\n", |
| 126 | + " yield images[order[bsize*i:bsize*(i+1)]]\n", |
| 127 | + " if images%bsize != 0:\n", |
| 128 | + " yield images[order[bsize*(i+1):]]" |
| 129 | + ] |
| 130 | + }, |
| 131 | + { |
| 132 | + "cell_type": "code", |
| 133 | + "execution_count": 77, |
| 134 | + "metadata": {}, |
| 135 | + "outputs": [ |
| 136 | + { |
| 137 | + "ename": "ConcretizationTypeError", |
| 138 | + "evalue": "Abstract tracer value encountered where concrete value is expected: traced array with shape float32[256]\nThe axis argument must be known statically.\nThis BatchTracer with object id 1426817064992 was created on line:\n C:\\Users\\thoma\\AppData\\Local\\Temp\\ipykernel_13036\\740884550.py:12:14 (predict)\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError", |
| 139 | + "output_type": "error", |
| 140 | + "traceback": [ |
| 141 | + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", |
| 142 | + "\u001b[1;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)", |
| 143 | + "Cell \u001b[1;32mIn[77], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m10\u001b[39m):\n\u001b[1;32m----> 2\u001b[0m nn \u001b[38;5;241m=\u001b[39m \u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m accuracy(nn, x_train, y_train)\n\u001b[0;32m 5\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m accuracy(nn, x_test, y_test)\n", |
| 144 | + " \u001b[1;31m[... skipping hidden 13 frame]\u001b[0m\n", |
| 145 | + "Cell \u001b[1;32mIn[72], line 29\u001b[0m, in \u001b[0;36mupdate\u001b[1;34m(params, x, y, lr)\u001b[0m\n\u001b[0;32m 27\u001b[0m \u001b[38;5;129m@jit\u001b[39m\n\u001b[0;32m 28\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mupdate\u001b[39m(params, x, y, lr):\n\u001b[1;32m---> 29\u001b[0m grads \u001b[38;5;241m=\u001b[39m \u001b[43mgrad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [(w \u001b[38;5;241m-\u001b[39m lr \u001b[38;5;241m*\u001b[39m dw, b \u001b[38;5;241m-\u001b[39m lr \u001b[38;5;241m*\u001b[39m db)\n\u001b[0;32m 31\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m (w, b), (dw, db) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(params, grads)]\n", |
| 146 | + " \u001b[1;31m[... skipping hidden 17 frame]\u001b[0m\n", |
| 147 | + "Cell \u001b[1;32mIn[72], line 20\u001b[0m, in \u001b[0;36mloss\u001b[1;34m(params, images, targets)\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mloss\u001b[39m(params, images, targets):\n\u001b[1;32m---> 20\u001b[0m preds \u001b[38;5;241m=\u001b[39m \u001b[43mbatch_pred\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimages\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;241m-\u001b[39mjnp\u001b[38;5;241m.\u001b[39mmean(jnp\u001b[38;5;241m.\u001b[39mlog(preds)\u001b[38;5;241m*\u001b[39mtargets)\n", |
| 148 | + " \u001b[1;31m[... skipping hidden 6 frame]\u001b[0m\n", |
| 149 | + "Cell \u001b[1;32mIn[72], line 13\u001b[0m, in \u001b[0;36mpredict\u001b[1;34m(params, image)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m weights, bias \u001b[38;5;129;01min\u001b[39;00m params[:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]:\n\u001b[0;32m 12\u001b[0m res \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mdot(weights, res) \u001b[38;5;241m+\u001b[39m bias\n\u001b[1;32m---> 13\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mrelu\u001b[49m\u001b[43m(\u001b[49m\u001b[43mres\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 14\u001b[0m res \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mdot(params[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m][\u001b[38;5;241m0\u001b[39m], res) \u001b[38;5;241m+\u001b[39m params[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m][\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m 15\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m jnp\u001b[38;5;241m.\u001b[39mexp(\u001b[38;5;241m-\u001b[39mres) \u001b[38;5;241m/\u001b[39m jnp\u001b[38;5;241m.\u001b[39msum(jnp\u001b[38;5;241m.\u001b[39mexp(\u001b[38;5;241m-\u001b[39mres))\n", |
| 150 | + "Cell \u001b[1;32mIn[72], line 7\u001b[0m, in \u001b[0;36mrelu\u001b[1;34m(x)\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrelu\u001b[39m(x):\n\u001b[1;32m----> 7\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n", |
| 151 | + "File \u001b[1;32mc:\\Users\\thoma\\miniconda3\\envs\\neutang\\Lib\\site-packages\\jax\\_src\\numpy\\reductions.py:483\u001b[0m, in \u001b[0;36mmax\u001b[1;34m(a, axis, out, keepdims, initial, where)\u001b[0m\n\u001b[0;32m 412\u001b[0m \u001b[38;5;129m@export\u001b[39m\n\u001b[0;32m 413\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mmax\u001b[39m(a: ArrayLike, axis: Axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, out: \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m 414\u001b[0m keepdims: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, initial: ArrayLike \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m 415\u001b[0m where: ArrayLike \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Array:\n\u001b[0;32m 416\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Return the maximum of the array elements along a given axis.\u001b[39;00m\n\u001b[0;32m 417\u001b[0m \n\u001b[0;32m 418\u001b[0m \u001b[38;5;124;03m JAX implementation of :func:`numpy.max`.\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 481\u001b[0m \u001b[38;5;124;03m Array([[0, 0, 0, 0]], dtype=int32)\u001b[39;00m\n\u001b[0;32m 482\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m--> 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _reduce_max(a, axis\u001b[38;5;241m=\u001b[39m\u001b[43m_ensure_optional_axes\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m, out\u001b[38;5;241m=\u001b[39mout,\n\u001b[0;32m 484\u001b[0m keepdims\u001b[38;5;241m=\u001b[39mkeepdims, initial\u001b[38;5;241m=\u001b[39minitial, where\u001b[38;5;241m=\u001b[39mwhere)\n", |
| 152 | + "File \u001b[1;32mc:\\Users\\thoma\\miniconda3\\envs\\neutang\\Lib\\site-packages\\jax\\_src\\numpy\\reductions.py:224\u001b[0m, in \u001b[0;36m_ensure_optional_axes\u001b[1;34m(x)\u001b[0m\n\u001b[0;32m 222\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[0;32m 223\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(i \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(i, \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m operator\u001b[38;5;241m.\u001b[39mindex(i) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m x)\n\u001b[1;32m--> 224\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcore\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconcrete_or_error\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mThe axis argument must be known statically.\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", |
| 153 | + "File \u001b[1;32mc:\\Users\\thoma\\miniconda3\\envs\\neutang\\Lib\\site-packages\\jax\\_src\\core.py:1514\u001b[0m, in \u001b[0;36mconcrete_or_error\u001b[1;34m(force, val, context)\u001b[0m\n\u001b[0;32m 1512\u001b[0m maybe_concrete \u001b[38;5;241m=\u001b[39m val\u001b[38;5;241m.\u001b[39mto_concrete_value()\n\u001b[0;32m 1513\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m maybe_concrete \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m-> 1514\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ConcretizationTypeError(val, context)\n\u001b[0;32m 1515\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m force(maybe_concrete)\n", |
| 154 | + "\u001b[1;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[256]\nThe axis argument must be known statically.\nThis BatchTracer with object id 1426817064992 was created on line:\n C:\\Users\\thoma\\AppData\\Local\\Temp\\ipykernel_13036\\740884550.py:12:14 (predict)\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError" |
| 155 | + ] |
| 156 | + } |
| 157 | + ], |
| 158 | + "source": [ |
| 159 | + "for epoch in range(10):\n", |
| 160 | + " nn = update(nn, x_train, y_train, learning_rate)\n", |
| 161 | + " \n", |
| 162 | + " train_acc = accuracy(nn, x_train, y_train)\n", |
| 163 | + " test_acc = accuracy(nn, x_test, y_test)\n", |
| 164 | + " print(\"Training set accuracy {}\".format(train_acc))" |
| 165 | + ] |
| 166 | + }, |
| 167 | + { |
| 168 | + "cell_type": "code", |
| 169 | + "execution_count": null, |
| 170 | + "metadata": {}, |
| 171 | + "outputs": [], |
| 172 | + "source": [] |
| 173 | + }, |
| 174 | + { |
| 175 | + "cell_type": "code", |
| 176 | + "execution_count": null, |
| 177 | + "metadata": {}, |
| 178 | + "outputs": [], |
| 179 | + "source": [] |
| 180 | + } |
| 181 | + ], |
| 182 | + "metadata": { |
| 183 | + "kernelspec": { |
| 184 | + "display_name": "neutang", |
| 185 | + "language": "python", |
| 186 | + "name": "python3" |
| 187 | + }, |
| 188 | + "language_info": { |
| 189 | + "codemirror_mode": { |
| 190 | + "name": "ipython", |
| 191 | + "version": 3 |
| 192 | + }, |
| 193 | + "file_extension": ".py", |
| 194 | + "mimetype": "text/x-python", |
| 195 | + "name": "python", |
| 196 | + "nbconvert_exporter": "python", |
| 197 | + "pygments_lexer": "ipython3", |
| 198 | + "version": "3.11.11" |
| 199 | + } |
| 200 | + }, |
| 201 | + "nbformat": 4, |
| 202 | + "nbformat_minor": 2 |
| 203 | +} |
0 commit comments