11"""
22Check the speed of the conjugate gradient solver.
33"""
4+ import inspect
5+
46import numpy as np
57from numpy .testing import assert_equal
68
79from .common import Benchmark , safe_import
810
911with safe_import ():
1012 from scipy import linalg , sparse
11- from scipy .sparse .linalg import cg , minres , gmres , tfqmr , spsolve
13+ from scipy .sparse .linalg import cg , minres , gmres , tfqmr , spsolve , LinearOperator
1214with safe_import ():
1315 from scipy .sparse .linalg import lgmres
1416with safe_import ():
1820def _create_sparse_poisson1d (n ):
1921 # Make Gilbert Strang's favorite matrix
2022 # http://www-math.mit.edu/~gs/PIX/cupcakematrix.jpg
21- P1d = sparse .diags ([[- 1 ]* (n - 1 ), [2 ]* n , [- 1 ]* (n - 1 )], [- 1 , 0 , 1 ])
23+ P1d = sparse .diags_array (
24+ [[- 1 ]* (n - 1 ), [2 ]* n , [- 1 ]* (n - 1 )],
25+ offsets = [- 1 , 0 , 1 ],
26+ dtype = np .float64
27+ )
2228 assert_equal (P1d .shape , (n , n ))
2329 return P1d
2430
@@ -27,15 +33,21 @@ def _create_sparse_poisson2d(n):
2733 P1d = _create_sparse_poisson1d (n )
2834 P2d = sparse .kronsum (P1d , P1d )
2935 assert_equal (P2d .shape , (n * n , n * n ))
30- return P2d .tocsr ()
36+ return sparse .csr_array (P2d )
37+
38+
39+ def _create_sparse_poisson2d_coo (n ):
40+ P1d = _create_sparse_poisson1d (n )
41+ P2d = sparse .kronsum (P1d , P1d )
42+ assert_equal (P2d .shape , (n * n , n * n ))
43+ return sparse .coo_array (P2d )
3144
3245
3346class Bench (Benchmark ):
3447 params = [
35- [4 , 6 , 10 , 16 , 25 , 40 , 64 , 100 ],
36- # ['dense', 'spsolve', 'cg', 'minres', 'gmres', 'lgmres', 'gcrotmk',
37- # 'tfqmr']
38- ['cg' ]
48+ [4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 ],
49+ ['dense' , 'spsolve' , 'cg' , 'minres' , 'gmres' , 'lgmres' , 'gcrotmk' ,
50+ 'tfqmr' ]
3951 ]
4052 mapping = {'spsolve' : spsolve , 'cg' : cg , 'minres' : minres , 'gmres' : gmres ,
4153 'lgmres' : lgmres , 'gcrotmk' : gcrotmk , 'tfqmr' : tfqmr }
@@ -58,6 +70,86 @@ def time_solve(self, n, solver):
5870 self .mapping [solver ](self .P_sparse , self .b )
5971
6072
73+ class BatchedCG (Benchmark ):
74+ params = [
75+ [2 , 4 , 6 , 8 , 16 , 32 , 64 ],
76+ [1 , 10 , 100 , 500 , 1000 , 5000 , 10000 ]
77+ ]
78+ param_names = ['(n,n)' , 'batch_size' ]
79+
80+ def setup (self , n , batch_size ):
81+ if n >= 32 and batch_size >= 500 :
82+ raise NotImplementedError ()
83+ if n >= 16 and batch_size > 5000 :
84+ raise NotImplementedError ()
85+ rng = np .random .default_rng (42 )
86+
87+ self .batched = "xp" in inspect .signature (LinearOperator .__init__ ).parameters
88+ if self .batched :
89+ P_sparse = _create_sparse_poisson2d_coo (n )
90+ if batch_size > 1 :
91+ self .P_sparse = sparse .vstack (
92+ [P_sparse ] * batch_size , format = "coo"
93+ ).reshape (batch_size , n * n , n * n )
94+ self .b = rng .standard_normal ((batch_size , n * n ))
95+ else :
96+ self .P_sparse = P_sparse
97+ self .b = rng .standard_normal (n * n )
98+ else :
99+ self .P_sparse = _create_sparse_poisson2d (n )
100+ self .b = [rng .standard_normal (n * n ) for _ in range (batch_size )]
101+
102+ def time_solve (self , n , batch_size ):
103+ if self .batched :
104+ cg (self .P_sparse , self .b )
105+ else :
106+ for i in range (batch_size ):
107+ cg (self .P_sparse , self .b [i ])
108+
109+
110+ def _create_dense_random (n , batch_shape = None ):
111+ rng = np .random .default_rng (42 )
112+ M = rng .standard_normal ((n * n , n * n ))
113+ reg = 1e-3
114+ if batch_shape :
115+ M = np .broadcast_to (M [np .newaxis , ...], (* batch_shape , n * n , n * n ))
116+
117+ def matvec (x ):
118+ return np .squeeze (M .mT @ (M @ x [..., np .newaxis ]), axis = - 1 ) + reg * x
119+
120+ return LinearOperator (shape = M .shape , matvec = matvec , dtype = np .float64 )
121+
122+
123+ class BatchedCGDense (Benchmark ):
124+ params = [
125+ [2 , 4 , 8 , 16 , 24 ],
126+ [1 , 10 , 100 , 500 , 1000 ]
127+ ]
128+ param_names = ['(n,n)' , 'batch_size' ]
129+
130+ def setup (self , n , batch_size ):
131+ rng = np .random .default_rng (42 )
132+
133+ self .batched = "xp" in inspect .signature (LinearOperator .__init__ ).parameters
134+ if self .batched :
135+ if batch_size > 1 :
136+ self .A = _create_dense_random (n , batch_shape = (batch_size ,))
137+ self .b = rng .standard_normal ((batch_size , n * n ))
138+ else :
139+ self .A = _create_dense_random (n )
140+ self .b = rng .standard_normal (n * n )
141+ else :
142+ self .A = _create_dense_random (n )
143+ self .b = [rng .standard_normal (n * n ) for _ in range (batch_size )]
144+
145+ def time_solve (self , n , batch_size ):
146+ if self .batched :
147+ cg (self .A , self .b )
148+ else :
149+ for i in range (batch_size ):
150+ cg (self .A , self .b [i ])
151+
152+
61153class Lgmres (Benchmark ):
62154 params = [
63155 [10 , 50 , 100 , 1000 , 10000 ],
0 commit comments