Cython has moved to github.
cython-devel
view Cython/Compiler/Buffer.py @ 1238:d0c4b0a150ca
Buffers: Better error messages + bugfix
| author | Dag Sverre Seljebotn <dagss@student.matnat.uio.no> |
|---|---|
| date | Mon Oct 13 22:11:27 2008 +0200 (3 years ago) |
| parents | c9fc106f9412 |
| children | d53e468baad4 |
line source
1 from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
2 from Cython.Compiler.ModuleNode import ModuleNode
3 from Cython.Compiler.Nodes import *
4 from Cython.Compiler.ExprNodes import *
5 from Cython.Compiler.StringEncoding import EncodedString
6 from Cython.Compiler.Errors import CompileError
7 import Interpreter
8 import PyrexTypes
10 try:
11 set
12 except NameError:
13 from sets import Set as set
15 import textwrap
# Code cleanup ideas:
# - One could be smarter about casting in some places
# - Start using CCodeWriters to generate utility functions
# - Create a struct type per ndim rather than keeping loose local vars
def dedent(text, reindent=0):
    """Remove the common leading whitespace from *text*.

    When *reindent* is positive, every line of the result (empty lines
    included) is afterwards prefixed with that many spaces.
    """
    stripped = textwrap.dedent(text)
    if reindent <= 0:
        return stripped
    prefix = " " * reindent
    return "\n".join([prefix + line for line in stripped.split("\n")])
class IntroduceBufferAuxiliaryVars(CythonTransform):
    """
    Declare the auxiliary C variables every buffer-typed variable needs.

    For each buffer entry in a module or function scope this declares a
    Py_buffer struct plus per-dimension stride and shape variables (and
    suboffset variables in 'full' mode), and attaches them to the entry as
    entry.buffer_aux. If any buffers were seen in the module, the Python 2
    buffer emulation and the empty-bufstruct arrays are registered as
    utility code.
    """

    #
    # Entry point
    #

    # Whether the module currently being processed contains any buffer
    # variables. Kept as a class attribute for backwards compatibility; it
    # is reset per module in __call__.
    buffers_exists = False

    def __call__(self, node):
        assert isinstance(node, ModuleNode)
        # Reset per-module state together: a reused transform instance must
        # not carry "buffers exist" over from a previous module, otherwise
        # buffer utility code (with a max_ndim of 0) would be emitted for a
        # module that has no buffers at all.
        self.max_ndim = 0
        self.buffers_exists = False
        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            use_py2_buffer_functions(node.scope)
            use_empty_bufstruct_code(node.scope, self.max_ndim)
        return result

    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        # For all buffers, insert extra variables in the scope.
        # The variables are also accessible from the buffer_info
        # on the buffer entry
        bufvars = [entry for name, entry
                   in scope.entries.iteritems()
                   if entry.type.is_buffer]
        if bufvars:
            self.buffers_exists = True

        if isinstance(node, ModuleNode) and bufvars:
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")
        for entry in bufvars:
            name = entry.name
            buftype = entry.type
            if buftype.ndim > self.max_ndim:
                self.max_ndim = buftype.ndim

            # Declare auxiliary vars
            cname = scope.mangle(Naming.bufstruct_prefix, name)
            bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
                                        type=PyrexTypes.c_py_buffer_type, pos=node.pos)
            if entry.is_arg:
                bufinfo.used = True # otherwise, NameNode will mark whether it is used

            # Declares one Py_ssize_t auxiliary variable for dimension idx.
            # Called immediately within this iteration, so the closure over
            # entry/name is safe.
            def var(prefix, idx, initval):
                cname = scope.mangle(prefix, "%d_%s" % (idx, name))
                result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
                                           node.pos, cname=cname, is_cdef=True)

                result.init = initval
                if entry.is_arg:
                    result.used = True
                return result

            stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
            shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]
            mode = entry.type.mode
            if mode == 'full':
                # Suboffsets start at -1, i.e. "no indirection".
                suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
            else:
                suboffsetvars = None

            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, suboffsetvars)

        scope.buffer_entries = bufvars
        self.scope = scope

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node
#
# Analysis
#

# Names of the options accepted in a buffer type declaration, in the order
# positional arguments are matched against them (the order matters!).
buffer_options = tuple("dtype ndim mode negative_indices cast".split())
# Values used for options that are not given explicitly; dtype has none.
buffer_defaults = dict(ndim=1, mode="full", negative_indices=True, cast=False)
# Only this many options may be passed positionally; the rest need keywords.
buffer_positional_options_count = 1

# Messages for the CompileErrors raised while analysing buffer options.
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs should consist of a list and a dict
    of tuples (value, pos). Defaults should be a dict of values.

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped).

    If need_complete is true, every option must end up with a value (given
    explicitly or taken from defaults). Otherwise, options that have neither
    are left out of the result, and the validations below skip them.

    Raises CompileError for too many positional options, unknown or
    duplicated options, missing options (when need_complete), and invalid
    dtype/ndim/mode/boolean values.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env)

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    for name, (value, pos) in dictargs.iteritems():
        if name not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name.encode("ASCII")] = value

    # Positional values are matched to buffer_options in order, so their
    # names are always valid; they only need checking against keywords.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if name not in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    # Compare against None rather than relying on truthiness, so falsy but
    # invalid values (e.g. 0.0 or "") are still rejected; ndim == 0 is valid.
    ndim = options.get("ndim")
    if ndim is not None and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode is not None and mode not in ('full', 'strided', 'c', 'fortran'):
        raise CompileError(globalpos, ERR_BUF_MODE)

    def assert_bool(name):
        # An option can only be absent here when need_complete is false
        # (otherwise ERR_BUF_MISSING was raised above); nothing to check then.
        if name in options and not isinstance(options[name], bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    assert_bool('negative_indices')
    assert_bool('cast')

    return options
192 #
193 # Code generation
194 #
def get_flags(buffer_aux, buffer_type):
    """
    Return the PyBUF_* flags (as a C expression) used to acquire a buffer
    of buffer_type.

    PyBUF_FORMAT is always requested; the access mode selects the layout
    flag, and PyBUF_WRITABLE is added when buffer_aux.writable_needed.
    """
    layout_flags = {
        'full': 'PyBUF_INDIRECT',
        'strided': 'PyBUF_STRIDES',
        'c': 'PyBUF_C_CONTIGUOUS',
        'fortran': 'PyBUF_F_CONTIGUOUS',
    }
    parts = ['PyBUF_FORMAT']
    if buffer_type.mode in layout_flags:
        parts.append(layout_flags[buffer_type.mode])
    else:
        assert False
    if buffer_aux.writable_needed:
        parts.append('PyBUF_WRITABLE')
    return '| '.join(parts)
def used_buffer_aux_vars(entry):
    """Mark the Py_buffer struct and all per-dimension auxiliary variables
    of the buffer *entry* as used (suboffsets only if it has any)."""
    aux = entry.buffer_aux
    to_mark = [aux.buffer_info_var]
    to_mark.extend(aux.shapevars)
    to_mark.extend(aux.stridevars)
    to_mark.extend(aux.suboffsetvars or [])
    for var_entry in to_mark:
        var_entry.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
    """
    Emit C code copying the strides and shape (plus suboffsets in 'full'
    mode) out of the Py_buffer struct into the per-dimension local variables.

    One line is emitted per field, holding one assignment per dimension.
    """
    struct_cname = buffer_aux.buffer_info_var.cname
    fields = [("strides", buffer_aux.stridevars),
              ("shape", buffer_aux.shapevars)]
    if mode == 'full':
        fields.append(("suboffsets", buffer_aux.suboffsetvars))
    for field_name, targets in fields:
        assignments = []
        for dim, target in enumerate(targets):
            assignments.append("%s = %s.%s[%d];" % (target.cname, struct_cname, field_name, dim))
        code.putln(" ".join(assignments))
def put_acquire_arg_buffer(entry, code, pos):
    """
    Emit code acquiring the buffer of the function argument *entry*.

    On failure the generated code jumps to the error label (via
    code.error_goto_if); otherwise it unpacks strides/shape (and suboffsets)
    into the argument's auxiliary local variables.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    aux = entry.buffer_aux
    buftype = entry.type
    getbuffer_func = get_getbuffer_code(buftype.dtype, code)

    # Acquire any new buffer
    call = "%s((PyObject*)%s, &%s, %s, %d, %d)" % (
        getbuffer_func,
        entry.cname,
        aux.buffer_info_var.cname,
        get_flags(aux, buftype),
        buftype.ndim,
        int(buftype.cast))
    code.putln(code.error_goto_if(call + " == -1", pos))
    # An exception raised during argument parsing cannot be caught, so
    # there is no need to care about the buffer on that path.
    put_unpack_buffer_aux_into_scope(aux, buftype.mode, code)
253 #def put_release_buffer_normal(entry, code):
254 # code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
255 # entry.cname,
256 # entry.cname,
257 # entry.buffer_aux.buffer_info_var.cname))
def get_release_buffer_code(entry):
    """Return a C expression that releases the buffer held by *entry* via
    __Pyx_SafeReleaseBuffer (a no-op when the buf pointer is NULL)."""
    struct_cname = entry.buffer_aux.buffer_info_var.cname
    return "__Pyx_SafeReleaseBuffer(&%(cname)s)" % {"cname": struct_cname}
def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with getting
    the buffer auxiliary structure and variables set up correctly, the assignment
    itself and refcounting is the responsibility of the caller.

    However, the assignment operation may throw an exception so that the reassignment
    never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs buffer
      (which may or may not succeed).

    lhs_cname/rhs_cname -- C names of the target and source objects
    buffer_aux          -- auxiliary variables of the target buffer entry
    buffer_type         -- the buffer type (dtype, ndim, mode, cast are read)
    is_initialized      -- whether the target may already hold a buffer
    pos                 -- source position for the error goto
    code                -- code writer receiving the generated statements
    """

    code.globalstate.use_utility_code(acquire_utility_code)
    bufstruct = buffer_aux.buffer_info_var.cname
    flags = get_flags(buffer_aux, buffer_type)

    # A getbuffer call template; the object expression is substituted later
    # with "getbuffer % obj_cname" (hence the escaped %%s).
    getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (get_getbuffer_code(buffer_type.dtype, code),
                                          # note: object is filled in later (%%s)
                                          bufstruct,
                                          flags,
                                          buffer_type.ndim,
                                          int(buffer_type.cast))

    if is_initialized:
        # Release any existing buffer
        code.putln('__Pyx_SafeReleaseBuffer(&%s);' % bufstruct)
        # Acquire
        retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
        code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
        code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
        # If acquisition failed, attempt to reacquire the old buffer
        # before raising the exception. A failure of reacquisition
        # will cause the reacquisition exception to be reported, one
        # can consider working around this later.
        code.begin_block()
        # Temps holding the pending exception while the old buffer is
        # reacquired (note: "type" shadows the builtin within this function).
        type, value, tb = [code.funcstate.allocate_temp(PyrexTypes.py_object_type)
                           for i in range(3)]
        code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (type, value, tb))
        code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
        code.begin_block()
        # Reacquisition failed too: drop the saved exception and report
        # the fallback error instead.
        code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (type, value, tb))
        code.globalstate.use_utility_code(raise_buffer_fallback_code)
        code.putln('__Pyx_RaiseBufferFallbackError();')
        # Reacquisition succeeded: restore the original exception.
        code.putln('} else {')
        code.putln('PyErr_Restore(%s, %s, %s);' % (type, value, tb))
        for t in (type, value, tb):
            code.funcstate.release_temp(t)
        # Closes the "} else {" branch opened above.
        code.end_block()
        # Unpack indices
        # Closes the outer "if (retcode < 0)" block.
        code.end_block()
        # Either the new buffer or the reacquired old one is now in the
        # struct; refresh the local strides/shape/suboffsets from it, then
        # propagate the error if the new acquisition had failed.
        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
        code.putln(code.error_goto_if_neg(retcode_cname, pos))
        code.funcstate.release_temp(retcode_cname)
    else:
        # Our entry had no previous value, so set to None when acquisition fails.
        # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
        # so it suffices to set the buf field to NULL.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        code.putln('%s = %s; Py_INCREF(Py_None); %s.buf = NULL;' %
                   (lhs_cname,
                    PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
                    bufstruct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
        code.putln('}')
def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
    """
    Generates code to process indices and calculate an offset into
    a buffer. Returns a C string which gives a pointer which can be
    read from or written to at will (it is an expression so caller should
    store it in a temporary if it is used more than once).

    As the bounds checking can have any number of combinations of unsigned
    arguments, smart optimizations etc. we insert it directly in the function
    body. The lookup however is delegated to an inline function that is instantiated
    once per ndim (lookup with suboffsets tend to get quite complicated).

    entry         -- the buffer variable's entry (entry.buffer_aux, entry.type)
    index_signeds -- one flag per index; 0 means the index is unsigned, so
                     the negative-index handling is skipped for it
    index_cnames  -- C names of the index variables; the generated code may
                     adjust negative indices in place
    options       -- directive dict; only options['boundscheck'] is read
    pos           -- source position used for the error goto
    code          -- code writer receiving the generated statements
    """
    bufaux = entry.buffer_aux
    bufstruct = bufaux.buffer_info_var.cname
    negative_indices = entry.type.negative_indices

    if options['boundscheck']:
        # Check bounds and fix negative indices.
        # We allocate a temporary which is initialized to -1, meaning OK (!).
        # If an error occurs, the temp is set to the dimension index the
        # error is occurring at.
        # (All dimensions are checked; with several failures the temp ends
        # up holding the last failing dimension.)
        tmp_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
        code.putln("%s = -1;" % tmp_cname)
        for dim, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
                                                          bufaux.shapevars)):
            if signed != 0:
                # not unsigned, deal with negative index
                code.putln("if (%s < 0) {" % cname)
                if negative_indices:
                    # Wrap around once; still negative means out of bounds.
                    code.putln("%s += %s;" % (cname, shape.cname))
                    code.putln("if (%s) %s = %d;" % (
                        code.unlikely("%s < 0" % cname), tmp_cname, dim))
                else:
                    code.putln("%s = %d;" % (tmp_cname, dim))
                # The positive-direction check below becomes the else branch.
                code.put("} else ")
            # check bounds in positive direction
            code.putln("if (%s) %s = %d;" % (
                code.unlikely("%s >= %s" % (cname, shape.cname)),
                tmp_cname, dim))
        code.globalstate.use_utility_code(raise_indexerror_code)
        code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
        code.begin_block()
        code.putln('__Pyx_RaiseBufferIndexError(%s);' % tmp_cname)
        code.putln(code.error_goto(pos))
        code.end_block()
        code.funcstate.release_temp(tmp_cname)
    elif negative_indices:
        # Only fix negative indices.
        for signed, cname, shape in zip(index_signeds, index_cnames,
                                        bufaux.shapevars):
            if signed != 0:
                code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape.cname))

    # Create buffer lookup and return it
    # This is done via utility macros/inline functions, which vary
    # according to the access mode used.
    params = []
    nd = entry.type.ndim
    mode = entry.type.mode
    if mode == 'full':
        # Full mode passes (index, stride, suboffset) per dimension.
        for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
            params.append(i)
            params.append(s.cname)
            params.append(o.cname)
        funcname = "__Pyx_BufPtrFull%dd" % nd
        funcgen = buf_lookup_full_code
    else:
        if mode == 'strided':
            funcname = "__Pyx_BufPtrStrided%dd" % nd
            funcgen = buf_lookup_strided_code
        elif mode == 'c':
            funcname = "__Pyx_BufPtrCContig%dd" % nd
            funcgen = buf_lookup_c_code
        elif mode == 'fortran':
            funcname = "__Pyx_BufPtrFortranContig%dd" % nd
            funcgen = buf_lookup_fortran_code
        else:
            assert False
        # The other modes pass (index, stride) per dimension.
        for i, s in zip(index_cnames, bufaux.stridevars):
            params.append(i)
            params.append(s.cname)

    # Make sure the utility code is available
    code.globalstate.use_code_from(funcgen, name=funcname, nd=nd)

    ptr_type = entry.type.buffer_ptr_type
    ptrcode = "%s(%s, %s.buf, %s)" % (funcname,
                                      ptr_type.declaration_code(""),
                                      bufstruct,
                                      ", ".join(params))
    return ptrcode
def use_empty_bufstruct_code(env, max_ndim):
    """
    Register the module-level C arrays __Pyx_zeros and __Pyx_minusones.

    __Pyx_ZeroBuffer points the strides/shape of an empty buffer at
    __Pyx_zeros and its suboffsets at __Pyx_minusones, and the getbuffer
    helpers use __Pyx_minusones when an exporter returns no suboffsets.

    The arrays hold max_ndim entries, but always at least one: an array
    with an empty initializer list ("= {}") is not valid C. Without the
    minimum, a module whose buffers are all zero-dimensional would emit one.
    """
    size = max(max_ndim, 1)
    code = ("\n"
            "Py_ssize_t __Pyx_zeros[] = {%s};\n"
            "Py_ssize_t __Pyx_minusones[] = {%s};\n") % (
        ", ".join(["0"] * size), ", ".join(["-1"] * size))
    env.use_utility_code([code, ""], "empty_bufstruct_code")
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.

    Used for 'full' (indirect) mode. proto receives a macro
    name(type, buf, i0, s0, o0, ...) that casts the result to `type`, plus
    the prototype of the inline helper <name>_imp. defin receives the
    helper's body: for each dimension it advances the pointer by
    stride * index, then, if the suboffset is non-negative, dereferences the
    pointer and adds the suboffset.
    """
    # _i_ndex, _s_tride, sub_o_ffset
    macroargs = ", ".join(["i%d, s%d, o%d" % (i, i, i) for i in range(nd)])
    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))

    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
    proto.putln("static INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))
    # Function header, then one stride/suboffset step per dimension, then return.
    defin.putln(dedent("""
static INLINE void* %s_imp(void* buf, %s) {
char* ptr = (char*)buf;
""") % (name, funcargs) + "".join([dedent("""\
ptr += s%d * i%d;
if (o%d >= 0) ptr = *((char**)ptr) + o%d;
""") % (i, i, i, i) for i in range(nd)]
        ) + "\nreturn ptr;\n}")
def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Emit (into proto) a lookup macro for an nd-dimensional strided buffer.

    The macro adds sum(index * stride) bytes to the buffer pointer and casts
    the result to the requested pointer type. defin is not used.
    """
    dims = range(nd)
    # _i_ndex, _s_tride
    arglist = ", ".join(["i%d, s%d" % (d, d) for d in dims])
    byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in dims])
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)"
                % (name, arglist, byte_offset))
def buf_lookup_c_code(proto, defin, name, nd):
    """
    Emit (into proto) a lookup macro for a C-contiguous buffer.

    Like the strided lookup, except that the last dimension is contiguous.
    Its index is therefore applied as typed pointer arithmetic instead of
    being multiplied by its stride. The macro keeps the strided signature
    (every stride is accepted) for now. defin is not used.
    """
    if nd != 1:
        last = nd - 1
        arglist = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
        byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(last)])
        proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)"
                    % (name, arglist, byte_offset, last))
        return
    proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Emit (into proto) a lookup macro for a Fortran-contiguous buffer.

    This mirrors the C-contiguous lookup, but here the *first* dimension is
    contiguous. i0 is applied as typed pointer arithmetic, and only the
    remaining dimensions are multiplied by their strides. defin is not used.
    """
    if nd != 1:
        arglist = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
        byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(1, nd)])
        proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)"
                    % (name, arglist, byte_offset, 0))
        return
    proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
491 #
492 # Utils for creating type string checkers
493 #
def mangle_dtype_name(dtype):
    """
    Return an identifier-friendly name for *dtype*, used to name the
    per-dtype utility functions.

    Python objects map to "object" and pointers to "ptr". Any other type
    uses its C declaration with spaces replaced by underscores. Types
    without a known typestring get an "nn_" prefix, which separates
    user-defined types from builtins (consider "typedef float unsigned_int").
    """
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    prefix = ""
    if dtype.typestring is None:
        prefix = "nn_"
    return prefix + "_".join(dtype.declaration_code("").split(" "))
def get_typestringchecker(code, dtype):
    """
    Return the C name of the format-string checker for *dtype*, asking the
    global code state to emit it (via create_typestringchecker) if needed.
    """
    mangled = mangle_dtype_name(dtype)
    checker_name = "__Pyx_CheckTypestring_%s" % (mangled,)
    code.globalstate.use_code_from(create_typestringchecker, checker_name, dtype=dtype)
    return checker_name
def create_typestringchecker(protocode, defcode, name, dtype):
    """
    Emit the C function `name`, which checks a buffer format string against
    dtype.

    The generated function takes a pointer into the format string and
    returns a pointer just past the part it consumed. On a mismatch it sets
    a ValueError and returns NULL. Three shapes of checker are generated:

    - simple dtypes: one format character, optionally preceded by a '1'
      repeat count;
    - structs that may represent a complex number: either a 'Z' code or
      the two fields in sequence;
    - other structs: field by field, where consecutive fields of equal type
      are grouped under one repeat count and checked via their own
      (recursively created) checkers.
    """

    # Emits a C check that, when cond is false, sets a "Buffer dtype
    # mismatch" ValueError describing the current token and returns NULL.
    def put_assert(cond, msg):
        defcode.putln("if (!(%s)) {" % cond)
        defcode.putln('PyErr_Format(PyExc_ValueError, "Buffer dtype mismatch (%s)", __Pyx_DescribeTokenInFormatString(ts));' % msg)
        defcode.putln("return NULL;")
        defcode.putln("}")

    if dtype.is_error: return
    simple = dtype.is_simple_buffer_dtype()
    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
    # Cannot add utility code recursively...
    if not simple:
        dtype_t = dtype.declaration_code("")
        protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
        fields = dtype.scope.var_entries

        # divide fields into blocks of equal type (for repeat count)
        # NOTE(review): if `fields` is empty the loop never binds `f` and the
        # final append raises NameError -- confirm empty structs cannot get here.
        field_blocks = [] # of (n, type, checkerfunc)
        n = 0
        prevtype = None
        for f in fields:
            if n and f.type != prevtype:
                field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype)))
                n = 0
            prevtype = f.type
            n += 1
        field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type)))

    protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
    defcode.putln("static const char* %s(const char* ts) {" % name)
    if simple:
        defcode.putln("int ok;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        if dtype.typestring is not None:
            assert len(dtype.typestring) == 1
            # Can use direct comparison
            defcode.putln("ok = (*ts == '%s');" % dtype.typestring)
        else:
            # Cannot trust declared size; but rely on int vs float and
            # signed/unsigned to be correctly declared. Use a switch statement
            # on all possible format codes to validate that the size is ok.
            # (Note that many codes may map to same size, e.g. 'i' and 'l'
            # may both be four bytes).
            ctype = dtype.declaration_code("")
            defcode.putln("switch (*ts) {")
            if dtype.is_int:
                types = [
                    ('b', 'char'), ('h', 'short'), ('i', 'int'),
                    ('l', 'long'), ('q', 'long long')
                ]
            elif dtype.is_float:
                types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
            else:
                assert False
            # Unsigned types match the upper-case codes; the "(T)-1 > 0" /
            # "(T)-1 < 0" test additionally checks the signedness of ctype.
            if dtype.signed == 0:
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
                                  (char.upper(), ctype, against, ctype))
            else:
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
                                  (char, ctype, against, ctype))
            defcode.putln("default: ok = 0;")
            defcode.putln("}")
        put_assert("ok", "expected %s, got %%s" % dtype)
        defcode.putln("++ts;")
    elif complex_possible:
        # Could be a struct representing a complex number, so allow
        # for parsing a "Zf" spec.
        real_t, imag_t = [x.type for x in fields]
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        defcode.putln("if (*ts == 'Z') {")
        if len(field_blocks) == 2:
            # Different float type, sizeof check needed
            defcode.putln("if (sizeof(%s) != sizeof(%s)) {" % (
                real_t.declaration_code(""),
                imag_t.declaration_code("")))
            defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
                dtype, real_t, imag_t))
            defcode.putln("return NULL;")
            defcode.putln("}")
            check_real, check_imag = [x[2] for x in field_blocks]
        else:
            assert len(field_blocks) == 1
            check_real = check_imag = field_blocks[0][2]
        # "Z<c>": skip the 'Z' and check the element code with the real
        # part's checker.
        defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % check_real)
        defcode.putln("} else {")
        # Otherwise expect the two fields one after the other.
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_real)
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_imag)
        defcode.putln("}")
    else:
        defcode.putln("int n, count;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

        # The type expected after each block, used in "too many" messages.
        next_types = [x[1] for x in field_blocks[1:]] + ["end"]
        for (n, type, checker), next_type in zip(field_blocks, next_types):
            if n == 1:
                defcode.putln("if (*ts == '1') ++ts;")
            else:
                # Consume repeat counts until exactly n fields are covered.
                defcode.putln("n = %d;" % n);
                defcode.putln("do {")
                defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
                put_assert("n >= 0", "expected %s, got %%s" % next_type)

            simple = type.is_simple_buffer_dtype()
            if not simple:
                # Nested structs appear as "T{...}" in the format string.
                put_assert("*ts == 'T' && *(ts+1) == '{'", "expected %s, got %%s" % type)
                defcode.putln("ts += 2;")
            defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
            if not simple:
                put_assert("*ts == '}'", "expected end of %s struct, got %%s" % type)
                defcode.putln("++ts;")

            if n > 1:
                defcode.putln("} while (n > 0);");
            defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

    defcode.putln("return ts;")
    defcode.putln("}")
def get_getbuffer_code(dtype, code):
    """
    Generate a utility function for getting a buffer for the given dtype.
    The function will:
    - Call PyObject_GetBuffer
    - Check that ndim matched the expected value
    - Check that the format string is right
    - Set suboffsets to all -1 if it is returned as NULL.

    A None object yields a zeroed buffer and success. If a check fails after
    the buffer was acquired, the buffer is released again before -1 is
    returned. This way the exporter's export and the reference held in
    buf->obj are not leaked. On any failure the struct is left zeroed
    (buf->buf == NULL), so a later __Pyx_SafeReleaseBuffer is a no-op.

    Returns the C name of the generated function, which is emitted only
    once per dtype.
    """

    name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
    if not code.globalstate.has_code(name):
        code.globalstate.use_utility_code(acquire_utility_code)
        typestringchecker = get_typestringchecker(code, dtype)
        dtype_cname = dtype.declaration_code("")
        utilcode = [dedent("""
            static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
        """) % name, dedent("""
            static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
              const char* ts;
              if (obj == Py_None) {
                __Pyx_ZeroBuffer(buf);
                return 0;
              }
              buf->buf = NULL;
              if (__Pyx_GetBuffer(obj, buf, flags) == -1) goto fail;
              /* The buffer is acquired from here on; failures must release it. */
              if (buf->ndim != nd) {
                __Pyx_BufferNdimError(buf, nd);
                goto fail_release;
              }
              if (!cast) {
                ts = buf->format;
                ts = __Pyx_ConsumeWhitespace(ts);
                if (!ts) goto fail_release;
                ts = %(typestringchecker)s(ts);
                if (!ts) goto fail_release;
                ts = __Pyx_ConsumeWhitespace(ts);
                if (!ts) goto fail_release;
                if (*ts != 0) {
                  PyErr_Format(PyExc_ValueError,
                    "Buffer dtype mismatch (expected end, got %%s)",
                    __Pyx_DescribeTokenInFormatString(ts));
                  goto fail_release;
                }
              } else {
                if (buf->itemsize != sizeof(%(dtype_cname)s)) {
                  PyErr_SetString(PyExc_ValueError,
                    "Attempted cast of buffer to datatype of different size.");
                  goto fail_release;
                }
              }
              if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
              return 0;
            fail_release:;
              __Pyx_ReleaseBuffer(buf);
            fail:;
              __Pyx_ZeroBuffer(buf);
              return -1;
            }""") % {"name": name,
                      "typestringchecker": typestringchecker,
                      "dtype_cname": dtype_cname}]
        code.globalstate.use_utility_code(utilcode, name)
    return name
def buffer_type_checker(dtype, code):
    """
    Return the name of the getbuffer/type-checking function for dtype,
    emitting it if necessary. Only int and float dtypes (including simple
    typedefs of them) are supported; anything else is an internal error.
    """
    supported = (not dtype.is_struct_or_union) and (dtype.is_int or dtype.is_float)
    assert supported
    return get_getbuffer_code(dtype, code)
def use_py2_buffer_functions(env):
    """
    Register the __Pyx_GetBuffer / __Pyx_ReleaseBuffer utility code.

    Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
    For >= 2.6 we do double mode -- use the new buffer interface on objects
    which have the right tp_flags set, but emulation otherwise. Under
    Python 3 both names are simply #defined to the real API.

    The emulation dispatches, via PyObject_TypeCheck, to the __getbuffer__ /
    __releasebuffer__ implementations of extension types found in env and,
    recursively, in its cimported modules.

    Release uses exactly the same Py_TPFLAGS_HAVE_NEWBUFFER test as
    acquisition. A buffer obtained through PyObject_GetBuffer is therefore
    always given back through PyBuffer_Release, which runs the exporter's
    release hook, instead of only having its object reference dropped.
    """
    codename = "PyObject_GetBuffer" # just a representative unique key

    # Search all types for __getbuffer__ overloads
    types = []
    def find_buffer_types(scope):
        for m in scope.cimported_modules:
            find_buffer_types(m)
        for e in scope.type_entries:
            t = e.type
            if t.is_extension_type:
                release = get = None
                for x in t.scope.pyfunc_entries:
                    if x.name == u"__getbuffer__": get = x.func_cname
                    elif x.name == u"__releasebuffer__": release = x.func_cname
                if get:
                    types.append((t.typeptr_cname, get, release))

    find_buffer_types(env)

    code = dedent("""
        #if PY_MAJOR_VERSION < 3
        static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
        #if PY_VERSION_HEX >= 0x02060000
            if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER)
                return PyObject_GetBuffer(obj, view, flags);
        #endif
    """)
    if len(types) > 0:
        clause = "if"
        for t, get, release in types:
            code += "  %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
            clause = "else if"
        code += "  else {\n"
    code += dedent("""\
        PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
        return -1;
    """, 2)
    if len(types) > 0: code += "  }"
    code += dedent("""
        }

        static void __Pyx_ReleaseBuffer(Py_buffer *view) {
            PyObject* obj = view->obj;
            if (obj) {
        #if PY_VERSION_HEX >= 0x02060000
                /* Mirror __Pyx_GetBuffer: buffers acquired through the new
                   buffer interface must be released through it as well. */
                if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER) {
                    PyBuffer_Release(view);
                    return;
                }
        #endif
    """)
    if len(types) > 0:
        clause = "if"
        for t, get, release in types:
            # Types without __releasebuffer__ need nothing beyond the DECREF
            # below; only types that have one take part in the if/else chain.
            if release:
                code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);\n" % (clause, t, release)
                clause = "else if"
    code += dedent("""
                Py_DECREF(obj);
                view->obj = NULL;
            }
        }

        #endif
    """)

    env.use_utility_code([dedent("""\
        #if PY_MAJOR_VERSION < 3
        static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
        static void __Pyx_ReleaseBuffer(Py_buffer *view);
        #else
        #define __Pyx_GetBuffer PyObject_GetBuffer
        #define __Pyx_ReleaseBuffer PyBuffer_Release
        #endif
    """), code], codename)
789 #
790 # Static utility code
791 #
# Utility function to set the right exception
# The caller should immediately goto_error
# [prototype, definition] pair of C source for __Pyx_RaiseBufferIndexError,
# which sets an IndexError naming the out-of-bounds axis.
raise_indexerror_code = [
"""\
static void __Pyx_RaiseBufferIndexError(int axis); /*proto*/
""","""\
static void __Pyx_RaiseBufferIndexError(int axis) {
PyErr_Format(PyExc_IndexError,
"Out of bounds on buffer access (axis %d)", axis);
}

"""]
#
# Buffer type checking. Utility code for checking that acquired
# buffers match our assumptions. We only need to check ndim and
# the format string; the access mode/flags are checked by the
# exporter.
#
# acquire_utility_code is a [prototypes, definitions] pair of C source:
#  - __Pyx_SafeReleaseBuffer: releases a buffer unless its buf pointer is
#    NULL, first undoing the __Pyx_minusones substitution for suboffsets.
#  - __Pyx_ZeroBuffer: resets a Py_buffer to the "no buffer" state
#    (NULL buf/obj, strides/shape -> __Pyx_zeros, suboffsets -> __Pyx_minusones).
#  - __Pyx_ConsumeWhitespace: skips '@', LF, CR and spaces in a format
#    string; sets a ValueError and returns NULL on '=', '<', '>' or '!'
#    (only native byte order/size/alignment is supported).
#  - __Pyx_BufferNdimError: sets the ValueError for an ndim mismatch.
#  - __Pyx_DescribeTokenInFormatString: human-readable name of the format
#    token at ts, used in "Buffer dtype mismatch" messages.
acquire_utility_code = ["""\
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static const char* __Pyx_DescribeTokenInFormatString(const char* ts); /*proto*/
""", """
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
if (info->buf == NULL) return;
if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
__Pyx_ReleaseBuffer(info);
}

static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
buf->buf = NULL;
buf->obj = NULL;
buf->strides = __Pyx_zeros;
buf->shape = __Pyx_zeros;
buf->suboffsets = __Pyx_minusones;
}

static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts) {
while (1) {
switch (*ts) {
case '@':
case 10:
case 13:
case ' ':
++ts;
break;
case '=':
case '<':
case '>':
case '!':
PyErr_SetString(PyExc_ValueError, "Buffer acquisition error: Only native byte order, size and alignment supported.");
return NULL;
default:
return ts;
}
}
}

static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
PyErr_Format(PyExc_ValueError,
"Buffer has wrong number of dimensions (expected %d, got %d)",
expected_ndim, buffer->ndim);
}

static const char* __Pyx_DescribeTokenInFormatString(const char* ts) {
switch (*ts) {
case 'b': return "char";
case 'B': return "unsigned char";
case 'h': return "short";
case 'H': return "unsigned short";
case 'i': return "int";
case 'I': return "unsigned int";
case 'l': return "long";
case 'L': return "unsigned long";
case 'q': return "long long";
case 'Q': return "unsigned long long";
case 'f': return "float";
case 'd': return "double";
case 'g': return "long double";
case 'Z': switch (*(ts+1)) {
case 'f': return "complex float";
case 'd': return "complex double";
case 'g': return "complex long double";
default: return "unparseable format string";
}
case 'T': return "a struct";
case 'O': return "Python object";
case 'P': return "a pointer";
default: return "unparseable format string";
}
}

"""]
# [prototype, definition] pair of C source for __Pyx_ParseTypestringRepeat.
# It parses an optional decimal repeat count at ts (e.g. the "19" in "19f"),
# stores it in *out_count (1 when there are no digits), and returns ts
# advanced past the digits. The digit loop must accept '9' too; a strict
# '< 9' comparison used to stop at any later 9, mis-parsing counts like "19".
parse_typestring_repeat_code = ["""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
""","""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
int count;
if (*ts < '0' || *ts > '9') {
count = 1;
} else {
count = *ts++ - '0';
while (*ts >= '0' && *ts <= '9') {
count *= 10;
count += *ts++ - '0';
}
}
*out_count = count;
return ts;
}
"""]
# [prototype, definition] pair of C source for __Pyx_RaiseBufferFallbackError.
# It sets the ValueError that put_assign_to_buffer's generated code reports
# when acquiring the new buffer fails and reacquiring the old one fails too.
raise_buffer_fallback_code = ["""
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""","""
static void __Pyx_RaiseBufferFallbackError(void) {
PyErr_Format(PyExc_ValueError,
"Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
}

"""]
