Skip to content

Commit b1f3098

Browse files
Merge branch 'main' into http-invocation-example
2 parents 664f70f + aafb900 commit b1f3098

File tree

13 files changed

+1001
-9
lines changed

13 files changed

+1001
-9
lines changed

dapr/actor/runtime/mock_actor.py

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
from datetime import timedelta
17+
from typing import Any, Optional, TypeVar
18+
19+
from dapr.actor.id import ActorId
20+
from dapr.actor.runtime._reminder_data import ActorReminderData
21+
from dapr.actor.runtime._timer_data import TIMER_CALLBACK, ActorTimerData
22+
from dapr.actor.runtime.actor import Actor
23+
from dapr.actor.runtime.mock_state_manager import MockStateManager
24+
25+
26+
class MockActor(Actor):
27+
"""A mock actor class to be used to override certain Actor methods for unit testing.
28+
To be used only via the create_mock_actor function, which takes in a class and returns a
29+
mock actor object for that class.
30+
31+
Examples:
32+
class SomeActorInterface(ActorInterface):
33+
@actor_method(name="method")
34+
async def set_state(self, data: dict) -> None:
35+
36+
class SomeActor(Actor, SomeActorInterface):
37+
async def set_state(self, data: dict) -> None:
38+
await self._state_manager.set_state('state', data)
39+
await self._state_manager.save_state()
40+
41+
mock_actor = create_mock_actor(SomeActor, "actor_1")
42+
assert mock_actor._state_manager._mock_state == {}
43+
await mock_actor.set_state({"test":10})
44+
assert mock_actor._state_manager._mock_state == {"test":10}
45+
"""
46+
47+
def __init__(self, actor_id: str, initstate: Optional[dict]):
48+
self.id = ActorId(actor_id)
49+
self._runtime_ctx = None # type: ignore
50+
self._state_manager = MockStateManager(self, initstate)
51+
52+
async def register_timer(
53+
self,
54+
name: Optional[str],
55+
callback: TIMER_CALLBACK,
56+
state: Any,
57+
due_time: timedelta,
58+
period: timedelta,
59+
ttl: Optional[timedelta] = None,
60+
) -> None:
61+
"""Adds actor timer to self._state_manager._mock_timers.
62+
Args:
63+
name (str): the name of the timer to register.
64+
callback (Callable): An awaitable callable which will be called when the timer fires.
65+
state (Any): An object which will pass to the callback method, or None.
66+
due_time (datetime.timedelta): the amount of time to delay before the awaitable
67+
callback is first invoked.
68+
period (datetime.timedelta): the time interval between invocations
69+
of the awaitable callback.
70+
ttl (Optional[datetime.timedelta]): the time interval before the timer stops firing
71+
"""
72+
name = name or self.__get_new_timer_name()
73+
timer = ActorTimerData(name, callback, state, due_time, period, ttl)
74+
self._state_manager._mock_timers[name] = timer # type: ignore
75+
76+
async def unregister_timer(self, name: str) -> None:
77+
"""Unregisters actor timer from self._state_manager._mock_timers.
78+
79+
Args:
80+
name (str): the name of the timer to unregister.
81+
"""
82+
self._state_manager._mock_timers.pop(name, None) # type: ignore
83+
84+
async def register_reminder(
85+
self,
86+
name: str,
87+
state: bytes,
88+
due_time: timedelta,
89+
period: timedelta,
90+
ttl: Optional[timedelta] = None,
91+
) -> None:
92+
"""Adds actor reminder to self._state_manager._mock_reminders.
93+
94+
Args:
95+
name (str): the name of the reminder to register. the name must be unique per actor.
96+
state (bytes): the user state passed to the reminder invocation.
97+
due_time (datetime.timedelta): the amount of time to delay before invoking the reminder
98+
for the first time.
99+
period (datetime.timedelta): the time interval between reminder invocations after
100+
the first invocation.
101+
ttl (datetime.timedelta): the time interval before the reminder stops firing
102+
"""
103+
reminder = ActorReminderData(name, state, due_time, period, ttl)
104+
self._state_manager._mock_reminders[name] = reminder # type: ignore
105+
106+
async def unregister_reminder(self, name: str) -> None:
107+
"""Unregisters actor reminder from self._state_manager._mock_reminders..
108+
109+
Args:
110+
name (str): the name of the reminder to unregister.
111+
"""
112+
self._state_manager._mock_reminders.pop(name, None) # type: ignore
113+
114+
115+
T = TypeVar('T', bound=Actor)
116+
117+
118+
def create_mock_actor(cls1: type[T], actor_id: str, initstate: Optional[dict] = None) -> T:
119+
class MockSuperClass(MockActor, cls1): # type: ignore
120+
pass
121+
122+
return MockSuperClass(actor_id, initstate) # type: ignore
+238
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
import asyncio
15+
from contextvars import ContextVar
16+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar
17+
18+
from dapr.actor.runtime._reminder_data import ActorReminderData
19+
from dapr.actor.runtime._timer_data import ActorTimerData
20+
from dapr.actor.runtime.state_change import ActorStateChange, StateChangeKind
21+
from dapr.actor.runtime.state_manager import ActorStateManager, StateMetadata
22+
23+
if TYPE_CHECKING:
24+
from dapr.actor.runtime.mock_actor import MockActor
25+
26+
T = TypeVar('T')
27+
CONTEXT: ContextVar[Optional[Dict[str, Any]]] = ContextVar('state_tracker_context')
28+
29+
30+
class MockStateManager(ActorStateManager):
31+
def __init__(self, actor: 'MockActor', initstate: Optional[dict]):
32+
self._actor = actor
33+
self._default_state_change_tracker: Dict[str, StateMetadata] = {}
34+
self._mock_state: Dict[str, Any] = {}
35+
self._mock_timers: Dict[str, ActorTimerData] = {}
36+
self._mock_reminders: Dict[str, ActorReminderData] = {}
37+
if initstate:
38+
self._mock_state = initstate
39+
40+
async def add_state(self, state_name: str, value: T) -> None:
41+
if not await self.try_add_state(state_name, value):
42+
raise ValueError(f'The actor state name {state_name} already exist.')
43+
44+
async def try_add_state(self, state_name: str, value: T) -> bool:
45+
if state_name in self._default_state_change_tracker:
46+
state_metadata = self._default_state_change_tracker[state_name]
47+
if state_metadata.change_kind == StateChangeKind.remove:
48+
self._default_state_change_tracker[state_name] = StateMetadata(
49+
value, StateChangeKind.update
50+
)
51+
return True
52+
return False
53+
existed = state_name in self._mock_state
54+
if existed:
55+
return False
56+
self._default_state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
57+
self._mock_state[state_name] = value
58+
return True
59+
60+
async def get_state(self, state_name: str) -> Optional[T]:
61+
has_value, val = await self.try_get_state(state_name)
62+
if has_value:
63+
return val
64+
else:
65+
raise KeyError(f'Actor State with name {state_name} was not found.')
66+
67+
async def try_get_state(self, state_name: str) -> Tuple[bool, Optional[T]]:
68+
if state_name in self._default_state_change_tracker:
69+
state_metadata = self._default_state_change_tracker[state_name]
70+
if state_metadata.change_kind == StateChangeKind.remove:
71+
return False, None
72+
return True, state_metadata.value
73+
has_value = state_name in self._mock_state
74+
val = self._mock_state.get(state_name)
75+
if has_value:
76+
self._default_state_change_tracker[state_name] = StateMetadata(
77+
val, StateChangeKind.none
78+
)
79+
return has_value, val
80+
81+
async def set_state(self, state_name: str, value: T) -> None:
82+
await self.set_state_ttl(state_name, value, None)
83+
84+
async def set_state_ttl(self, state_name: str, value: T, ttl_in_seconds: Optional[int]) -> None:
85+
if ttl_in_seconds is not None and ttl_in_seconds < 0:
86+
return
87+
88+
if state_name in self._default_state_change_tracker:
89+
state_metadata = self._default_state_change_tracker[state_name]
90+
state_metadata.value = value
91+
state_metadata.ttl_in_seconds = ttl_in_seconds
92+
93+
if (
94+
state_metadata.change_kind == StateChangeKind.none
95+
or state_metadata.change_kind == StateChangeKind.remove
96+
):
97+
state_metadata.change_kind = StateChangeKind.update
98+
self._default_state_change_tracker[state_name] = state_metadata
99+
self._mock_state[state_name] = value
100+
return
101+
102+
existed = state_name in self._mock_state
103+
if existed:
104+
self._default_state_change_tracker[state_name] = StateMetadata(
105+
value, StateChangeKind.update, ttl_in_seconds
106+
)
107+
else:
108+
self._default_state_change_tracker[state_name] = StateMetadata(
109+
value, StateChangeKind.add, ttl_in_seconds
110+
)
111+
self._mock_state[state_name] = value
112+
113+
async def remove_state(self, state_name: str) -> None:
114+
if not await self.try_remove_state(state_name):
115+
raise KeyError(f'Actor State with name {state_name} was not found.')
116+
117+
async def try_remove_state(self, state_name: str) -> bool:
118+
if state_name in self._default_state_change_tracker:
119+
state_metadata = self._default_state_change_tracker[state_name]
120+
if state_metadata.change_kind == StateChangeKind.remove:
121+
return False
122+
elif state_metadata.change_kind == StateChangeKind.add:
123+
self._default_state_change_tracker.pop(state_name, None)
124+
self._mock_state.pop(state_name, None)
125+
return True
126+
self._mock_state.pop(state_name, None)
127+
state_metadata.change_kind = StateChangeKind.remove
128+
return True
129+
130+
existed = state_name in self._mock_state
131+
if existed:
132+
self._default_state_change_tracker[state_name] = StateMetadata(
133+
None, StateChangeKind.remove
134+
)
135+
self._mock_state.pop(state_name, None)
136+
return True
137+
return False
138+
139+
async def contains_state(self, state_name: str) -> bool:
140+
if state_name in self._default_state_change_tracker:
141+
state_metadata = self._default_state_change_tracker[state_name]
142+
return state_metadata.change_kind != StateChangeKind.remove
143+
return state_name in self._mock_state
144+
145+
async def get_or_add_state(self, state_name: str, value: T) -> Optional[T]:
146+
has_value, val = await self.try_get_state(state_name)
147+
if has_value:
148+
return val
149+
change_kind = (
150+
StateChangeKind.update
151+
if self.is_state_marked_for_remove(state_name)
152+
else StateChangeKind.add
153+
)
154+
self._mock_state[state_name] = value
155+
self._default_state_change_tracker[state_name] = StateMetadata(value, change_kind)
156+
return value
157+
158+
async def add_or_update_state(
159+
self, state_name: str, value: T, update_value_factory: Callable[[str, T], T]
160+
) -> T:
161+
if not callable(update_value_factory):
162+
raise AttributeError('update_value_factory is not callable')
163+
164+
if state_name in self._default_state_change_tracker:
165+
state_metadata = self._default_state_change_tracker[state_name]
166+
if state_metadata.change_kind == StateChangeKind.remove:
167+
self._default_state_change_tracker[state_name] = StateMetadata(
168+
value, StateChangeKind.update
169+
)
170+
self._mock_state[state_name] = value
171+
return value
172+
new_value = update_value_factory(state_name, state_metadata.value)
173+
state_metadata.value = new_value
174+
if state_metadata.change_kind == StateChangeKind.none:
175+
state_metadata.change_kind = StateChangeKind.update
176+
self._default_state_change_tracker[state_name] = state_metadata
177+
self._mock_state[state_name] = new_value
178+
return new_value
179+
180+
has_value = state_name in self._mock_state
181+
val: Any = self._mock_state.get(state_name)
182+
if has_value:
183+
new_value = update_value_factory(state_name, val)
184+
self._default_state_change_tracker[state_name] = StateMetadata(
185+
new_value, StateChangeKind.update
186+
)
187+
self._mock_state[state_name] = new_value
188+
return new_value
189+
self._default_state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
190+
self._mock_state[state_name] = value
191+
return value
192+
193+
async def get_state_names(self) -> List[str]:
194+
# TODO: Get all state names from Dapr once implemented.
195+
def append_names_sync():
196+
state_names = []
197+
for key, value in self._default_state_change_tracker.items():
198+
if value.change_kind == StateChangeKind.add:
199+
state_names.append(key)
200+
elif value.change_kind == StateChangeKind.remove:
201+
state_names.append(key)
202+
return state_names
203+
204+
default_loop = asyncio.get_running_loop()
205+
return await default_loop.run_in_executor(None, append_names_sync)
206+
207+
async def clear_cache(self) -> None:
208+
self._default_state_change_tracker.clear()
209+
210+
async def save_state(self) -> None:
211+
if len(self._default_state_change_tracker) == 0:
212+
return
213+
214+
state_changes = []
215+
states_to_remove = []
216+
for state_name, state_metadata in self._default_state_change_tracker.items():
217+
if state_metadata.change_kind == StateChangeKind.none:
218+
continue
219+
state_changes.append(
220+
ActorStateChange(
221+
state_name,
222+
state_metadata.value,
223+
state_metadata.change_kind,
224+
state_metadata.ttl_in_seconds,
225+
)
226+
)
227+
if state_metadata.change_kind == StateChangeKind.remove:
228+
states_to_remove.append(state_name)
229+
# Mark the states as unmodified so that tracking for next invocation is done correctly.
230+
state_metadata.change_kind = StateChangeKind.none
231+
for state_name in states_to_remove:
232+
self._default_state_change_tracker.pop(state_name, None)
233+
234+
def is_state_marked_for_remove(self, state_name: str) -> bool:
235+
return (
236+
state_name in self._default_state_change_tracker
237+
and self._default_state_change_tracker[state_name].change_kind == StateChangeKind.remove
238+
)

dapr/actor/runtime/state_manager.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515

1616
import asyncio
1717
from contextvars import ContextVar
18+
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar
1819

19-
from dapr.actor.runtime.state_change import StateChangeKind, ActorStateChange
2020
from dapr.actor.runtime.reentrancy_context import reentrancy_ctx
21-
22-
from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar, Optional, TYPE_CHECKING
21+
from dapr.actor.runtime.state_change import ActorStateChange, StateChangeKind
2322

2423
if TYPE_CHECKING:
2524
from dapr.actor.runtime.actor import Actor

0 commit comments

Comments
 (0)