Cython has moved to github.

cython-devel

view Cython/Compiler/Optimize.py @ 2698:d2b963bc03f2

support len(char*) efficiently by calling strlen() instead
author Stefan Behnel <scoder@users.berlios.de>
date Thu Nov 26 20:49:09 2009 +0100 (2 years ago)
parents aca331774536
children 94aabaec040c
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
17 import codecs
19 try:
20 reduce
21 except NameError:
22 from functools import reduce
24 try:
25 set
26 except NameError:
27 from sets import Set as set
29 class FakePythonEnv(object):
30 "A fake environment for creating type test nodes etc."
31 nogil = False
33 def unwrap_node(node):
34 while isinstance(node, UtilNodes.ResultRefNode):
35 node = node.expression
36 return node
38 def is_common_value(a, b):
39 a = unwrap_node(a)
40 b = unwrap_node(b)
41 if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
42 return a.name == b.name
43 if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
44 return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
45 return False
47 class IterationTransform(Visitor.VisitorTransform):
48 """Transform some common for-in loop patterns into efficient C loops:
50 - for-in-dict loop becomes a while loop calling PyDict_Next()
51 - for-in-enumerate is replaced by an external counter variable
52 - for-in-range loop becomes a plain C for loop
53 """
54 PyDict_Next_func_type = PyrexTypes.CFuncType(
55 PyrexTypes.c_bint_type, [
56 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
57 PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
58 PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
59 PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
60 ])
62 PyDict_Next_name = EncodedString("PyDict_Next")
64 PyDict_Next_entry = Symtab.Entry(
65 PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
67 visit_Node = Visitor.VisitorTransform.recurse_to_children
69 def visit_ModuleNode(self, node):
70 self.current_scope = node.scope
71 self.visitchildren(node)
72 return node
74 def visit_DefNode(self, node):
75 oldscope = self.current_scope
76 self.current_scope = node.entry.scope
77 self.visitchildren(node)
78 self.current_scope = oldscope
79 return node
81 def visit_ForInStatNode(self, node):
82 self.visitchildren(node)
83 return self._optimise_for_loop(node)
85 def _optimise_for_loop(self, node):
86 iterator = node.iterator.sequence
87 if iterator.type is Builtin.dict_type:
88 # like iterating over dict.keys()
89 return self._transform_dict_iteration(
90 node, dict_obj=iterator, keys=True, values=False)
92 # C array slice iteration?
93 if isinstance(iterator, ExprNodes.SliceIndexNode) and \
94 (iterator.base.type.is_array or iterator.base.type.is_ptr):
95 return self._transform_carray_iteration(node, iterator)
96 elif not isinstance(iterator, ExprNodes.SimpleCallNode):
97 return node
99 function = iterator.function
100 # dict iteration?
101 if isinstance(function, ExprNodes.AttributeNode) and \
102 function.obj.type == Builtin.dict_type:
103 dict_obj = function.obj
104 method = function.attribute
106 keys = values = False
107 if method == 'iterkeys':
108 keys = True
109 elif method == 'itervalues':
110 values = True
111 elif method == 'iteritems':
112 keys = values = True
113 else:
114 return node
115 return self._transform_dict_iteration(
116 node, dict_obj, keys, values)
118 # enumerate() ?
119 if iterator.self is None and \
120 isinstance(function, ExprNodes.NameNode) and \
121 function.entry.is_builtin and \
122 function.name == 'enumerate':
123 return self._transform_enumerate_iteration(node, iterator)
125 # range() iteration?
126 if Options.convert_range and node.target.type.is_int:
127 if iterator.self is None and \
128 isinstance(function, ExprNodes.NameNode) and \
129 function.entry.is_builtin and \
130 function.name in ('range', 'xrange'):
131 return self._transform_range_iteration(node, iterator)
133 return node
135 def _transform_carray_iteration(self, node, slice_node):
136 start = slice_node.start
137 stop = slice_node.stop
138 step = None
139 if not stop:
140 return node
142 carray_ptr = slice_node.base.coerce_to_simple(self.current_scope)
144 if start and start.constant_result != 0:
145 start_ptr_node = ExprNodes.AddNode(
146 start.pos,
147 operand1=carray_ptr,
148 operator='+',
149 operand2=start,
150 type=carray_ptr.type)
151 else:
152 start_ptr_node = carray_ptr
154 stop_ptr_node = ExprNodes.AddNode(
155 stop.pos,
156 operand1=carray_ptr,
157 operator='+',
158 operand2=stop,
159 type=carray_ptr.type
160 ).coerce_to_simple(self.current_scope)
162 counter = UtilNodes.TempHandle(carray_ptr.type)
163 counter_temp = counter.ref(node.target.pos)
165 if slice_node.base.type.is_string and node.target.type.is_pyobject:
166 # special case: char* -> bytes
167 target_value = ExprNodes.SliceIndexNode(
168 node.target.pos,
169 start=ExprNodes.IntNode(node.target.pos, value='0',
170 constant_result=0,
171 type=PyrexTypes.c_int_type),
172 stop=ExprNodes.IntNode(node.target.pos, value='1',
173 constant_result=1,
174 type=PyrexTypes.c_int_type),
175 base=counter_temp,
176 type=Builtin.bytes_type,
177 is_temp=1)
178 else:
179 target_value = ExprNodes.IndexNode(
180 node.target.pos,
181 index=ExprNodes.IntNode(node.target.pos, value='0',
182 constant_result=0,
183 type=PyrexTypes.c_int_type),
184 base=counter_temp,
185 is_buffer_access=False,
186 type=carray_ptr.type.base_type)
188 if target_value.type != node.target.type:
189 target_value = target_value.coerce_to(node.target.type,
190 self.current_scope)
192 target_assign = Nodes.SingleAssignmentNode(
193 pos = node.target.pos,
194 lhs = node.target,
195 rhs = target_value)
197 body = Nodes.StatListNode(
198 node.pos,
199 stats = [target_assign, node.body])
201 for_node = Nodes.ForFromStatNode(
202 node.pos,
203 bound1=start_ptr_node, relation1='<=',
204 target=counter_temp,
205 relation2='<', bound2=stop_ptr_node,
206 step=step, body=body,
207 else_clause=node.else_clause,
208 from_range=True)
210 return UtilNodes.TempsBlockNode(
211 node.pos, temps=[counter],
212 body=for_node)
214 def _transform_enumerate_iteration(self, node, enumerate_function):
215 args = enumerate_function.arg_tuple.args
216 if len(args) == 0:
217 error(enumerate_function.pos,
218 "enumerate() requires an iterable argument")
219 return node
220 elif len(args) > 1:
221 error(enumerate_function.pos,
222 "enumerate() takes at most 1 argument")
223 return node
225 if not node.target.is_sequence_constructor:
226 # leave this untouched for now
227 return node
228 targets = node.target.args
229 if len(targets) != 2:
230 # leave this untouched for now
231 return node
232 if not isinstance(targets[0], ExprNodes.NameNode):
233 # leave this untouched for now
234 return node
236 enumerate_target, iterable_target = targets
237 counter_type = enumerate_target.type
239 if not counter_type.is_pyobject and not counter_type.is_int:
240 # nothing we can do here, I guess
241 return node
243 temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
244 value='0',
245 type=counter_type,
246 constant_result=0))
247 inc_expression = ExprNodes.AddNode(
248 enumerate_function.pos,
249 operand1 = temp,
250 operand2 = ExprNodes.IntNode(node.pos, value='1',
251 type=counter_type,
252 constant_result=1),
253 operator = '+',
254 type = counter_type,
255 is_temp = counter_type.is_pyobject
256 )
258 loop_body = [
259 Nodes.SingleAssignmentNode(
260 pos = enumerate_target.pos,
261 lhs = enumerate_target,
262 rhs = temp),
263 Nodes.SingleAssignmentNode(
264 pos = enumerate_target.pos,
265 lhs = temp,
266 rhs = inc_expression)
267 ]
269 if isinstance(node.body, Nodes.StatListNode):
270 node.body.stats = loop_body + node.body.stats
271 else:
272 loop_body.append(node.body)
273 node.body = Nodes.StatListNode(
274 node.body.pos,
275 stats = loop_body)
277 node.target = iterable_target
278 node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
279 node.iterator.sequence = enumerate_function.arg_tuple.args[0]
281 # recurse into loop to check for further optimisations
282 return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
284 def _transform_range_iteration(self, node, range_function):
285 args = range_function.arg_tuple.args
286 if len(args) < 3:
287 step_pos = range_function.pos
288 step_value = 1
289 step = ExprNodes.IntNode(step_pos, value='1',
290 constant_result=1)
291 else:
292 step = args[2]
293 step_pos = step.pos
294 if not isinstance(step.constant_result, (int, long)):
295 # cannot determine step direction
296 return node
297 step_value = step.constant_result
298 if step_value == 0:
299 # will lead to an error elsewhere
300 return node
301 if not isinstance(step, ExprNodes.IntNode):
302 step = ExprNodes.IntNode(step_pos, value=str(step_value),
303 constant_result=step_value)
305 if step_value < 0:
306 step.value = str(-step_value)
307 relation1 = '>='
308 relation2 = '>'
309 else:
310 relation1 = '<='
311 relation2 = '<'
313 if len(args) == 1:
314 bound1 = ExprNodes.IntNode(range_function.pos, value='0',
315 constant_result=0)
316 bound2 = args[0].coerce_to_integer(self.current_scope)
317 else:
318 bound1 = args[0].coerce_to_integer(self.current_scope)
319 bound2 = args[1].coerce_to_integer(self.current_scope)
320 step = step.coerce_to_integer(self.current_scope)
322 if not bound2.is_literal:
323 # stop bound must be immutable => keep it in a temp var
324 bound2_is_temp = True
325 bound2 = UtilNodes.LetRefNode(bound2)
326 else:
327 bound2_is_temp = False
329 for_node = Nodes.ForFromStatNode(
330 node.pos,
331 target=node.target,
332 bound1=bound1, relation1=relation1,
333 relation2=relation2, bound2=bound2,
334 step=step, body=node.body,
335 else_clause=node.else_clause,
336 from_range=True)
338 if bound2_is_temp:
339 for_node = UtilNodes.LetNode(bound2, for_node)
341 return for_node
343 def _transform_dict_iteration(self, node, dict_obj, keys, values):
344 py_object_ptr = PyrexTypes.c_void_ptr_type
346 temps = []
347 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
348 temps.append(temp)
349 dict_temp = temp.ref(dict_obj.pos)
350 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
351 temps.append(temp)
352 pos_temp = temp.ref(node.pos)
353 pos_temp_addr = ExprNodes.AmpersandNode(
354 node.pos, operand=pos_temp,
355 type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
356 if keys:
357 temp = UtilNodes.TempHandle(py_object_ptr)
358 temps.append(temp)
359 key_temp = temp.ref(node.target.pos)
360 key_temp_addr = ExprNodes.AmpersandNode(
361 node.target.pos, operand=key_temp,
362 type=PyrexTypes.c_ptr_type(py_object_ptr))
363 else:
364 key_temp_addr = key_temp = ExprNodes.NullNode(
365 pos=node.target.pos)
366 if values:
367 temp = UtilNodes.TempHandle(py_object_ptr)
368 temps.append(temp)
369 value_temp = temp.ref(node.target.pos)
370 value_temp_addr = ExprNodes.AmpersandNode(
371 node.target.pos, operand=value_temp,
372 type=PyrexTypes.c_ptr_type(py_object_ptr))
373 else:
374 value_temp_addr = value_temp = ExprNodes.NullNode(
375 pos=node.target.pos)
377 key_target = value_target = node.target
378 tuple_target = None
379 if keys and values:
380 if node.target.is_sequence_constructor:
381 if len(node.target.args) == 2:
382 key_target, value_target = node.target.args
383 else:
384 # unusual case that may or may not lead to an error
385 return node
386 else:
387 tuple_target = node.target
389 def coerce_object_to(obj_node, dest_type):
390 if dest_type.is_pyobject:
391 if dest_type != obj_node.type:
392 if dest_type.is_extension_type or dest_type.is_builtin_type:
393 obj_node = ExprNodes.PyTypeTestNode(
394 obj_node, dest_type, self.current_scope, notnone=True)
395 result = ExprNodes.TypecastNode(
396 obj_node.pos,
397 operand = obj_node,
398 type = dest_type)
399 return (result, None)
400 else:
401 temp = UtilNodes.TempHandle(dest_type)
402 temps.append(temp)
403 temp_result = temp.ref(obj_node.pos)
404 class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
405 def result(self):
406 return temp_result.result()
407 def generate_execution_code(self, code):
408 self.generate_result_code(code)
409 return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
411 if isinstance(node.body, Nodes.StatListNode):
412 body = node.body
413 else:
414 body = Nodes.StatListNode(pos = node.body.pos,
415 stats = [node.body])
417 if tuple_target:
418 tuple_result = ExprNodes.TupleNode(
419 pos = tuple_target.pos,
420 args = [key_temp, value_temp],
421 is_temp = 1,
422 type = Builtin.tuple_type,
423 )
424 body.stats.insert(
425 0, Nodes.SingleAssignmentNode(
426 pos = tuple_target.pos,
427 lhs = tuple_target,
428 rhs = tuple_result))
429 else:
430 # execute all coercions before the assignments
431 coercion_stats = []
432 assign_stats = []
433 if keys:
434 temp_result, coercion = coerce_object_to(
435 key_temp, key_target.type)
436 if coercion:
437 coercion_stats.append(coercion)
438 assign_stats.append(
439 Nodes.SingleAssignmentNode(
440 pos = key_temp.pos,
441 lhs = key_target,
442 rhs = temp_result))
443 if values:
444 temp_result, coercion = coerce_object_to(
445 value_temp, value_target.type)
446 if coercion:
447 coercion_stats.append(coercion)
448 assign_stats.append(
449 Nodes.SingleAssignmentNode(
450 pos = value_temp.pos,
451 lhs = value_target,
452 rhs = temp_result))
453 body.stats[0:0] = coercion_stats + assign_stats
455 result_code = [
456 Nodes.SingleAssignmentNode(
457 pos = dict_obj.pos,
458 lhs = dict_temp,
459 rhs = dict_obj),
460 Nodes.SingleAssignmentNode(
461 pos = node.pos,
462 lhs = pos_temp,
463 rhs = ExprNodes.IntNode(node.pos, value='0',
464 constant_result=0)),
465 Nodes.WhileStatNode(
466 pos = node.pos,
467 condition = ExprNodes.SimpleCallNode(
468 pos = dict_obj.pos,
469 type = PyrexTypes.c_bint_type,
470 function = ExprNodes.NameNode(
471 pos = dict_obj.pos,
472 name = self.PyDict_Next_name,
473 type = self.PyDict_Next_func_type,
474 entry = self.PyDict_Next_entry),
475 args = [dict_temp, pos_temp_addr,
476 key_temp_addr, value_temp_addr]
477 ),
478 body = body,
479 else_clause = node.else_clause
480 )
481 ]
483 return UtilNodes.TempsBlockNode(
484 node.pos, temps=temps,
485 body=Nodes.StatListNode(
486 node.pos,
487 stats = result_code
488 ))
491 class SwitchTransform(Visitor.VisitorTransform):
492 """
493 This transformation tries to turn long if statements into C switch statements.
494 The requirement is that every clause be an (or of) var == value, where the var
495 is common among all clauses and both var and value are ints.
496 """
497 def extract_conditions(self, cond):
498 while True:
499 if isinstance(cond, ExprNodes.CoerceToTempNode):
500 cond = cond.arg
501 elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
502 # this is what we get from the FlattenInListTransform
503 cond = cond.subexpression
504 elif isinstance(cond, ExprNodes.TypecastNode):
505 cond = cond.operand
506 else:
507 break
509 if (isinstance(cond, ExprNodes.PrimaryCmpNode)
510 and cond.cascade is None
511 and cond.operator == '=='
512 and not cond.is_python_comparison()):
513 if is_common_value(cond.operand1, cond.operand1):
514 if cond.operand2.is_literal:
515 return cond.operand1, [cond.operand2]
516 elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
517 return cond.operand1, [cond.operand2]
518 if is_common_value(cond.operand2, cond.operand2):
519 if cond.operand1.is_literal:
520 return cond.operand2, [cond.operand1]
521 elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
522 return cond.operand2, [cond.operand1]
523 elif (isinstance(cond, ExprNodes.BoolBinopNode)
524 and cond.operator == 'or'):
525 t1, c1 = self.extract_conditions(cond.operand1)
526 t2, c2 = self.extract_conditions(cond.operand2)
527 if is_common_value(t1, t2):
528 return t1, c1+c2
529 return None, None
531 def visit_IfStatNode(self, node):
532 self.visitchildren(node)
533 common_var = None
534 case_count = 0
535 cases = []
536 for if_clause in node.if_clauses:
537 var, conditions = self.extract_conditions(if_clause.condition)
538 if var is None:
539 return node
540 elif common_var is not None and not is_common_value(var, common_var):
541 return node
542 elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
543 return node
544 else:
545 common_var = var
546 case_count += len(conditions)
547 cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
548 conditions = conditions,
549 body = if_clause.body))
550 if case_count < 2:
551 return node
553 common_var = unwrap_node(common_var)
554 return Nodes.SwitchStatNode(pos = node.pos,
555 test = common_var,
556 cases = cases,
557 else_clause = node.else_clause)
559 visit_Node = Visitor.VisitorTransform.recurse_to_children
562 class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
563 """
564 This transformation flattens "x in [val1, ..., valn]" into a sequential list
565 of comparisons.
566 """
568 def visit_PrimaryCmpNode(self, node):
569 self.visitchildren(node)
570 if node.cascade is not None:
571 return node
572 elif node.operator == 'in':
573 conjunction = 'or'
574 eq_or_neq = '=='
575 elif node.operator == 'not_in':
576 conjunction = 'and'
577 eq_or_neq = '!='
578 else:
579 return node
581 if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
582 return node
584 args = node.operand2.args
585 if len(args) == 0:
586 return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
588 lhs = UtilNodes.ResultRefNode(node.operand1)
590 conds = []
591 for arg in args:
592 cond = ExprNodes.PrimaryCmpNode(
593 pos = node.pos,
594 operand1 = lhs,
595 operator = eq_or_neq,
596 operand2 = arg,
597 cascade = None)
598 conds.append(ExprNodes.TypecastNode(
599 pos = node.pos,
600 operand = cond,
601 type = PyrexTypes.c_bint_type))
602 def concat(left, right):
603 return ExprNodes.BoolBinopNode(
604 pos = node.pos,
605 operator = conjunction,
606 operand1 = left,
607 operand2 = right)
609 condition = reduce(concat, conds)
610 return UtilNodes.EvalWithTempExprNode(lhs, condition)
612 visit_Node = Visitor.VisitorTransform.recurse_to_children
615 class DropRefcountingTransform(Visitor.VisitorTransform):
616 """Drop ref-counting in safe places.
617 """
618 visit_Node = Visitor.VisitorTransform.recurse_to_children
620 def visit_ParallelAssignmentNode(self, node):
621 left_names, right_names = [], []
622 left_indices, right_indices = [], []
623 temps = []
625 for stat in node.stats:
626 if isinstance(stat, Nodes.SingleAssignmentNode):
627 if not self._extract_operand(stat.lhs, left_names,
628 left_indices, temps):
629 return node
630 if not self._extract_operand(stat.rhs, right_names,
631 right_indices, temps):
632 return node
633 elif isinstance(stat, Nodes.CascadedAssignmentNode):
634 # FIXME
635 return node
636 else:
637 return node
639 if left_names or right_names:
640 # lhs/rhs names must be a non-redundant permutation
641 lnames = [ path for path, n in left_names ]
642 rnames = [ path for path, n in right_names ]
643 if set(lnames) != set(rnames):
644 return node
645 if len(set(lnames)) != len(right_names):
646 return node
648 if left_indices or right_indices:
649 # base name and index of index nodes must be a
650 # non-redundant permutation
651 lindices = []
652 for lhs_node in left_indices:
653 index_id = self._extract_index_id(lhs_node)
654 if not index_id:
655 return node
656 lindices.append(index_id)
657 rindices = []
658 for rhs_node in right_indices:
659 index_id = self._extract_index_id(rhs_node)
660 if not index_id:
661 return node
662 rindices.append(index_id)
664 if set(lindices) != set(rindices):
665 return node
666 if len(set(lindices)) != len(right_indices):
667 return node
669 # really supporting IndexNode requires support in
670 # __Pyx_GetItemInt(), so let's stop short for now
671 return node
673 temp_args = [t.arg for t in temps]
674 for temp in temps:
675 temp.use_managed_ref = False
677 for _, name_node in left_names + right_names:
678 if name_node not in temp_args:
679 name_node.use_managed_ref = False
681 for index_node in left_indices + right_indices:
682 index_node.use_managed_ref = False
684 return node
686 def _extract_operand(self, node, names, indices, temps):
687 node = unwrap_node(node)
688 if not node.type.is_pyobject:
689 return False
690 if isinstance(node, ExprNodes.CoerceToTempNode):
691 temps.append(node)
692 node = node.arg
693 name_path = []
694 obj_node = node
695 while isinstance(obj_node, ExprNodes.AttributeNode):
696 if obj_node.is_py_attr:
697 return False
698 name_path.append(obj_node.member)
699 obj_node = obj_node.obj
700 if isinstance(obj_node, ExprNodes.NameNode):
701 name_path.append(obj_node.name)
702 names.append( ('.'.join(name_path[::-1]), node) )
703 elif isinstance(node, ExprNodes.IndexNode):
704 if node.base.type != Builtin.list_type:
705 return False
706 if not node.index.type.is_int:
707 return False
708 if not isinstance(node.base, ExprNodes.NameNode):
709 return False
710 indices.append(node)
711 else:
712 return False
713 return True
715 def _extract_index_id(self, index_node):
716 base = index_node.base
717 index = index_node.index
718 if isinstance(index, ExprNodes.NameNode):
719 index_val = index.name
720 elif isinstance(index, ExprNodes.ConstNode):
721 # FIXME:
722 return None
723 else:
724 return None
725 return (base.name, index_val)
728 class OptimizeBuiltinCalls(Visitor.EnvTransform):
729 """Optimize some common methods calls and instantiation patterns
730 for builtin types.
731 """
732 # only intercept on call nodes
733 visit_Node = Visitor.VisitorTransform.recurse_to_children
735 def visit_GeneralCallNode(self, node):
736 self.visitchildren(node)
737 function = node.function
738 if not function.type.is_pyobject:
739 return node
740 arg_tuple = node.positional_args
741 if not isinstance(arg_tuple, ExprNodes.TupleNode):
742 return node
743 args = arg_tuple.args
744 return self._dispatch_to_handler(
745 node, function, args, node.keyword_args)
747 def visit_SimpleCallNode(self, node):
748 self.visitchildren(node)
749 function = node.function
750 if function.type.is_pyobject:
751 arg_tuple = node.arg_tuple
752 if not isinstance(arg_tuple, ExprNodes.TupleNode):
753 return node
754 args = arg_tuple.args
755 else:
756 args = node.args
757 return self._dispatch_to_handler(
758 node, function, args)
760 ### cleanup to avoid redundant coercions to/from Python types
762 def _visit_PyTypeTestNode(self, node):
763 # disabled - appears to break assignments in some cases, and
764 # also drops a None check, which might still be required
765 """Flatten redundant type checks after tree changes.
766 """
767 old_arg = node.arg
768 self.visitchildren(node)
769 if old_arg is node.arg or node.arg.type != node.type:
770 return node
771 return node.arg
773 def visit_CoerceFromPyTypeNode(self, node):
774 """Drop redundant conversion nodes after tree changes.
776 Also, optimise away calls to Python's builtin int() and
777 float() if the result is going to be coerced back into a C
778 type anyway.
779 """
780 self.visitchildren(node)
781 arg = node.arg
782 if not arg.type.is_pyobject:
783 # no Python conversion left at all, just do a C coercion instead
784 if node.type == arg.type:
785 return arg
786 else:
787 return arg.coerce_to(node.type, self.env_stack[-1])
788 if not isinstance(arg, ExprNodes.SimpleCallNode):
789 return node
790 if not (node.type.is_int or node.type.is_float):
791 return node
792 function = arg.function
793 if not isinstance(function, ExprNodes.NameNode) \
794 or not function.type.is_builtin_type \
795 or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
796 return node
797 args = arg.arg_tuple.args
798 if len(args) != 1:
799 return node
800 func_arg = args[0]
801 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
802 func_arg = func_arg.arg
803 elif func_arg.type.is_pyobject:
804 # play safe: Python conversion might work on all sorts of things
805 return node
806 if function.name == 'int':
807 if func_arg.type.is_int or node.type.is_int:
808 if func_arg.type == node.type:
809 return func_arg
810 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
811 return ExprNodes.CastNode(func_arg, node.type)
812 elif function.name == 'float':
813 if func_arg.type.is_float or node.type.is_float:
814 if func_arg.type == node.type:
815 return func_arg
816 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
817 return ExprNodes.CastNode(func_arg, node.type)
818 return node
820 ### dispatch to specific optimisers
822 def _find_handler(self, match_name, has_kwargs):
823 call_type = has_kwargs and 'general' or 'simple'
824 handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
825 if handler is None:
826 handler = getattr(self, '_handle_any_%s' % match_name, None)
827 return handler
829 def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
830 if function.is_name:
831 # we only consider functions that are either builtin
832 # Python functions or builtins that were already replaced
833 # into a C function call (defined in the builtin scope)
834 is_builtin = function.entry.is_builtin \
835 or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
836 if not is_builtin:
837 return node
838 function_handler = self._find_handler(
839 "function_%s" % function.name, kwargs)
840 if function_handler is None:
841 return node
842 if kwargs:
843 return function_handler(node, arg_list, kwargs)
844 else:
845 return function_handler(node, arg_list)
846 elif function.is_attribute and function.type.is_pyobject:
847 attr_name = function.attribute
848 self_arg = function.obj
849 obj_type = self_arg.type
850 is_unbound_method = False
851 if obj_type.is_builtin_type:
852 if obj_type is Builtin.type_type and arg_list and \
853 arg_list[0].type.is_pyobject:
854 # calling an unbound method like 'list.append(L,x)'
855 # (ignoring 'type.mro()' here ...)
856 type_name = function.obj.name
857 self_arg = None
858 is_unbound_method = True
859 else:
860 type_name = obj_type.name
861 else:
862 type_name = "object" # safety measure
863 method_handler = self._find_handler(
864 "method_%s_%s" % (type_name, attr_name), kwargs)
865 if method_handler is None:
866 if attr_name in TypeSlots.method_name_to_slot \
867 or attr_name == '__new__':
868 method_handler = self._find_handler(
869 "slot%s" % attr_name, kwargs)
870 if method_handler is None:
871 return node
872 if self_arg is not None:
873 arg_list = [self_arg] + list(arg_list)
874 if kwargs:
875 return method_handler(node, arg_list, kwargs, is_unbound_method)
876 else:
877 return method_handler(node, arg_list, is_unbound_method)
878 else:
879 return node
881 def _error_wrong_arg_count(self, function_name, node, args, expected=None):
882 if not expected: # None or 0
883 arg_str = ''
884 elif isinstance(expected, basestring) or expected > 1:
885 arg_str = '...'
886 elif expected == 1:
887 arg_str = 'x'
888 else:
889 arg_str = ''
890 if expected is not None:
891 expected_str = 'expected %s, ' % expected
892 else:
893 expected_str = ''
894 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
895 function_name, arg_str, expected_str, len(args)))
897 ### builtin types
899 def _handle_general_function_dict(self, node, pos_args, kwargs):
900 """Replace dict(a=b,c=d,...) by the underlying keyword dict
901 construction which is done anyway.
902 """
903 if len(pos_args) > 0:
904 return node
905 if not isinstance(kwargs, ExprNodes.DictNode):
906 return node
907 if node.starstar_arg:
908 # we could optimize this by updating the kw dict instead
909 return node
910 return kwargs
912 PyDict_Copy_func_type = PyrexTypes.CFuncType(
913 Builtin.dict_type, [
914 PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
915 ])
917 def _handle_simple_function_dict(self, node, pos_args):
918 """Replace dict(some_dict) by PyDict_Copy(some_dict) and
919 dict([ (a,b) for ... ]) by a literal { a:b for ... }.
920 """
921 if len(pos_args) != 1:
922 return node
923 arg = pos_args[0]
924 if arg.type is Builtin.dict_type:
925 arg = ExprNodes.NoneCheckNode(
926 arg, "PyExc_TypeError", "'NoneType' is not iterable")
927 return ExprNodes.PythonCapiCallNode(
928 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
929 args = [arg],
930 is_temp = node.is_temp
931 )
932 elif isinstance(arg, ExprNodes.ComprehensionNode) and \
933 arg.type is Builtin.list_type:
934 append_node = arg.append
935 if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
936 len(append_node.expr.args) == 2:
937 key_node, value_node = append_node.expr.args
938 target_node = ExprNodes.DictNode(
939 pos=arg.target.pos, key_value_pairs=[], is_temp=1)
940 new_append_node = ExprNodes.DictComprehensionAppendNode(
941 append_node.pos, target=target_node,
942 key_expr=key_node, value_expr=value_node,
943 is_temp=1)
944 arg.target = target_node
945 arg.type = target_node.type
946 replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
947 return replace_in(arg)
948 return node
950 def _handle_simple_function_set(self, node, pos_args):
951 """Replace set([a,b,...]) by a literal set {a,b,...} and
952 set([ x for ... ]) by a literal { x for ... }.
953 """
954 arg_count = len(pos_args)
955 if arg_count == 0:
956 return ExprNodes.SetNode(node.pos, args=[],
957 type=Builtin.set_type, is_temp=1)
958 if arg_count > 1:
959 return node
960 iterable = pos_args[0]
961 if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
962 return ExprNodes.SetNode(node.pos, args=iterable.args,
963 type=Builtin.set_type, is_temp=1)
964 elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
965 iterable.type is Builtin.list_type:
966 iterable.target = ExprNodes.SetNode(
967 node.pos, args=[], type=Builtin.set_type, is_temp=1)
968 iterable.type = Builtin.set_type
969 iterable.pos = node.pos
970 return iterable
971 else:
972 return node
974 PyList_AsTuple_func_type = PyrexTypes.CFuncType(
975 Builtin.tuple_type, [
976 PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
977 ])
979 def _handle_simple_function_tuple(self, node, pos_args):
980 """Replace tuple([...]) by a call to PyList_AsTuple.
981 """
982 if len(pos_args) != 1:
983 return node
984 list_arg = pos_args[0]
985 if list_arg.type is not Builtin.list_type:
986 return node
987 if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
988 ExprNodes.ListNode)):
989 pos_args[0] = ExprNodes.NoneCheckNode(
990 list_arg, "PyExc_TypeError",
991 "'NoneType' object is not iterable")
993 return ExprNodes.PythonCapiCallNode(
994 node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
995 args = pos_args,
996 is_temp = node.is_temp
997 )
999 ### builtin functions
1001 PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
1002 PyrexTypes.py_object_type, [
1003 PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
1004 PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
1005 ])
1007 PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
1008 PyrexTypes.py_object_type, [
1009 PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
1010 PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
1011 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
1012 ])
1014 def _handle_simple_function_getattr(self, node, pos_args):
1015 if len(pos_args) == 2:
1016 node = ExprNodes.PythonCapiCallNode(
1017 node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
1018 args = pos_args,
1019 is_temp = node.is_temp
1021 elif len(pos_args) == 3:
1022 node = ExprNodes.PythonCapiCallNode(
1023 node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
1024 utility_code = Builtin.getattr3_utility_code,
1025 args = pos_args,
1026 is_temp = node.is_temp
1028 else:
1029 self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
1030 return node
1032 Pyx_Type_func_type = PyrexTypes.CFuncType(
1033 Builtin.type_type, [
1034 PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1035 ])
1037 def _handle_simple_function_type(self, node, pos_args):
1038 if len(pos_args) != 1:
1039 return node
1040 node = ExprNodes.PythonCapiCallNode(
1041 node.pos, "__Pyx_Type", self.Pyx_Type_func_type,
1042 args = pos_args,
1043 is_temp = node.is_temp,
1044 utility_code = pytype_utility_code,
1046 return node
1048 Pyx_strlen_func_type = PyrexTypes.CFuncType(
1049 PyrexTypes.c_size_t_type, [
1050 PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1051 ])
1053 def _handle_simple_function_len(self, node, pos_args):
1054 if len(pos_args) != 1:
1055 self._error_wrong_arg_count('len', node, pos_args, 1)
1056 return node
1057 arg = pos_args[0]
1058 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1059 arg = arg.arg
1060 if not arg.type.is_string:
1061 return node
1062 node = ExprNodes.PythonCapiCallNode(
1063 node.pos, "strlen", self.Pyx_strlen_func_type,
1064 args = [arg],
1065 is_temp = node.is_temp,
1066 utility_code = include_string_h_utility_code,
1068 return node
1070 ### special methods
1072 Pyx_tp_new_func_type = PyrexTypes.CFuncType(
1073 PyrexTypes.py_object_type, [
1074 PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
1075 ])
1077 def _handle_simple_slot__new__(self, node, args, is_unbound_method):
1078 """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()
1079 """
1080 obj = node.function.obj
1081 if not is_unbound_method or len(args) != 1:
1082 return node
1083 type_arg = args[0]
1084 if not obj.is_name or not type_arg.is_name:
1085 # play safe
1086 return node
1087 if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
1088 # not a known type, play safe
1089 return node
1090 if not type_arg.type_entry or not obj.type_entry:
1091 if obj.name != type_arg.name:
1092 return node
1093 # otherwise, we know it's a type and we know it's the same
1094 # type for both - that should do
1095 elif type_arg.type_entry != obj.type_entry:
1096 # different types - may or may not lead to an error at runtime
1097 return node
1099 # FIXME: we could potentially look up the actual tp_new C
1100 # method of the extension type and call that instead of the
1101 # generic slot. That would also allow us to pass parameters
1102 # efficiently.
1104 if not type_arg.type_entry:
1105 # arbitrary variable, needs a None check for safety
1106 type_arg = ExprNodes.NoneCheckNode(
1107 type_arg, "PyExc_TypeError",
1108 "object.__new__(X): X is not a type object (NoneType)")
1110 return ExprNodes.PythonCapiCallNode(
1111 node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
1112 args = [type_arg],
1113 utility_code = tpnew_utility_code,
1114 is_temp = node.is_temp
1117 ### methods of builtin types
1119 PyObject_Append_func_type = PyrexTypes.CFuncType(
1120 PyrexTypes.py_object_type, [
1121 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
1122 PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
1123 ])
1125 def _handle_simple_method_object_append(self, node, args, is_unbound_method):
1126 # X.append() is almost always referring to a list
1127 if len(args) != 2:
1128 return node
1130 return ExprNodes.PythonCapiCallNode(
1131 node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
1132 args = args,
1133 is_temp = node.is_temp,
1134 utility_code = append_utility_code
1137 PyObject_Pop_func_type = PyrexTypes.CFuncType(
1138 PyrexTypes.py_object_type, [
1139 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
1140 ])
1142 PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
1143 PyrexTypes.py_object_type, [
1144 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
1145 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
1146 ])
1148 def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
1149 # X.pop([n]) is almost always referring to a list
1150 if len(args) == 1:
1151 return ExprNodes.PythonCapiCallNode(
1152 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
1153 args = args,
1154 is_temp = node.is_temp,
1155 utility_code = pop_utility_code
1157 elif len(args) == 2:
1158 if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
1159 original_type = args[1].arg.type
1160 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
1161 args[1] = args[1].arg
1162 return ExprNodes.PythonCapiCallNode(
1163 node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
1164 args = args,
1165 is_temp = node.is_temp,
1166 utility_code = pop_index_utility_code
1169 return node
1171 PyList_Append_func_type = PyrexTypes.CFuncType(
1172 PyrexTypes.c_int_type, [
1173 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
1174 PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
1175 ],
1176 exception_value = "-1")
1178 def _handle_simple_method_list_append(self, node, args, is_unbound_method):
1179 if len(args) != 2:
1180 self._error_wrong_arg_count('list.append', node, args, 2)
1181 return node
1182 return self._substitute_method_call(
1183 node, "PyList_Append", self.PyList_Append_func_type,
1184 'append', is_unbound_method, args)
1186 single_param_func_type = PyrexTypes.CFuncType(
1187 PyrexTypes.c_int_type, [
1188 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1189 ],
1190 exception_value = "-1")
1192 def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
1193 if len(args) != 1:
1194 return node
1195 return self._substitute_method_call(
1196 node, "PyList_Sort", self.single_param_func_type,
1197 'sort', is_unbound_method, args)
1199 def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
1200 if len(args) != 1:
1201 self._error_wrong_arg_count('list.reverse', node, args, 1)
1202 return node
1203 return self._substitute_method_call(
1204 node, "PyList_Reverse", self.single_param_func_type,
1205 'reverse', is_unbound_method, args)
1207 PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
1208 Builtin.bytes_type, [
1209 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
1210 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
1211 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
1212 ],
1213 exception_value = "NULL")
1215 PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
1216 Builtin.bytes_type, [
1217 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
1218 ],
1219 exception_value = "NULL")
1221 _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
1222 'unicode_escape', 'raw_unicode_escape']
1224 _special_codecs = [ (name, codecs.getencoder(name))
1225 for name in _special_encodings ]
1227 def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
1228 if len(args) < 1 or len(args) > 3:
1229 self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
1230 return node
1232 string_node = args[0]
1234 if len(args) == 1:
1235 null_node = ExprNodes.NullNode(node.pos)
1236 return self._substitute_method_call(
1237 node, "PyUnicode_AsEncodedString",
1238 self.PyUnicode_AsEncodedString_func_type,
1239 'encode', is_unbound_method, [string_node, null_node, null_node])
1241 parameters = self._unpack_encoding_and_error_mode(node.pos, args)
1242 if parameters is None:
1243 return node
1244 encoding, encoding_node, error_handling, error_handling_node = parameters
1246 if isinstance(string_node, ExprNodes.UnicodeNode):
1247 # constant, so try to do the encoding at compile time
1248 try:
1249 value = string_node.value.encode(encoding, error_handling)
1250 except:
1251 # well, looks like we can't
1252 pass
1253 else:
1254 value = BytesLiteral(value)
1255 value.encoding = encoding
1256 return ExprNodes.BytesNode(
1257 string_node.pos, value=value, type=Builtin.bytes_type)
1259 if error_handling == 'strict':
1260 # try to find a specific encoder function
1261 codec_name = self._find_special_codec_name(encoding)
1262 if codec_name is not None:
1263 encode_function = "PyUnicode_As%sString" % codec_name
1264 return self._substitute_method_call(
1265 node, encode_function,
1266 self.PyUnicode_AsXyzString_func_type,
1267 'encode', is_unbound_method, [string_node])
1269 return self._substitute_method_call(
1270 node, "PyUnicode_AsEncodedString",
1271 self.PyUnicode_AsEncodedString_func_type,
1272 'encode', is_unbound_method,
1273 [string_node, encoding_node, error_handling_node])
1275 PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
1276 Builtin.unicode_type, [
1277 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
1278 PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
1279 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
1280 ],
1281 exception_value = "NULL")
1283 PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
1284 Builtin.unicode_type, [
1285 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
1286 PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
1287 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
1288 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
1289 ],
1290 exception_value = "NULL")
1292 def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
1293 if len(args) < 1 or len(args) > 3:
1294 self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
1295 return node
1296 if not isinstance(args[0], ExprNodes.SliceIndexNode):
1297 # we need the string length as a slice end index
1298 return node
1299 index_node = args[0]
1300 string_node = index_node.base
1301 if not string_node.type.is_string:
1302 # nothing to optimise here
1303 return node
1304 start, stop = index_node.start, index_node.stop
1305 if not stop:
1306 # FIXME: could use strlen() - although Python will do that anyway ...
1307 return node
1308 if stop.type.is_pyobject:
1309 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
1310 if start and start.constant_result != 0:
1311 # FIXME: put start into a temp and do the math
1312 return node
1314 parameters = self._unpack_encoding_and_error_mode(node.pos, args)
1315 if parameters is None:
1316 return node
1317 encoding, encoding_node, error_handling, error_handling_node = parameters
1319 # try to find a specific encoder function
1320 codec_name = self._find_special_codec_name(encoding)
1321 if codec_name is not None:
1322 decode_function = "PyUnicode_Decode%s" % codec_name
1323 return ExprNodes.PythonCapiCallNode(
1324 node.pos, decode_function,
1325 self.PyUnicode_DecodeXyz_func_type,
1326 args = [string_node, stop, error_handling_node],
1327 is_temp = node.is_temp,
1330 return ExprNodes.PythonCapiCallNode(
1331 node.pos, "PyUnicode_Decode",
1332 self.PyUnicode_Decode_func_type,
1333 args = [string_node, stop, encoding_node, error_handling_node],
1334 is_temp = node.is_temp,
1337 def _find_special_codec_name(self, encoding):
1338 try:
1339 requested_codec = codecs.getencoder(encoding)
1340 except:
1341 return None
1342 for name, codec in self._special_codecs:
1343 if codec == requested_codec:
1344 if '_' in name:
1345 name = ''.join([ s.capitalize()
1346 for s in name.split('_')])
1347 return name
1348 return None
1350 def _unpack_encoding_and_error_mode(self, pos, args):
1351 encoding_node = args[1]
1352 if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
1353 encoding_node = encoding_node.arg
1354 if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
1355 ExprNodes.BytesNode)):
1356 return None
1357 encoding = encoding_node.value
1358 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
1359 type=PyrexTypes.c_char_ptr_type)
1361 null_node = ExprNodes.NullNode(pos)
1362 if len(args) == 3:
1363 error_handling_node = args[2]
1364 if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
1365 error_handling_node = error_handling_node.arg
1366 if not isinstance(error_handling_node,
1367 (ExprNodes.UnicodeNode, ExprNodes.StringNode,
1368 ExprNodes.BytesNode)):
1369 return None
1370 error_handling = error_handling_node.value
1371 if error_handling == 'strict':
1372 error_handling_node = null_node
1373 else:
1374 error_handling_node = ExprNodes.BytesNode(
1375 error_handling_node.pos, value=error_handling,
1376 type=PyrexTypes.c_char_ptr_type)
1377 else:
1378 error_handling = 'strict'
1379 error_handling_node = null_node
1381 return (encoding, encoding_node, error_handling, error_handling_node)
1383 def _substitute_method_call(self, node, name, func_type,
1384 attr_name, is_unbound_method, args=()):
1385 args = list(args)
1386 if args:
1387 self_arg = args[0]
1388 if is_unbound_method:
1389 self_arg = ExprNodes.NoneCheckNode(
1390 self_arg, "PyExc_TypeError",
1391 "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
1392 attr_name, node.function.obj.name))
1393 else:
1394 self_arg = ExprNodes.NoneCheckNode(
1395 self_arg, "PyExc_AttributeError",
1396 "'NoneType' object has no attribute '%s'" % attr_name)
1397 args[0] = self_arg
1398 return ExprNodes.PythonCapiCallNode(
1399 node.pos, name, func_type,
1400 args = args,
1401 is_temp = node.is_temp
1405 append_utility_code = UtilityCode(
1406 proto = """
1407 static INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
1408 if (likely(PyList_CheckExact(L))) {
1409 if (PyList_Append(L, x) < 0) return NULL;
1410 Py_INCREF(Py_None);
1411 return Py_None; /* this is just to have an accurate signature */
1413 else {
1414 PyObject *r, *m;
1415 m = __Pyx_GetAttrString(L, "append");
1416 if (!m) return NULL;
1417 r = PyObject_CallFunctionObjArgs(m, x, NULL);
1418 Py_DECREF(m);
1419 return r;
1422 """,
1423 impl = ""
1427 pop_utility_code = UtilityCode(
1428 proto = """
1429 static INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
1430 if (likely(PyList_CheckExact(L))
1431 /* Check that both the size is positive and no reallocation shrinking needs to be done. */
1432 && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
1433 Py_SIZE(L) -= 1;
1434 return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
1436 else {
1437 PyObject *r, *m;
1438 m = __Pyx_GetAttrString(L, "pop");
1439 if (!m) return NULL;
1440 r = PyObject_CallObject(m, NULL);
1441 Py_DECREF(m);
1442 return r;
1445 """,
1446 impl = ""
1449 pop_index_utility_code = UtilityCode(
1450 proto = """
1451 static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
1452 """,
1453 impl = """
1454 static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
1455 PyObject *r, *m, *t, *py_ix;
1456 if (likely(PyList_CheckExact(L))) {
1457 Py_ssize_t size = PyList_GET_SIZE(L);
1458 if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
1459 if (ix < 0) {
1460 ix += size;
1462 if (likely(0 <= ix && ix < size)) {
1463 Py_ssize_t i;
1464 PyObject* v = PyList_GET_ITEM(L, ix);
1465 Py_SIZE(L) -= 1;
1466 size -= 1;
1467 for(i=ix; i<size; i++) {
1468 PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
1470 return v;
1474 py_ix = t = NULL;
1475 m = __Pyx_GetAttrString(L, "pop");
1476 if (!m) goto bad;
1477 py_ix = PyInt_FromSsize_t(ix);
1478 if (!py_ix) goto bad;
1479 t = PyTuple_New(1);
1480 if (!t) goto bad;
1481 PyTuple_SET_ITEM(t, 0, py_ix);
1482 py_ix = NULL;
1483 r = PyObject_CallObject(m, t);
1484 Py_DECREF(m);
1485 Py_DECREF(t);
1486 return r;
1487 bad:
1488 Py_XDECREF(m);
1489 Py_XDECREF(t);
1490 Py_XDECREF(py_ix);
1491 return NULL;
1493 """
1497 pytype_utility_code = UtilityCode(
1498 proto = """
1499 static INLINE PyObject* __Pyx_Type(PyObject* o) {
1500 PyObject* type = (PyObject*) Py_TYPE(o);
1501 Py_INCREF(type);
1502 return type;
1504 """
1508 include_string_h_utility_code = UtilityCode(
1509 proto = """
1510 #include <string.h>
1511 """
1515 tpnew_utility_code = UtilityCode(
1516 proto = """
1517 static INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
1518 return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
1519 (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
1521 """ % {'TUPLE' : Naming.empty_tuple}
1525 class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
1526 """Calculate the result of constant expressions to store it in
1527 ``expr_node.constant_result``, and replace trivial cases by their
1528 constant result.
1529 """
1530 def _calculate_const(self, node):
1531 if node.constant_result is not ExprNodes.constant_value_not_set:
1532 return
1534 # make sure we always set the value
1535 not_a_constant = ExprNodes.not_a_constant
1536 node.constant_result = not_a_constant
1538 # check if all children are constant
1539 children = self.visitchildren(node)
1540 for child_result in children.itervalues():
1541 if type(child_result) is list:
1542 for child in child_result:
1543 if child.constant_result is not_a_constant:
1544 return
1545 elif child_result.constant_result is not_a_constant:
1546 return
1548 # now try to calculate the real constant value
1549 try:
1550 node.calculate_constant_result()
1551 # if node.constant_result is not ExprNodes.not_a_constant:
1552 # print node.__class__.__name__, node.constant_result
1553 except (ValueError, TypeError, KeyError, IndexError, AttributeError):
1554 # ignore all 'normal' errors here => no constant result
1555 pass
1556 except Exception:
1557 # this looks like a real error
1558 import traceback, sys
1559 traceback.print_exc(file=sys.stdout)
1561 NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
1562 ExprNodes.LongNode, ExprNodes.FloatNode)
1564 def _widest_node_class(self, *nodes):
1565 try:
1566 return self.NODE_TYPE_ORDER[
1567 max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
1568 except ValueError:
1569 return None
1571 def visit_ExprNode(self, node):
1572 self._calculate_const(node)
1573 return node
1575 def visit_BinopNode(self, node):
1576 self._calculate_const(node)
1577 if node.constant_result is ExprNodes.not_a_constant:
1578 return node
1579 if isinstance(node.constant_result, float):
1580 # We calculate float constants to make them available to
1581 # the compiler, but we do not aggregate them into a
1582 # constant node to prevent any loss of precision.
1583 return node
1584 if not node.operand1.is_literal or not node.operand2.is_literal:
1585 # We calculate other constants to make them available to
1586 # the compiler, but we only aggregate constant nodes
1587 # recursively, so non-const nodes are straight out.
1588 return node
1590 # now inject a new constant node with the calculated value
1591 try:
1592 type1, type2 = node.operand1.type, node.operand2.type
1593 if type1 is None or type2 is None:
1594 return node
1595 except AttributeError:
1596 return node
1598 if type1 is type2:
1599 new_node = node.operand1
1600 else:
1601 widest_type = PyrexTypes.widest_numeric_type(type1, type2)
1602 if type(node.operand1) is type(node.operand2):
1603 new_node = node.operand1
1604 new_node.type = widest_type
1605 elif type1 is widest_type:
1606 new_node = node.operand1
1607 elif type2 is widest_type:
1608 new_node = node.operand2
1609 else:
1610 target_class = self._widest_node_class(
1611 node.operand1, node.operand2)
1612 if target_class is None:
1613 return node
1614 new_node = target_class(pos=node.pos, type = widest_type)
1616 new_node.constant_result = node.constant_result
1617 new_node.value = str(node.constant_result)
1618 #new_node = new_node.coerce_to(node.type, self.current_scope)
1619 return new_node
1621 # in the future, other nodes can have their own handler method here
1622 # that can replace them with a constant result node
1624 visit_Node = Visitor.VisitorTransform.recurse_to_children
1627 class FinalOptimizePhase(Visitor.CythonTransform):
1628 """
1629 This visitor handles several commuting optimizations, and is run
1630 just before the C code generation phase.
1632 The optimizations currently implemented in this class are:
1633 - Eliminate None assignment and refcounting for first assignment.
1634 - isinstance -> typecheck for cdef types
1635 """
1636 def visit_SingleAssignmentNode(self, node):
1637 """Avoid redundant initialisation of local variables before their
1638 first assignment.
1639 """
1640 self.visitchildren(node)
1641 if node.first:
1642 lhs = node.lhs
1643 lhs.lhs_of_first_assignment = True
1644 if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
1645 # Have variable initialized to 0 rather than None
1646 lhs.entry.init_to_none = False
1647 lhs.entry.init = 0
1648 return node
1650 def visit_SimpleCallNode(self, node):
1651 """Replace generic calls to isinstance(x, type) by a more efficient
1652 type check.
1653 """
1654 self.visitchildren(node)
1655 if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
1656 if node.function.name == 'isinstance':
1657 type_arg = node.args[1]
1658 if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
1659 from CythonScope import utility_scope
1660 node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
1661 node.function.type = node.function.entry.type
1662 PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
1663 node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
1664 return node