Cython has moved to github.

cython-devel

view Cython/Compiler/Optimize.py @ 1221:fe6e3dd9d513

Fix ticket #72 (compiler crash on bad code)
author Robert Bradshaw <robertwb@math.washington.edu>
date Wed Oct 08 00:05:25 2008 -0700 (3 years ago)
parents e95e3d2cb88d
children 2fecdc45a284
line source
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
def unwrap_node(node):
    """Peel off every PersistentNode wrapper and return the innermost node.

    A node that is not a PersistentNode is returned unchanged.
    """
    inner = node
    while True:
        if not isinstance(inner, ExprNodes.PersistentNode):
            return inner
        inner = inner.arg
def is_common_value(a, b):
    """Return True if ``a`` and ``b`` refer to the same name or C attribute chain.

    Both arguments are first stripped of PersistentNode wrappers.

    * Two NameNodes match when their names are equal.
    * Two AttributeNodes match when ``a`` is a C (non-Python) attribute
      lookup, both look up the same attribute name, and their base objects
      match recursively.
    * Anything else (including None) never matches.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name
    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
        # The attribute names must agree as well; comparing only the base
        # objects would treat ``s.x`` and ``s.y`` as the same value.
        return (not a.is_py_attr
                and a.attribute == b.attribute
                and is_common_value(a.obj, b.obj))
    return False
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.

    If-chains whose case values repeat are left untouched: Python takes the
    first matching clause, whereas C rejects duplicate case labels.
    """

    def extract_conditions(self, cond):
        """Match ``cond`` against ``var == value`` or an ``or`` of such tests.

        Returns ``(var, [value, ...])`` on success and ``(None, None)``
        otherwise.
        """
        # Look through the wrappers that may surround the comparison.
        # FlattenInListTransform, for one, emits TypecastNodes.
        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg

        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # Comparing an operand with itself is deliberate. is_common_value(x, x)
            # is true exactly when x is a name or a C attribute chain on a name,
            # which is what is accepted as the switch subject.
            if is_common_value(cond.operand1, cond.operand1):
                if self._is_case_value(cond.operand2):
                    return cond.operand1, [cond.operand2]
            if is_common_value(cond.operand2, cond.operand2):
                if self._is_case_value(cond.operand1):
                    return cond.operand2, [cond.operand1]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            t1, c1 = self.extract_conditions(cond.operand1)
            t2, c2 = self.extract_conditions(cond.operand2)
            if is_common_value(t1, t2):
                return t1, c1 + c2
        return None, None

    def _is_case_value(self, node):
        # A value this transform accepts on the other side of ``==``:
        # either a literal or something bound to a constant entry.
        if isinstance(node, ExprNodes.ConstNode):
            return True
        entry = getattr(node, 'entry', None)
        return bool(entry and entry.is_const)

    def _case_value_key(self, value):
        """Return a hashable key identifying a case value, or None.

        Literals are keyed on their node class and ``value`` attribute, and
        named constants on their entry's ``cname``. This is best-effort:
        spellings that differ textually but denote the same C value may
        still receive different keys.
        """
        if isinstance(value, ExprNodes.ConstNode):
            literal = getattr(value, 'value', None)
            if literal is None:
                return None
            return ('const', value.__class__, literal)
        entry = getattr(value, 'entry', None)
        cname = getattr(entry, 'cname', None)
        if cname is None:
            return None
        return ('entry', cname)

    def _has_duplicate_values(self, values):
        # A value we cannot key is reported as a duplicate, so that
        # uncertainty results in keeping the original if-chain.
        seen = {}
        for value in values:
            key = self._case_value_key(value)
            if key is None or key in seen:
                return True
            seen[key] = True
        return False

    def visit_IfStatNode(self, node):
        """Replace a qualifying if/elif chain with a SwitchStatNode."""
        self.visitchildren(node)
        common_var = None
        case_count = 0
        cases = []
        all_values = []
        for if_clause in node.if_clauses:
            var, conditions = self.extract_conditions(if_clause.condition)
            if var is None:
                return node
            if common_var is not None and not is_common_value(var, common_var):
                return node
            non_int = [c for c in conditions if not c.type.is_int]
            if not var.type.is_int or non_int:
                return node
            common_var = var
            case_count += len(conditions)
            all_values.extend(conditions)
            cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                              conditions = conditions,
                                              body = if_clause.body))
        if case_count < 2:
            return node
        # Duplicate case labels are invalid C.
        if self._has_duplicate_values(all_values):
            return node

        common_var = unwrap_node(common_var)
        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = common_var,
                                    cases = cases,
                                    else_clause = node.else_clause)

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class FlattenInListTransform(Visitor.VisitorTransform):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite membership tests against tuple/list literals.

        ``x in (a, b)`` becomes ``(x == a) or (x == b)``, and
        ``x not in [a, b]`` becomes ``(x != a) and (x != b)``. Each
        comparison is wrapped in a TypecastNode to ``bint``. Cascaded
        comparisons, other operators, non-literal right-hand sides and empty
        literals are returned unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        if node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node
        args = node.operand2.args
        if len(args) == 0:
            # Folding this to a constant would skip evaluating the left-hand
            # side, which may have side effects (e.g. ``f() in []``), so the
            # original comparison is kept.
            return node

        # One PersistentNode, constructed with the number of comparisons that
        # will reference it, is shared as the left operand of every comparison.
        lhs = ExprNodes.PersistentNode(node.operand1, len(args))
        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = arg,
                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                pos = node.pos,
                operand = cond,
                type = PyrexTypes.c_bint_type))

        # Left-fold the comparisons, giving ((c0 op c1) op c2) ...
        # This is an explicit loop rather than the builtin reduce, which
        # Python 3 no longer provides.
        result = conds[0]
        for cond in conds[1:]:
            result = ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = result,
                operand2 = cond)
        return result

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types (currently disabled; see
          visit_SimpleCallNode)
    """

    def visit_SingleAssignmentNode(self, node):
        """Skip None-initialisation and the decref for a first assignment."""
        # Descend into the operands first. Otherwise the traversal would stop
        # at every assignment, and expression hooks such as
        # visit_SimpleCallNode would never see calls on the right-hand side.
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
                # Set a flag in NameNode to skip the decref
                lhs.skip_assignment_decref = True
        return node

    def visit_SimpleCallNode(self, node):
        """Hook for rewriting isinstance() into PyObject_TypeCheck."""
        self.visitchildren(node)
        # NOTE: the leading ``0 and`` deliberately disables this rewrite, so
        # at present this method only visits the call's children.
        if 0 and node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'isinstance':
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    object_module = self.context.find_module('python_object')
                    node.function.entry = object_module.lookup('PyObject_TypeCheck')
                    node.function.type = node.function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(object_module.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node