From b11c059c0fecf2d030ffdc9a7a2584f28f497760 Mon Sep 17 00:00:00 2001 From: Lephenixnoir Date: Sun, 14 Jan 2024 19:28:36 +0100 Subject: [PATCH] stdio: start simplifying scanf limit tracking logic Basically removing it from the __scanf_input structure and specializing it at format sites. The reason is that pretending it's the end of the stream after the limit is reached does not work because we have to return EOF at end of stream but not when the limit is hit. So we have to handle it explicitly, and since we do, no need to have it in the structure too. --- src/stdio/scanf/scan.c | 38 ++++++++++++++-------------------- src/stdio/stdio_p.h | 12 +++++++---- src/stdlib/stdlib_p.h | 10 +++++++-- src/stdlib/strto_fp.c | 47 ++++++++++++++++++++++++++++-------------- src/stdlib/strto_int.c | 21 +++++++++---------- src/stdlib/strtod.c | 4 ++-- src/stdlib/strtof.c | 4 ++-- src/stdlib/strtol.c | 10 ++------- src/stdlib/strtold.c | 4 ++-- src/stdlib/strtoll.c | 4 ++-- src/stdlib/strtoul.c | 4 ++-- src/stdlib/strtoull.c | 4 ++-- 12 files changed, 88 insertions(+), 74 deletions(-) diff --git a/src/stdio/scanf/scan.c b/src/stdio/scanf/scan.c index 9bef20b..9643852 100644 --- a/src/stdio/scanf/scan.c +++ b/src/stdio/scanf/scan.c @@ -269,12 +269,11 @@ int __scanf( // we will have to manage a given format else if( format[pos] == '%' ) { - in->readmaxlength = -1; + in->readmaxlength = INT_MAX; // main loop loopagain: pos++; - in->currentlength = 0; switch( format[pos] ) { // we need to decrypt the corresponding scanf set of character @@ -286,7 +285,7 @@ int __scanf( // we need to assign the read char to the corresponding pointer if (!skip) { char *c = (char *) va_arg( *args, char* ); - if (in->readmaxlength==(unsigned int)-1) { + if (in->readmaxlength==INT_MAX) { for(;;) { temp = __scanf_peek( in ); if (temp==EOF) return EOF; @@ -332,7 +331,7 @@ int __scanf( else { - if (in->readmaxlength==(unsigned int)-1) { + if (in->readmaxlength==INT_MAX) { for(;;) { temp = __scanf_peek( in ); if (temp==EOF) return EOF; @@ -433,16 +432,7 @@ int __scanf( break; } - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': { + case '0' ... '9': { user_length = user_length * 10 + (int) ( format[pos] - '0' ); in->readmaxlength = user_length; goto loopagain; @@ -462,7 +452,8 @@ int __scanf( bool use_unsigned = (f == 'o' || f == 'x' || f == 'X'); long long int temp; - err = __strto_int(in, base, NULL, &temp, use_unsigned); + err = __strto_int(in, base, NULL, &temp, use_unsigned, + in->readmaxlength); if (err == EOF && validrets == 0) return EOF; if (err != 0) return validrets; if (skip) __scanf_store_i( temp, MODSKIP, args ); @@ -482,7 +473,8 @@ int __scanf( // read a double from the current input stream // and store in the corresponding arg as a char by reference long double temp; - err = __strto_fp( in, NULL, NULL, &temp ); + err = __strto_fp( in, NULL, NULL, &temp, + in->readmaxlength); if (err == EOF && validrets == 0) return EOF; if (err != 0) return validrets; if (skip) __scanf_store_d( temp, MODSKIP, args ); @@ -495,9 +487,11 @@ int __scanf( long int temp; if (!skip) { void *p = (void *) va_arg( *args, void** ); // get the adress of the target pointer (void**) - err = __strto_int( in, 0, p, NULL, true ); + err = __strto_int( in, 0, p, NULL, true, + in->readmaxlength); } - else err = __strto_int( in, 0, &temp, NULL, true ); + else err = __strto_int( in, 0, &temp, NULL, true, + in->readmaxlength); if (err == 0) validrets++; else return validrets; skip = false; @@ -508,7 +502,7 @@ int __scanf( int temp; if (!skip) { char *c = (char *) va_arg( *args, char* ); - if (in->readmaxlength==(unsigned int)-1) { + if (in->readmaxlength==INT_MAX) { temp = __scanf_peek( in ); if (temp==EOF) return EOF; else *c = __scanf_in( in ); @@ -522,7 +516,7 @@ int __scanf( } } else { - if (in->readmaxlength==(unsigned int)-1) { + if (in->readmaxlength==INT_MAX) { temp = __scanf_peek( in ); if (temp==EOF) return EOF; else { @@ -552,7 +546,7 @@ int __scanf( __purge_space( in ); if (!skip) { char *c = (char *) va_arg( *args, char* ); - if (in->readmaxlength==(unsigned int)-1) { + if (in->readmaxlength==INT_MAX) { loopstring: temp = __scanf_peek( in ); if (temp==EOF && curstrlength==0) return validrets; @@ -582,7 +576,7 @@ int __scanf( } } else { - if (in->readmaxlength==(unsigned int)-1) { + if (in->readmaxlength==INT_MAX) { loopstringskip: temp = __scanf_peek( in ); if (temp==EOF && curstrlength==0) return validrets; diff --git a/src/stdio/stdio_p.h b/src/stdio/stdio_p.h index 792ea7b..8dae131 100644 --- a/src/stdio/stdio_p.h +++ b/src/stdio/stdio_p.h @@ -20,7 +20,6 @@ struct __scanf_input { // max char to read from the input stream as per user length modifier unsigned int readmaxlength; - int currentlength; // total number of char read so far in the current call of a XYscanf() function (to return a %n when required) int readsofar; @@ -48,15 +47,20 @@ static inline int __scanf_in(struct __scanf_input *__in) int c = __in->buffer; __in->buffer = __scanf_fetch(__in); __in->readsofar++; - __in->currentlength++; return c; } +/* Read the next byte and also decrease a total count of available reads. */ +static inline int __scanf_in_limit(struct __scanf_input *__in, int *__N) +{ + (*__N)--; + return __scanf_in(__in); +} + /* Peek the next byte without advancing. */ static inline int __scanf_peek(struct __scanf_input *__in) { - return ((unsigned)__in->currentlength < __in->readmaxlength) - ? __in->buffer : EOF; + return __in->buffer; } /* Close the input by unsending the buffer once finished. */ diff --git a/src/stdlib/stdlib_p.h b/src/stdlib/stdlib_p.h index c368086..e2c0c2c 100644 --- a/src/stdlib/stdlib_p.h +++ b/src/stdlib/stdlib_p.h @@ -3,6 +3,7 @@ #include #include +#include #include "../stdio/stdio_p.h" /* @@ -22,13 +23,17 @@ ** On platforms where long is 32-bit, 64-bit operations are performed only if ** outll is non-NULL. This is because multiplications with overflow can be ** expensive. +** +** N is the bound on the number of characters to read. To disable the bound, +** specify INT_MAX. */ int __strto_int( struct __scanf_input *__input, int __base, long *__outl, long long *__outll, - bool __use_unsigned); + bool __use_unsigned, + int __N); /* ** Parse a floating-point value from a string. This is the base function for @@ -42,6 +47,7 @@ int __strto_fp( struct __scanf_input *__input, double *__out, float *__outf, - long double *__outl); + long double *__outl, + int __N); #endif /*__STDLIB_P_H__*/ diff --git a/src/stdlib/strto_fp.c b/src/stdlib/strto_fp.c index 72a4fd1..8f6e2cd 100644 --- a/src/stdlib/strto_fp.c +++ b/src/stdlib/strto_fp.c @@ -38,8 +38,8 @@ ** -> In hexadecimal notation, we read as many bits as the mantissa of a long ** double, then later multiply by a power of 2. There are no approximations. */ -static bool parse_digits(struct __scanf_input *input, - SIGNIFICAND_TYPE *digits, long *exponent, bool hexadecimal) +static int parse_digits(struct __scanf_input *input, + SIGNIFICAND_TYPE *digits, long *exponent, bool hexadecimal, int *N) { bool dot_found = false; int digits_found=0, c=0; @@ -53,12 +53,14 @@ static bool parse_digits(struct __scanf_input *input, int dot_character = '.'; int exp_character = (hexadecimal ? 'p' : 'e'); - for(int i = 0; true; i++) { + for(int i = 0; *N >= 0; i++) { c = __scanf_peek(input); + if(i == 0 && c == EOF) + return EOF; if(!(isdigit(c) || (hexadecimal && isxdigit(c)) || (c == dot_character && !dot_found))) break; - __scanf_in(input); + __scanf_in_limit(input, N); if(c == dot_character) { dot_found = true; @@ -102,9 +104,10 @@ static bool parse_digits(struct __scanf_input *input, set correctly */ struct __scanf_input backup = *input; - __scanf_in(input); + __scanf_in_limit(input, N); long e = 0; - if(__strto_int(input, 10, &e, NULL, false) == 0) + // TODO: strto_fp: Pass limit to __strto_int + if(__strto_int(input, 10, &e, NULL, false, *N) == 0) *exponent += e; else *input = backup; @@ -124,18 +127,21 @@ static bool expect(struct __scanf_input *input, char const *sequence) } int __strto_fp(struct __scanf_input *input, double *out, float *outf, - long double *outl) + long double *outl, int N) { - input->currentlength = 0; - /* Skip initial whitespace */ while(isspace(__scanf_peek(input))) __scanf_in(input); + // TODO: strto_fp() doesn't support size limits well, affecting %5f etc. + + if(N <= 0) + return EOF; + /* Read optional sign */ bool negative = false; int sign = __scanf_peek(input); if(sign == '-') negative = true; - if(sign == '-' || sign == '+') __scanf_in(input); + if(sign == '-' || sign == '+') __scanf_in_limit(input, &N); int errno_value = 0; bool valid = false; @@ -156,8 +162,10 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf, if(__scanf_peek(input) == '(') { while(i < 31) { - int c = __scanf_in(input); + int c = __scanf_in_limit(input, &N); if(c == ')') break; + if(c == EOF || N <= 0) + return EOF; arg[i++] = c; } arg[i] = 0; @@ -179,6 +187,9 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf, if(outl) *outl = __builtin_infl(); valid = true; } + else if(__scanf_peek(input) == EOF) { + return EOF; + } else { SIGNIFICAND_TYPE digits = 0; long e = 0; @@ -187,9 +198,9 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf, not 0x isn't a problem. */ bool hexa = false; if(__scanf_peek(input) == '0') { - __scanf_in(input); + __scanf_in_limit(input, &N); if(tolower(__scanf_peek(input)) == 'x') { - __scanf_in(input); + __scanf_in_limit(input, &N); hexa = true; } /* Count the 0 as a digit */ @@ -197,13 +208,19 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf, } if(hexa) { - valid |= parse_digits(input, &digits, &e, true); + int rc = parse_digits(input, &digits, &e, true, &N); + if(!valid && rc == EOF) + return EOF; + valid |= rc; if(out) *out = (double)digits * exp2(e); if(outf) *outf = (float)digits * exp2f(e); if(outl) *outl = (long double)digits * exp2l(e); } else { - valid |= parse_digits(input, &digits, &e, false); + int rc = parse_digits(input, &digits, &e, false, &N); + if(!valid && rc == EOF) + return EOF; + valid |= rc; if(out) *out = (double)digits * pow(10, e); if(outf) *outf = (float)digits * powf(10, e); if(outl) *outl = (long double)digits * powl(10, e); diff --git a/src/stdlib/strto_int.c b/src/stdlib/strto_int.c index b75e43a..2da6dc7 100644 --- a/src/stdlib/strto_int.c +++ b/src/stdlib/strto_int.c @@ -5,20 +5,19 @@ #include int __strto_int(struct __scanf_input *input, int base, long *outl, - long long *outll, bool use_unsigned) + long long *outll, bool use_unsigned, int N) { - input->currentlength = 0; - - /* Skip initial whitespace */ while(isspace(__scanf_peek(input))) __scanf_in(input); + if(N <= 0) + return EOF; + /* Accept a sign character */ bool negative = false; int sign = __scanf_peek(input); - if(sign == EOF) return EOF; if(sign == '-') negative = true; - if(sign == '-' || sign == '+') __scanf_in(input); + if(sign == '-' || sign == '+') __scanf_in_limit(input, &N); /* Use unsigned variables as only these have defined overflow */ unsigned long xl = 0; @@ -29,10 +28,10 @@ int __strto_int(struct __scanf_input *input, int base, long *outl, /* Read prefixes and determine base */ if(__scanf_peek(input) == '0') { - __scanf_in(input); + __scanf_in_limit(input, &N); if((base == 0 || base == 16) && tolower(__scanf_peek(input)) == 'x') { - __scanf_in(input); + __scanf_in_limit(input, &N); base = 16; } /* If we don't consume the x then count the 0 as a digit */ @@ -40,13 +39,13 @@ int __strto_int(struct __scanf_input *input, int base, long *outl, if(base == 0) base = 8; } - else if(__scanf_peek(input) == EOF) + if(!valid && (N <= 0 || __scanf_peek(input) == EOF)) return EOF; if(base == 0) base = 10; /* Read digits */ - while(1) { + while(N > 0) { int v = -1; int c = __scanf_peek(input); if(isdigit(c)) v = c - '0'; @@ -71,7 +70,7 @@ int __strto_int(struct __scanf_input *input, int base, long *outl, errno_value = ERANGE; } - __scanf_in(input); + __scanf_in_limit(input, &N); } /* Handle sign and range */ diff --git a/src/stdlib/strtod.c b/src/stdlib/strtod.c index c7c36ca..f58d08b 100644 --- a/src/stdlib/strtod.c +++ b/src/stdlib/strtod.c @@ -7,9 +7,9 @@ double strtod(char const * restrict ptr, char ** restrict endptr) if(endptr) *endptr = (char *)ptr; - struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; + struct __scanf_input in = { .str = ptr, .fp = NULL }; __scanf_start(&in); - int err = __strto_fp(&in, &d, NULL, NULL); + int err = __strto_fp(&in, &d, NULL, NULL, INT_MAX); __scanf_end(&in); if(err != 0) diff --git a/src/stdlib/strtof.c b/src/stdlib/strtof.c index 4459dbe..a1cf870 100644 --- a/src/stdlib/strtof.c +++ b/src/stdlib/strtof.c @@ -7,9 +7,9 @@ float strtof(char const * restrict ptr, char ** restrict endptr) if(endptr) *endptr = (char *)ptr; - struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; + struct __scanf_input in = { .str = ptr, .fp = NULL }; __scanf_start(&in); - int err = __strto_fp(&in, NULL, &f, NULL); + int err = __strto_fp(&in, NULL, &f, NULL, INT_MAX); __scanf_end(&in); if(err != 0) diff --git a/src/stdlib/strtol.c b/src/stdlib/strtol.c index e584d35..1169910 100644 --- a/src/stdlib/strtol.c +++ b/src/stdlib/strtol.c @@ -7,15 +7,9 @@ long int strtol(char const * restrict ptr, char ** restrict endptr, int base) if(endptr) *endptr = (char *)ptr; - struct __scanf_input in = { - .str = ptr, - .fp = NULL, - .readmaxlength = -1, - .currentlength = 0, - .readsofar = 0, - }; + struct __scanf_input in = { .str = ptr, .fp = NULL }; __scanf_start(&in); - int err = __strto_int(&in, base, &n, NULL, false); + int err = __strto_int(&in, base, &n, NULL, false, INT_MAX); __scanf_end(&in); if(err != 0) diff --git a/src/stdlib/strtold.c b/src/stdlib/strtold.c index e4d8f6d..2fb748f 100644 --- a/src/stdlib/strtold.c +++ b/src/stdlib/strtold.c @@ -7,9 +7,9 @@ long double strtold(char const * restrict ptr, char ** restrict endptr) if(endptr) *endptr = (char *)ptr; - struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; + struct __scanf_input in = { .str = ptr, .fp = NULL }; __scanf_start(&in); - int err = __strto_fp(&in, NULL, NULL, &ld); + int err = __strto_fp(&in, NULL, NULL, &ld, INT_MAX); __scanf_end(&in); if(err != 0) diff --git a/src/stdlib/strtoll.c b/src/stdlib/strtoll.c index 6fd19b4..4c616df 100644 --- a/src/stdlib/strtoll.c +++ b/src/stdlib/strtoll.c @@ -8,9 +8,9 @@ long long int strtoll(char const * restrict ptr, char ** restrict endptr, if(endptr) *endptr = (char *)ptr; - struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; + struct __scanf_input in = { .str = ptr, .fp = NULL }; __scanf_start(&in); - int err = __strto_int(&in, base, NULL, &n, false); + int err = __strto_int(&in, base, NULL, &n, false, INT_MAX); __scanf_end(&in); if(err != 0) diff --git a/src/stdlib/strtoul.c b/src/stdlib/strtoul.c index b0c7148..b911386 100644 --- a/src/stdlib/strtoul.c +++ b/src/stdlib/strtoul.c @@ -8,9 +8,9 @@ unsigned long int strtoul(char const * restrict ptr, char ** restrict endptr, if(endptr) *endptr = (char *)ptr; - struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; + struct __scanf_input in = { .str = ptr, .fp = NULL }; __scanf_start(&in); - int err = __strto_int(&in, base, (long *)&n, NULL, true); + int err = __strto_int(&in, base, (long *)&n, NULL, true, INT_MAX); __scanf_end(&in); if(err != 0) diff --git a/src/stdlib/strtoull.c b/src/stdlib/strtoull.c index 4061286..3ab0d8e 100644 --- a/src/stdlib/strtoull.c +++ b/src/stdlib/strtoull.c @@ -8,9 +8,9 @@ unsigned long long int strtoull(char const * restrict ptr, if(endptr) *endptr = (char *)ptr; - struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; + struct __scanf_input in = { .str = ptr, .fp = NULL }; __scanf_start(&in); - int err = __strto_int(&in, base, NULL, (long long *)&n, true); + int err = __strto_int(&in, base, NULL, (long long *)&n, true, INT_MAX); __scanf_end(&in); if(err != 0)