/*
This file is part of Telegram Desktop,
the official desktop application for the Telegram messaging service.

For license and copyright information please follow this link:
https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/
#include "mtproto/concurrent_sender.h"

#include "mtproto/mtp_instance.h"
#include "mtproto/rpc_sender.h"

namespace MTP {

// Response handler given to the MTP instance for a single request.
// It copies the raw response and forwards it through the sender's runner
// to ConcurrentSender::senderRequestDone. Only a weak reference to the
// sender is held, so a response that arrives after the sender has been
// destroyed is dropped.
class ConcurrentSender::RPCDoneHandler : public RPCAbstractDoneHandler {
public:
	RPCDoneHandler(
		not_null<ConcurrentSender*> sender,
		Fn<void(FnMut<void()>)> runner);

	bool operator()(
		mtpRequestId requestId,
		const mtpPrime *from,
		const mtpPrime *end) override;

private:
	base::weak_ptr<ConcurrentSender> _weak;
	Fn<void(FnMut<void()>)> _runner;

};

// Error handler given to the MTP instance for a single request.
// Depending on the FailSkipPolicy it either declines the error (returns
// false) or forwards a copy of it through the sender's runner to
// ConcurrentSender::senderRequestFail. Like RPCDoneHandler, it holds
// only a weak reference to the sender.
class ConcurrentSender::RPCFailHandler : public RPCAbstractFailHandler {
public:
	RPCFailHandler(
		not_null<ConcurrentSender*> sender,
		Fn<void(FnMut<void()>)> runner,
		FailSkipPolicy skipPolicy);

	bool operator()(
		mtpRequestId requestId,
		const RPCError &error) override;

private:
	base::weak_ptr<ConcurrentSender> _weak;
	Fn<void(FnMut<void()>)> _runner;
	FailSkipPolicy _skipPolicy = FailSkipPolicy::Simple;

};

ConcurrentSender::RPCDoneHandler::RPCDoneHandler(
	not_null<ConcurrentSender*> sender,
	Fn<void(FnMut<void()>)> runner)
: _weak(sender)
, _runner(std::move(runner)) {
}

// Hands the response to the owning sender via the runner; always returns
// true.
bool ConcurrentSender::RPCDoneHandler::operator()(
		mtpRequestId requestId,
		const mtpPrime *from,
		const mtpPrime *end) {
	auto response = gsl::make_span(
		from,
		end - from);

	// The runner may invoke the closure after this call returns, so the
	// response is copied into an owned buffer rather than referencing
	// the [from, end) range.
	_runner([=, weak = _weak, moved = bytes::make_vector(response)]() mutable {
		if (const auto strong = weak.get()) {
			strong->senderRequestDone(requestId, std::move(moved));
		}
	});
	return true;
}

ConcurrentSender::RPCFailHandler::RPCFailHandler(
	not_null<ConcurrentSender*> sender,
	Fn<void(FnMut<void()>)> runner,
	FailSkipPolicy skipPolicy)
: _weak(sender)
, _runner(std::move(runner))
, _skipPolicy(skipPolicy) {
}

// Returns false to decline the error and true once the error has been
// forwarded to the sender. NOTE(review): a declined error presumably
// falls back to the instance's default handling; confirm against
// RPCAbstractFailHandler.
bool ConcurrentSender::RPCFailHandler::operator()(
		mtpRequestId requestId,
		const RPCError &error) {
	// Policy behavior:
	// - Simple declines every error that MTP::isDefaultHandledError
	//   accepts.
	// - HandleFlood declines those as well, except flood errors.
	// - Any other policy forwards every error.
	if (_skipPolicy == FailSkipPolicy::Simple) {
		if (MTP::isDefaultHandledError(error)) {
			return false;
		}
	} else if (_skipPolicy == FailSkipPolicy::HandleFlood) {
		if (MTP::isDefaultHandledError(error)
			&& !MTP::isFloodError(error)) {
			return false;
		}
	}

	// Copy the error into the closure: `error` is only a reference that
	// is valid for the duration of this call.
	_runner([=, weak = _weak, error = error]() mutable {
		if (const auto strong = weak.get()) {
			strong->senderRequestFail(requestId, std::move(error));
		}
	});
	return true;
}

// Posts `method` with crl::on_main. When the posted job runs, `method` is
// invoked with MainInstance(); if MainInstance() returns null, the method
// is dropped without being called.
template <typename Method>
auto ConcurrentSender::with_instance(Method &&method)
-> std::enable_if_t<is_callable_v<Method, not_null<Instance*>>> {
	crl::on_main([method = std::forward<Method>(method)]() mutable {
		if (const auto instance = MainInstance()) {
			std::move(method)(instance);
		}
	});
}

ConcurrentSender::RequestBuilder::RequestBuilder(
	not_null<ConcurrentSender*> sender,
	SecureRequest &&serialized) noexcept
: _sender(sender)
, _serialized(std::move(serialized)) {
}

void ConcurrentSender::RequestBuilder::setToDC(ShiftedDcId dcId) noexcept {
	_dcId = dcId;
}

void ConcurrentSender::RequestBuilder::setCanWait(crl::time ms) noexcept {
	_canWait = ms;
}

void ConcurrentSender::RequestBuilder::setFailSkipPolicy(
		FailSkipPolicy policy) noexcept {
	_failSkipPolicy = policy;
}

void ConcurrentSender::RequestBuilder::setAfter(
		mtpRequestId requestId) noexcept {
	_afterRequestId = requestId;
}

// Allocates a request id, registers this builder's handlers with the
// sender, and posts the serialized request to the main MTP instance.
// The builder's handlers and serialized request are moved out, so the
// builder should not be sent twice. Returns the allocated request id.
mtpRequestId ConcurrentSender::RequestBuilder::send() {
	const auto requestId = GetNextRequestId();

	// Copy the settings into locals. The closure below is executed later
	// via with_instance, so it must not reach back into `this`.
	const auto dcId = _dcId;
	const auto msCanWait = _canWait;
	const auto afterRequestId = _afterRequestId;

	// Handlers are registered before the request is handed to the
	// instance.
	_sender->senderRequestRegister(requestId, std::move(_handlers));
	_sender->with_instance([
		=,
		request = std::move(_serialized),
		done = std::make_shared<ConcurrentSender::RPCDoneHandler>(
			_sender,
			_sender->_runner),
		fail = std::make_shared<ConcurrentSender::RPCFailHandler>(
			_sender,
			_sender->_runner,
			_failSkipPolicy)
	](not_null<Instance*> instance) mutable {
		instance->sendSerialized(
			requestId,
			std::move(request),
			RPCResponseHandler(std::move(done), std::move(fail)),
			dcId,
			msCanWait,
			afterRequestId);
	});

	return requestId;
}

// `runner` is used to deliver every response and error back to this
// sender.
ConcurrentSender::ConcurrentSender(Fn<void(FnMut<void()>)> runner)
: _runner(std::move(runner)) {
}

// Cancels every request that is still registered.
ConcurrentSender::~ConcurrentSender() {
	senderRequestCancelAll();
}

void ConcurrentSender::senderRequestRegister(
		mtpRequestId requestId,
		Handlers &&handlers) {
	_requests.emplace(requestId, std::move(handlers));
}

// Removes the handlers for `requestId` (if still registered) and passes
// them the response.
void ConcurrentSender::senderRequestDone(
		mtpRequestId requestId,
		bytes::const_span result) {
	if (auto handlers = _requests.take(requestId)) {
		// A done handler returning false is treated as a parse failure
		// and reported to the fail handler as a local error.
		if (!handlers->done(requestId, result)) {
			handlers->fail(
				requestId,
				RPCError::Local(
					"RESPONSE_PARSE_FAILED",
					"ConcurrentSender::senderRequestDone"));
		}
	}
}

// Removes the handlers for `requestId` (if still registered) and passes
// them the error.
void ConcurrentSender::senderRequestFail(
		mtpRequestId requestId,
		RPCError &&error) {
	if (auto handlers = _requests.take(requestId)) {
		handlers->fail(requestId, std::move(error));
	}
}

// Forgets the handlers right away, then asks the instance to cancel the
// request.
void ConcurrentSender::senderRequestCancel(mtpRequestId requestId) {
	senderRequestDetach(requestId);
	with_instance([=](not_null<Instance*> instance) {
		instance->cancel(requestId);
	});
}

// Clears every registered request and asks the instance to cancel them.
void ConcurrentSender::senderRequestCancelAll() {
	// Reserve, don't size: constructing the vector with a count would
	// prepend that many zero ids, which would then be passed to
	// Instance::cancel alongside the real ones.
	auto list = std::vector<mtpRequestId>();
	list.reserve(_requests.size());
	for (const auto &pair : base::take(_requests)) {
		list.push_back(pair.first);
	}
	if (list.empty()) {
		return;
	}
	with_instance([list = std::move(list)](not_null<Instance*> instance) {
		for (const auto requestId : list) {
			instance->cancel(requestId);
		}
	});
}

// Forgets the handlers for `requestId` without cancelling the request.
void ConcurrentSender::senderRequestDetach(mtpRequestId requestId) {
	_requests.erase(requestId);
}

} // namespace MTP