diff --git a/include/state-machine_pool.hpp b/include/state-machine_pool.hpp old mode 100644 new mode 100755 index 7082208..4432294 --- a/include/state-machine_pool.hpp +++ b/include/state-machine_pool.hpp @@ -1,7 +1,12 @@ #ifndef STATEMACHINE_POOL_HPP #define STATEMACHINE_POOL_HPP +#include "state.hpp" #include "state_machine.hpp" +#include +#include +#include +#include template class StateMachinePool { @@ -15,16 +20,46 @@ class StateMachinePool { machines.erase(id); } + void AddTransition(Transition &transition) { + std::string from = transition.from; + std::string to = transition.target_state; + std::function condition; + + transitions[from].push_back(transition); + } + void UpdateAll() { - for (const auto &[id, machine] : machines) + for (const auto &[id, machine] : machines) { + auto &owner = machine->GetOwner(); + const std::string ¤t = machine->GetCurrentStateId(); + + auto it = transitions.find(current); + if (it != transitions.end() && auto_transition == true) + for (const auto &transition : it->second) + if (transition.condition(owner)) { + machine->ChangeState(transition.target_state); + break; + } machine->Update(); + } } void ChangeState(const std::string &id, const std::string &stateId) { auto it = machines.find(id); + const std::string current = it->second->GetCurrentStateId(); + auto &owner = it->second->GetOwner(); - if (it != machines.end()) + bool allowed = true; + auto transitions_it = transitions.find(current); + if (strict_transition && (transitions_it != transitions.end())) { + for (const auto &transition : transitions_it->second) + if (!transition.condition(owner)) + allowed = false; + } + + if (it != machines.end() && allowed) { it->second->ChangeState(stateId); + } } void RevertState(const std::string &id) { @@ -88,17 +123,19 @@ class StateMachinePool { } void SetStrictTransitions(bool enabled) { - strict_transition = enabled; + this->strict_transition = enabled; + } - for (auto &[_, machine] : machines) - machine->SetStrictTransitions(enabled); + bool GetStrictTransitions() const { + return strict_transition; } void SetAutoTransitions(bool enabled) { - auto_transition = enabled; + this->auto_transition = enabled; + } - for (const auto &[_, machine] : machines) - machine->SetAutoTransitions(enabled); + bool GetAutoTransitions() const { + return auto_transition; } private: @@ -106,6 +143,7 @@ class StateMachinePool { std::unordered_map>> machines; bool auto_transition = false; bool strict_transition = false; + std::unordered_map>> transitions; }; #endif // STATEMACHINE_POOL_HPP diff --git a/include/state.hpp b/include/state.hpp index 7852363..c428366 100644 --- a/include/state.hpp +++ b/include/state.hpp @@ -7,6 +7,7 @@ template struct Transition { + const std::string from; const std::string target_state; const std::function condition; }; @@ -20,10 +21,6 @@ class State { virtual void OnUpdate(OwnerType &owner) = 0; virtual void OnExit(OwnerType &owner) = 0; - virtual void AddTransition(const std::string &to, - std::function condition) = 0; - virtual std::vector> GetTransitions() = 0; - virtual std::string Id() const = 0; }; @@ -32,20 +29,11 @@ class AState : public State { public: AState(const std::string &id) : entity_id(id) {}; - void AddTransition(const std::string &to, - std::function condition) override { - transitions.push_back({ to, condition }); - } - std::vector> GetTransitions() override { - return transitions; - } - std::string Id() const override { return entity_id; } private: - std::vector> transitions; std::string entity_id; }; diff --git a/include/state_machine.hpp b/include/state_machine.hpp index 85901a7..a0d0f1b 100644 --- a/include/state_machine.hpp +++ b/include/state_machine.hpp @@ -3,10 +3,12 @@ #include #include +#include #include #include +#include "state.hpp" #include "state_factory.hpp" using json = nlohmann::json; @@ -21,19 +23,6 @@ class StateMachine { if (!new_state) return; - if (strict_transition && current_state) { - bool allowed = false; - for (const auto &transition : current_state->GetTransitions()) { - if (transition.target_state == id && - (!transition.condition || transition.condition(owner))) { - allowed = true; - break; - } - } - if (!allowed) - return; - } - if (current_state) { current_state->OnExit(owner); history.push_back(current_state->Id()); @@ -67,12 +56,6 @@ class StateMachine { void Update() { if (!current_state) return; - - for (const auto &transition : current_state->GetTransitions()) - if (auto_transition && transition.condition(owner)) { - ChangeState(transition.target_state); - return; - } current_state->OnUpdate(owner); } @@ -91,32 +74,18 @@ class StateMachine { } const std::string GetCurrentStateId() const { - static const std::string nullState = "null"; + static const std::string nullState = "nullstate"; return current_state ? current_state->Id() : nullState; } - void SetStrictTransitions(bool enabled) { - this->strict_transition = enabled; - } - - bool GetStrictTransitions() const { - return strict_transition; - } - - void SetAutoTransitions(bool enabled) { - this->auto_transition = enabled; - } - - bool GetAutoTransitions() const { - return auto_transition; + T &GetOwner() { + return owner; } private: T &owner; const StateFactory &factory; std::shared_ptr> current_state; - bool strict_transition = false; - bool auto_transition = false; std::vector history; }; diff --git a/include/state_system.hpp b/include/state_system.hpp index d9f7b41..a9b1441 100644 --- a/include/state_system.hpp +++ b/include/state_system.hpp @@ -1,6 +1,7 @@ #ifndef STATE_SYSTEM_HPP #define STATE_SYSTEM_HPP +#include "state.hpp" #include "state_machine.hpp" #include "state-machine_pool.hpp" @@ -16,11 +17,6 @@ class StateSystem { GetPool().UnregisterEntity(id); } - template - void UpdateAll() { - GetPool().UpdateAll(); - } - template void ChangeState(const std::string &id, const std::string &stateId) { GetPool().ChangeState(id, stateId); @@ -37,6 +33,16 @@ class StateSystem { GetPool().RegisterState(id, std::move(state)); } + template + void AddTransition(Transition &transition) { + GetPool().AddTransition(transition); + } + + template + void UpdateAll() { + GetPool().UpdateAll(); + } + template json Serialize() const { StateMachinePool &pool = GetPool();