Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions openvoice/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from openvoice.commons import init_weights, get_padding

from openvoice.mps_support import convt1d_chunked, conv1d_chunked

class TextEncoder(nn.Module):
def __init__(self,
Expand Down Expand Up @@ -276,7 +277,10 @@ def forward(self, x, g=None):

for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
if str(x.device) != "mps:0":
x = self.ups[i](x)
else:
x = convt1d_chunked(self.ups[i], x)
xs = None
for j in range(self.num_kernels):
if xs is None:
Expand All @@ -285,7 +289,10 @@ def forward(self, x, g=None):
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
if str(x.device) != "mps:0":
x = self.conv_post(x)
else:
x = conv1d_chunked(self.conv_post, x)
x = torch.tanh(x)

return x
Expand Down
12 changes: 10 additions & 2 deletions openvoice/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from openvoice.transforms import piecewise_rational_quadratic_transform
from openvoice.attentions import Encoder

from openvoice.mps_support import convt1d_chunked, conv1d_chunked

LRELU_SLOPE = 0.1


Expand Down Expand Up @@ -298,11 +300,17 @@ def forward(self, x, x_mask=None):
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c1(xt)
if str(x.device) != "mps:0":
xt = c1(xt)
else:
xt = conv1d_chunked(c1, xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c2(xt)
if str(x.device) != "mps:0":
xt = c2(xt)
else:
xt = conv1d_chunked(c2, xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
Expand Down
45 changes: 45 additions & 0 deletions openvoice/mps_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch


@torch.no_grad()
def convt1d_chunked(deconv: torch.nn.ConvTranspose1d, x: torch.Tensor, chunk: int = 32768, overlap: int = 2):
N, C, L = x.shape
outs = []
start = 0
while start < L:
end = min(L, start + chunk)
left_ctx = max(0, start - overlap)
right_ctx = min(L, end + overlap)
x_piece = x[..., left_ctx:right_ctx]
y_piece = deconv(x_piece)
left_cut = start - left_ctx
right_cut = y_piece.shape[-1] - (right_ctx - end)
if right_cut == 0:
y_valid = y_piece[..., left_cut:]
else:
y_valid = y_piece[..., left_cut:right_cut]
outs.append(y_valid)
start = end
return torch.cat(outs, dim=-1)


@torch.no_grad()
def conv1d_chunked(conv: torch.nn.Conv1d, x: torch.Tensor, chunk: int = 32768, overlap: int = 2):
N, C, L = x.shape
ys = []
start = 0
while start < L:
end = min(L, start + chunk)
left_ctx = max(0, start - overlap)
right_ctx = min(L, end + overlap)
x_piece = x[..., left_ctx:right_ctx]
y_piece = conv(x_piece)
left_cut = start - left_ctx
right_cut = y_piece.shape[-1] - (right_ctx - end)
if right_cut == 0:
y_valid = y_piece[..., left_cut:]
else:
y_valid = y_piece[..., left_cut:right_cut]
ys.append(y_valid)
start = end
return torch.cat(ys, dim=-1)