|
9 | 9 | from .base import BaseEstimator, is_classifier, check_version |
10 | 10 |
|
11 | 11 |
|
| 12 | +ALLOWED_DISTRS = ['gaussian', 'binomial', 'softplus', 'poisson', |
| 13 | + 'probit', 'gamma'] |
| 14 | + |
| 15 | + |
12 | 16 | def _lmb(distr, beta0, beta, X, eta): |
13 | 17 | """Conditional intensity function.""" |
14 | 18 | z = beta0 + np.dot(X, beta) |
@@ -333,7 +337,7 @@ class GLM(BaseEstimator): |
333 | 337 | ---------- |
334 | 338 | distr : str |
335 | 339 | distribution family can be one of the following |
336 | | - 'gaussian' | 'binomial' | 'poisson' | 'softplus' |
| 340 | + 'gaussian' | 'binomial' | 'poisson' | 'softplus' | 'probit' | 'gamma' |
337 | 341 | default: 'poisson'. |
338 | 342 | alpha : float |
339 | 343 | the weighting between L1 penalty and L2 penalty term |
@@ -401,6 +405,10 @@ def __init__(self, distr='poisson', alpha=0.5, |
401 | 405 | if not isinstance(max_iter, int): |
402 | 406 | raise ValueError('max_iter must be of type int') |
403 | 407 |
|
| 408 | + if distr not in ALLOWED_DISTRS: |
| 409 | + raise ValueError('distr must be one of %s, Got ' |
| 410 | + '%s' % (', '.join(ALLOWED_DISTRS), distr)) |
| 411 | + |
404 | 412 | self.distr = distr |
405 | 413 | self.alpha = alpha |
406 | 414 | self.reg_lambda = reg_lambda |
@@ -845,7 +853,7 @@ class GLMCV(object): |
845 | 853 | ---------- |
846 | 854 | distr : str |
847 | 855 | distribution family can be one of the following |
848 | | - 'gaussian' | 'binomial' | 'poisson' | 'softplus' |
| 856 | + 'gaussian' | 'binomial' | 'poisson' | 'softplus' | 'probit' | 'gamma' |
849 | 857 | default: 'poisson'. |
850 | 858 | alpha : float |
851 | 859 | the weighting between L1 penalty and L2 penalty term |
@@ -926,6 +934,10 @@ def __init__(self, distr='poisson', alpha=0.5, |
926 | 934 | if not isinstance(reg_lambda, (list, np.ndarray)): |
927 | 935 | reg_lambda = [reg_lambda] |
928 | 936 |
|
| 937 | + if distr not in ALLOWED_DISTRS: |
| 938 | + raise ValueError('distr must be one of %s, Got ' |
| 939 | + '%s' % (', '.join(ALLOWED_DISTRS), distr)) |
| 940 | + |
929 | 941 | if not isinstance(max_iter, int): |
930 | 942 | raise ValueError('max_iter must be of type int') |
931 | 943 |
|
|
0 commit comments