Feat(Tree): Add exclude_nodes in balance to trees

This commit is contained in:
Bertrand Benjamin 2018-10-05 08:52:04 +02:00
parent 948402755a
commit 9e0a703e98
2 changed files with 33 additions and 8 deletions

View File

@ -106,7 +106,10 @@ class Expression(object):
""" """
if optimize: if optimize:
try: try:
self._tree = self._tree.balance() # TODO: need to test exclude_nodes |ven. oct. 5 08:51:02 CEST 2018
self._tree = self._tree.balance(
exclude_nodes=["\\", "**"]
)
except AttributeError: except AttributeError:
pass pass
try: try:

View File

@ -655,7 +655,7 @@ class Tree(object):
else: else:
return self.left_value return self.left_value
def balance(self): def balance(self, exclude_nodes = []):
""" Recursively balance the tree without permutting different nodes """ Recursively balance the tree without permutting different nodes
:return: balanced tree :return: balanced tree
@ -719,6 +719,26 @@ class Tree(object):
| > * | > *
| | > 8 | | > 8
| | > 9 | | > 9
>>> t = Tree.from_str("1+2+3+4+5/6/7/8/9")
>>> bal_t = t.balance(exclude_nodes=['/'])
>>> print(bal_t)
+
> +
| > +
| | > +
| | | > 1
| | | > 2
| | > 3
| > 4
> /
| > /
| | > /
| | | > /
| | | | > 5
| | | | > 6
| | | > 7
| | > 8
| > 9
""" """
try: try:
l_depth = self.left_value.depth() l_depth = self.left_value.depth()
@ -730,27 +750,29 @@ class Tree(object):
r_depth = 1 r_depth = 1
if l_depth > r_depth+1 and\ if l_depth > r_depth+1 and\
self.node == self.left_value.node: self.node == self.left_value.node and \
self.node not in exclude_nodes:
new_left = self.left_value.long_branch new_left = self.left_value.long_branch
new_right = Tree(self.node, new_right = Tree(self.node,
self.left_value.short_branch, self.left_value.short_branch,
self.right_value) self.right_value)
return Tree(self.node, new_left, new_right).balance() return Tree(self.node, new_left, new_right).balance(exclude_nodes)
if r_depth > l_depth+1 and\ if r_depth > l_depth+1 and\
self.node == self.right_value.node: self.node == self.right_value.node and \
self.node not in exclude_nodes:
new_right = self.right_value.long_branch new_right = self.right_value.long_branch
new_left = Tree(self.node, new_left = Tree(self.node,
self.left_value, self.left_value,
self.right_value.short_branch) self.right_value.short_branch)
return Tree(self.node, new_left, new_right).balance() return Tree(self.node, new_left, new_right).balance(exclude_nodes)
try: try:
left_v = self.left_value.balance() left_v = self.left_value.balance(exclude_nodes)
except AttributeError: except AttributeError:
left_v = self.left_value left_v = self.left_value
try: try:
right_v = self.right_value.balance() right_v = self.right_value.balance(exclude_nodes)
except AttributeError: except AttributeError:
right_v = self.right_value right_v = self.right_value