1- import torch
2- from functools import partial
3-
41from typing import Optional , Tuple
52
3+ import torch
4+
65
76class IncrementalPCA :
87 """
98 An implementation of Incremental Principal Components Analysis (IPCA) that leverages PyTorch for GPU acceleration.
9+ Adapted from https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/decomposition/_incremental_pca.py
1010
11- This class provides methods to fit the model on data incrementally in batches, and to transform new data
12- based on the principal components learned during the fitting process.
11+ This class provides methods to fit the model on data incrementally in batches, and to transform new data based on
12+ the principal components learned during the fitting process.
1313
14- Attributes :
14+ Args :
1515 n_components (int, optional): Number of components to keep. If `None`, it's set to the minimum of the
16- number of samples and features. Defaults to None.
16+ number of samples and features. Defaults to None.
1717 copy (bool): If False, input data will be overwritten. Defaults to True.
1818 batch_size (int, optional): The number of samples to use for each batch. Only needed if self.fit is called.
19- If `None`, it's inferred from the data and set to `5 * n_features`. Defaults to None.
20- svd_driver (str, optional): name of the cuSOLVER method to be used for torch.linalg.svd. This keyword
21- argument only works on CUDA inputs. Available options are: None, gesvd, gesvdj,
22- and gesvda. Defaults to None.
19+ If `None`, it's inferred from the data and set to `5 * n_features`. Defaults to None.
20+ svd_driver (str, optional): name of the cuSOLVER method to be used for torch.linalg.svd. This keyword
21+ argument only works on CUDA inputs. Available options are: None, gesvd, gesvdj, and gesvda. Defaults to
22+ None.
2323 lowrank (bool, optional): Whether to use torch.svd_lowrank instead of torch.linalg.svd which can be faster.
24- Defaults to False.
25- lowrank_q (int, optional): For an adequate approximation of n_components, this parameter defaults to
26- n_components * 2.
24+ Defaults to False.
25+ lowrank_q (int, optional): For an adequate approximation of n_components, this parameter defaults to
26+ n_components * 2.
2727 lowrank_niter (int, optional): Number of subspace iterations to conduct for torch.svd_lowrank.
28- Defaults to 4.
28+ Defaults to 4.
29+ lowrank_seed (int, optional): Seed for making results of torch.svd_lowrank reproducible.
2930 """
3031
3132 def __init__ (
@@ -36,49 +37,81 @@ def __init__(
3637 svd_driver : Optional [str ] = None ,
3738 lowrank : bool = False ,
3839 lowrank_q : Optional [int ] = None ,
39- lowrank_niter : int = 4
40+ lowrank_niter : int = 4 ,
41+ lowrank_seed : Optional [int ] = None ,
4042 ):
41- self .n_components_ = n_components
43+ self .n_components = n_components
4244 self .copy = copy
4345 self .batch_size = batch_size
44-
45- if lowrank :
46- if lowrank_q is None :
47- assert n_components is not None , "n_components must be specified when using lowrank mode with lowrank_q=None."
48- lowrank_q = n_components * 2
49- assert lowrank_q >= n_components , "lowrank_q must be greater than or equal to n_components."
50- def svd_fn (X ):
51- U , S , V = torch .svd_lowrank (X , q = lowrank_q , niter = lowrank_niter )
52- return U , S , V .mH # V is returned as a conjugate transpose
53- self ._svd_fn = svd_fn
54-
55- else :
56- self ._svd_fn = partial (torch .linalg .svd , full_matrices = False , driver = svd_driver )
57-
58-
59- def _validate_data (self , X , dtype = torch .float32 ) -> torch .Tensor :
46+ self .svd_driver = svd_driver
47+ self .lowrank = lowrank
48+ self .lowrank_q = lowrank_q
49+ self .lowrank_niter = lowrank_niter
50+ self .lowrank_seed = lowrank_seed
51+
52+ self .n_features_ = None
53+
54+ if self .lowrank :
55+ self ._validate_lowrank_params ()
56+
def _validate_lowrank_params(self):
    """
    Validates the low-rank SVD settings and fills in the default for `lowrank_q`.

    If `lowrank_q` is None it is set to `2 * n_components`, which requires `n_components` to be specified.
    If `lowrank_q` is given, it must not be smaller than `n_components`; that bound is only checked when
    `n_components` has been specified.

    Raises:
        ValueError: If both `lowrank_q` and `n_components` are None, or if `lowrank_q` is smaller than
            `n_components`.
    """
    if self.lowrank_q is None:
        if self.n_components is None:
            raise ValueError("n_components must be specified when using lowrank mode with lowrank_q=None.")
        self.lowrank_q = self.n_components * 2
    # Guard against n_components=None: comparing an int with None raises a TypeError, and the lower bound
    # cannot be checked until n_components is known.
    elif self.n_components is not None and self.lowrank_q < self.n_components:
        raise ValueError("lowrank_q must be greater than or equal to n_components.")
64+
def _svd_fn_full(self, X):
    """
    Computes the reduced singular value decomposition of `X` with torch.linalg.svd.

    Args:
        X (torch.Tensor): Matrix to decompose.

    Returns:
        The result of torch.linalg.svd called with `full_matrices=False` and the configured `svd_driver`.
    """
    svd_options = {"full_matrices": False, "driver": self.svd_driver}
    return torch.linalg.svd(X, **svd_options)
67+
def _svd_fn_lowrank(self, X):
    """
    Computes an approximate SVD of `X` with torch.svd_lowrank.

    When `lowrank_seed` is set, the random number generator is seeded with it inside a forked RNG scope
    before the decomposition runs.

    Args:
        X (torch.Tensor): Matrix to decompose.

    Returns:
        Tuple of the U and S outputs of torch.svd_lowrank and the conjugate transpose (`.mH`) of its V output.
    """
    use_seed = self.lowrank_seed is not None
    with torch.random.fork_rng(enabled=use_seed):
        if use_seed:
            torch.manual_seed(self.lowrank_seed)
        left, singular_values, right = torch.svd_lowrank(X, q=self.lowrank_q, niter=self.lowrank_niter)
    return left, singular_values, right.mH
75+
def _validate_data(self, X) -> torch.Tensor:
    """
    Validates and converts the input data `X` to the appropriate tensor format.

    Non-tensor input is converted to a float32 tensor. Tensor input is cloned when `self.copy` is True.
    Tensors whose dtype is neither float32 nor float64 are cast to float32.

    Args:
        X (torch.Tensor): Input data, two-dimensional with shape (n_samples, n_features).

    Returns:
        torch.Tensor: Converted to appropriate format.

    Raises:
        ValueError: If `X` is not two-dimensional, or if `n_components` is set and exceeds the number of
            features or the number of samples of `X`.
    """
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X, dtype=torch.float32)
    elif self.copy:
        X = X.clone()

    # Fail with an explicit message instead of a cryptic tuple-unpacking error on the shape below.
    if X.ndim != 2:
        raise ValueError(f"Expected 2D input of shape (n_samples, n_features), got {X.ndim}D input instead.")

    n_samples, n_features = X.shape
    if self.n_components is not None:
        if self.n_components > n_features:
            raise ValueError(
                f"n_components={self.n_components} invalid for n_features={n_features}, "
                "need more rows than columns for IncrementalPCA processing."
            )
        if self.n_components > n_samples:
            raise ValueError(
                f"n_components={self.n_components} must be less or equal to the batch number of samples {n_samples}"
            )

    if X.dtype not in (torch.float32, torch.float64):
        X = X.to(torch.float32)

    return X
79110
80111 @staticmethod
81- def _incremental_mean_and_var (X , last_mean , last_variance , last_sample_count ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
112+ def _incremental_mean_and_var (
113+ X , last_mean , last_variance , last_sample_count
114+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
82115 """
83116 Computes the incremental mean and variance for the data `X`.
84117
@@ -95,12 +128,10 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) ->
95128 return last_mean , last_variance , last_sample_count
96129
97130 if last_sample_count > 0 :
98- assert (
99- last_mean is not None
100- ), "last_mean should not be None if last_sample_count > 0."
101- assert (
102- last_variance is not None
103- ), "last_variance should not be None if last_sample_count > 0."
131+ if last_mean is None :
132+ raise ValueError ("last_mean should not be None if last_sample_count > 0." )
133+ if last_variance is None :
134+ raise ValueError ("last_variance should not be None if last_sample_count > 0." )
104135
105136 new_sample_count = torch .tensor ([X .shape [0 ]], device = X .device )
106137 updated_sample_count = last_sample_count + new_sample_count
@@ -128,9 +159,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) ->
128159 updated_unnormalized_variance = (
129160 last_unnormalized_variance
130161 + new_unnormalized_variance
131- + last_over_new_count
132- / updated_sample_count
133- * (last_sum / last_over_new_count - new_sum ).square ()
162+ + last_over_new_count / updated_sample_count * (last_sum / last_over_new_count - new_sum ).square ()
134163 )
135164 updated_variance = updated_unnormalized_variance / updated_sample_count
136165
@@ -146,7 +175,8 @@ def _svd_flip(u, v, u_based_decision=True) -> Tuple[torch.Tensor, torch.Tensor]:
146175 Args:
147176 u (torch.Tensor): Left singular vectors tensor.
148177 v (torch.Tensor): Right singular vectors tensor.
149- u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping. Defaults to True.
178+ u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping.
179+ Defaults to True.
150180
151181 Returns:
152182 Tuple[torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
@@ -157,7 +187,7 @@ def _svd_flip(u, v, u_based_decision=True) -> Tuple[torch.Tensor, torch.Tensor]:
157187 else :
158188 max_abs_rows = torch .argmax (torch .abs (v ), dim = 1 )
159189 signs = torch .sign (v [range (v .shape [0 ]), max_abs_rows ])
160- u *= signs [:u .shape [1 ]].view (1 , - 1 )
190+ u *= signs [: u .shape [1 ]].view (1 , - 1 )
161191 v *= signs .view (- 1 , 1 )
162192 return u , v
163193
@@ -176,14 +206,10 @@ def fit(self, X, check_input=True):
176206 X = self ._validate_data (X )
177207 n_samples , n_features = X .shape
178208 if self .batch_size is None :
179- self .batch_size_ = 5 * n_features
180- else :
181- self .batch_size_ = self .batch_size
209+ self .batch_size = 5 * n_features
182210
183- for start in range (0 , n_samples , self .batch_size_ ):
184- end = min (start + self .batch_size_ , n_samples )
185- X_batch = X [start :end ]
186- self .partial_fit (X_batch , check_input = False )
211+ for batch in self .gen_batches (n_samples , self .batch_size , min_batch_size = self .n_components or 0 ):
212+ self .partial_fit (X [batch ], check_input = False )
187213
188214 return self
189215
@@ -209,8 +235,14 @@ def partial_fit(self, X, check_input=True):
209235 self .mean_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions
210236 self .var_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions
211237 self .n_samples_seen_ = torch .tensor ([0 ], device = X .device )
212- if not self .n_components_ :
213- self .n_components_ = min (n_samples , n_features )
238+ self .n_features_ = n_features
239+ if not self .n_components :
240+ self .n_components = min (n_samples , n_features )
241+
242+ if n_features != self .n_features_ :
243+ raise ValueError (
244+ "Number of features of the new batch does not match the number of features of the first batch."
245+ )
214246
215247 col_mean , col_var , n_total_samples = self ._incremental_mean_and_var (
216248 X , self .mean_ , self .var_ , self .n_samples_seen_
@@ -221,9 +253,7 @@ def partial_fit(self, X, check_input=True):
221253 else :
222254 col_batch_mean = torch .mean (X , dim = 0 )
223255 X -= col_batch_mean
224- mean_correction_factor = torch .sqrt (
225- (self .n_samples_seen_ .double () / n_total_samples ) * n_samples
226- )
256+ mean_correction_factor = torch .sqrt ((self .n_samples_seen_ .double () / n_total_samples ) * n_samples )
227257 mean_correction = mean_correction_factor * (self .mean_ - col_batch_mean )
228258 X = torch .vstack (
229259 (
@@ -233,20 +263,23 @@ def partial_fit(self, X, check_input=True):
233263 )
234264 )
235265
236- U , S , Vt = self ._svd_fn (X )
266+ if self .lowrank :
267+ U , S , Vt = self ._svd_fn_lowrank (X )
268+ else :
269+ U , S , Vt = self ._svd_fn_full (X )
237270 U , Vt = self ._svd_flip (U , Vt , u_based_decision = False )
238271 explained_variance = S ** 2 / (n_total_samples - 1 )
239272 explained_variance_ratio = S ** 2 / torch .sum (col_var * n_total_samples )
240273
241274 self .n_samples_seen_ = n_total_samples
242- self .components_ = Vt [:self .n_components_ ]
243- self .singular_values_ = S [:self .n_components_ ]
275+ self .components_ = Vt [: self .n_components ]
276+ self .singular_values_ = S [: self .n_components ]
244277 self .mean_ = col_mean
245278 self .var_ = col_var
246- self .explained_variance_ = explained_variance [:self .n_components_ ]
247- self .explained_variance_ratio_ = explained_variance_ratio [:self .n_components_ ]
248- if self .n_components_ not in (n_samples , n_features ):
249- self .noise_variance_ = explained_variance [self .n_components_ :].mean ()
279+ self .explained_variance_ = explained_variance [: self .n_components ]
280+ self .explained_variance_ratio_ = explained_variance_ratio [: self .n_components ]
281+ if self .n_components not in (n_samples , n_features ):
282+ self .noise_variance_ = explained_variance [self .n_components :].mean ()
250283 else :
251284 self .noise_variance_ = torch .tensor (0.0 , device = X .device )
252285 return self
@@ -263,5 +296,29 @@ def transform(self, X) -> torch.Tensor:
263296 Returns:
264297 torch.Tensor: Transformed data tensor with shape (n_samples, n_components).
265298 """
266- X -= self .mean_
267- return torch .mm (X , self .components_ .T )
299+ X = X - self .mean_
300+ return torch .mm (X .double (), self .components_ .T ).to (X .dtype )
301+
@staticmethod
def gen_batches(n: int, batch_size: int, min_batch_size: int = 0):
    """Generator to create slices containing `batch_size` elements from 0 to `n`.

    The last slice may contain less than `batch_size` elements, when `batch_size` does not divide `n`.
    When cutting another full batch would leave fewer than `min_batch_size` elements after it, the
    remaining elements are yielded together as one final, larger slice.

    Args:
        n (int): Size of the sequence.
        batch_size (int): Number of elements in each batch. Must be at least 1.
        min_batch_size (int, optional): Minimum number of elements in each batch. Defaults to 0.

    Yields:
        slice: A slice of `batch_size` elements.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    # A zero batch size would divide by zero below and a negative one would silently yield a single
    # slice spanning the whole range, so reject both up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    start = 0
    for _ in range(int(n // batch_size)):
        end = start + batch_size
        if end + min_batch_size > n:
            # `start` no longer advances once this holds, so every later iteration would fail the same check.
            break
        yield slice(start, end)
        start = end
    if start < n:
        yield slice(start, n)
0 commit comments