11import torch
22from functools import partial
33
4- from typing import Optional
4+ from typing import Optional , Tuple
55
66
77class IncrementalPCA :
@@ -78,7 +78,7 @@ def _validate_data(self, X, dtype=torch.float32) -> torch.Tensor:
7878 return X
7979
8080 @staticmethod
81- def _incremental_mean_and_var (X , last_mean , last_variance , last_sample_count ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
81+ def _incremental_mean_and_var (X , last_mean , last_variance , last_sample_count ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
8282 """
8383 Computes the incremental mean and variance for the data `X`.
8484
@@ -89,7 +89,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) ->
8989 last_sample_count (torch.Tensor): The count tensor of samples processed before the current batch.
9090
9191 Returns:
92- tuple [torch.Tensor, torch.Tensor, torch.Tensor]: Updated mean, variance tensors, and total sample count.
92+ Tuple [torch.Tensor, torch.Tensor, torch.Tensor]: Updated mean, variance tensors, and total sample count.
9393 """
9494 if X .shape [0 ] == 0 :
9595 return last_mean , last_variance , last_sample_count
@@ -137,7 +137,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) ->
137137 return updated_mean , updated_variance , updated_sample_count
138138
139139 @staticmethod
140- def _svd_flip (u , v , u_based_decision = True ) -> tuple [torch .Tensor , torch .Tensor ]:
140+ def _svd_flip (u , v , u_based_decision = True ) -> Tuple [torch .Tensor , torch .Tensor ]:
141141 """
142142 Adjusts the signs of the singular vectors from the SVD decomposition for deterministic output.
143143
@@ -149,7 +149,7 @@ def _svd_flip(u, v, u_based_decision=True) -> tuple[torch.Tensor, torch.Tensor]:
149149 u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping. Defaults to True.
150150
151151 Returns:
152- tuple [torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
152+ Tuple [torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
153153 """
154154 if u_based_decision :
155155 max_abs_cols = torch .argmax (torch .abs (u ), dim = 0 )
0 commit comments