cython-devel

changeset 1500:c1a7180ac974

moved iter-range() optimisation into a transform (worth a review)
author Stefan Behnel <scoder@users.berlios.de>
date Wed Dec 17 22:29:11 2008 +0100 (19 months ago)
parents 08eb7538e220
children e3ff81c3835b
files Cython/Compiler/Main.py Cython/Compiler/Nodes.py Cython/Compiler/Optimize.py tests/run/r_forloop.pyx
line diff
1.1 --- a/Cython/Compiler/Main.py Wed Dec 17 22:24:04 2008 +0100 1.2 +++ b/Cython/Compiler/Main.py Wed Dec 17 22:29:11 2008 +0100 1.3 @@ -82,7 +82,7 @@ 1.4 from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods 1.5 from ParseTreeTransforms import AlignFunctionDefinitions 1.6 from AutoDocTransforms import EmbedSignature 1.7 - from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform 1.8 + from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform 1.9 from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase 1.10 from Buffer import IntroduceBufferAuxiliaryVars 1.11 from ModuleNode import check_c_declarations 1.12 @@ -125,7 +125,7 @@ 1.13 AnalyseExpressionsTransform(self), 1.14 FlattenBuiltinTypeCreation(), 1.15 ConstantFolding(), 1.16 - DictIterTransform(), 1.17 + IterationTransform(), 1.18 SwitchTransform(), 1.19 FinalOptimizePhase(self), 1.20 # ClearResultCodes(self),
2.1 --- a/Cython/Compiler/Nodes.py Wed Dec 17 22:24:04 2008 +0100 2.2 +++ b/Cython/Compiler/Nodes.py Wed Dec 17 22:29:11 2008 +0100 2.3 @@ -3719,7 +3719,7 @@ 2.4 def analyse_expressions(self, env): 2.5 import ExprNodes 2.6 self.target.analyse_target_types(env) 2.7 - if Options.convert_range and self.target.type.is_int: 2.8 + if False: # Options.convert_range and self.target.type.is_int: 2.9 sequence = self.iterator.sequence 2.10 if isinstance(sequence, ExprNodes.SimpleCallNode) \ 2.11 and sequence.self is None \ 2.12 @@ -3801,7 +3801,11 @@ 2.13 # loopvar_name string 2.14 # py_loopvar_node PyTempNode or None 2.15 child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause"] 2.16 - 2.17 + 2.18 + is_py_target = False 2.19 + loopvar_name = None 2.20 + py_loopvar_node = None 2.21 + 2.22 def analyse_declarations(self, env): 2.23 self.target.analyse_target_declaration(env) 2.24 self.body.analyse_declarations(env) 2.25 @@ -3866,6 +3870,13 @@ 2.26 self.bound2.release_temp(env) 2.27 if self.step is not None: 2.28 self.step.release_temp(env) 2.29 + 2.30 + def reanalyse_c_loop(self, env): 2.31 + # only make sure all subnodes have an integer type 2.32 + self.bound1 = self.bound1.coerce_to_integer(env) 2.33 + self.bound2 = self.bound2.coerce_to_integer(env) 2.34 + if self.step is not None: 2.35 + self.step = self.step.coerce_to_integer(env) 2.36 2.37 def generate_execution_code(self, code): 2.38 old_loop_labels = code.new_loop_labels()
3.1 --- a/Cython/Compiler/Optimize.py Wed Dec 17 22:24:04 2008 +0100 3.2 +++ b/Cython/Compiler/Optimize.py Wed Dec 17 22:29:11 2008 +0100 3.3 @@ -6,6 +6,7 @@ 3.4 import UtilNodes 3.5 import TypeSlots 3.6 import Symtab 3.7 +import Options 3.8 from StringEncoding import EncodedString 3.9 3.10 from ParseTreeTransforms import SkipDeclarations 3.11 @@ -29,8 +30,11 @@ 3.12 return False 3.13 3.14 3.15 -class DictIterTransform(Visitor.VisitorTransform): 3.16 - """Transform a for-in-dict loop into a while loop calling PyDict_Next(). 3.17 +class IterationTransform(Visitor.VisitorTransform): 3.18 + """Transform some common for-in loop patterns into efficient C loops: 3.19 + 3.20 + - for-in-dict loop becomes a while loop calling PyDict_Next() 3.21 + - for-in-range loop becomes a plain C for loop 3.22 """ 3.23 PyDict_Next_func_type = PyrexTypes.CFuncType( 3.24 PyrexTypes.c_bint_type, [ 3.25 @@ -50,6 +54,18 @@ 3.26 self.visitchildren(node) 3.27 return node 3.28 3.29 + def visit_ModuleNode(self, node): 3.30 + self.current_scope = node.scope 3.31 + self.visitchildren(node) 3.32 + return node 3.33 + 3.34 + def visit_DefNode(self, node): 3.35 + oldscope = self.current_scope 3.36 + self.current_scope = node.entry.scope 3.37 + self.visitchildren(node) 3.38 + self.current_scope = oldscope 3.39 + return node 3.40 + 3.41 def visit_ForInStatNode(self, node): 3.42 self.visitchildren(node) 3.43 iterator = node.iterator.sequence 3.44 @@ -61,6 +77,7 @@ 3.45 return node 3.46 3.47 function = iterator.function 3.48 + # dict iteration? 3.49 if isinstance(function, ExprNodes.AttributeNode) and \ 3.50 function.obj.type == Builtin.dict_type: 3.51 dict_obj = function.obj 3.52 @@ -77,8 +94,67 @@ 3.53 return node 3.54 return self._transform_dict_iteration( 3.55 node, dict_obj, keys, values) 3.56 + 3.57 + # range() iteration? 3.58 + if Options.convert_range and node.target.type.is_int: 3.59 + if iterator.self is None and \ 3.60 + isinstance(function, ExprNodes.NameNode) and \ 3.61 + function.entry.is_builtin and \ 3.62 + function.name in ('range', 'xrange'): 3.63 + return self._transform_range_iteration( 3.64 + node, iterator) 3.65 + 3.66 return node 3.67 3.68 + def _transform_range_iteration(self, node, range_function): 3.69 + args = range_function.arg_tuple.args 3.70 + if len(args) < 3: 3.71 + step_pos = range_function.pos 3.72 + step_value = 1 3.73 + step = ExprNodes.IntNode(step_pos, value=1) 3.74 + else: 3.75 + step = args[2] 3.76 + step_pos = step.pos 3.77 + if step.constant_result is ExprNodes.not_a_constant: 3.78 + # cannot determine step direction 3.79 + return node 3.80 + try: 3.81 + # FIXME: check how Python handles rounding here, e.g. from float 3.82 + step_value = int(step.constant_result) 3.83 + except: 3.84 + return node 3.85 + if not isinstance(step, ExprNodes.IntNode): 3.86 + step = ExprNodes.IntNode(step_pos, value=step_value) 3.87 + 3.88 + if step_value > 0: 3.89 + relation1 = '<=' 3.90 + relation2 = '<' 3.91 + elif step_value < 0: 3.92 + step.value = -step_value 3.93 + relation1 = '>=' 3.94 + relation2 = '>' 3.95 + else: 3.96 + return node 3.97 + 3.98 + if len(args) == 1: 3.99 + bound1 = ExprNodes.IntNode(range_function.pos, value=0) 3.100 + bound2 = args[0] 3.101 + else: 3.102 + bound1 = args[0] 3.103 + bound2 = args[1] 3.104 + 3.105 + for_node = Nodes.ForFromStatNode( 3.106 + node.pos, 3.107 + target=node.target, 3.108 + bound1=bound1, relation1=relation1, 3.109 + relation2=relation2, bound2=bound2, 3.110 + step=step, body=node.body, 3.111 + else_clause=node.else_clause, 3.112 + loopvar_name = node.target.entry.cname) 3.113 + for_node.reanalyse_c_loop(self.current_scope) 3.114 +# for_node.analyse_expressions(self.current_scope) 3.115 + return for_node 3.116 + 3.117 def _transform_dict_iteration(self, node, dict_obj, keys, values): 3.118 py_object_ptr = PyrexTypes.c_void_ptr_type 3.119
4.1 --- a/tests/run/r_forloop.pyx Wed Dec 17 22:24:04 2008 +0100 4.2 +++ b/tests/run/r_forloop.pyx Wed Dec 17 22:29:11 2008 +0100 4.3 @@ -12,8 +12,22 @@ 4.4 Spam! 4.5 Spam! 4.6 Spam! 4.7 + >>> go_c_all() 4.8 + Spam! 4.9 + Spam! 4.10 + Spam! 4.11 + >>> go_c_all_exprs(1) 4.12 + Spam! 4.13 + >>> go_c_all_exprs(3) 4.14 + Spam! 4.15 + Spam! 4.16 + >>> go_c_calc(2) 4.17 + Spam! 4.18 + Spam! 4.19 >>> go_c_ret() 4.20 2 4.21 + >>> go_c_calc_ret(2) 4.22 + 6 4.23 4.24 >>> go_list() 4.25 Spam! 4.26 @@ -54,6 +68,30 @@ 4.27 for i in range(4): 4.28 print u"Spam!" 4.29 4.30 +def go_c_all(): 4.31 + cdef int i 4.32 + for i in range(8,2,-2): 4.33 + print u"Spam!" 4.34 + 4.35 +def go_c_all_exprs(x): 4.36 + cdef int i 4.37 + for i in range(4*x,2*x,-3): 4.38 + print u"Spam!" 4.39 + 4.40 +def f(x): 4.41 + return 2*x 4.42 + 4.43 +def go_c_calc(x): 4.44 + cdef int i 4.45 + for i in range(2*f(x),f(x), -2): 4.46 + print u"Spam!" 4.47 + 4.48 +def go_c_calc_ret(x): 4.49 + cdef int i 4.50 + for i in range(2*f(x),f(x), -2): 4.51 + if i < 2*f(x): 4.52 + return i 4.53 + 4.54 def go_c_ret(): 4.55 cdef int i 4.56 for i in range(4):