diff --git a/openvoice/models.py b/openvoice/models.py
index b7c659a0..9ed0f27b 100644
--- a/openvoice/models.py
+++ b/openvoice/models.py
@@ -12,6 +12,7 @@
 
 from openvoice.commons import init_weights, get_padding
+from openvoice.mps_support import convt1d_chunked, conv1d_chunked
 
 
 class TextEncoder(nn.Module):
     def __init__(self,
@@ -276,7 +277,10 @@ def forward(self, x, g=None):
 
         for i in range(self.num_upsamples):
             x = F.leaky_relu(x, modules.LRELU_SLOPE)
-            x = self.ups[i](x)
+            if str(x.device) != "mps:0":
+                x = self.ups[i](x)
+            else:
+                x = convt1d_chunked(self.ups[i], x)
             xs = None
             for j in range(self.num_kernels):
                 if xs is None:
@@ -285,7 +289,10 @@ def forward(self, x, g=None):
                     xs += self.resblocks[i * self.num_kernels + j](x)
             x = xs / self.num_kernels
         x = F.leaky_relu(x)
-        x = self.conv_post(x)
+        if str(x.device) != "mps:0":
+            x = self.conv_post(x)
+        else:
+            x = conv1d_chunked(self.conv_post, x)
         x = torch.tanh(x)
 
         return x
diff --git a/openvoice/modules.py b/openvoice/modules.py
index d659a326..f3ca6d51 100644
--- a/openvoice/modules.py
+++ b/openvoice/modules.py
@@ -11,6 +11,8 @@
 from openvoice.transforms import piecewise_rational_quadratic_transform
 from openvoice.attentions import Encoder
 
+from openvoice.mps_support import convt1d_chunked, conv1d_chunked
+
 LRELU_SLOPE = 0.1
 
 
@@ -298,11 +300,17 @@ def forward(self, x, x_mask=None):
             xt = F.leaky_relu(x, LRELU_SLOPE)
             if x_mask is not None:
                 xt = xt * x_mask
-            xt = c1(xt)
+            if str(x.device) != "mps:0":
+                xt = c1(xt)
+            else:
+                xt = conv1d_chunked(c1, xt)
             xt = F.leaky_relu(xt, LRELU_SLOPE)
             if x_mask is not None:
                 xt = xt * x_mask
-            xt = c2(xt)
+            if str(x.device) != "mps:0":
+                xt = c2(xt)
+            else:
+                xt = conv1d_chunked(c2, xt)
             x = xt + x
         if x_mask is not None:
             x = x * x_mask
diff --git a/openvoice/mps_support.py b/openvoice/mps_support.py
new file mode 100644
index 00000000..1a6d972b
--- /dev/null
+++ b/openvoice/mps_support.py
@@ -0,0 +1,45 @@
+import torch
+
+
+@torch.no_grad()
+def convt1d_chunked(deconv: torch.nn.ConvTranspose1d, x: torch.Tensor, chunk: int = 32768, overlap: int = 2):
+    N, C, L = x.shape
+    outs = []
+    start = 0
+    while start < L:
+        end = min(L, start + chunk)
+        left_ctx = max(0, start - overlap)
+        right_ctx = min(L, end + overlap)
+        x_piece = x[..., left_ctx:right_ctx]
+        y_piece = deconv(x_piece)
+        left_cut = start - left_ctx
+        right_cut = y_piece.shape[-1] - (right_ctx - end)
+        if right_cut == 0:
+            y_valid = y_piece[..., left_cut:]
+        else:
+            y_valid = y_piece[..., left_cut:right_cut]
+        outs.append(y_valid)
+        start = end
+    return torch.cat(outs, dim=-1)
+
+
+@torch.no_grad()
+def conv1d_chunked(conv: torch.nn.Conv1d, x: torch.Tensor, chunk: int = 32768, overlap: int = 2):
+    N, C, L = x.shape
+    ys = []
+    start = 0
+    while start < L:
+        end = min(L, start + chunk)
+        left_ctx = max(0, start - overlap)
+        right_ctx = min(L, end + overlap)
+        x_piece = x[..., left_ctx:right_ctx]
+        y_piece = conv(x_piece)
+        left_cut = start - left_ctx
+        right_cut = y_piece.shape[-1] - (right_ctx - end)
+        if right_cut == 0:
+            y_valid = y_piece[..., left_cut:]
+        else:
+            y_valid = y_piece[..., left_cut:right_cut]
+        ys.append(y_valid)
+        start = end
+    return torch.cat(ys, dim=-1)
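
As a quick sanity check (not part of the patch), the chunked helpers can be compared against the direct layer call on CPU. The sketch below is illustrative only: the layer shape, input length, and `overlap=3` are made-up values for the test, not settings used by this diff. For a stride-1, "same"-padded Conv1d, the chunked output reproduces the direct output exactly whenever `overlap` is at least the layer's padding (3 for kernel size 7); with a smaller overlap, samples right at chunk boundaries can differ slightly, which is the trade-off the default `overlap=2` makes in exchange for copying less context.

import torch
from openvoice.mps_support import conv1d_chunked

torch.manual_seed(0)

# Hypothetical layer for the check: kernel_size=7 with padding=3 mirrors a
# typical HiFi-GAN conv_post, but the numbers are not taken from this diff.
conv = torch.nn.Conv1d(8, 1, kernel_size=7, padding=3)
x = torch.randn(1, 8, 100_000)

with torch.no_grad():
    y_ref = conv(x)                                      # direct, unchunked call
y_chk = conv1d_chunked(conv, x, chunk=32768, overlap=3)  # overlap >= padding

print(y_ref.shape == y_chk.shape)               # chunks reassemble to the full length
print(torch.allclose(y_ref, y_chk, atol=1e-6))  # matches when overlap covers the padding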