// tdesktop/Telegram/SourceFiles/mtproto/concurrent_sender.cpp
// (225 lines, 5.5 KiB, C++)

/*
This file is part of Telegram Desktop,
the official desktop application for the Telegram messaging service.
For license and copyright information please follow this link:
https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/
#include "mtproto/concurrent_sender.h"
#include "mtproto/mtp_instance.h"
#include "mtproto/rpc_sender.h"
namespace MTP {
// Low-level success handler handed to the MTP instance. It copies the raw
// response and forwards it, through the runner, to the owning sender.
class ConcurrentSender::RPCDoneHandler : public RPCAbstractDoneHandler {
	// Weak reference: the response is dropped if the sender is gone by the
	// time the runner executes the forwarded task.
	base::weak_ptr<ConcurrentSender> _weak;

	// Scheduler supplied by the sender; every forwarded task goes through it.
	Fn<void(FnMut<void()>)> _runner;

public:
	RPCDoneHandler(
		not_null<ConcurrentSender*> sender,
		Fn<void(FnMut<void()>)> runner);

	void operator()(
		mtpRequestId requestId,
		const mtpPrime *from,
		const mtpPrime *end) override;

};
// Low-level failure handler handed to the MTP instance. Depending on the
// skip policy it either leaves an error to the default handling (returns
// false) or forwards it, through the runner, to the owning sender.
class ConcurrentSender::RPCFailHandler : public RPCAbstractFailHandler {
	// Weak reference: the error is dropped if the sender is gone by the
	// time the runner executes the forwarded task.
	base::weak_ptr<ConcurrentSender> _weak;

	// Scheduler supplied by the sender; every forwarded task goes through it.
	Fn<void(FnMut<void()>)> _runner;

	// Selects which default-handled errors are still left to MTP.
	FailSkipPolicy _skipPolicy = FailSkipPolicy::Simple;

public:
	RPCFailHandler(
		not_null<ConcurrentSender*> sender,
		Fn<void(FnMut<void()>)> runner,
		FailSkipPolicy skipPolicy);

	bool operator()(
		mtpRequestId requestId,
		const RPCError &error) override;

};
ConcurrentSender::RPCDoneHandler::RPCDoneHandler(
not_null<ConcurrentSender*> sender,
Fn<void(FnMut<void()>)> runner)
: _weak(sender)
, _runner(std::move(runner)) {
}
// Copies the [from, end) response into an owned byte vector and schedules
// delivery to the sender through the runner. The copy is taken here because
// the task runs later; the incoming buffer is not assumed to outlive this call.
void ConcurrentSender::RPCDoneHandler::operator()(
		mtpRequestId requestId,
		const mtpPrime *from,
		const mtpPrime *end) {
	auto payload = bytes::make_vector(gsl::make_span(from, end - from));
	auto deliver = [
		requestId,
		weak = _weak,
		moved = std::move(payload)
	]() mutable {
		// The sender may have been destroyed in the meantime.
		if (const auto strong = weak.get()) {
			strong->senderRequestDone(requestId, std::move(moved));
		}
	};
	_runner(std::move(deliver));
}
ConcurrentSender::RPCFailHandler::RPCFailHandler(
not_null<ConcurrentSender*> sender,
Fn<void(FnMut<void()>)> runner,
FailSkipPolicy skipPolicy)
: _weak(sender)
, _runner(std::move(runner))
, _skipPolicy(skipPolicy) {
}
// Returns false to leave the error to MTP's default handling, or schedules
// delivery to the sender through the runner and returns true.
//   Simple      - every default-handled error is left to MTP;
//   HandleFlood - the same, except flood errors, which are forwarded;
//   otherwise   - every error is forwarded.
bool ConcurrentSender::RPCFailHandler::operator()(
		mtpRequestId requestId,
		const RPCError &error) {
	const auto leaveToDefault = [&] {
		switch (_skipPolicy) {
		case FailSkipPolicy::Simple:
			return MTP::isDefaultHandledError(error);
		case FailSkipPolicy::HandleFlood:
			// Flood is only checked for default-handled errors.
			return MTP::isDefaultHandledError(error)
				&& !MTP::isFloodError(error);
		default:
			return false;
		}
	}();
	if (leaveToDefault) {
		return false;
	}

	auto deliver = [
		requestId,
		weak = _weak,
		error = error
	]() mutable {
		// The sender may have been destroyed in the meantime.
		if (const auto strong = weak.get()) {
			strong->senderRequestFail(requestId, std::move(error));
		}
	};
	_runner(std::move(deliver));
	return true;
}
// Posts `method` to the main thread (crl::on_main) and invokes it there with
// the main MTP instance. If no main instance exists at that moment, the call
// is silently skipped.
template <typename Method>
auto ConcurrentSender::with_instance(Method &&method)
-> std::enable_if_t<is_callable_v<Method, not_null<Instance*>>> {
	auto task = [callback = std::forward<Method>(method)]() mutable {
		const auto instance = MainInstance();
		if (!instance) {
			return;
		}
		std::move(callback)(instance);
	};
	crl::on_main(std::move(task));
}
ConcurrentSender::RequestBuilder::RequestBuilder(
not_null<ConcurrentSender*> sender,
SecureRequest &&serialized) noexcept
: _sender(sender)
, _serialized(std::move(serialized)) {
}
// Target datacenter passed to Instance::sendSerialized by send().
void ConcurrentSender::RequestBuilder::setToDC(ShiftedDcId dcId) noexcept {
	this->_dcId = dcId;
}

// "Can wait" delay passed to Instance::sendSerialized by send().
void ConcurrentSender::RequestBuilder::setCanWait(crl::time ms) noexcept {
	this->_canWait = ms;
}

// Policy given to the RPCFailHandler created by send().
void ConcurrentSender::RequestBuilder::setFailSkipPolicy(
		FailSkipPolicy policy) noexcept {
	this->_failSkipPolicy = policy;
}

// Request id passed as the "after" argument of Instance::sendSerialized.
void ConcurrentSender::RequestBuilder::setAfter(
		mtpRequestId requestId) noexcept {
	this->_afterRequestId = requestId;
}
// Allocates a request id, registers the user handlers with the sender and
// schedules the actual send on the main MTP instance. Returns the id
// immediately; the network send happens later, inside with_instance().
mtpRequestId ConcurrentSender::RequestBuilder::send() {
	const auto id = GetNextRequestId();

	// User-level handlers are stored in the sender under the new id.
	_sender->senderRequestRegister(id, std::move(_handlers));

	// Low-level handlers; each keeps only a weak reference to the sender
	// and forwards results through the sender's runner.
	auto doneHandler = std::make_shared<RPCDoneHandler>(
		_sender,
		_sender->_runner);
	auto failHandler = std::make_shared<RPCFailHandler>(
		_sender,
		_sender->_runner,
		_failSkipPolicy);

	// Everything the deferred send needs is captured by value, so the task
	// does not reference this builder.
	_sender->with_instance([
		id,
		dcId = _dcId,
		msCanWait = _canWait,
		afterRequestId = _afterRequestId,
		request = std::move(_serialized),
		done = std::move(doneHandler),
		fail = std::move(failHandler)
	](not_null<Instance*> instance) mutable {
		instance->sendSerialized(
			id,
			std::move(request),
			RPCResponseHandler(std::move(done), std::move(fail)),
			dcId,
			msCanWait,
			afterRequestId);
	});
	return id;
}
// Creates a sender whose responses and errors are delivered through
// `runner`. The runner is taken by value and moved into place, which avoids
// a second copy of the std::function (and any state it captured).
ConcurrentSender::ConcurrentSender(Fn<void(FnMut<void()>)> runner)
: _runner(std::move(runner)) {
}
// Cancels every request still registered with this sender.
ConcurrentSender::~ConcurrentSender() {
	this->senderRequestCancelAll();
}

// Stores the user-level handlers for `requestId` until a response, error
// or cancellation removes them.
void ConcurrentSender::senderRequestRegister(
		mtpRequestId requestId,
		Handlers &&handlers) {
	this->_requests.emplace(requestId, std::move(handlers));
}
// Delivers a successful response to the handlers registered for
// `requestId`. The handlers are taken out of _requests first, so a request
// that was already completed, failed or detached is ignored.
// If the done handler throws an Exception (for example while parsing the
// response), the same request is reported to the fail handler as a local
// RESPONSE_PARSE_FAILED error.
void ConcurrentSender::senderRequestDone(
		mtpRequestId requestId,
		bytes::const_span result) {
	auto handlers = _requests.take(requestId);
	if (!handlers) {
		return;
	}
	try {
		handlers->done(requestId, result);
	} catch (const Exception &e) {
		// Catch by const reference: the exception is only read here.
		handlers->fail(
			requestId,
			RPCError::Local(
				"RESPONSE_PARSE_FAILED",
				QString("exception text: ") + e.what()));
	}
}
// Delivers an error to the handlers registered for `requestId`, if they are
// still present (they are taken out, so this happens at most once).
void ConcurrentSender::senderRequestFail(
		mtpRequestId requestId,
		RPCError &&error) {
	auto handlers = _requests.take(requestId);
	if (!handlers) {
		return;
	}
	handlers->fail(requestId, std::move(error));
}
// Forgets the local handlers immediately, then asks the main MTP instance
// (on the main thread, via with_instance) to cancel the request itself.
void ConcurrentSender::senderRequestCancel(mtpRequestId requestId) {
	senderRequestDetach(requestId);
	auto cancelOnInstance = [requestId](not_null<Instance*> instance) {
		instance->cancel(requestId);
	};
	with_instance(std::move(cancelOnInstance));
}
// Drops every registered handler and asks the main MTP instance to cancel
// all of the corresponding requests in a single posted task.
void ConcurrentSender::senderRequestCancelAll() {
	// reserve(), not the size constructor: constructing the vector with
	// _requests.size() elements would pre-fill it with zero ids, and
	// cancel(0) would then be sent once per pending request.
	auto list = std::vector<mtpRequestId>();
	list.reserve(_requests.size());
	for (const auto &pair : base::take(_requests)) {
		list.push_back(pair.first);
	}
	if (list.empty()) {
		// Nothing to cancel; avoid posting an empty task to the main thread.
		return;
	}
	with_instance([list = std::move(list)](not_null<Instance*> instance) {
		for (const auto requestId : list) {
			instance->cancel(requestId);
		}
	});
}
// Removes the handlers for `requestId` without touching the network request;
// any later response or error for that id is then ignored by this sender.
void ConcurrentSender::senderRequestDetach(mtpRequestId requestId) {
	this->_requests.erase(requestId);
}
} // namespace MTP