88from jax import numpy as jnp
99from jax .interpreters import ad
1010from jax .experimental .sparse import csr
11+ from braintaichi import event_csrmm as bt_event_csrmm
1112
1213from brainpy ._src .dependency_check import import_taichi
1314from brainpy ._src .math .interoperability import as_jax
1415from brainpy ._src .math .ndarray import Array
1516from brainpy ._src .math .op_register import (XLACustomOp , register_general_batching )
16- from brainpy ._src .math .sparse .csr_mm import raw_csrmm_taichi as normal_csrmm
1717from brainpy ._src .math .sparse .utils import csr_to_coo
1818from brainpy ._src .math .defaults import float_
1919
@@ -49,262 +49,4 @@ def csrmm(
4949 C : array of shape ``(shape[1] if transpose else shape[0], cols)``
5050 representing the matrix-matrix product product.
5151 """
52- return raw_event_csrmm_taichi (data , indices , indptr , matrix , shape = shape , transpose = transpose )[0 ]
53-
54-
55- def raw_event_csrmm_taichi (
56- data : Union [float , jnp .ndarray , Array ],
57- indices : Union [jnp .ndarray , Array ],
58- indptr : Union [jnp .ndarray , Array ],
59- matrix : Union [jnp .ndarray , Array ],
60- * ,
61- shape : Tuple [int , int ],
62- transpose : bool = False ,
63- ):
64- assert len (shape ) == 2
65-
66- data = jnp .atleast_1d (data )
67- if np .ndim (data ) == 1 :
68- if data .shape [0 ] not in [1 , indices .shape [0 ]]:
69- raise ValueError ('The size of data should be 1 or be consistent with indices.'
70- f'But we got { data .shape } != { indices .shape } , { data .shape } != 1.' )
71-
72- indices = as_jax (indices )
73- indptr = as_jax (indptr )
74- matrix = as_jax (matrix )
75-
76- assert data .ndim == indices .ndim == indptr .ndim == 1
77- assert matrix .ndim == 2
78- assert indptr .shape [0 ] == shape [0 ] + 1
79- if not jnp .issubdtype (indices .dtype , jnp .integer ):
80- raise ValueError ('indices should be a 1D vector with integer type.' )
81- if not jnp .issubdtype (indptr .dtype , jnp .integer ):
82- raise ValueError ('indptr should be a 1D vector with integer type.' )
83-
84- out_shape = shape [1 ] if transpose else shape [0 ]
85- result_shape = (out_shape , matrix .shape [1 ])
86- # if the shape of indices is (0,), then we return a zero matrix
87- if indices .shape [0 ] == 0 :
88- return [jnp .zeros (result_shape , dtype = data .dtype ), ]
89-
90- assert matrix .shape [0 ] == (shape [0 ] if transpose else shape [1 ])
91-
92- # homo -> taichi
93- # heter -> cusparse
94- if data .shape [0 ] != 1 :
95- if matrix .dtype == jnp .bool_ :
96- # change dtype to float
97- matrix = matrix .astype (float_ )
98- return [_csr_matmat_cusparse_p .bind (data , indices , indptr , matrix , shape = shape , transpose = transpose ), ]
99- else :
100- if transpose :
101- if matrix .dtype == jnp .bool_ :
102- prim = _event_csr_matmat_transpose_homo_p
103- else :
104- return normal_csrmm (data , indices , indptr , matrix , shape = shape , transpose = transpose )
105- else :
106- if matrix .dtype == jnp .bool_ :
107- prim = _event_csr_matmat_bool_homo_p
108- else :
109- return normal_csrmm (data , indices , indptr , matrix , shape = shape , transpose = transpose )
110- return prim (data ,
111- indices ,
112- indptr ,
113- matrix ,
114- outs = [jax .ShapeDtypeStruct (result_shape , dtype = data .dtype )],
115- transpose = transpose ,
116- shape = shape )
117-
118-
119- # taichi kernels
120-
121- @ti .kernel
122- def _event_csr_matmat_transpose_heter (values : ti .types .ndarray (ndim = 1 ),
123- col_indices : ti .types .ndarray (ndim = 1 ),
124- row_ptr : ti .types .ndarray (ndim = 1 ),
125- matrix : ti .types .ndarray (ndim = 2 ),
126- out : ti .types .ndarray (ndim = 2 )):
127- for col_i , row_k in ti .ndrange (out .shape [1 ], out .shape [0 ]):
128- for row_j in range (matrix .shape [0 ]):
129- if matrix [row_j , col_i ] != 0. :
130- for j in range (row_ptr [row_j ], row_ptr [row_j + 1 ]):
131- if col_indices [j ] == row_k :
132- out [row_k , col_i ] += values [j ] * matrix [row_j , col_i ]
133-
134-
135- @ti .kernel
136- def _event_csr_matmat_transpose_bool_heter (values : ti .types .ndarray (ndim = 1 ),
137- col_indices : ti .types .ndarray (ndim = 1 ),
138- row_ptr : ti .types .ndarray (ndim = 1 ),
139- matrix : ti .types .ndarray (ndim = 2 ),
140- out : ti .types .ndarray (ndim = 2 )):
141- for col_i , row_k in ti .ndrange (out .shape [1 ], out .shape [0 ]):
142- for row_j in range (matrix .shape [0 ]):
143- if matrix [row_j , col_i ]:
144- for j in range (row_ptr [row_j ], row_ptr [row_j + 1 ]):
145- if col_indices [j ] == row_k :
146- out [row_k , col_i ] += values [j ] * matrix [row_j , col_i ]
147-
148-
149- @ti .kernel
150- def _event_csr_matmat_heter (values : ti .types .ndarray (ndim = 1 ),
151- col_indices : ti .types .ndarray (ndim = 1 ),
152- row_ptr : ti .types .ndarray (ndim = 1 ),
153- matrix : ti .types .ndarray (ndim = 2 ),
154- out : ti .types .ndarray (ndim = 2 )):
155- for row_i , col_k in ti .ndrange (out .shape [0 ], out .shape [1 ]):
156- r = 0.
157- for row_j in range (row_ptr [row_i ], row_ptr [row_i + 1 ]):
158- if matrix [col_indices [row_j ], col_k ] != 0. :
159- r += values [row_j ] * matrix [col_indices [row_j ], col_k ]
160- out [row_i , col_k ] = r
161-
162-
163- @ti .kernel
164- def _event_csr_matmat_bool_heter (values : ti .types .ndarray (ndim = 1 ),
165- col_indices : ti .types .ndarray (ndim = 1 ),
166- row_ptr : ti .types .ndarray (ndim = 1 ),
167- matrix : ti .types .ndarray (ndim = 2 ),
168- out : ti .types .ndarray (ndim = 2 )):
169- for row_i , col_k in ti .ndrange (out .shape [0 ], out .shape [1 ]):
170- r = 0.
171- for row_j in range (row_ptr [row_i ], row_ptr [row_i + 1 ]):
172- if matrix [col_indices [row_j ], col_k ]:
173- r += values [row_j ] * matrix [col_indices [row_j ], col_k ]
174- out [row_i , col_k ] = r
175-
176-
177- @ti .kernel
178- def _event_csr_matmat_transpose_homo (values : ti .types .ndarray (ndim = 1 ),
179- col_indices : ti .types .ndarray (ndim = 1 ),
180- row_ptr : ti .types .ndarray (ndim = 1 ),
181- matrix : ti .types .ndarray (ndim = 2 ),
182- out : ti .types .ndarray (ndim = 2 )):
183- value = values [0 ]
184- for col_i , row_k in ti .ndrange (out .shape [1 ], out .shape [0 ]):
185- for row_j in range (matrix .shape [0 ]):
186- if matrix [row_j , col_i ] != 0. :
187- for j in range (row_ptr [row_j ], row_ptr [row_j + 1 ]):
188- if col_indices [j ] == row_k :
189- out [row_k , col_i ] += value * matrix [row_j , col_i ]
190-
191-
192- @ti .kernel
193- def _event_csr_matmat_transpose_bool_homo (values : ti .types .ndarray (ndim = 1 ),
194- col_indices : ti .types .ndarray (ndim = 1 ),
195- row_ptr : ti .types .ndarray (ndim = 1 ),
196- matrix : ti .types .ndarray (ndim = 2 ),
197- out : ti .types .ndarray (ndim = 2 )):
198- value = values [0 ]
199- for col_i , row_k in ti .ndrange (out .shape [1 ], out .shape [0 ]):
200- for row_j in range (matrix .shape [0 ]):
201- if matrix [row_j , col_i ]:
202- for j in range (row_ptr [row_j ], row_ptr [row_j + 1 ]):
203- if col_indices [j ] == row_k :
204- out [row_k , col_i ] += value * matrix [row_j , col_i ]
205-
206-
207- @ti .kernel
208- def _event_csr_matmat_homo (values : ti .types .ndarray (ndim = 1 ),
209- col_indices : ti .types .ndarray (ndim = 1 ),
210- row_ptr : ti .types .ndarray (ndim = 1 ),
211- matrix : ti .types .ndarray (ndim = 2 ),
212- out : ti .types .ndarray (ndim = 2 )):
213- value = values [0 ]
214- for row_i , col_k in ti .ndrange (out .shape [0 ], out .shape [1 ]):
215- r = 0.
216- for row_j in range (row_ptr [row_i ], row_ptr [row_i + 1 ]):
217- if matrix [col_indices [row_j ], col_k ] != 0. :
218- r += matrix [col_indices [row_j ], col_k ]
219- out [row_i , col_k ] = r * value
220-
221-
222- @ti .kernel
223- def _event_csr_matmat_bool_homo (values : ti .types .ndarray (ndim = 1 ),
224- col_indices : ti .types .ndarray (ndim = 1 ),
225- row_ptr : ti .types .ndarray (ndim = 1 ),
226- matrix : ti .types .ndarray (ndim = 2 ),
227- out : ti .types .ndarray (ndim = 2 )):
228- value = values [0 ]
229- for row_i , col_k in ti .ndrange (out .shape [0 ], out .shape [1 ]):
230- r = 0.
231- for row_j in range (row_ptr [row_i ], row_ptr [row_i + 1 ]):
232- if matrix [col_indices [row_j ], col_k ]:
233- r += matrix [col_indices [row_j ], col_k ]
234- out [row_i , col_k ] = r * value
235-
236-
237- def _event_csr_matmat_jvp_values (val_dot , values , col_indices , row_ptr , matrix , * , outs , transpose , shape ):
238- return normal_csrmm (val_dot , col_indices , row_ptr , matrix , shape = shape , transpose = transpose )
239-
240-
241- def _event_csr_matmat_jvp_matrix (mat_dot , values , col_indices , row_ptr , matrix , * , outs , transpose , shape ):
242- return normal_csrmm (values , col_indices , row_ptr , mat_dot , shape = shape , transpose = transpose )
243-
244-
245- def _event_csr_matmat_transpose (
246- ct , data , indices , indptr , matrix , * , outs , transpose , shape ,
247- ):
248- if ad .is_undefined_primal (indices ) or ad .is_undefined_primal (indptr ):
249- raise ValueError ("Cannot transpose with respect to sparse indices." )
250- if ad .is_undefined_primal (matrix ):
251- ct_matrix = raw_event_csrmm_taichi (data , indices , indptr , ct [0 ], shape = shape , transpose = not transpose )[0 ]
252- return data , indices , indptr , (ad .Zero (matrix ) if type (ct [0 ]) is ad .Zero else ct_matrix )
253-
254- else :
255- if type (ct [0 ]) is ad .Zero :
256- ct_data = ad .Zero (data )
257- else :
258- if data .aval .shape [0 ] == 1 : # scalar
259- ct_data = \
260- raw_event_csrmm_taichi (jnp .ones (1 ), indices , indptr , matrix , shape = shape , transpose = transpose )[0 ]
261- ct_data = jnp .sum (ct [0 ] * ct_data )
262- else : # heter
263- matrix = jnp .asarray (matrix )
264- row , col = csr_to_coo (indices , indptr )
265- ct_data = (ct [0 ][row ] * matrix [col ]).sum (1 )
266- return ct_data , indices , indptr , matrix
267-
268-
269- def _define_op (cpu_kernel , gpu_kernel ):
270- prim = XLACustomOp (cpu_kernel = cpu_kernel , gpu_kernel = gpu_kernel )
271- prim .defjvp (_event_csr_matmat_jvp_values , None , None , _event_csr_matmat_jvp_matrix )
272- prim .def_transpose_rule (_event_csr_matmat_transpose )
273- return prim
274-
275-
276- # transpose heter
277- _event_csr_matmat_transpose_heter_p = _define_op (cpu_kernel = _event_csr_matmat_transpose_heter ,
278- gpu_kernel = _event_csr_matmat_transpose_heter )
279-
280- # no transpose heter
281- _event_csr_matmat_heter_p = _define_op (cpu_kernel = _event_csr_matmat_heter ,
282- gpu_kernel = _event_csr_matmat_heter )
283-
284- # transpose homo
285- _event_csr_matmat_transpose_homo_p = _define_op (cpu_kernel = _event_csr_matmat_transpose_homo ,
286- gpu_kernel = _event_csr_matmat_transpose_homo )
287-
288- # no transpose homo
289- _event_csr_matmat_homo_p = _define_op (cpu_kernel = _event_csr_matmat_homo ,
290- gpu_kernel = _event_csr_matmat_homo )
291-
292- # bool transpose heter
293- _event_csr_matmat_transpose_bool_heter_p = _define_op (cpu_kernel = _event_csr_matmat_transpose_bool_heter ,
294- gpu_kernel = _event_csr_matmat_transpose_bool_heter )
295-
296- # bool no transpose heter
297- _event_csr_matmat_bool_heter_p = _define_op (cpu_kernel = _event_csr_matmat_bool_heter ,
298- gpu_kernel = _event_csr_matmat_bool_heter )
299-
300- # bool transpose homo
301- _event_csr_matmat_transpose_bool_homo_p = _define_op (cpu_kernel = _event_csr_matmat_transpose_bool_homo ,
302- gpu_kernel = _event_csr_matmat_transpose_bool_homo )
303-
304- # bool no transpose homo
305- _event_csr_matmat_bool_homo_p = _define_op (cpu_kernel = _event_csr_matmat_bool_homo ,
306- gpu_kernel = _event_csr_matmat_bool_homo )
307-
308- # heter CUSPARSE
309- _csr_matmat_cusparse_p = csr .csr_matmat_p
310- register_general_batching (_csr_matmat_cusparse_p )
52+ return bt_event_csrmm (data , indices , indptr , matrix , shape = shape , transpose = transpose )
0 commit comments