diff --git a/main.py b/main.py index 05bfa0a..dec99c0 100755 --- a/main.py +++ b/main.py @@ -29,8 +29,8 @@ def evaluate(data, X, Y, model, evaluateL2, evaluateL1, batch_size): test = torch.cat((test, Y)); scale = data.scale.expand(output.size(0), data.m) - total_loss += evaluateL2(output * scale, Y * scale).data[0] - total_loss_l1 += evaluateL1(output * scale, Y * scale).data[0] + total_loss += evaluateL2(output * scale, Y * scale).item() + total_loss_l1 += evaluateL1(output * scale, Y * scale).item() n_samples += (output.size(0) * data.m); rse = math.sqrt(total_loss / n_samples)/data.rse rae = (total_loss_l1/n_samples)/data.rae @@ -42,7 +42,7 @@ def evaluate(data, X, Y, model, evaluateL2, evaluateL1, batch_size): mean_p = predict.mean(axis = 0) mean_g = Ytest.mean(axis = 0) index = (sigma_g!=0); - correlation = ((predict - mean_p) * (Ytest - mean_g)).mean(axis = 0)/(sigma_p * sigma_g); + correlation = ((predict - mean_p) * (Ytest - mean_g)).mean(axis = 0)/(sigma_p * sigma_g + 0.000000000000001); correlation = (correlation[index]).mean(); return rse, rae, correlation;