diff --git a/Assets/BetterStateMachine/Runtime/IStateMachine.cs b/Assets/BetterStateMachine/Runtime/IStateMachine.cs index a718a7d..1d7ddb0 100644 --- a/Assets/BetterStateMachine/Runtime/IStateMachine.cs +++ b/Assets/BetterStateMachine/Runtime/IStateMachine.cs @@ -14,20 +14,16 @@ public interface IStateMachine where TState : BaseState TState CurrentState { get; } bool InTransition { get; } Task TransitionTask { get; } - + void Run(); - Task ChangeStateAsync(TState newState, CancellationToken cancellationToken = default); bool InState() where T : TState; + Task ChangeStateAsync(TState newState, CancellationToken cancellationToken = default); void Stop(); - - public void AddModule(Module module); + + public bool AddModule(Module module); + public bool HasModule(Module module); public bool HasModule(Type type); - public bool HasModule() where TModule : Module; public bool TryGetModule(Type type, out Module module); - public bool TryGetModule(out TModule module) where TModule : Module; - public Module GetModule(Type type); - public TModule GetModule() where TModule : Module; - public bool RemoveModule(Type type); - public bool RemoveModule() where TModule : Module; + public bool RemoveModule(Module module); } } \ No newline at end of file diff --git a/Assets/BetterStateMachine/Runtime/Modules/Module.cs b/Assets/BetterStateMachine/Runtime/Modules/Module.cs index 5e8107e..55e9721 100644 --- a/Assets/BetterStateMachine/Runtime/Modules/Module.cs +++ b/Assets/BetterStateMachine/Runtime/Modules/Module.cs @@ -6,77 +6,76 @@ namespace Better.StateMachine.Runtime.Modules { public abstract class Module where TState : BaseState { - protected IStateMachine StateMachine { get; private set; } + public int LinksCount { get; private set; } + public bool IsLinked => LinksCount > 0; - internal void Link(IStateMachine stateMachine) + public virtual bool AllowLinkTo(IStateMachine stateMachine) { - if (StateMachine != null) - { - var message = $"Already linked to {nameof(StateMachine)}"; - DebugUtility.LogException(message); - return; - } + return true; + } - StateMachine = stateMachine; + internal void Link(IStateMachine stateMachine) + { + LinksCount++; OnLinked(stateMachine); } protected abstract void OnLinked(IStateMachine stateMachine); - internal void Unlink() + internal void Unlink(IStateMachine stateMachine) { - if (StateMachine == null) - { - var message = "Already unlinked"; - DebugUtility.LogException(message); - return; - } - - StateMachine = null; - OnUnlinked(); + LinksCount--; + OnUnlinked(stateMachine); } - protected abstract void OnUnlinked(); - - public virtual bool AllowRunMachine() + protected abstract void OnUnlinked(IStateMachine stateMachine); + + public virtual bool AllowRunMachine(IStateMachine stateMachine) { return true; } - public virtual void OnMachineRunned() + public virtual void OnMachineRunned(IStateMachine stateMachine) { } - public virtual bool AllowChangeState(TState state) + public virtual bool AllowChangeState(IStateMachine stateMachine, TState state) { return true; } - public virtual void OnStatePreChanged(TState state) + public virtual void OnStatePreChanged(IStateMachine stateMachine, TState state) { } - public virtual void OnStateChanged(TState state) + public virtual void OnStateChanged(IStateMachine stateMachine, TState state) { } - public virtual bool AllowStopMachine() + public virtual bool AllowStopMachine(IStateMachine stateMachine) { return true; } - public virtual void OnMachineStopped() + public virtual void OnMachineStopped(IStateMachine stateMachine) { } - protected bool ValidateMachineRunning(bool targetState, bool logException = true) + protected bool ValidateMachineRunning(IStateMachine stateMachine, bool targetState, bool logException = true) { - var isRunning = StateMachine?.IsRunning ?? false; + if (stateMachine == null) + { + var message = $"Is not valid, {nameof(stateMachine)} is null"; + DebugUtility.LogException(message); + return false; + } + + var isRunning = stateMachine.IsRunning; var isValid = isRunning == targetState; if (!isValid && logException) { var reason = targetState ? "not running" : "is running"; - var message = $"Is not valid, {nameof(StateMachine)} {reason}"; + var message = $"Is not valid, {nameof(stateMachine)} {reason}"; DebugUtility.LogException(message); } @@ -88,4 +87,8 @@ public override string ToString() return GetType().Name; } } + + public abstract class Module : Module + { + } } \ No newline at end of file diff --git a/Assets/BetterStateMachine/Runtime/Modules/SingleModule.cs b/Assets/BetterStateMachine/Runtime/Modules/SingleModule.cs new file mode 100644 index 0000000..04a2a44 --- /dev/null +++ b/Assets/BetterStateMachine/Runtime/Modules/SingleModule.cs @@ -0,0 +1,53 @@ +using System; +using Better.Commons.Runtime.Utility; +using Better.StateMachine.Runtime.States; + +namespace Better.StateMachine.Runtime.Modules +{ + public abstract class SingleModule : Module + where TState : BaseState + { + protected IStateMachine StateMachine { get; private set; } + + public override bool AllowLinkTo(IStateMachine stateMachine) + { + return base.AllowLinkTo(stateMachine) && !IsLinked; + } + + protected override void OnLinked(IStateMachine stateMachine) + { + if (IsLinked) + { + var message = "Already linked"; + DebugUtility.LogException(message); + + stateMachine.RemoveModule(this); + return; + } + + StateMachine = stateMachine; + } + + protected override void OnUnlinked(IStateMachine stateMachine) + { + if (!IsLinked) + { + var message = "Already unlinked"; + DebugUtility.LogException(message); + + return; + } + + StateMachine = null; + } + + protected bool ValidateMachineRunning(bool targetState, bool logException = true) + { + return ValidateMachineRunning(StateMachine, targetState, logException); + } + } + + public abstract class SingleModule : SingleModule + { + } +} \ No newline at end of file diff --git a/Assets/BetterStateMachine/Runtime/Modules/SingleModule.cs.meta b/Assets/BetterStateMachine/Runtime/Modules/SingleModule.cs.meta new file mode 100644 index 0000000..9e4962d --- /dev/null +++ b/Assets/BetterStateMachine/Runtime/Modules/SingleModule.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: e555dace4e234a95a9cd56ea8af7f1d6 +timeCreated: 1713278364 \ No newline at end of file diff --git a/Assets/BetterStateMachine/Runtime/Modules/StatesCacheModule.cs b/Assets/BetterStateMachine/Runtime/Modules/StatesCacheModule.cs index 19302bf..46f0894 100644 --- a/Assets/BetterStateMachine/Runtime/Modules/StatesCacheModule.cs +++ b/Assets/BetterStateMachine/Runtime/Modules/StatesCacheModule.cs @@ -10,20 +10,31 @@ public class StatesCacheModule : Module { public event Action Cached; + private readonly bool _autoCache; + private readonly bool _autoClear; private readonly Dictionary _typeInstanceMap; - public StatesCacheModule() + public StatesCacheModule(bool autoCache, bool autoClear) { _typeInstanceMap = new(); + autoCache = autoCache; + _autoClear = autoClear; + } + + public StatesCacheModule() : this(true, true) + { } protected override void OnLinked(IStateMachine stateMachine) { } - protected override void OnUnlinked() + protected override void OnUnlinked(IStateMachine stateMachine) { - ClearCache(); + if (_autoClear && !IsLinked) + { + ClearCache(); + } } public void Cache(TState state) @@ -39,11 +50,6 @@ public void Cache(TState state) OnCached(state); } - protected virtual void OnCached(TState state) - { - Cached?.Invoke(state); - } - public T Cache() where T : TState, new() { var state = new T(); @@ -52,11 +58,34 @@ protected virtual void OnCached(TState state) return state; } + protected virtual void OnCached(TState state) + { + Cached?.Invoke(state); + } + public bool Contains(Type type) { + if (type == null) + { + DebugUtility.LogException(nameof(type)); + return false; + } + return _typeInstanceMap.ContainsKey(type); } + public bool Contains(TState state) + { + if (state == null) + { + DebugUtility.LogException(nameof(state)); + return false; + } + + var type = state.GetType(); + return Remove(type); + } + public bool Contains() where T : TState { @@ -66,6 +95,14 @@ public bool Contains() public bool TryGet(Type type, out TState module) { + if (type == null) + { + DebugUtility.LogException(nameof(module)); + + module = default; + return false; + } + return _typeInstanceMap.TryGetValue(type, out module); } @@ -85,6 +122,12 @@ public bool TryGet(out T state) where T : TState public TState Get(Type type) { + if (type == null) + { + DebugUtility.LogException(nameof(type)); + return default; + } + if (TryGet(type, out var module)) { return module; @@ -120,21 +163,43 @@ public T Get() where T : TState public bool Remove(Type type) { + if (type == null) + { + DebugUtility.LogException(nameof(type)); + return false; + } + return _typeInstanceMap.Remove(type); } + public bool Remove(TState state) + { + if (state == null) + { + DebugUtility.LogException(nameof(state)); + return false; + } + + var type = state.GetType(); + return Remove(type); + } + public bool Remove() where T : TState { var type = typeof(T); return Remove(type); } - public override void OnStatePreChanged(TState state) + public override void OnStatePreChanged(IStateMachine stateMachine, TState state) { - base.OnStatePreChanged(state); - Cache(state); + base.OnStatePreChanged(stateMachine, state); + + if (_autoCache) + { + Cache(state); + } } - + public void ClearCache() { _typeInstanceMap.Clear(); diff --git a/Assets/BetterStateMachine/Runtime/Modules/SynchronousModule.cs b/Assets/BetterStateMachine/Runtime/Modules/SynchronousModule.cs index 32cf540..b9045c7 100644 --- a/Assets/BetterStateMachine/Runtime/Modules/SynchronousModule.cs +++ b/Assets/BetterStateMachine/Runtime/Modules/SynchronousModule.cs @@ -4,7 +4,7 @@ namespace Better.StateMachine.Runtime.Modules { - public class SynchronousModule : Module + public class SynchronousModule : SingleModule where TState : BaseState { private readonly bool _onlySyncState; @@ -12,39 +12,35 @@ public class SynchronousModule : Module private float _cachedFrame; - public SynchronousModule(bool onlySyncState = true, bool allowLogs = true) + public SynchronousModule(bool onlySyncState, bool allowLogs) { _onlySyncState = onlySyncState; _allowLogs = allowLogs; } - protected override void OnLinked(IStateMachine stateMachine) + public SynchronousModule() : this(true, true) { } - protected override void OnUnlinked() + public override bool AllowRunMachine(IStateMachine stateMachine) { + return base.AllowRunMachine(stateMachine) && ValidateStateType(_allowLogs); } - public override bool AllowRunMachine() + public override bool AllowChangeState(IStateMachine stateMachine, TState state) { - return base.AllowRunMachine() && ValidateStateType(_allowLogs); + return base.AllowChangeState(stateMachine, state) && ValidateStateType(state, _allowLogs); } - public override bool AllowChangeState(TState state) + public override void OnStatePreChanged(IStateMachine stateMachine, TState state) { - return base.AllowChangeState(state) && ValidateStateType(state, _allowLogs); - } - - public override void OnStatePreChanged(TState state) - { - base.OnStatePreChanged(state); + base.OnStatePreChanged(stateMachine, state); _cachedFrame = Time.frameCount; } - public override void OnStateChanged(TState state) + public override void OnStateChanged(IStateMachine stateMachine, TState state) { - base.OnStateChanged(state); + base.OnStateChanged(stateMachine, state); if (_cachedFrame < Time.frameCount && _allowLogs) { diff --git a/Assets/BetterStateMachine/Runtime/Modules/Transitions/AutoTransitionsModule.cs b/Assets/BetterStateMachine/Runtime/Modules/Transitions/AutoTransitionsModule.cs index 6e79852..118527c 100644 --- a/Assets/BetterStateMachine/Runtime/Modules/Transitions/AutoTransitionsModule.cs +++ b/Assets/BetterStateMachine/Runtime/Modules/Transitions/AutoTransitionsModule.cs @@ -15,14 +15,18 @@ public class AutoTransitionsModule : TransitionsModule private CancellationTokenSource _tokenSource; private float _tickTimestep; - public AutoTransitionsModule(float tickTimestep = DefaultTickTimestep) : base() + public AutoTransitionsModule(float tickTimestep) { _tickTimestep = Mathf.Max(tickTimestep, 0f); } - public override void OnMachineRunned() + public AutoTransitionsModule() : this(DefaultTickTimestep) { - base.OnMachineRunned(); + } + + public override void OnMachineRunned(IStateMachine stateMachine) + { + base.OnMachineRunned(stateMachine); _tokenSource?.Cancel(); _tokenSource = new CancellationTokenSource(); @@ -43,16 +47,16 @@ protected async Task TickAsync(CancellationToken cancellationToken) } while (!cancellationToken.IsCancellationRequested); } - public override void OnMachineStopped() + public override void OnMachineStopped(IStateMachine stateMachine) { - base.OnMachineStopped(); + base.OnMachineStopped(stateMachine); _tokenSource?.Cancel(); } - protected override void OnUnlinked() + protected override void OnUnlinked(IStateMachine stateMachine) { - base.OnUnlinked(); + base.OnUnlinked(stateMachine); _tokenSource?.Cancel(); } diff --git a/Assets/BetterStateMachine/Runtime/Modules/Transitions/TransitionsModule.cs b/Assets/BetterStateMachine/Runtime/Modules/Transitions/TransitionsModule.cs index 15dd315..ab8a295 100644 --- a/Assets/BetterStateMachine/Runtime/Modules/Transitions/TransitionsModule.cs +++ b/Assets/BetterStateMachine/Runtime/Modules/Transitions/TransitionsModule.cs @@ -6,7 +6,7 @@ namespace Better.StateMachine.Runtime.Modules.Transitions { - public abstract class TransitionsModule : Module + public abstract class TransitionsModule : SingleModule where TState : BaseState { protected readonly Dictionary> _outfromingBundles; @@ -22,24 +22,22 @@ public TransitionsModule() protected override void OnLinked(IStateMachine stateMachine) { + base.OnLinked(stateMachine); + _currentBundles.Clear(); _currentBundles.Add(_anyToBundles); } - protected override void OnUnlinked() - { - } - - public override void OnMachineRunned() + public override void OnMachineRunned(IStateMachine stateMachine) { - base.OnMachineRunned(); + base.OnMachineRunned(stateMachine); ReconditionTransitions(); } - public override void OnStateChanged(TState state) + public override void OnStateChanged(IStateMachine stateMachine, TState state) { - base.OnStateChanged(state); + base.OnStateChanged(stateMachine, state); UpdateTransitions(state); } diff --git a/Assets/BetterStateMachine/Runtime/StateMachine.cs b/Assets/BetterStateMachine/Runtime/StateMachine.cs index 5d561c7..8ed969c 100644 --- a/Assets/BetterStateMachine/Runtime/StateMachine.cs +++ b/Assets/BetterStateMachine/Runtime/StateMachine.cs @@ -52,7 +52,7 @@ public virtual void Run() foreach (var module in _typeModuleMap.Values) { - if (!module.AllowRunMachine()) + if (!module.AllowRunMachine(this)) { var message = $"{module} not allow machine run"; Debug.LogWarning(message); @@ -66,7 +66,7 @@ public virtual void Run() foreach (var module in _typeModuleMap.Values) { - module.OnMachineRunned(); + module.OnMachineRunned(this); } } @@ -79,7 +79,7 @@ public virtual void Stop() foreach (var module in _typeModuleMap.Values) { - if (!module.AllowStopMachine()) + if (!module.AllowStopMachine(this)) { var message = $"{module} not allow machine stop"; Debug.LogWarning(message); @@ -93,7 +93,7 @@ public virtual void Stop() foreach (var module in _typeModuleMap.Values) { - module.OnMachineStopped(); + module.OnMachineStopped(this); } } @@ -117,7 +117,7 @@ public async Task ChangeStateAsync(TState newState, CancellationToken cancellati foreach (var module in _typeModuleMap.Values) { - if (!module.AllowChangeState(newState)) + if (!module.AllowChangeState(this, newState)) { var message = $"{module} not allow change state to {newState}"; Debug.LogWarning(message); @@ -159,7 +159,7 @@ protected virtual void OnStatePreChanged(TState state) { foreach (var module in _typeModuleMap.Values) { - module.OnStatePreChanged(state); + module.OnStatePreChanged(this, state); } } @@ -167,7 +167,7 @@ protected virtual void OnStateChanged(TState state) { foreach (var module in _typeModuleMap.Values) { - module.OnStateChanged(state); + module.OnStateChanged(this, state); } StateChanged?.Invoke(state); @@ -182,17 +182,17 @@ public bool InState() where T : TState #region Modules - public void AddModule(Module module) + public bool AddModule(Module module) { if (module == null) { DebugUtility.LogException(nameof(module)); - return; + return false; } if (!ValidateRunning(false)) { - return; + return false; } var type = module.GetType(); @@ -200,11 +200,19 @@ public void AddModule(Module module) { var message = $"{nameof(module)} of {nameof(type)}({type}) already added"; Debug.LogWarning(message); - return; + return false; + } + + if (!module.AllowLinkTo(this)) + { + var message = $"{nameof(module)} of {nameof(type)}({type}) not allowed linked"; + Debug.LogWarning(message); + return false; } _typeModuleMap.Add(type, module); module.Link(this); + return true; } public TModule AddModule() @@ -218,6 +226,12 @@ public TModule AddModule() public bool HasModule(Type type) { + if (type == null) + { + DebugUtility.LogException(nameof(type)); + return false; + } + return _typeModuleMap.ContainsKey(type); } @@ -228,8 +242,27 @@ public bool HasModule() return HasModule(type); } + public bool HasModule(Module module) + { + if (module == null) + { + DebugUtility.LogException(nameof(module)); + return false; + } + + return _typeModuleMap.ContainsValue(module); + } + public bool TryGetModule(Type type, out Module module) { + if (type == null) + { + DebugUtility.LogException(nameof(type)); + + module = default; + return false; + } + return _typeModuleMap.TryGetValue(type, out module); } @@ -250,6 +283,12 @@ public bool TryGetModule(out TModule module) public Module GetModule(Type type) { + if (type == null) + { + DebugUtility.LogException(nameof(type)); + return null; + } + if (TryGetModule(type, out var module)) { return module; @@ -285,28 +324,44 @@ public TModule GetOrAddModule() return AddModule(); } - public bool RemoveModule(Type type) + public bool RemoveModule(Module module) { - if (_typeModuleMap.TryGetValue(type, out var module) - && _typeModuleMap.Remove(type)) + if (module == null) + { + DebugUtility.LogException(nameof(module)); + return false; + } + + var type = module.GetType(); + if (_typeModuleMap.Remove(type)) { - module.Unlink(); + module.Unlink(this); return true; } return false; } + public bool RemoveModule(Type type) + { + if (type == null) + { + DebugUtility.LogException(nameof(type)); + return false; + } + + return TryGetModule(type, out var module) && RemoveModule(module); + } + public bool RemoveModule() where TModule : Module { - var type = typeof(TModule); - return RemoveModule(type); + return TryGetModule(out var module) && RemoveModule(module); } #endregion - private bool ValidateRunning(bool targetState, bool logException = true) + protected bool ValidateRunning(bool targetState, bool logException = true) { var isValid = IsRunning == targetState; if (!isValid && logException)