Cython has moved to github.

cython-devel

view Cython/Compiler/Optimize.py @ 1523:965dc9fc3da7

tiny cleanup, fix #163
author Stefan Behnel <scoder@users.berlios.de>
date Fri Dec 19 21:57:32 2008 +0100 (3 years ago)
parents 524b90274c1b
children b3a305cf4a4b
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 from StringEncoding import EncodedString
12 from ParseTreeTransforms import SkipDeclarations
14 #def unwrap_node(node):
15 # while isinstance(node, ExprNodes.PersistentNode):
16 # node = node.arg
17 # return node
# Temporary hack while PersistentNode is out of order: unwrapping is
# currently a no-op and simply hands back the node it was given.
def unwrap_node(node):
    """Return *node* itself.

    The disabled implementation above stripped PersistentNode wrappers;
    until PersistentNode works again this is the identity function.
    """
    unwrapped = node
    return unwrapped
def is_common_value(a, b):
    """Return True if nodes *a* and *b* denote the same simple value.

    Two plain names match when their names are equal.  Two attribute
    accesses match when the first one is not a Python attribute lookup,
    their base objects are themselves common values and the attribute
    names agree.  Every other combination yields False.
    """
    left, right = unwrap_node(a), unwrap_node(b)
    name_node = ExprNodes.NameNode
    attr_node = ExprNodes.AttributeNode

    if isinstance(left, name_node) and isinstance(right, name_node):
        return left.name == right.name

    both_attributes = isinstance(left, attr_node) and isinstance(right, attr_node)
    if not both_attributes:
        return False
    if left.is_py_attr:
        return False
    return is_common_value(left.obj, right.obj) and left.attribute == right.attribute
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-range loop becomes a plain C for loop
    """
    # C signature PyDict_Next(dict, &pos, &key, &value) -> bint, used to
    # emit a direct C call as the condition of the generated while loop.
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    def visit_Node(self, node):
        # descend into statements (loops) and nodes (comprehensions)
        self.visitchildren(node)
        return node

    def visit_ModuleNode(self, node):
        # the module scope is needed for the integer coercions below
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # switch to the function's scope while visiting its body
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Replace dict and range()/xrange() iteration by C loops.

        Returns the replacement statement, or *node* unchanged when no
        pattern applies.
        """
        self.visitchildren(node)
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and \
                    isinstance(function, ExprNodes.NameNode) and \
                    function.entry.is_builtin and \
                    function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_range_iteration(self, node, range_function):
        """Turn ``for i in range(...)`` into a ForFromStatNode.

        Only calls with 1-3 positional arguments and, when given, a step
        whose constant value is known and non-zero are converted; anything
        else is returned unchanged so the generic code path handles it.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # range() takes 1 to 3 arguments; leave invalid calls alone so
            # that they are reported normally instead of crashing here
            return node
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=step_value)

        # The loop direction is expressed through the relations; the step
        # node itself always carries the absolute value.
        if step_value < 0:
            step.value = -step_value
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            loopvar_node=node.target)
        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Rewrite a dict loop into a while loop driven by PyDict_Next().

        *keys* / *values* select which of the key and value slots are
        requested from PyDict_Next; unused slots are passed as NULL.
        Returns a TempsBlockNode wrapping the generated statements, or
        *node* unchanged for an unpackable target of unexpected length.
        """
        # NOTE: despite its name this is a void* type; the key/value temps
        # are declared with it and cast to the target type on assignment.
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Return (value_node, coercion_stat_or_None) for assigning
            # obj_node to something of dest_type.  Python object targets get
            # a (type-tested) cast; C targets get a new temp that is filled
            # by a separate coercion statement.
            class FakeEnv(object):
                # minimal environment stand-in: only 'nogil' is provided
                nogil = False
            if dest_type.is_pyobject:
                if dest_type.is_extension_type or dest_type.is_builtin_type:
                    obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # writes its coercion result into temp_result
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            # a single (non-unpacking) target receives a (key, value) tuple
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # dict_temp = dict_obj; pos_temp = 0
        # while PyDict_Next(dict_temp, &pos_temp, &key, &value): body
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        # all temporaries collected above are handed to the wrapping block
        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    Chains that test the same value more than once are left alone, since
    C does not allow duplicate case labels.
    """
    def extract_conditions(self, cond):
        """Split *cond* into (tested_value, [case_values]).

        Accepts ``var == const`` (either operand order) and ``or``-chains of
        such comparisons on the same variable.  Returns (None, None) when
        the condition does not have that shape.
        """
        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg

        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # is_common_value(x, x) acts as a "simple value" test here: it
            # is only true for names and non-Python attribute chains
            if is_common_value(cond.operand1, cond.operand1):
                if isinstance(cond.operand2, ExprNodes.ConstNode):
                    return cond.operand1, [cond.operand2]
                elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
                    return cond.operand1, [cond.operand2]
            if is_common_value(cond.operand2, cond.operand2):
                if isinstance(cond.operand1, ExprNodes.ConstNode):
                    return cond.operand2, [cond.operand1]
                elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
                    return cond.operand2, [cond.operand1]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            t1, c1 = self.extract_conditions(cond.operand1)
            t2, c2 = self.extract_conditions(cond.operand2)
            if is_common_value(t1, t2):
                return t1, c1+c2
        return None, None

    def _has_duplicate_values(self, condition_values):
        """Return True if two case values may end up as the same C label.

        An if/elif chain may legally test the same value twice, but a C
        switch with repeated case labels does not compile.  Values are
        compared by their folded constant result when known, otherwise by
        their literal value (ConstNode) or by the C name of the constant
        entry they refer to.  Anything else counts as a potential duplicate
        so that the conversion is skipped on the safe side.
        """
        unknown_markers = (ExprNodes.constant_value_not_set,
                           ExprNodes.not_a_constant)
        seen = set()
        for value in condition_values:
            constant = getattr(value, 'constant_result', ExprNodes.not_a_constant)
            if constant is not unknown_markers[0] and constant is not unknown_markers[1]:
                key = ('constant', constant)
            elif isinstance(value, ExprNodes.ConstNode):
                key = ('literal', value.value)
            else:
                entry = getattr(value, 'entry', None)
                cname = getattr(entry, 'cname', None)
                if not cname:
                    return True
                key = ('cname', cname)
            try:
                if key in seen:
                    return True
                seen.add(key)
            except TypeError:
                # unhashable constant: cannot compare reliably
                return True
        return False

    def visit_IfStatNode(self, node):
        """Replace an if/elif chain by a SwitchStatNode when possible."""
        self.visitchildren(node)
        common_var = None
        cases = []
        all_conditions = []
        for if_clause in node.if_clauses:
            var, conditions = self.extract_conditions(if_clause.condition)
            if var is None:
                return node
            elif common_var is not None and not is_common_value(var, common_var):
                return node
            elif not var.type.is_int or any(not cond.type.is_int for cond in conditions):
                return node
            else:
                common_var = var
                all_conditions.extend(conditions)
                cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                                  conditions = conditions,
                                                  body = if_clause.body))
        if len(all_conditions) < 2:
            return node
        if self._has_duplicate_values(all_conditions):
            # duplicate case labels would not compile in C
            return node

        common_var = unwrap_node(common_var)
        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = common_var,
                                    cases = cases,
                                    else_clause = node.else_clause)

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite ``x in <literal sequence>`` / ``x not in ...``.

        ``in`` becomes ``x == v1 or x == v2 ...`` and ``not in`` becomes
        ``x != v1 and x != v2 ...``, with *x* evaluated only once.  Cascaded
        comparisons and non-literal containers are left unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            if not isinstance(node.operand1, ExprNodes.ConstNode):
                # The left operand must still be evaluated (it may have side
                # effects or raise), so only fold when it is a literal.
                return node
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # evaluate the left operand once and reuse the result
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = arg,
                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                pos = node.pos,
                operand = cond,
                type = PyrexTypes.c_bint_type))

        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = left,
                operand2 = right)

        # left fold: ((c1 op c2) op c3) ...
        condition = conds[0]
        for cond in conds[1:]:
            condition = concat(condition, cond)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
    """Optimise some common instantiation patterns for builtin types.

    Calls whose callee is a builtin type name are dispatched to a method
    named ``_handle_<general|simple>_<name>`` (falling back to
    ``_handle_any_<name>``), which may return a cheaper replacement node.
    """
    # C signature PyList_AsTuple(list) -> object.
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

    PyList_AsTuple_name = EncodedString("PyList_AsTuple")

    PyList_AsTuple_entry = Symtab.Entry(
        PyList_AsTuple_name, PyList_AsTuple_name, PyList_AsTuple_func_type)

    def visit_GeneralCallNode(self, node):
        self.visitchildren(node)
        handler = self._find_handler('general', node.function)
        if handler is None:
            return node
        return handler(node, node.positional_args, node.keyword_args)

    def visit_SimpleCallNode(self, node):
        self.visitchildren(node)
        handler = self._find_handler('simple', node.function)
        if handler is None:
            return node
        return handler(node, node.arg_tuple)

    def _find_handler(self, call_type, function):
        """Return the bound handler method for this call, or None."""
        if not function.type.is_builtin_type or \
                not isinstance(function, ExprNodes.NameNode):
            return None
        candidate_names = ('_handle_%s_%s' % (call_type, function.name),
                           '_handle_any_%s' % function.name)
        for method_name in candidate_names:
            handler = getattr(self, method_name, None)
            if handler is not None:
                return handler
        return None

    def _handle_general_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        # we could optimise the **kwargs case by updating the kw dict instead
        keywords_only = (isinstance(pos_args, ExprNodes.TupleNode)
                         and not pos_args.args
                         and isinstance(kwargs, ExprNodes.DictNode)
                         and not node.starstar_arg)
        if keywords_only:
            return kwargs
        return node

    def _handle_simple_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...}.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        set_args = pos_args.args
        if not set_args:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type, is_temp=1)
        if len(set_args) != 1:
            return node
        (iterable,) = set_args
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args,
                                     type=Builtin.set_type, is_temp=1)
        is_list_comprehension = (
            isinstance(iterable, ExprNodes.ComprehensionNode)
            and iterable.type is Builtin.list_type)
        if not is_list_comprehension:
            return node
        # let the comprehension fill a set directly instead of a list
        iterable.target = ExprNodes.SetNode(
            node.pos, args=[], type=Builtin.set_type, is_temp=1)
        iterable.type = Builtin.set_type
        iterable.pos = node.pos
        return iterable

    def _handle_simple_tuple(self, node, pos_args):
        """Replace tuple([...]) by a call to PyList_AsTuple.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode) or \
                len(pos_args.args) != 1:
            return node
        list_arg = pos_args.args[0]
        if list_arg.type is not Builtin.list_type or \
                not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                          ExprNodes.ListNode)):
            # everything else may be None => take the safe path
            return node

        # turn the node into a direct C call PyList_AsTuple(list_arg)
        node.args = pos_args.args
        node.arg_tuple = None
        node.type = Builtin.tuple_type
        node.result_ctype = Builtin.tuple_type
        node.function = ExprNodes.NameNode(
            pos = node.pos,
            name = self.PyList_AsTuple_name,
            type = self.PyList_AsTuple_func_type,
            entry = self.PyList_AsTuple_entry)
        return node

    def visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.
        """
        previous_arg = node.arg
        self.visitchildren(node)
        arg_unchanged = previous_arg is node.arg
        if arg_unchanged or node.arg.type != node.type:
            return node
        # the rewritten argument already has the tested type
        return node.arg

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Fill in ``node.constant_result`` (at most once per node).

        The node is first marked as not constant; it only receives a real
        value if all of its children turned out to be constant and its own
        ``calculate_constant_result()`` succeeds.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        not_a_constant = ExprNodes.not_a_constant
        # mark first so that every early exit leaves a definite value
        node.constant_result = not_a_constant

        # visit (and thereby fold) the children, flattening list children
        child_nodes = []
        for child_result in self.visitchildren(node).itervalues():
            if type(child_result) is list:
                child_nodes.extend(child_result)
            else:
                child_nodes.append(child_result)
        if any(child.constant_result is not_a_constant for child in child_nodes):
            return

        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # ordinary evaluation failure: the node stays non-constant
            pass
        except Exception:
            # anything else looks like a real error: report it, keep going
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Fold a C binary operation into one of its constant operands."""
        self._calculate_const(node)
        if node.type is PyrexTypes.py_object_type or \
                node.constant_result is ExprNodes.not_a_constant:
            return node
        # reuse the first constant operand whose type matches the result
        replacement = None
        for operand in (node.operand1, node.operand2):
            if isinstance(operand, ExprNodes.ConstNode) and \
                    node.type is operand.type:
                replacement = operand
                break
        if replacement is None:
            return node
        replacement.value = replacement.constant_result = node.constant_result
        return replacement.coerce_to(node.type, self.current_scope)

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    def visit_ModuleNode(self, node):
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        # coercions inside the function body need the function's scope
        enclosing_scope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = enclosing_scope
        return node

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        All required declarations are looked up before the node is touched,
        so a failed lookup (after an earlier error) leaves the call intact
        instead of half-rewritten.
        """
        self.visitchildren(node)
        function = node.function
        if not (function.type.is_cfunction and
                isinstance(function, ExprNodes.NameNode)):
            return node
        if function.name != 'isinstance' or len(node.args) != 2:
            return node
        type_arg = node.args[1]
        if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
            return node

        object_module = self.context.find_module('python_object')
        if object_module is None:
            return node # only happens when there was an error earlier
        type_check_entry = object_module.lookup('PyObject_TypeCheck')
        type_object_entry = object_module.lookup('PyTypeObject')
        if type_check_entry is None or type_object_entry is None:
            return node # only happens when there was an error earlier

        function.entry = type_check_entry
        function.type = type_check_entry.type
        PyTypeObjectPtr = PyrexTypes.CPtrType(type_object_entry.type)
        node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node