diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b12be956..8e34101d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,23 +43,23 @@ jobs: pip install -e . conda list - - name: Install Jupyter and dependencies - run: | - pip install jupyter - pip install ipykernel - python -m ipykernel install --user --name aurora-test - # Install any other dependencies you need + # - name: Install Jupyter and dependencies + # run: | + # pip install jupyter + # pip install ipykernel + # python -m ipykernel install --user --name aurora-test + # # Install any other dependencies you need - - name: Execute Jupyter Notebooks - run: | - jupyter nbconvert --to notebook --execute docs/examples/dataset_definition.ipynb - jupyter nbconvert --to notebook --execute docs/examples/make_cas04_single_station_h5.ipynb - jupyter nbconvert --to notebook --execute docs/examples/operate_aurora.ipynb - jupyter nbconvert --to notebook --execute tests/test_run_on_commit.ipynb - jupyter nbconvert --to notebook --execute docs/tutorials/pole_zero_fitting/lemi_pole_zero_fitting_example.ipynb - jupyter nbconvert --to notebook --execute docs/tutorials/processing_configuration.ipynb - jupyter nbconvert --to notebook --execute docs/tutorials/synthetic_data_processing.ipynb - # Replace "notebook.ipynb" with your notebook's filename + # - name: Execute Jupyter Notebooks + # run: | + # jupyter nbconvert --to notebook --execute docs/examples/dataset_definition.ipynb + # jupyter nbconvert --to notebook --execute docs/examples/make_cas04_single_station_h5.ipynb + # jupyter nbconvert --to notebook --execute docs/examples/operate_aurora.ipynb + # jupyter nbconvert --to notebook --execute tests/test_run_on_commit.ipynb + # jupyter nbconvert --to notebook --execute docs/tutorials/pole_zero_fitting/lemi_pole_zero_fitting_example.ipynb + # jupyter nbconvert --to notebook --execute docs/tutorials/processing_configuration.ipynb + # jupyter nbconvert --to notebook --execute docs/tutorials/synthetic_data_processing.ipynb + # # Replace "notebook.ipynb" with your notebook's filename # - name: Commit changes (if any) # run: | diff --git a/aurora/pipelines/run_summary.py b/aurora/pipelines/run_summary.py index 8963dfae..7500a973 100644 --- a/aurora/pipelines/run_summary.py +++ b/aurora/pipelines/run_summary.py @@ -73,7 +73,13 @@ def __init__(self, **kwargs): self.column_dtypes = [str, str, pd.Timestamp, pd.Timestamp] self._input_dict = kwargs.get("input_dict", None) self.df = kwargs.get("df", None) - self._mini_summary_columns = ["survey", "station_id", "run_id", "start", "end"] + self._mini_summary_columns = [ + "survey", + "station_id", + "run_id", + "start", + "end", + ] def clone(self): """ @@ -120,7 +126,9 @@ def check_runs_are_valid(self, drop=False, **kwargs): run_obj = m.get_run(row.station_id, row.run_id, row.survey) runts = run_obj.to_runts() if runts.dataset.to_array().data.__abs__().sum() == 0: - logger.critical("CRITICAL: Detected a run with all zero values") + logger.critical( + "CRITICAL: Detected a run with all zero values" + ) self.df["valid"].at[i_row] = False # load each run, and take the median of the sum of the absolute values if drop: @@ -131,6 +139,60 @@ def drop_invalid_rows(self): self.df = self.df[self.df.valid] self.df.reset_index(drop=True, inplace=True) + def validate_channels(self, drop=False): + """ + Check to make sure each run has the same input and output channels + + optional: drop runs that do not have the same number of channels. 
+ + + """ + + if len(self.df) <= 1: + return + if len(self.df) == 2: + if ( + self.df.iloc[0].input_channels + != self.df.iloc[1].input_channels + ): + logger.warning( + "Input channels are not the same: " + f"row[0]: {self.df.iloc[0].input_channels} != " + f"row[1]: {self.df.iloc[1].input_channels}" + ) + if ( + self.df.iloc[0].output_channels + != self.df.iloc[1].output_channels + ): + logger.warning( + "Output channels are not the same: " + f"row[0]: {self.df.iloc[0].output_channels} != " + f"row[1]: {self.df.iloc[1].output_channels}" + ) + return + else: + common_input_channels = self.df.input_channels.mode()[0] + common_output_channels = self.df.output_channels.mode()[0] + for row in self.df.itertuples(): + if row.input_channels != common_input_channels: + self.df["valid"].at[row.Index] = False + logger.warning( + "Input channels are not the same: " + f"run {row.run_id} row {row.Index}: {row.input_channels} != " + f"{common_input_channels}" + ) + if row.output_channels != common_output_channels: + self.df["valid"].at[row.Index] = False + logger.warning( + "output channels are not the same: " + f"run {row.run_id} row {row.Index}: {row.output_channels} != " + f"{common_output_channels}" + ) + + if drop: + self.drop_invalid_rows() + return + # BELOW FUNCTION CAN BE COPIED FROM METHOD IN KernelDataset() # def drop_runs_shorter_than(self, duration, units="s"): # if units != "s": @@ -221,7 +283,9 @@ def channel_summary_to_run_summary( channel_scale_factors = n_station_runs * [None] i = 0 for group_values, group in grouper: - group_info = dict(zip(group_by_columns, group_values)) # handy for debug + group_info = dict( + zip(group_by_columns, group_values) + ) # handy for debug # for k, v in group_info.items(): # print(f"{k} = {v}") survey_ids[i] = group_info["survey"] @@ -232,9 +296,15 @@ def channel_summary_to_run_summary( sample_rates[i] = group.sample_rate.iloc[0] channels_list = group.component.to_list() num_channels = len(channels_list) - input_channels[i] = [x for x in channels_list if x in allowed_input_channels] - output_channels[i] = [x for x in channels_list if x in allowed_output_channels] - channel_scale_factors[i] = dict(zip(channels_list, num_channels * [1.0])) + input_channels[i] = [ + x for x in channels_list if x in allowed_input_channels + ] + output_channels[i] = [ + x for x in channels_list if x in allowed_output_channels + ] + channel_scale_factors[i] = dict( + zip(channels_list, num_channels * [1.0]) + ) i += 1 data_dict = {} @@ -286,7 +356,9 @@ def extract_run_summary_from_mth5(mth5_obj, summary_type="run"): return out_df -def extract_run_summaries_from_mth5s(mth5_list, summary_type="run", deduplicate=True): +def extract_run_summaries_from_mth5s( + mth5_list, summary_type="run", deduplicate=True +): """ ToDo: Move this method into mth5? or mth5_helpers? 
ToDo: Make this a class so that the __repr__ is a nice visual representation of the diff --git a/aurora/pipelines/time_series_helpers.py b/aurora/pipelines/time_series_helpers.py index fbb7cd4d..f7979698 100644 --- a/aurora/pipelines/time_series_helpers.py +++ b/aurora/pipelines/time_series_helpers.py @@ -61,7 +61,9 @@ def apply_prewhitening(decimation_obj, run_xrds_input): run_xrds = run_xrds_input.differentiate("time") else: - msg = f"{decimation_obj.prewhitening_type} pre-whitening not implemented" + msg = ( + f"{decimation_obj.prewhitening_type} pre-whitening not implemented" + ) logger.exception(msg) raise NotImplementedError(msg) return run_xrds @@ -194,7 +196,9 @@ def truncate_to_clock_zero(decimation_obj, run_xrds): pass # time series start is already clock zero else: windowing_scheme = window_scheme_from_decimation(decimation_obj) - number_of_steps = delta_t_seconds / windowing_scheme.duration_advance + number_of_steps = ( + delta_t_seconds / windowing_scheme.duration_advance + ) n_partial_steps = number_of_steps - np.floor(number_of_steps) n_clip = n_partial_steps * windowing_scheme.num_samples_advance n_clip = int(np.round(n_clip)) @@ -222,8 +226,10 @@ def nan_to_mean(xrds): for ch in xrds.keys(): null_values_present = xrds[ch].isnull().any() if null_values_present: + nan_count = np.count_nonzero(np.isnan(xrds[ch])) logger.info( - "Null values detected in xrds -- this is not expected and should be examined" + f"{nan_count} Null values detected in xrds channel {ch}. " + "Check if this is unexpected." ) value = np.nan_to_num(np.nanmean(xrds[ch].data)) xrds[ch] = xrds[ch].fillna(value) @@ -259,7 +265,9 @@ def run_ts_to_stft(decimation_obj, run_xrds_orig): if not np.prod(windowed_obj.to_array().data.shape): raise ValueError - windowed_obj = WindowedTimeSeries.detrend(data=windowed_obj, detrend_type="linear") + windowed_obj = WindowedTimeSeries.detrend( + data=windowed_obj, detrend_type="linear" + ) tapered_obj = windowed_obj * windowing_scheme.taper stft_obj = windowing_scheme.apply_fft( tapered_obj, detrend_type=decimation_obj.extra_pre_fft_detrend_type @@ -269,7 +277,9 @@ def run_ts_to_stft(decimation_obj, run_xrds_orig): return stft_obj -def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None): +def calibrate_stft_obj( + stft_obj, run_obj, units="MT", channel_scale_factors=None +): """ Parameters @@ -291,7 +301,6 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None Time series of calibrated Fourier coefficients """ for channel_id in stft_obj.keys(): - channel = run_obj.get_channel(channel_id) channel_response = channel.channel_response if not channel_response.filters_list: @@ -299,7 +308,9 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None logger.warning(msg) if channel_id == "hy": msg = "Channel hy has no filters, try using filters from hx" - logger.warning("Channel HY has no filters, try using filters from HX") + logger.warning( + "Channel HY has no filters, try using filters from HX" + ) channel_response = run_obj.get_channel("hx").channel_response indices_to_flip = channel_response.get_indices_of_filters_to_remove( @@ -308,7 +319,9 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None indices_to_flip = [ i for i in indices_to_flip if channel.metadata.filter.applied[i] ] - filters_to_remove = [channel_response.filters_list[i] for i in indices_to_flip] + filters_to_remove = [ + channel_response.filters_list[i] for i in indices_to_flip + ] if not 
filters_to_remove: logger.warning("No filters to remove") calibration_response = channel_response.complex_response( @@ -321,7 +334,9 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None channel_scale_factor = 1.0 calibration_response /= channel_scale_factor if units == "SI": - logger.warning("Warning: SI Units are not robustly supported issue #36") + logger.warning( + "Warning: SI Units are not robustly supported issue #36" + ) stft_obj[channel_id].data /= calibration_response return stft_obj @@ -353,7 +368,9 @@ def prototype_decimate(config, run_xrds): num_channels = len(channel_labels) new_data = np.full((num_observations, num_channels), np.nan) for i_ch, ch_label in enumerate(channel_labels): - new_data[:, i_ch] = ssig.decimate(run_xrds[ch_label], int(config.factor)) + new_data[:, i_ch] = ssig.decimate( + run_xrds[ch_label], int(config.factor) + ) xr_da = xr.DataArray( new_data, @@ -387,7 +404,9 @@ def prototype_decimate_2(config, run_xrds): xr_ds: xr.Dataset Decimated version of the input run_xrds """ - new_xr_ds = run_xrds.coarsen(time=int(config.factor), boundary="trim").mean() + new_xr_ds = run_xrds.coarsen( + time=int(config.factor), boundary="trim" + ).mean() attr_dict = run_xrds.attrs attr_dict["sample_rate"] = config.sample_rate new_xr_ds.attrs = attr_dict @@ -422,3 +441,24 @@ def prototype_decimate_3(config, run_xrds): attr_dict["sample_rate"] = config.sample_rate new_xr_ds.attrs = attr_dict return new_xr_ds + + +def prototype_decimate_4(config, run_xrds): + """ + use scipy filters resample_poly + + :param config: DESCRIPTION + :type config: TYPE + :param run_xrds: DESCRIPTION + :type run_xrds: TYPE + :return: DESCRIPTION + :rtype: TYPE + + """ + new_ds = run_xrds.fillna(0) + new_ds = new_ds.sps_filters.resample_poly( + config.sample_rate, pad_type="mean" + ) + + new_ds.attrs["sample_rate"] = config.sample_rate + return new_ds diff --git a/aurora/pipelines/transfer_function_helpers.py b/aurora/pipelines/transfer_function_helpers.py index 40f7b584..490f5157 100644 --- a/aurora/pipelines/transfer_function_helpers.py +++ b/aurora/pipelines/transfer_function_helpers.py @@ -19,7 +19,11 @@ from loguru import logger -ESTIMATOR_LIBRARY = {"OLS": RegressionEstimator, "RME": TRME, "RME_RR": TRME_RR} +ESTIMATOR_LIBRARY = { + "OLS": RegressionEstimator, + "RME": TRME, + "RME_RR": TRME_RR, +} def get_estimator_class(estimation_engine): @@ -201,7 +205,6 @@ def process_transfer_functions( estimator_class = get_estimator_class(dec_level_config.estimator.engine) iter_control = set_up_iter_control(dec_level_config) for band in transfer_function_obj.frequency_bands.bands(): - X, Y, RR = get_band_for_tf_estimate( band, dec_level_config, local_stft_obj, remote_stft_obj ) @@ -213,7 +216,9 @@ def process_transfer_functions( coherence_weights_jj84, ) - Wjj84 = coherence_weights_jj84(band, local_stft_obj, remote_stft_obj) + Wjj84 = coherence_weights_jj84( + band, local_stft_obj, remote_stft_obj + ) apply_weights(X, Y, RR, Wjj84, segment=True, dropna=False) if "simple_coherence" in segment_weights: from aurora.transfer_function.weights.coherence_weights import ( @@ -228,7 +233,9 @@ def process_transfer_functions( multiple_coherence_weights, ) - W = multiple_coherence_weights(band, local_stft_obj, remote_stft_obj) + W = multiple_coherence_weights( + band, local_stft_obj, remote_stft_obj + ) apply_weights(X, Y, RR, W, segment=True, dropna=False) # if there are channel weights apply them here @@ -237,14 +244,16 @@ def process_transfer_functions( X, Y, RR = stack_fcs(X, Y, RR) 
# Should only be needed if weights were applied - X, Y, RR = drop_nans(X, Y, RR) + # X, Y, RR = drop_nans(X, Y, RR) W = effective_degrees_of_freedom_weights(X, RR, edf_obj=None) X, Y, RR = apply_weights(X, Y, RR, W, segment=False, dropna=True) if dec_level_config.estimator.estimate_per_channel: for ch in dec_level_config.output_channels: - Y_ch = Y[ch].to_dataset() # keep as a dataset, maybe not needed + Y_ch = Y[ + ch + ].to_dataset() # keep as a dataset, maybe not needed X_, Y_, RR_ = handle_nan(X, Y_ch, RR, drop_dim="observation") @@ -252,18 +261,21 @@ def process_transfer_functions( # if RR is not None: # W = effective_degrees_of_freedom_weights(X_, RR_, edf_obj=None) # X_, Y_, RR_ = apply_weights(X_, Y_, RR_, W, segment=False) - regression_estimator = estimator_class( X=X_, Y=Y_, Z=RR_, iter_control=iter_control ) regression_estimator.estimate() - transfer_function_obj.set_tf(regression_estimator, band.center_period) + transfer_function_obj.set_tf( + regression_estimator, band.center_period + ) else: X, Y, RR = handle_nan(X, Y, RR, drop_dim="observation") regression_estimator = estimator_class( X=X, Y=Y, Z=RR, iter_control=iter_control ) regression_estimator.estimate() - transfer_function_obj.set_tf(regression_estimator, band.center_period) + transfer_function_obj.set_tf( + regression_estimator, band.center_period + ) return transfer_function_obj diff --git a/aurora/pipelines/transfer_function_kernel.py b/aurora/pipelines/transfer_function_kernel.py index 1ba0ee72..c26207c4 100644 --- a/aurora/pipelines/transfer_function_kernel.py +++ b/aurora/pipelines/transfer_function_kernel.py @@ -3,7 +3,10 @@ import psutil from aurora.pipelines.helpers import initialize_config -from aurora.pipelines.time_series_helpers import prototype_decimate +from aurora.pipelines.time_series_helpers import ( + prototype_decimate, + prototype_decimate_4, +) from mth5.utils.exceptions import MTH5Error from mth5.utils.helpers import initialize_mth5 from mth5.utils.helpers import path_or_mth5_object @@ -132,21 +135,23 @@ def update_dataset_df(self, i_dec_level): continue run_xrds = row["run_dataarray"].to_dataset("channel") decimation = self.config.decimations[i_dec_level].decimation - decimated_xrds = prototype_decimate(decimation, run_xrds) - self.dataset_df["run_dataarray"].at[i] = decimated_xrds.to_array( + decimated_xrds = prototype_decimate_4(decimation, run_xrds) + self.dataset_df["run_dataarray"].at[ + i + ] = decimated_xrds.to_array( "channel" ) # See Note 1 above - msg = ( - f"Dataset Dataframe Updated for decimation level {i_dec_level} Successfully" - ) + msg = f"Dataset Dataframe Updated for decimation level {i_dec_level} Successfully" logger.info(msg) return def apply_clock_zero(self, dec_level_config): """get clock-zero from data if needed""" if dec_level_config.window.clock_zero_type == "data start": - dec_level_config.window.clock_zero = str(self.dataset_df.start.min()) + dec_level_config.window.clock_zero = str( + self.dataset_df.start.min() + ) return dec_level_config @property @@ -212,7 +217,12 @@ def check_if_fcs_already_exist(self): remote = run_sub_df.remote.iloc[0] mth5_path = run_sub_df.mth5_path.iloc[0] fcs_present = mth5_has_fcs( - mth5_path, survey_id, station_id, run_id, remote, self.processing_config + mth5_path, + survey_id, + station_id, + run_id, + remote, + self.processing_config, ) self.dataset_df.loc[dataset_df_indices, "fc"] = fcs_present @@ -245,7 +255,9 @@ def show_processing_summary( columns_to_show = self.processing_summary.columns columns_to_show = [x for x in 
columns_to_show if x not in omit_columns] logger.info("Processing Summary Dataframe:") - logger.info(f"\n{self.processing_summary[columns_to_show].to_string()}") + logger.info( + f"\n{self.processing_summary[columns_to_show].to_string()}" + ) def make_processing_summary(self): """ @@ -265,11 +277,15 @@ def make_processing_summary(self): decimation_info = self.config.decimation_info() for i_dec, dec_factor in decimation_info.items(): tmp[i_dec] = dec_factor - tmp = tmp.melt(id_vars=id_vars, value_name="dec_factor", var_name="dec_level") + tmp = tmp.melt( + id_vars=id_vars, value_name="dec_factor", var_name="dec_level" + ) sortby = ["survey", "station_id", "run_id", "start", "dec_level"] tmp.sort_values(by=sortby, inplace=True) tmp.reset_index(drop=True, inplace=True) - tmp.drop("sample_rate", axis=1, inplace=True) # not valid for decimated data + tmp.drop( + "sample_rate", axis=1, inplace=True + ) # not valid for decimated data # Add window info group_by = [ @@ -305,7 +321,9 @@ def make_processing_summary(self): num_samples_window=row.num_samples_window, num_samples_overlap=row.num_samples_overlap, ) - num_windows[i] = ws.available_number_of_windows(row.num_samples) + num_windows[i] = ws.available_number_of_windows( + row.num_samples + ) df["num_stft_windows"] = num_windows groups.append(df) @@ -347,7 +365,8 @@ def validate_decimation_scheme_and_dataset_compatability( for x in self.processing_config.decimations } min_stft_window_list = [ - min_stft_window_info[x] for x in self.processing_summary.dec_level + min_stft_window_info[x] + for x in self.processing_summary.dec_level ] min_num_stft_windows = pd.Series(min_stft_window_list) @@ -371,7 +390,9 @@ def validate_processing(self): self.config.drop_reference_channels() for decimation in self.config.decimations: if decimation.estimator.engine == "RME_RR": - logger.info("No RR station specified, switching RME_RR to RME") + logger.info( + "No RR station specified, switching RME_RR to RME" + ) decimation.estimator.engine = "RME" # Make sure that a local station is defined @@ -399,7 +420,9 @@ def valid_decimations(self): valid_levels = tmp.dec_level.unique() dec_levels = [x for x in self.config.decimations] - dec_levels = [x for x in dec_levels if x.decimation.level in valid_levels] + dec_levels = [ + x for x in dec_levels if x.decimation.level in valid_levels + ] msg = f"After validation there are {len(dec_levels)} valid decimation levels" logger.info(msg) return dec_levels @@ -412,7 +435,9 @@ def validate_save_fc_settings(self): # if dec_level_config.save_fcs: dec_level_config.save_fcs = False if self.config.stations.remote: - save_any_fcs = np.array([x.save_fcs for x in self.config.decimations]).any() + save_any_fcs = np.array( + [x.save_fcs for x in self.config.decimations] + ).any() if save_any_fcs: msg = "\n Saving FCs for remote reference processing is not supported" msg = f"{msg} \n - To save FCs, process as single station, then you can use the FCs for RR processing" @@ -521,17 +546,27 @@ def make_decimation_dict_for_tf(tf_collection, processing_config): ------- """ - from mt_metadata.transfer_functions.io.zfiles.zmm import PERIOD_FORMAT + from mt_metadata.transfer_functions.io.zfiles.zmm import ( + PERIOD_FORMAT, + ) decimation_dict = {} - for i_dec, dec_level_cfg in enumerate(processing_config.decimations): + for i_dec, dec_level_cfg in enumerate( + processing_config.decimations + ): for i_band, band in enumerate(dec_level_cfg.bands): period_key = f"{band.center_period:{PERIOD_FORMAT}}" period_value = {} - period_value["level"] = i_dec 
+ 1 # +1 to match EMTF standard - period_value["bands"] = tuple(band.harmonic_indices[np.r_[0, -1]]) - period_value["sample_rate"] = dec_level_cfg.sample_rate_decimation + period_value["level"] = ( + i_dec + 1 + ) # +1 to match EMTF standard + period_value["bands"] = tuple( + band.harmonic_indices[np.r_[0, -1]] + ) + period_value[ + "sample_rate" + ] = dec_level_cfg.sample_rate_decimation try: period_value["npts"] = tf_collection.tf_dict[ i_dec @@ -561,21 +596,30 @@ def make_decimation_dict_for_tf(tf_collection, processing_config): tf_cls.transfer_function = tmp isp = merged_tf_dict["cov_ss_inv"] - renamer_dict = {"input_channel_1": "input", "input_channel_2": "output"} + renamer_dict = { + "input_channel_1": "input", + "input_channel_2": "output", + } isp = isp.rename(renamer_dict) tf_cls.inverse_signal_power = isp res_cov = merged_tf_dict["cov_nn"] - renamer_dict = {"output_channel_1": "input", "output_channel_2": "output"} + renamer_dict = { + "output_channel_1": "input", + "output_channel_2": "output", + } res_cov = res_cov.rename(renamer_dict) tf_cls.residual_covariance = res_cov # Set key as first el't of dict, nor currently supporting mixed surveys in TF tf_cls.survey_metadata = self.dataset.local_survey_metadata - tf_cls.station_metadata.transfer_function.processing_type = self.processing_type - # tf_cls.station_metadata.transfer_function.processing_config = ( - # self.processing_config - # ) + tf_cls.station_metadata.transfer_function.processing_type = ( + self.processing_type + ) + tf_cls.station_metadata.transfer_function.processing_config = ( + self.processing_config + ) + return tf_cls def memory_warning(self): @@ -600,7 +644,9 @@ def memory_warning(self): num_samples = self.dataset_df.duration * self.dataset_df.sample_rate total_samples = num_samples.sum() total_bytes = total_samples * bytes_per_sample - logger.info(f"Total Bytes of Raw Data: {total_bytes / (1024 ** 3):.3f} GB") + logger.info( + f"Total Bytes of Raw Data: {total_bytes / (1024 ** 3):.3f} GB" + ) ram_fraction = 1.0 * total_bytes / total_memory logger.info(f"Raw Data will use: {100 * ram_fraction:.3f} % of memory") @@ -656,7 +702,9 @@ def mth5_has_fcs(m, survey_id, station_id, run_id, remote, processing_config): return False if len(fc_group.groups_list) < processing_config.num_decimation_levels: - msg = f"Not enough FC Groups found for {row_ssr_str} -- will build them" + msg = ( + f"Not enough FC Groups found for {row_ssr_str} -- will build them" + ) return False # Can check time periods here if desired, but unique (survey, station, run) should make this unneeded diff --git a/aurora/test_utils/synthetic/rms_helpers.py b/aurora/test_utils/synthetic/rms_helpers.py index 75e35114..8546d1aa 100644 --- a/aurora/test_utils/synthetic/rms_helpers.py +++ b/aurora/test_utils/synthetic/rms_helpers.py @@ -40,21 +40,42 @@ def get_expected_rms_misfit(test_case_id, emtf_version=None): expected_rms_misfit["phi"] = {} if test_case_id == "test1": if emtf_version == "fortran": - expected_rms_misfit["rho"]["xy"] = 4.433905 - expected_rms_misfit["phi"]["xy"] = 0.910484 - expected_rms_misfit["rho"]["yx"] = 3.658614 - expected_rms_misfit["phi"]["yx"] = 0.844645 + # original decimation method + # expected_rms_misfit["rho"]["xy"] = 4.433905 + # expected_rms_misfit["phi"]["xy"] = 0.910484 + # expected_rms_misfit["rho"]["yx"] = 3.658614 + # expected_rms_misfit["phi"]["yx"] = 0.844645 + + # resample_poly method + expected_rms_misfit["rho"]["xy"] = 4.432282 + expected_rms_misfit["phi"]["xy"] = 0.915786 + expected_rms_misfit["rho"]["yx"] = 
3.649244 + expected_rms_misfit["phi"]["yx"] = 0.843633 elif emtf_version == "matlab": - expected_rms_misfit["rho"]["xy"] = 2.706098 - expected_rms_misfit["phi"]["xy"] = 0.784229 - expected_rms_misfit["rho"]["yx"] = 3.745280 - expected_rms_misfit["phi"]["yx"] = 1.374938 + # original decimation method + # expected_rms_misfit["rho"]["xy"] = 2.706098 + # expected_rms_misfit["phi"]["xy"] = 0.784229 + # expected_rms_misfit["rho"]["yx"] = 3.745280 + # expected_rms_misfit["phi"]["yx"] = 1.374938 + + # resample_poly method + expected_rms_misfit["rho"]["xy"] = 2.711959 + expected_rms_misfit["phi"]["xy"] = 0.787291 + expected_rms_misfit["rho"]["yx"] = 3.632992 + expected_rms_misfit["phi"]["yx"] = 1.365387 elif test_case_id == "test2r1": - expected_rms_misfit["rho"]["xy"] = 3.971313 - expected_rms_misfit["phi"]["xy"] = 0.982613 - expected_rms_misfit["rho"]["yx"] = 3.967259 - expected_rms_misfit["phi"]["yx"] = 1.62881 + # original decimation method + # expected_rms_misfit["rho"]["xy"] = 3.971313 + # expected_rms_misfit["phi"]["xy"] = 0.982613 + # expected_rms_misfit["rho"]["yx"] = 3.967259 + # expected_rms_misfit["phi"]["yx"] = 1.62881 + + # resample_poly method + expected_rms_misfit["rho"]["xy"] = 3.96470 + expected_rms_misfit["phi"]["xy"] = 0.991345 + expected_rms_misfit["rho"]["yx"] = 4.01597 + expected_rms_misfit["phi"]["yx"] = 1.59927 return expected_rms_misfit @@ -81,22 +102,38 @@ def assert_rms_misfit_ok( expected_rms_phi = expected_rms_misfit["phi"][xy_or_yx] logger.info(f"expected_rms_rho_{xy_or_yx} {expected_rms_rho}") logger.info(f"expected_rms_phi_{xy_or_yx} {expected_rms_phi}") - if not np.isclose(rho_rms_aurora - expected_rms_rho, 0, atol=rho_tol): - logger.error("==== AURORA ====\n") + + rho = True + phi = True + if not np.isclose(abs(rho_rms_aurora - expected_rms_rho), 0, atol=rho_tol): + logger.error(f"==== AURORA (rho_{xy_or_yx}) ====") logger.error(rho_rms_aurora) - logger.error("==== EXPECTED ====\n") + logger.error(f"==== EXPECTED (rho_{xy_or_yx}) ====") logger.error(expected_rms_rho) - logger.error("==== DIFFERENCE ====\n") + logger.error(f"==== DIFFERENCE (rho_{xy_or_yx}) ====") logger.error(rho_rms_aurora - expected_rms_rho) - raise AssertionError("Expected misfit for resistivity is not correct") + rho = False + # raise AssertionError("Expected misfit for resistivity is not correct") - if not np.isclose(phi_rms_aurora - expected_rms_phi, 0, atol=rho_tol): - logger.error("==== AURORA ====\n") + if not np.isclose(abs(phi_rms_aurora - expected_rms_phi), 0, atol=phi_tol): + logger.error(f"==== AURORA (phi_{xy_or_yx}) ====\n") logger.error(phi_rms_aurora) - logger.error("==== EXPECTED ====\n") + logger.error(f"==== EXPECTED (phi_{xy_or_yx}) ====\n") logger.error(expected_rms_phi) - logger.error("==== DIFFERENCE ====\n") + logger.error(f"==== DIFFERENCE (phi_{xy_or_yx}) ====\n") logger.error(phi_rms_aurora - expected_rms_phi) - raise AssertionError("Expected misfit for phase is not correct") + phi = False + # raise AssertionError("Expected misfit for phase is not correct") + if not rho: + if not phi: + raise AssertionError( + "Expected misfit for resistivity and phase is not correct" + ) + else: + raise AssertionError( + "Expected misfit for resistivity is not correct" + ) + elif not phi: + raise AssertionError("Expected misfit for phase is not correct") return diff --git a/aurora/time_series/filters/filter_helpers.py b/aurora/time_series/filters/filter_helpers.py new file mode 100644 index 00000000..00d24c9a --- /dev/null +++ b/aurora/time_series/filters/filter_helpers.py @@ -0,0 
+1,90 @@ +from mt_metadata.timeseries.filters.coefficient_filter import CoefficientFilter +from mt_metadata.timeseries.filters.frequency_response_table_filter import ( + FrequencyResponseTableFilter, +) +from loguru import logger + + +def make_coefficient_filter(gain=1.0, name="generic coefficient filter", **kwargs): + """ + + Parameters + ---------- + gain + name + units_in : string + one of "digital counts", "millivolts", etc. + TODO: Add a refernce here to the list of units supported in mt_metadata + + Returns + ------- + + """ + # in general, you need to add all required fields from the standards.json + default_units_in = "unknown" + default_units_out = "unknown" + + cf = CoefficientFilter() + cf.gain = gain + cf.name = name + + cf.units_in = kwargs.get("units_in", default_units_in) + cf.units_out = kwargs.get("units_out", default_units_out) + + return cf + + +def make_frequency_response_table_filter(file_path, case="bf4"): + """ + Parameters + ---------- + filepath: pathlib.Path or string + case : string, placeholder for handlig different fap table formats. + + Returns + ------- + fap_filter: FrequencyResponseTableFilter + """ + fap_filter = FrequencyResponseTableFilter() + + if case == "bf4": + import numpy as np + import pandas as pd + + df = pd.read_csv(file_path) # , skiprows=1) + # Hz, V/nT, degrees + fap_filter.frequencies = df["Frequency [Hz]"].values + fap_filter.amplitudes = df["Amplitude [V/nT]"].values + fap_filter.phases = np.deg2rad(df["Phase [degrees]"].values) + fap_filter.units_in = "volts" + fap_filter.units_out = "nanotesla" + fap_filter.gain = 1.0 + fap_filter.name = "bf4" + return fap_filter + else: + logger.error(f"case {case} not supported for FAP Table") + raise Exception + + +def make_volt_per_meter_to_millivolt_per_km_converter(): + """ + This represents a filter that converts from mV/km to V/m. 
+ + Returns + ------- + + """ + coeff_filter = make_coefficient_filter( + gain=1e-6, + units_in="millivolts per kilometer", + units_out="volts per meter", + name="MT to SI electric field conversion", + ) + return coeff_filter + + +MT2SI_ELECTRIC_FIELD_FILTER = make_volt_per_meter_to_millivolt_per_km_converter() + + +def main(): + make_volt_per_meter_to_millivolt_per_km_converter() diff --git a/aurora/time_series/xarray_helpers.py b/aurora/time_series/xarray_helpers.py index f8cdb780..dae54ab1 100644 --- a/aurora/time_series/xarray_helpers.py +++ b/aurora/time_series/xarray_helpers.py @@ -54,15 +54,29 @@ def handle_nan(X, Y, RR, drop_dim=""): data_var_rm_label_mapper[f"remote_{ch}"] = ch RR = RR.rename(data_var_add_label_mapper) - merged_xr = X.merge(Y, join="exact") + try: + merged_xr = X.merge(Y, join="exact") + except ValueError as error: + logger.debug(error) + logger.debug("Merging with 'outer'") + merged_xr = X.merge(Y, join="outer") + # Workaround for issue #228 # merged_xr = merged_xr.merge(RR, join="exact") try: - merged_xr = merged_xr.merge(RR, join="exact") + try: + merged_xr = merged_xr.merge(RR, join="exact") + except ValueError as error: + logger.debug(error) + logger.debug("Merging with 'outer'") + merged_xr = merged_xr.merge(RR, join="outer") + except ValueError: logger.error("Coordinate alignment mismatch -- see aurora issue #228 ") matches = X.time.values == RR.time.values - logger.error(f"{matches.sum()}/{len(matches)} timestamps match exactly") + logger.error( + f"{matches.sum()}/{len(matches)} timestamps match exactly" + ) deltas = X.time.values - RR.time.values logger.error(f"Maximum offset is {deltas.__abs__().max()}ns") # print(f"X.time.[0]: {X.time[0].values}") diff --git a/aurora/transfer_function/kernel_dataset.py b/aurora/transfer_function/kernel_dataset.py index 4ea21fe0..982a265b 100644 --- a/aurora/transfer_function/kernel_dataset.py +++ b/aurora/transfer_function/kernel_dataset.py @@ -132,7 +132,9 @@ def clone(self): def clone_dataframe(self): return copy.deepcopy(self.df) - def from_run_summary(self, run_summary, local_station_id, remote_station_id=None): + def from_run_summary( + self, run_summary, local_station_id, remote_station_id=None + ): """ Parameters @@ -156,7 +158,9 @@ def from_run_summary(self, run_summary, local_station_id, remote_station_id=None ] if remote_station_id: station_ids.append(remote_station_id) - df = restrict_to_station_list(run_summary.df, station_ids, inplace=False) + df = restrict_to_station_list( + run_summary.df, station_ids, inplace=False + ) # Check df is non-empty if len(df) == 0: msg = f"Restricting run_summary df to {station_ids} yields an empty set" @@ -372,17 +376,23 @@ def update_survey_metadata(self, i, row, run_ts): """ survey_id = run_ts.survey_metadata.id - if i == 0: + # need to add another survey if it is not in the survey dictionary. 
+ if i == 0 or survey_id not in self.survey_metadata.keys(): self.survey_metadata[survey_id] = run_ts.survey_metadata elif i > 0: - if row.station_id in self.survey_metadata[survey_id].stations.keys(): - self.survey_metadata[survey_id].stations[row.station_id].add_run( - run_ts.run_metadata - ) + if ( + row.station_id + in self.survey_metadata[survey_id].stations.keys() + ): + self.survey_metadata[survey_id].stations[ + row.station_id + ].add_run(run_ts.run_metadata) else: - self.survey_metadata[survey_id].add_station(run_ts.station_metadata) - if len(self.survey_metadata.keys()) > 1: - raise NotImplementedError + self.survey_metadata[survey_id].add_station( + run_ts.station_metadata + ) + # if len(self.survey_metadata.keys()) > 1: + # raise NotImplementedError def initialize_dataframe_for_processing(self, mth5_objs): """ diff --git a/aurora/transfer_function/regression/TRME.py b/aurora/transfer_function/regression/TRME.py index 1ea5f070..967f41bf 100644 --- a/aurora/transfer_function/regression/TRME.py +++ b/aurora/transfer_function/regression/TRME.py @@ -114,6 +114,7 @@ def update_residual_variance(self, correction_factor=1): def update_b(self): """matlab was: b = R\QTY;""" + self.b = solve_triangular(self.R, self.QHYc) def compute_inverse_signal_covariance(self): diff --git a/aurora/transfer_function/regression/m_estimator.py b/aurora/transfer_function/regression/m_estimator.py index 004a7923..407d8417 100644 --- a/aurora/transfer_function/regression/m_estimator.py +++ b/aurora/transfer_function/regression/m_estimator.py @@ -56,7 +56,9 @@ def Y_hat(self): return self._Y_hat def update_y_hat(self): - logger.error("Y_hat update method is not defined for abstract MEstimator class") + logger.error( + "Y_hat update method is not defined for abstract MEstimator class" + ) logger.error("Try using RME or RME_RR class instead") raise Exception @@ -103,7 +105,9 @@ def residual_variance_method1(self): than the one in TRME, but also has more computational overhead. 
""" res = self.Yc - self.Y_hat # intial estimate of error variance - residual_variance = np.sum(np.abs(res * np.conj(res)), axis=0) / self.n_data + residual_variance = ( + np.sum(np.abs(res * np.conj(res)), axis=0) / self.n_data + ) return residual_variance def residual_variance_method2(self): @@ -142,12 +146,12 @@ def residual_variance_method2(self): try: assert (residual_variance > 0).all() except AssertionError: - logger.warning("WARNING - Negative error variances observed") - logger.warning(residual_variance) - logger.warning( + # logger.warning("WARNING - Negative error variances observed") + # logger.warning(residual_variance) + logger.debug( "Setting residual_variance to zero - Negative values observed" ) - residual_variance *= 0 + residual_variance = np.zeros_like(residual_variance) return residual_variance @@ -196,19 +200,23 @@ def apply_huber_regression(self): converged = self.iter_control.max_number_of_iterations <= 0 self.iter_control.number_of_iterations = 0 while not converged: - b0 = self.b + b0 = self.b.copy() self.iter_control.number_of_iterations += 1 self.update_y_cleaned_via_huber_weights() self.update_b() self.update_y_hat() - self.update_residual_variance(correction_factor=self.correction_factor) + self.update_residual_variance( + correction_factor=self.correction_factor + ) converged = self.iter_control.converged(self.b, b0) return def apply_redecending_influence_function(self): """one or two iterations with redescending influence curve cleaned data""" if self.iter_control.max_number_of_redescending_iterations: - self.iter_control.number_of_redescending_iterations = 0 # reset per channel + self.iter_control.number_of_redescending_iterations = ( + 0 # reset per channel + ) while self.iter_control.continue_redescending: self.iter_control.number_of_redescending_iterations += 1 self.update_y_cleaned_via_redescend_weights() diff --git a/tests/pipelines/test_run_summary.py b/tests/pipelines/test_run_summary.py index dd90e7f8..18fa0790 100644 --- a/tests/pipelines/test_run_summary.py +++ b/tests/pipelines/test_run_summary.py @@ -1,6 +1,8 @@ # import logging import unittest +import numpy as np +import pandas as pd from aurora.pipelines.run_summary import RunSummary from aurora.test_utils.synthetic.make_mth5_from_asc import create_test12rr_h5 from aurora.test_utils.synthetic.paths import DATA_PATH @@ -27,12 +29,147 @@ def test_add_duration(self): assert "duration" in rs.df.columns -def main(): - # tmp = TestRunSummary() - # tmp.setUpClass() - # tmp.test_add_duration() - unittest.main() +class TestRunSummaryValidation(unittest.TestCase): + @classmethod + def setUpClass(self): + self.df_bad_outputs = pd.DataFrame( + { + "survey": {2: "LD2024", 4: "LD2024", 6: "LD2024"}, + "station_id": {2: "12", 4: "12", 6: "12"}, + "run_id": { + 2: "sr4096_0002", + 4: "sr4096_0004", + 6: "sr4096_0006", + }, + "start": { + 2: pd.Timestamp("2024-05-09 00:59:58+0000", tz="UTC"), + 4: pd.Timestamp("2024-05-09 06:59:58+0000", tz="UTC"), + 6: pd.Timestamp("2024-05-09 12:59:58+0000", tz="UTC"), + }, + "end": { + 2: pd.Timestamp( + "2024-05-09 01:09:41.997070312+0000", tz="UTC" + ), + 4: pd.Timestamp( + "2024-05-09 07:09:41.996582031+0000", tz="UTC" + ), + 6: pd.Timestamp( + "2024-05-09 13:09:41.996338+0000", tz="UTC" + ), + }, + "sample_rate": {2: 4096.0, 4: 4096.0, 6: 4096.0}, + "input_channels": { + 2: ["hx", "hy"], + 4: ["hx", "hy"], + 6: ["hx", "hy"], + }, + "output_channels": { + 2: ["ey", "hz"], + 4: ["ex", "ey", "hz"], + 6: ["ex", "ey", "hz"], + }, + "channel_scale_factors": { + 2: {"ey": 1.0, 
"hx": 1.0, "hy": 1.0, "hz": 1.0}, + 4: {"ex": 1.0, "ey": 1.0, "hx": 1.0, "hy": 1.0, "hz": 1.0}, + 6: {"ex": 1.0, "ey": 1.0, "hx": 1.0, "hy": 1.0, "hz": 1.0}, + }, + "valid": {2: True, 4: True, 6: True}, + "mth5_path": { + 2: "path1\test.h5", + 4: "path1\test.h5", + 6: "path1\test.h5", + }, + } + ) + self.df_bad_inputs = pd.DataFrame( + { + "survey": {2: "LD2024", 4: "LD2024", 6: "LD2024"}, + "station_id": {2: "12", 4: "12", 6: "12"}, + "run_id": { + 2: "sr4096_0002", + 4: "sr4096_0004", + 6: "sr4096_0006", + }, + "start": { + 2: pd.Timestamp("2024-05-09 00:59:58+0000", tz="UTC"), + 4: pd.Timestamp("2024-05-09 06:59:58+0000", tz="UTC"), + 6: pd.Timestamp("2024-05-09 12:59:58+0000", tz="UTC"), + }, + "end": { + 2: pd.Timestamp( + "2024-05-09 01:09:41.997070312+0000", tz="UTC" + ), + 4: pd.Timestamp( + "2024-05-09 07:09:41.996582031+0000", tz="UTC" + ), + 6: pd.Timestamp( + "2024-05-09 13:09:41.996338+0000", tz="UTC" + ), + }, + "sample_rate": {2: 4096.0, 4: 4096.0, 6: 4096.0}, + "input_channels": { + 2: ["hx", "hy"], + 4: ["hx"], + 6: ["hx", "hy"], + }, + "output_channels": { + 2: ["ex", "ey", "hz"], + 4: ["ex", "ey", "hz"], + 6: ["ex", "ey", "hz"], + }, + "channel_scale_factors": { + 2: {"ey": 1.0, "hx": 1.0, "hy": 1.0, "hz": 1.0}, + 4: {"ex": 1.0, "ey": 1.0, "hx": 1.0, "hy": 1.0, "hz": 1.0}, + 6: {"ex": 1.0, "ey": 1.0, "hx": 1.0, "hy": 1.0, "hz": 1.0}, + }, + "valid": {2: True, 4: True, 6: True}, + "mth5_path": { + 2: "path1\test.h5", + 4: "path1\test.h5", + 6: "path1\test.h5", + }, + } + ) + + def test_bad_outputs(self): + rs = RunSummary(df=self.df_bad_outputs) + rs.validate_channels() + + self.assertEqual( + True, np.all([False, True, True] == rs.df.valid.values) + ) + + def test_bad_inputs(self): + rs = RunSummary(df=self.df_bad_inputs) + rs.validate_channels() + + self.assertEqual( + True, np.all([True, False, True] == rs.df.valid.values) + ) + + def test_bad_outputs_drop(self): + rs = RunSummary(df=self.df_bad_outputs) + rs.validate_channels(drop=True) + + self.assertEqual(True, np.all([True, True] == rs.df.valid.values)) + + def test_bad_inputs_drop(self): + rs = RunSummary(df=self.df_bad_inputs) + rs.validate_channels(drop=True) + + self.assertEqual(True, np.all([True, True] == rs.df.valid.values)) + + def test_duration(self): + rs = RunSummary(df=self.df_bad_outputs) + rs.add_duration() + + self.assertEqual( + True, + np.isclose( + np.array([583.99707, 583.996582, 583.996338]), rs.df.duration + ).all(), + ) if __name__ == "__main__": - main() + unittest.main() diff --git a/tests/synthetic/test_compare_aurora_vs_archived_emtf.py b/tests/synthetic/test_compare_aurora_vs_archived_emtf.py index 04fa05e9..f235c908 100644 --- a/tests/synthetic/test_compare_aurora_vs_archived_emtf.py +++ b/tests/synthetic/test_compare_aurora_vs_archived_emtf.py @@ -1,16 +1,23 @@ +import unittest +import logging + from aurora.pipelines.process_mth5 import process_mth5 from aurora.pipelines.run_summary import RunSummary from aurora.sandbox.io_helpers.zfile_murphy import read_z_file -from aurora.test_utils.synthetic.make_mth5_from_asc import create_test1_h5 -from aurora.test_utils.synthetic.make_mth5_from_asc import create_test2_h5 -from aurora.test_utils.synthetic.make_mth5_from_asc import create_test12rr_h5 +from aurora.test_utils.synthetic.make_mth5_from_asc import ( + create_test1_h5, + create_test2_h5, + create_test12rr_h5, +) from aurora.test_utils.synthetic.make_processing_configs import ( create_test_run_config, ) from aurora.test_utils.synthetic.paths import SyntheticTestPaths -from 
aurora.test_utils.synthetic.rms_helpers import assert_rms_misfit_ok -from aurora.test_utils.synthetic.rms_helpers import compute_rms -from aurora.test_utils.synthetic.rms_helpers import get_expected_rms_misfit +from aurora.test_utils.synthetic.rms_helpers import ( + assert_rms_misfit_ok, + compute_rms, + get_expected_rms_misfit, +) from aurora.transfer_function.emtf_z_file_helpers import ( merge_tf_collection_to_match_z_file, ) @@ -85,11 +92,14 @@ def aurora_vs_emtf( ) aux_data = read_z_file(auxilliary_z_file) - aurora_rho_phi = merge_tf_collection_to_match_z_file(aux_data, tf_collection) + aurora_rho_phi = merge_tf_collection_to_match_z_file( + aux_data, tf_collection + ) data_dict = {} data_dict["period"] = aux_data.periods data_dict["emtf_rho_xy"] = aux_data.rxy data_dict["emtf_phi_xy"] = aux_data.pxy + tf_dict = {"xy": True, "yx": True, "xy_error": None, "yx_error": None} for xy_or_yx in ["xy", "yx"]: aurora_rho = aurora_rho_phi["rho"][xy_or_yx] aurora_phi = aurora_rho_phi["phi"][xy_or_yx] @@ -102,9 +112,16 @@ def aurora_vs_emtf( data_dict["aurora_rho_xy"] = aurora_rho data_dict["aurora_phi_xy"] = aurora_phi if expected_rms_misfit is not None: - assert_rms_misfit_ok( - expected_rms_misfit, xy_or_yx, rho_rms_aurora, phi_rms_aurora - ) + try: + assert_rms_misfit_ok( + expected_rms_misfit, + xy_or_yx, + rho_rms_aurora, + phi_rms_aurora, + ) + except AssertionError as error: + tf_dict[xy_or_yx] = False + tf_dict[f"{xy_or_yx}_error"] = error if make_rho_phi_plot: plot_rho_phi( @@ -121,7 +138,17 @@ def aurora_vs_emtf( output_path=AURORA_RESULTS_PATH, ) - return + if not tf_dict["xy"]: + if not tf_dict["yx"]: + raise AssertionError( + f"{tf_dict['xy_error']}; {tf_dict['yx_error']}" + ) + else: + raise AssertionError(tf_dict["xy_error"]) + elif not tf_dict["yx"]: + raise AssertionError(tf_dict["yx_error"]) + + return True def run_test1(emtf_version, ds_df): @@ -142,8 +169,9 @@ def run_test1(emtf_version, ds_df): test_case_id = "test1" auxilliary_z_file = EMTF_RESULTS_PATH.joinpath("test1.zss") z_file_base = f"{test_case_id}_aurora_{emtf_version}.zss" - aurora_vs_emtf(test_case_id, emtf_version, auxilliary_z_file, z_file_base, ds_df) - return + return aurora_vs_emtf( + test_case_id, emtf_version, auxilliary_z_file, z_file_base, ds_df + ) def run_test2r1(tfk_dataset): @@ -162,10 +190,9 @@ def run_test2r1(tfk_dataset): emtf_version = "fortran" auxilliary_z_file = EMTF_RESULTS_PATH.joinpath("test2r1.zrr") z_file_base = f"{test_case_id}_aurora_{emtf_version}.zrr" - aurora_vs_emtf( + return aurora_vs_emtf( test_case_id, emtf_version, auxilliary_z_file, z_file_base, tfk_dataset ) - return def make_mth5s(merged=True): @@ -186,58 +213,116 @@ def make_mth5s(merged=True): return mth5_paths -def test_pipeline(merged=True): - """ +class TestAuroraVsArchivedMergedTrue(unittest.TestCase): + @classmethod + def setUpClass(self): + logging.getLogger("matplotlib.font_manager").disabled = True + logging.getLogger("matplotlib.ticker").disabled = True - Parameters - ---------- - merged: bool - If true, summarise two separate mth5 files and merge their run summaries - If False, use an already-merged mth5 + close_open_files() - Returns - ------- + mth5_paths = make_mth5s(merged=True) + self.run_summary = RunSummary() + self.run_summary.from_mth5s(mth5_paths) + self.tfk_dataset = KernelDataset() + self.tfk_dataset.from_run_summary(self.run_summary, "test1") - """ - close_open_files() + def test_aurora_vs_fortran(self): + self.assertEqual(True, run_test1("fortran", self.tfk_dataset)) - mth5_paths = 
make_mth5s(merged=merged) - run_summary = RunSummary() - run_summary.from_mth5s(mth5_paths) - tfk_dataset = KernelDataset() - tfk_dataset.from_run_summary(run_summary, "test1") + def test_aurora_vs_matlab(self): + self.assertEqual(True, run_test1("matlab", self.tfk_dataset)) - run_test1("fortran", tfk_dataset) - run_test1("matlab", tfk_dataset) + def test_aurora_vs_rr(self): + tfk_dataset = KernelDataset() + tfk_dataset.from_run_summary(self.run_summary, "test2", "test1") + self.assertEqual(True, run_test2r1(tfk_dataset)) - tfk_dataset = KernelDataset() - tfk_dataset.from_run_summary(run_summary, "test2", "test1") - # Uncomment to sanity check the problem is linear - # scale_factors = { - # "ex": 20.0, - # "ey": 20.0, - # "hx": 20.0, - # "hy": 20.0, - # "hz": 20.0, - # } - # tfk_dataset.df["channel_scale_factors"].at[0] = scale_factors - # tfk_dataset.df["channel_scale_factors"].at[1] = scale_factors - run_test2r1(tfk_dataset) +class TestAuroraVsArchivedMergedFalse(unittest.TestCase): + @classmethod + def setUpClass(self): + logging.getLogger("matplotlib.font_manager").disabled = True + logging.getLogger("matplotlib.ticker").disabled = True -def test(): - import logging + close_open_files() - logging.getLogger("matplotlib.font_manager").disabled = True - logging.getLogger("matplotlib.ticker").disabled = True + mth5_paths = make_mth5s(merged=False) + self.run_summary = RunSummary() + self.run_summary.from_mth5s(mth5_paths) + self.tfk_dataset = KernelDataset() + self.tfk_dataset.from_run_summary(self.run_summary, "test1") - test_pipeline(merged=False) - test_pipeline(merged=True) + def test_aurora_vs_fortran(self): + self.assertEqual(True, run_test1("fortran", self.tfk_dataset)) + def test_aurora_vs_matlab(self): + self.assertEqual(True, run_test1("matlab", self.tfk_dataset)) -def main(): - test() + def test_aurora_vs_rr(self): + tfk_dataset = KernelDataset() + tfk_dataset.from_run_summary(self.run_summary, "test2", "test1") + self.assertEqual(True, run_test2r1(tfk_dataset)) +# ============================================================================= +# run +# ============================================================================= if __name__ == "__main__": - main() + unittest.main() + +# def test_pipeline(merged=True): +# """ + +# Parameters +# ---------- +# merged: bool +# If true, summarise two separate mth5 files and merge their run summaries +# If False, use an already-merged mth5 + +# Returns +# ------- + +# """ +# close_open_files() + +# mth5_paths = make_mth5s(merged=merged) +# run_summary = RunSummary() +# run_summary.from_mth5s(mth5_paths) +# tfk_dataset = KernelDataset() +# tfk_dataset.from_run_summary(run_summary, "test1") + +# run_test1("fortran", tfk_dataset) +# run_test1("matlab", tfk_dataset) + +# tfk_dataset = KernelDataset() +# tfk_dataset.from_run_summary(run_summary, "test2", "test1") +# # Uncomment to sanity check the problem is linear +# # scale_factors = { +# # "ex": 20.0, +# # "ey": 20.0, +# # "hx": 20.0, +# # "hy": 20.0, +# # "hz": 20.0, +# # } +# # tfk_dataset.df["channel_scale_factors"].at[0] = scale_factors +# # tfk_dataset.df["channel_scale_factors"].at[1] = scale_factors +# run_test2r1(tfk_dataset) + + +# def test(): +# import logging + +# logging.getLogger("matplotlib.font_manager").disabled = True +# logging.getLogger("matplotlib.ticker").disabled = True + +# test_pipeline(merged=False) +# test_pipeline(merged=True) + + +# def main(): +# test() + + +# if __name__ == "__main__": +# main() diff --git a/tests/synthetic/test_define_frequency_bands.py 
b/tests/synthetic/test_define_frequency_bands.py index 2087c671..af7e80b1 100644 --- a/tests/synthetic/test_define_frequency_bands.py +++ b/tests/synthetic/test_define_frequency_bands.py @@ -2,7 +2,9 @@ from aurora.config.config_creator import ConfigCreator from aurora.pipelines.process_mth5 import process_mth5 -from aurora.test_utils.synthetic.processing_helpers import get_example_kernel_dataset +from aurora.test_utils.synthetic.processing_helpers import ( + get_example_kernel_dataset, +) class TestDefineBandsFromDict(unittest.TestCase): @@ -18,7 +20,9 @@ def test_can_declare_frequencies_directly_in_config(self): cfg1 = cc.create_from_kernel_dataset( kernel_dataset, estimator={"engine": "RME"} ) - decimation_factors = list(cfg1.decimation_info().values()) # [1, 4, 4, 4] + decimation_factors = list( + cfg1.decimation_info().values() + ) # [1, 4, 4, 4] # Default Band edges, corresponds to DEFAULT_BANDS_FILE band_edges = cfg1.band_edges_dict cfg2 = cc.create_from_kernel_dataset( @@ -33,6 +37,10 @@ def test_can_declare_frequencies_directly_in_config(self): tf_cls1.write(fn="cfg1.xml", file_type="emtfxml") tf_cls2 = process_mth5(cfg2, kernel_dataset) tf_cls2.write(fn="cfg2.xml", file_type="emtfxml") + + # the processing parameters are not the same so need to null them + tf_cls1.station_metadata.transfer_function.processing_parameters = [] + tf_cls2.station_metadata.transfer_function.processing_parameters = [] assert tf_cls2 == tf_cls1 diff --git a/tests/synthetic/test_fourier_coefficients.py b/tests/synthetic/test_fourier_coefficients.py index 252f4989..44c6d6df 100644 --- a/tests/synthetic/test_fourier_coefficients.py +++ b/tests/synthetic/test_fourier_coefficients.py @@ -10,7 +10,9 @@ from aurora.test_utils.synthetic.make_mth5_from_asc import create_test2_h5 from aurora.test_utils.synthetic.make_mth5_from_asc import create_test3_h5 from aurora.test_utils.synthetic.make_mth5_from_asc import create_test12rr_h5 -from aurora.test_utils.synthetic.make_processing_configs import create_test_run_config +from aurora.test_utils.synthetic.make_processing_configs import ( + create_test_run_config, +) from aurora.test_utils.synthetic.paths import SyntheticTestPaths from aurora.transfer_function.kernel_dataset import KernelDataset from loguru import logger @@ -55,7 +57,12 @@ def setUpClass(self): mth5_path_2 = create_test2_h5(file_version=self.file_version) mth5_path_3 = create_test3_h5(file_version=self.file_version) mth5_path_12rr = create_test12rr_h5(file_version=self.file_version) - self.mth5_paths = [mth5_path_1, mth5_path_2, mth5_path_3, mth5_path_12rr] + self.mth5_paths = [ + mth5_path_1, + mth5_path_2, + mth5_path_3, + mth5_path_12rr, + ] def test_123(self): """ @@ -85,7 +92,9 @@ def test_123(self): ]: station_id = mth5_path.stem tfk_dataset.from_run_summary(run_summary, station_id) - processing_config = create_test_run_config(station_id, tfk_dataset) + processing_config = create_test_run_config( + station_id, tfk_dataset + ) elif mth5_path.stem in [ "test3", ]: @@ -131,11 +140,18 @@ def test_create_then_use_stored_fcs_for_processing(self): from test_processing import process_synthetic_2 z_file_path_1 = AURORA_RESULTS_PATH.joinpath("test2.zss") - z_file_path_2 = AURORA_RESULTS_PATH.joinpath("test2_from_stored_fc.zss") + z_file_path_2 = AURORA_RESULTS_PATH.joinpath( + "test2_from_stored_fc.zss" + ) tf1 = process_synthetic_2( force_make_mth5=True, z_file_path=z_file_path_1, save_fc=True ) - tf2 = process_synthetic_2(force_make_mth5=False, z_file_path=z_file_path_2) + tf2 = process_synthetic_2( + 
+            force_make_mth5=False, z_file_path=z_file_path_2
+        )
+
+        tf1.station_metadata.transfer_function.processing_parameters = []
+        tf2.station_metadata.transfer_function.processing_parameters = []
         assert tf1 == tf2
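
For readers skimming the patch, a minimal sketch of how the new RunSummary.validate_channels check is meant to be used. The small DataFrame below is hypothetical (a trimmed-down version of the fixtures added in tests/pipelines/test_run_summary.py); RunSummary(df=...), validate_channels(drop=...), and the expected outcome mirror the tests in that file rather than anything beyond this diff.

import pandas as pd

from aurora.pipelines.run_summary import RunSummary

# Three runs from one station; the first run is missing "ex" from its output channels.
# Column layout is a trimmed-down version of the test fixtures above.
df = pd.DataFrame(
    {
        "survey": ["LD2024"] * 3,
        "station_id": ["12"] * 3,
        "run_id": ["sr4096_0002", "sr4096_0004", "sr4096_0006"],
        "input_channels": [["hx", "hy"], ["hx", "hy"], ["hx", "hy"]],
        "output_channels": [["ey", "hz"], ["ex", "ey", "hz"], ["ex", "ey", "hz"]],
        "valid": [True, True, True],
    }
)

run_summary = RunSummary(df=df)
# Compares each run's channel lists against the most common lists, warns about the
# mismatched run, marks it invalid, and (with drop=True) removes it from the summary.
run_summary.validate_channels(drop=True)
print(run_summary.df.run_id.to_list())  # ['sr4096_0004', 'sr4096_0006'], as asserted in test_bad_outputs_drop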