From fdd89d65ca7df9e811e99eca9faa8f278d20519a Mon Sep 17 00:00:00 2001 From: John Preston Date: Thu, 5 Oct 2017 16:35:31 +0100 Subject: [PATCH] Allow using custom comparators in flat_[map|set]. --- Telegram/SourceFiles/base/flat_map.h | 202 +++++++++++++------ Telegram/SourceFiles/base/flat_set.h | 181 +++++++++++------ Telegram/SourceFiles/base/flat_set_tests.cpp | 2 + 3 files changed, 258 insertions(+), 127 deletions(-) diff --git a/Telegram/SourceFiles/base/flat_map.h b/Telegram/SourceFiles/base/flat_map.h index 86553d61a5..fe076adb50 100644 --- a/Telegram/SourceFiles/base/flat_map.h +++ b/Telegram/SourceFiles/base/flat_map.h @@ -26,15 +26,22 @@ Copyright (c) 2014-2017 John Preston, https://desktop.telegram.org namespace base { -template +template < + typename Key, + typename Type, + typename Compare = std::less<>> class flat_map; -template +template < + typename Key, + typename Type, + typename Compare = std::less<>> class flat_multi_map; template < typename Key, typename Type, + typename Compare, typename iterator_impl, typename pointer_impl, typename reference_impl> @@ -43,6 +50,7 @@ class flat_multi_map_iterator_base_impl; template < typename Key, typename Type, + typename Compare, typename iterator_impl, typename pointer_impl, typename reference_impl> @@ -50,12 +58,12 @@ class flat_multi_map_iterator_base_impl { public: using iterator_category = typename iterator_impl::iterator_category; - using value_type = typename flat_multi_map::value_type; + using value_type = typename flat_multi_map::value_type; using difference_type = typename iterator_impl::difference_type; using pointer = pointer_impl; - using const_pointer = typename flat_multi_map::const_pointer; + using const_pointer = typename flat_multi_map::const_pointer; using reference = reference_impl; - using const_reference = typename flat_multi_map::const_reference; + using const_reference = typename flat_multi_map::const_reference; flat_multi_map_iterator_base_impl(iterator_impl impl = iterator_impl()) : _impl(impl) { @@ -123,64 +131,129 @@ public: private: iterator_impl _impl; - friend class flat_multi_map; + friend class flat_multi_map; }; -template +template class flat_multi_map { - using self = flat_multi_map; class key_const_wrap { public: - key_const_wrap(const Key &value) : _value(value) { + constexpr key_const_wrap(const Key &value) : _value(value) { } - key_const_wrap(Key &&value) : _value(std::move(value)) { + constexpr key_const_wrap(Key &&value) : _value(std::move(value)) { } - inline operator const Key&() const { + inline constexpr operator const Key&() const { return _value; } - friend inline bool operator<(const Key &a, const key_const_wrap &b) { - return a < ((const Key&)b); - } - friend inline bool operator<(const key_const_wrap &a, const Key &b) { - return ((const Key&)a) < b; - } - friend inline bool operator<( - const key_const_wrap &a, - const key_const_wrap &b) { - return ((const Key&)a) < ((const Key&)b); - } - private: Key _value; }; - using pair_type = std::pair; + + class compare { + public: + template < + typename OtherType1, + typename OtherType2, + typename = std::enable_if_t< + !std::is_same_v, key_const_wrap> && + !std::is_same_v, pair_type> && + !std::is_same_v, key_const_wrap> && + !std::is_same_v, pair_type>>> + inline constexpr auto operator()( + OtherType1 &&a, + OtherType2 &b) const { + return Compare()( + std::forward(a), + std::forward(b)); + } + inline constexpr auto operator()( + const key_const_wrap &a, + const key_const_wrap &b) const { + return operator()( + static_cast(a), + static_cast(b)); + } + template < + typename OtherType, + typename = std::enable_if_t< + !std::is_same_v, key_const_wrap> && + !std::is_same_v, pair_type>>> + inline constexpr auto operator()( + const key_const_wrap &a, + OtherType &&b) const { + return operator()( + static_cast(a), + std::forward(b)); + } + template < + typename OtherType, + typename = std::enable_if_t< + !std::is_same_v, key_const_wrap> && + !std::is_same_v, pair_type>>> + inline constexpr auto operator()( + OtherType &&a, + const key_const_wrap &b) const { + return operator()( + std::forward(a), + static_cast(b)); + } + inline constexpr auto operator()( + const pair_type &a, + const pair_type &b) const { + return operator()(a.first, b.first); + } + template < + typename OtherType, + typename = std::enable_if_t< + !std::is_same_v, pair_type>>> + inline constexpr auto operator()( + const pair_type &a, + OtherType &&b) const { + return operator()(a.first, std::forward(b)); + } + template < + typename OtherType, + typename = std::enable_if_t< + !std::is_same_v, pair_type>>> + inline constexpr auto operator()( + OtherType &&a, + const pair_type &b) const { + return operator()(std::forward(a), b.first); + } + + }; + using impl = std::deque; using iterator_base = flat_multi_map_iterator_base_impl< Key, Type, + Compare, typename impl::iterator, pair_type*, pair_type&>; using const_iterator_base = flat_multi_map_iterator_base_impl< Key, Type, + Compare, typename impl::const_iterator, const pair_type*, const pair_type&>; using reverse_iterator_base = flat_multi_map_iterator_base_impl< Key, Type, + Compare, typename impl::reverse_iterator, pair_type*, pair_type&>; using const_reverse_iterator_base = flat_multi_map_iterator_base_impl< Key, Type, + Compare, typename impl::const_reverse_iterator, const pair_type*, const pair_type&>; @@ -292,10 +365,10 @@ public: } iterator insert(const value_type &value) { - if (empty() || (value.first < front().first)) { + if (empty() || compare()(value.first, front().first)) { _impl.push_front(value); return begin(); - } else if (!(value.first < back().first)) { + } else if (!compare()(value.first, back().first)) { _impl.push_back(value); return (end() - 1); } @@ -303,10 +376,10 @@ public: return _impl.insert(where, value); } iterator insert(value_type &&value) { - if (empty() || (value.first < front().first)) { + if (empty() || compare()(value.first, front().first)) { _impl.push_front(std::move(value)); return begin(); - } else if (!(value.first < back().first)) { + } else if (!compare()(value.first, back().first)) { _impl.push_back(std::move(value)); return (end() - 1); } @@ -319,18 +392,22 @@ public: } bool removeOne(const Key &key) { - if (empty() || (key < front().first) || (back().first < key)) { + if (empty() + || compare()(key, front().first) + || compare()(back().first, key)) { return false; } auto where = getLowerBound(key); - if (key < where->first) { + if (compare()(key, where->first)) { return false; } _impl.erase(where); return true; } int removeAll(const Key &key) { - if (empty() || (key < front().first) || (back().first < key)) { + if (empty() + || compare()(key, front().first) + || compare()(back().first, key)) { return 0; } auto range = getEqualRange(key); @@ -349,26 +426,32 @@ public: } iterator findFirst(const Key &key) { - if (empty() || (key < front().first) || (back().first < key)) { + if (empty() + || compare()(key, front().first) + || compare()(back().first, key)) { return end(); } auto where = getLowerBound(key); - return (key < where->first) ? _impl.end() : where; + return compare()(key, where->first) ? _impl.end() : where; } const_iterator findFirst(const Key &key) const { - if (empty() || (key < front().first) || (back().first < key)) { + if (empty() + || compare()(key, front().first) + || compare()(back().first, key)) { return end(); } auto where = getLowerBound(key); - return (key < where->first) ? _impl.end() : where; + return compare()(key, where->first) ? _impl.end() : where; } bool contains(const Key &key) const { return findFirst(key) != end(); } int count(const Key &key) const { - if (empty() || (key < front().first) || (back().first < key)) { + if (empty() + || compare()(key, front().first) + || compare()(back().first, key)) { return 0; } auto range = getEqualRange(key); @@ -377,46 +460,39 @@ public: private: impl _impl; - friend class flat_map; + friend class flat_map; - struct Comparator { - inline bool operator()(const pair_type &a, const Key &b) { - return a.first < b; - } - inline bool operator()(const Key &a, const pair_type &b) { - return a < b.first; - } - }; typename impl::iterator getLowerBound(const Key &key) { - return base::lower_bound(_impl, key, Comparator()); + return base::lower_bound(_impl, key, compare()); } typename impl::const_iterator getLowerBound(const Key &key) const { - return base::lower_bound(_impl, key, Comparator()); + return base::lower_bound(_impl, key, compare()); } typename impl::iterator getUpperBound(const Key &key) { - return base::upper_bound(_impl, key, Comparator()); + return base::upper_bound(_impl, key, compare()); } typename impl::const_iterator getUpperBound(const Key &key) const { - return base::upper_bound(_impl, key, Comparator()); + return base::upper_bound(_impl, key, compare()); } std::pair< typename impl::iterator, typename impl::iterator > getEqualRange(const Key &key) { - return base::equal_range(_impl, key, Comparator()); + return base::equal_range(_impl, key, compare()); } std::pair< typename impl::const_iterator, typename impl::const_iterator > getEqualRange(const Key &key) const { - return base::equal_range(_impl, key, Comparator()); + return base::equal_range(_impl, key, compare()); } }; -template -class flat_map : private flat_multi_map { - using parent = flat_multi_map; +template +class flat_map : private flat_multi_map { + using parent = flat_multi_map; + using compare = typename parent::compare; using pair_type = typename parent::pair_type; public: @@ -450,29 +526,29 @@ public: using parent::contains; iterator insert(const value_type &value) { - if (this->empty() || (value.first < this->front().first)) { + if (this->empty() || compare()(value.first, this->front().first)) { this->_impl.push_front(value); return this->begin(); - } else if (this->back().first < value.first) { + } else if (compare()(this->back().first, value.first)) { this->_impl.push_back(value); return (this->end() - 1); } auto where = this->getLowerBound(value.first); - if (value.first < where->first) { + if (compare()(value.first, where->first)) { return this->_impl.insert(where, value); } return this->end(); } iterator insert(value_type &&value) { - if (this->empty() || (value.first < this->front().first)) { + if (this->empty() || compare()(value.first, this->front().first)) { this->_impl.push_front(std::move(value)); return this->begin(); - } else if (this->back().first < value.first) { + } else if (compare()(this->back().first, value.first)) { this->_impl.push_back(std::move(value)); return (this->end() - 1); } auto where = this->getLowerBound(value.first); - if (value.first < where->first) { + if (compare()(value.first, where->first)) { return this->_impl.insert(where, std::move(value)); } return this->end(); @@ -494,15 +570,15 @@ public: } Type &operator[](const Key &key) { - if (this->empty() || (key < this->front().first)) { + if (this->empty() || compare()(key, this->front().first)) { this->_impl.push_front({ key, Type() }); return this->front().second; - } else if (this->back().first < key) { + } else if (compare()(this->back().first, key)) { this->_impl.push_back({ key, Type() }); return this->back().second; } auto where = this->getLowerBound(key); - if (key < where->first) { + if (compare()(key, where->first)) { return this->_impl.insert(where, { key, Type() })->second; } return where->second; diff --git a/Telegram/SourceFiles/base/flat_set.h b/Telegram/SourceFiles/base/flat_set.h index 8dd2b1012a..184a0e0f1e 100644 --- a/Telegram/SourceFiles/base/flat_set.h +++ b/Telegram/SourceFiles/base/flat_set.h @@ -25,24 +25,24 @@ Copyright (c) 2014-2017 John Preston, https://desktop.telegram.org namespace base { -template +template > class flat_set; -template +template > class flat_multi_set; -template +template class flat_multi_set_iterator_base_impl; -template +template class flat_multi_set_iterator_base_impl { public: using iterator_category = typename iterator_impl::iterator_category; - using value_type = typename flat_multi_set::value_type; + using value_type = typename flat_multi_set::value_type; using difference_type = typename iterator_impl::difference_type; - using pointer = typename flat_multi_set::pointer; - using reference = typename flat_multi_set::reference; + using pointer = typename flat_multi_set::pointer; + using reference = typename flat_multi_set::reference; flat_multi_set_iterator_base_impl(iterator_impl impl = iterator_impl()) : _impl(impl) { @@ -101,8 +101,8 @@ public: private: iterator_impl _impl; - friend class flat_multi_set; - friend class flat_set; + friend class flat_multi_set; + friend class flat_set; Type &wrapped() { return _impl->wrapped(); @@ -110,50 +110,92 @@ private: }; -template +template class flat_multi_set { - using self = flat_multi_set; class const_wrap { public: - const_wrap(const Type &value) : _value(value) { + constexpr const_wrap(const Type &value) + : _value(value) { } - const_wrap(Type &&value) : _value(std::move(value)) { + constexpr const_wrap(Type &&value) + : _value(std::move(value)) { } - inline operator const Type&() const { + inline constexpr operator const Type&() const { return _value; } - Type &wrapped() { + constexpr Type &wrapped() { return _value; } - friend inline bool operator<(const Type &a, const const_wrap &b) { - return a < ((const Type&)b); - } - friend inline bool operator<(const const_wrap &a, const Type &b) { - return ((const Type&)a) < b; - } - friend inline bool operator<(const const_wrap &a, const const_wrap &b) { - return ((const Type&)a) < ((const Type&)b); - } - private: Type _value; }; + class compare { + public: + template < + typename OtherType1, + typename OtherType2, + typename = std::enable_if_t< + !std::is_same_v, const_wrap> && + !std::is_same_v, const_wrap>>> + inline constexpr auto operator()( + OtherType1 &&a, + OtherType2 &b) const { + return Compare()( + std::forward(a), + std::forward(b)); + } + inline constexpr auto operator()( + const const_wrap &a, + const const_wrap &b) const { + return operator()( + static_cast(a), + static_cast(b)); + } + template < + typename OtherType, + typename = std::enable_if_t< + !std::is_same_v, const_wrap>>> + inline constexpr auto operator()( + const const_wrap &a, + OtherType &&b) const { + return operator()( + static_cast(a), + std::forward(b)); + } + template < + typename OtherType, + typename = std::enable_if_t< + !std::is_same_v, const_wrap>>> + inline constexpr auto operator()( + OtherType &&a, + const const_wrap &b) const { + return operator()( + std::forward(a), + static_cast(b)); + } + + }; + using impl = std::deque; using iterator_base = flat_multi_set_iterator_base_impl< Type, + Compare, typename impl::iterator>; using const_iterator_base = flat_multi_set_iterator_base_impl< Type, + Compare, typename impl::const_iterator>; using reverse_iterator_base = flat_multi_set_iterator_base_impl< Type, + Compare, typename impl::reverse_iterator>; using const_reverse_iterator_base = flat_multi_set_iterator_base_impl< Type, + Compare, typename impl::const_reverse_iterator>; public: @@ -209,7 +251,7 @@ public: typename Iterator, typename = typename std::iterator_traits::iterator_category> flat_multi_set(Iterator first, Iterator last) : _impl(first, last) { - base::sort(_impl); + base::sort(_impl, compare()); } flat_multi_set(std::initializer_list iter) @@ -271,10 +313,10 @@ public: } iterator insert(const Type &value) { - if (empty() || (value < front())) { + if (empty() || compare()(value, front())) { _impl.push_front(value); return begin(); - } else if (!(value < back())) { + } else if (!compare()(value, back())) { _impl.push_back(value); return (end() - 1); } @@ -282,10 +324,10 @@ public: return _impl.insert(where, value); } iterator insert(Type &&value) { - if (empty() || (value < front())) { + if (empty() || compare()(value, front())) { _impl.push_front(std::move(value)); return begin(); - } else if (!(value < back())) { + } else if (!compare()(value, back())) { _impl.push_back(std::move(value)); return (end() - 1); } @@ -298,18 +340,22 @@ public: } bool removeOne(const Type &value) { - if (empty() || (value < front()) || (back() < value)) { + if (empty() + || compare()(value, front()) + || compare()(back(), value)) { return false; } auto where = getLowerBound(value); - if (value < *where) { + if (compare()(value, *where)) { return false; } _impl.erase(where); return true; } int removeAll(const Type &value) { - if (empty() || (value < front()) || (back() < value)) { + if (empty() + || compare()(value, front()) + || compare()(back(), value)) { return 0; } auto range = getEqualRange(value); @@ -328,26 +374,32 @@ public: } iterator findFirst(const Type &value) { - if (empty() || (value < front()) || (back() < value)) { + if (empty() + || compare()(value, front()) + || compare()(back(), value)) { return end(); } auto where = getLowerBound(value); - return (value < *where) ? _impl.end() : where; + return compare()(value, *where) ? _impl.end() : where; } const_iterator findFirst(const Type &value) const { - if (empty() || (value < front()) || (back() < value)) { + if (empty() + || compare()(value, front()) + || compare()(back(), value)) { return end(); } auto where = getLowerBound(value); - return (value < *where) ? _impl.end() : where; + return compare()(value, *where) ? _impl.end() : where; } bool contains(const Type &value) const { return findFirst(value) != end(); } int count(const Type &value) const { - if (empty() || (value < front()) || (back() < value)) { + if (empty() + || compare()(value, front()) + || compare()(back(), value)) { return 0; } auto range = getEqualRange(value); @@ -358,7 +410,7 @@ public: auto modify(iterator which, Action action) { auto result = action(which.wrapped()); for (auto i = which + 1, e = end(); i != e; ++i) { - if (*i < *which) { + if (compare()(*i, *which)) { std::swap(i.wrapped(), which.wrapped()); } else { break; @@ -366,7 +418,7 @@ public: } for (auto i = which, b = begin(); i != b;) { --i; - if (*which < *i) { + if (compare()(*which, *i)) { std::swap(i.wrapped(), which.wrapped()); } else { break; @@ -380,10 +432,10 @@ public: typename = typename std::iterator_traits::iterator_category> void merge(Iterator first, Iterator last) { _impl.insert(_impl.end(), first, last); - base::sort(_impl); + base::sort(_impl, compare()); } - void merge(const flat_multi_set &other) { + void merge(const flat_multi_set &other) { merge(other.begin(), other.end()); } @@ -393,38 +445,39 @@ public: private: impl _impl; - friend class flat_set; + friend class flat_set; typename impl::iterator getLowerBound(const Type &value) { - return base::lower_bound(_impl, value); + return base::lower_bound(_impl, value, compare()); } typename impl::const_iterator getLowerBound(const Type &value) const { - return base::lower_bound(_impl, value); + return base::lower_bound(_impl, value, compare()); } typename impl::iterator getUpperBound(const Type &value) { - return base::upper_bound(_impl, value); + return base::upper_bound(_impl, value, compare()); } typename impl::const_iterator getUpperBound(const Type &value) const { - return base::upper_bound(_impl, value); + return base::upper_bound(_impl, value, compare()); } std::pair< typename impl::iterator, typename impl::iterator > getEqualRange(const Type &value) { - return base::equal_range(_impl, value); + return base::equal_range(_impl, value, compare()); } std::pair< typename impl::const_iterator, typename impl::const_iterator > getEqualRange(const Type &value) const { - return base::equal_range(_impl, value); + return base::equal_range(_impl, value, compare()); } }; -template -class flat_set : private flat_multi_set { - using parent = flat_multi_set; +template +class flat_set : private flat_multi_set { + using parent = flat_multi_set; + using compare = typename parent::compare; public: using iterator = typename parent::iterator; @@ -469,29 +522,29 @@ public: using parent::erase; iterator insert(const Type &value) { - if (this->empty() || (value < this->front())) { + if (this->empty() || compare()(value, this->front())) { this->_impl.push_front(value); return this->begin(); - } else if (this->back() < value) { + } else if (compare()(this->back(), value)) { this->_impl.push_back(value); return (this->end() - 1); } auto where = this->getLowerBound(value); - if (value < *where) { + if (compare()(value, *where)) { return this->_impl.insert(where, value); } return this->end(); } iterator insert(Type &&value) { - if (this->empty() || (value < this->front())) { + if (this->empty() || compare()(value, this->front())) { this->_impl.push_front(std::move(value)); return this->begin(); - } else if (this->back() < value) { + } else if (compare()(this->back(), value)) { this->_impl.push_back(std::move(value)); return (this->end() - 1); } auto where = this->getLowerBound(value); - if (value < *where) { + if (compare()(value, *where)) { return this->_impl.insert(where, std::move(value)); } return this->end(); @@ -516,9 +569,9 @@ public: void modify(iterator which, Action action) { action(which.wrapped()); for (auto i = iterator(which + 1), e = end(); i != e; ++i) { - if (*i < *which) { + if (compare()(*i, *which)) { std::swap(i.wrapped(), which.wrapped()); - } else if (!(*which < *i)) { + } else if (!compare()(*which, *i)) { erase(which); return; } else{ @@ -527,9 +580,9 @@ public: } for (auto i = which, b = begin(); i != b;) { --i; - if (*which < *i) { + if (compare()(*which, *i)) { std::swap(i.wrapped(), which.wrapped()); - } else if (!(*i < *which)) { + } else if (!compare()(*i, *which)) { erase(which); return; } else { @@ -546,7 +599,7 @@ public: finalize(); } - void merge(const flat_multi_set &other) { + void merge(const flat_multi_set &other) { merge(other.begin(), other.end()); } @@ -560,7 +613,7 @@ private: std::unique( this->_impl.begin(), this->_impl.end(), - [](auto &&a, auto &&b) { return !(a < b); }), + [](auto &&a, auto &&b) { return !compare()(a, b); }), this->_impl.end()); } diff --git a/Telegram/SourceFiles/base/flat_set_tests.cpp b/Telegram/SourceFiles/base/flat_set_tests.cpp index 675b0bdafe..c428ab5c0e 100644 --- a/Telegram/SourceFiles/base/flat_set_tests.cpp +++ b/Telegram/SourceFiles/base/flat_set_tests.cpp @@ -29,6 +29,8 @@ TEST_CASE("flat_sets should keep items sorted", "[flat_set]") { v.insert(4); v.insert(2); + REQUIRE(v.contains(4)); + auto checkSorted = [&] { auto prev = v.begin(); REQUIRE(prev != v.end());