implement rules

This commit is contained in:
Chris Cromer 2023-02-19 01:03:35 -03:00
parent 2f54f13b54
commit ba8788af56
Signed by: cromer
GPG Key ID: FA91071797BEEEC2
7 changed files with 517 additions and 27 deletions

View File

@ -174,6 +174,25 @@ void obelisk::KnowledgeBase::addSuggestActions(std::vector<obelisk::SuggestActio
}
}
void obelisk::KnowledgeBase::addRules(std::vector<obelisk::Rule>& rules)
{
for (auto& rule : rules)
{
try
{
rule.insert(dbConnection_);
}
catch (obelisk::DatabaseConstraintException& exception)
{
// ignore unique constraint error
if (std::strcmp(exception.what(), "UNIQUE constraint failed: rule.fact, rule.reason") != 0)
{
throw;
}
}
}
}
void obelisk::KnowledgeBase::getEntity(obelisk::Entity& entity)
{
entity.selectByName(dbConnection_);
@ -199,6 +218,11 @@ void obelisk::KnowledgeBase::getSuggestAction(obelisk::SuggestAction& suggestAct
suggestAction.selectById(dbConnection_);
}
void obelisk::KnowledgeBase::getRule(obelisk::Rule& rule)
{
rule.selectById(dbConnection_);
}
void obelisk::KnowledgeBase::getFloat(float& result1, float& result2, double var)
{
result1 = (float) var;

View File

@ -127,6 +127,14 @@ namespace obelisk
*/
void addSuggestActions(std::vector<obelisk::SuggestAction>& suggestActions);
/**
* @brief Add rules to the KnowledgeBase.
*
* @param[in,out] rules The rules to add. If the insert is successful it
* will have a row ID, if not the ID will be 0.
*/
void addRules(std::vector<obelisk::Rule>& rules);
/**
* @brief Get an Entity object based on the ID it contains.
*
@ -167,6 +175,14 @@ namespace obelisk
*/
void getSuggestAction(obelisk::SuggestAction& suggestAction);
/**
* @brief Get a Rule based on the ID it contains.
*
* @param[in,out] rule The Rule object should contain just the ID
* and the rest will be filled in.
*/
void getRule(obelisk::Rule& rule);
/**
* @brief Take a float and divide it into 2 floats.
*

View File

@ -1,3 +1,4 @@
#include "models/error.h"
#include "models/rule.h"
const char* obelisk::Rule::createTable()
@ -15,6 +16,173 @@ const char* obelisk::Rule::createTable()
)";
}
void obelisk::Rule::selectById(sqlite3* dbConnection)
{
if (dbConnection == nullptr)
{
throw obelisk::DatabaseException("database isn't open");
}
sqlite3_stmt* ppStmt = nullptr;
auto result = sqlite3_prepare_v2(dbConnection,
"SELECT id, fact, reason FROM rule WHERE (fact=? AND reason=?)",
-1,
&ppStmt,
nullptr);
if (result != SQLITE_OK)
{
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
}
result = sqlite3_bind_int(ppStmt, 1, getFact().getId());
switch (result)
{
case SQLITE_OK :
break;
case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException();
break;
case SQLITE_RANGE :
throw obelisk::DatabaseRangeException();
break;
case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_bind_int(ppStmt, 2, getReason().getId());
switch (result)
{
case SQLITE_OK :
break;
case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException();
break;
case SQLITE_RANGE :
throw obelisk::DatabaseRangeException();
break;
case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_step(ppStmt);
switch (result)
{
case SQLITE_DONE :
// no rows in the database
break;
case SQLITE_ROW :
setId(sqlite3_column_int(ppStmt, 0));
getFact().setId(sqlite3_column_int(ppStmt, 1));
getReason().setId(sqlite3_column_int(ppStmt, 2));
break;
case SQLITE_BUSY :
throw obelisk::DatabaseBusyException();
break;
case SQLITE_MISUSE :
throw obelisk::DatabaseMisuseException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_finalize(ppStmt);
if (result != SQLITE_OK)
{
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
}
}
void obelisk::Rule::insert(sqlite3* dbConnection)
{
if (dbConnection == nullptr)
{
throw obelisk::DatabaseException("database isn't open");
}
sqlite3_stmt* ppStmt = nullptr;
auto result
= sqlite3_prepare_v2(dbConnection, "INSERT INTO rule (fact, reason) VALUES (?, ?)", -1, &ppStmt, nullptr);
if (result != SQLITE_OK)
{
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
}
result = sqlite3_bind_int(ppStmt, 1, getFact().getId());
switch (result)
{
case SQLITE_OK :
break;
case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException();
break;
case SQLITE_RANGE :
throw obelisk::DatabaseRangeException();
break;
case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_bind_int(ppStmt, 2, getReason().getId());
switch (result)
{
case SQLITE_OK :
break;
case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException();
break;
case SQLITE_RANGE :
throw obelisk::DatabaseRangeException();
break;
case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_step(ppStmt);
switch (result)
{
case SQLITE_DONE :
setId((int) sqlite3_last_insert_rowid(dbConnection));
sqlite3_set_last_insert_rowid(dbConnection, 0);
break;
case SQLITE_CONSTRAINT :
throw obelisk::DatabaseConstraintException(sqlite3_errmsg(dbConnection));
case SQLITE_BUSY :
throw obelisk::DatabaseBusyException();
break;
case SQLITE_MISUSE :
throw obelisk::DatabaseMisuseException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_finalize(ppStmt);
if (result != SQLITE_OK)
{
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
}
}
int& obelisk::Rule::getId()
{
return id_;

View File

@ -131,6 +131,21 @@ namespace obelisk
* @param[in] reason The reason Fact.
*/
void setReason(obelisk::Fact reason);
/**
* @brief Select the Rule from the KnowledgeBase by IDs of the
* sub-objects.
*
* @param[in] dbConnection The database connection to use.
*/
void selectById(sqlite3* dbConnection);
/**
* @brief Insert the Rule into the KnowledgeBase.
*
* @param[in] dbConnection The database connection to use.
*/
void insert(sqlite3* dbConnection);
};
} // namespace obelisk

View File

@ -84,13 +84,37 @@ int obelisk::mainLoop(const std::vector<std::string>& sourceFiles, const std::st
}
break;
case obelisk::Lexer::kTokenFact :
parser->handleFact(kb);
try
{
parser->handleFact(kb);
}
catch (obelisk::ParserException& exception)
{
std::cout << "Error: " << exception.what() << std::endl;
return EXIT_FAILURE;
}
break;
case obelisk::Lexer::kTokenRule :
// parser->handleRule();
try
{
parser->handleRule(kb);
}
catch (obelisk::ParserException& exception)
{
std::cout << "Error: " << exception.what() << std::endl;
return EXIT_FAILURE;
}
break;
case obelisk::Lexer::kTokenAction :
parser->handleAction(kb);
try
{
parser->handleAction(kb);
}
catch (obelisk::ParserException& exception)
{
std::cout << "Error: " << exception.what() << std::endl;
return EXIT_FAILURE;
}
break;
default :
parser->getNextToken();

View File

@ -257,14 +257,13 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction)
syntax.pop();
if (verb == "")
{
leftEntity = entityName;
leftEntity = std::move(entityName);
}
else
{
rightEntity = entityName;
rightEntity = std::move(entityName);
}
entityName = "";
getEntity = false;
getEntity = false;
getNextToken();
continue;
}
@ -290,23 +289,22 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction)
if (getCurrentToken() == '"')
{
throw obelisk::ParserException("unexpected '\"'");
break;
}
if (getLexer()->getIdentifier() == "and")
{
getNextToken();
getEntity = true;
continue;
}
else if (getLexer()->getIdentifier() == "then")
if (getLexer()->getIdentifier() == "then")
{
break;
}
else
{
verb = getLexer()->getIdentifier();
// TODO: make sure verb is alphabetic
verb = getLexer()->getIdentifier();
for (const auto& letter : verb)
{
if (!isalpha(letter))
{
throw new obelisk::ParserException("non alphabetic symbol in verb");
}
}
getEntity = true;
continue;
}
@ -333,14 +331,13 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction)
syntax.pop();
if (trueAction == "")
{
trueAction = entityName;
trueAction = std::move(entityName);
}
else
{
falseAction = entityName;
falseAction = std::move(entityName);
}
entityName = "";
getAction = false;
getAction = false;
getNextToken();
continue;
}
@ -370,6 +367,21 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction)
syntax.pop();
}
if (leftEntity == "")
{
throw obelisk::ParserException("missing left entity");
}
if (rightEntity == "")
{
throw obelisk::ParserException("missing left entity");
}
if (verb == "")
{
throw obelisk::ParserException("missing verb");
}
if (trueAction == "")
{
throw obelisk::ParserException("missing true action");
@ -380,6 +392,12 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction)
throw obelisk::ParserException("missing false action");
}
getNextToken();
if (getCurrentToken() != ';')
{
throw obelisk::ParserException("missing ';'");
}
break;
}
@ -409,8 +427,185 @@ void obelisk::Parser::parseAction(obelisk::SuggestAction& suggestAction)
suggestAction.setFalseAction(obelisk::Action(falseAction));
}
void obelisk::Parser::parseRule(std::vector<obelisk::Rule>& rules)
void obelisk::Parser::parseRule(obelisk::Rule& rule)
{
std::stack<char> syntax;
getNextToken();
if (getCurrentToken() != '(')
{
throw obelisk::ParserException("expected '(' but got '" + std::to_string(getCurrentToken()) + "'");
}
syntax.push('(');
bool getEntity {true};
bool getReason {false};
std::string leftEntity {""};
std::string rightEntity {""};
std::string verb {""};
std::string leftReasonEntity {""};
std::string rightReasonEntity {""};
std::string reasonVerb {""};
std::string entityName {""};
getNextToken();
// get the entity side of statement
while (true)
{
if (getEntity)
{
if (getCurrentToken() == '"')
{
if (syntax.top() != '"')
{
// open a double quote
syntax.push('"');
getNextToken();
}
else if (syntax.top() == '"')
{
// close a double quote
syntax.pop();
if (!getReason)
{
if (verb == "")
{
leftEntity = std::move(entityName);
}
else
{
rightEntity = std::move(entityName);
}
}
else
{
if (reasonVerb == "")
{
leftReasonEntity = std::move(entityName);
}
else
{
rightReasonEntity = std::move(entityName);
}
}
getEntity = false;
getNextToken();
continue;
}
}
if (syntax.top() == '"')
{
if (entityName != "")
{
entityName += " ";
}
entityName += getLexer()->getIdentifier();
}
getNextToken();
}
else
{
if (getCurrentToken() == ')')
{
// closing parenthesis found, make sure we have everything needed
if (syntax.top() != '(')
{
throw obelisk::ParserException("unexpected ')'");
}
else
{
syntax.pop();
}
if (leftEntity == "")
{
throw obelisk::ParserException("missing left entity");
}
if (rightEntity == "")
{
throw obelisk::ParserException("missing left entity");
}
if (verb == "")
{
throw obelisk::ParserException("missing verb");
}
if (leftReasonEntity == "")
{
throw obelisk::ParserException("missing left reason entity");
}
if (rightReasonEntity == "")
{
throw obelisk::ParserException("missing right reason entity");
}
if (reasonVerb == "")
{
throw obelisk::ParserException("missing reason verb");
}
getNextToken();
if (getCurrentToken() != ';')
{
throw obelisk::ParserException("missing ';'");
}
break;
}
if (getCurrentToken() == '"')
{
throw obelisk::ParserException("unexpected '\"'");
}
if (getLexer()->getIdentifier() == "if")
{
getReason = true;
getEntity = true;
getNextToken();
continue;
}
else
{
if (!getReason)
{
verb = getLexer()->getIdentifier();
for (const auto& letter : verb)
{
if (!isalpha(letter))
{
throw new obelisk::ParserException("non alphabetic symbol in verb");
}
}
getEntity = true;
continue;
}
else
{
reasonVerb = getLexer()->getIdentifier();
for (const auto& letter : reasonVerb)
{
if (!isalpha(letter))
{
throw new obelisk::ParserException("non alphabetic symbol in verb");
}
}
getEntity = true;
continue;
}
}
}
}
rule.setFact(obelisk::Fact(obelisk::Entity(leftEntity), obelisk::Entity(rightEntity), obelisk::Verb(verb)));
rule.setReason(obelisk::Fact(obelisk::Entity(leftReasonEntity),
obelisk::Entity(rightReasonEntity),
obelisk::Verb(reasonVerb)));
}
void obelisk::Parser::parseFact(std::vector<obelisk::Fact>& facts)
@ -501,13 +696,18 @@ void obelisk::Parser::parseFact(std::vector<obelisk::Fact>& facts)
throw obelisk::ParserException("missing right side entities");
}
getNextToken();
if (getCurrentToken() != ';')
{
throw obelisk::ParserException("missing ';'");
}
break;
}
if (getCurrentToken() == '"')
{
throw obelisk::ParserException("unexpected '\"'");
break;
}
if (getLexer()->getIdentifier() == "and")
@ -518,8 +718,14 @@ void obelisk::Parser::parseFact(std::vector<obelisk::Fact>& facts)
}
else
{
verb = getLexer()->getIdentifier();
// TODO: make sure verb is alphabetic
verb = getLexer()->getIdentifier();
for (const auto& letter : verb)
{
if (!isalpha(letter))
{
throw new obelisk::ParserException("non alphabetic symbol in verb");
}
}
getEntity = true;
continue;
}
@ -539,10 +745,10 @@ void obelisk::Parser::parseFact(std::vector<obelisk::Fact>& facts)
void obelisk::Parser::handleAction(std::unique_ptr<obelisk::KnowledgeBase>& kb)
{
obelisk::SuggestAction suggestAction;
parseAction(suggestAction);
try
{
parseAction(suggestAction);
insertEntity(kb, suggestAction.getFact().getLeftEntity());
insertEntity(kb, suggestAction.getFact().getRightEntity());
insertVerb(kb, suggestAction.getFact().getVerb());
@ -559,6 +765,25 @@ void obelisk::Parser::handleAction(std::unique_ptr<obelisk::KnowledgeBase>& kb)
void obelisk::Parser::handleRule(std::unique_ptr<obelisk::KnowledgeBase>& kb)
{
obelisk::Rule rule;
try
{
parseRule(rule);
insertEntity(kb, rule.getFact().getLeftEntity());
insertEntity(kb, rule.getFact().getRightEntity());
insertVerb(kb, rule.getFact().getVerb());
insertFact(kb, rule.getFact());
insertEntity(kb, rule.getReason().getLeftEntity());
insertEntity(kb, rule.getReason().getRightEntity());
insertVerb(kb, rule.getReason().getVerb());
insertFact(kb, rule.getReason());
insertRule(kb, rule);
}
catch (obelisk::ParserException& exception)
{
throw;
}
}
void obelisk::Parser::handleFact(std::unique_ptr<obelisk::KnowledgeBase>& kb)
@ -699,3 +924,20 @@ void obelisk::Parser::insertSuggestAction(std::unique_ptr<obelisk::KnowledgeBase
}
}
}
void obelisk::Parser::insertRule(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Rule& rule)
{
std::vector<obelisk::Rule> rules {rule};
kb->addRules(rules);
rule = std::move(rules.front());
// the id was not inserted, so check if it exists in the database
if (rule.getId() == 0)
{
kb->getRule(rule);
if (rule.getId() == 0)
{
throw obelisk::ParserException("rule could not be inserted into the database");
}
}
}

View File

@ -38,7 +38,7 @@ namespace obelisk
std::unique_ptr<obelisk::FunctionAST> parseTopLevelExpression();
std::unique_ptr<obelisk::PrototypeAST> parseExtern();
void parseAction(obelisk::SuggestAction& suggestAction);
void parseRule(std::vector<obelisk::Rule>& rules);
void parseRule(obelisk::Rule& rule);
void parseFact(std::vector<obelisk::Fact>& facts);
public:
@ -67,6 +67,7 @@ namespace obelisk
void insertFact(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Fact& fact);
void insertSuggestAction(std::unique_ptr<obelisk::KnowledgeBase>& kb,
obelisk::SuggestAction& suggestAction);
void insertRule(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Rule& rule);
};
class ParserException : public std::exception