Cython has moved to github.
cython-devel
view Cython/Compiler/Optimize.py @ 2953:4b2e6c18fe38
enable for-in iteration also for C arrays of known size
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Thu Feb 11 20:42:35 2010 +0100 (2 years ago) |
| parents | e36d5a315205 |
| children | 1b927079ea17 7925926971e9 |
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
17 import codecs
# Compatibility shims for older/newer Python versions running the compiler.
try:
    reduce
except NameError:
    # Python 3 only provides reduce() in functools
    from functools import reduce

try:
    set
except NameError:
    # very old Pythons have no builtin set type, only sets.Set
    from sets import Set as set
class FakePythonEnv(object):
    "A fake environment for creating type test nodes etc."
    # the only environment property provided: code runs with the GIL held
    nogil = False
def unwrap_node(node):
    """Return *node* with any chain of ResultRefNode wrappers removed.

    Each ResultRefNode is replaced by its ``expression`` until a node of
    a different kind is reached; that node is returned.
    """
    result_ref_type = UtilNodes.ResultRefNode
    while isinstance(node, result_ref_type):
        node = node.expression
    return node
def is_common_value(a, b):
    """Tell whether the expression nodes *a* and *b* denote the same
    simple value.

    Both nodes are first stripped of ResultRefNode wrappers.  Two
    NameNodes match if their names are equal.  Two AttributeNodes match
    if *a* is not a Python attribute lookup, their ``obj`` parts match
    recursively and their attribute names are equal.  Anything else
    yields False.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    name_type = ExprNodes.NameNode
    attribute_type = ExprNodes.AttributeNode
    if isinstance(a, name_type) and isinstance(b, name_type):
        return a.name == b.name
    if not (isinstance(a, attribute_type) and isinstance(b, attribute_type)):
        return False
    if a.is_py_attr:
        return False
    return is_common_value(a.obj, b.obj) and a.attribute == b.attribute
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over a C array of known size, or over a slice of a C array
      or pointer, becomes a C for loop that advances a pointer temp
      (a pointer slice needs an explicit stop index; an array slice
      without one is bounded by the array size when that is known)
    """
    # C signature used to call PyDict_Next(dict, &pos, &key, &value)
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        # coercions created below are analysed in the current scope
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use the function's own scope while visiting its body
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        # transform inner loops first, then this one
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Return a replacement for the ForInStatNode *node*, or *node*
        itself if its iteration pattern is not recognised.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?
        if isinstance(iterator, ExprNodes.SliceIndexNode) and \
               (iterator.base.type.is_array or iterator.base.type.is_ptr):
            return self._transform_carray_iteration(node, iterator)
        elif iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
               function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_carray_iteration(self, node, slice_node):
        """Turn iteration over a C array, or over a slice of a C array or
        pointer, into a ForFromStatNode.

        The loop advances a pointer temp from the start address up to
        (but excluding) the stop address.  On each pass it assigns the
        item the temp points to to the loop target.  If the base is
        string-typed and the target is a Python object, it assigns a
        one-item bytes slice instead.

        Returns *node* unchanged if the end of the iteration cannot be
        determined.
        """
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                # an open-ended slice can only be bounded by the size
                # of a C array whose length is known at compile time
                base_type = slice_base.type
                if not base_type.is_array or base_type.size is None:
                    return node
                stop = ExprNodes.IntNode(
                    slice_node.pos, value=str(base_type.size))
        elif slice_node.type.is_array and slice_node.type.size is not None:
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size))
            step = None
        else:
            return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=carray_ptr,
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        else:
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1='<=',
            target=counter_temp,
            relation2='<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Rewrite ``for i, x in enumerate(seq)`` into a loop over ``seq``
        with a separately incremented counter that is assigned to ``i``.

        Only a two-item target whose first item is a NameNode of Python
        object or C integer type is handled.  Other shapes are returned
        unchanged; wrong argument counts are reported as errors.  The
        rewritten loop goes through _optimise_for_loop() again.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # at the top of each pass: i = counter; counter = counter + 1
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Turn ``for i in range(...)`` / ``xrange(...)`` into a
        ForFromStatNode.

        The step must be a compile-time integer constant other than zero;
        otherwise *node* is returned unchanged.  A negative step is
        expressed through descending relations plus its magnitude.  A
        non-literal stop bound is kept in a temp.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace iteration over a dict by a while loop that calls
        PyDict_Next(dict, &pos, key_ptr, value_ptr).

        *keys* and *values* select which results are requested.  An
        unrequested one is passed as NULL.  The key/value results are
        assigned (after coercion) to the loop target, its two items, or,
        as a tuple, to a non-sequence target.
        """
        # the key/value temps are declared void* and only their
        # addresses are passed to PyDict_Next()
        # NOTE(review): presumably void* to avoid refcounting the
        # borrowed references - confirm
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (value_to_assign, coercion_stat_or_None).  Python
            # targets get a (type-tested) cast; C targets get a temp
            # that a CoerceFromPyTypeNode fills in before assignment.
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    Try to turn long if/elif chains into C switch statements.

    Every clause condition must be ``var == value`` (or an 'or' of such
    comparisons).  ``var`` must be the same simple value in all clauses,
    each ``value`` must be a literal or a const entry, and both sides
    must have C integer types.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def extract_conditions(self, cond):
        """Return ``(var, [value, ...])`` if *cond* has a switchable shape,
        otherwise ``(None, None)``.
        """
        # (wrapper node class, attribute holding the wrapped expression)
        wrappers = (
            (ExprNodes.CoerceToTempNode, 'arg'),
            # this is what we get from the FlattenInListTransform
            (UtilNodes.EvalWithTempExprNode, 'subexpression'),
            (ExprNodes.TypecastNode, 'operand'),
            )
        peeled = True
        while peeled:
            peeled = False
            for wrapper_class, attr_name in wrappers:
                if isinstance(cond, wrapper_class):
                    cond = getattr(cond, attr_name)
                    peeled = True
                    break

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # either side may hold the variable; is_common_value(x, x)
            # checks that x is a simple (name/attribute) value
            for var, value in ((cond.operand1, cond.operand2),
                               (cond.operand2, cond.operand1)):
                if is_common_value(var, var) and self._is_case_value(value):
                    return var, [value]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            left_var, left_values = self.extract_conditions(cond.operand1)
            right_var, right_values = self.extract_conditions(cond.operand2)
            if is_common_value(left_var, right_var):
                return left_var, left_values + right_values
        return None, None

    def _is_case_value(self, node):
        # a case label must be a literal or refer to a const entry
        if node.is_literal:
            return True
        return bool(hasattr(node, 'entry') and node.entry and node.entry.is_const)

    def visit_IfStatNode(self, node):
        self.visitchildren(node)
        common_var = None
        case_count = 0
        cases = []
        for if_clause in node.if_clauses:
            var, conditions = self.extract_conditions(if_clause.condition)
            if var is None:
                return node
            if common_var is not None and not is_common_value(var, common_var):
                return node
            if not var.type.is_int:
                return node
            for condition in conditions:
                if not condition.type.is_int:
                    return node
            common_var = var
            case_count += len(conditions)
            cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                              conditions = conditions,
                                              body = if_clause.body))
        # a switch only pays off with at least two case labels
        if case_count < 2:
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(common_var),
                                    cases = cases,
                                    else_clause = node.else_clause)
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    Flatten "x in [val1, ..., valn]" (or a tuple) into '==' comparisons
    joined by 'or'.  "x not in [...]" becomes '!=' comparisons joined
    by 'and'.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        op = node.operator
        if op == 'in':
            conjunction, eq_or_neq = 'or', '=='
        elif op == 'not_in':
            conjunction, eq_or_neq = 'and', '!='
        else:
            return node

        sequence = node.operand2
        if not isinstance(sequence, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        items = sequence.args
        if len(items) == 0:
            # 'x in ()' is False, 'x not in ()' is True
            return ExprNodes.BoolNode(pos = node.pos, value = op == 'not_in')

        # all comparisons share one ResultRefNode for the left operand,
        # bound by the EvalWithTempExprNode returned below
        lhs = UtilNodes.ResultRefNode(node.operand1)

        tests = [
            ExprNodes.TypecastNode(
                pos = node.pos,
                operand = ExprNodes.PrimaryCmpNode(
                    pos = node.pos,
                    operand1 = lhs,
                    operator = eq_or_neq,
                    operand2 = item,
                    cascade = None),
                type = PyrexTypes.c_bint_type)
            for item in items]

        # left fold: ((t1 op t2) op t3) ...
        condition = tests[0]
        for test in tests[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = condition,
                operand2 = test)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        Sides must consist of Python-object names, or non-Python attribute
        chains rooted at a name, forming a non-redundant permutation.  If
        they do, the name nodes and CoerceToTempNode temps involved get
        ``use_managed_ref = False``.  Index operands are recognised, but
        the node is returned untouched when any are present.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            # FIXME: CascadedAssignmentNode (like anything that is not a
            # SingleAssignmentNode) is not supported
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not self._extract_operand(stat.lhs, left_names,
                                         left_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, right_names,
                                         right_indices, temps):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = [path for path, _ in left_names]
            rhs_paths = [path for path, _ in right_names]
            distinct_lhs = set(lhs_paths)
            if distinct_lhs != set(rhs_paths):
                return node
            if len(distinct_lhs) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            id_lists = []
            for side in (left_indices, right_indices):
                side_ids = []
                for index_node in side:
                    index_id = self._extract_index_id(index_node)
                    if not index_id:
                        return node
                    side_ids.append(index_id)
                id_lists.append(side_ids)
            lindices, rindices = id_lists

            if set(lindices) != set(rindices):
                return node
            if len(set(lindices)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [temp_node.arg for temp_node in temps]
        for temp_node in temps:
            temp_node.use_managed_ref = False

        for _, name_node in left_names + right_names:
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        # Classify one assignment side, appending it to *names*
        # (as (dotted_path, node)) or *indices*; CoerceToTempNodes are
        # collected in *temps*.  Returns False for unsupported operands.
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        path_parts = []
        current = node
        while isinstance(current, ExprNodes.AttributeNode):
            if current.is_py_attr:
                return False
            path_parts.append(current.member)
            current = current.obj

        if isinstance(current, ExprNodes.NameNode):
            path_parts.append(current.name)
            path_parts.reverse()
            names.append(('.'.join(path_parts), node))
            return True

        if isinstance(node, ExprNodes.IndexNode):
            if (node.base.type != Builtin.list_type
                    or not node.index.type.is_int
                    or not isinstance(node.base, ExprNodes.NameNode)):
                return False
            indices.append(node)
            return True

        return False

    def _extract_index_id(self, index_node):
        # Identify an index operand as (base name, index name).
        # FIXME: constant (ConstNode) indices are not supported yet and,
        # like any other non-name index, yield None.
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            return None
        return (index_node.base.name, index.name)
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimise some common calls to builtin types *before* type
    analysis and *after* declaration analysis.

    Because no argument types are known yet, this transform only
    restructures the tree in ways the type analysis phase can respond
    to.

    Introducing C function calls here may not be a good idea; they
    belong in the OptimizeBuiltinCalls transform, which runs after type
    analysis.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        self.visitchildren(node)
        if self._function_is_builtin_name(node.function):
            return self._dispatch_to_handler(node, node.function, node.args)
        return node

    def visit_GeneralCallNode(self, node):
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        return self._dispatch_to_handler(
            node, function, arg_tuple.args, node.keyword_args)

    def _function_is_builtin_name(self, function):
        # True if *function* is a name that resolves to the builtin scope
        if not function.is_name:
            return False
        entry = self.env_stack[-1].lookup(function.name)
        return bool(entry) and \
               getattr(entry, 'scope', None) is Builtin.builtin_scope

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        # Call the matching _handle_simple_function_* (no kwargs) or
        # _handle_general_function_* handler, if this class defines one.
        simple_call = kwargs is None
        if simple_call:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is None:
            return node
        if simple_call:
            return handle_call(node, args)
        return handle_call(node, args, kwargs)

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        # replace the called function by a direct C-API function reference
        old_function = node.function
        node.function = ExprNodes.PythonCapiFunctionNode(
            old_function.pos, old_function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        # Report "<name>(<args>) called with wrong number of args, ...".
        arg_str = ''
        if expected: # neither None nor 0
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'
        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    # specific handlers for simple call nodes

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...} and
        set([ x for ... ]) by a literal { x for ... }.
        """
        if len(pos_args) == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type)
        if len(pos_args) > 1:
            return node
        iterable = pos_args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args)
        if not isinstance(iterable, ExprNodes.ComprehensionNode):
            return node
        if not isinstance(iterable.target, (ExprNodes.ListNode,
                                            ExprNodes.SetNode)):
            return node
        iterable.target = ExprNodes.SetNode(node.pos, args=[])
        iterable.pos = node.pos
        return iterable

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
        """
        if len(pos_args) != 1:
            return node
        comprehension = pos_args[0]
        if not (isinstance(comprehension, ExprNodes.ComprehensionNode) and
                isinstance(comprehension.target, (ExprNodes.ListNode,
                                                  ExprNodes.SetNode))):
            return node
        append_node = comprehension.append
        item = append_node.expr
        if not (isinstance(item, (ExprNodes.TupleNode, ExprNodes.ListNode))
                and len(item.args) == 2):
            return node
        key_node, value_node = item.args
        target_node = ExprNodes.DictNode(
            pos=comprehension.target.pos, key_value_pairs=[])
        new_append_node = ExprNodes.DictComprehensionAppendNode(
            append_node.pos, target=target_node,
            key_expr=key_node, value_expr=value_node)
        comprehension.target = target_node
        comprehension.type = target_node.type
        replacer = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
        return replacer(comprehension)

    def _handle_simple_function_float(self, node, pos_args):
        # float() -> 0.0 literal; too many args is reported as an error
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if arg_count > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0 or not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs
891 class OptimizeBuiltinCalls(Visitor.EnvTransform):
892 """Optimize some common methods calls and instantiation patterns
893 for builtin types *after* the type analysis phase.
895 Running after type analysis, this transform can only perform
896 function replacements that do not alter the function return type
897 in a way that was not anticipated by the type analysis.
898 """
899 # only intercept on call nodes
900 visit_Node = Visitor.VisitorTransform.recurse_to_children
902 def visit_GeneralCallNode(self, node):
903 self.visitchildren(node)
904 function = node.function
905 if not function.type.is_pyobject:
906 return node
907 arg_tuple = node.positional_args
908 if not isinstance(arg_tuple, ExprNodes.TupleNode):
909 return node
910 args = arg_tuple.args
911 return self._dispatch_to_handler(
912 node, function, args, node.keyword_args)
914 def visit_SimpleCallNode(self, node):
915 self.visitchildren(node)
916 function = node.function
917 if function.type.is_pyobject:
918 arg_tuple = node.arg_tuple
919 if not isinstance(arg_tuple, ExprNodes.TupleNode):
920 return node
921 args = arg_tuple.args
922 else:
923 args = node.args
924 return self._dispatch_to_handler(
925 node, function, args)
927 ### cleanup to avoid redundant coercions to/from Python types
929 def _visit_PyTypeTestNode(self, node):
930 # disabled - appears to break assignments in some cases, and
931 # also drops a None check, which might still be required
932 """Flatten redundant type checks after tree changes.
933 """
934 old_arg = node.arg
935 self.visitchildren(node)
936 if old_arg is node.arg or node.arg.type != node.type:
937 return node
938 return node.arg
940 def visit_CoerceFromPyTypeNode(self, node):
941 """Drop redundant conversion nodes after tree changes.
943 Also, optimise away calls to Python's builtin int() and
944 float() if the result is going to be coerced back into a C
945 type anyway.
946 """
947 self.visitchildren(node)
948 arg = node.arg
949 if not arg.type.is_pyobject:
950 # no Python conversion left at all, just do a C coercion instead
951 if node.type == arg.type:
952 return arg
953 else:
954 return arg.coerce_to(node.type, self.env_stack[-1])
955 if not isinstance(arg, ExprNodes.SimpleCallNode):
956 return node
957 if not (node.type.is_int or node.type.is_float):
958 return node
959 function = arg.function
960 if not isinstance(function, ExprNodes.NameNode) \
961 or not function.type.is_builtin_type \
962 or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
963 return node
964 args = arg.arg_tuple.args
965 if len(args) != 1:
966 return node
967 func_arg = args[0]
968 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
969 func_arg = func_arg.arg
970 elif func_arg.type.is_pyobject:
971 # play safe: Python conversion might work on all sorts of things
972 return node
973 if function.name == 'int':
974 if func_arg.type.is_int or node.type.is_int:
975 if func_arg.type == node.type:
976 return func_arg
977 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
978 return ExprNodes.TypecastNode(
979 node.pos, operand=func_arg, type=node.type)
980 elif function.name == 'float':
981 if func_arg.type.is_float or node.type.is_float:
982 if func_arg.type == node.type:
983 return func_arg
984 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
985 return ExprNodes.TypecastNode(
986 node.pos, operand=func_arg, type=node.type)
987 return node
989 ### dispatch to specific optimisers
991 def _find_handler(self, match_name, has_kwargs):
992 call_type = has_kwargs and 'general' or 'simple'
993 handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
994 if handler is None:
995 handler = getattr(self, '_handle_any_%s' % match_name, None)
996 return handler
def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
    """Hand a call over to a matching ``_handle_*`` optimiser method.

    Two kinds of calls are considered:

    - calls of a builtin function by name, looked up as
      ``_handle_{simple,general}_function_<name>``;
    - Python method calls ``obj.attr(...)``, looked up as
      ``_handle_{simple,general}_method_<typename>_<attr>``, with a
      fallback to ``_handle_*_slot<attr>`` for special (slot) methods
      and ``__new__``.

    Method handlers receive the object as the first entry of the argument
    list, plus an ``is_unbound_method`` flag. For unbound calls like
    ``list.append(L, x)``, the type name comes from the name used to
    reference the type.

    Returns the replacement node, or ``node`` unchanged when no optimiser
    applies.
    """
    if function.is_name:
        # we only consider functions that are either builtin
        # Python functions or builtins that were already replaced
        # into a C function call (defined in the builtin scope)
        if not function.entry:
            return node
        is_builtin = function.entry.is_builtin \
            or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
        if not is_builtin:
            return node
        function_handler = self._find_handler(
            "function_%s" % function.name, kwargs)
        if function_handler is None:
            return node
        if kwargs:
            return function_handler(node, arg_list, kwargs)
        else:
            return function_handler(node, arg_list)
    elif function.is_attribute and function.type.is_pyobject:
        attr_name = function.attribute
        self_arg = function.obj
        obj_type = self_arg.type
        is_unbound_method = False
        if obj_type.is_builtin_type:
            # Only a plain name tells us *which* type is being called.
            # Any other expression of type 'type' (e.g. an attribute)
            # has no '.name' and must not be treated as an unbound call.
            if obj_type is Builtin.type_type and self_arg.is_name and \
                    arg_list and arg_list[0].type.is_pyobject:
                # calling an unbound method like 'list.append(L,x)'
                # (ignoring 'type.mro()' here ...)
                type_name = self_arg.name
                self_arg = None
                is_unbound_method = True
            else:
                type_name = obj_type.name
        else:
            type_name = "object" # safety measure
        method_handler = self._find_handler(
            "method_%s_%s" % (type_name, attr_name), kwargs)
        if method_handler is None:
            if attr_name in TypeSlots.method_name_to_slot \
                    or attr_name == '__new__':
                method_handler = self._find_handler(
                    "slot%s" % attr_name, kwargs)
            if method_handler is None:
                return node
        if self_arg is not None:
            arg_list = [self_arg] + list(arg_list)
        if kwargs:
            return method_handler(node, arg_list, kwargs, is_unbound_method)
        else:
            return method_handler(node, arg_list, is_unbound_method)
    else:
        return node
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a compile error for a call with the wrong number of arguments.

    ``expected`` may be None (unknown), an int, or a descriptive string
    such as '2 or 3'. It determines the placeholder argument list shown
    in the message ('', 'x' or '...') and the optional 'expected ...'
    hint.
    """
    if expected and (isinstance(expected, basestring) or expected > 1):
        arg_str = '...'
    elif expected == 1:
        arg_str = 'x'
    else:
        # None, 0 or any other value: show an empty argument list
        arg_str = ''
    expected_str = ''
    if expected is not None:
        expected_str = 'expected %s, ' % expected
    message = "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args))
    error(node.pos, message)
### builtin types

PyDict_Copy_func_type = PyrexTypes.CFuncType(
    Builtin.dict_type, [
        PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
        ])

def _handle_simple_function_dict(self, node, pos_args):
    """Replace dict(some_dict) by PyDict_Copy(some_dict).

    Only applies when the single positional argument is statically typed
    as a dict. The argument is wrapped in a None check that raises the
    same TypeError message CPython gives for ``dict(None)``.
    """
    if len(pos_args) != 1:
        return node
    arg = pos_args[0]
    if arg.type is not Builtin.dict_type:
        return node
    arg = ExprNodes.NoneCheckNode(
        arg, "PyExc_TypeError", "'NoneType' object is not iterable")
    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
        args = [arg],
        is_temp = node.is_temp
        )
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
    Builtin.tuple_type, [
        PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
        ])

def _handle_simple_function_tuple(self, node, pos_args):
    """Turn tuple(some_list) into a direct PyList_AsTuple() call.

    Only applies to a single argument statically typed as a list. List
    literals and comprehensions can never be None. Any other list
    expression gets a None check (stored back into ``pos_args``) that
    raises a TypeError first.
    """
    if len(pos_args) == 1 and pos_args[0].type is Builtin.list_type:
        list_arg = pos_args[0]
        may_be_none = not isinstance(
            list_arg, (ExprNodes.ComprehensionNode, ExprNodes.ListNode))
        if may_be_none:
            pos_args[0] = ExprNodes.NoneCheckNode(
                list_arg, "PyExc_TypeError",
                "'NoneType' object is not iterable")
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
            args = pos_args,
            is_temp = node.is_temp
            )
    return node
PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_double_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "((double)-1)",
    exception_check = True)

def _handle_simple_function_float(self, node, pos_args):
    """Transform float() into a C constant, a C type cast, or a call to
    the fast __Pyx_PyObject_AsDouble() conversion helper.

    - ``float()`` without arguments becomes the C constant 0.0.
    - A C double argument is returned as-is.
    - Other C numeric arguments are cast to the call's result type.
    - Anything else goes through __Pyx_PyObject_AsDouble().
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) == 0:
        # float() is valid Python and simply returns 0.0
        return ExprNodes.FloatNode(
            node.pos, value='0.0', constant_result=0.0,
            type=PyrexTypes.c_double_type)
    elif len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
        return node
    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion to see the original C value
        func_arg = func_arg.arg
    if func_arg.type is PyrexTypes.c_double_type:
        return func_arg
    elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = pyobject_as_double_utility_code,
        py_name = "float")
### builtin functions

PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
        ])

PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])

def _handle_simple_function_getattr(self, node, pos_args):
    """Map getattr(obj, name) to PyObject_GetAttr() and
    getattr(obj, name, default) to the __Pyx_GetAttr3() helper.

    Any other argument count is reported as a compile error.
    """
    arg_count = len(pos_args)
    if arg_count not in (2, 3):
        self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
        return node
    if arg_count == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
            args = pos_args,
            is_temp = node.is_temp)
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = Builtin.getattr3_utility_code)
Pyx_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type, [
        PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
        ])

def _handle_simple_function_len(self, node, pos_args):
    """Replace len(c_string) by a plain C strlen() call.

    Only applies when the single argument is a C string (char*),
    possibly wrapped in a coercion to a Python object. Other arguments
    are left to the generic len() implementation.
    """
    # note: this only works because we already replaced len() by
    # PyObject_Length() which returns a Py_ssize_t instead of a
    # Python object, so we can return a plain size_t instead
    # without caring about Python object conversion etc.
    if len(pos_args) != 1:
        self._error_wrong_arg_count('len', node, pos_args, 1)
        return node
    c_string = pos_args[0]
    if isinstance(c_string, ExprNodes.CoerceToPyTypeNode):
        c_string = c_string.arg
    if c_string.type.is_string:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "strlen", self.Pyx_strlen_func_type,
            args = [c_string],
            is_temp = node.is_temp,
            utility_code = include_string_h_utility_code
            )
    return node
Pyx_Type_func_type = PyrexTypes.CFuncType(
    Builtin.type_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
        ])

def _handle_simple_function_type(self, node, pos_args):
    """Replace type(obj) by the Py_TYPE() macro.

    The result is cast back to a generic Python object.
    """
    if len(pos_args) == 1:
        type_lookup = ExprNodes.PythonCapiCallNode(
            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
            args = pos_args,
            is_temp = False)
        return ExprNodes.CastNode(type_lookup, PyrexTypes.py_object_type)
    return node
### special methods

Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
        ])

def _handle_simple_slot__new__(self, node, args, is_unbound_method):
    """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

    Applies only when all of the following hold:

    - the call is unbound and has exactly one argument;
    - both the called type and the argument are plain names typed as
      'type';
    - both names provably refer to the same type, i.e. the same type
      entry or, if either entry is unknown, the same name.
    """
    type_obj = node.function.obj
    if not is_unbound_method or len(args) != 1:
        return node
    type_arg = args[0]
    if not (type_obj.is_name and type_arg.is_name):
        # play safe
        return node
    if type_obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
        # not a known type, play safe
        return node
    if type_arg.type_entry and type_obj.type_entry:
        if type_arg.type_entry != type_obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node
    elif type_obj.name != type_arg.name:
        return node
    # otherwise, we know it's a type and we know it's the same
    # type for both - that should do

    # FIXME: we could potentially look up the actual tp_new C
    # method of the extension type and call that instead of the
    # generic slot. That would also allow us to pass parameters
    # efficiently.

    if not type_arg.type_entry:
        # arbitrary variable, needs a None check for safety
        type_arg = ExprNodes.NoneCheckNode(
            type_arg, "PyExc_TypeError",
            "object.__new__(X): X is not a type object (NoneType)")

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
        args = [type_arg],
        utility_code = tpnew_utility_code,
        is_temp = node.is_temp
        )
### methods of builtin types

PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ])

def _handle_simple_method_object_append(self, node, args, is_unbound_method):
    """Optimistically turn X.append(item) into __Pyx_PyObject_Append().

    The helper has a fast path for exact lists and otherwise calls the
    object's 'append' attribute.
    """
    # X.append() is almost always referring to a list
    if len(args) == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = append_utility_code
            )
    return node
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        ])

# The index is declared as Py_ssize_t to match the C prototype of
# __Pyx_PyObject_PopIndex(). The argument is cast to this formal type,
# so a narrower 'long' could truncate the index where long < Py_ssize_t.
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
        ])

def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
    """Optimistically turn X.pop() and X.pop(n) into list-specialised helpers.

    - X.pop() becomes __Pyx_PyObject_Pop().
    - X.pop(n) becomes __Pyx_PyObject_PopIndex(), but only when ``n`` is a
      C integer that fits into a Py_ssize_t and can be passed unboxed.
    - Any other call is left unchanged.
    """
    # X.pop([n]) is almost always referring to a list
    if len(args) == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = pop_utility_code
            )
    elif len(args) == 2:
        index = args[1]
        if isinstance(index, ExprNodes.CoerceToPyTypeNode) and index.arg.type.is_int:
            original_type = index.arg.type
            if PyrexTypes.widest_numeric_type(
                    original_type, PyrexTypes.c_py_ssize_t_type) \
                    == PyrexTypes.c_py_ssize_t_type:
                # pass the original C integer instead of the Python object;
                # build a new list rather than mutating the caller's one
                return ExprNodes.PythonCapiCallNode(
                    node.pos, "__Pyx_PyObject_PopIndex",
                    self.PyObject_PopIndex_func_type,
                    args = [args[0], index.arg],
                    is_temp = node.is_temp,
                    utility_code = pop_index_utility_code
                    )

    return node
PyList_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ],
    exception_value = "-1")

def _handle_simple_method_list_append(self, node, args, is_unbound_method):
    """Call PyList_Append() directly for L.append(x) on a known list.

    Any other argument count is reported as a compile error.
    """
    if len(args) == 2:
        return self._substitute_method_call(
            node, "PyList_Append", self.PyList_Append_func_type,
            'append', is_unbound_method, args)
    self._error_wrong_arg_count('list.append', node, args, 2)
    return node
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "-1")

def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
    """Call PyList_Sort() directly for L.sort() without further arguments.

    Calls that pass extra positional arguments are left untouched and are
    not reported as errors here.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args)
    return node
def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
    """Call PyList_Reverse() directly for L.reverse().

    Any extra arguments are reported as a compile error.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Reverse", self.single_param_func_type,
            'reverse', is_unbound_method, args)
    self._error_wrong_arg_count('list.reverse', node, args, 1)
    return node
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        ],
    exception_value = "NULL")

# codec names mapped (via _find_special_codec_name) to the dedicated
# PyUnicode_As<Name>String() / PyUnicode_Decode<Name>() C-API functions
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                      'unicode_escape', 'raw_unicode_escape']

_special_codecs = [ (name, codecs.getencoder(name))
                    for name in _special_encodings ]

def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
    """Replace unicode.encode(...) by a direct C-API call.

    - A constant unicode string with a constant encoding and error mode
      is encoded at compile time into a bytes literal.
    - Otherwise, with 'strict' error handling and a codec that has a
      dedicated C function, PyUnicode_As<Codec>String() is called.
    - Otherwise PyUnicode_AsEncodedString() is used.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
        return node

    string_node = args[0]

    if len(args) == 1:
        # no encoding and no error mode: pass NULL for both
        null_node = ExprNodes.NullNode(node.pos)
        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method, [string_node, null_node, null_node])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    if isinstance(string_node, ExprNodes.UnicodeNode) \
            and encoding is not None and error_handling is not None:
        # constant, so try to do the encoding at compile time
        try:
            value = string_node.value.encode(encoding, error_handling)
        except Exception:
            # unknown codec/error handler or unencodable text: keep the
            # runtime call so that the user gets the real error there
            pass
        else:
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)

    if error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])

    return self._substitute_method_call(
        node, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
    """Replace decoding a C string (optionally sliced) by PyUnicode_Decode()
    or a codec-specific PyUnicode_Decode<Codec>() call.

    The receiver must be a char* that is either sliced directly or only
    coerced to a Python object; anything else is left to Python. The
    decode length is the slice length, or strlen() when the slice has no
    end. Sub-expressions needed more than once are bound to LetRefNode
    temps, and the final call is wrapped in one EvalWithTempExprNode per
    temp.
    """
    if not 1 <= len(args) <= 3:
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
        return node
    ssize_t_type = PyrexTypes.c_py_ssize_t_type
    temps = []
    receiver = args[0]
    if isinstance(receiver, ExprNodes.SliceIndexNode):
        c_string = receiver.base
        if not c_string.type.is_string:
            # nothing to optimise here
            return node
        start, end = receiver.start, receiver.stop
        if not start or start.constant_result == 0:
            start = None
        else:
            if start.type.is_pyobject:
                start = start.coerce_to(ssize_t_type, self.env_stack[-1])
            if end:
                # 'start' is used again below to compute 'end - start'
                start = UtilNodes.LetRefNode(start)
                temps.append(start)
            c_string = ExprNodes.AddNode(
                pos=start.pos,
                operand1=c_string,
                operator='+',
                operand2=start,
                is_temp=False,
                type=c_string.type)
        if end and end.type.is_pyobject:
            end = end.coerce_to(ssize_t_type, self.env_stack[-1])
    elif isinstance(receiver, ExprNodes.CoerceToPyTypeNode) \
            and receiver.arg.type.is_string:
        # use strlen() to find the string length, just as CPython would
        start = end = None
        c_string = receiver.arg
    else:
        # let Python do its job
        return node

    if end:
        if start:
            # slice length = end - start
            end = ExprNodes.SubNode(
                pos = end.pos,
                operand1 = end,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = ssize_t_type)
    else:
        if start or not c_string.is_name:
            # the pointer is used twice (strlen() and the decode call)
            c_string = UtilNodes.LetRefNode(c_string)
            temps.append(c_string)
        end = ExprNodes.PythonCapiCallNode(
            c_string.pos, "strlen", self.Pyx_strlen_func_type,
            args = [c_string],
            is_temp = False,
            utility_code = include_string_h_utility_code,
            ).coerce_to(ssize_t_type, self.env_stack[-1])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # try to find a specific decoder function
    codec_name = None
    if encoding is not None:
        codec_name = self._find_special_codec_name(encoding)
    if codec_name is None:
        decode_call = ExprNodes.PythonCapiCallNode(
            node.pos, "PyUnicode_Decode",
            self.PyUnicode_Decode_func_type,
            args = [c_string, end, encoding_node, error_handling_node],
            is_temp = node.is_temp,
            )
    else:
        decode_call = ExprNodes.PythonCapiCallNode(
            node.pos, "PyUnicode_Decode%s" % codec_name,
            self.PyUnicode_DecodeXyz_func_type,
            args = [c_string, end, error_handling_node],
            is_temp = node.is_temp,
            )

    # wrap innermost (last bound) temp first, so the first temp is outermost
    while temps:
        decode_call = UtilNodes.EvalWithTempExprNode(temps.pop(), decode_call)
    return decode_call
1525 def _find_special_codec_name(self, encoding):
1526 try:
1527 requested_codec = codecs.getencoder(encoding)
1528 except:
1529 return None
1530 for name, codec in self._special_codecs:
1531 if codec == requested_codec:
1532 if '_' in name:
1533 name = ''.join([ s.capitalize()
1534 for s in name.split('_')])
1535 return name
1536 return None
def _unpack_encoding_and_error_mode(self, pos, args):
    """Extract the encoding and error mode of an encode()/decode() call.

    ``args`` is ``[string, encoding?, errors?]``. Each optional argument
    may be a string literal or a C string. A literal yields its constant
    value plus a char* BytesNode; a C string yields None plus the node
    itself. A missing encoding yields ``(None, NULL)``. A missing or
    'strict' error mode yields ``('strict', NULL)``.

    Returns ``(encoding, encoding_node, error_handling,
    error_handling_node)``, or None if an argument has an unsupported
    type.
    """
    null_node = ExprNodes.NullNode(pos)

    if len(args) >= 2:
        encoding_node = args[1]
        if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
            encoding_node = encoding_node.arg
        if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                                      ExprNodes.BytesNode)):
            encoding = encoding_node.value
            encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
                                                type=PyrexTypes.c_char_ptr_type)
        elif encoding_node.type.is_string:
            encoding = None
        else:
            return None
    else:
        # no encoding given (e.g. 'cstr.decode()'): passing NULL lets the
        # C-API select its default encoding
        encoding = None
        encoding_node = null_node

    if len(args) == 3:
        error_handling_node = args[2]
        if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
            error_handling_node = error_handling_node.arg
        if isinstance(error_handling_node,
                      (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                       ExprNodes.BytesNode)):
            error_handling = error_handling_node.value
            if error_handling == 'strict':
                error_handling_node = null_node
            else:
                error_handling_node = ExprNodes.BytesNode(
                    error_handling_node.pos, value=error_handling,
                    type=PyrexTypes.c_char_ptr_type)
        elif error_handling_node.type.is_string:
            error_handling = None
        else:
            return None
    else:
        error_handling = 'strict'
        error_handling_node = null_node

    return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type,
                            attr_name, is_unbound_method, args=()):
    """Replace a method call by a call to the C function ``name``.

    The first argument (the method's object) is wrapped in a None check.
    It mimics the error Python would raise: a TypeError for an unbound
    descriptor call such as ``list.append(None, x)``, otherwise an
    AttributeError for ``None.attr``.
    """
    args = list(args)
    if args:
        if is_unbound_method:
            exception_type = "PyExc_TypeError"
            message = "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                attr_name, node.function.obj.name)
        else:
            exception_type = "PyExc_AttributeError"
            message = "'NoneType' object has no attribute '%s'" % attr_name
        args[0] = ExprNodes.NoneCheckNode(args[0], exception_type, message)
    return ExprNodes.PythonCapiCallNode(
        node.pos, name, func_type,
        args = args,
        is_temp = node.is_temp
        )
# X.append(x) helper: fast path for exact lists, generic method call otherwise
append_utility_code = UtilityCode(
    impl = "",
    proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
    )
# X.pop() helper: pops the last item of an exact list in place when no
# shrinking reallocation would be needed, else calls X.pop()
pop_utility_code = UtilityCode(
    impl = "",
    proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "pop");
        if (!m) return NULL;
        r = PyObject_CallObject(m, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
    )
# X.pop(ix) helper: removes item 'ix' of an exact list in place when it is
# in range and no shrinking reallocation would be needed. Otherwise it
# calls X.pop(ix) with the *original* index, so that out-of-range indices
# raise IndexError exactly like Python (a negative index is only
# normalised in a local copy).
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            /* normalise a negative index in a copy: the fallback below
               must still see the index the user passed */
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
# float(obj) -> C double conversion: inline fast path for exact floats,
# out-of-line helper for everything else (returns -1 on error)
pyobject_as_double_utility_code = UtilityCode(
    impl = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
''',
    proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
((likely(PyFloat_CheckExact(obj))) ? \\
 PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
    )
# pulls in the C header declaring strlen()
include_string_h_utility_code = UtilityCode(proto = """
#include <string.h>
""")
# type->tp_new(type, (), NULL) helper; the empty-tuple constant's C name
# is substituted from Naming.empty_tuple
tpnew_utility_code = UtilityCode(
    proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % dict(TUPLE=Naming.empty_tuple)
    )
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Compute ``node.constant_result`` once, after visiting the children.

        The node is first marked as 'not a constant'. That mark is only
        replaced if every child turned out to be constant and
        ``calculate_constant_result()`` succeeds.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            # already calculated
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # bail out unless all children are constant
        for child_result in self.visitchildren(node).itervalues():
            if type(child_result) is list:
                child_nodes = child_result
            else:
                child_nodes = [child_result]
            for child in child_nodes:
                if child.constant_result is not_a_constant:
                    return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, ordered from narrowest to widest
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the widest literal node class among ``nodes``, or None if
        any of them is not one of NODE_TYPE_ORDER (or no nodes are given).
        """
        try:
            ranks = [self.NODE_TYPE_ORDER.index(type(n)) for n in nodes]
            return self.NODE_TYPE_ORDER[max(ranks)]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Fold a binary operation on two literals into a single literal.

        Float results and operations on non-literal operands keep their
        computed ``constant_result`` but are not replaced.
        """
        self._calculate_const(node)
        result = node.constant_result
        if result is ExprNodes.not_a_constant:
            return node
        if isinstance(result, float):
            # We calculate float constants to make them available to
            # the compiler, but we do not aggregate them into a
            # constant node to prevent any loss of precision.
            return node
        left, right = node.operand1, node.operand2
        if not (left.is_literal and right.is_literal):
            # We calculate other constants to make them available to
            # the compiler, but we only aggregate constant nodes
            # recursively, so non-const nodes are straight out.
            return node

        # now inject a new constant node with the calculated value
        try:
            left_type, right_type = left.type, right.type
        except AttributeError:
            return node
        if left_type is None or right_type is None:
            return node

        if left_type is right_type:
            new_node = left
        else:
            widest_type = PyrexTypes.widest_numeric_type(left_type, right_type)
            if type(left) is type(right):
                new_node = left
                new_node.type = widest_type
            elif left_type is widest_type:
                new_node = left
            elif right_type is widest_type:
                new_node = right
            else:
                target_class = self._widest_node_class(left, right)
                if target_class is None:
                    return node
                new_node = target_class(pos=node.pos, type = widest_type)

        new_node.constant_result = result
        new_node.value = str(result)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.

        The left-hand side of a first assignment is flagged. A Python object
        variable is initialised to 0 (NULL) instead of None.
        """
        self.visitchildren(node)
        if not node.first:
            return node
        lhs = node.lhs
        lhs.lhs_of_first_assignment = True
        if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
            # Have variable initialized to 0 rather than None
            entry = lhs.entry
            entry.init_to_none = False
            entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        PyObject_TypeCheck() call.
        """
        self.visitchildren(node)
        function = node.function
        if not (function.type.is_cfunction and isinstance(function, ExprNodes.NameNode)):
            return node
        if function.name != 'isinstance':
            return node
        type_arg = node.args[1]
        if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
            return node
        from CythonScope import utility_scope
        function.entry = utility_scope.lookup('PyObject_TypeCheck')
        function.type = function.entry.type
        PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node
