diff --git a/test/scan-test.cc b/test/scan-test.cc index 84603cf1..993a4b3f 100644 --- a/test/scan-test.cc +++ b/test/scan-test.cc @@ -28,6 +28,8 @@ TEST(scan_test, read_int) { EXPECT_EQ(n, 42); fmt::scan("-42", "{}", n); EXPECT_EQ(n, -42); + EXPECT_THROW_MSG(fmt::scan(std::to_string(INT_MAX + 1u), "{}", n), + fmt::format_error, "number is too big"); } TEST(scan_test, read_longlong) { diff --git a/test/scan.h b/test/scan.h index db949388..d5831528 100644 --- a/test/scan.h +++ b/test/scan.h @@ -136,51 +136,90 @@ class string_scan_buffer : public scan_buffer { : scan_buffer(s.begin(), s.end(), true) {} }; -class file_scan_buffer : public scan_buffer { - private: - FILE* file_; - char next_; - bool filled_ = false; +// A FILE wrapper. F is FILE defined as a template parameter to make +// system-specific API detection work. +template class file_base { + protected: + F* file_; + + public: + file_base(F* file) : file_(file) {} + operator F*() const { return file_; } + + // Reads a code unit from the stream. + auto get() -> int { + int result = getc(file_); + if (result == EOF && ferror(file_) != 0) + FMT_THROW(system_error(errno, FMT_STRING("getc failed"))); + return result; + } + + // Puts the code unit back into the stream buffer. + void unget(char c) { + if (ungetc(c, file_) == EOF) + FMT_THROW(system_error(errno, FMT_STRING("ungetc failed"))); + } +}; + +// A FILE wrapper for Apple's libc. +template class apple_file : public file_base { + public: + using file_base::file_base; // Returns the file's read buffer as a string_view. - template - auto get_buffer(F* file) -> string_view { // Apple libc - char* ptr = reinterpret_cast(file->_p); - return {ptr, to_unsigned(file->_r)}; + auto buffer() const -> string_view { + return {reinterpret_cast(this->file_->_p), + to_unsigned(this->file_->_r)}; } - auto get_buffer(...) -> string_view { - return {&next_, (filled_ ? 1u : 0u)}; +}; + +// A fallback FILE wrapper. +template class fallback_file : public file_base { + private: + char next_; // The next unconsumed character in the buffer. + bool has_next_ = false; + + public: + using file_base::file_base; + + auto buffer() const -> string_view { return {&next_, has_next_ ? 1u : 0u}; } + + auto get() -> int { + has_next_ = false; + return file_base::get(); } + void unget(char c) { + file_base::unget(c); + next_ = c; + has_next_ = true; + } +}; + +template +auto get_file(F* file, int) -> apple_file { + return file; +} +auto get_file(FILE* file, ...) -> fallback_file { return file; } + +class file_scan_buffer : public scan_buffer { + private: + decltype(get_file(static_cast(nullptr), 0)) file_; + void do_fill() { - string_view buf = get_buffer(file_); + string_view buf = file_.buffer(); if (buf.size() == 0) { - int result = getc(file_); - if (result != EOF) { - // Put the character back since we are only filling the buffer. - if (ungetc(result, file_) == EOF) - FMT_THROW(system_error(errno, FMT_STRING("I/O error"))); - next_ = static_cast(result); - filled_ = true; - } else { - if (ferror(file_) != 0) - FMT_THROW(system_error(errno, FMT_STRING("I/O error"))); - filled_ = false; - } - buf = get_buffer(file_); + int c = file_.get(); + // Put the character back since we are only filling the buffer. + if (c != EOF) file_.unget(static_cast(c)); + buf = file_.buffer(); } this->set(buf.begin(), buf.end()); } void consume() override { // Consume the current buffer content. - string_view buf = get_buffer(file_); - for (size_t i = 0, n = buf.size(); i != n; ++i) { - int result = getc(file_); - if (result == EOF && ferror(file_) != 0) - FMT_THROW(system_error(errno, FMT_STRING("I/O error"))); - } - filled_ = false; + for (size_t i = 0, n = file_.buffer().size(); i != n; ++i) file_.get(); do_fill(); } @@ -317,18 +356,30 @@ struct scan_handler : error_handler { scan_arg arg_; template auto read_uint() -> T { - T value = 0; auto it = scan_ctx_.begin(), end = scan_ctx_.end(); - char c = it != end ? *it : '\0'; + char c = it != end ? *it : '\0', prev_digit; if (c < '0' || c > '9') on_error("invalid input"); + + int num_digits = 0; + T value = 0, prev = 0; do { + prev = value; value = value * 10 + static_cast(c - '0'); + prev_digit = c; c = *++it; + ++num_digits; if (c < '0' || c > '9') break; - // TODO: check overflow } while (it != end); scan_ctx_.advance_to(it); - return value; + + // Check overflow. + if (num_digits <= std::numeric_limits::digits10) return value; + const unsigned max = to_unsigned((std::numeric_limits::max)()); + if (num_digits == std::numeric_limits::digits10 + 1 && + prev * 10ull + unsigned(prev_digit - '0') <= max) { + return value; + } + throw format_error("number is too big"); } template auto read_int() -> T {