Cython has moved to github.
cython-devel
view Cython/Compiler/ParseTreeTransforms.py @ 3042:9756a762a5c8
fix ticket 467: restore eval-once semantics for all rhs items in parallel assignments by extracting common subexpressions into temps
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Sat Mar 06 15:30:38 2010 +0100 (2 years ago) |
| parents | 0a5602ec6abb |
| children | f38e938a4338 cad2a43b2cb0 |
line source
1 from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
2 from Cython.Compiler.Visitor import CythonTransform, EnvTransform
3 from Cython.Compiler.ModuleNode import ModuleNode
4 from Cython.Compiler.Nodes import *
5 from Cython.Compiler.ExprNodes import *
6 from Cython.Compiler.UtilNodes import *
7 from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
8 from Cython.Compiler.StringEncoding import EncodedString
9 from Cython.Compiler.Errors import error, CompileError
10 try:
11 set
12 except NameError:
13 from sets import Set as set
14 import copy
class NameNodeCollector(TreeVisitor):
    """Tree visitor that records every NameNode it encounters.

    After visiting a (sub-)tree, ``name_nodes`` holds the NameNodes in the
    order they were visited.  NameNodes themselves are not descended into;
    every other node just has its children visited.
    """

    def __init__(self):
        super(NameNodeCollector, self).__init__()
        # accumulated NameNode instances, in visiting order
        self.name_nodes = []

    def visit_NameNode(self, node):
        self.name_nodes.append(node)

    # Nodes without a more specific handler simply recurse into their children.
    visit_Node = TreeVisitor.visitchildren
class SkipDeclarations(object):
    """Mixin that keeps a transform from descending into declarations.

    Variable and function declarations can have a deep tree structure, yet
    most transformations don't need to look inside them.  Every handler
    defined here returns the declaration node unchanged, without visiting
    its children.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so
    there is no need to use this mixin for transformations after that point.
    """

    def _return_declaration_unchanged(self, node):
        return node

    visit_CTypeDefNode = _return_declaration_unchanged
    visit_CVarDefNode = _return_declaration_unchanged
    visit_CDeclaratorNode = _return_declaration_unchanged
    visit_CBaseTypeNode = _return_declaration_unchanged
    visit_CEnumDefNode = _return_declaration_unchanged
    visit_CStructOrUnionDefNode = _return_declaration_unchanged
class NormalizeTree(CythonTransform):
    """
    This transform fixes up a few things after parsing
    in order to make the parse tree more suitable for
    transforms.

    a) After parsing, blocks with only one statement will
    be represented by that statement, not by a StatListNode.
    When doing transforms this is annoying and inconsistent,
    as one cannot in general remove a statement in a consistent
    way and so on. This transform wraps any single statements
    in a StatListNode containing a single statement.

    b) The PassStatNode is a noop and serves no purpose beyond
    plugging such one-statement blocks; i.e., once parsed a
    "pass" can just as well be represented using an empty
    StatListNode. This means less special cases to worry about
    in subsequent transforms (one always checks to see if a
    StatListNode has no children to see if the block is empty).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while visiting the direct children of a StatListNode (or of a
        # statement handled as a list container); such statements need no
        # wrapping.
        self.is_in_statlist = False
        # True while visiting anything below an expression node; statements
        # found there are never wrapped.
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        stacktmp = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = stacktmp
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        """Visit a statement and wrap it in a single-statement StatListNode
        unless it sits directly in a statement list or inside an expression.

        ``is_listcontainer`` is passed as True by the ParallelAssignmentNode,
        CEnumDefNode and CStructOrUnionDefNode handlers below, so that their
        direct child statements are left unwrapped.
        """
        stacktmp = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        if not self.is_in_statlist and not self.is_in_expr:
            return StatListNode(pos=node.pos, stats=[node])
        else:
            return node

    def visit_StatListNode(self, node):
        # Save and restore the flag (as visit_StatNode and visit_ExprNode do)
        # rather than resetting it to False afterwards: otherwise a
        # StatListNode nested directly inside another statement list would
        # make all of its following siblings look "outside a statlist" and
        # get wrapped in redundant single-statement StatListNodes.
        stacktmp = self.is_in_statlist
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        return node

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    # Eliminate PassStatNode
    def visit_PassStatNode(self, node):
        # A lone "pass" becomes an empty block; inside a statement list it
        # is replaced by an empty list (presumably removing it from the
        # parent list -- depends on the visitor's list handling).
        if not self.is_in_statlist:
            return StatListNode(pos=node.pos, stats=[])
        else:
            return []

    def visit_CDeclaratorNode(self, node):
        # declarators contain no statements; don't descend
        return node
class PostParseError(CompileError):
    """CompileError used for problems detected by the post-parse
    transforms (PostParse, PxdPostParse, InterpretCompilerDirectives)."""

# Error strings are checked by unit tests, so they are defined once here
# as module-level constants.
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
133 class PostParse(CythonTransform):
134 """
135 Basic interpretation of the parse tree, as well as validity
136 checking that can be done on a very basic level on the parse
137 tree (while still not being a problem with the basic syntax,
138 as such).
140 Specifically:
141 - Default values to cdef assignments are turned into single
142 assignments following the declaration (everywhere but in class
143 bodies, where they raise a compile error)
145 - Interpret some node structures into Python runtime values.
146 Some nodes take compile-time arguments (currently:
147 TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
148 which should be interpreted. This happens in a general way
149 and other steps should be taken to ensure validity.
151 Type arguments cannot be interpreted in this way.
153 - For __cythonbufferdefaults__ the arguments are checked for
154 validity.
156 TemplatedTypeNode has its directives interpreted:
157 Any first positional argument goes into the "dtype" attribute,
158 any "ndim" keyword argument goes into the "ndim" attribute and
159 so on. Also it is checked that the directive combination is valid.
160 - __cythonbufferdefaults__ attributes are parsed and put into the
161 type information.
163 Note: Currently Parsing.py does a lot of interpretation and
164 reorganization that can be refactored into this transform
165 if a more pure Abstract Syntax Tree is wanted.
166 """
168 # Track our context.
169 scope_type = None # can be either of 'module', 'function', 'class'
171 def __init__(self, context):
172 super(PostParse, self).__init__(context)
173 self.specialattribute_handlers = {
174 '__cythonbufferdefaults__' : self.handle_bufferdefaults
175 }
177 def visit_ModuleNode(self, node):
178 self.scope_type = 'module'
179 self.scope_node = node
180 self.visitchildren(node)
181 return node
183 def visit_scope(self, node, scope_type):
184 prev = self.scope_type, self.scope_node
185 self.scope_type = scope_type
186 self.scope_node = node
187 self.visitchildren(node)
188 self.scope_type, self.scope_node = prev
189 return node
191 def visit_ClassDefNode(self, node):
192 return self.visit_scope(node, 'class')
194 def visit_FuncDefNode(self, node):
195 return self.visit_scope(node, 'function')
197 def visit_CStructOrUnionDefNode(self, node):
198 return self.visit_scope(node, 'struct')
200 # cdef variables
201 def handle_bufferdefaults(self, decl):
202 if not isinstance(decl.default, DictNode):
203 raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
204 self.scope_node.buffer_defaults_node = decl.default
205 self.scope_node.buffer_defaults_pos = decl.pos
207 def visit_CVarDefNode(self, node):
208 # This assumes only plain names and pointers are assignable on
209 # declaration. Also, it makes use of the fact that a cdef decl
210 # must appear before the first use, so we don't have to deal with
211 # "i = 3; cdef int i = i" and can simply move the nodes around.
212 try:
213 self.visitchildren(node)
214 stats = [node]
215 newdecls = []
216 for decl in node.declarators:
217 declbase = decl
218 while isinstance(declbase, CPtrDeclaratorNode):
219 declbase = declbase.base
220 if isinstance(declbase, CNameDeclaratorNode):
221 if declbase.default is not None:
222 if self.scope_type in ('class', 'struct'):
223 if isinstance(self.scope_node, CClassDefNode):
224 handler = self.specialattribute_handlers.get(decl.name)
225 if handler:
226 if decl is not declbase:
227 raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
228 handler(decl)
229 continue # Remove declaration
230 raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
231 first_assignment = self.scope_type != 'module'
232 stats.append(SingleAssignmentNode(node.pos,
233 lhs=NameNode(node.pos, name=declbase.name),
234 rhs=declbase.default, first=first_assignment))
235 declbase.default = None
236 newdecls.append(decl)
237 node.declarators = newdecls
238 return stats
239 except PostParseError, e:
240 # An error in a cdef clause is ok, simply remove the declaration
241 # and try to move on to report more errors
242 self.context.nonfatal_error(e)
243 return None
245 # Split parallel assignments (a,b = b,a) into separate partial
246 # assignments that are executed rhs-first using temps. This
247 # optimisation is best applied before type analysis so that known
248 # types on rhs and lhs can be matched directly.
250 def visit_SingleAssignmentNode(self, node):
251 self.visitchildren(node)
252 return self._visit_assignment_node(node, [node.lhs, node.rhs])
254 def visit_CascadedAssignmentNode(self, node):
255 self.visitchildren(node)
256 return self._visit_assignment_node(node, node.lhs_list + [node.rhs])
258 def _visit_assignment_node(self, node, expr_list):
259 """Flatten parallel assignments into separate single
260 assignments or cascaded assignments.
261 """
262 if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
263 # no parallel assignments => nothing to do
264 return node
266 expr_list_list = []
267 flatten_parallel_assignments(expr_list, expr_list_list)
268 temp_refs = []
269 eliminate_rhs_duplicates(expr_list_list, temp_refs)
271 nodes = []
272 for expr_list in expr_list_list:
273 lhs_list = expr_list[:-1]
274 rhs = expr_list[-1]
275 if len(lhs_list) == 1:
276 node = Nodes.SingleAssignmentNode(rhs.pos,
277 lhs = lhs_list[0], rhs = rhs)
278 else:
279 node = Nodes.CascadedAssignmentNode(rhs.pos,
280 lhs_list = lhs_list, rhs = rhs)
281 nodes.append(node)
283 if len(nodes) == 1:
284 assign_node = nodes[0]
285 else:
286 assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
288 if temp_refs:
289 duplicates_and_temps = [ (temp.expression, temp)
290 for temp in temp_refs ]
291 sort_common_subsequences(duplicates_and_temps)
292 for _, temp_ref in duplicates_and_temps[::-1]:
293 assign_node = LetNode(temp_ref, assign_node)
295 return assign_node
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Share rhs subexpressions that occur more than once via LetRefNodes.

    The rhs of every assignment (the last item of each list in
    ``expr_list_list``) is scanned, descending into sequence constructors.
    Each node other than a literal or a name that is reached a second time
    gets one LetRefNode. These are appended to ``ref_node_sequence`` in the
    order the duplicates are detected.

    The duplicates are then replaced by their LetRefNodes, first inside
    the duplicated subexpressions themselves and then in every rhs.
    ``expr_list_list`` is modified in place; nothing is returned.
    """
    visited = set()
    replacements = {}

    def collect(expr):
        # Literals and plain names need no temp.  Attribute lookups are
        # deliberately not exempted: their access is not necessarily free
        # of side effects.
        if expr.is_literal or expr.is_name:
            return
        if expr not in visited:
            visited.add(expr)
            if expr.is_sequence_constructor:
                for sub_expr in expr.args:
                    collect(sub_expr)
        elif expr not in replacements:
            ref = LetRefNode(expr)
            replacements[expr] = ref
            ref_node_sequence.append(ref)

    for exprs in expr_list_list:
        collect(exprs[-1])

    if not replacements:
        return

    def substitute(expr):
        if expr in replacements:
            return replacements[expr]
        if expr.is_sequence_constructor:
            expr.args = [substitute(arg) for arg in expr.args]
        return expr

    # first rewrite the insides of the shared subexpressions ...
    for duplicate in replacements:
        if duplicate.is_sequence_constructor:
            duplicate.args = [substitute(arg) for arg in duplicate.args]

    # ... then every rhs
    for exprs in expr_list_list:
        exprs[-1] = substitute(exprs[-1])
def sort_common_subsequences(items):
    """Reorder ``items`` in place so that every entry precedes all entries
    whose key contains it.

    ``items`` is a list of pairs whose first element is the sort key (an
    expression node).  A key contains another if the other is one of its
    ``args``, directly or nested inside sequence constructors.  This is
    only a partial order, so a stable insertion sort is used to keep the
    original order wherever possible.
    """
    def _occurs_in(target, candidates):
        # True if ``target`` is one of ``candidates`` or is nested, at any
        # depth, inside one of them that is a sequence constructor.
        for candidate in candidates:
            if candidate is target or (candidate.is_sequence_constructor
                                       and _occurs_in(target, candidate.args)):
                return True
        return False

    def _is_part_of(inner, outer):
        return outer.is_sequence_constructor and _occurs_in(inner, outer.args)

    for current in xrange(len(items)):
        entry = items[current]
        destination = current
        # find the left-most earlier entry whose key contains this key
        for earlier in xrange(current - 1, -1, -1):
            if _is_part_of(entry[0], items[earlier][0]):
                destination = earlier
        if destination != current:
            # shift [destination, current) one slot right, then drop the
            # entry into the freed position
            items[destination + 1:current + 1] = items[destination:current]
            items[destination] = entry
def flatten_parallel_assignments(input, output):
    """Split one (possibly cascaded) assignment into element-wise ones.

    ``input`` is a list of expression nodes: all lhs targets followed by
    the rhs as its last item.  When the rhs and at least one lhs are
    sequence constructors, the matching positions of both sides are paired
    up and flattened recursively, so nested structures get matched as
    well.  Each resulting assignment is appended to ``output`` as a list
    ``[lhs1, ..., lhsN, rhs]``.

    Plain (non-sequence) targets keep receiving the whole rhs.  Sequence
    targets with a size mismatch are reported through ``error()`` and also
    kept as whole assignments.  Starred targets are delegated to
    map_starred_assignment().
    """
    rhs = input[-1]
    lhs_nodes = input[:-1]
    if not rhs.is_sequence_constructor:
        output.append(input)
        return
    if not sum([target.is_sequence_constructor for target in lhs_nodes]):
        output.append(input)
        return

    rhs_args = rhs.args
    rhs_size = len(rhs_args)
    whole_targets = []          # targets assigned the complete rhs
    per_item_targets = [ [] for _ in xrange(rhs_size) ]
    starred_assignments = []

    for lhs in lhs_nodes:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            whole_targets.append(lhs)
            continue

        lhs_size = len(lhs.args)
        star_count = sum([1 for item in lhs.args if item.is_starred])
        message = None
        if star_count > 1:
            message = "more than 1 starred expression in assignment"
        elif lhs_size - star_count > rhs_size:
            message = ("need more than %d value%s to unpack"
                       % (rhs_size, (rhs_size != 1) and 's' or ''))
        elif star_count == 0 and lhs_size < rhs_size:
            message = ("too many values to unpack (expected %d, got %d)"
                       % (lhs_size, rhs_size))

        if message is not None:
            error(lhs.pos, message)
            output.append([lhs,rhs])
        elif star_count == 1:
            map_starred_assignment(per_item_targets, starred_assignments,
                                   lhs.args, rhs_args)
        else:
            for targets, item in zip(per_item_targets, lhs.args):
                targets.append(item)

    if whole_targets:
        whole_targets.append(rhs)
        output.append(whole_targets)

    # pair each rhs item with the lhs items collected for its position
    for targets, value in zip(per_item_targets, rhs_args):
        if targets:
            targets.append(value)
            flatten_parallel_assignments(targets, output)

    # starred targets receive a list; flatten further if the target is
    # itself a sequence
    for starred in starred_assignments:
        if starred[0].is_sequence_constructor:
            flatten_parallel_assignments(starred, output)
        else:
            output.append(starred)
def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    """Distribute the targets of a starred unpacking over the rhs items.

    ``lhs_args`` must contain exactly one starred target.  ``lhs_targets``
    holds one list per rhs item:

    - targets left of the starred one are appended to the list at their
      own position;
    - targets right of it are aligned with the *end* of ``lhs_targets``;
    - the starred target itself is appended to ``starred_assignments`` as
      a new pair ``[target, ListNode(remaining rhs items)]``, where the
      list may be empty.

    Raises InternalError if ``lhs_args`` contains no starred target.
    """
    # Left side of the starred target.  Iterate over lhs_args itself rather
    # than zip(lhs_targets, lhs_args): lhs_targets only has one entry per
    # rhs item, so for e.g. "a, *b = x," the zip would stop before reaching
    # the starred argument and wrongly raise InternalError.
    for i, expr in enumerate(lhs_args):
        if expr.is_starred:
            starred = i
            lhs_remaining = len(lhs_args) - i - 1
            break
        lhs_targets[i].append(expr)
    else:
        raise InternalError("no starred arg found when splitting starred assignment")

    # Right side of the starred target.  Slice explicitly from
    # (len - lhs_remaining) and starred + 1: a negative slice "[-0:]" would
    # select the *whole* list when the starred target comes last, appending
    # every lhs target (including the starred node) a second time.
    right_targets = lhs_targets[len(lhs_targets) - lhs_remaining:]
    for targets, expr in zip(right_targets, lhs_args[starred + 1:]):
        targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target # unpack starred node
    starred_rhs = rhs_args[starred:]
    if lhs_remaining:
        starred_rhs = starred_rhs[:-lhs_remaining]
    if starred_rhs:
        pos = starred_rhs[0].pos
    else:
        pos = target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that should only be
    done on pxd trees.

    A lot of this checking currently happens in the parser; but
    what is listed below happens here.

    - "def" functions are let through only if they fill the
      getbuffer/releasebuffer slots

    - cdef functions are let through only if they are on the
      top level and are declared "inline"
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        outer_scope_type = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = outer_scope_type
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNodes always come with an implementation (declarations
        # without one are CVarDefNodes), so each needs a reason to be
        # allowed in a pxd file; otherwise it is reported and dropped.
        err = self._pxd_function_error(node)
        if not err:
            return node
        self.context.nonfatal_error(PostParseError(node.pos, err))
        return None

    def _pxd_function_error(self, node):
        """Return the error message for a function definition that is not
        allowed in a pxd file, or None if it is acceptable.

        As a side effect, ``inline_in_pxd`` is set on cdef functions that
        are declared 'inline' at pxd top level.
        """
        if isinstance(node, CFuncDefNode):
            if u'inline' not in node.modifiers or self.scope_type != 'pxd':
                return self.ERR_INLINE_ONLY
            node.inline_in_pxd = True
            if node.visibility != 'private':
                return self.ERR_NOGO_WITH_INLINE % node.visibility
            if node.api:
                return self.ERR_NOGO_WITH_INLINE % 'api'
            return None # allow inline function
        if (isinstance(node, DefNode) and self.scope_type == 'cclass'
                and node.name in ('__getbuffer__', '__releasebuffer__')):
            return None # allow these slots
        return self.ERR_INLINE_ONLY
class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
    """
    After parsing, directives can be stored in a number of places:
    - #cython-comments at the top of the file (stored in ModuleNode)
    - Command-line arguments (overridden by the #cython-comments)
    - @cython.directivename decorators
    - with cython.directivename: statements

    This transform is responsible for interpreting these various sources
    and store the directive in two ways:
    - Set the directives attribute of the ModuleNode for global directives.
    - Use a CompilerDirectivesNode to override directives for a subtree.

    (The first one is primarily to not have to modify with the tree
    structure, so that ModuleNode stays on top.)

    The directives are stored in dictionaries from name to value in effect.
    Each such dictionary is always filled in for all possible directives,
    using default values where no value is given by the user.

    The available directives are controlled in Options.py.

    Note that we have to run this prior to analysis, and so some minor
    duplication of functionality has to occur: We manually track cimports
    and which names the "cython" module may have been imported to.
    """
    unop_method_nodes = {
        'typeof': TypeofNode,

        'operator.address': AmpersandNode,
        'operator.dereference': DereferenceNode,
        'operator.preincrement' : inc_dec_constructor(True, '++'),
        'operator.predecrement' : inc_dec_constructor(True, '--'),
        'operator.postincrement': inc_dec_constructor(False, '++'),
        'operator.postdecrement': inc_dec_constructor(False, '--'),

        # For backwards compatibility.
        'address': AmpersandNode,
    }

    special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof',
                           'cast', 'pointer', 'compiled', 'NULL']
                          + unop_method_nodes.keys())

    def __init__(self, context, compilation_directive_defaults):
        super(InterpretCompilerDirectives, self).__init__(context)
        # directive defaults given for this compilation, keyed by unicode name
        self.compilation_directive_defaults = {}
        for key, value in compilation_directive_defaults.iteritems():
            self.compilation_directive_defaults[unicode(key)] = value
        # names under which the "cython" module is reachable
        self.cython_module_names = set()
        # local name -> full directive name (e.g. from "from cython cimport x")
        self.directive_names = {}

    def check_directive_scope(self, pos, directive, scope):
        """Return True if ``directive`` may be used in ``scope``; otherwise
        report a non-fatal error and return False."""
        legal_scopes = Options.directive_scopes.get(directive, None)
        if legal_scopes and scope not in legal_scopes:
            self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
                                        'is not allowed in %s scope' % (directive, scope)))
            return False
        else:
            return True

    # Set up processing and handle the cython: comments.
    def visit_ModuleNode(self, node):
        # Iterate over a snapshot of the keys: entries are deleted inside
        # the loop, which is not allowed while iterating the dict itself.
        # check_directive_scope() already reports the error for a directive
        # that is not allowed at module level, so it only needs dropping.
        for key in list(node.directive_comments.keys()):
            if not self.check_directive_scope(node.pos, key, 'module'):
                del node.directive_comments[key]

        # precedence: global defaults < compilation defaults < file comments
        directives = copy.copy(Options.directive_defaults)
        directives.update(self.compilation_directive_defaults)
        directives.update(node.directive_comments)
        self.directives = directives
        node.directives = directives
        self.visitchildren(node)
        node.cython_module_names = self.cython_module_names
        return node

    # The following four functions track imports and cimports that
    # begin with "cython"
    def is_cython_directive(self, name):
        return (name in Options.directive_types or
                name in self.special_methods or
                PyrexTypes.parse_basic_type(name))

    def visit_CImportStatNode(self, node):
        if node.module_name == u"cython":
            self.cython_module_names.add(node.as_name or u"cython")
        elif node.module_name.startswith(u"cython."):
            if node.as_name:
                self.directive_names[node.as_name] = node.module_name[7:]
            else:
                self.cython_module_names.add(u"cython")
            # if this cimport was a compiler directive, we don't
            # want to leave the cimport node sitting in the tree
            return None
        return node

    def visit_FromCImportStatNode(self, node):
        if (node.module_name == u"cython") or \
               node.module_name.startswith(u"cython."):
            submodule = (node.module_name + u".")[7:]
            newimp = []
            for pos, name, as_name, kind in node.imported_names:
                full_name = submodule + name
                if self.is_cython_directive(full_name):
                    if as_name is None:
                        as_name = full_name
                    self.directive_names[as_name] = full_name
                    if kind is not None:
                        self.context.nonfatal_error(PostParseError(pos,
                            "Compiler directive imports must be plain imports"))
                else:
                    newimp.append((pos, name, as_name, kind))
            if not newimp:
                return None
            node.imported_names = newimp
        return node

    def visit_FromImportStatNode(self, node):
        if (node.module.module_name.value == u"cython") or \
               node.module.module_name.value.startswith(u"cython."):
            submodule = (node.module.module_name.value + u".")[7:]
            newimp = []
            for name, name_node in node.items:
                full_name = submodule + name
                if self.is_cython_directive(full_name):
                    self.directive_names[name_node.name] = full_name
                else:
                    newimp.append((name, name_node))
            if not newimp:
                return None
            node.items = newimp
        return node

    def visit_SingleAssignmentNode(self, node):
        # "x = __import__('cython')"-style imports are treated like
        # "cimport cython as x"
        if (isinstance(node.rhs, ImportNode) and
                node.rhs.module_name.value == u'cython'):
            node = CImportStatNode(node.pos,
                                   module_name = u'cython',
                                   as_name = node.lhs.name)
            self.visit_CImportStatNode(node)
        else:
            self.visitchildren(node)
        return node

    def visit_NameNode(self, node):
        if node.name in self.cython_module_names:
            node.is_cython_module = True
        else:
            node.cython_attribute = self.directive_names.get(node.name)
        return node

    def try_to_parse_directives(self, node):
        # If node is the contents of an directive (in a with statement or
        # decorator), returns a list of (directivename, value) pairs.
        # Otherwise, returns None
        if isinstance(node, CallNode):
            self.visit(node.function)
            optname = node.function.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype:
                    args, kwds = node.explicit_args_kwds()
                    directives = []
                    key_value_pairs = []
                    if kwds is not None and directivetype is not dict:
                        # keywords naming sub-directives ("name.key") become
                        # directives of their own
                        for keyvalue in kwds.key_value_pairs:
                            key, value = keyvalue
                            sub_optname = "%s.%s" % (optname, key.value)
                            if Options.directive_types.get(sub_optname):
                                directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
                            else:
                                key_value_pairs.append(keyvalue)
                        if not key_value_pairs:
                            kwds = None
                        else:
                            kwds.key_value_pairs = key_value_pairs
                        if directives and not kwds and not args:
                            return directives
                    directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
                    return directives

        return None

    def try_to_parse_directive(self, optname, args, kwds, pos):
        """Interpret the arguments of a single directive call.

        Returns an (optname, value) pair; raises PostParseError when the
        arguments do not match the directive's type in Options.
        """
        directivetype = Options.directive_types.get(optname)
        if len(args) == 1 and isinstance(args[0], NoneNode):
            return optname, Options.directive_defaults[optname]
        elif directivetype is bool:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time boolean argument' % optname)
            return (optname, args[0].value)
        elif directivetype is str:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], (StringNode, UnicodeNode)):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time string argument' % optname)
            return (optname, str(args[0].value))
        elif directivetype is dict:
            if len(args) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no prepositional arguments' % optname)
            if kwds is None:
                # called without any keyword arguments
                return optname, {}
            return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
        elif directivetype is list:
            # kwds is a DictNode, so check its key/value pairs (a node has
            # no len()).
            if kwds is not None and len(kwds.key_value_pairs) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no keyword arguments' % optname)
            return optname, [ str(arg.value) for arg in args ]
        else:
            assert False

    def visit_with_directives(self, body, directives):
        """Visit ``body`` with ``directives`` layered over the current ones
        and wrap the result in a CompilerDirectivesNode."""
        olddirectives = self.directives
        newdirectives = copy.copy(olddirectives)
        newdirectives.update(directives)
        self.directives = newdirectives
        assert isinstance(body, StatListNode), body
        retbody = self.visit_Node(body)
        directive = CompilerDirectivesNode(pos=retbody.pos, body=retbody,
                                           directives=newdirectives)
        self.directives = olddirectives
        return directive

    # Handle decorators
    def visit_FuncDefNode(self, node):
        directives = []
        if node.decorators:
            # Split the decorators into two lists -- real decorators and directives
            realdecs = []
            for dec in node.decorators:
                new_directives = self.try_to_parse_directives(dec.decorator)
                if new_directives is not None:
                    directives.extend(new_directives)
                else:
                    realdecs.append(dec)
            if realdecs and isinstance(node, CFuncDefNode):
                raise PostParseError(realdecs[0].pos, "Cdef functions cannot take arbitrary decorators.")
            else:
                node.decorators = realdecs

        if directives:
            optdict = {}
            directives.reverse() # Decorators coming first take precedence
            for directive in directives:
                name, value = directive
                if not self.check_directive_scope(node.pos, name, 'function'):
                    continue
                if name in optdict:
                    old_value = optdict[name]
                    # keywords and arg lists can be merged, everything
                    # else overrides completely
                    if isinstance(old_value, dict):
                        old_value.update(value)
                    elif isinstance(old_value, list):
                        old_value.extend(value)
                    else:
                        optdict[name] = value
                else:
                    optdict[name] = value
            body = StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(body, optdict)
        else:
            return self.visit_Node(node)

    def visit_CVarDefNode(self, node):
        if node.decorators:
            for dec in node.decorators:
                for directive in self.try_to_parse_directives(dec.decorator) or []:
                    if directive is not None and directive[0] == u'locals':
                        node.directive_locals = directive[1]
                    else:
                        self.context.nonfatal_error(PostParseError(dec.pos,
                            "Cdef functions can only take cython.locals() decorator."))
        return node

    # Handle with statements
    def visit_WithStatNode(self, node):
        directive_dict = {}
        for directive in self.try_to_parse_directives(node.manager) or []:
            if directive is not None:
                if node.target is not None:
                    self.context.nonfatal_error(
                        PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
                else:
                    name, value = directive
                    if self.check_directive_scope(node.pos, name, 'with statement'):
                        directive_dict[name] = value
        if directive_dict:
            return self.visit_with_directives(node.body, directive_dict)
        return self.visit_Node(node)
class WithTransform(CythonTransform, SkipDeclarations):
    """Rewrite ``with`` statements using the try/except/finally templates
    below, with separate variants for statements with and without an
    ``as`` target."""

    # EXCINFO is manually set to a variable that contains
    # the exc_info() tuple that can be generated by the enclosing except
    # statement.
    template_without_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
    """, temps=[u'MGR', u'EXC', u"EXIT"],
    pipeline=[NormalizeTree(None)])

    template_with_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        VALUE = MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                TARGET = VALUE
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
            MGR = EXIT = VALUE = EXC = None

    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
    pipeline=[NormalizeTree(None)])

    def visit_WithStatNode(self, node):
        # TODO: Cleanup badly needed
        # A uniquely named variable carries the exception info from the
        # except clause to the EXIT(*EXCINFO) call (a real temp would be
        # preferable to a generated name).
        TemplateTransform.temp_name_counter += 1
        excinfo_name = "__tmpvar_%d" % TemplateTransform.temp_name_counter

        self.visitchildren(node, ['body'])

        substitutions = {
            u'EXPR' : node.manager,
            u'BODY' : node.body,
            u'EXCINFO' : NameNode(node.pos, name=excinfo_name),
        }
        if node.target is None:
            template = self.template_without_target
        else:
            template = self.template_with_target
            substitutions[u'TARGET'] = node.target
        result = template.substitute(substitutions, pos=node.pos)

        # The last statement of the result is the outer try/finally; the
        # last statement of its body is the inner try/except, whose except
        # clause must store the exception info into the EXCINFO variable.
        try_except = result.stats[-1].body.stats[-1]
        try_except.except_clauses[0].excinfo_target = NameNode(node.pos, name=excinfo_name)
        return result

    def visit_ExprNode(self, node):
        # With statements are never inside expressions.
        return node
class DecoratorTransform(CythonTransform, SkipDeclarations):
    """Replace decorator syntax by explicit calls plus a reassignment.

    A definition ``@a @b def f(): ...`` becomes the plain definition of
    ``f`` followed by the statement ``f = a(b(f))``.
    """

    def visit_DefNode(self, func_node):
        self.visitchildren(func_node)
        if func_node.decorators:
            return self._handle_decorators(func_node, func_node.name)
        return func_node

    def _visit_CClassDefNode(self, class_node):
        # This doesn't currently work, so it's disabled (also in the
        # parser).
        #
        # Problem: assignments to cdef class names do not work. They
        # would require an additional check anyway, as the extension
        # type must not change its C type, so decorators cannot
        # replace an extension type, just alter it and return it.

        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.class_name)
        return class_node

    def visit_ClassDefNode(self, class_node):
        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.name)
        return class_node

    def _handle_decorators(self, node, name):
        """Return ``[node, name = dec1(dec2(...(name)))]``."""
        decorators = node.decorators
        call_chain = NameNode(node.pos, name = name)
        # the decorator listed last is applied first (innermost call)
        index = len(decorators)
        while index > 0:
            index -= 1
            decorator = decorators[index]
            call_chain = SimpleCallNode(
                decorator.pos,
                function = decorator.decorator,
                args = [call_chain])

        reassignment = SingleAssignmentNode(
            node.pos,
            lhs = NameNode(node.pos, name = name),
            rhs = call_chain)
        return [node, reassignment]
class AnalyseDeclarationsTransform(CythonTransform):
    """Run declaration analysis and prune declaration-only nodes.

    * Calls ``analyse_declarations`` for modules, functions and
      comprehensions, maintaining a stack of the scopes being analysed.
    * Records, per module/function, every name referenced by a NameNode, so
      that a later ``cdef`` declaration of that name (other than an extern
      one) triggers a level-2 "declared after it is used" warning.
    * Drops base-type, struct/union and non-public enum nodes, and replaces
      CVarDefNodes either by generated properties (for entries listed in
      ``need_properties``) or by nothing.
    """

    # Template for the property that exposes a ``cdef public`` attribute.
    # create_Property() substitutes ATTR and sets the property name directly.
    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """, level='c_class')

    def __call__(self, root):
        self.env_stack = [root.scope]
        # One set of referenced names per module/function; used to detect
        # a cdef variable that is declared after it is used.
        self.seen_vars_stack = []
        return super(AnalyseDeclarationsTransform, self).__call__(root)

    def visit_NameNode(self, node):
        current_seen = self.seen_vars_stack[-1]
        current_seen.add(node.name)
        return node

    def visit_ModuleNode(self, node):
        self.seen_vars_stack.append(set())
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node

    def visit_ClassDefNode(self, node):
        # A class body gets its own scope but shares the seen-names set of
        # the enclosing module/function.
        self.env_stack.append(node.scope)
        self.visitchildren(node)
        self.env_stack.pop()
        return node

    def visit_FuncDefNode(self, node):
        self.seen_vars_stack.append(set())
        local_env = node.local_scope
        node.body.analyse_control_flow(local_env) # this will be totally refactored
        node.declare_arguments(local_env)
        # Declare the variables typed in node.directive_locals, skipping
        # names already declared here (e.g. arguments).
        for var_name, type_node in node.directive_locals.items():
            if local_env.lookup_here(var_name):
                continue
            declared_type = type_node.analyse_as_type(local_env)
            if not declared_type:
                error(type_node.pos, "Not a type")
                continue
            local_env.declare_var(var_name, declared_type, type_node.pos)
        node.body.analyse_declarations(local_env)
        self.env_stack.append(local_env)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    def visit_ComprehensionNode(self, node):
        self.visitchildren(node)
        node.analyse_declarations(self.env_stack[-1])
        return node

    # Several declaration nodes are not needed once declaration analysis
    # is done. They were analysed by a separate recursive pass over the
    # enclosing function or module, so returning None simply drops them.

    def visit_CDeclaratorNode(self, node):
        # Descend so that every CNameDeclaratorNode gets visited.
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return None

    def visit_CEnumDefNode(self, node):
        # Only public enums are kept.
        if node.visibility == 'public':
            return node
        return None

    def visit_CStructOrUnionDefNode(self, node):
        return None

    def visit_CNameDeclaratorNode(self, node):
        declared_name = node.name
        if declared_name in self.seen_vars_stack[-1]:
            entry = self.env_stack[-1].lookup(declared_name)
            # extern declarations are exempt from the warning
            if not (entry is not None and entry.visibility == 'extern'):
                warning(node.pos, "cdef variable '%s' declared after it is used" % declared_name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):
        # Descend so that every CNameDeclaratorNode gets visited.
        self.visitchildren(node)
        if not node.need_properties:
            return None
        # cdef public attributes may need type testing on assignment, so
        # they are accessed through generated properties.
        generated = []
        for entry in node.need_properties:
            prop = self.create_Property(entry)
            prop.analyse_declarations(node.dest_scope)
            self.visit(prop)
            generated.append(prop)
        return StatListNode(pos=node.pos, stats=generated)

    def create_Property(self, entry):
        """Build the property node exposing ``self.<entry.name>``."""
        self_attr = AttributeNode(pos=entry.pos,
                                  obj=NameNode(pos=entry.pos, name="self"),
                                  attribute=entry.name)
        fragment = self.basic_property.substitute({u"ATTR": self_attr},
                                                  pos=entry.pos)
        prop = fragment.stats[0]
        prop.name = entry.name
        return prop


class AnalyseExpressionsTransform(CythonTransform):
    """Infer types and analyse expressions of module and function bodies.

    For each module/function the scope's types are inferred first, then the
    body's expressions are analysed in that scope, then the children are
    visited so nested functions get the same treatment.
    """

    def _analyse_in_scope(self, node, scope):
        # Shared by the module and function visitors below.
        scope.infer_types()
        node.body.analyse_expressions(scope)
        self.visitchildren(node)
        return node

    def visit_ModuleNode(self, node):
        return self._analyse_in_scope(node, node.scope)

    def visit_FuncDefNode(self, node):
        return self._analyse_in_scope(node, node.local_scope)


class AlignFunctionDefinitions(CythonTransform):
    """
    This class takes the signatures from a .pxd file and applies them to
    the def methods in a .py file.

    A Python class declared as a cdef class in the .pxd is converted with
    ``as_cclass()``, and a ``def`` declared as a C function in the .pxd is
    converted with ``as_cfunction()``. Under the ``auto_cpdef`` directive,
    module-level ``def`` functions without a .pxd declaration are converted
    too. Any other name clash is reported as a redeclaration.
    """

    def visit_ModuleNode(self, node):
        self.scope = node.scope
        self.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_PyClassDefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if pxd_def:
            if pxd_def.is_cclass:
                return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
            else:
                error(node.pos, "'%s' redeclared" % node.name)
                error(pxd_def.pos, "previous declaration here")
                return None
        else:
            return node

    def visit_CClassDefNode(self, node, pxd_def=None):
        # pxd_def is passed in when coming from visit_PyClassDefNode.
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if pxd_def:
            # Look up the methods in the class's own scope while visiting
            # its body, then restore the enclosing scope.
            outer_scope = self.scope
            self.scope = pxd_def.type.scope
        self.visitchildren(node)
        if pxd_def:
            self.scope = outer_scope
        return node

    def visit_DefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if pxd_def:
            # Check the kind of declaration first: only C function types
            # have 'args', so touching pxd_def.type.args for any other
            # entry (e.g. a cdef attribute of the same name in a cdef
            # class) would crash instead of reporting the clash.
            if not pxd_def.is_cfunction:
                error(node.pos, "'%s' redeclared" % node.name)
                error(pxd_def.pos, "previous declaration here")
                return None
            if self.scope.is_c_class_scope and len(pxd_def.type.args) > 0:
                # The self parameter type needs adjusting.
                pxd_def.type.args[0].type = self.scope.parent_type
            node = node.as_cfunction(pxd_def)
        elif self.scope.is_module_scope and self.directives['auto_cpdef']:
            node = node.as_cfunction(scope=self.scope)
        # Enable this when internal def functions are allowed.
        # self.visitchildren(node)
        return node


class MarkClosureVisitor(CythonTransform):
    """Set ``needs_closure`` on every function definition node.

    A function is marked as needing a closure when its body contains a
    ``yield``, or a nested function or class definition.
    """

    needs_closure = False

    def visit_FuncDefNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        # Any enclosing function now contains a nested function.
        self.needs_closure = True
        return node

    def visit_ClassDefNode(self, node):
        self.visitchildren(node)
        self.needs_closure = True
        return node

    def visit_YieldNode(self, node):
        self.needs_closure = True
        # Returning None from a transform's visit method drops the node
        # from the tree, so the yield must be returned explicitly. Its
        # children are visited like those of any other node.
        self.visitchildren(node)
        return node


class CreateClosureClasses(CythonTransform):
    """Declare closure classes in the module scope for function definitions.

    NOTE(review): the intent is to do this only for functions that need a
    closure, but visit_FuncDefNode currently declares a class for every
    FuncDefNode it reaches without consulting ``needs_closure``, and it does
    not descend into function bodies.
    """

    def visit_ModuleNode(self, node):
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def create_class_from_scope(self, node, target_module_scope):
        """Declare a cdef class in target_module_scope mirroring node's locals.

        Each entry of ``node.local_scope`` becomes a cdef attribute of the
        new class with the same name, cname and type. Nothing is returned.
        """
        closure_name = temp_name_handle("closure")
        local_entries = node.local_scope.entries
        class_entry = target_module_scope.declare_c_class(
            name = closure_name, pos = node.pos,
            defining = True, implementing = True)
        attr_scope = class_entry.type.scope
        for var_entry in local_entries.values():
            attr_scope.declare_var(pos=node.pos,
                                   name=var_entry.name,
                                   cname=var_entry.cname,
                                   type=var_entry.type,
                                   is_cdef=True)

    def visit_FuncDefNode(self, node):
        self.create_class_from_scope(node, self.module_scope)
        return node


class GilCheck(VisitorTransform):
    """
    While inside a `nogil` environment, call `node.nogil_check(env)` on
    every node that defines it, so that Python operations which need the
    GIL are reported as errors. `env` is the innermost module/function
    scope being visited.
    """
    def __call__(self, root):
        self.env_stack = [root.scope]
        self.nogil = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        self.env_stack.append(node.local_scope)
        was_nogil = self.nogil
        # A function's own scope decides whether its body runs without GIL.
        self.nogil = node.local_scope.nogil
        if self.nogil and node.nogil_check:
            node.nogil_check(node.local_scope)
        self.visitchildren(node)
        self.env_stack.pop()
        self.nogil = was_nogil
        return node

    def visit_GILStatNode(self, node):
        env = self.env_stack[-1]
        if self.nogil and node.nogil_check:
            # Pass the current scope, like every other nogil_check call in
            # this transform (the original computed env but dropped it).
            node.nogil_check(env)
        was_nogil = self.nogil
        # 'with gil' / 'with nogil' switches the state for its body only.
        self.nogil = (node.state == 'nogil')
        self.visitchildren(node)
        self.nogil = was_nogil
        return node

    def visit_Node(self, node):
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node


class TransformBuiltinMethods(EnvTransform):
    """Replace ``cython.*`` pseudo-builtins and the builtin ``locals()``.

    * ``cython.compiled`` becomes True and ``cython.NULL`` a NullNode; any
      other unknown ``cython.*`` attribute that is not a basic type name is
      an error.
    * ``cython.cast``, ``cython.sizeof``, ``cython.cmod``, ``cython.cdiv``
      and the unary helpers in ``InterpretCompilerDirectives.unop_method_nodes``
      are replaced by the corresponding expression nodes.
    * A call to the builtin ``locals()`` becomes a dict display mapping
      each name in the current scope to its value.
    * Declaration-only assignments are dropped.
    """

    def visit_SingleAssignmentNode(self, node):
        if node.declaration_only:
            return None
        else:
            self.visitchildren(node)
            return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = BoolNode(node.pos, value=True)
            elif attribute == u'NULL':
                node = NullNode(node.pos)
            elif not PyrexTypes.parse_basic_type(attribute):
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def visit_SimpleCallNode(self, node):

        # locals builtin
        if isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'locals':
                lenv = self.env_stack[-1]
                entry = lenv.lookup_here('locals')
                if entry:
                    # Not the builtin 'locals': leave the call alone, but
                    # still transform its arguments (they may contain
                    # cython.* constructs).
                    self.visitchildren(node)
                    return node
                if len(node.args) > 0:
                    # Report at the call site; the transform has no position.
                    error(node.pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d" % len(node.args))
                    return node
                pos = node.pos
                items = [ExprNodes.DictItemNode(pos,
                                                key=ExprNodes.StringNode(pos, value=var),
                                                value=ExprNodes.NameNode(pos, name=var))
                         for var in lenv.entries]
                return ExprNodes.DictNode(pos, key_value_pairs=items)

        # cython.foo
        function = node.function.as_cython_attribute()
        if function:
            if function in InterpretCompilerDirectives.unop_method_nodes:
                if len(node.args) != 1:
                    error(node.function.pos, u"%s() takes exactly one argument" % function)
                else:
                    node = InterpretCompilerDirectives.unop_method_nodes[function](node.function.pos, operand=node.args[0])
            elif function == u'cast':
                if len(node.args) != 2:
                    error(node.function.pos, u"cast() takes exactly two arguments")
                else:
                    type = node.args[0].analyse_as_type(self.env_stack[-1])
                    if type:
                        node = TypecastNode(node.function.pos, type=type, operand=node.args[1])
                    else:
                        error(node.args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(node.args) != 1:
                    error(node.function.pos, u"sizeof() takes exactly one argument")
                else:
                    # sizeof(type) vs. sizeof(variable)
                    type = node.args[0].analyse_as_type(self.env_stack[-1])
                    if type:
                        node = SizeofTypeNode(node.function.pos, arg_type=type)
                    else:
                        node = SizeofVarNode(node.function.pos, operand=node.args[0])
            elif function == 'cmod':
                if len(node.args) != 2:
                    error(node.function.pos, u"cmod() takes exactly two arguments")
                else:
                    node = binop_node(node.function.pos, '%', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == 'cdiv':
                if len(node.args) != 2:
                    error(node.function.pos, u"cdiv() takes exactly two arguments")
                else:
                    node = binop_node(node.function.pos, '/', node.args[0], node.args[1])
                    node.cdivision = True
            else:
                error(node.function.pos, u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)
        return node
