diff --git a/src/lib/knowledge_base.cpp b/src/lib/knowledge_base.cpp index 0e163de..f5f4170 100644 --- a/src/lib/knowledge_base.cpp +++ b/src/lib/knowledge_base.cpp @@ -223,6 +223,28 @@ void obelisk::KnowledgeBase::getRule(obelisk::Rule& rule) rule.selectById(dbConnection_); } +void obelisk::KnowledgeBase::checkRule(obelisk::Fact& fact) +{ + std::vector rules; + obelisk::Rule::selectByReason(dbConnection_, fact.getId(), rules); + for (auto& rule : rules) + { + auto reason = rule.getReason(); + getFact(reason); + if (reason.getIsTrue()) + { + auto updateFact = rule.getFact(); + updateFact.setIsTrue(true); + updateFact.updateIsTrue(dbConnection_); + } + } +} + +void obelisk::KnowledgeBase::updateIsTrue(obelisk::Fact& fact) +{ + fact.updateIsTrue(dbConnection_); +} + void obelisk::KnowledgeBase::getFloat(float& result1, float& result2, double var) { result1 = (float) var; diff --git a/src/lib/knowledge_base.h b/src/lib/knowledge_base.h index c1790a3..ac463d2 100644 --- a/src/lib/knowledge_base.h +++ b/src/lib/knowledge_base.h @@ -183,6 +183,21 @@ namespace obelisk */ void getRule(obelisk::Rule& rule); + /** + * @brief Check if a rule looks for this Fact, if so update its + * truth. + * + * @param[in,out] fact The Fact to check for existing rules. + */ + void checkRule(obelisk::Fact& fact); + + /** + * @brief Update the is true field in the KnowledgeBase. + * + * @param[in,out] fact The fact to update. + */ + void updateIsTrue(obelisk::Fact& fact); + /** * @brief Take a float and divide it into 2 floats. * diff --git a/src/lib/models/fact.cpp b/src/lib/models/fact.cpp index c6092ea..eafe955 100644 --- a/src/lib/models/fact.cpp +++ b/src/lib/models/fact.cpp @@ -28,71 +28,101 @@ void obelisk::Fact::selectById(sqlite3* dbConnection) sqlite3_stmt* ppStmt = nullptr; - auto result = sqlite3_prepare_v2(dbConnection, - "SELECT id, left_entity, right_entity, verb, is_true FROM fact WHERE (left_entity=? AND right_entity=? AND verb=?)", - -1, - &ppStmt, - nullptr); + const char* query; + if (getId() == 0) + { + query + = "SELECT id, left_entity, right_entity, verb, is_true FROM fact WHERE (left_entity=? AND right_entity=? AND verb=?)"; + } + else + { + query = "SELECT id, left_entity, right_entity, verb, is_true FROM fact WHERE (id=?)"; + } + auto result = sqlite3_prepare_v2(dbConnection, query, -1, &ppStmt, nullptr); if (result != SQLITE_OK) { throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); } - result = sqlite3_bind_int(ppStmt, 1, getLeftEntity().getId()); - switch (result) + if (getId() == 0) { - case SQLITE_OK : - break; - case SQLITE_TOOBIG : - throw obelisk::DatabaseSizeException(); - break; - case SQLITE_RANGE : - throw obelisk::DatabaseRangeException(); - break; - case SQLITE_NOMEM : - throw obelisk::DatabaseMemoryException(); - break; - default : - throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); - break; - } + result = sqlite3_bind_int(ppStmt, 1, getLeftEntity().getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } - result = sqlite3_bind_int(ppStmt, 2, getRightEntity().getId()); - switch (result) - { - case SQLITE_OK : - break; - case SQLITE_TOOBIG : - throw obelisk::DatabaseSizeException(); - break; - case SQLITE_RANGE : - throw obelisk::DatabaseRangeException(); - break; - case SQLITE_NOMEM : - throw obelisk::DatabaseMemoryException(); - break; - default : - throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); - break; - } + result = sqlite3_bind_int(ppStmt, 2, getRightEntity().getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } - result = sqlite3_bind_int(ppStmt, 3, getVerb().getId()); - switch (result) + result = sqlite3_bind_int(ppStmt, 3, getVerb().getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + } + else { - case SQLITE_OK : - break; - case SQLITE_TOOBIG : - throw obelisk::DatabaseSizeException(); - break; - case SQLITE_RANGE : - throw obelisk::DatabaseRangeException(); - break; - case SQLITE_NOMEM : - throw obelisk::DatabaseMemoryException(); - break; - default : - throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); - break; + result = sqlite3_bind_int(ppStmt, 1, getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } } result = sqlite3_step(ppStmt); @@ -248,6 +278,85 @@ void obelisk::Fact::insert(sqlite3* dbConnection) } } +void obelisk::Fact::updateIsTrue(sqlite3* dbConnection) +{ + if (dbConnection == nullptr) + { + throw obelisk::DatabaseException("database isn't open"); + } + + sqlite3_stmt* ppStmt = nullptr; + + auto result = sqlite3_prepare_v2(dbConnection, "UPDATE fact SET is_true=? WHERE id=?", -1, &ppStmt, nullptr); + if (result != SQLITE_OK) + { + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + } + + result = sqlite3_bind_int(ppStmt, 1, getIsTrue()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_bind_int(ppStmt, 2, getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_step(ppStmt); + switch (result) + { + case SQLITE_DONE : + // Row updated + break; + case SQLITE_CONSTRAINT : + throw obelisk::DatabaseConstraintException(sqlite3_errmsg(dbConnection)); + case SQLITE_BUSY : + throw obelisk::DatabaseBusyException(); + break; + case SQLITE_MISUSE : + throw obelisk::DatabaseMisuseException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_finalize(ppStmt); + if (result != SQLITE_OK) + { + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + } +} + int& obelisk::Fact::getId() { return id_; diff --git a/src/lib/models/fact.h b/src/lib/models/fact.h index 4e6b821..cc16b06 100644 --- a/src/lib/models/fact.h +++ b/src/lib/models/fact.h @@ -211,6 +211,14 @@ namespace obelisk * @param[in] dbConnection The database connection to use. */ void insert(sqlite3* dbConnection); + + /** + * @brief Update whether or not the fact is true in the + * KnowledgeBase. + * + * @param[in] dbConnection The database connection. + */ + void updateIsTrue(sqlite3* dbConnection); }; } // namespace obelisk diff --git a/src/lib/models/rule.cpp b/src/lib/models/rule.cpp index 26f9fd9..9184b6c 100644 --- a/src/lib/models/rule.cpp +++ b/src/lib/models/rule.cpp @@ -183,6 +183,69 @@ void obelisk::Rule::insert(sqlite3* dbConnection) } } +void obelisk::Rule::selectByReason(sqlite3* dbConnection, int reasonId, std::vector& rules) +{ + if (dbConnection == nullptr) + { + throw obelisk::DatabaseException("database isn't open"); + } + + sqlite3_stmt* ppStmt = nullptr; + + auto result + = sqlite3_prepare_v2(dbConnection, "SELECT id, fact, reason FROM rule WHERE (reason=?)", -1, &ppStmt, nullptr); + if (result != SQLITE_OK) + { + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + } + + result = sqlite3_bind_int(ppStmt, 1, reasonId); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + while ((result = sqlite3_step(ppStmt)) != SQLITE_DONE) + { + switch (result) + { + case SQLITE_ROW : + rules.push_back(obelisk::Rule(sqlite3_column_int(ppStmt, 0), + obelisk::Fact(sqlite3_column_int(ppStmt, 1)), + obelisk::Fact(sqlite3_column_int(ppStmt, 2)))); + break; + case SQLITE_BUSY : + throw obelisk::DatabaseBusyException(); + break; + case SQLITE_MISUSE : + throw obelisk::DatabaseMisuseException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + } + + result = sqlite3_finalize(ppStmt); + if (result != SQLITE_OK) + { + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + } +} + int& obelisk::Rule::getId() { return id_; diff --git a/src/lib/models/rule.h b/src/lib/models/rule.h index e48189d..02f06fd 100644 --- a/src/lib/models/rule.h +++ b/src/lib/models/rule.h @@ -4,6 +4,7 @@ #include "models/fact.h" #include +#include namespace obelisk { @@ -140,6 +141,14 @@ namespace obelisk */ void selectById(sqlite3* dbConnection); + /** + * @brief Get the rules that match the reason. + * + * @param[in] dbConnection The database connection to use. + * @param[out] rules The rules to fill in from the database. + */ + static void selectByReason(sqlite3* dbConnection, int reasonId, std::vector& rules); + /** * @brief Insert the Rule into the KnowledgeBase. * diff --git a/src/parser.cpp b/src/parser.cpp index de93c3c..b3207c9 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -770,14 +770,23 @@ void obelisk::Parser::handleRule(std::unique_ptr& kb) try { parseRule(rule); - insertEntity(kb, rule.getFact().getLeftEntity()); - insertEntity(kb, rule.getFact().getRightEntity()); - insertVerb(kb, rule.getFact().getVerb()); - insertFact(kb, rule.getFact()); + insertEntity(kb, rule.getReason().getLeftEntity()); insertEntity(kb, rule.getReason().getRightEntity()); insertVerb(kb, rule.getReason().getVerb()); insertFact(kb, rule.getReason()); + + // The rule is true, so the fact must be true to. + if (rule.getReason().getIsTrue()) + { + rule.getFact().setIsTrue(true); + } + + insertEntity(kb, rule.getFact().getLeftEntity()); + insertEntity(kb, rule.getFact().getRightEntity()); + insertVerb(kb, rule.getFact().getVerb()); + insertFact(kb, rule.getFact()); + insertRule(kb, rule); } catch (obelisk::ParserException& exception) @@ -830,12 +839,14 @@ void obelisk::Parser::handleFact(std::unique_ptr& kb) try { - insertFact(kb, fact); + insertFact(kb, fact, true); } catch (obelisk::ParserException& exception) { throw; } + + kb->checkRule(fact); } } @@ -890,7 +901,7 @@ void obelisk::Parser::insertAction(std::unique_ptr& kb, } } -void obelisk::Parser::insertFact(std::unique_ptr& kb, obelisk::Fact& fact) +void obelisk::Parser::insertFact(std::unique_ptr& kb, obelisk::Fact& fact, bool updateIsTrue) { std::vector facts {fact}; kb->addFacts(facts); @@ -904,6 +915,14 @@ void obelisk::Parser::insertFact(std::unique_ptr& kb, ob { throw obelisk::ParserException("fact could not be inserted into the database"); } + else + { + if (updateIsTrue) + { + fact.setIsTrue(true); + kb->updateIsTrue(fact); + } + } } } diff --git a/src/parser.h b/src/parser.h index 13ad7ea..573fbdc 100644 --- a/src/parser.h +++ b/src/parser.h @@ -17,12 +17,32 @@ namespace obelisk { + /** + * @brief The Parser is responsible for analyzing the language's key words + * and taking action based on its analysis. + * + */ class Parser { private: + /** + * @brief The Lexer object that the Parser is using to Parse a + * specific source file. + * + */ std::shared_ptr lexer_; + + /** + * @brief The current token that the lexer has retrieved. + * + */ int currentToken_ = 0; + /** + * @brief Set the current token. + * + * @param[in] currentToken The token should be ASCII character. + */ void setCurrentToken(int currentToken); std::unique_ptr logError(const char* str); @@ -64,7 +84,9 @@ namespace obelisk void insertEntity(std::unique_ptr& kb, obelisk::Entity& entity); void insertVerb(std::unique_ptr& kb, obelisk::Verb& verb); void insertAction(std::unique_ptr& kb, obelisk::Action& action); - void insertFact(std::unique_ptr& kb, obelisk::Fact& fact); + void insertFact(std::unique_ptr& kb, + obelisk::Fact& fact, + bool updateIsTrue = false); void insertSuggestAction(std::unique_ptr& kb, obelisk::SuggestAction& suggestAction); void insertRule(std::unique_ptr& kb, obelisk::Rule& rule);