Skip to content

Commit

Permalink
ENH: Use dict to map data to aes in manual scales
Browse files Browse the repository at this point in the history
closes #169
  • Loading branch information
has2k1 committed Aug 1, 2018
1 parent 83183df commit 3029492
Show file tree
Hide file tree
Showing 16 changed files with 95 additions and 39 deletions.
3 changes: 3 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ Enhancements
Using a list allows you to bundle up objects. I can be convenient when
creating some complicated plots. See the Periodic Table Example.

- You can now use a ``dict`` (with manual scales) to map data values to
aesthetics (:issue:`169`).

Bug Fixes
*********

Expand Down
6 changes: 4 additions & 2 deletions plotnine/geoms/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class geom(object):
__base__ = True
DEFAULT_AES = dict() #: Default aesthetics for the geom
REQUIRED_AES = set() #: Required aesthetics for the geom
NON_MISSING_AES = set() #: Required aesthetics for the geom
DEFAULT_PARAMS = dict() #: Required parameters for the geom

data = None #: geom/layer specific dataframe
Expand Down Expand Up @@ -395,6 +396,7 @@ def handle_na(self, data):
`na_rm` parameter is False. It only takes into account
the columns of the required aesthetics.
"""
return remove_missing(data, self.params['na_rm'],
list(self.REQUIRED_AES),
return remove_missing(data,
self.params['na_rm'],
list(self.REQUIRED_AES | self.NON_MISSING_AES),
self.__class__.__name__)
1 change: 1 addition & 0 deletions plotnine/geoms/geom_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class geom_bar(geom_rect):
plotnine.geoms.geom_histogram
"""
REQUIRED_AES = {'x', 'y'}
NON_MISSING_AES = {'xmin', 'xmax', 'ymin', 'ymax'}
DEFAULT_PARAMS = {'stat': 'count', 'position': 'stack',
'na_rm': False, 'width': None}

Expand Down
1 change: 1 addition & 0 deletions plotnine/geoms/geom_col.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ class geom_col(geom_bar):
plotnine.geoms.geom_bar
"""
REQUIRED_AES = {'x', 'y'}
NON_MISSING_AES = {'xmin', 'xmax', 'ymin', 'ymax'}
DEFAULT_PARAMS = {'stat': 'identity', 'position': 'stack',
'na_rm': False, 'width': None}
1 change: 1 addition & 0 deletions plotnine/geoms/geom_dotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class geom_dotplot(geom):
"""
DEFAULT_AES = {'alpha': 1, 'color': 'black', 'fill': 'black'}
REQUIRED_AES = {'x', 'y'}
NON_MISSING_AES = {'size', 'shape'}
DEFAULT_PARAMS = {'stat': 'bindot', 'position': 'identity',
'na_rm': False, 'stackdir': 'up', 'stackratio': 1,
'dotsize': 1, 'stackgroups': False}
Expand Down
1 change: 1 addition & 0 deletions plotnine/geoms/geom_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class geom_point(geom):
DEFAULT_AES = {'alpha': 1, 'color': 'black', 'fill': None,
'shape': 'o', 'size': 1.5, 'stroke': 0.5}
REQUIRED_AES = {'x', 'y'}
NON_MISSING_AES = {'color', 'shape', 'size'}
DEFAULT_PARAMS = {'stat': 'identity', 'position': 'identity',
'na_rm': False}

Expand Down
1 change: 1 addition & 0 deletions plotnine/geoms/geom_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class geom_segment(geom):
DEFAULT_AES = {'alpha': 1, 'color': 'black', 'linetype': 'solid',
'size': 0.5}
REQUIRED_AES = {'x', 'y', 'xend', 'yend'}
NON_MISSING_AES = {'linetype', 'size', 'shape'}
DEFAULT_PARAMS = {'stat': 'identity', 'position': 'identity',
'na_rm': False, 'lineend': 'butt', 'arrow': None}

Expand Down
54 changes: 30 additions & 24 deletions plotnine/guides/guide_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..scales.scale import scale_continuous
from ..utils import ColoredDrawingArea, suppress, SIZE_FACTOR
from ..utils import Registry
from ..utils import Registry, remove_missing
from ..exceptions import PlotnineError
from ..geoms import geom_text
from ..aes import rename_aesthetics
Expand Down Expand Up @@ -172,6 +172,10 @@ def get_legend_geom(layer):
data[ae] = self.override_aes[ae]

geom = get_legend_geom(l)
data = remove_missing(
data, l.geom.params['na_rm'],
list(l.geom.REQUIRED_AES | l.geom.NON_MISSING_AES),
'{} legend'.format(l.geom.__class__.__name__))
self.glayers.append(Bunch(geom=geom, data=data, layer=l))

if not self.glayers:
Expand Down Expand Up @@ -223,27 +227,28 @@ def determine_side_length(initial_size):
pad = default_pad
# Full size of object to appear in the
# legend key
if 'size' in gl.data:
_size = gl.data['size'].iloc[i] * SIZE_FACTOR
if 'stroke' in gl.data:
_size += (2 * gl.data['stroke'].iloc[i] *
SIZE_FACTOR)

# special case, color does not apply to
# border/linewidth
if issubclass(gl.geom, geom_text):
pad = 0
if _size < initial_size:
continue

try:
# color(edgecolor) affects size(linewidth)
# When the edge is not visible, we should
# not expand the size of the keys
if gl.data['color'].iloc[i] is not None:
size[i] = np.max([_size+pad, size[i]])
except KeyError:
break
with suppress(IndexError):
if 'size' in gl.data:
_size = gl.data['size'].iloc[i] * SIZE_FACTOR
if 'stroke' in gl.data:
_size += (2 * gl.data['stroke'].iloc[i] *
SIZE_FACTOR)

# special case, color does not apply to
# border/linewidth
if issubclass(gl.geom, geom_text):
pad = 0
if _size < initial_size:
continue

try:
# color(edgecolor) affects size(linewidth)
# When the edge is not visible, we should
# not expand the size of the keys
if gl.data['color'].iloc[i] is not None:
size[i] = np.max([_size+pad, size[i]])
except KeyError:
break

return size

Expand Down Expand Up @@ -310,8 +315,9 @@ def draw(self):
0, 0, color='white')
# overlay geoms
for gl in self.glayers:
data = gl.data.iloc[i]
da = gl.geom.draw_legend(data, da, gl.layer)
with suppress(IndexError):
data = gl.data.iloc[i]
da = gl.geom.draw_legend(data, da, gl.layer)
drawings.append(da)
themeable['legend_key'].append(drawings)

Expand Down
8 changes: 7 additions & 1 deletion plotnine/scales/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,19 @@ def map(self, x, limits=None):
pal = self.palette(n)
if isinstance(pal, dict):
# manual palette with specific assignments
pal_match = [pal[val] for val in x]
pal_match = []
for val in x:
try:
pal_match.append(pal[val])
except KeyError:
pal_match.append(self.na_value)
else:
pal = np.asarray(pal)
pal_match = pal[match(x, limits)]
bool_idx = pd.isnull(pal_match)
if np.any(bool_idx):
pal_match[bool_idx] = self.na_value

return pal_match

def break_info(self, range=None):
Expand Down
13 changes: 10 additions & 3 deletions plotnine/scales/scale_manual.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import absolute_import, division, print_function

from mizani.palettes import manual_pal
from warnings import warn

from ..doctools import document
from ..utils import alias
Expand All @@ -17,9 +16,17 @@ class _scale_manual(scale_discrete):
{superclass_parameters}
"""
def __init__(self, values, **kwargs):
self.palette = manual_pal(values)
self._values = values
scale_discrete.__init__(self, **kwargs)

def palette(self, n):
max_n = len(self._values)
if n > max_n:
msg = ("Palette can return a maximum of {} values. "
"{} were requested from it.")
warn(msg.format(max_n, n))
return self._values


@document
class scale_color_manual(_scale_manual):
Expand Down
5 changes: 3 additions & 2 deletions plotnine/stats/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class stat(object):

REQUIRED_AES = set()
DEFAULT_AES = dict()
NON_MISSING_AES = set()
DEFAULT_PARAMS = dict()

# Should the values produced by the statistic also
Expand All @@ -35,7 +36,7 @@ class stat(object):
# see: stat_bin
CREATES = set()

# Documentation for the aesthetics. It is added under the
# Documentation for the aesthetics. It ie added under the
# documentation for mapping parameter. Use {aesthetics_table}
# placeholder to insert a table for all the aesthetics and
# their default values.
Expand Down Expand Up @@ -252,7 +253,7 @@ def compute_layer(cls, data, params, layout):
data = remove_missing(
data,
na_rm=params.get('na_rm', False),
vars=list(cls.REQUIRED_AES),
vars=list(cls.REQUIRED_AES | cls.NON_MISSING_AES),
name=cls.__name__,
finite=True)

Expand Down
1 change: 1 addition & 0 deletions plotnine/stats/stat_bindot.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class stat_bindot(stat):
"""

REQUIRED_AES = {'x'}
NON_MISSING_AES = {'weight'}
DEFAULT_PARAMS = {'geom': 'dotplot', 'position': 'identity',
'na_rm': False,
'bins': None, 'binwidth': None, 'origin': None,
Expand Down
1 change: 1 addition & 0 deletions plotnine/stats/stat_boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class stat_boxplot(stat):
"""

REQUIRED_AES = {'x', 'y'}
NON_MISSING_AES = {'weight'}
DEFAULT_PARAMS = {'geom': 'boxplot', 'position': 'dodge',
'na_rm': False, 'coef': 1.5, 'width': None}
CREATES = {'lower', 'upper', 'middle', 'ymin', 'ymax',
Expand Down
1 change: 1 addition & 0 deletions plotnine/stats/stat_ydensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class stat_ydensity(stat):
e.g. :py:`'stat(width)'`.
"""
REQUIRED_AES = {'x', 'y'}
NON_MISSING_AES = {'weight'}
DEFAULT_PARAMS = {'geom': 'violin', 'position': 'dodge',
'na_rm': False,
'adjust': 1, 'kernel': 'gaussian',
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 30 additions & 7 deletions plotnine/tests/test_scale_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,21 @@ def is_manual_scale(name):
return (name.startswith('scale_') and
name.endswith('_manual'))

manual_scales = [getattr(scale_manual, name)
for name in scale_manual.__dict__
if is_manual_scale(name)]

values = [1, 2, 3, 4, 5]
for name in scale_manual.__dict__:
if is_manual_scale(name):
s = getattr(scale_manual, name)(values)
assert s.palette(2) == values[:2]
assert s.palette(len(values)) == values
with pytest.warns(UserWarning):
s.palette(len(values)+1)
for _scale in manual_scales:
s = _scale(values)
assert s.palette(2) == values
assert s.palette(len(values)) == values
with pytest.warns(UserWarning):
s.palette(len(values)+1)

values = {'A': 'red', 'B': 'violet', 'C': 'blue'}
sc = scale_manual.scale_color_manual(values)
assert sc.palette(3) == values


def test_alpha_palette():
Expand Down Expand Up @@ -415,3 +422,19 @@ def test_multiple_aesthetics():
type='qual', palette=1, aesthetics=['fill', 'color'])
)
assert p + _theme == 'multiple_aesthetics'


def test_missing_manual_dict_aesthetic():
df = pd.DataFrame({
'x': range(15),
'y': range(15),
'c': np.repeat(list('ABC'), 5)
})

values = {'A': 'red', 'B': 'violet', 'D': 'blue'}

p = (ggplot(df, aes('x', 'y', color='c'))
+ geom_point(size=3)
+ scale_manual.scale_color_manual(values)
)
assert p + _theme == 'missing_manual_dict_aesthetic'

0 comments on commit 3029492

Please sign in to comment.