cython-devel
changeset 1500:c1a7180ac974
moved iter-range() optimisation into a transform (worth a review)
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Wed Dec 17 22:29:11 2008 +0100 (19 months ago) |
| parents | 08eb7538e220 |
| children | e3ff81c3835b |
| files | Cython/Compiler/Main.py Cython/Compiler/Nodes.py Cython/Compiler/Optimize.py tests/run/r_forloop.pyx |
line diff
1.1 --- a/Cython/Compiler/Main.py Wed Dec 17 22:24:04 2008 +0100
1.2 +++ b/Cython/Compiler/Main.py Wed Dec 17 22:29:11 2008 +0100
1.3 @@ -82,7 +82,7 @@
1.4 from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
1.5 from ParseTreeTransforms import AlignFunctionDefinitions
1.6 from AutoDocTransforms import EmbedSignature
1.7 - from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform
1.8 + from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
1.9 from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
1.10 from Buffer import IntroduceBufferAuxiliaryVars
1.11 from ModuleNode import check_c_declarations
1.12 @@ -125,7 +125,7 @@
1.13 AnalyseExpressionsTransform(self),
1.14 FlattenBuiltinTypeCreation(),
1.15 ConstantFolding(),
1.16 - DictIterTransform(),
1.17 + IterationTransform(),
1.18 SwitchTransform(),
1.19 FinalOptimizePhase(self),
1.20 # ClearResultCodes(self),
2.1 --- a/Cython/Compiler/Nodes.py Wed Dec 17 22:24:04 2008 +0100
2.2 +++ b/Cython/Compiler/Nodes.py Wed Dec 17 22:29:11 2008 +0100
2.3 @@ -3719,7 +3719,7 @@
2.4 def analyse_expressions(self, env):
2.5 import ExprNodes
2.6 self.target.analyse_target_types(env)
2.7 - if Options.convert_range and self.target.type.is_int:
2.8 + if False: # Options.convert_range and self.target.type.is_int:
2.9 sequence = self.iterator.sequence
2.10 if isinstance(sequence, ExprNodes.SimpleCallNode) \
2.11 and sequence.self is None \
2.12 @@ -3801,7 +3801,11 @@
2.13 # loopvar_name string
2.14 # py_loopvar_node PyTempNode or None
2.15 child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause"]
2.16 -
2.17 +
2.18 + is_py_target = False
2.19 + loopvar_name = None
2.20 + py_loopvar_node = None
2.21 +
2.22 def analyse_declarations(self, env):
2.23 self.target.analyse_target_declaration(env)
2.24 self.body.analyse_declarations(env)
2.25 @@ -3866,6 +3870,13 @@
2.26 self.bound2.release_temp(env)
2.27 if self.step is not None:
2.28 self.step.release_temp(env)
2.29 +
2.30 + def reanalyse_c_loop(self, env):
2.31 + # only make sure all subnodes have an integer type
2.32 + self.bound1 = self.bound1.coerce_to_integer(env)
2.33 + self.bound2 = self.bound2.coerce_to_integer(env)
2.34 + if self.step is not None:
2.35 + self.step = self.step.coerce_to_integer(env)
2.36
2.37 def generate_execution_code(self, code):
2.38 old_loop_labels = code.new_loop_labels()
3.1 --- a/Cython/Compiler/Optimize.py Wed Dec 17 22:24:04 2008 +0100
3.2 +++ b/Cython/Compiler/Optimize.py Wed Dec 17 22:29:11 2008 +0100
3.3 @@ -6,6 +6,7 @@
3.4 import UtilNodes
3.5 import TypeSlots
3.6 import Symtab
3.7 +import Options
3.8 from StringEncoding import EncodedString
3.9
3.10 from ParseTreeTransforms import SkipDeclarations
3.11 @@ -29,8 +30,11 @@
3.12 return False
3.13
3.14
3.15 -class DictIterTransform(Visitor.VisitorTransform):
3.16 - """Transform a for-in-dict loop into a while loop calling PyDict_Next().
3.17 +class IterationTransform(Visitor.VisitorTransform):
3.18 + """Transform some common for-in loop patterns into efficient C loops:
3.19 +
3.20 + - for-in-dict loop becomes a while loop calling PyDict_Next()
3.21 + - for-in-range loop becomes a plain C for loop
3.22 """
3.23 PyDict_Next_func_type = PyrexTypes.CFuncType(
3.24 PyrexTypes.c_bint_type, [
3.25 @@ -50,6 +54,18 @@
3.26 self.visitchildren(node)
3.27 return node
3.28
3.29 + def visit_ModuleNode(self, node):
3.30 + self.current_scope = node.scope
3.31 + self.visitchildren(node)
3.32 + return node
3.33 +
3.34 + def visit_DefNode(self, node):
3.35 + oldscope = self.current_scope
3.36 + self.current_scope = node.entry.scope
3.37 + self.visitchildren(node)
3.38 + self.current_scope = oldscope
3.39 + return node
3.40 +
3.41 def visit_ForInStatNode(self, node):
3.42 self.visitchildren(node)
3.43 iterator = node.iterator.sequence
3.44 @@ -61,6 +77,7 @@
3.45 return node
3.46
3.47 function = iterator.function
3.48 + # dict iteration?
3.49 if isinstance(function, ExprNodes.AttributeNode) and \
3.50 function.obj.type == Builtin.dict_type:
3.51 dict_obj = function.obj
3.52 @@ -77,8 +94,67 @@
3.53 return node
3.54 return self._transform_dict_iteration(
3.55 node, dict_obj, keys, values)
3.56 +
3.57 + # range() iteration?
3.58 + if Options.convert_range and node.target.type.is_int:
3.59 + if iterator.self is None and \
3.60 + isinstance(function, ExprNodes.NameNode) and \
3.61 + function.entry.is_builtin and \
3.62 + function.name in ('range', 'xrange'):
3.63 + return self._transform_range_iteration(
3.64 + node, iterator)
3.65 +
3.66 return node
3.67
3.68 + def _transform_range_iteration(self, node, range_function):
3.69 + args = range_function.arg_tuple.args
3.70 + if len(args) < 3:
3.71 + step_pos = range_function.pos
3.72 + step_value = 1
3.73 + step = ExprNodes.IntNode(step_pos, value=1)
3.74 + else:
3.75 + step = args[2]
3.76 + step_pos = step.pos
3.77 + if step.constant_result is ExprNodes.not_a_constant:
3.78 + # cannot determine step direction
3.79 + return node
3.80 + try:
3.81 + # FIXME: check how Python handles rounding here, e.g. from float
3.82 + step_value = int(step.constant_result)
3.83 + except:
3.84 + return node
3.85 + if not isinstance(step, ExprNodes.IntNode):
3.86 + step = ExprNodes.IntNode(step_pos, value=step_value)
3.87 +
3.88 + if step_value > 0:
3.89 + relation1 = '<='
3.90 + relation2 = '<'
3.91 + elif step_value < 0:
3.92 + step.value = -step_value
3.93 + relation1 = '>='
3.94 + relation2 = '>'
3.95 + else:
3.96 + return node
3.97 +
3.98 + if len(args) == 1:
3.99 + bound1 = ExprNodes.IntNode(range_function.pos, value=0)
3.100 + bound2 = args[0]
3.101 + else:
3.102 + bound1 = args[0]
3.103 + bound2 = args[1]
3.104 +
3.105 + for_node = Nodes.ForFromStatNode(
3.106 + node.pos,
3.107 + target=node.target,
3.108 + bound1=bound1, relation1=relation1,
3.109 + relation2=relation2, bound2=bound2,
3.110 + step=step, body=node.body,
3.111 + else_clause=node.else_clause,
3.112 + loopvar_name = node.target.entry.cname)
3.113 + for_node.reanalyse_c_loop(self.current_scope)
3.114 +# for_node.analyse_expressions(self.current_scope)
3.115 + return for_node
3.116 +
3.117 def _transform_dict_iteration(self, node, dict_obj, keys, values):
3.118 py_object_ptr = PyrexTypes.c_void_ptr_type
3.119
4.1 --- a/tests/run/r_forloop.pyx Wed Dec 17 22:24:04 2008 +0100
4.2 +++ b/tests/run/r_forloop.pyx Wed Dec 17 22:29:11 2008 +0100
4.3 @@ -12,8 +12,22 @@
4.4 Spam!
4.5 Spam!
4.6 Spam!
4.7 + >>> go_c_all()
4.8 + Spam!
4.9 + Spam!
4.10 + Spam!
4.11 + >>> go_c_all_exprs(1)
4.12 + Spam!
4.13 + >>> go_c_all_exprs(3)
4.14 + Spam!
4.15 + Spam!
4.16 + >>> go_c_calc(2)
4.17 + Spam!
4.18 + Spam!
4.19 >>> go_c_ret()
4.20 2
4.21 + >>> go_c_calc_ret(2)
4.22 + 6
4.23
4.24 >>> go_list()
4.25 Spam!
4.26 @@ -54,6 +68,30 @@
4.27 for i in range(4):
4.28 print u"Spam!"
4.29
4.30 +def go_c_all():
4.31 + cdef int i
4.32 + for i in range(8,2,-2):
4.33 + print u"Spam!"
4.34 +
4.35 +def go_c_all_exprs(x):
4.36 + cdef int i
4.37 + for i in range(4*x,2*x,-3):
4.38 + print u"Spam!"
4.39 +
4.40 +def f(x):
4.41 + return 2*x
4.42 +
4.43 +def go_c_calc(x):
4.44 + cdef int i
4.45 + for i in range(2*f(x),f(x), -2):
4.46 + print u"Spam!"
4.47 +
4.48 +def go_c_calc_ret(x):
4.49 + cdef int i
4.50 + for i in range(2*f(x),f(x), -2):
4.51 + if i < 2*f(x):
4.52 + return i
4.53 +
4.54 def go_c_ret():
4.55 cdef int i
4.56 for i in range(4):
