All mtproto downloads using DownloadMtprotoTask.

This commit is contained in:
John Preston 2019-12-05 11:32:33 +03:00
parent 4611727ab9
commit ee94e78533
41 changed files with 1081 additions and 1010 deletions

View File

@ -55,7 +55,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "ui/emoji_config.h"
#include "support/support_helper.h"
#include "storage/localimageloader.h"
#include "storage/file_download_mtproto.h"
#include "storage/download_manager_mtproto.h"
#include "storage/file_upload.h"
#include "storage/storage_facade.h"
#include "storage/storage_shared_media.h"
@ -2973,12 +2973,12 @@ void ApiWrap::requestFileReference(
void ApiWrap::refreshFileReference(
Data::FileOrigin origin,
not_null<mtpFileLoader*> loader,
not_null<Storage::DownloadMtprotoTask*> task,
int requestId,
const QByteArray &current) {
return refreshFileReference(origin, crl::guard(loader, [=](
return refreshFileReference(origin, crl::guard(task, [=](
const UpdatedFileReferences &data) {
loader->refreshFileReferenceFrom(data, requestId, current);
task->refreshFileReferenceFrom(data, requestId, current);
}));
}

View File

@ -20,7 +20,6 @@ struct MessageGroupId;
struct SendingAlbum;
enum class SendMediaType;
struct FileLoadTo;
class mtpFileLoader;
namespace Main {
class Session;
@ -38,6 +37,7 @@ class Result;
namespace Storage {
enum class SharedMediaType : signed char;
struct PreparedList;
class DownloadMtprotoTask;
} // namespace Storage
namespace Dialogs {
@ -201,7 +201,7 @@ public:
FileReferencesHandler &&handler);
void refreshFileReference(
Data::FileOrigin origin,
not_null<mtpFileLoader*> loader,
not_null<Storage::DownloadMtprotoTask*> task,
int requestId,
const QByteArray &current);

View File

@ -30,6 +30,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_channel.h"
#include "data/data_chat.h"
#include "data/data_user.h"
#include "data/data_file_origin.h"
#include "base/unixtime.h"
#include "main/main_session.h"
#include "observer_peer.h"

View File

@ -14,6 +14,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "base/call_delayed.h"
#include "core/application.h"
#include "main/main_account.h"
#include "mtproto/facade.h"
#include "ui/widgets/checkbox.h"
#include "ui/widgets/buttons.h"
#include "ui/widgets/input_fields.h"

View File

@ -9,6 +9,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_document.h"
#include "data/data_session.h"
#include "data/data_file_origin.h"
#include "lang/lang_keys.h"
#include "chat_helpers/stickers.h"
#include "boxes/confirm_box.h"

View File

@ -10,6 +10,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_document.h"
#include "data/data_session.h"
#include "data/data_channel.h"
#include "data/data_file_origin.h"
#include "core/application.h"
#include "lang/lang_keys.h"
#include "mainwidget.h"

View File

@ -12,6 +12,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_chat.h"
#include "data/data_user.h"
#include "data/data_peer_values.h"
#include "data/data_file_origin.h"
#include "mainwindow.h"
#include "apiwrap.h"
#include "storage/localstorage.h"

View File

@ -11,7 +11,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_document.h"
#include "data/data_session.h"
#include "data/data_user.h"
#include "styles/style_chat_helpers.h"
#include "data/data_file_origin.h"
#include "ui/widgets/buttons.h"
#include "ui/widgets/input_fields.h"
#include "ui/effects/ripple_animation.h"
@ -27,6 +27,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "history/view/history_view_cursor_state.h"
#include "facades.h"
#include "app.h"
#include "styles/style_chat_helpers.h"
#include <QtWidgets/QApplication>

View File

@ -9,6 +9,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_document.h"
#include "data/data_session.h"
#include "data/data_file_origin.h"
#include "boxes/stickers_box.h"
#include "boxes/confirm_box.h"
#include "lang/lang_keys.h"

View File

@ -10,6 +10,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_document.h"
#include "data/data_session.h"
#include "data/data_channel.h"
#include "data/data_file_origin.h"
#include "ui/widgets/buttons.h"
#include "ui/effects/animations.h"
#include "ui/effects/ripple_animation.h"

View File

@ -341,6 +341,7 @@ QString PlatformString() {
}
void StartCatching(not_null<Core::Launcher*> launcher) {
return; AssertIsDebug();
#ifndef DESKTOP_APP_DISABLE_CRASH_REPORTS
ProcessAnnotations["Binary"] = cExeName().toUtf8().constData();
ProcessAnnotations["ApiId"] = QString::number(ApiId).toUtf8().constData();

View File

@ -817,7 +817,6 @@ void DocumentData::destroyLoader() const {
if (cancelled()) {
loader->cancel();
}
loader->stop();
}
bool DocumentData::loading() const {

View File

@ -37,6 +37,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_web_page.h"
#include "data/data_poll.h"
#include "data/data_channel.h"
#include "data/data_file_origin.h"
#include "lang/lang_keys.h"
#include "layout.h"
#include "storage/file_upload.h"

View File

@ -40,6 +40,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_chat.h"
#include "data/data_user.h"
#include "data/data_scheduled_messages.h"
#include "data/data_file_origin.h"
#include "history/history.h"
#include "history/history_item.h"
#include "history/history_message.h"

View File

@ -21,6 +21,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_session.h"
#include "data/data_document.h"
#include "data/data_media_types.h"
#include "data/data_file_origin.h"
#include "app.h"
#include "styles/style_history.h"

View File

@ -13,13 +13,13 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_photo.h"
#include "data/data_document.h"
#include "data/data_session.h"
#include "data/data_file_origin.h"
#include "history/history_item.h"
#include "history/history.h"
#include "history/view/history_view_cursor_state.h"
#include "window/themes/window_theme.h"
#include "window/window_session_controller.h"
#include "window/window_peer_menu.h"
#include "storage/file_download.h"
#include "ui/widgets/popup_menu.h"
#include "ui/ui_utility.h"
#include "ui/inactive_press.h"
@ -574,7 +574,7 @@ void ListWidget::start() {
}
}, lifetime());
ObservableViewer(
session().downloader().taskFinished()
session().downloaderTaskFinished()
) | rpl::start_with_next([this] { update(); }, lifetime());
session().data().itemLayoutChanged(
) | rpl::start_with_next([this](auto item) {

View File

@ -10,6 +10,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_photo.h"
#include "data/data_document.h"
#include "data/data_session.h"
#include "data/data_file_origin.h"
#include "styles/style_overview.h"
#include "styles/style_history.h"
#include "styles/style_chat_helpers.h"

View File

@ -10,6 +10,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_photo.h"
#include "data/data_document.h"
#include "data/data_peer.h"
#include "data/data_file_origin.h"
#include "core/click_handler_types.h"
#include "inline_bots/inline_bot_result.h"
#include "inline_bots/inline_bot_layout_internal.h"

View File

@ -11,6 +11,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_photo.h"
#include "data/data_document.h"
#include "data/data_session.h"
#include "data/data_file_origin.h"
#include "inline_bots/inline_bot_layout_item.h"
#include "inline_bots/inline_bot_send_data.h"
#include "storage/file_download.h"

View File

@ -12,7 +12,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_channel.h"
#include "data/data_user.h"
#include "data/data_session.h"
#include "styles/style_chat_helpers.h"
#include "data/data_file_origin.h"
#include "ui/widgets/buttons.h"
#include "ui/widgets/shadow.h"
#include "ui/effects/ripple_animation.h"
@ -34,6 +34,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "history/view/history_view_cursor_state.h"
#include "facades.h"
#include "app.h"
#include "styles/style_chat_helpers.h"
#include <QtWidgets/QApplication>

View File

@ -13,6 +13,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "main/main_account.h"
#include "chat_helpers/stickers_emoji_pack.h"
#include "storage/file_download.h"
#include "storage/download_manager_mtproto.h"
#include "storage/file_upload.h"
#include "storage/localstorage.h"
#include "storage/storage_facade.h"
@ -44,7 +45,7 @@ Session::Session(
, _autoLockTimer([=] { checkAutoLock(); })
, _api(std::make_unique<ApiWrap>(this))
, _calls(std::make_unique<Calls::Instance>(this))
, _downloader(std::make_unique<Storage::DownloadManager>(_api.get()))
, _downloader(std::make_unique<Storage::DownloadManagerMtproto>(_api.get()))
, _uploader(std::make_unique<Storage::Uploader>(_api.get()))
, _storage(std::make_unique<Storage::Facade>())
, _notifications(std::make_unique<Window::Notifications::System>(this))

View File

@ -29,7 +29,7 @@ class Session;
} // namespace Data
namespace Storage {
class DownloadManager;
class DownloadManagerMtproto;
class Uploader;
class Facade;
} // namespace Storage
@ -80,7 +80,7 @@ public:
}
bool validateSelf(const MTPUser &user);
[[nodiscard]] Storage::DownloadManager &downloader() {
[[nodiscard]] Storage::DownloadManagerMtproto &downloader() {
return *_downloader;
}
[[nodiscard]] Storage::Uploader &uploader() {
@ -145,7 +145,7 @@ private:
const std::unique_ptr<ApiWrap> _api;
const std::unique_ptr<Calls::Instance> _calls;
const std::unique_ptr<Storage::DownloadManager> _downloader;
const std::unique_ptr<Storage::DownloadManagerMtproto> _downloader;
const std::unique_ptr<Storage::Uploader> _uploader;
const std::unique_ptr<Storage::Facade> _storage;
const std::unique_ptr<Window::Notifications::System> _notifications;

View File

@ -23,6 +23,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_chat.h"
#include "data/data_user.h"
#include "data/data_scheduled_messages.h"
#include "data/data_file_origin.h"
#include "api/api_text_entities.h"
#include "ui/special_buttons.h"
#include "ui/widgets/buttons.h"

View File

@ -21,27 +21,17 @@ constexpr auto kMaxConcurrentRequests = 4;
} // namespace
LoaderMtproto::LoaderMtproto(
not_null<Storage::DownloadManager*> owner,
not_null<Storage::DownloadManagerMtproto*> owner,
const StorageFileLocation &location,
int size,
Data::FileOrigin origin)
: _owner(owner)
, _location(location)
, _dcId(location.dcId())
: DownloadMtprotoTask(owner, location, origin)
, _size(size)
, _origin(origin)
, _api(_owner->api().instance()) {
}
LoaderMtproto::~LoaderMtproto() {
for (const auto [index, amount] : _amountByDcIndex) {
changeRequestedAmount(index, -amount);
}
_owner->remove(this);
, _api(api().instance()) {
}
std::optional<Storage::Cache::Key> LoaderMtproto::baseCacheKey() const {
return _location.bigFileBaseCacheKey();
return location().data.get<StorageFileLocation>().bigFileBaseCacheKey();
}
int LoaderMtproto::size() const {
@ -58,22 +48,19 @@ void LoaderMtproto::load(int offset) {
return;
}
}
if (_requests.contains(offset)) {
if (haveSentRequestForOffset(offset)) {
return;
} else if (_requested.add(offset)) {
_owner->enqueue(this); // #TODO download priority
addToQueue(); // #TODO download priority
}
});
}
void LoaderMtproto::stop() {
crl::on_main(this, [=] {
ranges::for_each(
base::take(_requests),
_api.requestCanceller(),
&base::flat_map<int, mtpRequestId>::value_type::second);
cancelAllRequests();
_requested.clear();
_owner->remove(this);
removeFromQueue();
});
}
@ -84,9 +71,9 @@ void LoaderMtproto::cancel(int offset) {
}
void LoaderMtproto::cancelForOffset(int offset) {
if (const auto requestId = _requests.take(offset)) {
_api.request(*requestId).cancel();
_owner->enqueue(this);
if (haveSentRequestForOffset(offset)) {
cancelRequestForOffset(offset);
addToQueue(); // #TODO download priority
} else {
_requested.remove(offset);
}
@ -107,100 +94,26 @@ void LoaderMtproto::increasePriority() {
});
}
void LoaderMtproto::changeRequestedAmount(int index, int amount) {
_owner->requestedAmountIncrement(_dcId, index, amount);
_amountByDcIndex[index] += amount;
}
MTP::DcId LoaderMtproto::dcId() const {
return _dcId;
}
bool LoaderMtproto::readyToRequest() const {
return !_requested.empty();
}
void LoaderMtproto::loadPart(int dcIndex) {
const auto offset = _requested.take().value_or(-1);
if (offset < 0) {
return;
}
int LoaderMtproto::takeNextRequestOffset() {
const auto offset = _requested.take();
changeRequestedAmount(dcIndex, kPartSize);
const auto usedFileReference = _location.fileReference();
const auto id = _api.request(MTPupload_GetFile(
MTP_flags(0),
_location.tl(Auth().userId()),
MTP_int(offset),
MTP_int(kPartSize)
)).done([=](const MTPupload_File &result) {
changeRequestedAmount(dcIndex, -kPartSize);
requestDone(offset, result);
}).fail([=](const RPCError &error) {
changeRequestedAmount(dcIndex, -kPartSize);
requestFailed(offset, error, usedFileReference);
}).toDC(
MTP::downloadDcId(_dcId, dcIndex)
).send();
_requests.emplace(offset, id);
Ensures(offset.has_value());
return *offset;
}
void LoaderMtproto::requestDone(int offset, const MTPupload_File &result) {
result.match([&](const MTPDupload_file &data) {
_requests.erase(offset);
_owner->enqueue(this);
_parts.fire({ offset, data.vbytes().v });
}, [&](const MTPDupload_fileCdnRedirect &data) {
changeCdnParams(
offset,
data.vdc_id().v,
data.vfile_token().v,
data.vencryption_key().v,
data.vencryption_iv().v,
data.vfile_hashes().v);
});
bool LoaderMtproto::feedPart(int offset, const QByteArray &bytes) {
_parts.fire({ offset, bytes });
return true;
}
void LoaderMtproto::changeCdnParams(
int offset,
MTP::DcId dcId,
const QByteArray &token,
const QByteArray &encryptionKey,
const QByteArray &encryptionIV,
const QVector<MTPFileHash> &hashes) {
// #TODO streaming later cdn
void LoaderMtproto::cancelOnFail() {
_parts.fire({ LoadedPart::kFailedOffset });
}
void LoaderMtproto::requestFailed(
int offset,
const RPCError &error,
const QByteArray &usedFileReference) {
const auto &type = error.type();
const auto fail = [=] {
_parts.fire({ LoadedPart::kFailedOffset });
};
if (error.code() != 400 || !type.startsWith(qstr("FILE_REFERENCE_"))) {
return fail();
}
const auto callback = [=](const Data::UpdatedFileReferences &updated) {
_location.refreshFileReference(updated);
if (_location.fileReference() == usedFileReference) {
fail();
} else if (!_requests.take(offset)) {
// Request with such offset was already cancelled.
return;
} else {
_requested.add(offset);
_owner->enqueue(this);
}
};
_owner->api().refreshFileReference(
_origin,
crl::guard(this, callback));
}
rpl::producer<LoadedPart> LoaderMtproto::parts() const {
return _parts.events();
}

View File

@ -10,22 +10,18 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "media/streaming/media_streaming_loader.h"
#include "mtproto/sender.h"
#include "data/data_file_origin.h"
#include "storage/file_download.h"
#include "storage/download_manager_mtproto.h"
namespace Media {
namespace Streaming {
class LoaderMtproto
: public Loader
, public base::has_weak_ptr
, public Storage::Downloader {
class LoaderMtproto : public Loader, public Storage::DownloadMtprotoTask {
public:
LoaderMtproto(
not_null<Storage::DownloadManager*> owner,
not_null<Storage::DownloadManagerMtproto*> owner,
const StorageFileLocation &location,
int size,
Data::FileOrigin origin);
~LoaderMtproto();
[[nodiscard]] auto baseCacheKey() const
-> std::optional<Storage::Cache::Key> override;
@ -44,39 +40,18 @@ public:
void clearAttachedDownloader() override;
private:
MTP::DcId dcId() const override;
bool readyToRequest() const override;
void loadPart(int dcIndex) override;
int takeNextRequestOffset() override;
bool feedPart(int offset, const QByteArray &bytes) override;
void cancelOnFail() override;
void requestDone(int offset, const MTPupload_File &result);
void requestFailed(
int offset,
const RPCError &error,
const QByteArray &usedFileReference);
void changeCdnParams(
int offset,
MTP::DcId dcId,
const QByteArray &token,
const QByteArray &encryptionKey,
const QByteArray &encryptionIV,
const QVector<MTPFileHash> &hashes);
void cancelForOffset(int offset);
void changeRequestedAmount(int index, int amount);
const not_null<Storage::DownloadManager*> _owner;
// _location can be changed with an updated file_reference.
StorageFileLocation _location;
MTP::DcId _dcId = 0;
const int _size = 0;
const Data::FileOrigin _origin;
MTP::Sender _api;
PriorityQueue _requested;
base::flat_map<int, mtpRequestId> _requests;
base::flat_map<int, int> _amountByDcIndex;
rpl::event_stream<LoadedPart> _parts;
Storage::StreamedFileDownloader *_downloader = nullptr;

View File

@ -36,6 +36,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_channel.h"
#include "data/data_chat.h"
#include "data/data_user.h"
#include "data/data_file_origin.h"
#include "window/themes/window_theme_preview.h"
#include "window/window_peer_menu.h"
#include "window/window_session_controller.h"

View File

@ -12,6 +12,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "data/data_web_page.h"
#include "data/data_media_types.h"
#include "data/data_peer.h"
#include "data/data_file_origin.h"
#include "styles/style_overview.h"
#include "styles/style_history.h"
#include "core/file_utilities.h"

View File

@ -40,6 +40,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "core/application.h"
#include "data/data_session.h"
#include "data/data_cloud_themes.h"
#include "data/data_file_origin.h"
#include "chat_helpers/emoji_sets_manager.h"
#include "base/platform/base_platform_info.h"
#include "base/call_delayed.h"

View File

@ -0,0 +1,714 @@
/*
This file is part of Telegram Desktop,
the official desktop application for the Telegram messaging service.
For license and copyright information please follow this link:
https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/
#include "storage/download_manager_mtproto.h"
#include "mtproto/facade.h"
#include "mtproto/mtproto_auth_key.h"
#include "mtproto/mtproto_rpc_sender.h"
#include "main/main_session.h"
#include "apiwrap.h"
#include "base/openssl_help.h"
namespace Storage {
namespace {
// How much time without download causes additional session kill.
constexpr auto kKillSessionTimeout = 15 * crl::time(1000);
// Max 16 file parts downloaded at the same time, 128 KB each.
constexpr auto kMaxFileQueries = 16;
constexpr auto kMaxWaitedInConnection = 512 * 1024;
// Max 8 http[s] files downloaded at the same time.
constexpr auto kMaxWebFileQueries = 8;
constexpr auto kStartSessionsCount = 1;
constexpr auto kMaxSessionsCount = 8;
constexpr auto kResetDownloadPrioritiesTimeout = crl::time(200);
} // namespace
void DownloadManagerMtproto::Queue::enqueue(not_null<Task*> task) {
const auto i = ranges::find(_tasks, task);
if (i != end(_tasks)) {
return;
}
_tasks.push_back(task);
_previousGeneration.erase(
ranges::remove(_previousGeneration, task),
end(_previousGeneration));
}
void DownloadManagerMtproto::Queue::remove(not_null<Task*> task) {
_tasks.erase(ranges::remove(_tasks, task), end(_tasks));
_previousGeneration.erase(
ranges::remove(_previousGeneration, task),
end(_previousGeneration));
}
void DownloadManagerMtproto::Queue::resetGeneration() {
if (!_previousGeneration.empty()) {
_tasks.reserve(_tasks.size() + _previousGeneration.size());
std::copy(
begin(_previousGeneration),
end(_previousGeneration),
std::back_inserter(_tasks));
_previousGeneration.clear();
}
std::swap(_tasks, _previousGeneration);
}
bool DownloadManagerMtproto::Queue::empty() const {
return _tasks.empty() && _previousGeneration.empty();
}
auto DownloadManagerMtproto::Queue::nextTask() const -> Task* {
auto &&all = ranges::view::concat(_tasks, _previousGeneration);
const auto i = ranges::find(all, true, &Task::readyToRequest);
return (i != all.end()) ? i->get() : nullptr;
}
DownloadManagerMtproto::DownloadManagerMtproto(not_null<ApiWrap*> api)
: _api(api)
, _resetGenerationTimer([=] { resetGeneration(); })
, _killDownloadSessionsTimer([=] { killDownloadSessions(); }) {
}
DownloadManagerMtproto::~DownloadManagerMtproto() {
killDownloadSessions();
}
void DownloadManagerMtproto::enqueue(not_null<Task*> task) {
const auto dcId = task->dcId();
auto &queue = _queues[dcId];
queue.enqueue(task);
if (!_resetGenerationTimer.isActive()) {
_resetGenerationTimer.callOnce(kResetDownloadPrioritiesTimeout);
}
checkSendNext(dcId, queue);
}
void DownloadManagerMtproto::remove(not_null<Task*> task) {
const auto dcId = task->dcId();
auto &queue = _queues[dcId];
queue.remove(task);
checkSendNext(dcId, queue);
}
void DownloadManagerMtproto::resetGeneration() {
_resetGenerationTimer.cancel();
for (auto &[dcId, queue] : _queues) {
queue.resetGeneration();
}
}
void DownloadManagerMtproto::checkSendNext() {
for (auto &[dcId, queue] : _queues) {
if (queue.empty()) {
continue;
}
checkSendNext(dcId, queue);
}
}
void DownloadManagerMtproto::checkSendNext(MTP::DcId dcId, Queue &queue) {
const auto bestIndex = [&] {
const auto i = _requestedBytesAmount.find(dcId);
if (i == end(_requestedBytesAmount)) {
_requestedBytesAmount[dcId].resize(kStartSessionsCount);
return 0;
}
const auto j = ranges::min_element(i->second);
const auto already = *j;
return (already + kDownloadPartSize <= kMaxWaitedInConnection)
? (j - begin(i->second))
: -1;
}();
if (bestIndex >= 0) {
if (const auto task = queue.nextTask()) {
task->loadPart(bestIndex);
}
}
}
void DownloadManagerMtproto::requestedAmountIncrement(
MTP::DcId dcId,
int index,
int amount) {
Expects(dcId != 0);
using namespace rpl::mappers;
auto it = _requestedBytesAmount.find(dcId);
if (it == _requestedBytesAmount.end()) {
it = _requestedBytesAmount.emplace(
dcId,
std::vector<int>(dcId ? kStartSessionsCount : 1, 0)
).first;
}
it->second[index] += amount;
if (amount > 0) {
killDownloadSessionsStop(dcId);
} else {
crl::on_main(this, [=] { checkSendNext(); });
if (ranges::find_if(it->second, _1 > 0) == end(it->second)) {
killDownloadSessionsStart(dcId);
}
}
}
int DownloadManagerMtproto::chooseDcIndexForRequest(MTP::DcId dcId) {
const auto i = _requestedBytesAmount.find(dcId);
return (i != end(_requestedBytesAmount))
? (ranges::min_element(i->second) - begin(i->second))
: 0;
}
void DownloadManagerMtproto::killDownloadSessionsStart(MTP::DcId dcId) {
if (!_killDownloadSessionTimes.contains(dcId)) {
_killDownloadSessionTimes.emplace(
dcId,
crl::now() + kKillSessionTimeout);
}
if (!_killDownloadSessionsTimer.isActive()) {
_killDownloadSessionsTimer.callOnce(kKillSessionTimeout + 5);
}
}
void DownloadManagerMtproto::killDownloadSessionsStop(MTP::DcId dcId) {
_killDownloadSessionTimes.erase(dcId);
if (_killDownloadSessionTimes.empty()
&& _killDownloadSessionsTimer.isActive()) {
_killDownloadSessionsTimer.cancel();
}
}
void DownloadManagerMtproto::killDownloadSessions() {
const auto now = crl::now();
auto left = kKillSessionTimeout;
for (auto i = _killDownloadSessionTimes.begin(); i != _killDownloadSessionTimes.end(); ) {
if (i->second <= now) {
const auto j = _requestedBytesAmount.find(i->first);
if (j != end(_requestedBytesAmount)) {
for (auto index = 0; index != int(j->second.size()); ++index) {
MTP::stopSession(MTP::downloadDcId(i->first, index));
}
}
i = _killDownloadSessionTimes.erase(i);
} else {
if (i->second - now < left) {
left = i->second - now;
}
++i;
}
}
if (!_killDownloadSessionTimes.empty()) {
_killDownloadSessionsTimer.callOnce(left);
}
}
DownloadMtprotoTask::DownloadMtprotoTask(
not_null<DownloadManagerMtproto*> owner,
const StorageFileLocation &location,
Data::FileOrigin origin)
: _owner(owner)
, _dcId(location.dcId())
, _location({ location })
, _origin(origin) {
}
DownloadMtprotoTask::DownloadMtprotoTask(
not_null<DownloadManagerMtproto*> owner,
MTP::DcId dcId,
const Location &location)
: _owner(owner)
, _dcId(dcId)
, _location(location) {
}
DownloadMtprotoTask::~DownloadMtprotoTask() {
cancelAllRequests();
_owner->remove(this);
}
MTP::DcId DownloadMtprotoTask::dcId() const {
return _dcId;
}
Data::FileOrigin DownloadMtprotoTask::fileOrigin() const {
return _origin;
}
uint64 DownloadMtprotoTask::objectId() const {
if (const auto v = base::get_if<StorageFileLocation>(&_location.data)) {
return v->objectId();
}
return 0;
}
const DownloadMtprotoTask::Location &DownloadMtprotoTask::location() const {
return _location;
}
void DownloadMtprotoTask::refreshFileReferenceFrom(
const Data::UpdatedFileReferences &updates,
int requestId,
const QByteArray &current) {
if (const auto v = base::get_if<StorageFileLocation>(&_location.data)) {
v->refreshFileReference(updates);
if (v->fileReference() == current) {
cancelOnFail();
return;
}
} else {
cancelOnFail();
return;
}
makeRequest(finishSentRequest(requestId));
}
void DownloadMtprotoTask::loadPart(int dcIndex) {
makeRequest({ takeNextRequestOffset(), dcIndex });
}
mtpRequestId DownloadMtprotoTask::sendRequest(const RequestData &requestData) {
const auto offset = requestData.offset;
const auto limit = Storage::kDownloadPartSize;
const auto shiftedDcId = MTP::downloadDcId(
_cdnDcId ? _cdnDcId : dcId(),
requestData.dcIndex);
if (_cdnDcId) {
return api().request(MTPupload_GetCdnFile(
MTP_bytes(_cdnToken),
MTP_int(offset),
MTP_int(limit)
)).done([=](const MTPupload_CdnFile &result, mtpRequestId id) {
cdnPartLoaded(result, id);
}).fail([=](const RPCError &error, mtpRequestId id) {
cdnPartFailed(error, id);
}).toDC(shiftedDcId).send();
}
return _location.data.match([&](const WebFileLocation &location) {
return api().request(MTPupload_GetWebFile(
MTP_inputWebFileLocation(
MTP_bytes(location.url()),
MTP_long(location.accessHash())),
MTP_int(offset),
MTP_int(limit)
)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
webPartLoaded(result, id);
}).fail([=](const RPCError &error, mtpRequestId id) {
partFailed(error, id);
}).toDC(shiftedDcId).send();
}, [&](const GeoPointLocation &location) {
return api().request(MTPupload_GetWebFile(
MTP_inputWebFileGeoPointLocation(
MTP_inputGeoPoint(
MTP_double(location.lat),
MTP_double(location.lon)),
MTP_long(location.access),
MTP_int(location.width),
MTP_int(location.height),
MTP_int(location.zoom),
MTP_int(location.scale)),
MTP_int(offset),
MTP_int(limit)
)).done([=](const MTPupload_WebFile &result, mtpRequestId id) {
webPartLoaded(result, id);
}).fail([=](const RPCError &error, mtpRequestId id) {
partFailed(error, id);
}).toDC(shiftedDcId).send();
}, [&](const StorageFileLocation &location) {
const auto reference = location.fileReference();
return api().request(MTPupload_GetFile(
MTP_flags(0),
location.tl(api().session().userId()),
MTP_int(offset),
MTP_int(limit)
)).done([=](const MTPupload_File &result, mtpRequestId id) {
normalPartLoaded(result, id);
}).fail([=](const RPCError &error, mtpRequestId id) {
normalPartFailed(reference, error, id);
}).toDC(shiftedDcId).send();
});
}
bool DownloadMtprotoTask::setWebFileSizeHook(int size) {
return true;
}
void DownloadMtprotoTask::makeRequest(const RequestData &requestData) {
placeSentRequest(sendRequest(requestData), requestData);
}
void DownloadMtprotoTask::requestMoreCdnFileHashes() {
if (_cdnHashesRequestId || _cdnUncheckedParts.empty()) {
return;
}
const auto requestData = _cdnUncheckedParts.cbegin()->first;
const auto shiftedDcId = MTP::downloadDcId(
dcId(),
requestData.dcIndex);
_cdnHashesRequestId = api().request(MTPupload_GetCdnFileHashes(
MTP_bytes(_cdnToken),
MTP_int(requestData.offset)
)).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
getCdnFileHashesDone(result, id);
}).fail([=](const RPCError &error, mtpRequestId id) {
cdnPartFailed(error, id);
}).toDC(shiftedDcId).send();
placeSentRequest(_cdnHashesRequestId, requestData);
}
void DownloadMtprotoTask::normalPartLoaded(
const MTPupload_File &result,
mtpRequestId requestId) {
const auto requestData = finishSentRequest(requestId);
result.match([&](const MTPDupload_fileCdnRedirect &data) {
switchToCDN(requestData, data);
}, [&](const MTPDupload_file &data) {
partLoaded(requestData.offset, data.vbytes().v);
});
}
void DownloadMtprotoTask::webPartLoaded(
const MTPupload_WebFile &result,
mtpRequestId requestId) {
result.match([&](const MTPDupload_webFile &data) {
const auto requestData = finishSentRequest(requestId);
if (setWebFileSizeHook(data.vsize().v)) {
partLoaded(requestData.offset, data.vbytes().v);
}
});
}
void DownloadMtprotoTask::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId) {
const auto requestData = finishSentRequest(requestId);
result.match([&](const MTPDupload_cdnFileReuploadNeeded &data) {
const auto shiftedDcId = MTP::downloadDcId(
dcId(),
requestData.dcIndex);
const auto requestId = api().request(MTPupload_ReuploadCdnFile(
MTP_bytes(_cdnToken),
data.vrequest_token()
)).done([=](const MTPVector<MTPFileHash> &result, mtpRequestId id) {
reuploadDone(result, id);
}).fail([=](const RPCError &error, mtpRequestId id) {
cdnPartFailed(error, id);
}).toDC(shiftedDcId).send();
placeSentRequest(requestId, requestData);
}, [&](const MTPDupload_cdnFile &data) {
auto key = bytes::make_span(_cdnEncryptionKey);
auto iv = bytes::make_span(_cdnEncryptionIV);
Expects(key.size() == MTP::CTRState::KeySize);
Expects(iv.size() == MTP::CTRState::IvecSize);
auto state = MTP::CTRState();
auto ivec = bytes::make_span(state.ivec);
std::copy(iv.begin(), iv.end(), ivec.begin());
auto counterOffset = static_cast<uint32>(requestData.offset) >> 4;
state.ivec[15] = static_cast<uchar>(counterOffset & 0xFF);
state.ivec[14] = static_cast<uchar>((counterOffset >> 8) & 0xFF);
state.ivec[13] = static_cast<uchar>((counterOffset >> 16) & 0xFF);
state.ivec[12] = static_cast<uchar>((counterOffset >> 24) & 0xFF);
auto decryptInPlace = data.vbytes().v;
auto buffer = bytes::make_detached_span(decryptInPlace);
MTP::aesCtrEncrypt(buffer, key.data(), &state);
switch (checkCdnFileHash(requestData.offset, buffer)) {
case CheckCdnHashResult::NoHash: {
_cdnUncheckedParts.emplace(requestData, decryptInPlace);
requestMoreCdnFileHashes();
} return;
case CheckCdnHashResult::Invalid: {
LOG(("API Error: Wrong cdnFileHash for offset %1."
).arg(requestData.offset));
cancelOnFail();
} return;
case CheckCdnHashResult::Good: {
partLoaded(requestData.offset, decryptInPlace);
} return;
}
Unexpected("Result of checkCdnFileHash()");
});
}
DownloadMtprotoTask::CheckCdnHashResult DownloadMtprotoTask::checkCdnFileHash(
int offset,
bytes::const_span buffer) {
const auto cdnFileHashIt = _cdnFileHashes.find(offset);
if (cdnFileHashIt == _cdnFileHashes.cend()) {
return CheckCdnHashResult::NoHash;
}
const auto realHash = openssl::Sha256(buffer);
const auto receivedHash = bytes::make_span(cdnFileHashIt->second.hash);
if (bytes::compare(realHash, receivedHash)) {
return CheckCdnHashResult::Invalid;
}
return CheckCdnHashResult::Good;
}
void DownloadMtprotoTask::reuploadDone(
const MTPVector<MTPFileHash> &result,
mtpRequestId requestId) {
const auto requestData = finishSentRequest(requestId);
addCdnHashes(result.v);
makeRequest(requestData);
}
void DownloadMtprotoTask::getCdnFileHashesDone(
const MTPVector<MTPFileHash> &result,
mtpRequestId requestId) {
Expects(_cdnHashesRequestId == requestId);
_cdnHashesRequestId = 0;
const auto requestData = finishSentRequest(requestId);
addCdnHashes(result.v);
auto someMoreChecked = false;
for (auto i = _cdnUncheckedParts.begin(); i != _cdnUncheckedParts.cend();) {
const auto uncheckedData = i->first;
const auto uncheckedBytes = bytes::make_span(i->second);
switch (checkCdnFileHash(uncheckedData.offset, uncheckedBytes)) {
case CheckCdnHashResult::NoHash: {
++i;
} break;
case CheckCdnHashResult::Invalid: {
LOG(("API Error: Wrong cdnFileHash for offset %1."
).arg(uncheckedData.offset));
cancelOnFail();
return;
} break;
case CheckCdnHashResult::Good: {
someMoreChecked = true;
const auto goodOffset = uncheckedData.offset;
const auto goodBytes = std::move(i->second);
const auto weak = base::make_weak(this);
i = _cdnUncheckedParts.erase(i);
if (!feedPart(goodOffset, goodBytes) || !weak) {
return;
}
} break;
default: Unexpected("Result of checkCdnFileHash()");
}
}
if (!someMoreChecked) {
LOG(("API Error: "
"Could not find cdnFileHash for offset %1 "
"after getCdnFileHashes request."
).arg(requestData.offset));
cancelOnFail();
return;
}
requestMoreCdnFileHashes();
}
void DownloadMtprotoTask::placeSentRequest(
mtpRequestId requestId,
const RequestData &requestData) {
_owner->requestedAmountIncrement(
dcId(),
requestData.dcIndex,
Storage::kDownloadPartSize);
const auto [i, ok1] = _sentRequests.emplace(requestId, requestData);
const auto [j, ok2] = _requestByOffset.emplace(
requestData.offset,
requestId);
Ensures(ok1 && ok2);
}
auto DownloadMtprotoTask::finishSentRequest(mtpRequestId requestId)
-> RequestData {
auto it = _sentRequests.find(requestId);
Assert(it != _sentRequests.cend());
const auto result = it->second;
_owner->requestedAmountIncrement(
dcId(),
result.dcIndex,
-Storage::kDownloadPartSize);
_sentRequests.erase(it);
_requestByOffset.erase(result.offset);
return result;
}
bool DownloadMtprotoTask::haveSentRequests() const {
return !_sentRequests.empty() || !_cdnUncheckedParts.empty();
}
bool DownloadMtprotoTask::haveSentRequestForOffset(int offset) const {
return _requestByOffset.contains(offset)
|| _cdnUncheckedParts.contains({ offset, 0 });
}
void DownloadMtprotoTask::cancelAllRequests() {
while (!_sentRequests.empty()) {
cancelRequest(_sentRequests.begin()->first);
}
_cdnUncheckedParts.clear();
}
void DownloadMtprotoTask::cancelRequestForOffset(int offset) {
const auto i = _requestByOffset.find(offset);
if (i != end(_requestByOffset)) {
cancelRequest(i->second);
}
_cdnUncheckedParts.remove({ offset, 0 });
}
void DownloadMtprotoTask::cancelRequest(mtpRequestId requestId) {
api().request(requestId).cancel();
[[maybe_unused]] const auto data = finishSentRequest(requestId);
}
void DownloadMtprotoTask::addToQueue() {
_owner->enqueue(this);
}
void DownloadMtprotoTask::removeFromQueue() {
_owner->remove(this);
}
void DownloadMtprotoTask::partLoaded(
int offset,
const QByteArray &bytes) {
feedPart(offset, bytes);
}
bool DownloadMtprotoTask::normalPartFailed(
QByteArray fileReference,
const RPCError &error,
mtpRequestId requestId) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
if (error.code() == 400
&& error.type().startsWith(qstr("FILE_REFERENCE_"))) {
api().refreshFileReference(
_origin,
this,
requestId,
fileReference);
return true;
}
return partFailed(error, requestId);
}
bool DownloadMtprotoTask::partFailed(
const RPCError &error,
mtpRequestId requestId) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
cancelOnFail();
return true;
}
bool DownloadMtprotoTask::cdnPartFailed(
const RPCError &error,
mtpRequestId requestId) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
if (requestId == _cdnHashesRequestId) {
_cdnHashesRequestId = 0;
}
if (error.type() == qstr("FILE_TOKEN_INVALID")
|| error.type() == qstr("REQUEST_TOKEN_INVALID")) {
const auto requestData = finishSentRequest(requestId);
changeCDNParams(
requestData,
0,
QByteArray(),
QByteArray(),
QByteArray(),
QVector<MTPFileHash>());
return true;
}
return partFailed(error, requestId);
}
void DownloadMtprotoTask::switchToCDN(
const RequestData &requestData,
const MTPDupload_fileCdnRedirect &redirect) {
changeCDNParams(
requestData,
redirect.vdc_id().v,
redirect.vfile_token().v,
redirect.vencryption_key().v,
redirect.vencryption_iv().v,
redirect.vfile_hashes().v);
}
void DownloadMtprotoTask::addCdnHashes(
const QVector<MTPFileHash> &hashes) {
for (const auto &hash : hashes) {
hash.match([&](const MTPDfileHash &data) {
_cdnFileHashes.emplace(
data.voffset().v,
CdnFileHash{ data.vlimit().v, data.vhash().v });
});
}
}
void DownloadMtprotoTask::changeCDNParams(
const RequestData &requestData,
MTP::DcId dcId,
const QByteArray &token,
const QByteArray &encryptionKey,
const QByteArray &encryptionIV,
const QVector<MTPFileHash> &hashes) {
if (dcId != 0
&& (encryptionKey.size() != MTP::CTRState::KeySize
|| encryptionIV.size() != MTP::CTRState::IvecSize)) {
LOG(("Message Error: Wrong key (%1) / iv (%2) size in CDN params"
).arg(encryptionKey.size()
).arg(encryptionIV.size()));
cancelOnFail();
return;
}
auto resendAllRequests = (_cdnDcId != dcId
|| _cdnToken != token
|| _cdnEncryptionKey != encryptionKey
|| _cdnEncryptionIV != encryptionIV);
_cdnDcId = dcId;
_cdnToken = token;
_cdnEncryptionKey = encryptionKey;
_cdnEncryptionIV = encryptionIV;
addCdnHashes(hashes);
if (resendAllRequests && !_sentRequests.empty()) {
auto resendRequests = std::vector<RequestData>();
resendRequests.reserve(_sentRequests.size());
while (!_sentRequests.empty()) {
const auto requestId = _sentRequests.begin()->first;
api().request(requestId).cancel();
resendRequests.push_back(finishSentRequest(requestId));
}
for (const auto &requestData : resendRequests) {
makeRequest(requestData);
}
}
makeRequest(requestData);
}
} // namespace Storage

View File

@ -0,0 +1,227 @@
/*
This file is part of Telegram Desktop,
the official desktop application for the Telegram messaging service.
For license and copyright information please follow this link:
https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/
#pragma once
#include "data/data_file_origin.h"
#include "base/timer.h"
#include "base/weak_ptr.h"
class ApiWrap;
class RPCError;
namespace Storage {
// Different part sizes are not supported for now :(
// Because we start downloading with some part size
// and then we get a CDN-redirect where we support only
// fixed part size download for hash checking.
constexpr auto kDownloadPartSize = 128 * 1024;
class DownloadMtprotoTask;
class DownloadManagerMtproto final : public base::has_weak_ptr {
public:
using Task = DownloadMtprotoTask;
explicit DownloadManagerMtproto(not_null<ApiWrap*> api);
~DownloadManagerMtproto();
[[nodiscard]] ApiWrap &api() const {
return *_api;
}
void enqueue(not_null<Task*> task);
void remove(not_null<Task*> task);
[[nodiscard]] base::Observable<void> &taskFinished() {
return _taskFinishedObservable;
}
void requestedAmountIncrement(MTP::DcId dcId, int index, int amount);
[[nodiscard]] int chooseDcIndexForRequest(MTP::DcId dcId);
private:
class Queue final {
public:
void enqueue(not_null<Task*> task);
void remove(not_null<Task*> task);
void resetGeneration();
[[nodiscard]] bool empty() const;
[[nodiscard]] Task *nextTask() const;
private:
std::vector<not_null<Task*>> _tasks;
std::vector<not_null<Task*>> _previousGeneration;
};
void checkSendNext();
void checkSendNext(MTP::DcId dcId, Queue &queue);
void killDownloadSessionsStart(MTP::DcId dcId);
void killDownloadSessionsStop(MTP::DcId dcId);
void killDownloadSessions();
void resetGeneration();
const not_null<ApiWrap*> _api;
base::Observable<void> _taskFinishedObservable;
base::flat_map<MTP::DcId, std::vector<int>> _requestedBytesAmount;
base::Timer _resetGenerationTimer;
base::flat_map<MTP::DcId, crl::time> _killDownloadSessionTimes;
base::Timer _killDownloadSessionsTimer;
base::flat_map<MTP::DcId, Queue> _queues;
};
class DownloadMtprotoTask : public base::has_weak_ptr {
public:
struct Location {
base::variant<
StorageFileLocation,
WebFileLocation,
GeoPointLocation> data;
};
DownloadMtprotoTask(
not_null<DownloadManagerMtproto*> owner,
const StorageFileLocation &location,
Data::FileOrigin origin);
DownloadMtprotoTask(
not_null<DownloadManagerMtproto*> owner,
MTP::DcId dcId,
const Location &location);
virtual ~DownloadMtprotoTask();
[[nodiscard]] MTP::DcId dcId() const;
[[nodiscard]] Data::FileOrigin fileOrigin() const;
[[nodiscard]] uint64 objectId() const;
[[nodiscard]] const Location &location() const;
[[nodiscard]] virtual bool readyToRequest() const = 0;
void loadPart(int dcIndex);
void refreshFileReferenceFrom(
const Data::UpdatedFileReferences &updates,
int requestId,
const QByteArray &current);
protected:
[[nodiscard]] bool haveSentRequests() const;
[[nodiscard]] bool haveSentRequestForOffset(int offset) const;
void cancelAllRequests();
void cancelRequestForOffset(int offset);
void addToQueue();
void removeFromQueue();
[[nodiscard]] ApiWrap &api() const {
return _owner->api();
}
private:
struct RequestData {
int offset = 0;
int dcIndex = 0;
inline bool operator<(const RequestData &other) const {
return offset < other.offset;
}
};
struct CdnFileHash {
CdnFileHash(int limit, QByteArray hash) : limit(limit), hash(hash) {
}
int limit = 0;
QByteArray hash;
};
enum class CheckCdnHashResult {
NoHash,
Invalid,
Good,
};
// Called only if readyToRequest() == true.
[[nodiscard]] virtual int takeNextRequestOffset() = 0;
virtual bool feedPart(int offset, const QByteArray &bytes) = 0;
virtual bool setWebFileSizeHook(int size);
virtual void cancelOnFail() = 0;
void cancelRequest(mtpRequestId requestId);
void makeRequest(const RequestData &requestData);
void normalPartLoaded(
const MTPupload_File &result,
mtpRequestId requestId);
void webPartLoaded(
const MTPupload_WebFile &result,
mtpRequestId requestId);
void cdnPartLoaded(
const MTPupload_CdnFile &result,
mtpRequestId requestId);
void reuploadDone(
const MTPVector<MTPFileHash> &result,
mtpRequestId requestId);
void requestMoreCdnFileHashes();
void getCdnFileHashesDone(
const MTPVector<MTPFileHash> &result,
mtpRequestId requestId);
void partLoaded(int offset, const QByteArray &bytes);
bool partFailed(const RPCError &error, mtpRequestId requestId);
bool normalPartFailed(
QByteArray fileReference,
const RPCError &error,
mtpRequestId requestId);
bool cdnPartFailed(const RPCError &error, mtpRequestId requestId);
[[nodiscard]] mtpRequestId sendRequest(const RequestData &requestData);
void placeSentRequest(
mtpRequestId requestId,
const RequestData &requestData);
[[nodiscard]] RequestData finishSentRequest(mtpRequestId requestId);
void switchToCDN(
const RequestData &requestData,
const MTPDupload_fileCdnRedirect &redirect);
void addCdnHashes(const QVector<MTPFileHash> &hashes);
void changeCDNParams(
const RequestData &requestData,
MTP::DcId dcId,
const QByteArray &token,
const QByteArray &encryptionKey,
const QByteArray &encryptionIV,
const QVector<MTPFileHash> &hashes);
[[nodiscard]] CheckCdnHashResult checkCdnFileHash(
int offset,
bytes::const_span buffer);
const not_null<DownloadManagerMtproto*> _owner;
const MTP::DcId _dcId = 0;
// _location can be changed with an updated file_reference.
Location _location;
const Data::FileOrigin _origin;
base::flat_map<mtpRequestId, RequestData> _sentRequests;
base::flat_map<int, mtpRequestId> _requestByOffset;
MTP::DcId _cdnDcId = 0;
QByteArray _cdnToken;
QByteArray _cdnEncryptionKey;
QByteArray _cdnEncryptionIV;
base::flat_map<int, CdnFileHash> _cdnFileHashes;
base::flat_map<RequestData, QByteArray> _cdnUncheckedParts;
mtpRequestId _cdnHashesRequestId = 0;
};
} // namespace Storage

View File

@ -23,210 +23,6 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "facades.h"
#include "app.h"
namespace Storage {
namespace {
// How much time without download causes additional session kill.
constexpr auto kKillSessionTimeout = 15 * crl::time(1000);
// Max 16 file parts downloaded at the same time, 128 KB each.
constexpr auto kMaxFileQueries = 16;
constexpr auto kMaxWaitedInConnection = 512 * 1024;
// Max 8 http[s] files downloaded at the same time.
constexpr auto kMaxWebFileQueries = 8;
constexpr auto kStartSessionsCount = 1;
constexpr auto kMaxSessionsCount = 8;
constexpr auto kResetDownloadPrioritiesTimeout = crl::time(200);
} // namespace
void DownloadManager::Queue::enqueue(not_null<Downloader*> loader) {
const auto i = ranges::find(_loaders, loader);
if (i != end(_loaders)) {
return;
}
_loaders.push_back(loader);
_previousGeneration.erase(
ranges::remove(_previousGeneration, loader),
end(_previousGeneration));
}
void DownloadManager::Queue::remove(not_null<Downloader*> loader) {
_loaders.erase(ranges::remove(_loaders, loader), end(_loaders));
_previousGeneration.erase(
ranges::remove(_previousGeneration, loader),
end(_previousGeneration));
}
void DownloadManager::Queue::resetGeneration() {
if (!_previousGeneration.empty()) {
_loaders.reserve(_loaders.size() + _previousGeneration.size());
std::copy(
begin(_previousGeneration),
end(_previousGeneration),
std::back_inserter(_loaders));
_previousGeneration.clear();
}
std::swap(_loaders, _previousGeneration);
}
bool DownloadManager::Queue::empty() const {
return _loaders.empty() && _previousGeneration.empty();
}
Downloader *DownloadManager::Queue::nextLoader() const {
auto &&all = ranges::view::concat(_loaders, _previousGeneration);
const auto i = ranges::find(all, true, &Downloader::readyToRequest);
return (i != all.end()) ? i->get() : nullptr;
}
DownloadManager::DownloadManager(not_null<ApiWrap*> api)
: _api(api)
, _resetGenerationTimer([=] { resetGeneration(); })
, _killDownloadSessionsTimer([=] { killDownloadSessions(); }) {
}
DownloadManager::~DownloadManager() {
killDownloadSessions();
}
void DownloadManager::enqueue(not_null<Downloader*> loader) {
const auto dcId = loader->dcId();
(dcId ? _mtprotoLoaders[dcId] : _webLoaders).enqueue(loader);
if (!_resetGenerationTimer.isActive()) {
_resetGenerationTimer.callOnce(kResetDownloadPrioritiesTimeout);
}
checkSendNext();
}
void DownloadManager::remove(not_null<Downloader*> loader) {
const auto dcId = loader->dcId();
(dcId ? _mtprotoLoaders[dcId] : _webLoaders).remove(loader);
crl::on_main(&_api->session(), [=] { checkSendNext(); });
}
void DownloadManager::resetGeneration() {
_resetGenerationTimer.cancel();
for (auto &[dcId, queue] : _mtprotoLoaders) {
queue.resetGeneration();
}
_webLoaders.resetGeneration();
}
void DownloadManager::checkSendNext() {
for (auto &[dcId, queue] : _mtprotoLoaders) {
if (queue.empty()) {
continue;
}
const auto bestIndex = [&] {
const auto i = _requestedBytesAmount.find(dcId);
if (i == end(_requestedBytesAmount)) {
_requestedBytesAmount[dcId].resize(kStartSessionsCount);
return 0;
}
const auto j = ranges::min_element(i->second);
const auto already = *j;
return (already + kDownloadPartSize <= kMaxWaitedInConnection)
? (j - begin(i->second))
: -1;
}();
if (bestIndex < 0) {
continue;
}
if (const auto loader = queue.nextLoader()) {
loader->loadPart(bestIndex);
}
}
if (_requestedBytesAmount[0].empty()) {
_requestedBytesAmount[0] = std::vector<int>(1, 0);
}
if (_requestedBytesAmount[0][0] < kMaxWebFileQueries) {
if (const auto loader = _webLoaders.nextLoader()) {
loader->loadPart(0);
}
}
}
void DownloadManager::requestedAmountIncrement(
MTP::DcId dcId,
int index,
int amount) {
using namespace rpl::mappers;
auto it = _requestedBytesAmount.find(dcId);
if (it == _requestedBytesAmount.end()) {
it = _requestedBytesAmount.emplace(
dcId,
std::vector<int>(dcId ? kStartSessionsCount : 1, 0)
).first;
}
it->second[index] += amount;
if (!dcId) {
return; // webLoaders.
}
if (amount > 0) {
killDownloadSessionsStop(dcId);
} else if (ranges::find_if(it->second, _1 > 0) == end(it->second)) {
killDownloadSessionsStart(dcId);
checkSendNext();
}
}
int DownloadManager::chooseDcIndexForRequest(MTP::DcId dcId) {
const auto i = _requestedBytesAmount.find(dcId);
return (i != end(_requestedBytesAmount))
? (ranges::min_element(i->second) - begin(i->second))
: 0;
}
void DownloadManager::killDownloadSessionsStart(MTP::DcId dcId) {
if (!_killDownloadSessionTimes.contains(dcId)) {
_killDownloadSessionTimes.emplace(
dcId,
crl::now() + kKillSessionTimeout);
}
if (!_killDownloadSessionsTimer.isActive()) {
_killDownloadSessionsTimer.callOnce(kKillSessionTimeout + 5);
}
}
void DownloadManager::killDownloadSessionsStop(MTP::DcId dcId) {
_killDownloadSessionTimes.erase(dcId);
if (_killDownloadSessionTimes.empty()
&& _killDownloadSessionsTimer.isActive()) {
_killDownloadSessionsTimer.cancel();
}
}
void DownloadManager::killDownloadSessions() {
const auto now = crl::now();
auto left = kKillSessionTimeout;
for (auto i = _killDownloadSessionTimes.begin(); i != _killDownloadSessionTimes.end(); ) {
if (i->second <= now) {
const auto j = _requestedBytesAmount.find(i->first);
if (j != end(_requestedBytesAmount)) {
for (auto index = 0; index != int(j->second.size()); ++index) {
MTP::stopSession(MTP::downloadDcId(i->first, index));
}
}
i = _killDownloadSessionTimes.erase(i);
} else {
if (i->second - now < left) {
left = i->second - now;
}
++i;
}
}
if (!_killDownloadSessionTimes.empty()) {
_killDownloadSessionsTimer.callOnce(left);
}
}
} // namespace Storage
FileLoader::FileLoader(
const QString &toFile,
int32 size,
@ -445,7 +241,9 @@ void FileLoader::cancel() {
void FileLoader::cancel(bool fail) {
const auto started = (currentOffset() > 0);
cancelRequests();
cancelHook();
_cancelled = true;
_finished = true;
if (_fileIsOpen) {

View File

@ -10,12 +10,12 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "base/observer.h"
#include "base/timer.h"
#include "base/binary_guard.h"
#include "data/data_file_origin.h"
#include "mtproto/facade.h"
#include <QtNetwork/QNetworkReply>
class ApiWrap;
namespace Data {
struct FileOrigin;
} // namespace Data
namespace Main {
class Session;
@ -26,88 +26,22 @@ namespace Cache {
struct Key;
} // namespace Cache
// 10 MB max file could be hold in memory
// This value is used in local cache database settings!
constexpr auto kMaxFileInMemory = 10 * 1024 * 1024; // 10 MB max file could be hold in memory
constexpr auto kMaxFileInMemory = 10 * 1024 * 1024;
constexpr auto kMaxVoiceInMemory = 2 * 1024 * 1024; // 2 MB audio is hold in memory and auto loaded
constexpr auto kMaxStickerInMemory = 2 * 1024 * 1024; // 2 MB stickers hold in memory, auto loaded and displayed inline
// 2 MB audio is hold in memory and auto loaded
constexpr auto kMaxVoiceInMemory = 2 * 1024 * 1024;
// 2 MB stickers hold in memory, auto loaded and displayed inline
constexpr auto kMaxStickerInMemory = 2 * 1024 * 1024;
// 10 MB GIF and mp4 animations held in memory while playing
constexpr auto kMaxWallPaperInMemory = kMaxFileInMemory;
constexpr auto kMaxAnimationInMemory = kMaxFileInMemory; // 10 MB gif and mp4 animations held in memory while playing
constexpr auto kMaxWallPaperDimension = 4096; // 4096x4096 is max area.
constexpr auto kMaxAnimationInMemory = kMaxFileInMemory;
// Different part sizes are not supported for now :(
// Because we start downloading with some part size
// and then we get a cdn-redirect where we support only
// fixed part size download for hash checking.
constexpr auto kDownloadPartSize = 128 * 1024;
class Downloader {
public:
virtual ~Downloader() = default;
[[nodiscard]] virtual MTP::DcId dcId() const = 0;
[[nodiscard]] virtual bool readyToRequest() const = 0;
virtual void loadPart(int dcIndex) = 0;
};
class DownloadManager final : public base::has_weak_ptr {
public:
explicit DownloadManager(not_null<ApiWrap*> api);
~DownloadManager();
[[nodiscard]] ApiWrap &api() const {
return *_api;
}
void enqueue(not_null<Downloader*> loader);
void remove(not_null<Downloader*> loader);
[[nodiscard]] base::Observable<void> &taskFinished() {
return _taskFinishedObservable;
}
// dcId == 0 is for web requests.
void requestedAmountIncrement(MTP::DcId dcId, int index, int amount);
[[nodiscard]] int chooseDcIndexForRequest(MTP::DcId dcId);
private:
class Queue final {
public:
void enqueue(not_null<Downloader*> loader);
void remove(not_null<Downloader*> loader);
void resetGeneration();
[[nodiscard]] bool empty() const;
[[nodiscard]] Downloader *nextLoader() const;
private:
std::vector<not_null<Downloader*>> _loaders;
std::vector<not_null<Downloader*>> _previousGeneration;
};
void checkSendNext();
void killDownloadSessionsStart(MTP::DcId dcId);
void killDownloadSessionsStop(MTP::DcId dcId);
void killDownloadSessions();
void resetGeneration();
const not_null<ApiWrap*> _api;
base::Observable<void> _taskFinishedObservable;
base::flat_map<MTP::DcId, std::vector<int>> _requestedBytesAmount;
base::Timer _resetGenerationTimer;
base::flat_map<MTP::DcId, crl::time> _killDownloadSessionTimes;
base::Timer _killDownloadSessionsTimer;
base::flat_map<MTP::DcId, Queue> _mtprotoLoaders;
Queue _webLoaders;
};
// 4096x4096 is max area.
constexpr auto kMaxWallPaperDimension = 4096;
} // namespace Storage
@ -132,8 +66,9 @@ public:
LoadFromCloudSetting fromCloud,
bool autoLoading,
uint8 cacheTag);
virtual ~FileLoader();
Main::Session &session() const;
[[nodiscard]] Main::Session &session() const;
bool finished() const {
return _finished;
@ -153,7 +88,8 @@ public:
QString fileName() const {
return _filename;
}
virtual Data::FileOrigin fileOrigin() const;
// Used in MainWidget::documentLoadFailed.
[[nodiscard]] virtual Data::FileOrigin fileOrigin() const;
float64 currentProgress() const;
virtual int currentOffset() const;
int fullSize() const;
@ -171,10 +107,6 @@ public:
return _autoLoading;
}
virtual void stop() {
}
virtual ~FileLoader();
void localLoaded(
const StorageImageSaved &result,
const QByteArray &imageFormat,
@ -198,7 +130,7 @@ protected:
void loadLocal(const Storage::Cache::Key &key);
virtual Storage::Cache::Key cacheKey() const = 0;
virtual std::optional<MediaKey> fileLocationKey() const = 0;
virtual void cancelRequests() = 0;
virtual void cancelHook() = 0;
virtual void startLoading() = 0;
void cancel(bool failed);

View File

@ -35,10 +35,7 @@ mtpFileLoader::mtpFileLoader(
fromCloud,
autoLoading,
cacheTag)
, _downloader(&session().downloader())
, _dcId(location.dcId())
, _location(location)
, _origin(origin) {
, DownloadMtprotoTask(&session().downloader(), location, origin) {
}
mtpFileLoader::mtpFileLoader(
@ -55,9 +52,10 @@ mtpFileLoader::mtpFileLoader(
fromCloud,
autoLoading,
cacheTag)
, _downloader(&session().downloader())
, _dcId(Global::WebFileDcId())
, _location(location) {
, DownloadMtprotoTask(
&session().downloader(),
Global::WebFileDcId(),
{ location }) {
}
mtpFileLoader::mtpFileLoader(
@ -74,506 +72,85 @@ mtpFileLoader::mtpFileLoader(
fromCloud,
autoLoading,
cacheTag)
, _downloader(&session().downloader())
, _dcId(Global::WebFileDcId())
, _location(location) {
}
mtpFileLoader::~mtpFileLoader() {
cancelRequests();
_downloader->remove(this);
, DownloadMtprotoTask(
&session().downloader(),
Global::WebFileDcId(),
{ location }) {
}
Data::FileOrigin mtpFileLoader::fileOrigin() const {
return _origin;
return DownloadMtprotoTask::fileOrigin();
}
uint64 mtpFileLoader::objId() const {
if (const auto storage = base::get_if<StorageFileLocation>(&_location)) {
return storage->objectId();
}
return 0;
}
void mtpFileLoader::refreshFileReferenceFrom(
const Data::UpdatedFileReferences &updates,
int requestId,
const QByteArray &current) {
if (const auto storage = base::get_if<StorageFileLocation>(&_location)) {
storage->refreshFileReference(updates);
if (storage->fileReference() == current) {
cancel(true);
return;
}
} else {
cancel(true);
return;
}
makeRequest(finishSentRequest(requestId));
}
MTP::DcId mtpFileLoader::dcId() const {
return _dcId;
return DownloadMtprotoTask::objectId();
}
bool mtpFileLoader::readyToRequest() const {
return !_finished
&& !_lastComplete
&& (_sentRequests.empty() || _size != 0)
&& (_size != 0 || !haveSentRequests())
&& (!_size || _nextRequestOffset < _size);
}
void mtpFileLoader::loadPart(int dcIndex) {
int mtpFileLoader::takeNextRequestOffset() {
Expects(readyToRequest());
makeRequest({ _nextRequestOffset, dcIndex });
const auto result = _nextRequestOffset;
_nextRequestOffset += Storage::kDownloadPartSize;
}
mtpRequestId mtpFileLoader::sendRequest(const RequestData &requestData) {
const auto offset = requestData.offset;
const auto limit = Storage::kDownloadPartSize;
const auto shiftedDcId = MTP::downloadDcId(
_cdnDcId ? _cdnDcId : dcId(),
requestData.dcIndex);
if (_cdnDcId) {
return MTP::send(
MTPupload_GetCdnFile(
MTP_bytes(_cdnToken),
MTP_int(offset),
MTP_int(limit)),
rpcDone(&mtpFileLoader::cdnPartLoaded),
rpcFail(&mtpFileLoader::cdnPartFailed),
shiftedDcId,
50);
}
return _location.match([&](const WebFileLocation &location) {
return MTP::send(
MTPupload_GetWebFile(
MTP_inputWebFileLocation(
MTP_bytes(location.url()),
MTP_long(location.accessHash())),
MTP_int(offset),
MTP_int(limit)),
rpcDone(&mtpFileLoader::webPartLoaded),
rpcFail(&mtpFileLoader::partFailed),
shiftedDcId,
50);
}, [&](const GeoPointLocation &location) {
return MTP::send(
MTPupload_GetWebFile(
MTP_inputWebFileGeoPointLocation(
MTP_inputGeoPoint(
MTP_double(location.lat),
MTP_double(location.lon)),
MTP_long(location.access),
MTP_int(location.width),
MTP_int(location.height),
MTP_int(location.zoom),
MTP_int(location.scale)),
MTP_int(offset),
MTP_int(limit)),
rpcDone(&mtpFileLoader::webPartLoaded),
rpcFail(&mtpFileLoader::partFailed),
shiftedDcId,
50);
}, [&](const StorageFileLocation &location) {
return MTP::send(
MTPupload_GetFile(
MTP_flags(0),
location.tl(session().userId()),
MTP_int(offset),
MTP_int(limit)),
rpcDone(&mtpFileLoader::normalPartLoaded),
rpcFail(
&mtpFileLoader::normalPartFailed,
location.fileReference()),
shiftedDcId,
50);
});
}
void mtpFileLoader::makeRequest(const RequestData &requestData) {
Expects(!_finished);
placeSentRequest(sendRequest(requestData), requestData);
}
void mtpFileLoader::requestMoreCdnFileHashes() {
Expects(!_finished);
if (_cdnHashesRequestId || _cdnUncheckedParts.empty()) {
return;
}
const auto requestData = _cdnUncheckedParts.cbegin()->first;
const auto shiftedDcId = MTP::downloadDcId(
dcId(),
requestData.dcIndex);
const auto requestId = _cdnHashesRequestId = MTP::send(
MTPupload_GetCdnFileHashes(
MTP_bytes(_cdnToken),
MTP_int(requestData.offset)),
rpcDone(&mtpFileLoader::getCdnFileHashesDone),
rpcFail(&mtpFileLoader::cdnPartFailed),
shiftedDcId);
placeSentRequest(requestId, requestData);
}
void mtpFileLoader::normalPartLoaded(
const MTPupload_File &result,
mtpRequestId requestId) {
Expects(!_finished);
const auto requestData = finishSentRequest(requestId);
result.match([&](const MTPDupload_fileCdnRedirect &data) {
switchToCDN(requestData, data);
}, [&](const MTPDupload_file &data) {
partLoaded(requestData.offset, bytes::make_span(data.vbytes().v));
});
}
void mtpFileLoader::webPartLoaded(
const MTPupload_WebFile &result,
mtpRequestId requestId) {
result.match([&](const MTPDupload_webFile &data) {
const auto requestData = finishSentRequest(requestId);
if (!_size) {
_size = data.vsize().v;
} else if (data.vsize().v != _size) {
LOG(("MTP Error: "
"Bad size provided by bot for webDocument: %1, real: %2"
).arg(_size
).arg(data.vsize().v));
cancel(true);
return;
}
partLoaded(requestData.offset, bytes::make_span(data.vbytes().v));
});
}
void mtpFileLoader::cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId) {
Expects(!_finished);
const auto requestData = finishSentRequest(requestId);
result.match([&](const MTPDupload_cdnFileReuploadNeeded &data) {
const auto shiftedDcId = MTP::downloadDcId(
dcId(),
requestData.dcIndex);
const auto requestId = MTP::send(
MTPupload_ReuploadCdnFile(
MTP_bytes(_cdnToken),
data.vrequest_token()),
rpcDone(&mtpFileLoader::reuploadDone),
rpcFail(&mtpFileLoader::cdnPartFailed),
shiftedDcId);
placeSentRequest(requestId, requestData);
}, [&](const MTPDupload_cdnFile &data) {
auto key = bytes::make_span(_cdnEncryptionKey);
auto iv = bytes::make_span(_cdnEncryptionIV);
Expects(key.size() == MTP::CTRState::KeySize);
Expects(iv.size() == MTP::CTRState::IvecSize);
auto state = MTP::CTRState();
auto ivec = bytes::make_span(state.ivec);
std::copy(iv.begin(), iv.end(), ivec.begin());
auto counterOffset = static_cast<uint32>(requestData.offset) >> 4;
state.ivec[15] = static_cast<uchar>(counterOffset & 0xFF);
state.ivec[14] = static_cast<uchar>((counterOffset >> 8) & 0xFF);
state.ivec[13] = static_cast<uchar>((counterOffset >> 16) & 0xFF);
state.ivec[12] = static_cast<uchar>((counterOffset >> 24) & 0xFF);
auto decryptInPlace = data.vbytes().v;
auto buffer = bytes::make_detached_span(decryptInPlace);
MTP::aesCtrEncrypt(buffer, key.data(), &state);
switch (checkCdnFileHash(requestData.offset, buffer)) {
case CheckCdnHashResult::NoHash: {
_cdnUncheckedParts.emplace(requestData, decryptInPlace);
requestMoreCdnFileHashes();
} return;
case CheckCdnHashResult::Invalid: {
LOG(("API Error: Wrong cdnFileHash for offset %1."
).arg(requestData.offset));
cancel(true);
} return;
case CheckCdnHashResult::Good: {
partLoaded(requestData.offset, buffer);
} return;
}
Unexpected("Result of checkCdnFileHash()");
});
}
mtpFileLoader::CheckCdnHashResult mtpFileLoader::checkCdnFileHash(
int offset,
bytes::const_span buffer) {
const auto cdnFileHashIt = _cdnFileHashes.find(offset);
if (cdnFileHashIt == _cdnFileHashes.cend()) {
return CheckCdnHashResult::NoHash;
}
const auto realHash = openssl::Sha256(buffer);
const auto receivedHash = bytes::make_span(cdnFileHashIt->second.hash);
if (bytes::compare(realHash, receivedHash)) {
return CheckCdnHashResult::Invalid;
}
return CheckCdnHashResult::Good;
}
void mtpFileLoader::reuploadDone(
const MTPVector<MTPFileHash> &result,
mtpRequestId requestId) {
const auto requestData = finishSentRequest(requestId);
addCdnHashes(result.v);
makeRequest(requestData);
}
void mtpFileLoader::getCdnFileHashesDone(
const MTPVector<MTPFileHash> &result,
mtpRequestId requestId) {
Expects(!_finished);
Expects(_cdnHashesRequestId == requestId);
_cdnHashesRequestId = 0;
const auto requestData = finishSentRequest(requestId);
addCdnHashes(result.v);
auto someMoreChecked = false;
for (auto i = _cdnUncheckedParts.begin(); i != _cdnUncheckedParts.cend();) {
const auto uncheckedData = i->first;
const auto uncheckedBytes = bytes::make_span(i->second);
switch (checkCdnFileHash(uncheckedData.offset, uncheckedBytes)) {
case CheckCdnHashResult::NoHash: {
++i;
} break;
case CheckCdnHashResult::Invalid: {
LOG(("API Error: Wrong cdnFileHash for offset %1."
).arg(uncheckedData.offset));
cancel(true);
return;
} break;
case CheckCdnHashResult::Good: {
someMoreChecked = true;
const auto goodOffset = uncheckedData.offset;
const auto goodBytes = std::move(i->second);
const auto weak = QPointer<mtpFileLoader>(this);
i = _cdnUncheckedParts.erase(i);
if (!feedPart(goodOffset, bytes::make_span(goodBytes))
|| !weak) {
return;
} else if (_finished) {
notifyAboutProgress();
return;
}
} break;
default: Unexpected("Result of checkCdnFileHash()");
}
}
if (someMoreChecked) {
const auto weak = QPointer<mtpFileLoader>(this);
notifyAboutProgress();
if (weak) {
requestMoreCdnFileHashes();
}
return;
}
LOG(("API Error: "
"Could not find cdnFileHash for offset %1 "
"after getCdnFileHashes request."
).arg(requestData.offset));
cancel(true);
}
void mtpFileLoader::placeSentRequest(
mtpRequestId requestId,
const RequestData &requestData) {
Expects(!_finished);
_downloader->requestedAmountIncrement(
dcId(),
requestData.dcIndex,
Storage::kDownloadPartSize);
_sentRequests.emplace(requestId, requestData);
}
auto mtpFileLoader::finishSentRequest(mtpRequestId requestId)
-> RequestData {
auto it = _sentRequests.find(requestId);
Assert(it != _sentRequests.cend());
const auto result = it->second;
_downloader->requestedAmountIncrement(
dcId(),
result.dcIndex,
-Storage::kDownloadPartSize);
_sentRequests.erase(it);
return result;
}
bool mtpFileLoader::feedPart(int offset, bytes::const_span buffer) {
bool mtpFileLoader::feedPart(int offset, const QByteArray &bytes) {
const auto buffer = bytes::make_span(bytes);
if (!writeResultPart(offset, buffer)) {
return false;
}
if (buffer.empty() || (buffer.size() % 1024)) { // bad next offset
_lastComplete = true;
}
const auto finished = _sentRequests.empty()
&& _cdnUncheckedParts.empty()
const auto weak = QPointer<mtpFileLoader>(this);
const auto finished = !haveSentRequests()
&& (_lastComplete || (_size && _nextRequestOffset >= _size));
if (finished) {
_downloader->remove(this);
removeFromQueue();
if (!finalizeResult()) {
return false;
}
}
return true;
}
void mtpFileLoader::partLoaded(int offset, bytes::const_span buffer) {
if (feedPart(offset, buffer)) {
if (weak) {
notifyAboutProgress();
}
}
bool mtpFileLoader::normalPartFailed(
QByteArray fileReference,
const RPCError &error,
mtpRequestId requestId) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
if (error.code() == 400
&& error.type().startsWith(qstr("FILE_REFERENCE_"))) {
session().api().refreshFileReference(
_origin,
this,
requestId,
fileReference);
return true;
}
return partFailed(error, requestId);
}
bool mtpFileLoader::partFailed(
const RPCError &error,
mtpRequestId requestId) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
cancel(true);
return true;
}
bool mtpFileLoader::cdnPartFailed(
const RPCError &error,
mtpRequestId requestId) {
if (MTP::isDefaultHandledError(error)) {
return false;
}
void mtpFileLoader::cancelOnFail() {
cancel(true);
}
if (requestId == _cdnHashesRequestId) {
_cdnHashesRequestId = 0;
}
if (error.type() == qstr("FILE_TOKEN_INVALID")
|| error.type() == qstr("REQUEST_TOKEN_INVALID")) {
const auto requestData = finishSentRequest(requestId);
changeCDNParams(
requestData,
0,
QByteArray(),
QByteArray(),
QByteArray(),
QVector<MTPFileHash>());
bool mtpFileLoader::setWebFileSizeHook(int size) {
if (!_size || _size == size) {
_size = size;
return true;
}
return partFailed(error, requestId);
LOG(("MTP Error: "
"Bad size provided by bot for webDocument: %1, real: %2"
).arg(_size
).arg(size));
cancel(true);
return false;
}
void mtpFileLoader::startLoading() {
_downloader->enqueue(this);
addToQueue();
}
void mtpFileLoader::cancelRequests() {
while (!_sentRequests.empty()) {
auto requestId = _sentRequests.begin()->first;
MTP::cancel(requestId);
[[maybe_unused]] const auto data = finishSentRequest(requestId);
}
}
void mtpFileLoader::switchToCDN(
const RequestData &requestData,
const MTPDupload_fileCdnRedirect &redirect) {
changeCDNParams(
requestData,
redirect.vdc_id().v,
redirect.vfile_token().v,
redirect.vencryption_key().v,
redirect.vencryption_iv().v,
redirect.vfile_hashes().v);
}
void mtpFileLoader::addCdnHashes(const QVector<MTPFileHash> &hashes) {
for (const auto &hash : hashes) {
hash.match([&](const MTPDfileHash &data) {
_cdnFileHashes.emplace(
data.voffset().v,
CdnFileHash{ data.vlimit().v, data.vhash().v });
});
}
}
void mtpFileLoader::changeCDNParams(
const RequestData &requestData,
MTP::DcId dcId,
const QByteArray &token,
const QByteArray &encryptionKey,
const QByteArray &encryptionIV,
const QVector<MTPFileHash> &hashes) {
if (dcId != 0
&& (encryptionKey.size() != MTP::CTRState::KeySize
|| encryptionIV.size() != MTP::CTRState::IvecSize)) {
LOG(("Message Error: Wrong key (%1) / iv (%2) size in CDN params").arg(encryptionKey.size()).arg(encryptionIV.size()));
cancel(true);
return;
}
auto resendAllRequests = (_cdnDcId != dcId
|| _cdnToken != token
|| _cdnEncryptionKey != encryptionKey
|| _cdnEncryptionIV != encryptionIV);
_cdnDcId = dcId;
_cdnToken = token;
_cdnEncryptionKey = encryptionKey;
_cdnEncryptionIV = encryptionIV;
addCdnHashes(hashes);
if (resendAllRequests && !_sentRequests.empty()) {
auto resendRequests = std::vector<RequestData>();
resendRequests.reserve(_sentRequests.size());
while (!_sentRequests.empty()) {
auto requestId = _sentRequests.begin()->first;
MTP::cancel(requestId);
resendRequests.push_back(finishSentRequest(requestId));
}
for (const auto &requestData : resendRequests) {
makeRequest(requestData);
}
}
makeRequest(requestData);
void mtpFileLoader::cancelHook() {
cancelAllRequests();
}
Storage::Cache::Key mtpFileLoader::cacheKey() const {
return _location.match([&](const WebFileLocation &location) {
return location().data.match([&](const WebFileLocation &location) {
return Data::WebDocumentCacheKey(location);
}, [&](const GeoPointLocation &location) {
return Data::GeoPointCacheKey(location);

View File

@ -8,13 +8,11 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#pragma once
#include "storage/file_download.h"
#include "storage/download_manager_mtproto.h"
class StorageImageLocation;
class WebFileLocation;
class mtpFileLoader final
: public FileLoader
, public RPCSender
, public Storage::Downloader {
, private Storage::DownloadMtprotoTask {
public:
mtpFileLoader(
const StorageFileLocation &location,
@ -40,101 +38,21 @@ public:
uint8 cacheTag);
Data::FileOrigin fileOrigin() const override;
uint64 objId() const override;
void stop() override {
rpcInvalidate();
}
void refreshFileReferenceFrom(
const Data::UpdatedFileReferences &updates,
int requestId,
const QByteArray &current);
~mtpFileLoader();
private:
struct RequestData {
int offset = 0;
int dcIndex = 0;
inline bool operator<(const RequestData &other) const {
return offset < other.offset;
}
};
struct CdnFileHash {
CdnFileHash(int limit, QByteArray hash) : limit(limit), hash(hash) {
}
int limit = 0;
QByteArray hash;
};
Storage::Cache::Key cacheKey() const override;
std::optional<MediaKey> fileLocationKey() const override;
void startLoading() override;
void cancelRequests() override;
void cancelHook() override;
void makeRequest(const RequestData &requestData);
MTP::DcId dcId() const override;
bool readyToRequest() const override;
void loadPart(int dcIndex) override;
void normalPartLoaded(const MTPupload_File &result, mtpRequestId requestId);
void webPartLoaded(const MTPupload_WebFile &result, mtpRequestId requestId);
void cdnPartLoaded(const MTPupload_CdnFile &result, mtpRequestId requestId);
void reuploadDone(const MTPVector<MTPFileHash> &result, mtpRequestId requestId);
void requestMoreCdnFileHashes();
void getCdnFileHashesDone(const MTPVector<MTPFileHash> &result, mtpRequestId requestId);
void partLoaded(int offset, bytes::const_span buffer);
bool feedPart(int offset, bytes::const_span buffer);
bool partFailed(const RPCError &error, mtpRequestId requestId);
bool normalPartFailed(QByteArray fileReference, const RPCError &error, mtpRequestId requestId);
bool cdnPartFailed(const RPCError &error, mtpRequestId requestId);
mtpRequestId sendRequest(const RequestData &requestData);
void placeSentRequest(
mtpRequestId requestId,
const RequestData &requestData);
[[nodiscard]] RequestData finishSentRequest(mtpRequestId requestId);
void switchToCDN(
const RequestData &requestData,
const MTPDupload_fileCdnRedirect &redirect);
void addCdnHashes(const QVector<MTPFileHash> &hashes);
void changeCDNParams(
const RequestData &requestData,
MTP::DcId dcId,
const QByteArray &token,
const QByteArray &encryptionKey,
const QByteArray &encryptionIV,
const QVector<MTPFileHash> &hashes);
enum class CheckCdnHashResult {
NoHash,
Invalid,
Good,
};
CheckCdnHashResult checkCdnFileHash(int offset, bytes::const_span buffer);
const not_null<Storage::DownloadManager*> _downloader;
const MTP::DcId _dcId = 0;
std::map<mtpRequestId, RequestData> _sentRequests;
int takeNextRequestOffset() override;
bool feedPart(int offset, const QByteArray &bytes) override;
void cancelOnFail() override;
bool setWebFileSizeHook(int size) override;
bool _lastComplete = false;
int32 _nextRequestOffset = 0;
base::variant<
StorageFileLocation,
WebFileLocation,
GeoPointLocation> _location;
Data::FileOrigin _origin;
MTP::DcId _cdnDcId = 0;
QByteArray _cdnToken;
QByteArray _cdnEncryptionKey;
QByteArray _cdnEncryptionIV;
base::flat_map<int, CdnFileHash> _cdnFileHashes;
base::flat_map<RequestData, QByteArray> _cdnUncheckedParts;
mtpRequestId _cdnHashesRequestId = 0;
};

View File

@ -455,7 +455,7 @@ webFileLoader::webFileLoader(
}
webFileLoader::~webFileLoader() {
cancelRequests();
cancelRequest();
}
QString webFileLoader::url() const {
@ -493,7 +493,7 @@ void webFileLoader::loadProgress(qint64 ready, qint64 total) {
}
void webFileLoader::loadFinished(const QByteArray &data) {
cancelRequests();
cancelRequest();
if (writeResultPart(0, bytes::make_span(data))) {
if (finalizeResult()) {
notifyAboutProgress();
@ -513,11 +513,11 @@ std::optional<MediaKey> webFileLoader::fileLocationKey() const {
return std::nullopt;
}
void webFileLoader::stop() {
cancelRequests();
void webFileLoader::cancelHook() {
cancelRequest();
}
void webFileLoader::cancelRequests() {
void webFileLoader::cancelRequest() {
if (!_manager) {
return;
}

View File

@ -24,10 +24,10 @@ public:
[[nodiscard]] QString url() const;
int currentOffset() const override;
void stop() override;
void cancelRequests() override;
private:
void cancelRequest();
void cancelHook() override;
void startLoading() override;
Storage::Cache::Key cacheKey() const override;
std::optional<MediaKey> fileLocationKey() const override;

View File

@ -63,7 +63,7 @@ StreamedFileDownloader::StreamedFileDownloader(
}
StreamedFileDownloader::~StreamedFileDownloader() {
stop();
cancelHook();
}
uint64 StreamedFileDownloader::objId() const {
@ -74,10 +74,6 @@ Data::FileOrigin StreamedFileDownloader::fileOrigin() const {
return _origin;
}
void StreamedFileDownloader::stop() {
cancelRequests();
}
void StreamedFileDownloader::requestParts() {
while (!_finished
&& _nextPartIndex < _partsCount
@ -121,7 +117,7 @@ std::optional<MediaKey> StreamedFileDownloader::fileLocationKey() const {
return _fileLocationKey;
}
void StreamedFileDownloader::cancelRequests() {
void StreamedFileDownloader::cancelHook() {
_partsRequested = 0;
_nextPartIndex = 0;

View File

@ -9,6 +9,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "storage/file_download.h"
#include "storage/cache/storage_cache_types.h"
#include "data/data_file_origin.h"
namespace Media {
namespace Streaming {
@ -41,7 +42,6 @@ public:
uint64 objId() const override;
Data::FileOrigin fileOrigin() const override;
void stop() override;
QByteArray readLoadedPart(int offset);
@ -49,7 +49,7 @@ private:
void startLoading() override;
Cache::Key cacheKey() const override;
std::optional<MediaKey> fileLocationKey() const override;
void cancelRequests() override;
void cancelHook() override;
void requestParts();
void requestPart();

View File

@ -306,7 +306,6 @@ void RemoteSource::destroyLoader() {
if (cancelled()) {
loader->cancel();
}
loader->stop();
}
void RemoteSource::loadLocal() {

View File

@ -65,7 +65,7 @@ std::unique_ptr<Manager> Create(System *system) {
Manager::Manager(System *system)
: Notifications::Manager(system)
, _inputCheckTimer([=] { checkLastInput(); }) {
subscribe(system->session().downloader().taskFinished(), [this] {
subscribe(system->session().downloaderTaskFinished(), [this] {
for (const auto &notification : _notifications) {
notification->updatePeerPhoto();
}

View File

@ -683,6 +683,8 @@
<(src_loc)/settings/settings_privacy_controllers.h
<(src_loc)/settings/settings_privacy_security.cpp
<(src_loc)/settings/settings_privacy_security.h
<(src_loc)/storage/download_manager_mtproto.cpp
<(src_loc)/storage/download_manager_mtproto.h
<(src_loc)/storage/file_download.cpp
<(src_loc)/storage/file_download.h
<(src_loc)/storage/file_download_mtproto.cpp