Skip to content

Commit

Permalink
Merge pull request lf-lang#2375 from lf-lang/custom-serializer
Browse files Browse the repository at this point in the history
Custom Serialization in Python Target
  • Loading branch information
lhstrh authored Jul 31, 2024
2 parents d07547c + 496b345 commit a088c1a
Show file tree
Hide file tree
Showing 12 changed files with 281 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ local.properties
# PyDev specific (Python IDE for Eclipse)
*.pydevproject

## Python specific
*.pyc
**.egg-info
__pycache__

# CDT-specific (C/C++ Development Tooling)
.cproject

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.lflang.federated.generator.FederateInstance;
import org.lflang.federated.generator.FederationFileConfig;
import org.lflang.federated.launcher.RtiConfig;
import org.lflang.federated.serialization.FedCustomPythonSerialization;
import org.lflang.federated.serialization.FedNativePythonSerialization;
import org.lflang.federated.serialization.FedSerialization;
import org.lflang.federated.serialization.SupportedSerializers;
Expand Down Expand Up @@ -65,6 +66,12 @@ protected String generateSerializationIncludes(
FedNativePythonSerialization pickler = new FedNativePythonSerialization();
code.pr(pickler.generatePreambleForSupport().toString());
}
case CUSTOM:
{
FedCustomPythonSerialization serializer =
new FedCustomPythonSerialization(serialization.getSerializer());
code.pr(serializer.generatePreambleForSupport().toString());
}
case PROTO:
{
// Nothing needs to be done
Expand Down Expand Up @@ -144,6 +151,21 @@ protected void deserialize(
result.pr("lf_set_destructor(" + receiveRef + ", python_count_decrement);\n");
result.pr("lf_set_token(" + receiveRef + ", token);\n");
}
case CUSTOM -> {
value = action.getName();
FedCustomPythonSerialization serializer =
new FedCustomPythonSerialization(connection.getSerializer().getSerializer());
result.pr(serializer.generateNetworkDeserializerCode(value, null));
// Use token to set ports and destructor
result.pr(
"lf_token_t* token = lf_new_token((void*)"
+ receiveRef
+ ", "
+ FedSerialization.deserializedVarName
+ ", 1);\n");
result.pr("lf_set_destructor(" + receiveRef + ", python_count_decrement);\n");
result.pr("lf_set_token(" + receiveRef + ", token);\n");
}
case PROTO ->
throw new UnsupportedOperationException("Protobuf serialization is not supported yet.");
case ROS2 ->
Expand Down Expand Up @@ -174,6 +196,18 @@ protected void serializeAndSend(
// Decrease the reference count for serialized_pyobject
result.pr("Py_XDECREF(serialized_pyobject);\n");
}
case CUSTOM -> {
var variableToSerialize = sendRef + "->value";
FedCustomPythonSerialization serializer =
new FedCustomPythonSerialization(connection.getSerializer().getSerializer());
lengthExpression = serializer.serializedBufferLength();
pointerExpression = serializer.serializedBufferVar();
result.pr(serializer.generateNetworkSerializerCode(variableToSerialize, null));
result.pr("size_t _lf_message_length = " + lengthExpression + ";");
result.pr(sendingFunction + "(" + commonArgs + ", " + pointerExpression + ");\n");
// Decrease the reference count for serialized_pyobject
result.pr("Py_XDECREF(serialized_pyobject);\n");
}
case PROTO ->
throw new UnsupportedOperationException("Protobuf serialization is not supported yet.");
case ROS2 ->
Expand Down
14 changes: 14 additions & 0 deletions core/src/main/java/org/lflang/federated/generator/FedASTUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@
import org.lflang.lf.ParameterReference;
import org.lflang.lf.Reaction;
import org.lflang.lf.Reactor;
import org.lflang.lf.StateVar;
import org.lflang.lf.Type;
import org.lflang.lf.VarRef;
import org.lflang.lf.Variable;
import org.lflang.target.Target;
import org.lflang.target.property.type.CoordinationModeType.CoordinationMode;

/**
Expand Down Expand Up @@ -258,6 +260,12 @@ private static void addNetworkReceiverReactor(
receiver.getReactions().add(networkReceiverReaction);
receiver.getOutputs().add(out);

if (connection.dstFederate.targetConfig.target == Target.Python) {
StateVar serializer = factory.createStateVar();
serializer.setName("custom_serializer");
receiver.getStateVars().add(serializer);
}

addLevelAttribute(
networkInstance,
connection.getDestinationPortInstance(),
Expand Down Expand Up @@ -682,6 +690,12 @@ private static Reactor getNetworkSenderReactor(
in.setWidthSpec(widthSpec);
inRef.setVariable(in);

if (connection.getSrcFederate().targetConfig.target == Target.Python) {
StateVar serializer = factory.createStateVar();
serializer.setName("custom_serializer");
sender.getStateVars().add(serializer);
}

destRef.setContainer(connection.getDestinationPortInstance().getParent().getDefinition());
destRef.setVariable(connection.getDestinationPortInstance().getDefinition());

Expand Down
13 changes: 12 additions & 1 deletion core/src/main/java/org/lflang/federated/generator/FedUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,18 @@ public static SupportedSerializers getSerializer(
// Get the serializer
SupportedSerializers serializer = SupportedSerializers.NATIVE;
if (connection.getSerializer() != null) {
serializer = SupportedSerializers.valueOf(connection.getSerializer().getType().toUpperCase());
boolean isCustomSerializer = true;
for (SupportedSerializers method : SupportedSerializers.values()) {
if (method.name().equalsIgnoreCase(connection.getSerializer().getType())) {
serializer =
SupportedSerializers.valueOf(connection.getSerializer().getType().toUpperCase());
isCustomSerializer = false;
break;
}
}
if (isCustomSerializer) {
serializer = SupportedSerializers.fromCustomString(connection.getSerializer().getType());
}
}
// Add it to the list of enabled serializers for the source and destination federates
srcFederate.enabledSerializers.add(serializer);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*************
* Copyright (c) 2024, The University of California at Berkeley.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
***************/

package org.lflang.federated.serialization;

import org.lflang.generator.GeneratorBase;
import org.lflang.target.Target;

/**
* Enables support for custom serialization.
*
* @author Shulu Li
*/
public class FedCustomPythonSerialization implements FedSerialization {

String customSerializerPackage;

public FedCustomPythonSerialization(String customSerializerPackage) {
this.customSerializerPackage = customSerializerPackage;
}

@Override
public boolean isCompatible(GeneratorBase generator) {
if (generator.getTarget() != Target.Python) {
throw new UnsupportedOperationException(
"The FedCustomPythonSerialization class only supports the Python target.");
}
return true;
}

@Override
public String serializedBufferLength() {
return serializedVarName + ".len";
}

@Override
public String serializedBufferVar() {
return serializedVarName + ".buf";
}

private String initializeCustomSerializer() {
return "if (self->custom_serializer == NULL) \n"
+ "self->custom_serializer = load_serializer(\"%s\");\n".formatted(customSerializerPackage);
}

@Override
public StringBuilder generateNetworkSerializerCode(String varName, String originalType) {
StringBuilder serializerCode = new StringBuilder();
// Initialize self->custom_serializer if null
serializerCode.append(this.initializeCustomSerializer());
// Serialize PyObject to bytes using custom serializer
serializerCode.append(
"""
PyObject *serialized_pyobject = custom_serialize(msg[0]->value, self->custom_serializer);
Py_buffer %s;
int returnValue = PyBytes_AsStringAndSize(serialized_pyobject, (char**)&%s.buf, &%s.len);
if (returnValue == -1) {
if (PyErr_Occurred()) PyErr_Print();
lf_print_error_and_exit("Could not serialize %s.");
}
"""
.formatted(serializedVarName, serializedVarName, serializedVarName, serializedVarName));
return serializerCode;
}

@Override
public StringBuilder generateNetworkDeserializerCode(String varName, String targetType) {
StringBuilder deserializerCode = new StringBuilder();
// Initialize self->custom_serializer if null
deserializerCode.append(this.initializeCustomSerializer());
// Deserialize network message to a PyObject using custom serializer
deserializerCode.append(
"""
PyObject *message_byte_array = PyBytes_FromStringAndSize((char*) %s->token->value, %s->token->length);
PyObject *%s = custom_deserialize(message_byte_array, self->custom_serializer);
if ( %s == NULL ) {
if (PyErr_Occurred()) PyErr_Print();
lf_print_error_and_exit("Could not serialize %s.");
}
"""
.formatted(
varName, varName, deserializedVarName, deserializedVarName, deserializedVarName));
return deserializerCode;
}

@Override
public StringBuilder generatePreambleForSupport() {
return new StringBuilder();
}

@Override
public StringBuilder generateCompilerExtensionForSupport() {
return new StringBuilder();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
public enum SupportedSerializers {
NATIVE("native"), // Dangerous: just copies the memory layout of the sender
ROS2("ros2"),
PROTO("proto");
PROTO("proto"),
CUSTOM("");

private String serializer;

Expand All @@ -23,4 +24,9 @@ public String getSerializer() {
public void setSerializer(String serializer) {
this.serializer = serializer;
}

public static SupportedSerializers fromCustomString(String serializer) {
CUSTOM.setSerializer(serializer);
return CUSTOM;
}
}
4 changes: 4 additions & 0 deletions core/src/main/java/org/lflang/validation/LFValidator.java
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,10 @@ public void checkReactor(Reactor reactor) throws IOException {
@Check(CheckType.FAST)
public void checkSerializer(Serializer serializer) {
boolean isValidSerializer = false;
if (this.target == Target.Python) {
// Allow any serializer package name in python
isValidSerializer = true;
}
for (SupportedSerializers method : SupportedSerializers.values()) {
if (method.name().equalsIgnoreCase(serializer.getType())) {
isValidSerializer = true;
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/resources/lib/c/reactor-c
70 changes: 70 additions & 0 deletions test/Python/src/serialization/CustomSerializer.lf
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# To run this test, the `pickle_serializer` package must be installed in the Python environment.
# Run `pip3 install -e ./test/Python/src/serialization/pickle_serializer` in the project root directory to install the pickle_serializer.
target Python {
coordination: decentralized
}

preamble {=
os.system("pip install ./src/serialization/pickle_serializer/ --user")
=}

reactor Client {
input server_message
output client_message
state count

reaction(startup) {=
self.count = 0
print("Client Startup!")
=}

reaction(server_message) -> client_message {=
val = server_message.value
if val != self.count:
print("client: out of order", val, self.count)
exit(1)
self.count+=2
val += 1
print("client:", val)
if val==23:
print("client done")
request_stop()
if val<23:
client_message.set(val)
=}
}

reactor Server {
output server_message
input client_message
state count

reaction(startup) -> server_message {=
self.count = 1
print("Server Startup!")
server_message.set(0)
=}

reaction(client_message) -> server_message {=
val = client_message.value
if val != self.count:
print("server: out of order", val, self.count)
exit(1)
self.count+=2
val += 1
print("server:", val)
if val==22:
print("server done")
server_message.set(val)
request_stop()
if val<22:
server_message.set(val)
=}
}

federated reactor {
client = new Client()
server = new Server()
server.server_message -> client.server_message after 100 ms serializer "pickle_serializer"
client.client_message -> server.client_message serializer "pickle_serializer"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .serializer import Serializer
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pickle

class Serializer():
def serialize(self, obj)->bytes:
return pickle.dumps(obj)
def deserialize(self, message:bytes):
return pickle.loads(message)
8 changes: 8 additions & 0 deletions test/Python/src/serialization/pickle_serializer/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from setuptools import setup, find_packages

setup(
name='pickle_serializer',
version='0.1',
packages=find_packages(),
install_requires=[],
)

0 comments on commit a088c1a

Please sign in to comment.