From 718de09aa6b2059c7042f16f3feff3ec0645d554 Mon Sep 17 00:00:00 2001 From: John Preston Date: Mon, 2 Dec 2019 14:34:14 +0300 Subject: [PATCH] Handle state / resend requests separately. --- Telegram/SourceFiles/mtproto/connection.cpp | 142 ++++++++---------- Telegram/SourceFiles/mtproto/connection.h | 3 +- .../details/mtproto_serialized_request.cpp | 8 - .../details/mtproto_serialized_request.h | 1 - 4 files changed, 67 insertions(+), 87 deletions(-) diff --git a/Telegram/SourceFiles/mtproto/connection.cpp b/Telegram/SourceFiles/mtproto/connection.cpp index d180ebba66..f8b4a9bb57 100644 --- a/Telegram/SourceFiles/mtproto/connection.cpp +++ b/Telegram/SourceFiles/mtproto/connection.cpp @@ -226,9 +226,7 @@ void Connection::checkSentRequests() { const auto now = crl::now(); const auto checkAfter = kCheckSentRequestTimeout; for (const auto &[msgId, request] : haveSent) { - if (request.isStateRequest()) { - continue; - } else if (request->lastSentTime + checkAfter < now) { + if (request->lastSentTime + checkAfter < now) { // Need to check state. request->lastSentTime = now; if (_bindMsgId) { @@ -563,8 +561,6 @@ void Connection::tryToSend() { resendRequest = SerializedRequest::Serialize(MTPMsgResendReq( MTP_msg_resend_req(MTP_vector( base::take(_resendRequestData))))); - // Add to haveSent / _ackedIds, but don't add to requestMap. - resendRequest->requestId = GetNextRequestId(); } if (!_stateRequestData.empty()) { auto ids = QVector(); @@ -574,8 +570,6 @@ void Connection::tryToSend() { } stateRequest = SerializedRequest::Serialize(MTPMsgsStateReq( MTP_msgs_state_req(MTP_vector(ids)))); - // Add to haveSent / _ackedIds, but don't add to requestMap. - stateRequest->requestId = GetNextRequestId(); } if (_connection->usingHttpWait()) { httpWaitRequest = SerializedRequest::Serialize(MTPHttpWait( @@ -694,7 +688,10 @@ void Connection::tryToSend() { } else if (pingRequest) { _pingMsgId = msgId; needAnyResponse = true; - } else if (resendRequest || stateRequest) { + } else if (stateRequest || resendRequest) { + _stateAndResendRequests.emplace( + msgId, + stateRequest ? stateRequest : resendRequest); needAnyResponse = true; } @@ -793,6 +790,7 @@ void Connection::tryToSend() { pingRequest); needAnyResponse = true; } + for (auto &[requestId, request] : toSend) { const auto msgId = prepareToSend( request, @@ -836,15 +834,15 @@ void Connection::tryToSend() { memcpy(toSendRequest->data() + from, request->constData() + 4, len * sizeof(mtpPrime)); } } + toSend.clear(); + if (stateRequest) { const auto msgId = placeToContainer( toSendRequest, bigMsgId, forceNewMsgId, stateRequest); - Assert(!haveSent.contains(msgId)); - haveSent.emplace(msgId, stateRequest); - sentIdsWrap.messages.push_back(msgId); + _stateAndResendRequests.emplace(msgId, stateRequest); needAnyResponse = true; } if (resendRequest) { @@ -853,9 +851,7 @@ void Connection::tryToSend() { bigMsgId, forceNewMsgId, resendRequest); - Assert(!haveSent.contains(msgId)); - haveSent.emplace(msgId, resendRequest); - sentIdsWrap.messages.push_back(msgId); + _stateAndResendRequests.emplace(msgId, resendRequest); needAnyResponse = true; } if (ackRequest) { @@ -872,7 +868,6 @@ void Connection::tryToSend() { forceNewMsgId, httpWaitRequest); } - toSend.clear(); const auto containerMsgId = prepareToSend( toSendRequest, @@ -1585,52 +1580,44 @@ Connection::HandleResult Connection::handleOneReceived( auto &states = data.vinfo().v; DEBUG_LOG(("Message Info: msg state received, msgId %1, reqMsgId: %2, HEX states %3").arg(msgId).arg(reqMsgId).arg(Logs::mb(states.data(), states.length()).str())); - SerializedRequest requestBuffer; - { // find this request in session-shared sent requests map - QReadLocker locker(_sessionData->haveSentMutex()); - const auto &haveSent = _sessionData->haveSentMap(); - const auto replyTo = haveSent.find(reqMsgId); - if (replyTo == haveSent.end()) { // do not look in toResend, because we do not resend msgs_state_req requests - DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(reqMsgId)); - return (badTime ? HandleResult::Ignored : HandleResult::Success); - } - if (badTime) { - if (serverSalt) { - _sessionSalt = serverSalt; // requestsFixTimeSalt with no lookup - } - base::unixtime::update(serverTime, true); - - DEBUG_LOG(("Message Info: unixtime updated from mtpc_msgs_state_info, now %1").arg(serverTime)); - - badTime = false; - } - requestBuffer = replyTo->second; + const auto i = _stateAndResendRequests.find(reqMsgId); + if (i == _stateAndResendRequests.end()) { + DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(reqMsgId)); + return (badTime ? HandleResult::Ignored : HandleResult::Success); } - QVector toAckReq(1, MTP_long(reqMsgId)), toAck; - requestsAcked(toAck, true); + if (badTime) { + if (serverSalt) { + _sessionSalt = serverSalt; // requestsFixTimeSalt with no lookup + } + base::unixtime::update(serverTime, true); - if (requestBuffer->size() < 9) { - LOG(("Message Error: bad request %1 found in requestMap, size: %2").arg(reqMsgId).arg(requestBuffer->size())); - return HandleResult::RestartConnection; + DEBUG_LOG(("Message Info: unixtime updated from mtpc_msgs_state_info, now %1").arg(serverTime)); + + badTime = false; } - const mtpPrime *rFrom = requestBuffer->constData() + 8, *rEnd = requestBuffer->constData() + requestBuffer->size(); + const auto originalRequest = i->second; + Assert(originalRequest->size() > 8); + + requestsAcked(QVector(1, MTP_long(reqMsgId)), true); + + auto rFrom = originalRequest->constData() + 8; + const auto rEnd = originalRequest->constData() + originalRequest->size(); + auto toAck = QVector(); if (mtpTypeId(*rFrom) == mtpc_msgs_state_req) { MTPMsgsStateReq request; if (!request.read(rFrom, rEnd)) { LOG(("Message Error: could not parse sent msgs_state_req")); return HandleResult::ParseError; } - handleMsgsStates(request.c_msgs_state_req().vmsg_ids().v, states, toAck); + handleMsgsStates(request.c_msgs_state_req().vmsg_ids().v, states); } else { MTPMsgResendReq request; if (!request.read(rFrom, rEnd)) { LOG(("Message Error: could not parse sent msgs_resend_req")); return HandleResult::ParseError; } - handleMsgsStates(request.c_msg_resend_req().vmsg_ids().v, states, toAck); + handleMsgsStates(request.c_msg_resend_req().vmsg_ids().v, states); } - - requestsAcked(toAck); } return HandleResult::Success; case mtpc_msgs_all_info: { @@ -1647,12 +1634,8 @@ Connection::HandleResult Connection::handleOneReceived( auto &ids = data.vmsg_ids().v; auto &states = data.vinfo().v; - QVector toAck; - DEBUG_LOG(("Message Info: msgs all info received, msgId %1, reqMsgIds: %2, states %3").arg(msgId).arg(LogIdsVector(ids)).arg(Logs::mb(states.data(), states.length()).str())); - handleMsgsStates(ids, states, toAck); - - requestsAcked(toAck); + handleMsgsStates(ids, states); } return HandleResult::Success; case mtpc_msg_detailed_info: { @@ -1984,6 +1967,10 @@ void Connection::requestsAcked(const QVector &ids, bool byResponse) { _sentContainers.erase(i); continue; } + if (const auto i = _stateAndResendRequests.find(msgId); i != end(_stateAndResendRequests)) { + _stateAndResendRequests.erase(i); + continue; + } if (const auto i = haveSent.find(msgId); i != end(haveSent)) { const auto requestId = i->second->requestId; @@ -2040,21 +2027,22 @@ void Connection::requestsAcked(const QVector &ids, bool byResponse) { } } -void Connection::handleMsgsStates(const QVector &ids, const QByteArray &states, QVector &acked) { - uint32 idsCount = ids.size(); +void Connection::handleMsgsStates(const QVector &ids, const QByteArray &states) { + const auto idsCount = ids.size(); if (!idsCount) { DEBUG_LOG(("Message Info: void ids vector in handleMsgsStates()")); return; } - if (states.size() < idsCount) { + if (states.size() != idsCount) { LOG(("Message Error: got less states than required ids count.")); return; } - acked.reserve(acked.size() + idsCount); - for (uint32 i = 0, count = idsCount; i < count; ++i) { - char state = states[i]; - uint64 requestMsgId = ids[i].v; + auto acked = QVector(); + acked.reserve(idsCount); + for (auto i = 0; i != idsCount; ++i) { + const auto state = states[i]; + const auto requestMsgId = ids[i].v; { QReadLocker locker(_sessionData->haveSentMutex()); if (!_sessionData->haveSentMap().contains(requestMsgId)) { @@ -2081,6 +2069,7 @@ void Connection::handleMsgsStates(const QVector &ids, const QByteArray acked.push_back(MTP_long(requestMsgId)); } } + requestsAcked(acked); } void Connection::clearSpecialMsgId(mtpMsgId msgId) { @@ -2123,33 +2112,32 @@ void Connection::resend( haveSent.erase(i); lock.unlock(); - if (!request.isStateRequest()) { - request->lastSentTime = crl::now(); - request->forceSendInContainer = forceContainer; - _resendingIds.emplace(msgId, request->requestId); - { - QWriteLocker locker(_sessionData->toSendMutex()); - _sessionData->toSendMap().emplace(request->requestId, request); - } + request->lastSentTime = crl::now(); + request->forceSendInContainer = forceContainer; + _resendingIds.emplace(msgId, request->requestId); + { + QWriteLocker locker(_sessionData->toSendMutex()); + _sessionData->toSendMap().emplace(request->requestId, request); } } void Connection::resendAll() { - auto toResend = std::vector(); - - auto lock = QReadLocker(_sessionData->haveSentMutex()); - const auto &haveSent = _sessionData->haveSentMap(); - toResend.reserve(haveSent.size()); - for (const auto &[msgId, request] : haveSent) { - if (!request.isStateRequest()) { - toResend.push_back(msgId); + auto lock = QWriteLocker(_sessionData->haveSentMutex()); + auto haveSent = base::take(_sessionData->haveSentMap()); + lock.unlock(); + { + auto lock = QWriteLocker(_sessionData->toSendMutex()); + auto &toSend = _sessionData->toSendMap(); + const auto now = crl::now(); + for (auto &[msgId, request] : haveSent) { + const auto requestId = request->requestId; + request->lastSentTime = now; + request->forceSendInContainer = true; + _resendingIds.emplace(msgId, requestId); + toSend.emplace(requestId, std::move(request)); } } - lock.unlock(); - for (const auto msgId : toResend) { - resend(msgId, -1, true); - } _sessionData->queueSendAnything(); } diff --git a/Telegram/SourceFiles/mtproto/connection.h b/Telegram/SourceFiles/mtproto/connection.h index 706919b4ba..67b898f483 100644 --- a/Telegram/SourceFiles/mtproto/connection.h +++ b/Telegram/SourceFiles/mtproto/connection.h @@ -128,7 +128,7 @@ private: mtpMsgId requestMsgId, const mtpBuffer &response); mtpBuffer ungzip(const mtpPrime *from, const mtpPrime *end) const; - void handleMsgsStates(const QVector &ids, const QByteArray &states, QVector &acked); + void handleMsgsStates(const QVector &ids, const QByteArray &states); // _sessionDataMutex must be locked for read. bool setState(int state, int ifState = kUpdateStateAlways); @@ -219,6 +219,7 @@ private: details::ReceivedIdsManager _receivedMessageIds; base::flat_map _resendingIds; base::flat_map _ackedIds; + base::flat_map _stateAndResendRequests; base::flat_map _sentContainers; std::unique_ptr _keyCreator; diff --git a/Telegram/SourceFiles/mtproto/details/mtproto_serialized_request.cpp b/Telegram/SourceFiles/mtproto/details/mtproto_serialized_request.cpp index 437b79c9ec..5b04acafbb 100644 --- a/Telegram/SourceFiles/mtproto/details/mtproto_serialized_request.cpp +++ b/Telegram/SourceFiles/mtproto/details/mtproto_serialized_request.cpp @@ -124,14 +124,6 @@ uint32 SerializedRequest::messageSize() const { return kMessageIdInts + kSeqNoInts + kMessageLengthInts + ints; } -bool SerializedRequest::isStateRequest() const { - Expects(_data != nullptr); - Expects(_data->size() > kMessageBodyPosition); - - const auto type = mtpTypeId((*_data)[kMessageBodyPosition]); - return (type == mtpc_msgs_state_req); -} - bool SerializedRequest::needAck() const { Expects(_data != nullptr); Expects(_data->size() > kMessageBodyPosition); diff --git a/Telegram/SourceFiles/mtproto/details/mtproto_serialized_request.h b/Telegram/SourceFiles/mtproto/details/mtproto_serialized_request.h index 5da84f4b30..61dd8e6c70 100644 --- a/Telegram/SourceFiles/mtproto/details/mtproto_serialized_request.h +++ b/Telegram/SourceFiles/mtproto/details/mtproto_serialized_request.h @@ -67,7 +67,6 @@ public: void addPadding(bool extended, bool old); [[nodiscard]] uint32 messageSize() const; - [[nodiscard]] bool isStateRequest() const; [[nodiscard]] bool needAck() const; using ResponseType = void; // don't know real response type =(