tdesktop/Telegram/SourceFiles/storage/download_manager_mtproto.cpp

1046 lines
29 KiB
C++

/*
This file is part of Telegram Desktop,
the official desktop application for the Telegram messaging service.
For license and copyright information please follow this link:
https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/
#include "storage/download_manager_mtproto.h"
#include "mtproto/facade.h"
#include "mtproto/mtproto_auth_key.h"
#include "mtproto/mtproto_response.h"
#include "main/main_session.h"
#include "data/data_session.h"
#include "data/data_document.h"
#include "apiwrap.h"
#include "base/openssl_help.h"
namespace Storage {
namespace {
// Delay between the moment a dc has nothing requested in any session
// and the moment its download sessions get stopped.
constexpr auto kKillSessionTimeout = 15 * crl::time(1000);
// Initial per-session limit for the amount requested at the same time.
constexpr auto kStartWaitedInSession = 4 * kDownloadPartSize;
// Upper bound for the per-session limit above.
constexpr auto kMaxWaitedInSession = 16 * kDownloadPartSize;
// Sessions count a dc starts with; sessions are never removed below it.
constexpr auto kStartSessionsCount = 1;
// No more sessions are added to a dc once it has this many.
constexpr auto kMaxSessionsCount = 8;
// Cap for the counter of consecutive removes of the same session index.
constexpr auto kMaxTrackedSessionRemoves = 64;
// Base delay after a session remove before a session may be added back.
constexpr auto kRetryAddSessionTimeout = 8 * crl::time(1000);
// Base successes count required in each session before adding a session.
constexpr auto kRetryAddSessionSuccesses = 3;
// Cap for the per-session successes counter.
constexpr auto kMaxTrackedSuccesses = kRetryAddSessionSuccesses
* kMaxTrackedSessionRemoves;
// Timeouts count (with more than the starting sessions count) after
// which the last session of a dc is removed.
constexpr auto kRemoveSessionAfterTimeouts = 4;
// Delay of the timer that resets priorities of enqueued tasks.
constexpr auto kResetDownloadPrioritiesTimeout = crl::time(200);
// A successful request that took at least this long is treated
// as a session time out.
constexpr auto kBadRequestDurationThreshold = 8 * crl::time(1000);
// After each session remove caused by timeouts, before adding a session
// back we wait for a time that grows with the removes count:
// kRetryAddSessionTimeout * (1 + min(removesCount, kMaxTrackedSessionRemoves))
// and for a number of successes in every remaining session that grows
// the same way, in steps of kRetryAddSessionSuccesses.
} // namespace
// Adds the task with the given priority or, if it is already enqueued,
// updates its priority. Either way the entry ends up right before the
// first entry that had a priority less than or equal to the new one.
void DownloadManagerMtproto::Queue::enqueue(
		not_null<Task*> task,
		int priority) {
	// Keep the target as an index: push_back below may reallocate.
	const auto insertBefore = ranges::find_if(
		_tasks,
		[&](const Enqueued &entry) { return entry.priority <= priority; }
	) - begin(_tasks);

	const auto existing = ranges::find(_tasks, task, &Enqueued::task);
	auto current = existing;
	if (existing != end(_tasks)) {
		existing->priority = priority;
	} else {
		_tasks.push_back({ task, priority });
		current = end(_tasks) - 1;
	}

	// Move the single entry at 'current' to its target place.
	const auto target = begin(_tasks) + insertBefore;
	if (target < current) {
		std::rotate(target, current, current + 1);
	} else if (target > current + 1) {
		std::rotate(current, current + 1, target);
	}
}
// Erases every entry that refers to the given task.
void DownloadManagerMtproto::Queue::remove(not_null<Task*> task) {
	const auto removedFrom = ranges::remove(_tasks, task, &Enqueued::task);
	_tasks.erase(removedFrom, end(_tasks));
}
// Starting from the first entry with priority 0, changes the priority of
// the consecutive zero-priority entries to -1. Stops at the first entry
// with another priority, which is asserted to be -1 already.
void DownloadManagerMtproto::Queue::resetGeneration() {
	for (auto i = ranges::find(_tasks, 0, &Enqueued::priority)
		; i != end(_tasks)
		; ++i) {
		if (i->priority != 0) {
			Assert(i->priority == -1);
			break;
		}
		i->priority = -1;
	}
}
// Returns true when no tasks are enqueued.
bool DownloadManagerMtproto::Queue::empty() const {
	return (_tasks.size() == 0);
}
// Returns the first enqueued task that reports readyToRequest(), or
// nullptr if there is none. When onlyHighestPriority is set and the front
// entry has a positive priority, only the leading entries sharing that
// priority are considered.
auto DownloadManagerMtproto::Queue::nextTask(bool onlyHighestPriority) const
-> Task* {
	if (_tasks.empty()) {
		return nullptr;
	}
	const auto top = _tasks.front().priority;
	auto limit = end(_tasks);
	if (onlyHighestPriority && top > 0) {
		limit = ranges::find_if(_tasks, [&](const Enqueued &entry) {
			return (entry.priority != top);
		});
	}
	const auto found = ranges::find_if(
		ranges::make_subrange(begin(_tasks), limit),
		[](const Enqueued &entry) { return entry.task->readyToRequest(); });
	if (found == limit) {
		return nullptr;
	}
	return found->task.get();
}
// Forwards the session removal to every enqueued task.
void DownloadManagerMtproto::Queue::removeSession(int index) {
	for (const auto &entry : _tasks) {
		const auto task = entry.task;
		task->removeSession(index);
	}
}
// A new session starts with the initial limit (kStartWaitedInSession)
// for the amount that may be requested in it at the same time.
DownloadManagerMtproto::DcSessionBalanceData::DcSessionBalanceData()
: maxWaitedAmount(kStartWaitedInSession) {
}
// A new dc balance record starts with kStartSessionsCount sessions.
DownloadManagerMtproto::DcBalanceData::DcBalanceData()
: sessions(kStartSessionsCount) {
}
// Sets up the timers and subscribes to restarts-by-timeout of MTP
// sessions: those that belong to download dc ids are reported to
// sessionTimedOut() as (bare dc id, dc id shift).
DownloadManagerMtproto::DownloadManagerMtproto(not_null<ApiWrap*> api)
: _api(api)
, _resetGenerationTimer([=] { resetGeneration(); })
, _killSessionsTimer([=] { killSessions(); }) {
	const auto isDownloadSession = [](MTP::ShiftedDcId shiftedDcId) {
		return MTP::isDownloadDcId(shiftedDcId);
	};
	const auto handleTimeout = [=](MTP::ShiftedDcId shiftedDcId) {
		const auto bareDcId = MTP::BareDcId(shiftedDcId);
		const auto shift = MTP::GetDcIdShift(shiftedDcId);
		sessionTimedOut(bareDcId, shift);
	};
	_api->instance().restartsByTimeout(
	) | rpl::filter(
		isDownloadSession
	) | rpl::start_with_next(handleTimeout, _lifetime);
}
// Runs the scheduled-kill pass one more time on destruction: sessions
// of the dcs whose kill time has already come are stopped.
DownloadManagerMtproto::~DownloadManagerMtproto() {
killSessions();
}
// Puts the task into the queue of its dc with the given priority, makes
// sure the priorities reset timer is running and tries to send requests.
void DownloadManagerMtproto::enqueue(not_null<Task*> task, int priority) {
	const auto taskDcId = task->dcId();
	auto &dcQueue = _queues[taskDcId];
	dcQueue.enqueue(task, priority);

	const auto timerRunning = _resetGenerationTimer.isActive();
	if (!timerRunning) {
		_resetGenerationTimer.callOnce(kResetDownloadPrioritiesTimeout);
	}
	checkSendNext(taskDcId, dcQueue);
}
// Removes the task from the queue of its dc and then tries to send
// requests for the tasks that are left there.
void DownloadManagerMtproto::remove(not_null<Task*> task) {
	const auto taskDcId = task->dcId();
	auto &dcQueue = _queues[taskDcId];
	dcQueue.remove(task);
	checkSendNext(taskDcId, dcQueue);
}
void DownloadManagerMtproto::resetGeneration() {
_resetGenerationTimer.cancel();
for (auto &[dcId, queue] : _queues) {
queue.resetGeneration();
}
}
// Tries to send requests for every dc that has a non-empty queue.
void DownloadManagerMtproto::checkSendNext() {
	for (auto &[dcId, queue] : _queues) {
		if (!queue.empty()) {
			checkSendNext(dcId, queue);
		}
	}
}
// Keeps sending parts from the given dc queue until trySendNextPart()
// reports that nothing more can be sent.
void DownloadManagerMtproto::checkSendNext(MTP::DcId dcId, Queue &queue) {
	auto sent = false;
	do {
		sent = trySendNextPart(dcId, queue);
	} while (sent);
}
// Tries to send more requests to the given dc after a part was loaded.
void DownloadManagerMtproto::checkSendNextAfterSuccess(MTP::DcId dcId) {
	auto &dcQueue = _queues[dcId];
	checkSendNext(dcId, dcQueue);
}
// Picks the session of the dc with the smallest requested amount (the
// sessions that reached their limit are ranked as kMaxWaitedInSession)
// and, if one more part fits into its limit, asks the next task of the
// queue to load a part through it. Returns true if a part was requested.
bool DownloadManagerMtproto::trySendNextPart(MTP::DcId dcId, Queue &queue) {
	auto &dc = _balanceData[dcId];
	const auto &sessions = dc.sessions;

	const auto load = [](const DcSessionBalanceData &session) {
		return (session.requested < session.maxWaitedAmount)
			? session.requested
			: kMaxWaitedInSession;
	};
	const auto least = ranges::min_element(sessions, ranges::less(), load);
	if (least->requested + kDownloadPartSize > least->maxWaitedAmount) {
		return false;
	}
	const auto sessionIndex = (least - begin(sessions));

	// While something is already requested only the tasks with the
	// highest priority are allowed to add more requests.
	const auto somethingRequested = (dc.totalRequested > 0);
	const auto task = queue.nextTask(somethingRequested);
	if (!task) {
		return false;
	}
	task->loadPart(sessionIndex);
	return true;
}
// Applies delta to the requested amount of session (dcId, index) and to
// the total of the dc, and returns the new amount of the session.
// A positive delta cancels the scheduled kill of the dc sessions;
// otherwise the kill is scheduled when no session of the dc has
// anything requested.
int DownloadManagerMtproto::changeRequestedAmount(
		MTP::DcId dcId,
		int index,
		int delta) {
	const auto i = _balanceData.find(dcId);
	Assert(i != _balanceData.end());
	auto &dc = i->second;
	Assert(index < dc.sessions.size());

	auto &session = dc.sessions[index];
	session.requested += delta;
	const auto result = session.requested;
	dc.totalRequested += delta;

	if (delta > 0) {
		killSessionsCancel(dcId);
		return result;
	}
	using namespace rpl::mappers;
	const auto nonEmpty = ranges::find_if(
		dc.sessions,
		_1 > 0,
		&DcSessionBalanceData::requested);
	if (nonEmpty == end(dc.sessions)) {
		killSessionsSchedule(dcId);
	}
	return result;
}
// Accounts a successfully finished request of session (dcId, index).
//
// amountAtRequestStart - amount requested in the session right after
//   this request was placed.
// timeAtRequestStart - time when this request was placed.
//
// The result is ignored as 'overloaded' if the request was placed not
// later than the last session remove of the dc, or with more requested
// than the session currently allows. A request that took at least
// kBadRequestDurationThreshold is reported as a session time out instead.
// Otherwise the session limit may grow by one part, the successes are
// counted and, once every session has enough of them, either one pending
// timeout is forgiven or one more session is added (if the maximum count
// is not reached and the delay since the last remove has passed).
void DownloadManagerMtproto::requestSucceeded(
		MTP::DcId dcId,
		int index,
		int amountAtRequestStart,
		crl::time timeAtRequestStart) {
	const auto i = _balanceData.find(dcId);
	Assert(i != end(_balanceData));
	auto &dc = i->second;
	Assert(index < dc.sessions.size());
	auto &data = dc.sessions[index];
	const auto overloaded = (timeAtRequestStart <= dc.lastSessionRemove)
		|| (amountAtRequestStart > data.maxWaitedAmount);
	const auto parts = amountAtRequestStart / kDownloadPartSize;
	const auto duration = (crl::now() - timeAtRequestStart);
	DEBUG_LOG(("Download (%1,%2) request done, duration: %3, parts: %4%5"
		).arg(dcId
		).arg(index
		).arg(duration
		).arg(parts
		).arg(overloaded ? " (overloaded)" : ""));
	if (overloaded) {
		return;
	}
	if (duration >= kBadRequestDurationThreshold) {
		DEBUG_LOG(("Duration too large, signaling time out."));
		crl::on_main(this, [=] {
			sessionTimedOut(dcId, index);
		});
		return;
	}
	// The session was used at its full limit and handled it fine,
	// so allow one more part in it next time.
	if (amountAtRequestStart == data.maxWaitedAmount
		&& data.maxWaitedAmount < kMaxWaitedInSession) {
		data.maxWaitedAmount = std::min(
			data.maxWaitedAmount + kDownloadPartSize,
			kMaxWaitedInSession);
		DEBUG_LOG(("Download (%1,%2) increased max waited amount %3."
			).arg(dcId
			).arg(index
			).arg(data.maxWaitedAmount));
	}
	data.successes = std::min(data.successes + 1, kMaxTrackedSuccesses);

	// The per-session counter above is capped by kMaxTrackedSuccesses,
	// so the required count must be capped as well. Otherwise, once
	// sessionRemoveTimes reaches kMaxTrackedSessionRemoves, the required
	// (kMaxTrackedSessionRemoves + 1) * kRetryAddSessionSuccesses could
	// never be reached and no session would ever be added back.
	const auto required = std::min(
		(dc.sessionRemoveTimes + 1) * kRetryAddSessionSuccesses,
		kMaxTrackedSuccesses);
	const auto notEnough = ranges::any_of(
		dc.sessions,
		[=](const auto &successes) { return successes < required; },
		&DcSessionBalanceData::successes);
	if (notEnough) {
		return;
	}
	for (auto &session : dc.sessions) {
		session.successes = 0;
	}
	if (dc.timeouts > 0) {
		--dc.timeouts;
		return;
	} else if (dc.sessions.size() == kMaxSessionsCount) {
		return;
	}
	const auto now = crl::now();
	const auto delay = (dc.sessionRemoveTimes + 1) * kRetryAddSessionTimeout;
	if (dc.lastSessionRemove && now < dc.lastSessionRemove + delay) {
		return;
	}
	dc.sessions.emplace_back();
	DEBUG_LOG(("Download (%1,%2) adding, now sessions: %3"
		).arg(dcId
		).arg(dc.sessions.size() - 1
		).arg(dc.sessions.size()));
}
// Returns the index of the session of the dc that has the smallest
// requested amount. The dc must already have balance data.
int DownloadManagerMtproto::chooseSessionIndex(MTP::DcId dcId) const {
	const auto found = _balanceData.find(dcId);
	Assert(found != end(_balanceData));

	const auto &list = found->second.sessions;
	const auto least = ranges::min_element(
		list,
		ranges::less(),
		&DcSessionBalanceData::requested);
	return (least - begin(list));
}
// Handles a time out in session (dcId, index): resets the successes of
// all sessions of the dc and, when the dc has more than the starting
// sessions count, counts the timeout. After
// kRemoveSessionAfterTimeouts of them the last session gets removed.
// Unknown dcs and session indices are ignored.
void DownloadManagerMtproto::sessionTimedOut(MTP::DcId dcId, int index) {
	const auto found = _balanceData.find(dcId);
	if (found == end(_balanceData)) {
		return;
	}
	auto &dc = found->second;
	if (index >= dc.sessions.size()) {
		return;
	}
	DEBUG_LOG(("Download (%1,%2) session timed-out.").arg(dcId).arg(index));
	for (auto &session : dc.sessions) {
		session.successes = 0;
	}
	if (dc.sessions.size() == kStartSessionsCount) {
		return;
	}
	if (++dc.timeouts < kRemoveSessionAfterTimeouts) {
		return;
	}
	dc.timeouts = 0;
	removeSession(dcId);
}
// Removes the last session of the dc: tracks how many times in a row the
// same index was removed, lets the queued tasks redirect their requests
// away from it, kills the MTP session and remembers the remove time.
void DownloadManagerMtproto::removeSession(MTP::DcId dcId) {
	auto &dc = _balanceData[dcId];
	Assert(dc.sessions.size() > kStartSessionsCount);
	const auto index = int(dc.sessions.size() - 1);
	DEBUG_LOG(("Download (%1,%2) removing, now sessions: %3"
		).arg(dcId
		).arg(index
		).arg(index));
	auto &queue = _queues[dcId];

	const auto sameAsLast = (dc.sessionRemoveIndex == index);
	dc.sessionRemoveTimes = sameAsLast
		? std::min(dc.sessionRemoveTimes + 1, kMaxTrackedSessionRemoves)
		: 1;
	dc.sessionRemoveIndex = index;

	auto &removing = dc.sessions.back();

	// Make sure we don't send anything to that session while redirecting.
	const auto blocker = kMaxWaitedInSession * kMaxSessionsCount;
	removing.requested += blocker;
	queue.removeSession(index);
	// Everything that was requested there must have been taken away.
	Assert(removing.requested == blocker);

	dc.sessions.pop_back();
	api().instance().killSession(MTP::downloadDcId(dcId, index));
	dc.lastSessionRemove = crl::now();
}
// Schedules stopping of the dc sessions in kKillSessionTimeout (unless
// already scheduled) and starts the kill timer if it is not running.
void DownloadManagerMtproto::killSessionsSchedule(MTP::DcId dcId) {
	const auto scheduled = _killSessionsWhen.contains(dcId);
	if (!scheduled) {
		const auto when = crl::now() + kKillSessionTimeout;
		_killSessionsWhen.emplace(dcId, when);
	}
	const auto timerRunning = _killSessionsTimer.isActive();
	if (!timerRunning) {
		_killSessionsTimer.callOnce(kKillSessionTimeout + 5);
	}
}
// Unschedules stopping of the dc sessions; the kill timer is cancelled
// when nothing is left scheduled.
void DownloadManagerMtproto::killSessionsCancel(MTP::DcId dcId) {
	_killSessionsWhen.erase(dcId);
	if (!_killSessionsWhen.empty()) {
		return;
	}
	_killSessionsTimer.cancel();
}
// Stops the sessions of every dc whose scheduled kill time has come and
// restarts the timer for the nearest of the remaining ones (but not
// later than in kKillSessionTimeout).
void DownloadManagerMtproto::killSessions() {
	const auto now = crl::now();
	auto left = kKillSessionTimeout;
	auto i = begin(_killSessionsWhen);
	while (i != end(_killSessionsWhen)) {
		const auto remaining = i->second - now;
		if (remaining <= 0) {
			killSessions(i->first);
			i = _killSessionsWhen.erase(i);
			continue;
		}
		if (remaining < left) {
			left = remaining;
		}
		++i;
	}
	if (_killSessionsWhen.empty()) {
		return;
	}
	_killSessionsTimer.callOnce(left);
}
// Resets the balance data of the dc to defaults, keeping the sessions
// count, and stops all of its download MTP sessions. Nothing may be
// requested in the dc at this point. Does nothing for an unknown dc.
void DownloadManagerMtproto::killSessions(MTP::DcId dcId) {
	const auto found = _balanceData.find(dcId);
	if (found == end(_balanceData)) {
		return;
	}
	auto &dc = found->second;
	Assert(dc.totalRequested == 0);

	auto sessions = base::take(dc.sessions);
	dc = DcBalanceData();
	auto index = 0;
	for (auto &session : sessions) {
		Assert(session.requested == 0);
		session = DcSessionBalanceData();
		api().instance().stopSession(MTP::downloadDcId(dcId, index));
		++index;
	}
	dc.sessions = base::take(sessions);
}
// Creates a task for a storage file location: the dc id is taken from
// the location itself and the file origin is remembered (it is passed
// to the file reference refresh in normalPartFailed()).
DownloadMtprotoTask::DownloadMtprotoTask(
not_null<DownloadManagerMtproto*> owner,
const StorageFileLocation &location,
Data::FileOrigin origin)
: _owner(owner)
, _dcId(location.dcId())
, _location({ location })
, _origin(origin) {
}
// Creates a task for an arbitrary location with an explicitly given
// dc id. The file origin is left default-initialized.
DownloadMtprotoTask::DownloadMtprotoTask(
not_null<DownloadManagerMtproto*> owner,
MTP::DcId dcId,
const Location &location)
: _owner(owner)
, _dcId(dcId)
, _location(location) {
}
// Cancels everything that is still in flight (which also returns the
// requested amounts to the owner) and removes the task from the
// owner's queue.
DownloadMtprotoTask::~DownloadMtprotoTask() {
cancelAllRequests();
_owner->remove(this);
}
// Returns the dc id this task was created with.
MTP::DcId DownloadMtprotoTask::dcId() const {
return _dcId;
}
// Returns the file origin this task was created with.
Data::FileOrigin DownloadMtprotoTask::fileOrigin() const {
return _origin;
}
// Returns the object id of the storage file location, or 0 when the
// task downloads a location of another kind.
uint64 DownloadMtprotoTask::objectId() const {
	const auto storage = std::get_if<StorageFileLocation>(&_location.data);
	return storage ? storage->objectId() : 0;
}
// Returns the location this task downloads.
const DownloadMtprotoTask::Location &DownloadMtprotoTask::location() const {
return _location;
}
// Applies refreshed file references to the storage file location and
// resends the request that failed with the 'current' reference.
// The download is cancelled if the location is not a storage file
// location or if the reference stays equal to 'current' after the refresh.
void DownloadMtprotoTask::refreshFileReferenceFrom(
		const Data::UpdatedFileReferences &updates,
		int requestId,
		const QByteArray &current) {
	const auto storage = std::get_if<StorageFileLocation>(&_location.data);
	if (!storage) {
		cancelOnFail();
		return;
	}
	storage->refreshFileReference(updates);
	if (storage->fileReference() == current) {
		cancelOnFail();
		return;
	}
	if (!_sentRequests.contains(requestId)) {
		return;
	}
	makeRequest(finishSentRequest(
		requestId,
		FinishRequestReason::Redirect));
}
// Requests the part at the next request offset through the given session.
void DownloadMtprotoTask::loadPart(int sessionIndex) {
	const auto offset = takeNextRequestOffset();
	makeRequest({ offset, sessionIndex });
}
void DownloadMtprotoTask::removeSession(int sessionIndex) {
struct Redirect {
mtpRequestId requestId = 0;
int64 offset = 0;
};
auto redirect = std::vector<Redirect>();
for (const auto &[requestId, requestData] : _sentRequests) {
if (requestData.sessionIndex == sessionIndex) {
redirect.reserve(_sentRequests.size());
redirect.push_back({ requestId, requestData.offset });
}
}
for (auto &[requestData, bytes] : _cdnUncheckedParts) {
if (requestData.sessionIndex == sessionIndex) {
const auto newIndex = _owner->chooseSessionIndex(dcId());
Assert(newIndex < sessionIndex);
requestData.sessionIndex = newIndex;
}
}
for (const auto &[requestId, offset] : redirect) {
const auto needMakeRequest = (requestId != _cdnHashesRequestId);
cancelRequest(requestId);
if (needMakeRequest) {
const auto newIndex = _owner->chooseSessionIndex(dcId());
Assert(newIndex < sessionIndex);
makeRequest({ offset, newIndex });
}
}
}
// Sends the request for one part (kDownloadPartSize bytes starting at
// requestData.offset) through the session requestData.sessionIndex and
// returns the id of the sent request.
//
// While a CDN dc is set the part is requested from it with
// upload.getCdnFile, otherwise the request type depends on the kind of
// the location held by the task.
mtpRequestId DownloadMtprotoTask::sendRequest(
const RequestData &requestData) {
const auto offset = requestData.offset;
const auto limit = Storage::kDownloadPartSize;
// The CDN dc, when set, replaces the dc of the task.
const auto shiftedDcId = MTP::downloadDcId(
_cdnDcId ? _cdnDcId : dcId(),
requestData.sessionIndex);
if (_cdnDcId) {
return api().request(MTPupload_GetCdnFile(
MTP_bytes(_cdnToken),
MTP_long(offset),
MTP_int(limit)
)).done([=](const MTPupload_CdnFile &result, mtpRequestId id) {
cdnPartLoaded(result, id);
}).fail([=](const MTP::Error &error, mtpRequestId id) {
cdnPartFailed(error, id);
}).toDC(shiftedDcId).send();
}
// Web file, geo point and audio album thumbnail locations all use
// upload.getWebFile with a different input location.
return v::match(_location.data, [&](const WebFileLocation &location) {
return api().request(MTPupload_GetWebFile(
MTP_inputWebFileLocation(
MTP_bytes(location.url()),
MTP_long(location.accessHash())),
MTP_int(offset),
MTP_int(limit)
)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
webPartLoaded(result, id);
}).fail([=](const MTP::Error &error, mtpRequestId id) {
partFailed(error, id);
}).toDC(shiftedDcId).send();
}, [&](const GeoPointLocation &location) {
return api().request(MTPupload_GetWebFile(
MTP_inputWebFileGeoPointLocation(
MTP_inputGeoPoint(
MTP_flags(0),
MTP_double(location.lat),
MTP_double(location.lon),
MTP_int(0)), // accuracy_radius
MTP_long(location.access),
MTP_int(location.width),
MTP_int(location.height),
MTP_int(location.zoom),
MTP_int(location.scale)),
MTP_int(offset),
MTP_int(limit)
)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
webPartLoaded(result, id);
}).fail([=](const MTP::Error &error, mtpRequestId id) {
partFailed(error, id);
}).toDC(shiftedDcId).send();
}, [&](const AudioAlbumThumbLocation &location) {
using Flag = MTPDinputWebFileAudioAlbumThumbLocation::Flag;
const auto owner = &api().session().data();
return api().request(MTPupload_GetWebFile(
MTP_inputWebFileAudioAlbumThumbLocation(
MTP_flags(Flag::f_document | Flag::f_small),
owner->document(location.documentId)->mtpInput(),
MTPstring(),
MTPstring()),
MTP_int(offset),
MTP_int(limit)
)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
webPartLoaded(result, id);
}).fail([=](const MTP::Error &error, mtpRequestId id) {
partFailed(error, id);
}).toDC(shiftedDcId).send();
}, [&](const StorageFileLocation &location) {
// The file reference used for this request is copied into the fail
// handler, normalPartFailed() passes it to the reference refresh.
const auto reference = location.fileReference();
return api().request(MTPupload_GetFile(
MTP_flags(MTPupload_GetFile::Flag::f_cdn_supported),
location.tl(api().session().userId()),
MTP_long(offset),
MTP_int(limit)
)).done([=](const MTPupload_File &result, mtpRequestId id) {
normalPartLoaded(result, id);
}).fail([=](const MTP::Error &error, mtpRequestId id) {
normalPartFailed(reference, error, id);
}).toDC(shiftedDcId).send();
});
}
// Called from webPartLoaded() with the size reported for a web file;
// the loaded part is processed only if this returns true.
// This implementation ignores the size and always returns true.
bool DownloadMtprotoTask::setWebFileSizeHook(int64 size) {
return true;
}
// Sends the request for the given part and registers it as sent.
void DownloadMtprotoTask::makeRequest(const RequestData &requestData) {
	const auto requestId = sendRequest(requestData);
	placeSentRequest(requestId, requestData);
}
// Requests CDN file hashes starting from the offset of the first
// unchecked part, unless such a request is already in flight or there
// are no unchecked parts. The request goes to the dc of the task (not
// to the CDN dc) through the session of that part.
void DownloadMtprotoTask::requestMoreCdnFileHashes() {
	if (_cdnHashesRequestId) {
		return;
	} else if (_cdnUncheckedParts.empty()) {
		return;
	}
	const auto firstUnchecked = _cdnUncheckedParts.cbegin()->first;
	const auto targetDcId = MTP::downloadDcId(
		dcId(),
		firstUnchecked.sessionIndex);
	_cdnHashesRequestId = api().request(MTPupload_GetCdnFileHashes(
		MTP_bytes(_cdnToken),
		MTP_long(firstUnchecked.offset)
	)).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
		getCdnFileHashesDone(result, id);
	}).fail([=](const MTP::Error &error, mtpRequestId id) {
		cdnPartFailed(error, id);
	}).toDC(targetDcId).send();
	placeSentRequest(_cdnHashesRequestId, firstUnchecked);
}
void DownloadMtprotoTask::normalPartLoaded(
const MTPupload_File &result,
mtpRequestId requestId) {
const auto requestData = finishSentRequest(
requestId,
FinishRequestReason::Success);
const auto owner = _owner;
const auto dcId = this->dcId();
result.match([&](const MTPDupload_fileCdnRedirect &data) {
switchToCDN(requestData, data);
}, [&](const MTPDupload_file &data) {
partLoaded(requestData.offset, data.vbytes().v);
});
// 'this' may be deleted at this point.
owner->checkSendNextAfterSuccess(dcId);
}
void DownloadMtprotoTask::webPartLoaded(
const MTPupload_WebFile &result,
mtpRequestId requestId) {
const auto requestData = finishSentRequest(
requestId,
FinishRequestReason::Success);
const auto owner = _owner;
const auto dcId = this->dcId();
result.match([&](const MTPDupload_webFile &data) {
if (setWebFileSizeHook(data.vsize().v)) {
partLoaded(requestData.offset, data.vbytes().v);
}
});
// 'this' may be deleted at this point.
owner->checkSendNextAfterSuccess(dcId);
}
// Handles the result of upload.getCdnFile.
//
// If the CDN asks for a reupload, upload.reuploadCdnFile is sent to the
// dc of the task in place of the finished request. Otherwise the received
// bytes are processed with MTP::aesCtrEncrypt() using the CDN key / iv
// and checked against the known CDN file hashes.
void DownloadMtprotoTask::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId) {
result.match([&](const MTPDupload_cdnFileReuploadNeeded &data) {
const auto requestData = finishSentRequest(
requestId,
FinishRequestReason::Redirect);
const auto shiftedDcId = MTP::downloadDcId(
dcId(),
requestData.sessionIndex);
// NOTE(review): from here on 'requestId' names the new reupload request
// and shadows the id of the finished one.
const auto requestId = api().request(MTPupload_ReuploadCdnFile(
MTP_bytes(_cdnToken),
data.vrequest_token()
)).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
reuploadDone(result, id);
}).fail([=](const MTP::Error &error, mtpRequestId id) {
cdnPartFailed(error, id);
}).toDC(shiftedDcId).send();
placeSentRequest(requestId, requestData);
}, [&](const MTPDupload_cdnFile &data) {
const auto requestData = finishSentRequest(
requestId,
FinishRequestReason::Success);
// The guard runs on every way out of this handler, it uses only the
// copies made here because the task may be destroyed by then.
const auto owner = _owner;
const auto dcId = this->dcId();
const auto guard = gsl::finally([=] {
// 'this' may be deleted at this point.
owner->checkSendNextAfterSuccess(dcId);
});
auto key = bytes::make_span(_cdnEncryptionKey);
auto iv = bytes::make_span(_cdnEncryptionIV);
Expects(key.size() == MTP::CTRState::KeySize);
Expects(iv.size() == MTP::CTRState::IvecSize);
auto state = MTP::CTRState();
auto ivec = bytes::make_span(state.ivec);
std::copy(iv.begin(), iv.end(), ivec.begin());
// The last four bytes of the ivec are replaced with (offset / 16),
// written as a big-endian 32-bit value.
auto counterOffset = static_cast<uint32>(requestData.offset >> 4);
state.ivec[15] = static_cast<uchar>(counterOffset & 0xFF);
state.ivec[14] = static_cast<uchar>((counterOffset >> 8) & 0xFF);
state.ivec[13] = static_cast<uchar>((counterOffset >> 16) & 0xFF);
state.ivec[12] = static_cast<uchar>((counterOffset >> 24) & 0xFF);
auto decryptInPlace = data.vbytes().v;
auto buffer = bytes::make_detached_span(decryptInPlace);
MTP::aesCtrEncrypt(buffer, key.data(), &state);
switch (checkCdnFileHash(requestData.offset, buffer)) {
case CheckCdnHashResult::NoHash: {
// No hash for this offset yet: keep the part and request more hashes.
_cdnUncheckedParts.emplace(requestData, decryptInPlace);
requestMoreCdnFileHashes();
} return;
case CheckCdnHashResult::Invalid: {
LOG(("API Error: Wrong cdnFileHash for offset %1."
).arg(requestData.offset));
cancelOnFail();
} return;
case CheckCdnHashResult::Good: {
partLoaded(requestData.offset, decryptInPlace);
} return;
}
Unexpected("Result of checkCdnFileHash()");
});
}
// Compares the SHA-256 of the buffer with the CDN file hash stored for
// the offset. Returns NoHash when no hash is stored for that offset,
// Invalid when the hashes differ and Good when they match.
DownloadMtprotoTask::CheckCdnHashResult DownloadMtprotoTask::checkCdnFileHash(
		int64 offset,
		bytes::const_span buffer) {
	const auto known = _cdnFileHashes.find(offset);
	if (known == _cdnFileHashes.cend()) {
		return CheckCdnHashResult::NoHash;
	}
	const auto computed = openssl::Sha256(buffer);
	const auto expected = bytes::make_span(known->second.hash);
	return bytes::compare(computed, expected)
		? CheckCdnHashResult::Invalid
		: CheckCdnHashResult::Good;
}
void DownloadMtprotoTask::reuploadDone(
const MTPVector<MTPFileHash> &result,
mtpRequestId requestId) {
const auto requestData = finishSentRequest(
requestId,
FinishRequestReason::Redirect);
addCdnHashes(result.v);
makeRequest(requestData);
}
// Handles the result of upload.getCdnFileHashes: stores the received
// hashes and re-checks all the parts that were waiting for a hash.
// Parts with a good hash are fed to the loader, a wrong hash cancels the
// download. If no part could be checked after this request the download
// is cancelled as well, otherwise more hashes are requested if needed.
void DownloadMtprotoTask::getCdnFileHashesDone(
const MTPVector<MTPFileHash> &result,
mtpRequestId requestId) {
Expects(_cdnHashesRequestId == requestId);
const auto requestData = finishSentRequest(
requestId,
FinishRequestReason::Redirect);
addCdnHashes(result.v);
auto someMoreChecked = false;
for (auto i = _cdnUncheckedParts.begin(); i != _cdnUncheckedParts.cend();) {
const auto uncheckedData = i->first;
const auto uncheckedBytes = bytes::make_span(i->second);
switch (checkCdnFileHash(uncheckedData.offset, uncheckedBytes)) {
case CheckCdnHashResult::NoHash: {
// Still no hash for this part, leave it waiting.
++i;
} break;
case CheckCdnHashResult::Invalid: {
LOG(("API Error: Wrong cdnFileHash for offset %1."
).arg(uncheckedData.offset));
cancelOnFail();
return;
} break;
case CheckCdnHashResult::Good: {
someMoreChecked = true;
// Take the data out and erase the entry before feeding the part,
// and stop if feedPart() fails or the task gets destroyed by it.
const auto goodOffset = uncheckedData.offset;
const auto goodBytes = std::move(i->second);
const auto weak = base::make_weak(this);
i = _cdnUncheckedParts.erase(i);
if (!feedPart(goodOffset, goodBytes) || !weak) {
return;
}
} break;
default: Unexpected("Result of checkCdnFileHash()");
}
}
if (!someMoreChecked) {
LOG(("API Error: "
"Could not find cdnFileHash for offset %1 "
"after getCdnFileHashes request."
).arg(requestData.offset));
cancelOnFail();
return;
}
requestMoreCdnFileHashes();
}
// Registers a just sent request: adds one part to the requested amount
// of its session, stores the request by id and by offset together with
// the amount requested in the session and the send time. Subscribes to
// the non-premium limit notifications before the first sent request.
void DownloadMtprotoTask::placeSentRequest(
		mtpRequestId requestId,
		const RequestData &requestData) {
	const auto firstRequest = _sentRequests.empty();
	if (firstRequest) {
		subscribeToNonPremiumLimit();
	}
	const auto requestedNow = _owner->changeRequestedAmount(
		dcId(),
		requestData.sessionIndex,
		Storage::kDownloadPartSize);

	const auto &[sent, sentAdded] = _sentRequests.emplace(
		requestId,
		requestData);
	const auto &[byOffset, byOffsetAdded] = _requestByOffset.emplace(
		requestData.offset,
		requestId);

	auto &stored = sent->second;
	stored.requestedInSession = requestedNow;
	stored.sent = crl::now();

	Ensures(sentAdded && byOffsetAdded);
}
// Subscribes (once) to the requests delayed by the non-premium limit.
// When one of the requests of this task is delayed and the task loads a
// storage file location of the Document type with a non-zero object id,
// the owner is notified about the delay for that document.
void DownloadMtprotoTask::subscribeToNonPremiumLimit() {
	if (_nonPremiumLimitSubscription) {
		return;
	}
	_owner->api().instance().nonPremiumDelayedRequests(
	) | rpl::start_with_next([=](mtpRequestId id) {
		if (!_sentRequests.contains(id)) {
			return;
		}
		const auto documentId = objectId();
		if (!documentId) {
			return;
		}
		const auto type = v::get<StorageFileLocation>(
			_location.data).type();
		if (type != StorageFileLocation::Type::Document) {
			return;
		}
		_owner->notifyNonPremiumDelay(documentId);
	}, _nonPremiumLimitSubscription);
}
// Unregisters a sent request and returns its data: takes one part back
// from the requested amount of its session, removes the request from the
// by-id and by-offset maps and drops the non-premium subscription when no
// sent requests are left. For the Success reason the owner is also told
// that the request succeeded.
auto DownloadMtprotoTask::finishSentRequest(
		mtpRequestId requestId,
		FinishRequestReason reason)
-> RequestData {
	auto found = _sentRequests.find(requestId);
	Assert(found != _sentRequests.cend());

	if (requestId == _cdnHashesRequestId) {
		_cdnHashesRequestId = 0;
	}
	const auto finished = found->second;
	_owner->changeRequestedAmount(
		dcId(),
		finished.sessionIndex,
		-Storage::kDownloadPartSize);
	_sentRequests.erase(found);
	const auto removedByOffset = _requestByOffset.remove(finished.offset);

	if (_sentRequests.empty()) {
		_nonPremiumLimitSubscription.destroy();
	}
	const auto succeeded = (reason == FinishRequestReason::Success);
	if (succeeded) {
		_owner->requestSucceeded(
			dcId(),
			finished.sessionIndex,
			finished.requestedInSession,
			finished.sent);
	}

	Ensures(removedByOffset);
	return finished;
}
// Returns true while the task has sent requests or loaded CDN parts
// that still wait for their hash check.
bool DownloadMtprotoTask::haveSentRequests() const {
	return !(_sentRequests.empty() && _cdnUncheckedParts.empty());
}
// Returns true if there is a sent request or an unchecked CDN part
// for the given offset.
bool DownloadMtprotoTask::haveSentRequestForOffset(int64 offset) const {
	if (_requestByOffset.contains(offset)) {
		return true;
	}
	return _cdnUncheckedParts.contains({ offset, 0 });
}
// Cancels every sent request and drops all unchecked CDN parts.
void DownloadMtprotoTask::cancelAllRequests() {
	while (!_sentRequests.empty()) {
		const auto firstId = _sentRequests.begin()->first;
		cancelRequest(firstId);
	}
	_cdnUncheckedParts.clear();
}
// Cancels the sent request for the given offset, if there is one, and
// drops the unchecked CDN part for that offset.
void DownloadMtprotoTask::cancelRequestForOffset(int64 offset) {
	if (const auto sent = _requestByOffset.find(offset)
		; sent != end(_requestByOffset)) {
		cancelRequest(sent->second);
	}
	_cdnUncheckedParts.remove({ offset, 0 });
}
// Cancels the MTP request and unregisters it. If that was the CDN hashes
// request while some parts still wait for their hashes, a new hashes
// request is scheduled on the main thread.
void DownloadMtprotoTask::cancelRequest(mtpRequestId requestId) {
	// Must be checked first: finishSentRequest() resets the stored id.
	const auto wasHashesRequest = (requestId == _cdnHashesRequestId);
	api().request(requestId).cancel();
	[[maybe_unused]] const auto finished = finishSentRequest(
		requestId,
		FinishRequestReason::Cancel);
	if (!wasHashesRequest || _cdnUncheckedParts.empty()) {
		return;
	}
	crl::on_main(this, [=] {
		requestMoreCdnFileHashes();
	});
}
// Enqueues this task in the owner with the given priority.
void DownloadMtprotoTask::addToQueue(int priority) {
_owner->enqueue(this, priority);
}
// Removes this task from the owner's queue.
void DownloadMtprotoTask::removeFromQueue() {
_owner->remove(this);
}
// Passes a loaded part to feedPart(); its result is ignored here.
void DownloadMtprotoTask::partLoaded(
int64 offset,
const QByteArray &bytes) {
feedPart(offset, bytes);
}
// Fail handler for upload.getFile. Returns false for the errors that
// MTP::IsDefaultHandledError() reports, true otherwise. A 400 error
// whose type starts with FILE_REFERENCE_ starts a file reference refresh
// for this request, any other error goes to partFailed().
bool DownloadMtprotoTask::normalPartFailed(
		QByteArray fileReference,
		const MTP::Error &error,
		mtpRequestId requestId) {
	if (MTP::IsDefaultHandledError(error)) {
		return false;
	}
	const auto referenceProblem = (error.code() == 400)
		&& error.type().startsWith(u"FILE_REFERENCE_"_q);
	if (!referenceProblem) {
		return partFailed(error, requestId);
	}
	api().refreshFileReference(
		_origin,
		this,
		requestId,
		fileReference);
	return true;
}
// Generic fail handler: returns false for the errors that
// MTP::IsDefaultHandledError() reports, otherwise cancels the download
// and returns true.
bool DownloadMtprotoTask::partFailed(
		const MTP::Error &error,
		mtpRequestId requestId) {
	const auto handleHere = !MTP::IsDefaultHandledError(error);
	if (handleHere) {
		cancelOnFail();
	}
	return handleHere;
}
// Fail handler for the CDN requests. Returns false for the errors that
// MTP::IsDefaultHandledError() reports. On FILE_TOKEN_INVALID or
// REQUEST_TOKEN_INVALID the CDN parameters are cleared (dc id 0, empty
// token / key / iv) and the part is requested again, any other error
// goes to partFailed().
bool DownloadMtprotoTask::cdnPartFailed(
		const MTP::Error &error,
		mtpRequestId requestId) {
	if (MTP::IsDefaultHandledError(error)) {
		return false;
	}
	const auto tokenInvalid = (error.type() == u"FILE_TOKEN_INVALID"_q)
		|| (error.type() == u"REQUEST_TOKEN_INVALID"_q);
	if (!tokenInvalid) {
		return partFailed(error, requestId);
	}
	const auto retry = finishSentRequest(
		requestId,
		FinishRequestReason::Redirect);
	changeCDNParams(
		retry,
		0,
		QByteArray(),
		QByteArray(),
		QByteArray(),
		QVector<MTPFileHash>());
	return true;
}
// Applies the CDN parameters received in an upload.fileCdnRedirect and
// requests the given part again.
void DownloadMtprotoTask::switchToCDN(
		const RequestData &requestData,
		const MTPDupload_fileCdnRedirect &redirect) {
	const auto &token = redirect.vfile_token().v;
	const auto &key = redirect.vencryption_key().v;
	const auto &iv = redirect.vencryption_iv().v;
	const auto &hashes = redirect.vfile_hashes().v;
	changeCDNParams(
		requestData,
		redirect.vdc_id().v,
		token,
		key,
		iv,
		hashes);
}
void DownloadMtprotoTask::addCdnHashes(
const QVector<MTPFileHash> &hashes) {
for (const auto &hash : hashes) {
hash.match([&](const MTPDfileHash &data) {
_cdnFileHashes.emplace(
data.voffset().v,
CdnFileHash{ data.vlimit().v, data.vhash().v });
});
}
}
// Replaces the CDN parameters of the task and requests the given part.
// For a non-zero CDN dc id the key and iv sizes are validated first, a
// wrong size cancels the download. If any of the parameters changed, all
// the requests sent so far are cancelled and sent again.
void DownloadMtprotoTask::changeCDNParams(
		const RequestData &requestData,
		MTP::DcId dcId,
		const QByteArray &token,
		const QByteArray &encryptionKey,
		const QByteArray &encryptionIV,
		const QVector<MTPFileHash> &hashes) {
	const auto wrongSizes
		= (encryptionKey.size() != MTP::CTRState::KeySize)
		|| (encryptionIV.size() != MTP::CTRState::IvecSize);
	if (dcId != 0 && wrongSizes) {
		LOG(("Message Error: Wrong key (%1) / iv (%2) size in CDN params"
			).arg(encryptionKey.size()
			).arg(encryptionIV.size()));
		cancelOnFail();
		return;
	}

	// Compare with the old values before they are overwritten below.
	const auto paramsChanged = (_cdnDcId != dcId)
		|| (_cdnToken != token)
		|| (_cdnEncryptionKey != encryptionKey)
		|| (_cdnEncryptionIV != encryptionIV);
	_cdnDcId = dcId;
	_cdnToken = token;
	_cdnEncryptionKey = encryptionKey;
	_cdnEncryptionIV = encryptionIV;
	addCdnHashes(hashes);

	if (paramsChanged && !_sentRequests.empty()) {
		// Finish all of them first, then send everything again.
		auto toResend = std::vector<RequestData>();
		toResend.reserve(_sentRequests.size());
		while (!_sentRequests.empty()) {
			const auto sentId = _sentRequests.begin()->first;
			api().request(sentId).cancel();
			toResend.push_back(finishSentRequest(
				sentId,
				FinishRequestReason::Redirect));
		}
		for (const auto &resend : toResend) {
			makeRequest(resend);
		}
	}
	makeRequest(requestData);
}
} // namespace Storage