Commit dc63758
Replace math operators with braintaichi
1 parent: 2a5adea

File tree

11 files changed: +45 −3100 lines


brainpy/_src/dependency_check.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 ]

 _minimal_brainpylib_version = '0.2.6'
-_minimal_taichi_version = (1, 7, 0)
+_minimal_taichi_version = (1, 7, 2)

 numba = None
 taichi = None
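
For context, the bumped constant is the kind of guard that is applied when taichi is lazily imported. Below is a minimal illustrative sketch of such a minimum-version check; the helper name `check_taichi_version` and its error messages are hypothetical rather than BrainPy's actual dependency_check code, and the sketch assumes taichi exposes its version as a (major, minor, patch) tuple via `taichi.__version__`.

```python
# Illustrative sketch only (hypothetical helper): enforce a minimum taichi
# version at import time, mirroring the role of _minimal_taichi_version.
_minimal_taichi_version = (1, 7, 2)

def check_taichi_version():
    try:
        import taichi as ti
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError('taichi is required but is not installed.') from e
    # taichi reports its version as a (major, minor, patch) tuple.
    if tuple(ti.__version__) < _minimal_taichi_version:
        wanted = '.'.join(map(str, _minimal_taichi_version))
        found = '.'.join(map(str, ti.__version__))
        raise RuntimeError(f'taichi>={wanted} is required, but taichi=={found} is installed.')
    return ti
```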

brainpy/_src/math/event/csr_matmat.py

Lines changed: 2 additions & 260 deletions

@@ -8,12 +8,12 @@
 from jax import numpy as jnp
 from jax.interpreters import ad
 from jax.experimental.sparse import csr
+from braintaichi import event_csrmm as bt_event_csrmm

 from brainpy._src.dependency_check import import_taichi
 from brainpy._src.math.interoperability import as_jax
 from brainpy._src.math.ndarray import Array
 from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
-from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm
 from brainpy._src.math.sparse.utils import csr_to_coo
 from brainpy._src.math.defaults import float_

@@ -49,262 +49,4 @@ def csrmm(
   C : array of shape ``(shape[1] if transpose else shape[0], cols)``
     representing the matrix-matrix product product.
   """
-  return raw_event_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0]
-
-
-def raw_event_csrmm_taichi(
-    data: Union[float, jnp.ndarray, Array],
-    indices: Union[jnp.ndarray, Array],
-    indptr: Union[jnp.ndarray, Array],
-    matrix: Union[jnp.ndarray, Array],
-    *,
-    shape: Tuple[int, int],
-    transpose: bool = False,
-):
-  assert len(shape) == 2
-
-  data = jnp.atleast_1d(data)
-  if np.ndim(data) == 1:
-    if data.shape[0] not in [1, indices.shape[0]]:
-      raise ValueError('The size of data should be 1 or be consistent with indices.'
-                       f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.')
-
-  indices = as_jax(indices)
-  indptr = as_jax(indptr)
-  matrix = as_jax(matrix)
-
-  assert data.ndim == indices.ndim == indptr.ndim == 1
-  assert matrix.ndim == 2
-  assert indptr.shape[0] == shape[0] + 1
-  if not jnp.issubdtype(indices.dtype, jnp.integer):
-    raise ValueError('indices should be a 1D vector with integer type.')
-  if not jnp.issubdtype(indptr.dtype, jnp.integer):
-    raise ValueError('indptr should be a 1D vector with integer type.')
-
-  out_shape = shape[1] if transpose else shape[0]
-  result_shape = (out_shape, matrix.shape[1])
-  # if the shape of indices is (0,), then we return a zero matrix
-  if indices.shape[0] == 0:
-    return [jnp.zeros(result_shape, dtype=data.dtype), ]
-
-  assert matrix.shape[0] == (shape[0] if transpose else shape[1])
-
-  # homo -> taichi
-  # heter -> cusparse
-  if data.shape[0] != 1:
-    if matrix.dtype == jnp.bool_:
-      # change dtype to float
-      matrix = matrix.astype(float_)
-    return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ]
-  else:
-    if transpose:
-      if matrix.dtype == jnp.bool_:
-        prim = _event_csr_matmat_transpose_homo_p
-      else:
-        return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
-    else:
-      if matrix.dtype == jnp.bool_:
-        prim = _event_csr_matmat_bool_homo_p
-      else:
-        return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
-    return prim(data,
-                indices,
-                indptr,
-                matrix,
-                outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)],
-                transpose=transpose,
-                shape=shape)
-
-
-# taichi kernels
-
-@ti.kernel
-def _event_csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1),
-                                      col_indices: ti.types.ndarray(ndim=1),
-                                      row_ptr: ti.types.ndarray(ndim=1),
-                                      matrix: ti.types.ndarray(ndim=2),
-                                      out: ti.types.ndarray(ndim=2)):
-  for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
-    for row_j in range(matrix.shape[0]):
-      if matrix[row_j, col_i] != 0.:
-        for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
-          if col_indices[j] == row_k:
-            out[row_k, col_i] += values[j] * matrix[row_j, col_i]
-
-
-@ti.kernel
-def _event_csr_matmat_transpose_bool_heter(values: ti.types.ndarray(ndim=1),
-                                           col_indices: ti.types.ndarray(ndim=1),
-                                           row_ptr: ti.types.ndarray(ndim=1),
-                                           matrix: ti.types.ndarray(ndim=2),
-                                           out: ti.types.ndarray(ndim=2)):
-  for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
-    for row_j in range(matrix.shape[0]):
-      if matrix[row_j, col_i]:
-        for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
-          if col_indices[j] == row_k:
-            out[row_k, col_i] += values[j] * matrix[row_j, col_i]
-
-
-@ti.kernel
-def _event_csr_matmat_heter(values: ti.types.ndarray(ndim=1),
-                            col_indices: ti.types.ndarray(ndim=1),
-                            row_ptr: ti.types.ndarray(ndim=1),
-                            matrix: ti.types.ndarray(ndim=2),
-                            out: ti.types.ndarray(ndim=2)):
-  for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
-    r = 0.
-    for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-      if matrix[col_indices[row_j], col_k] != 0.:
-        r += values[row_j] * matrix[col_indices[row_j], col_k]
-    out[row_i, col_k] = r
-
-
-@ti.kernel
-def _event_csr_matmat_bool_heter(values: ti.types.ndarray(ndim=1),
-                                 col_indices: ti.types.ndarray(ndim=1),
-                                 row_ptr: ti.types.ndarray(ndim=1),
-                                 matrix: ti.types.ndarray(ndim=2),
-                                 out: ti.types.ndarray(ndim=2)):
-  for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
-    r = 0.
-    for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-      if matrix[col_indices[row_j], col_k]:
-        r += values[row_j] * matrix[col_indices[row_j], col_k]
-    out[row_i, col_k] = r
-
-
-@ti.kernel
-def _event_csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1),
-                                     col_indices: ti.types.ndarray(ndim=1),
-                                     row_ptr: ti.types.ndarray(ndim=1),
-                                     matrix: ti.types.ndarray(ndim=2),
-                                     out: ti.types.ndarray(ndim=2)):
-  value = values[0]
-  for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
-    for row_j in range(matrix.shape[0]):
-      if matrix[row_j, col_i] != 0.:
-        for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
-          if col_indices[j] == row_k:
-            out[row_k, col_i] += value * matrix[row_j, col_i]
-
-
-@ti.kernel
-def _event_csr_matmat_transpose_bool_homo(values: ti.types.ndarray(ndim=1),
-                                          col_indices: ti.types.ndarray(ndim=1),
-                                          row_ptr: ti.types.ndarray(ndim=1),
-                                          matrix: ti.types.ndarray(ndim=2),
-                                          out: ti.types.ndarray(ndim=2)):
-  value = values[0]
-  for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
-    for row_j in range(matrix.shape[0]):
-      if matrix[row_j, col_i]:
-        for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
-          if col_indices[j] == row_k:
-            out[row_k, col_i] += value * matrix[row_j, col_i]
-
-
-@ti.kernel
-def _event_csr_matmat_homo(values: ti.types.ndarray(ndim=1),
-                           col_indices: ti.types.ndarray(ndim=1),
-                           row_ptr: ti.types.ndarray(ndim=1),
-                           matrix: ti.types.ndarray(ndim=2),
-                           out: ti.types.ndarray(ndim=2)):
-  value = values[0]
-  for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
-    r = 0.
-    for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-      if matrix[col_indices[row_j], col_k] != 0.:
-        r += matrix[col_indices[row_j], col_k]
-    out[row_i, col_k] = r * value
-
-
-@ti.kernel
-def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1),
-                                col_indices: ti.types.ndarray(ndim=1),
-                                row_ptr: ti.types.ndarray(ndim=1),
-                                matrix: ti.types.ndarray(ndim=2),
-                                out: ti.types.ndarray(ndim=2)):
-  value = values[0]
-  for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
-    r = 0.
-    for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-      if matrix[col_indices[row_j], col_k]:
-        r += matrix[col_indices[row_j], col_k]
-    out[row_i, col_k] = r * value
-
-
-def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
-  return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose)
-
-
-def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
-  return normal_csrmm(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose)
-
-
-def _event_csr_matmat_transpose(
-    ct, data, indices, indptr, matrix, *, outs, transpose, shape,
-):
-  if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
-    raise ValueError("Cannot transpose with respect to sparse indices.")
-  if ad.is_undefined_primal(matrix):
-    ct_matrix = raw_event_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
-    return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix)
-
-  else:
-    if type(ct[0]) is ad.Zero:
-      ct_data = ad.Zero(data)
-    else:
-      if data.aval.shape[0] == 1:  # scalar
-        ct_data = \
-          raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0]
-        ct_data = jnp.sum(ct[0] * ct_data)
-      else:  # heter
-        matrix = jnp.asarray(matrix)
-        row, col = csr_to_coo(indices, indptr)
-        ct_data = (ct[0][row] * matrix[col]).sum(1)
-    return ct_data, indices, indptr, matrix
-
-
-def _define_op(cpu_kernel, gpu_kernel):
-  prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
-  prim.defjvp(_event_csr_matmat_jvp_values, None, None, _event_csr_matmat_jvp_matrix)
-  prim.def_transpose_rule(_event_csr_matmat_transpose)
-  return prim
-
-
-# transpose heter
-_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter,
-                                                 gpu_kernel=_event_csr_matmat_transpose_heter)
-
-# no transpose heter
-_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter,
-                                       gpu_kernel=_event_csr_matmat_heter)
-
-# transpose homo
-_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo,
-                                                gpu_kernel=_event_csr_matmat_transpose_homo)
-
-# no transpose homo
-_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo,
-                                      gpu_kernel=_event_csr_matmat_homo)
-
-# bool transpose heter
-_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter,
-                                                      gpu_kernel=_event_csr_matmat_transpose_bool_heter)
-
-# bool no transpose heter
-_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter,
-                                            gpu_kernel=_event_csr_matmat_bool_heter)
-
-# bool transpose homo
-_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo,
-                                                     gpu_kernel=_event_csr_matmat_transpose_bool_homo)
-
-# bool no transpose homo
-_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
-                                           gpu_kernel=_event_csr_matmat_bool_homo)
-
-# heter CUSPARSE
-_csr_matmat_cusparse_p = csr.csr_matmat_p
-register_general_batching(_csr_matmat_cusparse_p)
+  return bt_event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
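
To make the new call path concrete, here is a small self-contained usage sketch of the braintaichi-backed product. The signature of `event_csrmm` matches the import in the hunk above; the CSR pattern, event matrix, and shapes are made-up illustrative values, and the example assumes braintaichi is installed.

```python
# Illustrative usage of the braintaichi-backed event CSR @ dense product.
# The toy CSR pattern and event matrix below are made up for demonstration.
import jax.numpy as jnp
from braintaichi import event_csrmm  # same operator the wrapper now delegates to

# 3x4 sparse connectivity in CSR form with a single (homogeneous) weight.
data = jnp.asarray([1.0])
indices = jnp.asarray([0, 2, 1, 3], dtype=jnp.int32)
indptr = jnp.asarray([0, 2, 3, 4], dtype=jnp.int32)  # one entry per row, plus one

# (4, 2) event matrix: rows correspond to the 4 presynaptic columns.
events = jnp.asarray([[True, False],
                      [False, True],
                      [True, True],
                      [False, False]])

# Equivalent to calling brainpy's csrmm(...) after this commit.
out = event_csrmm(data, indices, indptr, events, shape=(3, 4), transpose=False)
print(out.shape)  # expected (3, 2)
```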
