Cython has moved to github.

cython-devel

view Cython/Compiler/Optimize.py @ 1376:e490ccfecad0

handle value coercion correctly in dict iteration
author Stefan Behnel <scoder@users.berlios.de>
date Tue Nov 25 18:24:52 2008 +0100 (3 years ago)
parents 875a228251ae
children 9f94f4e5b3d7
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 from StringEncoding import EncodedString
def unwrap_node(node):
    """Strip any number of PersistentNode wrappers and return the inner node.

    A node that is not a PersistentNode is returned as-is.
    """
    current = node
    while True:
        if not isinstance(current, ExprNodes.PersistentNode):
            return current
        # PersistentNode keeps the wrapped expression in ``arg``.
        current = current.arg
def is_common_value(a, b):
    """Tell whether two expression nodes refer to the same simple value.

    Both nodes are first unwrapped from PersistentNode wrappers.

    - Two NameNodes match when their names are equal.
    - Two AttributeNodes match when the first one is not a Python attribute
      access, their base objects match recursively, and the attribute names
      are equal.
    - Anything else never matches.
    """
    first = unwrap_node(a)
    second = unwrap_node(b)

    if (isinstance(first, ExprNodes.NameNode)
            and isinstance(second, ExprNodes.NameNode)):
        return first.name == second.name

    if not (isinstance(first, ExprNodes.AttributeNode)
            and isinstance(second, ExprNodes.AttributeNode)):
        return False

    # Only the first node's is_py_attr flag is consulted.
    if first.is_py_attr:
        return False
    if not is_common_value(first.obj, second.obj):
        return False
    return first.attribute == second.attribute
class DictIterTransform(Visitor.VisitorTransform):
    """Transform a for-in-dict loop into a while loop calling PyDict_Next().

    Handles these loops when the iterated object is statically typed as the
    builtin ``dict``:

    - ``for k in d``
    - ``for k in d.iterkeys()``
    - ``for v in d.itervalues()``
    - ``for k, v in d.iteritems()``, or the same with a single tuple target

    Every other loop is returned unchanged.
    """
    # C signature used for the generated call PyDict_Next(dict, &pos, &key, &value).
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    # Symbol table entry that lets the NameNode in the generated call refer
    # to the C function directly.
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    def visit_ForInStatNode(self, node):
        """Replace a dict iteration with a PyDict_Next() based while loop.

        Returns a TempsBlockNode wrapping the new loop, or ``node`` itself
        when the loop does not match one of the supported forms.
        """
        self.visitchildren(node)
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            dict_obj = iterator
            keys = True
            values = False
        else:
            if not isinstance(iterator, ExprNodes.SimpleCallNode):
                return node
            function = iterator.function
            if not isinstance(function, ExprNodes.AttributeNode):
                return node
            if function.obj.type != Builtin.dict_type:
                return node
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node

            # iterkeys()/itervalues()/iteritems() take no arguments.  If this
            # call passes any, keep the original loop so the call runs (and
            # fails) as written, instead of its arguments being silently
            # dropped together with their side effects.
            # NOTE(review): assumes a Python-object call may keep its
            # arguments in ``arg_tuple.args`` instead of ``args`` after type
            # analysis; both are checked defensively via getattr().
            call_args = getattr(iterator, 'args', None)
            if call_args is None:
                arg_tuple = getattr(iterator, 'arg_tuple', None)
                if arg_tuple is not None:
                    call_args = getattr(arg_tuple, 'args', None)
            if call_args:
                return node

        # NOTE(review): despite the name, this is a C ``void*`` type; it is
        # used for the key/value output temps.
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        # The position counter reuses the loop's existing counter temp.
        pos_temp = node.iterator.counter
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # Keys are not wanted: pass NULL for the key out-pointer.
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # Values are not wanted: pass NULL for the value out-pointer.
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                # A single (non-unpacked) target receives a (key, value) tuple.
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            """Return (node to assign, optional node to execute before it).

            Python-object targets get a PyTypeTestNode when they are
            extension or builtin types.  C targets get a coercion into a
            fresh temp, which is registered in ``temps``.
            """
            class FakeEnv(object):
                nogil = False
            if dest_type.is_pyobject:
                if dest_type.is_extension_type or dest_type.is_builtin_type:
                    return (obj_node, ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv()))
                else:
                    return (obj_node, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # FIXME: remove this after result-code refactoring
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
            temps.append(temp)
            temp_tuple = temp.ref(tuple_target.pos)
            class TempTupleNode(ExprNodes.TupleNode):
                # FIXME: remove this after result-code refactoring
                def result(self):
                    return temp_tuple.result()

            tuple_result = TempTupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        rhs = temp_result,
                        lhs = key_target))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        rhs = temp_result,
                        lhs = value_target))
            body.stats[0:0] = coercion_stats + assign_stats

        # pos = 0; dict_temp = dict_obj;
        # while PyDict_Next(dict_temp, &pos, &key, &value): body  [else: ...]
        # NOTE(review): a dict-typed expression may still be None at runtime;
        # confirm how the generated PyDict_Next call behaves in that case.
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value=0)),
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                pos = node.pos,
                stats = result_code
                ))

    def visit_Node(self, node):
        """Default: recurse into children and keep the node."""
        self.visitchildren(node)
        return node
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.

    Chains whose case values might repeat are left alone.  C rejects duplicate
    ``case`` labels, whereas Python merely makes the later comparison dead.
    """
    def extract_conditions(self, cond):
        """Split *cond* into ``(var, [values])`` for ``var == value`` tests.

        Accepted forms:

        - ``var == value`` and ``value == var``
        - ``or`` combinations of such tests over the same ``var``

        ``var`` must satisfy ``is_common_value(var, var)``, i.e. be a name or
        a chain of non-Python attribute accesses.  ``value`` must be a
        ConstNode or a node whose entry is marked ``is_const``.

        Returns ``(None, None)`` when *cond* does not have this shape.
        """

        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg

        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            if is_common_value(cond.operand1, cond.operand1):
                if isinstance(cond.operand2, ExprNodes.ConstNode):
                    return cond.operand1, [cond.operand2]
                elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
                    return cond.operand1, [cond.operand2]
            if is_common_value(cond.operand2, cond.operand2):
                if isinstance(cond.operand1, ExprNodes.ConstNode):
                    return cond.operand2, [cond.operand1]
                elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
                    return cond.operand2, [cond.operand1]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            t1, c1 = self.extract_conditions(cond.operand1)
            t2, c2 = self.extract_conditions(cond.operand2)
            if is_common_value(t1, t2):
                return t1, c1+c2
        return None, None

    def _case_value_key(self, value):
        """Return a hashable identity for a case value, or None if unknown.

        Integer literal spellings are normalised where Python's int(x, 0)
        can parse them, after stripping C L/U suffixes.  Named constants are
        identified by their C name.
        NOTE(review): two distinct named constants with the same C value are
        not detected; this mirrors the best information available here.
        """
        if isinstance(value, ExprNodes.ConstNode):
            literal = getattr(value, 'value', None)
            if literal is None:
                return None
            literal = str(literal)
            try:
                return ('int', int(literal.rstrip('LlUu'), 0))
            except ValueError:
                return ('literal', literal)
        entry = getattr(value, 'entry', None)
        if entry is not None and entry.is_const:
            cname = getattr(entry, 'cname', None)
            if cname is not None:
                return ('const', cname)
        return None

    def has_duplicate_values(self, condition_values):
        """Return True if the values could produce a repeated C case label.

        Values whose identity cannot be determined count as duplicates, so
        the caller plays safe and keeps the if-chain.
        """
        seen = {}
        for value in condition_values:
            key = self._case_value_key(value)
            if key is None or key in seen:
                return True
            seen[key] = True
        return False

    def visit_IfStatNode(self, node):
        """Replace an int equality if/elif chain with a SwitchStatNode."""
        self.visitchildren(node)
        common_var = None
        case_count = 0
        cases = []
        all_values = []
        for if_clause in node.if_clauses:
            var, conditions = self.extract_conditions(if_clause.condition)
            if var is None:
                return node
            elif common_var is not None and not is_common_value(var, common_var):
                return node
            elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
                return node
            else:
                common_var = var
                case_count += len(conditions)
                all_values.extend(conditions)
                cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                                  conditions = conditions,
                                                  body = if_clause.body))
        if case_count < 2:
            return node
        # Duplicate labels would make the generated C fail to compile.
        if self.has_duplicate_values(all_values):
            return node

        common_var = unwrap_node(common_var)
        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = common_var,
                                    cases = cases,
                                    else_clause = node.else_clause)

    def visit_Node(self, node):
        """Default: recurse into children and keep the node."""
        self.visitchildren(node)
        return node
class FlattenInListTransform(Visitor.VisitorTransform):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.

    ``x in [a, b]`` becomes ``x == a or x == b``, and ``x not in [a, b]``
    becomes ``x != a and x != b``.  ``x`` is wrapped in a PersistentNode
    sized to the number of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Flatten an ``in``/``not in`` test against a literal tuple or list."""
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # Folding to a constant would drop the evaluation of the left
            # operand, which may have side effects (e.g. ``f() in []``).
            return node

        # NOTE(review): the resulting and/or chain short-circuits, so later
        # sequence items may not be evaluated at all, unlike building the
        # tuple/list first.  Confirm this is acceptable for non-simple items.
        lhs = ExprNodes.PersistentNode(node.operand1, len(args))
        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = arg,
                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                pos = node.pos,
                operand = cond,
                type = PyrexTypes.c_bint_type))

        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = left,
                operand2 = right)

        # Left fold, as reduce(concat, conds) would do.  Written out so it
        # does not depend on the Python 2 ``reduce`` builtin.
        result = conds[0]
        for cond in conds[1:]:
            result = concat(result, cond)
        return result

    def visit_Node(self, node):
        """Default: recurse into children and keep the node."""
        self.visitchildren(node)
        return node
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Mark the first assignment to a variable so it skips None-init."""
        # Recurse first.  Without this, calls nested in an assignment (for
        # instance ``b = isinstance(x, list)``) never reach
        # visit_SimpleCallNode.
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Turn isinstance(obj, <type object>) into PyObject_TypeCheck()."""
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'isinstance':
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    object_module = self.context.find_module('python_object')
                    # Resolve both names before touching the node.  A failed
                    # lookup (only after an earlier error) must not leave
                    # the node half-rewritten with a None entry.
                    type_check_entry = object_module.lookup('PyObject_TypeCheck')
                    type_object_entry = object_module.lookup('PyTypeObject')
                    if type_check_entry is None or type_object_entry is None:
                        return node # only happens when there was an error earlier
                    node.function.entry = type_check_entry
                    node.function.type = type_check_entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(type_object_entry.type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node