diff --git a/json_merger/comparator.py b/json_merger/comparator.py index c4fce06..c56446a 100644 --- a/json_merger/comparator.py +++ b/json_merger/comparator.py @@ -105,16 +105,25 @@ class PrimaryKeyComparator(BaseComparator): primary_key_fields = ['pk'] normalization_functions = {} - def _have_field_equal(self, obj1, obj2, field): + def _get_compared_objects_at_field_path(self, obj1, obj2, field): key_path = tuple(k for k in field.split('.') if k) o1 = get_obj_at_key_path(obj1, key_path, NOTHING) o2 = get_obj_at_key_path(obj2, key_path, NOTHING) + return o1, o2 + + def _have_field_equal(self, obj1, obj2, field): + o1, o2 = self._get_compared_objects_at_field_path(obj1, obj2, field) if o1 == NOTHING or o2 == NOTHING: return False fn = self.normalization_functions.get(field, lambda x: x) return fn(o1) == fn(o2) + def _are_fields_nothing(self, obj1, obj2, field): + o1, o2 = self._get_compared_objects_at_field_path(obj1, obj2, field) + if o1 == NOTHING or o2 == NOTHING: + return True + def equal(self, obj1, obj2): if obj1 == obj2: return True @@ -124,7 +133,11 @@ def equal(self, obj1, obj2): field_set = [field_set] checks = [self._have_field_equal(obj1, obj2, field) for field in field_set] - if all(checks): + are_all_fields_nothing = [ + self._are_fields_nothing(obj1, obj2, field) + for field in field_set + ] + if all(checks) and not all(are_all_fields_nothing): return True return False