Cython has moved to github.

cython-devel

view Cython/Compiler/Optimize.py @ 3156:a92b70b5624e

optimise unicode.split() and unicode.splitlines()
author Stefan Behnel <scoder@users.berlios.de>
date Sun Mar 21 07:57:00 2010 +0100 (2 years ago)
parents 3d05d88c544d
children bb090cf72455
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
17 import codecs
19 try:
20 reduce
21 except NameError:
22 from functools import reduce
24 try:
25 set
26 except NameError:
27 from sets import Set as set
class FakePythonEnv(object):
    "A fake environment for creating type test nodes etc."
    # The only environment attribute provided here is 'nogil', fixed to
    # False (i.e. the fake environment is never a nogil context).
    nogil = False
def unwrap_node(node):
    """Return the innermost expression behind any ResultRefNode wrappers.

    Nodes that are not ResultRefNodes are returned unchanged.
    """
    current = node
    while isinstance(current, UtilNodes.ResultRefNode):
        current = current.expression
    return current
def is_common_value(a, b):
    """Return True if *a* and *b* (after unwrapping) denote the same value.

    Two names match when they have the same name; two attribute lookups
    match when the first one is not a Python attribute lookup, their
    objects match recursively, and the attribute names are equal.
    Anything else is considered different.
    """
    a, b = unwrap_node(a), unwrap_node(b)

    name_node = ExprNodes.NameNode
    if isinstance(a, name_node) and isinstance(b, name_node):
        return a.name == b.name

    attr_node = ExprNodes.AttributeNode
    if not (isinstance(a, attr_node) and isinstance(b, attr_node)):
        return False
    if a.is_py_attr:
        return False
    if not is_common_value(a.obj, b.obj):
        return False
    return a.attribute == b.attribute
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over a C array, or over a slice of a C array/pointer,
      becomes a C for loop over a pointer counter
    """
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        # the module scope is the initial scope for coercions
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use the function's own scope while transforming its body
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Pick a specialised transformation for the loop's iterable,
        or return the loop unchanged if none applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?
        if isinstance(iterator, ExprNodes.SliceIndexNode) and \
               (iterator.base.type.is_array or iterator.base.type.is_ptr):
            return self._transform_carray_iteration(node, iterator)
        elif iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
               function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_carray_iteration(self, node, slice_node):
        """Turn iteration over a C array or a sliced C array/pointer into
        a C for loop that advances a pointer from the start to the stop
        address and assigns the pointed-to item to the loop target.
        """
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                # no upper bound => cannot compute the end pointer
                return node
        elif slice_node.type.is_array and slice_node.type.size is not None:
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size))
            step = None
        else:
            return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=carray_ptr,
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        else:
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1='<=',
            target=counter_temp,
            relation2='<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace 'for i, x in enumerate(seq)' by a loop over 'seq' that
        maintains the counter 'i' in a separate variable.

        Only the single-argument form is optimised; enumerate(seq, start)
        is valid Python and is left to the runtime call.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 2:
            error(enumerate_function.pos,
                  "enumerate() takes at most 2 arguments")
            return node
        elif len(args) == 2:
            # enumerate(iterable, start): the 'start' argument is not
            # supported by this optimisation yet - leave the call alone
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # at the start of each iteration: target = counter; counter += 1
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Replace 'for i in range(...)' with a C integer for-from loop.

        Only calls with 1 to 3 arguments and a compile-time constant,
        non-zero step are transformed; anything else is left to the
        runtime call (which raises the appropriate TypeError/ValueError).
        """
        args = range_function.arg_tuple.args
        if len(args) < 1 or len(args) > 3:
            # invalid argument count: previously crashed the compiler
            # (no args) or silently ignored extra arguments
            return node
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # the for-from loop takes a positive step and encodes the
            # direction in the relations
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace dict iteration by a while loop calling PyDict_Next().

        'keys' / 'values' select which of the key and value slots are
        fetched; unused slots are passed as NULL.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (value node to assign, coercion node to run first or None).
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    def _strip_condition_wrappers(self, cond):
        # Peel off temp coercions, the EvalWithTempExprNode produced by
        # FlattenInListTransform, and C type casts around a condition.
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                return cond

    def _is_case_value(self, operand):
        # A usable case value is a literal or a name bound to a constant.
        if operand.is_literal:
            return True
        return bool(hasattr(operand, 'entry') and operand.entry
                    and operand.entry.is_const)

    def extract_conditions(self, cond):
        """Split a condition into (switch variable, [case values]).

        Accepts 'var == value' (either operand order) and 'or'
        combinations of such tests on the same variable.  Returns
        (None, None) for anything else.
        """
        cond = self._strip_condition_wrappers(cond)

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None or cond.operator != '==' \
                   or cond.is_python_comparison():
                return None, None
            left, right = cond.operand1, cond.operand2
            # is_common_value(x, x) is only true when x is a name, or a
            # chain of non-Python attribute lookups rooted at a name
            if is_common_value(left, left) and self._is_case_value(right):
                return left, [right]
            if is_common_value(right, right) and self._is_case_value(left):
                return right, [left]
            return None, None

        if isinstance(cond, ExprNodes.BoolBinopNode) and cond.operator == 'or':
            var1, values1 = self.extract_conditions(cond.operand1)
            var2, values2 = self.extract_conditions(cond.operand2)
            if is_common_value(var1, var2):
                return var1, values1 + values2

        return None, None

    def visit_IfStatNode(self, node):
        self.visitchildren(node)
        switch_var = None
        total_cases = 0
        switch_cases = []
        for clause in node.if_clauses:
            var, values = self.extract_conditions(clause.condition)
            if var is None:
                return node
            if switch_var is not None and not is_common_value(var, switch_var):
                return node
            if not var.type.is_int:
                return node
            non_int_values = [value for value in values if not value.type.is_int]
            if non_int_values:
                return node
            switch_var = var
            total_cases = total_cases + len(values)
            switch_cases.append(Nodes.SwitchCaseNode(pos = clause.pos,
                                                     conditions = values,
                                                     body = clause.body))
        if total_cases < 2:
            # a single case is not worth a switch statement
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(switch_var),
                                    cases = switch_cases,
                                    else_clause = node.else_clause)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    # comparison operator -> (boolean conjunction, per-item comparison)
    _flattened_operators = {
        'in':     ('or',  '=='),
        'not_in': ('and', '!='),
        }

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        operators = self._flattened_operators.get(node.operator)
        if operators is None:
            return node
        conjunction, eq_or_neq = operators

        sequence = node.operand2
        if not isinstance(sequence, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        items = sequence.args
        if not items:
            # 'x in ()' is always false, 'x not in ()' always true
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # evaluate the left operand only once, shared by all comparisons
        lhs = UtilNodes.ResultRefNode(node.operand1)

        comparisons = []
        for item in items:
            comparison = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = item,
                cascade = None)
            comparisons.append(ExprNodes.TypecastNode(
                pos = node.pos,
                operand = comparison,
                type = PyrexTypes.c_bint_type))

        # left fold: ((c0 op c1) op c2) ...
        condition = comparisons[0]
        for comparison in comparisons[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = condition,
                operand2 = comparison)

        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    The only case handled so far is a parallel assignment whose left
    and right hand sides are the same set of distinct names / C
    attribute paths (e.g. 'a,b = b,a').
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            # anything other than a single assignment (including cascaded
            # assignments - FIXME) disables the optimisation
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not self._extract_operand(stat.lhs, left_names,
                                         left_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, right_names,
                                         right_indices, temps):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = [path for path, _ in left_names]
            rhs_paths = [path for path, _ in right_names]
            distinct_lhs = set(lhs_paths)
            if distinct_lhs != set(rhs_paths):
                return node
            if len(distinct_lhs) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lhs_ids = self._collect_index_ids(left_indices)
            if lhs_ids is None:
                return node
            rhs_ids = self._collect_index_ids(right_indices)
            if rhs_ids is None:
                return node
            distinct_ids = set(lhs_ids)
            if distinct_ids != set(rhs_ids):
                return node
            if len(distinct_ids) != len(right_indices):
                return node
            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [t.arg for t in temps]
        for temp in temps:
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            # the argument of a coerced temp keeps its managed reference
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _collect_index_ids(self, index_nodes):
        # Return the (base, index) ids of all nodes, or None as soon as one
        # of them has no usable id.
        ids = []
        for index_node in index_nodes:
            index_id = self._extract_index_id(index_node)
            if not index_id:
                return None
            ids.append(index_id)
        return ids

    def _extract_operand(self, node, names, indices, temps):
        """Classify one side of an assignment.

        Appends (dotted path, node) to *names* for names / C attribute
        paths, or the node to *indices* for integer-indexed list names.
        A CoerceToTempNode wrapper is recorded in *temps*.  Returns False
        if the operand is not supported.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        path = []
        cursor = node
        while isinstance(cursor, ExprNodes.AttributeNode):
            if cursor.is_py_attr:
                return False
            path.insert(0, cursor.member)
            cursor = cursor.obj
        if isinstance(cursor, ExprNodes.NameNode):
            path.insert(0, cursor.name)
            names.append(('.'.join(path), node))
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if node.base.type != Builtin.list_type \
               or not node.index.type.is_int \
               or not isinstance(node.base, ExprNodes.NameNode):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        # Only 'name[name]' is supported; constant indices (ConstNode) are
        # not handled yet (FIXME) and, like anything else, yield None.
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            return None
        return (index_node.base.name, index.name)
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        return self._dispatch_to_handler(node, function, node.args)

    def visit_GeneralCallNode(self, node):
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        if node.keyword_args is None and \
               getattr(node, 'starstar_arg', None) is not None:
            # e.g. 'dict([...], **kw)': without an explicit keyword dict this
            # would be dispatched to a '_handle_simple_*' handler, which
            # knows nothing about '**kw' and would silently drop it
            return node
        args = arg_tuple.args
        return self._dispatch_to_handler(
            node, function, args, node.keyword_args)

    def _function_is_builtin_name(self, function):
        # True only for a plain name that resolves to the builtin scope
        if not function.is_name:
            return False
        entry = self.env_stack[-1].lookup(function.name)
        if not entry or getattr(entry, 'scope', None) is not Builtin.builtin_scope:
            return False
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        # look up '_handle_simple_function_NAME' (no kwargs) or
        # '_handle_general_function_NAME' (with kwargs)
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        node.function = ExprNodes.PythonCapiFunctionNode(
            node.function.pos, node.function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        # report a compile time error for a call with a wrong argument count
        if not expected: # None or 0
            arg_str = ''
        elif isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            arg_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        else:
            expected_str = ''
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    # specific handlers for simple call nodes

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...} and
        set([ x for ... ]) by a literal { x for ... }.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type)
        if arg_count > 1:
            return node
        iterable = pos_args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args)
        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
                 isinstance(iterable.target, (ExprNodes.ListNode,
                                              ExprNodes.SetNode)):
            iterable.target = ExprNodes.SetNode(node.pos, args=[])
            iterable.pos = node.pos
            return iterable
        else:
            return node

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.ComprehensionNode) and \
               isinstance(arg.target, (ExprNodes.ListNode,
                                       ExprNodes.SetNode)):
            append_node = arg.append
            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
                   len(append_node.expr.args) == 2:
                key_node, value_node = append_node.expr.args
                target_node = ExprNodes.DictNode(
                    pos=arg.target.pos, key_value_pairs=[])
                new_append_node = ExprNodes.DictComprehensionAppendNode(
                    append_node.pos, target=target_node,
                    key_expr=key_node, value_expr=value_node)
                arg.target = target_node
                arg.type = target_node.type
                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
                return replace_in(arg)
        return node

    def _handle_simple_function_float(self, node, pos_args):
        # float() -> 0.0; more than one argument is a compile time error
        if len(pos_args) == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if len(pos_args) > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs
891 class OptimizeBuiltinCalls(Visitor.EnvTransform):
892 """Optimize some common methods calls and instantiation patterns
893 for builtin types *after* the type analysis phase.
895 Running after type analysis, this transform can only perform
896 function replacements that do not alter the function return type
897 in a way that was not anticipated by the type analysis.
898 """
899 # only intercept on call nodes
900 visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_GeneralCallNode(self, node):
    """Dispatch a Python call with keyword/star arguments to a specific
    optimiser, if one exists and the call shape is supported.
    """
    self.visitchildren(node)
    function = node.function
    if not function.type.is_pyobject:
        return node
    arg_tuple = node.positional_args
    if not isinstance(arg_tuple, ExprNodes.TupleNode):
        return node
    if node.keyword_args is None and \
           getattr(node, 'starstar_arg', None) is not None:
        # e.g. 'dict(d, **kw)': without an explicit keyword dict the call
        # would reach a '_handle_simple_*' handler, which ignores '**kw'
        # and would silently drop it from the optimised call
        return node
    args = arg_tuple.args
    return self._dispatch_to_handler(
        node, function, args, node.keyword_args)
def visit_SimpleCallNode(self, node):
    """Dispatch a simple call (positional arguments only) to a specific
    optimiser, if one exists.
    """
    self.visitchildren(node)
    function = node.function
    if not function.type.is_pyobject:
        # C function call: use the argument list directly
        return self._dispatch_to_handler(node, function, node.args)
    arg_tuple = node.arg_tuple
    if isinstance(arg_tuple, ExprNodes.TupleNode):
        return self._dispatch_to_handler(node, function, arg_tuple.args)
    return node
### cleanup to avoid redundant coercions to/from Python types

def _visit_PyTypeTestNode(self, node):
    """Flatten redundant type checks after tree changes.
    """
    # Disabled (the leading underscore keeps the visitor from using it):
    # appears to break assignments in some cases, and also drops a None
    # check, which might still be required.
    original_arg = node.arg
    self.visitchildren(node)
    new_arg = node.arg
    if original_arg is new_arg or new_arg.type != node.type:
        return node
    return new_arg
def visit_CoerceFromPyTypeNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    Also, optimise away calls to Python's builtin int() and
    float() if the result is going to be coerced back into a C
    type anyway.
    """
    self.visitchildren(node)
    arg = node.arg
    target_type = node.type

    if not arg.type.is_pyobject:
        # no Python conversion left at all, just do a C coercion instead
        if target_type == arg.type:
            return arg
        return arg.coerce_to(target_type, self.env_stack[-1])

    # only 'int(x)' / 'float(x)' coerced to a C number are of interest
    if not isinstance(arg, ExprNodes.SimpleCallNode):
        return node
    if not (target_type.is_int or target_type.is_float):
        return node
    function = arg.function
    if not isinstance(function, ExprNodes.NameNode) \
           or not function.type.is_builtin_type \
           or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
        return node
    call_args = arg.arg_tuple.args
    if len(call_args) != 1:
        return node

    value = call_args[0]
    if isinstance(value, ExprNodes.CoerceToPyTypeNode):
        value = value.arg
    elif value.type.is_pyobject:
        # play safe: Python conversion might work on all sorts of things
        return node

    if function.name == 'int':
        applicable = value.type.is_int or target_type.is_int
    elif function.name == 'float':
        applicable = value.type.is_float or target_type.is_float
    else:
        return node
    if not applicable:
        return node

    if value.type == target_type:
        return value
    if target_type.assignable_from(value.type) or value.type.is_float:
        return ExprNodes.TypecastNode(
            node.pos, operand=value, type=target_type)
    return node
### dispatch to specific optimisers

def _find_handler(self, match_name, has_kwargs):
    """Return the optimiser method for *match_name*, or None.

    A call-kind specific '_handle_general_*' (with kwargs) or
    '_handle_simple_*' (without) handler is preferred; otherwise the
    kind-independent '_handle_any_*' handler is used if present.
    """
    if has_kwargs:
        call_type = 'general'
    else:
        call_type = 'simple'
    specific = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
    if specific is not None:
        return specific
    return getattr(self, '_handle_any_%s' % match_name, None)
def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
    """Route a call to a builtin function or a Python method call to
    its specific optimiser, returning *node* unchanged if there is none.
    """
    if function.is_name:
        return self._dispatch_function_call(node, function, arg_list, kwargs)
    if function.is_attribute and function.type.is_pyobject:
        return self._dispatch_method_call(node, function, arg_list, kwargs)
    return node

def _dispatch_function_call(self, node, function, arg_list, kwargs):
    # Handle calls of a plain name.  We only consider functions that are
    # either builtin Python functions or builtins that were already
    # replaced by a C function call (defined in the builtin scope).
    entry = function.entry
    if not entry:
        return node
    if not (entry.is_builtin or
            getattr(entry, 'scope', None) is Builtin.builtin_scope):
        return node
    handler = self._find_handler("function_%s" % function.name, kwargs)
    if handler is None:
        return node
    if kwargs:
        return handler(node, arg_list, kwargs)
    return handler(node, arg_list)

def _dispatch_method_call(self, node, function, arg_list, kwargs):
    # Handle 'obj.attr(...)' calls; unbound builtin methods such as
    # 'list.append(L, x)' are flagged for their handler.
    attr_name = function.attribute
    self_arg = function.obj
    obj_type = self_arg.type
    is_unbound_method = False
    if not obj_type.is_builtin_type:
        type_name = "object" # safety measure
    elif obj_type is Builtin.type_type and arg_list and \
             arg_list[0].type.is_pyobject:
        # calling an unbound method like 'list.append(L,x)'
        # (ignoring 'type.mro()' here ...)
        type_name = function.obj.name
        self_arg = None
        is_unbound_method = True
    else:
        type_name = obj_type.name

    handler = self._find_handler(
        "method_%s_%s" % (type_name, attr_name), kwargs)
    if handler is None and (attr_name in TypeSlots.method_name_to_slot
                            or attr_name == '__new__'):
        handler = self._find_handler("slot%s" % attr_name, kwargs)
    if handler is None:
        return node

    if self_arg is not None:
        arg_list = [self_arg] + list(arg_list)
    if kwargs:
        return handler(node, arg_list, kwargs, is_unbound_method)
    return handler(node, arg_list, is_unbound_method)
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call with an unexpected number of arguments.

    ``expected`` may be None (unknown), an int, or a descriptive string
    such as '2 or 3'.  The message shows a placeholder argument list:
    'x' for exactly one argument, '...' for several or a string spec, and
    nothing otherwise.
    """
    arg_str = ''
    if expected:  # neither None nor 0
        if isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'

    expected_str = ''
    if expected is not None:
        expected_str = 'expected %s, ' % expected

    message = "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args))
    error(node.pos, message)
### builtin types

# C signature of PyDict_Copy(): takes a dict and returns a new dict.
PyDict_Copy_func_type = PyrexTypes.CFuncType(
    Builtin.dict_type,
    [PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)])
def _handle_simple_function_dict(self, node, pos_args):
    """Replace dict(some_dict) by PyDict_Copy(some_dict).

    Only single-argument calls whose argument is statically typed as the
    builtin ``dict`` are rewritten.  Every other call is returned
    unchanged.
    """
    if len(pos_args) != 1:
        return node
    arg = pos_args[0]
    if arg.type is not Builtin.dict_type:
        return node
    # A dict-typed value may still be None at runtime.  Raise the same
    # TypeError message as CPython's dict(None) (and as the tuple()
    # optimisation in this class).
    arg = ExprNodes.NoneCheckNode(
        arg, "PyExc_TypeError", "'NoneType' object is not iterable")
    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
        args = [arg],
        is_temp = node.is_temp
        )
# C signature of PyList_AsTuple(): takes a list and returns a new tuple.
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
    Builtin.tuple_type,
    [PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)])
def _handle_simple_function_tuple(self, node, pos_args):
    """Turn tuple(a_list) into a PyList_AsTuple(a_list) call.

    Applies only to single-argument calls with a list-typed argument.
    List literals and comprehensions can never be None.  Any other
    list-typed argument gets a None check that raises TypeError.
    """
    if len(pos_args) != 1:
        return node
    list_arg = pos_args[0]
    if list_arg.type is not Builtin.list_type:
        return node
    never_none = isinstance(
        list_arg, (ExprNodes.ComprehensionNode, ExprNodes.ListNode))
    if not never_none:
        pos_args[0] = ExprNodes.NoneCheckNode(
            list_arg, "PyExc_TypeError",
            "'NoneType' object is not iterable")
    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
        args=pos_args,
        is_temp=node.is_temp)
# C signature of the __Pyx_PyObject_AsDouble() helper: object -> double.
# The helper returns -1.0 on error, so callers must check for an exception.
PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_double_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value="((double)-1)",
    exception_check=True)
def _handle_simple_function_float(self, node, pos_args):
    """Transform float() into either a C type cast or a faster C
    function call.

    * A C ``double`` argument (possibly wrapped in a coercion to a Python
      object) is returned as is.
    * Other real C numeric arguments are cast to ``node.type`` in C.
    * Everything else, including C complex values, is passed as a Python
      object to ``__Pyx_PyObject_AsDouble()``.
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, 1)
        return node
    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        func_arg = func_arg.arg
    arg_type = func_arg.type
    if arg_type is PyrexTypes.c_double_type:
        return func_arg
    # A C cast would silently drop the imaginary part of a complex value
    # (or not compile at all for struct-based complex types).  Python's
    # float() rejects complex numbers with a TypeError, so leave those to
    # the runtime helper below.
    if not arg_type.is_complex and (
            node.type.assignable_from(arg_type) or arg_type.is_numeric):
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = pyobject_as_double_utility_code,
        py_name = "float")
### builtin functions

# getattr(obj, name) -> PyObject_GetAttr(obj, name)
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None)])

# getattr(obj, name, default) -> __Pyx_GetAttr3(obj, name, default)
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None)])
def _handle_simple_function_getattr(self, node, pos_args):
    """Map getattr() onto C-API calls.

    The 2-argument form becomes ``PyObject_GetAttr()``.  The 3-argument
    form becomes the ``__Pyx_GetAttr3()`` helper from
    ``Builtin.getattr3_utility_code``.  Any other arity is reported as an
    error and the call is left alone.
    """
    arg_count = len(pos_args)
    if arg_count == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
            args=pos_args,
            is_temp=node.is_temp)
    if arg_count == 3:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
            args=pos_args,
            is_temp=node.is_temp,
            utility_code=Builtin.getattr3_utility_code)
    self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
    return node
# iter(obj) -> PyObject_GetIter(obj)
PyObject_GetIter_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)])

# iter(callable, sentinel) -> PyCallIter_New(callable, sentinel)
PyCallIter_New_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("sentinel", PyrexTypes.py_object_type, None)])
def _handle_simple_function_iter(self, node, pos_args):
    """Map iter() onto C-API calls.

    ``iter(obj)`` becomes ``PyObject_GetIter(obj)`` and
    ``iter(callable, sentinel)`` becomes ``PyCallIter_New()``.  Any other
    arity is reported as an error and the call is left alone.
    """
    arg_count = len(pos_args)
    if arg_count == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyObject_GetIter", self.PyObject_GetIter_func_type,
            args=pos_args,
            is_temp=node.is_temp)
    if arg_count == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyCallIter_New", self.PyCallIter_New_func_type,
            args=pos_args,
            is_temp=node.is_temp)
    self._error_wrong_arg_count('iter', node, pos_args, '1 or 2')
    return node
# C signature of strlen(): char* -> size_t
Pyx_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type,
    [PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)])
def _handle_simple_function_len(self, node, pos_args):
    """Turn len(some_char_ptr) into a plain strlen() call.

    The rewrite only happens when the argument is a C string (possibly
    coerced to a Python object) AND the call has already been given a
    numeric result type.  That means len() was already replaced by a
    C-level length call, so returning a C size_t needs no Python object
    conversion.
    """
    if len(pos_args) != 1:
        self._error_wrong_arg_count('len', node, pos_args, 1)
        return node
    c_string = pos_args[0]
    if isinstance(c_string, ExprNodes.CoerceToPyTypeNode):
        c_string = c_string.arg
    if not (c_string.type.is_string and node.type.is_numeric):
        return node
    return ExprNodes.PythonCapiCallNode(
        node.pos, "strlen", self.Pyx_strlen_func_type,
        args=[c_string],
        is_temp=node.is_temp,
        utility_code=include_string_h_utility_code)
# Signature used for the Py_TYPE() macro: object -> its type object.
Pyx_Type_func_type = PyrexTypes.CFuncType(
    Builtin.type_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)])
def _handle_simple_function_type(self, node, pos_args):
    """Turn type(o) into the Py_TYPE(o) macro.

    The macro result is not a temp.  It is wrapped in a CastNode to a
    generic Python object.  Calls with other than one argument are left
    unchanged.
    """
    if len(pos_args) != 1:
        return node
    type_lookup = ExprNodes.PythonCapiCallNode(
        node.pos, "Py_TYPE", self.Pyx_Type_func_type,
        args=pos_args,
        is_temp=False)
    return ExprNodes.CastNode(type_lookup, PyrexTypes.py_object_type)
### special methods

# Signature of the __Pyx_tp_new() helper: type object -> new instance.
Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)])
def _handle_simple_slot__new__(self, node, args, is_unbound_method):
    """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new().

    The call is only rewritten when it is an unbound call with exactly one
    argument, both the receiver and the argument are plain names typed as
    ``type``, and they provably refer to the same type.  That means the
    same type entry, or the same name when a type entry is missing.
    """
    obj = node.function.obj
    if not is_unbound_method or len(args) != 1:
        return node
    type_arg = args[0]
    if not (obj.is_name and type_arg.is_name):
        # play safe
        return node
    if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
        # not a known type, play safe
        return node
    if type_arg.type_entry and obj.type_entry:
        if type_arg.type_entry != obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node
    elif obj.name != type_arg.name:
        return node
    # Past this point both names denote the same type object.

    # FIXME: we could potentially look up the actual tp_new C
    # method of the extension type and call that instead of the
    # generic slot. That would also allow us to pass parameters
    # efficiently.

    if not type_arg.type_entry:
        # an arbitrary variable may hold None, so check it first
        type_arg = ExprNodes.NoneCheckNode(
            type_arg, "PyExc_TypeError",
            "object.__new__(X): X is not a type object (NoneType)")

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
        args=[type_arg],
        utility_code=tpnew_utility_code,
        is_temp=node.is_temp)
### methods of builtin types

# Signature of the __Pyx_PyObject_Append() helper: (obj, item) -> result.
PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None)])
def _handle_simple_method_object_append(self, node, args, is_unbound_method):
    """Rewrite X.append(item) on an untyped object into
    __Pyx_PyObject_Append(X, item).

    This is optimistic: X is usually a list, and the helper falls back to
    a normal method call for other objects.
    """
    if len(args) == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args=args,
            is_temp=node.is_temp,
            utility_code=append_utility_code)
    return node
# Signature of the __Pyx_PyObject_Pop() helper: X.pop()
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        ])

# Signature of the __Pyx_PyObject_PopIndex() helper: X.pop(index).
# The C helper is declared as taking a Py_ssize_t index, so declare the
# argument with that type too.  C 'long' is narrower than Py_ssize_t on
# LLP64 platforms such as 64-bit Windows.
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
        ])
def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
    """Optimistically map X.pop() / X.pop(n) onto list-oriented helpers.

    ``X.pop()`` always becomes ``__Pyx_PyObject_Pop(X)``.  ``X.pop(n)``
    becomes ``__Pyx_PyObject_PopIndex(X, n)`` only when ``n`` is a C
    integer that was coerced to a Python object and whose type is no
    wider than Py_ssize_t.  The coercion is then dropped and the C value
    is passed directly.  All other calls are left untouched.
    """
    arg_count = len(args)
    if arg_count == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
            args=args,
            is_temp=node.is_temp,
            utility_code=pop_utility_code)
    if arg_count == 2:
        py_index = args[1]
        if isinstance(py_index, ExprNodes.CoerceToPyTypeNode) \
                and py_index.arg.type.is_int:
            c_index_type = py_index.arg.type
            ssize_t_type = PyrexTypes.c_py_ssize_t_type
            widest = PyrexTypes.widest_numeric_type(c_index_type, ssize_t_type)
            if widest == ssize_t_type:
                args[1] = py_index.arg
                return ExprNodes.PythonCapiCallNode(
                    node.pos, "__Pyx_PyObject_PopIndex",
                    self.PyObject_PopIndex_func_type,
                    args=args,
                    is_temp=node.is_temp,
                    utility_code=pop_index_utility_code)
    return node
# C signature of PyList_Append(): (list, item) -> int, -1 on error.
PyList_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type,
    [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None)],
    exception_value="-1")
def _handle_simple_method_list_append(self, node, args, is_unbound_method):
    """Turn L.append(x) on a list into PyList_Append(L, x).

    ``args`` holds the list followed by the item.  Any other argument
    count is reported as an error.
    """
    if len(args) == 2:
        return self._substitute_method_call(
            node, "PyList_Append", self.PyList_Append_func_type,
            'append', is_unbound_method, args)
    self._error_wrong_arg_count('list.append', node, args, 2)
    return node
# Generic C signature "int f(PyObject*)" returning -1 on error.
# Shared by PyList_Sort() and PyList_Reverse().
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value="-1")
def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
    """Turn an argument-less L.sort() into PyList_Sort(L).

    Calls with extra arguments (e.g. key/reverse) are left alone without
    an error.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args)
    return node
def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
    """Turn L.reverse() into PyList_Reverse(L).

    ``args`` must hold only the list itself.  Anything else is reported
    as an error.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Reverse", self.single_param_func_type,
            'reverse', is_unbound_method, args)
    self._error_wrong_arg_count('list.reverse', node, args, 1)
    return node
# Signature of __Pyx_PyDict_GetItemDefault(): (dict, key, default) -> value
Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None)])
def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
    """Turn d.get(key[, default]) into __Pyx_PyDict_GetItemDefault().

    A missing default is filled in with a None literal, so the helper
    always receives three arguments.
    """
    arg_count = len(args)
    if arg_count not in (2, 3):
        self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
        return node
    if arg_count == 2:
        args.append(ExprNodes.NoneNode(node.pos))
    return self._substitute_method_call(
        node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
        'get', is_unbound_method, args,
        utility_code=dict_getitem_default_utility_code)
# C signature of PyUnicode_Splitlines(): (unicode, bint keepends) -> list
PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
    Builtin.list_type,
    [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
     PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None)])
def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
    """Turn u.splitlines([keepends]) into PyUnicode_Splitlines(u, keepends).

    A missing ``keepends`` becomes a literal False.  An explicit one is
    coerced to a C boolean.
    """
    arg_count = len(args)
    if arg_count not in (1, 2):
        self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
        return node
    if arg_count == 1:
        args.append(ExprNodes.BoolNode(node.pos, value=False))
    else:
        current_env = self.env_stack[-1]
        args[1] = args[1].coerce_to(PyrexTypes.c_bint_type, current_env)
    return self._substitute_method_call(
        node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
        'splitlines', is_unbound_method, args)
# C signature of PyUnicode_Split(): (unicode, sep or NULL, Py_ssize_t maxsplit)
# -> list
PyUnicode_Split_func_type = PyrexTypes.CFuncType(
    Builtin.list_type,
    [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
     PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None)])
def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
    """Replace unicode.split(...) by a direct call to the
    corresponding C-API function.

    ``args`` is ``[u]``, ``[u, sep]`` or ``[u, sep, maxsplit]``.  A missing
    separator, or a literal ``None`` one, is passed as NULL (split on
    whitespace).  A missing ``maxsplit`` becomes -1 (no limit).  An
    explicit one is coerced to Py_ssize_t.
    """
    if len(args) not in (1,2,3):
        self._error_wrong_arg_count('unicode.split', node, args, "1-3")
        return node
    if len(args) < 2:
        args.append(ExprNodes.NullNode(node.pos))
    elif isinstance(args[1], ExprNodes.NoneNode):
        # In Python, u.split(None, n) means "split on whitespace".
        # PyUnicode_Split() only does that for a NULL separator and raises
        # TypeError when it is handed Py_None.
        # NOTE(review): a non-literal separator that is None at runtime
        # still reaches the C-API unchanged.
        args[1] = ExprNodes.NullNode(args[1].pos)
    if len(args) < 3:
        args.append(ExprNodes.IntNode(
            node.pos, value="-1", type=PyrexTypes.c_py_ssize_t_type))
    else:
        args[2] = args[2].coerce_to(PyrexTypes.c_py_ssize_t_type,
                                    self.env_stack[-1])

    return self._substitute_method_call(
        node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
        'split', is_unbound_method, args)
# PyUnicode_AsEncodedString(unicode, char* encoding, char* errors) -> bytes
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type,
    [PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
     PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
     PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None)])

# Shared signature of the codec-specific PyUnicode_As<Codec>String()
# functions: unicode -> bytes
PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type,
    [PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None)])

# Codec names for which the C-API has dedicated encode/decode functions.
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                      'unicode_escape', 'raw_unicode_escape']

# (name, encoder function) pairs, used to match arbitrary codec aliases.
_special_codecs = [(name, codecs.getencoder(name))
                   for name in _special_encodings]
def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
    """Replace unicode.encode(...) by a direct C-API call to the
    corresponding codec.

    * ``u.encode()`` -> ``PyUnicode_AsEncodedString(u, NULL, NULL)``.
    * A unicode literal with a constant encoding and error mode is encoded
      at compile time into a bytes literal, if that succeeds.
    * Strict error mode with a codec listed in ``_special_codecs`` ->
      ``PyUnicode_As<Codec>String(u)``.
    * Otherwise -> ``PyUnicode_AsEncodedString(u, encoding, errors)``.

    The call is left unchanged if the encoding or error mode is neither a
    string literal nor a C string.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
        return node

    string_node = args[0]

    if len(args) == 1:
        # use separate NULL nodes rather than one node in two positions
        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, ExprNodes.NullNode(node.pos),
             ExprNodes.NullNode(node.pos)])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    if isinstance(string_node, ExprNodes.UnicodeNode):
        # constant, so try to do the encoding at compile time
        try:
            value = string_node.value.encode(encoding, error_handling)
        except Exception:
            # Unknown codec, unencodable characters, or an encoding/error
            # mode that is only known at runtime (None): leave it to
            # runtime.  Unlike a bare except, this does not swallow
            # KeyboardInterrupt/SystemExit.
            pass
        else:
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)

    if error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])

    return self._substitute_method_call(
        node, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
# Shared signature of the codec-specific PyUnicode_Decode<Codec>()
# functions: (char* data, Py_ssize_t size, char* errors) -> unicode
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
     PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
     PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None)])

# PyUnicode_Decode(char* data, Py_ssize_t size, char* encoding, char* errors)
# -> unicode
PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
     PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
     PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
     PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None)])
def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
    """Replace char*.decode() by a direct C-API call to the
    corresponding codec, possibly resolving a slice on the char*.

    Handled receivers (``args[0]``):

    * ``s[start:stop]`` where ``s`` is a C string.  The slice becomes
      pointer arithmetic ``s + start`` and a length ``stop - start``.
      A missing ``stop`` is computed with strlen().
    * a C string coerced to a Python object.  Its length is computed
      with strlen().

    Any other receiver is left to Python.  Sub-expressions that are used
    twice are bound to LetRefNode temps, which are wrapped around the
    final call with EvalWithTempExprNode.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
        return node
    # LetRefNode temps, in creation order
    temps = []
    if isinstance(args[0], ExprNodes.SliceIndexNode):
        index_node = args[0]
        string_node = index_node.base
        if not string_node.type.is_string:
            # nothing to optimise here
            return node
        start, stop = index_node.start, index_node.stop
        if not start or start.constant_result == 0:
            # no offset needed
            start = None
        else:
            if start.type.is_pyobject:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
            if stop:
                # 'start' is also used below to compute 'stop - start',
                # so evaluate it only once
                start = UtilNodes.LetRefNode(start)
                temps.append(start)
            # C pointer arithmetic: decode from s + start
            string_node = ExprNodes.AddNode(pos=start.pos,
                                            operand1=string_node,
                                            operator='+',
                                            operand2=start,
                                            is_temp=False,
                                            type=string_node.type
                                            )
        if stop and stop.type.is_pyobject:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
    elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
             and args[0].arg.type.is_string:
        # use strlen() to find the string length, just as CPython would
        start = stop = None
        string_node = args[0].arg
    else:
        # let Python do its job
        return node

    if not stop:
        # the string pointer is used twice (strlen() and the decode call),
        # so bind it to a temp unless it is a plain name
        if start or not string_node.is_name:
            string_node = UtilNodes.LetRefNode(string_node)
            temps.append(string_node)
        stop = ExprNodes.PythonCapiCallNode(
            string_node.pos, "strlen", self.Pyx_strlen_func_type,
            args = [string_node],
            is_temp = False,
            utility_code = include_string_h_utility_code,
            ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
    elif start:
        # from here on 'stop' holds the length of the decoded data
        stop = ExprNodes.SubNode(
            pos = stop.pos,
            operand1 = stop,
            operator = '-',
            operand2 = start,
            is_temp = False,
            type = PyrexTypes.c_py_ssize_t_type
            )

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # try to find a specific encoder function
    codec_name = None
    if encoding is not None:
        codec_name = self._find_special_codec_name(encoding)
    if codec_name is not None:
        decode_function = "PyUnicode_Decode%s" % codec_name
        node = ExprNodes.PythonCapiCallNode(
            node.pos, decode_function,
            self.PyUnicode_DecodeXyz_func_type,
            args = [string_node, stop, error_handling_node],
            is_temp = node.is_temp,
            )
    else:
        node = ExprNodes.PythonCapiCallNode(
            node.pos, "PyUnicode_Decode",
            self.PyUnicode_Decode_func_type,
            args = [string_node, stop, encoding_node, error_handling_node],
            is_temp = node.is_temp,
            )

    # wrap in reverse order so the first temp created is the outermost
    for temp in temps[::-1]:
        node = UtilNodes.EvalWithTempExprNode(temp, node)
    return node
1649 def _find_special_codec_name(self, encoding):
1650 try:
1651 requested_codec = codecs.getencoder(encoding)
1652 except:
1653 return None
1654 for name, codec in self._special_codecs:
1655 if codec == requested_codec:
1656 if '_' in name:
1657 name = ''.join([ s.capitalize()
1658 for s in name.split('_')])
1659 return name
1660 return None
def _unpack_encoding_and_error_mode(self, pos, args):
    """Extract the encoding and error mode of an encode/decode call.

    ``args`` starts with the string itself, optionally followed by the
    encoding (index 1) and the error mode (index 2).

    Returns ``(encoding, encoding_node, error_handling, error_handling_node)``:

    * ``encoding`` / ``error_handling`` are the compile-time string
      values, or None when the value is only known at runtime (a C string
      variable) or, for the encoding, not given at all.
    * ``encoding_node`` / ``error_handling_node`` are C string nodes
      usable as C-API arguments.  A missing encoding, and a missing or
      'strict' error mode, become NULL nodes.  The C-API treats a NULL
      encoding as "default encoding" and NULL errors as 'strict'.

    Returns None when an argument is neither a string literal nor a C
    string.  The caller should then not optimise.
    """
    if len(args) >= 2:
        encoding_node = args[1]
        if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
            encoding_node = encoding_node.arg
        if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                                      ExprNodes.BytesNode)):
            encoding = encoding_node.value
            encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
                                                type=PyrexTypes.c_char_ptr_type)
        elif encoding_node.type.is_string:
            encoding = None
        else:
            return None
    else:
        # No encoding given, e.g. ``s[:n].decode()``.  Callers such as
        # bytes.decode() accept a 1-element argument list, so args[1]
        # must not be read unconditionally.
        encoding = None
        encoding_node = ExprNodes.NullNode(pos)

    if len(args) == 3:
        error_handling_node = args[2]
        if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
            error_handling_node = error_handling_node.arg
        if isinstance(error_handling_node,
                      (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                       ExprNodes.BytesNode)):
            error_handling = error_handling_node.value
            if error_handling == 'strict':
                error_handling_node = ExprNodes.NullNode(pos)
            else:
                error_handling_node = ExprNodes.BytesNode(
                    error_handling_node.pos, value=error_handling,
                    type=PyrexTypes.c_char_ptr_type)
        elif error_handling_node.type.is_string:
            error_handling = None
        else:
            return None
    else:
        error_handling = 'strict'
        error_handling_node = ExprNodes.NullNode(pos)

    return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type,
                            attr_name, is_unbound_method, args=(),
                            utility_code=None):
    """Build a direct call to the C function ``name`` for a method call.

    The first argument (the object the method is applied to) is guarded
    by a NoneCheckNode.  For unbound calls like ``list.append(L, x)`` the
    check raises TypeError with a descriptor-style message.  For normal
    bound calls it raises AttributeError, just as Python would for a
    missing attribute on None.
    """
    call_args = list(args)
    if call_args:
        if is_unbound_method:
            exc_type = "PyExc_TypeError"
            message = "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                attr_name, node.function.obj.name)
        else:
            exc_type = "PyExc_AttributeError"
            message = "'NoneType' object has no attribute '%s'" % attr_name
        call_args[0] = ExprNodes.NoneCheckNode(call_args[0], exc_type, message)
    return ExprNodes.PythonCapiCallNode(
        node.pos, name, func_type,
        args=call_args,
        is_temp=node.is_temp,
        utility_code=utility_code)
# C helper behind the dict.get() optimisation.  It returns a new reference
# to d[key], or to default_value if the key is missing, and NULL on error.
# - Py3: PyDict_GetItemWithError(), so lookup errors are propagated.
# - Py2: PyDict_GetItem() is only used for exact str/unicode/int keys (it
#   hides hashing errors).  For other keys the object's own 'get' method
#   is called, omitting the default argument when it is None.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)
# C helper for the optimistic X.append(item) rewrite.  Exact lists use
# PyList_Append() and return a new reference to None.  Any other object
# gets its 'append' attribute called.  Returns NULL on error.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
# C helper for the optimistic X.pop() rewrite.
# Fast path (Python >= 2.4, exact lists only): when the list is more than
# half full, removing the last item cannot trigger a shrinking realloc.
# The size is then just decremented and the item's reference handed to
# the caller.  Otherwise X.pop() is called through its attribute.
# The local variables are declared at the top of the function so the code
# stays valid C89.  Declarations after the fast-path statement are
# rejected by C89 compilers such as MSVC.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)
# C helper for X.pop(ix) with a C integer index.
# Fast path (Python >= 2.4, exact lists that are more than half full):
# a negative ix is adjusted by the list size.  If the adjusted index is
# in range, the item is removed by shifting the following items one slot
# left, and its reference is returned.  Every other case goes through
# X.pop(ix) via the object's 'pop' attribute.  On that path ix may already
# have been adjusted by the fast path.  Returns NULL on error.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            if (ix < 0) {
                ix += size;
            }
            if (likely(0 <= ix && ix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, ix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=ix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
# C helper behind the float() optimisation.  The macro handles exact floats
# inline.  The out-of-line function tries, in order:
#   1. the type's nb_float slot, via PyFloat_AsDouble(),
#   2. parsing exact unicode/bytes objects with PyFloat_FromString(),
#   3. calling the float type itself on the object.
# For step 3 the object is placed in the argument tuple without an INCREF
# and removed again before the tuple is released (a borrowed-reference
# trick).  Errors are signalled by returning (double)-1 with an exception
# set.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
((likely(PyFloat_CheckExact(obj))) ? \\
 PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
# Pulls <string.h> into the generated C code for the strlen() calls
# emitted by the len()/decode() optimisations.
include_string_h_utility_code = UtilityCode(
proto = """
#include <string.h>
"""
)
# C helper for 'exttype.__new__(exttype)': calls the type's tp_new slot
# directly with the module's empty tuple (Naming.empty_tuple) as args and
# no keyword arguments.  The argument is cast to PyTypeObject* unchecked.
# NOTE(review): callers must ensure it really is a type object.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Tree transform that evaluates constant expressions at compile time.

    Every visited expression gets its ``constant_result`` attribute set
    (``ExprNodes.not_a_constant`` when no value can be computed).  Binary
    operations on two literal operands with a non-float result are also
    collapsed into a single literal node holding the computed value.
    """

    def _calculate_const(self, node):
        """Fill in ``node.constant_result``, computing it at most once.

        Children are visited first.  As soon as one of them is found to
        be non-constant, the node keeps the pessimistic
        ``not_a_constant`` marker.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            # already handled on an earlier visit
            return

        not_a_constant = ExprNodes.not_a_constant
        # set the pessimistic default first so every early exit leaves a value
        node.constant_result = not_a_constant

        for child_result in self.visitchildren(node).itervalues():
            if type(child_result) is list:
                child_nodes = child_result
            else:
                child_nodes = [child_result]
            for child in child_nodes:
                if child.constant_result is not_a_constant:
                    return

        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # an ordinary evaluation failure just means "not a constant"
            pass
        except Exception:
            # anything else hints at a real bug, so make it visible
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, from narrowest to widest
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the widest class in NODE_TYPE_ORDER among ``nodes``,
        or None if any node is not of one of those classes."""
        order = self.NODE_TYPE_ORDER
        try:
            positions = [order.index(type(n)) for n in nodes]
            return order[max(positions)]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Compute the operation's value and, for literal operands with a
        non-float result, replace the node by a literal of that value."""
        self._calculate_const(node)
        result = node.constant_result
        if result is ExprNodes.not_a_constant:
            return node
        if isinstance(result, float):
            # The value is still available to the compiler, but float
            # constants are not folded into a new literal, to avoid any
            # loss of precision.
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not (operand1.is_literal and operand2.is_literal):
            # only trees of literal nodes are collapsed; other constant
            # results just stay available on the node
            return node

        try:
            type1, type2 = operand1.type, operand2.type
        except AttributeError:
            return node
        if type1 is None or type2 is None:
            return node

        # reuse an operand node when possible, otherwise build a new one
        if type1 is type2:
            new_node = operand1
        else:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
            if type(operand1) is type(operand2):
                new_node = operand1
                new_node.type = widest_type
            elif type1 is widest_type:
                new_node = operand1
            elif type2 is widest_type:
                new_node = operand2
            else:
                target_class = self._widest_node_class(operand1, operand2)
                if target_class is None:
                    return node
                new_node = target_class(pos=node.pos, type=widest_type)

        new_node.constant_result = result
        new_node.value = str(result)
        return new_node

    # Other node types may get their own constant-replacing handlers in
    # the future.  For now, everything else is simply recursed into.
    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.

        Every first assignment marks its target with
        ``lhs_of_first_assignment``.  If the target is a name of Python
        object type, its entry is switched from None-initialisation to
        initialisation with 0.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Applies to calls of a C function named 'isinstance' with exactly
        two arguments, where the second argument is typed as the builtin
        'type'.  Such a call is redirected to 'PyObject_TypeCheck' from the
        utility scope, and the second argument is cast to a pointer to
        'PyTypeObject'.
        """
        self.visitchildren(node)
        function = node.function
        if function.type.is_cfunction and isinstance(function, ExprNodes.NameNode):
            # A call with the wrong number of arguments is presumably
            # reported as an error during type analysis; don't crash on it
            # here with an IndexError.
            if function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    from CythonScope import utility_scope
                    function.entry = utility_scope.lookup('PyObject_TypeCheck')
                    function.type = function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node