/*
This file is part of Telegram Desktop,
the official desktop application for the Telegram messaging service.

For license and copyright information please follow this link:
https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
*/
#include "mtproto/connection_tcp.h"

#include "base/bytes.h"
#include "base/openssl_help.h"
#include "base/qthelp_url.h"
// NOTE(review): the target of this include was missing in the text this file
// was restored from; <openssl/aes.h> is the presumed header - please confirm.
#include <openssl/aes.h>

namespace MTP {
namespace internal {
namespace {

// Bounds of the receive timeout used while waiting for the fake req_pq reply.
// The timeout starts at the minimum and is doubled on every expiry until it
// is no longer below the maximum (see TCPConnection::onTcpTimeoutTimer()).
constexpr auto kMinReceiveTimeout = TimeMs(2000);
constexpr auto kMaxReceiveTimeout = TimeMs(8000);

// Returns the full size in bytes (length header included) of the packet that
// starts at 'packet'.
//
// Header layout, as decoded here:
//  - first byte in (0, 0x7f): payload length in 4-byte words, 1-byte header;
//  - first byte == 0x7f: the next three bytes hold the payload length in
//    4-byte words, little-endian, 4-byte header.
// A zero or negative first byte yields 1, which callers reject as too small.
//
// The caller must guarantee that at least 4 bytes are readable at 'packet'.
uint32 tcpPacketSize(const char *packet) {
	uint32 result = (packet[0] > 0) ? packet[0] : 0;
	if (result == 0x7f) {
		const uchar *bytes = reinterpret_cast<const uchar*>(packet);
		result = (((uint32(bytes[3]) << 8) | uint32(bytes[2])) << 8) | uint32(bytes[1]);
		return (result << 2) + 4;
	}
	return (result << 2) + 1;
}

} // namespace

// Starts with the read cursor at the beginning of the fixed short buffer.
AbstractTCPConnection::AbstractTCPConnection(
	QThread *thread)
: AbstractConnection(thread)
, currentPos((char*)shortBuffer) {
}

AbstractTCPConnection::~AbstractTCPConnection() {
}

// Drains the socket, decrypts what was read and splits it into packets,
// passing every complete packet to socketPacket().
//
// State kept between calls:
//  - currentPos: where the next read is written;
//  - packetRead: number of received but not yet consumed bytes, so the
//    unconsumed data always starts at (currentPos - packetRead);
//  - packetLeft: bytes still missing for the packet whose size is already
//    known, or 0 when the size of the next packet is not known yet;
//  - readingToShort: whether currentPos points into shortBuffer or longBuffer.
//
// Emits error(kErrorCodeOther) when the socket is not connected, when a read
// fails or when a packet announces an invalid size; emits receivedSome() when
// only a part of a packet has arrived so far.
void AbstractTCPConnection::socketRead() {
	if (sock.state() != QAbstractSocket::ConnectedState) {
		LOG(("MTP error: socket not connected in socketRead(), state: %1").arg(sock.state()));
		emit error(kErrorCodeOther);
		return;
	}

	do {
		// Read exactly the rest of a known packet, or everything that still
		// fits into the short buffer, or just enough to learn a packet size.
		uint32 toRead = packetLeft
			? packetLeft
			: (readingToShort
				? (MTPShortBufferSize * sizeof(mtpPrime) - packetRead)
				: 4);
		if (readingToShort) {
			if (currentPos + toRead > ((char*)shortBuffer) + MTPShortBufferSize * sizeof(mtpPrime)) {
				// The short buffer can't hold this read: continue in the long
				// buffer. The unconsumed bytes begin at (currentPos -
				// packetRead), which is not the start of shortBuffer when
				// earlier packets from the same buffer were already consumed,
				// so they must be copied from there and not from shortBuffer.
				const char *unread = currentPos - packetRead;
				longBuffer.resize(((packetRead + toRead) >> 2) + 1);
				memcpy(&longBuffer[0], unread, packetRead);
				currentPos = ((char*)&longBuffer[0]) + packetRead;
				readingToShort = false;
			}
		} else {
			if (longBuffer.size() * sizeof(mtpPrime) < packetRead + toRead) {
				// resize() may reallocate, so the cursor is recomputed.
				longBuffer.resize(((packetRead + toRead) >> 2) + 1);
				currentPos = ((char*)&longBuffer[0]) + packetRead;
			}
		}
		int32 bytes = (int32)sock.read(currentPos, toRead);
		if (bytes > 0) {
			// Incoming bytes are run through AES-CTR with the receive key and
			// state before they are parsed.
			aesCtrEncrypt(currentPos, bytes, _receiveKey, &_receiveState);
			TCP_LOG(("TCP Info: read %1 bytes").arg(bytes));

			packetRead += bytes;
			currentPos += bytes;
			if (packetLeft) {
				packetLeft -= bytes;
				if (!packetLeft) {
					// The awaited packet is complete: deliver it and start
					// over from the beginning of the short buffer.
					socketPacket(currentPos - packetRead, packetRead);
					currentPos = (char*)shortBuffer;
					packetRead = packetLeft = 0;
					readingToShort = true;
					longBuffer.clear();
				} else {
					TCP_LOG(("TCP Info: not enough %1 for packet! read %2").arg(packetLeft).arg(packetRead));
					emit receivedSome();
				}
			} else {
				// No packet size is known yet: peel off as many complete
				// packets as the received data contains.
				bool move = false;
				while (packetRead >= 4) {
					uint32 packetSize = tcpPacketSize(currentPos - packetRead);
					if (packetSize < 5 || packetSize > MTPPacketSizeMax) {
						LOG(("TCP Error: packet size = %1").arg(packetSize));
						emit error(kErrorCodeOther);
						return;
					}
					if (packetRead >= packetSize) {
						socketPacket(currentPos - packetRead, packetSize);
						packetRead -= packetSize;
						packetLeft = 0;
						move = true;
					} else {
						packetLeft = packetSize - packetRead;
						TCP_LOG(("TCP Info: not enough %1 for packet! size %2 read %3").arg(packetLeft).arg(packetSize).arg(packetRead));
						emit receivedSome();
						break;
					}
				}
				if (move) {
					if (!packetRead) {
						// Everything was consumed: rewind to the short buffer.
						currentPos = (char*)shortBuffer;
						readingToShort = true;
						longBuffer.clear();
					} else if (!readingToShort && packetRead < MTPShortBufferSize * sizeof(mtpPrime)) {
						// The remainder fits into the short buffer again, so
						// move it there and release the long buffer.
						memcpy(shortBuffer, currentPos - packetRead, packetRead);
						currentPos = (char*)shortBuffer + packetRead;
						readingToShort = true;
						longBuffer.clear();
					}
				}
			}
		} else if (bytes < 0) {
			LOG(("TCP Error: socket read return -1"));
			emit error(kErrorCodeOther);
			return;
		} else {
			TCP_LOG(("TCP Info: no bytes read, but bytes available was true..."));
			break;
		}
	} while (sock.state() == QAbstractSocket::ConnectedState && sock.bytesAvailable());
}

// Validates one complete packet (length header included) and returns its
// payload as a buffer of mtpPrime values.
//
// A returned buffer of size 1 denotes an error: it holds -500 when the packet
// is malformed (bad total length or a header that disagrees with it), or the
// single value carried by the packet when its payload is exactly one
// mtpPrime long (logged as an error code received from the other side).
mtpBuffer AbstractTCPConnection::handleResponse(const char *packet, uint32 length) {
	if (length < 5 || length > MTPPacketSizeMax) {
		LOG(("TCP Error: bad packet size %1").arg(length));
		return mtpBuffer(1, -500);
	}
	// 'size' is the payload length in mtpPrime units taken from the header,
	// 'len' is the payload length in bytes implied by the packet length.
	int32 size = packet[0], len = length - 1;
	if (size == 0x7f) {
		const uchar *bytes = reinterpret_cast<const uchar*>(packet);
		size = (((uint32(bytes[3]) << 8) | uint32(bytes[2])) << 8) | uint32(bytes[1]);
		len -= 3;
	}
	if (size * int32(sizeof(mtpPrime)) != len) {
		LOG(("TCP Error: bad packet header"));
		TCP_LOG(("TCP Error: bad packet header, packet: %1").arg(Logs::mb(packet, length).str()));
		return mtpBuffer(1, -500);
	}
	// (length - len) is the header size: 1 or 4 bytes.
	const mtpPrime *packetdata = reinterpret_cast<const mtpPrime*>(packet + (length - len));
	TCP_LOG(("TCP Info: packet received, size = %1").arg(size * sizeof(mtpPrime)));
	if (size == 1) {
		LOG(("TCP Error: error packet received, code = %1").arg(*packetdata));
		return mtpBuffer(1, *packetdata);
	}

	mtpBuffer data(size);
	memcpy(data.data(), packetdata, size * sizeof(mtpPrime));

	return data;
}

// Logs a socket error. This function only writes to the log; reacting to the
// error is left to the caller.
void AbstractTCPConnection::handleError(QAbstractSocket::SocketError e, QTcpSocket &sock) {
	switch (e) {
	case QAbstractSocket::ConnectionRefusedError:
		LOG(("TCP Error: socket connection refused - %1").arg(sock.errorString()));
		break;

	case QAbstractSocket::RemoteHostClosedError:
		TCP_LOG(("TCP Info: remote host closed socket connection - %1").arg(sock.errorString()));
		break;

	case QAbstractSocket::HostNotFoundError:
		LOG(("TCP Error: host not found - %1").arg(sock.errorString()));
		break;

	case QAbstractSocket::SocketTimeoutError:
		LOG(("TCP Error: socket timeout - %1").arg(sock.errorString()));
		break;

	case QAbstractSocket::NetworkError:
		LOG(("TCP Error: network - %1").arg(sock.errorString()));
		break;

	case QAbstractSocket::ProxyAuthenticationRequiredError:
	case QAbstractSocket::ProxyConnectionRefusedError:
	case QAbstractSocket::ProxyConnectionClosedError:
	case QAbstractSocket::ProxyConnectionTimeoutError:
	case QAbstractSocket::ProxyNotFoundError:
	case QAbstractSocket::ProxyProtocolError:
		LOG(("TCP Error: proxy (%1) - %2").arg(e).arg(sock.errorString()));
		break;

	default:
		LOG(("TCP Error: other (%1) - %2").arg(e).arg(sock.errorString()));
		break;
	}

	TCP_LOG(("TCP Error %1, restarting! - %2").arg(e).arg(sock.errorString()));
}

// Moves the timer and the socket to 'thread' and wires their signals to the
// slots of this object. The nonce for the fake req_pq is generated here.
TCPConnection::TCPConnection(QThread *thread)
: AbstractTCPConnection(thread)
, status(WaitingTcp)
// The explicit template argument of rand_value was missing in the text this
// file was restored from; the declared type of the member is used for it.
, tcpNonce(rand_value<decltype(tcpNonce)>())
, _tcpTimeout(kMinReceiveTimeout) {
	tcpTimeoutTimer.moveToThread(thread);
	tcpTimeoutTimer.setSingleShot(true);
	connect(&tcpTimeoutTimer, SIGNAL(timeout()), this, SLOT(onTcpTimeoutTimer()));

	sock.moveToThread(thread);
	connect(&sock, SIGNAL(error(QAbstractSocket::SocketError)), this, SLOT(socketError(QAbstractSocket::SocketError)));
	connect(&sock, SIGNAL(connected()), this, SLOT(onSocketConnected()));
	connect(&sock, SIGNAL(disconnected()), this, SLOT(onSocketDisconnected()));
}

// Once the socket is connected, sends a fake req_pq and arms the timeout
// timer that waits for its reply.
void TCPConnection::onSocketConnected() {
	if (status == WaitingTcp) {
		mtpBuffer buffer(preparePQFake(tcpNonce));

		DEBUG_LOG(("Connection Info: sending fake req_pq through TCP transport to %1").arg(_address));

		// A negative value marks a reconnect requested by the timeout
		// handler; restore the real timeout before using it.
		if (_tcpTimeout < 0) _tcpTimeout = -_tcpTimeout;
		tcpTimeoutTimer.start(_tcpTimeout);

		sendData(buffer);
	}
}

// No reply to the fake req_pq arrived in time: grow the timeout and
// reconnect. The timeout is stored negated so that onSocketDisconnected()
// can tell this deliberate reconnect from an ordinary disconnect.
void TCPConnection::onTcpTimeoutTimer() {
	if (status == WaitingTcp) {
		if (_tcpTimeout < kMaxReceiveTimeout) {
			_tcpTimeout *= 2;
		}
		_tcpTimeout = -_tcpTimeout;

		QAbstractSocket::SocketState state = sock.state();
		if (state == QAbstractSocket::ConnectedState || state == QAbstractSocket::ConnectingState || state == QAbstractSocket::HostLookupState) {
			// The reconnect itself happens in onSocketDisconnected().
			sock.disconnectFromHost();
		} else if (state != QAbstractSocket::ClosingState) {
			sock.connectToHost(QHostAddress(_address), _port);
		}
	}
}

// Reconnects when the disconnect was requested by the timeout handler while
// still waiting for the transport; otherwise reports the disconnect.
void TCPConnection::onSocketDisconnected() {
	if (_tcpTimeout < 0) {
		_tcpTimeout = -_tcpTimeout;
		if (status == WaitingTcp) {
			sock.connectToHost(QHostAddress(_address), _port);
			return;
		}
	}
	if (status == WaitingTcp || status == UsingTcp) {
		emit disconnected();
	}
}

// Sends 'buffer' through the socket unless the connection has finished its
// work. Buffers shorter than 3 mtpPrime values are rejected with
// error(kErrorCodeOther), because tcpSend() treats 3 of them as reserved.
void TCPConnection::sendData(mtpBuffer &buffer) {
	if (status == FinishedWork) return;

	if (buffer.size() < 3) {
		LOG(("TCP Error: writing bad packet, len = %1").arg(buffer.size() * sizeof(mtpPrime)));
		TCP_LOG(("TCP Error: bad packet %1").arg(Logs::mb(&buffer[0], buffer.size() * sizeof(mtpPrime)).str()));
		emit error(kErrorCodeOther);
		return;
	}

	tcpSend(buffer);
}

// Generates and sends the 64-byte connection prefix and derives the AES-CTR
// keys and initial states for both directions from it.
//
// Prefix layout, as used below:
//  - bytes 8..39 and 40..55: send key material and send ivec;
//  - the same 48 bytes reversed: receive key material and receive ivec;
//  - bytes 56..59: protocol tag 0xEFEFEFEF;
//  - bytes 60..61: _protocolDcId.
// The first 56 bytes are sent as generated; the whole prefix is then
// encrypted with the send key, which advances the send state, and only the
// encrypted last 8 bytes are sent.
void AbstractTCPConnection::writeConnectionStart() {
	// prepare random part
	auto nonceBytes = bytes::vector(64);
	const auto nonce = bytes::make_span(nonceBytes);

	// Prefix values that must not be generated. The first four words spell
	// the beginnings of HTTP requests when laid out in little-endian order.
	const auto zero = reinterpret_cast<uchar*>(nonce.data());
	const auto first = reinterpret_cast<uint32*>(nonce.data());
	const auto second = first + 1;
	const auto reserved01 = 0x000000EFU;
	const auto reserved11 = 0x44414548U; // "HEAD"
	const auto reserved12 = 0x54534F50U; // "POST"
	const auto reserved13 = 0x20544547U; // "GET "
	// This used to repeat the "GET " value, which left one check below dead;
	// "OPTI" is the remaining first word forbidden by the MTProto transport
	// obfuscation documentation.
	const auto reserved14 = 0x4954504FU; // "OPTI"
	const auto reserved15 = 0xEEEEEEEEU;
	const auto reserved21 = 0x00000000U;
	do {
		bytes::set_random(nonce);
	} while (*zero == reserved01
		|| *first == reserved11
		|| *first == reserved12
		|| *first == reserved13
		|| *first == reserved14
		|| *first == reserved15
		|| *second == reserved21);

	// With a 16-byte protocol secret the key is SHA256(material + secret),
	// without a secret the material is used as is, and with a secret of any
	// other size the key is zero-filled.
	const auto prepareKey = [&](bytes::span key, bytes::const_span from) {
		if (_protocolSecret.size() == 16) {
			const auto payload = bytes::concatenate(from, _protocolSecret);
			bytes::copy(key, openssl::Sha256(payload));
		} else if (_protocolSecret.empty()) {
			bytes::copy(key, from);
		} else {
			bytes::set_with_const(key, gsl::byte{});
		}
	};

	// prepare encryption key/iv
	prepareKey(
		bytes::make_span(_sendKey),
		nonce.subspan(8, CTRState::KeySize));
	bytes::copy(
		bytes::make_span(_sendState.ivec),
		nonce.subspan(8 + CTRState::KeySize, CTRState::IvecSize));

	// prepare decryption key/iv
	auto reversedBytes = bytes::vector(48);
	const auto reversed = bytes::make_span(reversedBytes);
	bytes::copy(reversed, nonce.subspan(8, reversed.size()));
	std::reverse(reversed.begin(), reversed.end());
	prepareKey(
		bytes::make_span(_receiveKey),
		reversed.subspan(0, CTRState::KeySize));
	bytes::copy(
		bytes::make_span(_receiveState.ivec),
		reversed.subspan(CTRState::KeySize, CTRState::IvecSize));

	// write protocol and dc ids
	const auto protocol = reinterpret_cast<uint32*>(nonce.data() + 56);
	*protocol = 0xEFEFEFEFU;
	const auto dcId = reinterpret_cast<int16*>(nonce.data() + 60);
	*dcId = _protocolDcId;

	sock.write(reinterpret_cast<const char*>(nonce.data()), 56);
	aesCtrEncrypt(nonce.data(), 64, _sendKey, &_sendState);
	sock.write(reinterpret_cast<const char*>(nonce.subspan(56).data()), 8);
}

// Frames, encrypts and writes one packet; the connection prefix is sent
// before the very first packet.
//
// Three mtpPrime values of 'buffer' are not payload: the first two and the
// last one. The length header is written in place into the end of the first
// two, so header and payload are encrypted and sent as one contiguous range:
// a single byte at offset 7 for payloads shorter than 0x7f words, otherwise
// the 0x7f marker and a 3-byte little-endian word count at offsets 4..7.
// The contents of 'buffer' are encrypted in place.
void AbstractTCPConnection::tcpSend(mtpBuffer &buffer) {
	if (!packetNum) {
		writeConnectionStart();
	}
	++packetNum;

	uint32 size = buffer.size() - 3, len = size * 4;
	char *data = reinterpret_cast<char*>(&buffer[0]);
	if (size < 0x7f) {
		data[7] = char(size);
		TCP_LOG(("TCP Info: write %1 packet %2").arg(packetNum).arg(len + 1));

		aesCtrEncrypt(data + 7, len + 1, _sendKey, &_sendState);
		sock.write(data + 7, len + 1);
	} else {
		data[4] = 0x7f;
		reinterpret_cast<uchar*>(data)[5] = uchar(size & 0xFF);
		reinterpret_cast<uchar*>(data)[6] = uchar((size >> 8) & 0xFF);
		reinterpret_cast<uchar*>(data)[7] = uchar((size >> 16) & 0xFF);
		TCP_LOG(("TCP Info: write %1 packet %2").arg(packetNum).arg(len + 4));

		aesCtrEncrypt(data + 4, len + 4, _sendKey, &_sendState);
		sock.write(data + 4, len + 4);
	}
}

// Marks the connection as finished, stops reacting to incoming data and
// closes the socket. Does nothing when called again.
void TCPConnection::disconnectFromServer() {
	if (status == FinishedWork) return;
	status = FinishedWork;

	disconnect(&sock, SIGNAL(readyRead()), 0, 0);
	sock.close();
}

// Remembers the endpoint and protocol parameters, starts listening for
// incoming data and begins connecting to the given address.
void TCPConnection::connectToServer(
		const QString &ip,
		int port,
		const bytes::vector &protocolSecret,
		int16 protocolDcId) {
	_address = ip;
	_port = port;
	_protocolSecret = protocolSecret;
	_protocolDcId = protocolDcId;

	connect(&sock, SIGNAL(readyRead()), this, SLOT(socketRead()));
	sock.connectToHost(QHostAddress(_address), _port);
}

// Handles one complete packet taken from the socket.
//
// A one-element result of handleResponse() is reported through error().
// While the transport is in use the payload is queued and receivedData() is
// emitted. While still waiting for the transport the payload is parsed as
// the reply to the fake req_pq, and a matching nonce switches the connection
// to UsingTcp and emits connected().
void TCPConnection::socketPacket(const char *packet, uint32 length) {
	if (status == FinishedWork) return;

	mtpBuffer data = handleResponse(packet, length);
	if (data.size() == 1) {
		emit error(data[0]);
	} else if (status == UsingTcp) {
		_receivedQueue.push_back(data);
		emit receivedData();
	} else if (status == WaitingTcp) {
		tcpTimeoutTimer.stop();
		try {
			auto res_pq = readPQFakeReply(data);
			const auto &res_pq_data(res_pq.c_resPQ());
			// NOTE(review): when the nonce does not match nothing is emitted
			// here, and the timeout timer has already been stopped above.
			if (res_pq_data.vnonce == tcpNonce) {
				DEBUG_LOG(("Connection Info: TCP-transport to %1 chosen by pq-response").arg(_address));
				status = UsingTcp;
				emit connected();
			}
		} catch (Exception &e) {
			DEBUG_LOG(("Connection Error: exception in parsing TCP fake pq-responce, %1").arg(e.what()));
			emit error(kErrorCodeOther);
		}
	}
}

// True once the reply to the fake req_pq confirmed the transport.
bool TCPConnection::isConnected() const {
	return (status == UsingTcp);
}

// Returns the current state of the socket as an integer.
int32 TCPConnection::debugState() const {
	return sock.state();
}

// Name of the transport in use: "TCP" or "TCP/IPv6", or an empty string
// while the connection is not established.
QString TCPConnection::transport() const {
	if (!isConnected()) {
		return QString();
	}
	auto result = qsl("TCP");
	if (qthelp::is_ipv6(_address)) {
		result += qsl("/IPv6");
	}
	return result;
}

// Tag of this connection: "TCP/IPv6" or "TCP/IPv4" depending on the address,
// regardless of the connection state.
QString TCPConnection::tag() const {
	auto result = qsl("TCP");
	if (qthelp::is_ipv6(_address)) {
		result += qsl("/IPv6");
	} else {
		result += qsl("/IPv4");
	}
	return result;
}

// Logs a socket error and reports it as error(kErrorCodeOther), unless the
// connection has already finished its work.
void TCPConnection::socketError(QAbstractSocket::SocketError e) {
	if (status == FinishedWork) return;

	handleError(e, sock);
	emit error(kErrorCodeOther);
}

} // namespace internal
} // namespace MTP