diff --git a/src/libs/sqlite/sqlitebasestatement.cpp b/src/libs/sqlite/sqlitebasestatement.cpp index 15e4a980436..01cd43facb2 100644 --- a/src/libs/sqlite/sqlitebasestatement.cpp +++ b/src/libs/sqlite/sqlitebasestatement.cpp @@ -54,8 +54,7 @@ BaseStatement::BaseStatement(Utils::SmallStringView sqlStatement, Database &data void BaseStatement::deleteCompiledStatement(sqlite3_stmt *compiledStatement) { - if (compiledStatement) - sqlite3_finalize(compiledStatement); + sqlite3_finalize(compiledStatement); } class UnlockNotification @@ -145,12 +144,6 @@ void BaseStatement::step() const next(); } -void BaseStatement::execute() const -{ - next(); - reset(); -} - int BaseStatement::columnCount() const { return m_columnCount; diff --git a/src/libs/sqlite/sqlitebasestatement.h b/src/libs/sqlite/sqlitebasestatement.h index 5e629e5ec9e..86d03fd8553 100644 --- a/src/libs/sqlite/sqlitebasestatement.h +++ b/src/libs/sqlite/sqlitebasestatement.h @@ -60,7 +60,6 @@ public: bool next() const; void step() const; - void execute() const; void reset() const; int fetchIntValue(int column) const; @@ -159,6 +158,13 @@ class StatementImplementation : public BaseStatement public: using BaseStatement::BaseStatement; + void execute() + { + Resetter resetter{*this}; + BaseStatement::next(); + resetter.reset(); + } + void bindValues() { } @@ -172,8 +178,10 @@ public: template void write(const ValueType&... values) { + Resetter resetter{*this}; bindValuesByIndex(1, values...); - BaseStatement::execute(); + BaseStatement::next(); + resetter.reset(); } template @@ -185,8 +193,10 @@ public: template void writeNamed(const ValueType&... values) { + Resetter resetter{*this}; bindValuesByName(values...); - BaseStatement::execute(); + BaseStatement::next(); + resetter.reset(); } template m_locker{m_interface}; + std::unique_lock m_locker{m_interface}; bool m_isAlreadyCommited = false; bool m_rollback = false; }; diff --git a/tests/unit/unittest/mocksqlitestatement.h b/tests/unit/unittest/mocksqlitestatement.h index ab919b9f4e2..26f5b0de6e0 100644 --- a/tests/unit/unittest/mocksqlitestatement.h +++ b/tests/unit/unittest/mocksqlitestatement.h @@ -35,7 +35,6 @@ class BaseMockSqliteStatement public: MOCK_METHOD0(next, bool ()); MOCK_METHOD0(step, void ()); - MOCK_METHOD0(execute, void ()); MOCK_METHOD0(reset, void ()); MOCK_CONST_METHOD1(fetchIntValue, int (int)); @@ -53,6 +52,7 @@ public: MOCK_METHOD2(bind, void (int, double)); MOCK_METHOD2(bind, void (int, Utils::SmallStringView)); MOCK_METHOD2(bind, void (int, long)); + MOCK_CONST_METHOD1(bindingIndexForName, int (Utils::SmallStringView name)); MOCK_METHOD1(prepare, void (Utils::SmallStringView sqlStatement)); }; diff --git a/tests/unit/unittest/sqlitestatement-test.cpp b/tests/unit/unittest/sqlitestatement-test.cpp index f033c07ed73..cb8206d8757 100644 --- a/tests/unit/unittest/sqlitestatement-test.cpp +++ b/tests/unit/unittest/sqlitestatement-test.cpp @@ -639,6 +639,40 @@ TEST_F(SqliteStatement, ThrowExceptionOnlyInReset) Sqlite::StatementHasError); } +TEST_F(SqliteStatement, ResetIfWriteIsThrowingException) +{ + MockSqliteStatement mockStatement; + + EXPECT_CALL(mockStatement, bind(1, TypedEq("bar"))) + .WillOnce(Throw(Sqlite::StatementIsBusy(""))); + EXPECT_CALL(mockStatement, reset()); + + ASSERT_ANY_THROW(mockStatement.write("bar")); +} + +TEST_F(SqliteStatement, ResetIfWriteNamedIsThrowingException) +{ + MockSqliteStatement mockStatement; + + EXPECT_CALL(mockStatement, bindingIndexForName(TypedEq("@foo"))) + .WillOnce(Return(1)); + EXPECT_CALL(mockStatement, bind(1, TypedEq("bar"))) + .WillOnce(Throw(Sqlite::StatementIsBusy(""))); + EXPECT_CALL(mockStatement, reset()); + + ASSERT_ANY_THROW(mockStatement.writeNamed("@foo", "bar")); +} + +TEST_F(SqliteStatement, ResetIfExecuteThrowsException) +{ + MockSqliteStatement mockStatement; + + EXPECT_CALL(mockStatement, next()).WillOnce(Throw(Sqlite::StatementIsBusy(""))); + EXPECT_CALL(mockStatement, reset()); + + ASSERT_ANY_THROW(mockStatement.execute()); +} + void SqliteStatement::SetUp() { database.execute("CREATE TABLE test(name TEXT UNIQUE, number NUMERIC, value NUMERIC)");