diff --git a/compilers/concrete-compiler/compiler/lib/Common/Transformers.cpp b/compilers/concrete-compiler/compiler/lib/Common/Transformers.cpp index 4e652248e7..1fc31d04bb 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Transformers.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Transformers.cpp @@ -195,30 +195,45 @@ updateGateInfoAccordingValue(Message &gate, } auto gateCiphertext = gateTypeInfo.getLweCiphertext(); auto gateCompression = gateCiphertext.getCompression(); - if (gateCompression == concreteprotocol::Compression::SEED) { - auto valueReader = value.asReader(); - auto valueTypeInfo = valueReader.getTypeInfo(); - if (!valueTypeInfo.hasLweCiphertext()) { - return gate; - } - auto valueCiphertext = valueTypeInfo.getLweCiphertext(); - auto valueCompression = valueCiphertext.getCompression(); - if (valueCompression == concreteprotocol::Compression::NONE) { - // If the compression of transportValue is none and the gateInfo have - // compression we update the gateInfo to allow uncompressed transportValue - auto gateBuilder = gate.asBuilder(); - gateBuilder.getTypeInfo().getLweCiphertext().setCompression( - concreteprotocol::Compression::NONE); - auto gateDimensions = gateBuilder.getRawInfo().getShape().getDimensions(); - auto lweSize = gateCiphertext.getEncryption().getLweDimension() + 1; - gateDimensions.set(gateDimensions.size() - 1, lweSize); - auto concreteShapeDimensions = gateBuilder.getTypeInfo() - .getLweCiphertext() - .getConcreteShape() - .getDimensions(); - concreteShapeDimensions.set(concreteShapeDimensions.size() - 1, lweSize); - return gateBuilder.asReader(); - } + auto valueReader = value.asReader(); + auto valueTypeInfo = valueReader.getTypeInfo(); + if (!valueTypeInfo.hasLweCiphertext()) { + return gate; + } + auto valueCiphertext = valueTypeInfo.getLweCiphertext(); + auto valueCompression = valueCiphertext.getCompression(); + if (gateCompression == concreteprotocol::Compression::SEED && + valueCompression == concreteprotocol::Compression::NONE) { + // If the compression of transportValue is none and the gateInfo have + // compression we update the gateInfo to allow uncompressed transportValue + auto gateBuilder = gate.asBuilder(); + gateBuilder.getTypeInfo().getLweCiphertext().setCompression( + concreteprotocol::Compression::NONE); + auto gateDimensions = gateBuilder.getRawInfo().getShape().getDimensions(); + auto lweSize = gateCiphertext.getEncryption().getLweDimension() + 1; + gateDimensions.set(gateDimensions.size() - 1, lweSize); + auto concreteShapeDimensions = gateBuilder.getTypeInfo() + .getLweCiphertext() + .getConcreteShape() + .getDimensions(); + concreteShapeDimensions.set(concreteShapeDimensions.size() - 1, lweSize); + return gateBuilder.asReader(); + } + if (gateCompression == concreteprotocol::Compression::NONE && + valueCompression == concreteprotocol::Compression::SEED) { + // If the compression of transportValue is none and the gateInfo have + // compression we update the gateInfo to allow uncompressed transportValue + auto gateBuilder = gate.asBuilder(); + gateBuilder.getTypeInfo().getLweCiphertext().setCompression( + concreteprotocol::Compression::SEED); + auto gateDimensions = gateBuilder.getRawInfo().getShape().getDimensions(); + gateDimensions.set(gateDimensions.size() - 1, 3); + auto concreteShapeDimensions = gateBuilder.getTypeInfo() + .getLweCiphertext() + .getConcreteShape() + .getDimensions(); + concreteShapeDimensions.set(concreteShapeDimensions.size() - 1, 3); + return gateBuilder.asReader(); } return gate; } @@ -902,6 +917,23 @@ Result getSeededLweCiphertextDecompressionTransformer( }; } +Result getDecompressionTransformer( + const Message &info) { + + return [=](TransportValue transportVal) -> Result { + auto value = Value::fromRawTransportValue(transportVal); + auto compression = transportVal.asReader() + .getTypeInfo() + .getLweCiphertext() + .getCompression(); + if (compression == concreteprotocol::Compression::SEED) { + OUTCOME_TRY(auto d, getSeededLweCiphertextDecompressionTransformer(info)); + return d(value); + } + return value; + }; +} + Result TransformerFactory::getLweCiphertextArgTransformer( Message gateInfo, bool useSimulation) { if (!gateInfo.asReader().getTypeInfo().hasLweCiphertext()) { @@ -910,20 +942,9 @@ Result TransformerFactory::getLweCiphertextArgTransformer( } /// Generating the decompression transformer. - Transformer decompressionTransformer; auto lweCiphertextInfo = gateInfo.asReader().getTypeInfo().getLweCiphertext(); - auto compression = lweCiphertextInfo.getCompression(); - if (compression == concreteprotocol::Compression::NONE || useSimulation) { - OUTCOME_TRY(decompressionTransformer, getNoneDecompressionTransformer()); - } else if (compression == concreteprotocol::Compression::SEED) { - OUTCOME_TRY(decompressionTransformer, - getSeededLweCiphertextDecompressionTransformer( - lweCiphertextInfo.getEncryption())); - } else { - return StringError( - "Only none compression is currently supported for lwe ciphertext " - "currently."); - } + OUTCOME_TRY(auto decompressionTransformer, + getDecompressionTransformer(lweCiphertextInfo.getEncryption())); // Generating the verifier. TransportValueVerifier verify; @@ -935,13 +956,7 @@ Result TransformerFactory::getLweCiphertextArgTransformer( return [=](TransportValue transportVal) -> Result { OUTCOME_TRYV(verify(transportVal)); - auto value = Value::fromRawTransportValue(transportVal); - if (transportVal.asReader() - .getTypeInfo() - .getLweCiphertext() - .getCompression() == concreteprotocol::Compression::NONE) - return value; - return decompressionTransformer(value); + return decompressionTransformer(transportVal); }; } @@ -1001,15 +1016,10 @@ Result TransformerFactory::getLweCiphertextOutputTransformer( } /// Generating the decompression transformer. - Transformer decompressionTransformer; - if (gateInfo.asReader().getTypeInfo().getLweCiphertext().getCompression() == - concreteprotocol::Compression::NONE) { - OUTCOME_TRY(decompressionTransformer, getNoneDecompressionTransformer()); - } else { - return StringError( - "Only none compression is currently supported for lwe ciphertext " - "currently."); - } + auto encryptionInfo = + gateInfo.asReader().getTypeInfo().getLweCiphertext().getEncryption(); + OUTCOME_TRY(auto decompressionTransformer, + getDecompressionTransformer(encryptionInfo)); /// Generating the decryption transformer. Transformer decryptionTransformer; @@ -1056,8 +1066,8 @@ Result TransformerFactory::getLweCiphertextOutputTransformer( return [=](TransportValue transportVal) -> Result { OUTCOME_TRYV(verify(transportVal)); - return decodingTransformer(decryptionTransformer( - decompressionTransformer(Value::fromRawTransportValue(transportVal)))); + OUTCOME_TRY(auto value, decompressionTransformer(transportVal)); + return decodingTransformer(decryptionTransformer(value)); }; }