Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/config/plotting.csv
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
``graticule_radii`` None Define a list of radii at which circular graticules will be drawn.
``interactive`` False Enable interactive mode. If True then plots will be drawn after each plotting command.
``label_ts_threshold`` 0.0 TS threshold for labeling sources in sky maps. If None then no sources will be labeled.
``label_source`` None Name(s) of source(s) to label on plots. If specified, only these sources will be labeled, overriding label_ts_threshold. Can be a single source name (string) or a list of source names.
``loge_bounds`` None
7 changes: 7 additions & 0 deletions fermipy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def validate_option(opt_name, opt_val, schema_type):
if opt_val is None:
return

# Union type: schema_type is a tuple of allowed types (e.g. (str, list))
if isinstance(schema_type, tuple):
if type(opt_val) not in schema_type:
raise TypeError('Wrong type for %s %s (allowed: %s)' %
(opt_name, type(opt_val), schema_type))
return

type_match = type(opt_val) is schema_type
type_checks = (schema_type in [list, dict, bool] or
type(opt_val) in [list, dict, bool])
Expand Down
2 changes: 2 additions & 0 deletions fermipy/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,8 @@ def make_attrs_class(typename, d):
'figsize': ([8.0, 6.0], 'Set the default figure size.', list),
'label_ts_threshold':
(0., 'TS threshold for labeling sources in sky maps. If None then no sources will be labeled.', float),
'label_source':
(None, 'Name(s) of source(s) to label on plots. If specified, only these sources will be labeled, overriding label_ts_threshold. Can be a single source name (string) or a list of source names.', (str, list)),
'interactive': (False, 'Enable interactive mode. If True then plots will be drawn after each plotting command.', bool),
}

Expand Down
46 changes: 42 additions & 4 deletions fermipy/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ class ROIPlotter(fermipy.config.Configurable):
'catalogs': (None, '', list),
'graticule_radii': (None, '', list),
'label_ts_threshold': (0.0, '', float),
'label_source': (None, '', (str, list)),
'cmap': ('ds9_b', '', str),
}

Expand Down Expand Up @@ -534,6 +535,7 @@ def plot_roi(self, roi, **kwargs):
src_color = 'w'

label_ts_threshold = kwargs.get('label_ts_threshold', 0.0)
label_source = kwargs.get('label_source', None)
plot_kwargs = dict(linestyle='None', marker='+',
markerfacecolor='None', mew=0.66, ms=8,
# markersize=8,
Expand All @@ -543,16 +545,24 @@ def plot_roi(self, roi, **kwargs):
fontweight='normal')

ts = np.array([s['ts'] for s in roi.point_sources])
skydir = roi._src_skydir
labels = [s.name for s in roi.point_sources]

if label_ts_threshold is None:
# If label_source is specified, only label those sources
if label_source is not None:
# Convert to list if it's a single string
if utils.isstr(label_source):
label_source = [label_source]
# Create mask for sources matching the specified names
m = np.array([s.name in label_source for s in roi.point_sources])
# Otherwise use the TS threshold logic
elif label_ts_threshold is None:
m = np.zeros(len(ts), dtype=bool)
elif label_ts_threshold <= 0:
m = np.ones(len(ts), dtype=bool)
else:
m = ts > label_ts_threshold

skydir = roi._src_skydir
labels = [s.name for s in roi.point_sources]
self.plot_sources(skydir, labels, plot_kwargs, text_kwargs,
label_mask=m, **kwargs)

Expand Down Expand Up @@ -587,6 +597,7 @@ def plot(self, **kwargs):
self.config['graticule_radii'])
label_ts_threshold = kwargs.get('label_ts_threshold',
self.config['label_ts_threshold'])
label_source = kwargs.get('label_source', self.config['label_source'])

im_kwargs = dict(cmap=self.config['cmap'],
interpolation='nearest', transform=None,
Expand All @@ -608,7 +619,8 @@ def plot(self, **kwargs):

if self._roi is not None:
self.plot_roi(self._roi,
label_ts_threshold=label_ts_threshold)
label_ts_threshold=label_ts_threshold,
label_source=label_source)

self._extent = im.get_extent()
ax.set_xlim(self._extent[0], self._extent[1])
Expand Down Expand Up @@ -979,6 +991,17 @@ def make_residmap_plots(self, maps, roi=None, **kwargs):
Crop the image by this factor. If None then no crop is
applied.

label_source : str or list, optional
Name(s) of source(s) to label on the plot. If specified,
only these sources will be labeled, overriding the
`label_ts_threshold` setting. Can be a single source name
(string) or a list of source names.

label_ts_threshold : float, optional
TS threshold for labeling sources. Only used if
`label_source` is not specified. If None, no sources will
be labeled. If <= 0, all sources will be labeled.

"""

fmt = kwargs.get('format', self.config['format'])
Expand All @@ -992,6 +1015,7 @@ def make_residmap_plots(self, maps, roi=None, **kwargs):
kwargs.setdefault('graticule_radii', self.config['graticule_radii'])
kwargs.setdefault('label_ts_threshold',
self.config['label_ts_threshold'])
kwargs.setdefault('label_source', self.config['label_source'])
cmap = kwargs.setdefault('cmap', self.config['cmap'])
cmap_resid = kwargs.pop('cmap_resid', self.config['cmap_resid'])
kwargs.setdefault('catalogs', self.config['catalogs'])
Expand Down Expand Up @@ -1113,10 +1137,22 @@ def make_tsmap_plots(self, maps, roi=None, **kwargs):
zoom : float
Crop the image by this factor. If None then no crop is
applied.

label_source : str or list, optional
Name(s) of source(s) to label on the plot. If specified,
only these sources will be labeled, overriding the
`label_ts_threshold` setting. Can be a single source name
(string) or a list of source names.

label_ts_threshold : float, optional
TS threshold for labeling sources. Only used if
`label_source` is not specified. If None, no sources will
be labeled. If <= 0, all sources will be labeled.
"""
kwargs.setdefault('graticule_radii', self.config['graticule_radii'])
kwargs.setdefault('label_ts_threshold',
self.config['label_ts_threshold'])
kwargs.setdefault('label_source', self.config['label_source'])
kwargs.setdefault('cmap', self.config['cmap'])
kwargs.setdefault('catalogs', self.config['catalogs'])
fmt = kwargs.get('format', self.config['format'])
Expand Down Expand Up @@ -1199,6 +1235,7 @@ def make_psmap_plots(self, psmaps, roi=None, **kwargs):
kwargs.setdefault('graticule_radii', self.config['graticule_radii'])
kwargs.setdefault('label_ts_threshold',
self.config['label_ts_threshold'])
kwargs.setdefault('label_source', self.config['label_source'])
kwargs.setdefault('cmap', self.config['cmap'])
kwargs.setdefault('catalogs', self.config['catalogs'])
fmt = kwargs.get('format', self.config['format'])
Expand Down Expand Up @@ -1371,6 +1408,7 @@ def make_roi_plots(self, gta, mcube_tot, **kwargs):
'graticule_radii', self.config['graticule_radii'])
roi_kwargs.setdefault('label_ts_threshold',
self.config['label_ts_threshold'])
roi_kwargs.setdefault('label_source', self.config['label_source'])
roi_kwargs.setdefault('cmap', self.config['cmap'])
roi_kwargs.setdefault('catalogs', self._catalogs)

Expand Down
13 changes: 12 additions & 1 deletion fermipy/tests/test_gtanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,18 @@ def test_gtanalysis_fit_newton(create_diffuse_dir, create_draco_analysis):
def test_gtanalysis_tsmap(create_diffuse_dir, create_draco_analysis):
gta = create_draco_analysis
gta.load_roi('fit1')
gta.tsmap(model={}, make_plots=True)
gta.tsmap(model={}, make_plots=True, prefix='tsmap_default')
gta.config['plotting']['label_source'] = None
gta.config['plotting']['label_ts_threshold'] = 25.0
gta.tsmap(model={}, make_plots=True, prefix='tsmap_ts25')

gta.config['plotting']['label_source'] = ['draco',
'4FGL J1741.2+5739',
'4FGL J1742.5+5944']
gta.config['plotting']['label_ts_threshold'] = 0.0
gta.tsmap(model={}, make_plots=True, prefix='tsmap_label_source')
gta.config['plotting']['label_source'] = None


def test_gtanalysis_psmap(create_diffuse_dir, create_draco_analysis):
gta = create_draco_analysis
Expand Down
44 changes: 44 additions & 0 deletions fermipy/tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import absolute_import, division, print_function

import numpy as np

from fermipy.plotting import ROIPlotter


class _DummySrc(dict):
def __init__(self, name, ts):
super(_DummySrc, self).__init__(ts=ts)
self.name = name


class _DummyROI(object):
def __init__(self):
self.point_sources = [
_DummySrc('srcA', 10.0),
_DummySrc('srcB', 20.0),
_DummySrc('srcC', 5.0),
]
self._src_skydir = None


class _MaskRecorder(object):
def __init__(self):
self.label_mask = None

def plot_sources(self, skydir, labels, plot_kwargs, text_kwargs, **kwargs):
self.label_mask = kwargs.get('label_mask')


def test_plot_roi_label_source_and_ts_threshold():
"""Verify source labeling by explicit names and TS threshold."""
roi = _DummyROI()

# label_source path: only listed sources are labeled.
recorder = _MaskRecorder()
ROIPlotter.plot_roi(recorder, roi, label_source=['srcA', 'srcC'])
assert np.array_equal(recorder.label_mask, np.array([True, False, True]))

# label_ts_threshold path: only sources above threshold are labeled.
recorder = _MaskRecorder()
ROIPlotter.plot_roi(recorder, roi, label_ts_threshold=12.0)
assert np.array_equal(recorder.label_mask, np.array([False, True, False]))
Loading