/* This file is part of Telegram Desktop, the official desktop application for the Telegram messaging service. For license and copyright information please follow this link: https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL */ #pragma once #include "base/bytes.h" #include "base/algorithm.h" #include "base/basic_types.h" extern "C" { #include #include #include #include #include #include #include } // extern "C" #ifdef small #undef small #endif // small namespace openssl { class Context { public: Context() : _data(BN_CTX_new()) { } Context(const Context &other) = delete; Context(Context &&other) : _data(base::take(other._data)) { } Context &operator=(const Context &other) = delete; Context &operator=(Context &&other) { _data = base::take(other._data); return *this; } ~Context() { if (_data) { BN_CTX_free(_data); } } BN_CTX *raw() const { return _data; } private: BN_CTX *_data = nullptr; }; class BigNum { public: BigNum() : _data(BN_new()) { } BigNum(const BigNum &other) : BigNum() { *this = other; } BigNum &operator=(const BigNum &other) { if (other.failed() || !BN_copy(raw(), other.raw())) { _failed = true; } return *this; } ~BigNum() { BN_clear_free(raw()); } explicit BigNum(unsigned int word) : BigNum() { setWord(word); } explicit BigNum(bytes::const_span bytes) : BigNum() { setBytes(bytes); } void setWord(unsigned int word) { if (!BN_set_word(raw(), word)) { _failed = true; } } void setBytes(bytes::const_span bytes) { if (!BN_bin2bn( reinterpret_cast(bytes.data()), bytes.size(), raw())) { _failed = true; } } void setAdd(const BigNum &a, const BigNum &b) { if (a.failed() || b.failed()) { _failed = true; } else if (!BN_add(raw(), a.raw(), b.raw())) { _failed = true; } } void setSub(const BigNum &a, const BigNum &b) { if (a.failed() || b.failed()) { _failed = true; } else if (!BN_sub(raw(), a.raw(), b.raw())) { _failed = true; } } void setSubWord(unsigned int word) { if (failed()) { return; } else if (!BN_sub_word(raw(), word)) { _failed = true; } } void setMul( const BigNum &a, const BigNum &b, const Context &context = Context()) { if (a.failed() || b.failed()) { _failed = true; } else if (!BN_mul(raw(), a.raw(), b.raw(), context.raw())) { _failed = true; } } BN_ULONG setDivWord(BN_ULONG word) { Expects(word != 0); if (failed()) { return (BN_ULONG)-1; } auto result = BN_div_word(raw(), word); if (result == (BN_ULONG)-1) { _failed = true; } return result; } void setModSub( const BigNum &a, const BigNum &b, const BigNum &m, const Context &context = Context()) { if (a.failed() || b.failed() || m.failed()) { _failed = true; } else if (a.isNegative() || b.isNegative() || m.isNegative()) { _failed = true; } else if (!BN_mod_sub(raw(), a.raw(), b.raw(), m.raw(), context.raw())) { _failed = true; } else if (isNegative()) { _failed = true; } } void setModMul( const BigNum &a, const BigNum &b, const BigNum &m, const Context &context = Context()) { if (a.failed() || b.failed() || m.failed()) { _failed = true; } else if (a.isNegative() || b.isNegative() || m.isNegative()) { _failed = true; } else if (!BN_mod_mul(raw(), a.raw(), b.raw(), m.raw(), context.raw())) { _failed = true; } else if (isNegative()) { _failed = true; } } void setModExp( const BigNum &base, const BigNum &power, const BigNum &m, const Context &context = Context()) { if (base.failed() || power.failed() || m.failed()) { _failed = true; } else if (base.isNegative() || power.isNegative() || m.isNegative()) { _failed = true; } else if (!BN_mod_exp(raw(), base.raw(), power.raw(), m.raw(), context.raw())) { _failed = true; } else if (isNegative()) { _failed = true; } } bool isNegative() const { return failed() ? false : BN_is_negative(raw()); } bool isPrime(const Context &context = Context()) const { if (failed()) { return false; } constexpr auto kMillerRabinIterationCount = 30; auto result = BN_is_prime_ex( raw(), kMillerRabinIterationCount, context.raw(), NULL); if (result == 1) { return true; } else if (result != 0) { _failed = true; } return false; } BN_ULONG modWord(BN_ULONG word) const { Expects(word != 0); if (failed()) { return (BN_ULONG)-1; } auto result = BN_mod_word(raw(), word); if (result == (BN_ULONG)-1) { _failed = true; } return result; } int bitsSize() const { return failed() ? 0 : BN_num_bits(raw()); } int bytesSize() const { return failed() ? 0 : BN_num_bytes(raw()); } bytes::vector getBytes() const { if (failed()) { return {}; } auto length = BN_num_bytes(raw()); auto result = bytes::vector(length); auto resultSize = BN_bn2bin( raw(), reinterpret_cast(result.data())); Assert(resultSize == length); return result; } BIGNUM *raw() { return _data; } const BIGNUM *raw() const { return _data; } BIGNUM *takeRaw() { return base::take(_data); } bool failed() const { return _failed; } static BigNum Add(const BigNum &a, const BigNum &b) { BigNum result; result.setAdd(a, b); return result; } static BigNum Sub(const BigNum &a, const BigNum &b) { BigNum result; result.setSub(a, b); return result; } static BigNum Mul( const BigNum &a, const BigNum &b, const Context &context = Context()) { BigNum result; result.setMul(a, b, context); return result; } static BigNum ModSub( const BigNum &a, const BigNum &b, const BigNum &mod, const Context &context = Context()) { BigNum result; result.setModSub(a, b, mod, context); return result; } static BigNum ModMul( const BigNum &a, const BigNum &b, const BigNum &mod, const Context &context = Context()) { BigNum result; result.setModMul(a, b, mod, context); return result; } static BigNum ModExp( const BigNum &base, const BigNum &power, const BigNum &mod, const Context &context = Context()) { BigNum result; result.setModExp(base, power, mod, context); return result; } static BigNum Failed() { BigNum result; result._failed = true; return result; } private: BIGNUM *_data = nullptr; mutable bool _failed = false; }; namespace details { template inline void ShaUpdate(Context context, Method method, Arg &&arg) { const auto span = bytes::make_span(arg); method(context, span.data(), span.size()); } template inline void ShaUpdate(Context context, Method method, Arg &&arg, Args &&...args) { const auto span = bytes::make_span(arg); method(context, span.data(), span.size()); ShaUpdate(context, method, args...); } template inline bytes::vector Sha(Method method, bytes::const_span data) { auto result = bytes::vector(Size); method( reinterpret_cast(data.data()), data.size(), reinterpret_cast(result.data())); return result; } template < size_type Size, typename Context, typename Init, typename Update, typename Finalize, typename ...Args, typename = std::enable_if_t<(sizeof...(Args) > 1)>> bytes::vector Sha( Context context, Init init, Update update, Finalize finalize, Args &&...args) { auto result = bytes::vector(Size); init(&context); ShaUpdate(&context, update, args...); finalize(reinterpret_cast(result.data()), &context); return result; } template < size_type Size, typename Evp> bytes::vector Pbkdf2( bytes::const_span password, bytes::const_span salt, int iterations, Evp evp) { auto result = bytes::vector(Size); PKCS5_PBKDF2_HMAC( reinterpret_cast(password.data()), password.size(), reinterpret_cast(salt.data()), salt.size(), iterations, evp, result.size(), reinterpret_cast(result.data())); return result; } } // namespace details constexpr auto kSha1Size = size_type(SHA_DIGEST_LENGTH); constexpr auto kSha256Size = size_type(SHA256_DIGEST_LENGTH); constexpr auto kSha512Size = size_type(SHA512_DIGEST_LENGTH); inline bytes::vector Sha1(bytes::const_span data) { return details::Sha(SHA1, data); } template < typename ...Args, typename = std::enable_if_t<(sizeof...(Args) > 1)>> inline bytes::vector Sha1(Args &&...args) { return details::Sha( SHA_CTX(), SHA1_Init, SHA1_Update, SHA1_Final, args...); } inline bytes::vector Sha256(bytes::const_span data) { return details::Sha(SHA256, data); } template < typename ...Args, typename = std::enable_if_t<(sizeof...(Args) > 1)>> inline bytes::vector Sha256(Args &&...args) { return details::Sha( SHA256_CTX(), SHA256_Init, SHA256_Update, SHA256_Final, args...); } inline bytes::vector Sha512(bytes::const_span data) { return details::Sha(SHA512, data); } template < typename ...Args, typename = std::enable_if_t<(sizeof...(Args) > 1)>> inline bytes::vector Sha512(Args &&...args) { return details::Sha( SHA512_CTX(), SHA512_Init, SHA512_Update, SHA512_Final, args...); } inline void AddRandomSeed(bytes::const_span data) { RAND_seed(data.data(), data.size()); } inline bytes::vector Pbkdf2Sha512( bytes::const_span password, bytes::const_span salt, int iterations) { return details::Pbkdf2( password, salt, iterations, EVP_sha512()); } } // namespace openssl namespace bytes { inline void set_random(span destination) { RAND_bytes( reinterpret_cast(destination.data()), destination.size()); } } // namespace bytes