finish knowledge base insertion

This commit is contained in:
Chris Cromer 2023-02-20 09:11:10 -03:00
parent fc0984904c
commit 88f17011ef
Signed by: cromer
GPG Key ID: FA91071797BEEEC2
8 changed files with 330 additions and 63 deletions

View File

@ -223,6 +223,28 @@ void obelisk::KnowledgeBase::getRule(obelisk::Rule& rule)
rule.selectById(dbConnection_); rule.selectById(dbConnection_);
} }
void obelisk::KnowledgeBase::checkRule(obelisk::Fact& fact)
{
std::vector<obelisk::Rule> rules;
obelisk::Rule::selectByReason(dbConnection_, fact.getId(), rules);
for (auto& rule : rules)
{
auto reason = rule.getReason();
getFact(reason);
if (reason.getIsTrue())
{
auto updateFact = rule.getFact();
updateFact.setIsTrue(true);
updateFact.updateIsTrue(dbConnection_);
}
}
}
void obelisk::KnowledgeBase::updateIsTrue(obelisk::Fact& fact)
{
fact.updateIsTrue(dbConnection_);
}
void obelisk::KnowledgeBase::getFloat(float& result1, float& result2, double var) void obelisk::KnowledgeBase::getFloat(float& result1, float& result2, double var)
{ {
result1 = (float) var; result1 = (float) var;

View File

@ -183,6 +183,21 @@ namespace obelisk
*/ */
void getRule(obelisk::Rule& rule); void getRule(obelisk::Rule& rule);
/**
* @brief Check if a rule looks for this Fact, if so update its
* truth.
*
* @param[in,out] fact The Fact to check for existing rules.
*/
void checkRule(obelisk::Fact& fact);
/**
* @brief Update the is true field in the KnowledgeBase.
*
* @param[in,out] fact The fact to update.
*/
void updateIsTrue(obelisk::Fact& fact);
/** /**
* @brief Take a float and divide it into 2 floats. * @brief Take a float and divide it into 2 floats.
* *

View File

@ -28,71 +28,101 @@ void obelisk::Fact::selectById(sqlite3* dbConnection)
sqlite3_stmt* ppStmt = nullptr; sqlite3_stmt* ppStmt = nullptr;
auto result = sqlite3_prepare_v2(dbConnection, const char* query;
"SELECT id, left_entity, right_entity, verb, is_true FROM fact WHERE (left_entity=? AND right_entity=? AND verb=?)", if (getId() == 0)
-1, {
&ppStmt, query
nullptr); = "SELECT id, left_entity, right_entity, verb, is_true FROM fact WHERE (left_entity=? AND right_entity=? AND verb=?)";
}
else
{
query = "SELECT id, left_entity, right_entity, verb, is_true FROM fact WHERE (id=?)";
}
auto result = sqlite3_prepare_v2(dbConnection, query, -1, &ppStmt, nullptr);
if (result != SQLITE_OK) if (result != SQLITE_OK)
{ {
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
} }
result = sqlite3_bind_int(ppStmt, 1, getLeftEntity().getId()); if (getId() == 0)
switch (result)
{ {
case SQLITE_OK : result = sqlite3_bind_int(ppStmt, 1, getLeftEntity().getId());
break; switch (result)
case SQLITE_TOOBIG : {
throw obelisk::DatabaseSizeException(); case SQLITE_OK :
break; break;
case SQLITE_RANGE : case SQLITE_TOOBIG :
throw obelisk::DatabaseRangeException(); throw obelisk::DatabaseSizeException();
break; break;
case SQLITE_NOMEM : case SQLITE_RANGE :
throw obelisk::DatabaseMemoryException(); throw obelisk::DatabaseRangeException();
break; break;
default : case SQLITE_NOMEM :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); throw obelisk::DatabaseMemoryException();
break; break;
} default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_bind_int(ppStmt, 2, getRightEntity().getId()); result = sqlite3_bind_int(ppStmt, 2, getRightEntity().getId());
switch (result) switch (result)
{ {
case SQLITE_OK : case SQLITE_OK :
break; break;
case SQLITE_TOOBIG : case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException(); throw obelisk::DatabaseSizeException();
break; break;
case SQLITE_RANGE : case SQLITE_RANGE :
throw obelisk::DatabaseRangeException(); throw obelisk::DatabaseRangeException();
break; break;
case SQLITE_NOMEM : case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException(); throw obelisk::DatabaseMemoryException();
break; break;
default : default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break; break;
} }
result = sqlite3_bind_int(ppStmt, 3, getVerb().getId()); result = sqlite3_bind_int(ppStmt, 3, getVerb().getId());
switch (result) switch (result)
{
case SQLITE_OK :
break;
case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException();
break;
case SQLITE_RANGE :
throw obelisk::DatabaseRangeException();
break;
case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
}
else
{ {
case SQLITE_OK : result = sqlite3_bind_int(ppStmt, 1, getId());
break; switch (result)
case SQLITE_TOOBIG : {
throw obelisk::DatabaseSizeException(); case SQLITE_OK :
break; break;
case SQLITE_RANGE : case SQLITE_TOOBIG :
throw obelisk::DatabaseRangeException(); throw obelisk::DatabaseSizeException();
break; break;
case SQLITE_NOMEM : case SQLITE_RANGE :
throw obelisk::DatabaseMemoryException(); throw obelisk::DatabaseRangeException();
break; break;
default : case SQLITE_NOMEM :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection)); throw obelisk::DatabaseMemoryException();
break; break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
} }
result = sqlite3_step(ppStmt); result = sqlite3_step(ppStmt);
@ -248,6 +278,85 @@ void obelisk::Fact::insert(sqlite3* dbConnection)
} }
} }
void obelisk::Fact::updateIsTrue(sqlite3* dbConnection)
{
if (dbConnection == nullptr)
{
throw obelisk::DatabaseException("database isn't open");
}
sqlite3_stmt* ppStmt = nullptr;
auto result = sqlite3_prepare_v2(dbConnection, "UPDATE fact SET is_true=? WHERE id=?", -1, &ppStmt, nullptr);
if (result != SQLITE_OK)
{
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
}
result = sqlite3_bind_int(ppStmt, 1, getIsTrue());
switch (result)
{
case SQLITE_OK :
break;
case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException();
break;
case SQLITE_RANGE :
throw obelisk::DatabaseRangeException();
break;
case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_bind_int(ppStmt, 2, getId());
switch (result)
{
case SQLITE_OK :
break;
case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException();
break;
case SQLITE_RANGE :
throw obelisk::DatabaseRangeException();
break;
case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_step(ppStmt);
switch (result)
{
case SQLITE_DONE :
// Row updated
break;
case SQLITE_CONSTRAINT :
throw obelisk::DatabaseConstraintException(sqlite3_errmsg(dbConnection));
case SQLITE_BUSY :
throw obelisk::DatabaseBusyException();
break;
case SQLITE_MISUSE :
throw obelisk::DatabaseMisuseException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
result = sqlite3_finalize(ppStmt);
if (result != SQLITE_OK)
{
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
}
}
int& obelisk::Fact::getId() int& obelisk::Fact::getId()
{ {
return id_; return id_;

View File

@ -211,6 +211,14 @@ namespace obelisk
* @param[in] dbConnection The database connection to use. * @param[in] dbConnection The database connection to use.
*/ */
void insert(sqlite3* dbConnection); void insert(sqlite3* dbConnection);
/**
* @brief Update whether or not the fact is true in the
* KnowledgeBase.
*
* @param[in] dbConnection The database connection.
*/
void updateIsTrue(sqlite3* dbConnection);
}; };
} // namespace obelisk } // namespace obelisk

View File

@ -183,6 +183,69 @@ void obelisk::Rule::insert(sqlite3* dbConnection)
} }
} }
void obelisk::Rule::selectByReason(sqlite3* dbConnection, int reasonId, std::vector<obelisk::Rule>& rules)
{
if (dbConnection == nullptr)
{
throw obelisk::DatabaseException("database isn't open");
}
sqlite3_stmt* ppStmt = nullptr;
auto result
= sqlite3_prepare_v2(dbConnection, "SELECT id, fact, reason FROM rule WHERE (reason=?)", -1, &ppStmt, nullptr);
if (result != SQLITE_OK)
{
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
}
result = sqlite3_bind_int(ppStmt, 1, reasonId);
switch (result)
{
case SQLITE_OK :
break;
case SQLITE_TOOBIG :
throw obelisk::DatabaseSizeException();
break;
case SQLITE_RANGE :
throw obelisk::DatabaseRangeException();
break;
case SQLITE_NOMEM :
throw obelisk::DatabaseMemoryException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
while ((result = sqlite3_step(ppStmt)) != SQLITE_DONE)
{
switch (result)
{
case SQLITE_ROW :
rules.push_back(obelisk::Rule(sqlite3_column_int(ppStmt, 0),
obelisk::Fact(sqlite3_column_int(ppStmt, 1)),
obelisk::Fact(sqlite3_column_int(ppStmt, 2))));
break;
case SQLITE_BUSY :
throw obelisk::DatabaseBusyException();
break;
case SQLITE_MISUSE :
throw obelisk::DatabaseMisuseException();
break;
default :
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
break;
}
}
result = sqlite3_finalize(ppStmt);
if (result != SQLITE_OK)
{
throw obelisk::DatabaseException(sqlite3_errmsg(dbConnection));
}
}
int& obelisk::Rule::getId() int& obelisk::Rule::getId()
{ {
return id_; return id_;

View File

@ -4,6 +4,7 @@
#include "models/fact.h" #include "models/fact.h"
#include <string> #include <string>
#include <vector>
namespace obelisk namespace obelisk
{ {
@ -140,6 +141,14 @@ namespace obelisk
*/ */
void selectById(sqlite3* dbConnection); void selectById(sqlite3* dbConnection);
/**
* @brief Get the rules that match the reason.
*
* @param[in] dbConnection The database connection to use.
* @param[out] rules The rules to fill in from the database.
*/
static void selectByReason(sqlite3* dbConnection, int reasonId, std::vector<obelisk::Rule>& rules);
/** /**
* @brief Insert the Rule into the KnowledgeBase. * @brief Insert the Rule into the KnowledgeBase.
* *

View File

@ -770,14 +770,23 @@ void obelisk::Parser::handleRule(std::unique_ptr<obelisk::KnowledgeBase>& kb)
try try
{ {
parseRule(rule); parseRule(rule);
insertEntity(kb, rule.getFact().getLeftEntity());
insertEntity(kb, rule.getFact().getRightEntity());
insertVerb(kb, rule.getFact().getVerb());
insertFact(kb, rule.getFact());
insertEntity(kb, rule.getReason().getLeftEntity()); insertEntity(kb, rule.getReason().getLeftEntity());
insertEntity(kb, rule.getReason().getRightEntity()); insertEntity(kb, rule.getReason().getRightEntity());
insertVerb(kb, rule.getReason().getVerb()); insertVerb(kb, rule.getReason().getVerb());
insertFact(kb, rule.getReason()); insertFact(kb, rule.getReason());
// The rule is true, so the fact must be true to.
if (rule.getReason().getIsTrue())
{
rule.getFact().setIsTrue(true);
}
insertEntity(kb, rule.getFact().getLeftEntity());
insertEntity(kb, rule.getFact().getRightEntity());
insertVerb(kb, rule.getFact().getVerb());
insertFact(kb, rule.getFact());
insertRule(kb, rule); insertRule(kb, rule);
} }
catch (obelisk::ParserException& exception) catch (obelisk::ParserException& exception)
@ -830,12 +839,14 @@ void obelisk::Parser::handleFact(std::unique_ptr<obelisk::KnowledgeBase>& kb)
try try
{ {
insertFact(kb, fact); insertFact(kb, fact, true);
} }
catch (obelisk::ParserException& exception) catch (obelisk::ParserException& exception)
{ {
throw; throw;
} }
kb->checkRule(fact);
} }
} }
@ -890,7 +901,7 @@ void obelisk::Parser::insertAction(std::unique_ptr<obelisk::KnowledgeBase>& kb,
} }
} }
void obelisk::Parser::insertFact(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Fact& fact) void obelisk::Parser::insertFact(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Fact& fact, bool updateIsTrue)
{ {
std::vector<obelisk::Fact> facts {fact}; std::vector<obelisk::Fact> facts {fact};
kb->addFacts(facts); kb->addFacts(facts);
@ -904,6 +915,14 @@ void obelisk::Parser::insertFact(std::unique_ptr<obelisk::KnowledgeBase>& kb, ob
{ {
throw obelisk::ParserException("fact could not be inserted into the database"); throw obelisk::ParserException("fact could not be inserted into the database");
} }
else
{
if (updateIsTrue)
{
fact.setIsTrue(true);
kb->updateIsTrue(fact);
}
}
} }
} }

View File

@ -17,12 +17,32 @@
namespace obelisk namespace obelisk
{ {
/**
* @brief The Parser is responsible for analyzing the language's key words
* and taking action based on its analysis.
*
*/
class Parser class Parser
{ {
private: private:
/**
* @brief The Lexer object that the Parser is using to Parse a
* specific source file.
*
*/
std::shared_ptr<obelisk::Lexer> lexer_; std::shared_ptr<obelisk::Lexer> lexer_;
/**
* @brief The current token that the lexer has retrieved.
*
*/
int currentToken_ = 0; int currentToken_ = 0;
/**
* @brief Set the current token.
*
* @param[in] currentToken The token should be ASCII character.
*/
void setCurrentToken(int currentToken); void setCurrentToken(int currentToken);
std::unique_ptr<obelisk::ExpressionAST> logError(const char* str); std::unique_ptr<obelisk::ExpressionAST> logError(const char* str);
@ -64,7 +84,9 @@ namespace obelisk
void insertEntity(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Entity& entity); void insertEntity(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Entity& entity);
void insertVerb(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Verb& verb); void insertVerb(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Verb& verb);
void insertAction(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Action& action); void insertAction(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Action& action);
void insertFact(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Fact& fact); void insertFact(std::unique_ptr<obelisk::KnowledgeBase>& kb,
obelisk::Fact& fact,
bool updateIsTrue = false);
void insertSuggestAction(std::unique_ptr<obelisk::KnowledgeBase>& kb, void insertSuggestAction(std::unique_ptr<obelisk::KnowledgeBase>& kb,
obelisk::SuggestAction& suggestAction); obelisk::SuggestAction& suggestAction);
void insertRule(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Rule& rule); void insertRule(std::unique_ptr<obelisk::KnowledgeBase>& kb, obelisk::Rule& rule);