Handle state / resend requests separately.

This commit is contained in:
John Preston 2019-12-02 14:34:14 +03:00
parent 3b703d7262
commit 718de09aa6
4 changed files with 67 additions and 87 deletions

View File

@ -226,9 +226,7 @@ void Connection::checkSentRequests() {
const auto now = crl::now(); const auto now = crl::now();
const auto checkAfter = kCheckSentRequestTimeout; const auto checkAfter = kCheckSentRequestTimeout;
for (const auto &[msgId, request] : haveSent) { for (const auto &[msgId, request] : haveSent) {
if (request.isStateRequest()) { if (request->lastSentTime + checkAfter < now) {
continue;
} else if (request->lastSentTime + checkAfter < now) {
// Need to check state. // Need to check state.
request->lastSentTime = now; request->lastSentTime = now;
if (_bindMsgId) { if (_bindMsgId) {
@ -563,8 +561,6 @@ void Connection::tryToSend() {
resendRequest = SerializedRequest::Serialize(MTPMsgResendReq( resendRequest = SerializedRequest::Serialize(MTPMsgResendReq(
MTP_msg_resend_req(MTP_vector<MTPlong>( MTP_msg_resend_req(MTP_vector<MTPlong>(
base::take(_resendRequestData))))); base::take(_resendRequestData)))));
// Add to haveSent / _ackedIds, but don't add to requestMap.
resendRequest->requestId = GetNextRequestId();
} }
if (!_stateRequestData.empty()) { if (!_stateRequestData.empty()) {
auto ids = QVector<MTPlong>(); auto ids = QVector<MTPlong>();
@ -574,8 +570,6 @@ void Connection::tryToSend() {
} }
stateRequest = SerializedRequest::Serialize(MTPMsgsStateReq( stateRequest = SerializedRequest::Serialize(MTPMsgsStateReq(
MTP_msgs_state_req(MTP_vector<MTPlong>(ids)))); MTP_msgs_state_req(MTP_vector<MTPlong>(ids))));
// Add to haveSent / _ackedIds, but don't add to requestMap.
stateRequest->requestId = GetNextRequestId();
} }
if (_connection->usingHttpWait()) { if (_connection->usingHttpWait()) {
httpWaitRequest = SerializedRequest::Serialize(MTPHttpWait( httpWaitRequest = SerializedRequest::Serialize(MTPHttpWait(
@ -694,7 +688,10 @@ void Connection::tryToSend() {
} else if (pingRequest) { } else if (pingRequest) {
_pingMsgId = msgId; _pingMsgId = msgId;
needAnyResponse = true; needAnyResponse = true;
} else if (resendRequest || stateRequest) { } else if (stateRequest || resendRequest) {
_stateAndResendRequests.emplace(
msgId,
stateRequest ? stateRequest : resendRequest);
needAnyResponse = true; needAnyResponse = true;
} }
@ -793,6 +790,7 @@ void Connection::tryToSend() {
pingRequest); pingRequest);
needAnyResponse = true; needAnyResponse = true;
} }
for (auto &[requestId, request] : toSend) { for (auto &[requestId, request] : toSend) {
const auto msgId = prepareToSend( const auto msgId = prepareToSend(
request, request,
@ -836,15 +834,15 @@ void Connection::tryToSend() {
memcpy(toSendRequest->data() + from, request->constData() + 4, len * sizeof(mtpPrime)); memcpy(toSendRequest->data() + from, request->constData() + 4, len * sizeof(mtpPrime));
} }
} }
toSend.clear();
if (stateRequest) { if (stateRequest) {
const auto msgId = placeToContainer( const auto msgId = placeToContainer(
toSendRequest, toSendRequest,
bigMsgId, bigMsgId,
forceNewMsgId, forceNewMsgId,
stateRequest); stateRequest);
Assert(!haveSent.contains(msgId)); _stateAndResendRequests.emplace(msgId, stateRequest);
haveSent.emplace(msgId, stateRequest);
sentIdsWrap.messages.push_back(msgId);
needAnyResponse = true; needAnyResponse = true;
} }
if (resendRequest) { if (resendRequest) {
@ -853,9 +851,7 @@ void Connection::tryToSend() {
bigMsgId, bigMsgId,
forceNewMsgId, forceNewMsgId,
resendRequest); resendRequest);
Assert(!haveSent.contains(msgId)); _stateAndResendRequests.emplace(msgId, resendRequest);
haveSent.emplace(msgId, resendRequest);
sentIdsWrap.messages.push_back(msgId);
needAnyResponse = true; needAnyResponse = true;
} }
if (ackRequest) { if (ackRequest) {
@ -872,7 +868,6 @@ void Connection::tryToSend() {
forceNewMsgId, forceNewMsgId,
httpWaitRequest); httpWaitRequest);
} }
toSend.clear();
const auto containerMsgId = prepareToSend( const auto containerMsgId = prepareToSend(
toSendRequest, toSendRequest,
@ -1585,52 +1580,44 @@ Connection::HandleResult Connection::handleOneReceived(
auto &states = data.vinfo().v; auto &states = data.vinfo().v;
DEBUG_LOG(("Message Info: msg state received, msgId %1, reqMsgId: %2, HEX states %3").arg(msgId).arg(reqMsgId).arg(Logs::mb(states.data(), states.length()).str())); DEBUG_LOG(("Message Info: msg state received, msgId %1, reqMsgId: %2, HEX states %3").arg(msgId).arg(reqMsgId).arg(Logs::mb(states.data(), states.length()).str()));
SerializedRequest requestBuffer; const auto i = _stateAndResendRequests.find(reqMsgId);
{ // find this request in session-shared sent requests map if (i == _stateAndResendRequests.end()) {
QReadLocker locker(_sessionData->haveSentMutex()); DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(reqMsgId));
const auto &haveSent = _sessionData->haveSentMap(); return (badTime ? HandleResult::Ignored : HandleResult::Success);
const auto replyTo = haveSent.find(reqMsgId);
if (replyTo == haveSent.end()) { // do not look in toResend, because we do not resend msgs_state_req requests
DEBUG_LOG(("Message Error: such message was not sent recently %1").arg(reqMsgId));
return (badTime ? HandleResult::Ignored : HandleResult::Success);
}
if (badTime) {
if (serverSalt) {
_sessionSalt = serverSalt; // requestsFixTimeSalt with no lookup
}
base::unixtime::update(serverTime, true);
DEBUG_LOG(("Message Info: unixtime updated from mtpc_msgs_state_info, now %1").arg(serverTime));
badTime = false;
}
requestBuffer = replyTo->second;
} }
QVector<MTPlong> toAckReq(1, MTP_long(reqMsgId)), toAck; if (badTime) {
requestsAcked(toAck, true); if (serverSalt) {
_sessionSalt = serverSalt; // requestsFixTimeSalt with no lookup
}
base::unixtime::update(serverTime, true);
if (requestBuffer->size() < 9) { DEBUG_LOG(("Message Info: unixtime updated from mtpc_msgs_state_info, now %1").arg(serverTime));
LOG(("Message Error: bad request %1 found in requestMap, size: %2").arg(reqMsgId).arg(requestBuffer->size()));
return HandleResult::RestartConnection; badTime = false;
} }
const mtpPrime *rFrom = requestBuffer->constData() + 8, *rEnd = requestBuffer->constData() + requestBuffer->size(); const auto originalRequest = i->second;
Assert(originalRequest->size() > 8);
requestsAcked(QVector<MTPlong>(1, MTP_long(reqMsgId)), true);
auto rFrom = originalRequest->constData() + 8;
const auto rEnd = originalRequest->constData() + originalRequest->size();
auto toAck = QVector<MTPlong>();
if (mtpTypeId(*rFrom) == mtpc_msgs_state_req) { if (mtpTypeId(*rFrom) == mtpc_msgs_state_req) {
MTPMsgsStateReq request; MTPMsgsStateReq request;
if (!request.read(rFrom, rEnd)) { if (!request.read(rFrom, rEnd)) {
LOG(("Message Error: could not parse sent msgs_state_req")); LOG(("Message Error: could not parse sent msgs_state_req"));
return HandleResult::ParseError; return HandleResult::ParseError;
} }
handleMsgsStates(request.c_msgs_state_req().vmsg_ids().v, states, toAck); handleMsgsStates(request.c_msgs_state_req().vmsg_ids().v, states);
} else { } else {
MTPMsgResendReq request; MTPMsgResendReq request;
if (!request.read(rFrom, rEnd)) { if (!request.read(rFrom, rEnd)) {
LOG(("Message Error: could not parse sent msgs_resend_req")); LOG(("Message Error: could not parse sent msgs_resend_req"));
return HandleResult::ParseError; return HandleResult::ParseError;
} }
handleMsgsStates(request.c_msg_resend_req().vmsg_ids().v, states, toAck); handleMsgsStates(request.c_msg_resend_req().vmsg_ids().v, states);
} }
requestsAcked(toAck);
} return HandleResult::Success; } return HandleResult::Success;
case mtpc_msgs_all_info: { case mtpc_msgs_all_info: {
@ -1647,12 +1634,8 @@ Connection::HandleResult Connection::handleOneReceived(
auto &ids = data.vmsg_ids().v; auto &ids = data.vmsg_ids().v;
auto &states = data.vinfo().v; auto &states = data.vinfo().v;
QVector<MTPlong> toAck;
DEBUG_LOG(("Message Info: msgs all info received, msgId %1, reqMsgIds: %2, states %3").arg(msgId).arg(LogIdsVector(ids)).arg(Logs::mb(states.data(), states.length()).str())); DEBUG_LOG(("Message Info: msgs all info received, msgId %1, reqMsgIds: %2, states %3").arg(msgId).arg(LogIdsVector(ids)).arg(Logs::mb(states.data(), states.length()).str()));
handleMsgsStates(ids, states, toAck); handleMsgsStates(ids, states);
requestsAcked(toAck);
} return HandleResult::Success; } return HandleResult::Success;
case mtpc_msg_detailed_info: { case mtpc_msg_detailed_info: {
@ -1984,6 +1967,10 @@ void Connection::requestsAcked(const QVector<MTPlong> &ids, bool byResponse) {
_sentContainers.erase(i); _sentContainers.erase(i);
continue; continue;
} }
if (const auto i = _stateAndResendRequests.find(msgId); i != end(_stateAndResendRequests)) {
_stateAndResendRequests.erase(i);
continue;
}
if (const auto i = haveSent.find(msgId); i != end(haveSent)) { if (const auto i = haveSent.find(msgId); i != end(haveSent)) {
const auto requestId = i->second->requestId; const auto requestId = i->second->requestId;
@ -2040,21 +2027,22 @@ void Connection::requestsAcked(const QVector<MTPlong> &ids, bool byResponse) {
} }
} }
void Connection::handleMsgsStates(const QVector<MTPlong> &ids, const QByteArray &states, QVector<MTPlong> &acked) { void Connection::handleMsgsStates(const QVector<MTPlong> &ids, const QByteArray &states) {
uint32 idsCount = ids.size(); const auto idsCount = ids.size();
if (!idsCount) { if (!idsCount) {
DEBUG_LOG(("Message Info: void ids vector in handleMsgsStates()")); DEBUG_LOG(("Message Info: void ids vector in handleMsgsStates()"));
return; return;
} }
if (states.size() < idsCount) { if (states.size() != idsCount) {
LOG(("Message Error: got less states than required ids count.")); LOG(("Message Error: got less states than required ids count."));
return; return;
} }
acked.reserve(acked.size() + idsCount); auto acked = QVector<MTPlong>();
for (uint32 i = 0, count = idsCount; i < count; ++i) { acked.reserve(idsCount);
char state = states[i]; for (auto i = 0; i != idsCount; ++i) {
uint64 requestMsgId = ids[i].v; const auto state = states[i];
const auto requestMsgId = ids[i].v;
{ {
QReadLocker locker(_sessionData->haveSentMutex()); QReadLocker locker(_sessionData->haveSentMutex());
if (!_sessionData->haveSentMap().contains(requestMsgId)) { if (!_sessionData->haveSentMap().contains(requestMsgId)) {
@ -2081,6 +2069,7 @@ void Connection::handleMsgsStates(const QVector<MTPlong> &ids, const QByteArray
acked.push_back(MTP_long(requestMsgId)); acked.push_back(MTP_long(requestMsgId));
} }
} }
requestsAcked(acked);
} }
void Connection::clearSpecialMsgId(mtpMsgId msgId) { void Connection::clearSpecialMsgId(mtpMsgId msgId) {
@ -2123,33 +2112,32 @@ void Connection::resend(
haveSent.erase(i); haveSent.erase(i);
lock.unlock(); lock.unlock();
if (!request.isStateRequest()) { request->lastSentTime = crl::now();
request->lastSentTime = crl::now(); request->forceSendInContainer = forceContainer;
request->forceSendInContainer = forceContainer; _resendingIds.emplace(msgId, request->requestId);
_resendingIds.emplace(msgId, request->requestId); {
{ QWriteLocker locker(_sessionData->toSendMutex());
QWriteLocker locker(_sessionData->toSendMutex()); _sessionData->toSendMap().emplace(request->requestId, request);
_sessionData->toSendMap().emplace(request->requestId, request);
}
} }
} }
void Connection::resendAll() { void Connection::resendAll() {
auto toResend = std::vector<mtpMsgId>(); auto lock = QWriteLocker(_sessionData->haveSentMutex());
auto haveSent = base::take(_sessionData->haveSentMap());
auto lock = QReadLocker(_sessionData->haveSentMutex()); lock.unlock();
const auto &haveSent = _sessionData->haveSentMap(); {
toResend.reserve(haveSent.size()); auto lock = QWriteLocker(_sessionData->toSendMutex());
for (const auto &[msgId, request] : haveSent) { auto &toSend = _sessionData->toSendMap();
if (!request.isStateRequest()) { const auto now = crl::now();
toResend.push_back(msgId); for (auto &[msgId, request] : haveSent) {
const auto requestId = request->requestId;
request->lastSentTime = now;
request->forceSendInContainer = true;
_resendingIds.emplace(msgId, requestId);
toSend.emplace(requestId, std::move(request));
} }
} }
lock.unlock();
for (const auto msgId : toResend) {
resend(msgId, -1, true);
}
_sessionData->queueSendAnything(); _sessionData->queueSendAnything();
} }

View File

@ -128,7 +128,7 @@ private:
mtpMsgId requestMsgId, mtpMsgId requestMsgId,
const mtpBuffer &response); const mtpBuffer &response);
mtpBuffer ungzip(const mtpPrime *from, const mtpPrime *end) const; mtpBuffer ungzip(const mtpPrime *from, const mtpPrime *end) const;
void handleMsgsStates(const QVector<MTPlong> &ids, const QByteArray &states, QVector<MTPlong> &acked); void handleMsgsStates(const QVector<MTPlong> &ids, const QByteArray &states);
// _sessionDataMutex must be locked for read. // _sessionDataMutex must be locked for read.
bool setState(int state, int ifState = kUpdateStateAlways); bool setState(int state, int ifState = kUpdateStateAlways);
@ -219,6 +219,7 @@ private:
details::ReceivedIdsManager _receivedMessageIds; details::ReceivedIdsManager _receivedMessageIds;
base::flat_map<mtpMsgId, mtpRequestId> _resendingIds; base::flat_map<mtpMsgId, mtpRequestId> _resendingIds;
base::flat_map<mtpMsgId, mtpRequestId> _ackedIds; base::flat_map<mtpMsgId, mtpRequestId> _ackedIds;
base::flat_map<mtpMsgId, details::SerializedRequest> _stateAndResendRequests;
base::flat_map<mtpMsgId, SentContainer> _sentContainers; base::flat_map<mtpMsgId, SentContainer> _sentContainers;
std::unique_ptr<details::BoundKeyCreator> _keyCreator; std::unique_ptr<details::BoundKeyCreator> _keyCreator;

View File

@ -124,14 +124,6 @@ uint32 SerializedRequest::messageSize() const {
return kMessageIdInts + kSeqNoInts + kMessageLengthInts + ints; return kMessageIdInts + kSeqNoInts + kMessageLengthInts + ints;
} }
bool SerializedRequest::isStateRequest() const {
Expects(_data != nullptr);
Expects(_data->size() > kMessageBodyPosition);
const auto type = mtpTypeId((*_data)[kMessageBodyPosition]);
return (type == mtpc_msgs_state_req);
}
bool SerializedRequest::needAck() const { bool SerializedRequest::needAck() const {
Expects(_data != nullptr); Expects(_data != nullptr);
Expects(_data->size() > kMessageBodyPosition); Expects(_data->size() > kMessageBodyPosition);

View File

@ -67,7 +67,6 @@ public:
void addPadding(bool extended, bool old); void addPadding(bool extended, bool old);
[[nodiscard]] uint32 messageSize() const; [[nodiscard]] uint32 messageSize() const;
[[nodiscard]] bool isStateRequest() const;
[[nodiscard]] bool needAck() const; [[nodiscard]] bool needAck() const;
using ResponseType = void; // don't know real response type =( using ResponseType = void; // don't know real response type =(