Skip to content

Commit

Permalink
Merge pull request #2133 from djhoese/feature-enh-map-blocks
Browse files Browse the repository at this point in the history
Rewrite 'apply_enhancement' as individual decorators to allow for easier dask map_blocks usage
  • Loading branch information
mraspaud authored Aug 5, 2022
2 parents 07ed078 + 0fb7080 commit aa7f0dd
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 136 deletions.
222 changes: 126 additions & 96 deletions satpy/enhancements/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import logging
import os
import warnings
from functools import partial
from collections import namedtuple
from functools import wraps
from numbers import Number

import dask
Expand Down Expand Up @@ -49,67 +50,77 @@ def invert(img, *args):
return img.invert(*args)


def apply_enhancement(data, func, exclude=None, separate=False,
pass_dask=False):
"""Apply `func` to the provided data.
def exclude_alpha(func):
"""Exclude the alpha channel from the DataArray before further processing."""
@wraps(func)
def wrapper(data, **kwargs):
bands = data.coords['bands'].values
exclude = ['A'] if 'A' in bands else []
band_data = data.sel(bands=[b for b in bands
if b not in exclude])
band_data = func(band_data, **kwargs)

attrs = data.attrs
attrs.update(band_data.attrs)
# combine the new data with the excluded data
new_data = xr.concat([band_data, data.sel(bands=exclude)],
dim='bands')
data.data = new_data.sel(bands=bands).data
data.attrs = attrs
return data
return wrapper

Args:
data (xarray.DataArray): Data to be modified inplace.
func (callable): Function to be applied to an xarray
exclude (iterable): Bands in the 'bands' dimension to not include
in the calculations.
separate (bool): Apply `func` one band at a time. Default is False.
pass_dask (bool): Pass the underlying dask array instead of the
xarray.DataArray.

"""
attrs = data.attrs
bands = data.coords['bands'].values
if exclude is None:
exclude = ['A'] if 'A' in bands else []
def on_separate_bands(func):
"""Apply `func` one band of the DataArray at a time.
If this decorator is to be applied along with `on_dask_array`, this decorator has to be applied first, eg::
@on_separate_bands
@on_dask_array
def my_enhancement_function(data):
...
if separate:
"""
@wraps(func)
def wrapper(data, **kwargs):
attrs = data.attrs
data_arrs = []
for idx, band_name in enumerate(bands):
band_data = data.sel(bands=[band_name])
if band_name in exclude:
# don't modify alpha
data_arrs.append(band_data)
continue

if pass_dask:
dims = band_data.dims
coords = band_data.coords
d_arr = func(band_data.data, index=idx)
band_data = xr.DataArray(d_arr, dims=dims, coords=coords)
else:
band_data = func(band_data, index=idx)
for idx, band in enumerate(data.coords['bands'].values):
band_data = func(data.sel(bands=[band]), index=idx, **kwargs)
data_arrs.append(band_data)
# we assume that the func can add attrs
attrs.update(band_data.attrs)

data.data = xr.concat(data_arrs, dim='bands').data
data.attrs = attrs
return data

band_data = data.sel(bands=[b for b in bands
if b not in exclude])
if pass_dask:
dims = band_data.dims
coords = band_data.coords
d_arr = func(band_data.data)
band_data = xr.DataArray(d_arr, dims=dims, coords=coords)
else:
band_data = func(band_data)
return wrapper

attrs.update(band_data.attrs)
# combine the new data with the excluded data
new_data = xr.concat([band_data, data.sel(bands=exclude)],
dim='bands')
data.data = new_data.sel(bands=bands).data
data.attrs = attrs

return data
def on_dask_array(func):
"""Pass the underlying dask array to *func* instead of the xarray.DataArray."""
@wraps(func)
def wrapper(data, **kwargs):
dims = data.dims
coords = data.coords
d_arr = func(data.data, **kwargs)
return xr.DataArray(d_arr, dims=dims, coords=coords)
return wrapper


def using_map_blocks(func):
"""Run the provided function using :func:`dask.array.core.map_blocks`.
This means dask will call the provided function with a single chunk
as a numpy array.
"""
@wraps(func)
def wrapper(data, **kwargs):
return da.map_blocks(func, data, meta=np.array((), dtype=data.dtype), dtype=data.dtype, chunks=data.chunks,
**kwargs)
return on_dask_array(wrapper)


def crefl_scaling(img, **kwargs):
Expand Down Expand Up @@ -185,15 +196,16 @@ def piecewise_linear_stretch(
xp = np.asarray(xp) / reference_scale_factor
fp = np.asarray(fp) / reference_scale_factor

def func(band_data, xp, fp, index=None):
# Interpolate band on [0,1] using "lazy" arrays (put calculations off until the end).
band_data = xr.DataArray(da.clip(band_data.data.map_blocks(np.interp, xp=xp, fp=fp), 0, 1),
coords=band_data.coords, dims=band_data.dims, name=band_data.name,
attrs=band_data.attrs)
return band_data
return _piecewise_linear(img.data, xp=xp, fp=fp)

func_with_kwargs = partial(func, xp=xp, fp=fp)
return apply_enhancement(img.data, func_with_kwargs, separate=True)

@exclude_alpha
@using_map_blocks
def _piecewise_linear(band_data, xp, fp):
# Interpolate band on [0,1] using "lazy" arrays (put calculations off until the end).
interp_data = np.interp(band_data, xp=xp, fp=fp)
interp_data = np.clip(interp_data, 0, 1, out=interp_data)
return interp_data


def cira_stretch(img, **kwargs):
Expand All @@ -202,18 +214,19 @@ def cira_stretch(img, **kwargs):
Applicable only for visible channels.
"""
LOG.debug("Applying the cira-stretch")
return _cira_stretch(img.data)

def func(band_data):
log_root = np.log10(0.0223)
denom = (1.0 - log_root) * 0.75
band_data *= 0.01
band_data = band_data.clip(np.finfo(float).eps)
band_data = np.log10(band_data)
band_data -= log_root
band_data /= denom
return band_data

return apply_enhancement(img.data, func)
@exclude_alpha
def _cira_stretch(band_data):
log_root = np.log10(0.0223)
denom = (1.0 - log_root) * 0.75
band_data *= 0.01
band_data = band_data.clip(np.finfo(float).eps)
band_data = np.log10(band_data)
band_data -= log_root
band_data /= denom
return band_data


def reinhard_to_srgb(img, saturation=1.25, white=100, **kwargs):
Expand Down Expand Up @@ -272,18 +285,21 @@ def _lookup_delayed(luts, band_data):
def lookup(img, **kwargs):
"""Assign values to channels based on a table."""
luts = np.array(kwargs['luts'], dtype=np.float32) / 255.0
return _lookup_table(img.data, luts=luts)

def func(band_data, luts=luts, index=-1):
# NaN/null values will become 0
lut = luts[:, index] if len(luts.shape) == 2 else luts
band_data = band_data.clip(0, lut.size - 1).astype(np.uint8)

new_delay = dask.delayed(_lookup_delayed)(lut, band_data)
new_data = da.from_delayed(new_delay, shape=band_data.shape,
dtype=luts.dtype)
return new_data
@exclude_alpha
@on_separate_bands
@using_map_blocks
def _lookup_table(band_data, luts=None, index=-1):
# NaN/null values will become 0
lut = luts[:, index] if len(luts.shape) == 2 else luts
band_data = band_data.clip(0, lut.size - 1).astype(np.uint8)

return apply_enhancement(img.data, func, separate=True, pass_dask=True)
new_delay = dask.delayed(_lookup_delayed)(lut, band_data)
new_data = da.from_delayed(new_delay, shape=band_data.shape,
dtype=luts.dtype)
return new_data


def colorize(img, **kwargs):
Expand Down Expand Up @@ -510,14 +526,6 @@ def _read_colormap_data_from_file(filename):
return np.loadtxt(filename, delimiter=",")


def _three_d_effect_delayed(band_data, kernel, mode):
"""Kernel for running delayed 3D effect creation."""
from scipy.signal import convolve2d
band_data = band_data.reshape(band_data.shape[1:])
new_data = convolve2d(band_data, kernel, mode=mode)
return new_data.reshape((1, band_data.shape[0], band_data.shape[1]))


def three_d_effect(img, **kwargs):
"""Create 3D effect using convolution."""
w = kwargs.get('weight', 1)
Expand All @@ -527,14 +535,26 @@ def three_d_effect(img, **kwargs):
[-w, 0, w]])
mode = kwargs.get('convolve_mode', 'same')

def func(band_data, kernel=kernel, mode=mode, index=None):
del index
return _three_d_effect(img.data, kernel=kernel, mode=mode)


@exclude_alpha
@on_separate_bands
@using_map_blocks
def _three_d_effect(band_data, kernel=None, mode=None, index=None):
del index

delay = dask.delayed(_three_d_effect_delayed)(band_data, kernel, mode)
new_data = da.from_delayed(delay, shape=band_data.shape, dtype=band_data.dtype)
return new_data
delay = dask.delayed(_three_d_effect_delayed)(band_data, kernel, mode)
new_data = da.from_delayed(delay, shape=band_data.shape, dtype=band_data.dtype)
return new_data

return apply_enhancement(img.data, func, separate=True, pass_dask=True)

def _three_d_effect_delayed(band_data, kernel, mode):
"""Kernel for running delayed 3D effect creation."""
from scipy.signal import convolve2d
band_data = band_data.reshape(band_data.shape[1:])
new_data = convolve2d(band_data, kernel, mode=mode)
return new_data.reshape((1, band_data.shape[0], band_data.shape[1]))


def btemp_threshold(img, min_in, max_in, threshold, threshold_out=None, **kwargs):
Expand Down Expand Up @@ -563,10 +583,20 @@ def btemp_threshold(img, min_in, max_in, threshold, threshold_out=None, **kwargs
high_factor = threshold_out / (max_in - threshold)
high_offset = high_factor * max_in

def _bt_threshold(band_data):
# expects dask array to be passed
return da.where(band_data >= threshold,
high_offset - high_factor * band_data,
low_offset - low_factor * band_data)
Coeffs = namedtuple("Coeffs", "factor offset")
high = Coeffs(high_factor, high_offset)
low = Coeffs(low_factor, low_offset)

return _bt_threshold(img.data,
threshold=threshold,
high_coeffs=high,
low_coeffs=low)


return apply_enhancement(img.data, _bt_threshold, pass_dask=True)
@exclude_alpha
@using_map_blocks
def _bt_threshold(band_data, threshold, high_coeffs, low_coeffs):
# expects dask array to be passed
return da.where(band_data >= threshold,
high_coeffs.offset - high_coeffs.factor * band_data,
low_coeffs.offset - low_coeffs.factor * band_data)
41 changes: 22 additions & 19 deletions satpy/enhancements/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,32 @@
# satpy. If not, see <http://www.gnu.org/licenses/>.
"""Enhancement functions specific to the ABI sensor."""

from satpy.enhancements import apply_enhancement
from satpy.enhancements import exclude_alpha, using_map_blocks


def cimss_true_color_contrast(img, **kwargs):
"""Scale data based on CIMSS True Color recipe for AWIPS."""
def func(img_data):
"""Perform per-chunk enhancement.
_cimss_true_color_contrast(img.data)

Code ported from Kaba Bah's AWIPS python plugin for creating the
CIMSS Natural (True) Color image in AWIPS. AWIPS provides that python
code the image data on a 0-255 scale. Satpy gives this function the
data on a 0-1.0 scale (assuming linear stretching and sqrt
enhancements have already been applied).

"""
max_value = 1.0
acont = (255.0 / 10.0) / 255.0
amax = (255.0 + 4.0) / 255.0
amid = 1.0 / 2.0
afact = (amax * (acont + max_value) / (max_value * (amax - acont)))
aband = (afact * (img_data - amid) + amid)
aband[aband <= 10 / 255.0] = 0
aband[aband >= 1.0] = 1.0
return aband
@exclude_alpha
@using_map_blocks
def _cimss_true_color_contrast(img_data):
"""Perform per-chunk enhancement.
apply_enhancement(img.data, func, pass_dask=True)
Code ported from Kaba Bah's AWIPS python plugin for creating the
CIMSS Natural (True) Color image in AWIPS. AWIPS provides that python
code the image data on a 0-255 scale. Satpy gives this function the
data on a 0-1.0 scale (assuming linear stretching and sqrt
enhancements have already been applied).
"""
max_value = 1.0
acont = (255.0 / 10.0) / 255.0
amax = (255.0 + 4.0) / 255.0
amid = 1.0 / 2.0
afact = (amax * (acont + max_value) / (max_value * (amax - acont)))
aband = (afact * (img_data - amid) + amid)
aband[aband <= 10 / 255.0] = 0
aband[aband >= 1.0] = 1.0
return aband
22 changes: 12 additions & 10 deletions satpy/enhancements/ahi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import dask.array as da
import numpy as np

from satpy.enhancements import apply_enhancement
from satpy.enhancements import exclude_alpha, on_dask_array


def jma_true_color_reproduction(img, **kwargs):
Expand All @@ -31,14 +31,16 @@ def jma_true_color_reproduction(img, **kwargs):
Colorado State University—CIRA
https://www.jma.go.jp/jma/jma-eng/satellite/introduction/TCR.html
"""
_jma_true_color_reproduction(img.data)

def func(img_data):
ccm = np.array([
[1.1759, 0.0561, -0.1322],
[-0.0386, 0.9587, 0.0559],
[-0.0189, -0.1161, 1.0777]
])
output = da.dot(img_data.T, ccm.T)
return output.T

apply_enhancement(img.data, func, pass_dask=True)
@exclude_alpha
@on_dask_array
def _jma_true_color_reproduction(img_data):
ccm = np.array([
[1.1759, 0.0561, -0.1322],
[-0.0386, 0.9587, 0.0559],
[-0.0189, -0.1161, 1.0777]
])
output = da.dot(img_data.T, ccm.T)
return output.T
Loading

0 comments on commit aa7f0dd

Please sign in to comment.