cython-devel
changeset 1478:b638811d14d0
implement set/dict comprehensions and set literals
| author | Stefan Behnel <scoder@users.berlios.de> |
|---|---|
| date | Fri Dec 12 09:21:10 2008 +0100 (3 years ago) |
| parents | 4105074b6c2e |
| children | f21bbde7d1c3 |
| files | Cython/Compiler/ExprNodes.py Cython/Compiler/Parsing.py tests/run/dictcomp.pyx tests/run/set.pyx tests/run/setcomp.pyx |
line diff
1.1 --- a/Cython/Compiler/ExprNodes.py Fri Dec 12 08:22:49 2008 +0100
1.2 +++ b/Cython/Compiler/ExprNodes.py Fri Dec 12 09:21:10 2008 +0100
1.3 @@ -12,7 +12,7 @@
1.4 from Nodes import Node
1.5 import PyrexTypes
1.6 from PyrexTypes import py_object_type, c_long_type, typecast, error_type
1.7 -from Builtin import list_type, tuple_type, dict_type, unicode_type
1.8 +from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type
1.9 import Symtab
1.10 import Options
1.11 from Annotate import AnnotationItem
1.12 @@ -3007,7 +3007,7 @@
1.13 gil_message = "Constructing Python list"
1.14
1.15 def analyse_expressions(self, env):
1.16 - ExprNode.analyse_expressions(self, env)
1.17 + SequenceNode.analyse_expressions(self, env)
1.18 self.coerce_to_pyobject(env)
1.19
1.20 def analyse_types(self, env):
1.21 @@ -3091,7 +3091,7 @@
1.22 arg.result()))
1.23 else:
1.24 raise InternalError("List type never specified")
1.25 -
1.26 +
1.27 def generate_subexpr_disposal_code(self, code):
1.28 # We call generate_post_assignment_code here instead
1.29 # of generate_disposal_code, because values were stored
1.30 @@ -3101,16 +3101,16 @@
1.31 # Should NOT call free_temps -- this is invoked by the default
1.32 # generate_evaluation_code which will do that.
1.33
1.34 -
1.35 -class ListComprehensionNode(SequenceNode):
1.36 -
1.37 +
1.38 +class ComprehensionNode(SequenceNode):
1.39 subexprs = []
1.40 is_sequence_constructor = 0 # not unpackable
1.41 + comp_result_type = py_object_type
1.42
1.43 child_attrs = ["loop", "append"]
1.44
1.45 def analyse_types(self, env):
1.46 - self.type = list_type
1.47 + self.type = self.comp_result_type
1.48 self.is_temp = 1
1.49 self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
1.50
1.51 @@ -3132,25 +3132,126 @@
1.52 self.loop.annotate(code)
1.53
1.54
1.55 -class ListComprehensionAppendNode(ExprNode):
1.56 -
1.57 +class ListComprehensionNode(ComprehensionNode):
1.58 + comp_result_type = list_type
1.59 +
1.60 + def generate_operation_code(self, code):
1.61 + code.putln("%s = PyList_New(%s); %s" %
1.62 + (self.result(),
1.63 + 0,
1.64 + code.error_goto_if_null(self.result(), self.pos)))
1.65 + self.loop.generate_execution_code(code)
1.66 +
1.67 +class SetComprehensionNode(ComprehensionNode):
1.68 + comp_result_type = set_type
1.69 +
1.70 + def generate_operation_code(self, code):
1.71 + code.putln("%s = PySet_New(0); %s" % # arg == iterable, not size!
1.72 + (self.result(),
1.73 + code.error_goto_if_null(self.result(), self.pos)))
1.74 + self.loop.generate_execution_code(code)
1.75 +
1.76 +class DictComprehensionNode(ComprehensionNode):
1.77 + comp_result_type = dict_type
1.78 +
1.79 + def generate_operation_code(self, code):
1.80 + code.putln("%s = PyDict_New(); %s" %
1.81 + (self.result(),
1.82 + code.error_goto_if_null(self.result(), self.pos)))
1.83 + self.loop.generate_execution_code(code)
1.84 +
1.85 +
1.86 +class ComprehensionAppendNode(NewTempExprNode):
1.87 # Need to be careful to avoid infinite recursion:
1.88 # target must not be in child_attrs/subexprs
1.89 subexprs = ['expr']
1.90
1.91 def analyse_types(self, env):
1.92 self.expr.analyse_types(env)
1.93 - if self.expr.type != py_object_type:
1.94 + if not self.expr.type.is_pyobject:
1.95 self.expr = self.expr.coerce_to_pyobject(env)
1.96 self.type = PyrexTypes.c_int_type
1.97 self.is_temp = 1
1.98 -
1.99 +
1.100 +class ListComprehensionAppendNode(ComprehensionAppendNode):
1.101 def generate_result_code(self, code):
1.102 code.putln("%s = PyList_Append(%s, (PyObject*)%s); %s" %
1.103 (self.result(),
1.104 - self.target.result(),
1.105 - self.expr.result(),
1.106 - code.error_goto_if(self.result(), self.pos)))
1.107 + self.target.result(),
1.108 + self.expr.result(),
1.109 + code.error_goto_if(self.result(), self.pos)))
1.110 +
1.111 +class SetComprehensionAppendNode(ComprehensionAppendNode):
1.112 + def generate_result_code(self, code):
1.113 + code.putln("%s = PySet_Add(%s, (PyObject*)%s); %s" %
1.114 + (self.result(),
1.115 + self.target.result(),
1.116 + self.expr.result(),
1.117 + code.error_goto_if(self.result(), self.pos)))
1.118 +
1.119 +class DictComprehensionAppendNode(ComprehensionAppendNode):
1.120 + subexprs = ['key_expr', 'value_expr']
1.121 +
1.122 + def analyse_types(self, env):
1.123 + self.key_expr.analyse_types(env)
1.124 + if not self.key_expr.type.is_pyobject:
1.125 + self.key_expr = self.key_expr.coerce_to_pyobject(env)
1.126 + self.value_expr.analyse_types(env)
1.127 + if not self.value_expr.type.is_pyobject:
1.128 + self.value_expr = self.value_expr.coerce_to_pyobject(env)
1.129 + self.type = PyrexTypes.c_int_type
1.130 + self.is_temp = 1
1.131 +
1.132 + def generate_result_code(self, code):
1.133 + code.putln("%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" %
1.134 + (self.result(),
1.135 + self.target.result(),
1.136 + self.key_expr.result(),
1.137 + self.value_expr.result(),
1.138 + code.error_goto_if(self.result(), self.pos)))
1.139 +
1.140 +
1.141 +class SetNode(NewTempExprNode):
1.142 + # Set constructor.
1.143 +
1.144 + subexprs = ['args']
1.145 +
1.146 + gil_message = "Constructing Python set"
1.147 +
1.148 + def analyse_types(self, env):
1.149 + for i in range(len(self.args)):
1.150 + arg = self.args[i]
1.151 + arg.analyse_types(env)
1.152 + self.args[i] = arg.coerce_to_pyobject(env)
1.153 + self.type = set_type
1.154 + self.gil_check(env)
1.155 + self.is_temp = 1
1.156 +
1.157 + def compile_time_value(self, denv):
1.158 + values = [arg.compile_time_value(denv) for arg in self.args]
1.159 + try:
1.160 + set
1.161 + except NameError:
1.162 + from sets import Set as set
1.163 + try:
1.164 + return set(values)
1.165 + except Exception, e:
1.166 + self.compile_time_value_error(e)
1.167 +
1.168 + def generate_evaluation_code(self, code):
1.169 + self.allocate_temp_result(code)
1.170 + code.putln(
1.171 + "%s = PySet_New(0); %s" % (
1.172 + self.result(),
1.173 + code.error_goto_if_null(self.result(), self.pos)))
1.174 + for arg in self.args:
1.175 + arg.generate_evaluation_code(code)
1.176 + code.putln(
1.177 + code.error_goto_if_neg(
1.178 + "PySet_Add(%s, %s)" % (self.result(), arg.py_result()),
1.179 + self.pos))
1.180 + arg.generate_disposal_code(code)
1.181 + arg.free_temps(code)
1.182
1.183
1.184 class DictNode(ExprNode):
2.1 --- a/Cython/Compiler/Parsing.py Fri Dec 12 08:22:49 2008 +0100
2.2 +++ b/Cython/Compiler/Parsing.py Fri Dec 12 09:21:10 2008 +0100
2.3 @@ -473,7 +473,7 @@
2.4 return ExprNodes.SliceNode(pos,
2.5 start = start, stop = stop, step = step)
2.6
2.7 -#atom: '(' [testlist] ')' | '[' [listmaker] ']' | '{' [dictmaker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+
2.8 +#atom: '(' [testlist] ')' | '[' [listmaker] ']' | '{' [dict_or_set_maker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+
2.9
2.10 def p_atom(s):
2.11 pos = s.position()
2.12 @@ -491,7 +491,7 @@
2.13 elif sy == '[':
2.14 return p_list_maker(s)
2.15 elif sy == '{':
2.16 - return p_dict_maker(s)
2.17 + return p_dict_or_set_maker(s)
2.18 elif sy == '`':
2.19 return p_backquote_expr(s)
2.20 elif sy == 'INT':
2.21 @@ -701,13 +701,8 @@
2.22 if s.sy == 'for':
2.23 loop = p_list_for(s)
2.24 s.expect(']')
2.25 - inner_loop = loop
2.26 - while not isinstance(inner_loop.body, Nodes.PassStatNode):
2.27 - inner_loop = inner_loop.body
2.28 - if isinstance(inner_loop, Nodes.IfStatNode):
2.29 - inner_loop = inner_loop.if_clauses[0]
2.30 append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr )
2.31 - inner_loop.body = Nodes.ExprStatNode(pos, expr = append)
2.32 + set_inner_comp_append(loop, append)
2.33 return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
2.34 else:
2.35 exprs = [expr]
2.36 @@ -742,27 +737,69 @@
2.37 return Nodes.IfStatNode(pos,
2.38 if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
2.39 else_clause = None )
2.40 -
2.41 +
2.42 +def set_inner_comp_append(loop, append):
2.43 + inner_loop = loop
2.44 + while not isinstance(inner_loop.body, Nodes.PassStatNode):
2.45 + inner_loop = inner_loop.body
2.46 + if isinstance(inner_loop, Nodes.IfStatNode):
2.47 + inner_loop = inner_loop.if_clauses[0]
2.48 + inner_loop.body = Nodes.ExprStatNode(append.pos, expr = append)
2.49 +
2.50 #dictmaker: test ':' test (',' test ':' test)* [',']
2.51
2.52 -def p_dict_maker(s):
2.53 +def p_dict_or_set_maker(s):
2.54 # s.sy == '{'
2.55 pos = s.position()
2.56 s.next()
2.57 - items = []
2.58 - while s.sy != '}':
2.59 - items.append(p_dict_item(s))
2.60 - if s.sy != ',':
2.61 - break
2.62 + if s.sy == '}':
2.63 s.next()
2.64 - s.expect('}')
2.65 - return ExprNodes.DictNode(pos, key_value_pairs = items)
2.66 -
2.67 -def p_dict_item(s):
2.68 - key = p_simple_expr(s)
2.69 - s.expect(':')
2.70 - value = p_simple_expr(s)
2.71 - return ExprNodes.DictItemNode(key.pos, key=key, value=value)
2.72 + return ExprNodes.DictNode(pos, key_value_pairs = [])
2.73 + item = p_simple_expr(s)
2.74 + if s.sy == ',' or s.sy == '}':
2.75 + # set literal
2.76 + values = [item]
2.77 + while s.sy == ',':
2.78 + s.next()
2.79 + values.append( p_simple_expr(s) )
2.80 + s.expect('}')
2.81 + return ExprNodes.SetNode(pos, args=values)
2.82 + elif s.sy == 'for':
2.83 + # set comprehension
2.84 + loop = p_list_for(s)
2.85 + s.expect('}')
2.86 + append = ExprNodes.SetComprehensionAppendNode(item.pos, expr=item)
2.87 + set_inner_comp_append(loop, append)
2.88 + return ExprNodes.SetComprehensionNode(pos, loop=loop, append=append)
2.89 + elif s.sy == ':':
2.90 + # dict literal or comprehension
2.91 + key = item
2.92 + s.next()
2.93 + value = p_simple_expr(s)
2.94 + if s.sy == 'for':
2.95 + # dict comprehension
2.96 + loop = p_list_for(s)
2.97 + s.expect('}')
2.98 + append = ExprNodes.DictComprehensionAppendNode(
2.99 + item.pos, key_expr = key, value_expr = value)
2.100 + set_inner_comp_append(loop, append)
2.101 + return ExprNodes.DictComprehensionNode(pos, loop=loop, append=append)
2.102 + else:
2.103 + # dict literal
2.104 + items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]
2.105 + while s.sy == ',':
2.106 + s.next()
2.107 + key = p_simple_expr(s)
2.108 + s.expect(':')
2.109 + value = p_simple_expr(s)
2.110 + items.append(
2.111 + ExprNodes.DictItemNode(key.pos, key=key, value=value))
2.112 + s.expect('}')
2.113 + return ExprNodes.DictNode(pos, key_value_pairs=items)
2.114 + else:
2.115 + # raise an error
2.116 + s.expect('}')
2.117 + return ExprNodes.DictNode(pos, key_value_pairs = [])
2.118
2.119 def p_backquote_expr(s):
2.120 # s.sy == '`'
3.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
3.2 +++ b/tests/run/dictcomp.pyx Fri Dec 12 09:21:10 2008 +0100
3.3 @@ -0,0 +1,32 @@
3.4 +u"""
3.5 +>>> type(smoketest()) is dict
3.6 +True
3.7 +
3.8 +>>> sorted(smoketest().items())
3.9 +[(2, 0), (4, 4), (6, 8)]
3.10 +>>> list(typed().items())
3.11 +[(A, 1), (A, 1), (A, 1)]
3.12 +>>> sorted(iterdict().items())
3.13 +[(1, 'a'), (2, 'b'), (3, 'c')]
3.14 +"""
3.15 +
3.16 +def smoketest():
3.17 + return {x+2:x*2 for x in range(5) if x % 2 == 0}
3.18 +
3.19 +cdef class A:
3.20 + def __repr__(self): return u"A"
3.21 + def __richcmp__(one, other, op): return one is other
3.22 + def __hash__(self): return id(self) % 65536
3.23 +
3.24 +def typed():
3.25 + cdef A obj
3.26 + return {obj:1 for obj in [A(), A(), A()]}
3.27 +
3.28 +def iterdict():
3.29 + cdef dict d = dict(a=1,b=2,c=3)
3.30 + return {d[key]:key for key in d}
3.31 +
3.32 +def sorted(it):
3.33 + l = list(it)
3.34 + l.sort()
3.35 + return l
4.1 --- a/tests/run/set.pyx Fri Dec 12 08:22:49 2008 +0100
4.2 +++ b/tests/run/set.pyx Fri Dec 12 09:21:10 2008 +0100
4.3 @@ -1,14 +1,37 @@
4.4 -__doc__ = u"""
4.5 ->>> test_set_add()
4.6 -set(['a', 1])
4.7 ->>> test_set_clear()
4.8 -set([])
4.9 ->>> test_set_pop()
4.10 -set([])
4.11 ->>> test_set_discard()
4.12 -set([233, '12'])
4.13 +u"""
4.14 +>>> type(test_set_literal()) is _set
4.15 +True
4.16 +>>> sorted(test_set_literal())
4.17 +['a', 'b', 1]
4.18 +
4.19 +>>> type(test_set_add()) is _set
4.20 +True
4.21 +>>> sorted(test_set_add())
4.22 +['a', 1]
4.23 +
4.24 +>>> type(test_set_add()) is _set
4.25 +True
4.26 +>>> list(test_set_clear())
4.27 +[]
4.28 +
4.29 +>>> type(test_set_pop()) is _set
4.30 +True
4.31 +>>> list(test_set_pop())
4.32 +[]
4.33 +
4.34 +>>> type(test_set_discard()) is _set
4.35 +True
4.36 +>>> sorted(test_set_discard())
4.37 +['12', 233]
4.38 """
4.39
4.40 +# Py2.3 doesn't have the 'set' builtin type, but Cython does :)
4.41 +_set = set
4.42 +
4.43 +def test_set_literal():
4.44 + cdef set s1 = {1,'a',1,'b','a'}
4.45 + return s1
4.46 +
4.47 def test_set_add():
4.48 cdef set s1
4.49 s1 = set([1])
4.50 @@ -39,4 +62,16 @@
4.51 s1.discard('3')
4.52 s1.discard(3)
4.53 return s1
4.54 -
4.55 +
4.56 +def sorted(it):
4.57 + # Py3 can't compare strings to ints
4.58 + chars = []
4.59 + nums = []
4.60 + for item in it:
4.61 + if type(item) is int:
4.62 + nums.append(item)
4.63 + else:
4.64 + chars.append(item)
4.65 + nums.sort()
4.66 + chars.sort()
4.67 + return chars+nums
5.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
5.2 +++ b/tests/run/setcomp.pyx Fri Dec 12 09:21:10 2008 +0100
5.3 @@ -0,0 +1,37 @@
5.4 +u"""
5.5 +>>> type(smoketest()) is not list
5.6 +True
5.7 +>>> type(smoketest()) is _set
5.8 +True
5.9 +
5.10 +>>> sorted(smoketest())
5.11 +[0, 4, 8]
5.12 +>>> list(typed())
5.13 +[A, A, A]
5.14 +>>> sorted(iterdict())
5.15 +[1, 2, 3]
5.16 +"""
5.17 +
5.18 +# Py2.3 doesn't have the set type, but Cython does :)
5.19 +_set = set
5.20 +
5.21 +def smoketest():
5.22 + return {x*2 for x in range(5) if x % 2 == 0}
5.23 +
5.24 +cdef class A:
5.25 + def __repr__(self): return u"A"
5.26 + def __richcmp__(one, other, op): return one is other
5.27 + def __hash__(self): return id(self) % 65536
5.28 +
5.29 +def typed():
5.30 + cdef A obj
5.31 + return {obj for obj in {A(), A(), A()}}
5.32 +
5.33 +def iterdict():
5.34 + cdef dict d = dict(a=1,b=2,c=3)
5.35 + return {d[key] for key in d}
5.36 +
5.37 +def sorted(it):
5.38 + l = list(it)
5.39 + l.sort()
5.40 + return l
