diff --git a/src/app/MessageDef/InvokeResponseMessage.cpp b/src/app/MessageDef/InvokeResponseMessage.cpp index cc9b89d3b15f35..f9887530c15e15 100644 --- a/src/app/MessageDef/InvokeResponseMessage.cpp +++ b/src/app/MessageDef/InvokeResponseMessage.cpp @@ -147,14 +147,32 @@ InvokeResponseIBs::Builder & InvokeResponseMessage::Builder::CreateInvokeRespons InvokeResponseMessage::Builder & InvokeResponseMessage::Builder::MoreChunkedMessages(const bool aMoreChunkedMessages) { + // If any changes are made to how we encoded MoreChunkedMessage that involves how many + // bytes are needed, a corresponding change to GetSizeForMoreChunkResponses indicating + // the new size that will be required. + // skip if error has already been set - if (mError == CHIP_NO_ERROR) + SuccessOrExit(mError); + + if (mIsMoreChunkMessageBufferReserved) { - mError = mpWriter->PutBoolean(TLV::ContextTag(Tag::kMoreChunkedMessages), aMoreChunkedMessages); + mError = GetWriter()->UnreserveBuffer(GetSizeForMoreChunkResponses()); + SuccessOrExit(mError); + mIsMoreChunkMessageBufferReserved = false; } + + mError = mpWriter->PutBoolean(TLV::ContextTag(Tag::kMoreChunkedMessages), aMoreChunkedMessages); +exit: return *this; } +CHIP_ERROR InvokeResponseMessage::Builder::ReserveSpaceForMoreChunkedMessages() +{ + ReturnErrorOnFailure(GetWriter()->ReserveBuffer(GetSizeForMoreChunkResponses())); + mIsMoreChunkMessageBufferReserved = true; + return CHIP_NO_ERROR; +} + CHIP_ERROR InvokeResponseMessage::Builder::EndOfInvokeResponseMessage() { // If any changes are made to how we end the invoke response message that involves how many @@ -177,6 +195,15 @@ CHIP_ERROR InvokeResponseMessage::Builder::EndOfInvokeResponseMessage() return GetError(); } +uint32_t InvokeResponseMessage::Builder::GetSizeForMoreChunkResponses() +{ + // MoreChunkedMessages() encodes a uint8_t with context tag 0x02. This means 1 control byte, + // 1 byte for the tag. For booleans the value is encoded in control byte. + uint32_t kEncodeMoreChunkedMessages = 1 + 1; + + return kEncodeMoreChunkedMessages; +} + uint32_t InvokeResponseMessage::Builder::GetSizeToEndInvokeResponseMessage() { // EncodeInteractionModelRevision() encodes a uint8_t with context tag 0xFF. This means 1 control byte, diff --git a/src/app/MessageDef/InvokeResponseMessage.h b/src/app/MessageDef/InvokeResponseMessage.h index ff08ab1780cc02..020a049a8eae9e 100644 --- a/src/app/MessageDef/InvokeResponseMessage.h +++ b/src/app/MessageDef/InvokeResponseMessage.h @@ -110,6 +110,13 @@ class Builder : public MessageBuilder */ InvokeResponseMessage::Builder & MoreChunkedMessages(const bool aMoreChunkedMessages); + /** + * @brief Reserved space in TLVWriter for MoreChunkedMessages + * @return CHIP_NO_ERROR upon successfully reserving space for MoreChunkedMessages + * @return other CHIP error see TLVWriter::ReserveBuffer for more details. + */ + CHIP_ERROR ReserveSpaceForMoreChunkedMessages(); + /** * @brief Mark the end of this InvokeResponseMessage * @@ -117,6 +124,13 @@ class Builder : public MessageBuilder */ CHIP_ERROR EndOfInvokeResponseMessage(); + /** + * @brief Get number of bytes required in the buffer by MoreChunkedMessages + * + * @return Expected number of bytes required in the buffer by MoreChunkedMessages() + */ + uint32_t GetSizeForMoreChunkResponses(); + /** * @brief Get number of bytes required in the buffer by EndOfInvokeResponseMessage() * @@ -126,7 +140,8 @@ class Builder : public MessageBuilder private: InvokeResponseIBs::Builder mInvokeResponses; - bool mIsEndBufferReserved = false; + bool mIsEndBufferReserved = false; + bool mIsMoreChunkMessageBufferReserved = false; }; } // namespace InvokeResponseMessage } // namespace app diff --git a/src/app/tests/TestMessageDef.cpp b/src/app/tests/TestMessageDef.cpp index 4d53ba04cc5446..10417fa2025522 100644 --- a/src/app/tests/TestMessageDef.cpp +++ b/src/app/tests/TestMessageDef.cpp @@ -2136,6 +2136,29 @@ void InvokeResponseMessageEndOfMessageReservationTest(nlTestSuite * apSuite, voi NL_TEST_ASSERT(apSuite, remainingLengthAfterInitWithReservation == remainingLengthAfterEndingInvokeResponseMessage); } +void InvokeResponseMessageReservationForEndandMoreChunkTest(nlTestSuite * apSuite, void * apContext) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + chip::System::PacketBufferTLVWriter writer; + InvokeResponseMessage::Builder invokeResponseMessageBuilder; + const uint32_t kSmallBufferSize = 100; + writer.Init(chip::System::PacketBufferHandle::New(kSmallBufferSize, /* aReservedSize = */ 0), /* useChainedBuffers = */ false); + err = invokeResponseMessageBuilder.InitWithEndBufferReserved(&writer); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + err = invokeResponseMessageBuilder.ReserveSpaceForMoreChunkedMessages(); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + uint32_t remainingLengthAllReservations = writer.GetRemainingFreeLength(); + + invokeResponseMessageBuilder.MoreChunkedMessages(/* aMoreChunkedMessages = */ true); + NL_TEST_ASSERT(apSuite, invokeResponseMessageBuilder.GetError() == CHIP_NO_ERROR); + err = invokeResponseMessageBuilder.EndOfInvokeResponseMessage(); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + uint32_t remainingLengthAfterEndingInvokeResponseMessage = writer.GetRemainingFreeLength(); + NL_TEST_ASSERT(apSuite, remainingLengthAllReservations == remainingLengthAfterEndingInvokeResponseMessage); +} + void InvokeResponsesEndOfResponseReservationTest(nlTestSuite * apSuite, void * apContext) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -2373,6 +2396,7 @@ const nlTest sTests[] = NL_TEST_DEF("InvokeRequestsEndOfRequestReservationTest", InvokeRequestsEndOfRequestReservationTest), NL_TEST_DEF("InvokeInvokeResponseMessageTest", InvokeInvokeResponseMessageTest), NL_TEST_DEF("InvokeResponseMessageEndOfMessageReservationTest", InvokeResponseMessageEndOfMessageReservationTest), + NL_TEST_DEF("InvokeResponseMessageReservationForEndandMoreChunkTest", InvokeResponseMessageReservationForEndandMoreChunkTest), NL_TEST_DEF("InvokeResponsesEndOfResponseReservationTest", InvokeResponsesEndOfResponseReservationTest), NL_TEST_DEF("ReportDataMessageTest", ReportDataMessageTest), NL_TEST_DEF("ReadRequestMessageTest", ReadRequestMessageTest),