Cython has moved to github.
cython-devel
view Cython/Compiler/Optimize.py @ 1500:c1a7180ac974
moved iter-range() optimisation into a transform (worth a review)
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Wed Dec 17 22:29:11 2008 +0100 (3 years ago) |
| parents | b8e290068894 |
| children | 27f0d0f718a5 |
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 from StringEncoding import EncodedString
12 from ParseTreeTransforms import SkipDeclarations
# Intended implementation, disabled for now:
#
#def unwrap_node(node):
#    while isinstance(node, ExprNodes.PersistentNode):
#        node = node.arg
#    return node

# Temporary hack while PersistentNode is out of order
def unwrap_node(node):
    """Return ``node`` unchanged.

    Stand-in for the disabled implementation kept in the comment above,
    which would strip ``ExprNodes.PersistentNode`` wrappers by following
    their ``arg`` attribute.  Callers in this module go through this
    function so that the real unwrapping can be re-enabled in one place.
    """
    return node
def is_common_value(a, b):
    """Tell whether two expression nodes name the same simple value.

    Only two node kinds are recognised:

    - two ``NameNode``s with equal ``name``
    - two ``AttributeNode``s where ``a`` is not a Python attribute, the
      objects they are looked up on are themselves common values, and the
      ``attribute`` names are equal

    Every other combination (including ``None``) yields ``False``.
    Note that calling this with the same node twice is therefore a test
    for "is this a plain name or C attribute chain".
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name
    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
        return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
    return False
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-range loop becomes a plain C for loop
    """
    # C signature used for the generated PyDict_Next() call:
    #   bint PyDict_Next(object dict, Py_ssize_t* pos, object* key, object* value)
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    # symbol table entry handed to the NameNode that represents the function
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    def visit_Node(self, node):
        # descend into statements (loops) and nodes (comprehensions)
        self.visitchildren(node)
        return node

    def visit_ModuleNode(self, node):
        # remember the module scope; it is needed when re-analysing loops
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # switch to the scope recorded on the function's entry while its
        # body is visited, then restore the previous one
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Dispatch a for-in loop to the dict or range() optimisation.

        Returns the replacement node, or ``node`` itself when no pattern
        matches.
        """
        self.visitchildren(node)
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and \
                    isinstance(function, ExprNodes.NameNode) and \
                    function.entry.is_builtin and \
                    function.name in ('range', 'xrange'):
                return self._transform_range_iteration(
                    node, iterator)

        return node

    def _transform_range_iteration(self, node, range_function):
        """Replace ``for i in range(...)`` by a ``ForFromStatNode``.

        ``node`` is the ForInStatNode, ``range_function`` the call node of
        range()/xrange().  The loop is only rewritten when the step is
        absent or a non-zero constant; otherwise ``node`` is returned
        unchanged.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # Not a valid range() call.  Indexing args below would raise
            # IndexError for zero arguments and silently drop surplus
            # ones, so keep the generic loop instead.
            return node

        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value=1)
        else:
            step = args[2]
            step_pos = step.pos
            if step.constant_result is ExprNodes.not_a_constant:
                # cannot determine step direction
                return node
            try:
                # FIXME: check how Python handles rounding here, e.g. from float
                step_value = int(step.constant_result)
            except (ValueError, TypeError, OverflowError, AttributeError):
                # the constant result is not convertible to an integer
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=step_value)

        if step_value > 0:
            relation1 = '<='
            relation2 = '<'
        elif step_value < 0:
            # the direction is expressed by the relations, so the step
            # node carries the absolute value
            step.value = -step_value
            relation1 = '>='
            relation2 = '>'
        else:
            # step 0: leave it to the generic loop
            return node

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value=0)
            bound2 = args[0]
        else:
            bound1 = args[0]
            bound2 = args[1]

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            loopvar_name = node.target.entry.cname)
        for_node.reanalyse_c_loop(self.current_scope)
#        for_node.analyse_expressions(self.current_scope)
        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace a loop over a dict by a ``while PyDict_Next(...)`` loop.

        ``keys`` / ``values`` select which of the two PyDict_Next() output
        slots are used.  The result is a ``TempsBlockNode`` that assigns
        the dict and a zero position to temps and then runs the while
        loop; assignments to the loop target(s) are inserted at the start
        of the loop body.  When both keys and values are requested and the
        target is a sequence of a length other than two, ``node`` is
        returned unchanged.
        """
        # the key/value output slots are declared as void* temps
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # unused slot: pass NULL to PyDict_Next()
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # unused slot: pass NULL to PyDict_Next()
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                # single target receives a (key, value) tuple
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (node to assign from, coercion statement or None).
            # The coercion statement, if any, must run before the
            # assignment.
            class FakeEnv(object):
                nogil = False
            if dest_type.is_pyobject:
                if dest_type.is_extension_type or dest_type.is_builtin_type:
                    obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # coercion node that writes into the temp above and
                    # can be used as a statement
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    def extract_conditions(self, cond):
        """Split a clause condition into ``(variable, [values])``.

        Accepts a non-cascaded, non-Python ``==`` comparison between a
        plain name / C attribute and a constant (in either operand
        order), or an ``or`` of such comparisons on the same variable.
        Returns ``(None, None)`` for anything else.
        """
        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg

        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # is_common_value(x, x) is deliberate: it is true exactly when
            # x is a plain name or a C attribute chain, i.e. something
            # usable as the switch variable
            if is_common_value(cond.operand1, cond.operand1):
                if isinstance(cond.operand2, ExprNodes.ConstNode):
                    return cond.operand1, [cond.operand2]
                elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
                    return cond.operand1, [cond.operand2]
            if is_common_value(cond.operand2, cond.operand2):
                if isinstance(cond.operand1, ExprNodes.ConstNode):
                    return cond.operand2, [cond.operand1]
                elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
                    return cond.operand2, [cond.operand1]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            t1, c1 = self.extract_conditions(cond.operand1)
            t2, c2 = self.extract_conditions(cond.operand2)
            if is_common_value(t1, t2):
                return t1, c1+c2
        return None, None

    def _has_duplicate_values(self, conditions, seen):
        """Record the case values in ``seen``; report a repeated one.

        C does not allow the same case label twice in one switch, so an
        if-chain that tests a value more than once must stay an if-chain.
        Literals are compared by their ``value`` attribute and named
        constants by the ``cname`` of their entry.  This is conservative:
        equal values spelled differently are not detected here.
        """
        for cond in conditions:
            if isinstance(cond, ExprNodes.ConstNode):
                ident = getattr(cond, 'value', None)
                key = ('const', ident)
            else:
                entry = getattr(cond, 'entry', None)
                ident = getattr(entry, 'cname', None)
                key = ('entry', ident)
            if ident is None:
                # nothing to compare by
                continue
            if key in seen:
                return True
            seen[key] = True
        return False

    def visit_IfStatNode(self, node):
        """Replace an if statement by a ``SwitchStatNode`` where possible.

        Every clause must test the same int variable against int
        constants, no value may occur twice, and there must be at least
        two values in total; otherwise ``node`` is returned unchanged.
        """
        self.visitchildren(node)
        common_var = None
        case_count = 0
        cases = []
        seen_values = {}
        for if_clause in node.if_clauses:
            var, conditions = self.extract_conditions(if_clause.condition)
            if var is None:
                return node
            elif common_var is not None and not is_common_value(var, common_var):
                return node
            elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
                return node
            elif self._has_duplicate_values(conditions, seen_values):
                # duplicate case labels are not valid in a C switch
                return node
            else:
                common_var = var
                case_count += len(conditions)
                cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                                  conditions = conditions,
                                                  body = if_clause.body))
        if case_count < 2:
            # not worth a switch
            return node

        common_var = unwrap_node(common_var)
        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = common_var,
                                    cases = cases,
                                    else_clause = node.else_clause)

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite ``x in (a, b, ...)`` / ``x not in (a, b, ...)``.

        ``in`` becomes ``x == a or x == b or ...``; ``not_in`` becomes
        ``x != a and x != b and ...``.  Only non-cascaded comparisons
        whose right operand is a literal tuple or list are rewritten;
        anything else is returned unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # nothing can be contained in an empty sequence
            # NOTE(review): the returned BoolNode does not include
            # operand1 - confirm that dropping it is acceptable when the
            # operand is not a plain value
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # the left operand is wrapped so that every comparison below
        # refers to the same result
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = arg,
                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                pos = node.pos,
                operand = cond,
                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = left,
                operand2 = right)

        condition = reduce(concat, conds)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
    """Optimise some common instantiation patterns for builtin types.
    """
    def visit_GeneralCallNode(self, node):
        self.visitchildren(node)
        handler = self._find_handler('general', node.function)
        if handler is not None:
            node = handler(node, node.positional_args, node.keyword_args)
        return node

    def visit_SimpleCallNode(self, node):
        self.visitchildren(node)
        handler = self._find_handler('simple', node.function)
        if handler is not None:
            node = handler(node, node.arg_tuple, None)
        return node

    def _find_handler(self, call_type, function):
        """Look up the handler method for a call to a builtin type.

        ``call_type`` is ``'general'`` or ``'simple'``.  Handlers are
        methods named ``_handle_<call_type>_<name>``, with
        ``_handle_any_<name>`` as fallback.  Returns ``None`` when the
        callee is not of a builtin type, has no name, or has no handler.
        """
        if not function.type.is_builtin_type:
            return None
        # the callee may be a node kind without a 'name' attribute;
        # there is nothing to dispatch on in that case
        name = getattr(function, 'name', None)
        if name is None:
            return None
        handler = getattr(self, '_handle_%s_%s' % (call_type, name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % name, None)
        return handler

    def _handle_general_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        if len(pos_args.args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimise this by updating the kw dict instead
            return node
        return kwargs

    def _handle_simple_set(self, node, pos_args, kwargs):
        """Replace set([a,b,...]) by a literal set {a,b,...}.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        arg_count = len(pos_args.args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type, is_temp=1)
        if arg_count > 1:
            return node
        iterable = pos_args.args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args,
                                     type=Builtin.set_type, is_temp=1)
        elif isinstance(iterable, ExprNodes.ListComprehensionNode):
            # turn the list comprehension into a set comprehension in
            # place by switching the classes of the existing nodes
            iterable.__class__ = ExprNodes.SetComprehensionNode
            iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode
            iterable.pos = node.pos
            return iterable
        else:
            return node

    def visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.
        """
        old_arg = node.arg
        self.visitchildren(node)
        # only drop the test if the argument was replaced by a node that
        # already has the tested type
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` if it has not been set yet.

        The children are visited first.  The result stays
        ``ExprNodes.not_a_constant`` unless every visited child has a
        constant result and ``node.calculate_constant_result()`` succeeds.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.itervalues():
            if type(child_result) is list:
                for child in child_result:
                    if child.constant_result is not_a_constant:
                        return
            elif child_result.constant_result is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
#            if node.constant_result is not ExprNodes.not_a_constant:
#                print node.__class__.__name__, node.constant_result
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

#    def visit_NumBinopNode(self, node):
    def visit_BinopNode(self, node):
        """Replace a constant, non-object binary operation by a constant.

        One of the operands is reused as the result node: it must be a
        ``ConstNode`` whose type is the type of the operation.  Its
        ``value`` and ``constant_result`` are overwritten with the
        computed result.
        """
        self._calculate_const(node)
        if node.type is PyrexTypes.py_object_type:
            return node
        if node.constant_result is ExprNodes.not_a_constant:
            return node
#        print node.constant_result, node.operand1, node.operand2, node.pos
        if isinstance(node.operand1, ExprNodes.ConstNode) and \
                node.type is node.operand1.type:
            new_node = node.operand1
        elif isinstance(node.operand2, ExprNodes.ConstNode) and \
                node.type is node.operand2.type:
            new_node = node.operand2
        else:
            return node
        new_node.value = new_node.constant_result = node.constant_result
        new_node = new_node.coerce_to(node.type, self.current_scope)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    def visit_ModuleNode(self, node):
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        # use the scope recorded on the function's entry while visiting
        # the function, then restore the previous one
        old_scope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = old_scope
        return node

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        The call is redirected to ``PyObject_TypeCheck`` (looked up in the
        ``python_object`` module) and its second argument is cast to a
        ``PyTypeObject`` pointer.  The node is returned unchanged when the
        call does not have exactly two arguments or the lookup fails.
        """
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'isinstance':
                if len(node.args) != 2:
                    # malformed call; there is no type argument to inspect
                    return node
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    object_module = self.context.find_module('python_object')
                    entry = object_module.lookup('PyObject_TypeCheck')
                    if entry is None:
                        # only happens when there was an error earlier;
                        # leave the original function entry in place
                        return node
                    node.function.entry = entry
                    node.function.type = node.function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(object_module.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node
