stdio: start simplifying scanf limit tracking logic

Basically removing it from the __scanf_input structure and specializing
it at format sites. The reason is that pretending it's the end of the
stream after the limit is reached does not work because we have to
return EOF at end of stream but not when the limit is hit. So we have to
handle it explicitly, and since we do, no need to have it in the
structure too.
This commit is contained in:
Lephenixnoir 2024-01-14 19:28:36 +01:00
parent 2215b3c267
commit b11c059c0f
Signed by: Lephenixnoir
GPG Key ID: 1BBA026E13FC0495
12 changed files with 88 additions and 74 deletions

View File

@ -269,12 +269,11 @@ int __scanf(
// we will have to manage a given format
else if( format[pos] == '%' ) {
in->readmaxlength = -1;
in->readmaxlength = INT_MAX;
// main loop
loopagain:
pos++;
in->currentlength = 0;
switch( format[pos] ) {
// we need to decrypt the corresponding scanf set of character
@ -286,7 +285,7 @@ int __scanf(
// we need to assign the read char to the corresponding pointer
if (!skip) {
char *c = (char *) va_arg( *args, char* );
if (in->readmaxlength==(unsigned int)-1) {
if (in->readmaxlength==INT_MAX) {
for(;;) {
temp = __scanf_peek( in );
if (temp==EOF) return EOF;
@ -332,7 +331,7 @@ int __scanf(
else
{
if (in->readmaxlength==(unsigned int)-1) {
if (in->readmaxlength==INT_MAX) {
for(;;) {
temp = __scanf_peek( in );
if (temp==EOF) return EOF;
@ -433,16 +432,7 @@ int __scanf(
break;
}
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9': {
case '0' ... '9': {
user_length = user_length * 10 + (int) ( format[pos] - '0' );
in->readmaxlength = user_length;
goto loopagain;
@ -462,7 +452,8 @@ int __scanf(
bool use_unsigned = (f == 'o' || f == 'x' || f == 'X');
long long int temp;
err = __strto_int(in, base, NULL, &temp, use_unsigned);
err = __strto_int(in, base, NULL, &temp, use_unsigned,
in->readmaxlength);
if (err == EOF && validrets == 0) return EOF;
if (err != 0) return validrets;
if (skip) __scanf_store_i( temp, MODSKIP, args );
@ -482,7 +473,8 @@ int __scanf(
// read a double from the current input stream
// and store in the corresponding arg as a char by reference
long double temp;
err = __strto_fp( in, NULL, NULL, &temp );
err = __strto_fp( in, NULL, NULL, &temp,
in->readmaxlength);
if (err == EOF && validrets == 0) return EOF;
if (err != 0) return validrets;
if (skip) __scanf_store_d( temp, MODSKIP, args );
@ -495,9 +487,11 @@ int __scanf(
long int temp;
if (!skip) {
void *p = (void *) va_arg( *args, void** ); // get the adress of the target pointer (void**)
err = __strto_int( in, 0, p, NULL, true );
err = __strto_int( in, 0, p, NULL, true,
in->readmaxlength);
}
else err = __strto_int( in, 0, &temp, NULL, true );
else err = __strto_int( in, 0, &temp, NULL, true,
in->readmaxlength);
if (err == 0) validrets++;
else return validrets;
skip = false;
@ -508,7 +502,7 @@ int __scanf(
int temp;
if (!skip) {
char *c = (char *) va_arg( *args, char* );
if (in->readmaxlength==(unsigned int)-1) {
if (in->readmaxlength==INT_MAX) {
temp = __scanf_peek( in );
if (temp==EOF) return EOF;
else *c = __scanf_in( in );
@ -522,7 +516,7 @@ int __scanf(
}
}
else {
if (in->readmaxlength==(unsigned int)-1) {
if (in->readmaxlength==INT_MAX) {
temp = __scanf_peek( in );
if (temp==EOF) return EOF;
else {
@ -552,7 +546,7 @@ int __scanf(
__purge_space( in );
if (!skip) {
char *c = (char *) va_arg( *args, char* );
if (in->readmaxlength==(unsigned int)-1) {
if (in->readmaxlength==INT_MAX) {
loopstring:
temp = __scanf_peek( in );
if (temp==EOF && curstrlength==0) return validrets;
@ -582,7 +576,7 @@ int __scanf(
}
}
else {
if (in->readmaxlength==(unsigned int)-1) {
if (in->readmaxlength==INT_MAX) {
loopstringskip:
temp = __scanf_peek( in );
if (temp==EOF && curstrlength==0) return validrets;

View File

@ -20,7 +20,6 @@ struct __scanf_input {
// max char to read from the input stream as per user length modifier
unsigned int readmaxlength;
int currentlength;
// total number of char read so far in the current call of a XYscanf() function (to return a %n when required)
int readsofar;
@ -48,15 +47,20 @@ static inline int __scanf_in(struct __scanf_input *__in)
int c = __in->buffer;
__in->buffer = __scanf_fetch(__in);
__in->readsofar++;
__in->currentlength++;
return c;
}
/* Read the next byte and also decrease a total count of available reads. */
static inline int __scanf_in_limit(struct __scanf_input *__in, int *__N)
{
(*__N)--;
return __scanf_in(__in);
}
/* Peek the next byte without advancing. */
static inline int __scanf_peek(struct __scanf_input *__in)
{
return ((unsigned)__in->currentlength < __in->readmaxlength)
? __in->buffer : EOF;
return __in->buffer;
}
/* Close the input by unsending the buffer once finished. */

View File

@ -3,6 +3,7 @@
#include <stdlib.h>
#include <stdbool.h>
#include <limits.h>
#include "../stdio/stdio_p.h"
/*
@ -22,13 +23,17 @@
** On platforms where long is 32-bit, 64-bit operations are performed only if
** outll is non-NULL. This is because multiplications with overflow can be
** expensive.
**
** N is the bound on the number of characters to read. To disable the bound,
** specify INT_MAX.
*/
int __strto_int(
struct __scanf_input *__input,
int __base,
long *__outl,
long long *__outll,
bool __use_unsigned);
bool __use_unsigned,
int __N);
/*
** Parse a floating-point value from a string. This is the base function for
@ -42,6 +47,7 @@ int __strto_fp(
struct __scanf_input *__input,
double *__out,
float *__outf,
long double *__outl);
long double *__outl,
int __N);
#endif /*__STDLIB_P_H__*/

View File

@ -38,8 +38,8 @@
** -> In hexadecimal notation, we read as many bits as the mantissa of a long
** double, then later multiply by a power of 2. There are no approximations.
*/
static bool parse_digits(struct __scanf_input *input,
SIGNIFICAND_TYPE *digits, long *exponent, bool hexadecimal)
static int parse_digits(struct __scanf_input *input,
SIGNIFICAND_TYPE *digits, long *exponent, bool hexadecimal, int *N)
{
bool dot_found = false;
int digits_found=0, c=0;
@ -53,12 +53,14 @@ static bool parse_digits(struct __scanf_input *input,
int dot_character = '.';
int exp_character = (hexadecimal ? 'p' : 'e');
for(int i = 0; true; i++) {
for(int i = 0; *N >= 0; i++) {
c = __scanf_peek(input);
if(i == 0 && c == EOF)
return EOF;
if(!(isdigit(c) ||
(hexadecimal && isxdigit(c)) ||
(c == dot_character && !dot_found))) break;
__scanf_in(input);
__scanf_in_limit(input, N);
if(c == dot_character) {
dot_found = true;
@ -102,9 +104,10 @@ static bool parse_digits(struct __scanf_input *input,
set correctly */
struct __scanf_input backup = *input;
__scanf_in(input);
__scanf_in_limit(input, N);
long e = 0;
if(__strto_int(input, 10, &e, NULL, false) == 0)
// TODO: strto_fp: Pass limit to __strto_int
if(__strto_int(input, 10, &e, NULL, false, *N) == 0)
*exponent += e;
else
*input = backup;
@ -124,18 +127,21 @@ static bool expect(struct __scanf_input *input, char const *sequence)
}
int __strto_fp(struct __scanf_input *input, double *out, float *outf,
long double *outl)
long double *outl, int N)
{
input->currentlength = 0;
/* Skip initial whitespace */
while(isspace(__scanf_peek(input))) __scanf_in(input);
// TODO: strto_fp() doesn't support size limits well, affecting %5f etc.
if(N <= 0)
return EOF;
/* Read optional sign */
bool negative = false;
int sign = __scanf_peek(input);
if(sign == '-') negative = true;
if(sign == '-' || sign == '+') __scanf_in(input);
if(sign == '-' || sign == '+') __scanf_in_limit(input, &N);
int errno_value = 0;
bool valid = false;
@ -156,8 +162,10 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf,
if(__scanf_peek(input) == '(') {
while(i < 31) {
int c = __scanf_in(input);
int c = __scanf_in_limit(input, &N);
if(c == ')') break;
if(c == EOF || N <= 0)
return EOF;
arg[i++] = c;
}
arg[i] = 0;
@ -179,6 +187,9 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf,
if(outl) *outl = __builtin_infl();
valid = true;
}
else if(__scanf_peek(input) == EOF) {
return EOF;
}
else {
SIGNIFICAND_TYPE digits = 0;
long e = 0;
@ -187,9 +198,9 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf,
not 0x isn't a problem. */
bool hexa = false;
if(__scanf_peek(input) == '0') {
__scanf_in(input);
__scanf_in_limit(input, &N);
if(tolower(__scanf_peek(input)) == 'x') {
__scanf_in(input);
__scanf_in_limit(input, &N);
hexa = true;
}
/* Count the 0 as a digit */
@ -197,13 +208,19 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf,
}
if(hexa) {
valid |= parse_digits(input, &digits, &e, true);
int rc = parse_digits(input, &digits, &e, true, &N);
if(!valid && rc == EOF)
return EOF;
valid |= rc;
if(out) *out = (double)digits * exp2(e);
if(outf) *outf = (float)digits * exp2f(e);
if(outl) *outl = (long double)digits * exp2l(e);
}
else {
valid |= parse_digits(input, &digits, &e, false);
int rc = parse_digits(input, &digits, &e, false, &N);
if(!valid && rc == EOF)
return EOF;
valid |= rc;
if(out) *out = (double)digits * pow(10, e);
if(outf) *outf = (float)digits * powf(10, e);
if(outl) *outl = (long double)digits * powl(10, e);

View File

@ -5,20 +5,19 @@
#include <limits.h>
int __strto_int(struct __scanf_input *input, int base, long *outl,
long long *outll, bool use_unsigned)
long long *outll, bool use_unsigned, int N)
{
input->currentlength = 0;
/* Skip initial whitespace */
while(isspace(__scanf_peek(input))) __scanf_in(input);
if(N <= 0)
return EOF;
/* Accept a sign character */
bool negative = false;
int sign = __scanf_peek(input);
if(sign == EOF) return EOF;
if(sign == '-') negative = true;
if(sign == '-' || sign == '+') __scanf_in(input);
if(sign == '-' || sign == '+') __scanf_in_limit(input, &N);
/* Use unsigned variables as only these have defined overflow */
unsigned long xl = 0;
@ -29,10 +28,10 @@ int __strto_int(struct __scanf_input *input, int base, long *outl,
/* Read prefixes and determine base */
if(__scanf_peek(input) == '0') {
__scanf_in(input);
__scanf_in_limit(input, &N);
if((base == 0 || base == 16) &&
tolower(__scanf_peek(input)) == 'x') {
__scanf_in(input);
__scanf_in_limit(input, &N);
base = 16;
}
/* If we don't consume the x then count the 0 as a digit */
@ -40,13 +39,13 @@ int __strto_int(struct __scanf_input *input, int base, long *outl,
if(base == 0)
base = 8;
}
else if(__scanf_peek(input) == EOF)
if(!valid && (N <= 0 || __scanf_peek(input) == EOF))
return EOF;
if(base == 0)
base = 10;
/* Read digits */
while(1) {
while(N > 0) {
int v = -1;
int c = __scanf_peek(input);
if(isdigit(c)) v = c - '0';
@ -71,7 +70,7 @@ int __strto_int(struct __scanf_input *input, int base, long *outl,
errno_value = ERANGE;
}
__scanf_in(input);
__scanf_in_limit(input, &N);
}
/* Handle sign and range */

View File

@ -7,9 +7,9 @@ double strtod(char const * restrict ptr, char ** restrict endptr)
if(endptr)
*endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 };
struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in);
int err = __strto_fp(&in, &d, NULL, NULL);
int err = __strto_fp(&in, &d, NULL, NULL, INT_MAX);
__scanf_end(&in);
if(err != 0)

View File

@ -7,9 +7,9 @@ float strtof(char const * restrict ptr, char ** restrict endptr)
if(endptr)
*endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 };
struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in);
int err = __strto_fp(&in, NULL, &f, NULL);
int err = __strto_fp(&in, NULL, &f, NULL, INT_MAX);
__scanf_end(&in);
if(err != 0)

View File

@ -7,15 +7,9 @@ long int strtol(char const * restrict ptr, char ** restrict endptr, int base)
if(endptr)
*endptr = (char *)ptr;
struct __scanf_input in = {
.str = ptr,
.fp = NULL,
.readmaxlength = -1,
.currentlength = 0,
.readsofar = 0,
};
struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in);
int err = __strto_int(&in, base, &n, NULL, false);
int err = __strto_int(&in, base, &n, NULL, false, INT_MAX);
__scanf_end(&in);
if(err != 0)

View File

@ -7,9 +7,9 @@ long double strtold(char const * restrict ptr, char ** restrict endptr)
if(endptr)
*endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 };
struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in);
int err = __strto_fp(&in, NULL, NULL, &ld);
int err = __strto_fp(&in, NULL, NULL, &ld, INT_MAX);
__scanf_end(&in);
if(err != 0)

View File

@ -8,9 +8,9 @@ long long int strtoll(char const * restrict ptr, char ** restrict endptr,
if(endptr)
*endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 };
struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in);
int err = __strto_int(&in, base, NULL, &n, false);
int err = __strto_int(&in, base, NULL, &n, false, INT_MAX);
__scanf_end(&in);
if(err != 0)

View File

@ -8,9 +8,9 @@ unsigned long int strtoul(char const * restrict ptr, char ** restrict endptr,
if(endptr)
*endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 };
struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in);
int err = __strto_int(&in, base, (long *)&n, NULL, true);
int err = __strto_int(&in, base, (long *)&n, NULL, true, INT_MAX);
__scanf_end(&in);
if(err != 0)

View File

@ -8,9 +8,9 @@ unsigned long long int strtoull(char const * restrict ptr,
if(endptr)
*endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 };
struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in);
int err = __strto_int(&in, base, NULL, (long long *)&n, true);
int err = __strto_int(&in, base, NULL, (long long *)&n, true, INT_MAX);
__scanf_end(&in);
if(err != 0)