// tdesktop/Telegram/SourceFiles/mtproto/sender.h
//
// 403 lines
// 11 KiB
// C++

/*
This file is part of Telegram Desktop,
the official desktop application for the Telegram messaging service.
For license and copyright information please follow this link:
https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/
#pragma once
#include "base/variant.h"
namespace MTP {
// Forward declaration only; the Instance type is defined outside this header.
class Instance;
// Accessor for the Instance used below to send and cancel requests.
// NOTE(review): RequestWrap::cancelRequest() null-checks the result, the other
// call sites in this header dereference it directly - confirm when it may be
// null.
Instance *MainInstance();
class Sender {
class RequestBuilder {
public:
RequestBuilder(const RequestBuilder &other) = delete;
RequestBuilder &operator=(const RequestBuilder &other) = delete;
RequestBuilder &operator=(RequestBuilder &&other) = delete;
protected:
using FailPlainHandler = FnMut<void(const RPCError &error)>;
using FailRequestIdHandler = FnMut<void(const RPCError &error, mtpRequestId requestId)>;
enum class FailSkipPolicy {
Simple,
HandleFlood,
HandleAll,
};
template <typename Response>
struct DonePlainPolicy {
using Callback = FnMut<void(const Response &result)>;
static void handle(Callback &&handler, mtpRequestId requestId, Response &&result) {
handler(result);
}
};
template <typename Response>
struct DoneRequestIdPolicy {
using Callback = FnMut<void(const Response &result, mtpRequestId requestId)>;
static void handle(Callback &&handler, mtpRequestId requestId, Response &&result) {
handler(result, requestId);
}
};
template <typename Response, template <typename> typename PolicyTemplate>
class DoneHandler : public RPCAbstractDoneHandler {
using Policy = PolicyTemplate<Response>;
using Callback = typename Policy::Callback;
public:
DoneHandler(not_null<Sender*> sender, Callback handler) : _sender(sender), _handler(std::move(handler)) {
}
void operator()(mtpRequestId requestId, const mtpPrime *from, const mtpPrime *end) override {
auto handler = std::move(_handler);
_sender->senderRequestHandled(requestId);
if (handler) {
auto result = Response();
result.read(from, end);
Policy::handle(std::move(handler), requestId, std::move(result));
}
}
private:
not_null<Sender*> _sender;
Callback _handler;
};
struct FailPlainPolicy {
using Callback = FnMut<void(const RPCError &error)>;
static void handle(Callback &&handler, mtpRequestId requestId, const RPCError &error) {
handler(error);
}
};
struct FailRequestIdPolicy {
using Callback = FnMut<void(const RPCError &error, mtpRequestId requestId)>;
static void handle(Callback &&handler, mtpRequestId requestId, const RPCError &error) {
handler(error, requestId);
}
};
template <typename Policy>
class FailHandler : public RPCAbstractFailHandler {
using Callback = typename Policy::Callback;
public:
FailHandler(not_null<Sender*> sender, Callback handler, FailSkipPolicy skipPolicy)
: _sender(sender)
, _handler(std::move(handler))
, _skipPolicy(skipPolicy) {
}
bool operator()(mtpRequestId requestId, const RPCError &error) override {
if (_skipPolicy == FailSkipPolicy::Simple) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
} else if (_skipPolicy == FailSkipPolicy::HandleFlood) {
if (MTP::isDefaultHandledError(error) && !MTP::isFloodError(error)) {
return false;
}
}
auto handler = std::move(_handler);
_sender->senderRequestHandled(requestId);
if (handler) {
Policy::handle(std::move(handler), requestId, error);
}
return true;
}
private:
not_null<Sender*> _sender;
Callback _handler;
FailSkipPolicy _skipPolicy = FailSkipPolicy::Simple;
};
explicit RequestBuilder(not_null<Sender*> sender) noexcept : _sender(sender) {
}
RequestBuilder(RequestBuilder &&other) = default;
void setToDC(ShiftedDcId dcId) noexcept {
_dcId = dcId;
}
void setCanWait(TimeMs ms) noexcept {
_canWait = ms;
}
void setDoneHandler(RPCDoneHandlerPtr &&handler) noexcept {
_done = std::move(handler);
}
void setFailHandler(FailPlainHandler &&handler) noexcept {
_fail = std::move(handler);
}
void setFailHandler(FailRequestIdHandler &&handler) noexcept {
_fail = std::move(handler);
}
void setFailSkipPolicy(FailSkipPolicy policy) noexcept {
_failSkipPolicy = policy;
}
void setAfter(mtpRequestId requestId) noexcept {
_afterRequestId = requestId;
}
ShiftedDcId takeDcId() const noexcept {
return _dcId;
}
TimeMs takeCanWait() const noexcept {
return _canWait;
}
RPCDoneHandlerPtr takeOnDone() noexcept {
return std::move(_done);
}
RPCFailHandlerPtr takeOnFail() {
if (auto handler = base::get_if<FailPlainHandler>(&_fail)) {
return std::make_shared<FailHandler<FailPlainPolicy>>(_sender, std::move(*handler), _failSkipPolicy);
} else if (auto handler = base::get_if<FailRequestIdHandler>(&_fail)) {
return std::make_shared<FailHandler<FailRequestIdPolicy>>(_sender, std::move(*handler), _failSkipPolicy);
}
return RPCFailHandlerPtr();
}
mtpRequestId takeAfter() const noexcept {
return _afterRequestId;
}
not_null<Sender*> sender() const noexcept {
return _sender;
}
void registerRequest(mtpRequestId requestId) {
_sender->senderRequestRegister(requestId);
}
private:
not_null<Sender*> _sender;
ShiftedDcId _dcId = 0;
TimeMs _canWait = 0;
RPCDoneHandlerPtr _done;
base::variant<FailPlainHandler, FailRequestIdHandler> _fail;
FailSkipPolicy _failSkipPolicy = FailSkipPolicy::Simple;
mtpRequestId _afterRequestId = 0;
};
public:
// Creates a sender with no tracked requests.
Sender() noexcept {
}
template <typename Request>
class SpecificRequestBuilder : public RequestBuilder {
private:
friend class Sender;
SpecificRequestBuilder(not_null<Sender*> sender, Request &&request) noexcept : RequestBuilder(sender), _request(std::move(request)) {
}
SpecificRequestBuilder(SpecificRequestBuilder &&other) = default;
public:
[[nodiscard]] SpecificRequestBuilder &toDC(ShiftedDcId dcId) noexcept {
setToDC(dcId);
return *this;
}
[[nodiscard]] SpecificRequestBuilder &afterDelay(TimeMs ms) noexcept {
setCanWait(ms);
return *this;
}
[[nodiscard]] SpecificRequestBuilder &done(FnMut<void(const typename Request::ResponseType &result)> callback) {
setDoneHandler(std::make_shared<DoneHandler<typename Request::ResponseType, DonePlainPolicy>>(sender(), std::move(callback)));
return *this;
}
[[nodiscard]] SpecificRequestBuilder &done(FnMut<void(const typename Request::ResponseType &result, mtpRequestId requestId)> callback) {
setDoneHandler(std::make_shared<DoneHandler<typename Request::ResponseType, DoneRequestIdPolicy>>(sender(), std::move(callback)));
return *this;
}
[[nodiscard]] SpecificRequestBuilder &fail(FnMut<void(const RPCError &error)> callback) noexcept {
setFailHandler(std::move(callback));
return *this;
}
[[nodiscard]] SpecificRequestBuilder &fail(FnMut<void(const RPCError &error, mtpRequestId requestId)> callback) noexcept {
setFailHandler(std::move(callback));
return *this;
}
[[nodiscard]] SpecificRequestBuilder &handleFloodErrors() noexcept {
setFailSkipPolicy(FailSkipPolicy::HandleFlood);
return *this;
}
[[nodiscard]] SpecificRequestBuilder &handleAllErrors() noexcept {
setFailSkipPolicy(FailSkipPolicy::HandleAll);
return *this;
}
[[nodiscard]] SpecificRequestBuilder &afterRequest(mtpRequestId requestId) noexcept {
setAfter(requestId);
return *this;
}
mtpRequestId send() {
const auto id = MainInstance()->send(
_request,
takeOnDone(),
takeOnFail(),
takeDcId(),
takeCanWait(),
takeAfter());
registerRequest(id);
return id;
}
private:
Request _request;
};
// Lightweight handle for an already sent request, identified by its id.
// Only Sender can create one (private constructor, Sender is a friend).
class SentRequestWrap {
private:
	friend class Sender;

	SentRequestWrap(not_null<Sender*> sender, mtpRequestId requestId)
	: _sender(sender)
	, _requestId(requestId) {
	}

public:
	// Asks the owning Sender to stop tracking the request; erasing the
	// tracked entry cancels the request in its destructor.
	void cancel() {
		_sender->senderRequestCancel(_requestId);
	}

private:
	not_null<Sender*> _sender;
	mtpRequestId _requestId = 0;
};
// Starts building a request. The template constraints accept only rvalue
// arguments (Request must not deduce to a reference type) whose type
// declares a nested Request::Unboxed. The returned builder must be used:
// nothing is sent until send() is called on it.
template <
typename Request,
typename = std::enable_if_t<!std::is_reference_v<Request>>,
typename = typename Request::Unboxed>
[[nodiscard]] SpecificRequestBuilder<Request> request(Request &&request) noexcept;
// Wraps the id of an already sent request so that it can be cancelled.
[[nodiscard]] SentRequestWrap request(mtpRequestId requestId) noexcept;
[[nodiscard]] auto requestCanceller() noexcept {
return [this](mtpRequestId requestId) {
request(requestId).cancel();
};
}
void requestSendDelayed() {
MainInstance()->sendAnything();
}
void requestCancellingDiscard() {
for (auto &request : _requests) {
request.handled();
}
}
// Returns the main instance as a not_null pointer.
not_null<Instance*> requestMTP() const {
	const auto instance = MainInstance();
	return instance;
}
private:
class RequestWrap {
public:
RequestWrap(
Instance *instance,
mtpRequestId requestId) noexcept
: _id(requestId) {
}
RequestWrap(const RequestWrap &other) = delete;
RequestWrap &operator=(const RequestWrap &other) = delete;
RequestWrap(RequestWrap &&other) : _id(base::take(other._id)) {
}
RequestWrap &operator=(RequestWrap &&other) {
if (_id != other._id) {
cancelRequest();
_id = base::take(other._id);
}
return *this;
}
mtpRequestId id() const noexcept {
return _id;
}
void handled() const noexcept {
}
~RequestWrap() {
cancelRequest();
}
private:
void cancelRequest() {
if (_id) {
if (auto instance = MainInstance()) {
instance->cancel(_id);
}
}
}
mtpRequestId _id = 0;
};
struct RequestWrapComparator {
using is_transparent = std::true_type;
struct helper {
mtpRequestId requestId = 0;
helper() = default;
helper(const helper &other) = default;
helper(mtpRequestId requestId) noexcept : requestId(requestId) {
}
helper(const RequestWrap &request) noexcept : requestId(request.id()) {
}
bool operator<(helper other) const {
return requestId < other.requestId;
}
};
bool operator()(const helper &&lhs, const helper &&rhs) const {
return lhs < rhs;
}
};
// Helper classes that call the private senderRequest*() methods below.
// The template friend names SpecificRequestBuilder, the builder template
// declared in this class (it used to name a non-existent
// "SpecialRequestBuilder", which befriended nothing useful and injected a
// stray template declaration into the enclosing namespace).
template <typename Request>
friend class SpecificRequestBuilder;
friend class RequestBuilder;
friend class RequestWrap;
friend class SentRequestWrap;
// Starts tracking a sent request. From now on the request is cancelled when
// its entry is erased or when the Sender is destroyed.
void senderRequestRegister(mtpRequestId requestId) {
	const auto instance = MainInstance();
	_requests.emplace(instance, requestId);
}
// Called by the done / fail handlers: marks the tracked request as handled
// and stops tracking it. Unknown ids are ignored.
void senderRequestHandled(mtpRequestId requestId) {
	auto entry = _requests.find(requestId);
	if (entry == _requests.cend()) {
		return;
	}
	entry->handled();
	_requests.erase(entry);
}
// Stops tracking the request; destroying the erased RequestWrap cancels it
// through the main instance. Unknown ids are ignored.
void senderRequestCancel(mtpRequestId requestId) {
	auto entry = _requests.find(requestId);
	if (entry == _requests.cend()) {
		return;
	}
	_requests.erase(entry);
}
// Requests sent through this Sender that were neither handled nor cancelled
// yet, compared by request id. Destroying the set destroys every
// RequestWrap, and each wrap cancels its request in its destructor.
base::flat_set<RequestWrap, RequestWrapComparator> _requests;
};
// Creates the typed builder for a request object; the request is moved into
// the builder and nothing is sent until send() is called on it.
template <typename Request, typename, typename>
Sender::SpecificRequestBuilder<Request> Sender::request(Request &&request) noexcept {
	return { this, std::move(request) };
}
// Wraps the id of an already sent request into a cancellable handle.
inline Sender::SentRequestWrap Sender::request(mtpRequestId requestId) noexcept {
	return { this, requestId };
}
} // namespace MTP