Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion posydon/interpolation/IF_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class relies on the BaseIFInterpolator class to perform the interpolation
from posydon.interpolation.constraints import (
find_constraints_to_apply, sanitize_interpolated_quantities)

import time

# INITIAL-FINAL INTERPOLATOR
class IFInterpolator:
Expand Down Expand Up @@ -269,13 +270,16 @@ def evaluate(self, binary, sanitization_verbose=False):
"""
ynums = {}
ycats = {}

# s = time.time()
for interpolator in self.interpolators:
ynum, ycat = interpolator.evaluate(binary, sanitization_verbose)

ynums = {**ynums, **ynum}
ycats = {**ycats, **ycat}

# e = time.time()
# print(f"Iterated over {len(self.interpolators)} interpolators in {e - s}")

return ynums, ycats


Expand Down Expand Up @@ -666,7 +670,11 @@ def test_interpolator(self, Xt):

if isinstance(self.interp_method, list):
Xtn = self.X_scaler.normalize(Xt, classes)
# s = time.time()
Ypredn = self.interpolator.predict(Xtn, classes, self.X_scaler)
# e = time.time()

# print(f"Predicted one interpolator value in {e - s} seconds")
else:
Xtn = self.X_scaler.normalize(Xt)
Ypredn = self.interpolator.predict(Xtn)
Expand Down
68 changes: 68 additions & 0 deletions posydon/interpolation/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@
from posydon.utils.common_functions import (stefan_boltzmann_law,
orbital_separation_from_period)

CLASSIFICATION_KEYS = [
"S<*>_state",
"mt_hist",
"S<*>_MOD<n>_SN_type",
"S<*>_MOD<n>_CO_type"
]

N_MODELS = 11 # how many super nova models are there?

# toggle this flag to enable/disable constraints (used for debugging)
INTERPOLATION_CONSTRAINTS_ON = True
Expand Down Expand Up @@ -511,3 +519,63 @@ def sanitize_interpolated_quantities(fvalues, constraints, verbose=False):
constraint["constraint"])

return sanitized


def mt_constraint(classes):

interpolation_class = classes["interpolation_class"]

if interpolation_class == "initial_MT":
classes["mt_hist"] == "ini_RLO"
elif interpolation_class == "no_MT":
classes["mt_hist"] = "no_RLO"
elif interpolation_class == "stable_MT":
pass
elif interpolation_class == "unstable_MT":
pass
elif interpolation_class == "stable_reverse_MT":
pass


CLASS_CONSTRAINTS = {
"S<*>_state": None,
"mt_hist": mt_constraint,
"S<*>_MOD<n>_SN_type": None,
"S<*>_MOD<n>_CO_type": None
}

def apply_class_constraint(key_name, classes):

if key_name not in classes.keys():
return
else:
CLASS_CONSTRAINTS[key_name](classes)

def sanitize_classes(classes, ):

assert(type(classes) == dict)

if "interpolation_class" not in classes.keys():
raise ValueError(
"Interpolation class must be present as a classified quantity to enforce classification constraints!"
)

for key in CLASSIFICATION_KEYS:
if "<*>" in key:

for star in range(2):
key_name = key.replace("<*>", f"{star}")

if "MOD<n>" in key_name:

for model in range(N_MODELS):
key_name = key_name.replace("<n>", f"{model}")

apply_class_constraint(key_name, classes)

else:
apply_class_constraint(key_name, classes)
else:

apply_class_constraint(key, classes)

89 changes: 61 additions & 28 deletions posydon/interpolation/data_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,25 @@


import numpy as np

import warnings
import sys

# Convert UserWarning to an error
warnings.simplefilter("error", RuntimeWarning)

eps = 1.0e-16

SCALING_OPTIONS = [
"none",
"min_max",
"max_abs",
# "standardize",
"log_min_max", # has
# "neg_log_min_max", # has
"log_max_abs", # has
# "log_standardize", # has
# "neg_log_standardize" # has
]

class DataScaler:
"""Data Normalization class.
Expand Down Expand Up @@ -68,27 +86,28 @@ def fit(self, x, method='none', lower=-1.0, upper=1.0):
if method == 'min_max':
assert upper > lower, "upper must be greater than lower"
self.lower, self.upper = lower, upper
self.params = [x.min(axis=0), x.max(axis=0)]
self.params = [np.nanmin(x, axis=0), np.nanmax(x, axis=0)]
elif method == 'log_min_max':
assert upper > lower, "upper must be greater than lower"
self.lower, self.upper = lower, upper
self.params = [np.log10(x.min(axis=0)), np.log10(x.max(axis=0))]
self.params = [self.log(np.nanmin(x, axis=0)), self.log(np.nanmax(x, axis=0))]

elif method == 'neg_log_min_max':
assert upper > lower, "upper must be greater than lower"
self.lower, self.upper = lower, upper
self.params = [np.log10((-x).min(axis=0)),
np.log10((-x).max(axis=0))]
self.params = [self.log(np.nanmin(-x, axis=0)),
self.log(np.nanmax(-x, axis=0))]
elif method == 'max_abs':
self.params = [np.abs(x).max(axis=0)]
self.params = [np.nanmax(np.abs(x), axis=0)]
elif method == 'log_max_abs':
self.params = [np.abs(np.log10(x)).max(axis=0)]
elif method == 'standarize':
self.params = [x.mean(axis=0), x.std(axis=0)]
elif method == 'log_standarize':
self.params = [np.nanmax(np.abs(self.log(x)), axis=0)]
elif method == 'standardize':
self.params = [np.nanmean(x, axis=0), np.nanstd(x, axis=0)]
elif method == 'log_standardize':
# log will be computed in transform again
self.params = [np.log10(x).mean(axis=0), np.log10(x).std(axis=0)]
elif method == 'neg_log_standarize': # log(-x)
self.params = [np.log10(-x).mean(axis=0), np.log10(-x).std(axis=0)]
self.params = [np.nanmean(self.log(x), axis=0), np.nanstd(self.log(x), axis=0)]
elif method == 'neg_log_standardize': # log(-x)
self.params = [np.nanmean(self.log(-x), axis=0), np.nanstd(self.log(-x), axis=0)]
elif method == 'log':
self.params = []
elif method == 'none': # no transformation
Expand Down Expand Up @@ -124,26 +143,26 @@ def transform(self, x):
x_t = ((x - self.params[0]) / (self.params[1] - self.params[0])
* (self.upper - self.lower) + self.lower)
elif self.method == 'log_min_max':
x_t = ((np.log10(x) - self.params[0])
x_t = ((self.log(x) - self.params[0])
/ (self.params[1] - self.params[0])
* (self.upper - self.lower) + self.lower)
elif self.method == 'neg_log_min_max':
x_t = ((np.log10(-x) - self.params[0])
x_t = ((self.log(-x) - self.params[0])
/ (self.params[1] - self.params[0])
* (self.upper - self.lower) + self.lower)
elif self.method == 'max_abs':
x_t = x / self.params[0]
elif self.method == 'log_max_abs':
x_t = np.log10(x) / self.params[0]
elif self.method == 'standarize':
x_t = self.log(x) / self.params[0]
elif self.method == 'standardize':
x_t = (x - self.params[0]) / self.params[1]
elif self.method == 'log_standarize':
elif self.method == 'log_standardize':
# log will be computed in transform again
x_t = (np.log10(x) - self.params[0]) / self.params[1]
elif self.method == 'neg_log_standarize':
x_t = (np.log10(-x) - self.params[0]) / self.params[1]
x_t = (self.log(x) - self.params[0]) / self.params[1]
elif self.method == 'neg_log_standardize':
x_t = (self.log(-x) - self.params[0]) / self.params[1]
elif self.method == 'log':
x_t = np.log10(x)
x_t = self.log(x)
else: # no transformation
x_t = x

Expand Down Expand Up @@ -201,24 +220,38 @@ def inv_transform(self, x_t):
/ (self.upper - self.lower)
* (self.params[1] - self.params[0]) + self.params[0])
elif self.method == 'log_min_max':
x = 10 ** ((x_t - self.lower) / (self.upper - self.lower)
x = self.unlog((x_t - self.lower) / (self.upper - self.lower)
* (self.params[1] - self.params[0]) + self.params[0])
elif self.method == 'neg_log_min_max':
x = -10 ** ((x_t - self.lower) / (self.upper - self.lower)
x = -self.unlog((x_t - self.lower) / (self.upper - self.lower)
* (self.params[1] - self.params[0]) + self.params[0])
elif self.method == 'max_abs':
x = x_t * self.params[0]
elif self.method == 'log_max_abs':
x = 10 ** (x_t * self.params[0])
x = self.unlog(x_t * self.params[0])
elif self.method == 'standarize':
x = x_t * self.params[1] + self.params[0]
elif self.method == 'log_standarize':
x = 10 ** (x_t * self.params[1] + self.params[0])
x = self.unlog(x_t * self.params[1] + self.params[0])
elif self.method == 'neg_log_standarize':
x = -10 ** (x_t * self.params[1] + self.params[0])
x = -self.unlog(x_t * self.params[1] + self.params[0])
elif self.method == 'log':
x = 10 ** x_t
x = self.unlog(x_t)
else: # no transformation
x = x_t

return x

def log(self, x):
logged = None
try:
logged = np.log10(x + eps)
except RuntimeWarning:
print(self.method)
print(x, np.isinf(x).any(), np.isnan(x).any(), (x < 0).any(), np.nanmin(x))
# sys.exit()

return logged

def unlog(self, x):
return (10 ** x) - eps
Loading