Cython has moved to github.

cython-devel

view Cython/Compiler/Optimize.py @ 2203:4495287d2fdb

NoneCheckNode to enforce runtime None checks for object references
author Stefan Behnel <scoder@users.berlios.de>
date Sun Mar 29 20:55:51 2009 +0200 (3 years ago)
parents 606442166cc7
children 91be9458f343
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
11 from Cython.Utils import UtilityCode
12 from StringEncoding import EncodedString
13 from Errors import error
14 from ParseTreeTransforms import SkipDeclarations
16 #def unwrap_node(node):
17 # while isinstance(node, ExprNodes.PersistentNode):
18 # node = node.arg
19 # return node
21 # Temporary hack while PersistentNode is out of order
def unwrap_node(node):
    """Return ``node`` unchanged.

    Stand-in for the PersistentNode-unwrapping helper, which is disabled
    while PersistentNode is out of order; callers keep calling it so the
    real unwrapping can be restored later without touching them.
    """
    return node
def is_common_value(a, b):
    """Tell whether two expression nodes refer to the same simple value.

    Two name nodes match when their names are equal.  Two attribute nodes
    match when the first is not a Python attribute lookup, their objects
    match recursively and their attribute names are equal.  Any other
    combination of nodes does not match.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    name_cls = ExprNodes.NameNode
    if isinstance(left, name_cls) and isinstance(right, name_cls):
        return left.name == right.name

    attr_cls = ExprNodes.AttributeNode
    if isinstance(left, attr_cls) and isinstance(right, attr_cls):
        return (not left.is_py_attr
                and is_common_value(left.obj, right.obj)
                and left.attribute == right.attribute)

    return False
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-range loop becomes a plain C for loop
    """
    # C signature used for the generated PyDict_Next() call:
    # bint PyDict_Next(object dict, Py_ssize_t* pos, object* key, object* value)
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    # symbol table entry that the generated NameNode for PyDict_Next refers to
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    # all node types without a dedicated handler are simply traversed
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        """Remember the module scope as the current scope and visit the
        module's children.
        """
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Visit the function's children with the scope of the function's
        entry as current scope, then restore the previous scope.
        """
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Replace a for-in loop by a specialised loop where possible.

        Handles iteration over a dict-typed sequence, over the result of
        iterkeys()/itervalues()/iteritems() called on a dict-typed object,
        and over a builtin range()/xrange() call with an integer loop
        target.  Any other loop is returned unchanged.
        """
        self.visitchildren(node)
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            # select which of key/value the loop target receives
            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and \
                    isinstance(function, ExprNodes.NameNode) and \
                    function.entry.is_builtin and \
                    function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_range_iteration(self, node, range_function):
        """Build a ForFromStatNode from a for-in loop over range()/xrange().

        ``range_function`` is the call node; its positional arguments
        provide the bounds and the optional step.  The original ``node``
        is returned unchanged when the step is not a compile time integer
        or is zero.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            # no explicit step given => step of 1
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=step_value)

        # the step node always carries the absolute value, the direction
        # of the loop is expressed by the two relations
        if step_value < 0:
            step.value = -step_value
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            # range(stop) starts at 0
            bound1 = ExprNodes.IntNode(range_function.pos, value=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)
        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Build a while loop calling PyDict_Next() from a for-in loop.

        ``dict_obj`` is the node of the dict that is iterated over;
        ``keys`` and ``values`` select whether the loop target receives
        keys, values or both.  The result is a TempsBlockNode holding the
        temp setup and the while loop.  The original ``node`` is returned
        unchanged if both keys and values are requested and the target is
        a sequence of other than two items.
        """
        # key and value temps are declared with the C void* type
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        # temp holding the dict that is iterated over
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        # temp holding the position counter passed to PyDict_Next()
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # keys not requested => pass NULL to PyDict_Next()
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # values not requested => pass NULL to PyDict_Next()
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                # a single target receives a (key, value) tuple
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns a pair (node to assign from, coercion node or None).
            # For Python object targets the result is a typecast (preceded
            # by a type test for extension and builtin types) and no
            # separate coercion is needed.  For other targets, a new temp
            # of the target type is registered and a coercion node that
            # writes into that temp is returned alongside it.
            class FakeEnv(object):
                nogil = False
            if dest_type.is_pyobject:
                if dest_type.is_extension_type or dest_type.is_builtin_type:
                    obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # coercion node whose result is the temp created above
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))

        # make sure the loop body is a statement list that the target
        # assignments can be prepended to
        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # dict_temp = dict_obj; pos_temp = 0;
        # while PyDict_Next(dict_temp, &pos_temp, &key, &value): body
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    Turn if/elif chains into C switch statements where possible.

    A chain qualifies when every clause tests ``var == value`` (or an
    ``or`` of such tests), the same ``var`` is used in all clauses, and
    both ``var`` and every ``value`` are of integer type.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def _is_switch_constant(self, operand):
        # A usable case value is either a constant node or a node whose
        # symbol table entry is flagged as const.
        if isinstance(operand, ExprNodes.ConstNode):
            return True
        return hasattr(operand, 'entry') and operand.entry and operand.entry.is_const

    def extract_conditions(self, cond):
        """Split a condition into ``(variable node, [value nodes])``.

        Returns ``(None, None)`` if the condition does not have a form
        that can be used in a switch statement.
        """
        # look through temp coercions and type casts
        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg
        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        is_plain_equality = (
            isinstance(cond, ExprNodes.PrimaryCmpNode)
            and cond.cascade is None
            and cond.operator == '=='
            and not cond.is_python_comparison())

        if is_plain_equality:
            # accept "var == value" as well as "value == var"
            operand_pairs = ((cond.operand1, cond.operand2),
                             (cond.operand2, cond.operand1))
            for variable, value in operand_pairs:
                if is_common_value(variable, variable):
                    if self._is_switch_constant(value):
                        return variable, [value]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            left_var, left_values = self.extract_conditions(cond.operand1)
            right_var, right_values = self.extract_conditions(cond.operand2)
            if is_common_value(left_var, right_var):
                return left_var, left_values + right_values
        return None, None

    def visit_IfStatNode(self, node):
        """Replace the if statement by a SwitchStatNode if all of its
        clauses compare one common integer variable against integer
        values and there are at least two values in total.
        """
        self.visitchildren(node)
        switch_var = None
        switch_cases = []
        value_count = 0
        for clause in node.if_clauses:
            var, conditions = self.extract_conditions(clause.condition)
            if var is None:
                return node
            if switch_var is not None and not is_common_value(var, switch_var):
                return node
            if not var.type.is_int:
                return node
            for value in conditions:
                if not value.type.is_int:
                    return node

            switch_var = var
            value_count += len(conditions)
            case = Nodes.SwitchCaseNode(
                pos = clause.pos,
                conditions = conditions,
                body = clause.body)
            switch_cases.append(case)

        if value_count < 2:
            return node

        return Nodes.SwitchStatNode(
            pos = node.pos,
            test = unwrap_node(switch_var),
            cases = switch_cases,
            else_clause = node.else_clause)
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    Expand "x in [val1, ..., valn]" (and the "not in" variant, for list
    and tuple literals) into a chain of single comparisons against x.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_PrimaryCmpNode(self, node):
        """Rewrite a non-cascaded 'in' / 'not_in' test against a literal
        list or tuple; return every other comparison unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        # operator -> (boolean operator joining the tests, test operator)
        if node.operator == 'in':
            joiner, comparison_op = 'or', '=='
        elif node.operator == 'not_in':
            joiner, comparison_op = 'and', '!='
        else:
            return node

        container = node.operand2
        if not isinstance(container, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        items = container.args
        if len(items) == 0:
            # nothing is contained in an empty sequence
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # evaluate the left hand side only once
        lhs = UtilNodes.ResultRefNode(node.operand1)

        tests = []
        for item in items:
            comparison = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = comparison_op,
                operand2 = item,
                cascade = None)
            tests.append(ExprNodes.TypecastNode(
                pos = node.pos,
                operand = comparison,
                type = PyrexTypes.c_bint_type))

        # left fold: ((t1 op t2) op t3) ...
        condition = tests[0]
        for test in tests[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = joiner,
                operand1 = condition,
                operand2 = test)

        return UtilNodes.EvalWithTempExprNode(lhs, condition)
class OptimizeBuiltinCalls(Visitor.VisitorTransform):
    """Optimize some common methods calls and instantiation patterns
    for builtin types.

    Handlers are looked up by name: ``_handle_simple_*`` for calls without
    keyword arguments, ``_handle_general_*`` for calls with keyword
    arguments and ``_handle_any_*`` as a fallback for both.  The rest of
    the name is ``function_<name>`` for plain function calls and
    ``method_<type>_<attribute>`` for method calls.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_GeneralCallNode(self, node):
        """Dispatch a call with keyword arguments to a matching handler.

        Calls to non-Python functions and calls whose positional
        arguments are not a literal tuple are left unchanged.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        return self._dispatch_to_handler(
            node, function, arg_tuple, node.keyword_args)

    def visit_SimpleCallNode(self, node):
        """Dispatch a call without keyword arguments to a matching handler.

        Calls to non-Python functions and calls whose arguments are not a
        literal tuple are left unchanged.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        arg_tuple = node.arg_tuple
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        return self._dispatch_to_handler(
            node, node.function, arg_tuple)

    def visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.
        """
        old_arg = node.arg
        self.visitchildren(node)
        # keep the test unless the argument was replaced by a node that
        # already has the tested type
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

    def _find_handler(self, match_name, has_kwargs):
        """Return the handler method for ``match_name`` or None.

        The call type specific handler ('general' with keyword arguments,
        'simple' without) takes precedence over an '_handle_any_' handler.
        """
        call_type = has_kwargs and 'general' or 'simple'
        handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % match_name, None)
        return handler

    def _dispatch_to_handler(self, node, function, arg_tuple, kwargs=None):
        """Call the handler that matches the called function or method.

        Returns the handler's result, or ``node`` unchanged if there is no
        matching handler.  Method handlers receive the argument list with
        the object as first item, plus a flag telling whether the call
        was written as an unbound method call on the type.
        """
        if function.is_name:
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            if kwargs:
                return function_handler(node, arg_tuple, kwargs)
            else:
                return function_handler(node, arg_tuple)
        elif isinstance(function, ExprNodes.AttributeNode):
            arg_list = arg_tuple.args
            self_arg = function.obj
            obj_type = self_arg.type
            is_unbound_method = False
            if obj_type.is_builtin_type:
                # Only a plain name has a '.name' to take the type name
                # from; any other expression of type 'type' is handled
                # like a normal object of that type.
                if obj_type is Builtin.type_type and self_arg.is_name and \
                        arg_list and arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...)
                    type_name = self_arg.name
                    self_arg = None
                    is_unbound_method = True
                else:
                    type_name = obj_type.name
            else:
                type_name = "object" # safety measure
            method_handler = self._find_handler(
                "method_%s_%s" % (type_name, function.attribute), kwargs)
            if method_handler is None:
                return node
            if self_arg is not None:
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
                return method_handler(node, arg_list, kwargs, is_unbound_method)
            else:
                return method_handler(node, arg_list, is_unbound_method)
        else:
            return node

    ### builtin types

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args.args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs

    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict(some_dict) by PyDict_Copy(some_dict).
        """
        if len(pos_args.args) != 1:
            return node
        if pos_args.args[0].type is not Builtin.dict_type:
            return node

        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
            args = pos_args.args,
            is_temp = node.is_temp
            )

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...}.
        """
        arg_count = len(pos_args.args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type, is_temp=1)
        if arg_count > 1:
            return node
        iterable = pos_args.args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args,
                                     type=Builtin.set_type, is_temp=1)
        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
                iterable.type is Builtin.list_type:
            # let the list comprehension build a set instead
            iterable.target = ExprNodes.SetNode(
                node.pos, args=[], type=Builtin.set_type, is_temp=1)
            iterable.type = Builtin.set_type
            iterable.pos = node.pos
            return iterable
        else:
            return node

    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

    def _handle_simple_function_tuple(self, node, pos_args):
        """Replace tuple([...]) by a call to PyList_AsTuple.
        """
        if len(pos_args.args) != 1:
            return node
        list_arg = pos_args.args[0]
        if list_arg.type is not Builtin.list_type:
            return node
        if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                     ExprNodes.ListNode)):
            # everything else may be None => take the safe path
            return node

        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
            args = pos_args.args,
            is_temp = node.is_temp
            )

    ### builtin functions

    PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            ])

    PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_function_getattr(self, node, pos_args):
        """Replace getattr(obj, name) by PyObject_GetAttr() and
        getattr(obj, name, default) by the __Pyx_GetAttr3() helper.

        Any other number of arguments is reported as an error and the
        call node is returned unchanged.
        """
        # not really a builtin *type*, but worth optimising anyway
        args = pos_args.args
        if len(args) == 2:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
                args = args,
                is_temp = node.is_temp
                )
        elif len(args) == 3:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
                utility_code = Builtin.getattr3_utility_code,
                args = args,
                is_temp = node.is_temp
                )
        else:
            error(node.pos, "getattr() called with wrong number of args, "
                  "expected 2 or 3, found %d" % len(args))
        return node

    ### methods of builtin types

    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_method_object_append(self, node, args, is_unbound_method):
        """Replace obj.append(x) on an object of unknown type by a call
        to the __Pyx_PyObject_Append() helper.
        """
        # X.append() is almost always referring to a list
        if len(args) != 2:
            return node

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = append_utility_code
            )

    PyList_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")

    def _handle_simple_method_list_append(self, node, args, is_unbound_method):
        """Replace L.append(x) on a list by a call to PyList_Append()."""
        if len(args) != 2:
            error(node.pos, "list.append(x) called with wrong number of args, found %d" %
                  len(args))
            return node
        return self._substitute_method_call(
            node, "PyList_Append", self.PyList_Append_func_type,
            'append', is_unbound_method, args)

    # signature shared by the C functions that take only the object itself
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")

    def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
        """Replace L.sort() without arguments by a call to PyList_Sort()."""
        if len(args) != 1:
            return node
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args)

    def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
        """Replace L.reverse() by a call to PyList_Reverse()."""
        if len(args) != 1:
            error(node.pos, "list.reverse(x) called with wrong number of args, found %d" %
                  len(args))
            return node
        return self._substitute_method_call(
            node, "PyList_Reverse", self.single_param_func_type,
            'reverse', is_unbound_method, args)

    def _substitute_method_call(self, node, name, func_type,
                                attr_name, is_unbound_method, args=()):
        """Build the C-API call node that replaces a method call.

        ``name`` and ``func_type`` describe the C function to call and
        ``attr_name`` is the name of the replaced method.  The first
        argument (the object the method is called on) is wrapped in a
        None check: a TypeError for unbound method calls, an
        AttributeError for normal method calls.
        """
        args = list(args)
        if args:
            self_arg = args[0]
            if is_unbound_method:
                self_arg = ExprNodes.NoneCheckNode(
                    self_arg, "PyExc_TypeError",
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                    attr_name, node.function.obj.name))
            else:
                self_arg = ExprNodes.NoneCheckNode(
                    self_arg, "PyExc_AttributeError",
                    "'NoneType' object has no attribute '%s'" % attr_name)
            args[0] = self_arg
        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args = args,
            is_temp = node.is_temp
            )
# C helper used by OptimizeBuiltinCalls._handle_simple_method_object_append():
# calls PyList_Append() directly if the object is an exact list and
# otherwise looks up and calls the object's "append" attribute.
append_utility_code = UtilityCode(
    proto = """
static INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
    impl = ""
)
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` unless it has been set before.

        The children are visited first.  The value stays at
        ``ExprNodes.not_a_constant`` if any child is not constant or if
        the calculation fails.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.itervalues():
            if type(child_result) is list:
                for child in child_result:
                    if child.constant_result is not_a_constant:
                        return
            elif child_result.constant_result is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
#            if node.constant_result is not ExprNodes.not_a_constant:
#                print node.__class__.__name__, node.constant_result
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, ordered from narrowest to widest
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the class from NODE_TYPE_ORDER that comes last among the
        classes of ``nodes``, or None if any of them is not listed.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        """Calculate the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Replace a binary operation on two constant nodes by a single
        constant node that carries the calculated result.

        The node is returned unchanged if its result is not constant, if
        an operand has no type, if either operand is not a ConstNode, or
        if no suitable node class exists for the result.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        try:
            if node.operand1.type is None or node.operand2.type is None:
                return node
        except AttributeError:
            return node

        type1, type2 = node.operand1.type, node.operand2.type
        # Both operands must be constant nodes: one of them is reused as
        # the result node below and gets its value overwritten.
        if isinstance(node.operand1, ExprNodes.ConstNode) and \
                isinstance(node.operand2, ExprNodes.ConstNode):
            if type1 is type2:
                new_node = node.operand1
            else:
                widest_type = PyrexTypes.widest_numeric_type(type1, type2)
                if type(node.operand1) is type(node.operand2):
                    new_node = node.operand1
                    new_node.type = widest_type
                elif type1 is widest_type:
                    new_node = node.operand1
                elif type2 is widest_type:
                    new_node = node.operand2
                else:
                    target_class = self._widest_node_class(
                        node.operand1, node.operand2)
                    if target_class is None:
                        return node
                    # nodes are created with a source position, as
                    # everywhere else in this module
                    new_node = target_class(pos = node.pos, type = widest_type)
        else:
            return node

        new_node.constant_result = node.constant_result
        new_node.value = str(node.constant_result)
        #new_node = new_node.coerce_to(node.type, self.current_scope)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Only calls with exactly two arguments whose second argument has
        the builtin type 'type' are rewritten; every other call is
        returned unchanged.
        """
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            # the argument count is checked before node.args[1] is read,
            # so a call with fewer arguments is left alone
            if node.function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    from CythonScope import utility_scope
                    node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
                    node.function.type = node.function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node