Cython has moved to github.

cython

view Cython/Compiler/Buffer.py @ 1381:43d4e2b19134

fix for infinite loop in buffer code
author DagSverreSeljebotn
date Tue Nov 25 12:30:05 2008 -0800 (3 years ago)
parents 3421c9767918
children 76b485a17607 4a96e1aff2d4
line source
1 from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
2 from Cython.Compiler.ModuleNode import ModuleNode
3 from Cython.Compiler.Nodes import *
4 from Cython.Compiler.ExprNodes import *
5 from Cython.Compiler.StringEncoding import EncodedString
6 from Cython.Compiler.Errors import CompileError
7 from Cython.Utils import UtilityCode
8 import Interpreter
9 import PyrexTypes
11 try:
12 set
13 except NameError:
14 from sets import Set as set
16 import textwrap
# Code cleanup ideas:
# - One could be smarter about casting in some places
# - Start using CCodeWriters to generate utility functions
# - Create a struct type per ndim rather than keeping loose local vars
def dedent(text, reindent=0):
    """
    Remove the whitespace common to the start of every line of ``text``
    (using textwrap.dedent). If ``reindent`` is positive, every resulting
    line -- including blank ones -- is then prefixed with that many spaces.
    """
    stripped = textwrap.dedent(text)
    if reindent <= 0:
        return stripped
    prefix = " " * reindent
    return "\n".join([prefix + line for line in stripped.split("\n")])
class IntroduceBufferAuxiliaryVars(CythonTransform):
    """
    Declare the auxiliary C variables used to access buffer variables.

    For every buffer-typed entry in a scope this declares a Py_buffer struct
    variable plus one stride and one shape variable per dimension (and one
    suboffset variable per dimension in 'full' mode), and stores them on the
    entry as ``entry.buffer_aux``. If any buffer was seen, the module-level
    buffer utility code is registered once the whole module is processed.
    """

    #
    # Entry point
    #

    buffers_exists = False

    def __call__(self, node):
        assert isinstance(node, ModuleNode)
        # Reset all per-run state. buffers_exists used to be left over from
        # a previous run, which (with max_ndim reset to 0) could register
        # buffer utility code with empty index arrays for a module that has
        # no buffers at all.
        self.max_ndim = 0
        self.buffers_exists = False
        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            use_py2_buffer_functions(node.scope)
            use_empty_bufstruct_code(node.scope, self.max_ndim)
        return result


    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        # For all buffers, insert extra variables in the scope.
        # The variables are also reachable through the buffer_aux
        # attribute of the buffer entry.
        bufvars = [entry for name, entry
                   in scope.entries.iteritems()
                   if entry.type.is_buffer]
        if len(bufvars) > 0:
            self.buffers_exists = True

        if isinstance(node, ModuleNode) and len(bufvars) > 0:
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")
        for entry in bufvars:
            name = entry.name
            buftype = entry.type
            if buftype.ndim > self.max_ndim:
                self.max_ndim = buftype.ndim

            # Declare auxiliary vars
            cname = scope.mangle(Naming.bufstruct_prefix, name)
            bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
                                        type=PyrexTypes.c_py_buffer_type, pos=node.pos)
            if entry.is_arg:
                bufinfo.used = True # otherwise, NameNode will mark whether it is used

            def var(prefix, idx, initval):
                # Declare one Py_ssize_t helper variable for dimension idx,
                # initialised to initval.
                cname = scope.mangle(prefix, "%d_%s" % (idx, name))
                result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
                                           node.pos, cname=cname, is_cdef=True)

                result.init = initval
                if entry.is_arg:
                    result.used = True
                return result


            stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
            shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]
            mode = entry.type.mode
            if mode == 'full':
                # Suboffsets are only needed for indirect ('full') access.
                suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
            else:
                suboffsetvars = None

            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, suboffsetvars)

        scope.buffer_entries = bufvars
        self.scope = scope

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node
#
# Analysis
#

# All buffer option names, in the order in which they may be given
# positionally (ordered!).
buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast")
# Values used for options that are not given explicitly.
buffer_defaults = dict(ndim=1, mode="full", negative_indices=True, cast=False)
# Only this many leading options may be positional; any option beyond
# this needs a keyword argument.
buffer_positional_options_count = 1

# Error messages used when analysing buffer options.
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs should consist of a list and a dict
    of tuples (value, pos). Defaults should be a dict of values.

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped). If need_complete is false,
    options that were neither given nor have a default are simply absent
    from the result.

    Raises CompileError for too many positional options, unknown or
    duplicated options, missing options (when need_complete is true) and
    options with invalid values.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env)

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    for name, (value, pos) in dictargs.iteritems():
        if not name in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name.encode("ASCII")] = value

    # Positional values are matched to buffer_options in order, so their
    # names are always valid; only clashes with keyword options can occur.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if not name in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    # Validate only the options that are present; absent ones can only
    # occur when need_complete is false.
    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    ndim = options.get("ndim")
    if ndim is not None and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode is not None and not (mode in ('full', 'strided', 'c', 'fortran')):
        raise CompileError(globalpos, ERR_BUF_MODE)

    for name in ('negative_indices', 'cast'):
        if name in options and not isinstance(options[name], bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    return options
#
# Code generation
#


def get_flags(buffer_aux, buffer_type):
    """
    Build the C expression of PyBUF_* request flags for acquiring a buffer
    of ``buffer_type``: PyBUF_FORMAT, the flag belonging to the access mode,
    and PyBUF_WRITABLE when ``buffer_aux.writable_needed`` is set.
    """
    mode_flag = {
        'full': 'PyBUF_INDIRECT',
        'strided': 'PyBUF_STRIDES',
        'c': 'PyBUF_C_CONTIGUOUS',
        'fortran': 'PyBUF_F_CONTIGUOUS',
    }
    parts = ['PyBUF_FORMAT']
    if buffer_type.mode in mode_flag:
        parts.append(mode_flag[buffer_type.mode])
    else:
        assert False
    if buffer_aux.writable_needed:
        parts.append('PyBUF_WRITABLE')
    return '| '.join(parts)
def used_buffer_aux_vars(entry):
    """
    Flag the buffer struct variable and every stride, shape and (if present)
    suboffset variable of ``entry.buffer_aux`` as used.
    """
    aux = entry.buffer_aux
    to_mark = [aux.buffer_info_var]
    to_mark.extend(aux.shapevars)
    to_mark.extend(aux.stridevars)
    if aux.suboffsetvars:
        to_mark.extend(aux.suboffsetvars)
    for var in to_mark:
        var.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
    """
    Emit C statements copying the strides and shape (and, in 'full' mode,
    the suboffsets) of the acquired Py_buffer struct into the per-dimension
    local variables; one output line per struct field.
    """
    bufstruct = buffer_aux.buffer_info_var.cname
    fields = [("strides", buffer_aux.stridevars),
              ("shape", buffer_aux.shapevars)]
    if mode == 'full':
        fields.append(("suboffsets", buffer_aux.suboffsetvars))

    for field, local_vars in fields:
        assignments = []
        for idx, var in enumerate(local_vars):
            assignments.append("%s = %s.%s[%d];" % (var.cname, bufstruct, field, idx))
        code.putln(" ".join(assignments))
def put_acquire_arg_buffer(entry, code, pos):
    """
    Emit code acquiring the buffer of the function argument ``entry``
    (jumping to the error label if that fails), then unpack the buffer's
    strides/shape/suboffsets into the auxiliary local variables.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    aux = entry.buffer_aux
    buftype = entry.type
    getbuffer_cname = get_getbuffer_code(buftype.dtype, code)

    # Acquire any new buffer
    failure_test = "%s((PyObject*)%s, &%s, %s, %d, %d) == -1" % (
        getbuffer_cname,
        entry.cname,
        aux.buffer_info_var.cname,
        get_flags(aux, buftype),
        buftype.ndim,
        int(buftype.cast))
    code.putln(code.error_goto_if(failure_test, pos))
    # An exception raised during argument parsing cannot be caught inside
    # the function, so the buffer needs no special handling on that path.
    put_unpack_buffer_aux_into_scope(aux, buftype.mode, code)
254 #def put_release_buffer_normal(entry, code):
255 # code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
256 # entry.cname,
257 # entry.cname,
258 # entry.buffer_aux.buffer_info_var.cname))
def get_release_buffer_code(entry):
    """Return the C call that releases the buffer held in ``entry``'s buffer struct."""
    template = "__Pyx_SafeReleaseBuffer(&%s)"
    bufstruct_cname = entry.buffer_aux.buffer_info_var.cname
    return template % bufstruct_cname
def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                         is_initialized, pos, code):
    """
    Emit the C code that reacquires the buffer when a buffer variable is
    assigned a new object.

    Only the buffer struct and its auxiliary variables are handled here;
    the assignment itself and the reference counting are the caller's job.
    Because acquiring the new buffer may fail, the assignment may never
    happen. The possible outcomes are:

    - is_initialized: the old buffer is released and the new one acquired.
      If that fails, the old lhs buffer is reacquired (which may itself fail,
      in which case a fallback error replaces the original exception) and
      control jumps to the error label.
    - not is_initialized: if acquisition fails, lhs is set to None, the
      struct's buf pointer is cleared and control jumps to the error label.

    In both cases the strides/shape/suboffsets are unpacked afterwards.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    bufstruct = buffer_aux.buffer_info_var.cname
    flags = get_flags(buffer_aux, buffer_type)

    # Template of the getbuffer call; the object is filled in later via the
    # remaining %s.
    getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (
        get_getbuffer_code(buffer_type.dtype, code),
        bufstruct,
        flags,
        buffer_type.ndim,
        int(buffer_type.cast))

    if not is_initialized:
        # The entry had no previous value, so set it to None when
        # acquisition fails. The auxiliary variables are initialised to a
        # zero-buffer, so clearing the buf field is enough.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        code.putln('%s = %s; Py_INCREF(Py_None); %s.buf = NULL;' %
                   (lhs_cname,
                    PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
                    bufstruct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
        code.putln('}')
        return

    # Drop the currently held buffer, then acquire the new one.
    code.putln('__Pyx_SafeReleaseBuffer(&%s);' % bufstruct)
    retcode = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
    code.putln("%s = %s;" % (retcode, getbuffer % rhs_cname))
    code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode)))
    # On failure, try to reacquire the old buffer before propagating the
    # exception. If that fails too, the reacquisition error is reported.
    code.begin_block()
    exc_type, exc_value, exc_tb = [code.funcstate.allocate_temp(PyrexTypes.py_object_type)
                                   for _ in range(3)]
    code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (exc_type, exc_value, exc_tb))
    code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
    code.begin_block()
    code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (exc_type, exc_value, exc_tb))
    code.globalstate.use_utility_code(raise_buffer_fallback_code)
    code.putln('__Pyx_RaiseBufferFallbackError();')
    code.putln('} else {')
    code.putln('PyErr_Restore(%s, %s, %s);' % (exc_type, exc_value, exc_tb))
    for temp in (exc_type, exc_value, exc_tb):
        code.funcstate.release_temp(temp)
    code.end_block()
    code.end_block()
    # Unpack indices (of whichever buffer is now held), then propagate
    # the acquisition error, if any.
    put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
    code.putln(code.error_goto_if_neg(retcode, pos))
    code.funcstate.release_temp(retcode)
def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
    """
    Emit index processing for a buffer access and return the lookup expression.

    Depending on ``options['boundscheck']`` and the buffer's
    negative_indices setting, code is emitted that wraps negative indices
    and/or raises IndexError for out-of-bounds ones. The returned C string
    calls a mode-specific lookup macro and yields a pointer that can be
    read from or written to; it is an expression, so callers using it more
    than once should store it in a temporary.

    The bounds checks go directly into the function body because the
    needed checks depend on each index's signedness. The pointer computation
    is delegated to a helper generated once per mode and ndim.
    """
    bufaux = entry.buffer_aux
    bufstruct = bufaux.buffer_info_var.cname
    wrap_negative = entry.type.negative_indices
    dims = zip(index_signeds, index_cnames, bufaux.shapevars)

    if options['boundscheck']:
        # failed_dim starts at -1 (all in range); every out-of-bounds index
        # overwrites it with its axis number.
        failed_dim = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
        code.putln("%s = -1;" % failed_dim)
        for dim, (signed, cname, shape) in enumerate(dims):
            if signed != 0:
                # The index type can be negative.
                code.putln("if (%s < 0) {" % cname)
                if wrap_negative:
                    code.putln("%s += %s;" % (cname, shape.cname))
                    code.putln("if (%s) %s = %d;" % (
                        code.unlikely("%s < 0" % cname), failed_dim, dim))
                else:
                    code.putln("%s = %d;" % (failed_dim, dim))
                code.put("} else ")
            # Upper bound check.
            code.putln("if (%s) %s = %d;" % (
                code.unlikely("%s >= %s" % (cname, shape.cname)),
                failed_dim, dim))
        code.globalstate.use_utility_code(raise_indexerror_code)
        code.put("if (%s) " % code.unlikely("%s != -1" % failed_dim))
        code.begin_block()
        code.putln('__Pyx_RaiseBufferIndexError(%s);' % failed_dim)
        code.putln(code.error_goto(pos))
        code.end_block()
        code.funcstate.release_temp(failed_dim)
    elif wrap_negative:
        # No bounds checking; only wrap negative indices around.
        for signed, cname, shape in dims:
            if signed != 0:
                code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape.cname))

    # Pick the lookup macro for the access mode and collect its arguments.
    nd = entry.type.ndim
    mode = entry.type.mode
    params = []
    if mode == 'full':
        funcname = "__Pyx_BufPtrFull%dd" % nd
        funcgen = buf_lookup_full_code
        for idx, stride, suboffset in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
            params.extend([idx, stride.cname, suboffset.cname])
    else:
        strided_variants = {
            'strided': ("__Pyx_BufPtrStrided%dd", buf_lookup_strided_code),
            'c': ("__Pyx_BufPtrCContig%dd", buf_lookup_c_code),
            'fortran': ("__Pyx_BufPtrFortranContig%dd", buf_lookup_fortran_code),
        }
        assert mode in strided_variants
        name_template, funcgen = strided_variants[mode]
        funcname = name_template % nd
        for idx, stride in zip(index_cnames, bufaux.stridevars):
            params.extend([idx, stride.cname])

    # Make sure the utility code is available
    code.globalstate.use_code_from(funcgen, name=funcname, nd=nd)

    ptr_decl = entry.type.buffer_ptr_type.declaration_code("")
    return "%s(%s, %s.buf, %s)" % (funcname, ptr_decl, bufstruct, ", ".join(params))
def use_empty_bufstruct_code(env, max_ndim):
    """
    Register the global __Pyx_zeros / __Pyx_minusones arrays, which are
    used to fill empty buffer structs (shape/strides all 0, suboffsets all
    -1), with one entry per dimension up to ``max_ndim``.

    At least one entry is always emitted. A C array definition with an empty
    initializer list is invalid, and __Pyx_ZeroBuffer references both arrays
    even when every buffer in the module has ndim == 0 (which option
    analysis allows).
    """
    length = max(max_ndim, 1)
    code = dedent("""
        Py_ssize_t __Pyx_zeros[] = {%s};
        Py_ssize_t __Pyx_minusones[] = {%s};
        """) % (", ".join(["0"] * length), ", ".join(["-1"] * length))
    env.use_utility_code(UtilityCode(proto=code), "empty_bufstruct_code")
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Emit the lookup helper for 'full' (indirect) buffers of ``nd`` dimensions.

    ``proto`` receives a macro ``name(type, buf, i0, s0, o0, ...)`` that
    casts the result of the inline function ``name_imp`` to ``type``, plus
    that function's prototype; ``defin`` receives the function body. For
    each dimension the function advances the pointer by index * stride and,
    if the suboffset is non-negative, dereferences it and adds the suboffset.
    The function gives back a void* at the right location.
    """
    dims = range(nd)
    # _i_ndex, _s_tride, sub_o_ffset
    macroargs = ", ".join(["i%d, s%d, o%d" % (d, d, d) for d in dims])
    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (d, d, d)
                          for d in dims])
    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))
    proto.putln("static INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))

    body = ["",
            "static INLINE void* %s_imp(void* buf, %s) {" % (name, funcargs),
            "char* ptr = (char*)buf;"]
    for d in dims:
        body.append("ptr += s%d * i%d;" % (d, d))
        body.append("if (o%d >= 0) ptr = *((char**)ptr) + o%d;" % (d, d))
    body.extend(["", "return ptr;", "}"])
    defin.putln("\n".join(body))
def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Emit a lookup macro for strided buffers of ``nd`` dimensions into
    ``proto``: the item pointer is buf plus the sum of index * stride over
    all dimensions, cast to ``type``. ``defin`` is not used.
    """
    # (_i_ndex, _s_tride) name pairs per dimension
    index_stride = [("i%d" % d, "s%d" % d) for d in range(nd)]
    args = ", ".join(["%s, %s" % pair for pair in index_stride])
    offset = " + ".join(["%s * %s" % pair for pair in index_stride])
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)" % (name, args, offset))
def buf_lookup_c_code(proto, defin, name, nd):
    """
    Emit a lookup macro for C-contiguous buffers into ``proto``.

    Similar to the strided lookup, but the last dimension is assumed to be
    contiguous: its index is added after the cast to the item pointer type,
    so it needs no multiplication by its stride. The macro still keeps the
    same (type, buf, i0, s0, ...) signature as the other lookup macros.
    ``defin`` is not used.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    last = nd - 1
    args = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
    offset = " + ".join(["i%d * s%d" % (d, d) for d in range(last)])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (name, args, offset, last))
def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Emit a lookup macro for Fortran-contiguous buffers into ``proto``.

    Like the C lookup, but the first dimension is the contiguous one: its
    index is added after the cast to the item pointer type, while the
    remaining dimensions use index * stride. ``defin`` is not used.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    args = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
    strided_dims = range(1, nd)
    offset = " + ".join(["i%d * s%d" % (d, d) for d in strided_dims])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i0)" % (name, args, offset))
#
# Utils for creating type string checkers
#
def mangle_dtype_name(dtype):
    """
    Return an identifier fragment naming ``dtype`` for use in generated
    helper function names. Python objects map to "object" and pointers to
    "ptr"; other types use their C declaration with spaces replaced by
    underscores. Types without a typestring get an "nn_" prefix so that a
    user-defined type cannot collide with a builtin one
    (consider "typedef float unsigned_int").
    """
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    prefix = ""
    if dtype.typestring is None:
        prefix = "nn_"
    return prefix + dtype.declaration_code("").replace(" ", "_")
def get_typestringchecker(code, dtype):
    """
    Return the C name of the format-string checker function for ``dtype``.
    The name and create_typestringchecker are handed to
    code.globalstate.use_code_from so the checker gets emitted to ``code``.
    """
    checker_name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
    globalstate = code.globalstate
    globalstate.use_code_from(create_typestringchecker, checker_name, dtype=dtype)
    return checker_name
def _typestring_field_blocks(code, fields):
    # Group consecutive struct fields of equal type into
    # (count, type, checker_cname) tuples, so that repeat counts such as
    # "3d" can be matched against a run of fields. Empty input gives [].
    blocks = []
    count = 0
    prevtype = None
    for field in fields:
        if count and field.type != prevtype:
            blocks.append((count, prevtype, get_typestringchecker(code, prevtype)))
            count = 0
        prevtype = field.type
        count += 1
    if count:
        blocks.append((count, prevtype, get_typestringchecker(code, prevtype)))
    return blocks

def create_typestringchecker(protocode, defcode, name, dtype):
    """
    Emit a C function ``name`` that checks a buffer format string against
    ``dtype``.

    The generated function takes a pointer into the format string, checks
    the item(s) for ``dtype`` starting there, and returns a pointer just
    past them, or NULL with ValueError set on a mismatch. Three variants
    are generated:

    - simple dtypes: a single format character (optionally preceded by '1');
    - structs that can be complex: either 'Z' followed by the real part's
      format, or the real and imaginary fields checked one after the other;
    - other structs: fields checked in order, runs of same-typed fields
      accepting repeat counts, nested structs wrapped in 'T{...}'.

    Nothing is emitted for an error type.
    """
    def put_assert(cond, msg):
        defcode.putln("if (!(%s)) {" % cond)
        defcode.putln('PyErr_Format(PyExc_ValueError, "Buffer dtype mismatch (%s)", __Pyx_DescribeTokenInFormatString(ts));' % msg)
        defcode.putln("return NULL;")
        defcode.putln("}")

    if dtype.is_error: return
    is_simple = dtype.is_simple_buffer_dtype()
    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
    # Cannot add utility code recursively...
    if not is_simple:
        protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
        fields = dtype.scope.var_entries
        # divide fields into blocks of equal type (for repeat count);
        # a struct without fields yields no blocks instead of crashing
        field_blocks = _typestring_field_blocks(protocode, fields)

    protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
    defcode.putln("static const char* %s(const char* ts) {" % name)
    if is_simple:
        defcode.putln("int ok;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        if dtype.typestring is not None:
            assert len(dtype.typestring) == 1
            # Can use direct comparison
            defcode.putln("ok = (*ts == '%s');" % dtype.typestring)
        else:
            # Cannot trust declared size; but rely on int vs float and
            # signed/unsigned to be correctly declared. Use a switch statement
            # on all possible format codes to validate that the size is ok.
            # (Note that many codes may map to same size, e.g. 'i' and 'l'
            # may both be four bytes).
            ctype = dtype.declaration_code("")
            defcode.putln("switch (*ts) {")
            if dtype.is_int:
                types = [
                    ('b', 'char'), ('h', 'short'), ('i', 'int'),
                    ('l', 'long'), ('q', 'long long')
                ]
            elif dtype.is_float:
                types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
            else:
                assert False
            if dtype.signed == 0:
                # Unsigned format codes are the upper-case letters.
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
                                  (char.upper(), ctype, against, ctype))
            else:
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
                                  (char, ctype, against, ctype))
            defcode.putln("default: ok = 0;")
            defcode.putln("}")
        put_assert("ok", "expected %s, got %%s" % dtype)
        defcode.putln("++ts;")
    elif complex_possible:
        # Could be a struct representing a complex number, so allow
        # for parsing a "Zf" spec.
        real_t, imag_t = [x.type for x in fields]
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        defcode.putln("if (*ts == 'Z') {")
        if len(field_blocks) == 2:
            # Different float type, sizeof check needed
            defcode.putln("if (sizeof(%s) != sizeof(%s)) {" % (
                real_t.declaration_code(""),
                imag_t.declaration_code("")))
            defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
                dtype, real_t, imag_t))
            defcode.putln("return NULL;")
            defcode.putln("}")
            check_real, check_imag = [x[2] for x in field_blocks]
        else:
            assert len(field_blocks) == 1
            check_real = check_imag = field_blocks[0][2]
        defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % check_real)
        defcode.putln("} else {")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_real)
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_imag)
        defcode.putln("}")
    else:
        defcode.putln("int n, count;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

        # For error messages: what the format string should contain after
        # each block.
        next_types = [x[1] for x in field_blocks[1:]] + ["end"]
        for (block_len, field_type, checker), next_type in zip(field_blocks, next_types):
            if block_len == 1:
                defcode.putln("if (*ts == '1') ++ts;")
            else:
                # Consume repeat counts until the whole run of fields is
                # matched; overshooting the run is an error.
                defcode.putln("n = %d;" % block_len)
                defcode.putln("do {")
                defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
                put_assert("n >= 0", "expected %s, got %%s" % next_type)

            field_is_simple = field_type.is_simple_buffer_dtype()
            if not field_is_simple:
                put_assert("*ts == 'T' && *(ts+1) == '{'", "expected %s, got %%s" % field_type)
                defcode.putln("ts += 2;")
            defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
            if not field_is_simple:
                put_assert("*ts == '}'", "expected end of %s struct, got %%s" % field_type)
                defcode.putln("++ts;")

            if block_len > 1:
                defcode.putln("} while (n > 0);")
            defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

    defcode.putln("return ts;")
    defcode.putln("}")
def get_getbuffer_code(dtype, code):
    """
    Return the name of a C utility function that acquires a buffer for the
    given dtype, registering the function the first time it is needed.

    The generated function:
    - zeroes the buffer struct and succeeds if the object is None,
    - calls __Pyx_GetBuffer,
    - checks that ndim matches the expected value,
    - unless ``cast`` is set, checks the format string with the dtype's
      typestring checker; with ``cast`` set, only checks the item size,
    - sets suboffsets to __Pyx_minusones if they are returned as NULL.

    If a check fails after a successful acquisition, the buffer is released
    (via __Pyx_SafeReleaseBuffer) before the struct is zeroed, so neither
    the exporter's resources nor the reference to the exporting object are
    leaked.
    """
    name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
    if not code.globalstate.has_code(name):
        code.globalstate.use_utility_code(acquire_utility_code)
        typestringchecker = get_typestringchecker(code, dtype)
        dtype_cname = dtype.declaration_code("")
        subst = {
            'name': name,
            'typestringchecker': typestringchecker,
            'dtype_cname': dtype_cname,
        }
        utilcode = UtilityCode(proto = dedent("""
            static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
            """) % name, impl = dedent("""
            static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
              const char* ts;
              if (obj == Py_None) {
                __Pyx_ZeroBuffer(buf);
                return 0;
              }
              buf->buf = NULL;
              if (__Pyx_GetBuffer(obj, buf, flags) == -1) {
                /* Nothing was acquired; only reset the struct. */
                __Pyx_ZeroBuffer(buf);
                return -1;
              }
              if (buf->ndim != nd) {
                __Pyx_BufferNdimError(buf, nd);
                goto fail;
              }
              if (!cast) {
                ts = buf->format;
                ts = __Pyx_ConsumeWhitespace(ts);
                if (!ts) goto fail;
                ts = %(typestringchecker)s(ts);
                if (!ts) goto fail;
                ts = __Pyx_ConsumeWhitespace(ts);
                if (!ts) goto fail;
                if (*ts != 0) {
                  PyErr_Format(PyExc_ValueError,
                    "Buffer dtype mismatch (expected end, got %%s)",
                    __Pyx_DescribeTokenInFormatString(ts));
                  goto fail;
                }
              } else {
                if (buf->itemsize != sizeof(%(dtype_cname)s)) {
                  PyErr_SetString(PyExc_ValueError,
                    "Attempted cast of buffer to datatype of different size.");
                  goto fail;
                }
              }
              if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
              return 0;
            fail:;
              /* The buffer was acquired: release it before clearing the struct. */
              __Pyx_SafeReleaseBuffer(buf);
              __Pyx_ZeroBuffer(buf);
              return -1;
            }""") % subst)
        code.globalstate.use_utility_code(utilcode, name)
    return name
def use_py2_buffer_functions(env):
    """
    Register utility code providing __Pyx_GetBuffer and __Pyx_ReleaseBuffer.

    On Python 3 these are aliases of PyObject_GetBuffer and PyBuffer_Release.
    On Python 2 they are emulated. Objects whose type has
    Py_TPFLAGS_HAVE_NEWBUFFER (Python >= 2.6) use the real buffer protocol
    for both acquisition and release, so the exporter's release hook is
    called. Otherwise the call is dispatched by type check to the
    __getbuffer__ / __releasebuffer__ functions of extension types found in
    ``env`` and (recursively) the modules it cimports. Any other object
    makes __Pyx_GetBuffer raise TypeError.
    """
    codename = "PyObject_GetBuffer" # just a representative unique key

    # Search all types for __getbuffer__ overloads
    types = []
    visited_scopes = set()
    def find_buffer_types(scope):
        if scope in visited_scopes:
            return
        visited_scopes.add(scope)
        for m in scope.cimported_modules:
            find_buffer_types(m)
        for e in scope.type_entries:
            t = e.type
            if t.is_extension_type:
                release = get = None
                for x in t.scope.pyfunc_entries:
                    if x.name == u"__getbuffer__": get = x.func_cname
                    elif x.name == u"__releasebuffer__": release = x.func_cname
                if get:
                    types.append((t.typeptr_cname, get, release))

    find_buffer_types(env)

    code = dedent("""
        #if PY_MAJOR_VERSION < 3
        static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
          #if PY_VERSION_HEX >= 0x02060000
          if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER)
              return PyObject_GetBuffer(obj, view, flags);
          #endif
        """)
    if len(types) > 0:
        clause = "if"
        for t, get, release in types:
            code += "  %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
            clause = "else if"
        code += "  else {\n"
    code += dedent("""\
        PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
        return -1;
        """, 2)
    if len(types) > 0: code += "  }"
    # Release must mirror acquisition: objects acquired through
    # PyObject_GetBuffer are released with PyBuffer_Release, which calls the
    # exporter's release hook and drops the reference to obj.
    code += dedent("""
        }

        static void __Pyx_ReleaseBuffer(Py_buffer *view) {
          PyObject* obj = view->obj;
          if (obj) {
            #if PY_VERSION_HEX >= 0x02060000
            if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER) {
              PyBuffer_Release(view);
              return;
            }
            #endif
        """)
    if len(types) > 0:
        clause = "if"
        for t, get, release in types:
            if release:
                code += "    %s (PyObject_TypeCheck(obj, %s)) %s(obj, view);\n" % (clause, t, release)
                clause = "else if"
    code += dedent("""
            Py_DECREF(obj);
            view->obj = NULL;
          }
        }

        #endif
        """)

    env.use_utility_code(UtilityCode(
        proto = dedent("""\
            #if PY_MAJOR_VERSION < 3
            static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
            static void __Pyx_ReleaseBuffer(Py_buffer *view);
            #else
            #define __Pyx_GetBuffer PyObject_GetBuffer
            #define __Pyx_ReleaseBuffer PyBuffer_Release
            #endif
            """), impl = code), codename)
#
# Static utility code
#


# __Pyx_RaiseBufferIndexError(axis) only sets an IndexError naming the
# offending axis; the generated caller jumps to its error label right after.
raise_indexerror_code = UtilityCode(
    proto=(
        "static void __Pyx_RaiseBufferIndexError(int axis); /*proto*/\n"
    ),
    impl=(
        "static void __Pyx_RaiseBufferIndexError(int axis) {\n"
        "  PyErr_Format(PyExc_IndexError,\n"
        "     \"Out of bounds on buffer access (axis %d)\", axis);\n"
        "}\n"
        "\n"
    ),
)
#
# Buffer type checking. Utility code for checking that acquired
# buffers match our assumptions. We only need to check ndim and
# the format string; the access mode/flags is checked by the
# exporter.
#
# __Pyx_ConsumeWhitespace skips '@' (native byte order, the only one
# supported) and all C isspace() characters, which struct-style format
# strings ignore between codes (tab, vertical tab and form feed included).
# It rejects the other byte-order prefixes.
acquire_utility_code = UtilityCode(
proto = """\
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info); /*proto*/
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static const char* __Pyx_DescribeTokenInFormatString(const char* ts); /*proto*/
""",
impl = """
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
  if (info->buf == NULL) return;
  if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
  __Pyx_ReleaseBuffer(info);
}

static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
  buf->buf = NULL;
  buf->obj = NULL;
  buf->strides = __Pyx_zeros;
  buf->shape = __Pyx_zeros;
  buf->suboffsets = __Pyx_minusones;
}

static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts) {
  while (1) {
    switch (*ts) {
      case '@':
      case 9:
      case 10:
      case 11:
      case 12:
      case 13:
      case ' ':
        ++ts;
        break;
      case '=':
      case '<':
      case '>':
      case '!':
        PyErr_SetString(PyExc_ValueError, "Buffer acquisition error: Only native byte order, size and alignment supported.");
        return NULL;
      default:
        return ts;
    }
  }
}

static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
  PyErr_Format(PyExc_ValueError,
               "Buffer has wrong number of dimensions (expected %d, got %d)",
               expected_ndim, buffer->ndim);
}

static const char* __Pyx_DescribeTokenInFormatString(const char* ts) {
  switch (*ts) {
    case 'b': return "char";
    case 'B': return "unsigned char";
    case 'h': return "short";
    case 'H': return "unsigned short";
    case 'i': return "int";
    case 'I': return "unsigned int";
    case 'l': return "long";
    case 'L': return "unsigned long";
    case 'q': return "long long";
    case 'Q': return "unsigned long long";
    case 'f': return "float";
    case 'd': return "double";
    case 'g': return "long double";
    case 'Z': switch (*(ts+1)) {
        case 'f': return "complex float";
        case 'd': return "complex double";
        case 'g': return "complex long double";
        default: return "unparseable format string";
    }
    case 'T': return "a struct";
    case 'O': return "Python object";
    case 'P': return "a pointer";
    default: return "unparseable format string";
  }
}

""")
# __Pyx_ParseTypestringRepeat reads an optional decimal repeat count at ts
# (defaulting to 1 when no digit is present), stores it in *out_count and
# returns a pointer just past the digits. The digit loop previously used
# "< '9'", so any repeat count containing a 9 after its first digit
# (e.g. "19d") was cut short.
parse_typestring_repeat_code = UtilityCode(
proto = """
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
""",
impl = """
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
  int count;
  if (*ts < '0' || *ts > '9') {
    count = 1;
  } else {
    count = *ts++ - '0';
    while (*ts >= '0' && *ts <= '9') {
      count *= 10;
      count += *ts++ - '0';
    }
  }
  *out_count = count;
  return ts;
}
""")
# Raised when acquiring the new buffer on assignment failed and
# reacquiring the previously held buffer failed as well. The message has no
# format arguments, so PyErr_SetString is used instead of PyErr_Format.
raise_buffer_fallback_code = UtilityCode(
proto = """
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""",
impl = """
static void __Pyx_RaiseBufferFallbackError(void) {
  PyErr_SetString(PyExc_ValueError,
     "Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
}

""")