cython-devel
changeset 3042:9756a762a5c8
fix ticket 467: restore eval-once semantics for all rhs items in parallel assignments by extracting common subexpressions into temps
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Sat Mar 06 15:30:38 2010 +0100 (3 years ago) |
| parents | 0ccfe1050137 |
| children | f38e938a4338cad2a43b2cb0 |
| files | Cython/Compiler/ParseTreeTransforms.py Cython/Compiler/UtilNodes.py |
line diff
1.1 --- a/Cython/Compiler/ParseTreeTransforms.py Sat Mar 06 08:04:50 2010 +0100
1.2 +++ b/Cython/Compiler/ParseTreeTransforms.py Sat Mar 06 15:30:38 2010 +0100
1.3 @@ -265,6 +265,9 @@
1.4
1.5 expr_list_list = []
1.6 flatten_parallel_assignments(expr_list, expr_list_list)
1.7 + temp_refs = []
1.8 + eliminate_rhs_duplicates(expr_list_list, temp_refs)
1.9 +
1.10 nodes = []
1.11 for expr_list in expr_list_list:
1.12 lhs_list = expr_list[:-1]
1.13 @@ -276,11 +279,94 @@
1.14 node = Nodes.CascadedAssignmentNode(rhs.pos,
1.15 lhs_list = lhs_list, rhs = rhs)
1.16 nodes.append(node)
1.17 +
1.18 if len(nodes) == 1:
1.19 - return nodes[0]
1.20 + assign_node = nodes[0]
1.21 else:
1.22 - return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
1.23 + assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
1.24
1.25 + if temp_refs:
1.26 + duplicates_and_temps = [ (temp.expression, temp)
1.27 + for temp in temp_refs ]
1.28 + sort_common_subsequences(duplicates_and_temps)
1.29 + for _, temp_ref in duplicates_and_temps[::-1]:
1.30 + assign_node = LetNode(temp_ref, assign_node)
1.31 +
1.32 + return assign_node
1.33 +
1.34 +def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
1.35 + """Replace rhs items by LetRefNodes if they appear more than once.
1.36 + Creates a sequence of LetRefNodes that set up the required temps
1.37 + and appends them to ref_node_sequence. The input list is modified
1.38 + in-place.
1.39 + """
1.40 + seen_nodes = set()
1.41 + ref_nodes = {}
1.42 + def find_duplicates(node):
1.43 + if node.is_literal or node.is_name:
1.44 + # no need to replace those; can't include attributes here
1.45 + # as their access is not necessarily side-effect free
1.46 + return
1.47 + if node in seen_nodes:
1.48 + if node not in ref_nodes:
1.49 + ref_node = LetRefNode(node)
1.50 + ref_nodes[node] = ref_node
1.51 + ref_node_sequence.append(ref_node)
1.52 + else:
1.53 + seen_nodes.add(node)
1.54 + if node.is_sequence_constructor:
1.55 + for item in node.args:
1.56 + find_duplicates(item)
1.57 +
1.58 + for expr_list in expr_list_list:
1.59 + rhs = expr_list[-1]
1.60 + find_duplicates(rhs)
1.61 + if not ref_nodes:
1.62 + return
1.63 +
1.64 + def substitute_nodes(node):
1.65 + if node in ref_nodes:
1.66 + return ref_nodes[node]
1.67 + elif node.is_sequence_constructor:
1.68 + node.args = map(substitute_nodes, node.args)
1.69 + return node
1.70 +
1.71 + # replace nodes inside of the common subexpressions
1.72 + for node in ref_nodes:
1.73 + if node.is_sequence_constructor:
1.74 + node.args = map(substitute_nodes, node.args)
1.75 +
1.76 + # replace common subexpressions on all rhs items
1.77 + for expr_list in expr_list_list:
1.78 + expr_list[-1] = substitute_nodes(expr_list[-1])
1.79 +
1.80 +def sort_common_subsequences(items):
1.81 + """Sort items/subsequences so that all items and subsequences that
1.82 + an item contains appear before the item itself. This implies a
1.83 + partial order, and the sort must be stable to preserve the
1.84 + original order as much as possible, so we use a simple insertion
1.85 + sort.
1.86 + """
1.87 + def contains(seq, x):
1.88 + for item in seq:
1.89 + if item is x:
1.90 + return True
1.91 + elif item.is_sequence_constructor and contains(item.args, x):
1.92 + return True
1.93 + return False
1.94 + def lower_than(a,b):
1.95 + return b.is_sequence_constructor and contains(b.args, a)
1.96 +
1.97 + for pos, item in enumerate(items):
1.98 + new_pos = pos
1.99 + key = item[0]
1.100 + for i in xrange(pos-1, -1, -1):
1.101 + if lower_than(key, items[i][0]):
1.102 + new_pos = i
1.103 + if new_pos != pos:
1.104 + for i in xrange(pos, new_pos, -1):
1.105 + items[i] = items[i-1]
1.106 + items[new_pos] = item
1.107
1.108 def flatten_parallel_assignments(input, output):
1.109 # The input is a list of expression nodes, representing the LHSs
2.1 --- a/Cython/Compiler/UtilNodes.py Sat Mar 06 08:04:50 2010 +0100
2.2 +++ b/Cython/Compiler/UtilNodes.py Sat Mar 06 15:30:38 2010 +0100
2.3 @@ -130,6 +130,9 @@
2.4 def infer_type(self, env):
2.5 return self.expression.infer_type(env)
2.6
2.7 + def is_simple(self):
2.8 + return True
2.9 +
2.10 def result(self):
2.11 return self.result_code
2.12
2.13 @@ -222,7 +225,8 @@
2.14 # BLOCK (can modify temp)
2.15 # if temp is an object, decref
2.16 #
2.17 - # To be used after analysis phase, does no analysis.
2.18 + # Usually used after analysis phase, but forwards analysis methods
2.19 + # to its children
2.20
2.21 child_attrs = ['temp_expression', 'body']
2.22
2.23 @@ -231,6 +235,17 @@
2.24 self.pos = body.pos
2.25 self.body = body
2.26
2.27 + def analyse_control_flow(self, env):
2.28 + self.body.analyse_control_flow(env)
2.29 +
2.30 + def analyse_declarations(self, env):
2.31 + self.temp_expression.analyse_declarations(env)
2.32 + self.body.analyse_declarations(env)
2.33 +
2.34 + def analyse_expressions(self, env):
2.35 + self.temp_expression.analyse_expressions(env)
2.36 + self.body.analyse_expressions(env)
2.37 +
2.38 def generate_execution_code(self, code):
2.39 self.setup_temp_expr(code)
2.40 self.body.generate_execution_code(code)
