/*- * Copyright (c) 2015--2018 Taylor R. Campbell * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF * SUCH DAMAGE. */ #if defined(__NetBSD__) && defined(_KERNEL) #include #include #include #else #include #include #include #include #include #include #endif #include "pb.h" #include "pb_decode.h" #define PB_ABI_VERSION 0x00000000 struct decode { pb_decoder_callback_t *callback; void *arg; }; static int pb_decode_buf_partial(struct decode *, void *, size_t *); static int pb_decode_buf(struct decode *, void *, size_t); static int pb_skip(struct decode *, size_t); static int pb_decode_1_eof(struct decode *, uint8_t *, bool *); static int pb_decode_1(struct decode *, uint8_t *); static int pb_decode_by_hdr(struct decode *, struct pb_msg_hdr *); static const struct pb_field * pb_find_field(const struct pb_msgdesc *, uint32_t); static int pb_decode_check_required(const struct pb_msgdesc *, uint32_t[PB_MAX_REQUIRED_FIELDS], unsigned int); static int pb_skip_field(struct decode *, enum pb_wiretype); static int pb_skip_varint(struct decode *); static int pb_skip_length_delimited(struct decode *); static int pb_decode_field(struct decode *, unsigned char *, const struct pb_field *, enum pb_wiretype, uint32_t[PB_MAX_REQUIRED_FIELDS], unsigned int *); static int pb_decode_field_value(struct decode *, const struct pb_field *, enum pb_wiretype, unsigned char *); static const struct pb_enumerand * pb_enumerand_by_number(const struct pb_enumeration *, int32_t); static int pb_decompose_tag(uint64_t, uint32_t *, enum pb_wiretype *); static int pb_decode_tag(struct decode *, uint32_t *, enum pb_wiretype *); static int pb_decode_varint_eof(struct decode *, uint64_t *, bool *); static int pb_decode_varint(struct decode *, uint64_t *); static int pb_decode_varint_u(struct decode *, uint64_t *); static int pb_decode_fixed32(struct decode *, uint32_t *); static int pb_decode_fixed64(struct decode *, uint64_t *); static int pb_decode_varint_s(struct decode *, int64_t *); static int pb_decode_zigzag(struct decode *, int64_t *); static int pb_decode_sfixed32(struct decode *, int32_t *); static int pb_decode_sfixed64(struct decode *, int64_t *); static int pb_decode_ieee32(struct decode *, float *); static int pb_decode_ieee64(struct decode *, double *); static int pb_decode_length(struct decode *, size_t *); static int pb_decode_submsg(struct decode *, const struct pb_msgdesc *, struct pb_msg_hdr *, size_t); static void sort32(uint32_t *, size_t); struct decode_memory { const uint8_t *ptr; size_t nleft; }; static pb_decoder_callback_t decode_memory_callback; static int decode_memory_callback(void *cookie, void *buf, size_t *size) { struct decode_memory *const M = cookie; size_t n = *size < M->nleft? *size : M->nleft; (void)memcpy(buf, M->ptr, n); M->ptr += n; M->nleft -= n; *size = n; return 0; } int pb_decode_from_memory(struct pb_msg msg, const void *buf, size_t len) { struct decode_memory M = { .ptr = buf, .nleft = len }; return pb_decode(msg, &decode_memory_callback, &M); } int pb_decode(struct pb_msg msg, pb_decoder_callback_t *callback, void *arg) { struct pb_msg_hdr *const msg_hdr = (struct pb_msg_hdr *)msg.pbm_ptr; struct decode D = { .callback = callback, .arg = arg, }; if (!pb_abi_compatible(msg.pbm_msgdesc->pbmd_abi_version, PB_ABI_VERSION)) /* XXX pb_bug */ return EINVAL; if (msg_hdr->pbmh_msgdesc != msg.pbm_msgdesc) /* XXX pb_bug */ return EINVAL; return pb_decode_by_hdr(&D, msg_hdr); } /* * XXX It would be nice if we could have a pb_decode_ptr(D, &p, size) * which would store in p a pointer to a buffer of the requested size, * rather than copying it to another buffer, if possible. */ static int pb_decode_buf_partial(struct decode *D, void *buf, size_t *size) { return (*D->callback)(D->arg, buf, size); } static int pb_decode_buf(struct decode *D, void *buf, size_t size) { size_t rsize = size; int error; error = pb_decode_buf_partial(D, buf, &rsize); if (error) return error; if (rsize != size) return EIO; /* XXX What error code? */ return 0; } static int pb_skip(struct decode *D, size_t size) { return pb_decode_buf(D, NULL, size); } static int pb_decode_1_eof(struct decode *D, uint8_t *p, bool *eofp) { size_t n = 1; int error; error = pb_decode_buf_partial(D, p, &n); if (error) return error; pb_assert(n <= 1); *eofp = (n == 0); return 0; } static int pb_decode_1(struct decode *D, uint8_t *p) { return pb_decode_buf(D, p, 1); } static int pb_decode_by_hdr(struct decode *D, struct pb_msg_hdr *msg_hdr) { unsigned char *const addr = (void *)msg_hdr; const struct pb_msgdesc *const msgdesc = msg_hdr->pbmh_msgdesc; uint32_t tag; enum pb_wiretype wiretype; const struct pb_field *field; uint32_t req_fields[PB_MAX_REQUIRED_FIELDS]; unsigned int nreq_fields = 0; int error; while ((error = pb_decode_tag(D, &tag, &wiretype)) == 0) { /* Allow for zero-delimited or externally framed messages. */ if (tag == 0) break; field = pb_find_field(msgdesc, tag); if (field == NULL) error = pb_skip_field(D, wiretype); else error = pb_decode_field(D, addr, field, wiretype, req_fields, &nreq_fields); if (error) break; } if (error) return error; error = pb_decode_check_required(msgdesc, req_fields, nreq_fields); if (error) return error; return 0; } static const struct pb_field * pb_find_field(const struct pb_msgdesc *msgdesc, uint32_t tag) { size_t start = 0, end = msgdesc->pbmd_nfields; while (start < end) { const size_t i = (start + ((end - start) / 2)); if (tag < msgdesc->pbmd_fields[i].pbf_tag) end = i; else if (tag > msgdesc->pbmd_fields[i].pbf_tag) start = (i + 1); else return &msgdesc->pbmd_fields[i]; } return NULL; } static int pb_decode_check_required(const struct pb_msgdesc *msgdesc, uint32_t req_fields[PB_MAX_REQUIRED_FIELDS], unsigned int nreq_fields) { unsigned int i, j; sort32(req_fields, nreq_fields); for (i = 0, j = 0; i < nreq_fields; i++, j++) { while (pb_assert(j < msgdesc->pbmd_nfields), msgdesc->pbmd_fields[j].pbf_quant != PBQ_REQUIRED) j++; if (req_fields[i] != msgdesc->pbmd_fields[j].pbf_tag) return EIO; /* XXX What error code? */ } return 0; } static int pb_skip_field(struct decode *D, enum pb_wiretype wiretype) { switch (wiretype) { case PB_WIRETYPE_VARINT: return pb_skip_varint(D); case PB_WIRETYPE_32BIT: return pb_skip(D, 4); case PB_WIRETYPE_64BIT: return pb_skip(D, 8); case PB_WIRETYPE_LENGTH_DELIMITED: return pb_skip_length_delimited(D); default: return EIO; /* XXX What error code? */ } } static int pb_skip_varint(struct decode *D) { uint8_t o; int error; do { error = pb_decode_1(D, &o); if (error) return error; } while ((o & 0x80) != 0); return 0; } static int pb_skip_length_delimited(struct decode *D) { size_t n; int error; error = pb_decode_length(D, &n); if (error) return error; return pb_skip(D, n); } static int pb_decode_field(struct decode *D, unsigned char *addr, const struct pb_field *field, enum pb_wiretype wiretype, uint32_t req_fields[PB_MAX_REQUIRED_FIELDS], unsigned int *nreq_fields) { int error; switch (field->pbf_quant) { case PBQ_REQUIRED: error = pb_decode_field_value(D, field, wiretype, (addr + field->pbf_qu.required.offset)); if (error) return error; pb_assert(*nreq_fields < (PB_MAX_REQUIRED_FIELDS + 1)); req_fields[*nreq_fields++] = field->pbf_tag; return 0; case PBQ_OPTIONAL: *(bool *)(addr + field->pbf_qu.optional.present_offset) = true; return pb_decode_field_value(D, field, wiretype, (addr + field->pbf_qu.optional.value_offset)); case PBQ_REPEATED: { struct pb_repeated *const repeated = (struct pb_repeated *)(addr + field->pbf_qu.repeated.hdr_offset); unsigned char *ptr; const size_t elemsize = pb_type_size(&field->pbf_type); size_t i; if ((0 < field->pbf_qu.repeated.maximum) && (field->pbf_qu.repeated.maximum <= pb_repeated_count(repeated))) return pb_skip_field(D, wiretype); error = pb_repeated_add(repeated, &i); if (error) return error; ptr = *(void *const *)(addr + field->pbf_qu.repeated.ptr_offset); return pb_decode_field_value(D, field, wiretype, (ptr + (i * elemsize))); } default: return EIO; /* XXX */ } } static int pb_decode_field_value(struct decode *D, const struct pb_field *field, enum pb_wiretype wiretype, unsigned char *value) { int error; switch (field->pbf_type.pbt_type) { uint32_t u32; uint64_t u64; int32_t s32; int64_t s64; float f; double d; #define DECODE_CHECK(TYPE, WIRETYPE, DECODER, VAR, CHECK) do \ { \ if (wiretype != (WIRETYPE)) \ return EIO; /* XXX What error code? */ \ error = DECODER(D, &(VAR)); \ if (error) \ return error; \ CHECK; \ *(TYPE *)value = (VAR); \ return 0; \ } while (0) #define DECODE(TYPE, WIRETYPE, DECODER, VAR) \ DECODE_CHECK(TYPE, WIRETYPE, DECODER, VAR, do {} while (0)) #define DECODE_MAX(TYPE, WIRETYPE, DECODER, VAR, MAXIMUM) \ DECODE_CHECK(TYPE, WIRETYPE, DECODER, VAR, do { \ if ((MAXIMUM) < (VAR)) \ return ERANGE; /* XXX What error code? */ \ } while (0)) #define DECODE_MINMAX(TYPE, WIRETYPE, DECODER, VAR, MINIMUM, MAXIMUM) \ DECODE_CHECK(TYPE, WIRETYPE, DECODER, VAR, do { \ if ((MAXIMUM) < (VAR)) \ return ERANGE; /* XXX What error code? */ \ } while (0)) case PB_TYPE_BOOL: DECODE_MAX(bool, PB_WIRETYPE_VARINT, pb_decode_varint_u, u64, 1); case PB_TYPE_UINT32: DECODE_MAX(uint32_t, PB_WIRETYPE_VARINT, pb_decode_varint_u, u64, UINT32_MAX); case PB_TYPE_UINT64: DECODE(uint64_t, PB_WIRETYPE_VARINT, pb_decode_varint_u, u64); case PB_TYPE_FIXED32: DECODE(uint32_t, PB_WIRETYPE_32BIT, pb_decode_fixed32, u32); case PB_TYPE_FIXED64: DECODE(uint64_t, PB_WIRETYPE_64BIT, pb_decode_fixed64, u64); case PB_TYPE_INT32: DECODE_MINMAX(int32_t, PB_WIRETYPE_VARINT, pb_decode_varint_s, s64, INT32_MIN, INT32_MAX); case PB_TYPE_INT64: DECODE(int64_t, PB_WIRETYPE_VARINT, pb_decode_varint_s, s64); case PB_TYPE_SINT32: DECODE_MINMAX(int32_t, PB_WIRETYPE_VARINT, pb_decode_zigzag, s64, INT32_MIN, INT32_MAX); case PB_TYPE_SINT64: DECODE(int64_t, PB_WIRETYPE_VARINT, pb_decode_zigzag, s64); case PB_TYPE_SFIXED32: DECODE(int32_t, PB_WIRETYPE_32BIT, pb_decode_sfixed32, s32); case PB_TYPE_SFIXED64: DECODE(int64_t, PB_WIRETYPE_64BIT, pb_decode_sfixed64, s64); case PB_TYPE_ENUM: DECODE_CHECK(int32_t, PB_WIRETYPE_VARINT, pb_decode_varint_s, s64, do { const struct pb_type *const type = &field->pbf_type; const struct pb_enumeration *const enumeration = type->pbt_u.enumerated.enumeration; if ((s64 < INT32_MIN) || (INT32_MAX < s64)) return ERANGE; /* XXX What error code? */ if (pb_enumerand_by_number(enumeration, s64) == NULL) return EIO; /* XXX What error code? */ } while (0)); case PB_TYPE_FLOAT: DECODE(float, PB_WIRETYPE_32BIT, pb_decode_ieee32, f); case PB_TYPE_DOUBLE: DECODE(double, PB_WIRETYPE_64BIT, pb_decode_ieee64, d); case PB_TYPE_BYTES: { struct pb_bytes *const bytes = (struct pb_bytes *)value; size_t size, tsize pb_attr_diagused; uint8_t *ptr; if (wiretype != PB_WIRETYPE_LENGTH_DELIMITED) return EIO; /* XXX What error code? */ error = pb_decode_length(D, &size); if (error) return error; error = pb_bytes_alloc(bytes, size); if (error) return error; ptr = pb_bytes_ptr_mutable(bytes, &tsize); pb_assert(tsize == size); return pb_decode_buf(D, ptr, size); } case PB_TYPE_STRING: { struct pb_string *const string = (struct pb_string *)value; size_t len; char *ptr; if (wiretype != PB_WIRETYPE_LENGTH_DELIMITED) return EIO; /* XXX What error code? */ error = pb_decode_length(D, &len); if (error) return error; error = pb_string_alloc(string, len); if (error) return error; pb_assert(pb_string_len(string) == len); ptr = pb_string_ptr_mutable(string); pb_assert(ptr[len] == '\0'); error = pb_decode_buf(D, ptr, len); if (error) return error; error = pb_utf8_validate(ptr, len); if (error) { (void)memset(ptr, 0, len); /* paranoia */ pb_string_set_ptr(string, "", 0); return error; } return 0; } case PB_TYPE_MSG: { struct pb_msg_hdr *msg_hdr = (struct pb_msg_hdr *)value; size_t size; if (wiretype != PB_WIRETYPE_LENGTH_DELIMITED) return EIO; /* XXX What error code? */ error = pb_decode_length(D, &size); if (error) return error; return pb_decode_submsg(D, field->pbf_type.pbt_u.msg.msgdesc, msg_hdr, size); } default: return EIO; /* XXX What error code? */ } } static const struct pb_enumerand * pb_enumerand_by_number(const struct pb_enumeration *enumeration, int32_t number) { size_t start = 0, end = enumeration->pben_nenumerands; while (start < end) { const size_t i = (start + ((end - start) / 2)); if (number < enumeration->pben_enumerands[i].pbed_number) end = i; else if (number > enumeration->pben_enumerands[i].pbed_number) start = (i + 1); else return &enumeration->pben_enumerands[i]; } return NULL; } /* * Tags and varints */ static int pb_decompose_tag(uint64_t wtag, uint32_t *tag, enum pb_wiretype *wiretype) { if ((wtag >> 3) & ~(uint64_t)0xffffffff) return EIO; /* XXX What error code? */ *tag = ((wtag >> 3) & 0xffffffff); *wiretype = (wtag & 7); return 0; } static int pb_decode_tag(struct decode *D, uint32_t *tag, enum pb_wiretype *wiretype) { uint64_t wtag; bool eof; int error; error = pb_decode_varint_eof(D, &wtag, &eof); if (error) return error; if (eof) { *tag = 0; return 0; } else { return pb_decompose_tag(wtag, tag, wiretype); } } static int pb_decode_varint_eof(struct decode *D, uint64_t *value, bool *eofp) { uint8_t o; uint64_t v; unsigned int s = 0; int error; error = pb_decode_1_eof(D, &o, eofp); if (error) return error; if (*eofp) return 0; if ((o & 0x80) == 0) { *value = o; return 0; } v = (o & 0x7f); do { s += 7; if (s >= 32) return ERANGE; /* XXX What error code? */ error = pb_decode_1(D, &o); if (error) return error; v |= (uint64_t)(o & 0x7f) << s; } while ((o & 0x80) != 0); *value = v; return 0; } static int pb_decode_varint(struct decode *D, uint64_t *value) { bool eof; int error; error = pb_decode_varint_eof(D, value, &eof); if (error) return error; if (eof) return EIO; /* XXX What error code? */ return 0; } /* * Unsigned integer formats */ static int pb_decode_varint_u(struct decode *D, uint64_t *p) { return pb_decode_varint(D, p); } static int pb_decode_fixed32(struct decode *D, uint32_t *p) { uint8_t buf[4]; int error; error = pb_decode_buf(D, buf, sizeof buf); if (error) return error; *p = buf[0] | ((uint32_t)buf[1] << 8) | ((uint32_t)buf[2] << 16) | ((uint32_t)buf[3] << 24); return 0; } static int pb_decode_fixed64(struct decode *D, uint64_t *p) { uint8_t buf[8]; int error; error = pb_decode_buf(D, buf, sizeof buf); if (error) return error; *p = buf[0] | ((uint64_t)buf[1] << 8) | ((uint64_t)buf[2] << 16) | ((uint64_t)buf[3] << 24) | ((uint64_t)buf[4] << 32) | ((uint64_t)buf[5] << 40) | ((uint64_t)buf[6] << 48) | ((uint64_t)buf[7] << 56); return 0; } /* * Signed integer formats * * XXX These assume two's-complement arithmetic. */ static int pb_decode_varint_s(struct decode *D, int64_t *p) { uint64_t u; int error; error = pb_decode_varint_u(D, &u); if (error) return error; *p = (int64_t)u; return 0; } static int pb_decode_zigzag(struct decode *D, int64_t *p) { uint64_t u; int error; error = pb_decode_varint_u(D, &u); if (error) return error; *p = (int64_t)(((u & 1) << 63) | (u >> 1)); return 0; } static int pb_decode_sfixed32(struct decode *D, int32_t *p) { uint32_t u; int error; error = pb_decode_fixed32(D, &u); if (error) return error; *p = (int32_t)u; return 0; } static int pb_decode_sfixed64(struct decode *D, int64_t *p) { uint64_t u; int error; error = pb_decode_fixed64(D, &u); if (error) return error; *p = (int64_t)u; return 0; } static int pb_decode_ieee32(struct decode *D, float *p) { union { float f; uint32_t i; } u; int error; error = pb_decode_fixed32(D, &u.i); if (error) return error; *p = u.f; return 0; } static int pb_decode_ieee64(struct decode *D, double *p) { union { double f; uint64_t i; } u; int error; error = pb_decode_fixed64(D, &u.i); if (error) return error; *p = u.f; return 0; } /* * Length-delimited fields */ static int pb_decode_length(struct decode *D, size_t *p) { uint64_t u; int error; error = pb_decode_varint(D, &u); if (error) return error; if (SIZE_MAX < u) return ERANGE; /* XXX What error code? */ *p = (size_t)u; return 0; } /* * Submessages */ struct submsg { size_t sm_size; struct decode *sm_D; }; static pb_decoder_callback_t decode_submsg_callback; static int pb_decode_submsg(struct decode *D, const struct pb_msgdesc *msgdesc, struct pb_msg_hdr *msg_hdr, size_t size) { struct submsg submsg = { .sm_size = size, .sm_D = D, }; struct decode Dsub = { .callback = &decode_submsg_callback, .arg = &submsg, }; if (msg_hdr->pbmh_msgdesc != msgdesc) return EINVAL; return pb_decode_by_hdr(&Dsub, msg_hdr); } #define MIN(A,B) ((A) < (B)? (A) : (B)) static int decode_submsg_callback(void *arg, void *buf, size_t *n) { struct submsg *const submsg = arg; size_t nreq, n0; int error; nreq = *n; n0 = MIN(nreq, submsg->sm_size); error = pb_decode_buf_partial(submsg->sm_D, buf, &n0); if (error) return error; assert(n0 <= MIN(nreq, submsg->sm_size)); *n = n0; submsg->sm_size -= n0; return 0; } /* * Trivial heap sort */ static size_t parent(size_t i) { return ((i - 1)/2); } static size_t left(size_t i) { return ((2*i) + 1); } static size_t right(size_t i) { return ((2*i) + 2); } static void swap32(uint32_t *a, uint32_t *b) { uint32_t t; t = *a; *a = *b; *b = t; } static void heapify32(uint32_t *a, size_t node, size_t end) { /* * XXX Arithmetic overflow is not an issue here because the * array size has a small bound, but it would be an issue if * you made copypasta of this code elsewhere. */ while (left(node) <= end) { size_t largest = node; if ((left(node) <= end) && (a[largest] < a[left(node)])) largest = left(node); if ((right(node) <= end) && (a[largest] < a[right(node)])) largest = right(node); if (largest == node) break; swap32(&a[node], &a[largest]); node = largest; } } static void sort32(uint32_t *a, size_t n) { size_t start, end; if (n < 2) return; end = (n - 1); start = parent(end); do heapify32(a, start, end); while (0 < start--); while (0 < end) { swap32(&a[0], &a[end--]); heapify32(a, 0, end); } }