Cython has moved to github.
cython-devel
view Cython/Compiler/Optimize.py @ 1447:36de07933246
merge
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Tue Dec 02 20:27:40 2008 +0100 (3 years ago) |
| parents | 07a018cdcdd8 6dbd25167239 |
| children | 34f96dba6feb |
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 from StringEncoding import EncodedString
11 #def unwrap_node(node):
12 # while isinstance(node, ExprNodes.PersistentNode):
13 # node = node.arg
14 # return node
16 # Temporary hack while PersistentNode is out of order
17 def unwrap_node(node):
18 return node
20 def is_common_value(a, b):
21 a = unwrap_node(a)
22 b = unwrap_node(b)
23 if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
24 return a.name == b.name
25 if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
26 return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
27 return False
30 class DictIterTransform(Visitor.VisitorTransform):
31 """Transform a for-in-dict loop into a while loop calling PyDict_Next().
32 """
33 PyDict_Next_func_type = PyrexTypes.CFuncType(
34 PyrexTypes.c_bint_type, [
35 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
36 PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
37 PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
38 PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
39 ])
41 PyDict_Next_name = EncodedString("PyDict_Next")
43 PyDict_Next_entry = Symtab.Entry(
44 PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
46 def visit_ForInStatNode(self, node):
47 self.visitchildren(node)
48 iterator = node.iterator.sequence
49 if iterator.type is Builtin.dict_type:
50 # like iterating over dict.keys()
51 dict_obj = iterator
52 keys = True
53 values = False
54 else:
55 if not isinstance(iterator, ExprNodes.SimpleCallNode):
56 return node
57 function = iterator.function
58 if not isinstance(function, ExprNodes.AttributeNode):
59 return node
60 if function.obj.type != Builtin.dict_type:
61 return node
62 dict_obj = function.obj
63 method = function.attribute
65 keys = values = False
66 if method == 'iterkeys':
67 keys = True
68 elif method == 'itervalues':
69 values = True
70 elif method == 'iteritems':
71 keys = values = True
72 else:
73 return node
75 py_object_ptr = PyrexTypes.c_void_ptr_type
77 temps = []
78 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
79 temps.append(temp)
80 dict_temp = temp.ref(dict_obj.pos)
81 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
82 temps.append(temp)
83 pos_temp = temp.ref(node.pos)
84 pos_temp_addr = ExprNodes.AmpersandNode(
85 node.pos, operand=pos_temp,
86 type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
87 if keys:
88 temp = UtilNodes.TempHandle(py_object_ptr)
89 temps.append(temp)
90 key_temp = temp.ref(node.target.pos)
91 key_temp_addr = ExprNodes.AmpersandNode(
92 node.target.pos, operand=key_temp,
93 type=PyrexTypes.c_ptr_type(py_object_ptr))
94 else:
95 key_temp_addr = key_temp = ExprNodes.NullNode(
96 pos=node.target.pos)
97 if values:
98 temp = UtilNodes.TempHandle(py_object_ptr)
99 temps.append(temp)
100 value_temp = temp.ref(node.target.pos)
101 value_temp_addr = ExprNodes.AmpersandNode(
102 node.target.pos, operand=value_temp,
103 type=PyrexTypes.c_ptr_type(py_object_ptr))
104 else:
105 value_temp_addr = value_temp = ExprNodes.NullNode(
106 pos=node.target.pos)
108 key_target = value_target = node.target
109 tuple_target = None
110 if keys and values:
111 if node.target.is_sequence_constructor:
112 if len(node.target.args) == 2:
113 key_target, value_target = node.target.args
114 else:
115 # unusual case that may or may not lead to an error
116 return node
117 else:
118 tuple_target = node.target
120 def coerce_object_to(obj_node, dest_type):
121 class FakeEnv(object):
122 nogil = False
123 if dest_type.is_pyobject:
124 if dest_type.is_extension_type or dest_type.is_builtin_type:
125 obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
126 result = ExprNodes.TypecastNode(
127 obj_node.pos,
128 operand = obj_node,
129 type = dest_type)
130 return (result, None)
131 else:
132 temp = UtilNodes.TempHandle(dest_type)
133 temps.append(temp)
134 temp_result = temp.ref(obj_node.pos)
135 class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
136 def result(self):
137 return temp_result.result()
138 def generate_execution_code(self, code):
139 self.generate_result_code(code)
140 return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))
142 if isinstance(node.body, Nodes.StatListNode):
143 body = node.body
144 else:
145 body = Nodes.StatListNode(pos = node.body.pos,
146 stats = [node.body])
148 if tuple_target:
149 tuple_result = ExprNodes.TupleNode(
150 pos = tuple_target.pos,
151 args = [key_temp, value_temp],
152 is_temp = 1,
153 type = Builtin.tuple_type,
154 )
155 body.stats.insert(
156 0, Nodes.SingleAssignmentNode(
157 pos = tuple_target.pos,
158 lhs = tuple_target,
159 rhs = tuple_result))
160 else:
161 # execute all coercions before the assignments
162 coercion_stats = []
163 assign_stats = []
164 if keys:
165 temp_result, coercion = coerce_object_to(
166 key_temp, key_target.type)
167 if coercion:
168 coercion_stats.append(coercion)
169 assign_stats.append(
170 Nodes.SingleAssignmentNode(
171 pos = key_temp.pos,
172 lhs = key_target,
173 rhs = temp_result))
174 if values:
175 temp_result, coercion = coerce_object_to(
176 value_temp, value_target.type)
177 if coercion:
178 coercion_stats.append(coercion)
179 assign_stats.append(
180 Nodes.SingleAssignmentNode(
181 pos = value_temp.pos,
182 lhs = value_target,
183 rhs = temp_result))
184 body.stats[0:0] = coercion_stats + assign_stats
186 result_code = [
187 Nodes.SingleAssignmentNode(
188 pos = dict_obj.pos,
189 lhs = dict_temp,
190 rhs = dict_obj),
191 Nodes.SingleAssignmentNode(
192 pos = node.pos,
193 lhs = pos_temp,
194 rhs = ExprNodes.IntNode(node.pos, value=0)),
195 Nodes.WhileStatNode(
196 pos = node.pos,
197 condition = ExprNodes.SimpleCallNode(
198 pos = dict_obj.pos,
199 type = PyrexTypes.c_bint_type,
200 function = ExprNodes.NameNode(
201 pos = dict_obj.pos,
202 name = self.PyDict_Next_name,
203 type = self.PyDict_Next_func_type,
204 entry = self.PyDict_Next_entry),
205 args = [dict_temp, pos_temp_addr,
206 key_temp_addr, value_temp_addr]
207 ),
208 body = body,
209 else_clause = node.else_clause
210 )
211 ]
213 return UtilNodes.TempsBlockNode(
214 node.pos, temps=temps,
215 body=Nodes.StatListNode(
216 node.pos,
217 stats = result_code
218 ))
220 def visit_Node(self, node):
221 self.visitchildren(node)
222 return node
225 class SwitchTransform(Visitor.VisitorTransform):
226 """
227 This transformation tries to turn long if statements into C switch statements.
228 The requirement is that every clause be an (or of) var == value, where the var
229 is common among all clauses and both var and value are ints.
230 """
231 def extract_conditions(self, cond):
233 if isinstance(cond, ExprNodes.CoerceToTempNode):
234 cond = cond.arg
236 if isinstance(cond, ExprNodes.TypecastNode):
237 cond = cond.operand
239 if (isinstance(cond, ExprNodes.PrimaryCmpNode)
240 and cond.cascade is None
241 and cond.operator == '=='
242 and not cond.is_python_comparison()):
243 if is_common_value(cond.operand1, cond.operand1):
244 if isinstance(cond.operand2, ExprNodes.ConstNode):
245 return cond.operand1, [cond.operand2]
246 elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
247 return cond.operand1, [cond.operand2]
248 if is_common_value(cond.operand2, cond.operand2):
249 if isinstance(cond.operand1, ExprNodes.ConstNode):
250 return cond.operand2, [cond.operand1]
251 elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
252 return cond.operand2, [cond.operand1]
253 elif (isinstance(cond, ExprNodes.BoolBinopNode)
254 and cond.operator == 'or'):
255 t1, c1 = self.extract_conditions(cond.operand1)
256 t2, c2 = self.extract_conditions(cond.operand2)
257 if is_common_value(t1, t2):
258 return t1, c1+c2
259 return None, None
261 def visit_IfStatNode(self, node):
262 self.visitchildren(node)
263 common_var = None
264 case_count = 0
265 cases = []
266 for if_clause in node.if_clauses:
267 var, conditions = self.extract_conditions(if_clause.condition)
268 if var is None:
269 return node
270 elif common_var is not None and not is_common_value(var, common_var):
271 return node
272 elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
273 return node
274 else:
275 common_var = var
276 case_count += len(conditions)
277 cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
278 conditions = conditions,
279 body = if_clause.body))
280 if case_count < 2:
281 return node
283 common_var = unwrap_node(common_var)
284 return Nodes.SwitchStatNode(pos = node.pos,
285 test = common_var,
286 cases = cases,
287 else_clause = node.else_clause)
290 def visit_Node(self, node):
291 self.visitchildren(node)
292 return node
295 class FlattenInListTransform(Visitor.VisitorTransform):
296 """
297 This transformation flattens "x in [val1, ..., valn]" into a sequential list
298 of comparisons.
299 """
301 def visit_PrimaryCmpNode(self, node):
302 self.visitchildren(node)
303 if node.cascade is not None:
304 return node
305 elif node.operator == 'in':
306 conjunction = 'or'
307 eq_or_neq = '=='
308 elif node.operator == 'not_in':
309 conjunction = 'and'
310 eq_or_neq = '!='
311 else:
312 return node
314 if isinstance(node.operand2, ExprNodes.TupleNode) or isinstance(node.operand2, ExprNodes.ListNode):
315 args = node.operand2.args
316 if len(args) == 0:
317 return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
319 if node.operand1.is_temp or node.operand1.is_simple():
320 lhs = node.operand1
321 else:
322 # FIXME: allocate temp for evaluated node.operand1
323 return node
325 conds = []
326 for arg in args:
327 cond = ExprNodes.PrimaryCmpNode(
328 pos = node.pos,
329 operand1 = lhs,
330 operator = eq_or_neq,
331 operand2 = arg,
332 cascade = None)
333 conds.append(ExprNodes.TypecastNode(
334 pos = node.pos,
335 operand = cond,
336 type = PyrexTypes.c_bint_type))
337 def concat(left, right):
338 return ExprNodes.BoolBinopNode(
339 pos = node.pos,
340 operator = conjunction,
341 operand1 = left,
342 operand2 = right)
344 return reduce(concat, conds)
345 else:
346 return node
348 def visit_Node(self, node):
349 self.visitchildren(node)
350 return node
353 class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
354 """Optimise some common instantiation patterns for builtin types.
355 """
356 def visit_GeneralCallNode(self, node):
357 """Replace dict(a=b,c=d,...) by the underlying keyword dict
358 construction which is done anyway.
359 """
360 self.visitchildren(node)
361 if not node.function.type.is_builtin_type:
362 return node
363 if node.function.name != 'dict':
364 return node
365 if not isinstance(node.positional_args, ExprNodes.TupleNode):
366 return node
367 if len(node.positional_args.args) > 0:
368 return node
369 if not isinstance(node.keyword_args, ExprNodes.DictNode):
370 return node
371 if node.starstar_arg:
372 # we could optimise this by updating the kw dict instead
373 return node
374 return node.keyword_args
376 def visit_PyTypeTestNode(self, node):
377 """Flatten redundant type checks after tree changes.
378 """
379 old_arg = node.arg
380 self.visitchildren(node)
381 if old_arg is node.arg or node.arg.type != node.type:
382 return node
383 return node.arg
385 def visit_Node(self, node):
386 self.visitchildren(node)
387 return node
390 class FinalOptimizePhase(Visitor.CythonTransform):
391 """
392 This visitor handles several commuting optimizations, and is run
393 just before the C code generation phase.
395 The optimizations currently implemented in this class are:
396 - Eliminate None assignment and refcounting for first assignment.
397 - isinstance -> typecheck for cdef types
398 """
399 def visit_SingleAssignmentNode(self, node):
400 """Avoid redundant initialisation of local variables before their
401 first assignment.
402 """
403 self.visitchildren(node)
404 if node.first:
405 lhs = node.lhs
406 lhs.lhs_of_first_assignment = True
407 if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
408 # Have variable initialized to 0 rather than None
409 lhs.entry.init_to_none = False
410 lhs.entry.init = 0
411 return node
413 def visit_SimpleCallNode(self, node):
414 """Replace generic calls to isinstance(x, type) by a more efficient
415 type check.
416 """
417 self.visitchildren(node)
418 if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
419 if node.function.name == 'isinstance':
420 type_arg = node.args[1]
421 if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
422 object_module = self.context.find_module('python_object')
423 node.function.entry = object_module.lookup('PyObject_TypeCheck')
424 if node.function.entry is None:
425 return node # only happens when there was an error earlier
426 node.function.type = node.function.entry.type
427 PyTypeObjectPtr = PyrexTypes.CPtrType(object_module.lookup('PyTypeObject').type)
428 node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
429 return node
