Cython has moved to github.

cython-devel

view Cython/Compiler/Optimize.py @ 2194:b9d8cecc8975

general optimisation support for calls to builtin types and their methods

currently providing optimisations for
- getattr(o,a)
- getattr(o,a,d)
- X.append(o)
- L.append(o)
- list.append(L,x)
author Stefan Behnel <scoder@users.berlios.de>
date Sun Mar 29 13:27:55 2009 +0200 (3 years ago)
parents 0c165f57146c
children c3ab9cc6856f
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
11 from Cython.Utils import UtilityCode
12 from StringEncoding import EncodedString
13 from Errors import error
14 from ParseTreeTransforms import SkipDeclarations
16 #def unwrap_node(node):
17 # while isinstance(node, ExprNodes.PersistentNode):
18 # node = node.arg
19 # return node
21 # Temporary hack while PersistentNode is out of order
def unwrap_node(node):
    """Return ``node`` unchanged.

    Placeholder for the PersistentNode-unwrapping helper that is
    currently disabled; callers keep using it so the real
    implementation can be restored without touching them.
    """
    return node
def is_common_value(a, b):
    """Tell whether two expression nodes refer to the same simple value.

    Two name nodes match when their names are equal.  Two attribute
    nodes match when the first is not a Python attribute lookup, their
    objects match recursively and their attribute names are equal.
    Every other combination yields False.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    name_node = ExprNodes.NameNode
    if isinstance(left, name_node) and isinstance(right, name_node):
        return left.name == right.name

    attr_node = ExprNodes.AttributeNode
    if isinstance(left, attr_node) and isinstance(right, attr_node):
        if left.is_py_attr:
            return False
        if not is_common_value(left.obj, right.obj):
            return False
        return left.attribute == right.attribute

    return False
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-range loop becomes a plain C for loop
    """
    # Declared signature, name and symbol table entry used for the
    # PyDict_Next() call node built in _transform_dict_iteration().
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        """Record the module scope, then process the module's children."""
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Process a function body with its own scope as current scope."""
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Replace a for-in loop by a specialised loop where possible.

        Returns the original node when no known pattern matches.
        """
        self.visitchildren(node)
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and \
                    isinstance(function, ExprNodes.NameNode) and \
                    function.entry.is_builtin and \
                    function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_range_iteration(self, node, range_function):
        """Turn ``for i in range(...)`` into a ForFromStatNode.

        ``range_function`` is the call node of the range()/xrange() call.
        The loop is left untouched when the call does not have one to
        three arguments, or when the step is not a non-zero integer
        constant.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # Wrong number of arguments: indexing args below would fail
            # (no arguments) or silently drop arguments (more than
            # three), so keep the generic call.
            return node

        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=step_value)

        if step_value < 0:
            # the loop node gets the absolute step, the direction is
            # expressed by the relations
            step.value = -step_value
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)
        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Turn a loop over a dict into a while loop calling PyDict_Next().

        ``keys`` / ``values`` select which of the two PyDict_Next()
        outputs are requested; the other one is passed as a null node.
        Returns a TempsBlockNode wrapping the setup assignments and the
        while loop, or the original node for an unsupported loop target.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (node to assign, coercion statement or None).
            class FakeEnv(object):
                nogil = False
            if dest_type.is_pyobject:
                if dest_type.is_extension_type or dest_type.is_builtin_type:
                    obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    def extract_conditions(self, cond):
        """Split a clause condition into (tested variable, list of values).

        Returns (None, None) when the condition is not a C-level
        ``var == value`` comparison or an ``or`` of such comparisons on
        a common variable.
        """
        # look through temp coercion and typecast wrappers
        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg

        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        no_match = (None, None)

        if isinstance(cond, ExprNodes.BoolBinopNode) \
                and not isinstance(cond, ExprNodes.PrimaryCmpNode) \
                and cond.operator == 'or':
            left_var, left_values = self.extract_conditions(cond.operand1)
            right_var, right_values = self.extract_conditions(cond.operand2)
            if is_common_value(left_var, right_var):
                return left_var, left_values + right_values
            return no_match

        if not isinstance(cond, ExprNodes.PrimaryCmpNode):
            return no_match
        if cond.cascade is not None or cond.operator != '==':
            return no_match
        if cond.is_python_comparison():
            return no_match

        # try "var == value" first, then "value == var"
        for var, value in ((cond.operand1, cond.operand2),
                           (cond.operand2, cond.operand1)):
            if not is_common_value(var, var):
                continue
            if isinstance(value, ExprNodes.ConstNode):
                return var, [value]
            if hasattr(value, 'entry') and value.entry and value.entry.is_const:
                return var, [value]
        return no_match

    def visit_IfStatNode(self, node):
        """Replace an if statement by a SwitchStatNode where possible.

        The original node is returned unless every clause tests the same
        int variable against int values and there are at least two
        values in total.
        """
        self.visitchildren(node)
        switch_var = None
        total_values = 0
        switch_cases = []
        for clause in node.if_clauses:
            var, values = self.extract_conditions(clause.condition)
            if var is None:
                return node
            if switch_var is not None and not is_common_value(var, switch_var):
                return node
            if not var.type.is_int:
                return node
            for value in values:
                if not value.type.is_int:
                    return node

            switch_var = var
            total_values += len(values)
            switch_cases.append(
                Nodes.SwitchCaseNode(pos = clause.pos,
                                     conditions = values,
                                     body = clause.body))

        if total_values < 2:
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(switch_var),
                                    cases = switch_cases,
                                    else_clause = node.else_clause)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite ``x in (a, b, ...)`` / ``x not in (a, b, ...)``.

        Only non-cascaded comparisons against a literal tuple or list
        are rewritten; anything else is returned unchanged.  An empty
        sequence folds to a boolean constant.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        # operator -> (how to join the comparisons, per-item comparison)
        rewrite_rule = {
            'in':     ('or',  '=='),
            'not_in': ('and', '!='),
        }.get(node.operator)
        if rewrite_rule is None:
            return node
        conjunction, eq_or_neq = rewrite_rule

        sequence = node.operand2
        if not isinstance(sequence, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        items = sequence.args
        if len(items) == 0:
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # the left operand is shared by all generated comparisons
        lhs = UtilNodes.ResultRefNode(node.operand1)

        comparisons = [
            ExprNodes.TypecastNode(
                pos = node.pos,
                operand = ExprNodes.PrimaryCmpNode(
                    pos = node.pos,
                    operand1 = lhs,
                    operator = eq_or_neq,
                    operand2 = item,
                    cascade = None),
                type = PyrexTypes.c_bint_type)
            for item in items ]

        # left fold of the comparisons with the chosen boolean operator
        condition = comparisons[0]
        for comparison in comparisons[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = condition,
                operand2 = comparison)

        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class OptimiseBuiltinCalls(Visitor.VisitorTransform):
    """Optimise some common methods calls and instantiation patterns
    for builtin types.

    Handlers are looked up by name, see _find_handler().
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_GeneralCallNode(self, node):
        """Dispatch a call with keyword arguments to a 'general' handler."""
        self.visitchildren(node)
        handler = self._find_handler('general', node.function)
        if handler is not None:
            node = handler(node, node.positional_args, node.keyword_args)
        return node

    def visit_SimpleCallNode(self, node):
        """Dispatch a positional-only call to a 'simple' handler."""
        self.visitchildren(node)
        handler = self._find_handler('simple', node.function)
        if handler is not None:
            node = handler(node, node.arg_tuple)
        return node

    def visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.
        """
        old_arg = node.arg
        self.visitchildren(node)
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

    def _find_handler(self, call_type, function):
        """Look up the handler method for a called function node.

        The handler name is ``_handle_<call_type>_<name>`` with a
        fallback to ``_handle_any_<name>``.  For plain names ``<name>``
        is the function name, for attribute calls it is
        ``<type name>_<attribute>``, using "object" when the object is
        not of a builtin type.  Returns None if there is no handler.
        """
        if not function.type.is_pyobject:
            return None
        if function.is_name:
            if not function.type.is_builtin_type and '_' in function.name:
                # not interesting anyway, so let's play safe here
                return None
            match_name = function.name
        elif isinstance(function, ExprNodes.AttributeNode):
            if not function.obj.type.is_builtin_type:
                type_name = "object" # safety measure
            else:
                type_name = function.obj.type.name
            match_name = "%s_%s" % (type_name, function.attribute)
        else:
            return None
        handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % match_name, None)
        return handler

    ### builtin types

    def _handle_general_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        if len(pos_args.args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimise this by updating the kw dict instead
            return node
        return kwargs

    def _handle_simple_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...}.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        arg_count = len(pos_args.args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type, is_temp=1)
        if arg_count > 1:
            return node
        iterable = pos_args.args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args,
                                     type=Builtin.set_type, is_temp=1)
        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
                iterable.type is Builtin.list_type:
            # retarget the list comprehension to build a set instead
            iterable.target = ExprNodes.SetNode(
                node.pos, args=[], type=Builtin.set_type, is_temp=1)
            iterable.type = Builtin.set_type
            iterable.pos = node.pos
            return iterable
        else:
            return node

    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

    def _handle_simple_tuple(self, node, pos_args):
        """Replace tuple([...]) by a call to PyList_AsTuple.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        if len(pos_args.args) != 1:
            return node
        list_arg = pos_args.args[0]
        if list_arg.type is not Builtin.list_type:
            return node
        if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                     ExprNodes.ListNode)):
            # everything else may be None => take the safe path
            return node

        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
            args = pos_args.args,
            is_temp = node.is_temp
            )

    ### builtin functions

    PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            ])

    PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_getattr(self, node, pos_args):
        """Replace getattr(o,a) / getattr(o,a,d) by direct C calls.

        Any other argument count is reported as an error and the call
        node is returned unchanged.
        """
        # not really a builtin *type*, but worth optimising anyway
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        args = pos_args.args
        if len(args) == 2:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
                args = args,
                is_temp = node.is_temp
                )
        elif len(args) == 3:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
                utility_code = Builtin.getattr3_utility_code,
                args = args,
                is_temp = node.is_temp
                )
        else:
            error(node.pos, "getattr() called with wrong number of args, "
                  "expected 2 or 3, found %d" %
                  len(pos_args.args))
        return node

    ### methods of builtin types

    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_object_append(self, node, pos_args):
        """Replace X.append(o) on an untyped object by a call to the
        __Pyx_PyObject_Append() utility function.
        """
        # X.append() is almost always referring to a list
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        if len(pos_args.args) != 1:
            return node

        args = [node.function.obj] + pos_args.args
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = append_utility_code # FIXME: move to Builtin.py
            )

    PyList_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")

    def _handle_simple_list_append(self, node, pos_args):
        """Replace L.append(o) on a typed list by a PyList_Append() call.

        A wrong argument count is reported as an error and the call
        node is returned unchanged.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        if len(pos_args.args) != 1:
            error(node.pos, "list.append(x) called with wrong number of args, found %d" %
                  len(pos_args.args))
            return node

        obj = node.function.obj
        # FIXME: obj may need a None check (ticket #166)
        args = [obj] + pos_args.args
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_Append", self.PyList_Append_func_type,
            args = args,
            is_temp = node.is_temp
            )

    def _handle_simple_type_append(self, node, pos_args):
        """Replace the unbound call list.append(L, x) by PyList_Append().

        Only a call through the plain name ``list`` is rewritten.  A
        wrong argument count is reported as an error and the call node
        is returned unchanged.
        """
        # unbound method call to list.append(L, x) ?
        type_obj = node.function.obj
        # Only name nodes are known to carry a .name here; any other
        # expression of type 'type' is left alone.
        if not type_obj.is_name or type_obj.name != 'list':
            return node
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node

        args = pos_args.args
        if len(args) != 2:
            error(node.pos, "list.append(x) called with wrong number of args, found %d" %
                  len(pos_args.args))
            return node

        # FIXME: this may need a type check on the first operand
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_Append", self.PyList_Append_func_type,
            args = args,
            is_temp = node.is_temp
            )
# C helper behind the generic X.append(o) optimisation: exact lists use
# PyList_Append(), any other object gets a regular "append" method call.
append_utility_code = UtilityCode(
    proto = """
static INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
    impl = ""
)
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` unless it was already set.

        The children are visited first; if any of them is not a
        constant, the node is marked as not constant as well.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.itervalues():
            if type(child_result) is list:
                for child in child_result:
                    if child.constant_result is not_a_constant:
                        return
            elif child_result.constant_result is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
#            if node.constant_result is not ExprNodes.not_a_constant:
#                print node.__class__.__name__, node.constant_result
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # constant node classes, ordered from narrowest to widest
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the widest class from NODE_TYPE_ORDER among ``nodes``,
        or None if one of the nodes has a class that is not listed.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        """Calculate the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Replace a binary operation on two constant nodes by a single
        constant node that carries the calculated result.

        The node is returned unchanged when it has no constant result,
        when an operand has no type, when either operand is not a
        ConstNode, or when no suitable node class for the result is
        found.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        try:
            if node.operand1.type is None or node.operand2.type is None:
                return node
        except AttributeError:
            return node

        type1, type2 = node.operand1.type, node.operand2.type
        # both operands must be constant nodes, as one of them may get
        # reused below to carry the result
        if isinstance(node.operand1, ExprNodes.ConstNode) and \
                isinstance(node.operand2, ExprNodes.ConstNode):
            if type1 is type2:
                new_node = node.operand1
            else:
                widest_type = PyrexTypes.widest_numeric_type(type1, type2)
                if type(node.operand1) is type(node.operand2):
                    new_node = node.operand1
                    new_node.type = widest_type
                elif type1 is widest_type:
                    new_node = node.operand1
                elif type2 is widest_type:
                    new_node = node.operand2
                else:
                    target_class = self._widest_node_class(
                        node.operand1, node.operand2)
                    if target_class is None:
                        return node
                    # the new node takes over the source position of
                    # the operation it replaces
                    new_node = target_class(pos = node.pos, type = widest_type)
        else:
            return node

        new_node.constant_result = node.constant_result
        new_node.value = str(node.constant_result)
        #new_node = new_node.coerce_to(node.type, self.current_scope)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Only calls with exactly two arguments whose second argument is
        of the builtin type 'type' are rewritten; all other calls are
        returned unchanged.
        """
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            # the argument count is checked before indexing, so a
            # malformed isinstance() call is left untouched
            if node.function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    object_module = self.context.find_module('python_object')
                    node.function.entry = object_module.lookup('PyObject_TypeCheck')
                    if node.function.entry is None:
                        return node # only happens when there was an error earlier
                    node.function.type = node.function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(object_module.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node