cython-devel
changeset 1235:5de00fce9b73
Buffers: NumPy record array support, format string parsing improvements
| author | Dag Sverre Seljebotn <dagss@student.matnat.uio.no> |
|---|---|
| date | Sat Oct 11 18:48:15 2008 +0200 (3 years ago) |
| parents | 81b6cc6209c1 |
| children | c9fc106f9412 |
| files | Cython/Compiler/Buffer.py Cython/Compiler/PyrexTypes.py Cython/Includes/numpy.pxd tests/run/bufaccess.pyx tests/run/numpy_test.pyx |
line diff
1.1 --- a/Cython/Compiler/Buffer.py Fri Oct 10 02:42:55 2008 -0700
1.2 +++ b/Cython/Compiler/Buffer.py Sat Oct 11 18:48:15 2008 +0200
1.3 @@ -562,14 +562,31 @@
1.4
1.5 return name
1.6
1.7 +def get_typestringchecker(code, dtype):
1.8 + """
1.9 + Returns the name of a typestring checker with the given type; emitting
1.10 + it to code if needed.
1.11 + """
1.12 + name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
1.13 + code.globalstate.use_code_from(create_typestringchecker,
1.14 + name,
1.15 + dtype=dtype)
1.16 + return name
1.17 +
1.18 def create_typestringchecker(protocode, defcode, name, dtype):
1.19 +
1.20 + def put_assert(cond, msg):
1.21 + defcode.putln("if (!(%s)) {" % cond)
1.22 + msg += ", got '%s'"
1.23 + defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % msg)
1.24 + defcode.putln("return NULL;")
1.25 + defcode.putln("}")
1.26 +
1.27 if dtype.is_error: return
1.28 - simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr
1.29 + simple = dtype.is_simple_buffer_dtype()
1.30 complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
1.31 # Cannot add utility code recursively...
1.32 - if simple:
1.33 - itemchecker = get_ts_check_item(dtype, protocode)
1.34 - else:
1.35 + if not simple:
1.36 dtype_t = dtype.declaration_code("")
1.37 protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
1.38 fields = dtype.scope.var_entries
1.39 @@ -580,18 +597,58 @@
1.40 prevtype = None
1.41 for f in fields:
1.42 if n and f.type != prevtype:
1.43 - field_blocks.append((n, prevtype, get_ts_check_item(prevtype, protocode)))
1.44 + field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype)))
1.45 n = 0
1.46 prevtype = f.type
1.47 n += 1
1.48 - field_blocks.append((n, f.type, get_ts_check_item(f.type, protocode)))
1.49 + field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type)))
1.50
1.51 protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
1.52 defcode.putln("static const char* %s(const char* ts) {" % name)
1.53 if simple:
1.54 + defcode.putln("int ok;")
1.55 defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
1.56 defcode.putln("if (*ts == '1') ++ts;")
1.57 - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker)
1.58 + if dtype.typestring is not None:
1.59 + assert len(dtype.typestring) == 1
1.60 + # Can use direct comparison
1.61 + defcode.putln("ok = (*ts == '%s');" % dtype.typestring)
1.62 + else:
1.63 + # Cannot trust declared size; but rely on int vs float and
1.64 + # signed/unsigned to be correctly declared. Use a switch statement
1.65 + # on all possible format codes to validate that the size is ok.
1.66 + # (Note that many codes may map to same size, e.g. 'i' and 'l'
1.67 + # may both be four bytes).
1.68 + ctype = dtype.declaration_code("")
1.69 + defcode.putln("switch (*ts) {")
1.70 + if dtype.is_int:
1.71 + types = [
1.72 + ('b', 'char'), ('h', 'short'), ('i', 'int'),
1.73 + ('l', 'long'), ('q', 'long long')
1.74 + ]
1.75 + elif dtype.is_float:
1.76 + types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
1.77 + else:
1.78 + assert False
1.79 + if dtype.signed == 0:
1.80 + for char, against in types:
1.81 + defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
1.82 + (char.upper(), ctype, against, ctype))
1.83 + else:
1.84 + for char, against in types:
1.85 + defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
1.86 + (char, ctype, against, ctype))
1.87 + defcode.putln("default: ok = 0;")
1.88 + defcode.putln("}")
1.89 + defcode.putln("if (!ok) {")
1.90 + if dtype.typestring is not None:
1.91 + errmsg = "Buffer datatype mismatch (expected '%s', got '%%s')" % dtype.typestring
1.92 + else:
1.93 + errmsg = "Buffer datatype mismatch (rejecting on '%s')"
1.94 + defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % errmsg)
1.95 + defcode.putln("return NULL;");
1.96 + defcode.putln("}")
1.97 + defcode.putln("++ts;")
1.98 elif complex_possible:
1.99 # Could be a struct representing a complex number, so allow
1.100 # for parsing a "Zf" spec.
1.101 @@ -623,15 +680,25 @@
1.102 else:
1.103 defcode.putln("int n, count;")
1.104 defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
1.105 +
1.106 for n, type, checker in field_blocks:
1.107 if n == 1:
1.108 defcode.putln("if (*ts == '1') ++ts;")
1.109 - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
1.110 else:
1.111 defcode.putln("n = %d;" % n);
1.112 defcode.putln("do {")
1.113 defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
1.114 - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
1.115 +
1.116 + simple = type.is_simple_buffer_dtype()
1.117 + if not simple:
1.118 + put_assert("*ts == 'T' && *(ts+1) == '{'", "Expected start of %s" % type.declaration_code("", for_display=True))
1.119 + defcode.putln("ts += 2;")
1.120 + defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
1.121 + if not simple:
1.122 + put_assert("*ts == '}'", "Expected end of '%s'" % type.declaration_code("", for_display=True))
1.123 + defcode.putln("++ts;")
1.124 +
1.125 + if n > 1:
1.126 defcode.putln("} while (n > 0);");
1.127 defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
1.128
1.129 @@ -651,11 +718,7 @@
1.130 name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
1.131 if not code.globalstate.has_code(name):
1.132 code.globalstate.use_utility_code(acquire_utility_code)
1.133 - typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
1.134 - code.globalstate.use_code_from(create_typestringchecker,
1.135 - typestringchecker,
1.136 - dtype=dtype)
1.137 -
1.138 + typestringchecker = get_typestringchecker(code, dtype)
1.139 dtype_name = str(dtype)
1.140 dtype_cname = dtype.declaration_code("")
1.141 utilcode = [dedent("""
2.1 --- a/Cython/Compiler/PyrexTypes.py Fri Oct 10 02:42:55 2008 -0700
2.2 +++ b/Cython/Compiler/PyrexTypes.py Sat Oct 11 18:48:15 2008 +0200
2.3 @@ -140,6 +140,10 @@
2.4 # a struct whose attributes are not defined, etc.
2.5 return 1
2.6
2.7 + def is_simple_buffer_dtype(self):
2.8 + return (self.is_int or self.is_float or self.is_pyobject or
2.9 + self.is_extension_type or self.is_ptr)
2.10 +
2.11 class CTypedefType(BaseType):
2.12 #
2.13 # Pseudo-type defined with a ctypedef statement in a
3.1 --- a/Cython/Includes/numpy.pxd Fri Oct 10 02:42:55 2008 -0700
3.2 +++ b/Cython/Includes/numpy.pxd Sat Oct 11 18:48:15 2008 +0200
3.3 @@ -1,4 +1,5 @@
3.4 cimport python_buffer as pybuf
3.5 +cimport stdlib
3.6
3.7 cdef extern from "Python.h":
3.8 ctypedef int Py_intptr_t
3.9 @@ -26,6 +27,11 @@
3.10 NPY_C_CONTIGUOUS,
3.11 NPY_F_CONTIGUOUS
3.12
3.13 + ctypedef class numpy.dtype [object PyArray_Descr]:
3.14 + cdef int type_num
3.15 + cdef object fields
3.16 + cdef object names
3.17 +
3.18
3.19 ctypedef class numpy.ndarray [object PyArrayObject]:
3.20 cdef __cythonbufferdefaults__ = {"mode": "strided"}
3.21 @@ -36,6 +42,7 @@
3.22 npy_intp *shape "dimensions"
3.23 npy_intp *strides
3.24 int flags
3.25 + dtype descr
3.26
3.27 # Note: This syntax (function definition in pxd files) is an
3.28 # experimental exception made for __getbuffer__ and __releasebuffer__
3.29 @@ -57,7 +64,6 @@
3.30 raise ValueError("ndarray is not Fortran contiguous")
3.31
3.32 info.buf = PyArray_DATA(self)
3.33 - # info.obj = None # this is automatic
3.34 info.ndim = PyArray_NDIM(self)
3.35 info.strides = <Py_ssize_t*>PyArray_STRIDES(self)
3.36 info.shape = <Py_ssize_t*>PyArray_DIMS(self)
3.37 @@ -65,31 +71,104 @@
3.38 info.itemsize = PyArray_ITEMSIZE(self)
3.39 info.readonly = not PyArray_ISWRITEABLE(self)
3.40
3.41 - # Formats that are not tested and working in Cython are not
3.42 - # made available from this pxd file yet.
3.43 - cdef int t = PyArray_TYPE(self)
3.44 - cdef char* f = NULL
3.45 - if t == NPY_BYTE: f = "b"
3.46 - elif t == NPY_UBYTE: f = "B"
3.47 - elif t == NPY_SHORT: f = "h"
3.48 - elif t == NPY_USHORT: f = "H"
3.49 - elif t == NPY_INT: f = "i"
3.50 - elif t == NPY_UINT: f = "I"
3.51 - elif t == NPY_LONG: f = "l"
3.52 - elif t == NPY_ULONG: f = "L"
3.53 - elif t == NPY_LONGLONG: f = "q"
3.54 - elif t == NPY_ULONGLONG: f = "Q"
3.55 - elif t == NPY_FLOAT: f = "f"
3.56 - elif t == NPY_DOUBLE: f = "d"
3.57 - elif t == NPY_LONGDOUBLE: f = "g"
3.58 - elif t == NPY_CFLOAT: f = "Zf"
3.59 - elif t == NPY_CDOUBLE: f = "Zd"
3.60 - elif t == NPY_CLONGDOUBLE: f = "Zg"
3.61 - elif t == NPY_OBJECT: f = "O"
3.62 + cdef int t
3.63 + cdef char* f = NULL
3.64 + cdef dtype descr = self.descr
3.65 + cdef list stack
3.66
3.67 - if f == NULL:
3.68 - raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t)
3.69 - info.format = f
3.70 + cdef bint hasfields = PyDataType_HASFIELDS(descr)
3.71 +
3.72 + # Ugly hack warning:
3.73 + # Cython currently will not support helper functions in
3.74 + # pxd files -- so we must keep our own, manual stack!
3.75 + # In addition, avoid allocation of the stack in the common
3.76 + # case that we are dealing with a single non-nested datatype...
3.77 + # (this would look much prettier if we could use utility
3.78 + # functions).
3.79 +
3.80 +
3.81 + if not hasfields:
3.82 + info.obj = None # do not call releasebuffer
3.83 + t = descr.type_num
3.84 + if t == NPY_BYTE: f = "b"
3.85 + elif t == NPY_UBYTE: f = "B"
3.86 + elif t == NPY_SHORT: f = "h"
3.87 + elif t == NPY_USHORT: f = "H"
3.88 + elif t == NPY_INT: f = "i"
3.89 + elif t == NPY_UINT: f = "I"
3.90 + elif t == NPY_LONG: f = "l"
3.91 + elif t == NPY_ULONG: f = "L"
3.92 + elif t == NPY_LONGLONG: f = "q"
3.93 + elif t == NPY_ULONGLONG: f = "Q"
3.94 + elif t == NPY_FLOAT: f = "f"
3.95 + elif t == NPY_DOUBLE: f = "d"
3.96 + elif t == NPY_LONGDOUBLE: f = "g"
3.97 + elif t == NPY_CFLOAT: f = "Zf"
3.98 + elif t == NPY_CDOUBLE: f = "Zd"
3.99 + elif t == NPY_CLONGDOUBLE: f = "Zg"
3.100 + elif t == NPY_OBJECT: f = "O"
3.101 + else:
3.102 + raise ValueError("unknown dtype code in numpy.pxd (%d)" % t)
3.103 + info.format = f
3.104 + return
3.105 + else:
3.106 + info.obj = self # need to call releasebuffer
3.107 + info.format = <char*>stdlib.malloc(255) # static size
3.108 + f = info.format
3.109 + stack = [iter(descr.fields.iteritems())]
3.110 +
3.111 + while True:
3.112 + iterator = stack[-1]
3.113 + descr = None
3.114 + while descr is None:
3.115 + try:
3.116 + descr = iterator.next()[1][0]
3.117 + except StopIteration:
3.118 + stack.pop()
3.119 + if len(stack) > 0:
3.120 + f[0] = "}"
3.121 + f += 1
3.122 + iterator = stack[-1]
3.123 + else:
3.124 + f[0] = 0 # Terminate string!
3.125 + return
3.126 +
3.127 + hasfields = PyDataType_HASFIELDS(descr)
3.128 + if not hasfields:
3.129 + t = descr.type_num
3.130 + if f - info.format > 240: # this should leave room for "T{" and "}" as well
3.131 + raise RuntimeError("Format string allocated too short.")
3.132 +
3.133 + if t == NPY_BYTE: f[0] = "b"
3.134 + elif t == NPY_UBYTE: f[0] = "B"
3.135 + elif t == NPY_SHORT: f[0] = "h"
3.136 + elif t == NPY_USHORT: f[0] = "H"
3.137 + elif t == NPY_INT: f[0] = "i"
3.138 + elif t == NPY_UINT: f[0] = "I"
3.139 + elif t == NPY_LONG: f[0] = "l"
3.140 + elif t == NPY_ULONG: f[0] = "L"
3.141 + elif t == NPY_LONGLONG: f[0] = "q"
3.142 + elif t == NPY_ULONGLONG: f[0] = "Q"
3.143 + elif t == NPY_FLOAT: f[0] = "f"
3.144 + elif t == NPY_DOUBLE: f[0] = "d"
3.145 + elif t == NPY_LONGDOUBLE: f[0] = "g"
3.146 + elif t == NPY_CFLOAT: f[0] = "Z"; f[1] = "f"; f += 1
3.147 + elif t == NPY_CDOUBLE: f[0] = "Z"; f[1] = "d"; f += 1
3.148 + elif t == NPY_CLONGDOUBLE: f[0] = "Z"; f[1] = "g"; f += 1
3.149 + elif t == NPY_OBJECT: f[0] = "O"
3.150 + else:
3.151 + raise ValueError("unknown dtype code in numpy.pxd (%d)" % t)
3.152 + f += 1
3.153 + else:
3.154 + f[0] = "T"
3.155 + f[1] = "{"
3.156 + f += 2
3.157 + stack.append(iter(descr.fields.iteritems()))
3.158 +
3.159 + def __releasebuffer__(ndarray self, Py_buffer* info):
3.160 + # This is only called when format needs to be freed (obj is
3.161 + # set to NULL otherwise, so releasebuffer is never invoked)
3.162 + stdlib.free(info.format)
3.163
3.164
3.165 cdef void* PyArray_DATA(ndarray arr)
3.166 @@ -100,6 +179,9 @@
3.167 cdef npy_intp PyArray_DIMS(ndarray arr)
3.168 cdef Py_ssize_t PyArray_ITEMSIZE(ndarray arr)
3.169 cdef int PyArray_CHKFLAGS(ndarray arr, int flags)
3.170 + cdef int PyArray_HASFIELDS(ndarray arr, int flags)
3.171 +
3.172 + cdef int PyDataType_HASFIELDS(dtype obj)
3.173
3.174 ctypedef signed int npy_byte
3.175 ctypedef signed int npy_short
4.1 --- a/tests/run/bufaccess.pyx Fri Oct 10 02:42:55 2008 -0700
4.2 +++ b/tests/run/bufaccess.pyx Sat Oct 11 18:48:15 2008 +0200
4.3 @@ -1292,6 +1292,15 @@
4.4 int d
4.5 int e
4.6
4.7 +cdef struct SmallStruct:
4.8 + int a
4.9 + int b
4.10 +
4.11 +cdef struct NestedStruct:
4.12 + SmallStruct x
4.13 + SmallStruct y
4.14 + int z
4.15 +
4.16 cdef class MyStructMockBuffer(MockBuffer):
4.17 cdef int write(self, char* buf, object value) except -1:
4.18 cdef MyStruct* s
4.19 @@ -1302,6 +1311,16 @@
4.20 cdef get_itemsize(self): return sizeof(MyStruct)
4.21 cdef get_default_format(self): return b"2bq2i"
4.22
4.23 +cdef class NestedStructMockBuffer(MockBuffer):
4.24 + cdef int write(self, char* buf, object value) except -1:
4.25 + cdef NestedStruct* s
4.26 + s = <NestedStruct*>buf;
4.27 + s.x.a, s.x.b, s.y.a, s.y.b, s.z = value
4.28 + return 0
4.29 +
4.30 + cdef get_itemsize(self): return sizeof(NestedStruct)
4.31 + cdef get_default_format(self): return b"2T{ii}i"
4.32 +
4.33 @testcase
4.34 def basic_struct(object[MyStruct] buf):
4.35 """
4.36 @@ -1316,6 +1335,21 @@
4.37 """
4.38 print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
4.39
4.40 +@testcase
4.41 +def nested_struct(object[NestedStruct] buf):
4.42 + """
4.43 + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
4.44 + 1 2 3 4 5
4.45 + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i"))
4.46 + 1 2 3 4 5
4.47 + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="iiiii"))
4.48 + Traceback (most recent call last):
4.49 + ...
4.50 + ValueError: Expected start of SmallStruct, got 'iiiii'
4.51 + """
4.52 + print buf[0].x.a, buf[0].x.b, buf[0].y.a, buf[0].y.b, buf[0].z
4.53 +
4.54 +
4.55 cdef struct LongComplex:
4.56 long double real
4.57 long double imag
5.1 --- a/tests/run/numpy_test.pyx Fri Oct 10 02:42:55 2008 -0700
5.2 +++ b/tests/run/numpy_test.pyx Sat Oct 11 18:48:15 2008 +0200
5.3 @@ -129,12 +129,22 @@
5.4 >>> test_dtype(np.int32, inc1_int32_t)
5.5 >>> test_dtype(np.float64, inc1_float64_t)
5.6
5.7 - Unsupported types:
5.8 - >>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
5.9 - >>> inc1_byte(a)
5.10 + >>> test_recordarray()
5.11 +
5.12 + >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\
5.13 + ('a', np.dtype('i,i')),\
5.14 + ('b', np.dtype('i,i'))\
5.15 + ])))
5.16 + array([((0, 0), (0, 0)), ((1, 2), (1, 4)), ((1, 2), (1, 4))],
5.17 + dtype=[('a', [('f0', '<i4'), ('f1', '<i4')]), ('b', [('f0', '<i4'), ('f1', '<i4')])])
5.18 +
5.19 + >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\
5.20 + ('a', np.dtype('i,f')),\
5.21 + ('b', np.dtype('i,i'))\
5.22 + ])))
5.23 Traceback (most recent call last):
5.24 - ...
5.25 - ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20)
5.26 + ...
5.27 + ValueError: Buffer datatype mismatch (expected 'i', got 'f}T{ii}')
5.28
5.29 >>> test_good_cast()
5.30 True
5.31 @@ -261,6 +271,49 @@
5.32 inc1(a)
5.33 if a[1] != 11: print "failed!"
5.34
5.35 +cdef struct DoubleInt:
5.36 + int x, y
5.37 +
5.38 +def test_recordarray():
5.39 + cdef object[DoubleInt] arr
5.40 + arr = np.array([(5,5), (4, 6)], dtype=np.dtype('i,i'))
5.41 + cdef DoubleInt rec
5.42 + rec = arr[0]
5.43 + if rec.x != 5: print "failed"
5.44 + if rec.y != 5: print "failed"
5.45 + rec.y += 5
5.46 + arr[1] = rec
5.47 + arr[0].x -= 2
5.48 + arr[0].y += 3
5.49 + if arr[0].x != 3: print "failed"
5.50 + if arr[0].y != 8: print "failed"
5.51 + if arr[1].x != 5: print "failed"
5.52 + if arr[1].y != 10: print "failed"
5.53 +
5.54 +cdef struct NestedStruct:
5.55 + DoubleInt a
5.56 + DoubleInt b
5.57 +
5.58 +cdef struct BadDoubleInt:
5.59 + float x
5.60 + int y
5.61 +
5.62 +cdef struct BadNestedStruct:
5.63 + DoubleInt a
5.64 + BadDoubleInt b
5.65 +
5.66 +def test_nested_dtypes(obj):
5.67 + cdef object[NestedStruct] arr = obj
5.68 + arr[1].a.x = 1
5.69 + arr[1].a.y = 2
5.70 + arr[1].b.x = arr[0].a.y + 1
5.71 + arr[1].b.y = 4
5.72 + arr[2] = arr[1]
5.73 + return arr
5.74 +
5.75 +def test_bad_nested_dtypes():
5.76 + cdef object[BadNestedStruct] arr
5.77 +
5.78 def test_good_cast():
5.79 # Check that a signed int can round-trip through casted unsigned int access
5.80 cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i')
