Skip to content

Commit 3297054

Browse files
authored
Merge pull request #48 from simpeg/consistent_refactor
Refactor base class
2 parents cf3f2e0 + 5047329 commit 3297054

File tree

13 files changed

+1356
-342
lines changed

13 files changed

+1356
-342
lines changed

.github/workflows/python-package-conda.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: Testing
33
on:
44
push:
55
branches:
6-
- '*'
6+
- 'main'
77
tags:
88
- 'v*'
99
pull_request:
@@ -75,6 +75,8 @@ jobs:
7575
uses: codecov/codecov-action@v4
7676
with:
7777
verbose: true # optional (default = false)
78+
env:
79+
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
7880

7981
distribute:
8082
name: Distributing from 3.8

pymatsolver/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
.. autosummary::
2525
:toctree: generated/
2626
27+
Triangle
2728
Forward
2829
Backward
2930
@@ -60,9 +61,9 @@
6061
}
6162

6263
# Simple solvers
63-
from .solvers import Diagonal, Forward, Backward
64-
from .wrappers import WrapDirect
65-
from .wrappers import WrapIterative
64+
from .solvers import Diagonal, Triangle, Forward, Backward
65+
from .wrappers import wrap_direct, WrapDirect
66+
from .wrappers import wrap_iterative, WrapIterative
6667

6768
# Scipy Iterative solvers
6869
from .iterative import SolverCG

pymatsolver/direct/mumps.py

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,64 +2,106 @@
22
from mumps import Context
33

44
class Mumps(Base):
5-
"""
6-
Mumps solver
5+
"""The MUMPS direct solver.
6+
7+
This solver uses the python-mumps wrappers to factorize a sparse matrix, and use that factorization for solving.
8+
9+
Parameters
10+
----------
11+
A
12+
Matrix to solve with.
13+
ordering : str, default 'metis'
14+
Which ordering algorithm to use. See the `python-mumps` documentation for more details.
15+
is_symmetric : bool, optional
16+
Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and
17+
default to ``False`` if those fail.
18+
is_positive_definite : bool, optional
19+
Whether the matrix is positive definite.
20+
check_accuracy : bool, optional
21+
Whether to check the accuracy of the solution.
22+
check_rtol : float, optional
23+
The relative tolerance to check against for accuracy.
24+
check_atol : float, optional
25+
The absolute tolerance to check against for accuracy.
26+
accuracy_tol : float, optional
27+
Relative accuracy tolerance.
28+
.. deprecated:: 0.3.0
29+
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
30+
**kwargs
31+
Extra keyword arguments. If there are any left here a warning will be raised.
732
"""
833
_transposed = False
9-
ordering = ''
1034

11-
def __init__(self, A, **kwargs):
12-
self.set_kwargs(**kwargs)
35+
def __init__(self, A, ordering=None, is_symmetric=None, is_positive_definite=False, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
36+
is_hermitian = kwargs.pop('is_hermitian', False)
37+
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
38+
if ordering is None:
39+
ordering = "metis"
40+
self.ordering = ordering
1341
self.solver = Context()
14-
self._set_A(A)
15-
self.A = A
42+
self._set_A(self.A)
1643

1744
def _set_A(self, A):
1845
self.solver.set_matrix(
1946
A,
2047
symmetric=self.is_symmetric,
21-
# positive_definite=self.is_positive_definite # doesn't (yet) support setting positive definiteness
2248
)
2349

2450
@property
2551
def ordering(self):
26-
return getattr(self, '_ordering', "metis")
52+
"""The ordering algorithm to use.
53+
54+
Returns
55+
-------
56+
str
57+
"""
58+
return self._ordering
2759

2860
@ordering.setter
2961
def ordering(self, value):
30-
self._ordering = value
62+
self._ordering = str(value)
3163

3264
@property
3365
def _factored(self):
3466
return self.solver.factored
3567

36-
@property
68+
def get_attributes(self):
69+
attrs = super().get_attributes()
70+
attrs['ordering'] = self.ordering
71+
return attrs
72+
3773
def transpose(self):
3874
trans_obj = Mumps.__new__(Mumps)
39-
trans_obj.A = self.A
75+
trans_obj._A = self.A
76+
for attr, value in self.get_attributes().items():
77+
setattr(trans_obj, attr, value)
4078
trans_obj.solver = self.solver
41-
trans_obj.is_symmetric = self.is_symmetric
42-
trans_obj.is_positive_definite = self.is_positive_definite
43-
trans_obj.ordering = self.ordering
4479
trans_obj._transposed = not self._transposed
4580
return trans_obj
4681

47-
T = transpose
48-
4982
def factor(self, A=None):
50-
reuse_analysis = False
51-
if A is not None:
52-
self._set_A(A)
53-
self.A = A
83+
"""(Re)factor the A matrix.
84+
85+
Parameters
86+
----------
87+
A : scipy.sparse.spmatrix
88+
The matrix to be factorized. If a previous factorization has been performed, this will
89+
reuse the previous factorization's analysis.
90+
"""
91+
reuse_analysis = self._factored
92+
do_factor = not self._factored
93+
if A is not None and A is not self.A:
5494
# if it was previously factored then re-use the analysis.
55-
reuse_analysis = self._factored
56-
if not self._factored:
95+
self._set_A(A)
96+
self._A = A
97+
do_factor = True
98+
if do_factor:
5799
pivot_tol = 0.0 if self.is_positive_definite else 0.01
58100
self.solver.factor(
59101
ordering=self.ordering, reuse_analysis=reuse_analysis, pivot_tol=pivot_tol
60102
)
61103

62-
def _solveM(self, rhs):
104+
def _solve_multiple(self, rhs):
63105
self.factor()
64106
if self._transposed:
65107
self.solver.mumps_instance.icntl[9] = 0
@@ -68,4 +110,4 @@ def _solveM(self, rhs):
68110
sol = self.solver.solve(rhs)
69111
return sol
70112

71-
_solve1 = _solveM
113+
_solve_single = _solve_multiple

pymatsolver/direct/pardiso.py

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,50 @@
33
from pydiso.mkl_solver import set_mkl_pardiso_threads, get_mkl_pardiso_max_threads
44

55
class Pardiso(Base):
6+
"""The Pardiso direct solver.
7+
8+
This solver uses the `pydiso` Intel MKL wrapper to factorize a sparse matrix, and use that
9+
factorization for solving.
10+
11+
Parameters
12+
----------
13+
A : scipy.sparse.spmatrix
14+
Matrix to solve with.
15+
n_threads : int, optional
16+
Number of threads to use for the `Pardiso` routine in Intel's MKL.
17+
is_symmetric : bool, optional
18+
Whether the matrix is symmetric. By default, it will perform some simple tests to check for symmetry, and
19+
default to ``False`` if those fail.
20+
is_positive_definite : bool, optional
21+
Whether the matrix is positive definite.
22+
is_hermitian : bool, optional
23+
Whether the matrix is hermitian. By default, it will perform some simple tests to check, and default to
24+
``False`` if those fail.
25+
check_accuracy : bool, optional
26+
Whether to check the accuracy of the solution.
27+
check_rtol : float, optional
28+
The relative tolerance to check against for accuracy.
29+
check_atol : float, optional
30+
The absolute tolerance to check against for accuracy.
31+
accuracy_tol : float, optional
32+
Relative accuracy tolerance.
33+
.. deprecated:: 0.3.0
34+
`accuracy_tol` will be removed in pymatsolver 0.4.0. Use `check_rtol` and `check_atol` instead.
35+
**kwargs
36+
Extra keyword arguments. If there are any left here a warning will be raised.
637
"""
7-
Pardiso Solver
838

9-
https://github.com/simpeg/pydiso
39+
_transposed = False
1040

11-
12-
documentation::
13-
14-
http://www.pardiso-project.org/
15-
"""
16-
17-
_factored = False
18-
19-
def __init__(self, A, **kwargs):
20-
self.A = A
21-
self.set_kwargs(**kwargs)
41+
def __init__(self, A, n_threads=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
42+
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
2243
self.solver = MKLPardisoSolver(
2344
self.A,
2445
matrix_type=self._matrixType(),
2546
factor=False
2647
)
48+
if n_threads is not None:
49+
self.n_threads = n_threads
2750

2851
def _matrixType(self):
2952
"""
@@ -65,28 +88,45 @@ def _matrixType(self):
6588
return 13
6689

6790
def factor(self, A=None):
68-
if A is not None:
69-
self._factored = False
70-
self.A = A
71-
if not self._factored:
91+
"""(Re)factor the A matrix.
92+
93+
Parameters
94+
----------
95+
A : scipy.sparse.spmatrix
96+
The matrix to be factorized. If a previous factorization has been performed, this will
97+
reuse the previous factorization's analysis.
98+
"""
99+
if A is not None and self.A is not A:
100+
self._A = A
72101
self.solver.refactor(self.A)
73-
self._factored = True
74102

75-
def _solveM(self, rhs):
76-
self.factor()
77-
sol = self.solver.solve(rhs)
103+
def _solve_multiple(self, rhs):
104+
sol = self.solver.solve(rhs, transpose=self._transposed)
78105
return sol
79106

107+
def transpose(self):
108+
trans_obj = Pardiso.__new__(Pardiso)
109+
trans_obj._A = self.A
110+
for attr, value in self.get_attributes().items():
111+
setattr(trans_obj, attr, value)
112+
trans_obj.solver = self.solver
113+
trans_obj._transposed = not self._transposed
114+
return trans_obj
115+
80116
@property
81117
def n_threads(self):
82-
"""
83-
Number of threads to use for the Pardiso solver routine. This property
84-
is global to all Pardiso solver objects for a single python process.
118+
"""Number of threads to use for the Pardiso solver routine.
119+
120+
This property is global to all Pardiso solver objects for a single python process.
121+
122+
Returns
123+
-------
124+
int
85125
"""
86126
return get_mkl_pardiso_max_threads()
87127

88128
@n_threads.setter
89129
def n_threads(self, n_threads):
90130
set_mkl_pardiso_threads(n_threads)
91131

92-
_solve1 = _solveM
132+
_solve_single = _solve_multiple

0 commit comments

Comments
 (0)