diff --git a/mapytex/calculus/API/expression.py b/mapytex/calculus/API/expression.py index 114f290..231b253 100644 --- a/mapytex/calculus/API/expression.py +++ b/mapytex/calculus/API/expression.py @@ -106,7 +106,10 @@ class Expression(object): """ if optimize: try: - self._tree = self._tree.balance() + # TODO: need to test exclude_nodes |ven. oct. 5 08:51:02 CEST 2018 + self._tree = self._tree.balance( + exclude_nodes=["\\", "**"] + ) except AttributeError: pass try: diff --git a/mapytex/calculus/core/tree.py b/mapytex/calculus/core/tree.py index d64d4e4..6f8bafa 100644 --- a/mapytex/calculus/core/tree.py +++ b/mapytex/calculus/core/tree.py @@ -655,7 +655,7 @@ class Tree(object): else: return self.left_value - def balance(self): + def balance(self, exclude_nodes = []): """ Recursively balance the tree without permutting different nodes :return: balanced tree @@ -719,6 +719,26 @@ class Tree(object): | > * | | > 8 | | > 9 + >>> t = Tree.from_str("1+2+3+4+5/6/7/8/9") + >>> bal_t = t.balance(exclude_nodes=['/']) + >>> print(bal_t) + + + > + + | > + + | | > + + | | | > 1 + | | | > 2 + | | > 3 + | > 4 + > / + | > / + | | > / + | | | > / + | | | | > 5 + | | | | > 6 + | | | > 7 + | | > 8 + | > 9 """ try: l_depth = self.left_value.depth() @@ -730,27 +750,29 @@ class Tree(object): r_depth = 1 if l_depth > r_depth+1 and\ - self.node == self.left_value.node: + self.node == self.left_value.node and \ + self.node not in exclude_nodes: new_left = self.left_value.long_branch new_right = Tree(self.node, self.left_value.short_branch, self.right_value) - return Tree(self.node, new_left, new_right).balance() + return Tree(self.node, new_left, new_right).balance(exclude_nodes) if r_depth > l_depth+1 and\ - self.node == self.right_value.node: + self.node == self.right_value.node and \ + self.node not in exclude_nodes: new_right = self.right_value.long_branch new_left = Tree(self.node, self.left_value, self.right_value.short_branch) - return Tree(self.node, new_left, new_right).balance() + return Tree(self.node, new_left, new_right).balance(exclude_nodes) try: - left_v = self.left_value.balance() + left_v = self.left_value.balance(exclude_nodes) except AttributeError: left_v = self.left_value try: - right_v = self.right_value.balance() + right_v = self.right_value.balance(exclude_nodes) except AttributeError: right_v = self.right_value