Improve first socket message.

This commit is contained in:
John Preston 2019-08-24 19:31:51 +03:00
parent 56c4d164f3
commit d66541989e
6 changed files with 326 additions and 111 deletions

View File

@ -99,6 +99,7 @@ tlsBlockRandom length:int = TlsBlock;
tlsBlockZero length:int = TlsBlock;
tlsBlockDomain = TlsBlock;
tlsBlockGrease seed:int = TlsBlock;
tlsBlockPublicKey = TlsBlock;
tlsBlockScope entries:Vector<TlsBlock> = TlsBlock;
---functions---

View File

@ -57,19 +57,38 @@ private:
class BigNum {
public:
BigNum() : _data(BN_new()) {
BigNum() = default;
BigNum(const BigNum &other)
: _data((other.failed() || other.isZero())
? nullptr
: BN_dup(other.raw()))
, _failed(other._failed) {
}
BigNum(const BigNum &other) : BigNum() {
*this = other;
BigNum(BigNum &&other)
: _data(std::exchange(other._data, nullptr))
, _failed(std::exchange(other._failed, false)) {
}
BigNum &operator=(const BigNum &other) {
if (other.failed() || !BN_copy(raw(), other.raw())) {
if (other.failed()) {
_failed = true;
} else if (other.isZero()) {
clear();
_failed = false;
} else if (!_data) {
_data = BN_dup(other.raw());
_failed = false;
} else {
_failed = !BN_copy(raw(), other.raw());
}
return *this;
}
BigNum &operator=(BigNum &&other) {
std::swap(_data, other._data);
std::swap(_failed, other._failed);
return *this;
}
~BigNum() {
BN_clear_free(raw());
clear();
}
explicit BigNum(unsigned int word) : BigNum() {
@ -79,64 +98,74 @@ public:
setBytes(bytes);
}
void setWord(unsigned int word) {
if (!BN_set_word(raw(), word)) {
_failed = true;
BigNum &setWord(unsigned int word) {
if (!word) {
clear();
_failed = false;
} else {
_failed = !BN_set_word(raw(), word);
}
return *this;
}
void setBytes(bytes::const_span bytes) {
if (!BN_bin2bn(
BigNum &setBytes(bytes::const_span bytes) {
if (bytes.empty()) {
clear();
_failed = false;
} else {
_failed = !BN_bin2bn(
reinterpret_cast<const unsigned char*>(bytes.data()),
bytes.size(),
raw())) {
_failed = true;
raw());
}
return *this;
}
void setAdd(const BigNum &a, const BigNum &b) {
BigNum &setAdd(const BigNum &a, const BigNum &b) {
if (a.failed() || b.failed()) {
_failed = true;
} else if (!BN_add(raw(), a.raw(), b.raw())) {
_failed = true;
} else {
_failed = !BN_add(raw(), a.raw(), b.raw());
}
return *this;
}
void setSub(const BigNum &a, const BigNum &b) {
BigNum &setSub(const BigNum &a, const BigNum &b) {
if (a.failed() || b.failed()) {
_failed = true;
} else if (!BN_sub(raw(), a.raw(), b.raw())) {
_failed = true;
} else {
_failed = !BN_sub(raw(), a.raw(), b.raw());
}
return *this;
}
void setSubWord(unsigned int word) {
if (failed()) {
return;
} else if (!BN_sub_word(raw(), word)) {
_failed = true;
}
}
void setMul(
BigNum &setMul(
const BigNum &a,
const BigNum &b,
const Context &context = Context()) {
if (a.failed() || b.failed()) {
_failed = true;
} else if (!BN_mul(raw(), a.raw(), b.raw(), context.raw())) {
_failed = true;
} else {
_failed = !BN_mul(raw(), a.raw(), b.raw(), context.raw());
}
return *this;
}
BN_ULONG setDivWord(BN_ULONG word) {
Expects(word != 0);
if (failed()) {
return (BN_ULONG)-1;
}
auto result = BN_div_word(raw(), word);
if (result == (BN_ULONG)-1) {
BigNum &setModAdd(
const BigNum &a,
const BigNum &b,
const BigNum &m,
const Context &context = Context()) {
if (a.failed() || b.failed() || m.failed()) {
_failed = true;
} else if (a.isNegative() || b.isNegative() || m.isNegative()) {
_failed = true;
} else if (!BN_mod_add(raw(), a.raw(), b.raw(), m.raw(), context.raw())) {
_failed = true;
} else if (isNegative()) {
_failed = true;
} else {
_failed = false;
}
return result;
return *this;
}
void setModSub(
BigNum &setModSub(
const BigNum &a,
const BigNum &b,
const BigNum &m,
@ -149,9 +178,12 @@ public:
_failed = true;
} else if (isNegative()) {
_failed = true;
} else {
_failed = false;
}
return *this;
}
void setModMul(
BigNum &setModMul(
const BigNum &a,
const BigNum &b,
const BigNum &m,
@ -164,9 +196,29 @@ public:
_failed = true;
} else if (isNegative()) {
_failed = true;
} else {
_failed = false;
}
return *this;
}
void setModExp(
BigNum &setModInverse(
const BigNum &a,
const BigNum &m,
const Context &context = Context()) {
if (a.failed() || m.failed()) {
_failed = true;
} else if (a.isNegative() || m.isNegative()) {
_failed = true;
} else if (!BN_mod_inverse(raw(), a.raw(), m.raw(), context.raw())) {
_failed = true;
} else if (isNegative()) {
_failed = true;
} else {
_failed = false;
}
return *this;
}
BigNum &setModExp(
const BigNum &base,
const BigNum &power,
const BigNum &m,
@ -179,23 +231,34 @@ public:
_failed = true;
} else if (isNegative()) {
_failed = true;
} else {
_failed = false;
}
return *this;
}
bool isNegative() const {
return failed() ? false : BN_is_negative(raw());
[[nodiscard]] bool isZero() const {
return !failed() && (!_data || BN_is_zero(raw()));
}
bool isPrime(const Context &context = Context()) const {
if (failed()) {
[[nodiscard]] bool isOne() const {
return !failed() && _data && BN_is_one(raw());
}
[[nodiscard]] bool isNegative() const {
return !failed() && _data && BN_is_negative(raw());
}
[[nodiscard]] bool isPrime(const Context &context = Context()) const {
if (failed() || !_data) {
return false;
}
constexpr auto kMillerRabinIterationCount = 30;
auto result = BN_is_prime_ex(
const auto result = BN_is_prime_ex(
raw(),
kMillerRabinIterationCount,
context.raw(),
NULL);
nullptr);
if (result == 1) {
return true;
} else if (result != 0) {
@ -204,27 +267,42 @@ public:
return false;
}
BN_ULONG modWord(BN_ULONG word) const {
Expects(word != 0);
BigNum &subWord(unsigned int word) {
if (failed()) {
return (BN_ULONG)-1;
return *this;
} else if (!BN_sub_word(raw(), word)) {
_failed = true;
}
return *this;
}
BigNum &divWord(BN_ULONG word, BN_ULONG *mod = nullptr) {
Expects(word != 0);
auto result = BN_mod_word(raw(), word);
const auto result = failed()
? (BN_ULONG)-1
: BN_div_word(raw(), word);
if (result == (BN_ULONG)-1) {
_failed = true;
}
return result;
if (mod) {
*mod = result;
}
return *this;
}
[[nodiscard]] BN_ULONG countModWord(BN_ULONG word) const {
Expects(word != 0);
return failed() ? (BN_ULONG)-1 : BN_mod_word(raw(), word);
}
int bitsSize() const {
[[nodiscard]] int bitsSize() const {
return failed() ? 0 : BN_num_bits(raw());
}
int bytesSize() const {
[[nodiscard]] int bytesSize() const {
return failed() ? 0 : BN_num_bytes(raw());
}
bytes::vector getBytes() const {
[[nodiscard]] bytes::vector getBytes() const {
if (failed()) {
return {};
}
@ -237,73 +315,84 @@ public:
return result;
}
BIGNUM *raw() {
[[nodiscard]] BIGNUM *raw() {
if (!_data) _data = BN_new();
return _data;
}
const BIGNUM *raw() const {
[[nodiscard]] const BIGNUM *raw() const {
if (!_data) _data = BN_new();
return _data;
}
BIGNUM *takeRaw() {
return base::take(_data);
[[nodiscard]] BIGNUM *takeRaw() {
return _failed
? nullptr
: _data
? std::exchange(_data, nullptr)
: BN_new();
}
bool failed() const {
[[nodiscard]] bool failed() const {
return _failed;
}
static BigNum Add(const BigNum &a, const BigNum &b) {
BigNum result;
result.setAdd(a, b);
return result;
[[nodiscard]] static BigNum Add(const BigNum &a, const BigNum &b) {
return BigNum().setAdd(a, b);
}
static BigNum Sub(const BigNum &a, const BigNum &b) {
BigNum result;
result.setSub(a, b);
return result;
[[nodiscard]] static BigNum Sub(const BigNum &a, const BigNum &b) {
return BigNum().setSub(a, b);
}
static BigNum Mul(
[[nodiscard]] static BigNum Mul(
const BigNum &a,
const BigNum &b,
const Context &context = Context()) {
BigNum result;
result.setMul(a, b, context);
return result;
return BigNum().setMul(a, b, context);
}
static BigNum ModSub(
[[nodiscard]] static BigNum ModAdd(
const BigNum &a,
const BigNum &b,
const BigNum &mod,
const Context &context = Context()) {
BigNum result;
result.setModSub(a, b, mod, context);
return result;
return BigNum().setModAdd(a, b, mod, context);
}
static BigNum ModMul(
[[nodiscard]] static BigNum ModSub(
const BigNum &a,
const BigNum &b,
const BigNum &mod,
const Context &context = Context()) {
BigNum result;
result.setModMul(a, b, mod, context);
return result;
return BigNum().setModSub(a, b, mod, context);
}
static BigNum ModExp(
[[nodiscard]] static BigNum ModMul(
const BigNum &a,
const BigNum &b,
const BigNum &mod,
const Context &context = Context()) {
return BigNum().setModMul(a, b, mod, context);
}
[[nodiscard]] static BigNum ModInverse(
const BigNum &a,
const BigNum &mod,
const Context &context = Context()) {
return BigNum().setModInverse(a, mod, context);
}
[[nodiscard]] static BigNum ModExp(
const BigNum &base,
const BigNum &power,
const BigNum &mod,
const Context &context = Context()) {
BigNum result;
result.setModExp(base, power, mod, context);
return result;
return BigNum().setModExp(base, power, mod, context);
}
static BigNum Failed() {
BigNum result;
[[nodiscard]] static BigNum Failed() {
auto result = BigNum();
result._failed = true;
return result;
}
private:
BIGNUM *_data = nullptr;
void clear() {
BN_clear_free(std::exchange(_data, nullptr));
}
mutable BIGNUM *_data = nullptr;
mutable bool _failed = false;
};

View File

@ -86,12 +86,16 @@ bool IsGoodModExpFirst(
bool IsPrimeAndGoodCheck(const openssl::BigNum &prime, int g) {
constexpr auto kGoodPrimeBitsCount = 2048;
if (prime.failed() || prime.isNegative() || prime.bitsSize() != kGoodPrimeBitsCount) {
LOG(("MTP Error: Bad prime bits count %1, expected %2.").arg(prime.bitsSize()).arg(kGoodPrimeBitsCount));
if (prime.failed()
|| prime.isNegative()
|| prime.bitsSize() != kGoodPrimeBitsCount) {
LOG(("MTP Error: Bad prime bits count %1, expected %2."
).arg(prime.bitsSize()
).arg(kGoodPrimeBitsCount));
return false;
}
openssl::Context context;
const auto context = openssl::Context();
if (!prime.isPrime(context)) {
LOG(("MTP Error: Bad prime."));
return false;
@ -99,14 +103,14 @@ bool IsPrimeAndGoodCheck(const openssl::BigNum &prime, int g) {
switch (g) {
case 2: {
auto mod8 = prime.modWord(8);
const auto mod8 = prime.countModWord(8);
if (mod8 != 7) {
LOG(("BigNum PT Error: bad g value: %1, mod8: %2").arg(g).arg(mod8));
return false;
}
} break;
case 3: {
auto mod3 = prime.modWord(3);
const auto mod3 = prime.countModWord(3);
if (mod3 != 2) {
LOG(("BigNum PT Error: bad g value: %1, mod3: %2").arg(g).arg(mod3));
return false;
@ -114,21 +118,21 @@ bool IsPrimeAndGoodCheck(const openssl::BigNum &prime, int g) {
} break;
case 4: break;
case 5: {
auto mod5 = prime.modWord(5);
const auto mod5 = prime.countModWord(5);
if (mod5 != 1 && mod5 != 4) {
LOG(("BigNum PT Error: bad g value: %1, mod5: %2").arg(g).arg(mod5));
return false;
}
} break;
case 6: {
auto mod24 = prime.modWord(24);
const auto mod24 = prime.countModWord(24);
if (mod24 != 19 && mod24 != 23) {
LOG(("BigNum PT Error: bad g value: %1, mod24: %2").arg(g).arg(mod24));
return false;
}
} break;
case 7: {
auto mod7 = prime.modWord(7);
const auto mod7 = prime.countModWord(7);
if (mod7 != 3 && mod7 != 5 && mod7 != 6) {
LOG(("BigNum PT Error: bad g value: %1, mod7: %2").arg(g).arg(mod7));
return false;
@ -140,10 +144,7 @@ bool IsPrimeAndGoodCheck(const openssl::BigNum &prime, int g) {
} break;
}
auto primeSubOneDivTwo = prime;
primeSubOneDivTwo.setSubWord(1);
primeSubOneDivTwo.setDivWord(2);
if (!primeSubOneDivTwo.isPrime(context)) {
if (!openssl::BigNum(prime).subWord(1).divWord(2).isPrime(context)) {
LOG(("MTP Error: Bad (prime - 1) / 2."));
return false;
}
@ -184,8 +185,9 @@ bytes::vector CreateAuthKey(
bytes::const_span randomBytes,
bytes::const_span primeBytes) {
using openssl::BigNum;
BigNum first(firstBytes);
BigNum prime(primeBytes);
const auto first = BigNum(firstBytes);
const auto prime = BigNum(primeBytes);
if (!IsGoodModExpFirst(first, prime)) {
LOG(("AuthKey Error: Bad first prime in CreateAuthKey()."));
return {};
@ -3304,15 +3306,23 @@ bool IsPrimeAndGood(bytes::const_span primeBytes, int g) {
return internal::IsPrimeAndGood(primeBytes, g);
}
bool IsGoodModExpFirst(const openssl::BigNum &modexp, const openssl::BigNum &prime) {
bool IsGoodModExpFirst(
const openssl::BigNum &modexp,
const openssl::BigNum &prime) {
return internal::IsGoodModExpFirst(modexp, prime);
}
ModExpFirst CreateModExp(int g, bytes::const_span primeBytes, bytes::const_span randomSeed) {
ModExpFirst CreateModExp(
int g,
bytes::const_span primeBytes,
bytes::const_span randomSeed) {
return internal::CreateModExp(g, primeBytes, randomSeed);
}
bytes::vector CreateAuthKey(bytes::const_span firstBytes, bytes::const_span randomBytes, bytes::const_span primeBytes) {
bytes::vector CreateAuthKey(
bytes::const_span firstBytes,
bytes::const_span randomBytes,
bytes::const_span primeBytes) {
return internal::CreateAuthKey(firstBytes, randomBytes, primeBytes);
}

View File

@ -20,16 +20,24 @@ constexpr auto kAckSendWaiting = crl::time(10000);
class Instance;
bool IsPrimeAndGood(bytes::const_span primeBytes, int g);
[[nodiscard]] bool IsPrimeAndGood(bytes::const_span primeBytes, int g);
struct ModExpFirst {
static constexpr auto kRandomPowerSize = 256;
bytes::vector modexp;
bytes::vector randomPower;
};
bool IsGoodModExpFirst(const openssl::BigNum &modexp, const openssl::BigNum &prime);
ModExpFirst CreateModExp(int g, bytes::const_span primeBytes, bytes::const_span randomSeed);
bytes::vector CreateAuthKey(bytes::const_span firstBytes, bytes::const_span randomBytes, bytes::const_span primeBytes);
[[nodiscard]] bool IsGoodModExpFirst(
const openssl::BigNum &modexp,
const openssl::BigNum &prime);
[[nodiscard]] ModExpFirst CreateModExp(
int g,
bytes::const_span primeBytes,
bytes::const_span randomSeed);
[[nodiscard]] bytes::vector CreateAuthKey(
bytes::const_span firstBytes,
bytes::const_span randomBytes,
bytes::const_span primeBytes);
namespace internal {

View File

@ -14,6 +14,7 @@ https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
#include "base/unixtime.h"
#include <QtCore/QtEndian>
#include <range/v3/algorithm/reverse.hpp>
namespace MTP {
namespace internal {
@ -31,6 +32,9 @@ constexpr auto kClientPartSize = 2878;
const auto kClientPrefix = qstr("\x14\x03\x03\x00\x01\x01");
const auto kClientHeader = qstr("\x17\x03\x03");
using BigNum = openssl::BigNum;
using BigNumContext = openssl::Context;
[[nodiscard]] MTPTlsClientHello PrepareClientHelloRules() {
auto stack = std::vector<QVector<MTPTlsBlock>>();
const auto pushToBack = [&](MTPTlsBlock &&block) {
@ -54,6 +58,9 @@ const auto kClientHeader = qstr("\x17\x03\x03");
const auto D = [&] {
pushToBack(MTP_tlsBlockDomain());
};
const auto K = [&] {
pushToBack(MTP_tlsBlockPublicKey());
};
const auto Open = [&] {
stack.emplace_back();
};
@ -102,7 +109,7 @@ const auto kClientHeader = qstr("\x17\x03\x03");
"\x01\x02\x01\x00\x12\x00\x00\x00\x33\x00\x2b\x00\x29"));
G(4);
S(qstr("\x00\x01\x00\x00\x1d\x00\x20"));
R(32);
K();
S(qstr("\x00\x2d\x00\x02\x01\x01\x00\x2b\x00\x0b\x0a"));
G(6);
S(qstr("\x03\x04\x03\x03\x03\x02\x03\x01\x00\x1b\x00\x03\x02\x00\x02"));
@ -127,6 +134,96 @@ const auto kClientHeader = qstr("\x17\x03\x03");
return result;
}
// Returns y^2 = x^3 + 486662 * x^2 + x.
[[nodiscard]] BigNum GenerateY2(
const BigNum &x,
const BigNum &mod,
const BigNumContext &context) {
auto coef = BigNum(486662);
auto y = BigNum::ModAdd(x, coef, mod, context);
y.setModMul(y, x, mod, context);
coef.setWord(1);
y.setModAdd(y, coef, mod, context);
return BigNum::ModMul(y, x, mod, context);
}
// Returns x_2 = (x^2 - 1)^2/(4*y^2).
[[nodiscard]] BigNum GenerateX2(
const BigNum &x,
const BigNum &mod,
const BigNumContext &context) {
auto denominator = GenerateY2(x, mod, context);
auto coef = BigNum(4);
denominator.setModMul(denominator, coef, mod, context);
auto numerator = BigNum::ModMul(x, x, mod, context);
coef.setWord(1);
numerator.setModSub(numerator, coef, mod, context);
numerator.setModMul(numerator, numerator, mod, context);
denominator.setModInverse(denominator, mod, context);
return BigNum::ModMul(numerator, denominator, mod, context);
}
[[nodiscard]] bytes::vector GeneratePublicKey() {
const auto context = BigNumContext();
const char modBytes[] = ""
"\x7f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xed";
const char powBytes[] = ""
"\x3f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xf6";
const auto mod = BigNum(bytes::make_span(modBytes).subspan(0, 32));
const auto pow = BigNum(bytes::make_span(powBytes).subspan(0, 32));
auto x = BigNum();
do {
while (true) {
auto random = bytes::vector(32);
bytes::set_random(random);
random[31] &= bytes::type(0x7FU);
x.setBytes(random);
x.setModMul(x, x, mod, context);
auto y = GenerateY2(x, mod, context);
if (BigNum::ModExp(y, pow, mod, context).isOne()) {
break;
}
}
for (auto i = 0; i != 3; ++i) {
x = GenerateX2(x, mod, context);
}
const auto xBytes = x.getBytes();
Assert(!xBytes.empty());
Assert(xBytes.size() <= 32);
} while (x.bytesSize() == 32);
const auto xBytes = x.getBytes();
auto result = bytes::vector(32, bytes::type());
bytes::copy(
bytes::make_span(result).subspan(32 - xBytes.size()),
xBytes);
ranges::reverse(result);
//auto string = QString();
//string.reserve(64);
//for (const auto byte : result) {
// const auto code = uchar(byte);
// const auto hex = [](uchar value) -> char {
// if (value >= 0 && value <= 9) {
// return '0' + value;
// } else if (value >= 10 && value <= 15) {
// return 'a' + (value - 10);
// }
// return '-';
// };
// string.append(hex(code / 16)).append(hex(code % 16));
//}
//LOG(("KEY: %1").arg(string));
return result;
}
struct ClientHello {
QByteArray data;
QByteArray digest;
@ -149,6 +246,7 @@ private:
void writeBlock(const MTPDtlsBlockGrease &data);
void writeBlock(const MTPDtlsBlockRandom &data);
void writeBlock(const MTPDtlsBlockDomain &data);
void writeBlock(const MTPDtlsBlockPublicKey &data);
void writeBlock(const MTPDtlsBlockScope &data);
void writePadding();
void writeDigest();
@ -264,6 +362,15 @@ void ClientHelloGenerator::writeBlock(const MTPDtlsBlockDomain &data) {
bytes::copy(storage, _domain);
}
void ClientHelloGenerator::writeBlock(const MTPDtlsBlockPublicKey &data) {
const auto key = GeneratePublicKey();
const auto storage = grow(key.size());
if (storage.empty()) {
return;
}
bytes::copy(storage, key);
}
void ClientHelloGenerator::writeBlock(const MTPDtlsBlockScope &data) {
const auto storage = grow(kLengthSize);
if (storage.empty()) {

View File

@ -101,9 +101,9 @@ public:
Private(bytes::const_span nBytes, bytes::const_span eBytes)
: _rsa(RSA_new()) {
if (_rsa) {
auto n = openssl::BigNum(nBytes).takeRaw();
auto e = openssl::BigNum(eBytes).takeRaw();
auto valid = (n != nullptr) && (e != nullptr);
const auto n = openssl::BigNum(nBytes).takeRaw();
const auto e = openssl::BigNum(eBytes).takeRaw();
const auto valid = (n != nullptr) && (e != nullptr);
// We still pass both values to RSA_set0_key() so that even
// if only one of them is valid RSA would take ownership of it.
if (!RSA_set0_key(_rsa, n, e, nullptr) || !valid) {