Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit 4085662

Browse files
author
tnpw2
committed
JAX NN test 2
1 parent 17627c5 commit 4085662

File tree

1 file changed

+203
-0
lines changed

1 file changed

+203
-0
lines changed

notebooks/jax_nn_test2.ipynb

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 2,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import jax.numpy as jnp\n",
10+
"from jax import random, jit, vmap, grad\n",
11+
"from tensorflow.keras.datasets import mnist\n"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 3,
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"# Create a program to classify MNIST data, using JAX\n",
21+
"# First load data\n",
22+
"\n",
23+
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
24+
"x_train, x_test = x_train / 255.0, x_test / 255.0\n",
25+
"x_train, x_test, y_train, y_test = jnp.array(x_train), jnp.array(x_test), jnp.array(y_train), jnp.array(y_test)\n"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": 17,
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"# Define functions to create network\n",
35+
"\n",
36+
"def initialise_layer_params(inp, out, key=None, scale=1e-2):\n",
37+
" if key == None:\n",
38+
" raise RuntimeError(\"Key must be provided for generating weights\")\n",
39+
" \n",
40+
" key1, key2 = random.split(key)\n",
41+
"\n",
42+
" return scale*random.normal(key1, (out, inp)), scale*random.normal(key2)\n",
43+
"\n",
44+
"def initialise_nn(layers, key):\n",
45+
" keys = random.split(key, num=len(layers))\n",
46+
" return [initialise_layer_params(i, o, k) for i, o, k in zip(layers[:-1], layers[1:], keys)]\n",
47+
"\n",
48+
"layer_sizes = [784,256,256,256,10]\n",
49+
"\n",
50+
"nn = initialise_nn(layer_sizes, random.key(1878))"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": 72,
56+
"metadata": {},
57+
"outputs": [],
58+
"source": [
59+
"# Define one-hot encoding for y data, prediction function and loss function\n",
60+
"\n",
61+
"def one_hot(y, num_classes, dtype=jnp.float32):\n",
62+
" return jnp.array(y[:,None] == jnp.arange(num_classes), dtype=dtype)\n",
63+
"\n",
64+
"def relu(x):\n",
65+
" return jnp.max(0,x)\n",
66+
"\n",
67+
"def predict(params, image):\n",
68+
" res = image\n",
69+
" for weights, bias in params[:-1]:\n",
70+
" res = jnp.dot(weights, res) + bias\n",
71+
" res = relu(res)\n",
72+
" res = jnp.dot(params[-1][0], res) + params[-1][1]\n",
73+
" return jnp.exp(-res) / jnp.sum(jnp.exp(-res))\n",
74+
"\n",
75+
"batch_pred = vmap(predict, in_axes=(None, 0))\n",
76+
"\n",
77+
"def loss(params, images, targets):\n",
78+
" preds = batch_pred(params, images)\n",
79+
" return -jnp.mean(jnp.log(preds)*targets)\n",
80+
"\n",
81+
"def accuracy(params, images, targets):\n",
82+
" preds = batch_pred(params, images)\n",
83+
" return jnp.mean(jnp.argmax(preds, axis=1)==jnp.argmax(targets, axis=1))\n",
84+
"\n",
85+
"@jit\n",
86+
"def update(params, x, y, lr):\n",
87+
" grads = grad(loss)(params, x, y)\n",
88+
" return [(w - lr * dw, b - lr * db)\n",
89+
" for (w, b), (dw, db) in zip(params, grads)]\n"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": 73,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"# Define some parameters\n",
99+
"epochs = 100\n",
100+
"batch_size = 128\n",
101+
"learning_rate = 0.01\n",
102+
"num_digits = 10"
103+
]
104+
},
105+
{
106+
"cell_type": "code",
107+
"execution_count": 74,
108+
"metadata": {},
109+
"outputs": [],
110+
"source": [
111+
"# One-hot encode data\n",
112+
"y_train, y_test = one_hot(y_train, num_digits), one_hot(y_test, num_digits)\n",
113+
"x_train, x_test = jnp.reshape(x_train, (-1, 28*28)), jnp.reshape(x_test, (-1, 28*28))\n"
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": 75,
119+
"metadata": {},
120+
"outputs": [],
121+
"source": [
122+
"def batch_data(images, key, bsize=128):\n",
123+
" order = random.permutation(key, len(images))\n",
124+
" print(order[0])\n",
125+
" for i in range(jnp.floor(len(images)/bsize)):\n",
126+
" yield images[order[bsize*i:bsize*(i+1)]]\n",
127+
" if images%bsize != 0:\n",
128+
" yield images[order[bsize*(i+1):]]"
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 77,
134+
"metadata": {},
135+
"outputs": [
136+
{
137+
"ename": "ConcretizationTypeError",
138+
"evalue": "Abstract tracer value encountered where concrete value is expected: traced array with shape float32[256]\nThe axis argument must be known statically.\nThis BatchTracer with object id 1426817064992 was created on line:\n C:\\Users\\thoma\\AppData\\Local\\Temp\\ipykernel_13036\\740884550.py:12:14 (predict)\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError",
139+
"output_type": "error",
140+
"traceback": [
141+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
142+
"\u001b[1;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)",
143+
"Cell \u001b[1;32mIn[77], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m10\u001b[39m):\n\u001b[1;32m----> 2\u001b[0m nn \u001b[38;5;241m=\u001b[39m \u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m accuracy(nn, x_train, y_train)\n\u001b[0;32m 5\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m accuracy(nn, x_test, y_test)\n",
144+
" \u001b[1;31m[... skipping hidden 13 frame]\u001b[0m\n",
145+
"Cell \u001b[1;32mIn[72], line 29\u001b[0m, in \u001b[0;36mupdate\u001b[1;34m(params, x, y, lr)\u001b[0m\n\u001b[0;32m 27\u001b[0m \u001b[38;5;129m@jit\u001b[39m\n\u001b[0;32m 28\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mupdate\u001b[39m(params, x, y, lr):\n\u001b[1;32m---> 29\u001b[0m grads \u001b[38;5;241m=\u001b[39m \u001b[43mgrad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [(w \u001b[38;5;241m-\u001b[39m lr \u001b[38;5;241m*\u001b[39m dw, b \u001b[38;5;241m-\u001b[39m lr \u001b[38;5;241m*\u001b[39m db)\n\u001b[0;32m 31\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m (w, b), (dw, db) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(params, grads)]\n",
146+
" \u001b[1;31m[... skipping hidden 17 frame]\u001b[0m\n",
147+
"Cell \u001b[1;32mIn[72], line 20\u001b[0m, in \u001b[0;36mloss\u001b[1;34m(params, images, targets)\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mloss\u001b[39m(params, images, targets):\n\u001b[1;32m---> 20\u001b[0m preds \u001b[38;5;241m=\u001b[39m \u001b[43mbatch_pred\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimages\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;241m-\u001b[39mjnp\u001b[38;5;241m.\u001b[39mmean(jnp\u001b[38;5;241m.\u001b[39mlog(preds)\u001b[38;5;241m*\u001b[39mtargets)\n",
148+
" \u001b[1;31m[... skipping hidden 6 frame]\u001b[0m\n",
149+
"Cell \u001b[1;32mIn[72], line 13\u001b[0m, in \u001b[0;36mpredict\u001b[1;34m(params, image)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m weights, bias \u001b[38;5;129;01min\u001b[39;00m params[:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]:\n\u001b[0;32m 12\u001b[0m res \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mdot(weights, res) \u001b[38;5;241m+\u001b[39m bias\n\u001b[1;32m---> 13\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mrelu\u001b[49m\u001b[43m(\u001b[49m\u001b[43mres\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 14\u001b[0m res \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mdot(params[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m][\u001b[38;5;241m0\u001b[39m], res) \u001b[38;5;241m+\u001b[39m params[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m][\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m 15\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m jnp\u001b[38;5;241m.\u001b[39mexp(\u001b[38;5;241m-\u001b[39mres) \u001b[38;5;241m/\u001b[39m jnp\u001b[38;5;241m.\u001b[39msum(jnp\u001b[38;5;241m.\u001b[39mexp(\u001b[38;5;241m-\u001b[39mres))\n",
150+
"Cell \u001b[1;32mIn[72], line 7\u001b[0m, in \u001b[0;36mrelu\u001b[1;34m(x)\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrelu\u001b[39m(x):\n\u001b[1;32m----> 7\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
151+
"File \u001b[1;32mc:\\Users\\thoma\\miniconda3\\envs\\neutang\\Lib\\site-packages\\jax\\_src\\numpy\\reductions.py:483\u001b[0m, in \u001b[0;36mmax\u001b[1;34m(a, axis, out, keepdims, initial, where)\u001b[0m\n\u001b[0;32m 412\u001b[0m \u001b[38;5;129m@export\u001b[39m\n\u001b[0;32m 413\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mmax\u001b[39m(a: ArrayLike, axis: Axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, out: \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m 414\u001b[0m keepdims: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, initial: ArrayLike \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m 415\u001b[0m where: ArrayLike \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Array:\n\u001b[0;32m 416\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Return the maximum of the array elements along a given axis.\u001b[39;00m\n\u001b[0;32m 417\u001b[0m \n\u001b[0;32m 418\u001b[0m \u001b[38;5;124;03m JAX implementation of :func:`numpy.max`.\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 481\u001b[0m \u001b[38;5;124;03m Array([[0, 0, 0, 0]], dtype=int32)\u001b[39;00m\n\u001b[0;32m 482\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m--> 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _reduce_max(a, axis\u001b[38;5;241m=\u001b[39m\u001b[43m_ensure_optional_axes\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m, out\u001b[38;5;241m=\u001b[39mout,\n\u001b[0;32m 484\u001b[0m keepdims\u001b[38;5;241m=\u001b[39mkeepdims, initial\u001b[38;5;241m=\u001b[39minitial, where\u001b[38;5;241m=\u001b[39mwhere)\n",
152+
"File \u001b[1;32mc:\\Users\\thoma\\miniconda3\\envs\\neutang\\Lib\\site-packages\\jax\\_src\\numpy\\reductions.py:224\u001b[0m, in \u001b[0;36m_ensure_optional_axes\u001b[1;34m(x)\u001b[0m\n\u001b[0;32m 222\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[0;32m 223\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(i \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(i, \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m operator\u001b[38;5;241m.\u001b[39mindex(i) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m x)\n\u001b[1;32m--> 224\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcore\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconcrete_or_error\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mThe axis argument must be known statically.\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
153+
"File \u001b[1;32mc:\\Users\\thoma\\miniconda3\\envs\\neutang\\Lib\\site-packages\\jax\\_src\\core.py:1514\u001b[0m, in \u001b[0;36mconcrete_or_error\u001b[1;34m(force, val, context)\u001b[0m\n\u001b[0;32m 1512\u001b[0m maybe_concrete \u001b[38;5;241m=\u001b[39m val\u001b[38;5;241m.\u001b[39mto_concrete_value()\n\u001b[0;32m 1513\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m maybe_concrete \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m-> 1514\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ConcretizationTypeError(val, context)\n\u001b[0;32m 1515\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m force(maybe_concrete)\n",
154+
"\u001b[1;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[256]\nThe axis argument must be known statically.\nThis BatchTracer with object id 1426817064992 was created on line:\n C:\\Users\\thoma\\AppData\\Local\\Temp\\ipykernel_13036\\740884550.py:12:14 (predict)\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError"
155+
]
156+
}
157+
],
158+
"source": [
159+
"for epoch in range(10):\n",
160+
" nn = update(nn, x_train, y_train, learning_rate)\n",
161+
" \n",
162+
" train_acc = accuracy(nn, x_train, y_train)\n",
163+
" test_acc = accuracy(nn, x_test, y_test)\n",
164+
" print(\"Training set accuracy {}\".format(train_acc))"
165+
]
166+
},
167+
{
168+
"cell_type": "code",
169+
"execution_count": null,
170+
"metadata": {},
171+
"outputs": [],
172+
"source": []
173+
},
174+
{
175+
"cell_type": "code",
176+
"execution_count": null,
177+
"metadata": {},
178+
"outputs": [],
179+
"source": []
180+
}
181+
],
182+
"metadata": {
183+
"kernelspec": {
184+
"display_name": "neutang",
185+
"language": "python",
186+
"name": "python3"
187+
},
188+
"language_info": {
189+
"codemirror_mode": {
190+
"name": "ipython",
191+
"version": 3
192+
},
193+
"file_extension": ".py",
194+
"mimetype": "text/x-python",
195+
"name": "python",
196+
"nbconvert_exporter": "python",
197+
"pygments_lexer": "ipython3",
198+
"version": "3.11.11"
199+
}
200+
},
201+
"nbformat": 4,
202+
"nbformat_minor": 2
203+
}

0 commit comments

Comments
 (0)