Cython has moved to github.

cython-devel

view Cython/Compiler/ParseTreeTransforms.py @ 3091:1a2e04bc1395

remove dependency on structmember.h
author Lisandro Dalcin <dalcinl@gmail.com>
date Thu Mar 11 17:21:13 2010 -0300 (2 years ago)
parents f38e938a4338
children f2b986c34e79
line source
1 from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
2 from Cython.Compiler.Visitor import CythonTransform, EnvTransform
3 from Cython.Compiler.ModuleNode import ModuleNode
4 from Cython.Compiler.Nodes import *
5 from Cython.Compiler.ExprNodes import *
6 from Cython.Compiler.UtilNodes import *
7 from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
8 from Cython.Compiler.StringEncoding import EncodedString
9 from Cython.Compiler.Errors import error, CompileError
10 try:
11 set
12 except NameError:
13 from sets import Set as set
14 import copy
class NameNodeCollector(TreeVisitor):
    """Gather every NameNode found while walking a (sub-)tree.

    The collected nodes are stored, in the order they are visited, in
    the ``name_nodes`` list attribute.
    """

    def __init__(self):
        super(NameNodeCollector, self).__init__()
        self.name_nodes = []

    def visit_Node(self, node):
        # Anything that is not a NameNode is only traversed, not recorded.
        return self.visitchildren(node)

    def visit_NameNode(self, node):
        self.name_nodes.append(node)
class SkipDeclarations(object):
    """
    Mixin that keeps a visitor from descending into declaration nodes.

    Variable and function declarations can often have a deep tree
    structure, and yet most transformations don't need to go that deep.
    Every handler below therefore returns the declaration node as-is,
    without visiting its children.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so
    there is no need to use this for transformations after that point.
    """

    def _return_declaration_unchanged(self, node):
        # Shared implementation of all the handlers below.
        return node

    visit_CTypeDefNode = _return_declaration_unchanged
    visit_CVarDefNode = _return_declaration_unchanged
    visit_CDeclaratorNode = _return_declaration_unchanged
    visit_CBaseTypeNode = _return_declaration_unchanged
    visit_CEnumDefNode = _return_declaration_unchanged
    visit_CStructOrUnionDefNode = _return_declaration_unchanged
class NormalizeTree(CythonTransform):
    """
    This transform fixes up a few things after parsing
    in order to make the parse tree more suitable for
    transforms.

    a) After parsing, blocks with only one statement will
    be represented by that statement, not by a StatListNode.
    When doing transforms this is annoying and inconsistent,
    as one cannot in general remove a statement in a consistent
    way and so on. This transform wraps any single statements
    in a StatListNode containing a single statement.

    b) The PassStatNode is a noop and serves no purpose beyond
    plugging such one-statement blocks; i.e., once parsed a
    "pass" can just as well be represented using an empty
    StatListNode. This means fewer special cases to worry about
    in subsequent transforms (one always checks to see if a
    StatListNode has no children to see if the block is empty).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while visiting the direct children of a StatListNode (or of
        # a statement that is treated like a list container).
        self.is_in_statlist = False
        # True while visiting anything nested inside an expression.
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        stacktmp = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = stacktmp
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        """Visit a statement and wrap it in a single-item StatListNode
        unless it already sits directly in a statement list or inside an
        expression.

        ``is_listcontainer`` makes the statement's own children count as
        being inside a statement list, so they are not wrapped.
        """
        stacktmp = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        if not self.is_in_statlist and not self.is_in_expr:
            return StatListNode(pos=node.pos, stats=[node])
        else:
            return node

    def visit_StatListNode(self, node):
        # Save and restore the flag, as visit_StatNode/visit_ExprNode do.
        # Forcing it back to False here would make the remaining siblings
        # of a nested StatListNode look like lone statements to be wrapped.
        stacktmp = self.is_in_statlist
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        return node

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    # Eliminate PassStatNode
    def visit_PassStatNode(self, node):
        if not self.is_in_statlist:
            # A lone "pass" block becomes an empty statement list.
            return StatListNode(pos=node.pos, stats=[])
        else:
            # Inside a statement list the "pass" is simply dropped
            # (an empty list result removes the statement).
            return []

    def visit_CDeclaratorNode(self, node):
        # Declarators are left untouched; their children are not visited.
        return node
class PostParseError(CompileError):
    """Compile error raised or reported by the tree transforms in this
    module (e.g. PostParse, PxdPostParse, InterpretCompilerDirectives)
    when they find an invalid construct after parsing."""
    pass

# error strings checked by unit tests, so define them
# Reported when a cdef class/struct/union field declaration has a default value.
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
# Reported when __cythonbufferdefaults__ is not assigned a dict literal.
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
# Reported when a special class attribute is declared through a pointer declarator.
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
133 class PostParse(CythonTransform):
134 """
135 Basic interpretation of the parse tree, as well as validity
136 checking that can be done on a very basic level on the parse
137 tree (while still not being a problem with the basic syntax,
138 as such).
140 Specifically:
141 - Default values to cdef assignments are turned into single
142 assignments following the declaration (everywhere but in class
143 bodies, where they raise a compile error)
145 - Interpret some node structures into Python runtime values.
146 Some nodes take compile-time arguments (currently:
147 TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
148 which should be interpreted. This happens in a general way
149 and other steps should be taken to ensure validity.
151 Type arguments cannot be interpreted in this way.
153 - For __cythonbufferdefaults__ the arguments are checked for
154 validity.
156 TemplatedTypeNode has its directives interpreted:
157 Any first positional argument goes into the "dtype" attribute,
158 any "ndim" keyword argument goes into the "ndim" attribute and
159 so on. Also it is checked that the directive combination is valid.
160 - __cythonbufferdefaults__ attributes are parsed and put into the
161 type information.
163 Note: Currently Parsing.py does a lot of interpretation and
164 reorganization that can be refactored into this transform
165 if a more pure Abstract Syntax Tree is wanted.
166 """
168 # Track our context.
169 scope_type = None # can be either of 'module', 'function', 'class'
171 def __init__(self, context):
172 super(PostParse, self).__init__(context)
173 self.specialattribute_handlers = {
174 '__cythonbufferdefaults__' : self.handle_bufferdefaults
175 }
177 def visit_ModuleNode(self, node):
178 self.scope_type = 'module'
179 self.scope_node = node
180 self.visitchildren(node)
181 return node
183 def visit_scope(self, node, scope_type):
184 prev = self.scope_type, self.scope_node
185 self.scope_type = scope_type
186 self.scope_node = node
187 self.visitchildren(node)
188 self.scope_type, self.scope_node = prev
189 return node
191 def visit_ClassDefNode(self, node):
192 return self.visit_scope(node, 'class')
194 def visit_FuncDefNode(self, node):
195 return self.visit_scope(node, 'function')
197 def visit_CStructOrUnionDefNode(self, node):
198 return self.visit_scope(node, 'struct')
200 # cdef variables
201 def handle_bufferdefaults(self, decl):
202 if not isinstance(decl.default, DictNode):
203 raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
204 self.scope_node.buffer_defaults_node = decl.default
205 self.scope_node.buffer_defaults_pos = decl.pos
207 def visit_CVarDefNode(self, node):
208 # This assumes only plain names and pointers are assignable on
209 # declaration. Also, it makes use of the fact that a cdef decl
210 # must appear before the first use, so we don't have to deal with
211 # "i = 3; cdef int i = i" and can simply move the nodes around.
212 try:
213 self.visitchildren(node)
214 stats = [node]
215 newdecls = []
216 for decl in node.declarators:
217 declbase = decl
218 while isinstance(declbase, CPtrDeclaratorNode):
219 declbase = declbase.base
220 if isinstance(declbase, CNameDeclaratorNode):
221 if declbase.default is not None:
222 if self.scope_type in ('class', 'struct'):
223 if isinstance(self.scope_node, CClassDefNode):
224 handler = self.specialattribute_handlers.get(decl.name)
225 if handler:
226 if decl is not declbase:
227 raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
228 handler(decl)
229 continue # Remove declaration
230 raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
231 first_assignment = self.scope_type != 'module'
232 stats.append(SingleAssignmentNode(node.pos,
233 lhs=NameNode(node.pos, name=declbase.name),
234 rhs=declbase.default, first=first_assignment))
235 declbase.default = None
236 newdecls.append(decl)
237 node.declarators = newdecls
238 return stats
239 except PostParseError, e:
240 # An error in a cdef clause is ok, simply remove the declaration
241 # and try to move on to report more errors
242 self.context.nonfatal_error(e)
243 return None
245 # Split parallel assignments (a,b = b,a) into separate partial
246 # assignments that are executed rhs-first using temps. This
247 # optimisation is best applied before type analysis so that known
248 # types on rhs and lhs can be matched directly.
250 def visit_SingleAssignmentNode(self, node):
251 self.visitchildren(node)
252 return self._visit_assignment_node(node, [node.lhs, node.rhs])
254 def visit_CascadedAssignmentNode(self, node):
255 self.visitchildren(node)
256 return self._visit_assignment_node(node, node.lhs_list + [node.rhs])
258 def _visit_assignment_node(self, node, expr_list):
259 """Flatten parallel assignments into separate single
260 assignments or cascaded assignments.
261 """
262 if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
263 # no parallel assignments => nothing to do
264 return node
266 expr_list_list = []
267 flatten_parallel_assignments(expr_list, expr_list_list)
268 temp_refs = []
269 eliminate_rhs_duplicates(expr_list_list, temp_refs)
271 nodes = []
272 for expr_list in expr_list_list:
273 lhs_list = expr_list[:-1]
274 rhs = expr_list[-1]
275 if len(lhs_list) == 1:
276 node = Nodes.SingleAssignmentNode(rhs.pos,
277 lhs = lhs_list[0], rhs = rhs)
278 else:
279 node = Nodes.CascadedAssignmentNode(rhs.pos,
280 lhs_list = lhs_list, rhs = rhs)
281 nodes.append(node)
283 if len(nodes) == 1:
284 assign_node = nodes[0]
285 else:
286 assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
288 if temp_refs:
289 duplicates_and_temps = [ (temp.expression, temp)
290 for temp in temp_refs ]
291 sort_common_subsequences(duplicates_and_temps)
292 for _, temp_ref in duplicates_and_temps[::-1]:
293 assign_node = LetNode(temp_ref, assign_node)
295 return assign_node
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Replace rhs expressions that occur more than once by LetRefNodes.

    Walks the rhs (last item) of every list in ``expr_list_list``,
    descending into sequence constructors.  Each non-literal, non-name
    expression seen a second time gets one LetRefNode, which is appended
    to ``ref_node_sequence``.  The rhs items, and the arguments of the
    duplicated sequence constructors themselves, are then rewritten in
    place to use those LetRefNodes.
    """
    already_seen = set()
    replacement_for = {}

    def collect(expr):
        if expr.is_literal or expr.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if expr not in already_seen:
            already_seen.add(expr)
            if expr.is_sequence_constructor:
                for element in expr.args:
                    collect(element)
        elif expr not in replacement_for:
            let_ref = LetRefNode(expr)
            replacement_for[expr] = let_ref
            ref_node_sequence.append(let_ref)

    for exprs in expr_list_list:
        collect(exprs[-1])
    if not replacement_for:
        return

    def substitute(expr):
        if expr in replacement_for:
            return replacement_for[expr]
        if expr.is_sequence_constructor:
            expr.args = [substitute(element) for element in expr.args]
        return expr

    # first rewrite the insides of the shared subexpressions ...
    for shared in replacement_for:
        if shared.is_sequence_constructor:
            shared.args = [substitute(element) for element in shared.args]

    # ... then every rhs item
    for exprs in expr_list_list:
        exprs[-1] = substitute(exprs[-1])
def sort_common_subsequences(items):
    """Reorder ``items`` in place so that every entry whose key (item[0])
    is contained in another entry's key comes before that entry.

    Containment is by identity, recursing through the ``args`` of
    sequence constructors.  This is only a partial order, so the sort
    must be stable to keep the original order as much as possible; a
    plain insertion sort is used.
    """
    def nested_in(sequence, target):
        for element in sequence:
            if element is target:
                return True
            if element.is_sequence_constructor and nested_in(element.args, target):
                return True
        return False

    def must_precede(a, b):
        return b.is_sequence_constructor and nested_in(b.args, a)

    index = 0
    while index < len(items):
        entry = items[index]
        destination = index
        # find the earliest preceding entry that contains this one
        candidate = index - 1
        while candidate >= 0:
            if must_precede(entry[0], items[candidate][0]):
                destination = candidate
            candidate -= 1
        if destination != index:
            # shift the intervening entries one slot right, then insert
            items[destination + 1:index + 1] = items[destination:index]
            items[destination] = entry
        index += 1
def flatten_parallel_assignments(input, output):
    """Break one (possibly cascaded) assignment into element assignments.

    ``input`` is a list of expression nodes: the lhs targets followed by
    the rhs.  When the rhs and at least one lhs are sequence constructors,
    the matching elements of both sides are paired up recursively, so
    nested structures are matched as well.  Every resulting
    [lhs..., rhs] list is appended to ``output``; erroneous lhs items are
    reported and emitted unsplit as [lhs, rhs].
    """
    rhs = input[-1]
    lhs_items = input[:-1]
    if not rhs.is_sequence_constructor:
        output.append(input)
        return
    any_sequence_lhs = False
    for lhs in lhs_items:
        if lhs.is_sequence_constructor:
            any_sequence_lhs = True
    if not any_sequence_lhs:
        output.append(input)
        return

    rhs_size = len(rhs.args)
    per_position_targets = []
    for _ in range(rhs_size):
        per_position_targets.append([])
    whole_rhs_targets = []
    starred_assignments = []

    for lhs in lhs_items:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            whole_rhs_targets.append(lhs)
            continue

        lhs_size = len(lhs.args)
        star_count = len([arg for arg in lhs.args if arg.is_starred])
        if star_count > 1:
            error(lhs.pos, "more than 1 starred expression in assignment")
        elif lhs_size - star_count > rhs_size:
            plural = ''
            if rhs_size != 1:
                plural = 's'
            error(lhs.pos, "need more than %d value%s to unpack"
                  % (rhs_size, plural))
        elif star_count:
            map_starred_assignment(per_position_targets, starred_assignments,
                                   lhs.args, rhs.args)
            continue
        elif lhs_size < rhs_size:
            error(lhs.pos, "too many values to unpack (expected %d, got %d)"
                  % (lhs_size, rhs_size))
        else:
            for targets, element in zip(per_position_targets, lhs.args):
                targets.append(element)
            continue
        # only the error branches fall through to here
        output.append([lhs, rhs])

    if whole_rhs_targets:
        whole_rhs_targets.append(rhs)
        output.append(whole_rhs_targets)

    # recursively flatten the per-position assignments
    for targets, value in zip(per_position_targets, rhs.args):
        if targets:
            targets.append(value)
            flatten_parallel_assignments(targets, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)
def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    # Appends the fixed-position LHS targets to the target list that
    # appear left and right of the starred argument.
    #
    # The starred_assignments list receives a new list
    # [lhs_target, ListNode(rhs_values)] that maps the remaining arguments
    # (those that match the starred target) to a list.
    #
    # lhs_targets holds one target list per rhs position.  The caller
    # (flatten_parallel_assignments) only calls this with exactly one
    # starred item in lhs_args and at most len(rhs_args) + 1 lhs items.

    # locate the starred target
    starred = None
    for i, expr in enumerate(lhs_args):
        if expr.is_starred:
            starred = i
            break
    if starred is None:
        raise InternalError("no starred arg found when splitting starred assignment")
    lhs_remaining = len(lhs_args) - starred - 1

    # left side of the starred target: aligned with the leading rhs positions.
    # Searching lhs_args directly (not zip()ed with lhs_targets) also handles
    # a starred target that matches no rhs value, e.g. "a, b, *c = 1, 2".
    for targets, expr in zip(lhs_targets, lhs_args[:starred]):
        targets.append(expr)

    # right side of the starred target: aligned with the trailing rhs
    # positions.  Compute the start index explicitly: a negative slice such
    # as seq[-lhs_remaining:] selects the *whole* sequence when
    # lhs_remaining == 0 and would re-append every lhs item.
    right_start = len(lhs_targets) - lhs_remaining
    for targets, expr in zip(lhs_targets[right_start:], lhs_args[starred + 1:]):
        targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target # unpack starred node
    starred_rhs = rhs_args[starred:len(rhs_args) - lhs_remaining]
    if starred_rhs:
        pos = starred_rhs[0].pos
    else:
        pos = target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that should only be
    done on pxd trees.

    A lot of this checking currently happens in the parser, but
    what is listed below happens here.

    - "def" functions are let through only if they fill the
    getbuffer/releasebuffer slots of a cdef class

    - cdef functions are let through only if they are on the
    top level and are declared "inline"

    Rejected function definitions are reported as non-fatal errors and
    removed from the tree.
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # The whole tree starts out at pxd top level.
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        outer_scope = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = outer_scope
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        if isinstance(node, CFuncDefNode):
            top_level_inline = (u'inline' in node.modifiers
                                and self.scope_type == 'pxd')
            if not top_level_inline:
                err = self.ERR_INLINE_ONLY
            else:
                node.inline_in_pxd = True
                if node.visibility != 'private':
                    err = self.ERR_NOGO_WITH_INLINE % node.visibility
                elif node.api:
                    err = self.ERR_NOGO_WITH_INLINE % 'api'
                else:
                    err = None # allow inline function
        elif (isinstance(node, DefNode) and self.scope_type == 'cclass'
              and node.name in ('__getbuffer__', '__releasebuffer__')):
            err = None # allow these slots
        else:
            err = self.ERR_INLINE_ONLY

        if not err:
            return node
        self.context.nonfatal_error(PostParseError(node.pos, err))
        return None
class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
    """
    After parsing, directives can be stored in a number of places:
    - #cython-comments at the top of the file (stored in ModuleNode)
    - Command-line arguments overriding these
    - @cython.directivename decorators
    - with cython.directivename: statements

    This transform is responsible for interpreting these various sources
    and storing the directives in two ways:
    - Set the directives attribute of the ModuleNode for global directives.
    - Use a CompilerDirectivesNode to override directives for a subtree.

    (The first one is primarily to avoid having to modify the tree
    structure, so that ModuleNode stays on top.)

    The directives are stored in dictionaries from name to value in effect.
    Each such dictionary is always filled in for all possible directives,
    using default values where no value is given by the user.

    The available directives are controlled in Options.py.

    Note that we have to run this prior to analysis, and so some minor
    duplication of functionality has to occur: We manually track cimports
    and which names the "cython" module may have been imported to.
    """
    unop_method_nodes = {
        'typeof': TypeofNode,

        'operator.address': AmpersandNode,
        'operator.dereference': DereferenceNode,
        'operator.preincrement' : inc_dec_constructor(True, '++'),
        'operator.predecrement' : inc_dec_constructor(True, '--'),
        'operator.postincrement': inc_dec_constructor(False, '++'),
        'operator.postdecrement': inc_dec_constructor(False, '--'),

        # For backwards compatibility.
        'address': AmpersandNode,
    }

    special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof',
                           'cast', 'pointer', 'compiled', 'NULL']
                          + unop_method_nodes.keys())

    def __init__(self, context, compilation_directive_defaults):
        super(InterpretCompilerDirectives, self).__init__(context)
        self.compilation_directive_defaults = {}
        for key, value in compilation_directive_defaults.iteritems():
            self.compilation_directive_defaults[unicode(key)] = value
        # Names under which the "cython" module itself is visible.
        self.cython_module_names = set()
        # Maps local names to the cython attribute/directive they alias.
        self.directive_names = {}

    def check_directive_scope(self, pos, directive, scope):
        """Return True if ``directive`` may be used in ``scope``; otherwise
        report a non-fatal error and return False."""
        legal_scopes = Options.directive_scopes.get(directive, None)
        if legal_scopes and scope not in legal_scopes:
            self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
                                        'is not allowed in %s scope' % (directive, scope)))
            return False
        else:
            return True

    # Set up processing and handle the cython: comments.
    def visit_ModuleNode(self, node):
        # Iterate over a snapshot of the keys: entries are deleted below,
        # which must not happen while iterating the dict itself.
        for key in list(node.directive_comments.keys()):
            if not self.check_directive_scope(node.pos, key, 'module'):
                # check_directive_scope() has already reported the error.
                del node.directive_comments[key]

        # Precedence (lowest first): built-in defaults, compilation
        # defaults, then "#cython:" comments in the file.
        directives = copy.copy(Options.directive_defaults)
        directives.update(self.compilation_directive_defaults)
        directives.update(node.directive_comments)
        self.directives = directives
        node.directives = directives
        self.visitchildren(node)
        node.cython_module_names = self.cython_module_names
        return node

    # The following four functions track imports and cimports that
    # begin with "cython"
    def is_cython_directive(self, name):
        return (name in Options.directive_types or
                name in self.special_methods or
                PyrexTypes.parse_basic_type(name))

    def visit_CImportStatNode(self, node):
        if node.module_name == u"cython":
            self.cython_module_names.add(node.as_name or u"cython")
        elif node.module_name.startswith(u"cython."):
            if node.as_name:
                self.directive_names[node.as_name] = node.module_name[7:]
            else:
                self.cython_module_names.add(u"cython")
            # if this cimport was a compiler directive, we don't
            # want to leave the cimport node sitting in the tree
            return None
        return node

    def visit_FromCImportStatNode(self, node):
        if (node.module_name == u"cython") or \
               node.module_name.startswith(u"cython."):
            submodule = (node.module_name + u".")[7:]
            newimp = []
            for pos, name, as_name, kind in node.imported_names:
                full_name = submodule + name
                if self.is_cython_directive(full_name):
                    if as_name is None:
                        as_name = full_name
                    self.directive_names[as_name] = full_name
                    if kind is not None:
                        self.context.nonfatal_error(PostParseError(pos,
                            "Compiler directive imports must be plain imports"))
                else:
                    newimp.append((pos, name, as_name, kind))
            if not newimp:
                return None
            node.imported_names = newimp
        return node

    def visit_FromImportStatNode(self, node):
        if (node.module.module_name.value == u"cython") or \
               node.module.module_name.value.startswith(u"cython."):
            submodule = (node.module.module_name.value + u".")[7:]
            newimp = []
            for name, name_node in node.items:
                full_name = submodule + name
                if self.is_cython_directive(full_name):
                    self.directive_names[name_node.name] = full_name
                else:
                    newimp.append((name, name_node))
            if not newimp:
                return None
            node.items = newimp
        return node

    def visit_SingleAssignmentNode(self, node):
        # "import cython" is handled like "cimport cython".
        if (isinstance(node.rhs, ImportNode) and
                node.rhs.module_name.value == u'cython'):
            node = CImportStatNode(node.pos,
                                   module_name = u'cython',
                                   as_name = node.lhs.name)
            self.visit_CImportStatNode(node)
        else:
            self.visitchildren(node)
        return node

    def visit_NameNode(self, node):
        if node.name in self.cython_module_names:
            node.is_cython_module = True
        else:
            node.cython_attribute = self.directive_names.get(node.name)
        return node

    def try_to_parse_directives(self, node):
        # If node is the contents of an directive (in a with statement or
        # decorator), returns a list of (directivename, value) pairs.
        # Otherwise, returns None
        if isinstance(node, CallNode):
            self.visit(node.function)
            optname = node.function.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype:
                    args, kwds = node.explicit_args_kwds()
                    directives = []
                    key_value_pairs = []
                    if kwds is not None and directivetype is not dict:
                        # Keywords naming "optname.key" sub-directives are
                        # parsed as directives of their own.
                        for keyvalue in kwds.key_value_pairs:
                            key, value = keyvalue
                            sub_optname = "%s.%s" % (optname, key.value)
                            if Options.directive_types.get(sub_optname):
                                directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
                            else:
                                key_value_pairs.append(keyvalue)
                        if not key_value_pairs:
                            kwds = None
                        else:
                            kwds.key_value_pairs = key_value_pairs
                        if directives and not kwds and not args:
                            return directives
                    directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
                    return directives

        return None

    def try_to_parse_directive(self, optname, args, kwds, pos):
        """Interpret the arguments of one directive call according to its
        declared type; returns (optname, value) or raises PostParseError."""
        directivetype = Options.directive_types.get(optname)
        if len(args) == 1 and isinstance(args[0], NoneNode):
            return optname, Options.directive_defaults[optname]
        elif directivetype is bool:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time boolean argument' % optname)
            return (optname, args[0].value)
        elif directivetype is str:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], (StringNode, UnicodeNode)):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time string argument' % optname)
            return (optname, str(args[0].value))
        elif directivetype is dict:
            if len(args) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no prepositional arguments' % optname)
            if kwds is None:
                # e.g. "@cython.locals()": no keywords at all
                return optname, {}
            return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
        elif directivetype is list:
            # kwds is a DictNode (or None) and does not support len().
            if kwds is not None and kwds.key_value_pairs:
                raise PostParseError(pos,
                    'The %s directive takes no keyword arguments' % optname)
            return optname, [ str(arg.value) for arg in args ]
        else:
            assert False

    def visit_with_directives(self, body, directives):
        """Visit ``body`` with ``directives`` in effect and wrap it in a
        CompilerDirectivesNode holding the merged directives."""
        olddirectives = self.directives
        newdirectives = copy.copy(olddirectives)
        newdirectives.update(directives)
        self.directives = newdirectives
        assert isinstance(body, StatListNode), body
        retbody = self.visit_Node(body)
        directive = CompilerDirectivesNode(pos=retbody.pos, body=retbody,
                                           directives=newdirectives)
        self.directives = olddirectives
        return directive

    # Handle decorators
    def visit_FuncDefNode(self, node):
        directives = []
        if node.decorators:
            # Split the decorators into two lists -- real decorators and directives
            realdecs = []
            for dec in node.decorators:
                new_directives = self.try_to_parse_directives(dec.decorator)
                if new_directives is not None:
                    directives.extend(new_directives)
                else:
                    realdecs.append(dec)
            if realdecs and isinstance(node, CFuncDefNode):
                raise PostParseError(realdecs[0].pos, "Cdef functions cannot take arbitrary decorators.")
            else:
                node.decorators = realdecs

        if directives:
            optdict = {}
            directives.reverse() # Decorators coming first take precedence
            for directive in directives:
                name, value = directive
                if not self.check_directive_scope(node.pos, name, 'function'):
                    continue
                if name in optdict:
                    old_value = optdict[name]
                    # keywords and arg lists can be merged, everything
                    # else overrides completely
                    if isinstance(old_value, dict):
                        old_value.update(value)
                    elif isinstance(old_value, list):
                        old_value.extend(value)
                    else:
                        optdict[name] = value
                else:
                    optdict[name] = value
            body = StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(body, optdict)
        else:
            return self.visit_Node(node)

    def visit_CVarDefNode(self, node):
        if node.decorators:
            for dec in node.decorators:
                directives = self.try_to_parse_directives(dec.decorator)
                if directives is None:
                    # Not a compiler directive at all: report it instead of
                    # silently dropping the decorator.
                    self.context.nonfatal_error(PostParseError(dec.pos,
                        "Cdef functions can only take cython.locals() decorator."))
                    continue
                for directive in directives:
                    if directive is not None and directive[0] == u'locals':
                        node.directive_locals = directive[1]
                    else:
                        self.context.nonfatal_error(PostParseError(dec.pos,
                            "Cdef functions can only take cython.locals() decorator."))
        return node

    # Handle with statements
    def visit_WithStatNode(self, node):
        directive_dict = {}
        for directive in self.try_to_parse_directives(node.manager) or []:
            if directive is not None:
                if node.target is not None:
                    self.context.nonfatal_error(
                        PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
                else:
                    name, value = directive
                    if self.check_directive_scope(node.pos, name, 'with statement'):
                        directive_dict[name] = value
        if directive_dict:
            return self.visit_with_directives(node.body, directive_dict)
        return self.visit_Node(node)
class WithTransform(CythonTransform, SkipDeclarations):
    """Replace each ``with`` statement by explicit try/except/finally code
    built from one of the TreeFragment templates below."""

    # EXCINFO is manually set to a variable that contains
    # the exc_info() tuple that can be generated by the enclosing except
    # statement.
    # MGR, EXC, EXIT (and VALUE) are temps of the fragment; EXPR, BODY,
    # TARGET and EXCINFO are substituted in visit_WithStatNode.
    template_without_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
    """, temps=[u'MGR', u'EXC', u"EXIT"],
    pipeline=[NormalizeTree(None)])

    # Same expansion for "with EXPR as TARGET:".  The temps are cleared
    # inside the finally clause so that the try/finally stays the last
    # statement (visit_WithStatNode relies on result.stats[-1]).
    template_with_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        VALUE = MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                TARGET = VALUE
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
            MGR = EXIT = VALUE = EXC = None

    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
    pipeline=[NormalizeTree(None)])

    def visit_WithStatNode(self, node):
        # TODO: Cleanup badly needed
        # A fresh module-wide counter value gives each with statement its
        # own name for the exception-info variable.
        TemplateTransform.temp_name_counter += 1
        handle = "__tmpvar_%d" % TemplateTransform.temp_name_counter

        # Only the body is visited before substitution (nested with
        # statements); manager and target are inserted as-is.
        self.visitchildren(node, ['body'])
        excinfo_temp = NameNode(node.pos, name=handle)#TempHandle(Builtin.tuple_type)
        if node.target is not None:
            result = self.template_with_target.substitute({
                u'EXPR' : node.manager,
                u'BODY' : node.body,
                u'TARGET' : node.target,
                u'EXCINFO' : excinfo_temp
                }, pos=node.pos)
        else:
            result = self.template_without_target.substitute({
                u'EXPR' : node.manager,
                u'BODY' : node.body,
                u'EXCINFO' : excinfo_temp
                }, pos=node.pos)

        # Set except excinfo target to EXCINFO
        # result.stats[-1] is the outer try/finally; the last statement of
        # its body is the inner try/except of the template.
        try_except = result.stats[-1].body.stats[-1]
        try_except.except_clauses[0].excinfo_target = NameNode(node.pos, name=handle)
#            excinfo_temp.ref(node.pos))

#        result.stats[-1].body.stats[-1] = TempsBlockNode(
#            node.pos, temps=[excinfo_temp], body=try_except)

        return result

    def visit_ExprNode(self, node):
        # With statements are never inside expressions.
        return node
class DecoratorTransform(CythonTransform, SkipDeclarations):
    """Desugar decorators on def functions and Python classes.

    A decorated definition ``NAME`` with decorators d1, d2, ... dn (in
    source order) is replaced by the undecorated definition followed by
    the assignment ``NAME = d1(d2(...dn(NAME)))``.
    """

    def visit_DefNode(self, func_node):
        self.visitchildren(func_node)
        if func_node.decorators:
            return self._handle_decorators(func_node, func_node.name)
        return func_node

    def _visit_CClassDefNode(self, class_node):
        # This doesn't currently work, so it's disabled (also in the
        # parser).
        #
        # Problem: assignments to cdef class names do not work. They
        # would require an additional check anyway, as the extension
        # type must not change its C type, so decorators cannot
        # replace an extension type, just alter it and return it.

        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.class_name)
        return class_node

    def visit_ClassDefNode(self, class_node):
        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.name)
        return class_node

    def _handle_decorators(self, node, name):
        # Apply the decorators innermost-last: the one listed closest to
        # the definition is called first.
        innermost_first = list(node.decorators)
        innermost_first.reverse()
        wrapped = NameNode(node.pos, name = name)
        for decorator in innermost_first:
            wrapped = SimpleCallNode(decorator.pos,
                                     function = decorator.decorator,
                                     args = [wrapped])
        rebind = SingleAssignmentNode(node.pos,
                                      lhs = NameNode(node.pos, name = name),
                                      rhs = wrapped)
        return [node, rebind]
class AnalyseDeclarationsTransform(CythonTransform):
    """Run declaration analysis and prune nodes that only declared things.

    * Calls ``analyse_declarations`` for modules, function bodies and
      comprehensions, keeping a stack of the scopes currently entered.
    * Records every name referenced in the current module/function body so
      that a ``cdef`` declaration appearing after a use can be warned about.
    * Removes (returns ``None`` for) C declaration nodes that are no longer
      needed, and expands ``public``/``readonly`` cdef attributes into
      generated ``property`` definitions.
    """

    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """, level='c_class')
    basic_property_ro = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    """, level='c_class')

    def __call__(self, root):
        self.env_stack = [root.scope]
        # One set of referenced names per enclosing module/function body;
        # used to detect cdef variables declared after they are used.
        self.seen_vars_stack = []
        return super(AnalyseDeclarationsTransform, self).__call__(root)

    def visit_NameNode(self, node):
        names_in_body = self.seen_vars_stack[-1]
        names_in_body.add(node.name)
        return node

    def visit_ModuleNode(self, node):
        self.seen_vars_stack.append(set())
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node

    def visit_ClassDefNode(self, node):
        self.env_stack.append(node.scope)
        self.visitchildren(node)
        self.env_stack.pop()
        return node

    def visit_FuncDefNode(self, node):
        self.seen_vars_stack.append(set())
        local_env = node.local_scope
        # NOTE: control flow analysis happens here for now; it is slated
        # to be refactored.
        node.body.analyse_control_flow(local_env)
        node.declare_arguments(local_env)
        for var_name, type_node in node.directive_locals.items():
            if local_env.lookup_here(var_name):
                # Already declared (e.g. as an argument): don't redeclare.
                continue
            declared_type = type_node.analyse_as_type(local_env)
            if declared_type:
                local_env.declare_var(var_name, declared_type, type_node.pos)
            else:
                error(type_node.pos, "Not a type")
        node.body.analyse_declarations(local_env)
        self.env_stack.append(local_env)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    def visit_ComprehensionNode(self, node):
        self.visitchildren(node)
        node.analyse_declarations(self.env_stack[-1])
        return node

    # Several declaration nodes are no longer needed once declarations have
    # been analysed and are removed here by returning None.  Their analysis
    # already happened in a separate recursive pass driven by the enclosing
    # module or function, so dropping them loses nothing.

    def visit_CDeclaratorNode(self, node):
        # Descend so that every CNameDeclaratorNode below gets visited.
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return None

    def visit_CEnumDefNode(self, node):
        # Only public enums stay in the tree.
        if node.visibility != 'public':
            return None
        return node

    def visit_CStructOrUnionDefNode(self, node):
        return None

    def visit_CNameDeclaratorNode(self, node):
        used_earlier = node.name in self.seen_vars_stack[-1]
        if used_earlier:
            entry = self.env_stack[-1].lookup(node.name)
            # extern declarations are exempt from the warning
            if entry is None or entry.visibility != 'extern':
                warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):
        # Descend so that every CNameDeclaratorNode below gets visited.
        self.visitchildren(node)

        if not node.properties:
            return None
        generated = []
        for entry in node.properties:
            prop = self.create_Property(entry)
            prop.analyse_declarations(node.dest_scope)
            self.visit(prop)
            generated.append(prop)
        return StatListNode(pos=node.pos, stats=generated)

    def create_Property(self, entry):
        """Build a ``property`` node exposing the cdef attribute *entry*.

        ``public`` attributes get a getter and a setter, ``readonly`` ones
        only a getter.  With the ``embedsignature`` directive enabled the
        property also receives a docstring describing the attribute.
        """
        if entry.visibility == 'public':
            template = self.basic_property
        elif entry.visibility == 'readonly':
            template = self.basic_property_ro
        prop = template.substitute({
            u"ATTR": AttributeNode(pos=entry.pos,
                                   obj=NameNode(pos=entry.pos, name="self"),
                                   attribute=entry.name),
            }, pos=entry.pos).stats[0]
        prop.name = entry.name
        # XXX The docstring part should move to AutoDocTransforms.
        if self.current_directives['embedsignature']:
            prop.doc = EncodedString(self._property_signature(entry))
        return prop

    def _property_signature(self, entry):
        # Render "name: type[ = default]" for the embedded property docstring.
        entry_type = entry.type
        shown_type = entry_type.declaration_code("", for_display=1)
        if not entry_type.is_pyobject:
            shown_type = "'%s'" % shown_type
        elif entry_type.is_extension_type:
            shown_type = entry_type.module_name + '.' + shown_type

        if entry.init is not None:
            default_part = ' = ' + entry.init
        elif entry.init_to_none:
            default_part = ' = ' + repr(None)
        else:
            default_part = ''
        return entry.name + ': ' + shown_type + default_part
class AnalyseExpressionsTransform(CythonTransform):
    """Infer types, then analyse expressions, for module and function bodies."""

    def _analyse_body(self, node, scope):
        # Shared by the module and function visitors: type inference runs on
        # the scope before the body's expressions are analysed in it.
        scope.infer_types()
        node.body.analyse_expressions(scope)
        self.visitchildren(node)
        return node

    def visit_ModuleNode(self, node):
        return self._analyse_body(node, node.scope)

    def visit_FuncDefNode(self, node):
        return self._analyse_body(node, node.local_scope)
class AlignFunctionDefinitions(CythonTransform):
    """
    This class takes the signatures from a .pxd file and applies them to
    the def methods (and classes) in a .py file.

    * A Python class whose name is declared as a cdef class in the .pxd is
      converted with ``as_cclass()``; any other clash is reported as a
      redeclaration and the class is dropped.
    * A ``def`` function whose name is declared as a C function in the .pxd
      is converted with ``as_cfunction()``; any other clash is reported as a
      redeclaration and the function is dropped.
    * With the ``auto_cpdef`` directive, undeclared module-level ``def``
      functions are converted with ``as_cfunction()`` as well.
    """

    def visit_ModuleNode(self, node):
        self.scope = node.scope
        self.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_PyClassDefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        error(node.pos, "'%s' redeclared" % node.name)
        error(pxd_def.pos, "previous declaration here")
        return None

    def visit_CClassDefNode(self, node, pxd_def=None):
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if not pxd_def:
            self.visitchildren(node)
            return node
        # Visit the class body inside the scope declared in the .pxd, and
        # always restore the outer scope afterwards.
        outer_scope = self.scope
        self.scope = pxd_def.type.scope
        try:
            self.visitchildren(node)
        finally:
            self.scope = outer_scope
        return node

    def visit_DefNode(self, node):
        pxd_def = self.scope.lookup(node.name)
        if pxd_def:
            # Check the kind of declaration first: a non-function entry has
            # no argument list, so touching ``pxd_def.type.args`` before this
            # check would crash (or mutate an unrelated type) instead of
            # reporting the redeclaration.
            if not pxd_def.is_cfunction:
                error(node.pos, "'%s' redeclared" % node.name)
                error(pxd_def.pos, "previous declaration here")
                return None
            if self.scope.is_c_class_scope and len(pxd_def.type.args) > 0:
                # The self parameter type needs adjusting.
                pxd_def.type.args[0].type = self.scope.parent_type
            node = node.as_cfunction(pxd_def)
        elif self.scope.is_module_scope and self.directives['auto_cpdef']:
            node = node.as_cfunction(scope=self.scope)
        # NOTE: children are deliberately not visited; enable
        # self.visitchildren(node) once internal def functions are allowed.
        return node
class MarkClosureVisitor(CythonTransform):
    """Set ``needs_closure`` on every function node.

    A function is marked as needing a closure when its body contains a
    ``yield`` or a nested function or class definition.
    """

    needs_closure = False

    def visit_FuncDefNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        # This function is itself nested in whatever encloses it, so the
        # enclosing function needs a closure.
        self.needs_closure = True
        return node

    def visit_ClassDefNode(self, node):
        self.visitchildren(node)
        # A class defined inside a function makes that function need a closure.
        self.needs_closure = True
        return node

    def visit_YieldNode(self, node):
        self.needs_closure = True
        self.visitchildren(node)
        # Must return the node: returning None (as before) drops it from
        # the tree.
        return node
class CreateClosureClasses(CythonTransform):
    """Declare closure classes in the module scope for functions.

    For each top-level function node a new cdef class with a generated
    ``closure`` name is declared in the module scope, with one cdef
    attribute per entry of the function's local scope.

    NOTE(review): despite the intent ("for all functions that need it"),
    this currently runs for every FuncDefNode regardless of
    ``needs_closure`` and does not descend into nested functions.
    """

    def visit_ModuleNode(self, node):
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def create_class_from_scope(self, node, target_module_scope):
        closure_name = temp_name_handle("closure")
        func_scope = node.local_scope

        class_entry = target_module_scope.declare_c_class(
            name=closure_name,
            pos=node.pos,
            defining=True,
            implementing=True)
        class_scope = class_entry.type.scope
        # Mirror every local variable of the function as a cdef attribute.
        for local_entry in func_scope.entries.values():
            class_scope.declare_var(
                pos=node.pos,
                name=local_entry.name,
                cname=local_entry.cname,
                type=local_entry.type,
                is_cdef=True)

    def visit_FuncDefNode(self, node):
        self.create_class_from_scope(node, self.module_scope)
        return node
class GilCheck(VisitorTransform):
    """
    Walk the tree tracking whether the code is inside a ``nogil``
    environment (a ``nogil`` function or a ``with nogil:`` block), and call
    each node's ``nogil_check`` hook there, so that operations which need
    the GIL can be reported.
    """
    def __call__(self, root):
        # Stack of scopes: the module scope plus one per enclosing function.
        self.env_stack = [root.scope]
        self.nogil = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        self.env_stack.append(node.local_scope)
        was_nogil = self.nogil
        # A function's own nogil-ness replaces that of its surroundings.
        self.nogil = node.local_scope.nogil
        if self.nogil and node.nogil_check:
            node.nogil_check(node.local_scope)
        self.visitchildren(node)
        self.env_stack.pop()
        self.nogil = was_nogil
        return node

    def visit_GILStatNode(self, node):
        # The check is made against the state *outside* the statement;
        # note that this hook is called without an env argument.
        if self.nogil and node.nogil_check:
            node.nogil_check()
        was_nogil = self.nogil
        # 'with nogil:' releases the GIL, 'with gil:' reacquires it.
        self.nogil = (node.state == 'nogil')
        self.visitchildren(node)
        self.nogil = was_nogil
        return node

    def visit_Node(self, node):
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node
class TransformBuiltinMethods(EnvTransform):
    """Rewrite the ``locals()`` builtin and ``cython.*`` special names/calls.

    * Declaration-only single assignments are dropped from the tree.
    * ``cython.compiled`` becomes the constant ``True`` and ``cython.NULL``
      a NULL literal; any other ``cython.<name>`` attribute must be a basic
      C type name, otherwise an error is reported.
    * ``locals()`` (unless ``locals`` is declared in the current scope)
      becomes a dict literal mapping each name declared in the current scope
      to its value.
    * ``cython.cast``, ``cython.sizeof``, ``cython.cmod``, ``cython.cdiv``
      and the unary operators in ``InterpretCompilerDirectives.unop_method_nodes``
      become the corresponding C-level expression nodes.
    """

    # cython.cmod() / cython.cdiv() map onto these binary operators; the
    # resulting node is flagged to use C division semantics.
    _cdivision_operators = {u'cmod': '%', u'cdiv': '/'}

    def visit_SingleAssignmentNode(self, node):
        if node.declaration_only:
            return None
        self.visitchildren(node)
        return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        """Replace a ``cython.<attribute>`` reference, if *node* is one."""
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = BoolNode(node.pos, value=True)
            elif attribute == u'NULL':
                node = NullNode(node.pos)
            elif not PyrexTypes.parse_basic_type(attribute):
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def _transform_locals_call(self, node):
        """Replace a call to the builtin ``locals()`` by a dict literal.

        *node* is returned unchanged when ``locals`` is declared in the
        current scope (so it is not the builtin), or when it was called with
        arguments (after reporting an error).
        """
        lenv = self.env_stack[-1]
        if lenv.lookup_here('locals'):
            # not the builtin 'locals'
            return node
        if len(node.args) > 0:
            # Report at the call site; the transform itself is not a node
            # (the previous code used ``self.pos`` here and crashed).
            error(node.pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d" % len(node.args))
            return node
        pos = node.pos
        items = [DictItemNode(pos,
                              key=StringNode(pos, value=var),
                              value=NameNode(pos, name=var))
                 for var in lenv.entries]
        return DictNode(pos, key_value_pairs=items)

    def visit_SimpleCallNode(self, node):
        function_node = node.function

        # locals builtin
        if isinstance(function_node, NameNode) and function_node.name == 'locals':
            return self._transform_locals_call(node)

        # cython.foo(...)
        function = function_node.as_cython_attribute()
        if function:
            pos = function_node.pos
            args = node.args
            if function in InterpretCompilerDirectives.unop_method_nodes:
                if len(args) != 1:
                    error(pos, u"%s() takes exactly one argument" % function)
                else:
                    node = InterpretCompilerDirectives.unop_method_nodes[function](pos, operand=args[0])
            elif function == u'cast':
                if len(args) != 2:
                    error(pos, u"cast() takes exactly two arguments")
                else:
                    target_type = args[0].analyse_as_type(self.env_stack[-1])
                    if target_type:
                        node = TypecastNode(pos, type=target_type, operand=args[1])
                    else:
                        error(args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(args) != 1:
                    error(pos, u"sizeof() takes exactly one argument")
                else:
                    # sizeof(type) vs. sizeof(expression)
                    arg_type = args[0].analyse_as_type(self.env_stack[-1])
                    if arg_type:
                        node = SizeofTypeNode(pos, arg_type=arg_type)
                    else:
                        node = SizeofVarNode(pos, operand=args[0])
            elif function in self._cdivision_operators:
                if len(args) != 2:
                    error(pos, u"%s() takes exactly two arguments" % function)
                else:
                    operator = self._cdivision_operators[function]
                    node = binop_node(pos, operator, args[0], args[1])
                    node.cdivision = True
            else:
                error(pos, u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)
        return node