Skip to content

Commit ab069bd

Browse files
committed
change typehints for backward compatibility
1 parent 56455b7 commit ab069bd

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torch_incremental_pca/incremental_pca.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from functools import partial
33

4-
from typing import Optional
4+
from typing import Optional, Tuple
55

66

77
class IncrementalPCA:
@@ -78,7 +78,7 @@ def _validate_data(self, X, dtype=torch.float32) -> torch.Tensor:
7878
return X
7979

8080
@staticmethod
81-
def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
81+
def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
8282
"""
8383
Computes the incremental mean and variance for the data `X`.
8484
@@ -89,7 +89,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) ->
8989
last_sample_count (torch.Tensor): The count tensor of samples processed before the current batch.
9090
9191
Returns:
92-
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Updated mean, variance tensors, and total sample count.
92+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Updated mean, variance tensors, and total sample count.
9393
"""
9494
if X.shape[0] == 0:
9595
return last_mean, last_variance, last_sample_count
@@ -137,7 +137,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count) ->
137137
return updated_mean, updated_variance, updated_sample_count
138138

139139
@staticmethod
140-
def _svd_flip(u, v, u_based_decision=True) -> tuple[torch.Tensor, torch.Tensor]:
140+
def _svd_flip(u, v, u_based_decision=True) -> Tuple[torch.Tensor, torch.Tensor]:
141141
"""
142142
Adjusts the signs of the singular vectors from the SVD decomposition for deterministic output.
143143
@@ -149,7 +149,7 @@ def _svd_flip(u, v, u_based_decision=True) -> tuple[torch.Tensor, torch.Tensor]:
149149
u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping. Defaults to True.
150150
151151
Returns:
152-
tuple[torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
152+
Tuple[torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
153153
"""
154154
if u_based_decision:
155155
max_abs_cols = torch.argmax(torch.abs(u), dim=0)

0 commit comments

Comments
 (0)