Cython has moved to github.
cython-devel
view Cython/Compiler/TypeInference.py @ 2811:6f8fc01ddab1
Verbose type inference directive.
| author | Robert Bradshaw <robertwb@math.washington.edu> |
|---|---|
| date | Thu Jan 21 16:41:28 2010 -0800 (2 years ago) |
| parents | 3859bf52700a |
| children | c0db10c1a935 |
line source
1 from Errors import error, warning, warn_once, InternalError
2 import ExprNodes
3 import Nodes
4 import Builtin
5 import PyrexTypes
6 from PyrexTypes import py_object_type, unspecified_type
7 from Visitor import CythonTransform
9 try:
10 set
11 except NameError:
12 # Python 2.3
13 from sets import Set as set
class TypedExprNode(ExprNodes.ExprNode):
    # Used for declaring assignments of a specified type without a known
    # entry: a stand-in right-hand side that carries nothing but a type.
    # MarkAssignments appends instances of it to entry.assignments when the
    # real assigned expression is not available (loop targets, imports,
    # *args/**kwargs, ...).
    # NOTE(review): presumably relies on ExprNode's default infer_type() /
    # type_dependencies() reading self.type -- confirm in ExprNodes.
    def __init__(self, type):
        # Deliberately does not call the base-class constructor; only the
        # ``type`` attribute is set on this fake node.
        self.type = type

# Shared placeholder meaning "assigned some Python object".
object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform):
    """
    Tree pass that records, for every assigned name, what gets assigned.

    Each right-hand side is appended to ``lhs.entry.assignments`` so that
    the type inferer can later derive a type from all assignments to the
    entry.  Where the assigned value is not available as an expression
    (iteration over arbitrary sequences, except-clause targets, imports,
    ``*args``/``**kwargs``) a TypedExprNode of the known type is recorded
    instead.
    """

    def mark_assignment(self, lhs, rhs):
        """
        Record ``rhs`` as a value assigned to ``lhs``.

        Names and Python argument declarations get ``rhs`` appended to
        their entry's assignment list.  For a sequence (unpacking) target
        each element is marked as receiving a Python object.  Any other
        kind of target is ignored.
        """
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return
            lhs.entry.assignments.append(rhs)
        elif isinstance(lhs, ExprNodes.SequenceNode):
            # Unpacked items come from an arbitrary object, so all that is
            # known about them is that they are Python objects.
            for arg in lhs.args:
                self.mark_assignment(arg, object_expr)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_SingleAssignmentNode(self, node):
        """``lhs = rhs``"""
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        """``a = b = rhs``: every target receives the same rhs."""
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        """``lhs op= rhs``: the target receives the equivalent binop node."""
        self.mark_assignment(node.lhs, node.create_binop_node())
        self.visitchildren(node)
        return node

    def visit_ForInStatNode(self, node):
        """
        ``for target in sequence``.  Loops over range()/xrange() calls type
        the target from the call arguments; any other loop target is
        marked as a Python object.
        """
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                if function.name in ('range', 'xrange'):
                    is_special = True
                    # The target is typed from the first two arguments,
                    # plus "start + step" when a step argument is given.
                    for arg in sequence.args[:2]:
                        self.mark_assignment(node.target, arg)
                    if len(sequence.args) > 2:
                        self.mark_assignment(
                            node.target,
                            ExprNodes.binop_node(node.pos,
                                                 '+',
                                                 sequence.args[0],
                                                 sequence.args[2]))
        if not is_special:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        """``for target from bound1 ... by step``: typed from bound1 (+ step)."""
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                                 ExprNodes.binop_node(node.pos,
                                                      '+',
                                                      node.bound1,
                                                      node.step))
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        """``except ... , target``: the target is a Python object."""
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        """cimported names can't be assigned to, so there is nothing to mark."""
        # Return the node unchanged like every other visitor here.  The
        # original bare ``pass`` returned None.
        # NOTE(review): in this transform a None return presumably removes
        # or replaces the node in the tree, which was never the intent.
        return node

    def visit_FromImportStatNode(self, node):
        """``from m import a, b``: every imported name is a Python object."""
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """``*args`` is marked as a tuple and ``**kwargs`` as a dict."""
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type))
        self.visitchildren(node)
        return node
class PyObjectTypeInferer:
    """
    Trivial inferer: anything whose type was never declared is treated as
    a Python object.
    """
    def infer_types(self, scope):
        """
        Set every entry of ``scope.entries`` whose type is still
        ``unspecified_type`` to ``py_object_type``; entries with a declared
        type are left alone.
        """
        undeclared = [entry for entry in scope.entries.values()
                      if entry.type is unspecified_type]
        for entry in undeclared:
            entry.type = py_object_type
class SimpleAssignmentTypeInferer:
    """
    Very basic type inference.

    Every entry whose type is still ``unspecified_type`` gets the spanning
    type of the inferred types of all expressions recorded in
    ``entry.assignments``.  Entries are inferred in dependency order.
    Entries whose only remaining dependency is themselves are seeded from
    their non-recursive assignments.  Whatever cannot be resolved falls
    back to ``py_object_type``.
    """
    # TODO: Implement a real type inference algorithm.
    # (Something more powerful than just extending this one...)
    def infer_types(self, scope):
        """
        Replace every ``unspecified_type`` in ``scope.entries`` by a
        concrete type.

        The ``infer_types`` directive selects the policy:
          - True uses aggressive_spanning_type;
          - None (safe mode) uses safe_spanning_type;
          - any other value disables inference, and unspecified entries
            simply become ``py_object_type``.
        When the ``infer_types.verbose`` directive is set, a warning
        reports the type chosen for each inferred entry.
        """
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']
        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None: # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    entry.type = py_object_type
            return

        dependencies_by_entry = {} # entry -> entries it still waits on
        entries_by_dependency = {} # dependency -> entries waiting on it
        ready_to_infer = []
        for entry in scope.entries.values():
            if entry.type is unspecified_type:
                deps = set()
                for expr in entry.assignments:
                    deps.update(expr.type_dependencies(scope))
                if deps:
                    dependencies_by_entry[entry] = deps
                    for dep in deps:
                        if dep not in entries_by_dependency:
                            entries_by_dependency[dep] = set([entry])
                        else:
                            entries_by_dependency[dep].add(entry)
                else:
                    ready_to_infer.append(entry)

        def resolve_dependency(dep):
            # ``dep`` now has a type: drop it from every entry waiting on it.
            # An entry left with nothing to wait for becomes ready.  A
            # self-dependent entry (entry == dep) is removed by the caller
            # instead.
            if dep in entries_by_dependency:
                for entry in entries_by_dependency[dep]:
                    entry_deps = dependencies_by_entry[entry]
                    entry_deps.remove(dep)
                    if not entry_deps and entry != dep:
                        del dependencies_by_entry[entry]
                        ready_to_infer.append(entry)

        # Try to infer things in order...
        while True:
            while ready_to_infer:
                entry = ready_to_infer.pop()
                types = [expr.infer_type(scope) for expr in entry.assignments]
                if types:
                    entry.type = spanning_type(types)
                else:
                    # FIXME: raise a warning?
                    # print "No assignments", entry.pos, entry
                    entry.type = py_object_type
                if verbose:
                    warning(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type), 1)
                resolve_dependency(entry)
            # Deal with simple circular dependencies: entries whose only
            # remaining dependency is themselves.  Iterate over a snapshot
            # because entries are deleted inside the loop.
            progress = False
            for entry, deps in list(dependencies_by_entry.items()):
                if len(deps) == 1 and deps == set([entry]):
                    types = [expr.infer_type(scope) for expr in entry.assignments
                             if expr.type_dependencies(scope) == ()]
                    if types:
                        entry.type = spanning_type(types)
                        types = [expr.infer_type(scope) for expr in entry.assignments]
                        entry.type = spanning_type(types) # might be wider...
                        if verbose:
                            warning(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type), 1)
                        resolve_dependency(entry)
                        del dependencies_by_entry[entry]
                        progress = True
                        if ready_to_infer:
                            break
            # Resolving one self-dependent entry can reduce another entry,
            # possibly one already passed over above, to a pure
            # self-dependency without making anything ready.  Keep going
            # while anything was resolved.  Stopping merely because
            # ready_to_infer is empty made the outcome depend on dict
            # iteration order.  Each productive pass removes at least one
            # entry, so this terminates.
            if not progress:
                break

        # We can't figure out the rest with this algorithm, let them be objects.
        for entry in dependencies_by_entry:
            entry.type = py_object_type
            if verbose:
                warning(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type), 1)
def find_spanning_type(type1, type2):
    """
    Combine two assigned types into the type an inferred variable gets.

    - Identical types are returned unchanged.
    - Any other combination involving ``bint`` yields ``py_object_type``.
    - Otherwise PyrexTypes.spanning_type() decides.  A C float, C double
      or builtin Python float result is reported as C double.
    """
    if type1 is type2:
        return type1
    bint = PyrexTypes.c_bint_type
    if bint is type1 or bint is type2:
        # Widening bint to an arbitrary C int here could break the
        # coercion of the value back to a Python bool.
        return py_object_type
    spanned = PyrexTypes.spanning_type(type1, type2)
    float_like = (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                  Builtin.float_type)
    if spanned in float_like:
        # Python's float is a C double underneath, so the C type is a safe
        # stand-in for all of these.
        return PyrexTypes.c_double_type
    return spanned
def aggressive_spanning_type(types):
    """
    Fold ``types`` pairwise with find_spanning_type() and return the
    result as-is, whatever kind of type it is.  ``types`` must be
    non-empty (reduce() has no initial value here).
    """
    return reduce(find_spanning_type, types)
def safe_spanning_type(types):
    """
    Like aggressive_spanning_type(), but only keep results known to be safe.

    The following results are kept; any other type degrades to
    ``py_object_type``:
      - any Python object type, since a specific Python type is always
        safe to infer;
      - C double, since Python's float is a C double underneath;
      - bint, since find_spanning_type() only yields it when every
        combined type was bint.
    """
    spanned = reduce(find_spanning_type, types)
    keep = (spanned.is_pyobject
            or spanned is PyrexTypes.c_double_type
            or spanned is PyrexTypes.c_bint_type)
    return spanned if keep else py_object_type
def get_type_inferer():
    """Return a new SimpleAssignmentTypeInferer instance."""
    inferer = SimpleAssignmentTypeInferer()
    return inferer
