Cython has moved to github.
cython-devel
view Cython/Compiler/Optimize.py @ 1486:8a2e7b51e770
initial constant folding transform: calculate constant values in node.constant_result
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Sat Dec 13 22:23:00 2008 +0100 (3 years ago) |
| parents | 5c0b99aebb9b |
| children | e8038f8da796 |
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 from StringEncoding import EncodedString
11 from ParseTreeTransforms import SkipDeclarations
13 #def unwrap_node(node):
14 # while isinstance(node, ExprNodes.PersistentNode):
15 # node = node.arg
16 # return node
18 # Temporary hack while PersistentNode is out of order
def unwrap_node(node):
    """Return ``node`` unchanged.

    Placeholder for the disabled implementation (kept above as a comment)
    that stripped ``ExprNodes.PersistentNode`` wrappers.  While that node
    type is out of order, callers get the node back as-is.
    """
    return node
def is_common_value(a, b):
    """Tell whether two expression nodes denote the same simple value.

    Two ``NameNode``s match when their names are equal.  Two
    ``AttributeNode``s match when the first one is not a Python attribute
    access, their objects match (checked recursively) and their attribute
    names are equal.  Every other combination yields False.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    if isinstance(left, ExprNodes.NameNode) and isinstance(right, ExprNodes.NameNode):
        return left.name == right.name

    if not (isinstance(left, ExprNodes.AttributeNode)
            and isinstance(right, ExprNodes.AttributeNode)):
        return False

    # attribute access: compare the object chain first, then the attribute
    if left.is_py_attr:
        return False
    same_object = is_common_value(left.obj, right.obj)
    if not same_object:
        return same_object
    return left.attribute == right.attribute
class DictIterTransform(Visitor.VisitorTransform):
    """Transform a for-in-dict loop into a while loop calling PyDict_Next().

    Handles ``for x in d`` where ``d`` is typed as the builtin dict, and
    ``for ... in d.iterkeys()/itervalues()/iteritems()`` on such a dict.
    Any other loop is returned unchanged.
    """
    # C-level signature used for the generated PyDict_Next() call:
    # returns a C boolean and takes the dict, a pointer to the position
    # counter, and pointers that receive the current key and value.
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    # symbol table entry shared by every generated call node
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    def visit_ForInStatNode(self, node):
        """Replace a for-loop over a dict by a PyDict_Next() while-loop.

        Returns the original ``node`` when the loop does not iterate over a
        dict-typed object (directly or through iterkeys/itervalues/iteritems),
        or when an ``iteritems()`` target is a sequence that does not have
        exactly two items.  Otherwise returns a ``TempsBlockNode`` wrapping
        the setup assignments and the new ``WhileStatNode``.
        """
        self.visitchildren(node)
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            dict_obj = iterator
            keys = True
            values = False
        else:
            # only a method call on a dict-typed object qualifies
            if not isinstance(iterator, ExprNodes.SimpleCallNode):
                return node
            function = iterator.function
            if not isinstance(function, ExprNodes.AttributeNode):
                return node
            if function.obj.type != Builtin.dict_type:
                return node
            dict_obj = function.obj
            method = function.attribute

            # decide which of key/value the loop body needs
            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node

        # key/value temps are declared with the void pointer type
        py_object_ptr = PyrexTypes.c_void_ptr_type

        # 'temps' collects every temp handle; it is handed to the
        # TempsBlockNode at the end (and extended by coerce_object_to below)
        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # key not needed: pass a NULL node instead of an address
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # value not needed: pass a NULL node instead of an address
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        # work out where key and value get assigned in the loop body
        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                # single target: it receives a (key, value) tuple
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            """Build the node that converts ``obj_node`` to ``dest_type``.

            Returns ``(rhs_node, coercion_stat)``: ``rhs_node`` is what gets
            assigned to the loop target and ``coercion_stat`` is either None
            or a node to execute before the assignments.
            """
            # minimal stand-in passed where the node constructors expect an
            # environment argument; it only provides 'nogil'
            class FakeEnv(object):
                nogil = False
            if dest_type.is_pyobject:
                if dest_type.is_extension_type or dest_type.is_builtin_type:
                    obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                # non-Python target type: coerce into an extra temp
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                # coercion node whose result is redirected to the temp above
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))

        # make sure the body is a statement list so that the assignments
        # can be inserted at its start
        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            # build the (key, value) tuple and assign it to the single target
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # generated code: dict_temp = <dict>; pos_temp = 0;
        # while PyDict_Next(dict_temp, &pos_temp, <key addr>, <value addr>): body
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))

    def visit_Node(self, node):
        # descend into statements (loops) and nodes (comprehensions)
        self.visitchildren(node)
        return node
class SwitchTransform(Visitor.VisitorTransform):
    """
    Rewrite suitable if/elif chains as C switch statements.

    A chain qualifies when every clause tests ``var == value`` (or an
    ``or`` of such tests), all clauses test the same ``var``, and both the
    variable and all values are of integer type.
    """
    def extract_conditions(self, cond):
        """Split a clause condition into ``(variable, [values])``.

        Returns ``(None, None)`` when the condition does not have the
        required shape.
        """
        # look through temp coercion and typecast wrappers
        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg
        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # try "var == const" first, then "const == var"
            for var, value in ((cond.operand1, cond.operand2),
                               (cond.operand2, cond.operand1)):
                # comparing a node with itself checks that it is a plain
                # name or attribute chain
                if not is_common_value(var, var):
                    continue
                if isinstance(value, ExprNodes.ConstNode):
                    return var, [value]
                if hasattr(value, 'entry') and value.entry and value.entry.is_const:
                    return var, [value]
            return None, None

        if (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            left_var, left_values = self.extract_conditions(cond.operand1)
            right_var, right_values = self.extract_conditions(cond.operand2)
            if is_common_value(left_var, right_var):
                return left_var, left_values + right_values

        return None, None

    def visit_IfStatNode(self, node):
        """Return a ``SwitchStatNode`` for ``node`` if it qualifies, else ``node``."""
        self.visitchildren(node)
        switch_var = None
        value_count = 0
        switch_cases = []
        for clause in node.if_clauses:
            var, conditions = self.extract_conditions(clause.condition)
            if var is None:
                return node
            if switch_var is not None and not is_common_value(var, switch_var):
                return node
            if not var.type.is_int:
                return node
            non_int_values = [value for value in conditions
                              if not value.type.is_int]
            if non_int_values:
                return node
            switch_var = var
            value_count += len(conditions)
            switch_cases.append(
                Nodes.SwitchCaseNode(pos = clause.pos,
                                     conditions = conditions,
                                     body = clause.body))

        # a switch with fewer than two values is not worth generating
        if value_count < 2:
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(switch_var),
                                    cases = switch_cases,
                                    else_clause = node.else_clause)

    def visit_Node(self, node):
        # generic nodes: just recurse
        self.visitchildren(node)
        return node
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Expand ``in`` / ``not in`` tests against a literal tuple or list.

        ``x in (a, b)`` becomes ``x == a or x == b`` and ``x not in (a, b)``
        becomes ``x != a and x != b``, with ``x`` evaluated once through a
        ``ResultRefNode``.  Cascaded comparisons, other operators,
        non-literal right-hand sides and empty sequences are returned
        unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        args = node.operand2.args
        if not args:
            # Do not fold to a constant: the left operand may have side
            # effects (e.g. a function call) that must still be executed.
            return node

        # evaluate the left operand once and reuse its result in each test
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = arg,
                cascade = None)
            # each comparison is cast to a C boolean
            conds.append(ExprNodes.TypecastNode(
                pos = node.pos,
                operand = cond,
                type = PyrexTypes.c_bint_type))

        def concat(left, right):
            # join two tests with 'or' (for 'in') or 'and' (for 'not in')
            return ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = left,
                operand2 = right)

        condition = reduce(concat, conds)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    def visit_Node(self, node):
        # generic nodes: just recurse
        self.visitchildren(node)
        return node
class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
    """Optimise some common instantiation patterns for builtin types.
    """
    def visit_GeneralCallNode(self, node):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.

        The call node is returned unchanged unless the callee is the
        builtin type named 'dict', there are no positional arguments, the
        keyword arguments form a ``DictNode`` and there is no ``**`` argument.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_builtin_type:
            return node
        # not every callee node carries a 'name' (e.g. attribute access);
        # treat those as "not dict" instead of failing on the lookup
        if getattr(function, 'name', None) != 'dict':
            return node
        if not isinstance(node.positional_args, ExprNodes.TupleNode):
            return node
        if len(node.positional_args.args) > 0:
            return node
        if not isinstance(node.keyword_args, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimise this by updating the kw dict instead
            return node
        return node.keyword_args

    def visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.

        When visiting the children replaced the argument and the new
        argument already has the tested type, the argument itself is
        returned and the type test node is dropped.
        """
        old_arg = node.arg
        self.visitchildren(node)
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

    def visit_Node(self, node):
        # generic nodes: just recurse
        self.visitchildren(node)
        return node
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Compute the value of constant expressions and store it in
    ``expr_node.constant_result``; trivial cases may be replaced by
    their constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` unless it has been set before.

        The node is first marked as not-a-constant.  Only if none of its
        visited children carries that marker is
        ``node.calculate_constant_result()`` invoked.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            # something was already stored for this node
            return

        # pessimistic default, so that a value is always set
        non_constant = ExprNodes.not_a_constant
        node.constant_result = non_constant

        # a single non-constant child makes the whole node non-constant
        for visited in self.visitchildren(node).itervalues():
            if type(visited) is list:
                group = visited
            else:
                group = [visited]
            for child in group:
                if child.constant_result is non_constant:
                    return

        # all children are constant: try to compute the actual value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, IndexError, AttributeError):
            # 'normal' evaluation errors just mean: no constant result
            pass
        except Exception:
            # anything else looks like a real error, so report it
            import traceback
            import sys
            traceback.print_exc(file=sys.stdout)

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    def visit_Node(self, node):
        # non-expression nodes: just recurse
        self.visitchildren(node)
        return node
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Only two-argument calls to a C-function ``isinstance`` name are
        rewritten, and only when the second argument has the builtin type
        named 'type'.  The call node is modified in place and returned.
        """
        self.visitchildren(node)
        function = node.function
        if not (function.type.is_cfunction
                and isinstance(function, ExprNodes.NameNode)):
            return node
        # guard the args[1] access below against malformed calls
        if function.name != 'isinstance' or len(node.args) != 2:
            return node

        type_arg = node.args[1]
        if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
            object_module = self.context.find_module('python_object')
            entry = object_module.lookup('PyObject_TypeCheck')
            if entry is None:
                # only happens when there was an error earlier; leave the
                # call node untouched rather than clearing its entry
                return node
            function.entry = entry
            function.type = entry.type
            # the second argument is cast to a PyTypeObject pointer
            PyTypeObjectPtr = PyrexTypes.CPtrType(object_module.lookup('PyTypeObject').type)
            node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node
