Cython has moved to github.
cython-devel
view Cython/Compiler/Optimize.py @ 2819:3bc6d034486a
INLINE -> CYTHON_INLINE to avoid conflicts
| author | Robert Bradshaw <robertwb@math.washington.edu> |
|---|---|
| date | Mon Jan 25 22:47:09 2010 -0800 (2 years ago) |
| parents | e557b0ea1381 |
| children | e36d5a315205 |
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
17 import codecs
19 try:
20 reduce
21 except NameError:
22 from functools import reduce
24 try:
25 set
26 except NameError:
27 from sets import Set as set
class FakePythonEnv(object):
    """Stand-in environment object for creating type test nodes etc.

    It carries nothing but the ``nogil`` flag, which is switched off.
    """
    nogil = False
def unwrap_node(node):
    """Return the expression found below any chain of ResultRefNode wrappers.

    A node that is not a ResultRefNode is returned as is.
    """
    while True:
        if not isinstance(node, UtilNodes.ResultRefNode):
            return node
        node = node.expression
def is_common_value(a, b):
    """Tell whether two expression nodes name the same simple value.

    After stripping ResultRefNode wrappers, two name nodes match when
    their names are equal.  Two attribute nodes match when the lookup is
    not a Python attribute lookup, their objects match recursively and
    the attribute names are equal.  Every other combination does not
    match.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    both_names = (isinstance(left, ExprNodes.NameNode)
                  and isinstance(right, ExprNodes.NameNode))
    if both_names:
        return left.name == right.name

    both_attributes = (isinstance(left, ExprNodes.AttributeNode)
                       and isinstance(right, ExprNodes.AttributeNode))
    if not both_attributes:
        return False

    return (not left.is_py_attr
            and is_common_value(left.obj, right.obj)
            and left.attribute == right.attribute)
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over a slice of a C array or pointer becomes a C loop that
      steps a pointer from the slice start to the slice stop
    """
    # function type used for the injected PyDict_Next() call:
    # (dict, pos, key, value), declared as returning a C boolean
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    # all node types without a specific visitor are simply traversed
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        """Remember the module scope, then visit the module's children."""
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Visit a function with ``current_scope`` set to the function's
        entry scope; the previous scope is put back afterwards.
        """
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Visit the children first, then try to optimise the loop itself."""
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Pick the transformation that matches the iterated expression.

        Returns the replacement node, or ``node`` itself if no known
        pattern applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array slice iteration?
        if isinstance(iterator, ExprNodes.SliceIndexNode) and \
               (iterator.base.type.is_array or iterator.base.type.is_ptr):
            return self._transform_carray_iteration(node, iterator)
        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_carray_iteration(self, node, slice_node):
        """Rewrite a loop over ``base[start:stop]`` of a C array or pointer.

        The loop becomes a ForFromStatNode over a pointer temp that runs
        from ``base + start`` (or ``base`` if there is no start or the
        start is the constant 0) up to ``base + stop``.  The loop target
        is assigned from the pointer temp at the top of the loop body.
        A slice without a stop value is left untouched.
        """
        start = slice_node.start
        stop = slice_node.stop
        step = None
        if not stop:
            return node

        carray_ptr = slice_node.base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=carray_ptr.type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=carray_ptr,
            operator='+',
            operand2=stop,
            type=carray_ptr.type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(carray_ptr.type)
        counter_temp = counter.ref(node.target.pos)

        if slice_node.base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        else:
            # read the item that the pointer temp currently refers to
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=carray_ptr.type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1='<=',
            target=counter_temp,
            relation2='<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace ``for i, x in enumerate(seq)`` by a loop over ``seq``.

        The counter lives in a temp that starts at 0; at the top of the
        loop body it is assigned to the counter target and then
        incremented.  The resulting loop over ``seq`` is passed through
        ``_optimise_for_loop()`` again.

        A wrong number of enumerate() arguments is reported as an error.
        Loops whose target is not a 2-item sequence starting with a plain
        name, or whose counter type is neither a Python object nor a C
        integer, are left untouched.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # counter handling that gets prepended to the loop body
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Replace ``for i in range(...)`` by a ForFromStatNode.

        The loop is left untouched if range() is not called with one to
        three arguments, if the step is not a compile time integer
        constant (its direction would be unknown) or if the step is 0.
        A stop bound that is not a literal is kept in a temp.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # range() takes one to three arguments.  The code below reads
            # args[0], which raised an IndexError for 'range()', and it
            # silently ignored any argument after the third.  Leave the
            # call alone so that it is handled like any other wrong call.
            return node

        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # the loop gets a positive step value and reversed relations
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace a loop over a dict by a while loop on PyDict_Next().

        ``keys`` and ``values`` select which part of each item the loop
        target receives.  With both set, a 2-item sequence target gets
        key and value separately, any other target gets a (key, value)
        tuple.  A sequence target of a different length leaves the loop
        untouched.

        The result is a TempsBlockNode that assigns the dict and a zero
        position to temps and then runs the while loop.
        """
        # key and value temps are declared as plain void pointers
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # keys are not needed: pass NULL to PyDict_Next()
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # values are not needed: pass NULL to PyDict_Next()
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (node to assign from, coercion node or None).
            # Python object targets get a type cast, preceded by a type
            # test for extension and builtin types.  Other targets get a
            # coercion node that is run as a separate step before the
            # assignments and whose result is the new temp.
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.

    C does not allow two case labels with the same value, so an if statement
    that tests the same value more than once is left untouched.
    """
    def extract_conditions(self, cond):
        """Split a clause condition into switch variable and case values.

        Returns ``(var_node, [value_node, ...])`` if ``cond`` is a
        non-Python ``==`` comparison between a name or C attribute and a
        literal or constant entry, or an ``or`` of such comparisons on a
        common variable.  Returns ``(None, None)`` otherwise.
        """
        # strip the wrapper nodes around the actual condition
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # comparing an operand with itself is only true for names
            # and non-Python attribute lookups on such values, i.e. for
            # what can serve as the switch variable
            if is_common_value(cond.operand1, cond.operand1):
                if cond.operand2.is_literal:
                    return cond.operand1, [cond.operand2]
                elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
                    return cond.operand1, [cond.operand2]
            if is_common_value(cond.operand2, cond.operand2):
                if cond.operand1.is_literal:
                    return cond.operand2, [cond.operand1]
                elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
                    return cond.operand2, [cond.operand1]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            t1, c1 = self.extract_conditions(cond.operand1)
            t2, c2 = self.extract_conditions(cond.operand2)
            if is_common_value(t1, t2):
                return t1, c1+c2
        return None, None

    def _has_duplicate_values(self, condition_values):
        """Tell whether two of the given case values are known to be equal.

        Values are compared by their integer constant result if they have
        one, literals otherwise by their ``value`` and other nodes by the
        C name of their entry.  Values that offer none of these are
        skipped.
        """
        seen = set()
        for value in condition_values:
            key = None
            constant = getattr(value, 'constant_result', None)
            if isinstance(constant, (int, long)):
                key = ('constant', constant)
            elif value.is_literal:
                literal_value = getattr(value, 'value', None)
                if literal_value is not None:
                    key = ('literal', literal_value)
            else:
                entry = getattr(value, 'entry', None)
                if entry:
                    key = ('cname', entry.cname)
            if key is None:
                # nothing to compare by
                continue
            if key in seen:
                return True
            seen.add(key)
        return False

    def visit_IfStatNode(self, node):
        """Replace the if statement by a SwitchStatNode where possible.

        The statement is returned unchanged if a clause does not have the
        required form, if the clauses do not test a common variable, if
        variable or values are not C integers, if there are fewer than
        two case values in total or if a case value occurs twice.
        """
        self.visitchildren(node)
        common_var = None
        case_count = 0
        cases = []
        all_conditions = []
        for if_clause in node.if_clauses:
            var, conditions = self.extract_conditions(if_clause.condition)
            if var is None:
                return node
            elif common_var is not None and not is_common_value(var, common_var):
                return node
            elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
                return node
            else:
                common_var = var
                case_count += len(conditions)
                all_conditions.extend(conditions)
                cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                                  conditions = conditions,
                                                  body = if_clause.body))
        if case_count < 2:
            return node
        if self._has_duplicate_values(all_conditions):
            # duplicate case labels do not compile in C
            return node

        common_var = unwrap_node(common_var)
        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = common_var,
                                    cases = cases,
                                    else_clause = node.else_clause)

    # all node types without a specific visitor are simply traversed
    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Flatten ``in`` / ``not_in`` tests against a tuple or list node.

        ``x in (a, b)`` becomes ``x == a or x == b`` and ``x not in
        (a, b)`` becomes ``x != a and x != b``.  The left operand is
        wrapped in a ResultRefNode that all comparisons share, and the
        result is an EvalWithTempExprNode around the combined condition.

        Cascaded comparisons, other operators, other right hand operands
        and empty sequences are returned unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # The result is known, but folding the test into a constant
            # would drop the evaluation of the left operand, which may
            # have side effects.  Keep the comparison as it is.
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            # join two conditions with 'or' (in) or 'and' (not_in)
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    # all node types without a specific visitor are simply traversed
    visit_Node = Visitor.VisitorTransform.recurse_to_children
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Switch off reference counting where it is safe to do so.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Handle parallel swap assignments like 'a,b = b,a', which are safe.

        If the names on both sides form the same duplicate-free set, the
        managed reference handling is switched off for the collected
        temp, name and index nodes.  The node itself is always returned.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        temp_nodes = []

        for assignment in node.stats:
            if not isinstance(assignment, Nodes.SingleAssignmentNode):
                # cascaded assignments (FIXME) and anything else
                return node
            if not self._extract_operand(assignment.lhs, lhs_names,
                                         lhs_indices, temp_nodes):
                return node
            if not self._extract_operand(assignment.rhs, rhs_names,
                                         rhs_indices, temp_nodes):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = [ path for path, _ in lhs_names ]
            rhs_paths = [ path for path, _ in rhs_names ]
            if set(lhs_paths) != set(rhs_paths):
                return node
            if len(set(lhs_paths)) != len(rhs_names):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lhs_ids = []
            for indexed in lhs_indices:
                index_id = self._extract_index_id(indexed)
                if not index_id:
                    return node
                lhs_ids.append(index_id)

            rhs_ids = []
            for indexed in rhs_indices:
                index_id = self._extract_index_id(indexed)
                if not index_id:
                    return node
                rhs_ids.append(index_id)

            if set(lhs_ids) != set(rhs_ids):
                return node
            if len(set(lhs_ids)) != len(rhs_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        for temp_node in temp_nodes:
            temp_node.use_managed_ref = False

        # names that sit directly below a temp node are left alone
        wrapped_args = [ temp_node.arg for temp_node in temp_nodes ]
        for _, name_node in lhs_names + rhs_names:
            if name_node not in wrapped_args:
                name_node.use_managed_ref = False

        for indexed in lhs_indices + rhs_indices:
            indexed.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one side of an assignment and record it.

        CoerceToTempNode wrappers are appended to ``temps``.  A name, or a
        chain of non-Python attribute lookups on a name, is appended to
        ``names`` as (dotted path, node).  An index node with a C integer
        index into a list typed name is appended to ``indices``.

        Returns False for anything else, including operands that are not
        Python objects, and True otherwise.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # walk down the attribute chain, innermost name comes last
        reversed_path = []
        current = node
        while isinstance(current, ExprNodes.AttributeNode):
            if current.is_py_attr:
                return False
            reversed_path.append(current.member)
            current = current.obj

        if isinstance(current, ExprNodes.NameNode):
            reversed_path.append(current.name)
            names.append( ('.'.join(reversed_path[::-1]), node) )
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if node.base.type != Builtin.list_type:
            return False
        if not node.index.type.is_int:
            return False
        if not isinstance(node.base, ExprNodes.NameNode):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return (base name, index name) for an index node that is
        indexed by a plain name, and None for any other index.
        """
        base = index_node.base
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices are not handled, nor anything else
            return None
        return (base.name, index.name)
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *after* the
    declarations analysis phase and *before* the type analysis phase.

    No argument types can be used by this transform.  What it can do is
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  They
    should go into the OptimizeBuiltinCalls transform instead, which
    runs after type analysis.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        """Pass calls of builtin names on to their specific handler."""
        self.visitchildren(node)
        callee = node.function
        if self._function_is_builtin_name(callee):
            return self._dispatch_to_handler(node, callee, node.args)
        return node

    def visit_GeneralCallNode(self, node):
        """Pass keyword calls of builtin names on to their handler.

        Calls whose positional arguments are not a tuple node are left
        untouched.
        """
        self.visitchildren(node)
        callee = node.function
        if not self._function_is_builtin_name(callee):
            return node
        positional = node.positional_args
        if isinstance(positional, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, callee, positional.args, node.keyword_args)
        return node

    def _function_is_builtin_name(self, function):
        """Tell whether ``function`` is a name node that the innermost
        environment resolves to an entry of the builtin scope.
        """
        if not function.is_name:
            return False
        entry = self.env_stack[-1].lookup(function.name)
        if not entry:
            return False
        return getattr(entry, 'scope', None) is Builtin.builtin_scope

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        """Call the handler method for the called function, if any.

        Handlers are named ``_handle_simple_function_NAME`` for calls
        without keyword arguments (``kwargs`` is None) and
        ``_handle_general_function_NAME`` otherwise.  Without a matching
        handler, ``node`` is returned unchanged.
        """
        has_kwargs = kwargs is not None
        if has_kwargs:
            handler_name = '_handle_general_function_%s' % function.name
        else:
            handler_name = '_handle_simple_function_%s' % function.name

        handle_call = getattr(self, handler_name, None)
        if handle_call is None:
            return node
        if has_kwargs:
            return handle_call(node, args, kwargs)
        return handle_call(node, args)

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the called function of ``node`` by a
        PythonCapiFunctionNode for the C function ``cname``.
        """
        replaced = node.function
        node.function = ExprNodes.PythonCapiFunctionNode(
            replaced.pos, replaced.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with a wrong number of arguments at ``node.pos``.

        ``expected`` may be a number, a string or None; it is quoted in
        the message unless it is None.
        """
        arg_str = ''
        if expected: # neither None nor 0
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'

        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected

        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    # specific handlers for simple call nodes

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...} and
        set([ x for ... ]) by a literal { x for ... }.

        A call without arguments becomes an empty set node.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type)
        if arg_count > 1:
            return node

        iterable = pos_args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args)

        is_collecting_comprehension = (
            isinstance(iterable, ExprNodes.ComprehensionNode)
            and isinstance(iterable.target, (ExprNodes.ListNode,
                                             ExprNodes.SetNode)))
        if not is_collecting_comprehension:
            return node

        # let the comprehension collect into a set instead
        iterable.target = ExprNodes.SetNode(node.pos, args=[])
        iterable.pos = node.pos
        return iterable

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
        """
        if len(pos_args) != 1:
            return node

        comprehension = pos_args[0]
        if not (isinstance(comprehension, ExprNodes.ComprehensionNode)
                and isinstance(comprehension.target, (ExprNodes.ListNode,
                                                      ExprNodes.SetNode))):
            return node

        append_node = comprehension.append
        if not (isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode))
                and len(append_node.expr.args) == 2):
            return node

        key_node, value_node = append_node.expr.args
        target_node = ExprNodes.DictNode(
            pos=comprehension.target.pos, key_value_pairs=[])
        new_append_node = ExprNodes.DictComprehensionAppendNode(
            append_node.pos, target=target_node,
            key_expr=key_node, value_expr=value_node)
        comprehension.target = target_node
        comprehension.type = target_node.type
        # swap the old append node for the dict version inside the tree
        replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
        return replace_in(comprehension)

    def _handle_simple_function_float(self, node, pos_args):
        """Replace float() by the literal 0.0 and report calls with more
        than one argument as an error.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if arg_count > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        # a call with a '**' argument could be optimized by updating
        # the kw dict instead, but is left alone for now
        if len(pos_args) > 0 \
               or not isinstance(kwargs, ExprNodes.DictNode) \
               or node.starstar_arg:
            return node
        return kwargs
876 class OptimizeBuiltinCalls(Visitor.EnvTransform):
877 """Optimize some common methods calls and instantiation patterns
878 for builtin types *after* the type analysis phase.
880 Running after type analysis, this transform can only perform
881 function replacements that do not alter the function return type
882 in a way that was not anticipated by the type analysis.
883 """
884 # only intercept on call nodes
885 visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_GeneralCallNode(self, node):
    """Dispatch keyword calls of Python object functions to a handler.

    Calls whose positional arguments are not a tuple node are returned
    unchanged.
    """
    self.visitchildren(node)
    callee = node.function
    if callee.type.is_pyobject:
        positional = node.positional_args
        if isinstance(positional, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, callee, positional.args, node.keyword_args)
    return node
def visit_SimpleCallNode(self, node):
    """Dispatch positional-only calls to a handler.

    For a Python object function the arguments are taken from the
    argument tuple, which must be a tuple node; otherwise the call is
    returned unchanged.  For any other function ``node.args`` is used.
    """
    self.visitchildren(node)
    callee = node.function
    if not callee.type.is_pyobject:
        return self._dispatch_to_handler(
            node, callee, node.args)

    packed_args = node.arg_tuple
    if not isinstance(packed_args, ExprNodes.TupleNode):
        return node
    return self._dispatch_to_handler(
        node, callee, packed_args.args)
912 ### cleanup to avoid redundant coercions to/from Python types
def _visit_PyTypeTestNode(self, node):
    # disabled - appears to break assignments in some cases, and
    # also drops a None check, which might still be required
    """Drop a type test that became redundant after tree changes.

    The argument replaces the test node if visiting the children
    changed it and its type now matches the type of the test.
    """
    arg_before = node.arg
    self.visitchildren(node)
    arg_after = node.arg
    if arg_after is arg_before:
        return node
    if arg_after.type != node.type:
        return node
    return arg_after
def visit_CoerceFromPyTypeNode(self, node):
    """Remove conversion nodes that became redundant after tree changes.

    In addition, a call to Python's builtin int() or float() is
    optimised away if its result gets coerced back into a C type
    anyway.
    """
    self.visitchildren(node)
    arg = node.arg
    if not arg.type.is_pyobject:
        # no Python conversion left at all, just do a C coercion instead
        if node.type == arg.type:
            return arg
        return arg.coerce_to(node.type, self.env_stack[-1])

    if not isinstance(arg, ExprNodes.SimpleCallNode):
        return node
    if not (node.type.is_int or node.type.is_float):
        return node

    function = arg.function
    is_builtin_type_call = (isinstance(function, ExprNodes.NameNode)
                            and function.type.is_builtin_type
                            and isinstance(arg.arg_tuple, ExprNodes.TupleNode))
    if not is_builtin_type_call:
        return node

    call_args = arg.arg_tuple.args
    if len(call_args) != 1:
        return node

    func_arg = call_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        # look at the C value below the coercion to a Python object
        func_arg = func_arg.arg
    elif func_arg.type.is_pyobject:
        # play safe: Python conversion might work on all sorts of things
        return node

    # the call is only redundant if the argument or the target type
    # is of the kind that the called builtin type stands for
    if function.name == 'int':
        kind_matches = func_arg.type.is_int or node.type.is_int
    elif function.name == 'float':
        kind_matches = func_arg.type.is_float or node.type.is_float
    else:
        kind_matches = False

    if kind_matches:
        if func_arg.type == node.type:
            return func_arg
        elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
            return ExprNodes.TypecastNode(
                node.pos, operand=func_arg, type=node.type)
    return node
974 ### dispatch to specific optimisers
976 def _find_handler(self, match_name, has_kwargs):
977 call_type = has_kwargs and 'general' or 'simple'
978 handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
979 if handler is None:
980 handler = getattr(self, '_handle_any_%s' % match_name, None)
981 return handler
def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
    """Find and call the handler for a function or method call.

    Plain names are handled if their entry is a builtin or lives in the
    builtin scope.  Attribute calls on Python objects are looked up as
    ``method_TYPE_NAME`` handlers, falling back to ``slotNAME`` handlers
    for type slot methods and ``__new__``.  Method handlers receive the
    object as first argument unless the call is an unbound method call.

    Returns the handler's result, or ``node`` if nothing applies.
    """
    if function.is_name:
        # we only consider functions that are either builtin
        # Python functions or builtins that were already replaced
        # into a C function call (defined in the builtin scope)
        if not function.entry:
            return node
        if not (function.entry.is_builtin
                or getattr(function.entry, 'scope', None) is Builtin.builtin_scope):
            return node
        function_handler = self._find_handler(
            "function_%s" % function.name, kwargs)
        if function_handler is None:
            return node
        if kwargs:
            return function_handler(node, arg_list, kwargs)
        return function_handler(node, arg_list)

    if not (function.is_attribute and function.type.is_pyobject):
        return node

    attr_name = function.attribute
    self_arg = function.obj
    obj_type = self_arg.type
    is_unbound_method = False
    if not obj_type.is_builtin_type:
        type_name = "object" # safety measure
    elif obj_type is Builtin.type_type and arg_list and \
             arg_list[0].type.is_pyobject:
        # calling an unbound method like 'list.append(L,x)'
        # (ignoring 'type.mro()' here ...)
        type_name = function.obj.name
        self_arg = None
        is_unbound_method = True
    else:
        type_name = obj_type.name

    method_handler = self._find_handler(
        "method_%s_%s" % (type_name, attr_name), kwargs)
    if method_handler is None:
        # no type specific handler, try a generic slot handler
        if attr_name in TypeSlots.method_name_to_slot \
               or attr_name == '__new__':
            method_handler = self._find_handler(
                "slot%s" % attr_name, kwargs)
        if method_handler is None:
            return node

    if self_arg is not None:
        arg_list = [self_arg] + list(arg_list)
    if kwargs:
        return method_handler(node, arg_list, kwargs, is_unbound_method)
    return method_handler(node, arg_list, is_unbound_method)
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
    """Report a call to ``function_name`` with a wrong argument count.

    ``expected`` may be None (unknown), a number, or a string that
    describes the accepted counts (e.g. '2 or 3').  It is included in
    the message unless it is None.  The error is reported at
    ``node.pos`` together with the number of arguments in ``args``.
    """
    # sketch of the expected signature: '', 'x' or '...'
    arg_str = ''
    if expected:
        if isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'

    if expected is None:
        expected_str = ''
    else:
        expected_str = 'expected %s, ' % expected

    message = "%s(%s) called with wrong number of args, %sfound %d" % (
        function_name, arg_str, expected_str, len(args))
    error(node.pos, message)
1053 ### builtin types
# C signature used for the PyDict_Copy() call generated below.
PyDict_Copy_func_type = PyrexTypes.CFuncType(
    Builtin.dict_type, [
        PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
        ])

def _handle_simple_function_dict(self, node, pos_args):
    """Replace dict(some_dict) by PyDict_Copy(some_dict).

    Only applies to a single positional argument that is typed as a
    dict; any other call is returned unchanged.
    """
    if len(pos_args) != 1:
        return node
    arg = pos_args[0]
    if arg.type is Builtin.dict_type:
        # A dict-typed value may still be None at runtime.  The
        # message uses the same wording as the tuple() handler.
        arg = ExprNodes.NoneCheckNode(
            arg, "PyExc_TypeError", "'NoneType' object is not iterable")
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
            args = [arg],
            is_temp = node.is_temp
            )
    return node
# C signature used for the PyList_AsTuple() call generated below.
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
    Builtin.tuple_type, [
        PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
        ])

def _handle_simple_function_tuple(self, node, pos_args):
    """Replace tuple([...]) by a call to PyList_AsTuple.

    Only applies to a single positional argument typed as a list;
    any other call is returned unchanged.
    """
    if len(pos_args) != 1:
        return node
    list_arg = pos_args[0]
    if list_arg.type is not Builtin.list_type:
        return node

    # list displays and comprehensions are exempt from the None
    # check; every other list-typed expression gets one
    unchecked_classes = (ExprNodes.ComprehensionNode, ExprNodes.ListNode)
    if not isinstance(list_arg, unchecked_classes):
        pos_args[0] = ExprNodes.NoneCheckNode(
            list_arg, "PyExc_TypeError",
            "'NoneType' object is not iterable")

    return ExprNodes.PythonCapiCallNode(
        node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
        args = pos_args,
        is_temp = node.is_temp
        )
# C signature of the __Pyx_PyObject_AsDouble() helper; it signals an
# error by returning -1 with an exception set.
PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_double_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "((double)-1)",
    exception_check = True)

def _handle_simple_function_float(self, node, pos_args):
    """Optimise float(x) into a C-level conversion.

    A C double argument is returned as is, other assignable or
    numeric C values get a plain type cast, and everything else is
    converted through __Pyx_PyObject_AsDouble().
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    operand = pos_args[0]
    if isinstance(operand, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion at the underlying C value
        operand = operand.arg
    operand_type = operand.type

    if operand_type is PyrexTypes.c_double_type:
        return operand
    if node.type.assignable_from(operand_type) or operand_type.is_numeric:
        return ExprNodes.TypecastNode(
            node.pos, operand=operand, type=node.type)

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = pyobject_as_double_utility_code,
        py_name = "float")
1130 ### builtin functions
# C signature for the two-argument form: getattr(obj, name)
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
        ])

# C signature for the three-argument form: getattr(obj, name, default)
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])

def _handle_simple_function_getattr(self, node, pos_args):
    """Replace getattr() by a direct C call.

    Two arguments map to PyObject_GetAttr(), three arguments to the
    __Pyx_GetAttr3() helper.  Any other argument count is reported
    as an error and the node is returned unchanged.
    """
    arg_count = len(pos_args)
    if arg_count not in (2, 3):
        self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
        return node

    if arg_count == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
            args = pos_args,
            is_temp = node.is_temp)

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = Builtin.getattr3_utility_code)
# C signature of strlen(), as used for C string arguments.
Pyx_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type, [
        PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
        ])

def _handle_simple_function_len(self, node, pos_args):
    """Replace len(c_string) by a call to strlen().

    Calls whose argument is not a C string are returned unchanged.
    """
    # note: this only works because we already replaced len() by
    # PyObject_Length() which returns a Py_ssize_t instead of a
    # Python object, so we can return a plain size_t instead
    # without caring about Python object conversion etc.
    if len(pos_args) != 1:
        self._error_wrong_arg_count('len', node, pos_args, 1)
        return node

    c_arg = pos_args[0]
    if isinstance(c_arg, ExprNodes.CoerceToPyTypeNode):
        # look at the C value below the coercion to a Python object
        c_arg = c_arg.arg
    if c_arg.type.is_string:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "strlen", self.Pyx_strlen_func_type,
            args = [c_arg],
            is_temp = node.is_temp,
            utility_code = include_string_h_utility_code
            )
    return node
# C signature used for the Py_TYPE() call generated below.
Pyx_Type_func_type = PyrexTypes.CFuncType(
    Builtin.type_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
        ])

def _handle_simple_function_type(self, node, pos_args):
    """Replace the one-argument type(x) by Py_TYPE(x).

    The result is cast to a generic Python object.  Calls with any
    other number of arguments are returned unchanged.
    """
    if len(pos_args) == 1:
        type_call = ExprNodes.PythonCapiCallNode(
            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
            args = pos_args,
            is_temp = False)
        return ExprNodes.CastNode(type_call, PyrexTypes.py_object_type)
    return node
1201 ### special methods
# C signature of the __Pyx_tp_new() helper.
Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
        ])

def _handle_simple_slot__new__(self, node, args, is_unbound_method):
    """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

    Only the unbound, single-argument form is handled, and only when
    both the type and its argument are names that refer to the same
    type.  Everything else is returned unchanged.
    """
    obj = node.function.obj
    if not is_unbound_method or len(args) != 1:
        return node
    type_arg = args[0]
    if not (obj.is_name and type_arg.is_name):
        # play safe
        return node
    if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
        # not a known type, play safe
        return node

    if type_arg.type_entry and obj.type_entry:
        if type_arg.type_entry != obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node
    elif obj.name != type_arg.name:
        return node
    # at this point we know it's a type and that it's the same
    # type for both - that should do

    # FIXME: we could potentially look up the actual tp_new C
    # method of the extension type and call that instead of the
    # generic slot. That would also allow us to pass parameters
    # efficiently.

    if not type_arg.type_entry:
        # arbitrary variable, needs a None check for safety
        type_arg = ExprNodes.NoneCheckNode(
            type_arg, "PyExc_TypeError",
            "object.__new__(X): X is not a type object (NoneType)")

    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
        args = [type_arg],
        utility_code = tpnew_utility_code,
        is_temp = node.is_temp
        )
1248 ### methods of builtin types
# C signature of the __Pyx_PyObject_Append() helper.
PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ])

def _handle_simple_method_object_append(self, node, args, is_unbound_method):
    """Replace obj.append(x) on an untyped object by a call to
    __Pyx_PyObject_Append().

    Calls that do not have exactly two arguments (object and item)
    are returned unchanged.
    """
    # X.append() is almost always referring to a list
    if len(args) == 2:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = append_utility_code
            )
    return node
# C signature of the __Pyx_PyObject_Pop() helper.
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        ])

# C signature of the __Pyx_PyObject_PopIndex() helper.
# NOTE(review): the index is declared as a C long here, while the
# helper's C definition takes a Py_ssize_t - confirm this is intended.
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
        ])

def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
    """Replace obj.pop() and obj.pop(n) on an untyped object by calls
    to __Pyx_PyObject_Pop() / __Pyx_PyObject_PopIndex().

    Calls with any other number of arguments are returned unchanged.
    """
    # X.pop([n]) is almost always referring to a list
    arg_count = len(args)
    if arg_count == 1:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = pop_utility_code
            )
    if arg_count != 2:
        return node

    index = args[1]
    if isinstance(index, ExprNodes.CoerceToPyTypeNode) and index.arg.type.is_int:
        # pass a C integer index directly, provided that its type is
        # not wider than Py_ssize_t
        c_index_type = index.arg.type
        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        if PyrexTypes.widest_numeric_type(c_index_type, ssize_t_type) == ssize_t_type:
            args[1] = index.arg
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
        args = args,
        is_temp = node.is_temp,
        utility_code = pop_index_utility_code
        )
# C signature of PyList_Append(); -1 signals an error.
PyList_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
        ],
    exception_value = "-1")

def _handle_simple_method_list_append(self, node, args, is_unbound_method):
    """Replace list.append(x) on a typed list by PyList_Append().

    A wrong argument count is reported as an error and the node is
    returned unchanged.
    """
    if len(args) == 2:
        return self._substitute_method_call(
            node, "PyList_Append", self.PyList_Append_func_type,
            'append', is_unbound_method, args)
    self._error_wrong_arg_count('list.append', node, args, 2)
    return node
# C signature shared by C-API functions that take a single object and
# return an int, with -1 signalling an error.
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_int_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "-1")

def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
    """Replace the argument-less list.sort() by PyList_Sort().

    Calls that pass anything besides the list itself are returned
    unchanged, without reporting an error.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args)
    return node
def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
    """Replace list.reverse() by PyList_Reverse().

    A wrong argument count is reported as an error and the node is
    returned unchanged.
    """
    if len(args) == 1:
        return self._substitute_method_call(
            node, "PyList_Reverse", self.single_param_func_type,
            'reverse', is_unbound_method, args)
    self._error_wrong_arg_count('list.reverse', node, args, 1)
    return node
# C signature of PyUnicode_AsEncodedString(); NULL signals an error.
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

# C signature shared by the codec specific PyUnicode_As<Codec>String()
# functions.
PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        ],
    exception_value = "NULL")

# Codecs for which a dedicated C-API function name is generated.
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                      'unicode_escape', 'raw_unicode_escape']

# (name, encoder function) pairs, used to recognise encoding aliases.
_special_codecs = [ (name, codecs.getencoder(name))
                    for name in _special_encodings ]

def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
    """Replace unicode.encode(...) by a direct C-API call.

    Literal strings are encoded at compile time where possible.
    Otherwise, a codec specific PyUnicode_As<Codec>String() function
    is used for known codecs in 'strict' mode, and
    PyUnicode_AsEncodedString() in all other cases.

    Returns the replacement node, or ``node`` unchanged if the
    arguments cannot be handled.
    """
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
        return node

    string_node = args[0]

    if len(args) == 1:
        # no encoding given: pass NULL for encoding and error mode
        null_node = ExprNodes.NullNode(node.pos)
        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method, [string_node, null_node, null_node])

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # 'encoding' and 'error_handling' are None when they are only
    # known at runtime, so check for that explicitly instead of
    # relying on an exception being raised and swallowed
    if encoding is not None and error_handling is not None \
           and isinstance(string_node, ExprNodes.UnicodeNode):
        # constant, so try to do the encoding at compile time
        try:
            value = string_node.value.encode(encoding, error_handling)
        except Exception:
            # well, looks like we can't - leave it to the runtime
            pass
        else:
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)

    if encoding is not None and error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])

    return self._substitute_method_call(
        node, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
# C signature shared by the codec specific PyUnicode_Decode<Codec>()
# functions; NULL signals an error.
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

# C signature of the generic PyUnicode_Decode(); NULL signals an error.
PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ],
    exception_value = "NULL")

def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
    """Replace .decode() on a C string, or on a slice of a C string,
    by a direct call to a PyUnicode_Decode*() function.

    The C string pointer is advanced by the slice start, and the
    length is taken from the slice bounds, or from strlen() when no
    stop index is given.  Any other kind of call is returned
    unchanged.
    """
    if not (1 <= len(args) <= 3):
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
        return node

    ssize_t_type = PyrexTypes.c_py_ssize_t_type
    let_refs = []   # values that must only be evaluated once
    first_arg = args[0]
    if isinstance(first_arg, ExprNodes.SliceIndexNode):
        slice_node = first_arg
        string_node = slice_node.base
        if not string_node.type.is_string:
            # nothing to optimise here
            return node
        start, stop = slice_node.start, slice_node.stop
        if not start or start.constant_result == 0:
            start = None
        else:
            if start.type.is_pyobject:
                start = start.coerce_to(ssize_t_type, self.env_stack[-1])
            if stop:
                # start is used twice: for the pointer offset and
                # for the length calculation
                start = UtilNodes.LetRefNode(start)
                let_refs.append(start)
            string_node = ExprNodes.AddNode(pos=start.pos,
                                            operand1=string_node,
                                            operator='+',
                                            operand2=start,
                                            is_temp=False,
                                            type=string_node.type
                                            )
        if stop and stop.type.is_pyobject:
            stop = stop.coerce_to(ssize_t_type, self.env_stack[-1])
    elif isinstance(first_arg, ExprNodes.CoerceToPyTypeNode) \
             and first_arg.arg.type.is_string:
        # use strlen() to find the string length, just as CPython would
        start = stop = None
        string_node = first_arg.arg
    else:
        # let Python do its job
        return node

    if not stop:
        if start or not string_node.is_name:
            # the string expression is used twice: by strlen() and
            # by the decoding call
            string_node = UtilNodes.LetRefNode(string_node)
            let_refs.append(string_node)
        stop = ExprNodes.PythonCapiCallNode(
            string_node.pos, "strlen", self.Pyx_strlen_func_type,
            args = [string_node],
            is_temp = False,
            utility_code = include_string_h_utility_code,
            ).coerce_to(ssize_t_type, self.env_stack[-1])
    elif start:
        # the length of the slice is stop - start
        stop = ExprNodes.SubNode(
            pos = stop.pos,
            operand1 = stop,
            operator = '-',
            operand2 = start,
            is_temp = False,
            type = ssize_t_type
            )

    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
        return node
    encoding, encoding_node, error_handling, error_handling_node = parameters

    # try to find a specific encoder function
    codec_name = None
    if encoding is not None:
        codec_name = self._find_special_codec_name(encoding)

    if codec_name is None:
        node = ExprNodes.PythonCapiCallNode(
            node.pos, "PyUnicode_Decode",
            self.PyUnicode_Decode_func_type,
            args = [string_node, stop, encoding_node, error_handling_node],
            is_temp = node.is_temp,
            )
    else:
        decode_function = "PyUnicode_Decode%s" % codec_name
        node = ExprNodes.PythonCapiCallNode(
            node.pos, decode_function,
            self.PyUnicode_DecodeXyz_func_type,
            args = [string_node, stop, error_handling_node],
            is_temp = node.is_temp,
            )

    # wrap the call in the temps, last one first
    while let_refs:
        node = UtilNodes.EvalWithTempExprNode(let_refs.pop(), node)
    return node
1510 def _find_special_codec_name(self, encoding):
1511 try:
1512 requested_codec = codecs.getencoder(encoding)
1513 except:
1514 return None
1515 for name, codec in self._special_codecs:
1516 if codec == requested_codec:
1517 if '_' in name:
1518 name = ''.join([ s.capitalize()
1519 for s in name.split('_')])
1520 return name
1521 return None
def _unpack_encoding_and_error_mode(self, pos, args):
    """Extract the encoding and error mode arguments of a call to
    .encode() or .decode().

    ``args`` holds the call arguments including the string itself:
    ``args[1]`` is the encoding and ``args[2]`` the error mode, both
    optional.

    Returns None if an argument is neither a string literal nor a C
    string.  Otherwise returns the tuple
    ``(encoding, encoding_node, error_handling, error_handling_node)``
    where the plain values are None if they are not known at compile
    time, and the nodes are C string nodes, or NULL nodes for a
    missing encoding or for the 'strict' error mode.
    """
    null_node = ExprNodes.NullNode(pos)
    string_literal_classes = (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                              ExprNodes.BytesNode)

    if len(args) >= 2:
        encoding_node = args[1]
        if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
            encoding_node = encoding_node.arg
        if isinstance(encoding_node, string_literal_classes):
            encoding = encoding_node.value
            encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
                                                type=PyrexTypes.c_char_ptr_type)
        elif encoding_node.type.is_string:
            # C string, only known at runtime
            encoding = None
        else:
            return None
    else:
        # No encoding argument, e.g. 'cstring.decode()'.  Pass NULL,
        # as the one-argument case of unicode.encode() does, which
        # presumably selects the runtime's default encoding.
        encoding = None
        encoding_node = null_node

    if len(args) == 3:
        error_handling_node = args[2]
        if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
            error_handling_node = error_handling_node.arg
        if isinstance(error_handling_node, string_literal_classes):
            error_handling = error_handling_node.value
            if error_handling == 'strict':
                error_handling_node = null_node
            else:
                error_handling_node = ExprNodes.BytesNode(
                    error_handling_node.pos, value=error_handling,
                    type=PyrexTypes.c_char_ptr_type)
        elif error_handling_node.type.is_string:
            # C string, only known at runtime
            error_handling = None
        else:
            return None
    else:
        error_handling = 'strict'
        error_handling_node = null_node

    return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type,
                            attr_name, is_unbound_method, args=()):
    """Build a C-API call node that replaces a Python method call.

    The first argument, i.e. the object the method is called on, is
    wrapped in a None check.  Its error mimics what Python raises
    for a call on None: a TypeError for unbound method calls and an
    AttributeError for normal method calls.

    ``name`` is the C function to call, ``func_type`` its signature
    and ``attr_name`` the name of the replaced Python method.
    """
    call_args = list(args)
    if call_args:
        if is_unbound_method:
            exception_name = "PyExc_TypeError"
            message = "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                attr_name, node.function.obj.name)
        else:
            exception_name = "PyExc_AttributeError"
            message = "'NoneType' object has no attribute '%s'" % attr_name
        call_args[0] = ExprNodes.NoneCheckNode(
            call_args[0], exception_name, message)
    return ExprNodes.PythonCapiCallNode(
        node.pos, name, func_type,
        args = call_args,
        is_temp = node.is_temp
        )
# C helper __Pyx_PyObject_Append(L, x): calls PyList_Append() directly
# if L is exactly a list and returns a new reference to None on
# success.  For any other object, it looks up the "append" attribute
# and calls it with x.  Returns NULL on error.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
# C helper __Pyx_PyObject_Pop(L): if L is exactly a list that holds
# more items than half of its allocated size, the list size is
# decremented in place and the former last item is returned.
# Otherwise, the "pop" attribute is looked up and called without
# arguments.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "pop");
        if (!m) return NULL;
        r = PyObject_CallObject(m, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
# C helper __Pyx_PyObject_PopIndex(L, ix): if L is exactly a list that
# holds more items than half of its allocated size, and ix (after
# adding the size to a negative index) is within bounds, the item is
# removed in place by shifting the following items down by one.
# In all other cases, the "pop" attribute is looked up and called with
# the index as a Python integer.  Returns NULL on error.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            if (ix < 0) {
                ix += size;
            }
            if (likely(0 <= ix && ix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, ix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=ix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
# C helper for float(obj) that returns a C double.  The macro
# __Pyx_PyObject_AsDouble() reads the value of exact float objects
# directly and passes everything else on to __Pyx__PyObject_AsDouble().
# That function uses PyFloat_AsDouble() for types that provide an
# nb_float slot, PyFloat_FromString() for exact unicode/bytes objects,
# and calls the float type with the object for anything else.
# Failures are signalled by returning (double)-1.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
# Makes the C header <string.h> available; attached to the generated
# strlen() calls in this module.
include_string_h_utility_code = UtilityCode(
proto = """
#include <string.h>
"""
)
# C helper __Pyx_tp_new(type_obj): calls the tp_new slot of the given
# type object, passing the type itself, the value of
# Naming.empty_tuple as positional arguments and NULL as keyword
# arguments.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` if it has not been set yet.

        The children of the node are visited first.  The node's own
        result is only calculated if all of them are constant;
        otherwise it is left marked as 'not a constant'.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant; children that do not
        # carry a 'constant_result' at all are treated as non-constant
        # instead of failing with an AttributeError
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result',
                               not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result',
                         not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
#            if node.constant_result is not ExprNodes.not_a_constant:
#                print node.__class__.__name__, node.constant_result
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, ordered from the narrowest to the widest
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the class in NODE_TYPE_ORDER that comes last among the
        classes of ``nodes``, or None if any node has another class.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        """Calculate the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Calculate the constant result of a binary operation and
        replace the node by a literal node if both operands are
        non-float literals.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            # We calculate float constants to make them available to
            # the compiler, but we do not aggregate them into a
            # constant node to prevent any loss of precision.
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            # We calculate other constants to make them available to
            # the compiler, but we only aggregate constant nodes
            # recursively, so non-const nodes are straight out.
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = node.operand1.type, node.operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1 is type2:
            # same type, so simply reuse the first operand node
            new_node = node.operand1
        else:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
            if type(node.operand1) is type(node.operand2):
                new_node = node.operand1
                new_node.type = widest_type
            elif type1 is widest_type:
                new_node = node.operand1
            elif type2 is widest_type:
                new_node = node.operand2
            else:
                target_class = self._widest_node_class(
                    node.operand1, node.operand2)
                if target_class is None:
                    return node
                new_node = target_class(pos=node.pos, type = widest_type)

        new_node.constant_result = node.constant_result
        new_node.value = str(node.constant_result)
        #new_node = new_node.coerce_to(node.type, self.current_scope)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.
        """
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            # only the two-argument form has a type argument to look
            # at; leave calls with any other argument count alone
            if node.function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    from CythonScope import utility_scope
                    node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
                    node.function.type = node.function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node
