Rewrite TCP socket reading using bytes::vector.

I hope this fixes a strange assertion violation.
This commit is contained in:
John Preston 2018-07-11 17:00:06 +03:00
parent 951634a717
commit 556f75ef6c
2 changed files with 105 additions and 95 deletions

View File

@ -19,8 +19,9 @@ namespace MTP {
namespace internal {
namespace {
constexpr auto kPacketSizeMax = 0x01000000 * sizeof(mtpPrime);
constexpr auto kPacketSizeMax = int(0x01000000 * sizeof(mtpPrime));
constexpr auto kFullConnectionTimeout = 8 * TimeMs(1000);
constexpr auto kSmallBufferSize = 256 * 1024;
using ErrorSignal = void(QTcpSocket::*)(QAbstractSocket::SocketError);
const auto QTcpSocket_error = ErrorSignal(&QAbstractSocket::error);
@ -38,9 +39,9 @@ public:
virtual void prepareKey(bytes::span key, bytes::const_span source) = 0;
virtual bytes::span finalizePacket(mtpBuffer &buffer) = 0;
static constexpr auto kUnknownSize = uint32(-1);
static constexpr auto kInvalidSize = uint32(-2);
virtual uint32 readPacketLength(bytes::const_span bytes) const = 0;
static constexpr auto kUnknownSize = -1;
static constexpr auto kInvalidSize = -2;
virtual int readPacketLength(bytes::const_span bytes) const = 0;
virtual bytes::const_span readPacket(bytes::const_span bytes) const = 0;
virtual ~Protocol() = default;
@ -61,7 +62,7 @@ public:
void prepareKey(bytes::span key, bytes::const_span source) override;
bytes::span finalizePacket(mtpBuffer &buffer) override;
uint32 readPacketLength(bytes::const_span bytes) const override;
int readPacketLength(bytes::const_span bytes) const override;
bytes::const_span readPacket(bytes::const_span bytes) const override;
};
@ -105,7 +106,7 @@ bytes::span TcpConnection::Protocol::Version0::finalizePacket(
return bytes::make_span(buffer).subspan(8 - added, added + bytesSize);
}
uint32 TcpConnection::Protocol::Version0::readPacketLength(
int TcpConnection::Protocol::Version0::readPacketLength(
bytes::const_span bytes) const {
if (bytes.empty()) {
return kUnknownSize;
@ -118,10 +119,10 @@ uint32 TcpConnection::Protocol::Version0::readPacketLength(
const auto ints = static_cast<uint32>(bytes[1])
| (static_cast<uint32>(bytes[2]) << 8)
| (static_cast<uint32>(bytes[3]) << 16);
return (ints >= 0x7F) ? ((ints << 2) + 4) : kInvalidSize;
return (ints >= 0x7F) ? (int(ints << 2) + 4) : kInvalidSize;
} else if (first > 0 && first < 0x7F) {
const auto ints = uint32(first);
return (ints << 2) + 1;
return int(ints << 2) + 1;
}
return kInvalidSize;
}
@ -172,7 +173,7 @@ public:
bytes::span finalizePacket(mtpBuffer &buffer) override;
uint32 readPacketLength(bytes::const_span bytes) const override;
int readPacketLength(bytes::const_span bytes) const override;
bytes::const_span readPacket(bytes::const_span bytes) const override;
};
@ -200,13 +201,15 @@ bytes::span TcpConnection::Protocol::VersionD::finalizePacket(
return bytes::make_span(buffer).subspan(4, 4 + bytesSize);
}
uint32 TcpConnection::Protocol::VersionD::readPacketLength(
int TcpConnection::Protocol::VersionD::readPacketLength(
bytes::const_span bytes) const {
if (bytes.size() < 4) {
return kUnknownSize;
}
const auto value = *reinterpret_cast<const uint32*>(bytes.data()) + 4;
return (value >= 8 && value < kPacketSizeMax) ? value : kInvalidSize;
return (value >= 8 && value < kPacketSizeMax)
? int(value)
: kInvalidSize;
}
bytes::const_span TcpConnection::Protocol::VersionD::readPacket(
@ -234,7 +237,6 @@ auto TcpConnection::Protocol::Create(bytes::vector &&secret)
TcpConnection::TcpConnection(QThread *thread, const ProxyData &proxy)
: AbstractConnection(thread, proxy)
, _currentPosition(reinterpret_cast<char*>(_shortBuffer))
, _checkNonce(rand_value<MTPint128>()) {
_socket.moveToThread(thread);
_socket.setProxy(ToNetworkProxy(proxy));
@ -265,6 +267,8 @@ ConnectionPointer TcpConnection::clone(const ProxyData &proxy) {
}
void TcpConnection::socketRead() {
Expects(_leftBytes > 0 || !_usingLargeBuffer);
if (_socket.state() != QAbstractSocket::ConnectedState) {
LOG(("MTP error: "
"socket not connected in socketRead(), state: %1"
@ -273,93 +277,101 @@ void TcpConnection::socketRead() {
return;
}
if (_smallBuffer.empty()) {
_smallBuffer.resize(kSmallBufferSize);
}
do {
uint32 toRead = _packetLeft
? _packetLeft
: (_readingToShort
? (kShortBufferSize * sizeof(mtpPrime) - _packetRead)
: 4);
if (_readingToShort) {
if (_currentPosition + toRead > ((char*)_shortBuffer) + kShortBufferSize * sizeof(mtpPrime)) {
_longBuffer.resize(((_packetRead + toRead) >> 2) + 1);
memcpy(&_longBuffer[0], _shortBuffer, _packetRead);
_currentPosition = ((char*)&_longBuffer[0]) + _packetRead;
_readingToShort = false;
}
} else {
if (_longBuffer.size() * sizeof(mtpPrime) < _packetRead + toRead) {
_longBuffer.resize(((_packetRead + toRead) >> 2) + 1);
_currentPosition = ((char*)&_longBuffer[0]) + _packetRead;
}
}
int32 bytes = (int32)_socket.read(_currentPosition, toRead);
if (bytes > 0) {
aesCtrEncrypt(
bytes::make_span(_currentPosition, bytes),
_receiveKey,
&_receiveState);
TCP_LOG(("TCP Info: read %1 bytes").arg(bytes));
const auto readLimit = (_leftBytes > 0)
? _leftBytes
: (kSmallBufferSize - _offsetBytes - _readBytes);
Assert(readLimit > 0);
_packetRead += bytes;
_currentPosition += bytes;
if (_packetLeft) {
_packetLeft -= bytes;
if (!_packetLeft) {
socketPacket(bytes::make_span(
_currentPosition - _packetRead,
_packetRead));
_currentPosition = (char*)_shortBuffer;
_packetRead = _packetLeft = 0;
_readingToShort = true;
_longBuffer.clear();
auto &buffer = _usingLargeBuffer ? _largeBuffer : _smallBuffer;
const auto full = bytes::make_span(buffer).subspan(_offsetBytes);
const auto free = full.subspan(_readBytes);
Assert(free.size() >= readLimit);
const auto readCount = _socket.read(
reinterpret_cast<char*>(free.data()),
readLimit);
if (readCount > 0) {
const auto read = free.subspan(0, readCount);
aesCtrEncrypt(read, _receiveKey, &_receiveState);
TCP_LOG(("TCP Info: read %1 bytes").arg(readCount));
_readBytes += readCount;
if (_leftBytes > 0) {
Assert(readCount <= _leftBytes);
_leftBytes -= readCount;
if (!_leftBytes) {
socketPacket(full.subspan(0, _readBytes));
_usingLargeBuffer = false;
_largeBuffer.clear();
_offsetBytes = _readBytes = 0;
} else {
TCP_LOG(("TCP Info: not enough %1 for packet! read %2"
).arg(_packetLeft
).arg(_packetRead));
).arg(_leftBytes
).arg(_readBytes));
emit receivedSome();
}
} else {
bool move = false;
while (_packetRead >= 4) {
auto available = full.subspan(0, _readBytes);
while (_readBytes > 0) {
const auto packetSize = _protocol->readPacketLength(
bytes::make_span(
_currentPosition - _packetRead,
_packetRead));
if (packetSize == Protocol::kUnknownSize
|| packetSize == Protocol::kInvalidSize) {
LOG(("TCP Error: packet size = %1").arg(packetSize));
available);
if (packetSize == Protocol::kUnknownSize) {
// Not enough bytes yet.
break;
} else if (packetSize <= 0) {
LOG(("TCP Error: bad packet size in 4 bytes: %1"
).arg(packetSize));
emit error(kErrorCodeOther);
return;
}
if (_packetRead >= packetSize) {
socketPacket(bytes::make_span(
_currentPosition - _packetRead,
packetSize));
_packetRead -= packetSize;
_packetLeft = 0;
move = true;
} else if (available.size() >= packetSize) {
socketPacket(available.subspan(0, packetSize));
available = available.subspan(packetSize);
_offsetBytes += packetSize;
_readBytes -= packetSize;
} else {
_packetLeft = packetSize - _packetRead;
TCP_LOG(("TCP Info: not enough %1 for packet! size %2 read %3").arg(_packetLeft).arg(packetSize).arg(_packetRead));
_leftBytes = packetSize - available.size();
// If the next packet won't fit in the buffer.
const auto full = bytes::make_span(buffer).subspan(
_offsetBytes);
if (full.size() < packetSize) {
const auto read = full.subspan(0, _readBytes);
if (packetSize <= _smallBuffer.size()) {
if (_usingLargeBuffer) {
bytes::copy(_smallBuffer, read);
_usingLargeBuffer = false;
_largeBuffer.clear();
} else {
bytes::move(_smallBuffer, read);
}
} else if (packetSize <= _largeBuffer.size()) {
Assert(_usingLargeBuffer);
bytes::move(_largeBuffer, read);
} else {
auto enough = bytes::vector(packetSize);
bytes::copy(enough, read);
_largeBuffer = std::move(enough);
_usingLargeBuffer = true;
}
_offsetBytes = 0;
}
TCP_LOG(("TCP Info: not enough %1 for packet! "
"full size %2 read %3"
).arg(_leftBytes
).arg(packetSize
).arg(available.size()));
emit receivedSome();
break;
}
}
if (move) {
if (!_packetRead) {
_currentPosition = (char*)_shortBuffer;
_readingToShort = true;
_longBuffer.clear();
} else if (!_readingToShort && _packetRead < kShortBufferSize * sizeof(mtpPrime)) {
memcpy(_shortBuffer, _currentPosition - _packetRead, _packetRead);
_currentPosition = (char*)_shortBuffer + _packetRead;
_readingToShort = true;
_longBuffer.clear();
}
}
}
} else if (bytes < 0) {
LOG(("TCP Error: socket read return -1"));
} else if (readCount < 0) {
LOG(("TCP Error: socket read return %1").arg(readCount));
emit error(kErrorCodeOther);
return;
} else {
@ -527,15 +539,14 @@ void TcpConnection::writeConnectionStart() {
}
void TcpConnection::sendBuffer(mtpBuffer &&buffer) {
if (!_packetIndex++) {
if (!_connectionStarted) {
writeConnectionStart();
_connectionStarted = true;
}
// buffer: 2 available int-s + data + available int.
const auto bytes = _protocol->finalizePacket(buffer);
TCP_LOG(("TCP Info: write %1 packet %2"
).arg(_packetIndex
).arg(bytes.size()));
TCP_LOG(("TCP Info: write packet %1 bytes").arg(bytes.size()));
aesCtrEncrypt(bytes, _sendKey, &_sendState);
_socket.write(
reinterpret_cast<const char*>(bytes.data()),

View File

@ -47,7 +47,6 @@ private:
Ready,
Finished,
};
static constexpr auto kShortBufferSize = 65535; // Of ints, 256 kb.
void socketRead();
void writeConnectionStart();
@ -68,14 +67,14 @@ private:
void sendBuffer(mtpBuffer &&buffer);
QTcpSocket _socket;
uint32 _packetIndex = 0; // sent packet number
bool _connectionStarted = false;
uint32 _packetRead = 0;
uint32 _packetLeft = 0; // reading from socket
bool _readingToShort = true;
mtpBuffer _longBuffer;
mtpPrime _shortBuffer[kShortBufferSize];
char *_currentPosition = nullptr;
int _offsetBytes = 0;
int _readBytes = 0;
int _leftBytes = 0;
bytes::vector _smallBuffer;
bytes::vector _largeBuffer;
bool _usingLargeBuffer = false;
uchar _sendKey[CTRState::KeySize];
CTRState _sendState;