From d44d259ed27fbf6e7355559a246f356894e7ab5a Mon Sep 17 00:00:00 2001 From: Wouter Zwerink <35296208+wouterzwerink@users.noreply.github.com> Date: Thu, 13 Mar 2025 14:05:09 +0100 Subject: [PATCH 1/3] Support other devices for TorchMacenkoAugmentor --- torchstain/torch/augmentors/macenko.py | 46 ++++++++++++++++---------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/torchstain/torch/augmentors/macenko.py b/torchstain/torch/augmentors/macenko.py index 0def8e4..76b6849 100644 --- a/torchstain/torch/augmentors/macenko.py +++ b/torchstain/torch/augmentors/macenko.py @@ -6,28 +6,31 @@ Source code ported from: https://github.com/schaugf/HEnorm_python Original implementation: https://github.com/mitkovetta/staining-normalization """ + + class TorchMacenkoAugmentor(HEAugmentor): - def __init__(self, sigma1=0.2, sigma2=0.2): + def __init__(self, sigma1=0.2, sigma2=0.2, device="cpu"): super().__init__() self.sigma1 = sigma1 self.sigma2 = sigma2 + self.device = device self.I = None - self.HERef = torch.tensor([[0.5626, 0.2159], - [0.7201, 0.8012], - [0.4062, 0.5581]]) - self.maxCRef = torch.tensor([1.9705, 1.0308]) + self.HERef = torch.tensor( + [[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]], device=device + ) + self.maxCRef = torch.tensor([1.9705, 1.0308], device=device) # Avoid using deprecated torch.lstsq (since 1.9.0) - self.updated_lstsq = hasattr(torch.linalg, 'lstsq') - + self.updated_lstsq = hasattr(torch.linalg, "lstsq") + def __convert_rgb2od(self, I, Io, beta): I = I.permute(1, 2, 0) # calculate optical density - OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1)/Io) + OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io) # remove transparent pixels ODhat = OD[~torch.any(OD < beta, dim=1)] @@ -43,12 +46,18 @@ def __find_HE(self, ODhat, eigvecs, alpha): minPhi = percentile(phi, alpha) maxPhi = percentile(phi, 100 - alpha) - vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1) - vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1) + vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze( + 1 + ) + vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze( + 1 + ) # a heuristic to make the vector corresponding to hematoxylin first and the # one corresponding to eosin second - HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1)) + HE = torch.where( + vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1) + ) return HE @@ -59,14 +68,14 @@ def __find_concentration(self, OD, HE): # determine concentrations of the individual stains if not self.updated_lstsq: return torch.lstsq(Y, HE)[0][:2] - + return torch.linalg.lstsq(HE, Y)[0] def __compute_matrices(self, I, Io, alpha, beta): OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) # compute eigenvectors - _, eigvecs = torch.linalg.eigh(cov(ODhat.T)) + _, eigvecs = torch.linalg.eigh(cov(ODhat.T)) eigvecs = eigvecs[:, [1, 2]] HE = self.__find_HE(ODhat, eigvecs, alpha) @@ -84,10 +93,10 @@ def fit(self, I, Io=240, alpha=1, beta=0.15): self.HERef = HE self.CRef = C self.maxCRef = maxC - + @staticmethod - def random_uniform(shape, low, high): - return (low - high) * torch.rand(*shape) + high + def random_uniform(shape, low, high, device): + return (low - high) * torch.rand(*shape, device=device) + high def augment(self, Io=240, alpha=1, beta=0.15): I = self.I @@ -99,7 +108,9 @@ def augment(self, Io=240, alpha=1, beta=0.15): C2 = C / torch.unsqueeze(maxC, axis=-1) # introduce noise to the concentrations (applied along axis=0) - C2 = (C2 * self.random_uniform((2, 1), 1 - self.sigma1, 1 + self.sigma1)) + self.random_uniform((2, 1), -self.sigma2, self.sigma2) + C2 = ( + C2 * self.random_uniform((2, 1), 1 - self.sigma1, 1 + self.sigma1, self.device) + ) + self.random_uniform((2, 1), -self.sigma2, self.sigma2, self.device) # recreate the image using reference mixing matrix Inorm = Io * torch.exp(-torch.matmul(self.HERef, C2)) @@ -107,4 +118,3 @@ def augment(self, Io=240, alpha=1, beta=0.15): Inorm = Inorm.T.reshape(h, w, c).int() return Inorm - \ No newline at end of file From a544e5463b9c7b21e1785b70bce499d2d4cf0404 Mon Sep 17 00:00:00 2001 From: Wouter Zwerink <35296208+wouterzwerink@users.noreply.github.com> Date: Thu, 13 Mar 2025 14:06:22 +0100 Subject: [PATCH 2/3] Update macenko.py --- torchstain/torch/augmentors/macenko.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchstain/torch/augmentors/macenko.py b/torchstain/torch/augmentors/macenko.py index 76b6849..954a866 100644 --- a/torchstain/torch/augmentors/macenko.py +++ b/torchstain/torch/augmentors/macenko.py @@ -6,8 +6,6 @@ Source code ported from: https://github.com/schaugf/HEnorm_python Original implementation: https://github.com/mitkovetta/staining-normalization """ - - class TorchMacenkoAugmentor(HEAugmentor): def __init__(self, sigma1=0.2, sigma2=0.2, device="cpu"): super().__init__() @@ -117,4 +115,5 @@ def augment(self, Io=240, alpha=1, beta=0.15): Inorm[Inorm > 255] = 255 Inorm = Inorm.T.reshape(h, w, c).int() + return Inorm From c87526019d7a66964bf26b11af6057a89bc858df Mon Sep 17 00:00:00 2001 From: Wouter Zwerink <35296208+wouterzwerink@users.noreply.github.com> Date: Thu, 13 Mar 2025 14:09:14 +0100 Subject: [PATCH 3/3] Revert formatting change --- torchstain/torch/augmentors/macenko.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchstain/torch/augmentors/macenko.py b/torchstain/torch/augmentors/macenko.py index 954a866..6fc1735 100644 --- a/torchstain/torch/augmentors/macenko.py +++ b/torchstain/torch/augmentors/macenko.py @@ -115,5 +115,4 @@ def augment(self, Io=240, alpha=1, beta=0.15): Inorm[Inorm > 255] = 255 Inorm = Inorm.T.reshape(h, w, c).int() - return Inorm