Skip to content

Commit 6d75c6b

Browse files
committed
reverse mode AD general compute graph in Python
1 parent 05f476b commit 6d75c6b

File tree

1 file changed

+395
-0
lines changed

1 file changed

+395
-0
lines changed
Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import numpy as np"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"Function is:\n",
17+
"\n",
18+
"$$ f(x_0, x_1) = \\sin(x_0) (x_0 + x_1) $$\n",
19+
"\n",
20+
"or broken down\n",
21+
"\n",
22+
"$$ \\begin{align}\n",
23+
"z_0 &= x_0 \\\\\n",
24+
"z_1 &= x_1 \\\\\n",
25+
"z_2 &= \\sin(z_0) \\\\\n",
26+
"z_3 &= z_0 + z_1 \\\\\n",
27+
"z_4 &= z_2 z_3 \\\\\n",
28+
"\\end{align} $$\n",
29+
"\n",
30+
"Its symbolic derivative is:\n",
31+
"\n",
32+
"$$ \\nabla f(x_0, x_1) = \\begin{bmatrix}\n",
33+
"\\cos(x_0) (x_0 + x_1) + \\sin(x_0) \\\\\n",
34+
"\\sin(x_0)\n",
35+
"\\end{bmatrix} $$"
36+
]
37+
},
38+
{
39+
"cell_type": "code",
40+
"execution_count": 2,
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"def f(x_0, x_1):\n",
45+
" return np.sin(x_0) * (x_0 + x_1)\n",
46+
"\n",
47+
"def f_grad(x_0, x_1):\n",
48+
" return np.array([\n",
49+
" np.cos(x_0) * (x_0 + x_1) + np.sin(x_0),\n",
50+
" np.sin(x_0),\n",
51+
" ])"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": 3,
57+
"metadata": {},
58+
"outputs": [],
59+
"source": [
60+
"compute_graph = [\n",
61+
" (\"inp\", (0,)), # 0\n",
62+
" (\"inp\", (1,)), # 1\n",
63+
" (\"sin\", (0,)), # 2\n",
64+
" (\"add\", (0, 1)), # 3\n",
65+
" (\"mul\", (2, 3)), # 4\n",
66+
"]"
67+
]
68+
},
69+
{
70+
"cell_type": "code",
71+
"execution_count": 4,
72+
"metadata": {},
73+
"outputs": [],
74+
"source": [
75+
"fn_library = {\n",
76+
" \"inp\": lambda x: x,\n",
77+
" \"sin\": lambda x: np.sin(x),\n",
78+
" \"add\": lambda x, y: x + y,\n",
79+
" \"mul\": lambda x, y: x * y,\n",
80+
"}"
81+
]
82+
},
83+
{
84+
"cell_type": "code",
85+
"execution_count": 5,
86+
"metadata": {},
87+
"outputs": [],
88+
"source": [
89+
"def compute(graph, inputs):\n",
90+
" values = list(inputs)\n",
91+
" for operation, indices in graph:\n",
92+
" if operation == \"inp\":\n",
93+
" continue\n",
94+
" args = [values[index] for index in indices]\n",
95+
" result = fn_library[operation](*args)\n",
96+
" values.append(result)\n",
97+
" \n",
98+
" return values[-1]"
99+
]
100+
},
101+
{
102+
"cell_type": "code",
103+
"execution_count": 6,
104+
"metadata": {},
105+
"outputs": [],
106+
"source": [
107+
"SAMPLE_INPUT = (0.6, 1.4)"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": 7,
113+
"metadata": {},
114+
"outputs": [
115+
{
116+
"data": {
117+
"text/plain": [
118+
"1.1292849467900707"
119+
]
120+
},
121+
"execution_count": 7,
122+
"metadata": {},
123+
"output_type": "execute_result"
124+
}
125+
],
126+
"source": [
127+
"f(*SAMPLE_INPUT)"
128+
]
129+
},
130+
{
131+
"cell_type": "code",
132+
"execution_count": 8,
133+
"metadata": {},
134+
"outputs": [
135+
{
136+
"data": {
137+
"text/plain": [
138+
"1.1292849467900707"
139+
]
140+
},
141+
"execution_count": 8,
142+
"metadata": {},
143+
"output_type": "execute_result"
144+
}
145+
],
146+
"source": [
147+
"compute(compute_graph, SAMPLE_INPUT)"
148+
]
149+
},
150+
{
151+
"cell_type": "code",
152+
"execution_count": 9,
153+
"metadata": {},
154+
"outputs": [],
155+
"source": [
156+
"def inp_backprop_rule(x):\n",
157+
" z = x\n",
158+
"\n",
159+
" def inp_pullback(z_cotangent):\n",
160+
" x_cotangent = z_cotangent\n",
161+
" return (x_cotangent,)\n",
162+
" \n",
163+
" return z, inp_pullback\n",
164+
"\n",
165+
"def sin_backprop_rule(x):\n",
166+
" z = np.sin(x)\n",
167+
"\n",
168+
" def sin_pullback(z_cotangent):\n",
169+
" x_cotangent = np.cos(x) * z_cotangent\n",
170+
" return (x_cotangent,)\n",
171+
" \n",
172+
" return z, sin_pullback\n",
173+
"\n",
174+
"def add_backprop_rule(x, y):\n",
175+
" z = x + y\n",
176+
"\n",
177+
" def add_pullback(z_cotangent):\n",
178+
" x_cotangent = z_cotangent\n",
179+
" y_cotangent = z_cotangent\n",
180+
"\n",
181+
" return (x_cotangent, y_cotangent)\n",
182+
" \n",
183+
" return z, add_pullback\n",
184+
"\n",
185+
"def mul_backprop_rule(x, y):\n",
186+
" z = x * y\n",
187+
"\n",
188+
" def mul_pullback(z_cotangent):\n",
189+
" x_cotangent = y * z_cotangent\n",
190+
" y_cotangent = x * z_cotangent\n",
191+
" return (x_cotangent, y_cotangent)\n",
192+
" \n",
193+
" return z, mul_pullback"
194+
]
195+
},
196+
{
197+
"cell_type": "code",
198+
"execution_count": 10,
199+
"metadata": {},
200+
"outputs": [],
201+
"source": [
202+
"backprop_library = {\n",
203+
" \"inp\": inp_backprop_rule,\n",
204+
" \"sin\": sin_backprop_rule,\n",
205+
" \"add\": add_backprop_rule,\n",
206+
" \"mul\": mul_backprop_rule,\n",
207+
"}"
208+
]
209+
},
210+
{
211+
"cell_type": "code",
212+
"execution_count": 11,
213+
"metadata": {},
214+
"outputs": [],
215+
"source": [
216+
"def vjp(graph, inputs):\n",
217+
" values = list(inputs)\n",
218+
" pullback_stack = []\n",
219+
"\n",
220+
" # Forward pass\n",
221+
" for operation, indices in graph:\n",
222+
" if operation == \"inp\":\n",
223+
" continue\n",
224+
" args = [values[index] for index in indices]\n",
225+
" result, pullback_fn = backprop_library[operation](*args)\n",
226+
" values.append(result)\n",
227+
" pullback_stack.append((pullback_fn, indices))\n",
228+
"\n",
229+
" def pullback(output_cotangent):\n",
230+
" cotangent_values = np.zeros(len(values))\n",
231+
" cotangent_values[-1] = output_cotangent\n",
232+
"\n",
233+
" for i, (pullback_fn, indices) in enumerate(reversed(pullback_stack)):\n",
234+
" current_cotangent_value = cotangent_values[-1 - i]\n",
235+
" cotangent_args = pullback_fn(current_cotangent_value)\n",
236+
" for index, cotangent in zip(indices, cotangent_args):\n",
237+
" cotangent_values[index] += cotangent\n",
238+
" \n",
239+
" return cotangent_values[:len(inputs)]\n",
240+
" \n",
241+
" return values[-1], pullback\n",
242+
" "
243+
]
244+
},
245+
{
246+
"cell_type": "code",
247+
"execution_count": 12,
248+
"metadata": {},
249+
"outputs": [],
250+
"source": [
251+
"out, back_fn = vjp(compute_graph, SAMPLE_INPUT)"
252+
]
253+
},
254+
{
255+
"cell_type": "code",
256+
"execution_count": 13,
257+
"metadata": {},
258+
"outputs": [
259+
{
260+
"data": {
261+
"text/plain": [
262+
"1.1292849467900707"
263+
]
264+
},
265+
"execution_count": 13,
266+
"metadata": {},
267+
"output_type": "execute_result"
268+
}
269+
],
270+
"source": [
271+
"out"
272+
]
273+
},
274+
{
275+
"cell_type": "code",
276+
"execution_count": 15,
277+
"metadata": {},
278+
"outputs": [
279+
{
280+
"data": {
281+
"text/plain": [
282+
"array([2.2153137 , 0.56464247])"
283+
]
284+
},
285+
"execution_count": 15,
286+
"metadata": {},
287+
"output_type": "execute_result"
288+
}
289+
],
290+
"source": [
291+
"back_fn(1.0)"
292+
]
293+
},
294+
{
295+
"cell_type": "code",
296+
"execution_count": 16,
297+
"metadata": {},
298+
"outputs": [
299+
{
300+
"data": {
301+
"text/plain": [
302+
"array([2.2153137 , 0.56464247])"
303+
]
304+
},
305+
"execution_count": 16,
306+
"metadata": {},
307+
"output_type": "execute_result"
308+
}
309+
],
310+
"source": [
311+
"f_grad(*SAMPLE_INPUT)"
312+
]
313+
},
314+
{
315+
"cell_type": "code",
316+
"execution_count": 17,
317+
"metadata": {},
318+
"outputs": [],
319+
"source": [
320+
"def value_and_grad(graph, inputs):\n",
321+
" out, back_fn = vjp(graph, inputs)\n",
322+
" grad = back_fn(1.0)\n",
323+
" return out, grad"
324+
]
325+
},
326+
{
327+
"cell_type": "code",
328+
"execution_count": 18,
329+
"metadata": {},
330+
"outputs": [
331+
{
332+
"data": {
333+
"text/plain": [
334+
"(1.1292849467900707, array([2.2153137 , 0.56464247]))"
335+
]
336+
},
337+
"execution_count": 18,
338+
"metadata": {},
339+
"output_type": "execute_result"
340+
}
341+
],
342+
"source": [
343+
"value_and_grad(compute_graph, SAMPLE_INPUT)"
344+
]
345+
},
346+
{
347+
"cell_type": "code",
348+
"execution_count": 19,
349+
"metadata": {},
350+
"outputs": [
351+
{
352+
"data": {
353+
"text/plain": [
354+
"(1.1292849467900707, array([2.2153137 , 0.56464247]))"
355+
]
356+
},
357+
"execution_count": 19,
358+
"metadata": {},
359+
"output_type": "execute_result"
360+
}
361+
],
362+
"source": [
363+
"f(*SAMPLE_INPUT), f_grad(*SAMPLE_INPUT)"
364+
]
365+
},
366+
{
367+
"cell_type": "code",
368+
"execution_count": null,
369+
"metadata": {},
370+
"outputs": [],
371+
"source": []
372+
}
373+
],
374+
"metadata": {
375+
"kernelspec": {
376+
"display_name": "base",
377+
"language": "python",
378+
"name": "python3"
379+
},
380+
"language_info": {
381+
"codemirror_mode": {
382+
"name": "ipython",
383+
"version": 3
384+
},
385+
"file_extension": ".py",
386+
"mimetype": "text/x-python",
387+
"name": "python",
388+
"nbconvert_exporter": "python",
389+
"pygments_lexer": "ipython3",
390+
"version": "3.10.9"
391+
}
392+
},
393+
"nbformat": 4,
394+
"nbformat_minor": 2
395+
}

0 commit comments

Comments
 (0)