@@ -28,6 +28,26 @@ inline double squared_exponential_covariance(double distance,
2828 return sigma * sigma * exp (-pow (distance / length_scale, 2 ));
2929}
3030
31+ inline double squared_exponential_covariance_derivative (double distance,
32+ double length_scale,
33+ double sigma = 1 .) {
34+ if (length_scale <= 0 .) {
35+ return 0 .;
36+ }
37+ return -2 * distance / (length_scale * length_scale) *
38+ squared_exponential_covariance (distance, length_scale, sigma);
39+ }
40+
41+ inline double squared_exponential_covariance_second_derivative (
42+ double distance, double length_scale, double sigma = 1 .) {
43+ if (length_scale <= 0 .) {
44+ return 0 .;
45+ }
46+ const auto ratio = distance / length_scale;
47+ return (4 . * ratio * ratio - 2 .) / (length_scale * length_scale) *
48+ squared_exponential_covariance (distance, length_scale, sigma);
49+ }
50+
3151/*
3252 * SquaredExponential distance
3353 * covariance(d) = sigma^2 exp(-(d/length_scale)^2)
@@ -83,6 +103,72 @@ class SquaredExponential
83103 sigma_squared_exponential.value );
84104 }
85105
106+ // This operator is only defined when the distance metric is also defined.
107+ template <typename X,
108+ typename std::enable_if<
109+ has_call_operator<DistanceMetricType, X &, X &>::value,
110+ int >::type = 0 >
111+ double _call_impl (const Derivative<X> &x, const X &y) const {
112+ double distance = this ->distance_metric_ (x.value , y);
113+ double distance_derivative = this ->distance_metric_ .derivative (x.value , y);
114+ return distance_derivative * squared_exponential_covariance_derivative (
115+ distance,
116+ squared_exponential_length_scale.value ,
117+ sigma_squared_exponential.value );
118+ }
119+
120+ template <typename X,
121+ typename std::enable_if<
122+ has_call_operator<DistanceMetricType, X &, X &>::value,
123+ int >::type = 0 >
124+ double _call_impl (const Derivative<X> &x, const Derivative<X> &y) const {
125+ const double distance = this ->distance_metric_ (x.value , y.value );
126+ const double d_x = this ->distance_metric_ .derivative (x.value , y.value );
127+ const double d_y = this ->distance_metric_ .derivative (y.value , x.value );
128+ const double d_xy =
129+ this ->distance_metric_ .second_derivative (x.value , y.value );
130+
131+ const double f_d = squared_exponential_covariance_derivative (
132+ distance, squared_exponential_length_scale.value ,
133+ sigma_squared_exponential.value );
134+
135+ const double f_dd = squared_exponential_covariance_second_derivative (
136+ distance, squared_exponential_length_scale.value ,
137+ sigma_squared_exponential.value );
138+
139+ std::cout << x.value << " " << y.value << " " << d_xy << " , " << f_d
140+ << " , " << d_x << " , " << d_y << " , " << f_dd << std::endl;
141+ return d_xy * f_d + d_x * d_y * f_dd;
142+ }
143+
144+ // This operator is only defined when the distance metric is also defined.
145+ template <typename X,
146+ typename std::enable_if<
147+ has_call_operator<DistanceMetricType, X &, X &>::value,
148+ int >::type = 0 >
149+ double _call_impl (const SecondDerivative<X> &x, const X &y) const {
150+ double d = this ->distance_metric_ (x.value , y);
151+ double d_1 = this ->distance_metric_ .derivative (x.value , y);
152+ double d_2 = this ->distance_metric_ .second_derivative (x.value , y);
153+ double f_1 = squared_exponential_covariance_derivative (
154+ d, squared_exponential_length_scale.value ,
155+ sigma_squared_exponential.value );
156+ double f_2 = squared_exponential_covariance_second_derivative (
157+ d, squared_exponential_length_scale.value ,
158+ sigma_squared_exponential.value );
159+ return d_2 * f_1 + d_1 * d_1 * f_2;
160+ }
161+
162+ // This operator is only defined when the distance metric is also defined.
163+ template <typename X,
164+ typename std::enable_if<
165+ has_call_operator<DistanceMetricType, X &, X &>::value,
166+ int >::type = 0 >
167+ double _call_impl (const SecondDerivative<X> &x,
168+ const SecondDerivative<X> &y) const {
169+ return NAN;
170+ }
171+
86172 DistanceMetricType distance_metric_;
87173};
88174
0 commit comments