diff options
Diffstat (limited to 'lib/mpi/mpicoder.c')
-rw-r--r-- | lib/mpi/mpicoder.c | 249 |
1 files changed, 81 insertions, 168 deletions
diff --git a/lib/mpi/mpicoder.c b/lib/mpi/mpicoder.c index 747606f..c6272ae 100644 --- a/lib/mpi/mpicoder.c +++ b/lib/mpi/mpicoder.c @@ -21,6 +21,7 @@ #include <linux/bitops.h> #include <linux/count_zeros.h> #include <linux/byteorder/generic.h> +#include <linux/scatterlist.h> #include <linux/string.h> #include "mpi-internal.h" @@ -50,9 +51,7 @@ MPI mpi_read_raw_data(const void *xbuffer, size_t nbytes) return NULL; } if (nbytes > 0) - nbits -= count_leading_zeros(buffer[0]); - else - nbits = 0; + nbits -= count_leading_zeros(buffer[0]) - (BITS_PER_LONG - 8); nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB); val = mpi_alloc(nlimbs); @@ -82,50 +81,30 @@ EXPORT_SYMBOL_GPL(mpi_read_raw_data); MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread) { const uint8_t *buffer = xbuffer; - int i, j; - unsigned nbits, nbytes, nlimbs, nread = 0; - mpi_limb_t a; - MPI val = NULL; + unsigned int nbits, nbytes; + MPI val; if (*ret_nread < 2) - goto leave; + return ERR_PTR(-EINVAL); nbits = buffer[0] << 8 | buffer[1]; if (nbits > MAX_EXTERN_MPI_BITS) { pr_info("MPI: mpi too large (%u bits)\n", nbits); - goto leave; + return ERR_PTR(-EINVAL); } - buffer += 2; - nread = 2; nbytes = DIV_ROUND_UP(nbits, 8); - nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB); - val = mpi_alloc(nlimbs); - if (!val) - return NULL; - i = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB; - i %= BYTES_PER_MPI_LIMB; - val->nbits = nbits; - j = val->nlimbs = nlimbs; - val->sign = 0; - for (; j > 0; j--) { - a = 0; - for (; i < BYTES_PER_MPI_LIMB; i++) { - if (++nread > *ret_nread) { - printk - ("MPI: mpi larger than buffer nread=%d ret_nread=%d\n", - nread, *ret_nread); - goto leave; - } - a <<= 8; - a |= *buffer++; - } - i = 0; - val->d[j - 1] = a; + if (nbytes + 2 > *ret_nread) { + pr_info("MPI: mpi larger than buffer nbytes=%u ret_nread=%u\n", + nbytes, *ret_nread); + return ERR_PTR(-EINVAL); } -leave: - *ret_nread = nread; + val = mpi_read_raw_data(buffer + 2, nbytes); + if (!val) + return ERR_PTR(-ENOMEM); + + *ret_nread = nbytes + 2; return val; } EXPORT_SYMBOL_GPL(mpi_read_from_buffer); @@ -250,82 +229,6 @@ void *mpi_get_buffer(MPI a, unsigned *nbytes, int *sign) } EXPORT_SYMBOL_GPL(mpi_get_buffer); -/**************** - * Use BUFFER to update MPI. - */ -int mpi_set_buffer(MPI a, const void *xbuffer, unsigned nbytes, int sign) -{ - const uint8_t *buffer = xbuffer, *p; - mpi_limb_t alimb; - int nlimbs; - int i; - - nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB); - if (RESIZE_IF_NEEDED(a, nlimbs) < 0) - return -ENOMEM; - a->sign = sign; - - for (i = 0, p = buffer + nbytes - 1; p >= buffer + BYTES_PER_MPI_LIMB;) { -#if BYTES_PER_MPI_LIMB == 4 - alimb = (mpi_limb_t) *p--; - alimb |= (mpi_limb_t) *p-- << 8; - alimb |= (mpi_limb_t) *p-- << 16; - alimb |= (mpi_limb_t) *p-- << 24; -#elif BYTES_PER_MPI_LIMB == 8 - alimb = (mpi_limb_t) *p--; - alimb |= (mpi_limb_t) *p-- << 8; - alimb |= (mpi_limb_t) *p-- << 16; - alimb |= (mpi_limb_t) *p-- << 24; - alimb |= (mpi_limb_t) *p-- << 32; - alimb |= (mpi_limb_t) *p-- << 40; - alimb |= (mpi_limb_t) *p-- << 48; - alimb |= (mpi_limb_t) *p-- << 56; -#else -#error please implement for this limb size. -#endif - a->d[i++] = alimb; - } - if (p >= buffer) { -#if BYTES_PER_MPI_LIMB == 4 - alimb = *p--; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 8; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 16; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 24; -#elif BYTES_PER_MPI_LIMB == 8 - alimb = (mpi_limb_t) *p--; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 8; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 16; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 24; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 32; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 40; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 48; - if (p >= buffer) - alimb |= (mpi_limb_t) *p-- << 56; -#else -#error please implement for this limb size. -#endif - a->d[i++] = alimb; - } - a->nlimbs = i; - - if (i != nlimbs) { - pr_emerg("MPI: mpi_set_buffer: Assertion failed (%d != %d)", i, - nlimbs); - BUG(); - } - return 0; -} -EXPORT_SYMBOL_GPL(mpi_set_buffer); - /** * mpi_write_to_sgl() - Funnction exports MPI to an sgl (msb first) * @@ -335,16 +238,13 @@ EXPORT_SYMBOL_GPL(mpi_set_buffer); * @a: a multi precision integer * @sgl: scatterlist to write to. Needs to be at least * mpi_get_size(a) long. - * @nbytes: in/out param - it has the be set to the maximum number of - * bytes that can be written to sgl. This has to be at least - * the size of the integer a. On return it receives the actual - * length of the data written on success or the data that would - * be written if buffer was too small. + * @nbytes: the number of bytes to write. Leading bytes will be + * filled with zero. * @sign: if not NULL, it will be set to the sign of a. * * Return: 0 on success or error code in case of error */ -int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes, +int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes, int *sign) { u8 *p, *p2; @@ -356,55 +256,60 @@ int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes, #error please implement for this limb size. #endif unsigned int n = mpi_get_size(a); - int i, x, y = 0, lzeros, buf_len; - - if (!nbytes) - return -EINVAL; + struct sg_mapping_iter miter; + int i, x, buf_len; + int nents; if (sign) *sign = a->sign; - lzeros = count_lzeros(a); - - if (*nbytes < n - lzeros) { - *nbytes = n - lzeros; + if (nbytes < n) return -EOVERFLOW; - } - *nbytes = n - lzeros; - buf_len = sgl->length; - p2 = sg_virt(sgl); + nents = sg_nents_for_len(sgl, nbytes); + if (nents < 0) + return -EINVAL; - for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB, - lzeros %= BYTES_PER_MPI_LIMB; - i >= 0; i--) { + sg_miter_start(&miter, sgl, nents, SG_MITER_ATOMIC | SG_MITER_TO_SG); + sg_miter_next(&miter); + buf_len = miter.length; + p2 = miter.addr; + + while (nbytes > n) { + i = min_t(unsigned, nbytes - n, buf_len); + memset(p2, 0, i); + p2 += i; + nbytes -= i; + + buf_len -= i; + if (!buf_len) { + sg_miter_next(&miter); + buf_len = miter.length; + p2 = miter.addr; + } + } + + for (i = a->nlimbs - 1; i >= 0; i--) { #if BYTES_PER_MPI_LIMB == 4 - alimb = cpu_to_be32(a->d[i]); + alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0; #elif BYTES_PER_MPI_LIMB == 8 - alimb = cpu_to_be64(a->d[i]); + alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0; #else #error please implement for this limb size. #endif - if (lzeros) { - y = lzeros; - lzeros = 0; - } + p = (u8 *)&alimb; - p = (u8 *)&alimb + y; - - for (x = 0; x < sizeof(alimb) - y; x++) { - if (!buf_len) { - sgl = sg_next(sgl); - if (!sgl) - return -EINVAL; - buf_len = sgl->length; - p2 = sg_virt(sgl); - } + for (x = 0; x < sizeof(alimb); x++) { *p2++ = *p++; - buf_len--; + if (!--buf_len) { + sg_miter_next(&miter); + buf_len = miter.length; + p2 = miter.addr; + } } - y = 0; } + + sg_miter_stop(&miter); return 0; } EXPORT_SYMBOL_GPL(mpi_write_to_sgl); @@ -424,19 +329,23 @@ EXPORT_SYMBOL_GPL(mpi_write_to_sgl); */ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes) { - struct scatterlist *sg; - int x, i, j, z, lzeros, ents; + struct sg_mapping_iter miter; unsigned int nbits, nlimbs; + int x, j, z, lzeros, ents; + unsigned int len; + const u8 *buff; mpi_limb_t a; MPI val = NULL; - lzeros = 0; - ents = sg_nents(sgl); + ents = sg_nents_for_len(sgl, nbytes); + if (ents < 0) + return NULL; - for_each_sg(sgl, sg, ents, i) { - const u8 *buff = sg_virt(sg); - int len = sg->length; + sg_miter_start(&miter, sgl, ents, SG_MITER_ATOMIC | SG_MITER_FROM_SG); + lzeros = 0; + len = 0; + while (nbytes > 0) { while (len && !*buff) { lzeros++; len--; @@ -446,12 +355,14 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes) if (len && *buff) break; - ents--; + sg_miter_next(&miter); + buff = miter.addr; + len = miter.length; + nbytes -= lzeros; lzeros = 0; } - sgl = sg; nbytes -= lzeros; nbits = nbytes * 8; if (nbits > MAX_EXTERN_MPI_BITS) { @@ -460,8 +371,7 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes) } if (nbytes > 0) - nbits -= count_leading_zeros(*(u8 *)(sg_virt(sgl) + lzeros)) - - (BITS_PER_LONG - 8); + nbits -= count_leading_zeros(*buff) - (BITS_PER_LONG - 8); nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB); val = mpi_alloc(nlimbs); @@ -480,21 +390,24 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes) z = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB; z %= BYTES_PER_MPI_LIMB; - for_each_sg(sgl, sg, ents, i) { - const u8 *buffer = sg_virt(sg) + lzeros; - int len = sg->length - lzeros; - + for (;;) { for (x = 0; x < len; x++) { a <<= 8; - a |= *buffer++; + a |= *buff++; if (((z + x + 1) % BYTES_PER_MPI_LIMB) == 0) { val->d[j--] = a; a = 0; } } z += x; - lzeros = 0; + + if (!sg_miter_next(&miter)) + break; + + buff = miter.addr; + len = miter.length; } + return val; } EXPORT_SYMBOL_GPL(mpi_read_raw_from_sgl); |