Skip to content

Commit 8f4efa5

Browse files
authoredFeb 11, 2025
Python: Allow for factory callbacks in the process framework (microsoft#10451)
### Motivation and Context In the current Python process framework, in runtimes like Dapr, there is no easy way to pass complex (unserializable) dependencies to a step -- this includes things like a ChatCompletion service or an agent of type ChatCompletionAgent. Similar to how the kernel dependency was propagated to the step_actor or process_actor, we're introducing the ability to specify a factory callback that will be called as the step is instantiated. The factory is created, if specified via the optional kwarg when adding a step to the process builder like: ```python myBStep = process.add_step(step_type=BStep, factory_function=bstep_factory) ``` The `bstep_factory` looks like (along with its corresponding step) ```python async def bstep_factory(): """Creates a BStep instance with ephemeral references like ChatCompletionAgent.""" kernel = Kernel() kernel.add_service(AzureChatCompletion()) agent = ChatCompletionAgent(kernel=kernel, name="echo", instructions="repeat the input back") step_instance = BStep() step_instance.agent = agent return step_instance class BStep(KernelProcessStep): """A sample BStep that optionally holds a ChatCompletionAgent. By design, the agent is ephemeral (not stored in state). """ # Ephemeral references won't be persisted to Dapr # because we do not place them in a step state model. # We'll set this in the factory function: agent: ChatCompletionAgent | None = None @kernel_function(name="do_it") async def do_it(self, context: KernelProcessStepContext): print("##### BStep ran (do_it).") await asyncio.sleep(2) if self.agent: history = ChatHistory() history.add_user_message("Hello from BStep!") async for msg in self.agent.invoke(history): print(f"BStep got agent response: {msg.content}") await context.emit_event(process_event="BStepDone", data="I did B") ``` Although this isn't explicitly necessary with the local runtime, the factory callback will also work, if desired. <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> ### Description Adds the ability to specify a factory callback for a step in the process framework. - Adjusts the Dapr FastAPI demo sample to show how one can include a dependency like a `ChatCompletionAgent` and use the factory callback for `BStep`. Although the output from the agent isn't needed, it demonstrates the capability to handle these types of dependencies while running Dapr. - Adds unit tests - Closes microsoft#10409 <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone 😄
1 parent b8795b7 commit 8f4efa5

File tree

15 files changed

+283
-54
lines changed

15 files changed

+283
-54
lines changed
 

‎python/samples/demos/process_with_dapr/fastapi_app.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
from samples.demos.process_with_dapr.process.process import get_process
1212
from samples.demos.process_with_dapr.process.steps import CommonEvents
1313
from semantic_kernel import Kernel
14-
from semantic_kernel.processes.dapr_runtime import (
15-
register_fastapi_dapr_actors,
16-
start,
17-
)
14+
from semantic_kernel.processes.dapr_runtime import register_fastapi_dapr_actors, start
1815

1916
logging.basicConfig(level=logging.ERROR)
2017

@@ -34,12 +31,16 @@
3431
# and returns the actor instance with the kernel injected. #
3532
#########################################################################
3633

34+
# Get the process which means we have the `KernelProcess` object
35+
# along with any defined step factories
36+
process = get_process()
37+
3738

3839
# Define a lifespan method that registers the actors with the Dapr runtime
3940
@asynccontextmanager
4041
async def lifespan(app: FastAPI):
4142
print("## actor startup ##")
42-
await register_fastapi_dapr_actors(actor, kernel)
43+
await register_fastapi_dapr_actors(actor, kernel, process.factories)
4344
yield
4445

4546

@@ -56,8 +57,6 @@ async def healthcheck():
5657
@app.get("/processes/{process_id}")
5758
async def start_process(process_id: str):
5859
try:
59-
process = get_process()
60-
6160
_ = await start(
6261
process=process,
6362
kernel=kernel,

‎python/samples/demos/process_with_dapr/process/process.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22

33
from typing import TYPE_CHECKING
44

5-
from samples.demos.process_with_dapr.process.steps import AStep, BStep, CommonEvents, CStep, CStepState, KickOffStep
5+
from samples.demos.process_with_dapr.process.steps import (
6+
AStep,
7+
BStep,
8+
CommonEvents,
9+
CStep,
10+
CStepState,
11+
KickOffStep,
12+
bstep_factory,
13+
)
614
from semantic_kernel.processes import ProcessBuilder
715

816
if TYPE_CHECKING:
@@ -16,7 +24,7 @@ def get_process() -> "KernelProcess":
1624
# Add the step types to the builder
1725
kickoff_step = process.add_step(step_type=KickOffStep)
1826
myAStep = process.add_step(step_type=AStep)
19-
myBStep = process.add_step(step_type=BStep)
27+
myBStep = process.add_step(step_type=BStep, factory_function=bstep_factory)
2028

2129
# Initialize the CStep with an initial state and the state's current cycle set to 1
2230
myCStep = process.add_step(step_type=CStep, initial_state=CStepState(current_cycle=1))

‎python/samples/demos/process_with_dapr/process/steps.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
from pydantic import Field
88

9+
from semantic_kernel.agents.chat_completion.chat_completion_agent import ChatCompletionAgent
10+
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
11+
from semantic_kernel.contents.chat_history import ChatHistory
912
from semantic_kernel.functions import kernel_function
13+
from semantic_kernel.kernel import Kernel
1014
from semantic_kernel.kernel_pydantic import KernelBaseModel
1115
from semantic_kernel.processes.kernel_process import (
1216
KernelProcessStep,
@@ -52,14 +56,43 @@ async def do_it(self, context: KernelProcessStepContext):
5256
await context.emit_event(process_event=CommonEvents.AStepDone, data="I did A")
5357

5458

55-
# Define a sample `BStep` step that will emit an event after 2 seconds.
56-
# The event will be sent to the `CStep` step with the data `I did B`.
59+
# Define a simple factory for the BStep that can create the dependency that the BStep requires
60+
# As an example, this factory creates a kernel and adds an `AzureChatCompletion` service to it.
61+
async def bstep_factory():
62+
"""Creates a BStep instance with ephemeral references like ChatCompletionAgent."""
63+
kernel = Kernel()
64+
kernel.add_service(AzureChatCompletion())
65+
66+
agent = ChatCompletionAgent(kernel=kernel, name="echo", instructions="repeat the input back")
67+
step_instance = BStep()
68+
step_instance.agent = agent
69+
70+
return step_instance
71+
72+
5773
class BStep(KernelProcessStep):
58-
@kernel_function()
74+
"""A sample BStep that optionally holds a ChatCompletionAgent.
75+
76+
By design, the agent is ephemeral (not stored in state).
77+
"""
78+
79+
# Ephemeral references won't be persisted to Dapr
80+
# because we do not place them in a step state model.
81+
# We'll set this in the factory function:
82+
agent: ChatCompletionAgent | None = None
83+
84+
@kernel_function(name="do_it")
5985
async def do_it(self, context: KernelProcessStepContext):
60-
print("##### BStep ran.")
86+
print("##### BStep ran (do_it).")
6187
await asyncio.sleep(2)
62-
await context.emit_event(process_event=CommonEvents.BStepDone, data="I did B")
88+
89+
if self.agent:
90+
history = ChatHistory()
91+
history.add_user_message("Hello from BStep!")
92+
async for msg in self.agent.invoke(history):
93+
print(f"BStep got agent response: {msg.content}")
94+
95+
await context.emit_event(process_event="BStepDone", data="I did B")
6396

6497

6598
# Define a sample `CStepState` that will keep track of the current cycle.

‎python/semantic_kernel/processes/dapr_runtime/actors/process_actor.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import logging
77
import uuid
8+
from collections.abc import Callable
89
from queue import Queue
910
from typing import Any
1011

@@ -46,16 +47,18 @@
4647
class ProcessActor(StepActor, ProcessInterface):
4748
"""A local process that contains a collection of steps."""
4849

49-
def __init__(self, ctx: ActorRuntimeContext, actor_id: ActorId, kernel: Kernel):
50+
def __init__(self, ctx: ActorRuntimeContext, actor_id: ActorId, kernel: Kernel, factories: dict[str, Callable]):
5051
"""Initializes a new instance of ProcessActor.
5152
5253
Args:
5354
ctx: The actor runtime context.
5455
actor_id: The unique ID for the actor.
5556
kernel: The Kernel dependency to be injected.
57+
factories: The factory dictionary that contains step types to factory methods.
5658
"""
57-
super().__init__(ctx, actor_id, kernel)
59+
super().__init__(ctx, actor_id, kernel, factories)
5860
self.kernel = kernel
61+
self.factories = factories
5962
self.steps: list[StepInterface] = []
6063
self.step_infos: list[DaprStepInfo] = []
6164
self.initialize_task: bool | None = False

‎python/semantic_kernel/processes/dapr_runtime/actors/step_actor.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import importlib
55
import json
66
import logging
7+
from collections.abc import Callable
8+
from inspect import isawaitable
79
from queue import Queue
810
from typing import Any
911

@@ -46,16 +48,18 @@
4648
class StepActor(Actor, StepInterface, KernelProcessMessageChannel):
4749
"""Represents a step actor that follows the Step abstract class."""
4850

49-
def __init__(self, ctx: ActorRuntimeContext, actor_id: ActorId, kernel: Kernel):
51+
def __init__(self, ctx: ActorRuntimeContext, actor_id: ActorId, kernel: Kernel, factories: dict[str, Callable]):
5052
"""Initializes a new instance of StepActor.
5153
5254
Args:
5355
ctx: The actor runtime context.
5456
actor_id: The unique ID for the actor.
5557
kernel: The Kernel dependency to be injected.
58+
factories: The factory dictionary to use for creating the step.
5659
"""
5760
super().__init__(ctx, actor_id)
5861
self.kernel = kernel
62+
self.factories: dict[str, Callable] = factories
5963
self.parent_process_id: str | None = None
6064
self.step_info: DaprStepInfo | None = None
6165
self.initialize_task: bool | None = False
@@ -172,31 +176,38 @@ def _get_class_from_string(self, full_class_name: str):
172176

173177
async def activate_step(self):
174178
"""Initializes the step."""
175-
# Instantiate an instance of the inner step object
176-
step_cls = self._get_class_from_string(self.inner_step_type)
177-
178-
step_instance: KernelProcessStep = step_cls() # type: ignore
179+
# Instantiate an instance of the inner step object and retrieve its class reference.
180+
if self.factories and self.inner_step_type in self.factories:
181+
step_object = self.factories[self.inner_step_type]()
182+
if isawaitable(step_object):
183+
step_object = await step_object
184+
step_cls = step_object.__class__
185+
step_instance: KernelProcessStep = step_object # type: ignore
186+
else:
187+
step_cls = self._get_class_from_string(self.inner_step_type)
188+
step_instance: KernelProcessStep = step_cls() # type: ignore
179189

180190
kernel_plugin = self.kernel.add_plugin(
181-
step_instance, self.step_info.state.name if self.step_info.state else "default_name"
191+
step_instance,
192+
self.step_info.state.name if self.step_info.state else "default_name",
182193
)
183194

184-
# Load the kernel functions
195+
# Load the kernel functions.
185196
for name, f in kernel_plugin.functions.items():
186197
self.functions[name] = f
187198

188-
# Initialize the input channels
199+
# Initialize the input channels.
189200
self.initial_inputs = find_input_channels(channel=self, functions=self.functions)
190201
self.inputs = {k: {kk: vv for kk, vv in v.items()} if v else {} for k, v in self.initial_inputs.items()}
191202

192-
# Use the existing state or create a new one if not provided
203+
# Use the existing state or create a new one if not provided.
193204
state_object = self.step_info.state
194205

195-
# Extract TState from inner_step_type
206+
# Extract TState from inner_step_type using the class reference.
196207
t_state = get_generic_state_type(step_cls)
197208

198209
if t_state is not None:
199-
# Create state_type as KernelProcessStepState[TState]
210+
# Create state_type as KernelProcessStepState[TState].
200211
state_type = KernelProcessStepState[t_state]
201212

202213
if state_object is None:
@@ -206,7 +217,7 @@ async def activate_step(self):
206217
state=None,
207218
)
208219
else:
209-
# Make sure state_object is an instance of state_type
220+
# Ensure that state_object is an instance of the expected type.
210221
if not isinstance(state_object, KernelProcessStepState):
211222
error_message = "State object is not of the expected type."
212223
raise KernelException(error_message)
@@ -215,25 +226,22 @@ async def activate_step(self):
215226
ActorStateKeys.StepStateType.value,
216227
get_fully_qualified_name(t_state),
217228
)
218-
219229
await self._state_manager.try_add_state(
220230
ActorStateKeys.StepStateJson.value,
221231
json.dumps(state_object.model_dump()),
222232
)
223-
224233
await self._state_manager.save_state()
225234

226-
# Make sure that state_object.state is not None
235+
# Initialize state_object.state if it is not already set.
227236
if state_object.state is None:
228237
try:
229238
state_object.state = t_state()
230239
except Exception as e:
231240
error_message = f"Cannot instantiate state of type {t_state}: {e}"
232241
raise KernelException(error_message)
233242
else:
234-
# The step has no user-defined state; use the base KernelProcessStepState
243+
# The step has no user-defined state; use the base KernelProcessStepState.
235244
state_type = KernelProcessStepState
236-
237245
if state_object is None:
238246
state_object = state_type(
239247
name=step_cls.__name__,
@@ -245,7 +253,7 @@ async def activate_step(self):
245253
error_message = "The state object for the KernelProcessStep could not be created."
246254
raise KernelException(error_message)
247255

248-
# Set the step state and activate the step with the state object
256+
# Set the step state and activate the step with the state object.
249257
self.step_state = state_object
250258
await step_instance.activate(state_object)
251259

‎python/semantic_kernel/processes/dapr_runtime/dapr_actor_registration.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3+
from collections.abc import Callable
34
from typing import TYPE_CHECKING
45

56
from dapr.actor import ActorId
@@ -19,22 +20,32 @@
1920
from semantic_kernel.kernel import Kernel
2021

2122

22-
def create_actor_factories(kernel: "Kernel") -> tuple:
23+
def create_actor_factories(kernel: "Kernel", factories: dict[str, Callable] | None = None) -> tuple:
2324
"""Creates actor factories for ProcessActor and StepActor."""
25+
if factories is None:
26+
factories = {}
2427

25-
def process_actor_factory(ctx: ActorRuntimeContext, actor_id: ActorId) -> ProcessActor:
26-
return ProcessActor(ctx, actor_id, kernel)
28+
def process_actor_factory(
29+
ctx: ActorRuntimeContext,
30+
actor_id: ActorId,
31+
) -> ProcessActor:
32+
return ProcessActor(ctx, actor_id, kernel=kernel, factories=factories)
2733

28-
def step_actor_factory(ctx: ActorRuntimeContext, actor_id: ActorId) -> StepActor:
29-
return StepActor(ctx, actor_id, kernel=kernel)
34+
def step_actor_factory(
35+
ctx: ActorRuntimeContext,
36+
actor_id: ActorId,
37+
) -> StepActor:
38+
return StepActor(ctx, actor_id, kernel=kernel, factories=factories)
3039

3140
return process_actor_factory, step_actor_factory
3241

3342

3443
# Asynchronous registration for FastAPI
35-
async def register_fastapi_dapr_actors(actor: FastAPIDaprActor, kernel: "Kernel") -> None:
44+
async def register_fastapi_dapr_actors(
45+
actor: FastAPIDaprActor, kernel: "Kernel", factories: dict[str, Callable] | None = None
46+
) -> None:
3647
"""Registers the actors with the Dapr runtime for use with a FastAPI app."""
37-
process_actor_factory, step_actor_factory = create_actor_factories(kernel)
48+
process_actor_factory, step_actor_factory = create_actor_factories(kernel, factories)
3849
await actor.register_actor(ProcessActor, actor_factory=process_actor_factory)
3950
await actor.register_actor(StepActor, actor_factory=step_actor_factory)
4051
await actor.register_actor(EventBufferActor)
@@ -43,9 +54,11 @@ async def register_fastapi_dapr_actors(actor: FastAPIDaprActor, kernel: "Kernel"
4354

4455

4556
# Synchronous registration for Flask
46-
def register_flask_dapr_actors(actor: FlaskDaprActor, kernel: "Kernel") -> None:
57+
def register_flask_dapr_actors(
58+
actor: FlaskDaprActor, kernel: "Kernel", factory: dict[str, Callable] | None = None
59+
) -> None:
4760
"""Registers the actors with the Dapr runtime for use with a Flask app."""
48-
process_actor_factory, step_actor_factory = create_actor_factories(kernel)
61+
process_actor_factory, step_actor_factory = create_actor_factories(kernel, factory)
4962
actor.register_actor(ProcessActor, actor_factory=process_actor_factory)
5063
actor.register_actor(StepActor, actor_factory=step_actor_factory)
5164
actor.register_actor(EventBufferActor)

‎python/semantic_kernel/processes/kernel_process/kernel_process.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3+
from collections.abc import Callable
34
from typing import TYPE_CHECKING, Any
45

56
from pydantic import Field
@@ -18,14 +19,24 @@ class KernelProcess(KernelProcessStepInfo):
1819
"""A kernel process."""
1920

2021
steps: list[KernelProcessStepInfo] = Field(default_factory=list)
22+
factories: dict[str, Callable] = Field(default_factory=dict)
2123

2224
def __init__(
2325
self,
2426
state: KernelProcessState,
2527
steps: list[KernelProcessStepInfo],
2628
edges: dict[str, list["KernelProcessEdge"]] | None = None,
29+
factories: dict[str, Callable] | None = None,
2730
):
28-
"""Initialize the kernel process."""
31+
"""Initialize the kernel process.
32+
33+
Args:
34+
state: The state of the process.
35+
steps: The steps of the process.
36+
edges: The edges of the process. Defaults to None.
37+
factories: The factories of the process. This allows for the creation of
38+
steps that require complex dependencies that cannot be JSON serialized or deserialized.
39+
"""
2940
if not state:
3041
raise ValueError("state cannot be None")
3142
if not steps:
@@ -43,4 +54,7 @@ def __init__(
4354
"output_edges": edges or {},
4455
}
4556

57+
if factories:
58+
args["factories"] = factories
59+
4660
super().__init__(**args)

‎python/semantic_kernel/processes/local_runtime/local_kernel_process_context.py

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(self, process: "KernelProcess", kernel: "Kernel"):
3333
process=process,
3434
kernel=kernel,
3535
parent_process_id=None,
36+
factories=process.factories,
3637
)
3738

3839
super().__init__(local_process=local_process)

‎python/semantic_kernel/processes/local_runtime/local_process.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import contextlib
55
import logging
66
import uuid
7+
from collections.abc import Callable
78
from queue import Queue
89
from typing import TYPE_CHECKING, Any
910

@@ -42,8 +43,15 @@ class LocalProcess(LocalStep):
4243
initialize_task: bool | None = False
4344
external_event_queue: Queue = Field(default_factory=Queue)
4445
process_task: asyncio.Task | None = None
45-
46-
def __init__(self, process: "KernelProcess", kernel: Kernel, parent_process_id: str | None = None):
46+
factories: dict[str, Callable] = Field(default_factory=dict)
47+
48+
def __init__(
49+
self,
50+
process: "KernelProcess",
51+
kernel: Kernel,
52+
factories: dict[str, Callable] | None = None,
53+
parent_process_id: str | None = None,
54+
):
4755
"""Initializes the local process."""
4856
args: dict[str, Any] = {
4957
"step_info": process,
@@ -54,6 +62,9 @@ def __init__(self, process: "KernelProcess", kernel: Kernel, parent_process_id:
5462
"initialize_task": False,
5563
}
5664

65+
if factories:
66+
args["factories"] = factories
67+
5768
super().__init__(**args)
5869

5970
def ensure_initialized(self):
@@ -124,6 +135,7 @@ def initialize_process(self):
124135
process = LocalProcess(
125136
process=step,
126137
kernel=self.kernel,
138+
factorie=self.factories,
127139
parent_process_id=self.id,
128140
)
129141

@@ -136,6 +148,7 @@ def initialize_process(self):
136148
local_step = LocalStep(
137149
step_info=step,
138150
kernel=self.kernel,
151+
factories=self.factories,
139152
parent_process_id=self.id,
140153
)
141154

‎python/semantic_kernel/processes/local_runtime/local_step.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import asyncio
44
import logging
55
import uuid
6+
from collections.abc import Callable
7+
from inspect import isawaitable
68
from queue import Queue
79
from typing import Any
810

@@ -19,13 +21,12 @@
1921
from semantic_kernel.processes.kernel_process.kernel_process_edge import KernelProcessEdge
2022
from semantic_kernel.processes.kernel_process.kernel_process_event import KernelProcessEvent
2123
from semantic_kernel.processes.kernel_process.kernel_process_message_channel import KernelProcessMessageChannel
22-
from semantic_kernel.processes.kernel_process.kernel_process_step import KernelProcessStep
2324
from semantic_kernel.processes.kernel_process.kernel_process_step_info import KernelProcessStepInfo
2425
from semantic_kernel.processes.kernel_process.kernel_process_step_state import KernelProcessStepState
2526
from semantic_kernel.processes.local_runtime.local_event import LocalEvent
2627
from semantic_kernel.processes.local_runtime.local_message import LocalMessage
2728
from semantic_kernel.processes.process_types import get_generic_state_type
28-
from semantic_kernel.processes.step_utils import find_input_channels
29+
from semantic_kernel.processes.step_utils import find_input_channels, get_fully_qualified_name
2930
from semantic_kernel.utils.experimental_decorator import experimental_class
3031

3132
logger: logging.Logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class LocalStep(KernelProcessMessageChannel, KernelBaseModel):
4748
output_edges: dict[str, list[KernelProcessEdge]] = Field(default_factory=dict)
4849
parent_process_id: str | None = None
4950
init_lock: asyncio.Lock = Field(default_factory=asyncio.Lock, exclude=True)
51+
factories: dict[str, Callable]
5052

5153
@model_validator(mode="before")
5254
@classmethod
@@ -185,8 +187,16 @@ async def initialize_step(self):
185187
"""Initializes the step."""
186188
# Instantiate an instance of the inner step object
187189
step_cls = self.step_info.inner_step_type
188-
189-
step_instance: KernelProcessStep = step_cls() # type: ignore
190+
factory = (
191+
self.factories.get(get_fully_qualified_name(self.step_info.inner_step_type)) if self.factories else None
192+
)
193+
if factory:
194+
step_instance = factory()
195+
if isawaitable(step_instance):
196+
step_instance = await step_instance
197+
step_cls = type(step_instance)
198+
else:
199+
step_instance = step_cls() # type: ignore
190200

191201
kernel_plugin = self.kernel.add_plugin(
192202
step_instance, self.step_info.state.name if self.step_info.state else "default_name"

‎python/semantic_kernel/processes/process_builder.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
import inspect
5+
from collections.abc import Callable
56
from copy import copy
67
from enum import Enum
78
from typing import TYPE_CHECKING
@@ -17,6 +18,7 @@
1718
from semantic_kernel.processes.process_step_builder import ProcessStepBuilder
1819
from semantic_kernel.processes.process_step_edge_builder import ProcessStepEdgeBuilder
1920
from semantic_kernel.processes.process_types import TState, TStep
21+
from semantic_kernel.processes.step_utils import get_fully_qualified_name
2022
from semantic_kernel.utils.experimental_decorator import experimental_class
2123

2224
if TYPE_CHECKING:
@@ -32,20 +34,38 @@ class ProcessBuilder(ProcessStepBuilder):
3234
has_parent_process: bool = False
3335

3436
steps: list["ProcessStepBuilder"] = Field(default_factory=list)
37+
factories: dict[str, Callable] = Field(default_factory=dict)
3538

3639
def add_step(
3740
self,
3841
step_type: type[TStep],
3942
name: str | None = None,
4043
initial_state: TState | None = None,
44+
factory_function: Callable | None = None,
4145
**kwargs,
4246
) -> ProcessStepBuilder[TState, TStep]:
43-
"""Register a step type with optional constructor arguments."""
47+
"""Register a step type with optional constructor arguments.
48+
49+
Args:
50+
step_type: The step type.
51+
name: The name of the step. Defaults to None.
52+
initial_state: The initial state of the step. Defaults to None.
53+
factory_function: The factory function. Allows for a callable that is used to create the step instance
54+
that may have complex dependencies that cannot be JSON serialized or deserialized. Defaults to None.
55+
kwargs: Additional keyword arguments.
56+
57+
Returns:
58+
The process step builder.
59+
"""
4460
if not inspect.isclass(step_type):
4561
raise ProcessInvalidConfigurationException(
4662
f"Expected a class type, but got an instance of {type(step_type).__name__}"
4763
)
4864

65+
if factory_function:
66+
fq_name = get_fully_qualified_name(step_type)
67+
self.factories[fq_name] = factory_function
68+
4969
name = name or step_type.__name__
5070
process_step_builder = ProcessStepBuilder(type=step_type, name=name, initial_state=initial_state, **kwargs)
5171
self.steps.append(process_step_builder)
@@ -117,4 +137,4 @@ def build(self) -> "KernelProcess":
117137
built_edges = {key: [edge.build() for edge in edges] for key, edges in self.edges.items()}
118138
built_steps = [step.build_step() for step in self.steps]
119139
process_state = KernelProcessState(name=self.name, id=self.id if self.has_parent_process else None)
120-
return KernelProcess(state=process_state, steps=built_steps, edges=built_edges)
140+
return KernelProcess(state=process_state, steps=built_steps, edges=built_edges, factories=self.factories)

‎python/tests/unit/processes/dapr_runtime/test_process_actor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def actor_context():
2727
actor_client=MagicMock(),
2828
)
2929
kernel_mock = MagicMock()
30-
actor = ProcessActor(runtime_context, actor_id, kernel=kernel_mock)
30+
actor = ProcessActor(runtime_context, actor_id, kernel=kernel_mock, factories={})
3131

3232
actor._state_manager = AsyncMock()
3333
actor._state_manager.try_add_state = AsyncMock(return_value=True)

‎python/tests/unit/processes/dapr_runtime/test_step_actor.py

+106-1
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,28 @@
66
import pytest
77
from dapr.actor import ActorId
88

9+
from semantic_kernel.processes.dapr_runtime.actors.actor_state_key import ActorStateKeys
910
from semantic_kernel.processes.dapr_runtime.actors.step_actor import StepActor
1011
from semantic_kernel.processes.dapr_runtime.dapr_step_info import DaprStepInfo
1112
from semantic_kernel.processes.kernel_process.kernel_process_step_state import KernelProcessStepState
1213
from semantic_kernel.processes.process_message import ProcessMessage
1314

1415

16+
class FakeStep:
17+
async def activate(self, state):
18+
self.activated_state = state
19+
20+
21+
class FakeState:
22+
pass
23+
24+
1525
@pytest.fixture
1626
def actor_context():
1727
ctx = MagicMock()
1828
actor_id = ActorId("test_actor")
1929
kernel = MagicMock()
20-
return StepActor(ctx, actor_id, kernel)
30+
return StepActor(ctx, actor_id, kernel, factories={})
2131

2232

2333
async def test_initialize_step(actor_context):
@@ -97,3 +107,98 @@ async def test_process_incoming_messages(actor_context):
97107
expected_messages = []
98108
expected_messages = [json.dumps(msg.model_dump()) for msg in list(actor_context.incoming_messages.queue)]
99109
mock_try_add_state.assert_any_call("incomingMessagesState", expected_messages)
110+
111+
112+
async def test_activate_step_with_factory_creates_state(actor_context):
113+
fake_step_instance = FakeStep()
114+
fake_step_instance.activate = AsyncMock(side_effect=fake_step_instance.activate)
115+
116+
fake_plugin = MagicMock()
117+
fake_plugin.functions = {"test_function": lambda x: x}
118+
119+
with (
120+
patch(
121+
"semantic_kernel.processes.dapr_runtime.actors.step_actor.get_generic_state_type",
122+
return_value=FakeState,
123+
),
124+
patch(
125+
"semantic_kernel.processes.dapr_runtime.actors.step_actor.get_fully_qualified_name",
126+
return_value="FakeStateFullyQualified",
127+
),
128+
patch(
129+
"semantic_kernel.processes.dapr_runtime.actors.step_actor.find_input_channels",
130+
return_value={"channel": {"input": "value"}},
131+
),
132+
):
133+
actor_context.factories = {"FakeStep": lambda: fake_step_instance}
134+
actor_context.inner_step_type = "FakeStep"
135+
actor_context.step_info = DaprStepInfo(
136+
state=KernelProcessStepState(name="default_name", id="step_123"),
137+
inner_step_python_type="FakeStep",
138+
edges={},
139+
)
140+
actor_context.kernel.add_plugin = MagicMock(return_value=fake_plugin)
141+
actor_context._state_manager.try_add_state = AsyncMock()
142+
actor_context._state_manager.save_state = AsyncMock()
143+
144+
await actor_context.activate_step()
145+
146+
actor_context.kernel.add_plugin.assert_called_once_with(fake_step_instance, "default_name")
147+
assert actor_context.functions == fake_plugin.functions
148+
assert actor_context.initial_inputs == {"channel": {"input": "value"}}
149+
assert actor_context.inputs == {"channel": {"input": "value"}}
150+
assert actor_context.step_state is not None
151+
assert isinstance(actor_context.step_state.state, FakeState)
152+
fake_step_instance.activate.assert_awaited_once_with(actor_context.step_state)
153+
154+
155+
async def test_activate_step_with_factory_uses_existing_state(actor_context):
156+
fake_step_instance = FakeStep()
157+
fake_step_instance.activate = AsyncMock(side_effect=fake_step_instance.activate)
158+
159+
fake_plugin = MagicMock()
160+
fake_plugin.functions = {"test_function": lambda x: x}
161+
162+
pre_existing_state = KernelProcessStepState(name="ExistingState", id="ExistingState", state=None)
163+
164+
with (
165+
patch.object(
166+
KernelProcessStepState,
167+
"model_dump",
168+
return_value={"name": "ExistingState", "id": "ExistingState", "state": None},
169+
),
170+
patch(
171+
"semantic_kernel.processes.dapr_runtime.actors.step_actor.get_generic_state_type",
172+
return_value=FakeState,
173+
),
174+
patch(
175+
"semantic_kernel.processes.dapr_runtime.actors.step_actor.get_fully_qualified_name",
176+
return_value="FakeStateFullyQualified",
177+
),
178+
patch(
179+
"semantic_kernel.processes.dapr_runtime.actors.step_actor.find_input_channels",
180+
return_value={"channel": {"input": "value"}},
181+
),
182+
):
183+
actor_context.factories = {"FakeStep": lambda: fake_step_instance}
184+
actor_context.inner_step_type = "FakeStep"
185+
actor_context.step_info = DaprStepInfo(state=pre_existing_state, inner_step_python_type="FakeStep", edges={})
186+
actor_context.kernel.add_plugin = MagicMock(return_value=fake_plugin)
187+
actor_context._state_manager.try_add_state = AsyncMock()
188+
actor_context._state_manager.save_state = AsyncMock()
189+
190+
await actor_context.activate_step()
191+
192+
actor_context.kernel.add_plugin.assert_called_once_with(fake_step_instance, pre_existing_state.name)
193+
assert actor_context.functions == fake_plugin.functions
194+
assert actor_context.initial_inputs == {"channel": {"input": "value"}}
195+
assert actor_context.inputs == {"channel": {"input": "value"}}
196+
actor_context._state_manager.try_add_state.assert_any_await(
197+
ActorStateKeys.StepStateType.value, "FakeStateFullyQualified"
198+
)
199+
actor_context._state_manager.try_add_state.assert_any_await(
200+
ActorStateKeys.StepStateJson.value, json.dumps(pre_existing_state.model_dump())
201+
)
202+
actor_context._state_manager.save_state.assert_awaited_once()
203+
assert isinstance(actor_context.step_state.state, FakeState)
204+
fake_step_instance.activate.assert_awaited_once_with(actor_context.step_state)

‎python/tests/unit/processes/local_runtime/test_local_kernel_process_context.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def mock_process():
2626
process = MagicMock(spec=KernelProcess)
2727
process.state = state
2828
process.steps = [step_info]
29+
process.factories = {}
2930
return process
3031

3132

‎python/tests/unit/processes/local_runtime/test_local_process.py

+1
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def test_initialize_process(mock_process, mock_kernel, build_model):
271271
mock_local_step_init.assert_called_with(
272272
step_info=step_info,
273273
kernel=mock_kernel,
274+
factories={},
274275
parent_process_id=local_process.id,
275276
)
276277

0 commit comments

Comments
 (0)
Please sign in to comment.