Cython has moved to GitHub.

cython-devel

view Cython/Compiler/Optimize.py @ 2653:2e3dda4a7d23

Optimized list pop.
author Robert Bradshaw <robertwb@math.washington.edu>
date Tue Nov 03 01:01:54 2009 -0800 (2 years ago)
parents 6f1592e517ff
children 87e471ca62e0 b9710187b2c8
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
17 import codecs
19 try:
20 reduce
21 except NameError:
22 from functools import reduce
24 try:
25 set
26 except NameError:
27 from sets import Set as set
class FakePythonEnv(object):
    """A fake environment for creating type test nodes etc.

    The only thing it provides is a ``nogil`` flag that is always False.
    """
    nogil = False
def unwrap_node(node):
    """Peel off all ResultRefNode wrappers and return the wrapped expression.

    A node that is not a ResultRefNode is returned unchanged.
    """
    inner = node
    while True:
        if not isinstance(inner, UtilNodes.ResultRefNode):
            return inner
        inner = inner.expression
def is_common_value(a, b):
    """Tell whether two nodes refer to the same name or C attribute path.

    Both nodes are unwrapped first.  Two NameNodes match when their names
    are equal; two AttributeNodes match when the attribute is not a Python
    attribute, their objects match recursively and the attribute names are
    equal.  Every other combination yields False.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)
    if isinstance(left, ExprNodes.NameNode) and isinstance(right, ExprNodes.NameNode):
        return left.name == right.name
    both_attributes = (isinstance(left, ExprNodes.AttributeNode)
                       and isinstance(right, ExprNodes.AttributeNode))
    if not both_attributes:
        return False
    return (not left.is_py_attr
            and is_common_value(left.obj, right.obj)
            and left.attribute == right.attribute)
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over a C array/pointer slice becomes a pointer-stepping loop
    """
    # C signature used for the generated PyDict_Next() call
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        """Remember the module scope, then descend into the module."""
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Switch to the function's scope while visiting its children."""
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Visit the loop's children first, then try to optimise the loop."""
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Pick a specialised transformation based on the iterated sequence.

        Returns the replacement node, or 'node' itself when no pattern
        matches.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array slice iteration?
        if isinstance(iterator, ExprNodes.SliceIndexNode) and \
               (iterator.base.type.is_array or iterator.base.type.is_ptr):
            return self._transform_carray_iteration(node, iterator)
        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
               function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and \
               isinstance(function, ExprNodes.NameNode) and \
               function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and \
                   isinstance(function, ExprNodes.NameNode) and \
                   function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_carray_iteration(self, node, slice_node):
        """Rewrite a loop over a C array/pointer slice as a ForFromStatNode
        that steps a pointer from the slice start to the slice stop.

        Slices without a stop bound are returned unchanged.
        """
        start = slice_node.start
        stop = slice_node.stop
        step = None
        if not stop:
            return node

        carray_ptr = slice_node.base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=carray_ptr.type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=carray_ptr,
            operator='+',
            operand2=stop,
            type=carray_ptr.type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(carray_ptr.type)
        counter_temp = counter.ref(node.target.pos)

        if slice_node.base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        else:
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=carray_ptr.type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        # assign the current item to the loop target before the user's body
        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1='<=',
            target=counter_temp,
            relation2='<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace 'for i, x in enumerate(seq)' by a loop over 'seq' with a
        separately maintained counter.

        Reports an error for zero or more than one argument.  Loops whose
        target is not a two-element sequence starting with a plain name,
        or whose counter is neither a Python object nor a C integer, are
        returned unchanged.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # first hand the current count to the user's variable, then bump it
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Replace 'for i in range(...)' / 'xrange(...)' by a ForFromStatNode.

        The loop is returned unchanged when the call does not have 1 to 3
        arguments, when the step is not a compile-time integer constant,
        or when the step is zero.
        """
        args = range_function.arg_tuple.args
        if not args or len(args) > 3:
            # invalid argument count: keep the plain call instead of
            # indexing into a missing argument or dropping surplus ones
            return node
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # the loop node gets the absolute step; the direction is
            # expressed through the relations
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace a loop over a dict by a while loop calling PyDict_Next().

        'keys' / 'values' select which of the two outputs of PyDict_Next()
        are requested; the other one is passed as NULL.  When both are
        requested and the loop target is a sequence with a length other
        than 2, the loop is returned unchanged.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (value node to assign, coercion statement or None).
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # dict_temp = dict_obj; pos_temp = 0; while PyDict_Next(...): body
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    def extract_conditions(self, cond):
        """Return (switch variable, [case values]) for a condition, or
        (None, None) when the condition does not fit the switch pattern.
        """
        # strip wrapper nodes around the actual comparison
        stripping = True
        while stripping:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                stripping = False

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # try 'var == value' first, then 'value == var'
            for var, value in ((cond.operand1, cond.operand2),
                               (cond.operand2, cond.operand1)):
                # comparing a node with itself tests whether it is a plain
                # name or C attribute path at all
                if not is_common_value(var, var):
                    continue
                if value.is_literal:
                    return var, [value]
                if hasattr(value, 'entry') and value.entry and value.entry.is_const:
                    return var, [value]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            left_var, left_values = self.extract_conditions(cond.operand1)
            right_var, right_values = self.extract_conditions(cond.operand2)
            if is_common_value(left_var, right_var):
                return left_var, left_values + right_values
        return None, None

    def visit_IfStatNode(self, node):
        """Replace an if statement by a SwitchStatNode when every clause
        tests the same integer variable against integer values and there
        are at least two values in total.
        """
        self.visitchildren(node)
        switch_var = None
        value_count = 0
        switch_cases = []
        for clause in node.if_clauses:
            var, conditions = self.extract_conditions(clause.condition)
            if var is None:
                return node
            if switch_var is not None and not is_common_value(var, switch_var):
                return node
            if not var.type.is_int:
                return node
            for cond in conditions:
                if not cond.type.is_int:
                    return node
            switch_var = var
            value_count += len(conditions)
            switch_cases.append(Nodes.SwitchCaseNode(pos = clause.pos,
                                                     conditions = conditions,
                                                     body = clause.body))
        if value_count < 2:
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(switch_var),
                                    cases = switch_cases,
                                    else_clause = node.else_clause)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Expand 'in' / 'not_in' tests against a literal tuple or list into
        a chain of '==' / '!=' comparisons joined by 'or' / 'and'.

        Cascaded comparisons, other operators and non-literal right hand
        sides are returned unchanged.  An empty literal yields a constant
        BoolNode.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        operator = node.operator
        if operator == 'in':
            joiner, comparison = 'or', '=='
        elif operator == 'not_in':
            joiner, comparison = 'and', '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        candidates = node.operand2.args
        if len(candidates) == 0:
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # evaluate the left hand side once and share it between comparisons
        lhs = UtilNodes.ResultRefNode(node.operand1)

        condition = None
        for candidate in candidates:
            test = ExprNodes.TypecastNode(
                pos = node.pos,
                operand = ExprNodes.PrimaryCmpNode(
                    pos = node.pos,
                    operand1 = lhs,
                    operator = comparison,
                    operand2 = candidate,
                    cascade = None),
                type = PyrexTypes.c_bint_type)
            if condition is None:
                condition = test
            else:
                condition = ExprNodes.BoolBinopNode(
                    pos = node.pos,
                    operator = joiner,
                    operand1 = condition,
                    operand2 = test)

        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """Switch off managed references for parallel assignments whose
        left and right hand sides are a permutation of the same names.

        The node itself is always returned; only 'use_managed_ref' flags
        on its operands are changed.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                # FIXME: CascadedAssignmentNode is not handled either
                return node
            if not self._extract_operand(stat.lhs, left_names,
                                         left_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, right_names,
                                         right_indices, temps):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lname_set = set([ path for path, n in left_names ])
            rname_set = set([ path for path, n in right_names ])
            if lname_set != rname_set:
                return node
            if len(lname_set) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            collected = []
            for index_nodes in (left_indices, right_indices):
                index_ids = []
                for index_node in index_nodes:
                    index_id = self._extract_index_id(index_node)
                    if not index_id:
                        return node
                    index_ids.append(index_id)
                collected.append(index_ids)
            lindices, rindices = collected

            if set(lindices) != set(rindices):
                return node
            if len(set(lindices)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [t.arg for t in temps]
        for temp in temps:
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand.

        Appends CoerceToTempNodes to 'temps', (dotted path, node) pairs to
        'names' and list index nodes to 'indices'.  Returns False when the
        operand is not supported, True otherwise.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # walk down a chain of C attribute accesses to its root
        path = []
        root = node
        while isinstance(root, ExprNodes.AttributeNode):
            if root.is_py_attr:
                return False
            path.append(root.member)
            root = root.obj

        if isinstance(root, ExprNodes.NameNode):
            path.append(root.name)
            path.reverse()
            names.append( ('.'.join(path), node) )
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if node.base.type != Builtin.list_type:
            return False
        if not node.index.type.is_int:
            return False
        if not isinstance(node.base, ExprNodes.NameNode):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return (base name, index name) for an index node whose index is
        a plain name, or None for any other kind of index.
        """
        base = index_node.base
        index = index_node.index
        if isinstance(index, ExprNodes.NameNode):
            return (base.name, index.name)
        if isinstance(index, ExprNodes.ConstNode):
            # FIXME:
            return None
        return None
728 class OptimizeBuiltinCalls(Visitor.EnvTransform):
729 """Optimize some common methods calls and instantiation patterns
730 for builtin types.
731 """
732 # only intercept on call nodes
733 visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_GeneralCallNode(self, node):
    """Try to optimise a call that carries keyword arguments.

    Only calls to Python objects whose positional arguments form a
    TupleNode are passed on to the handler dispatch.
    """
    self.visitchildren(node)
    callee = node.function
    if callee.type.is_pyobject:
        positional = node.positional_args
        if isinstance(positional, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, callee, positional, node.keyword_args)
    return node
def visit_SimpleCallNode(self, node):
    """Try to optimise a call without keyword arguments.

    Only calls to Python objects whose arguments form a TupleNode are
    passed on to the handler dispatch.
    """
    self.visitchildren(node)
    if node.function.type.is_pyobject:
        call_args = node.arg_tuple
        if isinstance(call_args, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, node.function, call_args)
    return node
757 ### cleanup to avoid redundant coercions to/from Python types
def _visit_PyTypeTestNode(self, node):
    # disabled - appears to break assignments in some cases, and
    # also drops a None check, which might still be required
    """Flatten redundant type checks after tree changes.

    Returns the checked argument when visiting replaced it by a node
    that already has the tested type, and the test node otherwise.
    """
    arg_before = node.arg
    self.visitchildren(node)
    arg_after = node.arg
    was_replaced = arg_before is not arg_after
    if was_replaced and not (arg_after.type != node.type):
        return arg_after
    return node
def visit_CoerceFromPyTypeNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    Also, optimise away calls to Python's builtin int() and
    float() if the result is going to be coerced back into a C
    type anyway.
    """
    self.visitchildren(node)
    arg = node.arg
    if not arg.type.is_pyobject:
        # no Python conversion left at all, just do a C coercion instead
        if node.type == arg.type:
            return arg
        return arg.coerce_to(node.type, self.env_stack[-1])

    # from here on, only single-argument calls to a builtin type with a
    # C int/float result type are of interest
    if not isinstance(arg, ExprNodes.SimpleCallNode):
        return node
    if not (node.type.is_int or node.type.is_float):
        return node
    function = arg.function
    if not isinstance(function, ExprNodes.NameNode) \
           or not function.type.is_builtin_type \
           or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
        return node
    call_args = arg.arg_tuple.args
    if len(call_args) != 1:
        return node

    func_arg = call_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        func_arg = func_arg.arg
    elif func_arg.type.is_pyobject:
        # play safe: Python conversion might work on all sorts of things
        return node

    builtin_name = function.name
    if builtin_name == 'int':
        applicable = func_arg.type.is_int or node.type.is_int
    elif builtin_name == 'float':
        applicable = func_arg.type.is_float or node.type.is_float
    else:
        applicable = False

    if applicable:
        if func_arg.type == node.type:
            return func_arg
        if node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
            return ExprNodes.CastNode(func_arg, node.type)
    return node
817 ### dispatch to specific optimisers
def _find_handler(self, match_name, has_kwargs):
    """Look up the handler method for 'match_name'.

    Tries '_handle_general_<name>' (with keyword arguments) or
    '_handle_simple_<name>' (without) first and falls back to
    '_handle_any_<name>'.  Returns None when neither exists.
    """
    if has_kwargs:
        call_type = 'general'
    else:
        call_type = 'simple'
    specific = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
    if specific is not None:
        return specific
    return getattr(self, '_handle_any_%s' % match_name, None)
def _dispatch_to_handler(self, node, function, arg_tuple, kwargs=None):
    """Find and call the optimisation handler for a call node.

    Plain names are looked up as 'function_<name>'.  Attribute calls are
    looked up as 'method_<type>_<attr>' and, failing that, as
    'slot<attr>' for known slot methods and '__new__'.  Returns the
    handler's result, or 'node' unchanged when no handler exists or the
    function is neither a name nor an attribute.
    """
    if function.is_name:
        function_handler = self._find_handler(
            "function_%s" % function.name, kwargs)
        if function_handler is None:
            return node
        if kwargs:
            return function_handler(node, arg_tuple, kwargs)
        else:
            return function_handler(node, arg_tuple)
    elif function.is_attribute:
        attr_name = function.attribute
        arg_list = arg_tuple.args
        self_arg = function.obj
        obj_type = self_arg.type
        is_unbound_method = False
        if obj_type.is_builtin_type:
            # only a plain name has a '.name' to use as the type name
            if obj_type is Builtin.type_type and self_arg.is_name and \
                   arg_list and arg_list[0].type.is_pyobject:
                # calling an unbound method like 'list.append(L,x)'
                # (ignoring 'type.mro()' here ...)
                type_name = self_arg.name
                self_arg = None
                is_unbound_method = True
            else:
                type_name = obj_type.name
        else:
            type_name = "object" # safety measure
        method_handler = self._find_handler(
            "method_%s_%s" % (type_name, attr_name), kwargs)
        if method_handler is None:
            if attr_name in TypeSlots.method_name_to_slot \
                   or attr_name == '__new__':
                method_handler = self._find_handler(
                    "slot%s" % attr_name, kwargs)
            if method_handler is None:
                return node
        if self_arg is not None:
            # bound call: the object becomes the first argument
            arg_list = [self_arg] + list(arg_list)
        if kwargs:
            return method_handler(node, arg_list, kwargs, is_unbound_method)
        else:
            return method_handler(node, arg_list, is_unbound_method)
    else:
        return node
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call with the wrong number of arguments at node.pos.

    'expected' may be None, a number or a descriptive string; it only
    affects the wording of the message.
    """
    arg_str = ''
    if expected: # neither None nor 0
        if isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'

    expected_str = ''
    if expected is not None:
        expected_str = 'expected %s, ' % expected

    message = "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args))
    error(node.pos, message)
889 ### builtin types
def _handle_general_function_dict(self, node, pos_args, kwargs):
    """Replace dict(a=b,c=d,...) by the underlying keyword dict
    construction which is done anyway.
    """
    has_positional = len(pos_args.args) > 0
    if has_positional or not isinstance(kwargs, ExprNodes.DictNode):
        return node
    if node.starstar_arg:
        # we could optimize this by updating the kw dict instead
        return node
    return kwargs
# C signature used for the generated PyDict_Copy() call
PyDict_Copy_func_type = PyrexTypes.CFuncType(
    Builtin.dict_type, [
        PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
        ])

def _handle_simple_function_dict(self, node, pos_args):
    """Replace dict(some_dict) by PyDict_Copy(some_dict) and
    dict([ (a,b) for ... ]) by a literal { a:b for ... }.

    Calls with any other argument, or with an argument count other
    than one, are returned unchanged.
    """
    if len(pos_args.args) != 1:
        return node
    arg = pos_args.args[0]
    if arg.type is Builtin.dict_type:
        # same wording as the None check in the tuple() handler
        arg = ExprNodes.NoneCheckNode(
            arg, "PyExc_TypeError", "'NoneType' object is not iterable")
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
            args = [arg],
            is_temp = node.is_temp
            )
    elif isinstance(arg, ExprNodes.ComprehensionNode) and \
             arg.type is Builtin.list_type:
        append_node = arg.append
        if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
               len(append_node.expr.args) == 2:
            # turn the list comprehension of pairs into a dict comprehension
            key_node, value_node = append_node.expr.args
            target_node = ExprNodes.DictNode(
                pos=arg.target.pos, key_value_pairs=[], is_temp=1)
            new_append_node = ExprNodes.DictComprehensionAppendNode(
                append_node.pos, target=target_node,
                key_expr=key_node, value_expr=value_node,
                is_temp=1)
            arg.target = target_node
            arg.type = target_node.type
            replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
            return replace_in(arg)
    return node
def _handle_simple_function_set(self, node, pos_args):
    """Replace set([a,b,...]) by a literal set {a,b,...} and
    set([ x for ... ]) by a literal { x for ... }.
    """
    call_args = pos_args.args
    arg_count = len(call_args)
    if arg_count == 0:
        return ExprNodes.SetNode(node.pos, args=[],
                                 type=Builtin.set_type, is_temp=1)
    if arg_count > 1:
        return node

    iterable = call_args[0]
    if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
        return ExprNodes.SetNode(node.pos, args=iterable.args,
                                 type=Builtin.set_type, is_temp=1)

    is_list_comprehension = (
        isinstance(iterable, ExprNodes.ComprehensionNode)
        and iterable.type is Builtin.list_type)
    if not is_list_comprehension:
        return node

    # retarget the comprehension at a fresh, empty set
    iterable.target = ExprNodes.SetNode(
        node.pos, args=[], type=Builtin.set_type, is_temp=1)
    iterable.type = Builtin.set_type
    iterable.pos = node.pos
    return iterable
# C signature used for the generated PyList_AsTuple() call
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
    Builtin.tuple_type, [
        PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
        ])

def _handle_simple_function_tuple(self, node, pos_args):
    """Replace tuple([...]) by a call to PyList_AsTuple.
    """
    call_args = pos_args.args
    if len(call_args) != 1:
        return node
    list_arg = call_args[0]
    if list_arg.type is not Builtin.list_type:
        return node

    literal_kinds = (ExprNodes.ComprehensionNode, ExprNodes.ListNode)
    if not isinstance(list_arg, literal_kinds):
        # anything but a list literal/comprehension gets a None check
        pos_args.args[0] = ExprNodes.NoneCheckNode(
            list_arg, "PyExc_TypeError",
            "'NoneType' object is not iterable")

    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
        args = pos_args.args,
        is_temp = node.is_temp
        )
991 ### builtin functions
# C signature used for the two-argument getattr() replacement
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
        ])

# C signature used for the three-argument getattr() replacement
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])

def _handle_simple_function_getattr(self, node, pos_args):
    """Replace getattr(obj, name) by PyObject_GetAttr() and
    getattr(obj, name, default) by __Pyx_GetAttr3().

    Any other argument count is reported as an error and the call is
    returned unchanged.
    """
    args = pos_args.args
    if len(args) == 2:
        node = ExprNodes.PythonCapiCallNode(
            node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
            args = args,
            is_temp = node.is_temp
            )
    elif len(args) == 3:
        node = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
            utility_code = Builtin.getattr3_utility_code,
            args = args,
            is_temp = node.is_temp
            )
    else:
        self._error_wrong_arg_count('getattr', node, args, '2 or 3')
    return node
# Signature of the __Pyx_Type() helper: any object in, its type object out.
Pyx_Type_func_type = PyrexTypes.CFuncType(
    Builtin.type_type,
    [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)])
def _handle_simple_function_type(self, node, pos_args):
    """Replace the one-argument ``type(x)`` call by __Pyx_Type(x).

    Calls with any other number of arguments are left unchanged.
    """
    args = pos_args.args
    if len(args) == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Type", self.Pyx_Type_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = pytype_utility_code)
    return node
1042 ### special methods
# Signature of the __Pyx_tp_new() helper: a type object in, new object out.
Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)])
def _handle_simple_slot__new__(self, node, args, is_unbound_method):
    """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new().

    The replacement is only done for an unbound call with exactly one
    argument where both the receiver and the argument are plain names
    of type ``type`` that refer to the same type.  Everything else is
    returned unchanged.
    """
    receiver = node.function.obj
    if not (is_unbound_method and len(args) == 1):
        return node
    new_type = args[0]
    if not (receiver.is_name and new_type.is_name):
        # anything more complex than a name: play safe
        return node
    if receiver.type != Builtin.type_type or new_type.type != Builtin.type_type:
        # not known to be a type: play safe
        return node

    if new_type.type_entry and receiver.type_entry:
        if new_type.type_entry != receiver.type_entry:
            # different types - may or may not lead to an error at runtime
            return node
    elif receiver.name != new_type.name:
        return node
    # Here, both sides are types and refer to the same one, either by
    # type entry or by name.

    # FIXME: we could potentially look up the actual tp_new C method
    # of the extension type and call that instead of the generic slot

    if not new_type.type_entry:
        # arbitrary variable, needs a None check for safety
        new_type = ExprNodes.NoneCheckNode(
            new_type, "PyExc_TypeError",
            "object.__new__(X): X is not a type object (NoneType)")

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
        args = [new_type],
        utility_code = tpnew_utility_code,
        is_temp = node.is_temp)
1087 ### methods of builtin types
# Signature of the __Pyx_PyObject_Append() helper: (container, item) -> object.
PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type,
    [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None)])
def _handle_simple_method_object_append(self, node, args, is_unbound_method):
    """Route ``X.append(item)`` on an untyped object through the
    __Pyx_PyObject_Append() helper.

    X.append() is almost always referring to a list.  Calls that do not
    have exactly two arguments (object and item) are left unchanged.
    """
    if len(args) == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = append_utility_code)
    return node
# Signature of the __Pyx_PyObject_Pop() helper: container -> popped object.
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        ])

# Signature of the __Pyx_PyObject_PopIndex() helper.  The index is declared
# as Py_ssize_t to match the C prototype of the helper and the widest index
# type that _handle_simple_method_object_pop() lets through; declaring it as
# C long would narrow the value on platforms where long is smaller than
# Py_ssize_t.
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
        ])
def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
    """Route ``X.pop()`` / ``X.pop(n)`` on an untyped object through the
    list-optimised C helpers.

    X.pop([n]) is almost always referring to a list.  The indexed form is
    only replaced when the index is a C integer (wrapped in a coercion to
    a Python object) whose type is no wider than Py_ssize_t.
    """
    arg_count = len(args)
    if arg_count == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = pop_utility_code)

    if arg_count == 2:
        index = args[1]
        if isinstance(index, ExprNodes.CoerceToPyTypeNode) and index.arg.type.is_int:
            c_index_type = index.arg.type
            ssize_t_type = PyrexTypes.c_py_ssize_t_type
            if PyrexTypes.widest_numeric_type(c_index_type, ssize_t_type) == ssize_t_type:
                # pass the raw C integer instead of its Python wrapper
                args[1] = index.arg
                return ExprNodes.PythonCapiCallNode(
                    node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
                    args = args,
                    is_temp = node.is_temp,
                    utility_code = pop_index_utility_code)

    return node
# Signature of PyList_Append(): (list, item) -> C int, -1 signals an error.
PyList_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type,
    [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
     PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None)],
    exception_value = "-1")
def _handle_simple_method_list_append(self, node, args, is_unbound_method):
    """Replace ``list.append(item)`` on a typed list by PyList_Append().

    A wrong argument count is reported and the node is returned unchanged.
    """
    if len(args) == 2:
        return self._substitute_method_call(
            node, "PyList_Append", self.PyList_Append_func_type,
            'append', is_unbound_method, args)
    self._error_wrong_arg_count('list.append', node, args, 2)
    return node
# Shared signature for C functions that take a single object and return a
# C int, with -1 signalling an error (used for PyList_Sort/PyList_Reverse).
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type,
    [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
    exception_value = "-1")
def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
    """Replace the argument-less ``list.sort()`` by PyList_Sort().

    Calls that pass anything besides the list itself are left unchanged.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args)
    return node
def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
    """Replace ``list.reverse()`` by PyList_Reverse().

    A wrong argument count is reported and the node is returned unchanged.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Reverse", self.single_param_func_type,
            'reverse', is_unbound_method, args)
    self._error_wrong_arg_count('list.reverse', node, args, 1)
    return node
# Signature of PyUnicode_AsEncodedString(): (unicode, encoding, errors)
# -> bytes, NULL signals an error.
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type,
    [PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
     PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
     PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None)],
    exception_value = "NULL")

# Shared signature of the codec specific PyUnicode_As<Codec>String()
# functions: unicode -> bytes, NULL signals an error.
PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type,
    [PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None)],
    exception_value = "NULL")

# Encodings that are checked for a dedicated C-API function.  The names
# are used to build the C function name, see _find_special_codec_name().
_special_encodings = [
    'UTF8', 'UTF16', 'Latin1', 'ASCII',
    'unicode_escape', 'raw_unicode_escape',
    ]

# (name, encoder function) pairs used to recognise aliases of the
# special encodings by comparing their encoder functions.
_special_codecs = [ (codec_name, codecs.getencoder(codec_name))
                    for codec_name in _special_encodings ]
def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
    """Replace ``unicode.encode(...)`` by a direct C-API call.

    Depending on the arguments, this
    - encodes a unicode literal at compile time and returns a BytesNode,
    - calls a codec specific PyUnicode_As<Codec>String() function when
      the error mode is 'strict' and the encoding is a special one, or
    - calls the generic PyUnicode_AsEncodedString().

    The node is returned unchanged when the argument count is wrong
    (which is also reported) or when encoding/error mode are not
    string literals.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
        return node

    string_node = args[0]

    if len(args) == 1:
        # no encoding given: pass NULL for both encoding and errors
        null_node = ExprNodes.NullNode(node.pos)
        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method, [string_node, null_node, null_node])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    if isinstance(string_node, ExprNodes.UnicodeNode):
        # constant, so try to do the encoding at compile time
        try:
            value = string_node.value.encode(encoding, error_handling)
        except Exception:
            # Well, looks like we can't - leave it to the runtime.
            # Only ordinary errors are ignored here, so that
            # KeyboardInterrupt/SystemExit still stop the compiler.
            pass
        else:
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)

    if error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])

    return self._substitute_method_call(
        node, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
# Shared signature of codec specific PyUnicode_Decode<Codec>() functions:
# (char* string, Py_ssize_t size, char* errors) -> unicode, NULL on error.
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
     PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
     PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None)],
    exception_value = "NULL")

# Signature of the generic PyUnicode_Decode():
# (char* string, Py_ssize_t size, char* encoding, char* errors)
# -> unicode, NULL on error.
PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
     PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
     PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
     PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None)],
    exception_value = "NULL")
def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
    """Replace ``<char*>[:stop].decode(encoding[, errors])`` by a C-API call.

    The optimisation applies only when the decoded object is a slice of
    a C string with a stop index and no (or a constant zero) start
    index, and when encoding and error mode are string literals.  A
    codec specific PyUnicode_Decode<Codec>() function is used where one
    with a matching signature exists, the generic PyUnicode_Decode()
    otherwise.  In all other cases, the node is returned unchanged; a
    wrong argument count is also reported.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
        return node
    if len(args) < 2:
        # No explicit encoding: _unpack_encoding_and_error_mode() reads
        # args[1] unconditionally and would fail with an IndexError, so
        # leave this call to the generic method call.
        return node
    if not isinstance(args[0], ExprNodes.SliceIndexNode):
        # we need the string length as a slice end index
        return node
    index_node = args[0]
    string_node = index_node.base
    if not string_node.type.is_string:
        # nothing to optimise here
        return node
    start, stop = index_node.start, index_node.stop
    if not stop:
        # FIXME: could use strlen() - although Python will do that anyway ...
        return node
    if stop.type.is_pyobject:
        stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
    if start and start.constant_result != 0:
        # FIXME: put start into a temp and do the math
        return node

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # Try to find a specific decoder function.  PyUnicode_DecodeUTF16()
    # is excluded: it takes an additional 'int *byteorder' argument that
    # PyUnicode_DecodeXyz_func_type does not describe, so UTF-16 goes
    # through the generic PyUnicode_Decode() below.
    codec_name = self._find_special_codec_name(encoding)
    if codec_name is not None and codec_name != 'UTF16':
        decode_function = "PyUnicode_Decode%s" % codec_name
        return ExprNodes.PythonCapiCallNode(
            node.pos, decode_function,
            self.PyUnicode_DecodeXyz_func_type,
            args = [string_node, stop, error_handling_node],
            is_temp = node.is_temp,
            )

    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyUnicode_Decode",
        self.PyUnicode_Decode_func_type,
        args = [string_node, stop, encoding_node, error_handling_node],
        is_temp = node.is_temp,
        )
def _find_special_codec_name(self, encoding):
    """Return the C-API name part of a special codec, or None.

    The encoder function registered for ``encoding`` is compared with
    the encoders in ``self._special_codecs``, so that aliases of the
    same codec are recognised.  For a match, the name of the special
    codec is returned, with underscore separated names converted to
    CamelCase (e.g. 'unicode_escape' -> 'UnicodeEscape').  None is
    returned for unknown/invalid encodings and for encodings that have
    no special codec.
    """
    try:
        requested_codec = codecs.getencoder(encoding)
    except Exception:
        # Unknown or invalid encoding name => nothing to special-case.
        # Catching Exception instead of everything keeps
        # KeyboardInterrupt/SystemExit from being swallowed here.
        return None
    for name, codec in self._special_codecs:
        if codec == requested_codec:
            if '_' in name:
                name = ''.join([ s.capitalize()
                                 for s in name.split('_')])
            return name
    return None
def _unpack_encoding_and_error_mode(self, pos, args):
    """Extract the literal encoding and error mode from call arguments.

    ``args[1]`` is the encoding, the optional ``args[2]`` the error
    mode.  Both must be string literal nodes, possibly wrapped in a
    coercion to a Python object; otherwise None is returned.

    Returns a tuple ``(encoding, encoding_node, error_handling,
    error_handling_node)`` where the nodes are C ``char*`` nodes.  A
    missing or 'strict' error mode is represented by a NULL node.
    """
    literal_node_types = (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                          ExprNodes.BytesNode)

    def strip_coercion(arg_node):
        # look through a C-to-Python coercion to the node it wraps
        if isinstance(arg_node, ExprNodes.CoerceToPyTypeNode):
            return arg_node.arg
        return arg_node

    encoding_arg = strip_coercion(args[1])
    if not isinstance(encoding_arg, literal_node_types):
        return None
    encoding = encoding_arg.value
    encoding_node = ExprNodes.BytesNode(
        encoding_arg.pos, value=encoding, type=PyrexTypes.c_char_ptr_type)

    null_node = ExprNodes.NullNode(pos)
    if len(args) != 3:
        # no error mode given => default mode, passed to C as NULL
        return (encoding, encoding_node, 'strict', null_node)

    errors_arg = strip_coercion(args[2])
    if not isinstance(errors_arg, literal_node_types):
        return None
    error_handling = errors_arg.value
    if error_handling == 'strict':
        error_handling_node = null_node
    else:
        error_handling_node = ExprNodes.BytesNode(
            errors_arg.pos, value=error_handling,
            type=PyrexTypes.c_char_ptr_type)

    return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type,
                            attr_name, is_unbound_method, args=()):
    """Build a C-API call node that replaces a method call.

    The first argument (the object the method is called on), if there
    is one, is wrapped in a None check.  The exception type and message
    of that check depend on whether the method was called unbound
    (through the type) or bound (through the object).
    """
    call_args = list(args)
    if call_args:
        if is_unbound_method:
            exception_name = "PyExc_TypeError"
            message = "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                attr_name, node.function.obj.name)
        else:
            exception_name = "PyExc_AttributeError"
            message = "'NoneType' object has no attribute '%s'" % attr_name
        call_args[0] = ExprNodes.NoneCheckNode(
            call_args[0], exception_name, message)
    return ExprNodes.PythonCapiCallNode(
        node.pos, name, func_type,
        args = call_args,
        is_temp = node.is_temp)
# C helper behind the generic "obj.append(x)" optimisation: exact lists go
# straight to PyList_Append(), everything else gets a regular call of the
# object's "append" method.
append_utility_code = UtilityCode(
    proto = """
static INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
    impl = "")
# C helper behind the generic "obj.pop()" optimisation: an exact list that
# is more than half full hands out its last item directly, everything else
# gets a regular call of the object's "pop" method.
pop_utility_code = UtilityCode(
    proto = """
static INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "pop");
        if (!m) return NULL;
        r = PyObject_CallObject(m, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
    impl = "")
# C helper behind the generic "obj.pop(ix)" optimisation for C integer
# indices.  An exact list that is more than half full is handled inline by
# shifting the tail down by one item; everything else (including indices
# that are out of range) gets a regular call of the object's "pop" method.
#
# The index is normalised in a separate variable ("cix").  The fallback
# must receive the caller's original index: passing an already wrapped
# negative index on to list.pop() would wrap it around a second time and
# could pop an item instead of raising IndexError.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
# C helper behind the one-argument "type(x)" optimisation: returns a new
# reference to the type of the object.
pytype_utility_code = UtilityCode(
    proto = """
static INLINE PyObject* __Pyx_Type(PyObject* o) {
    PyObject* type = (PyObject*) Py_TYPE(o);
    Py_INCREF(type);
    return type;
}
""")
# C helper behind the "exttype.__new__(exttype)" optimisation: calls the
# tp_new slot of the type with the shared empty tuple as positional
# arguments and no keyword arguments.
tpnew_utility_code = UtilityCode(
    proto = """
static INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple})
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Compute the value of constant expressions and store it in
    ``expr_node.constant_result``.  Trivial cases (binary operations on
    literals) are replaced by a literal node that carries the result.
    """

    def _calculate_const(self, node):
        """Set ``node.constant_result`` unless it has been set before.

        The children are visited first.  The result stays "not a
        constant" if any child is not constant or if the calculation
        fails.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # mark the node up-front so that the value is always set
        unknown = ExprNodes.not_a_constant
        node.constant_result = unknown

        # give up as soon as any child turns out not to be constant
        visited = self.visitchildren(node)
        for child_result in visited.itervalues():
            if type(child_result) is list:
                members = child_result
            else:
                members = [child_result]
            for member in members:
                if member.constant_result is unknown:
                    return

        # all children are constant, so try to calculate the real value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # 'normal' errors just mean: no constant result
            pass
        except Exception:
            # this looks like a real error, so show it
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # Literal node classes, ordered from narrowest to widest.
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the widest class (by NODE_TYPE_ORDER) among the exact
        classes of ``nodes``, or None if a class is not in the order
        (or no nodes are passed).
        """
        order = self.NODE_TYPE_ORDER
        try:
            ranks = [ order.index(type(candidate)) for candidate in nodes ]
            return order[max(ranks)]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        """Calculate the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Fold a binary operation on two literals into a single literal.

        The node is kept when it is not constant, when the result is a
        float, when an operand is not a literal, or when no suitable
        literal node can be determined for the result.
        """
        self._calculate_const(node)
        result = node.constant_result
        if result is ExprNodes.not_a_constant:
            return node
        if isinstance(result, float):
            # Float constants are calculated to make them available to
            # the compiler, but they are not aggregated into a constant
            # node, to prevent any loss of precision.
            return node
        left, right = node.operand1, node.operand2
        if not left.is_literal or not right.is_literal:
            # Other constants are available to the compiler as well,
            # but only literal nodes are aggregated recursively.
            return node

        # now pick (or create) the literal node that takes the value
        try:
            left_type, right_type = left.type, right.type
            if left_type is None or right_type is None:
                return node
        except AttributeError:
            return node

        if left_type is right_type:
            folded = left
        else:
            widest_type = PyrexTypes.widest_numeric_type(left_type, right_type)
            if type(left) is type(right):
                folded = left
                folded.type = widest_type
            elif left_type is widest_type:
                folded = left
            elif right_type is widest_type:
                folded = right
            else:
                target_class = self._widest_node_class(left, right)
                if target_class is None:
                    return node
                folded = target_class(pos=node.pos, type = widest_type)

        folded.constant_result = node.constant_result
        folded.value = str(node.constant_result)
        return folded

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1590 class FinalOptimizePhase(Visitor.CythonTransform):
1591 """
1592 This visitor handles several commuting optimizations, and is run
1593 just before the C code generation phase.
1595 The optimizations currently implemented in this class are:
1596 - Eliminate None assignment and refcounting for first assignment.
1597 - isinstance -> typecheck for cdef types
1598 """
def visit_SingleAssignmentNode(self, node):
    """Avoid the redundant initialisation of local variables before
    their first assignment.

    The target of a first assignment is flagged as such.  If it is a
    name that holds a Python object, its entry is additionally set up
    to be initialised to 0 rather than None.
    """
    self.visitchildren(node)
    if not node.first:
        return node
    target = node.lhs
    target.lhs_of_first_assignment = True
    if isinstance(target, ExprNodes.NameNode) and target.entry.type.is_pyobject:
        # Have variable initialized to 0 rather than None
        entry = target.entry
        entry.init_to_none = False
        entry.init = 0
    return node
def visit_SimpleCallNode(self, node):
    """Replace generic calls to isinstance(x, type) by a more efficient
    type check.

    Only calls of the C function named 'isinstance' whose second
    argument has the builtin type 'type' are redirected to
    PyObject_TypeCheck(); the second argument is then cast to a
    PyTypeObject pointer.
    """
    self.visitchildren(node)
    function = node.function
    if not (function.type.is_cfunction and isinstance(function, ExprNodes.NameNode)):
        return node
    if function.name == 'isinstance':
        type_arg = node.args[1]
        if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
            from CythonScope import utility_scope
            function.entry = utility_scope.lookup('PyObject_TypeCheck')
            function.type = function.entry.type
            PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
            node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
    return node