cython-devel

changeset 1235:5de00fce9b73

Buffers: NumPy record array support, format string parsing improvements
author Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
date Sat Oct 11 18:48:15 2008 +0200 (3 years ago)
parents 81b6cc6209c1
children c9fc106f9412
files Cython/Compiler/Buffer.py Cython/Compiler/PyrexTypes.py Cython/Includes/numpy.pxd tests/run/bufaccess.pyx tests/run/numpy_test.pyx
line diff
1.1 --- a/Cython/Compiler/Buffer.py Fri Oct 10 02:42:55 2008 -0700 1.2 +++ b/Cython/Compiler/Buffer.py Sat Oct 11 18:48:15 2008 +0200 1.3 @@ -562,14 +562,31 @@ 1.4 1.5 return name 1.6 1.7 +def get_typestringchecker(code, dtype): 1.8 + """ 1.9 + Returns the name of a typestring checker with the given type; emitting 1.10 + it to code if needed. 1.11 + """ 1.12 + name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype) 1.13 + code.globalstate.use_code_from(create_typestringchecker, 1.14 + name, 1.15 + dtype=dtype) 1.16 + return name 1.17 + 1.18 def create_typestringchecker(protocode, defcode, name, dtype): 1.19 + 1.20 + def put_assert(cond, msg): 1.21 + defcode.putln("if (!(%s)) {" % cond) 1.22 + msg += ", got '%s'" 1.23 + defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % msg) 1.24 + defcode.putln("return NULL;") 1.25 + defcode.putln("}") 1.26 + 1.27 if dtype.is_error: return 1.28 - simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr 1.29 + simple = dtype.is_simple_buffer_dtype() 1.30 complex_possible = dtype.is_struct_or_union and dtype.can_be_complex() 1.31 # Cannot add utility code recursively... 1.32 - if simple: 1.33 - itemchecker = get_ts_check_item(dtype, protocode) 1.34 - else: 1.35 + if not simple: 1.36 dtype_t = dtype.declaration_code("") 1.37 protocode.globalstate.use_utility_code(parse_typestring_repeat_code) 1.38 fields = dtype.scope.var_entries 1.39 @@ -580,18 +597,58 @@ 1.40 prevtype = None 1.41 for f in fields: 1.42 if n and f.type != prevtype: 1.43 - field_blocks.append((n, prevtype, get_ts_check_item(prevtype, protocode))) 1.44 + field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype))) 1.45 n = 0 1.46 prevtype = f.type 1.47 n += 1 1.48 - field_blocks.append((n, f.type, get_ts_check_item(f.type, protocode))) 1.49 + field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type))) 1.50 1.51 protocode.putln("static const char* %s(const char* ts); /*proto*/" % name) 1.52 defcode.putln("static const char* %s(const char* ts) {" % name) 1.53 if simple: 1.54 + defcode.putln("int ok;") 1.55 defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") 1.56 defcode.putln("if (*ts == '1') ++ts;") 1.57 - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker) 1.58 + if dtype.typestring is not None: 1.59 + assert len(dtype.typestring) == 1 1.60 + # Can use direct comparison 1.61 + defcode.putln("ok = (*ts == '%s');" % dtype.typestring) 1.62 + else: 1.63 + # Cannot trust declared size; but rely on int vs float and 1.64 + # signed/unsigned to be correctly declared. Use a switch statement 1.65 + # on all possible format codes to validate that the size is ok. 1.66 + # (Note that many codes may map to same size, e.g. 'i' and 'l' 1.67 + # may both be four bytes). 1.68 + ctype = dtype.declaration_code("") 1.69 + defcode.putln("switch (*ts) {") 1.70 + if dtype.is_int: 1.71 + types = [ 1.72 + ('b', 'char'), ('h', 'short'), ('i', 'int'), 1.73 + ('l', 'long'), ('q', 'long long') 1.74 + ] 1.75 + elif dtype.is_float: 1.76 + types = [('f', 'float'), ('d', 'double'), ('g', 'long double')] 1.77 + else: 1.78 + assert False 1.79 + if dtype.signed == 0: 1.80 + for char, against in types: 1.81 + defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" % 1.82 + (char.upper(), ctype, against, ctype)) 1.83 + else: 1.84 + for char, against in types: 1.85 + defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" % 1.86 + (char, ctype, against, ctype)) 1.87 + defcode.putln("default: ok = 0;") 1.88 + defcode.putln("}") 1.89 + defcode.putln("if (!ok) {") 1.90 + if dtype.typestring is not None: 1.91 + errmsg = "Buffer datatype mismatch (expected '%s', got '%%s')" % dtype.typestring 1.92 + else: 1.93 + errmsg = "Buffer datatype mismatch (rejecting on '%s')" 1.94 + defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % errmsg) 1.95 + defcode.putln("return NULL;"); 1.96 + defcode.putln("}") 1.97 + defcode.putln("++ts;") 1.98 elif complex_possible: 1.99 # Could be a struct representing a complex number, so allow 1.100 # for parsing a "Zf" spec. 1.101 @@ -623,15 +680,25 @@ 1.102 else: 1.103 defcode.putln("int n, count;") 1.104 defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") 1.105 + 1.106 for n, type, checker in field_blocks: 1.107 if n == 1: 1.108 defcode.putln("if (*ts == '1') ++ts;") 1.109 - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker) 1.110 else: 1.111 defcode.putln("n = %d;" % n); 1.112 defcode.putln("do {") 1.113 defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;") 1.114 - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker) 1.115 + 1.116 + simple = type.is_simple_buffer_dtype() 1.117 + if not simple: 1.118 + put_assert("*ts == 'T' && *(ts+1) == '{'", "Expected start of %s" % type.declaration_code("", for_display=True)) 1.119 + defcode.putln("ts += 2;") 1.120 + defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker) 1.121 + if not simple: 1.122 + put_assert("*ts == '}'", "Expected end of '%s'" % type.declaration_code("", for_display=True)) 1.123 + defcode.putln("++ts;") 1.124 + 1.125 + if n > 1: 1.126 defcode.putln("} while (n > 0);"); 1.127 defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") 1.128 1.129 @@ -651,11 +718,7 @@ 1.130 name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype) 1.131 if not code.globalstate.has_code(name): 1.132 code.globalstate.use_utility_code(acquire_utility_code) 1.133 - typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype) 1.134 - code.globalstate.use_code_from(create_typestringchecker, 1.135 - typestringchecker, 1.136 - dtype=dtype) 1.137 - 1.138 + typestringchecker = get_typestringchecker(code, dtype) 1.139 dtype_name = str(dtype) 1.140 dtype_cname = dtype.declaration_code("") 1.141 utilcode = [dedent("""
2.1 --- a/Cython/Compiler/PyrexTypes.py Fri Oct 10 02:42:55 2008 -0700 2.2 +++ b/Cython/Compiler/PyrexTypes.py Sat Oct 11 18:48:15 2008 +0200 2.3 @@ -140,6 +140,10 @@ 2.4 # a struct whose attributes are not defined, etc. 2.5 return 1 2.6 2.7 + def is_simple_buffer_dtype(self): 2.8 + return (self.is_int or self.is_float or self.is_pyobject or 2.9 + self.is_extension_type or self.is_ptr) 2.10 + 2.11 class CTypedefType(BaseType): 2.12 # 2.13 # Pseudo-type defined with a ctypedef statement in a
3.1 --- a/Cython/Includes/numpy.pxd Fri Oct 10 02:42:55 2008 -0700 3.2 +++ b/Cython/Includes/numpy.pxd Sat Oct 11 18:48:15 2008 +0200 3.3 @@ -1,4 +1,5 @@ 3.4 cimport python_buffer as pybuf 3.5 +cimport stdlib 3.6 3.7 cdef extern from "Python.h": 3.8 ctypedef int Py_intptr_t 3.9 @@ -26,6 +27,11 @@ 3.10 NPY_C_CONTIGUOUS, 3.11 NPY_F_CONTIGUOUS 3.12 3.13 + ctypedef class numpy.dtype [object PyArray_Descr]: 3.14 + cdef int type_num 3.15 + cdef object fields 3.16 + cdef object names 3.17 + 3.18 3.19 ctypedef class numpy.ndarray [object PyArrayObject]: 3.20 cdef __cythonbufferdefaults__ = {"mode": "strided"} 3.21 @@ -36,6 +42,7 @@ 3.22 npy_intp *shape "dimensions" 3.23 npy_intp *strides 3.24 int flags 3.25 + dtype descr 3.26 3.27 # Note: This syntax (function definition in pxd files) is an 3.28 # experimental exception made for __getbuffer__ and __releasebuffer__ 3.29 @@ -57,7 +64,6 @@ 3.30 raise ValueError("ndarray is not Fortran contiguous") 3.31 3.32 info.buf = PyArray_DATA(self) 3.33 - # info.obj = None # this is automatic 3.34 info.ndim = PyArray_NDIM(self) 3.35 info.strides = <Py_ssize_t*>PyArray_STRIDES(self) 3.36 info.shape = <Py_ssize_t*>PyArray_DIMS(self) 3.37 @@ -65,31 +71,104 @@ 3.38 info.itemsize = PyArray_ITEMSIZE(self) 3.39 info.readonly = not PyArray_ISWRITEABLE(self) 3.40 3.41 - # Formats that are not tested and working in Cython are not 3.42 - # made available from this pxd file yet. 3.43 - cdef int t = PyArray_TYPE(self) 3.44 - cdef char* f = NULL 3.45 - if t == NPY_BYTE: f = "b" 3.46 - elif t == NPY_UBYTE: f = "B" 3.47 - elif t == NPY_SHORT: f = "h" 3.48 - elif t == NPY_USHORT: f = "H" 3.49 - elif t == NPY_INT: f = "i" 3.50 - elif t == NPY_UINT: f = "I" 3.51 - elif t == NPY_LONG: f = "l" 3.52 - elif t == NPY_ULONG: f = "L" 3.53 - elif t == NPY_LONGLONG: f = "q" 3.54 - elif t == NPY_ULONGLONG: f = "Q" 3.55 - elif t == NPY_FLOAT: f = "f" 3.56 - elif t == NPY_DOUBLE: f = "d" 3.57 - elif t == NPY_LONGDOUBLE: f = "g" 3.58 - elif t == NPY_CFLOAT: f = "Zf" 3.59 - elif t == NPY_CDOUBLE: f = "Zd" 3.60 - elif t == NPY_CLONGDOUBLE: f = "Zg" 3.61 - elif t == NPY_OBJECT: f = "O" 3.62 + cdef int t 3.63 + cdef char* f = NULL 3.64 + cdef dtype descr = self.descr 3.65 + cdef list stack 3.66 3.67 - if f == NULL: 3.68 - raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t) 3.69 - info.format = f 3.70 + cdef bint hasfields = PyDataType_HASFIELDS(descr) 3.71 + 3.72 + # Ugly hack warning: 3.73 + # Cython currently will not support helper functions in 3.74 + # pxd files -- so we must keep our own, manual stack! 3.75 + # In addition, avoid allocation of the stack in the common 3.76 + # case that we are dealing with a single non-nested datatype... 3.77 + # (this would look much prettier if we could use utility 3.78 + # functions). 3.79 + 3.80 + 3.81 + if not hasfields: 3.82 + info.obj = None # do not call releasebuffer 3.83 + t = descr.type_num 3.84 + if t == NPY_BYTE: f = "b" 3.85 + elif t == NPY_UBYTE: f = "B" 3.86 + elif t == NPY_SHORT: f = "h" 3.87 + elif t == NPY_USHORT: f = "H" 3.88 + elif t == NPY_INT: f = "i" 3.89 + elif t == NPY_UINT: f = "I" 3.90 + elif t == NPY_LONG: f = "l" 3.91 + elif t == NPY_ULONG: f = "L" 3.92 + elif t == NPY_LONGLONG: f = "q" 3.93 + elif t == NPY_ULONGLONG: f = "Q" 3.94 + elif t == NPY_FLOAT: f = "f" 3.95 + elif t == NPY_DOUBLE: f = "d" 3.96 + elif t == NPY_LONGDOUBLE: f = "g" 3.97 + elif t == NPY_CFLOAT: f = "Zf" 3.98 + elif t == NPY_CDOUBLE: f = "Zd" 3.99 + elif t == NPY_CLONGDOUBLE: f = "Zg" 3.100 + elif t == NPY_OBJECT: f = "O" 3.101 + else: 3.102 + raise ValueError("unknown dtype code in numpy.pxd (%d)" % t) 3.103 + info.format = f 3.104 + return 3.105 + else: 3.106 + info.obj = self # need to call releasebuffer 3.107 + info.format = <char*>stdlib.malloc(255) # static size 3.108 + f = info.format 3.109 + stack = [iter(descr.fields.iteritems())] 3.110 + 3.111 + while True: 3.112 + iterator = stack[-1] 3.113 + descr = None 3.114 + while descr is None: 3.115 + try: 3.116 + descr = iterator.next()[1][0] 3.117 + except StopIteration: 3.118 + stack.pop() 3.119 + if len(stack) > 0: 3.120 + f[0] = "}" 3.121 + f += 1 3.122 + iterator = stack[-1] 3.123 + else: 3.124 + f[0] = 0 # Terminate string! 3.125 + return 3.126 + 3.127 + hasfields = PyDataType_HASFIELDS(descr) 3.128 + if not hasfields: 3.129 + t = descr.type_num 3.130 + if f - info.format > 240: # this should leave room for "T{" and "}" as well 3.131 + raise RuntimeError("Format string allocated too short.") 3.132 + 3.133 + if t == NPY_BYTE: f[0] = "b" 3.134 + elif t == NPY_UBYTE: f[0] = "B" 3.135 + elif t == NPY_SHORT: f[0] = "h" 3.136 + elif t == NPY_USHORT: f[0] = "H" 3.137 + elif t == NPY_INT: f[0] = "i" 3.138 + elif t == NPY_UINT: f[0] = "I" 3.139 + elif t == NPY_LONG: f[0] = "l" 3.140 + elif t == NPY_ULONG: f[0] = "L" 3.141 + elif t == NPY_LONGLONG: f[0] = "q" 3.142 + elif t == NPY_ULONGLONG: f[0] = "Q" 3.143 + elif t == NPY_FLOAT: f[0] = "f" 3.144 + elif t == NPY_DOUBLE: f[0] = "d" 3.145 + elif t == NPY_LONGDOUBLE: f[0] = "g" 3.146 + elif t == NPY_CFLOAT: f[0] = "Z"; f[1] = "f"; f += 1 3.147 + elif t == NPY_CDOUBLE: f[0] = "Z"; f[1] = "d"; f += 1 3.148 + elif t == NPY_CLONGDOUBLE: f[0] = "Z"; f[1] = "g"; f += 1 3.149 + elif t == NPY_OBJECT: f[0] = "O" 3.150 + else: 3.151 + raise ValueError("unknown dtype code in numpy.pxd (%d)" % t) 3.152 + f += 1 3.153 + else: 3.154 + f[0] = "T" 3.155 + f[1] = "{" 3.156 + f += 2 3.157 + stack.append(iter(descr.fields.iteritems())) 3.158 + 3.159 + def __releasebuffer__(ndarray self, Py_buffer* info): 3.160 + # This can not be called unless format needs to be freed (as 3.161 + # obj is set to NULL in those case) 3.162 + stdlib.free(info.format) 3.163 3.164 3.165 cdef void* PyArray_DATA(ndarray arr) 3.166 @@ -100,6 +179,9 @@ 3.167 cdef npy_intp PyArray_DIMS(ndarray arr) 3.168 cdef Py_ssize_t PyArray_ITEMSIZE(ndarray arr) 3.169 cdef int PyArray_CHKFLAGS(ndarray arr, int flags) 3.170 + cdef int PyArray_HASFIELDS(ndarray arr, int flags) 3.171 + 3.172 + cdef int PyDataType_HASFIELDS(dtype obj) 3.173 3.174 ctypedef signed int npy_byte 3.175 ctypedef signed int npy_short
4.1 --- a/tests/run/bufaccess.pyx Fri Oct 10 02:42:55 2008 -0700 4.2 +++ b/tests/run/bufaccess.pyx Sat Oct 11 18:48:15 2008 +0200 4.3 @@ -1292,6 +1292,15 @@ 4.4 int d 4.5 int e 4.6 4.7 +cdef struct SmallStruct: 4.8 + int a 4.9 + int b 4.10 + 4.11 +cdef struct NestedStruct: 4.12 + SmallStruct x 4.13 + SmallStruct y 4.14 + int z 4.15 + 4.16 cdef class MyStructMockBuffer(MockBuffer): 4.17 cdef int write(self, char* buf, object value) except -1: 4.18 cdef MyStruct* s 4.19 @@ -1302,6 +1311,16 @@ 4.20 cdef get_itemsize(self): return sizeof(MyStruct) 4.21 cdef get_default_format(self): return b"2bq2i" 4.22 4.23 +cdef class NestedStructMockBuffer(MockBuffer): 4.24 + cdef int write(self, char* buf, object value) except -1: 4.25 + cdef NestedStruct* s 4.26 + s = <NestedStruct*>buf; 4.27 + s.x.a, s.x.b, s.y.a, s.y.b, s.z = value 4.28 + return 0 4.29 + 4.30 + cdef get_itemsize(self): return sizeof(NestedStruct) 4.31 + cdef get_default_format(self): return b"2T{ii}i" 4.32 + 4.33 @testcase 4.34 def basic_struct(object[MyStruct] buf): 4.35 """ 4.36 @@ -1316,6 +1335,21 @@ 4.37 """ 4.38 print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e 4.39 4.40 +@testcase 4.41 +def nested_struct(object[NestedStruct] buf): 4.42 + """ 4.43 + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)])) 4.44 + 1 2 3 4 5 4.45 + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i")) 4.46 + 1 2 3 4 5 4.47 + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="iiiii")) 4.48 + Traceback (most recent call last): 4.49 + ... 4.50 + ValueError: Expected start of SmallStruct, got 'iiiii' 4.51 + """ 4.52 + print buf[0].x.a, buf[0].x.b, buf[0].y.a, buf[0].y.b, buf[0].z 4.53 + 4.54 + 4.55 cdef struct LongComplex: 4.56 long double real 4.57 long double imag
5.1 --- a/tests/run/numpy_test.pyx Fri Oct 10 02:42:55 2008 -0700 5.2 +++ b/tests/run/numpy_test.pyx Sat Oct 11 18:48:15 2008 +0200 5.3 @@ -129,12 +129,22 @@ 5.4 >>> test_dtype(np.int32, inc1_int32_t) 5.5 >>> test_dtype(np.float64, inc1_float64_t) 5.6 5.7 - Unsupported types: 5.8 - >>> a = np.zeros((10,), dtype=np.dtype('i4,i4')) 5.9 - >>> inc1_byte(a) 5.10 + >>> test_recordarray() 5.11 + 5.12 + >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\ 5.13 + ('a', np.dtype('i,i')),\ 5.14 + ('b', np.dtype('i,i'))\ 5.15 + ]))) 5.16 + array([((0, 0), (0, 0)), ((1, 2), (1, 4)), ((1, 2), (1, 4))], 5.17 + dtype=[('a', [('f0', '<i4'), ('f1', '<i4')]), ('b', [('f0', '<i4'), ('f1', '<i4')])]) 5.18 + 5.19 + >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\ 5.20 + ('a', np.dtype('i,f')),\ 5.21 + ('b', np.dtype('i,i'))\ 5.22 + ]))) 5.23 Traceback (most recent call last): 5.24 - ... 5.25 - ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20) 5.26 + ... 5.27 + ValueError: Buffer datatype mismatch (expected 'i', got 'f}T{ii}') 5.28 5.29 >>> test_good_cast() 5.30 True 5.31 @@ -261,6 +271,49 @@ 5.32 inc1(a) 5.33 if a[1] != 11: print "failed!" 5.34 5.35 +cdef struct DoubleInt: 5.36 + int x, y 5.37 + 5.38 +def test_recordarray(): 5.39 + cdef object[DoubleInt] arr 5.40 + arr = np.array([(5,5), (4, 6)], dtype=np.dtype('i,i')) 5.41 + cdef DoubleInt rec 5.42 + rec = arr[0] 5.43 + if rec.x != 5: print "failed" 5.44 + if rec.y != 5: print "failed" 5.45 + rec.y += 5 5.46 + arr[1] = rec 5.47 + arr[0].x -= 2 5.48 + arr[0].y += 3 5.49 + if arr[0].x != 3: print "failed" 5.50 + if arr[0].y != 8: print "failed" 5.51 + if arr[1].x != 5: print "failed" 5.52 + if arr[1].y != 10: print "failed" 5.53 + 5.54 +cdef struct NestedStruct: 5.55 + DoubleInt a 5.56 + DoubleInt b 5.57 + 5.58 +cdef struct BadDoubleInt: 5.59 + float x 5.60 + int y 5.61 + 5.62 +cdef struct BadNestedStruct: 5.63 + DoubleInt a 5.64 + BadDoubleInt b 5.65 + 5.66 +def test_nested_dtypes(obj): 5.67 + cdef object[NestedStruct] arr = obj 5.68 + arr[1].a.x = 1 5.69 + arr[1].a.y = 2 5.70 + arr[1].b.x = arr[0].a.y + 1 5.71 + arr[1].b.y = 4 5.72 + arr[2] = arr[1] 5.73 + return arr 5.74 + 5.75 +def test_bad_nested_dtypes(): 5.76 + cdef object[BadNestedStruct] arr 5.77 + 5.78 def test_good_cast(): 5.79 # Check that a signed int can round-trip through casted unsigned int access 5.80 cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i')