diff --git a/src/lib/knowledge_base.cpp b/src/lib/knowledge_base.cpp index 37ba0a8..0e163de 100644 --- a/src/lib/knowledge_base.cpp +++ b/src/lib/knowledge_base.cpp @@ -174,6 +174,25 @@ void obelisk::KnowledgeBase::addSuggestActions(std::vector& rules) +{ + for (auto& rule : rules) + { + try + { + rule.insert(dbConnection_); + } + catch (obelisk::DatabaseConstraintException& exception) + { + // ignore unique constraint error + if (std::strcmp(exception.what(), "UNIQUE constraint failed: rule.fact, rule.reason") != 0) + { + throw; + } + } + } +} + void obelisk::KnowledgeBase::getEntity(obelisk::Entity& entity) { entity.selectByName(dbConnection_); @@ -199,6 +218,11 @@ void obelisk::KnowledgeBase::getSuggestAction(obelisk::SuggestAction& suggestAct suggestAction.selectById(dbConnection_); } +void obelisk::KnowledgeBase::getRule(obelisk::Rule& rule) +{ + rule.selectById(dbConnection_); +} + void obelisk::KnowledgeBase::getFloat(float& result1, float& result2, double var) { result1 = (float) var; diff --git a/src/lib/knowledge_base.h b/src/lib/knowledge_base.h index ddb1165..c1790a3 100644 --- a/src/lib/knowledge_base.h +++ b/src/lib/knowledge_base.h @@ -127,6 +127,14 @@ namespace obelisk */ void addSuggestActions(std::vector& suggestActions); + /** + * @brief Add rules to the KnowledgeBase. + * + * @param[in,out] rules The rules to add. If the insert is successful it + * will have a row ID, if not the ID will be 0. + */ + void addRules(std::vector& rules); + /** * @brief Get an Entity object based on the ID it contains. * @@ -167,6 +175,14 @@ namespace obelisk */ void getSuggestAction(obelisk::SuggestAction& suggestAction); + /** + * @brief Get a Rule based on the ID it contains. + * + * @param[in,out] rule The Rule object should contain just the ID + * and the rest will be filled in. + */ + void getRule(obelisk::Rule& rule); + /** * @brief Take a float and divide it into 2 floats. * diff --git a/src/lib/models/rule.cpp b/src/lib/models/rule.cpp index 7dfc4f5..26f9fd9 100644 --- a/src/lib/models/rule.cpp +++ b/src/lib/models/rule.cpp @@ -1,3 +1,4 @@ +#include "models/error.h" #include "models/rule.h" const char* obelisk::Rule::createTable() @@ -15,6 +16,173 @@ const char* obelisk::Rule::createTable() )"; } +void obelisk::Rule::selectById(sqlite3* dbConnection) +{ + if (dbConnection == nullptr) + { + throw obelisk::DatabaseException("database isn't open"); + } + + sqlite3_stmt* ppStmt = nullptr; + + auto result = sqlite3_prepare_v2(dbConnection, + "SELECT id, fact, reason FROM rule WHERE (fact=? AND reason=?)", + -1, + &ppStmt, + nullptr); + if (result != SQLITE_OK) + { + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + } + + result = sqlite3_bind_int(ppStmt, 1, getFact().getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_bind_int(ppStmt, 2, getReason().getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_step(ppStmt); + switch (result) + { + case SQLITE_DONE : + // no rows in the database + break; + case SQLITE_ROW : + setId(sqlite3_column_int(ppStmt, 0)); + getFact().setId(sqlite3_column_int(ppStmt, 1)); + getReason().setId(sqlite3_column_int(ppStmt, 2)); + break; + case SQLITE_BUSY : + throw obelisk::DatabaseBusyException(); + break; + case SQLITE_MISUSE : + throw obelisk::DatabaseMisuseException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_finalize(ppStmt); + if (result != SQLITE_OK) + { + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + } +} + +void obelisk::Rule::insert(sqlite3* dbConnection) +{ + if (dbConnection == nullptr) + { + throw obelisk::DatabaseException("database isn't open"); + } + + sqlite3_stmt* ppStmt = nullptr; + + auto result + = sqlite3_prepare_v2(dbConnection, "INSERT INTO rule (fact, reason) VALUES (?, ?)", -1, &ppStmt, nullptr); + if (result != SQLITE_OK) + { + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + } + + result = sqlite3_bind_int(ppStmt, 1, getFact().getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_bind_int(ppStmt, 2, getReason().getId()); + switch (result) + { + case SQLITE_OK : + break; + case SQLITE_TOOBIG : + throw obelisk::DatabaseSizeException(); + break; + case SQLITE_RANGE : + throw obelisk::DatabaseRangeException(); + break; + case SQLITE_NOMEM : + throw obelisk::DatabaseMemoryException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_step(ppStmt); + switch (result) + { + case SQLITE_DONE : + setId((int) sqlite3_last_insert_rowid(dbConnection)); + sqlite3_set_last_insert_rowid(dbConnection, 0); + break; + case SQLITE_CONSTRAINT : + throw obelisk::DatabaseConstraintException(sqlite3_errmsg(dbConnection)); + case SQLITE_BUSY : + throw obelisk::DatabaseBusyException(); + break; + case SQLITE_MISUSE : + throw obelisk::DatabaseMisuseException(); + break; + default : + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + break; + } + + result = sqlite3_finalize(ppStmt); + if (result != SQLITE_OK) + { + throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); + } +} + int& obelisk::Rule::getId() { return id_; diff --git a/src/lib/models/rule.h b/src/lib/models/rule.h index 5d3b23e..e48189d 100644 --- a/src/lib/models/rule.h +++ b/src/lib/models/rule.h @@ -131,6 +131,21 @@ namespace obelisk * @param[in] reason The reason Fact. */ void setReason(obelisk::Fact reason); + + /** + * @brief Select the Rule from the KnowledgeBase by IDs of the + * sub-objects. + * + * @param[in] dbConnection The database connection to use. + */ + void selectById(sqlite3* dbConnection); + + /** + * @brief Insert the Rule into the KnowledgeBase. + * + * @param[in] dbConnection The database connection to use. + */ + void insert(sqlite3* dbConnection); }; } // namespace obelisk diff --git a/src/main.cpp b/src/main.cpp index edd4cc0..e7f5e02 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -84,13 +84,37 @@ int obelisk::mainLoop(const std::vector& sourceFiles, const std::st } break; case obelisk::Lexer::kTokenFact : - parser->handleFact(kb); + try + { + parser->handleFact(kb); + } + catch (obelisk::ParserException& exception) + { + std::cout << "Error: " << exception.what() << std::endl; + return EXIT_FAILURE; + } break; case obelisk::Lexer::kTokenRule : - // parser->handleRule(); + try + { + parser->handleRule(kb); + } + catch (obelisk::ParserException& exception) + { + std::cout << "Error: " << exception.what() << std::endl; + return EXIT_FAILURE; + } break; case obelisk::Lexer::kTokenAction : - parser->handleAction(kb); + try + { + parser->handleAction(kb); + } + catch (obelisk::ParserException& exception) + { + std::cout << "Error: " << exception.what() << std::endl; + return EXIT_FAILURE; + } break; default : parser->getNextToken(); diff --git a/src/parser.cpp b/src/parser.cpp index e9f0b2e..de93c3c 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -257,14 +257,13 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction) syntax.pop(); if (verb == "") { - leftEntity = entityName; + leftEntity = std::move(entityName); } else { - rightEntity = entityName; + rightEntity = std::move(entityName); } - entityName = ""; - getEntity = false; + getEntity = false; getNextToken(); continue; } @@ -290,23 +289,22 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction) if (getCurrentToken() == '"') { throw obelisk::ParserException("unexpected '\"'"); - break; } - if (getLexer()->getIdentifier() == "and") - { - getNextToken(); - getEntity = true; - continue; - } - else if (getLexer()->getIdentifier() == "then") + if (getLexer()->getIdentifier() == "then") { break; } else { - verb = getLexer()->getIdentifier(); - // TODO: make sure verb is alphabetic + verb = getLexer()->getIdentifier(); + for (const auto& letter : verb) + { + if (!isalpha(letter)) + { + throw new obelisk::ParserException("non alphabetic symbol in verb"); + } + } getEntity = true; continue; } @@ -333,14 +331,13 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction) syntax.pop(); if (trueAction == "") { - trueAction = entityName; + trueAction = std::move(entityName); } else { - falseAction = entityName; + falseAction = std::move(entityName); } - entityName = ""; - getAction = false; + getAction = false; getNextToken(); continue; } @@ -370,6 +367,21 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction) syntax.pop(); } + if (leftEntity == "") + { + throw obelisk::ParserException("missing left entity"); + } + + if (rightEntity == "") + { + throw obelisk::ParserException("missing left entity"); + } + + if (verb == "") + { + throw obelisk::ParserException("missing verb"); + } + if (trueAction == "") { throw obelisk::ParserException("missing true action"); @@ -380,6 +392,12 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction) throw obelisk::ParserException("missing false action"); } + getNextToken(); + if (getCurrentToken() != ';') + { + throw obelisk::ParserException("missing ';'"); + } + break; } @@ -409,8 +427,185 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction) suggestAction.setFalseAction(obelisk::Action(falseAction)); } -void obelisk::Parser::parseRule(std::vector& rules) +void obelisk::Parser::parseRule(obelisk::Rule& rule) { + std::stack syntax; + + getNextToken(); + if (getCurrentToken() != '(') + { + throw obelisk::ParserException("expected '(' but got '" + std::to_string(getCurrentToken()) + "'"); + } + + syntax.push('('); + + bool getEntity {true}; + bool getReason {false}; + std::string leftEntity {""}; + std::string rightEntity {""}; + std::string verb {""}; + std::string leftReasonEntity {""}; + std::string rightReasonEntity {""}; + std::string reasonVerb {""}; + std::string entityName {""}; + getNextToken(); + + // get the entity side of statement + while (true) + { + if (getEntity) + { + if (getCurrentToken() == '"') + { + if (syntax.top() != '"') + { + // open a double quote + syntax.push('"'); + getNextToken(); + } + else if (syntax.top() == '"') + { + // close a double quote + syntax.pop(); + if (!getReason) + { + if (verb == "") + { + leftEntity = std::move(entityName); + } + else + { + rightEntity = std::move(entityName); + } + } + else + { + if (reasonVerb == "") + { + leftReasonEntity = std::move(entityName); + } + else + { + rightReasonEntity = std::move(entityName); + } + } + getEntity = false; + getNextToken(); + continue; + } + } + + if (syntax.top() == '"') + { + if (entityName != "") + { + entityName += " "; + } + entityName += getLexer()->getIdentifier(); + } + getNextToken(); + } + else + { + if (getCurrentToken() == ')') + { + // closing parenthesis found, make sure we have everything needed + if (syntax.top() != '(') + { + throw obelisk::ParserException("unexpected ')'"); + } + else + { + syntax.pop(); + } + + if (leftEntity == "") + { + throw obelisk::ParserException("missing left entity"); + } + + if (rightEntity == "") + { + throw obelisk::ParserException("missing left entity"); + } + + if (verb == "") + { + throw obelisk::ParserException("missing verb"); + } + + if (leftReasonEntity == "") + { + throw obelisk::ParserException("missing left reason entity"); + } + + if (rightReasonEntity == "") + { + throw obelisk::ParserException("missing right reason entity"); + } + + if (reasonVerb == "") + { + throw obelisk::ParserException("missing reason verb"); + } + + getNextToken(); + if (getCurrentToken() != ';') + { + throw obelisk::ParserException("missing ';'"); + } + + break; + } + + if (getCurrentToken() == '"') + { + throw obelisk::ParserException("unexpected '\"'"); + } + + if (getLexer()->getIdentifier() == "if") + { + getReason = true; + getEntity = true; + getNextToken(); + continue; + } + else + { + if (!getReason) + { + verb = getLexer()->getIdentifier(); + for (const auto& letter : verb) + { + if (!isalpha(letter)) + { + throw new obelisk::ParserException("non alphabetic symbol in verb"); + } + } + getEntity = true; + continue; + } + else + { + reasonVerb = getLexer()->getIdentifier(); + for (const auto& letter : reasonVerb) + { + if (!isalpha(letter)) + { + throw new obelisk::ParserException("non alphabetic symbol in verb"); + } + } + getEntity = true; + continue; + } + } + } + } + + rule.setFact(obelisk::Fact(obelisk::Entity(leftEntity), obelisk::Entity(rightEntity), obelisk::Verb(verb))); + rule.setReason(obelisk::Fact(obelisk::Entity(leftReasonEntity), + obelisk::Entity(rightReasonEntity), + obelisk::Verb(reasonVerb))); } void obelisk::Parser::parseFact(std::vector& facts) @@ -501,13 +696,18 @@ void obelisk::Parser::parseFact(std::vector& facts) throw obelisk::ParserException("missing right side entities"); } + getNextToken(); + if (getCurrentToken() != ';') + { + throw obelisk::ParserException("missing ';'"); + } + break; } if (getCurrentToken() == '"') { throw obelisk::ParserException("unexpected '\"'"); - break; } if (getLexer()->getIdentifier() == "and") @@ -518,8 +718,14 @@ void obelisk::Parser::parseFact(std::vector& facts) } else { - verb = getLexer()->getIdentifier(); - // TODO: make sure verb is alphabetic + verb = getLexer()->getIdentifier(); + for (const auto& letter : verb) + { + if (!isalpha(letter)) + { + throw new obelisk::ParserException("non alphabetic symbol in verb"); + } + } getEntity = true; continue; } @@ -539,10 +745,10 @@ void obelisk::Parser::parseFact(std::vector& facts) void obelisk::Parser::handleAction(std::unique_ptr& kb) { obelisk::SuggestAction suggestAction; - parseAction(suggestAction); try { + parseAction(suggestAction); insertEntity(kb, suggestAction.getFact().getLeftEntity()); insertEntity(kb, suggestAction.getFact().getRightEntity()); insertVerb(kb, suggestAction.getFact().getVerb()); @@ -559,6 +765,25 @@ void obelisk::Parser::handleAction(std::unique_ptr& kb) void obelisk::Parser::handleRule(std::unique_ptr& kb) { + obelisk::Rule rule; + + try + { + parseRule(rule); + insertEntity(kb, rule.getFact().getLeftEntity()); + insertEntity(kb, rule.getFact().getRightEntity()); + insertVerb(kb, rule.getFact().getVerb()); + insertFact(kb, rule.getFact()); + insertEntity(kb, rule.getReason().getLeftEntity()); + insertEntity(kb, rule.getReason().getRightEntity()); + insertVerb(kb, rule.getReason().getVerb()); + insertFact(kb, rule.getReason()); + insertRule(kb, rule); + } + catch (obelisk::ParserException& exception) + { + throw; + } } void obelisk::Parser::handleFact(std::unique_ptr& kb) @@ -699,3 +924,20 @@ void obelisk::Parser::insertSuggestAction(std::unique_ptr& kb, obelisk::Rule& rule) +{ + std::vector rules {rule}; + kb->addRules(rules); + rule = std::move(rules.front()); + + // the id was not inserted, so check if it exists in the database + if (rule.getId() == 0) + { + kb->getRule(rule); + if (rule.getId() == 0) + { + throw obelisk::ParserException("rule could not be inserted into the database"); + } + } +} diff --git a/src/parser.h b/src/parser.h index 7131a9e..13ad7ea 100644 --- a/src/parser.h +++ b/src/parser.h @@ -38,7 +38,7 @@ namespace obelisk std::unique_ptr parseTopLevelExpression(); std::unique_ptr parseExtern(); void parseAction(obelisk::SuggestAction& suggestAction); - void parseRule(std::vector& rules); + void parseRule(obelisk::Rule& rule); void parseFact(std::vector& facts); public: @@ -67,6 +67,7 @@ namespace obelisk void insertFact(std::unique_ptr& kb, obelisk::Fact& fact); void insertSuggestAction(std::unique_ptr& kb, obelisk::SuggestAction& suggestAction); + void insertRule(std::unique_ptr& kb, obelisk::Rule& rule); }; class ParserException : public std::exception