Cython has moved to github.
cython-devel
view Cython/Compiler/Optimize.py @ 2765:dda8f62132fc
specialised implementation for 'float(x) -> C double' to avoid redundant calls of float()
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Tue Dec 08 01:26:53 2009 +0100 (2 years ago) |
| parents | e5026358ebd3 |
| children | c1f3e5f0c594 |
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
17 import codecs
19 try:
20 reduce
21 except NameError:
22 from functools import reduce
24 try:
25 set
26 except NameError:
27 from sets import Set as set
class FakePythonEnv(object):
    """Stand-in environment used when creating type test nodes etc.

    It only carries the ``nogil`` flag, which is always off.
    """
    nogil = False
def unwrap_node(node):
    """Return the expression behind any chain of ResultRefNode wrappers.

    A node that is not a ``UtilNodes.ResultRefNode`` is returned unchanged.
    """
    inner = node
    while True:
        if not isinstance(inner, UtilNodes.ResultRefNode):
            return inner
        inner = inner.expression
def is_common_value(a, b):
    """Tell whether two nodes denote the same name or non-Python attribute.

    Both nodes are unwrapped first.  Two name nodes match when their names
    are equal; two attribute nodes match when the first one is not a Python
    attribute, their objects match recursively and the attribute names are
    equal.  Every other combination yields False.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    name_cls = ExprNodes.NameNode
    if isinstance(left, name_cls) and isinstance(right, name_cls):
        return left.name == right.name

    attr_cls = ExprNodes.AttributeNode
    if not (isinstance(left, attr_cls) and isinstance(right, attr_cls)):
        return False
    return (not left.is_py_attr
            and is_common_value(left.obj, right.obj)
            and left.attribute == right.attribute)
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over a slice of a C array or pointer becomes a pointer loop
    """
    # C signature of PyDict_Next(dict, &pos, &key, &value)
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        """Record the module scope as the current scope and visit the module."""
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Visit a function body with its own scope as the current scope."""
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Visit the loop's children, then try to optimise the loop itself."""
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Dispatch a for-in loop to the matching transformation.

        Returns the replacement node, or ``node`` itself when none of the
        supported patterns applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array slice iteration?
        if isinstance(iterator, ExprNodes.SliceIndexNode) and \
               (iterator.base.type.is_array or iterator.base.type.is_ptr):
            return self._transform_carray_iteration(node, iterator)
        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
               function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_carray_iteration(self, node, slice_node):
        """Rewrite iteration over ``carray[start:stop]`` as a pointer loop.

        A temporary pointer runs from ``base + start`` (or ``base``) up to
        ``base + stop``; on each step the loop target is assigned the
        dereferenced pointer.  Slices without a stop index are left alone.
        """
        start = slice_node.start
        stop = slice_node.stop
        step = None
        if not stop:
            return node

        carray_ptr = slice_node.base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=carray_ptr.type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=carray_ptr,
            operator='+',
            operand2=stop,
            type=carray_ptr.type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(carray_ptr.type)
        counter_temp = counter.ref(node.target.pos)

        if slice_node.base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        else:
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=carray_ptr.type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1='<=',
            target=counter_temp,
            relation2='<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace ``for i, x in enumerate(seq)`` by a loop over ``seq``
        that maintains the counter in a separate temporary.

        Only the one-argument form with a two-element target whose first
        item is a plain name is rewritten.  The resulting loop is passed
        through ``_optimise_for_loop()`` again, so the inner iterable can
        be optimised as well.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 2:
            error(enumerate_function.pos,
                  "enumerate() takes at most 2 arguments")
            return node
        elif len(args) == 2:
            # enumerate(iterable, start) is valid since Python 2.6.  The
            # counter set up below always starts at 0, so keep the generic
            # iteration instead of rejecting valid code.
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # first copy the counter into the user's variable, then advance it
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Replace ``for i in range(...)`` by a ForFromStatNode.

        The step must be a compile time integer constant other than zero
        so that the loop direction is known.  A non-literal stop bound is
        evaluated once into a temporary.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # range() accepts one to three arguments; leave anything else
            # to the regular call instead of failing on args[0] below or
            # silently dropping surplus arguments
            return node
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # the loop node gets the absolute step, the direction is
            # expressed through the relations
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Rewrite iteration over a dict as a while loop on PyDict_Next().

        ``keys`` and ``values`` select which of the two the loop target
        receives.  If both are requested, the target is either unpacked
        into (key, value) or assigned a freshly built tuple.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # not requested: pass NULL to PyDict_Next()
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # not requested: pass NULL to PyDict_Next()
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (node to assign, coercion statement or None).
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # writes its result into the temp allocated above
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # dict_temp = dict_obj; pos_temp = 0; while PyDict_Next(...): body
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    Turn if/elif chains into C switch statements where possible.

    Every clause must be a 'var == value' test (or an 'or' of several),
    all clauses must test the same variable, and both the variable and
    the values must be ints.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def extract_conditions(self, cond):
        """Split a condition into (tested variable, [case values]).

        Returns (None, None) if the condition does not have a shape that
        can be used in a switch statement.
        """
        # peel off wrapper nodes around the actual comparison
        stripped = cond
        while True:
            if isinstance(stripped, ExprNodes.CoerceToTempNode):
                stripped = stripped.arg
                continue
            if isinstance(stripped, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                stripped = stripped.subexpression
                continue
            if isinstance(stripped, ExprNodes.TypecastNode):
                stripped = stripped.operand
                continue
            break

        def is_case_value(value):
            # literal, or a name whose entry is a constant
            if value.is_literal:
                return True
            return bool(hasattr(value, 'entry') and value.entry
                        and value.entry.is_const)

        is_c_equality = (isinstance(stripped, ExprNodes.PrimaryCmpNode)
                         and stripped.cascade is None
                         and stripped.operator == '=='
                         and not stripped.is_python_comparison())
        if is_c_equality:
            candidates = ((stripped.operand1, stripped.operand2),
                          (stripped.operand2, stripped.operand1))
            for variable, value in candidates:
                # comparing a node with itself tells whether it is a
                # plain name or attribute reference
                if is_common_value(variable, variable) and is_case_value(value):
                    return variable, [value]
        elif (isinstance(stripped, ExprNodes.BoolBinopNode)
                and stripped.operator == 'or'):
            left_var, left_values = self.extract_conditions(stripped.operand1)
            right_var, right_values = self.extract_conditions(stripped.operand2)
            if is_common_value(left_var, right_var):
                return left_var, left_values + right_values
        return None, None

    def visit_IfStatNode(self, node):
        """Replace the if statement by a SwitchStatNode if all clauses fit.

        The statement is returned unchanged when a clause cannot be
        converted or when fewer than two case values were collected.
        """
        self.visitchildren(node)
        switch_var = None
        total_values = 0
        switch_cases = []
        for clause in node.if_clauses:
            var, conditions = self.extract_conditions(clause.condition)
            if var is None:
                return node
            if switch_var is not None and not is_common_value(var, switch_var):
                return node
            if not var.type.is_int:
                return node
            if [value for value in conditions if not value.type.is_int]:
                return node
            switch_var = var
            total_values += len(conditions)
            switch_cases.append(
                Nodes.SwitchCaseNode(pos = clause.pos,
                                     conditions = conditions,
                                     body = clause.body))
        if total_values < 2:
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(switch_var),
                                    cases = switch_cases,
                                    else_clause = node.else_clause)
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    Flatten "x in [val1, ..., valn]" (and the 'not in' counterpart) on a
    literal list or tuple into a chain of single comparisons.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_PrimaryCmpNode(self, node):
        """Expand a non-cascaded 'in'/'not_in' test against a literal
        tuple or list; any other comparison is returned unchanged."""
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        # operator -> (how the comparisons are joined, single comparison)
        rewrite_rules = {
            'in': ('or', '=='),
            'not_in': ('and', '!='),
        }
        rule = rewrite_rules.get(node.operator)
        if rule is None:
            return node
        conjunction, eq_or_neq = rule

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        candidates = node.operand2.args
        if len(candidates) == 0:
            # nothing to compare against: the result is a constant
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # the left side is referenced by every comparison below
        lhs = UtilNodes.ResultRefNode(node.operand1)

        tests = [
            ExprNodes.TypecastNode(
                pos = node.pos,
                operand = ExprNodes.PrimaryCmpNode(
                    pos = node.pos,
                    operand1 = lhs,
                    operator = eq_or_neq,
                    operand2 = candidate,
                    cascade = None),
                type = PyrexTypes.c_bint_type)
            for candidate in candidates ]

        # left fold of the tests into nested boolean operations
        condition = tests[0]
        for test in tests[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = condition,
                operand2 = test)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    Currently handles parallel assignments whose left and right hand
    sides are a permutation of the same names.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """Switch off managed references for the operands of a parallel
        assignment that merely permutes a set of names.

        The node itself is always returned; only the ``use_managed_ref``
        flag of its operands may be changed.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                # cascaded assignments (FIXME) and anything else
                return node
            if not self._extract_operand(stat.lhs, left_names,
                                         left_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, right_names,
                                         right_indices, temps):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            left_paths = set([ path for path, _ in left_names ])
            right_paths = set([ path for path, _ in right_names ])
            if left_paths != right_paths:
                return node
            if len(left_paths) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            collected = ([], [])
            for index_ids, index_nodes in zip(collected,
                                              (left_indices, right_indices)):
                for index_node in index_nodes:
                    index_id = self._extract_index_id(index_node)
                    if not index_id:
                        return node
                    index_ids.append(index_id)
            left_ids, right_ids = collected

            if set(left_ids) != set(right_ids):
                return node
            if len(set(left_ids)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = []
        for temp in temps:
            temp_args.append(temp.arg)
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand.

        Appends a (dotted path, node) pair to ``names`` for names and
        non-Python attribute chains, the node to ``indices`` for integer
        indexing of a list held in a name, and CoerceToTempNode wrappers
        to ``temps``.  Returns False for operands that are not supported.
        """
        operand = unwrap_node(node)
        if not operand.type.is_pyobject:
            return False
        if isinstance(operand, ExprNodes.CoerceToTempNode):
            temps.append(operand)
            operand = operand.arg

        # walk down the attribute chain, collecting the path innermost-last
        path = []
        owner = operand
        while isinstance(owner, ExprNodes.AttributeNode):
            if owner.is_py_attr:
                return False
            path.insert(0, owner.member)
            owner = owner.obj

        if isinstance(owner, ExprNodes.NameNode):
            path.insert(0, owner.name)
            names.append( ('.'.join(path), operand) )
            return True

        if not isinstance(operand, ExprNodes.IndexNode):
            return False
        if operand.base.type != Builtin.list_type:
            return False
        if not operand.index.type.is_int:
            return False
        if not isinstance(operand.base, ExprNodes.NameNode):
            return False
        indices.append(operand)
        return True

    def _extract_index_id(self, index_node):
        """Return (base name, index name) for an index node whose index
        is a plain name, None otherwise."""
        base = index_node.base
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices are not handled yet
            return None
        return (base.name, index.name)
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    No argument types are available at this point, but the tree can be
    restructured in a way that the type analysis phase can respond to.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        """Hand calls of builtin names without keywords to their handler."""
        self.visitchildren(node)
        callee = node.function
        if self._function_is_builtin_name(callee):
            return self._dispatch_to_handler(node, callee, node.args)
        return node

    def visit_GeneralCallNode(self, node):
        """Hand calls of builtin names with keywords to their handler.

        Calls whose positional arguments are not a literal tuple are
        left alone.
        """
        self.visitchildren(node)
        callee = node.function
        if not self._function_is_builtin_name(callee):
            return node
        positional = node.positional_args
        if isinstance(positional, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, callee, positional.args, node.keyword_args)
        return node

    def _function_is_builtin_name(self, function):
        """True if ``function`` is a name that the innermost environment
        resolves to an entry of the builtin scope."""
        if function.is_name:
            found = self.env_stack[-1].lookup(function.name)
            if found and getattr(found, 'scope', None) is Builtin.builtin_scope:
                return True
        return False

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        """Call the handler method named after the function, if any.

        ``kwargs is None`` selects the '_handle_simple_function_*'
        handlers, anything else the '_handle_general_function_*' ones.
        Without a matching handler the node is returned unchanged.
        """
        is_simple_call = kwargs is None
        if is_simple_call:
            name_pattern = '_handle_simple_function_%s'
        else:
            name_pattern = '_handle_general_function_%s'
        handler = getattr(self, name_pattern % function.name, None)
        if handler is None:
            return node
        if is_simple_call:
            return handler(node, args)
        return handler(node, args, kwargs)

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the called function of ``node`` by a C-API function."""
        called = node.function
        node.function = ExprNodes.PythonCapiFunctionNode(
            called.pos, called.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with the wrong number of arguments at node.pos.

        ``expected`` may be a number or a descriptive string.
        """
        arg_str = ''
        if expected: # neither None nor 0
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'
        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)

    # specific handlers for simple call nodes

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...} and
        set([ x for ... ]) by a literal { x for ... }.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type)
        if arg_count > 1:
            return node

        iterable = pos_args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args)

        is_comprehension = (
            isinstance(iterable, ExprNodes.ComprehensionNode)
            and isinstance(iterable.target, (ExprNodes.ListNode,
                                             ExprNodes.SetNode)))
        if not is_comprehension:
            return node
        # reuse the comprehension, but let it fill a set
        iterable.target = ExprNodes.SetNode(node.pos, args=[])
        iterable.pos = node.pos
        return iterable

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
        """
        if len(pos_args) != 1:
            return node
        comprehension = pos_args[0]
        if not (isinstance(comprehension, ExprNodes.ComprehensionNode)
                and isinstance(comprehension.target, (ExprNodes.ListNode,
                                                      ExprNodes.SetNode))):
            return node

        old_append = comprehension.append
        if not (isinstance(old_append.expr, (ExprNodes.TupleNode, ExprNodes.ListNode))
                and len(old_append.expr.args) == 2):
            return node

        key_node, value_node = old_append.expr.args
        dict_node = ExprNodes.DictNode(
            pos=comprehension.target.pos, key_value_pairs=[])
        new_append = ExprNodes.DictComprehensionAppendNode(
            old_append.pos, target=dict_node,
            key_expr=key_node, value_expr=value_node)
        comprehension.target = dict_node
        comprehension.type = dict_node.type
        replace_in = Visitor.RecursiveNodeReplacer(old_append, new_append)
        return replace_in(comprehension)

    PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            ])

    PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_function_getattr(self, node, pos_args):
        """Map getattr() with two or three arguments onto a C function;
        any other argument count is reported as an error."""
        arg_count = len(pos_args)
        if arg_count == 2:
            self._inject_capi_function(
                node, "PyObject_GetAttr",
                self.PyObject_GetAttr2_func_type)
        elif arg_count == 3:
            self._inject_capi_function(
                node, "__Pyx_GetAttr3",
                self.PyObject_GetAttr3_func_type,
                Builtin.getattr3_utility_code)
        else:
            self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
        return node

    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_type(self, node, pos_args):
        """Map the one-argument form of type() onto __Pyx_Type()."""
        if len(pos_args) == 1:
            # NOTE(review): pytype_utility_code is not defined in the part
            # of the module shown here - presumably a module level name
            # further down; confirm
            self._inject_capi_function(
                node, "__Pyx_Type",
                self.Pyx_Type_func_type,
                pytype_utility_code)
        return node

    def _handle_simple_function_float(self, node, pos_args):
        """Replace float() by the literal 0.0 and reject calls with more
        than one argument; float(x) is left unchanged."""
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if arg_count > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0 or not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs
910 class OptimizeBuiltinCalls(Visitor.EnvTransform):
911 """Optimize some common methods calls and instantiation patterns
912 for builtin types *after* the type analysis phase.
914 Running after type analysis, this transform can only perform
915 function replacements that do not alter the function return type
916 in a way that was not anticipated by the type analysis.
917 """
918 # only intercept on call nodes
919 visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_GeneralCallNode(self, node):
    """Dispatch a call with keyword arguments to its optimisation handler.

    Only calls of Python objects whose positional arguments form a
    literal tuple are considered; all others are returned unchanged.
    """
    self.visitchildren(node)
    callee = node.function
    if not callee.type.is_pyobject:
        return node
    positional = node.positional_args
    if isinstance(positional, ExprNodes.TupleNode):
        return self._dispatch_to_handler(
            node, callee, positional.args, node.keyword_args)
    return node
def visit_SimpleCallNode(self, node):
    """Dispatch a call without keyword arguments to its handler.

    Calls of Python objects take their arguments from the literal
    argument tuple (and are skipped if there is none); all other calls
    use ``node.args`` directly.
    """
    self.visitchildren(node)
    callee = node.function
    if not callee.type.is_pyobject:
        call_args = node.args
    else:
        packed_args = node.arg_tuple
        if not isinstance(packed_args, ExprNodes.TupleNode):
            return node
        call_args = packed_args.args
    return self._dispatch_to_handler(node, callee, call_args)
946 ### cleanup to avoid redundant coercions to/from Python types
def _visit_PyTypeTestNode(self, node):
    # disabled - appears to break assignments in some cases, and
    # also drops a None check, which might still be required
    """Flatten redundant type checks after tree changes.

    The test node is replaced by its argument if visiting the children
    replaced the argument by a node that already has the tested type.
    """
    previous_arg = node.arg
    self.visitchildren(node)
    arg_was_replaced = previous_arg is not node.arg
    if arg_was_replaced and node.arg.type == node.type:
        return node.arg
    return node
def visit_CoerceFromPyTypeNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    Calls to the builtins int() and float() are also removed when their
    single argument is not a Python object (or is one wrapped into a
    Python object just for the call) and the result is coerced straight
    back into a C int or float type.
    """
    self.visitchildren(node)
    call = node.arg
    if not call.type.is_pyobject:
        # no Python conversion left at all, just do a C coercion instead
        if node.type == call.type:
            return call
        return call.coerce_to(node.type, self.env_stack[-1])
    if not isinstance(call, ExprNodes.SimpleCallNode):
        return node
    if not (node.type.is_int or node.type.is_float):
        return node

    function = call.function
    if not (isinstance(function, ExprNodes.NameNode)
            and function.type.is_builtin_type
            and isinstance(call.arg_tuple, ExprNodes.TupleNode)):
        return node
    call_args = call.arg_tuple.args
    if len(call_args) != 1:
        return node

    value = call_args[0]
    if isinstance(value, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion that was added for the call
        value = value.arg
    elif value.type.is_pyobject:
        # play safe: Python conversion might work on all sorts of things
        return node

    if function.name == 'int':
        kind_matches = value.type.is_int or node.type.is_int
    elif function.name == 'float':
        kind_matches = value.type.is_float or node.type.is_float
    else:
        kind_matches = False

    if kind_matches:
        if value.type == node.type:
            return value
        if node.type.assignable_from(value.type) or value.type.is_float:
            # NOTE(review): every other cast in this file is built with
            # ExprNodes.TypecastNode - confirm that CastNode exists
            return ExprNodes.CastNode(value, node.type)
    return node
1006 ### dispatch to specific optimisers
1008 def _find_handler(self, match_name, has_kwargs):
1009 call_type = has_kwargs and 'general' or 'simple'
1010 handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1011 if handler is None:
1012 handler = getattr(self, '_handle_any_%s' % match_name, None)
1013 return handler
def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
    """Find and call the optimisation handler for a call node.

    Builtin functions are looked up as 'function_<name>'.  Methods of
    Python objects are looked up as 'method_<type>_<attr>' and, for slot
    methods and '__new__', as 'slot<attr>'.  Method handlers receive the
    object as first argument unless the method was called unbound on the
    type.  The node is returned unchanged if no handler applies.
    """
    if function.is_name:
        # we only consider functions that are either builtin
        # Python functions or builtins that were already replaced
        # into a C function call (defined in the builtin scope)
        entry = function.entry
        if not entry:
            return node
        if not (entry.is_builtin
                or getattr(entry, 'scope', None) is Builtin.builtin_scope):
            return node
        handler = self._find_handler(
            "function_%s" % function.name, kwargs)
        if handler is None:
            return node
        if kwargs:
            return handler(node, arg_list, kwargs)
        return handler(node, arg_list)

    if not (function.is_attribute and function.type.is_pyobject):
        return node

    attr_name = function.attribute
    self_arg = function.obj
    obj_type = self_arg.type
    is_unbound_method = False
    if not obj_type.is_builtin_type:
        type_name = "object" # safety measure
    elif obj_type is Builtin.type_type and arg_list and \
             arg_list[0].type.is_pyobject:
        # calling an unbound method like 'list.append(L,x)'
        # (ignoring 'type.mro()' here ...)
        type_name = function.obj.name
        self_arg = None
        is_unbound_method = True
    else:
        type_name = obj_type.name

    handler = self._find_handler(
        "method_%s_%s" % (type_name, attr_name), kwargs)
    if handler is None:
        if attr_name in TypeSlots.method_name_to_slot \
               or attr_name == '__new__':
            handler = self._find_handler(
                "slot%s" % attr_name, kwargs)
        if handler is None:
            return node

    if self_arg is not None:
        arg_list = [self_arg] + list(arg_list)
    if kwargs:
        return handler(node, arg_list, kwargs, is_unbound_method)
    return handler(node, arg_list, is_unbound_method)
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call with the wrong number of arguments as an error.

    function_name -- the name to show in the message
    node          -- the call node, used for the error position
    args          -- the arguments that were found
    expected      -- the expected argument count: a number, a string
                     like '1-3' for a range, or None if unknown
    """
    # sketch of the expected signature: 'f()', 'f(x)' or 'f(...)'
    arg_str = ''
    if expected: # neither None nor 0
        if isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'

    expected_str = ''
    if expected is not None:
        expected_str = 'expected %s, ' % expected

    message = "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args))
    error(node.pos, message)
1085 ### builtin types
# C signature used for the call: dict PyDict_Copy(dict)
PyDict_Copy_func_type = PyrexTypes.CFuncType(
    Builtin.dict_type, [
        PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
        ])

def _handle_simple_function_dict(self, node, pos_args):
    """Replace dict(some_dict) by PyDict_Copy(some_dict).

    Only a single argument that is typed as a dict is handled, all
    other calls are returned unchanged.  The argument gets a None
    check that raises the TypeError that Python raises for dict(None).
    """
    if len(pos_args) != 1:
        return node
    arg = pos_args[0]
    if arg.type is not Builtin.dict_type:
        return node
    # same message as CPython and as the tuple() optimisation
    arg = ExprNodes.NoneCheckNode(
        arg, "PyExc_TypeError", "'NoneType' object is not iterable")
    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
        args = [arg],
        is_temp = node.is_temp
        )
# C signature used for the call: tuple PyList_AsTuple(list)
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
    Builtin.tuple_type, [
        PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
        ])

def _handle_simple_function_tuple(self, node, pos_args):
    """Replace tuple([...]) by a call to PyList_AsTuple.

    Only a single argument that is typed as a list is handled, all
    other calls are returned unchanged.  Unless the argument is a list
    literal or a comprehension, it is wrapped into a None check (in
    place, inside of pos_args) that raises a TypeError.
    """
    if len(pos_args) != 1:
        return node
    list_arg = pos_args[0]
    if list_arg.type is not Builtin.list_type:
        return node

    is_list_expression = isinstance(
        list_arg, (ExprNodes.ComprehensionNode, ExprNodes.ListNode))
    if not is_list_expression:
        pos_args[0] = ExprNodes.NoneCheckNode(
            list_arg, "PyExc_TypeError",
            "'NoneType' object is not iterable")

    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
        args = pos_args,
        is_temp = node.is_temp
        )
# C signature of the conversion helper: double f(object), where a
# result of -1 requires an additional exception check
PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_double_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "((double)-1)",
    exception_check = True)

def _handle_simple_function_float(self, node, pos_args):
    """Replace float(x) by a C cast or by a direct conversion call.

    A C double argument is returned as is, other numeric or assignable
    arguments are cast to the type of the call node.  Everything else
    is converted by a call to __Pyx_PyObject_AsDouble().
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    value = pos_args[0]
    if isinstance(value, ExprNodes.CoerceToPyTypeNode):
        # look at the C value that is being converted to Python
        value = value.arg
    value_type = value.type
    if value_type is PyrexTypes.c_double_type:
        return value
    if node.type.assignable_from(value_type) or value_type.is_numeric:
        return ExprNodes.CastNode(value, node.type)

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = pyobject_as_double_utility_code,
        py_name = "float")
1161 ### builtin functions
# C signature used for the call: size_t strlen(char*)
Pyx_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type, [
        PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
        ])

def _handle_simple_function_len(self, node, pos_args):
    """Replace len(c_string) by a call to strlen().

    Arguments that are not C strings leave the call unchanged.
    """
    # note: this only works because we already replaced len() by
    # PyObject_Length() which returns a Py_ssize_t instead of a
    # Python object, so we can return a plain size_t instead
    # without caring about Python object conversion etc.
    if len(pos_args) != 1:
        self._error_wrong_arg_count('len', node, pos_args, 1)
        return node

    c_arg = pos_args[0]
    if isinstance(c_arg, ExprNodes.CoerceToPyTypeNode):
        # look at the C value that is being converted to Python
        c_arg = c_arg.arg
    if c_arg.type.is_string:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "strlen", self.Pyx_strlen_func_type,
            args = [c_arg],
            is_temp = node.is_temp,
            utility_code = include_string_h_utility_code
            )
    return node
1189 ### special methods
# C signature of the helper: object __Pyx_tp_new(type)
Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
        ])

def _handle_simple_slot__new__(self, node, args, is_unbound_method):
    """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

    This is only done for unbound calls with a single argument where
    both the called type and the argument are plain names that are
    known to refer to the same type.
    """
    called_type = node.function.obj
    if not is_unbound_method or len(args) != 1:
        return node
    type_arg = args[0]
    if not (called_type.is_name and type_arg.is_name):
        # play safe
        return node
    if called_type.type != Builtin.type_type \
           or type_arg.type != Builtin.type_type:
        # not a known type, play safe
        return node

    if type_arg.type_entry and called_type.type_entry:
        if type_arg.type_entry != called_type.type_entry:
            # different types - may or may not lead to an error at runtime
            return node
    elif called_type.name != type_arg.name:
        return node
    # otherwise, we know it's a type and we know it's the same
    # type for both - that should do

    # FIXME: we could potentially look up the actual tp_new C
    # method of the extension type and call that instead of the
    # generic slot. That would also allow us to pass parameters
    # efficiently.

    if not type_arg.type_entry:
        # arbitrary variable, needs a None check for safety
        type_arg = ExprNodes.NoneCheckNode(
            type_arg, "PyExc_TypeError",
            "object.__new__(X): X is not a type object (NoneType)")

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
        args = [type_arg],
        utility_code = tpnew_utility_code,
        is_temp = node.is_temp
        )
1236 ### methods of builtin types
# C signature of the helper: object __Pyx_PyObject_Append(object, object)
PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ])

def _handle_simple_method_object_append(self, node, args, is_unbound_method):
    """Replace X.append(item) on an untyped object by a call to
    __Pyx_PyObject_Append(), which has a fast path for lists.
    """
    # X.append() is almost always referring to a list
    if len(args) == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = append_utility_code
            )
    return node
# C signature of the helper: object __Pyx_PyObject_Pop(object)
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        ])

# C signature of the helper:
#     object __Pyx_PyObject_PopIndex(object, Py_ssize_t)
# The index must be declared as Py_ssize_t, as in the C function.
# Declaring it as a C long coerces the argument to a type that can be
# narrower than Py_ssize_t.
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
        ])

def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
    """Replace X.pop() and X.pop(n) on an untyped object by calls to
    helper functions that have a fast path for lists.

    The indexed variant is only used if the index is a C integer that
    was converted to a Python object and whose type fits into a
    Py_ssize_t.  In that case, the C value is passed on directly
    (replacing the second item of args in place).
    """
    # X.pop([n]) is almost always referring to a list
    if len(args) == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = pop_utility_code
            )
    elif len(args) == 2:
        index = args[1]
        if isinstance(index, ExprNodes.CoerceToPyTypeNode) and index.arg.type.is_int:
            original_type = index.arg.type
            # only unpack indices that cannot overflow a Py_ssize_t
            if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
                args[1] = index.arg
                return ExprNodes.PythonCapiCallNode(
                    node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
                    args = args,
                    is_temp = node.is_temp,
                    utility_code = pop_index_utility_code
                    )

    return node
# C signature used for the call: int PyList_Append(object, object),
# where -1 signals an exception
PyList_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ],
    exception_value = "-1")

def _handle_simple_method_list_append(self, node, args, is_unbound_method):
    """Replace L.append(item) on a typed list by PyList_Append().

    A wrong number of arguments is reported as an error.
    """
    if len(args) == 2:
        return self._substitute_method_call(
            node, "PyList_Append", self.PyList_Append_func_type,
            'append', is_unbound_method, args)
    self._error_wrong_arg_count('list.append', node, args, 2)
    return node
# C signature shared by the list functions that only take the list
# itself: int f(object), where -1 signals an exception
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "-1")

def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
    """Replace L.sort() on a typed list by PyList_Sort().

    Calls that pass arguments to sort() are left unchanged.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args)
    return node
def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
    """Replace L.reverse() on a typed list by PyList_Reverse().

    A wrong number of arguments is reported as an error.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Reverse", self.single_param_func_type,
            'reverse', is_unbound_method, args)
    self._error_wrong_arg_count('list.reverse', node, args, 1)
    return node
# C signature used for the generic encoding call:
#     bytes PyUnicode_AsEncodedString(unicode, char* encoding, char* errors)
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

# C signature shared by the codec specific PyUnicode_As<Codec>String()
# functions, which only take the unicode string
PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        ],
    exception_value = "NULL")

# names of the codecs that are looked up for a specialised C function
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                      'unicode_escape', 'raw_unicode_escape']

# (name, encoder function) pairs; the encoder functions are compared
# against the encoder of a requested encoding to recognise the codec
# regardless of how its name is spelled
_special_codecs = [ (name, codecs.getencoder(name))
                    for name in _special_encodings ]
def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
    """Replace unicode.encode(...) by a direct C-API call.

    args holds the string followed by the optional encoding and error
    handling arguments.  String constants are encoded at compile time
    if possible.  With 'strict' error handling, a codec specific
    PyUnicode_As<Codec>String() function is used for the codecs in
    _special_codecs.  All other cases call PyUnicode_AsEncodedString().

    The call is left unchanged if encoding or error mode are neither
    string constants nor C strings.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
        return node

    string_node = args[0]

    if len(args) == 1:
        # no encoding and no error mode: pass NULL for both
        null_node = ExprNodes.NullNode(node.pos)
        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method, [string_node, null_node, null_node])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    if isinstance(string_node, ExprNodes.UnicodeNode):
        # constant, so try to do the encoding at compile time
        try:
            value = string_node.value.encode(encoding, error_handling)
        except Exception:
            # well, looks like we can't
            # (catches encoding errors as well as encoding names or
            # error modes that are only known at runtime, but does not
            # swallow KeyboardInterrupt or SystemExit)
            pass
        else:
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)

    if error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])

    return self._substitute_method_call(
        node, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
# C signature shared by the codec specific PyUnicode_Decode<Codec>()
# functions: unicode f(char* string, Py_ssize_t size, char* errors)
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

# C signature used for the generic decoding call:
#     unicode PyUnicode_Decode(char* string, Py_ssize_t size,
#                              char* encoding, char* errors)
PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
    """Replace c_string.decode(...) and c_string[a:b].decode(...) by a
    direct call to a PyUnicode_Decode*() C-API function.

    args holds the decoded object followed by the optional encoding
    and error handling arguments.  Only C strings and slices of C
    strings are handled, all other calls are returned unchanged.

    The C functions take a start pointer and a length, so a slice
    start is added to the string pointer and a slice stop is turned
    into a length.  Without a stop index, the length is determined by
    calling strlen().
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
        return node
    # nodes that are used more than once and must therefore be
    # evaluated into temps around the final call
    temps = []
    if isinstance(args[0], ExprNodes.SliceIndexNode):
        index_node = args[0]
        string_node = index_node.base
        if not string_node.type.is_string:
            # nothing to optimise here
            return node
        start, stop = index_node.start, index_node.stop
        if not start or start.constant_result == 0:
            start = None
        else:
            if start.type.is_pyobject:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
            if stop:
                # start is used twice: in the pointer offset below and
                # in the length calculation (stop - start)
                start = UtilNodes.LetRefNode(start)
                temps.append(start)
            # decoding starts at the pointer 'string + start'
            string_node = ExprNodes.AddNode(pos=start.pos,
                                            operand1=string_node,
                                            operator='+',
                                            operand2=start,
                                            is_temp=False,
                                            type=string_node.type
                                            )
        if stop and stop.type.is_pyobject:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
    elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
             and args[0].arg.type.is_string:
        # use strlen() to find the string length, just as CPython would
        start = stop = None
        string_node = args[0].arg
    else:
        # let Python do its job
        return node

    # from here on, 'stop' is turned into the length argument
    if not stop:
        if start or not string_node.is_name:
            # the string expression is used twice: as argument of
            # strlen() and as argument of the decoding function
            string_node = UtilNodes.LetRefNode(string_node)
            temps.append(string_node)
        stop = ExprNodes.PythonCapiCallNode(
            string_node.pos, "strlen", self.Pyx_strlen_func_type,
            args = [string_node],
            is_temp = False,
            utility_code = include_string_h_utility_code,
            ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
    elif start:
        # the string pointer was moved by 'start' already
        stop = ExprNodes.SubNode(
            pos = stop.pos,
            operand1 = stop,
            operator = '-',
            operand2 = start,
            is_temp = False,
            type = PyrexTypes.c_py_ssize_t_type
            )

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # try to find a specific encoder function
    codec_name = None
    if encoding is not None:
        codec_name = self._find_special_codec_name(encoding)
    if codec_name is not None:
        decode_function = "PyUnicode_Decode%s" % codec_name
        node = ExprNodes.PythonCapiCallNode(
            node.pos, decode_function,
            self.PyUnicode_DecodeXyz_func_type,
            args = [string_node, stop, error_handling_node],
            is_temp = node.is_temp,
            )
    else:
        node = ExprNodes.PythonCapiCallNode(
            node.pos, "PyUnicode_Decode",
            self.PyUnicode_Decode_func_type,
            args = [string_node, stop, encoding_node, error_handling_node],
            is_temp = node.is_temp,
            )

    # wrap the call into the temp evaluations, in reverse order of
    # their creation, so that the first temp ends up outermost
    for temp in temps[::-1]:
        node = UtilNodes.EvalWithTempExprNode(temp, node)
    return node
1498 def _find_special_codec_name(self, encoding):
1499 try:
1500 requested_codec = codecs.getencoder(encoding)
1501 except:
1502 return None
1503 for name, codec in self._special_codecs:
1504 if codec == requested_codec:
1505 if '_' in name:
1506 name = ''.join([ s.capitalize()
1507 for s in name.split('_')])
1508 return name
1509 return None
def _unpack_encoding_and_error_mode(self, pos, args):
    """Extract encoding and error mode from the arguments of a call
    to encode() or decode().

    pos  -- the source position for newly created nodes
    args -- the call arguments: the string, optionally followed by the
            encoding and the error handling mode

    Returns None if an argument is neither a string constant nor a C
    string.  Otherwise, returns the tuple

        (encoding, encoding_node, error_handling, error_handling_node)

    where 'encoding' and 'error_handling' are the constant values, or
    None if they are not known at compile time, and the nodes are the
    C string (or NULL) arguments for the C-API call.  A missing
    encoding is passed as NULL.  A missing or 'strict' error mode is
    reported as 'strict' and passed as NULL.
    """
    null_node = ExprNodes.NullNode(pos)
    string_node_types = (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                         ExprNodes.BytesNode)

    if len(args) >= 2:
        encoding_node = args[1]
        if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
            encoding_node = encoding_node.arg
        if isinstance(encoding_node, string_node_types):
            encoding = encoding_node.value
            encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
                                                type=PyrexTypes.c_char_ptr_type)
        elif encoding_node.type.is_string:
            encoding = None
        else:
            return None
    else:
        # No encoding argument, as in 's.decode()'.  Reading args[1]
        # here would raise an IndexError.
        encoding = None
        encoding_node = null_node

    if len(args) == 3:
        error_handling_node = args[2]
        if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
            error_handling_node = error_handling_node.arg
        if isinstance(error_handling_node, string_node_types):
            error_handling = error_handling_node.value
            if error_handling == 'strict':
                error_handling_node = null_node
            else:
                error_handling_node = ExprNodes.BytesNode(
                    error_handling_node.pos, value=error_handling,
                    type=PyrexTypes.c_char_ptr_type)
        elif error_handling_node.type.is_string:
            error_handling = None
        else:
            return None
    else:
        error_handling = 'strict'
        error_handling_node = null_node

    return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type,
                            attr_name, is_unbound_method, args=()):
    """Build the C function call that replaces a method call.

    node              -- the method call node that gets replaced
    name              -- the name of the C function to call
    func_type         -- the C signature of that function
    attr_name         -- the name of the replaced method
    is_unbound_method -- true for calls like 'list.append(L, x)'
    args              -- the call arguments, starting with the object

    The first argument is wrapped into a None check that raises the
    exception of the corresponding Python call: a TypeError for
    unbound method calls and an AttributeError otherwise.
    """
    call_args = list(args)
    if call_args:
        if is_unbound_method:
            exception = "PyExc_TypeError"
            message = "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                attr_name, node.function.obj.name)
        else:
            exception = "PyExc_AttributeError"
            message = "'NoneType' object has no attribute '%s'" % attr_name
        call_args[0] = ExprNodes.NoneCheckNode(
            call_args[0], exception, message)
    return ExprNodes.PythonCapiCallNode(
        node.pos, name, func_type,
        args = call_args,
        is_temp = node.is_temp
        )
# C helper behind the optimised X.append(item): calls PyList_Append()
# for exact lists and returns a new reference to None, like the
# method does.  For all other objects, it looks up and calls their
# "append" attribute.
append_utility_code = UtilityCode(
proto = """
static INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
# C helper behind the optimised X.pop(): for exact lists that are more
# than half full, it decrements the list size and returns the item
# that was stored last.  In all other cases, it looks up and calls
# the "pop" attribute of the object.
pop_utility_code = UtilityCode(
proto = """
static INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "pop");
        if (!m) return NULL;
        r = PyObject_CallObject(m, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
# C helper behind the optimised X.pop(n): for exact lists that are
# more than half full, it removes the item in place and moves the
# following items down by one.  In all other cases, including indices
# that are out of range, it looks up and calls the "pop" attribute of
# the object with the index as Python integer.
#
# The fast path works on a copy of the index ('cix').  Adjusting a
# negative 'ix' in place would pass the adjusted index on to the
# fallback call, which adjusts negative indices once more.  An index
# like -5 for a list of 3 items would then pop an item instead of
# raising an IndexError.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            /* keep 'ix' unchanged for the generic call below */
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
# C helper behind the optimised float(x): the macro reads the value of
# exact float objects directly and calls the function for everything
# else.  The function uses PyFloat_AsDouble() for objects that have an
# nb_float slot, PyFloat_FromString() for exact string objects, and
# calls the float type for all other objects.  It returns (double)-1
# on errors.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyString_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
# C helper that returns a new reference to the type of an object
pytype_utility_code = UtilityCode(
proto = """
static INLINE PyObject* __Pyx_Type(PyObject* o) {
    PyObject* type = (PyObject*) Py_TYPE(o);
    Py_INCREF(type);
    return type;
}
"""
)
# makes the declarations of <string.h> available; requested by the
# optimisations in this file that generate calls to strlen()
include_string_h_utility_code = UtilityCode(
proto = """
#include <string.h>
"""
)
# C helper behind the optimised 'exttype.__new__(exttype)': calls the
# tp_new slot of the type with the empty tuple (Naming.empty_tuple) as
# positional arguments and without keyword arguments
tpnew_utility_code = UtilityCode(
proto = """
static INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result``, unless it is already set.

        The children of the node are visited first.  If any of them is
        not constant, the node is marked as not constant either.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.itervalues():
            if type(child_result) is list:
                child_nodes = child_result
            else:
                child_nodes = [child_result]
            for child in child_nodes:
                if child.constant_result is not_a_constant:
                    return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
#            if node.constant_result is not ExprNodes.not_a_constant:
#                print node.__class__.__name__, node.constant_result
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback
            import sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, ordered from narrowest to widest
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the widest class in NODE_TYPE_ORDER that is used by
        any of the nodes, or None if a node is of a different class.
        """
        order = self.NODE_TYPE_ORDER
        try:
            ranks = [ order.index(type(node)) for node in nodes ]
            return order[max(ranks)]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Replace a binary operation on two literals by a literal node
        that holds the calculated result.
        """
        self._calculate_const(node)
        result = node.constant_result
        if result is ExprNodes.not_a_constant:
            return node
        if isinstance(result, float):
            # We calculate float constants to make them available to
            # the compiler, but we do not aggregate them into a
            # constant node to prevent any loss of precision.
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not (operand1.is_literal and operand2.is_literal):
            # We calculate other constants to make them available to
            # the compiler, but we only aggregate constant nodes
            # recursively, so non-const nodes are straight out.
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1 is type2:
            new_node = operand1
        else:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
            if type(operand1) is type(operand2):
                # same node class: reuse the first operand with the
                # wider type
                new_node = operand1
                new_node.type = widest_type
            elif type1 is widest_type:
                new_node = operand1
            elif type2 is widest_type:
                new_node = operand2
            else:
                target_class = self._widest_node_class(operand1, operand2)
                if target_class is None:
                    return node
                new_node = target_class(pos=node.pos, type = widest_type)

        new_node.constant_result = node.constant_result
        new_node.value = str(node.constant_result)
        #new_node = new_node.coerce_to(node.type, self.current_scope)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Only calls with exactly two arguments are considered, so that
        calls with a wrong number of arguments are passed through
        instead of failing on the lookup of the second argument.
        """
        self.visitchildren(node)
        function = node.function
        if function.type.is_cfunction and isinstance(function, ExprNodes.NameNode):
            if function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    from CythonScope import utility_scope
                    function.entry = utility_scope.lookup('PyObject_TypeCheck')
                    function.type = function.entry.type
                    # the C function expects a PyTypeObject*
                    PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node
