Cython has moved to github.

cython-devel

view Cython/Compiler/Buffer.py @ 1235:5de00fce9b73

Buffers: NumPy record array support, format string parsing improvements
author Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
date Sat Oct 11 18:48:15 2008 +0200 (3 years ago)
parents da30eeb06679
children c9fc106f9412
line source
1 from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
2 from Cython.Compiler.ModuleNode import ModuleNode
3 from Cython.Compiler.Nodes import *
4 from Cython.Compiler.ExprNodes import *
5 from Cython.Compiler.StringEncoding import EncodedString
6 from Cython.Compiler.Errors import CompileError
7 import Interpreter
8 import PyrexTypes
10 try:
11 set
12 except NameError:
13 from sets import Set as set
15 import textwrap
17 # Code cleanup ideas:
18 # - One could be more smart about casting in some places
19 # - Start using CCodeWriters to generate utility functions
20 # - Create a struct type per ndim rather than keeping loose local vars
def dedent(text, reindent=0):
    """
    Strip the common leading whitespace from ``text`` and optionally
    re-indent the result.

    text     -- (multi-line) string to process
    reindent -- number of spaces to put in front of every line of the
                dedented text (blank lines included); values <= 0 leave
                the dedented text as it is

    Returns the processed string.
    """
    stripped = textwrap.dedent(text)
    if reindent <= 0:
        return stripped
    pad = " " * reindent
    padded_lines = [pad + line for line in stripped.split('\n')]
    return '\n'.join(padded_lines)
class IntroduceBufferAuxiliaryVars(CythonTransform):
    """
    Transform that declares the hidden helper variables of every buffer.

    For each buffer-typed entry found in a function scope, a Py_buffer
    struct variable and one stride/shape (and, in 'full' mode, suboffset)
    variable per dimension are declared in that scope. They are bundled
    in a Symtab.BufferAux object stored on ``entry.buffer_aux``.
    """

    #
    # Entry point
    #

    # Becomes True (on the instance) once any scope containing a buffer
    # has been seen.
    buffers_exists = False

    def __call__(self, node):
        """
        Run the transform on a module tree and, if any buffer was found,
        register the module-level utility code buffers depend on.
        """
        assert isinstance(node, ModuleNode)
        self.max_ndim = 0
        tree = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            use_py2_buffer_functions(node.scope)
            use_empty_bufstruct_code(node.scope, self.max_ndim)
        return tree

    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        """
        Insert the auxiliary variables for all buffers of ``scope``.

        The variables are also reachable through ``buffer_aux`` on the
        buffer entry. Raises CompileError when buffers are found in the
        module scope.
        """
        buffer_entries = [entry for name, entry
                          in scope.entries.iteritems()
                          if entry.type.is_buffer]
        has_buffers = len(buffer_entries) > 0
        if has_buffers:
            self.buffers_exists = True

        if has_buffers and isinstance(node, ModuleNode):
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")

        for entry in buffer_entries:
            self._declare_aux_vars(node, scope, entry)

        scope.buffer_entries = buffer_entries
        self.scope = scope

    def _declare_aux_vars(self, node, scope, entry):
        # Declare the Py_buffer struct and the per-dimension helper
        # variables for one buffer entry, and attach them to the entry.
        name = entry.name
        ndim = entry.type.ndim
        if ndim > self.max_ndim:
            self.max_ndim = ndim

        struct_cname = scope.mangle(Naming.bufstruct_prefix, name)
        bufinfo = scope.declare_var(name="$%s" % struct_cname, cname=struct_cname,
                                    type=PyrexTypes.c_py_buffer_type, pos=node.pos)
        if entry.is_arg:
            bufinfo.used = True # otherwise, NameNode will mark whether it is used

        def declare_group(prefix, initval):
            # One Py_ssize_t variable per dimension, initialised to initval.
            group = []
            for idx in range(ndim):
                cname = scope.mangle(prefix, "%d_%s" % (idx, name))
                auxvar = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
                                           node.pos, cname=cname, is_cdef=True)
                auxvar.init = initval
                if entry.is_arg:
                    auxvar.used = True
                group.append(auxvar)
            return group

        stridevars = declare_group(Naming.bufstride_prefix, "0")
        shapevars = declare_group(Naming.bufshape_prefix, "0")
        if entry.type.mode == 'full':
            suboffsetvars = declare_group(Naming.bufsuboffset_prefix, "-1")
        else:
            # Only 'full' mode tracks suboffsets.
            suboffsetvars = None

        entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, suboffsetvars)

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node
112 #
113 # Analysis
114 #
# Names of all options a buffer declaration accepts. The order is
# significant: positional arguments are matched against this tuple from
# the start (see analyse_buffer_options).
buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered!
# Values used for options that are not given explicitly. "dtype" has no
# default.
buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False}
buffer_positional_options_count = 1 # anything beyond this needs keyword argument

# Messages for the CompileErrors raised while analysing buffer options.
# Those containing "%s" are formatted with the option name.
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs should consist of a list and a dict
    of tuples (value, pos). Defaults should be a dict of values
    (buffer_defaults is used when it is None).

    If need_complete is true, an option that is neither supplied nor
    present in the defaults is an error; otherwise it is simply left out
    of the result.

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped).

    Raises CompileError for unknown, duplicated, missing or invalid
    options.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env)

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    for name, (value, pos) in dictargs.iteritems():
        if not name in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name.encode("ASCII")] = value

    # Positional arguments are matched against buffer_options in order,
    # so the name is always a known option here; only duplicates (the
    # same option also given by keyword) need to be rejected.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if not name in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    # Test against None rather than truthiness, so that invalid but
    # falsy values (e.g. "" or 0.0) do not slip past the validation.
    ndim = options.get("ndim")
    if ndim is not None and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode is not None and not (mode in ('full', 'strided', 'c', 'fortran')):
        raise CompileError(globalpos, ERR_BUF_MODE)

    def assert_bool(name):
        x = options.get(name)
        if not isinstance(x, bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    assert_bool('negative_indices')
    assert_bool('cast')

    return options
192 #
193 # Code generation
194 #
def get_flags(buffer_aux, buffer_type):
    """
    Build the C expression for the PyBUF_* flags to request when
    acquiring a buffer.

    buffer_aux  -- auxiliary info; its writable_needed attribute decides
                   whether PyBUF_WRITABLE is requested
    buffer_type -- buffer type; its mode attribute ('full', 'strided',
                   'c' or 'fortran') selects the layout flag

    Returns the flags as a string of '|'-combined constants.
    """
    layout_flags = {
        'full': '| PyBUF_INDIRECT',
        'strided': '| PyBUF_STRIDES',
        'c': '| PyBUF_C_CONTIGUOUS',
        'fortran': '| PyBUF_F_CONTIGUOUS',
    }
    pieces = ['PyBUF_FORMAT']
    mode = buffer_type.mode
    if mode in layout_flags:
        pieces.append(layout_flags[mode])
    else:
        assert False
    if buffer_aux.writable_needed:
        pieces.append("| PyBUF_WRITABLE")
    return ''.join(pieces)
def used_buffer_aux_vars(entry):
    """
    Flag every auxiliary variable of the buffer ``entry`` as used: the
    buffer info struct, all shape and stride variables and, when there
    are any, the suboffset variables.
    """
    aux = entry.buffer_aux
    aux.buffer_info_var.used = True
    groups = [aux.shapevars, aux.stridevars]
    if aux.suboffsetvars:
        groups.append(aux.suboffsetvars)
    for group in groups:
        for auxvar in group:
            auxvar.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
    """
    Emit C statements that copy the needed fields of the buffer info
    struct into the local auxiliary variables.

    One line is written per struct field (strides, shape and, when mode
    is 'full', suboffsets), holding one assignment per dimension.
    """
    struct_cname = buffer_aux.buffer_info_var.cname

    field_groups = [("strides", buffer_aux.stridevars),
                    ("shape", buffer_aux.shapevars)]
    if mode == 'full':
        field_groups.append(("suboffsets", buffer_aux.suboffsetvars))

    for field, auxvars in field_groups:
        assignments = []
        for idx, auxvar in enumerate(auxvars):
            assignments.append("%s = %s.%s[%d];" %
                               (auxvar.cname, struct_cname, field, idx))
        code.putln(" ".join(assignments))
def put_acquire_arg_buffer(entry, code, pos):
    """
    Emit the code acquiring the buffer of a buffer-typed function
    argument and unpacking it into the auxiliary variables.

    entry -- symbol table entry of the argument
    code  -- code writer to emit to
    pos   -- source position used for the error goto
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    buffer_aux = entry.buffer_aux
    buffer_type = entry.type
    getbuffer_cname = get_getbuffer_code(buffer_type.dtype, code)

    # Acquire any new buffer
    call_args = (getbuffer_cname,
                 entry.cname,
                 buffer_aux.buffer_info_var.cname,
                 get_flags(buffer_aux, buffer_type),
                 buffer_type.ndim,
                 int(buffer_type.cast))
    acquire_failed = "%s((PyObject*)%s, &%s, %s, %d, %d) == -1" % call_args
    code.putln(code.error_goto_if(acquire_failed, pos))
    # An exception raised in arg parsing cannot be caught, so no
    # need to care about the buffer then.
    put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
253 #def put_release_buffer_normal(entry, code):
254 # code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
255 # entry.cname,
256 # entry.cname,
257 # entry.buffer_aux.buffer_info_var.cname))
def get_release_buffer_code(entry):
    """
    Return the C expression (without trailing semicolon) that releases
    the buffer held in the info struct of ``entry``.
    """
    bufstruct = entry.buffer_aux.buffer_info_var.cname
    return "__Pyx_SafeReleaseBuffer(&%s)" % bufstruct
def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with
    getting the buffer auxiliary structure and variables set up correctly;
    the assignment itself and refcounting is the responsibility of the
    caller.

    However, the assignment operation may throw an exception so that the
    reassignment never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs
      buffer (which may or may not succeed).
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    bufstruct = buffer_aux.buffer_info_var.cname
    flags = get_flags(buffer_aux, buffer_type)
    getbuffer_cname = get_getbuffer_code(buffer_type.dtype, code)

    # Call template; the object to acquire from is filled in later (%%s).
    getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (
        getbuffer_cname, bufstruct, flags, buffer_type.ndim, int(buffer_type.cast))

    if not is_initialized:
        # Our entry had no previous value, so set to None when acquisition fails.
        # In this case, auxiliary vars should be set up right in initialization
        # to a zero-buffer, so it suffices to set the buf field to NULL.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        none_value = PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None")
        code.putln('%s = %s; Py_INCREF(Py_None); %s.buf = NULL;' %
                   (lhs_cname, none_value, bufstruct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
        code.putln('}')
        return

    # Release any existing buffer
    code.putln('__Pyx_SafeReleaseBuffer(&%s);' % bufstruct)
    # Acquire
    retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
    code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
    code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
    # If acquisition failed, attempt to reacquire the old buffer
    # before raising the exception. A failure of reacquisition
    # will cause the reacquisition exception to be reported, one
    # can consider working around this later.
    code.begin_block()
    exc_temps = [code.funcstate.allocate_temp(PyrexTypes.py_object_type)
                 for i in range(3)]
    exc_type, exc_value, exc_tb = exc_temps
    code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (exc_type, exc_value, exc_tb))
    code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
    code.begin_block()
    code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (exc_type, exc_value, exc_tb))
    code.globalstate.use_utility_code(raise_buffer_fallback_code)
    code.putln('__Pyx_RaiseBufferFallbackError();')
    code.putln('} else {')
    code.putln('PyErr_Restore(%s, %s, %s);' % (exc_type, exc_value, exc_tb))
    for temp in exc_temps:
        code.funcstate.release_temp(temp)
    code.end_block()
    # Unpack indices
    code.end_block()
    put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
    code.putln(code.error_goto_if_neg(retcode_cname, pos))
    code.funcstate.release_temp(retcode_cname)
def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
    """
    Generates code to process indices and calculate an offset into
    a buffer. Returns a C string which gives a pointer which can be
    read from or written to at will (it is an expression so the caller
    should store it in a temporary if it is used more than once).

    As the bounds checking can have any number of combinations of unsigned
    arguments, smart optimizations etc. we insert it directly in the
    function body. The lookup however is delegated to an inline function
    or macro that is instantiated once per ndim (lookup with suboffsets
    tends to get quite complicated).
    """
    bufaux = entry.buffer_aux
    bufstruct = bufaux.buffer_info_var.cname
    negative_indices = entry.type.negative_indices

    if options['boundscheck']:
        # Check bounds and fix negative indices.
        # We allocate a temporary which is initialized to -1, meaning OK (!).
        # If an error occurs, the temp is set to the dimension index the
        # error is occurring at.
        failed_dim = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
        code.putln("%s = -1;" % failed_dim)
        per_dim = zip(index_signeds, index_cnames, bufaux.shapevars)
        for dim, (signed, index, shape) in enumerate(per_dim):
            if signed != 0:
                # not unsigned, deal with negative index
                code.putln("if (%s < 0) {" % index)
                if negative_indices:
                    code.putln("%s += %s;" % (index, shape.cname))
                    still_negative = code.unlikely("%s < 0" % index)
                    code.putln("if (%s) %s = %d;" % (still_negative, failed_dim, dim))
                else:
                    code.putln("%s = %d;" % (failed_dim, dim))
                code.put("} else ")
            # check bounds in positive direction
            too_large = code.unlikely("%s >= %s" % (index, shape.cname))
            code.putln("if (%s) %s = %d;" % (too_large, failed_dim, dim))
        code.globalstate.use_utility_code(raise_indexerror_code)
        code.put("if (%s) " % code.unlikely("%s != -1" % failed_dim))
        code.begin_block()
        code.putln('__Pyx_RaiseBufferIndexError(%s);' % failed_dim)
        code.putln(code.error_goto(pos))
        code.end_block()
        code.funcstate.release_temp(failed_dim)
    elif negative_indices:
        # Only fix negative indices.
        for signed, index, shape in zip(index_signeds, index_cnames,
                                        bufaux.shapevars):
            if signed != 0:
                code.putln("if (%s < 0) %s += %s;" % (index, index, shape.cname))

    # Create buffer lookup and return it
    # This is done via utility macros/inline functions, which vary
    # according to the access mode used.
    nd = entry.type.ndim
    mode = entry.type.mode
    params = []
    if mode == 'full':
        for index, stride, suboffset in zip(index_cnames, bufaux.stridevars,
                                            bufaux.suboffsetvars):
            params.extend([index, stride.cname, suboffset.cname])
        funcname = "__Pyx_BufPtrFull%dd" % nd
        funcgen = buf_lookup_full_code
    else:
        strided_lookups = {
            'strided': ("__Pyx_BufPtrStrided%dd", buf_lookup_strided_code),
            'c': ("__Pyx_BufPtrCContig%dd", buf_lookup_c_code),
            'fortran': ("__Pyx_BufPtrFortranContig%dd", buf_lookup_fortran_code),
        }
        if mode in strided_lookups:
            name_template, funcgen = strided_lookups[mode]
            funcname = name_template % nd
        else:
            assert False
        for index, stride in zip(index_cnames, bufaux.stridevars):
            params.extend([index, stride.cname])

    # Make sure the utility code is available
    code.globalstate.use_code_from(funcgen, name=funcname, nd=nd)

    ptr_type = entry.type.buffer_ptr_type
    return "%s(%s, %s.buf, %s)" % (funcname,
                                   ptr_type.declaration_code(""),
                                   bufstruct,
                                   ", ".join(params))
def use_empty_bufstruct_code(env, max_ndim):
    """
    Register the utility code defining the __Pyx_zeros and
    __Pyx_minusones arrays.

    env      -- scope to register the utility code on
    max_ndim -- largest number of dimensions of any buffer in the module;
                determines the length of both arrays
    """
    # An empty initializer list ("{}") is not valid ISO C before C23, so
    # always emit at least one element, even when max_ndim is 0.
    length = max(max_ndim, 1)
    zeros = ", ".join(["0"] * length)
    minusones = ", ".join(["-1"] * length)
    code = ("\n"
            "Py_ssize_t __Pyx_zeros[] = {%s};\n"
            "Py_ssize_t __Pyx_minusones[] = {%s};\n") % (zeros, minusones)
    env.use_utility_code([code, ""], "empty_bufstruct_code")
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.

    A macro ``name`` casting the result to the wanted type and the
    prototype of ``name``_imp are written to proto; the implementation
    of ``name``_imp is written to defin.
    """
    dims = range(nd)
    # _i_ndex, _s_tride, sub_o_ffset
    macroargs = ", ".join(["i%d, s%d, o%d" % (i, i, i) for i in dims])
    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i)
                          for i in dims])

    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))
    proto.putln("static INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))

    header = dedent("""
        static INLINE void* %s_imp(void* buf, %s) {
          char* ptr = (char*)buf;
        """) % (name, funcargs)
    # One step per dimension: advance by the stride and, if the
    # dimension has a suboffset, follow the pointer stored there.
    steps = [dedent("""\
          ptr += s%d * i%d;
          if (o%d >= 0) ptr = *((char**)ptr) + o%d;
        """) % (i, i, i, i) for i in dims]
    defin.putln(header + "".join(steps) + "\nreturn ptr;\n}")
def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Generates a buffer lookup macro for the right number of dimensions.
    The macro yields a pointer, cast to the given type, at the byte
    offset sum(index * stride) from the buffer start.

    Only proto is written to; defin is unused.
    """
    # _i_ndex, _s_tride
    arg_parts = []
    offset_terms = []
    for i in range(nd):
        arg_parts.append("i%d, s%d" % (i, i))
        offset_terms.append("i%d * s%d" % (i, i))
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)" %
                (name, ", ".join(arg_parts), " + ".join(offset_terms)))
def buf_lookup_c_code(proto, defin, name, nd):
    """
    Similar to strided lookup, but the last dimension is indexed by
    plain pointer arithmetic on the item type instead of being
    multiplied by its stride.
    Still we keep the same signature for now.

    Only proto is written to; defin is unused.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    last = nd - 1
    args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
    # Byte offset contributed by all dimensions but the last one.
    offset = " + ".join(["i%d * s%d" % (i, i) for i in range(last)])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (name, args, offset, last))
def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Like C lookup, but the first index is optimized instead: it is
    applied by plain pointer arithmetic on the item type while all later
    dimensions are multiplied by their strides.

    Only proto is written to; defin is unused.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
    # Byte offset contributed by all dimensions but the first one.
    offset = " + ".join(["i%d * s%d" % (i, i) for i in range(1, nd)])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (name, args, offset, 0))
491 #
492 # Utils for creating type string checkers
493 #
def mangle_dtype_name(dtype):
    """
    Return an identifier-friendly name for ``dtype``, used to build the
    names of per-dtype utility functions.

    Python objects map to "object" and pointers to "ptr". Other types
    use their C declaration with spaces replaced by underscores,
    prefixed with "nn_" when the type has no typestring.
    """
    # Use prefixes to separate user defined types from builtins
    # (consider "typedef float unsigned_int")
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    prefix = ""
    if dtype.typestring is None:
        prefix = "nn_"
    return prefix + dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(dtype, writer):
    """
    Make sure a C function checking that the next format string item
    matches ``dtype`` is registered as utility code, and return the name
    of that function.

    The generated function takes the format string position ``ts`` and
    returns the position after the consumed item, or NULL with a
    ValueError set on mismatch.
    """
    # See if we can consume one (unnamed) dtype as next item
    # Put native and custom types in separate namespaces (as one could create a type named unsigned_int...)
    name = "__Pyx_CheckTypestringItem_%s" % mangle_dtype_name(dtype)
    if writer.globalstate.has_code(name):
        return name

    typechar = dtype.typestring
    if typechar is not None:
        assert len(typechar) == 1
        # Can use direct comparison
        code = dedent("""\
            if (*ts != '%s') {
              PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts);
              return NULL;
            } else return ts + 1;
        """, 2) % (typechar, typechar)
    else:
        # Cannot trust declared size; but rely on int vs float and
        # signed/unsigned to be correctly declared
        ctype = dtype.declaration_code("")
        code = dedent("""\
            int ok;
            switch (*ts) {""", 2)
        if dtype.is_int:
            candidates = [
                ('b', 'char'), ('h', 'short'), ('i', 'int'),
                ('l', 'long'), ('q', 'long long')
            ]
        elif dtype.is_float:
            candidates = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
        else:
            assert False
        cases = []
        if dtype.signed == 0:
            # Unsigned: upper-case format characters, (type)-1 is positive.
            for fmtchar, against in candidates:
                cases.append("\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
                             (fmtchar.upper(), ctype, against, ctype))
        else:
            for fmtchar, against in candidates:
                cases.append("\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
                             (fmtchar, ctype, against, ctype))
        code += "".join(cases)
        code += dedent("""\
            default: ok = 0;
            }
            if (!ok) {
              PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%s')", ts);
              return NULL;
            } else return ts + 1;
        """, 2)

    proto = dedent("""\
        static const char* %s(const char* ts); /*proto*/
    """) % name
    impl = dedent("""
        static const char* %s(const char* ts) {
        %s
        }
    """) % (name, code)
    writer.globalstate.use_utility_code([proto, impl], name=name)

    return name
def get_typestringchecker(code, dtype):
    """
    Returns the name of a typestring checker for the given type; the
    checker is generated by create_typestringchecker and emitted to code
    if needed.
    """
    checker_name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
    globalstate = code.globalstate
    globalstate.use_code_from(create_typestringchecker,
                              checker_name,
                              dtype=dtype)
    return checker_name
def create_typestringchecker(protocode, defcode, name, dtype):
    """
    Write a C function ``name`` that checks that a buffer format string
    describes ``dtype``.

    protocode -- writer receiving the prototype
    defcode   -- writer receiving the function definition
    name      -- name of the C function to generate
    dtype     -- the type the format string is validated against

    The generated function takes the current format string position
    ``ts`` and returns the position after the consumed description, or
    NULL after setting a ValueError. Nothing is written when dtype is an
    error type.
    """

    def put_assert(cond, msg):
        # Emit a C check that raises ValueError(msg + ", got '<rest of
        # format string>'") and returns NULL when cond does not hold.
        defcode.putln("if (!(%s)) {" % cond)
        msg += ", got '%s'"
        defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % msg)
        defcode.putln("return NULL;")
        defcode.putln("}")

    if dtype.is_error: return
    simple = dtype.is_simple_buffer_dtype()
    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
    # Cannot add utility code recursively...
    # ...so the checkers of the field types are requested here, before
    # anything is written for this function.
    if not simple:
        dtype_t = dtype.declaration_code("")
        protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
        fields = dtype.scope.var_entries

        # divide fields into blocks of equal type (for repeat count)
        field_blocks = [] # of (n, type, checkerfunc)
        n = 0
        prevtype = None
        for f in fields:
            if n and f.type != prevtype:
                field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype)))
                n = 0
            prevtype = f.type
            n += 1
        # Close the last block. NOTE(review): relies on the loop having
        # run at least once (f is unbound for a type without fields).
        field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type)))

    protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
    defcode.putln("static const char* %s(const char* ts) {" % name)
    if simple:
        # Single item: skip whitespace and an optional repeat count of
        # '1', then validate one format character.
        defcode.putln("int ok;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        if dtype.typestring is not None:
            assert len(dtype.typestring) == 1
            # Can use direct comparison
            defcode.putln("ok = (*ts == '%s');" % dtype.typestring)
        else:
            # Cannot trust declared size; but rely on int vs float and
            # signed/unsigned to be correctly declared. Use a switch statement
            # on all possible format codes to validate that the size is ok.
            # (Note that many codes may map to same size, e.g. 'i' and 'l'
            # may both be four bytes).
            ctype = dtype.declaration_code("")
            defcode.putln("switch (*ts) {")
            if dtype.is_int:
                types = [
                    ('b', 'char'), ('h', 'short'), ('i', 'int'),
                    ('l', 'long'), ('q', 'long long')
                ]
            elif dtype.is_float:
                types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
            else:
                assert False
            if dtype.signed == 0:
                # Unsigned types are matched against the upper-case
                # format characters.
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
                                  (char.upper(), ctype, against, ctype))
            else:
                for char, against in types:
                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
                                  (char, ctype, against, ctype))
            defcode.putln("default: ok = 0;")
            defcode.putln("}")
        defcode.putln("if (!ok) {")
        if dtype.typestring is not None:
            errmsg = "Buffer datatype mismatch (expected '%s', got '%%s')" % dtype.typestring
        else:
            errmsg = "Buffer datatype mismatch (rejecting on '%s')"
        defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % errmsg)
        defcode.putln("return NULL;");
        defcode.putln("}")
        defcode.putln("++ts;")
    elif complex_possible:
        # Could be a struct representing a complex number, so allow
        # for parsing a "Zf" spec.
        real_t, imag_t = [x.type for x in fields]
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("if (*ts == '1') ++ts;")
        defcode.putln("if (*ts == 'Z') {")
        if len(field_blocks) == 2:
            # Different float type, sizeof check needed
            defcode.putln("if (sizeof(%s) != sizeof(%s)) {" % (
                real_t.declaration_code(""),
                imag_t.declaration_code("")))
            defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
                dtype.declaration_code("", for_display=True),
                real_t.declaration_code("", for_display=True),
                imag_t.declaration_code("", for_display=True)))
            defcode.putln("return NULL;")
            defcode.putln("}")
            check_real, check_imag = [x[2] for x in field_blocks]
        else:
            assert len(field_blocks) == 1
            check_real = check_imag = field_blocks[0][2]
        # After 'Z' a single item (checked with the real part's checker)
        # is consumed; otherwise the two parts are parsed one by one.
        defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % check_real)
        defcode.putln("} else {")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_real)
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_imag)
        defcode.putln("}")
    else:
        # General struct: check the blocks of equally typed fields in
        # order.
        defcode.putln("int n, count;")
        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

        for n, type, checker in field_blocks:
            if n == 1:
                defcode.putln("if (*ts == '1') ++ts;")
            else:
                # Several fields of the same type: the generated loop
                # parses repeat counts until n items are accounted for.
                defcode.putln("n = %d;" % n);
                defcode.putln("do {")
                defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")

            simple = type.is_simple_buffer_dtype()
            if not simple:
                # Non-simple field types must be wrapped in "T{" ... "}".
                put_assert("*ts == 'T' && *(ts+1) == '{'", "Expected start of %s" % type.declaration_code("", for_display=True))
                defcode.putln("ts += 2;")
            defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
            if not simple:
                put_assert("*ts == '}'", "Expected end of '%s'" % type.declaration_code("", for_display=True))
                defcode.putln("++ts;")

            if n > 1:
                defcode.putln("} while (n > 0);");
            defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")

    defcode.putln("return ts;")
    defcode.putln("}")
def get_getbuffer_code(dtype, code):
    """
    Generate a utility function for getting a buffer for the given dtype.
    The function will:
    - Call PyObject_GetBuffer
    - Check that ndim matched the expected value
    - Check that the format string is right (or, when casting, that the
      item size matches)
    - Set suboffsets to all -1 if it is returned as NULL.
    - Release the acquired buffer again if one of the checks fails

    On failure the generated function leaves a zeroed buffer struct
    behind and returns -1 with an exception set.

    Returns the name of the generated C function; the utility code is
    only registered the first time it is requested for a dtype.
    """

    name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
    if not code.globalstate.has_code(name):
        code.globalstate.use_utility_code(acquire_utility_code)
        substitutions = {
            "name": name,
            "typestringchecker": get_typestringchecker(code, dtype),
            "dtype_name": str(dtype),
            "dtype_cname": dtype.declaration_code(""),
        }
        # Validation failures jump to release_and_fail: the buffer has
        # been acquired successfully at that point and must be released,
        # otherwise the export is leaked. The pending exception is saved
        # around the release so that release code does not run with an
        # error set. A failed acquisition jumps directly to fail.
        utilcode = [dedent("""
        static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
        """) % name, dedent("""
        static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
          const char* ts;
          if (obj == Py_None) {
            __Pyx_ZeroBuffer(buf);
            return 0;
          }
          buf->buf = NULL;
          if (__Pyx_GetBuffer(obj, buf, flags) == -1) goto fail;
          if (buf->ndim != nd) {
            __Pyx_BufferNdimError(buf, nd);
            goto release_and_fail;
          }
          if (!cast) {
            ts = buf->format;
            ts = __Pyx_ConsumeWhitespace(ts);
            if (!ts) goto release_and_fail;
            ts = %(typestringchecker)s(ts);
            if (!ts) goto release_and_fail;
            ts = __Pyx_ConsumeWhitespace(ts);
            if (!ts) goto release_and_fail;
            if (*ts != 0) {
              PyErr_Format(PyExc_ValueError,
                "Buffer format string specifies more data than '%(dtype_name)s' can hold (expected end, got '%%s')", ts);
              goto release_and_fail;
            }
          } else {
            if (buf->itemsize != sizeof(%(dtype_cname)s)) {
              PyErr_SetString(PyExc_ValueError,
                "Attempted cast of buffer to datatype of different size.");
              goto release_and_fail;
            }
          }
          if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
          return 0;
        release_and_fail:;
          {
            PyObject *exc_type, *exc_value, *exc_tb;
            PyErr_Fetch(&exc_type, &exc_value, &exc_tb);
            __Pyx_ReleaseBuffer(buf);
            PyErr_Restore(exc_type, exc_value, exc_tb);
          }
        fail:;
          __Pyx_ZeroBuffer(buf);
          return -1;
        }""") % substitutions]
        code.globalstate.use_utility_code(utilcode, name)
    return name
def buffer_type_checker(dtype, code):
    """
    Creates a type checker function for the given type and returns its
    name. Only int and float dtypes (including simple typedef-ed types)
    are supported; struct/union and any other dtype trip an assertion.
    """
    supported = (not dtype.is_struct_or_union) and (dtype.is_int or dtype.is_float)
    if not supported:
        assert False
    return get_getbuffer_code(dtype, code)
def use_py2_buffer_functions(env):
    """
    Register the utility code emulating PyObject_GetBuffer and
    PyBuffer_Release for Python 2.

    For >= 2.6 we do double mode -- use the new buffer interface on
    objects which have the right tp_flags set, but emulation otherwise.
    The emulation dispatches to the __getbuffer__/__releasebuffer__
    methods of the extension types found in env and, recursively, in its
    cimported modules. On Python 3 the names are simply defined to the
    C API functions.
    """
    codename = "PyObject_GetBuffer" # just a representative unique key

    # Search all types for __getbuffer__ overloads
    buffer_types = []   # of (typeptr_cname, getbuffer cname, releasebuffer cname or None)
    visited = set()     # ids of scopes already searched

    def find_buffer_types(scope):
        # A module can be reachable through several cimport chains;
        # search every scope only once (this also guards against cycles
        # and against emitting the same type check twice).
        if id(scope) in visited:
            return
        visited.add(id(scope))
        for m in scope.cimported_modules:
            find_buffer_types(m)
        for e in scope.type_entries:
            t = e.type
            if t.is_extension_type:
                release = get = None
                for x in t.scope.pyfunc_entries:
                    if x.name == u"__getbuffer__": get = x.func_cname
                    elif x.name == u"__releasebuffer__": release = x.func_cname
                if get:
                    buffer_types.append((t.typeptr_cname, get, release))

    find_buffer_types(env)

    code = dedent("""
        #if PY_MAJOR_VERSION < 3
        static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
          #if PY_VERSION_HEX >= 0x02060000
          if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER)
              return PyObject_GetBuffer(obj, view, flags);
          #endif
    """)
    if len(buffer_types) > 0:
        clause = "if"
        for t, get, release in buffer_types:
            code += " %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
            clause = "else if"
        code += " else {\n"
    # "%.100s" limits the length of the type name in the message (a bare
    # "%100s" does not truncate).
    code += dedent("""\
        PyErr_Format(PyExc_TypeError, "'%.100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
        return -1;
    """, 2)
    if len(buffer_types) > 0: code += " }"
    # Release must mirror acquisition: a buffer obtained through
    # PyObject_GetBuffer has to be handed back with PyBuffer_Release so
    # that the exporter's bf_releasebuffer runs. The same tp_flags test
    # as in __Pyx_GetBuffer is used to select that path.
    code += dedent("""
        }

        static void __Pyx_ReleaseBuffer(Py_buffer *view) {
          PyObject* obj = view->obj;
          if (obj) {
            #if PY_VERSION_HEX >= 0x02060000
            if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER) {
              PyBuffer_Release(view);
              return;
            }
            #endif
    """)
    if len(buffer_types) > 0:
        clause = "if"
        for t, get, release in buffer_types:
            if release:
                code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release)
                clause = "else if"
    code += dedent("""
            Py_DECREF(obj);
            view->obj = NULL;
          }
        }

        #endif
    """)

    env.use_utility_code([dedent("""\
        #if PY_MAJOR_VERSION < 3
        static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
        static void __Pyx_ReleaseBuffer(Py_buffer *view);
        #else
        #define __Pyx_GetBuffer PyObject_GetBuffer
        #define __Pyx_ReleaseBuffer PyBuffer_Release
        #endif
    """), code], codename)
853 #
854 # Static utility code
855 #
# Utility function to set the right exception (an IndexError naming the
# axis that was out of bounds).
# The caller should immediately goto_error
# The list holds [prototype, implementation].
raise_indexerror_code = [
"""\
static void __Pyx_RaiseBufferIndexError(int axis); /*proto*/
""","""\
static void __Pyx_RaiseBufferIndexError(int axis) {
  PyErr_Format(PyExc_IndexError,
     "Out of bounds on buffer access (axis %d)", axis);
}

"""]
#
# Buffer type checking. Utility code for checking that acquired
# buffers match our assumptions. We only need to check ndim and
# the format string; the access mode/flags is checked by the
# exporter.
#
# The list holds [prototypes, implementations] of:
# - __Pyx_SafeReleaseBuffer: releases a buffer unless its buf field is
#   NULL, restoring a NULL suboffsets pointer first if it had been
#   replaced by __Pyx_minusones
# - __Pyx_ZeroBuffer: resets a Py_buffer to the "no buffer" state
# - __Pyx_ConsumeWhitespace: skips whitespace and '@' in a format
#   string; fails on explicit byte order/size/alignment characters
# - __Pyx_BufferNdimError: sets the ValueError for an ndim mismatch
#
acquire_utility_code = ["""\
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
""", """
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
  if (info->buf == NULL) return;
  if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
  __Pyx_ReleaseBuffer(info);
}

static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
  buf->buf = NULL;
  buf->obj = NULL;
  buf->strides = __Pyx_zeros;
  buf->shape = __Pyx_zeros;
  buf->suboffsets = __Pyx_minusones;
}

static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts) {
  while (1) {
    switch (*ts) {
      case '@':
      case 10:
      case 13:
      case ' ':
        ++ts;
        break;
      case '=':
      case '<':
      case '>':
      case '!':
        PyErr_SetString(PyExc_ValueError, "Buffer acquisition error: Only native byte order, size and alignment supported.");
        return NULL;
      default:
        return ts;
    }
  }
}

static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
  PyErr_Format(PyExc_ValueError,
               "Buffer has wrong number of dimensions (expected %d, got %d)",
               expected_ndim, buffer->ndim);
}

"""]
# Utility code ([prototype, implementation]) for parsing the optional
# decimal repeat count in front of a format string item. The parsed
# count (1 if no digits are present) is stored in *out_count and the
# position after the digits is returned.
# Every digit '0'..'9' must be accepted while accumulating the count;
# the upper bound of the loop is therefore inclusive.
parse_typestring_repeat_code = ["""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
""","""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
    int count;
    if (*ts < '0' || *ts > '9') {
        count = 1;
    } else {
        count = *ts++ - '0';
        while (*ts >= '0' && *ts <= '9') {
            count *= 10;
            count += *ts++ - '0';
        }
    }
    *out_count = count;
    return ts;
}
"""]
# Utility code ([prototype, implementation]) raising the ValueError used
# by put_assign_to_buffer when, after a failed acquisition on
# assignment, reacquiring the previously held buffer fails as well.
raise_buffer_fallback_code = ["""
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""","""
static void __Pyx_RaiseBufferFallbackError(void) {
  PyErr_Format(PyExc_ValueError,
     "Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
}

"""]