diff --git a/Telegram/SourceFiles/mtproto/mtp_instance.cpp b/Telegram/SourceFiles/mtproto/mtp_instance.cpp index 782f51a3ca..042c0b248a 100644 --- a/Telegram/SourceFiles/mtproto/mtp_instance.cpp +++ b/Telegram/SourceFiles/mtproto/mtp_instance.cpp @@ -128,8 +128,10 @@ private: void configLoadDone(const MTPConfig &result); bool configLoadFail(const RPCError &error); - void cdnConfigLoadDone(const MTPCdnConfig &result); - bool cdnConfigLoadFail(const RPCError &error); + base::optional queryRequestByDc( + mtpRequestId requestId) const; + base::optional changeRequestByDc( + mtpRequestId requestId, DcId newdc); // RPCError::NoError means do not toggle onError callback. void clearCallbacks( @@ -164,7 +166,7 @@ private: // holds dcWithShift for request to this dc or -dc for request to main dc std::map _requestsByDc; - QMutex _requestByDcLock; + mutable QMutex _requestByDcLock; // holds target dcWithShift for auth export request std::map _authExportRequests; @@ -249,9 +251,7 @@ void Instance::Private::start(Config &&config) { _checkDelayedTimer.setCallback([this] { checkDelayedRequests(); }); Assert((_mainDcId == Config::kNoneMainDc) == isKeysDestroyer()); - if (!isKeysDestroyer()) { - requestConfig(); - } + requestConfig(); } void Instance::Private::suggestMainDcId(DcId mainDcId) { @@ -280,7 +280,7 @@ DcId Instance::Private::mainDcId() const { } void Instance::Private::requestConfig() { - if (_configLoader) { + if (_configLoader || isKeysDestroyer()) { return; } _configLoader = std::make_unique(_instance, rpcDone([this](const MTPConfig &result) { @@ -301,7 +301,9 @@ void Instance::Private::requestCDNConfig() { if (_cdnConfigLoadRequestId || _mainDcId == Config::kNoneMainDc) { return; } - _cdnConfigLoadRequestId = request(MTPhelp_GetCdnConfig()).done([this](const MTPCdnConfig &result) { + _cdnConfigLoadRequestId = request( + MTPhelp_GetCdnConfig() + ).done([this](const MTPCdnConfig &result) { _cdnConfigLoadRequestId = 0; Expects(result.type() == mtpc_cdnConfig); @@ -374,8 +376,8 @@ void Instance::Private::ping() { void Instance::Private::cancel(mtpRequestId requestId) { if (!requestId) return; - mtpMsgId msgId = 0; - _requestsDelays.erase(requestId); + const auto shiftedDcId = queryRequestByDc(requestId); + auto msgId = mtpMsgId(0); { QWriteLocker locker(&_requestMapLock); auto it = _requestMap.find(requestId); @@ -384,14 +386,10 @@ void Instance::Private::cancel(mtpRequestId requestId) { _requestMap.erase(it); } } - { - QMutexLocker locker(&_requestByDcLock); - auto it = _requestsByDc.find(requestId); - if (it != _requestsByDc.end()) { - if (auto session = getSession(qAbs(it->second))) { - session->cancel(requestId, msgId); - } - _requestsByDc.erase(it); + unregisterRequest(requestId); + if (shiftedDcId) { + if (const auto session = getSession(qAbs(*shiftedDcId))) { + session->cancel(requestId, msgId); } } clearCallbacks(requestId); @@ -399,10 +397,8 @@ void Instance::Private::cancel(mtpRequestId requestId) { int32 Instance::Private::state(mtpRequestId requestId) { // < 0 means waiting for such count of ms if (requestId > 0) { - QMutexLocker locker(&_requestByDcLock); - auto i = _requestsByDc.find(requestId); - if (i != _requestsByDc.end()) { - if (auto session = getSession(qAbs(i->second))) { + if (const auto shiftedDcId = queryRequestByDc(requestId)) { + if (auto session = getSession(qAbs(*shiftedDcId))) { return session->requestState(requestId); } return MTP::RequestConnecting; @@ -647,6 +643,32 @@ bool Instance::Private::configLoadFail(const RPCError &error) { return false; } +base::optional Instance::Private::queryRequestByDc( + mtpRequestId requestId) const { + QMutexLocker locker(&_requestByDcLock); + auto it = _requestsByDc.find(requestId); + if (it != _requestsByDc.cend()) { + return it->second; + } + return base::none; +} + +base::optional Instance::Private::changeRequestByDc( + mtpRequestId requestId, + DcId newdc) { + QMutexLocker locker(&_requestByDcLock); + auto it = _requestsByDc.find(requestId); + if (it != _requestsByDc.cend()) { + if (it->second < 0) { + it->second = -newdc; + } else { + it->second = shiftDcId(newdc, getDcIdShift(it->second)); + } + return it->second; + } + return base::none; +} + void Instance::Private::checkDelayedRequests() { auto now = getms(true); while (!_delayedRequests.empty() && now >= _delayedRequests.front().second) { @@ -654,15 +676,11 @@ void Instance::Private::checkDelayedRequests() { _delayedRequests.pop_front(); auto dcWithShift = ShiftedDcId(0); - { - QMutexLocker locker(&_requestByDcLock); - auto it = _requestsByDc.find(requestId); - if (it != _requestsByDc.cend()) { - dcWithShift = it->second; - } else { - LOG(("MTP Error: could not find request dc for delayed resend, requestId %1").arg(requestId)); - continue; - } + if (const auto shiftedDcId = queryRequestByDc(requestId)) { + dcWithShift = *shiftedDcId; + } else { + LOG(("MTP Error: could not find request dc for delayed resend, requestId %1").arg(requestId)); + continue; } auto request = mtpRequest(); @@ -880,10 +898,8 @@ bool Instance::Private::hasAuthorization() { } void Instance::Private::importDone(const MTPauth_Authorization &result, mtpRequestId requestId) { - QMutexLocker locker1(&_requestByDcLock); - - auto it = _requestsByDc.find(requestId); - if (it == _requestsByDc.end()) { + const auto shiftedDcId = queryRequestByDc(requestId); + if (!shiftedDcId) { LOG(("MTP Error: auth import request not found in requestsByDC, requestId: %1").arg(requestId)); RPCError error(internal::rpcClientError("AUTH_IMPORT_FAIL", QString("did not find import request in requestsByDC, request %1").arg(requestId))); if (_globalHandler.onFail && hasAuthorization()) { @@ -891,7 +907,7 @@ void Instance::Private::importDone(const MTPauth_Authorization &result, mtpReque } return; } - auto newdc = bareDcId(it->second); + auto newdc = bareDcId(*shiftedDcId); DEBUG_LOG(("MTP Info: auth import to dc %1 succeeded").arg(newdc)); @@ -904,23 +920,15 @@ void Instance::Private::importDone(const MTPauth_Authorization &result, mtpReque LOG(("MTP Error: could not find request %1 for resending").arg(waitedRequestId)); continue; } - auto dcWithShift = ShiftedDcId(newdc); - { - auto k = _requestsByDc.find(waitedRequestId); - if (k == _requestsByDc.cend()) { - LOG(("MTP Error: could not find request %1 by dc for resending").arg(waitedRequestId)); - continue; - } - if (k->second < 0) { - _instance->setMainDcId(newdc); - k->second = -newdc; - } else { - dcWithShift = shiftDcId(newdc, getDcIdShift(k->second)); - k->second = dcWithShift; - } - DEBUG_LOG(("MTP Info: resending request %1 to dc %2 after import auth").arg(waitedRequestId).arg(k->second)); + const auto shiftedDcId = changeRequestByDc(waitedRequestId, newdc); + if (!shiftedDcId) { + LOG(("MTP Error: could not find request %1 by dc for resending").arg(waitedRequestId)); + continue; + } else if (*shiftedDcId < 0) { + _instance->setMainDcId(newdc); } - if (auto session = getSession(dcWithShift)) { + DEBUG_LOG(("MTP Info: resending request %1 to dc %2 after import auth").arg(waitedRequestId).arg(*shiftedDcId)); + if (auto session = getSession(*shiftedDcId)) { session->sendPrepared(it->second); } } @@ -981,15 +989,12 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e if ((m = QRegularExpression("^(FILE|PHONE|NETWORK|USER)_MIGRATE_(\\d+)$").match(err)).hasMatch()) { if (!requestId) return false; - ShiftedDcId dcWithShift = 0, newdcWithShift = m.captured(2).toInt(); - { - QMutexLocker locker(&_requestByDcLock); - auto it = _requestsByDc.find(requestId); - if (it == _requestsByDc.end()) { - LOG(("MTP Error: could not find request %1 for migrating to %2").arg(requestId).arg(newdcWithShift)); - } else { - dcWithShift = it->second; - } + auto dcWithShift = ShiftedDcId(0); + auto newdcWithShift = ShiftedDcId(m.captured(2).toInt()); + if (const auto shiftedDcId = queryRequestByDc(requestId)) { + dcWithShift = *shiftedDcId; + } else { + LOG(("MTP Error: could not find request %1 for migrating to %2").arg(requestId).arg(newdcWithShift)); } if (!dcWithShift || !newdcWithShift) return false; @@ -1064,14 +1069,10 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e return true; } else if (code == 401 || (badGuestDc && _badGuestDcRequests.find(requestId) == _badGuestDcRequests.cend())) { auto dcWithShift = ShiftedDcId(0); - { - QMutexLocker locker(&_requestByDcLock); - auto it = _requestsByDc.find(requestId); - if (it != _requestsByDc.end()) { - dcWithShift = it->second; - } else { - LOG(("MTP Error: unauthorized request without dc info, requestId %1").arg(requestId)); - } + if (const auto shiftedDcId = queryRequestByDc(requestId)) { + dcWithShift = *shiftedDcId; + } else { + LOG(("MTP Error: unauthorized request without dc info, requestId %1").arg(requestId)); } auto newdc = bareDcId(qAbs(dcWithShift)); if (!newdc || newdc == mainDcId() || !hasAuthorization()) { @@ -1106,14 +1107,10 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e request = it->second; } auto dcWithShift = ShiftedDcId(0); - { - QMutexLocker locker(&_requestByDcLock); - auto it = _requestsByDc.find(requestId); - if (it == _requestsByDc.end()) { - LOG(("MTP Error: could not find request %1 for resending with init connection").arg(requestId)); - } else { - dcWithShift = it->second; - } + if (const auto shiftedDcId = queryRequestByDc(requestId)) { + dcWithShift = *shiftedDcId; + } else { + LOG(("MTP Error: could not find request %1 for resending with init connection").arg(requestId)); } if (!dcWithShift) return false; @@ -1140,20 +1137,17 @@ bool Instance::Private::onErrorDefault(mtpRequestId requestId, const RPCError &e return false; } auto dcWithShift = ShiftedDcId(0); - { - QMutexLocker locker(&_requestByDcLock); - auto it = _requestsByDc.find(requestId); - auto afterIt = _requestsByDc.find(request->after->requestId); - if (it == _requestsByDc.end()) { - LOG(("MTP Error: could not find request %1 by dc").arg(requestId)); - } else if (afterIt == _requestsByDc.end()) { - LOG(("MTP Error: could not find dependent request %1 by dc").arg(request->after->requestId)); - } else { - dcWithShift = it->second; - if (it->second != afterIt->second) { + if (const auto shiftedDcId = queryRequestByDc(requestId)) { + if (const auto afterDcId = queryRequestByDc(request->after->requestId)) { + dcWithShift = *shiftedDcId; + if (*shiftedDcId != *afterDcId) { request->after = mtpRequest(); } + } else { + LOG(("MTP Error: could not find dependent request %1 by dc").arg(request->after->requestId)); } + } else { + LOG(("MTP Error: could not find request %1 by dc").arg(requestId)); } if (!dcWithShift) return false;