Skip to content

Commit 09e1cba

Browse files
author
Pavan Ramkumar
authored
Merge pull request #247 from jasmainak/check_distr
ENH add check for distr
2 parents 29f2c71 + c27ce74 commit 09e1cba

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

pyglmnet/pyglmnet.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
from .base import BaseEstimator, is_classifier, check_version
1010

1111

12+
ALLOWED_DISTRS = ['gaussian', 'binomial', 'softplus', 'poisson',
13+
'probit', 'gamma']
14+
15+
1216
def _lmb(distr, beta0, beta, X, eta):
1317
"""Conditional intensity function."""
1418
z = beta0 + np.dot(X, beta)
@@ -333,7 +337,7 @@ class GLM(BaseEstimator):
333337
----------
334338
distr : str
335339
distribution family can be one of the following
336-
'gaussian' | 'binomial' | 'poisson' | 'softplus'
340+
'gaussian' | 'binomial' | 'poisson' | 'softplus' | 'probit' | 'gamma'
337341
default: 'poisson'.
338342
alpha : float
339343
the weighting between L1 penalty and L2 penalty term
@@ -401,6 +405,10 @@ def __init__(self, distr='poisson', alpha=0.5,
401405
if not isinstance(max_iter, int):
402406
raise ValueError('max_iter must be of type int')
403407

408+
if distr not in ALLOWED_DISTRS:
409+
raise ValueError('distr must be one of %s, Got '
410+
'%s' % (', '.join(ALLOWED_DISTRS), distr))
411+
404412
self.distr = distr
405413
self.alpha = alpha
406414
self.reg_lambda = reg_lambda
@@ -845,7 +853,7 @@ class GLMCV(object):
845853
----------
846854
distr : str
847855
distribution family can be one of the following
848-
'gaussian' | 'binomial' | 'poisson' | 'softplus'
856+
'gaussian' | 'binomial' | 'poisson' | 'softplus' | 'probit' | 'gamma'
849857
default: 'poisson'.
850858
alpha : float
851859
the weighting between L1 penalty and L2 penalty term
@@ -926,6 +934,10 @@ def __init__(self, distr='poisson', alpha=0.5,
926934
if not isinstance(reg_lambda, (list, np.ndarray)):
927935
reg_lambda = [reg_lambda]
928936

937+
if distr not in ALLOWED_DISTRS:
938+
raise ValueError('distr must be one of %s, Got '
939+
'%s' % (', '.join(ALLOWED_DISTRS), distr))
940+
929941
if not isinstance(max_iter, int):
930942
raise ValueError('max_iter must be of type int')
931943

tests/test_pyglmnet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ def test_group_lasso():
136136

137137
def test_glmnet():
138138
"""Test glmnet."""
139+
assert_raises(ValueError, GLM, distr='blah')
140+
assert_raises(ValueError, GLM, distr='gaussian', max_iter=1.8)
141+
139142
n_samples, n_features = 100, 10
140143

141144
# coefficients
@@ -192,6 +195,9 @@ def test_glmnet():
192195

193196
def test_glmcv():
194197
"""Test GLMCV class."""
198+
assert_raises(ValueError, GLM, distr='blah')
199+
assert_raises(ValueError, GLM, distr='gaussian', max_iter=1.8)
200+
195201
scaler = StandardScaler()
196202
n_samples, n_features = 100, 10
197203

0 commit comments

Comments
 (0)