-
Notifications
You must be signed in to change notification settings - Fork 15
Description
Please see the Python code below.
The attached Python code provides the statistical analysis and visualization; it is designed to give detailed insights into the data through both descriptive statistics and graphical representations. By calculating and printing essential statistics such as the minimum, maximum, mean, and standard deviation, the code offers a quick summary of the dataset's core characteristics. The inclusion of percentile calculations further enriches the statistical analysis, giving a more granular view of the data distribution. Visualization is handled using Matplotlib, with a dual approach of a histogram and a box plot for each field. The histogram provides a comprehensive view of the frequency distribution, while the box plot highlights the data's central tendency and variability, including outliers. Using a log scale for the histogram's y-axis ensures that a wide range of frequencies is effectively visualized. Overall, the code's approach to analysis and visualization is both thorough and user-friendly, making it easier to interpret complex datasets.
import glob
import os
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
def analyze_and_plot_distribution(data_array, field_name):
    """
    Calculate summary statistics and plot the distribution of a data array.

    Prints min/max/mean/std and a percentile table, then shows a figure with
    a log-scaled histogram and a box plot (whiskers at the 5th/95th
    percentiles).

    Args:
        data_array (np.ndarray): A NumPy array of scalar values; multi-dim
            arrays are flattened before analysis.
        field_name (str): The name of the field being analyzed
            (e.g., "Temperature").

    Returns:
        None. Output is printed to stdout and displayed via Matplotlib.
    """
    # Guard: nothing to analyze for an empty array.
    if data_array.size == 0:
        print(f"\n--- No data found for '{field_name}'. Skipping analysis. ---\n")
        return
    data_array = data_array.flatten()

    print(f"\n{'---'*5} Analysis for: {field_name.upper()} {'---'*5}")
    print(f"Shape of aggregated data: {data_array.shape}")
    print(f"Total values calculated: {len(data_array)}")

    # Basic statistics
    stats = {
        'Min': np.min(data_array),
        'Max': np.max(data_array),
        'Mean': np.mean(data_array),
        'Std Dev': np.std(data_array),
    }
    for stat, value in stats.items():
        print(f" {stat}: {value:.4f}")

    # Percentile distribution: a more granular view than min/max alone.
    print("\n--- Percentile Distribution ---")
    percentiles = [1, 5, 25, 50, 75, 95, 99]
    percentile_values = np.percentile(data_array, percentiles)
    for p, v in zip(percentiles, percentile_values):
        print(f" {p:2d}th percentile: {v:.4f}")
    print("***************************************\n")

    # Plotting: histogram on top, box plot below.
    plt.style.use("seaborn-v0_8-whitegrid")
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 1]}
    )
    fig.suptitle(f"Distribution for {field_name}", fontsize=20, y=0.98)

    # Plot 1: Histogram. The log y-scale keeps sparsely-populated bins
    # visible alongside dominant ones.
    ax1.hist(data_array, bins=100, color="skyblue", edgecolor="black", alpha=0.8)
    ax1.set_title("Histogram")
    ax1.set_xlabel(field_name)
    ax1.set_ylabel("Frequency")
    ax1.set_yscale("log")
    ax1.grid(True, which="both", linestyle='--', linewidth=0.5)

    # Plot 2: Box plot with whiskers at the 5th/95th percentiles so the
    # tails show up as flier points.
    ax2.boxplot(
        data_array,
        vert=False,
        whis=[5, 95],
        patch_artist=True,
        boxprops=dict(facecolor="lightgreen"),
        flierprops=dict(marker="o", markerfacecolor="red", markersize=5, alpha=0.3),
    )
    ax2.set_title("Box Plot")
    ax2.set_xlabel(field_name)
    ax2.set_yticks([])
    ax2.grid(True, linestyle='--', linewidth=0.5)

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
    # FIX: close the figure so repeated calls (one per field) do not
    # accumulate open figures and leak memory over a long run.
    plt.close(fig)
def process_and_plot_directory(data_dir):
    """
    Load all .npy files from a directory, aggregate the 'volume_fields' and
    'surface_fields' entries, and plot a distribution for each field.

    Each .npy file is expected to hold a pickled dict of NumPy arrays.
    1-D arrays are aggregated under their key; 2-D arrays are split per
    column ("<key>_col_<i>"), and for 'volume_fields' with at least three
    columns the first three columns are additionally combined into a
    Euclidean magnitude ("<key>_Magnitude") — presumably a velocity vector
    (TODO: confirm against the data producer).

    Args:
        data_dir (str): Directory searched (non-recursively) for .npy files.

    Returns:
        None. Results are printed and plotted via
        analyze_and_plot_distribution.
    """
    file_paths = glob.glob(os.path.join(data_dir, "*.npy"))
    if not file_paths:
        print(f"Error: No .npy files found in the directory '{data_dir}'.")
        return

    aggregated_data = defaultdict(list)
    target_keys = ['volume_fields', 'surface_fields']
    print(f"Found {len(file_paths)} files. Processing for keys: {target_keys}...")

    for file_path in file_paths:
        try:
            # SECURITY: allow_pickle=True can execute arbitrary code while
            # unpickling — only load .npy files from trusted sources.
            data_dict = np.load(file_path, allow_pickle=True).item()
            if not isinstance(data_dict, dict):
                print(f"Warning: File '{file_path}' does not contain a dictionary. Skipping.")
                continue
            # FIX: look up only the keys we care about instead of scanning
            # every item in the loaded dict.
            for key in target_keys:
                if key not in data_dict:
                    continue
                array = data_dict[key]
                if not isinstance(array, np.ndarray):
                    print(f"Warning: Item '{key}' in {os.path.basename(file_path)} is not a NumPy array. Skipping.")
                    continue
                if array.ndim == 1:
                    aggregated_data[key].extend(array)
                elif array.ndim == 2:
                    num_cols = array.shape[1]
                    for i in range(num_cols):
                        col_name = f"{key}_col_{i}"
                        aggregated_data[col_name].extend(array[:, i])
                    # Combine the first three columns into one magnitude
                    # field alongside the per-column distributions.
                    if key == 'volume_fields' and num_cols >= 3:
                        velocities = array[:, :3]
                        magnitudes = np.linalg.norm(velocities, axis=1)
                        aggregated_data[f"{key}_Magnitude"].extend(magnitudes)
                else:
                    print(f"Warning: Array '{key}' has an unsupported dimension {array.ndim}. Skipping.")
        except Exception as e:
            # Best-effort: one corrupt file must not abort the whole run.
            print(f"Error loading or processing file {os.path.basename(file_path)}: {e}")

    print("\n--- All files processed. Generating plots for aggregated data. ---\n")
    if not aggregated_data:
        print(f"No data for keys {target_keys} was found in any files. Exiting.")
        return
    for field_name, data_list in aggregated_data.items():
        data_array = np.array(data_list)
        analyze_and_plot_distribution(data_array, field_name)
if __name__ == "__main__":
    # Directory scanned for the pre-processed .npy training files —
    # presumably DoMINO-format data; adjust for your environment.
    DATA_DIRECTORY = "/workspace/cummins/data/domino_data/train/"
    process_and_plot_directory(DATA_DIRECTORY)