Skip to content

Commit

Permalink
Fixes #119. cleaned up the class names normalization as well. If you …
Browse files Browse the repository at this point in the history
…do not specify class names for classifier, "class i" is what you get. I reran all of the examples.
  • Loading branch information
parrt committed Jan 27, 2021
1 parent fbcb423 commit d89ea22
Show file tree
Hide file tree
Showing 6 changed files with 36,658 additions and 29,126 deletions.
22 changes: 11 additions & 11 deletions dtreeviz/models/shadow_decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,11 @@ def __init__(self,

self.feature_names = feature_names
self.target_name = target_name
self.class_names = class_names
# self.class_weight = self.get_class_weight()
self.x_data = ShadowDecTree._get_x_data(x_data)
self.y_data = ShadowDecTree._get_y_data(y_data)
# self.node_to_samples = self.get_node_samples()
self.root, self.leaves, self.internal = self._get_tree_nodes()
if class_names:
self.class_names = self._get_class_names()
if self.is_classifier():
self.class_names = self._normalize_class_names(class_names)

@abstractmethod
def is_fit(self) -> bool:
Expand Down Expand Up @@ -391,14 +388,17 @@ def get_leaf_sample_counts_by_class(self):
index, leaf_sample_0, leaf_samples_1 = zip(*leaf_samples)
return index, leaf_sample_0, leaf_samples_1

def _get_class_names(self):
def _normalize_class_names(self, class_names):
if self.is_classifier():
if isinstance(self.class_names, dict):
return self.class_names
elif isinstance(self.class_names, Sequence):
return {i: n for i, n in enumerate(self.class_names)}
if class_names is None:
return {i : f"class {i}" for i in range(self.nclasses())}
if isinstance(class_names, dict):
return class_names
elif isinstance(class_names, Sequence):
return {i: n for i, n in enumerate(class_names)}
else:
raise Exception(f"class_names must be dict or sequence, not {self.class_names.__class__.__name__}")
raise Exception(f"class_names must be dict or sequence, not {class_names.__class__.__name__}")
return None

def _get_tree_nodes(self):
# use locals not args to walk() for recursion speed in python
Expand Down
18,813 changes: 9,813 additions & 9,000 deletions notebooks/dtreeviz_sklearn_visualisations.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit d89ea22

Please sign in to comment.