@@ -54,20 +54,21 @@ def _grad_mu(distr, z, eta):
     return grad_mu
 
 
-def _logL(distr, y, y_hat):
+def _logL(distr, y, y_hat, z=None):
     """The log likelihood."""
     if distr in ['softplus', 'poisson']:
         eps = np.spacing(1)
         logL = np.sum(y * np.log(y_hat + eps) - y_hat)
     elif distr == 'gaussian':
         logL = -0.5 * np.sum((y - y_hat)**2)
     elif distr == 'binomial':
-        # analytical formula
-        logL = np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
 
-        # but this prevents underflow
-        # z = beta0 + np.dot(X, beta)
-        # logL = np.sum(y * z - np.log(1 + np.exp(z)))
+        # prevents underflow
+        if z is not None:
+            logL = np.sum(y * z - np.log(1 + np.exp(z)))
+        # for scoring
+        else:
+            logL = np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
     elif distr == 'probit':
         logL = np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
     elif distr == 'gamma':
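
For context on the change above (not part of the diff): the naive binomial formula loses all precision once the sigmoid saturates, because y_hat rounds to exactly 0.0 or 1.0 in float64 and log(0) returns -inf. Rewriting the log likelihood in terms of the linear predictor z stays finite. A minimal sketch, assuming the standard logistic link 1 / (1 + exp(-z)) used for the binomial distribution:

import numpy as np

z = np.array([40.0])              # large linear predictor
y = np.array([0.0])               # observed label
y_hat = 1.0 / (1 + np.exp(-z))    # sigmoid saturates: y_hat == 1.0 exactly in float64

# naive form: log(1 - y_hat) = log(0) -> -inf (NumPy also raises a divide warning)
naive = np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

# z-based form stays finite: 0 * 40 - log(1 + exp(40)) == -40.0
stable = np.sum(y * z - np.log(1 + np.exp(z)))

print(naive, stable)              # -inf -40.0
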
@@ -123,8 +124,9 @@ def _L1penalty(beta, group=None):
 def _loss(distr, alpha, Tau, reg_lambda, X, y, eta, group, beta):
     """Define the objective function for elastic net."""
     n_samples = X.shape[0]
-    y_hat = _mu(distr, beta[0] + np.dot(X, beta[1:]), eta)
-    L = 1. / n_samples * _logL(distr, y, y_hat)
+    z = beta[0] + np.dot(X, beta[1:])
+    y_hat = _mu(distr, z, eta)
+    L = 1. / n_samples * _logL(distr, y, y_hat, z)
     P = _penalty(alpha, beta[1:], Tau, group)
     J = -L + reg_lambda * P
     return J
@@ -133,8 +135,9 @@ def _loss(distr, alpha, Tau, reg_lambda, X, y, eta, group, beta):
 def _L2loss(distr, alpha, Tau, reg_lambda, X, y, eta, group, beta):
     """Define the objective function for elastic net."""
     n_samples = X.shape[0]
-    y_hat = _mu(distr, beta[0] + np.dot(X, beta[1:]), eta)
-    L = 1. / n_samples * _logL(distr, y, y_hat)
+    z = beta[0] + np.dot(X, beta[1:])
+    y_hat = _mu(distr, z, eta)
+    L = 1. / n_samples * _logL(distr, y, y_hat, z)
     P = 0.5 * (1 - alpha) * _L2penalty(beta[1:], Tau)
     J = -L + reg_lambda * P
     return J
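
A note on the design choice in these two hunks: computing z = beta[0] + np.dot(X, beta[1:]) once and threading it through to _logL avoids recomputing the linear predictor and lets the binomial branch use the numerically stable form during optimization, while callers that only have predictions in hand (e.g. scoring against a held-out y_hat) still work through the z=None default.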