Skip to content

Commit 752e516

Browse files
committed
🐛 Update evaluate code
1 parent ff22962 commit 752e516

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

evaluation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ def evaluate(hp, validloader, model):
1717
l1 = torch.nn.L1Loss()
1818
model.eval()
1919
for valid in validloader:
20-
x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
20+
x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_, p_avg_, p_std_, p_cwt_cont_ = valid
2121

2222
with torch.no_grad():
2323
ilens = torch.tensor([x_[-1].shape[0]], dtype=torch.long, device=x_.device)
24-
_, after_outs, d_outs, e_outs, p_outs = model._forward(x_.cuda(), ilens.cuda(), out_length_.cuda(), dur_.cuda(), es=e_.cuda(), ps=p_.cuda(), is_inference=False) # [T, num_mel]
24+
_, after_outs, d_outs, e_outs, p_outs, p_avg_outs, p_std_outs = model._forward(x_.cuda(), ilens.cuda(), out_length_.cuda(), dur_.cuda(), es=e_.cuda(), ps=p_.cuda(), is_inference=False) # [T, num_mel]
2525

2626
# e_orig = model.energy_predictor.to_one_hot(e_).squeeze()
2727
# p_orig = model.pitch_predictor.to_one_hot(p_).squeeze()
@@ -30,7 +30,7 @@ def evaluate(hp, validloader, model):
3030

3131
dur_diff.append(l1(d_outs, dur_.cuda()).item()) #.numpy()
3232
energy_diff.append(l1(e_outs, e_.cuda()).item()) #.numpy()
33-
pitch_diff.append(l1(p_outs, p_.cuda()).item()) #.numpy()
33+
pitch_diff.append(l1(p_outs, p_cwt_cont_.cuda()).item()) #.numpy()
3434

3535

3636
'''_, target = read_wav_np( hp.data.wav_dir + f"{ids_[-1]}.wav", sample_rate=hp.audio.sample_rate)

train_fastspeech.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def train(args, hp, hp_str, logger, vocoder):
166166
p_avg_.cuda(),
167167
p_std_.cuda()
168168
)
169-
169+
170170
mels_ = model.inference(x_[-1].cuda()) # [T, num_mel]
171171

172172
model.train()

0 commit comments

Comments
 (0)