diff --git a/HISTORY.rst b/HISTORY.rst index b578021b2d..6f9207b1b7 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -61,10 +61,28 @@ 7. InVEST model Z (model names should be sorted A-Z) +Unreleased Changes +------------------ + +Seasonal Water Yield +==================== +* The model now generates a report, a visual summary of results, available in + the output workspace and also viewable from the Workbench after the model run + completes. (`#2321 <https://github.com/natcap/invest/issues/2321>`_) +* The model now generates an additional output, a CSV containing average monthly + quickflow, baseflow, and precipitation values, in cubic meters per month, for + each feature in the AOI. This output is used by the report to generate some + plots. Note that this CSV is only created when the model is run without + inputting a Local Recharge raster. + (`#2321 <https://github.com/natcap/invest/issues/2321>`_) +* Various updates to model output data metadata, including correcting the + units of some outputs. + (`#2450 <https://github.com/natcap/invest/issues/2450>`_) +* Updated the naming convention of several monthly intermediate outputs to be + 1-indexed rather than 0-indexed. This makes filenames consistent throughout + the model, where 1=January and 12=December. + (`#2451 <https://github.com/natcap/invest/issues/2451>`_) -.. 
- Unreleased Changes - ------------------ 3.18.0 (2026-02-25) ------------------- diff --git a/Makefile b/Makefile index bc895612fb..8b3e5c108e 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ GIT_SAMPLE_DATA_REPO_REV := cfd1f07673e66823fd22989a2b87afb017aac447 GIT_TEST_DATA_REPO := https://bitbucket.org/natcap/invest-test-data.git GIT_TEST_DATA_REPO_PATH := $(DATA_DIR)/invest-test-data -GIT_TEST_DATA_REPO_REV := c791f2b50e67680832054536899efacbc72e9e0b +GIT_TEST_DATA_REPO_REV := 94c4bc9f0f22082d2251b6b14063eb0ff4094451 GIT_UG_REPO := https://github.com/natcap/invest.users-guide GIT_UG_REPO_PATH := doc/users-guide diff --git a/src/natcap/invest/carbon/reporter.py b/src/natcap/invest/carbon/reporter.py index 52e01df3bb..115c773e09 100644 --- a/src/natcap/invest/carbon/reporter.py +++ b/src/natcap/invest/carbon/reporter.py @@ -201,7 +201,7 @@ def report(file_registry: dict, args_dict: dict, model_spec: ModelSpec, ] input_raster_stats_table = raster_utils.raster_inputs_summary( - args_dict).to_html(na_rep='') + args_dict, model_spec).to_html(na_rep='') output_raster_stats_table = raster_utils.raster_workspace_summary( file_registry).to_html(na_rep='') diff --git a/src/natcap/invest/coastal_vulnerability/reporter.py b/src/natcap/invest/coastal_vulnerability/reporter.py index 2c36f8622a..bf9349dbbd 100644 --- a/src/natcap/invest/coastal_vulnerability/reporter.py +++ b/src/natcap/invest/coastal_vulnerability/reporter.py @@ -9,7 +9,7 @@ from natcap.invest import __version__ from natcap.invest import gettext import natcap.invest.spec -from natcap.invest.reports import jinja_env +from natcap.invest.reports import jinja_env, vector_utils LOGGER = logging.getLogger(__name__) @@ -28,49 +28,6 @@ POINT_SIZE = 20 MAP_WIDTH = 450 # pixels -LEGEND_CONFIG = { - 'labelFontSize': 14, - 'titleFontSize': 14, - 'orient': 'left', - 'gradientLength': 120 -} -AXIS_CONFIG = { - 'labelFontSize': 12, - 'titleFontSize': 12, -} - - -def _get_geojson_bbox(geodataframe): - """Get the bounding 
box of a GeoDataFrame as a GeoJSON feature. - - Also calculate its aspect ratio. These are useful for cropping - other layers in altair plots. - - Args: - geodataframe (geopandas.GeoDataFrame): - Returns: - tuple: A 2-tuple containing: - - extent_feature (dict): A GeoJSON feature representing the bounding - box of the input GeoDataFrame. - - xy_ratio (float): The aspect ratio of the bounding box - (width/height). - - """ - xmin, ymin, xmax, ymax = geodataframe.total_bounds - xy_ratio = (xmax - xmin) / (ymax - ymin) - extent_feature = { - "type": "Feature", - "geometry": {"type": "Polygon", - "coordinates": [[ - [xmax, ymax], - [xmax, ymin], - [xmin, ymin], - [xmin, ymax], - [xmax, ymax]]]}, - "properties": {} - } - return extent_feature, xy_ratio - def _chart_landmass(geodataframe, clip=False, extent_feature=None): landmass = altair.Chart(geodataframe).mark_geoshape( @@ -134,13 +91,15 @@ def concat_habitats(row): ] ) - _, xy_ratio = _get_geojson_bbox(exposure_geodf) + _, xy_ratio = vector_utils.get_geojson_bbox(exposure_geodf) habitat_map = landmass_chart + habitat_points habitat_map = habitat_map.properties( width=MAP_WIDTH, height=MAP_WIDTH / xy_ratio, title=gettext('The role of habitat in reducing coastal exposure') - ).configure_legend(**LEGEND_CONFIG).configure_axis(**AXIS_CONFIG) + ).configure_legend( + **vector_utils.LEGEND_CONFIG + ).configure_axis(**vector_utils.AXIS_CONFIG) return habitat_map @@ -172,7 +131,7 @@ def report(file_registry, args_dict, model_spec, target_html_filepath): landmass_geo = geopandas.read_file( file_registry['clipped_projected_landmass']) - extent_feature, xy_ratio = _get_geojson_bbox(exposure_geo) + extent_feature, xy_ratio = vector_utils.get_geojson_bbox(exposure_geo) landmass_chart = _chart_landmass( landmass_geo, clip=True, extent_feature=extent_feature) base_points = _chart_base_points(exposure_geo) @@ -247,7 +206,7 @@ def report(file_registry, args_dict, model_spec, target_html_filepath): width=MAP_WIDTH, height=MAP_WIDTH / 
xy_ratio, title='coastal exposure' - ).configure_legend(**LEGEND_CONFIG) + ).configure_legend(**vector_utils.LEGEND_CONFIG) exposure_map_json = exposure_map.to_json() exposure_map_caption = [model_spec.get_output( 'coastal_exposure').get_field('exposure').about] @@ -284,7 +243,7 @@ def report(file_registry, args_dict, model_spec, target_html_filepath): ).properties( width=MAP_WIDTH, height=200 - ).configure_axis(**AXIS_CONFIG) + ).configure_axis(**vector_utils.AXIS_CONFIG) exposure_histogram_json = exposure_histogram.to_json() base_rank_vars_chart = base_points.mark_circle( @@ -308,7 +267,7 @@ def report(file_registry, args_dict, model_spec, target_html_filepath): rank_vars_figure = altair.vconcat( altair.hconcat(*rank_vars_chart_list[:n_cols]), altair.hconcat(*rank_vars_chart_list[n_cols:]) - ).configure_axis(**AXIS_CONFIG) + ).configure_axis(**vector_utils.AXIS_CONFIG) rank_vars_figure_json = rank_vars_figure.to_json() rank_vars_figure_caption = gettext( """ @@ -345,7 +304,7 @@ def report(file_registry, args_dict, model_spec, target_html_filepath): ) histograms.append(hist) facetted_histograms = altair.hconcat( - *histograms).configure_axis(**AXIS_CONFIG) + *histograms).configure_axis(**vector_utils.AXIS_CONFIG) facetted_histograms_json = facetted_histograms.to_json() facetted_histograms_caption = model_spec.get_output( 'intermediate_exposure').about @@ -380,7 +339,7 @@ def report(file_registry, args_dict, model_spec, target_html_filepath): width=MAP_WIDTH + 30, # extra space for legend height=MAP_WIDTH / xy_ratio, title=gettext('local wind-driven waves vs. 
open ocean waves') - ).configure_legend(**LEGEND_CONFIG) + ).configure_legend(**vector_utils.LEGEND_CONFIG) wave_energy_map_json = wave_energy_map.to_json() wave_energy_map_caption = [model_spec.get_output( diff --git a/src/natcap/invest/reports/raster_utils.py b/src/natcap/invest/reports/raster_utils.py index b96c973eff..3bb8152d92 100644 --- a/src/natcap/invest/reports/raster_utils.py +++ b/src/natcap/invest/reports/raster_utils.py @@ -13,16 +13,18 @@ import pygeoprocessing import matplotlib import matplotlib.colors -from matplotlib.colors import ListedColormap +from matplotlib.colors import Colormap, ListedColormap import matplotlib.patches import matplotlib.pyplot as plt import pandas import yaml from osgeo import gdal +from pydantic import ConfigDict from pydantic.dataclasses import dataclass from natcap.invest import gettext -from natcap.invest.spec import ModelSpec, Input, Output +from natcap.invest.spec import ModelSpec, Input, Output, \ + CSVInput, SingleBandRasterInput LOGGER = logging.getLogger(__name__) @@ -143,7 +145,7 @@ class RasterTransform(str, Enum): log = 'log' -@dataclass +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class RasterPlotConfig: """A definition for how to plot a raster.""" @@ -157,12 +159,17 @@ class RasterPlotConfig: """For highly skewed data, a transformation can help reveal variation.""" title: str | None = None """An optional plot title. If ``None``, the filename is used.""" + colormap: str | Colormap | None = None + """The string name of a registered matplotlib colormap or a colormap object.""" def __post_init__(self): if self.title is None: self.title = os.path.basename(self.raster_path) self.caption = f'{self.title}:{self.spec.about}' + self.colormap = plt.get_cmap(self.colormap if self.colormap + else COLORMAPS[self.datatype]) + def build_raster_plot_configs(id_lookup_table, raster_plot_tuples): """Build RasterPlotConfigs for use in plotting input or output rasters. 
@@ -256,7 +263,7 @@ def _extra_wide_aoi(xy_ratio): return xy_ratio > EX_WIDE_AOI_THRESHOLD -def _choose_n_rows_n_cols(xy_ratio, n_plots): +def _choose_n_rows_n_cols(xy_ratio, n_plots, small_plots): if _extra_wide_aoi(xy_ratio): n_cols = 1 elif _wide_aoi(xy_ratio): @@ -264,14 +271,17 @@ def _choose_n_rows_n_cols(xy_ratio, n_plots): else: n_cols = 3 + if small_plots: + n_cols += 1 + if n_cols > n_plots: n_cols = n_plots n_rows = int(math.ceil(n_plots / n_cols)) return n_rows, n_cols -def _figure_subplots(xy_ratio, n_plots): - n_rows, n_cols = _choose_n_rows_n_cols(xy_ratio, n_plots) +def _figure_subplots(xy_ratio, n_plots, small_plots=False): + n_rows, n_cols = _choose_n_rows_n_cols(xy_ratio, n_plots, small_plots) figure_width = MAX_FIGURE_WIDTH_DEFAULT if (n_cols == 2) and (_wide_aoi(xy_ratio)): @@ -306,16 +316,21 @@ def _get_title_line_width(n_plots: int, xy_ratio: float) -> int: return 31 # 3-column layout -def _get_title_kwargs(title: str, resampled: bool, line_width: int): +def _get_title_kwargs(title: str, resampled: bool, line_width: int, facets=False): label = f"{title}{' (resampled)' if resampled else ''}" label = textwrap.fill(label, width=line_width) + padding = 1.5 + if not facets: + # Faceted plots don't need extra padding for title because their units + # label appears with the legend instead of under the title + padding *= SUBTITLE_FONT_SIZE return { 'fontfamily': 'monospace', 'fontsize': TITLE_FONT_SIZE, 'fontweight': 700, 'label': label, 'loc': 'left', - 'pad': 1.5 * SUBTITLE_FONT_SIZE, + 'pad': padding, 'verticalalignment': 'bottom', } @@ -390,7 +405,7 @@ def plot_raster_list(raster_list: list[RasterPlotConfig]): colorbar_kwargs = {} imshow_kwargs['norm'] = transform imshow_kwargs['interpolation'] = 'none' - cmap = COLORMAPS[dtype] + cmap = config.colormap if dtype == 'divergent': if transform == 'log': transform = matplotlib.colors.SymLogNorm(linthresh=0.03) @@ -501,7 +516,8 @@ def plot_and_base64_encode_rasters(raster_list: list[RasterPlotConfig]) 
-> str: return base64_encode(figure) -def plot_raster_facets(tif_list, datatype, transform=None, title_list=None): +def plot_raster_facets(tif_list, datatype, transform=None, title_list=None, + small_plots=False, colormap=None, supertitle=None): """Plot a list of rasters that will all share a fixed colorscale. When all the rasters have the same shape and represent the same variable, @@ -517,15 +533,22 @@ def plot_raster_facets(tif_list, datatype, transform=None, title_list=None): to the colormap. Either 'linear' or 'log'. title_list (list): Optional list of strings to use as subplot titles. If ``None``, the raster filename is used as the title. + small_plots (bool): Defaults to False. If True, the typical number of + columns calculated for plotting facets will be increased by 1, + making the plots smaller so more can be viewed side-by-side. + colormap (str): Optional string name of a registered matplotlib + colormap or a colormap object to use in place of the default + derived from the raster datatype. + supertitle (str): Optional title to use for the entire group of + raster facets. 
""" raster_info = pygeoprocessing.get_raster_info(tif_list[0]) bbox = raster_info['bounding_box'] n_plots = len(tif_list) xy_ratio = _get_aspect_ratio(bbox) - fig, axes = _figure_subplots(xy_ratio, n_plots) + fig, axes = _figure_subplots(xy_ratio, n_plots, small_plots=small_plots) - cmap_str = COLORMAPS[datatype] if transform is None: transform = 'linear' if title_list is None: @@ -549,7 +572,7 @@ def plot_raster_facets(tif_list, datatype, transform=None, title_list=None): # instead of storing all arrays in memory vmin = numpy.nanmin(ndarray) vmax = numpy.nanmax(ndarray) - cmap = plt.get_cmap(cmap_str) + cmap = plt.get_cmap(colormap if colormap else COLORMAPS[datatype]) if datatype == 'divergent': if transform == 'log': normalizer = matplotlib.colors.SymLogNorm(linthresh=0.03, vmin=vmin, vmax=vmax) @@ -559,7 +582,6 @@ def plot_raster_facets(tif_list, datatype, transform=None, title_list=None): if numpy.isclose(vmin, 0.0): vmin = 1e-6 normalizer = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax) - cmap.set_under(cmap.colors[0]) # values below vmin (0s) get this color else: normalizer = plt.Normalize(vmin=vmin, vmax=vmax) for arr, ax, raster_path, title in zip( @@ -567,14 +589,15 @@ def plot_raster_facets(tif_list, datatype, transform=None, title_list=None): mappable = ax.imshow(arr, cmap=cmap, norm=normalizer) # all rasters are identical size; `resampled` will be the same for all title_line_width = _get_title_line_width(n_plots, xy_ratio) - ax.set_title(**_get_title_kwargs(title, resampled, title_line_width)) - units = _get_raster_units(raster_path) - if units: - (ylim_kwargs, - text_kwargs) = _get_units_text_kwargs(units, len(arr)) - ax.set_ylim(**ylim_kwargs) - ax.text(**text_kwargs) - fig.colorbar(mappable, ax=ax) + ax.set_title(**_get_title_kwargs(title, resampled, title_line_width, facets=True)) + + units = _get_raster_units(tif_list[0]) + legend_label = f"{UNITS_TEXT}: {units}" if units else None + fig.colorbar(mappable, ax=axes.ravel().tolist(), 
label=legend_label) + + if supertitle: + fig.suptitle(supertitle, fontsize=TITLE_FONT_SIZE) + [ax.set_axis_off() for ax in axes.flatten()] return fig @@ -615,17 +638,13 @@ def _build_stats_table_row(resource, band): def _get_raster_metadata(filepath): - if isinstance(filepath, collections.abc.Mapping): - for path in filepath.values(): - return _get_raster_metadata(path) - else: - try: - resource = geometamaker_load(f'{filepath}.yml') - except (FileNotFoundError, ValueError) as err: - LOGGER.debug(err) - return None - if isinstance(resource, geometamaker.models.RasterResource): - return resource + try: + resource = geometamaker_load(f'{filepath}.yml') + except (FileNotFoundError, ValueError) as err: + LOGGER.debug(err) + return None + if isinstance(resource, geometamaker.models.RasterResource): + return resource def _get_raster_units(filepath): @@ -636,30 +655,64 @@ def _get_raster_units(filepath): def raster_workspace_summary(file_registry): """Create a table of stats for all rasters in a file_registry.""" raster_summary = {} - for path in file_registry.values(): - resource = _get_raster_metadata(path) - band = resource.get_band_description(1) if resource else None - if band: - filename = os.path.basename(resource.path) - raster_summary[filename] = _build_stats_table_row( - resource, band) + + def _summarize_output(item): + if isinstance(item, collections.abc.Mapping): + for path in item.values(): + _summarize_output(path) + else: + resource = _get_raster_metadata(item) + band = resource.get_band_description(1) if resource else None + if band: + filename = os.path.basename(resource.path) + raster_summary[filename] = _build_stats_table_row( + resource, band) + + for item in file_registry.values(): + _summarize_output(item) return pandas.DataFrame(raster_summary).T -def raster_inputs_summary(args_dict): +def raster_inputs_summary(args_dict, model_spec): """Create a table of stats for all rasters in an args_dict.""" raster_summary = {} - for v in args_dict.values(): - 
if isinstance(v, str) and os.path.isfile(v): - resource = geometamaker.describe(v, compute_stats=True) - if isinstance(resource, geometamaker.models.RasterResource): - filename = os.path.basename(resource.path) - band = resource.get_band_description(1) - raster_summary[filename] = _build_stats_table_row( - resource, band) - # Remove 'Units' column if all units are blank - if not any(raster_summary[filename][UNITS_COL_NAME]): - del raster_summary[filename][UNITS_COL_NAME] + + paths_to_check = [v for v in args_dict.values() + if isinstance(v, str) and os.path.isfile(v)] + + paths_to_check.extend(_parse_csv_paths_from_spec(args_dict, model_spec)) + + for v in paths_to_check: + resource = geometamaker.describe(v, compute_stats=True) + if isinstance(resource, geometamaker.models.RasterResource): + filename = os.path.basename(resource.path) + band = resource.get_band_description(1) + raster_summary[filename] = _build_stats_table_row( + resource, band) + # Remove 'Units' column if all units are blank + if not any(raster_summary[filename][UNITS_COL_NAME]): + del raster_summary[filename][UNITS_COL_NAME] return pandas.DataFrame(raster_summary).T + + +def _parse_csv_paths_from_spec(args_dict, spec): + table_map_inputs = [] + for input_ in spec.inputs: + if isinstance(input_, CSVInput): + table_map_inputs.extend([ + (input_.id, col.id) for col in input_.columns + if isinstance(col, SingleBandRasterInput)]) + + paths_to_check = [] + for input_id, col_name in table_map_inputs: + if args_dict.get(input_id): + df = CSVInput.get_validated_dataframe( + spec.get_input(input_id), + csv_path=args_dict.get(input_id)) + paths_to_check.extend([ + v for v in df[col_name] + if isinstance(v, str) and os.path.isfile(v)]) + + return paths_to_check diff --git a/src/natcap/invest/reports/sdr_ndr_report_generator.py b/src/natcap/invest/reports/sdr_ndr_report_generator.py index 4057db96c3..5a709171a4 100644 --- a/src/natcap/invest/reports/sdr_ndr_report_generator.py +++ 
b/src/natcap/invest/reports/sdr_ndr_report_generator.py @@ -69,7 +69,7 @@ def report(file_registry, args_dict, model_spec, target_html_filepath, file_registry).to_html(na_rep='') input_raster_stats_table = raster_utils.raster_inputs_summary( - args_dict).to_html(na_rep='') + args_dict, model_spec).to_html(na_rep='') model_description = model_spec.about model_description += gettext( diff --git a/src/natcap/invest/reports/templates/models/seasonal_water_yield.html b/src/natcap/invest/reports/templates/models/seasonal_water_yield.html new file mode 100644 index 0000000000..ceb96ba281 --- /dev/null +++ b/src/natcap/invest/reports/templates/models/seasonal_water_yield.html @@ -0,0 +1,130 @@ +{% extends 'base.html' %} + +{% block content %} + {{ super() }} + {% from 'args-table.html' import args_table %} + {% from 'caption.html' import caption %} + {% from 'content-grid.html' import content_grid %} + {% from 'metadata.html' import list_metadata %} + {% from 'raster-plot-img.html' import raster_plot_img %} + {% from 'wide-table.html' import wide_table %} + +

Results

+ {% if qf_rasters %} + {{ accordion_section( + 'Annual and Monthly Quickflow (QF)', + content_grid([ + (caption(raster_group_caption, pre_caption=True), 100), + (raster_plot_img(qf_rasters['annual_qf_img_src'], 'Annual Quickflow values'), 100), + (raster_plot_img(qf_rasters['monthly_qf_img_src'], 'Monthly Quickflow values'), 100), + (caption(qf_rasters['qf_caption'], definition_list=True), 100) + ]) + )}} + {% endif %} + + {% if qf_b_charts %} + {{ accordion_section( + 'Average Quickflow and Baseflow by Month', + content_grid([ + ('
', 100), + (caption(qf_b_charts['caption'], qf_b_charts['sources']), 100) + ]) + ) }} + {% endif %} + + {{ accordion_section( + outputs_heading, + content_grid([ + (caption(raster_group_caption, pre_caption=True), 100), + (raster_plot_img(outputs_img_src, 'Primary Outputs'), 100), + (caption(outputs_caption, definition_list=True), 100) + ]) + ) }} + + {{ accordion_section( + 'Results Aggregated by AOI Feature', + content_grid([ + (content_grid([ + ('
', 100), + (caption(qb_map_caption, aggregate_map_source_list), 100) + ]), 50), + (content_grid([ + ('
', 100), + (caption(vri_sum_map_caption, aggregate_map_source_list), 100) + ]), 50), + ]) + ) }} + + {{ accordion_section( + stream_outputs_heading, + content_grid([ + (caption(raster_group_caption, pre_caption=True), 100), + (raster_plot_img(stream_img_src, stream_outputs_heading), 100), + (caption(stream_caption, definition_list=True), 100) + ]) + )}} + + {{ accordion_section( + 'Output Raster Stats', + content_grid([ + (stats_table_note, 100), + (wide_table( + output_raster_stats_table | safe, + font_size_px=16 + ), 100) + ]) + )}} + +

Inputs

+ {{ accordion_section( + 'Arguments', + args_table(args_dict), + )}} + + {{ accordion_section( + 'Raster Inputs', + content_grid([ + (caption(raster_group_caption, pre_caption=True), 100), + (raster_plot_img(inputs_img_src, 'Raster Inputs'), 100), + (caption(inputs_caption, definition_list=True), 100) + ]) + ) }} + + {{ accordion_section( + 'Input Raster Stats', + content_grid([ + (stats_table_note, 100), + (wide_table( + input_raster_stats_table | safe, + font_size_px=16 + ), 100) + ]) + )}} + +

Metadata

+ {{ + accordion_section( + 'Output Filenames and Descriptions', + list_metadata(model_spec_outputs), + expanded=False + ) + }} + +{% endblock content %} + +{% from 'vegalite-plot.html' import embed_vega %} +{% block scripts %} + {{ super() }} + + + + {% include 'vega-embed-js.html' %} + {% set chart_spec_id_list = [ + (qb_map_json, 'qb_map'), + (vri_sum_map_json, 'vri_sum_map'), + ] %} + {{ embed_vega(chart_spec_id_list) }} + {% if qf_b_charts %} + {{ embed_vega([(qf_b_charts['json'], 'qf_b_charts')]) }} + {% endif %} +{% endblock scripts %} diff --git a/src/natcap/invest/reports/vector_utils.py b/src/natcap/invest/reports/vector_utils.py new file mode 100644 index 0000000000..f53abf0776 --- /dev/null +++ b/src/natcap/invest/reports/vector_utils.py @@ -0,0 +1,46 @@ +import altair + + +LEGEND_CONFIG = { + 'labelFontSize': 14, + 'titleFontSize': 14, + 'orient': 'left', + 'gradientLength': 120 +} +AXIS_CONFIG = { + 'labelFontSize': 12, + 'titleFontSize': 12, +} + + +def get_geojson_bbox(geodataframe): + """Get the bounding box of a GeoDataFrame as a GeoJSON feature. + + Also calculate its aspect ratio. These are useful for cropping + other layers in altair plots. + + Args: + geodataframe (geopandas.GeoDataFrame): + + Returns: + tuple: A 2-tuple containing: + - extent_feature (dict): A GeoJSON feature representing the bounding + box of the input GeoDataFrame. + - xy_ratio (float): The aspect ratio of the bounding box + (width/height). 
+ + """ + xmin, ymin, xmax, ymax = geodataframe.total_bounds + xy_ratio = (xmax - xmin) / (ymax - ymin) + extent_feature = { + "type": "Feature", + "geometry": {"type": "Polygon", + "coordinates": [[ + [xmax, ymax], + [xmax, ymin], + [xmin, ymin], + [xmin, ymax], + [xmax, ymax]]]}, + "properties": {} + } + return extent_feature, xy_ratio diff --git a/src/natcap/invest/seasonal_water_yield/reporter.py b/src/natcap/invest/seasonal_water_yield/reporter.py new file mode 100644 index 0000000000..44ad7cb9af --- /dev/null +++ b/src/natcap/invest/seasonal_water_yield/reporter.py @@ -0,0 +1,388 @@ +import calendar +import csv +import logging +import os +import time + +import altair +import geopandas +import numpy +import pandas +import pygeoprocessing +from matplotlib.colors import LinearSegmentedColormap +from osgeo import gdal +from osgeo import ogr + +from natcap.invest import __version__ +from natcap.invest import gettext +import natcap.invest.spec +from natcap.invest.reports import jinja_env +from natcap.invest.reports import raster_utils +from natcap.invest.reports import report_constants +from natcap.invest.reports import vector_utils +from natcap.invest.reports.raster_utils import RasterDatatype +from natcap.invest.reports.raster_utils import RasterPlotConfig +from natcap.invest.reports.raster_utils import RasterTransform +from natcap.invest.unit_registry import u + + +LOGGER = logging.getLogger(__name__) + +TEMPLATE = jinja_env.get_template('models/seasonal_water_yield.html') + +MAP_WIDTH = 450 # pixels + +qf_label_month_map = { + f"qf_{month_index}": str(month_index) for month_index in range(1, 13) +} + + +def _label_to_month(row): + return qf_label_month_map[row['MonthLabel']] + + +def _create_aggregate_map(geodataframe, xy_ratio, attribute, + title): + attr_map = altair.Chart(geodataframe).mark_geoshape( + stroke="white", + strokeWidth=0.5 + ).project( + type='identity', + reflectY=True + ).encode( + color=altair.Color( + attribute, + 
scale=altair.Scale(domainMid=0, scheme="brownbluegreen") + ), + tooltip=[altair.Tooltip(attribute, title=attribute)] + ).properties( + width=MAP_WIDTH, + height=MAP_WIDTH / xy_ratio, + title=title + ).configure_legend(**vector_utils.LEGEND_CONFIG) + + return attr_map.to_json() + + +def _create_linked_monthly_plots(aoi_vector_path, aggregate_csv_path, xy_ratio): + map_df = geopandas.read_file(aoi_vector_path) + values_df = pandas.read_csv(aggregate_csv_path) + values_df.month = values_df.month.astype(str) + + feat_select = altair.selection_point(fields=["geom_id"], name="feat_select", value=0) + + attr_map = altair.Chart(map_df).mark_geoshape( + stroke="white", + strokeWidth=0.5 + ).project( + type='identity', + reflectY=True + ).encode( + color=altair.condition( + feat_select, + altair.value("seagreen"), + altair.value("lightgray") + ), + tooltip=[altair.Tooltip("geom_id", title="Feature")] + ).properties( + width=MAP_WIDTH*1.25, + height=MAP_WIDTH*1.25 / xy_ratio, + title="AOI" + ).add_params( + feat_select + ) + + base_chart = altair.Chart(values_df) + + bar_chart = base_chart.mark_bar().transform_fold( + ['baseflow', 'quickflow'] + ).encode( + altair.X("month(month):O").title("Month"), + altair.Y("sum(value):Q").title("Quickflow + Baseflow (m3/month)"), + altair.Order(field='key', sort='ascending'), + color=altair.Color('key:N').scale( + domain=['quickflow', "baseflow", "precipitation"], + range=['#fdae6b', '#9ecae1', "#0500a3"] + ), + tooltip=[altair.Tooltip(val, aggregate="sum", type="quantitative", + format='.5f', title=val) + for val in ["quickflow", "baseflow", "precipitation"]] + ) + + precip_chart = base_chart.mark_line().encode( + altair.X("month(month):O").title("Month"), + altair.Y( + "sum(precipitation)", + axis=altair.Axis(orient="right") + ).title("Precipitation (m3/month)"), + color=altair.value('#0500a3') + ) + + combined_chart = altair.layer(bar_chart, precip_chart).resolve_scale( + y='independent' + ).transform_filter( + feat_select + 
).properties( + title=altair.Title(altair.expr( + f'"Mean Quickflow + Baseflow for Feature " + {feat_select.name}.geom_id') + ) + ) + + legend_config = vector_utils.LEGEND_CONFIG.copy() + legend_config['orient'] = 'right' + + chart = altair.hconcat( + attr_map, combined_chart, + spacing=30 + ).configure_legend( + **legend_config + ).configure_axis( + **vector_utils.AXIS_CONFIG + ).configure_view( + discreteWidth=300 + ) + return chart.to_json() + + +def report(file_registry, args_dict, model_spec, target_html_filepath): + """Generate an html summary of Seasonal Water Yield results. + + Args: + file_registry (dict): The ``natcap.invest.FileRegistry.registry`` + that was returned by ``natcap.invest.seasonal_water_yield.execute``. + args_dict (dict): The arguments that were passed to + ``natcap.invest.seasonal_water_yield.execute``. + model_spec (natcap.invest.spec.ModelSpec): + ``natcap.invest.seasonal_water_yield.MODEL_SPEC`` + target_html_filepath (str): path to an html file generated by this + function. 
+ + Returns: + None + """ + + # qb and vri_sum plots from the output aggregate vector + aggregated_results = geopandas.read_file(file_registry['aggregate_vector']) + _, xy_ratio = vector_utils.get_geojson_bbox(aggregated_results) + + qb_map_json = _create_aggregate_map( + aggregated_results, xy_ratio, 'qb', + gettext("Mean local recharge value within the watershed " + f"({model_spec.get_output('aggregate_vector').get_field('qb').units})")) + qb_map_caption = [ + model_spec.get_output('aggregate_vector').get_field('qb').about, + gettext('Values are in millimeters, but should be interpreted as ' + 'relative values, not absolute values.')] + + vri_sum_map_json = _create_aggregate_map( + aggregated_results, xy_ratio, 'vri_sum', + gettext("Total recharge contribution of the watershed " + f"({model_spec.get_output('aggregate_vector').get_field('vri_sum').units})")) + vri_sum_map_caption = [ + model_spec.get_output('aggregate_vector').get_field('vri_sum').about, + gettext('The sum of ``Vri_[suffix].tif`` pixel values within the watershed.')] + + vector_map_source_list = [model_spec.get_output('aggregate_vector').path] + + if args_dict['user_defined_local_recharge']: + # Quickflow isn't calculated if `user_defined_local_recharge` + # so we cannot construct monthly average qf + b charts + qf_b_charts = None + else: + # Monthly quickflow + baseflow plots and map + qf_b_charts_json = _create_linked_monthly_plots( + file_registry['aggregate_vector'], + file_registry['monthly_qf_table'], + xy_ratio) + qf_b_charts_caption = gettext( + """ + This chart displays the monthly combined average baseflow + quickflow for + each feature within the AOI, as well as the monthly average precipitation. + Select a feature on the AOI map to see the values for that feature. + Shift+Click to select multiple features; the chart will display the sum of + their values. 
+ """ + ) + qf_b_charts_source_list = [ + model_spec.get_output('monthly_qf_table').path, + model_spec.get_output('aggregate_vector').path] + qf_b_charts = { + 'json': qf_b_charts_json, + 'caption': qf_b_charts_caption, + 'sources': qf_b_charts_source_list + } + + # Raster config lists + stream_raster_config_list = [ + RasterPlotConfig( + raster_path=file_registry['pit_filled_dem'], + datatype=RasterDatatype.continuous, + spec=model_spec.get_output('pit_filled_dem')), + RasterPlotConfig( + raster_path=file_registry['stream'], + datatype=RasterDatatype.binary_high_contrast, + spec=model_spec.get_output('stream'))] + + output_raster_config_list = [ + RasterPlotConfig( + raster_path=file_registry['b'], + datatype=RasterDatatype.continuous, + spec=model_spec.get_output('b')) + ] + + input_raster_config_list = [ + RasterPlotConfig( + raster_path=args_dict['dem_raster_path'], + datatype=RasterDatatype.continuous, + spec=model_spec.get_input('dem_raster_path')), + RasterPlotConfig( + raster_path=args_dict['lulc_raster_path'], + datatype=RasterDatatype.nominal, + spec=model_spec.get_input('lulc_raster_path')) + ] + + if args_dict['user_defined_local_recharge']: + input_raster_config_list.append( + RasterPlotConfig( + raster_path=args_dict['l_path'], + datatype=RasterDatatype.divergent, + spec=model_spec.get_input('l_path'), + colormap='BrBG')) + qf_rasters = None + raster_outputs_heading = 'Annual Baseflow' + + else: + output_raster_config_list.extend([ + RasterPlotConfig( + raster_path=file_registry['annual_precip'], + datatype=RasterDatatype.continuous, + spec=model_spec.get_output('annual_precip')), + RasterPlotConfig( + raster_path=file_registry['aet'], + datatype=RasterDatatype.continuous, + spec=model_spec.get_output('aet')), + RasterPlotConfig( + raster_path=file_registry['cn'], + datatype=RasterDatatype.continuous, + spec=model_spec.get_output('cn')), + RasterPlotConfig( + raster_path=file_registry['l'], + datatype=RasterDatatype.divergent, + 
spec=model_spec.get_output('l'), + colormap='BrBG')]) + raster_outputs_heading = 'Additional Raster Outputs' + + input_raster_config_list.append( + RasterPlotConfig( + raster_path=args_dict['soil_group_path'], + datatype=RasterDatatype.nominal, + spec=model_spec.get_input('soil_group_path'))) + + # Quickflow outputs are only created if not `user_defined_local_recharge` + annual_qf_raster_config = RasterPlotConfig( + raster_path=file_registry['qf'], + datatype=RasterDatatype.continuous, + spec=model_spec.get_output('qf'), + title=gettext("Annual Quickflow"), + transform=RasterTransform.log, + colormap='GnBu') + + monthly_qf_raster_config_list = [ + RasterPlotConfig( + raster_path=file_registry['qf_[MONTH]'][str(month)], + datatype=RasterDatatype.continuous, + spec=model_spec.get_output('qf_[MONTH]'), + title=gettext(f"{calendar.month_name[month]}"), + transform=RasterTransform.log, + colormap='GnBu' + ) for month in range(1, 13)] + + annual_qf_img_src = raster_utils.plot_and_base64_encode_rasters( + [annual_qf_raster_config]) + monthly_qf_plots = raster_utils.plot_raster_facets( + [raster_config.raster_path for raster_config + in monthly_qf_raster_config_list], + 'continuous', + transform=RasterTransform.log, + title_list=[raster_config.title for raster_config + in monthly_qf_raster_config_list], + small_plots=True, + colormap='GnBu', + supertitle=gettext("Monthly Quickflow")) + annual_qf_displayname = os.path.basename( + annual_qf_raster_config.raster_path) + monthly_qf_img_src = raster_utils.base64_encode(monthly_qf_plots) + monthly_qf_displayname = os.path.basename( + monthly_qf_raster_config_list[0].raster_path).replace('1', '[MONTH]') + qf_raster_caption = [ + gettext(f'Map of Annual Quickflow: {annual_qf_displayname}'), + gettext(f'Maps of Monthly Quickflow: {monthly_qf_displayname}' + ' (1 = January… 12 = December)') + ] + + qf_rasters = { + 'annual_qf_img_src': annual_qf_img_src, + 'monthly_qf_img_src': monthly_qf_img_src, + 'qf_caption': qf_raster_caption} 
+ + # Create raster image sources and captions: + stream_img_src = raster_utils.plot_and_base64_encode_rasters( + stream_raster_config_list) + stream_raster_caption = raster_utils.caption_raster_list( + stream_raster_config_list) + stream_outputs_heading = gettext( + 'Stream Network Maps (Flow Algorithm: ' + f'{args_dict["flow_dir_algorithm"]}, ' + 'Threshold Flow Accumulation value: ' + f'{args_dict["threshold_flow_accumulation"]})') + + outputs_img_src = raster_utils.plot_and_base64_encode_rasters( + output_raster_config_list) + output_raster_caption = raster_utils.caption_raster_list( + output_raster_config_list) + + output_raster_stats_table = raster_utils.raster_workspace_summary( + file_registry).to_html(na_rep='') + + input_raster_stats_table = raster_utils.raster_inputs_summary( + args_dict, model_spec).to_html(na_rep='') + + inputs_img_src = raster_utils.plot_and_base64_encode_rasters( + input_raster_config_list) + inputs_raster_caption = raster_utils.caption_raster_list( + input_raster_config_list) + + with open(target_html_filepath, 'w', encoding='utf-8') as target_file: + target_file.write(TEMPLATE.render( + report_script=model_spec.reporter, + invest_version=__version__, + report_filepath=target_html_filepath, + model_id=model_spec.model_id, + model_name=model_spec.model_title, + model_description=model_spec.about, + userguide_page=model_spec.userguide, + timestamp=time.strftime('%Y-%m-%d %H:%M'), + args_dict=args_dict, + raster_group_caption=report_constants.RASTER_GROUP_CAPTION, + stats_table_note=report_constants.STATS_TABLE_NOTE, + stream_img_src=stream_img_src, + stream_caption=stream_raster_caption, + stream_outputs_heading=stream_outputs_heading, + outputs_heading=raster_outputs_heading, + outputs_img_src=outputs_img_src, + outputs_caption=output_raster_caption, + qf_rasters=qf_rasters, + output_raster_stats_table=output_raster_stats_table, + input_raster_stats_table=input_raster_stats_table, + inputs_img_src=inputs_img_src, + 
inputs_caption=inputs_raster_caption, + qf_b_charts=qf_b_charts, + qb_map_json=qb_map_json, + qb_map_caption=qb_map_caption, + vri_sum_map_json=vri_sum_map_json, + vri_sum_map_caption=vri_sum_map_caption, + aggregate_map_source_list=vector_map_source_list, + model_spec_outputs=model_spec.outputs, + )) + + LOGGER.info(f'Created {target_html_filepath}') diff --git a/src/natcap/invest/seasonal_water_yield/seasonal_water_yield.py b/src/natcap/invest/seasonal_water_yield/seasonal_water_yield.py index 8c6d931840..f52d2a2af8 100644 --- a/src/natcap/invest/seasonal_water_yield/seasonal_water_yield.py +++ b/src/natcap/invest/seasonal_water_yield/seasonal_water_yield.py @@ -1,4 +1,5 @@ """InVEST Seasonal Water Yield Model.""" +import csv import fractions import logging import os @@ -21,14 +22,31 @@ TARGET_NODATA = -1 N_MONTHS = 12 -MONTH_ID_TO_LABEL = [ - 'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', - 'nov', 'dec'] +MONTH_RANGE = range(1, N_MONTHS+1) +MONTH_ID_TO_LABEL = { + 1: 'jan', 2: 'feb', 3: 'mar', 4: 'apr', 5: 'may', 6: 'jun', + 7: 'jul', 8: 'aug', 9: 'sep', 10: 'oct', 11: 'nov', 12: 'dec'} + +_model_description = gettext( + """ + The Seasonal Water Yield (SWY) model estimates the amount of water produced + by a watershed, arriving in streams over the course of a year. The primary + outputs of the model are quickflow, local recharge, and baseflow. Quickflow + represents the amount of precipitation that runs off of the land directly, + during and soon after a rain event, and local recharge represents the amount + of rainfall that infiltrates into soil, minus what is evaporated or used by + vegetation. Baseflow is the amount of precipitation that enters streams more + gradually through sub-surface flow, including during the dry season. The model + is based on inputs of topography (DEM), soils, land cover and management, + rainfall, and vegetation water demand. 
+ """) MODEL_SPEC = spec.ModelSpec( model_id="seasonal_water_yield", model_title=gettext("Seasonal Water Yield"), userguide="seasonal_water_yield.html", + reporter="natcap.invest.seasonal_water_yield.reporter", + about=_model_description, validate_spatial_overlap=True, different_projections_ok=True, aliases=("swy",), @@ -332,7 +350,7 @@ " (which is not evapotranspired before it reaches the stream)." ), data_type=float, - units=u.millimeter + units=u.millimeter / u.year ), spec.SingleBandRasterOutput( id="b_sum", @@ -343,21 +361,34 @@ " stream." ), data_type=float, - units=u.millimeter + units=u.millimeter / u.year ), spec.SingleBandRasterOutput( id="cn", path="CN.tif", about=gettext("Map of curve number values."), data_type=float, - units=u.none + units=u.none, + created_if="not user_defined_local_recharge" + ), + spec.SingleBandRasterOutput( + id="l", + path="L.tif", + about=gettext( + "Map of local recharge. If a user-defined local recharge input" + " is provided, this is a copy of that layer, aligned and clipped" + " to match the other spatial inputs. Otherwise, this is the" + " local recharge as calculated by the model." + ), + data_type=float, + units=u.millimeter / u.year ), spec.SingleBandRasterOutput( id="l_avail", path="L_avail.tif", about=gettext("Map of available local recharge"), data_type=float, - units=u.millimeter + units=u.millimeter / u.year ), spec.SingleBandRasterOutput( id="l_sum_avail", @@ -367,7 +398,8 @@ " available for evapotranspiration by this pixel." ), data_type=float, - units=u.millimeter + units=u.millimeter / u.year, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="l_sum", @@ -378,14 +410,15 @@ " evapotranspiration to downslope pixels." 
), data_type=float, - units=u.millimeter + units=u.millimeter / u.year ), spec.SingleBandRasterOutput( id="qf", path="QF.tif", about=gettext("Map of quickflow"), data_type=float, - units=u.millimeter / u.year + units=u.millimeter / u.year, + created_if="not user_defined_local_recharge" ), spec.STREAM, spec.SingleBandRasterOutput( @@ -393,7 +426,8 @@ path="P.tif", about=gettext("The total precipitation across all months on this pixel."), data_type=float, - units=u.millimeter / u.year + units=u.millimeter / u.year, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="vri", @@ -403,7 +437,7 @@ " the total recharge." ), data_type=float, - units=u.millimeter + units=u.millimeter / u.year ), spec.VectorOutput( id="aggregate_vector", @@ -414,24 +448,89 @@ spec.NumberOutput( id="qb", about=gettext("Mean local recharge value within the watershed"), - units=u.millimeter + units=u.millimeter / u.year ), spec.NumberOutput( id="vri_sum", about=gettext( - "Total recharge contribution, (positive or negative) within the" + "Total recharge contribution (positive or negative) within the" " watershed." ), - units=u.millimeter + units=u.millimeter / u.year + ), + spec.IntegerOutput( + id="geom_id", + about=gettext( + "A unique ID for the watershed." + ), + units=u.none ) + ] ), + spec.CSVOutput( + id="monthly_qf_table", + path="monthly_quickflow_baseflow.csv", + about=gettext( + "Table of average monthly baseflow, quickflow, and precipitation" + " values for each watershed (or feature) within the AOI." + ), + columns=[ + spec.IntegerOutput( + id="geom_id", + about=gettext( + "A unique ID for the watershed. This will correspond to" + " the 'geom_id' column in the Aggregate Results shapefile." + ), + units=u.none + ), + spec.NumberOutput( + id="month", + about=gettext( + "Values are the numbers 1-12 corresponding to each month," + " January (1) through December (12)." 
+ ), + units=u.none + ), + spec.NumberOutput( + id="quickflow", + about=gettext( + "The average quickflow value for the month in the watershed," + " expressed in cubic meters." + ), + units=u.meter ** 3 / u.month + ), + spec.NumberOutput( + id="baseflow", + about=gettext( + "The average baseflow value for the month in the watershed," + " expressed in cubic meters. Since baseflow is calculated on" + " an annual scale, the values for each watershed have been" + " distributed evenly across the year (annual average divided" + " by 12)." + ), + units=u.meter ** 3 / u.month + ), + spec.NumberOutput( + id="precipitation", + about=gettext( + "The average precipitation value for the month in the watershed," + " expressed in cubic meters. Values are based on the aligned" + " input monthly precipitation rasters." + ), + units=u.meter ** 3 / u.month + ) + ], + index_col="geom_id", + created_if="not user_defined_local_recharge" + ), spec.SingleBandRasterOutput( id="aet", path="intermediate_outputs/aet.tif", about=gettext("Map of actual evapotranspiration"), data_type=float, - units=u.millimeter + units=u.millimeter / u.year, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="flow_dir", @@ -441,7 +540,7 @@ " the option selected." 
), data_type=int, - units=None + units=u.none ), spec.SingleBandRasterOutput( id="qf_[MONTH]", @@ -450,14 +549,16 @@ "Maps of monthly quickflow (1 = January… 12 = December)" ), data_type=float, - units=u.millimeter + units=u.millimeter / u.month, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="si", path="intermediate_outputs/Si.tif", about=gettext("Map of the S_i factor derived from CN"), data_type=float, - units=u.inch + units=u.inch, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="lulc_aligned", @@ -467,7 +568,7 @@ " spatial inputs" ), data_type=int, - units=None + units=u.none ), spec.SingleBandRasterOutput( id="dem_aligned", @@ -494,7 +595,8 @@ " other spatial inputs" ), data_type=int, - units=None + units=u.none, + created_if="not user_defined_local_recharge" ), spec.FLOW_ACCUMULATION.model_copy(update=dict( id="flow_accum", @@ -507,14 +609,16 @@ " other spatial inputs" ), data_type=float, - units=u.millimeter / u.year + units=u.millimeter / u.year, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="n_events[MONTH]", path="intermediate_outputs/n_events[MONTH].tif", about=gettext("Map of monthly rain events"), data_type=int, - units=None + units=u.none, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="et0_a[MONTH]", @@ -524,26 +628,16 @@ " spatial inputs" ), data_type=float, - units=u.millimeter + units=u.millimeter / u.month, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="kc_[MONTH]", path="intermediate_outputs/kc_[MONTH].tif", about=gettext("Map of monthly KC values"), data_type=float, - units=u.none - ), - spec.SingleBandRasterOutput( - id="l", - path="L.tif", - about=gettext( - "Map of local recharge. If a user-defined local recharge input" - " is provided, this is a copy of that layer, aligned and clipped" - " to match the other spatial inputs. 
Otherwise, this is the" - " local recharge as calculated by the model." - ), - data_type=float, - units=u.millimeter + units=u.none, + created_if="not user_defined_local_recharge" ), spec.SingleBandRasterOutput( id="cz_aligned", @@ -554,7 +648,7 @@ ), created_if="user_defined_climate_zones", data_type=int, - units=None + units=u.none ), spec.TASKGRAPH_CACHE ] @@ -661,7 +755,7 @@ def execute(args): # make all 12 entries equal to args['alpha_m'] alpha_m = float(fractions.Fraction(args['alpha_m'])) alpha_month_map = dict( - (month_index+1, alpha_m) for month_index in range(N_MONTHS)) + (month_index, alpha_m) for month_index in MONTH_RANGE) beta_i = float(fractions.Fraction(args['beta_i'])) gamma = float(fractions.Fraction(args['gamma'])) @@ -670,23 +764,22 @@ def execute(args): args['dem_raster_path'])['pixel_size'] LOGGER.info('Checking that the AOI is not the output aggregate vector') - LOGGER.debug("aoi_path: %s", args['aoi_path']) - LOGGER.debug("aggregate_vector_path: %s", - os.path.normpath(file_registry['aggregate_vector'])) + LOGGER.debug(f"aoi_path: {args['aoi_path']}") + LOGGER.debug("aggregate_vector_path: " + f"{os.path.normpath(file_registry['aggregate_vector'])}") if (os.path.normpath(args['aoi_path']) == os.path.normpath(file_registry['aggregate_vector'])): raise ValueError( "The input AOI is the same as the output aggregate vector, " "please choose a different workspace or move the AOI file " - "out of the current workspace %s" % - file_registry['aggregate_vector']) + f"out of the current workspace {file_registry['aggregate_vector']}") LOGGER.info('Aligning and clipping dataset list') input_align_list = [args['lulc_raster_path'], args['dem_raster_path']] output_align_list = [ file_registry['lulc_aligned'], file_registry['dem_aligned']] if not args['user_defined_local_recharge']: - month_indexes = [m+1 for m in range(N_MONTHS)] + month_indexes = [m for m in MONTH_RANGE] precip_df = MODEL_SPEC.get_input( 'precip_raster_table').get_validated_dataframe( @@ 
-711,9 +804,9 @@ def execute(args): precip_path_list + [args['soil_group_path']] + et0_path_list + input_align_list) output_align_list = ( - [file_registry['prcp_a[MONTH]', month] for month in range(12)] + + [file_registry['prcp_a[MONTH]', month] for month in MONTH_RANGE] + [file_registry['soil_group_aligned']] + - [file_registry['et0_a[MONTH]', month] for month in range(12)] + + [file_registry['et0_a[MONTH]', month] for month in MONTH_RANGE] + output_align_list) align_index = len(input_align_list) - 1 # this aligns with the DEM @@ -825,7 +918,7 @@ def execute(args): reclass_error_details = { 'raster_name': 'Climate Zone', 'column_name': 'cz_id', 'table_name': 'Climate Zone'} - for month_id in range(N_MONTHS): + for month_id in MONTH_RANGE: if args['user_defined_climate_zones']: cz_rain_events_df = MODEL_SPEC.get_input( 'climate_zone_table_path').get_validated_dataframe( @@ -843,7 +936,7 @@ def execute(args): target_path_list=[ file_registry['n_events[MONTH]', month_id]], dependent_task_list=[align_task], - task_name='n_events for month %d' % month_id) + task_name=f'n_events for month {month_id}') reclassify_n_events_task_list.append(n_events_task) else: n_events_task = task_graph.add_task( @@ -853,12 +946,12 @@ def execute(args): file_registry['n_events[MONTH]', month_id], gdal.GDT_Float32, [TARGET_NODATA]), kwargs={'fill_value_list': ( - rain_events_df['events'][month_id+1],)}, + rain_events_df['events'][month_id],)}, target_path_list=[ file_registry['n_events[MONTH]', month_id]], dependent_task_list=[align_task], task_name=( - 'n_events as a constant raster month %d' % month_id)) + f'n_events as a constant raster month {month_id}')) reclassify_n_events_task_list.append(n_events_task) curve_number_task = task_graph.add_task( @@ -882,8 +975,8 @@ def execute(args): task_name='calculate Si raster') quick_flow_task_list = [] - for month_index in range(N_MONTHS): - LOGGER.info('calculate quick flow for month %d', month_index+1) + for month_index in MONTH_RANGE: + 
LOGGER.info(f'calculate quick flow for month {month_index}') monthly_quick_flow_task = task_graph.add_task( func=_calculate_monthly_quick_flow, args=( @@ -891,21 +984,20 @@ def execute(args): file_registry['n_events[MONTH]', month_index], file_registry['stream'], file_registry['si'], - file_registry['qf_[MONTH]', month_index + 1]), + file_registry['qf_[MONTH]', month_index]), target_path_list=[ - file_registry['qf_[MONTH]', month_index + 1]], + file_registry['qf_[MONTH]', month_index]], dependent_task_list=[ - align_task, reclassify_n_events_task_list[month_index], + align_task, reclassify_n_events_task_list[month_index-1], si_task, stream_threshold_task], - task_name='calculate quick flow for month %d' % ( - month_index+1)) + task_name=f'calculate quick flow for month {month_index}') quick_flow_task_list.append(monthly_quick_flow_task) qf_task = task_graph.add_task( func=pygeoprocessing.raster_map, kwargs=dict( op=qfi_sum_op, - rasters=[file_registry['qf_[MONTH]', month] for month in range(1, 13)], + rasters=[file_registry['qf_[MONTH]', month] for month in MONTH_RANGE], target_path=file_registry['qf']), target_path_list=[file_registry['qf']], dependent_task_list=quick_flow_task_list, @@ -916,8 +1008,8 @@ def execute(args): reclass_error_details = { 'raster_name': 'LULC', 'column_name': 'lucode', 'table_name': 'Biophysical'} - for month_index in range(N_MONTHS): - kc_lookup = biophysical_df['kc_%d' % (month_index+1)].to_dict() + for month_index in MONTH_RANGE: + kc_lookup = biophysical_df[f'kc_{month_index}'].to_dict() kc_task = task_graph.add_task( func=utils.reclassify_raster, args=( @@ -926,7 +1018,7 @@ def execute(args): gdal.GDT_Float32, TARGET_NODATA, reclass_error_details), target_path_list=[file_registry['kc_[MONTH]', month_index]], dependent_task_list=[align_task], - task_name='classify kc month %d' % month_index) + task_name=f'classify kc month {month_index}') kc_task_list.append(kc_task) # call through to a cython function that does the necessary routing 
@@ -934,11 +1026,11 @@ def execute(args): calculate_local_recharge_task = task_graph.add_task( func=seasonal_water_yield_core.calculate_local_recharge, args=( - [file_registry['prcp_a[MONTH]', month] for month in range(12)], - [file_registry['et0_a[MONTH]', month] for month in range(12)], - [file_registry['qf_[MONTH]', month] for month in range(1, 13)], + [file_registry['prcp_a[MONTH]', month] for month in MONTH_RANGE], + [file_registry['et0_a[MONTH]', month] for month in MONTH_RANGE], + [file_registry['qf_[MONTH]', month] for month in MONTH_RANGE], file_registry['flow_dir'], - [file_registry['kc_[MONTH]', month] for month in range(12)], + [file_registry['kc_[MONTH]', month] for month in MONTH_RANGE], alpha_month_map, beta_i, gamma, file_registry['stream'], file_registry['l'], @@ -1027,6 +1119,21 @@ def execute(args): dependent_task_list=b_sum_dependent_task_list + [l_sum_task], task_name='calculate B_sum') + if not args['user_defined_local_recharge']: + monthly_csv_task = task_graph.add_task( + func=_generate_monthly_qf_b_p_csv, + args=( + file_registry['aggregate_vector'], + file_registry['b'], + [file_registry['qf_[MONTH]', month] for month in MONTH_RANGE], + [file_registry['prcp_a[MONTH]', month] for month in MONTH_RANGE], + file_registry['monthly_qf_table'] + ), + target_path_list=[file_registry['monthly_qf_table']], + dependent_task_list=[ + aggregate_recharge_task, align_task, b_sum_task] + quick_flow_task_list, + task_name='create monthly qf csv') + task_graph.close() task_graph.join() @@ -1398,8 +1505,7 @@ def _aggregate_recharge( """ if os.path.exists(aggregate_vector_path): LOGGER.warning( - '%s exists, deleting and writing new output', - aggregate_vector_path) + f'{aggregate_vector_path} exists, deleting and writing new output') os.remove(aggregate_vector_path) original_aoi_vector = gdal.OpenEx(aoi_path, gdal.OF_VECTOR) @@ -1414,7 +1520,6 @@ def _aggregate_recharge( for raster_path, aggregate_field_id, op_type in [ (l_path, 'qb', 'mean'), (vri_path, 
'vri_sum', 'sum')]: - # aggregate carbon stocks by the new ID field aggregate_stats = pygeoprocessing.zonal_statistics( (raster_path, 1), aggregate_vector_path) @@ -1440,12 +1545,96 @@ def _aggregate_recharge( poly_feat.SetField(aggregate_field_id, float(value)) aggregate_layer.SetFeature(poly_feat) + fid_field = ogr.FieldDefn('geom_id', ogr.OFTInteger) + aggregate_layer.CreateField(fid_field) + for feature in aggregate_layer: + feature_id = feature.GetFID() + feature.SetField('geom_id', feature_id) + aggregate_layer.SetFeature(feature) + aggregate_layer.SyncToDisk() aggregate_layer = None gdal.Dataset.__swig_destroy__(aggregate_vector) aggregate_vector = None +def _generate_monthly_qf_b_p_csv( + aoi_path, annual_baseflow_path, monthly_quickflow_path_list, + monthly_precip_path_list, target_csv_path): + """Generate a CSV of average monthly Qf, B, and P values for the watersheds/AOI. + + Args: + aoi_path (string): path to shapefile that will be used to + aggregate rasters + annual_baseflow_path (string): path to the annual baseflow raster + monthly_quickflow_path_list (list): list of paths to monthly quickflow + rasters + monthly_precip_path_list (list): list of paths to aligned monthly + precipitation rasters + target_csv_path (string): path to the output CSV. If this file exists + on disk prior to the call, it is overwritten with the result of + this call. 
+ + Returns: + None + """ + if os.path.exists(target_csv_path): + LOGGER.warning(f'{target_csv_path} exists, deleting and writing new output') + os.remove(target_csv_path) + + # Use the baseflow raster to get the pixel_size; all rasters should be aligned + raster_info = pygeoprocessing.get_raster_info(annual_baseflow_path) + pixel_area_m2 = numpy.prod([abs(x) for x in raster_info['pixel_size']]) + + # The baseflow raster is annual, so there will only be 1 + raster_path_tuples = [(annual_baseflow_path, 0, "baseflow")] + raster_path_tuples.extend([ + (qf_path, month, "quickflow") for qf_path, month + in zip(monthly_quickflow_path_list, MONTH_RANGE)]) + raster_path_tuples.extend([ + (precip_path, month, "precipitation") for precip_path, month + in zip(monthly_precip_path_list, MONTH_RANGE)]) + + raster_stats = pygeoprocessing.zonal_statistics( + [(path_tuple[0], 1) for path_tuple in raster_path_tuples], + aoi_path) + + stats_by_field = { + "baseflow": {}, + "quickflow": {}, + "precipitation": {} + } + for (_, month, field), stats in zip(raster_path_tuples, raster_stats): + stats_by_field[field][month] = stats + + # Baseflow is computed annually; distribute evenly over the year + b_avg_per_feat_per_month = {k: v['sum'] * 0.001 * pixel_area_m2 / 12 + for k, v in stats_by_field['baseflow'][0].items()} + + values_dict = {fid: {month: {'baseflow': b_val} + for month in MONTH_RANGE} + for fid, b_val in b_avg_per_feat_per_month.items()} + + for value_name in ['quickflow', 'precipitation']: + for month in MONTH_RANGE: + avg_per_feat_per_month = { + k: v['sum'] * 0.001 * pixel_area_m2 + for k, v in stats_by_field[value_name][month].items()} + for fid, value in avg_per_feat_per_month.items(): + values_dict[fid][month][value_name] = value + + with open(target_csv_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile, delimiter=',') + writer.writerow(['geom_id', 'month', 'quickflow', 'baseflow', 'precipitation']) + for fid, month_dicts in values_dict.items(): + for 
month, val_dicts in month_dicts.items(): + writer.writerow([fid, + month, + val_dicts['quickflow'], + val_dicts['baseflow'], + val_dicts['precipitation']]) + + @validation.invest_validator def validate(args, limit_to=None): """Validate args to ensure they conform to `execute`'s contract. diff --git a/tests/reports/test_raster_utils.py b/tests/reports/test_raster_utils.py index f4f04a2312..20462e63e3 100644 --- a/tests/reports/test_raster_utils.py +++ b/tests/reports/test_raster_utils.py @@ -339,6 +339,24 @@ def tearDown(self): """Override tearDown function to remove temporary directory.""" shutil.rmtree(self.workspace_dir) + @staticmethod + @patch('natcap.invest.reports.raster_utils._get_raster_units') + def create_small_plots_grid(workspace_dir, shape, mock_get_raster_units, + supertitle=None): + raster_paths = [os.path.join(workspace_dir, f'{s}.tif') + for s in ['a', 'b', 'c', 'd']] + arrays = [numpy.linspace( + i, i+1, num=numpy.multiply(*shape)).reshape(*shape) for i in range(4)] + for raster_path, array in zip(raster_paths, arrays): + pygeoprocessing.numpy_array_to_raster( + array, target_nodata=None, pixel_size=(1, 1), origin=(0, 0), + projection_wkt=PROJ_WKT, target_path=raster_path) + + mock_get_raster_units.return_value = 'flux capacitrons' + return raster_utils.plot_raster_facets( + raster_paths, RasterDatatype.continuous, small_plots=True, + supertitle=supertitle) + def test_plot_raster_facets(self): """Test rasters share a common colorscale.""" figname = 'plot_raster_facets.png' @@ -363,6 +381,40 @@ def test_plot_raster_facets(self): save_figure(fig, actual_png) compare_snapshots(reference, actual_png) + def test_plot_raster_facets_small_plots(self): + """Test small plots: standard AOI width should have 4 columns.""" + figname = 'plot_raster_facets_small_plots.png' + reference = os.path.join(REFS_DIR, figname) + shape = (4, 4) + fig = self.create_small_plots_grid(self.workspace_dir, shape) + + actual_png = os.path.join(self.workspace_dir, figname) + 
save_figure(fig, actual_png) + compare_snapshots(reference, actual_png) + + def test_plot_raster_facets_small_plots_wide_aoi(self): + """Test small plots: wide AOI width should have 3 columns.""" + figname = 'plot_raster_facets_small_plots_wide_aoi.png' + reference = os.path.join(REFS_DIR, figname) + shape = (6, 12) + fig = self.create_small_plots_grid(self.workspace_dir, shape) + + actual_png = os.path.join(self.workspace_dir, figname) + save_figure(fig, actual_png) + compare_snapshots(reference, actual_png) + + def test_plot_raster_facets_small_plots_supertitle(self): + """Test small plots with optional supertitle.""" + figname = 'plot_raster_facets_small_plots_supertitle.png' + reference = os.path.join(REFS_DIR, figname) + shape = (2, 10) + fig = self.create_small_plots_grid(self.workspace_dir, shape, + supertitle="Custom Title") + + actual_png = os.path.join(self.workspace_dir, figname) + save_figure(fig, actual_png) + compare_snapshots(reference, actual_png) + class RasterPlotTitleTests(unittest.TestCase): """Snapshot tests for plotting rasters with various titles.""" @@ -542,5 +594,5 @@ def test_raster_workspace_summary(self): file_registry, args_dict) dataframe = raster_utils.raster_workspace_summary(file_registry) - # There are 2 rasters in the sample output spec - self.assertEqual(dataframe.shape, (2, 7)) + # There are 3 rasters in the sample output spec + self.assertEqual(dataframe.shape, (3, 7)) diff --git a/tests/reports/test_swy_template.py b/tests/reports/test_swy_template.py new file mode 100644 index 0000000000..24f4d24b24 --- /dev/null +++ b/tests/reports/test_swy_template.py @@ -0,0 +1,105 @@ +import time +import unittest + +import pandas +from bs4 import BeautifulSoup + +from natcap.invest.seasonal_water_yield import MODEL_SPEC +from natcap.invest.reports import jinja_env + +TEMPLATE = jinja_env.get_template('models/seasonal_water_yield.html') + +BSOUP_HTML_PARSER = 'html.parser' + + +def _get_render_args(model_spec): + report_filepath = 
'swy_report_test.html' + invest_version = '987.65.0' + timestamp = time.strftime('%Y-%m-%d %H:%M') + args_dict = {'suffix': 'test'} + img_src = 'bAse64eNcoDEdIMagE' + output_stats_table = '
' + input_stats_table = '
' + stats_table_note = 'This is a test!' + raster_group_caption = 'This is another test!' + stream_caption = ['stream.tif:Stream map.'] + heading = 'Test heading' + outputs_caption = ['results.tif:Results map.'] + inputs_caption = ['input.tif:Input map.'] + vegalite_json = '{}' + caption = 'figure caption' + agg_map_source_list = ['/source/file.shp'] + + return { + 'report_script': model_spec.reporter, + 'invest_version': invest_version, + 'report_filepath': report_filepath, + 'model_id': model_spec.model_id, + 'model_name': model_spec.model_title, + 'model_description': model_spec.about, + 'userguide_page': model_spec.userguide, + 'timestamp': timestamp, + 'args_dict': args_dict, + 'raster_group_caption': raster_group_caption, + 'stats_table_note': stats_table_note, + 'stream_img_src': img_src, + 'stream_caption': stream_caption, + 'stream_outputs_heading': heading, + 'outputs_heading': heading, + 'outputs_img_src': img_src, + 'outputs_caption': outputs_caption, + 'qf_rasters': None, + 'output_raster_stats_table': output_stats_table, + 'input_raster_stats_table': input_stats_table, + 'inputs_img_src': img_src, + 'inputs_caption': inputs_caption, + 'qf_b_charts': None, + 'qb_map_json': vegalite_json, + 'qb_map_caption': caption, + 'vri_sum_map_json': vegalite_json, + 'vri_sum_map_caption': caption, + 'aggregate_map_source_list': agg_map_source_list, + 'model_spec_outputs': model_spec.outputs + } + +class SeasonalWaterYieldTemplateTests(unittest.TestCase): + """Unit tests for SWY template.""" + + def test_render_without_user_defined_recharge(self): + """Test report rendering when user_defined_local_recharge is False.""" + + render_args = _get_render_args(MODEL_SPEC) + + render_args['qf_rasters'] = { + 'annual_qf_img_src': 'bAse64eNcoDEdIMagE', + 'monthly_qf_img_src': 'bAse64eNcoDEdIMagE', + 'qf_caption': 'test caption' + } + render_args['qf_b_charts'] = { + 'json': '{}', + 'caption': 'test caption', + 'sources': ['/src/path.shp', '/src/path.csv'] + } + html = 
TEMPLATE.render(render_args) + soup = BeautifulSoup(html, BSOUP_HTML_PARSER) + + sections = soup.find_all(class_='accordion-section') + # Includes quickflow raster section and monthly qf + b charts section + self.assertEqual(len(sections), 10) + + self.assertEqual( + soup.h1.string, f'InVEST Results: {MODEL_SPEC.model_title}') + + def test_render_with_user_defined_recharge(self): + """Test report rendering when user_defined_local_recharge is True.""" + + render_args = _get_render_args(MODEL_SPEC) + html = TEMPLATE.render(render_args) + soup = BeautifulSoup(html, BSOUP_HTML_PARSER) + + sections = soup.find_all(class_='accordion-section') + # No quickflow raster section or monthly qf + b charts section + self.assertEqual(len(sections), 8) + + self.assertEqual( + soup.h1.string, f'InVEST Results: {MODEL_SPEC.model_title}') diff --git a/tests/test_seasonal_water_yield_regression.py b/tests/test_seasonal_water_yield_regression.py index 777064409f..fb54023707 100644 --- a/tests/test_seasonal_water_yield_regression.py +++ b/tests/test_seasonal_water_yield_regression.py @@ -14,9 +14,6 @@ from .utils import assert_complete_execute gdal.UseExceptions() -REGRESSION_DATA = os.path.join( - os.path.dirname(__file__), '..', 'data', 'invest-test-data', - 'seasonal_water_yield') def make_simple_shp(base_shp_path, origin): @@ -601,7 +598,8 @@ def test_base_regression(self): execute_kwargs = { 'generate_report': bool(seasonal_water_yield.MODEL_SPEC.reporter), - 'save_file_registry': True + 'save_file_registry': True, + 'check_outputs': True } seasonal_water_yield.MODEL_SPEC.execute(args, **execute_kwargs) assert_complete_execute( @@ -616,6 +614,18 @@ def test_base_regression(self): os.path.join(args['workspace_dir'], 'aggregated_results_swy.shp'), agg_results_csv_path) + # check the values in the avg monthly quickflow baseflow precip csv + actual_result_df = pandas.read_csv( + os.path.join(args['workspace_dir'], 'monthly_quickflow_baseflow.csv')) + expected_qf = [56.69889, 62.00944, 
67.35032, 72.72129, 78.12209, 83.55236, + 89.01173, 94.49973, 100.01591, 105.55979, 111.13096, 116.72885] + expected_b = [60.96804 for i in range(12)] + expected_p = [110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220] + for expected_val, col_name in [(expected_qf, 'quickflow'), + (expected_b, 'baseflow'), (expected_p, 'precipitation')]: + numpy.testing.assert_allclose(expected_val, actual_result_df[col_name], + rtol=1e-5) + def test_base_regression_d8(self): """SWY base regression test on sample data in D8 mode. @@ -668,6 +678,18 @@ def test_base_regression_d8(self): if mismatch_list: raise RuntimeError(f'results not expected: {mismatch_list}') + # check the values in the avg monthly quickflow baseflow precip csv + actual_result_df = pandas.read_csv( + os.path.join(args['workspace_dir'], 'monthly_quickflow_baseflow.csv')) + expected_qf = [55.81547, 61.04926, 66.31405, 71.60957, 76.93555, 82.29161, + 87.67738, 93.09235, 98.53606, 104.00803, 109.50784, 115.03485] + expected_b = [61.969 for i in range(12)] + expected_p = [110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220] + for expected_val, col_name in [(expected_qf, 'quickflow'), + (expected_b, 'baseflow'), (expected_p, 'precipitation')]: + numpy.testing.assert_allclose(expected_val, actual_result_df[col_name], + rtol=1e-5) + def test_base_regression_nodata_inf(self): """SWY base regression test on sample data with really small nodata. 
@@ -932,7 +954,14 @@ def test_user_recharge(self): make_recharge_raster(recharge_ras_path) args['l_path'] = recharge_ras_path - seasonal_water_yield.execute(args) + execute_kwargs = { + 'generate_report': bool(seasonal_water_yield.MODEL_SPEC.reporter), + 'save_file_registry': True, + 'check_outputs': True + } + seasonal_water_yield.MODEL_SPEC.execute(args, **execute_kwargs) + assert_complete_execute( + args, seasonal_water_yield.MODEL_SPEC, **execute_kwargs) # generate aggregated results csv table for assertion agg_results_csv_path = os.path.join(args['workspace_dir'], @@ -943,6 +972,49 @@ def test_user_recharge(self): os.path.join(args['workspace_dir'], 'aggregated_results_swy.shp'), agg_results_csv_path) + def test_user_climate_zones(self): + """SWY user climate zones test on sample data. + + Executes SWY in user defined climate zones mode and checks that the + output files are generated and that the aggregate shapefile fields + are the same as the regression case. + """ + from natcap.invest.seasonal_water_yield import seasonal_water_yield + + # use predefined directory so test can clean up files during teardown + workspace_dir = os.path.join(self.workspace_dir, 'workspace') + os.mkdir(workspace_dir) + args = SeasonalWaterYieldRegressionTests.generate_base_args( + workspace_dir) + args['monthly_alpha'] = False + args['results_suffix'] = '' + + cz_csv_path = os.path.join(self.workspace_dir, 'cz.csv') + make_climate_zone_csv(cz_csv_path) + cz_ras_path = os.path.join(args['workspace_dir'], 'dem.tif') + make_gradient_raster(cz_ras_path) + args['climate_zone_raster_path'] = cz_ras_path + args['climate_zone_table_path'] = cz_csv_path + args['user_defined_climate_zones'] = True + + execute_kwargs = { + 'generate_report': bool(seasonal_water_yield.MODEL_SPEC.reporter), + 'save_file_registry': True, + 'check_outputs': True + } + seasonal_water_yield.MODEL_SPEC.execute(args, **execute_kwargs) + assert_complete_execute( + args, seasonal_water_yield.MODEL_SPEC, 
**execute_kwargs) + + # generate aggregated results csv table for assertion + agg_results_csv_path = os.path.join(args['workspace_dir'], + 'agg_results_cz.csv') + make_agg_results_csv(agg_results_csv_path, climate_zones=True) + + SeasonalWaterYieldRegressionTests._assert_regression_results_equal( + os.path.join(args['workspace_dir'], 'aggregated_results_swy.shp'), + agg_results_csv_path) + @staticmethod def _assert_regression_results_equal( result_vector_path, agg_results_path): @@ -1262,10 +1334,11 @@ def test_local_recharge_undefined_nodata(self): target_aet_path = os.path.join(self.workspace_dir, 'target_aet_path.tif') + month_range = range(1, 13) seasonal_water_yield_core.calculate_local_recharge( - [precip_path for i in range(12)], [et0_path for i in range(12)], - [quickflow_path for i in range(12)], flow_dir_path, - [kc_path for i in range(12)], alpha_month_map, beta, + [precip_path for i in month_range], [et0_path for i in month_range], + [quickflow_path for i in month_range], flow_dir_path, + [kc_path for i in month_range], alpha_month_map, beta, gamma, stream_path, target_li_path, target_li_avail_path, target_l_sum_avail_path, target_aet_path, os.path.join(self.workspace_dir, 'target_precip_path.tif'), diff --git a/tests/test_spec.py b/tests/test_spec.py index 654587c16e..a7c9ec5b6a 100644 --- a/tests/test_spec.py +++ b/tests/test_spec.py @@ -288,7 +288,7 @@ def test_write_metadata_for_outputs(self): SAMPLE_MODEL_SPEC.generate_metadata_for_outputs(file_registry, args_dict) files, messages = geometamaker.validate_dir(self.workspace_dir) - self.assertEqual(len(files), 4) + self.assertEqual(len(files), 5) self.assertFalse(any(messages)) # Test some specific content of the metadata diff --git a/tests/utils.py b/tests/utils.py index 0a10cbd6ff..e74279410a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -107,12 +107,7 @@ def fake_execute(output_spec, workspace): dict (FileRegistry.registry) """ - file_registry = FileRegistry(output_spec, workspace) - for 
spec_data in output_spec: - reg_key = spec_data.id - if '[' in spec_data.id: - reg_key = (spec_data.id, 'A') - filepath = file_registry[reg_key] + def _create_file(spec_data, filepath): if isinstance(spec_data, spec.SingleBandRasterOutput): driver = gdal.GetDriverByName('GTIFF') raster = driver.Create(filepath, 2, 2, 1, gdal.GDT_Byte) @@ -134,4 +129,17 @@ def fake_execute(output_spec, workspace): # Such as taskgraph.db, just create the file. with open(filepath, 'w') as file: pass + + file_registry = FileRegistry(output_spec, workspace) + for spec_data in output_spec: + reg_key = spec_data.id + if '[' in spec_data.id: + reg_keys = [(spec_data.id, 'A'), (spec_data.id, 'B')] + for reg_key in reg_keys: + filepath = file_registry[reg_key] + _create_file(spec_data, filepath) + else: + filepath = file_registry[reg_key] + _create_file(spec_data, filepath) + return file_registry.registry