cython-devel
changeset 1486:8a2e7b51e770
initial constant folding transform: calculate constant values in node.constant_result
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Sat Dec 13 22:23:00 2008 +0100 (19 months ago) |
| parents | 421a3edf1abf |
| children | 89ab303106e4 |
| files | Cython/Compiler/ExprNodes.py Cython/Compiler/Main.py Cython/Compiler/Optimize.py |
line diff
1.1 --- a/Cython/Compiler/ExprNodes.py Sat Dec 13 21:25:00 2008 +0100
1.2 +++ b/Cython/Compiler/ExprNodes.py Sat Dec 13 22:23:00 2008 +0100
1.3 @@ -22,6 +22,13 @@
1.4 from DebugFlags import debug_disposal_code, debug_temp_alloc, \
1.5 debug_coercion
1.6
1.7 +try:
1.8 + set
1.9 +except NameError:
1.10 + from sets import Set as set
1.11 +
1.12 +not_a_constant = object()
1.13 +constant_value_not_set = object()
1.14
1.15 class ExprNode(Node):
1.16 # subexprs [string] Class var holding names of subexpr node attrs
1.17 @@ -172,6 +179,8 @@
1.18 is_temp = 0
1.19 is_target = 0
1.20
1.21 + constant_result = constant_value_not_set
1.22 +
1.23 def get_child_attrs(self):
1.24 return self.subexprs
1.25 child_attrs = property(fget=get_child_attrs)
1.26 @@ -224,7 +233,17 @@
1.27 # Return the native C type of the result (i.e. the
1.28 # C type of the result_code expression).
1.29 return self.result_ctype or self.type
1.30 -
1.31 +
1.32 + def calculate_constant_result(self):
1.33 + # Calculate the constant result of this expression and store
1.34 + # it in ``self.constant_result``. Does nothing by default,
1.35 + # thus leaving ``self.constant_result`` unknown.
1.36 + #
1.37 + # This must only be called when it is assured that all
1.38 + # sub-expressions have a valid constant_result value. The
1.39 + # ConstantFolding transform will do this.
1.40 + pass
1.41 +
1.42 def compile_time_value(self, denv):
1.43 # Return value of compile-time expression, or report error.
1.44 error(self.pos, "Invalid compile-time expression")
1.45 @@ -736,7 +755,9 @@
1.46 # The constant value None
1.47
1.48 value = "Py_None"
1.49 -
1.50 +
1.51 + constant_result = None
1.52 +
1.53 def compile_time_value(self, denv):
1.54 return None
1.55
1.56 @@ -745,6 +766,8 @@
1.57
1.58 value = "Py_Ellipsis"
1.59
1.60 + constant_result = Ellipsis
1.61 +
1.62 def compile_time_value(self, denv):
1.63 return Ellipsis
1.64
1.65 @@ -775,7 +798,10 @@
1.66 class BoolNode(ConstNode):
1.67 type = PyrexTypes.c_bint_type
1.68 # The constant value True or False
1.69 -
1.70 +
1.71 + def calculate_constant_result(self):
1.72 + self.constant_result = self.value
1.73 +
1.74 def compile_time_value(self, denv):
1.75 return self.value
1.76
1.77 @@ -785,10 +811,14 @@
1.78 class NullNode(ConstNode):
1.79 type = PyrexTypes.c_null_ptr_type
1.80 value = "NULL"
1.81 + constant_result = 0
1.82
1.83
1.84 class CharNode(ConstNode):
1.85 type = PyrexTypes.c_char_type
1.86 +
1.87 + def calculate_constant_result(self):
1.88 + self.constant_result = ord(self.value)
1.89
1.90 def compile_time_value(self, denv):
1.91 return ord(self.value)
1.92 @@ -830,6 +860,9 @@
1.93 else:
1.94 return str(self.value) + self.unsigned + self.longness
1.95
1.96 + def calculate_constant_result(self):
1.97 + self.constant_result = int(self.value, 0)
1.98 +
1.99 def compile_time_value(self, denv):
1.100 return int(self.value, 0)
1.101
1.102 @@ -953,6 +986,9 @@
1.103 # Python long integer literal
1.104 #
1.105 # value string
1.106 +
1.107 + def calculate_constant_result(self):
1.108 + self.constant_result = long(self.value)
1.109
1.110 def compile_time_value(self, denv):
1.111 return long(self.value)
1.112 @@ -978,6 +1014,9 @@
1.113 # Imaginary number literal
1.114 #
1.115 # value float imaginary part
1.116 +
1.117 + def calculate_constant_result(self):
1.118 + self.constant_result = complex(0.0, self.value)
1.119
1.120 def compile_time_value(self, denv):
1.121 return complex(0.0, self.value)
1.122 @@ -1350,6 +1389,9 @@
1.123
1.124 gil_message = "Backquote expression"
1.125
1.126 + def calculate_constant_result(self):
1.127 + self.constant_result = repr(self.arg.constant_result)
1.128 +
1.129 def generate_result_code(self, code):
1.130 code.putln(
1.131 "%s = PyObject_Repr(%s); %s" % (
1.132 @@ -1582,7 +1624,11 @@
1.133 def __init__(self, pos, index, *args, **kw):
1.134 ExprNode.__init__(self, pos, index=index, *args, **kw)
1.135 self._index = index
1.136 -
1.137 +
1.138 + def calculate_constant_result(self):
1.139 + self.constant_result = \
1.140 + self.base.constant_result[self.index.constant_result]
1.141 +
1.142 def compile_time_value(self, denv):
1.143 base = self.base.compile_time_value(denv)
1.144 index = self.index.compile_time_value(denv)
1.145 @@ -1881,7 +1927,11 @@
1.146 # stop ExprNode or None
1.147
1.148 subexprs = ['base', 'start', 'stop']
1.149 -
1.150 +
1.151 + def calculate_constant_result(self):
1.152 + self.constant_result = self.base.constant_result[
1.153 + self.start.constant_result : self.stop.constant_result]
1.154 +
1.155 def compile_time_value(self, denv):
1.156 base = self.base.compile_time_value(denv)
1.157 if self.start is None:
1.158 @@ -2055,7 +2105,13 @@
1.159 # start ExprNode
1.160 # stop ExprNode
1.161 # step ExprNode
1.162 -
1.163 +
1.164 + def calculate_constant_result(self):
1.165 + self.constant_result = self.base.constant_result[
1.166 + self.start.constant_result : \
1.167 + self.stop.constant_result : \
1.168 + self.step.constant_result]
1.169 +
1.170 def compile_time_value(self, denv):
1.171 start = self.start.compile_time_value(denv)
1.172 if self.stop is None:
1.173 @@ -2452,6 +2508,9 @@
1.174 # arg ExprNode
1.175
1.176 subexprs = ['arg']
1.177 +
1.178 + def calculate_constant_result(self):
1.179 + self.constant_result = tuple(self.base.constant_result)
1.180
1.181 def compile_time_value(self, denv):
1.182 arg = self.arg.compile_time_value(denv)
1.183 @@ -2517,7 +2576,13 @@
1.184 self.analyse_as_python_attribute(env)
1.185 return self
1.186 return ExprNode.coerce_to(self, dst_type, env)
1.187 -
1.188 +
1.189 + def calculate_constant_result(self):
1.190 + attr = self.attribute
1.191 + if attr.beginswith("__") and attr.endswith("__"):
1.192 + return
1.193 + self.constant_result = getattr(self.obj.constant_result, attr)
1.194 +
1.195 def compile_time_value(self, denv):
1.196 attr = self.attribute
1.197 if attr.beginswith("__") and attr.endswith("__"):
1.198 @@ -2963,6 +3028,10 @@
1.199 else:
1.200 return Naming.empty_tuple
1.201
1.202 + def calculate_constant_result(self):
1.203 + self.constant_result = tuple([
1.204 + arg.constant_result for arg in self.args])
1.205 +
1.206 def compile_time_value(self, denv):
1.207 values = self.compile_time_value_list(denv)
1.208 try:
1.209 @@ -3058,6 +3127,10 @@
1.210 else:
1.211 SequenceNode.release_temp(self, env)
1.212
1.213 + def calculate_constant_result(self):
1.214 + self.constant_result = [
1.215 + arg.constant_result for arg in self.args]
1.216 +
1.217 def compile_time_value(self, denv):
1.218 return self.compile_time_value_list(denv)
1.219
1.220 @@ -3228,13 +3301,13 @@
1.221 self.gil_check(env)
1.222 self.is_temp = 1
1.223
1.224 + def calculate_constant_result(self):
1.225 + self.constant_result = set([
1.226 + arg.constant_result for arg in self.args])
1.227 +
1.228 def compile_time_value(self, denv):
1.229 values = [arg.compile_time_value(denv) for arg in self.args]
1.230 try:
1.231 - set
1.232 - except NameError:
1.233 - from sets import Set as set
1.234 - try:
1.235 return set(values)
1.236 except Exception, e:
1.237 self.compile_time_value_error(e)
1.238 @@ -3264,6 +3337,10 @@
1.239 # obj_conversion_errors [PyrexError] used internally
1.240
1.241 subexprs = ['key_value_pairs']
1.242 +
1.243 + def calculate_constant_result(self):
1.244 + self.constant_result = dict([
1.245 + item.constant_result for item in self.key_value_pairs])
1.246
1.247 def compile_time_value(self, denv):
1.248 pairs = [(item.key.compile_time_value(denv), item.value.compile_time_value(denv))
1.249 @@ -3366,6 +3443,10 @@
1.250 # key ExprNode
1.251 # value ExprNode
1.252 subexprs = ['key', 'value']
1.253 +
1.254 + def calculate_constant_result(self):
1.255 + self.constant_result = (
1.256 + self.key.constant_result, self.value.constant_result)
1.257
1.258 def analyse_types(self, env):
1.259 self.key.analyse_types(env)
1.260 @@ -3507,6 +3588,10 @@
1.261 # - Allocate temporary for result if needed.
1.262
1.263 subexprs = ['operand']
1.264 +
1.265 + def calculate_constant_result(self):
1.266 + func = compile_time_unary_operators[self.operator]
1.267 + self.constant_result = func(self.operand.constant_result)
1.268
1.269 def compile_time_value(self, denv):
1.270 func = compile_time_unary_operators.get(self.operator)
1.271 @@ -3566,7 +3651,10 @@
1.272 # 'not' operator
1.273 #
1.274 # operand ExprNode
1.275 -
1.276 +
1.277 + def calculate_constant_result(self):
1.278 + self.constant_result = not self.operand.constant_result
1.279 +
1.280 def compile_time_value(self, denv):
1.281 operand = self.operand.compile_time_value(denv)
1.282 try:
1.283 @@ -3897,7 +3985,13 @@
1.284 # - Allocate temporary for result if needed.
1.285
1.286 subexprs = ['operand1', 'operand2']
1.287 -
1.288 +
1.289 + def calculate_constant_result(self):
1.290 + func = compile_time_binary_operators[self.operator]
1.291 + self.constant_result = func(
1.292 + self.operand1.constant_result,
1.293 + self.operand2.constant_result)
1.294 +
1.295 def compile_time_value(self, denv):
1.296 func = get_compile_time_binop(self)
1.297 operand1 = self.operand1.compile_time_value(denv)
1.298 @@ -4137,6 +4231,16 @@
1.299 # operand2 ExprNode
1.300
1.301 subexprs = ['operand1', 'operand2']
1.302 +
1.303 + def calculate_constant_result(self):
1.304 + if self.operator == 'and':
1.305 + self.constant_result = \
1.306 + self.operand1.constant_result and \
1.307 + self.operand2.constant_result
1.308 + else:
1.309 + self.constant_result = \
1.310 + self.operand1.constant_result or \
1.311 + self.operand2.constant_result
1.312
1.313 def compile_time_value(self, denv):
1.314 if self.operator == 'and':
1.315 @@ -4261,7 +4365,13 @@
1.316 false_val = None
1.317
1.318 subexprs = ['test', 'true_val', 'false_val']
1.319 -
1.320 +
1.321 + def calculate_constant_result(self):
1.322 + if self.test.constant_result:
1.323 + self.constant_result = self.true_val.constant_result
1.324 + else:
1.325 + self.constant_result = self.false_val.constant_result
1.326 +
1.327 def analyse_types(self, env):
1.328 self.test.analyse_types(env)
1.329 self.test = self.test.coerce_to_boolean(env)
1.330 @@ -4350,6 +4460,15 @@
1.331 class CmpNode:
1.332 # Mixin class containing code common to PrimaryCmpNodes
1.333 # and CascadedCmpNodes.
1.334 +
1.335 + def calculate_cascaded_constant_result(self, operand1_result):
1.336 + func = compile_time_binary_operators[self.operator]
1.337 + operand2_result = self.operand2.constant_result
1.338 + result = func(operand1_result, operand2_result)
1.339 + if result and self.cascade:
1.340 + result = result and \
1.341 + self.cascade.cascaded_compile_time_value(operand2_result)
1.342 + self.constant_result = result
1.343
1.344 def cascaded_compile_time_value(self, operand1, denv):
1.345 func = get_compile_time_binop(self)
1.346 @@ -4362,6 +4481,7 @@
1.347 if result:
1.348 cascade = self.cascade
1.349 if cascade:
1.350 + # FIXME: I bet this must call cascaded_compile_time_value()
1.351 result = result and cascade.compile_time_value(operand2, denv)
1.352 return result
1.353
1.354 @@ -4468,6 +4588,10 @@
1.355 child_attrs = ['operand1', 'operand2', 'cascade']
1.356
1.357 cascade = None
1.358 +
1.359 + def calculate_constant_result(self):
1.360 + self.constant_result = self.calculate_cascaded_constant_result(
1.361 + self.operand1.constant_result)
1.362
1.363 def compile_time_value(self, denv):
1.364 operand1 = self.operand1.compile_time_value(denv)
1.365 @@ -4598,7 +4722,8 @@
1.366 child_attrs = ['operand2', 'cascade']
1.367
1.368 cascade = None
1.369 -
1.370 + constant_result = constant_value_not_set # FIXME: where to calculate this?
1.371 +
1.372 def analyse_types(self, env, operand1):
1.373 self.operand2.analyse_types(env)
1.374 if self.cascade:
2.1 --- a/Cython/Compiler/Main.py Sat Dec 13 21:25:00 2008 +0100
2.2 +++ b/Cython/Compiler/Main.py Sat Dec 13 22:23:00 2008 +0100
2.3 @@ -83,7 +83,7 @@
2.4 from ParseTreeTransforms import AlignFunctionDefinitions
2.5 from AutoDocTransforms import EmbedSignature
2.6 from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform
2.7 - from Optimize import FlattenBuiltinTypeCreation, FinalOptimizePhase
2.8 + from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
2.9 from Buffer import IntroduceBufferAuxiliaryVars
2.10 from ModuleNode import check_c_declarations
2.11
2.12 @@ -123,6 +123,7 @@
2.13 IntroduceBufferAuxiliaryVars(self),
2.14 _check_c_declarations,
2.15 AnalyseExpressionsTransform(self),
2.16 + ConstantFolding(),
2.17 FlattenBuiltinTypeCreation(),
2.18 DictIterTransform(),
2.19 SwitchTransform(),
3.1 --- a/Cython/Compiler/Optimize.py Sat Dec 13 21:25:00 2008 +0100
3.2 +++ b/Cython/Compiler/Optimize.py Sat Dec 13 22:23:00 2008 +0100
3.3 @@ -387,6 +387,54 @@
3.4 return node
3.5
3.6
3.7 +class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
3.8 + """Calculate the result of constant expressions to store it in
3.9 + ``expr_node.constant_result``, and replace trivial cases by their
3.10 + constant result.
3.11 + """
3.12 + def _calculate_const(self, node):
3.13 + if node.constant_result is not ExprNodes.constant_value_not_set:
3.14 + return
3.15 +
3.16 + # make sure we always set the value
3.17 + not_a_constant = ExprNodes.not_a_constant
3.18 + node.constant_result = not_a_constant
3.19 +
3.20 + # check if all children are constant
3.21 + children = self.visitchildren(node)
3.22 + for child_result in children.itervalues():
3.23 + if type(child_result) is list:
3.24 + for child in child_result:
3.25 + if child.constant_result is not_a_constant:
3.26 + return
3.27 + elif child_result.constant_result is not_a_constant:
3.28 + return
3.29 +
3.30 + # now try to calculate the real constant value
3.31 + try:
3.32 + node.calculate_constant_result()
3.33 +# if node.constant_result is not ExprNodes.not_a_constant:
3.34 +# print node.__class__.__name__, node.constant_result
3.35 + except (ValueError, TypeError, IndexError, AttributeError):
3.36 + # ignore all 'normal' errors here => no constant result
3.37 + pass
3.38 + except Exception:
3.39 + # this looks like a real error
3.40 + import traceback, sys
3.41 + traceback.print_exc(file=sys.stdout)
3.42 +
3.43 + def visit_ExprNode(self, node):
3.44 + self._calculate_const(node)
3.45 + return node
3.46 +
3.47 + # in the future, other nodes can have their own handler method here
3.48 + # that can replace them with a constant result node
3.49 +
3.50 + def visit_Node(self, node):
3.51 + self.visitchildren(node)
3.52 + return node
3.53 +
3.54 +
3.55 class FinalOptimizePhase(Visitor.CythonTransform):
3.56 """
3.57 This visitor handles several commuting optimizations, and is run
