Commit a0b4223

improve readability and code logic
1 parent ab069bd commit a0b4223

File tree

torch_incremental_pca/incremental_pca.py

1 file changed: 127 additions, 70 deletions
@@ -1,31 +1,32 @@
-import torch
-from functools import partial
-
 from typing import Optional, Tuple
 
+import torch
+
 
 class IncrementalPCA:
     """
     An implementation of Incremental Principal Components Analysis (IPCA) that leverages PyTorch for GPU acceleration.
+    Adapted from https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/decomposition/_incremental_pca.py
 
-    This class provides methods to fit the model on data incrementally in batches, and to transform new data
-    based on the principal components learned during the fitting process.
+    This class provides methods to fit the model on data incrementally in batches, and to transform new data based on
+    the principal components learned during the fitting process.
 
-    Attributes:
+    Args:
        n_components (int, optional): Number of components to keep. If `None`, it's set to the minimum of the
-                                      number of samples and features. Defaults to None.
+            number of samples and features. Defaults to None.
        copy (bool): If False, input data will be overwritten. Defaults to True.
        batch_size (int, optional): The number of samples to use for each batch. Only needed if self.fit is called.
-                                    If `None`, it's inferred from the data and set to `5 * n_features`. Defaults to None.
-        svd_driver (str, optional): name of the cuSOLVER method to be used for torch.linalg.svd. This keyword
-                                    argument only works on CUDA inputs. Available options are: None, gesvd, gesvdj,
-                                    and gesvda. Defaults to None.
+            If `None`, it's inferred from the data and set to `5 * n_features`. Defaults to None.
+        svd_driver (str, optional): name of the cuSOLVER method to be used for torch.linalg.svd. This keyword
+            argument only works on CUDA inputs. Available options are: None, gesvd, gesvdj, and gesvda. Defaults to
+            None.
        lowrank (bool, optional): Whether to use torch.svd_lowrank instead of torch.linalg.svd which can be faster.
-                                  Defaults to False.
-        lowrank_q (int, optional): For an adequate approximation of n_components, this parameter defaults to
-                                   n_components * 2.
+            Defaults to False.
+        lowrank_q (int, optional): For an adequate approximation of n_components, this parameter defaults to
+            n_components * 2.
        lowrank_niter (int, optional): Number of subspace iterations to conduct for torch.svd_lowrank.
-                                       Defaults to 4.
+            Defaults to 4.
+        lowrank_seed (int, optional): Seed for making results of torch.svd_lowrank reproducible.
    """
 
    def __init__(
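The docstring now documents constructor arguments rather than attributes. A minimal usage sketch of the resulting API (the import path and the data shapes are my assumptions, not part of the commit):

import torch

from torch_incremental_pca.incremental_pca import IncrementalPCA  # assumed import path

X = torch.randn(10_000, 256)

# lowrank_seed is the argument added in this commit; with lowrank=True and
# lowrank_q=None, lowrank_q is derived as n_components * 2 = 64, per the docstring.
ipca = IncrementalPCA(n_components=32, lowrank=True, lowrank_seed=0)
ipca.fit(X)            # batch_size=None -> inferred as 5 * n_features
Z = ipca.transform(X)  # shape: (10000, 32)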
@@ -36,49 +37,81 @@ def __init__(
         svd_driver: Optional[str] = None,
         lowrank: bool = False,
         lowrank_q: Optional[int] = None,
-        lowrank_niter: int = 4
+        lowrank_niter: int = 4,
+        lowrank_seed: Optional[int] = None,
     ):
-        self.n_components_ = n_components
+        self.n_components = n_components
         self.copy = copy
         self.batch_size = batch_size
-
-        if lowrank:
-            if lowrank_q is None:
-                assert n_components is not None, "n_components must be specified when using lowrank mode with lowrank_q=None."
-                lowrank_q = n_components * 2
-            assert lowrank_q >= n_components, "lowrank_q must be greater than or equal to n_components."
-            def svd_fn(X):
-                U, S, V = torch.svd_lowrank(X, q=lowrank_q, niter=lowrank_niter)
-                return U, S, V.mH  # V is returned as a conjugate transpose
-            self._svd_fn = svd_fn
-
-        else:
-            self._svd_fn = partial(torch.linalg.svd, full_matrices=False, driver=svd_driver)
-
-
-    def _validate_data(self, X, dtype=torch.float32) -> torch.Tensor:
+        self.svd_driver = svd_driver
+        self.lowrank = lowrank
+        self.lowrank_q = lowrank_q
+        self.lowrank_niter = lowrank_niter
+        self.lowrank_seed = lowrank_seed
+
+        self.n_features_ = None
+
+        if self.lowrank:
+            self._validate_lowrank_params()
+
+    def _validate_lowrank_params(self):
+        if self.lowrank_q is None:
+            if self.n_components is None:
+                raise ValueError("n_components must be specified when using lowrank mode with lowrank_q=None.")
+            self.lowrank_q = self.n_components * 2
+        elif self.lowrank_q < self.n_components:
+            raise ValueError("lowrank_q must be greater than or equal to n_components.")
+
+    def _svd_fn_full(self, X):
+        return torch.linalg.svd(X, full_matrices=False, driver=self.svd_driver)
+
+    def _svd_fn_lowrank(self, X):
+        seed_enabled = self.lowrank_seed is not None
+        with torch.random.fork_rng(enabled=seed_enabled):
+            if seed_enabled:
+                torch.manual_seed(self.lowrank_seed)
+            U, S, V = torch.svd_lowrank(X, q=self.lowrank_q, niter=self.lowrank_niter)
+            return U, S, V.mH
+
+    def _validate_data(self, X) -> torch.Tensor:
         """
         Validates and converts the input data `X` to the appropriate tensor format.
 
         Args:
             X (torch.Tensor): Input data.
-            dtype (torch.dtype, optional): Desired data type for the tensor. Defaults to torch.float32.
 
         Returns:
             torch.Tensor: Converted to appropriate format.
         """
+        valid_dtypes = [torch.float32, torch.float64]
+
         if not isinstance(X, torch.Tensor):
-            X = torch.tensor(X, dtype=dtype)
+            X = torch.tensor(X, dtype=torch.float32)
         elif self.copy:
             X = X.clone()
 
-        if X.dtype != dtype:
-            X = X.to(dtype)
+        n_samples, n_features = X.shape
+        if self.n_components is None:
+            pass
+        elif self.n_components > n_features:
+            raise ValueError(
+                f"n_components={self.n_components} invalid for n_features={n_features}, "
+                "need more rows than columns for IncrementalPCA processing."
+            )
+        elif self.n_components > n_samples:
+            raise ValueError(
+                f"n_components={self.n_components} must be less or equal to the batch number of samples {n_samples}"
+            )
+
+        if X.dtype not in valid_dtypes:
+            X = X.to(torch.float32)
 
         return X
 
     @staticmethod
-    def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _incremental_mean_and_var(
+        X, last_mean, last_variance, last_sample_count
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Computes the incremental mean and variance for the data `X`.
 
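Worth noting: `_svd_fn_lowrank` seeds inside `torch.random.fork_rng`, so reproducibility is opt-in and the caller's global RNG state is left untouched. The same pattern in isolation (the helper name `lowrank_svd` is mine, for illustration):

import torch

def lowrank_svd(X, q, niter=4, seed=None):
    # fork_rng snapshots and restores the global RNG state, but only when enabled
    with torch.random.fork_rng(enabled=seed is not None):
        if seed is not None:
            torch.manual_seed(seed)
        U, S, V = torch.svd_lowrank(X, q=q, niter=niter)
    return U, S, V.mH  # .mH matches the Vt convention of torch.linalg.svd

X = torch.randn(500, 64)
state = torch.random.get_rng_state()
S1 = lowrank_svd(X, q=16, seed=0)[1]
S2 = lowrank_svd(X, q=16, seed=0)[1]
assert torch.allclose(S1, S2)                            # same seed, same result
assert torch.equal(state, torch.random.get_rng_state())  # caller RNG untouched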
@@ -95,12 +128,10 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) ->
             return last_mean, last_variance, last_sample_count
 
         if last_sample_count > 0:
-            assert (
-                last_mean is not None
-            ), "last_mean should not be None if last_sample_count > 0."
-            assert (
-                last_variance is not None
-            ), "last_variance should not be None if last_sample_count > 0."
+            if last_mean is None:
+                raise ValueError("last_mean should not be None if last_sample_count > 0.")
+            if last_variance is None:
+                raise ValueError("last_variance should not be None if last_sample_count > 0.")
 
         new_sample_count = torch.tensor([X.shape[0]], device=X.device)
         updated_sample_count = last_sample_count + new_sample_count
@@ -128,9 +159,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) ->
         updated_unnormalized_variance = (
             last_unnormalized_variance
             + new_unnormalized_variance
-            + last_over_new_count
-            / updated_sample_count
-            * (last_sum / last_over_new_count - new_sum).square()
+            + last_over_new_count / updated_sample_count * (last_sum / last_over_new_count - new_sum).square()
         )
         updated_variance = updated_unnormalized_variance / updated_sample_count
 
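The reflowed correction term is still the Chan et al. pairwise update: `last_over_new_count / updated_sample_count * (last_sum / last_over_new_count - new_sum).square()` reduces to `n_a * n_b / (n_a + n_b) * (mean_a - mean_b) ** 2`. A self-contained check of the combine step (helper and variable names are mine, not from the commit):

import torch

def combine(mean_a, var_a, n_a, B):
    # Chan-style pairwise update, mirroring the expression reformatted above
    n_b = B.shape[0]
    n = n_a + n_b
    last_sum, new_sum = mean_a * n_a, B.sum(dim=0)
    last_over_new_count = n_a / n_b
    unnorm_var = (
        var_a * n_a                              # last_unnormalized_variance
        + ((B - B.mean(dim=0)) ** 2).sum(dim=0)  # new_unnormalized_variance
        + last_over_new_count / n * (last_sum / last_over_new_count - new_sum).square()
    )
    return (last_sum + new_sum) / n, unnorm_var / n, n

A, B = torch.randn(100, 5).double(), torch.randn(50, 5).double()
mean, var, n = combine(A.mean(dim=0), A.var(dim=0, unbiased=False), 100, B)
X = torch.cat([A, B])
assert torch.allclose(mean, X.mean(dim=0))
assert torch.allclose(var, X.var(dim=0, unbiased=False))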
@@ -146,7 +175,8 @@ def _svd_flip(u, v, u_based_decision=True) -> Tuple[torch.Tensor, torch.Tensor]:
         Args:
             u (torch.Tensor): Left singular vectors tensor.
             v (torch.Tensor): Right singular vectors tensor.
-            u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping. Defaults to True.
+            u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping.
+                Defaults to True.
 
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
@@ -157,7 +187,7 @@ def _svd_flip(u, v, u_based_decision=True) -> Tuple[torch.Tensor, torch.Tensor]:
         else:
             max_abs_rows = torch.argmax(torch.abs(v), dim=1)
             signs = torch.sign(v[range(v.shape[0]), max_abs_rows])
-            u *= signs[:u.shape[1]].view(1, -1)
+            u *= signs[: u.shape[1]].view(1, -1)
             v *= signs.view(-1, 1)
         return u, v
 
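Only whitespace changed here, but the function matters: an SVD is unique only up to a per-component sign, and `_svd_flip` pins those signs so partial fits stay comparable across batches. A small check of the v-based branch used above (helper name `svd_flip_v` is mine):

import torch

def svd_flip_v(u, v):
    # same logic as the v-based branch above: make the largest-|entry| of each
    # row of v positive, and flip the matching columns of u to compensate
    max_abs_rows = torch.argmax(torch.abs(v), dim=1)
    signs = torch.sign(v[range(v.shape[0]), max_abs_rows])
    u = u * signs[: u.shape[1]].view(1, -1)
    v = v * signs.view(-1, 1)
    return u, v

X = torch.randn(20, 5).double()
U, S, Vt = torch.linalg.svd(X, full_matrices=False)
U1, Vt1 = svd_flip_v(U, Vt)
U2, Vt2 = svd_flip_v(-U, -Vt)  # a sign-flipped but equally valid SVD
assert torch.allclose(U1, U2) and torch.allclose(Vt1, Vt2)
assert torch.allclose(U1 @ torch.diag(S) @ Vt1, X)  # still reconstructs X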
@@ -176,14 +206,10 @@ def fit(self, X, check_input=True):
         X = self._validate_data(X)
         n_samples, n_features = X.shape
         if self.batch_size is None:
-            self.batch_size_ = 5 * n_features
-        else:
-            self.batch_size_ = self.batch_size
+            self.batch_size = 5 * n_features
 
-        for start in range(0, n_samples, self.batch_size_):
-            end = min(start + self.batch_size_, n_samples)
-            X_batch = X[start:end]
-            self.partial_fit(X_batch, check_input=False)
+        for batch in self.gen_batches(n_samples, self.batch_size, min_batch_size=self.n_components or 0):
+            self.partial_fit(X[batch], check_input=False)
 
         return self
 
@@ -209,8 +235,14 @@ def partial_fit(self, X, check_input=True):
             self.mean_ = None  # Will be initialized properly in _incremental_mean_and_var based on data dimensions
             self.var_ = None  # Will be initialized properly in _incremental_mean_and_var based on data dimensions
             self.n_samples_seen_ = torch.tensor([0], device=X.device)
-            if not self.n_components_:
-                self.n_components_ = min(n_samples, n_features)
+            self.n_features_ = n_features
+            if not self.n_components:
+                self.n_components = min(n_samples, n_features)
+
+        if n_features != self.n_features_:
+            raise ValueError(
+                "Number of features of the new batch does not match the number of features of the first batch."
+            )
 
         col_mean, col_var, n_total_samples = self._incremental_mean_and_var(
             X, self.mean_, self.var_, self.n_samples_seen_
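The new guard makes the feature count sticky after the first batch. A sketch of the expected behavior (import path and shapes are assumptions):

import torch

from torch_incremental_pca.incremental_pca import IncrementalPCA  # assumed import path

ipca = IncrementalPCA(n_components=2)
ipca.partial_fit(torch.randn(100, 8))      # first batch fixes n_features_ = 8
try:
    ipca.partial_fit(torch.randn(100, 6))  # 6 columns != 8
except ValueError as err:
    print(err)  # Number of features of the new batch does not match ...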
@@ -221,9 +253,7 @@
         else:
             col_batch_mean = torch.mean(X, dim=0)
             X -= col_batch_mean
-            mean_correction_factor = torch.sqrt(
-                (self.n_samples_seen_.double() / n_total_samples) * n_samples
-            )
+            mean_correction_factor = torch.sqrt((self.n_samples_seen_.double() / n_total_samples) * n_samples)
             mean_correction = mean_correction_factor * (self.mean_ - col_batch_mean)
             X = torch.vstack(
                 (
@@ -233,20 +263,23 @@
                 )
             )
 
-        U, S, Vt = self._svd_fn(X)
+        if self.lowrank:
+            U, S, Vt = self._svd_fn_lowrank(X)
+        else:
+            U, S, Vt = self._svd_fn_full(X)
         U, Vt = self._svd_flip(U, Vt, u_based_decision=False)
         explained_variance = S**2 / (n_total_samples - 1)
         explained_variance_ratio = S**2 / torch.sum(col_var * n_total_samples)
 
         self.n_samples_seen_ = n_total_samples
-        self.components_ = Vt[:self.n_components_]
-        self.singular_values_ = S[:self.n_components_]
+        self.components_ = Vt[: self.n_components]
+        self.singular_values_ = S[: self.n_components]
         self.mean_ = col_mean
         self.var_ = col_var
-        self.explained_variance_ = explained_variance[:self.n_components_]
-        self.explained_variance_ratio_ = explained_variance_ratio[:self.n_components_]
-        if self.n_components_ not in (n_samples, n_features):
-            self.noise_variance_ = explained_variance[self.n_components_:].mean()
+        self.explained_variance_ = explained_variance[: self.n_components]
+        self.explained_variance_ratio_ = explained_variance_ratio[: self.n_components]
+        if self.n_components not in (n_samples, n_features):
+            self.noise_variance_ = explained_variance[self.n_components :].mean()
         else:
             self.noise_variance_ = torch.tensor(0.0, device=X.device)
         return self
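Dispatching between `_svd_fn_lowrank` and `_svd_fn_full` happens once per batch here. For genuinely low-rank data the two agree closely on the leading spectrum; an illustrative comparison (not from the commit):

import torch

A = torch.randn(300, 8).double() @ torch.randn(8, 40).double()  # rank-8 data
S_full = torch.linalg.svd(A, full_matrices=False).S
_, S_low, _ = torch.svd_lowrank(A, q=16, niter=4)               # q = 2 * rank
assert torch.allclose(S_full[:8], S_low[:8], rtol=1e-6)         # leading values match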
@@ -263,5 +296,29 @@ def transform(self, X) -> torch.Tensor:
         Returns:
             torch.Tensor: Transformed data tensor with shape (n_samples, n_components).
         """
-        X -= self.mean_
-        return torch.mm(X, self.components_.T)
+        X = X - self.mean_
+        return torch.mm(X.double(), self.components_.T).to(X.dtype)
+
+    @staticmethod
+    def gen_batches(n: int, batch_size: int, min_batch_size: int = 0):
+        """Generator to create slices containing `batch_size` elements from 0 to `n`.
+
+        The last slice may contain less than `batch_size` elements, when `batch_size` does not divide `n`.
+
+        Args:
+            n (int): Size of the sequence.
+            batch_size (int): Number of elements in each batch.
+            min_batch_size (int, optional): Minimum number of elements in each batch. Defaults to 0.
+
+        Yields:
+            slice: A slice of `batch_size` elements.
+        """
+        start = 0
+        for _ in range(int(n // batch_size)):
+            end = start + batch_size
+            if end + min_batch_size > n:
+                continue
+            yield slice(start, end)
+            start = end
+        if start < n:
+            yield slice(start, n)

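`gen_batches` is a port of sklearn's helper of the same name. `fit` (earlier in this diff) passes `min_batch_size=self.n_components or 0`, so a tail shorter than `n_components` is folded into the final slice instead of being yielded on its own. Illustrative outputs (import path assumed):

from torch_incremental_pca.incremental_pca import IncrementalPCA  # assumed import path

print(list(IncrementalPCA.gen_batches(10, 4)))
# [slice(0, 4, None), slice(4, 8, None), slice(8, 10, None)]

print(list(IncrementalPCA.gen_batches(10, 4, min_batch_size=3)))
# [slice(0, 4, None), slice(4, 10, None)]  -- the 2-element tail is folded in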