Moved nesting decrement logic to class NestingLimit

This commit is contained in:
Benoit Blanchon
2020-02-13 16:54:18 +01:00
parent 6e52f242b2
commit fbffadb2cf
4 changed files with 88 additions and 84 deletions

View File

@ -5,13 +5,25 @@
#pragma once #pragma once
#include <ArduinoJson/Namespace.hpp> #include <ArduinoJson/Namespace.hpp>
#include <ArduinoJson/Polyfills/assert.hpp>
namespace ARDUINOJSON_NAMESPACE { namespace ARDUINOJSON_NAMESPACE {
struct NestingLimit { class NestingLimit {
NestingLimit() : value(ARDUINOJSON_DEFAULT_NESTING_LIMIT) {} public:
explicit NestingLimit(uint8_t n) : value(n) {} NestingLimit() : _value(ARDUINOJSON_DEFAULT_NESTING_LIMIT) {}
explicit NestingLimit(uint8_t n) : _value(n) {}
uint8_t value; NestingLimit decrement() const {
ARDUINOJSON_ASSERT(_value > 0);
return NestingLimit(static_cast<uint8_t>(_value - 1));
}
bool reached() const {
return _value == 0;
}
private:
uint8_t _value;
}; };
} // namespace ARDUINOJSON_NAMESPACE } // namespace ARDUINOJSON_NAMESPACE

View File

@ -15,9 +15,9 @@ namespace ARDUINOJSON_NAMESPACE {
template <template <typename, typename> class TDeserializer, typename TReader, template <template <typename, typename> class TDeserializer, typename TReader,
typename TWriter> typename TWriter>
TDeserializer<TReader, TWriter> makeDeserializer(MemoryPool &pool, TDeserializer<TReader, TWriter> makeDeserializer(MemoryPool &pool,
TReader reader, TWriter writer, TReader reader,
uint8_t nestingLimit) { TWriter writer) {
return TDeserializer<TReader, TWriter>(pool, reader, writer, nestingLimit); return TDeserializer<TReader, TWriter>(pool, reader, writer);
} }
// deserialize(JsonDocument&, const std::string&, NestingLimit, Filter); // deserialize(JsonDocument&, const std::string&, NestingLimit, Filter);
@ -34,8 +34,8 @@ deserialize(JsonDocument &doc, const TString &input, NestingLimit nestingLimit,
doc.clear(); doc.clear();
return makeDeserializer<TDeserializer>( return makeDeserializer<TDeserializer>(
doc.memoryPool(), reader, doc.memoryPool(), reader,
makeStringStorage(doc.memoryPool(), input), nestingLimit.value) makeStringStorage(doc.memoryPool(), input))
.parse(doc.data(), filter); .parse(doc.data(), filter, nestingLimit);
} }
// //
// deserialize(JsonDocument&, char*, size_t, NestingLimit, Filter); // deserialize(JsonDocument&, char*, size_t, NestingLimit, Filter);
@ -50,8 +50,8 @@ DeserializationError deserialize(JsonDocument &doc, TChar *input,
doc.clear(); doc.clear();
return makeDeserializer<TDeserializer>( return makeDeserializer<TDeserializer>(
doc.memoryPool(), reader, doc.memoryPool(), reader,
makeStringStorage(doc.memoryPool(), input), nestingLimit.value) makeStringStorage(doc.memoryPool(), input))
.parse(doc.data(), filter); .parse(doc.data(), filter, nestingLimit);
} }
// //
// deserialize(JsonDocument&, std::istream&, NestingLimit, Filter); // deserialize(JsonDocument&, std::istream&, NestingLimit, Filter);
@ -64,8 +64,8 @@ DeserializationError deserialize(JsonDocument &doc, TStream &input,
doc.clear(); doc.clear();
return makeDeserializer<TDeserializer>( return makeDeserializer<TDeserializer>(
doc.memoryPool(), reader, doc.memoryPool(), reader,
makeStringStorage(doc.memoryPool(), input), nestingLimit.value) makeStringStorage(doc.memoryPool(), input))
.parse(doc.data(), filter); .parse(doc.data(), filter, nestingLimit);
} }
} // namespace ARDUINOJSON_NAMESPACE } // namespace ARDUINOJSON_NAMESPACE

View File

@ -23,15 +23,13 @@ class JsonDeserializer {
public: public:
JsonDeserializer(MemoryPool &pool, TReader reader, JsonDeserializer(MemoryPool &pool, TReader reader,
TStringStorage stringStorage, uint8_t nestingLimit) TStringStorage stringStorage)
: _pool(&pool), : _pool(&pool), _stringStorage(stringStorage), _latch(reader) {}
_stringStorage(stringStorage),
_nestingLimit(nestingLimit),
_latch(reader) {}
template <typename TFilter> template <typename TFilter>
DeserializationError parse(VariantData &variant, TFilter filter) { DeserializationError parse(VariantData &variant, TFilter filter,
DeserializationError err = parseVariant(variant, filter); NestingLimit nestingLimit) {
DeserializationError err = parseVariant(variant, filter, nestingLimit);
if (!err && _latch.last() != 0 && !variant.isEnclosed()) { if (!err && _latch.last() != 0 && !variant.isEnclosed()) {
// We don't detect trailing characters earlier, so we need to check now // We don't detect trailing characters earlier, so we need to check now
@ -59,22 +57,23 @@ class JsonDeserializer {
} }
template <typename TFilter> template <typename TFilter>
DeserializationError parseVariant(VariantData &variant, TFilter filter) { DeserializationError parseVariant(VariantData &variant, TFilter filter,
NestingLimit nestingLimit) {
DeserializationError err = skipSpacesAndComments(); DeserializationError err = skipSpacesAndComments();
if (err) return err; if (err) return err;
switch (current()) { switch (current()) {
case '[': case '[':
if (filter.allowArray()) if (filter.allowArray())
return parseArray(variant.toArray(), filter); return parseArray(variant.toArray(), filter, nestingLimit);
else else
return skipArray(); return skipArray(nestingLimit);
case '{': case '{':
if (filter.allowObject()) if (filter.allowObject())
return parseObject(variant.toObject(), filter); return parseObject(variant.toObject(), filter, nestingLimit);
else else
return skipObject(); return skipObject(nestingLimit);
case '\"': case '\"':
case '\'': case '\'':
@ -91,16 +90,16 @@ class JsonDeserializer {
} }
} }
DeserializationError skipVariant() { DeserializationError skipVariant(NestingLimit nestingLimit) {
DeserializationError err = skipSpacesAndComments(); DeserializationError err = skipSpacesAndComments();
if (err) return err; if (err) return err;
switch (current()) { switch (current()) {
case '[': case '[':
return skipArray(); return skipArray(nestingLimit);
case '{': case '{':
return skipObject(); return skipObject(nestingLimit);
case '\"': case '\"':
case '\'': case '\'':
@ -112,8 +111,9 @@ class JsonDeserializer {
} }
template <typename TFilter> template <typename TFilter>
DeserializationError parseArray(CollectionData &array, TFilter filter) { DeserializationError parseArray(CollectionData &array, TFilter filter,
if (_nestingLimit == 0) return DeserializationError::TooDeep; NestingLimit nestingLimit) {
if (nestingLimit.reached()) return DeserializationError::TooDeep;
// Check opening braket // Check opening braket
if (!eat('[')) return DeserializationError::InvalidInput; if (!eat('[')) return DeserializationError::InvalidInput;
@ -135,14 +135,10 @@ class JsonDeserializer {
if (!value) return DeserializationError::NoMemory; if (!value) return DeserializationError::NoMemory;
// 1 - Parse value // 1 - Parse value
_nestingLimit--; err = parseVariant(*value, memberFilter, nestingLimit.decrement());
err = parseVariant(*value, memberFilter);
_nestingLimit++;
if (err) return err; if (err) return err;
} else { } else {
_nestingLimit--; err = skipVariant(nestingLimit.decrement());
err = skipVariant();
_nestingLimit++;
if (err) return err; if (err) return err;
} }
@ -156,8 +152,8 @@ class JsonDeserializer {
} }
} }
DeserializationError skipArray() { DeserializationError skipArray(NestingLimit nestingLimit) {
if (_nestingLimit == 0) return DeserializationError::TooDeep; if (nestingLimit.reached()) return DeserializationError::TooDeep;
// Check opening braket // Check opening braket
if (!eat('[')) return DeserializationError::InvalidInput; if (!eat('[')) return DeserializationError::InvalidInput;
@ -165,9 +161,7 @@ class JsonDeserializer {
// Read each value // Read each value
for (;;) { for (;;) {
// 1 - Skip value // 1 - Skip value
_nestingLimit--; DeserializationError err = skipVariant(nestingLimit.decrement());
DeserializationError err = skipVariant();
_nestingLimit++;
if (err) return err; if (err) return err;
// 2 - Skip spaces // 2 - Skip spaces
@ -181,8 +175,9 @@ class JsonDeserializer {
} }
template <typename TFilter> template <typename TFilter>
DeserializationError parseObject(CollectionData &object, TFilter filter) { DeserializationError parseObject(CollectionData &object, TFilter filter,
if (_nestingLimit == 0) return DeserializationError::TooDeep; NestingLimit nestingLimit) {
if (nestingLimit.reached()) return DeserializationError::TooDeep;
// Check opening brace // Check opening brace
if (!eat('{')) return DeserializationError::InvalidInput; if (!eat('{')) return DeserializationError::InvalidInput;
@ -221,15 +216,11 @@ class JsonDeserializer {
} }
// Parse value // Parse value
_nestingLimit--; err = parseVariant(*variant, memberFilter, nestingLimit.decrement());
err = parseVariant(*variant, memberFilter);
_nestingLimit++;
if (err) return err; if (err) return err;
} else { } else {
_stringStorage.reclaim(key); _stringStorage.reclaim(key);
_nestingLimit--; err = skipVariant(nestingLimit.decrement());
err = skipVariant();
_nestingLimit++;
if (err) return err; if (err) return err;
} }
@ -247,8 +238,8 @@ class JsonDeserializer {
} }
} }
DeserializationError skipObject() { DeserializationError skipObject(NestingLimit nestingLimit) {
if (_nestingLimit == 0) return DeserializationError::TooDeep; if (nestingLimit.reached()) return DeserializationError::TooDeep;
// Check opening brace // Check opening brace
if (!eat('{')) return DeserializationError::InvalidInput; if (!eat('{')) return DeserializationError::InvalidInput;
@ -263,7 +254,7 @@ class JsonDeserializer {
// Read each key value pair // Read each key value pair
for (;;) { for (;;) {
// Skip key // Skip key
err = skipVariant(); err = skipVariant(nestingLimit.decrement());
if (err) return err; if (err) return err;
// Skip spaces // Skip spaces
@ -272,9 +263,7 @@ class JsonDeserializer {
if (!eat(':')) return DeserializationError::InvalidInput; if (!eat(':')) return DeserializationError::InvalidInput;
// Skip value // Skip value
_nestingLimit--; err = skipVariant(nestingLimit.decrement());
err = skipVariant();
_nestingLimit++;
if (err) return err; if (err) return err;
// Skip spaces // Skip spaces
@ -538,7 +527,6 @@ class JsonDeserializer {
MemoryPool *_pool; MemoryPool *_pool;
TStringStorage _stringStorage; TStringStorage _stringStorage;
uint8_t _nestingLimit;
Latch<TReader> _latch; Latch<TReader> _latch;
}; };

View File

@ -20,15 +20,16 @@ class MsgPackDeserializer {
public: public:
MsgPackDeserializer(MemoryPool &pool, TReader reader, MsgPackDeserializer(MemoryPool &pool, TReader reader,
TStringStorage stringStorage, uint8_t nestingLimit) TStringStorage stringStorage)
: _pool(&pool), : _pool(&pool), _reader(reader), _stringStorage(stringStorage) {}
_reader(reader),
_stringStorage(stringStorage),
_nestingLimit(nestingLimit) {}
// TODO: add support for filter // TODO: add support for filter
DeserializationError parse(VariantData &variant, DeserializationError parse(VariantData &variant, AllowAllFilter,
AllowAllFilter = AllowAllFilter()) { NestingLimit nestingLimit) {
return parse(variant, nestingLimit);
}
DeserializationError parse(VariantData &variant, NestingLimit nestingLimit) {
uint8_t code; uint8_t code;
if (!readByte(code)) return DeserializationError::IncompleteInput; if (!readByte(code)) return DeserializationError::IncompleteInput;
@ -48,11 +49,11 @@ class MsgPackDeserializer {
} }
if ((code & 0xf0) == 0x90) { if ((code & 0xf0) == 0x90) {
return readArray(variant.toArray(), code & 0x0F); return readArray(variant.toArray(), code & 0x0F, nestingLimit);
} }
if ((code & 0xf0) == 0x80) { if ((code & 0xf0) == 0x80) {
return readObject(variant.toObject(), code & 0x0F); return readObject(variant.toObject(), code & 0x0F, nestingLimit);
} }
switch (code) { switch (code) {
@ -116,16 +117,16 @@ class MsgPackDeserializer {
return readString<uint32_t>(variant); return readString<uint32_t>(variant);
case 0xdc: case 0xdc:
return readArray<uint16_t>(variant.toArray()); return readArray<uint16_t>(variant.toArray(), nestingLimit);
case 0xdd: case 0xdd:
return readArray<uint32_t>(variant.toArray()); return readArray<uint32_t>(variant.toArray(), nestingLimit);
case 0xde: case 0xde:
return readObject<uint16_t>(variant.toObject()); return readObject<uint16_t>(variant.toObject(), nestingLimit);
case 0xdf: case 0xdf:
return readObject<uint32_t>(variant.toObject()); return readObject<uint32_t>(variant.toObject(), nestingLimit);
default: default:
return DeserializationError::NotSupported; return DeserializationError::NotSupported;
@ -242,36 +243,40 @@ class MsgPackDeserializer {
} }
template <typename TSize> template <typename TSize>
DeserializationError readArray(CollectionData &array) { DeserializationError readArray(CollectionData &array,
NestingLimit nestingLimit) {
TSize size; TSize size;
if (!readInteger(size)) return DeserializationError::IncompleteInput; if (!readInteger(size)) return DeserializationError::IncompleteInput;
return readArray(array, size); return readArray(array, size, nestingLimit);
} }
DeserializationError readArray(CollectionData &array, size_t n) { DeserializationError readArray(CollectionData &array, size_t n,
if (_nestingLimit == 0) return DeserializationError::TooDeep; NestingLimit nestingLimit) {
--_nestingLimit; if (nestingLimit.reached()) return DeserializationError::TooDeep;
for (; n; --n) { for (; n; --n) {
VariantData *value = array.add(_pool); VariantData *value = array.add(_pool);
if (!value) return DeserializationError::NoMemory; if (!value) return DeserializationError::NoMemory;
DeserializationError err = parse(*value); DeserializationError err = parse(*value, nestingLimit.decrement());
if (err) return err; if (err) return err;
} }
++_nestingLimit;
return DeserializationError::Ok; return DeserializationError::Ok;
} }
template <typename TSize> template <typename TSize>
DeserializationError readObject(CollectionData &object) { DeserializationError readObject(CollectionData &object,
NestingLimit nestingLimit) {
TSize size; TSize size;
if (!readInteger(size)) return DeserializationError::IncompleteInput; if (!readInteger(size)) return DeserializationError::IncompleteInput;
return readObject(object, size); return readObject(object, size, nestingLimit);
} }
DeserializationError readObject(CollectionData &object, size_t n) { DeserializationError readObject(CollectionData &object, size_t n,
if (_nestingLimit == 0) return DeserializationError::TooDeep; NestingLimit nestingLimit) {
--_nestingLimit; if (nestingLimit.reached()) return DeserializationError::TooDeep;
for (; n; --n) { for (; n; --n) {
VariantSlot *slot = object.addSlot(_pool); VariantSlot *slot = object.addSlot(_pool);
if (!slot) return DeserializationError::NoMemory; if (!slot) return DeserializationError::NoMemory;
@ -281,10 +286,10 @@ class MsgPackDeserializer {
if (err) return err; if (err) return err;
slot->setOwnedKey(make_not_null(key)); slot->setOwnedKey(make_not_null(key));
err = parse(*slot->data()); err = parse(*slot->data(), nestingLimit.decrement());
if (err) return err; if (err) return err;
} }
++_nestingLimit;
return DeserializationError::Ok; return DeserializationError::Ok;
} }
@ -312,7 +317,6 @@ class MsgPackDeserializer {
MemoryPool *_pool; MemoryPool *_pool;
TReader _reader; TReader _reader;
TStringStorage _stringStorage; TStringStorage _stringStorage;
uint8_t _nestingLimit;
}; };
template <typename TInput> template <typename TInput>