Cython has moved to github.
cython
view Cython/Compiler/Buffer.py @ 1587:cdf889c30e7a
Better error message for the lack of pointer buffer support.
| author | Dag Sverre Seljebotn <dagss@student.matnat.uio.no> |
|---|---|
| date | Tue Dec 16 10:09:52 2008 +0100 (3 years ago) |
| parents | 4a96e1aff2d4 |
| children | 136bd9fa09a7 |
line source
1 from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
2 from Cython.Compiler.ModuleNode import ModuleNode
3 from Cython.Compiler.Nodes import *
4 from Cython.Compiler.ExprNodes import *
5 from Cython.Compiler.StringEncoding import EncodedString
6 from Cython.Compiler.Errors import CompileError
7 from Cython.Utils import UtilityCode
8 import Interpreter
9 import PyrexTypes
11 try:
12 set
13 except NameError:
14 from sets import Set as set
16 import textwrap
18 # Code cleanup ideas:
# - One could be smarter about casting in some places
20 # - Start using CCodeWriters to generate utility functions
21 # - Create a struct type per ndim rather than keeping loose local vars
def dedent(text, reindent=0):
    """Remove common leading whitespace from *text*.

    When *reindent* is positive, every resulting line (empty ones included)
    is then prefixed with that many spaces.
    """
    stripped = textwrap.dedent(text)
    if reindent <= 0:
        return stripped
    prefix = " " * reindent
    return "\n".join(prefix + line for line in stripped.split("\n"))
class IntroduceBufferAuxiliaryVars(CythonTransform):
    """
    Transform that gives every buffer variable its auxiliary C variables.

    For each buffer-typed entry of a visited scope it declares a Py_buffer
    struct variable plus one Py_ssize_t variable per dimension for the
    strides and the shape (and, in "full" mode, the suboffsets). These are
    stored on the entry as ``entry.buffer_aux``. If any buffer was seen,
    the Python 2 buffer emulation and the shared empty-bufstruct arrays are
    registered as utility code on the module scope.
    """

    #
    # Entry point
    #

    # Set to True on the instance once any buffer variable has been found.
    buffers_exists = False

    def __call__(self, node):
        assert isinstance(node, ModuleNode)
        # Largest ndim of any buffer seen; sizes the shared empty arrays.
        self.max_ndim = 0
        transformed = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            use_py2_buffer_functions(node.scope)
            use_empty_bufstruct_code(node.scope, self.max_ndim)
        return transformed

    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        """
        Insert the auxiliary variables of every buffer entry of *scope*.

        The variables are also reachable through ``entry.buffer_aux``.
        Raises CompileError for buffers at module scope and for buffers
        whose dtype is a pointer.
        """
        buffer_entries = [entry for _, entry in scope.entries.iteritems()
                          if entry.type.is_buffer]
        if buffer_entries:
            self.buffers_exists = True
            if isinstance(node, ModuleNode):
                # for now...note that pos is wrong
                raise CompileError(node.pos, "Buffer vars not allowed in module scope")

        for entry in buffer_entries:
            if entry.type.dtype.is_ptr:
                raise CompileError(node.pos, "Buffers with pointer types not yet supported.")
            self._declare_buffer_aux(node, scope, entry)

        scope.buffer_entries = buffer_entries
        self.scope = scope

    def _declare_buffer_aux(self, node, scope, entry):
        """Declare the bufstruct and per-dimension variables for one buffer entry."""
        name = entry.name
        buftype = entry.type
        ndim = buftype.ndim
        self.max_ndim = max(self.max_ndim, ndim)

        # The Py_buffer struct itself.
        struct_cname = scope.mangle(Naming.bufstruct_prefix, name)
        bufinfo = scope.declare_var(name="$%s" % struct_cname, cname=struct_cname,
                                    type=PyrexTypes.c_py_buffer_type, pos=node.pos)
        if entry.is_arg:
            bufinfo.used = True  # otherwise, NameNode will mark whether it is used

        def declare_index_var(prefix, idx, initval):
            # One Py_ssize_t local per dimension, e.g. a stride or shape slot.
            var_cname = scope.mangle(prefix, "%d_%s" % (idx, name))
            var_entry = scope.declare_var("$%s" % var_cname, PyrexTypes.c_py_ssize_t_type,
                                          node.pos, cname=var_cname, is_cdef=True)
            var_entry.init = initval
            if entry.is_arg:
                var_entry.used = True
            return var_entry

        dims = range(ndim)
        stridevars = [declare_index_var(Naming.bufstride_prefix, i, "0") for i in dims]
        shapevars = [declare_index_var(Naming.bufshape_prefix, i, "0") for i in dims]
        if buftype.mode == 'full':
            suboffsetvars = [declare_index_var(Naming.bufsuboffset_prefix, i, "-1")
                             for i in dims]
        else:
            suboffsetvars = None

        entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, suboffsetvars)

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node
116 #
117 # Analysis
118 #
# Names of the options a buffer type accepts. The tuple order matters:
# analyse_buffer_options matches positional arguments against it by
# position (via zip).
buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered!
# Values used for options that were not supplied. "dtype" has no default,
# so it is reported as missing whenever a complete option set is required.
buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False}
# Only this many leading options (i.e. just "dtype") may be given positionally.
buffer_positional_options_count = 1 # anything beyond this needs keyword argument

# Messages for the CompileErrors raised by analyse_buffer_options; those
# containing %s are formatted with the offending option name.
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Validate and normalise the options given to a buffer type.

    Must be called during type analysis, because the dtype argument is
    analysed as a type here.

    posargs is a list and dictargs a dict whose values are (value, pos)
    tuples. defaults is a plain dict of option values (buffer_defaults
    when not given). With need_complete, every option must end up with a
    value, either supplied or defaulted.

    Returns a dict mapping each option name to its value, with the
    positions stripped. Raises CompileError for too many positional
    options, unknown, duplicate or missing options, and ill-typed values.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env)

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    # Keyword options first...
    for key, (value, pos) in dictargs.iteritems():
        if key not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % key)
        options[key.encode("ASCII")] = value

    # ...then positional ones, which must not repeat a keyword option.
    for key, (value, pos) in zip(buffer_options, posargs):
        if key not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % key)
        if key in options:
            raise CompileError(pos, ERR_BUF_DUP % key)
        options[key] = value

    # Check that they are all there and copy defaults
    for key in buffer_options:
        if key in options:
            continue
        if key in defaults:
            options[key] = defaults[key]
        elif need_complete:
            raise CompileError(globalpos, ERR_BUF_MISSING % key)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    ndim = options.get("ndim")
    if ndim and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode and mode not in ('full', 'strided', 'c', 'fortran'):
        raise CompileError(globalpos, ERR_BUF_MODE)

    for flag_name in ('negative_indices', 'cast'):
        if not isinstance(options.get(flag_name), bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % flag_name)

    return options
196 #
197 # Code generation
198 #
def get_flags(buffer_aux, buffer_type):
    """Return the C expression of PyBUF_* flags used to request a buffer.

    The access mode of *buffer_type* selects the layout flag, and
    PyBUF_WRITABLE is added when ``buffer_aux.writable_needed`` is set.
    """
    mode_flags = {
        'full': 'PyBUF_INDIRECT',
        'strided': 'PyBUF_STRIDES',
        'c': 'PyBUF_C_CONTIGUOUS',
        'fortran': 'PyBUF_F_CONTIGUOUS',
    }
    parts = ['PyBUF_FORMAT']
    try:
        parts.append(mode_flags[buffer_type.mode])
    except KeyError:
        assert False
    if buffer_aux.writable_needed:
        parts.append('PyBUF_WRITABLE')
    return '| '.join(parts)
def used_buffer_aux_vars(entry):
    """Set ``used = True`` on the bufstruct variable and every shape, stride
    and (if present) suboffset variable of a buffer entry."""
    aux = entry.buffer_aux
    aux_entries = [aux.buffer_info_var]
    aux_entries.extend(aux.shapevars)
    aux_entries.extend(aux.stridevars)
    if aux.suboffsetvars:
        aux_entries.extend(aux.suboffsetvars)
    for aux_entry in aux_entries:
        aux_entry.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
    """Emit C statements copying the Py_buffer's per-dimension info into locals.

    One line is emitted for strides and one for shape, plus one for
    suboffsets in "full" mode. Each line assigns every dimension's local
    variable from ``<bufstruct>.<field>[idx]``.
    """
    bufstruct = buffer_aux.buffer_info_var.cname
    fields = [("strides", buffer_aux.stridevars),
              ("shape", buffer_aux.shapevars)]
    if mode == 'full':
        fields.append(("suboffsets", buffer_aux.suboffsetvars))

    for field_name, aux_vars in fields:
        assignments = []
        for idx, var_entry in enumerate(aux_vars):
            assignments.append("%s = %s.%s[%d];" % (var_entry.cname, bufstruct, field_name, idx))
        code.putln(" ".join(assignments))
def put_acquire_arg_buffer(entry, code, pos):
    """Emit code acquiring the buffer of a function argument.

    A failed acquisition jumps to the error label for *pos*. On success,
    the per-dimension auxiliary variables are filled from the acquired
    Py_buffer.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    buffer_aux = entry.buffer_aux
    buftype = entry.type
    getbuffer_cname = get_getbuffer_code(buftype.dtype, code)

    # Acquire any new buffer
    acquire_call = "%s((PyObject*)%s, &%s, %s, %d, %d)" % (
        getbuffer_cname,
        entry.cname,
        buffer_aux.buffer_info_var.cname,
        get_flags(buffer_aux, buftype),
        buftype.ndim,
        int(buftype.cast))
    code.putln(code.error_goto_if("%s == -1" % acquire_call, pos))
    # An exception raised in argument parsing cannot be caught, so there
    # is no need to care about releasing the buffer on that path.
    put_unpack_buffer_aux_into_scope(buffer_aux, buftype.mode, code)
257 #def put_release_buffer_normal(entry, code):
258 # code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
259 # entry.cname,
260 # entry.cname,
261 # entry.buffer_aux.buffer_info_var.cname))
def get_release_buffer_code(entry):
    """Return a C expression releasing the buffer held by *entry*'s bufstruct."""
    bufstruct_var = entry.buffer_aux.buffer_info_var
    return "__Pyx_SafeReleaseBuffer(&%s)" % (bufstruct_var.cname,)
def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with
    getting the buffer auxiliary structure and variables set up correctly;
    the assignment itself and refcounting are the responsibility of the
    caller.

    However, the assignment operation may throw an exception, so that the
    reassignment never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs buffer
      (which may or may not succeed).

    When is_initialized is false there is no old buffer. A failed
    acquisition then sets lhs to None (taking a new reference), NULLs the
    buf field and jumps to the error label.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    bufstruct = buffer_aux.buffer_info_var.cname
    flags = get_flags(buffer_aux, buffer_type)

    # Acquisition call template; the object expression is substituted
    # later through the remaining %s.
    getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (
        get_getbuffer_code(buffer_type.dtype, code),
        bufstruct,
        flags,
        buffer_type.ndim,
        int(buffer_type.cast))

    if not is_initialized:
        # No previous value: on failure fall back to None. The auxiliary vars
        # were initialised to a zero-buffer, so NULLing buf suffices.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        none_value = PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None")
        code.putln('%s = %s; Py_INCREF(Py_None); %s.buf = NULL;' %
                   (lhs_cname, none_value, bufstruct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
        code.putln('}')
        return

    # Release any existing buffer, then try to acquire the new one.
    code.putln('__Pyx_SafeReleaseBuffer(&%s);' % bufstruct)
    retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
    code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
    code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
    # On failure, stash the pending exception and reacquire the old lhs
    # buffer. If reacquisition fails too, a fallback error replaces the
    # original exception; otherwise the original exception is restored.
    code.begin_block()
    exc_temps = [code.funcstate.allocate_temp(PyrexTypes.py_object_type)
                 for _ in range(3)]
    exc_args = tuple(exc_temps)
    code.putln('PyErr_Fetch(&%s, &%s, &%s);' % exc_args)
    code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
    code.begin_block()
    code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % exc_args)
    code.globalstate.use_utility_code(raise_buffer_fallback_code)
    code.putln('__Pyx_RaiseBufferFallbackError();')
    code.putln('} else {')
    code.putln('PyErr_Restore(%s, %s, %s);' % exc_args)
    for temp in exc_temps:
        code.funcstate.release_temp(temp)
    code.end_block()
    code.end_block()
    # Unpack indices (from whichever buffer is now held), then propagate
    # the failure of the original acquisition, if any.
    put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
    code.putln(code.error_goto_if_neg(retcode_cname, pos))
    code.funcstate.release_temp(retcode_cname)
def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
    """
    Generates code to process indices and calculate an offset into
    a buffer. Returns a C string which gives a pointer which can be
    read from or written to at will (it is an expression so caller should
    store it in a temporary if it is used more than once).

    As the bounds checking can have any number of combinations of unsigned
    arguments, smart optimizations etc. we insert it directly in the function
    body. The lookup however is delegated to an inline function that is instantiated
    once per ndim (lookup with suboffsets tend to get quite complicated).

    Parameters (as used below):
    - entry: the buffer variable's entry; entry.buffer_aux supplies the
      bufstruct, shape, stride and suboffset variables, and entry.type the
      ndim, mode, negative_indices and buffer_ptr_type.
    - index_signeds: per-index signedness; an index whose value is 0 is
      treated as unsigned and never tested for being negative.
    - index_cnames: C names of the index variables. The emitted code may
      modify them in place (negative indices get the dimension's shape
      added to them).
    - options: directive dict; only options['boundscheck'] is read.
    - pos: source position used for the error goto.
    - code: the code writer the index checks are emitted into.
    """
    bufaux = entry.buffer_aux
    bufstruct = bufaux.buffer_info_var.cname
    negative_indices = entry.type.negative_indices

    if options['boundscheck']:
        # Check bounds and fix negative indices.
        # We allocate a temporary which is initialized to -1, meaning OK (!).
        # If an error occurs, the temp is set to the dimension index the
        # error is occurring at (when several dimensions fail, the last
        # failing one is reported).
        tmp_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
        code.putln("%s = -1;" % tmp_cname)
        for dim, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
                                                          bufaux.shapevars)):
            if signed != 0:
                # not unsigned, deal with negative index
                code.putln("if (%s < 0) {" % cname)
                if negative_indices:
                    # Wrap around once; still negative means out of bounds.
                    code.putln("%s += %s;" % (cname, shape.cname))
                    code.putln("if (%s) %s = %d;" % (
                        code.unlikely("%s < 0" % cname), tmp_cname, dim))
                else:
                    # Negative indexing disabled: any negative index is an error.
                    code.putln("%s = %d;" % (tmp_cname, dim))
                # The upper-bound check emitted next becomes the else branch.
                code.put("} else ")
            # check bounds in positive direction
            code.putln("if (%s) %s = %d;" % (
                code.unlikely("%s >= %s" % (cname, shape.cname)),
                tmp_cname, dim))
        code.globalstate.use_utility_code(raise_indexerror_code)
        code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
        code.begin_block()
        code.putln('__Pyx_RaiseBufferIndexError(%s);' % tmp_cname)
        code.putln(code.error_goto(pos))
        code.end_block()
        code.funcstate.release_temp(tmp_cname)
    elif negative_indices:
        # Only fix negative indices.
        for signed, cname, shape in zip(index_signeds, index_cnames,
                                        bufaux.shapevars):
            if signed != 0:
                code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape.cname))

    # Create buffer lookup and return it
    # This is done via utility macros/inline functions, which vary
    # according to the access mode used.
    params = []
    nd = entry.type.ndim
    mode = entry.type.mode
    if mode == 'full':
        # Each dimension contributes (index, stride, suboffset) arguments.
        for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
            params.append(i)
            params.append(s.cname)
            params.append(o.cname)
        funcname = "__Pyx_BufPtrFull%dd" % nd
        funcgen = buf_lookup_full_code
    else:
        if mode == 'strided':
            funcname = "__Pyx_BufPtrStrided%dd" % nd
            funcgen = buf_lookup_strided_code
        elif mode == 'c':
            funcname = "__Pyx_BufPtrCContig%dd" % nd
            funcgen = buf_lookup_c_code
        elif mode == 'fortran':
            funcname = "__Pyx_BufPtrFortranContig%dd" % nd
            funcgen = buf_lookup_fortran_code
        else:
            assert False
        # Each dimension contributes (index, stride) arguments.
        for i, s in zip(index_cnames, bufaux.stridevars):
            params.append(i)
            params.append(s.cname)

    # Make sure the utility code is available
    code.globalstate.use_code_from(funcgen, name=funcname, nd=nd)

    ptr_type = entry.type.buffer_ptr_type
    ptrcode = "%s(%s, %s.buf, %s)" % (funcname,
                                     ptr_type.declaration_code(""),
                                     bufstruct,
                                     ", ".join(params))
    return ptrcode
def use_empty_bufstruct_code(env, max_ndim):
    """Register the shared __Pyx_zeros / __Pyx_minusones arrays as utility code.

    __Pyx_ZeroBuffer points an empty buffer's strides/shape at __Pyx_zeros
    and its suboffsets at __Pyx_minusones. The GetBuffer helpers also
    substitute __Pyx_minusones for NULL suboffsets. Each array needs
    max_ndim entries.

    At least one entry is always emitted. max_ndim is 0 when only
    0-dimensional buffers are declared (ndim only has to be non-negative),
    and an empty initializer list ``{}`` is not valid standard C.
    """
    length = max(max_ndim, 1)
    code = dedent("""
Py_ssize_t __Pyx_zeros[] = {%s};
Py_ssize_t __Pyx_minusones[] = {%s};
""") % (", ".join(["0"] * length), ", ".join(["-1"] * length))
    env.use_utility_code(UtilityCode(proto=code), "empty_bufstruct_code")
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.

    Into proto it writes a macro NAME(type, buf, i0, s0, o0, ...) that casts
    the result of NAME_imp to *type*, plus the prototype of NAME_imp. The
    inline definition of NAME_imp goes into defin. For each dimension the
    generated code advances the pointer by stride * index. When that
    dimension's suboffset is >= 0, it then dereferences the pointer and
    adds the suboffset.
    """
    # _i_ndex, _s_tride, sub_o_ffset
    macroargs = ", ".join(["i%d, s%d, o%d" % (i, i, i) for i in range(nd)])
    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))

    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
    proto.putln("static INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))
    # Function header, then one stride/suboffset step per dimension, then return.
    defin.putln(dedent("""
static INLINE void* %s_imp(void* buf, %s) {
char* ptr = (char*)buf;
""") % (name, funcargs) + "".join([dedent("""\
ptr += s%d * i%d;
if (o%d >= 0) ptr = *((char**)ptr) + o%d;
""") % (i, i, i, i) for i in range(nd)]
) + "\nreturn ptr;\n}")
def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Emit a strided-access lookup macro for *nd* dimensions into *proto*.

    The macro NAME(type, buf, i0, s0, ...) adds sum(i_k * s_k) to buf
    (as char*) and casts the result to *type*. *defin* is unused.
    """
    # _i_ndex, _s_tride
    arg_pairs = []
    terms = []
    for dim in range(nd):
        arg_pairs.append("i%d, s%d" % (dim, dim))
        terms.append("i%d * s%d" % (dim, dim))
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)"
                % (name, ", ".join(arg_pairs), " + ".join(terms)))
def buf_lookup_c_code(proto, defin, name, nd):
    """
    Emit a lookup macro for C-contiguous buffers into *proto*.

    Works like the strided lookup, except that the last dimension is
    contiguous. Its index is therefore applied by pointer arithmetic on
    *type* rather than by multiplying it with its stride. The stride
    arguments are still accepted so that every lookup macro shares one
    signature. *defin* is unused.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    last = nd - 1
    macro_args = ", ".join("i%d, s%d" % (k, k) for k in range(nd))
    outer_offset = " + ".join("i%d * s%d" % (k, k) for k in range(last))
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)"
                % (name, macro_args, outer_offset, last))
def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Emit a lookup macro for Fortran-contiguous buffers into *proto*.

    Mirrors the C-contiguous lookup, but the first dimension is the
    contiguous one. Index i0 is applied by pointer arithmetic on *type*,
    and the remaining dimensions use their strides. *defin* is unused.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    macro_args = ", ".join("i%d, s%d" % (k, k) for k in range(nd))
    inner_offset = " + ".join("i%d * s%d" % (k, k) for k in range(1, nd))
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)"
                % (name, macro_args, inner_offset, 0))
495 #
496 # Utils for creating type string checkers
497 #
def mangle_dtype_name(dtype):
    """Return an identifier-safe name for *dtype*, used to name generated helpers.

    Python objects map to "object" and pointers to "ptr". Other types use
    their C declaration with spaces replaced by underscores.
    """
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    # Prefix user-defined names (typedefs, structs, unions) so they cannot
    # collide with builtins once spaces become underscores
    # (consider "typedef float unsigned_int").
    user_defined = dtype.is_typedef or dtype.is_struct_or_union
    mangled = dtype.declaration_code("").replace(" ", "_")
    if user_defined:
        return "nn_" + mangled
    return "" + mangled
def get_typestringchecker(code, dtype):
    """
    Return the C name of the format-string checker for *dtype*.

    The checker is requested through ``code.globalstate.use_code_from``
    with create_typestringchecker as its generator.
    """
    checker_name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
    globalstate = code.globalstate
    globalstate.use_code_from(create_typestringchecker, checker_name, dtype=dtype)
    return checker_name
def create_typestringchecker(protocode, defcode, name, dtype):
    """
    Emit the C function `name`, which checks the part of a buffer format
    string that describes `dtype`.

    The prototype is written to protocode and the definition to defcode.
    The generated function takes the current format-string position `ts`.
    It returns the position just past the consumed part, or NULL with a
    ValueError set when the format does not match. Nothing is emitted for
    error types.
    """

    def put_assert(cond, msg):
        # Emit: if (!(cond)) { set "Buffer dtype mismatch (msg)"; return NULL; }
        # msg keeps one %s, filled in by the C PyErr_Format with a
        # description of the token at ts.
        defcode.putln("if (!(%s)) {" % cond)
        defcode.putln('PyErr_Format(PyExc_ValueError, "Buffer dtype mismatch (%s)", __Pyx_DescribeTokenInFormatString(ts));' % msg)
        defcode.putln("return NULL;")
        defcode.putln("}")

    if dtype.is_error: return
    simple = dtype.is_simple_buffer_dtype()
    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
    # Cannot add utility code recursively...
    # (so checkers for member types are requested here, before this
    # function's own code is written)
    if not simple:
        dtype_t = dtype.declaration_code("")  # NOTE(review): unused
        protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
        fields = dtype.scope.var_entries

        # divide fields into blocks of equal type (for repeat count)
        field_blocks = [] # of (n, type, checkerfunc)
        n = 0
        prevtype = None
        for f in fields:
            if n and f.type != prevtype:
                field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype)))
                n = 0
            prevtype = f.type
            n += 1
        # Flush the last block. NOTE(review): assumes at least one field;
        # with no fields `f` is unbound here.
        field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type)))

    protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
    defcode.putln("static const char* %s(const char* ts) {" % name)
    if simple:
        # A single format character, optionally preceded by a repeat count of 1.
        defcode.putln("int ok;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        if dtype.is_pyobject:
            defcode.putln("ok = (*ts == 'O');")
        else:
            # Cannot trust declared size; but rely on int vs float and
            # signed/unsigned to be correctly declared. Use a switch statement
            # on all possible format codes to validate that the size is ok.
            # (Note that many codes may map to same size, e.g. 'i' and 'l'
            # may both be four bytes).
            ctype = dtype.declaration_code("")
            defcode.putln("switch (*ts) {")
            if dtype.is_int:
                types = [
                    ('b', 'char'), ('h', 'short'), ('i', 'int'),
                    ('l', 'long'), ('q', 'long long')
                ]
            elif dtype.is_float:
                types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
            else:
                assert False
            if dtype.signed == 0:
                # Unsigned: upper-case codes; (ctype)-1 > 0 checks unsignedness.
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
                                  (char.upper(), ctype, against, ctype))
            else:
                # Signed: lower-case codes; (ctype)-1 < 0 checks signedness.
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
                                  (char, ctype, against, ctype))
            defcode.putln("default: ok = 0;")
            defcode.putln("}")
        put_assert("ok", "expected %s, got %%s" % dtype)
        defcode.putln("++ts;")
    elif complex_possible:
        # Could be a struct representing a complex number, so allow
        # for parsing a "Zf" spec.
        # Accepts either "Z<code>" or the two member codes in sequence.
        real_t, imag_t = [x.type for x in fields]
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        defcode.putln("if (*ts == 'Z') {")
        if len(field_blocks) == 2:
            # Different float type, sizeof check needed
            defcode.putln("if (sizeof(%s) != sizeof(%s)) {" % (
                real_t.declaration_code(""),
                imag_t.declaration_code("")))
            defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
                dtype, real_t, imag_t))
            defcode.putln("return NULL;")
            defcode.putln("}")
            check_real, check_imag = [x[2] for x in field_blocks]
        else:
            assert len(field_blocks) == 1
            check_real = check_imag = field_blocks[0][2]
        # "Z<code>": only the component code after 'Z' is checked.
        defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % check_real)
        defcode.putln("} else {")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_real)
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_imag)
        defcode.putln("}")
    else:
        # General struct: check each block of equally-typed fields in order.
        defcode.putln("int n, count;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

        # The type expected after each block, used in "too many" messages.
        next_types = [x[1] for x in field_blocks[1:]] + ["end"]
        for (n, type, checker), next_type in zip(field_blocks, next_types):
            if n == 1:
                defcode.putln("if (*ts == '1') ++ts;")
            else:
                # Consume repeat counts until the block's n items are covered;
                # overshooting means the format has more items than the struct.
                defcode.putln("n = %d;" % n);
                defcode.putln("do {")
                defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
                put_assert("n >= 0", "expected %s, got %%s" % next_type)

            simple = type.is_simple_buffer_dtype()
            if not simple:
                # Nested structs must appear as "T{...}".
                put_assert("*ts == 'T' && *(ts+1) == '{'", "expected %s, got %%s" % type)
                defcode.putln("ts += 2;")
            defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
            if not simple:
                put_assert("*ts == '}'", "expected end of %s struct, got %%s" % type)
                defcode.putln("++ts;")

            if n > 1:
                defcode.putln("} while (n > 0);");
            defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

    defcode.putln("return ts;")
    defcode.putln("}")
def get_getbuffer_code(dtype, code):
    """
    Generate (once per dtype) a utility function for getting a buffer for
    the given dtype, and return its C name.

    The generated function will:
    - Treat None as an empty (zeroed) buffer and succeed
    - Call __Pyx_GetBuffer
    - Check that ndim matched the expected value
    - Check that the format string is right, or, when cast is requested,
      only that the item size equals sizeof(dtype)
    - Set suboffsets to all -1 if it is returned as NULL.
    On any failure the buffer is zeroed and -1 is returned.
    """
    name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
    if code.globalstate.has_code(name):
        return name

    code.globalstate.use_utility_code(acquire_utility_code)
    # Explicit substitutions for the impl template (previously % locals()).
    template_values = {
        'name': name,
        'typestringchecker': get_typestringchecker(code, dtype),
        'dtype_cname': dtype.declaration_code(""),
    }
    utilcode = UtilityCode(proto = dedent("""
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
""") % name, impl = dedent("""
static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
const char* ts;
if (obj == Py_None) {
__Pyx_ZeroBuffer(buf);
return 0;
}
buf->buf = NULL;
if (__Pyx_GetBuffer(obj, buf, flags) == -1) goto fail;
if (buf->ndim != nd) {
__Pyx_BufferNdimError(buf, nd);
goto fail;
}
if (!cast) {
ts = buf->format;
ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail;
ts = %(typestringchecker)s(ts);
if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail;
if (*ts != 0) {
PyErr_Format(PyExc_ValueError,
"Buffer dtype mismatch (expected end, got %%s)",
__Pyx_DescribeTokenInFormatString(ts));
goto fail;
}
} else {
if (buf->itemsize != sizeof(%(dtype_cname)s)) {
PyErr_SetString(PyExc_ValueError,
"Attempted cast of buffer to datatype of different size.");
goto fail;
}
}
if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
return 0;
fail:;
__Pyx_ZeroBuffer(buf);
return -1;
}""") % template_values)
    code.globalstate.use_utility_code(utilcode, name)
    return name
def use_py2_buffer_functions(env):
    # Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
    # For >= 2.6 we do double mode -- use the new buffer interface on objects
    # which has the right tp_flags set, but emulation otherwise.
    # The emulation dispatches, via PyObject_TypeCheck, to the
    # __getbuffer__/__releasebuffer__ implementations of the extension types
    # found in env and (recursively) in its cimported modules. Under Python 3
    # the proto maps __Pyx_GetBuffer/__Pyx_ReleaseBuffer to the builtins.
    codename = "PyObject_GetBuffer" # just a representative unique key

    # Search all types for __getbuffer__ overloads
    # Collected as (typeptr_cname, getbuffer_cname, releasebuffer_cname or None).
    types = []
    visited_scopes = set()
    def find_buffer_types(scope):
        if scope in visited_scopes:
            return
        visited_scopes.add(scope)
        for m in scope.cimported_modules:
            find_buffer_types(m)
        for e in scope.type_entries:
            t = e.type
            if t.is_extension_type:
                release = get = None
                for x in t.scope.pyfunc_entries:
                    if x.name == u"__getbuffer__": get = x.func_cname
                    elif x.name == u"__releasebuffer__": release = x.func_cname
                if get:
                    types.append((t.typeptr_cname, get, release))

    find_buffer_types(env)

    # __Pyx_GetBuffer: new-style buffers first (2.6+), then an if/else-if
    # chain over the known types, else a TypeError.
    code = dedent("""
#if PY_MAJOR_VERSION < 3
static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
#if PY_VERSION_HEX >= 0x02060000
if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER)
return PyObject_GetBuffer(obj, view, flags);
#endif
""")
    if len(types) > 0:
        clause = "if"
        for t, get, release in types:
            code += " %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
            clause = "else if"
        code += " else {\n"
    code += dedent("""\
PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
return -1;
""", 2)
    if len(types) > 0: code += " }"
    # __Pyx_ReleaseBuffer: call the type's __releasebuffer__ if it has one,
    # then drop the reference to view->obj.
    code += dedent("""
}

static void __Pyx_ReleaseBuffer(Py_buffer *view) {
PyObject* obj = view->obj;
if (obj) {
""")
    if len(types) > 0:
        clause = "if"
        for t, get, release in types:
            if release:
                # Clauses are appended without newlines; still valid C.
                code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release)
                clause = "else if"
    code += dedent("""
Py_DECREF(obj);
view->obj = NULL;
}
}

#endif
""")

    env.use_utility_code(UtilityCode(
        proto = dedent("""\
#if PY_MAJOR_VERSION < 3
static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
static void __Pyx_ReleaseBuffer(Py_buffer *view);
#else
#define __Pyx_GetBuffer PyObject_GetBuffer
#define __Pyx_ReleaseBuffer PyBuffer_Release
#endif
"""), impl = code), codename)
785 #
786 # Static utility code
787 #
# Utility function to set the right exception
# The caller should immediately goto_error
# Sets an IndexError naming the offending axis (dimension index).
raise_indexerror_code = UtilityCode(
    proto = """\
static void __Pyx_RaiseBufferIndexError(int axis); /*proto*/
""",
    impl = """\
static void __Pyx_RaiseBufferIndexError(int axis) {
PyErr_Format(PyExc_IndexError,
"Out of bounds on buffer access (axis %d)", axis);
}

""")
#
# Buffer type checking. Utility code for checking that acquired
# buffers match our assumptions. We only need to check ndim and
# the format string; the access mode/flags is checked by the
# exporter.
#
# C helpers defined here:
# - __Pyx_SafeReleaseBuffer: no-op when buf is NULL. It undoes the
#   __Pyx_minusones suboffsets substitution before releasing.
# - __Pyx_ZeroBuffer: reset a Py_buffer to the shared empty state
#   (strides/shape -> __Pyx_zeros, suboffsets -> __Pyx_minusones).
# - __Pyx_ConsumeWhitespace: skip '@', LF, CR and spaces in a format
#   string. It fails with ValueError on '=', '<', '>' and '!' (non-native
#   byte order/size/alignment).
# - __Pyx_BufferNdimError: set the ValueError for an ndim mismatch.
# - __Pyx_DescribeTokenInFormatString: readable name of the format token at
#   ts, used in "Buffer dtype mismatch" messages.
acquire_utility_code = UtilityCode(
    proto = """\
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static const char* __Pyx_DescribeTokenInFormatString(const char* ts); /*proto*/
""",
    impl = """
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
if (info->buf == NULL) return;
if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
__Pyx_ReleaseBuffer(info);
}

static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
buf->buf = NULL;
buf->obj = NULL;
buf->strides = __Pyx_zeros;
buf->shape = __Pyx_zeros;
buf->suboffsets = __Pyx_minusones;
}

static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts) {
while (1) {
switch (*ts) {
case '@':
case 10:
case 13:
case ' ':
++ts;
break;
case '=':
case '<':
case '>':
case '!':
PyErr_SetString(PyExc_ValueError, "Buffer acquisition error: Only native byte order, size and alignment supported.");
return NULL;
default:
return ts;
}
}
}

static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
PyErr_Format(PyExc_ValueError,
"Buffer has wrong number of dimensions (expected %d, got %d)",
expected_ndim, buffer->ndim);
}

static const char* __Pyx_DescribeTokenInFormatString(const char* ts) {
switch (*ts) {
case 'b': return "char";
case 'B': return "unsigned char";
case 'h': return "short";
case 'H': return "unsigned short";
case 'i': return "int";
case 'I': return "unsigned int";
case 'l': return "long";
case 'L': return "unsigned long";
case 'q': return "long long";
case 'Q': return "unsigned long long";
case 'f': return "float";
case 'd': return "double";
case 'g': return "long double";
case 'Z': switch (*(ts+1)) {
case 'f': return "complex float";
case 'd': return "complex double";
case 'g': return "complex long double";
default: return "unparseable format string";
}
case 'T': return "a struct";
case 'O': return "Python object";
case 'P': return "a pointer";
default: return "unparseable format string";
}
}

""")
# __Pyx_ParseTypestringRepeat: read an optional decimal repeat count at ts.
# It stores the count in *out_count (1 when no digits are present) and
# returns the position after the digits.
# Fix: the continuation loop previously tested `*ts < '9'`, so a '9' after
# the first digit was not consumed (e.g. "19d" gave count 1 and left
# "9d" unparsed). It now accepts the full '0'..'9' range.
parse_typestring_repeat_code = UtilityCode(
    proto = """
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
""",
    impl = """
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
int count;
if (*ts < '0' || *ts > '9') {
count = 1;
} else {
count = *ts++ - '0';
while (*ts >= '0' && *ts <= '9') {
count *= 10;
count += *ts++ - '0';
}
}
*out_count = count;
return ts;
}
""")
# __Pyx_RaiseBufferFallbackError: sets the ValueError used when acquiring a
# new buffer on assignment failed and reacquiring the old one failed too.
raise_buffer_fallback_code = UtilityCode(
    proto = """
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""",
    impl = """
static void __Pyx_RaiseBufferFallbackError(void) {
PyErr_Format(PyExc_ValueError,
"Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
}

""")
