Cython has moved to github.

cython

view Cython/Compiler/Buffer.py @ 1586:4a96e1aff2d4

Fix error in buffer typestring checking
author Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
date Tue Dec 16 10:02:51 2008 +0100 (3 years ago)
parents 43d4e2b19134
children cdf889c30e7a
line source
1 from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
2 from Cython.Compiler.ModuleNode import ModuleNode
3 from Cython.Compiler.Nodes import *
4 from Cython.Compiler.ExprNodes import *
5 from Cython.Compiler.StringEncoding import EncodedString
6 from Cython.Compiler.Errors import CompileError
7 from Cython.Utils import UtilityCode
8 import Interpreter
9 import PyrexTypes
11 try:
12 set
13 except NameError:
14 from sets import Set as set
16 import textwrap
# Code cleanup ideas:
# - Be smarter about casting in some places
# - Start using CCodeWriters to generate utility functions
# - Create a struct type per ndim rather than keeping loose local vars
def dedent(text, reindent=0):
    """
    Remove the common leading whitespace from ``text`` (as
    ``textwrap.dedent`` does).  If ``reindent`` is positive, prefix every
    resulting line -- including empty ones and the one after a trailing
    newline -- with that many spaces.
    """
    stripped = textwrap.dedent(text)
    if reindent <= 0:
        return stripped
    prefix = reindent * " "
    return "\n".join(prefix + line for line in stripped.split("\n"))
class IntroduceBufferAuxiliaryVars(CythonTransform):
    """
    For every buffer-typed variable in a function scope, declare the
    auxiliary C variables used to access it:
      - the Py_buffer struct;
      - per-dimension stride and shape variables;
      - in 'full' mode, per-dimension suboffset variables.
    When any buffer is found, also register the buffer utility code the
    module needs.
    """

    #
    # Entry point
    #

    # Set to True by handle_scope when a scope contains buffer variables.
    # __call__ resets it (together with max_ndim), so reusing an instance
    # for another module does not inherit the previous module's state.
    buffers_exists = False

    def __call__(self, node):
        assert isinstance(node, ModuleNode)
        self.max_ndim = 0
        self.buffers_exists = False
        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            use_py2_buffer_functions(node.scope)
            use_empty_bufstruct_code(node.scope, self.max_ndim)
        return result


    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        """
        Declare the auxiliary variables for every buffer entry of ``scope``.

        Attaches a Symtab.BufferAux to each buffer entry and records the
        buffer entries on ``scope.buffer_entries``.  Raises CompileError
        if ``node`` is the module node and its scope contains buffers.
        """
        # For all buffers, insert extra variables in the scope.
        # The variables are also accessible from the buffer_info
        # on the buffer entry
        bufvars = [entry for name, entry
                   in scope.entries.iteritems()
                   if entry.type.is_buffer]
        if bufvars:
            self.buffers_exists = True

        if isinstance(node, ModuleNode) and bufvars:
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")
        for entry in bufvars:
            name = entry.name
            buftype = entry.type
            if buftype.ndim > self.max_ndim:
                self.max_ndim = buftype.ndim

            # Declare auxiliary vars
            cname = scope.mangle(Naming.bufstruct_prefix, name)
            bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
                                        type=PyrexTypes.c_py_buffer_type, pos=node.pos)
            if entry.is_arg:
                bufinfo.used = True # otherwise, NameNode will mark whether it is used

            def var(prefix, idx, initval):
                # Declare one Py_ssize_t helper for dimension ``idx`` of
                # the current buffer, initialised to ``initval``.
                cname = scope.mangle(prefix, "%d_%s" % (idx, name))
                result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
                                           node.pos, cname=cname, is_cdef=True)

                result.init = initval
                if entry.is_arg:
                    result.used = True
                return result

            stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
            shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]
            mode = entry.type.mode
            if mode == 'full':
                # Suboffsets start at -1, i.e. "no indirection".
                suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
            else:
                suboffsetvars = None

            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, suboffsetvars)

        scope.buffer_entries = bufvars
        self.scope = scope

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node
#
# Analysis
#

# Every option a buffer declaration accepts, in positional order; the
# order matters because positional arguments are matched against it.
buffer_options = tuple("dtype ndim mode negative_indices cast".split())

# Values used for options the declaration leaves out.
buffer_defaults = dict(ndim=1, mode="full", negative_indices=True, cast=False)

# Only this many leading options may be given positionally; the rest
# must be passed by keyword.
buffer_positional_options_count = 1

# Error messages raised while validating buffer options.
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Parse and validate the options of a buffer type declaration.

    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs should consist of a list and a dict
    of tuples (value, pos). Defaults should be a dict of values.

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped).

    Raises CompileError in these cases:
      - too many positional options;
      - an unknown or duplicated option;
      - a missing option with no default, when need_complete is true;
      - an invalid dtype, ndim or mode;
      - a non-boolean negative_indices or cast.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env)

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    for name, (value, pos) in dictargs.iteritems():
        if name not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name.encode("ASCII")] = value

    # zip() takes each name from buffer_options itself, so it is always a
    # known option; only a clash with a keyword argument must be checked.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if name not in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    # Compare against None (option absent) rather than relying on
    # truthiness, so falsy-but-invalid values such as 0.0 or "" are
    # rejected too.  ndim == 0 is still accepted (non-negative int).
    ndim = options.get("ndim")
    if ndim is not None and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode is not None and mode not in ('full', 'strided', 'c', 'fortran'):
        raise CompileError(globalpos, ERR_BUF_MODE)

    def assert_bool(name):
        x = options.get(name)
        if not isinstance(x, bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    assert_bool('negative_indices')
    assert_bool('cast')

    return options
193 #
194 # Code generation
195 #
def get_flags(buffer_aux, buffer_type):
    """
    Return the C expression of PyBUF_* flags used to request a buffer of
    ``buffer_type``.  The expression is built from:
      - PyBUF_FORMAT, always;
      - the flag matching the access mode;
      - PyBUF_WRITABLE, when ``buffer_aux.writable_needed`` is set.
    """
    mode_flag = {
        'full': 'PyBUF_INDIRECT',
        'strided': 'PyBUF_STRIDES',
        'c': 'PyBUF_C_CONTIGUOUS',
        'fortran': 'PyBUF_F_CONTIGUOUS',
    }.get(buffer_type.mode)

    flags = 'PyBUF_FORMAT'
    if mode_flag is None:
        # mode is expected to be one of the four modes above
        assert False
    else:
        flags += '| ' + mode_flag
    if buffer_aux.writable_needed:
        flags += "| PyBUF_WRITABLE"
    return flags
def used_buffer_aux_vars(entry):
    """
    Mark the Py_buffer struct of buffer ``entry`` as used, together with
    its shape and stride variables and, if it has any, its suboffset
    variables.
    """
    aux = entry.buffer_aux
    aux.buffer_info_var.used = True
    dim_vars = list(aux.shapevars) + list(aux.stridevars)
    if aux.suboffsetvars:
        dim_vars += aux.suboffsetvars
    for dim_var in dim_vars:
        dim_var.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
    """
    Emit C statements that copy strides and shape (and, in 'full' mode,
    suboffsets) from the acquired Py_buffer struct into the per-dimension
    local variables.  One output line is emitted per field.
    """
    struct_cname = buffer_aux.buffer_info_var.cname
    fields = [("strides", buffer_aux.stridevars),
              ("shape", buffer_aux.shapevars)]
    if mode == 'full':
        fields.append(("suboffsets", buffer_aux.suboffsetvars))

    for field_name, local_vars in fields:
        assignments = []
        for dim, local_var in enumerate(local_vars):
            assignments.append("%s = %s.%s[%d];" % (local_var.cname, struct_cname, field_name, dim))
        code.putln(" ".join(assignments))
def put_acquire_arg_buffer(entry, code, pos):
    """
    Emit code that acquires the buffer of the buffer-typed argument
    ``entry``, jumping to the error label on failure.  On success, the
    strides, shape and suboffsets are unpacked into the auxiliary locals.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    aux = entry.buffer_aux
    buftype = entry.type
    getbuffer_func = get_getbuffer_code(buftype.dtype, code)

    acquire_call = "%s((PyObject*)%s, &%s, %s, %d, %d)" % (
        getbuffer_func,
        entry.cname,
        aux.buffer_info_var.cname,
        get_flags(aux, buftype),
        buftype.ndim,
        int(buftype.cast))
    code.putln(code.error_goto_if(acquire_call + " == -1", pos))
    # An exception raised during argument handling cannot be caught, so
    # the buffer struct needs no special cleanup on this path.
    put_unpack_buffer_aux_into_scope(aux, buftype.mode, code)
254 #def put_release_buffer_normal(entry, code):
255 # code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
256 # entry.cname,
257 # entry.cname,
258 # entry.buffer_aux.buffer_info_var.cname))
def get_release_buffer_code(entry):
    """Return a C call that releases the buffer struct of ``entry``."""
    template = "__Pyx_SafeReleaseBuffer(&%s)"
    return template % (entry.buffer_aux.buffer_info_var.cname,)
def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with getting
    the buffer auxiliary structure and variables set up correctly, the assignment
    itself and refcounting is the responsibility of the caller.

    However, the assignment operation may throw an exception so that the reassignment
    never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs buffer
      (which may or may not succeed).

    Parameters:
      lhs_cname / rhs_cname -- C names of the target and source objects.
      is_initialized -- whether the lhs may already hold an acquired buffer;
          if false, a failed acquisition sets lhs to None instead of
          reacquiring.
      pos -- source position used for error reporting.
      code -- the code writer receiving the generated C.
    """

    code.globalstate.use_utility_code(acquire_utility_code)
    bufstruct = buffer_aux.buffer_info_var.cname
    flags = get_flags(buffer_aux, buffer_type)

    # C call template for acquiring a buffer into bufstruct; the object to
    # acquire from is substituted later through the remaining "%s".
    getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (get_getbuffer_code(buffer_type.dtype, code),
                                                          # note: object is filled in later (%%s)
                                                          bufstruct,
                                                          flags,
                                                          buffer_type.ndim,
                                                          int(buffer_type.cast))

    if is_initialized:
        # Release any existing buffer
        code.putln('__Pyx_SafeReleaseBuffer(&%s);' % bufstruct)
        # Acquire
        retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
        code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
        code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
        # If acquisition failed, attempt to reacquire the old buffer
        # before raising the exception. A failure of reacquisition
        # will cause the reacquisition exception to be reported, one
        # can consider working around this later.
        code.begin_block()
        # Temps to hold the pending exception (type, value, traceback)
        # while the old buffer is reacquired.
        type, value, tb = [code.funcstate.allocate_temp(PyrexTypes.py_object_type)
                           for i in range(3)]
        code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (type, value, tb))
        code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
        code.begin_block()
        # Reacquisition failed too: drop the original exception and raise
        # the fallback error instead.
        code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (type, value, tb))
        code.globalstate.use_utility_code(raise_buffer_fallback_code)
        code.putln('__Pyx_RaiseBufferFallbackError();')
        code.putln('} else {')
        # Old buffer is back: restore the original exception.
        code.putln('PyErr_Restore(%s, %s, %s);' % (type, value, tb))
        for t in (type, value, tb):
            code.funcstate.release_temp(t)
        code.end_block()
        # Unpack indices
        code.end_block()
        # Unpacking happens on both paths (new buffer or reacquired old
        # one); the error jump is taken afterwards if acquisition failed.
        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
        code.putln(code.error_goto_if_neg(retcode_cname, pos))
        code.funcstate.release_temp(retcode_cname)
    else:
        # Our entry had no previous value, so set to None when acquisition fails.
        # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
        # so it suffices to set the buf field to NULL.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        code.putln('%s = %s; Py_INCREF(Py_None); %s.buf = NULL;' %
                   (lhs_cname,
                    PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
                    bufstruct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
        code.putln('}')
def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
    """
    Generates code to process indices and calculate an offset into
    a buffer. Returns a C string which gives a pointer which can be
    read from or written to at will (it is an expression so caller should
    store it in a temporary if it is used more than once).

    As the bounds checking can have any number of combinations of unsigned
    arguments, smart optimizations etc. we insert it directly in the function
    body. The lookup however is delegated to a inline function that is instantiated
    once per ndim (lookup with suboffsets tend to get quite complicated).

    Parameters:
      entry -- the buffer variable's entry (provides buffer_aux and type).
      index_signeds -- one flag per index; 0 means the index is unsigned
          and can never be negative.
      index_cnames -- C names of the index variables (negative indices are
          wrapped in place in these variables).
      options -- directive dict; only 'boundscheck' is read here.
      pos -- source position used for error reporting.
      code -- the code writer receiving the generated C.
    """
    bufaux = entry.buffer_aux
    bufstruct = bufaux.buffer_info_var.cname
    negative_indices = entry.type.negative_indices

    if options['boundscheck']:
        # Check bounds and fix negative indices.
        # We allocate a temporary which is initialized to -1, meaning OK (!).
        # If an error occurs, the temp is set to the dimension index the
        # error is occurring at.
        tmp_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
        code.putln("%s = -1;" % tmp_cname)
        for dim, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
                                                          bufaux.shapevars)):
            if signed != 0:
                # not unsigned, deal with negative index
                code.putln("if (%s < 0) {" % cname)
                if negative_indices:
                    # Wrap around once; still negative means out of bounds.
                    code.putln("%s += %s;" % (cname, shape.cname))
                    code.putln("if (%s) %s = %d;" % (
                        code.unlikely("%s < 0" % cname), tmp_cname, dim))
                else:
                    # Negative indexing disabled: any negative index is an error.
                    code.putln("%s = %d;" % (tmp_cname, dim))
                code.put("} else ")
            # check bounds in positive direction
            code.putln("if (%s) %s = %d;" % (
                code.unlikely("%s >= %s" % (cname, shape.cname)),
                tmp_cname, dim))
        code.globalstate.use_utility_code(raise_indexerror_code)
        code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
        code.begin_block()
        code.putln('__Pyx_RaiseBufferIndexError(%s);' % tmp_cname)
        code.putln(code.error_goto(pos))
        code.end_block()
        code.funcstate.release_temp(tmp_cname)
    elif negative_indices:
        # Only fix negative indices.
        for signed, cname, shape in zip(index_signeds, index_cnames,
                                        bufaux.shapevars):
            if signed != 0:
                code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape.cname))

    # Create buffer lookup and return it
    # This is done via utility macros/inline functions, which vary
    # according to the access mode used.
    params = []
    nd = entry.type.ndim
    mode = entry.type.mode
    if mode == 'full':
        # Arguments per dimension: index, stride, suboffset.
        for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
            params.append(i)
            params.append(s.cname)
            params.append(o.cname)
        funcname = "__Pyx_BufPtrFull%dd" % nd
        funcgen = buf_lookup_full_code
    else:
        if mode == 'strided':
            funcname = "__Pyx_BufPtrStrided%dd" % nd
            funcgen = buf_lookup_strided_code
        elif mode == 'c':
            funcname = "__Pyx_BufPtrCContig%dd" % nd
            funcgen = buf_lookup_c_code
        elif mode == 'fortran':
            funcname = "__Pyx_BufPtrFortranContig%dd" % nd
            funcgen = buf_lookup_fortran_code
        else:
            assert False
        # Arguments per dimension: index, stride.
        for i, s in zip(index_cnames, bufaux.stridevars):
            params.append(i)
            params.append(s.cname)

    # Make sure the utility code is available
    code.globalstate.use_code_from(funcgen, name=funcname, nd=nd)

    ptr_type = entry.type.buffer_ptr_type
    ptrcode = "%s(%s, %s.buf, %s)" % (funcname,
                                      ptr_type.declaration_code(""),
                                      bufstruct,
                                      ", ".join(params))
    return ptrcode
def use_empty_bufstruct_code(env, max_ndim):
    """
    Register the global arrays __Pyx_zeros and __Pyx_minusones in ``env``.

    The arrays are used by __Pyx_ZeroBuffer and for missing suboffsets.
    They hold ``max_ndim`` elements, but always at least one:
    0-dimensional buffers are allowed, and an empty initializer
    ``{}`` / zero-length array is not valid ISO C (MSVC rejects it).
    """
    nelems = max(max_ndim, 1)
    code = dedent("""
        Py_ssize_t __Pyx_zeros[] = {%s};
        Py_ssize_t __Pyx_minusones[] = {%s};
    """) % (", ".join(["0"] * nelems), ", ".join(["-1"] * nelems))
    env.use_utility_code(UtilityCode(proto=code), "empty_bufstruct_code")
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.

    Emits into ``proto``:
      - a macro ``name(type, buf, i0, s0, o0, ...)`` that casts the
        result of ``name_imp`` to ``type``;
      - the prototype of ``name_imp``.
    Emits the definition of ``name_imp`` into ``defin``.
    """
    # _i_ndex, _s_tride, sub_o_ffset
    macroargs = ", ".join(["i%d, s%d, o%d" % (i, i, i) for i in range(nd)])
    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))

    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
    proto.putln("static INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))
    # For each dimension: advance by stride*index, then, if the suboffset
    # is non-negative, follow the pointer stored there and add the suboffset.
    defin.putln(dedent("""
        static INLINE void* %s_imp(void* buf, %s) {
        char* ptr = (char*)buf;
        """) % (name, funcargs) + "".join([dedent("""\
        ptr += s%d * i%d;
        if (o%d >= 0) ptr = *((char**)ptr) + o%d;
        """) % (i, i, i, i) for i in range(nd)]
        ) + "\nreturn ptr;\n}")
def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Emit a macro ``name(type, buf, i0, s0, i1, s1, ...)``.  The macro
    computes the element address ``(char*)buf + i0*s0 + i1*s1 + ...``
    and casts it to ``type``.  ``defin`` is not used: the whole lookup
    fits in the macro.
    """
    dims = range(nd)
    arglist = ", ".join(["i%d, s%d" % (d, d) for d in dims])
    byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in dims])
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)" % (name, arglist, byte_offset))
def buf_lookup_c_code(proto, defin, name, nd):
    """
    Like the strided lookup, but the last dimension is assumed contiguous.
    Its index is therefore applied by pointer arithmetic on ``type`` rather
    than multiplied by a stride.  The macro keeps the same
    ``(type, buf, i0, s0, ...)`` signature as the strided macro.
    """
    if nd != 1:
        last = nd - 1
        arglist = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
        outer_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(last)])
        macro = "#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (
            name, arglist, outer_offset, last)
    else:
        macro = "#define %s(type, buf, i0, s0) ((type)buf + i0)" % name
    proto.putln(macro)
def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Mirror image of the C-contiguous lookup.  The first dimension is
    assumed contiguous and indexed by pointer arithmetic on ``type``; all
    other dimensions are multiplied by their strides.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    arglist = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
    strided_part = " + ".join(["i%d * s%d" % (d, d) for d in range(1, nd)])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i0)" % (name, arglist, strided_part))
492 #
493 # Utils for creating type string checkers
494 #
def mangle_dtype_name(dtype):
    """
    Turn ``dtype`` into a string usable inside a C identifier.

    Python objects become "object" and pointers become "ptr".  Any other
    type uses its C declaration with spaces replaced by underscores.
    Typedefs and structs/unions also get an "nn_" prefix, so a user type
    cannot collide with a builtin (consider "typedef float unsigned_int").
    """
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    prefix = ""
    if dtype.is_typedef or dtype.is_struct_or_union:
        prefix = "nn_"
    return prefix + "_".join(dtype.declaration_code("").split(" "))
def get_typestringchecker(code, dtype):
    """
    Return the C function name of the format-string checker for ``dtype``.
    The checker is requested from the global state via
    create_typestringchecker under that name, so it is emitted if needed.
    """
    checker_cname = "__Pyx_CheckTypestring_" + mangle_dtype_name(dtype)
    code.globalstate.use_code_from(create_typestringchecker, checker_cname, dtype=dtype)
    return checker_cname
def create_typestringchecker(protocode, defcode, name, dtype):
    """
    Emit a C function ``const char* name(const char* ts)`` that checks the
    format string at ``ts`` against ``dtype``.

    The emitted function returns a pointer just past the consumed part.
    On mismatch it sets a ValueError and returns NULL.

    Three cases are handled:
      - simple dtypes (Python objects, ints, floats): a single format code,
        optionally preceded by a repeat count of '1';
      - structs that can be complex: either a 'Z'-prefixed code, or a real
        and an imaginary code checked one after the other;
      - other structs: fields are grouped into runs of equal type.  Each
        run is checked with that type's own checker, allowing repeat
        counts, and nested structs must appear as "T{...}".

    The prototype goes to ``protocode``, the definition to ``defcode``.
    Nothing is emitted for error types.
    """

    def put_assert(cond, msg):
        # Emit: if cond fails, raise ValueError describing the current token.
        defcode.putln("if (!(%s)) {" % cond)
        defcode.putln('PyErr_Format(PyExc_ValueError, "Buffer dtype mismatch (%s)", __Pyx_DescribeTokenInFormatString(ts));' % msg)
        defcode.putln("return NULL;")
        defcode.putln("}")

    if dtype.is_error: return
    simple = dtype.is_simple_buffer_dtype()
    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
    # Cannot add utility code recursively...
    if not simple:
        dtype_t = dtype.declaration_code("")
        protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
        fields = dtype.scope.var_entries

        # divide fields into blocks of equal type (for repeat count)
        field_blocks = [] # of (n, type, checkerfunc)
        n = 0
        prevtype = None
        for f in fields:
            if n and f.type != prevtype:
                field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype)))
                n = 0
            prevtype = f.type
            n += 1
        # Flush the last run (assumes the struct has at least one field).
        field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type)))

    protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
    defcode.putln("static const char* %s(const char* ts) {" % name)
    if simple:
        defcode.putln("int ok;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        if dtype.is_pyobject:
            defcode.putln("ok = (*ts == 'O');")
        else:
            # Cannot trust declared size; but rely on int vs float and
            # signed/unsigned to be correctly declared. Use a switch statement
            # on all possible format codes to validate that the size is ok.
            # (Note that many codes may map to same size, e.g. 'i' and 'l'
            # may both be four bytes).
            ctype = dtype.declaration_code("")
            defcode.putln("switch (*ts) {")
            if dtype.is_int:
                types = [
                    ('b', 'char'), ('h', 'short'), ('i', 'int'),
                    ('l', 'long'), ('q', 'long long')
                ]
            elif dtype.is_float:
                types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
            else:
                assert False
            if dtype.signed == 0:
                # Unsigned: upper-case codes; (ctype)-1 > 0 checks unsignedness.
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
                                  (char.upper(), ctype, against, ctype))
            else:
                # Signed: lower-case codes; (ctype)-1 < 0 checks signedness.
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
                                  (char, ctype, against, ctype))
            defcode.putln("default: ok = 0;")
            defcode.putln("}")
        put_assert("ok", "expected %s, got %%s" % dtype)
        defcode.putln("++ts;")
    elif complex_possible:
        # Could be a struct representing a complex number, so allow
        # for parsing a "Zf" spec.
        real_t, imag_t = [x.type for x in fields]
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        defcode.putln("if (*ts == 'Z') {")
        if len(field_blocks) == 2:
            # Different float type, sizeof check needed
            defcode.putln("if (sizeof(%s) != sizeof(%s)) {" % (
                real_t.declaration_code(""),
                imag_t.declaration_code("")))
            defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
                dtype, real_t, imag_t))
            defcode.putln("return NULL;")
            defcode.putln("}")
            check_real, check_imag = [x[2] for x in field_blocks]
        else:
            assert len(field_blocks) == 1
            check_real = check_imag = field_blocks[0][2]
        # 'Z' case: the code after 'Z' is checked against the real type.
        defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % check_real)
        defcode.putln("} else {")
        # Otherwise: two separate codes, real then imaginary.
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_real)
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_imag)
        defcode.putln("}")
    else:
        defcode.putln("int n, count;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

        # For error messages: what is expected after each run.
        next_types = [x[1] for x in field_blocks[1:]] + ["end"]
        for (n, type, checker), next_type in zip(field_blocks, next_types):
            if n == 1:
                defcode.putln("if (*ts == '1') ++ts;")
            else:
                # Consume repeat-counted items until the run of n fields is
                # covered; overshooting the run is an error.
                defcode.putln("n = %d;" % n);
                defcode.putln("do {")
                defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
                put_assert("n >= 0", "expected %s, got %%s" % next_type)

            simple = type.is_simple_buffer_dtype()
            if not simple:
                # Nested struct fields must be written as "T{...}".
                put_assert("*ts == 'T' && *(ts+1) == '{'", "expected %s, got %%s" % type)
                defcode.putln("ts += 2;")
            defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
            if not simple:
                put_assert("*ts == '}'", "expected end of %s struct, got %%s" % type)
                defcode.putln("++ts;")

            if n > 1:
                defcode.putln("} while (n > 0);");
            defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

    defcode.putln("return ts;")
    defcode.putln("}")
def get_getbuffer_code(dtype, code):
    """
    Generate a utility function for getting a buffer for the given dtype.
    The function will:
    - Call PyObject_GetBuffer
    - Check that ndim matched the expected value
    - Check that the format string is right (a NULL format is treated as
      "B", as specified by PEP 3118)
    - Set suboffsets to all -1 if it is returned as NULL.

    If any check fails after the buffer has been acquired, the buffer is
    released again before the struct is zeroed.  Otherwise the exporter's
    reference and its export lock would be leaked.

    Returns the C name of the generated function.
    """

    name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
    if not code.globalstate.has_code(name):
        code.globalstate.use_utility_code(acquire_utility_code)
        typestringchecker = get_typestringchecker(code, dtype)
        dtype_cname = dtype.declaration_code("")
        utilcode = UtilityCode(proto = dedent("""
            static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
            """) % name, impl = dedent("""
            static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
              const char* ts;
              if (obj == Py_None) {
                __Pyx_ZeroBuffer(buf);
                return 0;
              }
              buf->buf = NULL;
              if (__Pyx_GetBuffer(obj, buf, flags) == -1) goto fail;
              if (buf->ndim != nd) {
                __Pyx_BufferNdimError(buf, nd);
                goto fail_release;
              }
              if (!cast) {
                ts = buf->format ? buf->format : "B";
                ts = __Pyx_ConsumeWhitespace(ts);
                if (!ts) goto fail_release;
                ts = %(typestringchecker)s(ts);
                if (!ts) goto fail_release;
                ts = __Pyx_ConsumeWhitespace(ts);
                if (!ts) goto fail_release;
                if (*ts != 0) {
                  PyErr_Format(PyExc_ValueError,
                    "Buffer dtype mismatch (expected end, got %%s)",
                    __Pyx_DescribeTokenInFormatString(ts));
                  goto fail_release;
                }
              } else {
                if (buf->itemsize != sizeof(%(dtype_cname)s)) {
                  PyErr_SetString(PyExc_ValueError,
                    "Attempted cast of buffer to datatype of different size.");
                  goto fail_release;
                }
              }
              if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
              return 0;
            fail_release:;
              __Pyx_ReleaseBuffer(buf);
            fail:;
              __Pyx_ZeroBuffer(buf);
              return -1;
            }""") % {'name': name,
                     'typestringchecker': typestringchecker,
                     'dtype_cname': dtype_cname})
        code.globalstate.use_utility_code(utilcode, name)
    return name
def use_py2_buffer_functions(env):
    """
    Register C definitions of __Pyx_GetBuffer and __Pyx_ReleaseBuffer in ``env``.

    Under Python 3 these are aliases of PyObject_GetBuffer and
    PyBuffer_Release.  Under Python 2 they are emulated:
      - on Python >= 2.6, objects whose type has
        Py_TPFLAGS_HAVE_NEWBUFFER use the real new-style buffer API;
      - instances of extension types that define __getbuffer__ /
        __releasebuffer__ are dispatched to those methods directly.  Such
        types are collected from ``env`` and, recursively, its cimported
        modules;
      - any other object raises TypeError on acquisition.

    Release mirrors acquisition: a buffer obtained via PyObject_GetBuffer
    is released via PyBuffer_Release, so the exporter's bf_releasebuffer
    slot is called.
    """
    codename = "PyObject_GetBuffer" # just a representative unique key

    # Search all types for __getbuffer__ overloads; collects
    # (typeptr_cname, getbuffer_cname, releasebuffer_cname or None).
    types = []
    visited_scopes = set()
    def find_buffer_types(scope):
        if scope in visited_scopes:
            return
        visited_scopes.add(scope)
        for m in scope.cimported_modules:
            find_buffer_types(m)
        for e in scope.type_entries:
            t = e.type
            if t.is_extension_type:
                release = get = None
                for x in t.scope.pyfunc_entries:
                    if x.name == u"__getbuffer__": get = x.func_cname
                    elif x.name == u"__releasebuffer__": release = x.func_cname
                if get:
                    types.append((t.typeptr_cname, get, release))

    find_buffer_types(env)

    code = dedent("""
        #if PY_MAJOR_VERSION < 3
        static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
          #if PY_VERSION_HEX >= 0x02060000
          if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER)
              return PyObject_GetBuffer(obj, view, flags);
          #endif
    """)
    if types:
        clause = "if"
        for t, get, release in types:
            code += "  %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
            clause = "else if"
        code += "  else {\n"
    code += dedent("""\
        PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
        return -1;
    """, 2)
    if types:
        code += "  }"
    # Objects acquired through PyObject_GetBuffer above must be released
    # through PyBuffer_Release so that bf_releasebuffer runs.
    code += dedent("""
        }

        static void __Pyx_ReleaseBuffer(Py_buffer *view) {
          PyObject* obj = view->obj;
          if (obj) {
            #if PY_VERSION_HEX >= 0x02060000
            if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER) {
              PyBuffer_Release(view);
              return;
            }
            #endif
    """)
    if types:
        clause = "if"
        for t, get, release in types:
            if release:
                code += "    %s (PyObject_TypeCheck(obj, %s)) %s(obj, view);\n" % (clause, t, release)
                clause = "else if"
    code += dedent("""
            Py_DECREF(obj);
            view->obj = NULL;
          }
        }

        #endif
    """)

    env.use_utility_code(UtilityCode(
        proto = dedent("""\
            #if PY_MAJOR_VERSION < 3
            static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
            static void __Pyx_ReleaseBuffer(Py_buffer *view);
            #else
            #define __Pyx_GetBuffer PyObject_GetBuffer
            #define __Pyx_ReleaseBuffer PyBuffer_Release
            #endif
        """), impl = code), codename)
782 #
783 # Static utility code
784 #
# Utility function to set the right exception
# The caller should immediately goto_error
#
# __Pyx_RaiseBufferIndexError(axis) only sets an IndexError naming the
# out-of-bounds axis (dimension index); it does not jump anywhere itself.
raise_indexerror_code = UtilityCode(
proto = """\
static void __Pyx_RaiseBufferIndexError(int axis); /*proto*/
""",
impl = """\
static void __Pyx_RaiseBufferIndexError(int axis) {
PyErr_Format(PyExc_IndexError,
"Out of bounds on buffer access (axis %d)", axis);
}

""")
#
# Buffer type checking. Utility code for checking that acquired
# buffers match our assumptions. We only need to check ndim and
# the format string; the access mode/flags are checked by the
# exporter.
#
# Helpers defined here:
#   __Pyx_SafeReleaseBuffer  - no-op when buf is NULL (e.g. a zeroed/None
#                              buffer); otherwise resets suboffsets that
#                              point at __Pyx_minusones back to NULL and
#                              calls __Pyx_ReleaseBuffer.
#   __Pyx_ZeroBuffer         - resets a Py_buffer: buf/obj NULL, strides and
#                              shape -> __Pyx_zeros, suboffsets ->
#                              __Pyx_minusones.
#   __Pyx_ConsumeWhitespace  - skips '@', LF, CR and spaces in a format
#                              string; sets ValueError and returns NULL on
#                              the non-native byte-order prefixes = < > !.
#   __Pyx_BufferNdimError    - sets a ValueError for an ndim mismatch.
#   __Pyx_DescribeTokenInFormatString - human-readable name of the format
#                              code at ts, used in error messages.
acquire_utility_code = UtilityCode(
proto = """\
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static const char* __Pyx_DescribeTokenInFormatString(const char* ts); /*proto*/
""",
impl = """
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
if (info->buf == NULL) return;
if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
__Pyx_ReleaseBuffer(info);
}

static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
buf->buf = NULL;
buf->obj = NULL;
buf->strides = __Pyx_zeros;
buf->shape = __Pyx_zeros;
buf->suboffsets = __Pyx_minusones;
}

static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts) {
while (1) {
switch (*ts) {
case '@':
case 10:
case 13:
case ' ':
++ts;
break;
case '=':
case '<':
case '>':
case '!':
PyErr_SetString(PyExc_ValueError, "Buffer acquisition error: Only native byte order, size and alignment supported.");
return NULL;
default:
return ts;
}
}
}

static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
PyErr_Format(PyExc_ValueError,
"Buffer has wrong number of dimensions (expected %d, got %d)",
expected_ndim, buffer->ndim);
}

static const char* __Pyx_DescribeTokenInFormatString(const char* ts) {
switch (*ts) {
case 'b': return "char";
case 'B': return "unsigned char";
case 'h': return "short";
case 'H': return "unsigned short";
case 'i': return "int";
case 'I': return "unsigned int";
case 'l': return "long";
case 'L': return "unsigned long";
case 'q': return "long long";
case 'Q': return "unsigned long long";
case 'f': return "float";
case 'd': return "double";
case 'g': return "long double";
case 'Z': switch (*(ts+1)) {
case 'f': return "complex float";
case 'd': return "complex double";
case 'g': return "complex long double";
default: return "unparseable format string";
}
case 'T': return "a struct";
case 'O': return "Python object";
case 'P': return "a pointer";
default: return "unparseable format string";
}
}

""")
# __Pyx_ParseTypestringRepeat(ts, &count): parse an optional decimal repeat
# count at ts (e.g. the "12" in "12i").  Stores it in *out_count (1 when no
# digits are present) and returns a pointer to the first non-digit.
# The digit loop must accept '9' as well (it previously tested < '9', so
# "19i" was parsed as a count of 1 followed by "9i").
parse_typestring_repeat_code = UtilityCode(
proto = """
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
""",
impl = """
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
  int count;
  if (*ts < '0' || *ts > '9') {
    count = 1;
  } else {
    count = *ts++ - '0';
    while (*ts >= '0' && *ts <= '9') {
      count *= 10;
      count += *ts++ - '0';
    }
  }
  *out_count = count;
  return ts;
}
""")
# __Pyx_RaiseBufferFallbackError(): sets a ValueError for the case where
# acquiring the new buffer during an assignment failed and reacquiring the
# previously held buffer failed as well.  Like the other raise helpers it
# only sets the exception; the caller is responsible for jumping to its
# error handling.
raise_buffer_fallback_code = UtilityCode(
proto = """
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""",
impl = """
static void __Pyx_RaiseBufferFallbackError(void) {
PyErr_Format(PyExc_ValueError,
"Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
}

""")