diff --git a/.gitignore b/.gitignore index 35ccb1f2f9..da8902f971 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,11 @@ local.properties # PyDev specific (Python IDE for Eclipse) *.pydevproject +## Python specific +*.pyc +**.egg-info +__pycache__ + # CDT-specific (C/C++ Development Tooling) .cproject diff --git a/core/src/main/java/org/lflang/federated/extensions/PythonExtension.java b/core/src/main/java/org/lflang/federated/extensions/PythonExtension.java index b387d4966c..13c6a46fa2 100644 --- a/core/src/main/java/org/lflang/federated/extensions/PythonExtension.java +++ b/core/src/main/java/org/lflang/federated/extensions/PythonExtension.java @@ -34,6 +34,7 @@ import org.lflang.federated.generator.FederateInstance; import org.lflang.federated.generator.FederationFileConfig; import org.lflang.federated.launcher.RtiConfig; +import org.lflang.federated.serialization.FedCustomPythonSerialization; import org.lflang.federated.serialization.FedNativePythonSerialization; import org.lflang.federated.serialization.FedSerialization; import org.lflang.federated.serialization.SupportedSerializers; @@ -65,6 +66,12 @@ protected String generateSerializationIncludes( FedNativePythonSerialization pickler = new FedNativePythonSerialization(); code.pr(pickler.generatePreambleForSupport().toString()); } + case CUSTOM: + { + FedCustomPythonSerialization serializer = + new FedCustomPythonSerialization(serialization.getSerializer()); + code.pr(serializer.generatePreambleForSupport().toString()); + } case PROTO: { // Nothing needs to be done @@ -144,6 +151,21 @@ protected void deserialize( result.pr("lf_set_destructor(" + receiveRef + ", python_count_decrement);\n"); result.pr("lf_set_token(" + receiveRef + ", token);\n"); } + case CUSTOM -> { + value = action.getName(); + FedCustomPythonSerialization serializer = + new FedCustomPythonSerialization(connection.getSerializer().getSerializer()); + result.pr(serializer.generateNetworkDeserializerCode(value, null)); + // Use token to set ports and destructor + result.pr( + "lf_token_t* token = lf_new_token((void*)" + + receiveRef + + ", " + + FedSerialization.deserializedVarName + + ", 1);\n"); + result.pr("lf_set_destructor(" + receiveRef + ", python_count_decrement);\n"); + result.pr("lf_set_token(" + receiveRef + ", token);\n"); + } case PROTO -> throw new UnsupportedOperationException("Protobuf serialization is not supported yet."); case ROS2 -> @@ -174,6 +196,18 @@ protected void serializeAndSend( // Decrease the reference count for serialized_pyobject result.pr("Py_XDECREF(serialized_pyobject);\n"); } + case CUSTOM -> { + var variableToSerialize = sendRef + "->value"; + FedCustomPythonSerialization serializer = + new FedCustomPythonSerialization(connection.getSerializer().getSerializer()); + lengthExpression = serializer.serializedBufferLength(); + pointerExpression = serializer.serializedBufferVar(); + result.pr(serializer.generateNetworkSerializerCode(variableToSerialize, null)); + result.pr("size_t _lf_message_length = " + lengthExpression + ";"); + result.pr(sendingFunction + "(" + commonArgs + ", " + pointerExpression + ");\n"); + // Decrease the reference count for serialized_pyobject + result.pr("Py_XDECREF(serialized_pyobject);\n"); + } case PROTO -> throw new UnsupportedOperationException("Protobuf serialization is not supported yet."); case ROS2 -> diff --git a/core/src/main/java/org/lflang/federated/generator/FedASTUtils.java b/core/src/main/java/org/lflang/federated/generator/FedASTUtils.java index 5315e8562a..5f74e5c9ec 100644 --- a/core/src/main/java/org/lflang/federated/generator/FedASTUtils.java +++ b/core/src/main/java/org/lflang/federated/generator/FedASTUtils.java @@ -67,9 +67,11 @@ import org.lflang.lf.ParameterReference; import org.lflang.lf.Reaction; import org.lflang.lf.Reactor; +import org.lflang.lf.StateVar; import org.lflang.lf.Type; import org.lflang.lf.VarRef; import org.lflang.lf.Variable; +import org.lflang.target.Target; import org.lflang.target.property.type.CoordinationModeType.CoordinationMode; /** @@ -258,6 +260,12 @@ private static void addNetworkReceiverReactor( receiver.getReactions().add(networkReceiverReaction); receiver.getOutputs().add(out); + if (connection.dstFederate.targetConfig.target == Target.Python) { + StateVar serializer = factory.createStateVar(); + serializer.setName("custom_serializer"); + receiver.getStateVars().add(serializer); + } + addLevelAttribute( networkInstance, connection.getDestinationPortInstance(), @@ -682,6 +690,12 @@ private static Reactor getNetworkSenderReactor( in.setWidthSpec(widthSpec); inRef.setVariable(in); + if (connection.getSrcFederate().targetConfig.target == Target.Python) { + StateVar serializer = factory.createStateVar(); + serializer.setName("custom_serializer"); + sender.getStateVars().add(serializer); + } + destRef.setContainer(connection.getDestinationPortInstance().getParent().getDefinition()); destRef.setVariable(connection.getDestinationPortInstance().getDefinition()); diff --git a/core/src/main/java/org/lflang/federated/generator/FedUtils.java b/core/src/main/java/org/lflang/federated/generator/FedUtils.java index 2970be7c2a..fe755ce99d 100644 --- a/core/src/main/java/org/lflang/federated/generator/FedUtils.java +++ b/core/src/main/java/org/lflang/federated/generator/FedUtils.java @@ -14,7 +14,18 @@ public static SupportedSerializers getSerializer( // Get the serializer SupportedSerializers serializer = SupportedSerializers.NATIVE; if (connection.getSerializer() != null) { - serializer = SupportedSerializers.valueOf(connection.getSerializer().getType().toUpperCase()); + boolean isCustomSerializer = true; + for (SupportedSerializers method : SupportedSerializers.values()) { + if (method.name().equalsIgnoreCase(connection.getSerializer().getType())) { + serializer = + SupportedSerializers.valueOf(connection.getSerializer().getType().toUpperCase()); + isCustomSerializer = false; + break; + } + } + if (isCustomSerializer) { + serializer = SupportedSerializers.fromCustomString(connection.getSerializer().getType()); + } } // Add it to the list of enabled serializers for the source and destination federates srcFederate.enabledSerializers.add(serializer); diff --git a/core/src/main/java/org/lflang/federated/serialization/FedCustomPythonSerialization.java b/core/src/main/java/org/lflang/federated/serialization/FedCustomPythonSerialization.java new file mode 100644 index 0000000000..42a9de68eb --- /dev/null +++ b/core/src/main/java/org/lflang/federated/serialization/FedCustomPythonSerialization.java @@ -0,0 +1,118 @@ +/************* + * Copyright (c) 2024, The University of California at Berkeley. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + ***************/ + +package org.lflang.federated.serialization; + +import org.lflang.generator.GeneratorBase; +import org.lflang.target.Target; + +/** + * Enables support for custom serialization. + * + * @author Shulu Li + */ +public class FedCustomPythonSerialization implements FedSerialization { + + String customSerializerPackage; + + public FedCustomPythonSerialization(String customSerializerPackage) { + this.customSerializerPackage = customSerializerPackage; + } + + @Override + public boolean isCompatible(GeneratorBase generator) { + if (generator.getTarget() != Target.Python) { + throw new UnsupportedOperationException( + "The FedCustomPythonSerialization class only supports the Python target."); + } + return true; + } + + @Override + public String serializedBufferLength() { + return serializedVarName + ".len"; + } + + @Override + public String serializedBufferVar() { + return serializedVarName + ".buf"; + } + + private String initializeCustomSerializer() { + return "if (self->custom_serializer == NULL) \n" + + "self->custom_serializer = load_serializer(\"%s\");\n".formatted(customSerializerPackage); + } + + @Override + public StringBuilder generateNetworkSerializerCode(String varName, String originalType) { + StringBuilder serializerCode = new StringBuilder(); + // Initialize self->custom_serializer if null + serializerCode.append(this.initializeCustomSerializer()); + // Serialize PyObject to bytes using custom serializer + serializerCode.append( + """ + PyObject *serialized_pyobject = custom_serialize(msg[0]->value, self->custom_serializer); + Py_buffer %s; + int returnValue = PyBytes_AsStringAndSize(serialized_pyobject, (char**)&%s.buf, &%s.len); + if (returnValue == -1) { + if (PyErr_Occurred()) PyErr_Print(); + lf_print_error_and_exit("Could not serialize %s."); + } + """ + .formatted(serializedVarName, serializedVarName, serializedVarName, serializedVarName)); + return serializerCode; + } + + @Override + public StringBuilder generateNetworkDeserializerCode(String varName, String targetType) { + StringBuilder deserializerCode = new StringBuilder(); + // Initialize self->custom_serializer if null + deserializerCode.append(this.initializeCustomSerializer()); + // Deserialize network message to a PyObject using custom serializer + deserializerCode.append( + """ +PyObject *message_byte_array = PyBytes_FromStringAndSize((char*) %s->token->value, %s->token->length); +PyObject *%s = custom_deserialize(message_byte_array, self->custom_serializer); +if ( %s == NULL ) { + if (PyErr_Occurred()) PyErr_Print(); + lf_print_error_and_exit("Could not serialize %s."); +} +""" + .formatted( + varName, varName, deserializedVarName, deserializedVarName, deserializedVarName)); + return deserializerCode; + } + + @Override + public StringBuilder generatePreambleForSupport() { + return new StringBuilder(); + } + + @Override + public StringBuilder generateCompilerExtensionForSupport() { + return new StringBuilder(); + } +} diff --git a/core/src/main/java/org/lflang/federated/serialization/SupportedSerializers.java b/core/src/main/java/org/lflang/federated/serialization/SupportedSerializers.java index 4767f3f4d7..094a6b9c69 100644 --- a/core/src/main/java/org/lflang/federated/serialization/SupportedSerializers.java +++ b/core/src/main/java/org/lflang/federated/serialization/SupportedSerializers.java @@ -8,7 +8,8 @@ public enum SupportedSerializers { NATIVE("native"), // Dangerous: just copies the memory layout of the sender ROS2("ros2"), - PROTO("proto"); + PROTO("proto"), + CUSTOM(""); private String serializer; @@ -23,4 +24,9 @@ public String getSerializer() { public void setSerializer(String serializer) { this.serializer = serializer; } + + public static SupportedSerializers fromCustomString(String serializer) { + CUSTOM.setSerializer(serializer); + return CUSTOM; + } } diff --git a/core/src/main/java/org/lflang/validation/LFValidator.java b/core/src/main/java/org/lflang/validation/LFValidator.java index 7f08232c88..0e245be2db 100644 --- a/core/src/main/java/org/lflang/validation/LFValidator.java +++ b/core/src/main/java/org/lflang/validation/LFValidator.java @@ -1006,6 +1006,10 @@ public void checkReactor(Reactor reactor) throws IOException { @Check(CheckType.FAST) public void checkSerializer(Serializer serializer) { boolean isValidSerializer = false; + if (this.target == Target.Python) { + // Allow any serializer package name in python + isValidSerializer = true; + } for (SupportedSerializers method : SupportedSerializers.values()) { if (method.name().equalsIgnoreCase(serializer.getType())) { isValidSerializer = true; diff --git a/core/src/main/resources/lib/c/reactor-c b/core/src/main/resources/lib/c/reactor-c index 227c86d9d2..76ec485017 160000 --- a/core/src/main/resources/lib/c/reactor-c +++ b/core/src/main/resources/lib/c/reactor-c @@ -1 +1 @@ -Subproject commit 227c86d9d249e971d10a4dab1ff9b37e26cad2b5 +Subproject commit 76ec485017ae682a7f00aa9bc5609410c638c703 diff --git a/test/Python/src/serialization/CustomSerializer.lf b/test/Python/src/serialization/CustomSerializer.lf new file mode 100644 index 0000000000..adcec6a350 --- /dev/null +++ b/test/Python/src/serialization/CustomSerializer.lf @@ -0,0 +1,70 @@ +# To run this test, the `pickle_serializer` package must be installed in the Python environment. +# Run `pip3 install -e ./test/Python/src/serialization/pickle_serializer` in the project root directory to install the pickle_serializer. +target Python { + coordination: decentralized +} + +preamble {= + os.system("pip install ./src/serialization/pickle_serializer/ --user") +=} + +reactor Client { + input server_message + output client_message + state count + + reaction(startup) {= + self.count = 0 + print("Client Startup!") + =} + + reaction(server_message) -> client_message {= + val = server_message.value + if val != self.count: + print("client: out of order", val, self.count) + exit(1) + self.count+=2 + val += 1 + print("client:", val) + if val==23: + print("client done") + request_stop() + if val<23: + client_message.set(val) + =} +} + +reactor Server { + output server_message + input client_message + state count + + reaction(startup) -> server_message {= + self.count = 1 + print("Server Startup!") + server_message.set(0) + =} + + reaction(client_message) -> server_message {= + val = client_message.value + if val != self.count: + print("server: out of order", val, self.count) + exit(1) + self.count+=2 + val += 1 + print("server:", val) + if val==22: + print("server done") + server_message.set(val) + request_stop() + if val<22: + server_message.set(val) + =} +} + +federated reactor { + client = new Client() + server = new Server() + server.server_message -> client.server_message after 100 ms serializer "pickle_serializer" + client.client_message -> server.client_message serializer "pickle_serializer" +} diff --git a/test/Python/src/serialization/pickle_serializer/pickle_serializer/__init__.py b/test/Python/src/serialization/pickle_serializer/pickle_serializer/__init__.py new file mode 100644 index 0000000000..0bfb262658 --- /dev/null +++ b/test/Python/src/serialization/pickle_serializer/pickle_serializer/__init__.py @@ -0,0 +1 @@ +from .serializer import Serializer \ No newline at end of file diff --git a/test/Python/src/serialization/pickle_serializer/pickle_serializer/serializer.py b/test/Python/src/serialization/pickle_serializer/pickle_serializer/serializer.py new file mode 100644 index 0000000000..8a9ce82ae1 --- /dev/null +++ b/test/Python/src/serialization/pickle_serializer/pickle_serializer/serializer.py @@ -0,0 +1,7 @@ +import pickle + +class Serializer(): + def serialize(self, obj)->bytes: + return pickle.dumps(obj) + def deserialize(self, message:bytes): + return pickle.loads(message) \ No newline at end of file diff --git a/test/Python/src/serialization/pickle_serializer/setup.py b/test/Python/src/serialization/pickle_serializer/setup.py new file mode 100644 index 0000000000..f08a0ca8c1 --- /dev/null +++ b/test/Python/src/serialization/pickle_serializer/setup.py @@ -0,0 +1,8 @@ +from setuptools import setup, find_packages + +setup( + name='pickle_serializer', + version='0.1', + packages=find_packages(), + install_requires=[], +)