Cython has moved to github.

cython-devel

view Cython/Compiler/Optimize.py @ 1347:4a4a4eba9cd6

avoid calling TupleNode.allocate_temps() in iter-dict transform
author Stefan Behnel <scoder@users.berlios.de>
date Tue Nov 18 19:36:29 2008 +0100 (3 years ago)
parents db3eb81258f4
children da561ec3162a
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 from StringEncoding import EncodedString
def unwrap_node(node):
    """Return *node* with every enclosing PersistentNode wrapper removed.

    PersistentNode objects are peeled off one at a time through their
    ``arg`` attribute until a node of some other type is reached; that node
    is returned (it is *node* itself when *node* is not wrapped).
    """
    inner = node
    while isinstance(inner, ExprNodes.PersistentNode):
        inner = inner.arg
    return inner
def is_common_value(a, b):
    """Tell whether two expression nodes refer to the same simple value.

    Both nodes are first stripped of PersistentNode wrappers.  Two NameNodes
    match when their names are equal.  Two AttributeNodes match when the
    first is not a Python attribute access, their base objects match
    (checked recursively) and their attribute names are equal.  Any other
    combination yields False.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    names = (isinstance(left, ExprNodes.NameNode)
             and isinstance(right, ExprNodes.NameNode))
    if names:
        return left.name == right.name

    attributes = (isinstance(left, ExprNodes.AttributeNode)
                  and isinstance(right, ExprNodes.AttributeNode))
    if not attributes:
        return False

    return (not left.is_py_attr
            and is_common_value(left.obj, right.obj)
            and left.attribute == right.attribute)
class DictIterTransform(Visitor.VisitorTransform):
    """Transform a for-in-dict loop into a while loop calling PyDict_Next().

    Only loops of the form ``for ... in d.iterkeys()``, ``d.itervalues()``
    or ``d.iteritems()`` are rewritten, and only when ``d`` is statically
    typed as the builtin ``dict``.  Every other ForInStatNode is returned
    unchanged (after its children have been visited).
    """
    # C signature used for the PyDict_Next() call:
    # bint PyDict_Next(object dict, Py_ssize_t *pos, object *key, object *value)
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    # Symbol-table entry attached to the synthesized NameNode for the call.
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    def visit_ForInStatNode(self, node):
        """Rewrite *node* into a PyDict_Next() while-loop when it matches.

        On success the result is a TempsBlockNode, given the temps created
        here, whose body contains:

        1. an assignment resetting the loop's position counter to 0, and
        2. a WhileStatNode whose condition is the PyDict_Next() call and
           whose body first assigns the fetched key/value to the loop
           target(s) and then runs the original loop body.

        The original ``else`` clause is kept on the while loop.  Returns
        *node* unchanged when the loop does not match.
        """
        self.visitchildren(node)
        iterator = node.iterator.sequence
        # Only handle a call of the form <expr>.<method>() on a dict-typed object.
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node
        function = iterator.function
        if not isinstance(function, ExprNodes.AttributeNode):
            return node
        if function.obj.type != Builtin.dict_type:
            return node
        dict_obj = function.obj
        method = function.attribute

        # Decide which of key / value the loop actually needs.
        keys = values = False
        if method == 'iterkeys':
            keys = True
        elif method == 'itervalues':
            values = True
        elif method == 'iteritems':
            keys = values = True
        else:
            return node

        # The key/value temps are declared as plain void* rather than as
        # Python objects.  NOTE(review): presumably so the references that
        # PyDict_Next() stores into them are not reference-counted by the
        # temp machinery -- confirm.
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        # Reuse the for-loop iterator's counter as PyDict_Next()'s position
        # argument, passed by address.
        pos_temp = node.iterator.counter
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # Key not needed: pass NULL for the key out-parameter.
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # Value not needed: pass NULL for the value out-parameter.
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        # Work out where the key and value are assigned.  For iteritems()
        # with a two-element sequence target (``for k, v in ...``) each half
        # is assigned directly.  Any other sequence-target length is left
        # untransformed.  A non-sequence target receives a (key, value) tuple.
        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # FIXME ...
                    return node
            else:
                tuple_target = node.target

        # Cast the void* temps to the type of the target they are assigned to.
        if keys:
            key_cast = ExprNodes.TypecastNode(
                pos = key_target.pos,
                operand = key_temp,
                type = key_target.type)
        if values:
            value_cast = ExprNodes.TypecastNode(
                pos = value_target.pos,
                operand = value_temp,
                type = value_target.type)

        # Ensure the body is a statement list so assignments can be prepended.
        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            # An extra temp for the (key, value) tuple assigned to the target.
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            temp_tuple = temp.ref(tuple_target.pos)
            # TupleNode whose result() reports the extra temp's result.
            class TempTupleNode(ExprNodes.TupleNode):
                # FIXME: remove this after result-code refactoring
                def result(self):
                    return temp_tuple.result()

            tuple_result = TempTupleNode(
                pos = tuple_target.pos,
                args = [key_cast, value_cast],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # Both inserts go to index 0.  The value is inserted first, so the
            # key assignment ends up as the first statement of the body.
            if values:
                body.stats.insert(
                    0, Nodes.SingleAssignmentNode(
                        pos = value_target.pos,
                        lhs = value_target,
                        rhs = value_cast))
            if keys:
                body.stats.insert(
                    0, Nodes.SingleAssignmentNode(
                        pos = key_target.pos,
                        lhs = key_target,
                        rhs = key_cast))

        # pos = 0; while PyDict_Next(d, &pos, &key, &value): <body> [else: ...]
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos=dict_obj.pos, name=self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_obj, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        # Wrap everything in a block that declares the temps created above.
        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                pos = node.pos,
                stats = result_code
                ))

    def visit_Node(self, node):
        # Default: recurse into children and keep the node as-is.
        self.visitchildren(node)
        return node
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.

    The transformation is skipped when two case values may denote the same
    C constant, because a C switch statement cannot repeat a case label.
    """
    def extract_conditions(self, cond):
        """Split an if-condition into ``(var, [value, ...])``.

        Recognises ``var == value`` in either operand order, where ``value``
        is a literal (ConstNode) or refers to a constant entry.  It also
        recognises ``or``-chains of such comparisons on a common variable.
        Returns ``(None, None)`` for anything else.
        """
        # Look through the coercion and cast wrappers around the comparison.
        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg

        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # is_common_value(x, x) only holds when x is a name or a chain of
            # non-Python attribute accesses, i.e. usable as the switch variable.
            if is_common_value(cond.operand1, cond.operand1):
                if isinstance(cond.operand2, ExprNodes.ConstNode):
                    return cond.operand1, [cond.operand2]
                elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
                    return cond.operand1, [cond.operand2]
            if is_common_value(cond.operand2, cond.operand2):
                if isinstance(cond.operand1, ExprNodes.ConstNode):
                    return cond.operand2, [cond.operand1]
                elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
                    return cond.operand2, [cond.operand1]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            t1, c1 = self.extract_conditions(cond.operand1)
            t2, c2 = self.extract_conditions(cond.operand2)
            if is_common_value(t1, t2):
                return t1, c1+c2
        return None, None

    def _has_duplicate_values(self, values):
        """Return True if two case values may denote the same C constant.

        Literals are compared by their ``value`` and named constants by the
        ``cname`` of their entry.  A value that is neither is treated as a
        possible duplicate so that the caller backs off safely.
        """
        # A dict is used as a set so this stays compatible with old Pythons.
        seen = {}
        for value in values:
            if isinstance(value, ExprNodes.ConstNode):
                key = ('const', value.value)
            else:
                cname = getattr(getattr(value, 'entry', None), 'cname', None)
                if cname is None:
                    return True
                key = ('cname', cname)
            if key in seen:
                return True
            seen[key] = True
        return False

    def visit_IfStatNode(self, node):
        """Replace an if/elif chain by a SwitchStatNode when possible.

        Every clause must compare the same integer variable against integer
        constants.  At least two case values are required in total, and no
        case value may be duplicated.  Otherwise *node* is returned unchanged.
        """
        self.visitchildren(node)
        common_var = None
        cases = []
        case_values = []
        for if_clause in node.if_clauses:
            var, conditions = self.extract_conditions(if_clause.condition)
            if var is None:
                return node
            elif common_var is not None and not is_common_value(var, common_var):
                return node
            elif not var.type.is_int or [cond for cond in conditions
                                         if not cond.type.is_int]:
                return node
            else:
                common_var = var
                case_values.extend(conditions)
                cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                                  conditions = conditions,
                                                  body = if_clause.body))
        if len(case_values) < 2:
            return node
        if self._has_duplicate_values(case_values):
            # e.g. "if x == 1: ... elif x == 1: ..." -- C rejects repeated
            # case labels, so keep the original if/elif chain.
            return node

        common_var = unwrap_node(common_var)
        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = common_var,
                                    cases = cases,
                                    else_clause = node.else_clause)

    def visit_Node(self, node):
        # Default: recurse into children and keep the node as-is.
        self.visitchildren(node)
        return node
class FlattenInListTransform(Visitor.VisitorTransform):
    """
    Rewrite a membership test against a literal list or tuple into plain
    comparisons.

    ``x in [v1, ..., vn]`` becomes ``x == v1 or ... or x == vn``, and
    ``x not in [v1, ..., vn]`` becomes ``x != v1 and ... and x != vn``.
    An empty literal collapses to a constant boolean.
    """

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        # Cascaded comparisons (a < b < c) are never rewritten.
        if node.cascade is not None:
            return node
        if node.operator == 'in':
            joiner, comparison = 'or', '=='
        elif node.operator == 'not_in':
            joiner, comparison = 'and', '!='
        else:
            return node

        container = node.operand2
        if not isinstance(container, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        items = container.args
        if len(items) == 0:
            # Nothing to compare with: 'in' is False, 'not in' is True.
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # The left operand is wrapped in a PersistentNode, told how many
        # comparisons will reference it.
        shared_lhs = ExprNodes.PersistentNode(node.operand1, len(items))

        # Build one bint-typed comparison per item and fold them left to right.
        combined = None
        for item in items:
            test = ExprNodes.TypecastNode(
                pos = node.pos,
                operand = ExprNodes.PrimaryCmpNode(
                    pos = node.pos,
                    operand1 = shared_lhs,
                    operator = comparison,
                    operand2 = item,
                    cascade = None),
                type = PyrexTypes.c_bint_type)
            if combined is None:
                combined = test
            else:
                combined = ExprNodes.BoolBinopNode(
                    pos = node.pos,
                    operator = joiner,
                    operand1 = combined,
                    operand2 = test)
        return combined

    def visit_Node(self, node):
        # Default: recurse into children and keep the node as-is.
        self.visitchildren(node)
        return node
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Mark a variable's first assignment so it is not pre-set to None.

        For a first assignment, the target is flagged with
        ``lhs_of_first_assignment``.  If the target is a Python-object
        variable, its entry is switched from None-initialisation to
        initialisation with 0.
        """
        # Visit the operands too; otherwise calls on the right-hand side
        # (e.g. "x = isinstance(a, T)") never reach visit_SimpleCallNode.
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Turn ``isinstance(obj, <type object>)`` into PyObject_TypeCheck().

        The replacement applies when the callee is the C-level ``isinstance``
        and its second argument is typed as the builtin ``type``.  That
        argument is then cast to ``PyTypeObject*``.
        """
        self.visitchildren(node)
        function = node.function
        if not (function.type.is_cfunction
                and isinstance(function, ExprNodes.NameNode)
                and function.name == 'isinstance'
                and len(node.args) == 2):
            return node
        type_arg = node.args[1]
        if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
            return node
        object_module = self.context.find_module('python_object')
        typecheck_entry = object_module.lookup('PyObject_TypeCheck')
        if typecheck_entry is None:
            # only happens when there was an error earlier; leave the node intact
            return node
        function.entry = typecheck_entry
        function.type = typecheck_entry.type
        PyTypeObjectPtr = PyrexTypes.CPtrType(object_module.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node