Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions torchstain/torch/augmentors/macenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,28 @@
Original implementation: https://github.com/mitkovetta/staining-normalization
"""
class TorchMacenkoAugmentor(HEAugmentor):
def __init__(self, sigma1=0.2, sigma2=0.2):
def __init__(self, sigma1=0.2, sigma2=0.2, device="cpu"):
super().__init__()

self.sigma1 = sigma1
self.sigma2 = sigma2
self.device = device

self.I = None

self.HERef = torch.tensor([[0.5626, 0.2159],
[0.7201, 0.8012],
[0.4062, 0.5581]])
self.maxCRef = torch.tensor([1.9705, 1.0308])
self.HERef = torch.tensor(
[[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]], device=device
)
self.maxCRef = torch.tensor([1.9705, 1.0308], device=device)

# Avoid using deprecated torch.lstsq (since 1.9.0)
self.updated_lstsq = hasattr(torch.linalg, 'lstsq')
self.updated_lstsq = hasattr(torch.linalg, "lstsq")

def __convert_rgb2od(self, I, Io, beta):
I = I.permute(1, 2, 0)

# calculate optical density
OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1)/Io)
OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)

# remove transparent pixels
ODhat = OD[~torch.any(OD < beta, dim=1)]
Expand All @@ -43,12 +44,18 @@ def __find_HE(self, ODhat, eigvecs, alpha):
minPhi = percentile(phi, alpha)
maxPhi = percentile(phi, 100 - alpha)

vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1)
vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1)
vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(
1
)
vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(
1
)

# a heuristic to make the vector corresponding to hematoxylin first and the
# one corresponding to eosin second
HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1))
HE = torch.where(
vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1)
)

return HE

Expand All @@ -59,14 +66,14 @@ def __find_concentration(self, OD, HE):
# determine concentrations of the individual stains
if not self.updated_lstsq:
return torch.lstsq(Y, HE)[0][:2]

return torch.linalg.lstsq(HE, Y)[0]

def __compute_matrices(self, I, Io, alpha, beta):
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

# compute eigenvectors
_, eigvecs = torch.linalg.eigh(cov(ODhat.T))
_, eigvecs = torch.linalg.eigh(cov(ODhat.T))
eigvecs = eigvecs[:, [1, 2]]

HE = self.__find_HE(ODhat, eigvecs, alpha)
Expand All @@ -84,10 +91,10 @@ def fit(self, I, Io=240, alpha=1, beta=0.15):
self.HERef = HE
self.CRef = C
self.maxCRef = maxC

@staticmethod
def random_uniform(shape, low, high):
return (low - high) * torch.rand(*shape) + high
def random_uniform(shape, low, high, device):
return (low - high) * torch.rand(*shape, device=device) + high

def augment(self, Io=240, alpha=1, beta=0.15):
I = self.I
Expand All @@ -99,12 +106,13 @@ def augment(self, Io=240, alpha=1, beta=0.15):
C2 = C / torch.unsqueeze(maxC, axis=-1)

# introduce noise to the concentrations (applied along axis=0)
C2 = (C2 * self.random_uniform((2, 1), 1 - self.sigma1, 1 + self.sigma1)) + self.random_uniform((2, 1), -self.sigma2, self.sigma2)
C2 = (
C2 * self.random_uniform((2, 1), 1 - self.sigma1, 1 + self.sigma1, self.device)
) + self.random_uniform((2, 1), -self.sigma2, self.sigma2, self.device)

# recreate the image using reference mixing matrix
Inorm = Io * torch.exp(-torch.matmul(self.HERef, C2))
Inorm[Inorm > 255] = 255
Inorm = Inorm.T.reshape(h, w, c).int()

return Inorm