@@ -111,6 +111,29 @@ def testTwoClassLogLikelihood(self):
111111 actual = session .run (avg_log_likelihood )
112112 self .assertAlmostEqual (actual , expected )
113113
114+ def testTwoClassLogLikelihoodVersusOldImplementation (self ):
115+ def alt_two_class_log_likelihood_impl (predictions , labels ):
116+ float_labels = tf .cast (labels , dtype = tf .float64 )
117+ float_predictions = tf .cast (tf .squeeze (predictions ), dtype = tf .float64 )
118+ # likelihood should be just p for class 1, and 1 - p for class 0.
119+ # signs is 1 for class 1, and -1 for class 0
120+ signs = 2 * float_labels - tf .ones_like (float_labels )
121+ # constant_term is 1 for class 0, and 0 for class 1.
122+ constant_term = tf .ones_like (float_labels ) - float_labels
123+ likelihoods = constant_term + signs * float_predictions
124+ log_likelihoods = tf .log (likelihoods )
125+ avg_log_likelihood = tf .reduce_mean (log_likelihoods )
126+ return avg_log_likelihood
127+ predictions = np .random .rand (1 , 10 , 1 )
128+ targets = np .random .randint (2 , size = 10 )
129+ with self .test_session () as session :
130+ new_log_likelihood , _ = metrics .two_class_log_likelihood (
131+ predictions , targets )
132+ alt_log_likelihood = alt_two_class_log_likelihood_impl (
133+ predictions , targets )
134+ new_impl , alt_impl = session .run ([new_log_likelihood , alt_log_likelihood ])
135+ self .assertAlmostEqual (new_impl , alt_impl )
136+
114137 def testRMSEMetric (self ):
115138 predictions = np .full ((10 , 1 ), 1 ) # All 1's
116139 targets = np .full ((10 , 1 ), 3 ) # All 3's
0 commit comments