cython-devel

changeset 1486:8a2e7b51e770

initial constant folding transform: calculate constant values in node.constant_result
author Stefan Behnel <scoder@users.berlios.de>
date Sat Dec 13 22:23:00 2008 +0100 (19 months ago)
parents 421a3edf1abf
children 89ab303106e4
files Cython/Compiler/ExprNodes.py Cython/Compiler/Main.py Cython/Compiler/Optimize.py
line diff
1.1 --- a/Cython/Compiler/ExprNodes.py Sat Dec 13 21:25:00 2008 +0100 1.2 +++ b/Cython/Compiler/ExprNodes.py Sat Dec 13 22:23:00 2008 +0100 1.3 @@ -22,6 +22,13 @@ 1.4 from DebugFlags import debug_disposal_code, debug_temp_alloc, \ 1.5 debug_coercion 1.6 1.7 +try: 1.8 + set 1.9 +except NameError: 1.10 + from sets import Set as set 1.11 + 1.12 +not_a_constant = object() 1.13 +constant_value_not_set = object() 1.14 1.15 class ExprNode(Node): 1.16 # subexprs [string] Class var holding names of subexpr node attrs 1.17 @@ -172,6 +179,8 @@ 1.18 is_temp = 0 1.19 is_target = 0 1.20 1.21 + constant_result = constant_value_not_set 1.22 + 1.23 def get_child_attrs(self): 1.24 return self.subexprs 1.25 child_attrs = property(fget=get_child_attrs) 1.26 @@ -224,7 +233,17 @@ 1.27 # Return the native C type of the result (i.e. the 1.28 # C type of the result_code expression). 1.29 return self.result_ctype or self.type 1.30 - 1.31 + 1.32 + def calculate_constant_result(self): 1.33 + # Calculate the constant result of this expression and store 1.34 + # it in ``self.constant_result``. Does nothing by default, 1.35 + # thus leaving ``self.constant_result`` unknown. 1.36 + # 1.37 + # This must only be called when it is assured that all 1.38 + # sub-expressions have a valid constant_result value. The 1.39 + # ConstantFolding transform will do this. 1.40 + pass 1.41 + 1.42 def compile_time_value(self, denv): 1.43 # Return value of compile-time expression, or report error. 1.44 error(self.pos, "Invalid compile-time expression") 1.45 @@ -736,7 +755,9 @@ 1.46 # The constant value None 1.47 1.48 value = "Py_None" 1.49 - 1.50 + 1.51 + constant_result = None 1.52 + 1.53 def compile_time_value(self, denv): 1.54 return None 1.55 1.56 @@ -745,6 +766,8 @@ 1.57 1.58 value = "Py_Ellipsis" 1.59 1.60 + constant_result = Ellipsis 1.61 + 1.62 def compile_time_value(self, denv): 1.63 return Ellipsis 1.64 1.65 @@ -775,7 +798,10 @@ 1.66 class BoolNode(ConstNode): 1.67 type = PyrexTypes.c_bint_type 1.68 # The constant value True or False 1.69 - 1.70 + 1.71 + def calculate_constant_result(self): 1.72 + self.constant_result = self.value 1.73 + 1.74 def compile_time_value(self, denv): 1.75 return self.value 1.76 1.77 @@ -785,10 +811,14 @@ 1.78 class NullNode(ConstNode): 1.79 type = PyrexTypes.c_null_ptr_type 1.80 value = "NULL" 1.81 + constant_result = 0 1.82 1.83 1.84 class CharNode(ConstNode): 1.85 type = PyrexTypes.c_char_type 1.86 + 1.87 + def calculate_constant_result(self): 1.88 + self.constant_result = ord(self.value) 1.89 1.90 def compile_time_value(self, denv): 1.91 return ord(self.value) 1.92 @@ -830,6 +860,9 @@ 1.93 else: 1.94 return str(self.value) + self.unsigned + self.longness 1.95 1.96 + def calculate_constant_result(self): 1.97 + self.constant_result = int(self.value, 0) 1.98 + 1.99 def compile_time_value(self, denv): 1.100 return int(self.value, 0) 1.101 1.102 @@ -953,6 +986,9 @@ 1.103 # Python long integer literal 1.104 # 1.105 # value string 1.106 + 1.107 + def calculate_constant_result(self): 1.108 + self.constant_result = long(self.value) 1.109 1.110 def compile_time_value(self, denv): 1.111 return long(self.value) 1.112 @@ -978,6 +1014,9 @@ 1.113 # Imaginary number literal 1.114 # 1.115 # value float imaginary part 1.116 + 1.117 + def calculate_constant_result(self): 1.118 + self.constant_result = complex(0.0, self.value) 1.119 1.120 def compile_time_value(self, denv): 1.121 return complex(0.0, self.value) 1.122 @@ -1350,6 +1389,9 @@ 1.123 1.124 gil_message = "Backquote expression" 1.125 1.126 + def calculate_constant_result(self): 1.127 + self.constant_result = repr(self.arg.constant_result) 1.128 + 1.129 def generate_result_code(self, code): 1.130 code.putln( 1.131 "%s = PyObject_Repr(%s); %s" % ( 1.132 @@ -1582,7 +1624,11 @@ 1.133 def __init__(self, pos, index, *args, **kw): 1.134 ExprNode.__init__(self, pos, index=index, *args, **kw) 1.135 self._index = index 1.136 - 1.137 + 1.138 + def calculate_constant_result(self): 1.139 + self.constant_result = \ 1.140 + self.base.constant_result[self.index.constant_result] 1.141 + 1.142 def compile_time_value(self, denv): 1.143 base = self.base.compile_time_value(denv) 1.144 index = self.index.compile_time_value(denv) 1.145 @@ -1881,7 +1927,11 @@ 1.146 # stop ExprNode or None 1.147 1.148 subexprs = ['base', 'start', 'stop'] 1.149 - 1.150 + 1.151 + def calculate_constant_result(self): 1.152 + self.constant_result = self.base.constant_result[ 1.153 + self.start.constant_result : self.stop.constant_result] 1.154 + 1.155 def compile_time_value(self, denv): 1.156 base = self.base.compile_time_value(denv) 1.157 if self.start is None: 1.158 @@ -2055,7 +2105,13 @@ 1.159 # start ExprNode 1.160 # stop ExprNode 1.161 # step ExprNode 1.162 - 1.163 + 1.164 + def calculate_constant_result(self): 1.165 + self.constant_result = self.base.constant_result[ 1.166 + self.start.constant_result : \ 1.167 + self.stop.constant_result : \ 1.168 + self.step.constant_result] 1.169 + 1.170 def compile_time_value(self, denv): 1.171 start = self.start.compile_time_value(denv) 1.172 if self.stop is None: 1.173 @@ -2452,6 +2508,9 @@ 1.174 # arg ExprNode 1.175 1.176 subexprs = ['arg'] 1.177 + 1.178 + def calculate_constant_result(self): 1.179 + self.constant_result = tuple(self.base.constant_result) 1.180 1.181 def compile_time_value(self, denv): 1.182 arg = self.arg.compile_time_value(denv) 1.183 @@ -2517,7 +2576,13 @@ 1.184 self.analyse_as_python_attribute(env) 1.185 return self 1.186 return ExprNode.coerce_to(self, dst_type, env) 1.187 - 1.188 + 1.189 + def calculate_constant_result(self): 1.190 + attr = self.attribute 1.191 + if attr.beginswith("__") and attr.endswith("__"): 1.192 + return 1.193 + self.constant_result = getattr(self.obj.constant_result, attr) 1.194 + 1.195 def compile_time_value(self, denv): 1.196 attr = self.attribute 1.197 if attr.beginswith("__") and attr.endswith("__"): 1.198 @@ -2963,6 +3028,10 @@ 1.199 else: 1.200 return Naming.empty_tuple 1.201 1.202 + def calculate_constant_result(self): 1.203 + self.constant_result = tuple([ 1.204 + arg.constant_result for arg in self.args]) 1.205 + 1.206 def compile_time_value(self, denv): 1.207 values = self.compile_time_value_list(denv) 1.208 try: 1.209 @@ -3058,6 +3127,10 @@ 1.210 else: 1.211 SequenceNode.release_temp(self, env) 1.212 1.213 + def calculate_constant_result(self): 1.214 + self.constant_result = [ 1.215 + arg.constant_result for arg in self.args] 1.216 + 1.217 def compile_time_value(self, denv): 1.218 return self.compile_time_value_list(denv) 1.219 1.220 @@ -3228,13 +3301,13 @@ 1.221 self.gil_check(env) 1.222 self.is_temp = 1 1.223 1.224 + def calculate_constant_result(self): 1.225 + self.constant_result = set([ 1.226 + arg.constant_result for arg in self.args]) 1.227 + 1.228 def compile_time_value(self, denv): 1.229 values = [arg.compile_time_value(denv) for arg in self.args] 1.230 try: 1.231 - set 1.232 - except NameError: 1.233 - from sets import Set as set 1.234 - try: 1.235 return set(values) 1.236 except Exception, e: 1.237 self.compile_time_value_error(e) 1.238 @@ -3264,6 +3337,10 @@ 1.239 # obj_conversion_errors [PyrexError] used internally 1.240 1.241 subexprs = ['key_value_pairs'] 1.242 + 1.243 + def calculate_constant_result(self): 1.244 + self.constant_result = dict([ 1.245 + item.constant_result for item in self.key_value_pairs]) 1.246 1.247 def compile_time_value(self, denv): 1.248 pairs = [(item.key.compile_time_value(denv), item.value.compile_time_value(denv)) 1.249 @@ -3366,6 +3443,10 @@ 1.250 # key ExprNode 1.251 # value ExprNode 1.252 subexprs = ['key', 'value'] 1.253 + 1.254 + def calculate_constant_result(self): 1.255 + self.constant_result = ( 1.256 + self.key.constant_result, self.value.constant_result) 1.257 1.258 def analyse_types(self, env): 1.259 self.key.analyse_types(env) 1.260 @@ -3507,6 +3588,10 @@ 1.261 # - Allocate temporary for result if needed. 1.262 1.263 subexprs = ['operand'] 1.264 + 1.265 + def calculate_constant_result(self): 1.266 + func = compile_time_unary_operators[self.operator] 1.267 + self.constant_result = func(self.operand.constant_result) 1.268 1.269 def compile_time_value(self, denv): 1.270 func = compile_time_unary_operators.get(self.operator) 1.271 @@ -3566,7 +3651,10 @@ 1.272 # 'not' operator 1.273 # 1.274 # operand ExprNode 1.275 - 1.276 + 1.277 + def calculate_constant_result(self): 1.278 + self.constant_result = not self.operand.constant_result 1.279 + 1.280 def compile_time_value(self, denv): 1.281 operand = self.operand.compile_time_value(denv) 1.282 try: 1.283 @@ -3897,7 +3985,13 @@ 1.284 # - Allocate temporary for result if needed. 1.285 1.286 subexprs = ['operand1', 'operand2'] 1.287 - 1.288 + 1.289 + def calculate_constant_result(self): 1.290 + func = compile_time_binary_operators[self.operator] 1.291 + self.constant_result = func( 1.292 + self.operand1.constant_result, 1.293 + self.operand2.constant_result) 1.294 + 1.295 def compile_time_value(self, denv): 1.296 func = get_compile_time_binop(self) 1.297 operand1 = self.operand1.compile_time_value(denv) 1.298 @@ -4137,6 +4231,16 @@ 1.299 # operand2 ExprNode 1.300 1.301 subexprs = ['operand1', 'operand2'] 1.302 + 1.303 + def calculate_constant_result(self): 1.304 + if self.operator == 'and': 1.305 + self.constant_result = \ 1.306 + self.operand1.constant_result and \ 1.307 + self.operand2.constant_result 1.308 + else: 1.309 + self.constant_result = \ 1.310 + self.operand1.constant_result or \ 1.311 + self.operand2.constant_result 1.312 1.313 def compile_time_value(self, denv): 1.314 if self.operator == 'and': 1.315 @@ -4261,7 +4365,13 @@ 1.316 false_val = None 1.317 1.318 subexprs = ['test', 'true_val', 'false_val'] 1.319 - 1.320 + 1.321 + def calculate_constant_result(self): 1.322 + if self.test.constant_result: 1.323 + self.constant_result = self.true_val.constant_result 1.324 + else: 1.325 + self.constant_result = self.false_val.constant_result 1.326 + 1.327 def analyse_types(self, env): 1.328 self.test.analyse_types(env) 1.329 self.test = self.test.coerce_to_boolean(env) 1.330 @@ -4350,6 +4460,15 @@ 1.331 class CmpNode: 1.332 # Mixin class containing code common to PrimaryCmpNodes 1.333 # and CascadedCmpNodes. 1.334 + 1.335 + def calculate_cascaded_constant_result(self, operand1_result): 1.336 + func = compile_time_binary_operators[self.operator] 1.337 + operand2_result = self.operand2.constant_result 1.338 + result = func(operand1_result, operand2_result) 1.339 + if result and self.cascade: 1.340 + result = result and \ 1.341 + self.cascade.cascaded_compile_time_value(operand2_result) 1.342 + self.constant_result = result 1.343 1.344 def cascaded_compile_time_value(self, operand1, denv): 1.345 func = get_compile_time_binop(self) 1.346 @@ -4362,6 +4481,7 @@ 1.347 if result: 1.348 cascade = self.cascade 1.349 if cascade: 1.350 + # FIXME: I bet this must call cascaded_compile_time_value() 1.351 result = result and cascade.compile_time_value(operand2, denv) 1.352 return result 1.353 1.354 @@ -4468,6 +4588,10 @@ 1.355 child_attrs = ['operand1', 'operand2', 'cascade'] 1.356 1.357 cascade = None 1.358 + 1.359 + def calculate_constant_result(self): 1.360 + self.constant_result = self.calculate_cascaded_constant_result( 1.361 + self.operand1.constant_result) 1.362 1.363 def compile_time_value(self, denv): 1.364 operand1 = self.operand1.compile_time_value(denv) 1.365 @@ -4598,7 +4722,8 @@ 1.366 child_attrs = ['operand2', 'cascade'] 1.367 1.368 cascade = None 1.369 - 1.370 + constant_result = constant_value_not_set # FIXME: where to calculate this? 1.371 + 1.372 def analyse_types(self, env, operand1): 1.373 self.operand2.analyse_types(env) 1.374 if self.cascade:
2.1 --- a/Cython/Compiler/Main.py Sat Dec 13 21:25:00 2008 +0100 2.2 +++ b/Cython/Compiler/Main.py Sat Dec 13 22:23:00 2008 +0100 2.3 @@ -83,7 +83,7 @@ 2.4 from ParseTreeTransforms import AlignFunctionDefinitions 2.5 from AutoDocTransforms import EmbedSignature 2.6 from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform 2.7 - from Optimize import FlattenBuiltinTypeCreation, FinalOptimizePhase 2.8 + from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase 2.9 from Buffer import IntroduceBufferAuxiliaryVars 2.10 from ModuleNode import check_c_declarations 2.11 2.12 @@ -123,6 +123,7 @@ 2.13 IntroduceBufferAuxiliaryVars(self), 2.14 _check_c_declarations, 2.15 AnalyseExpressionsTransform(self), 2.16 + ConstantFolding(), 2.17 FlattenBuiltinTypeCreation(), 2.18 DictIterTransform(), 2.19 SwitchTransform(),
3.1 --- a/Cython/Compiler/Optimize.py Sat Dec 13 21:25:00 2008 +0100 3.2 +++ b/Cython/Compiler/Optimize.py Sat Dec 13 22:23:00 2008 +0100 3.3 @@ -387,6 +387,54 @@ 3.4 return node 3.5 3.6 3.7 +class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): 3.8 + """Calculate the result of constant expressions to store it in 3.9 + ``expr_node.constant_result``, and replace trivial cases by their 3.10 + constant result. 3.11 + """ 3.12 + def _calculate_const(self, node): 3.13 + if node.constant_result is not ExprNodes.constant_value_not_set: 3.14 + return 3.15 + 3.16 + # make sure we always set the value 3.17 + not_a_constant = ExprNodes.not_a_constant 3.18 + node.constant_result = not_a_constant 3.19 + 3.20 + # check if all children are constant 3.21 + children = self.visitchildren(node) 3.22 + for child_result in children.itervalues(): 3.23 + if type(child_result) is list: 3.24 + for child in child_result: 3.25 + if child.constant_result is not_a_constant: 3.26 + return 3.27 + elif child_result.constant_result is not_a_constant: 3.28 + return 3.29 + 3.30 + # now try to calculate the real constant value 3.31 + try: 3.32 + node.calculate_constant_result() 3.33 +# if node.constant_result is not ExprNodes.not_a_constant: 3.34 +# print node.__class__.__name__, node.constant_result 3.35 + except (ValueError, TypeError, IndexError, AttributeError): 3.36 + # ignore all 'normal' errors here => no constant result 3.37 + pass 3.38 + except Exception: 3.39 + # this looks like a real error 3.40 + import traceback, sys 3.41 + traceback.print_exc(file=sys.stdout) 3.42 + 3.43 + def visit_ExprNode(self, node): 3.44 + self._calculate_const(node) 3.45 + return node 3.46 + 3.47 + # in the future, other nodes can have their own handler method here 3.48 + # that can replace them with a constant result node 3.49 + 3.50 + def visit_Node(self, node): 3.51 + self.visitchildren(node) 3.52 + return node 3.53 + 3.54 + 3.55 class FinalOptimizePhase(Visitor.CythonTransform): 3.56 """ 3.57 This visitor handles several commuting optimizations, and is run