1111from tensorflow .experimental import numpy as tnp
1212
1313class Tweedie (distribution .AutoCompositeTensorDistribution ):
14- """Mean-Variance Regression (https://arxiv.org/pdf/1804.01631.pdf)
14+ """Tweedie
1515 """
1616
1717 def __init__ (self ,
1818 loc ,
19+ scale ,
1920 var_power = 1. ,
2021 quasi = False ,
22+ a = 1.01 ,
23+ b = 1.99 ,
2124 validate_args = False ,
2225 allow_nan_stats = True ,
2326 name = 'Tweedie' ):
@@ -26,6 +29,8 @@ def __init__(self,
2629 broadcasting.
2730 Args:
2831 loc: Floating point tensor; the means of the distribution(s).
32+ scale: Floating point tensor; the scale of the distribution for Quasi,
33+ phi for non-Quasi
2934 var_power: The variance power, also referred to as "p". The default is 1.
3035 quasi: Python `bool`, default `False`. When `True` quasi log-liklihood is used.
3136 validate_args: Python `bool`, default `False`. When `True` distribution
@@ -40,11 +45,15 @@ def __init__(self,
4045 """
4146 parameters = dict (locals ())
4247 with tf .name_scope (name ) as name :
43- dtype = dtype_util .common_dtype ([loc ], dtype_hint = tf .float32 )
48+ dtype = dtype_util .common_dtype ([loc , scale ], dtype_hint = tf .float32 )
4449 self ._loc = tensor_util .convert_nonref_to_tensor (
4550 loc , dtype = dtype , name = 'loc' )
51+ self ._scale = tensor_util .convert_nonref_to_tensor (
52+ scale , dtype = dtype , name = 'scale' )
4653 self ._p = var_power
4754 self .quasi = quasi
55+ self .a = a
56+ self .b = b
4857 super (Tweedie , self ).__init__ (
4958 dtype = dtype ,
5059 reparameterization_type = reparameterization .FULLY_REPARAMETERIZED ,
@@ -67,10 +76,15 @@ def _parameter_properties(cls, dtype, num_classes=None):
6776 def loc (self ):
6877 """Parameter for the mean."""
6978 return self ._loc
79+
80+ @property
81+ def scale (self ):
82+ """Parameter for standard deviation."""
83+ return self ._scale
7084
7185 @property
7286 def p (self ):
73- """Parameter for standard deviation ."""
87+ """Parameter for power ."""
7488 return self ._p
7589
7690 def _event_shape_tensor (self ):
@@ -81,7 +95,8 @@ def _event_shape(self):
8195
8296 def _sample_n (self , n , seed = None ):
8397 loc = tf .convert_to_tensor (self .loc )
84- shape = ps .concat ([[n ], self ._batch_shape_tensor (loc = loc , scale = 1 )],
98+ scale = tf .convert_to_tensor (self .scale )
99+ shape = ps .concat ([[n ], self ._batch_shape_tensor (loc = loc , scale = scale )],
85100 axis = 0 )
86101 sampled = samplers .normal (
87102 shape = shape , mean = 0. , stddev = 1. , dtype = self .dtype , seed = seed )
@@ -91,25 +106,36 @@ def _sample_n(self, n, seed=None):
91106 def _log_prob (self , x ):
92107 """Used for the loss of the model -- not an actual log prob"""
93108 if self .quasi : # from https://www.statsmodels.org/stable/_modules/statsmodels/genmod/families/family.html#Tweedie
94- llf = log (2 * tnp .pi ) + self .p * log (x )
109+ llf = log (2 * tnp .pi * self . scale ) + self .p * log (x )
95110 llf /= - 2
96111 u = (x ** (2 - self .p ) - (2 - self .p ) * x * self .loc ** (1 - self .p ) + (1 - self .p ) * self .loc ** (2 - self .p ))
97- u *= 1 / ((1 - self .p ) * (2 - self .p ))
112+ u *= 1 / (self . scale * (1 - self .p ) * (2 - self .p ))
98113 return llf - u
99114
100- else : # from https://github.com/cran/statmod/blob/master/R/tweedie.R negative deviance residuals
101- x1 = x + 0.1 * tf .cast (tf .equal (x , 0 ), tf .float32 )
102- theta = (tf .pow (x1 , 1 - self .p ) - tf .pow (self .loc , 1 - self .p )) / (1 - self .p )
103- kappa = (tf .pow (x , 2 - self .p ) - tf .pow (self .loc , 2 - self .p )) / (2 - self .p )
104-
105- return - 2 * (x * theta - kappa )
115+ else :
116+ # from https://github.com/cran/statmod/blob/master/R/tweedie.R negative deviance residuals
117+ # x1 = x + 0.1 * tf.cast(tf.equal(x, 0), tf.float32)
118+ # theta = (tf.pow(x1, 1 - self.p) - tf.pow(self.loc, 1 - self.p)) / (1 - self.p)
119+ # kappa = (tf.pow(x, 2 - self.p) - tf.pow(self.loc, 2 - self.p)) / (2 - self.p)
120+ # return - 2 * (x * theta - kappa)
121+ # from https://github.com/cran/mgcv/blob/aff4560d187dfd7d98c7bd367f5a0076faf129b7/R/gamlss.r#L2474
122+ ethi = tf .exp (- self .p ) # assuming p > 0
123+ p = (self .b + self .a * ethi )/ (1 + ethi )
124+ x1 = x + tf .cast (x == 0 , tf .float32 )
125+ theta = (tf .pow (x1 , 1 - p ) - tf .pow (self .loc , 1 - p )) / (1 - p )
126+ kappa = (tf .pow (x , 2 - p ) - tf .pow (self .loc , 2 - p )) / (2 - p )
127+ return tf .sign (x - self .loc ) * tf .sqrt (tf .nn .relu (2 * (x * theta - kappa ) * 1 / self .scale ))
128+
106129
107130
108131 def _mean (self ):
109132 return self .loc * tf .ones_like (self .scale )
110133
111134 def _stddev (self ):
112- return self .scale * tf .ones_like (self .loc )
135+ if self .quasi :
136+ return self ._scale
137+ else :
138+ return tf .sqrt (self ._scale * tf .pow (self .loc , self .p ))
113139
114140 def _default_event_space_bijector (self ):
115141 return identity_bijector .Identity (validate_args = self .validate_args )
0 commit comments