#ifndef __crush_BUCKET_H #define __crush_BUCKET_H #include "BinaryTree.h" #include "Hash.h" #include #include #include #include using namespace std; #include #include "include/buffer.h" namespace crush { const int CRUSH_BUCKET_UNIFORM = 1; const int CRUSH_BUCKET_TREE = 2; const int CRUSH_BUCKET_LIST = 3; const int CRUSH_BUCKET_STRAW = 4; /** abstract bucket **/ class Bucket { protected: int id; int parent; int type; float weight; public: Bucket(int _type, float _weight) : id(0), parent(0), type(_type), weight(_weight) { } Bucket(bufferlist& bl, int& off) { bl.copy(off, sizeof(id), (char*)&id); off += sizeof(id); bl.copy(off, sizeof(parent), (char*)&parent); off += sizeof(parent); bl.copy(off, sizeof(type), (char*)&type); off += sizeof(type); bl.copy(off, sizeof(weight), (char*)&weight); off += sizeof(weight); } virtual ~Bucket() { } virtual const char *get_bucket_type() const = 0; virtual bool is_uniform() const = 0; int get_id() const { return id; } int get_type() const { return type; } float get_weight() const { return weight; } int get_parent() const { return parent; } virtual int get_size() const = 0; void set_id(int i) { id = i; } void set_parent(int p) { parent = p; } void set_weight(float w) { weight = w; } virtual void get_items(vector& i) const = 0; virtual float get_item_weight(int item) const = 0; virtual void add_item(int item, float w, bool back=false) = 0; virtual void adjust_item_weight(int item, float w) = 0; virtual void set_item_weight(int item, float w) { adjust_item_weight(item, w - get_item_weight(item)); } virtual int choose_r(int x, int r, Hash& h) const = 0; virtual void _encode(bufferlist& bl) = 0; }; /** uniform bucket **/ class UniformBucket : public Bucket { protected: public: vector items; int item_type; float item_weight; // primes vector primes; int get_prime(int j) const { return primes[ j % primes.size() ]; } void make_primes() { if (items.empty()) return; //cout << "make_primes " << get_id() << " " << items.size() << endl; Hash h(123+get_id()); primes.clear(); // start with odd number > num_items unsigned x = items.size() + 1; // this is the minimum! x += h(items.size()) % (3*items.size()); // bump it up some x |= 1; // make it odd while (primes.size() < items.size()) { unsigned j; for (j=2; j*j<=x; j++) if (x % j == 0) break; if (j*j > x) { primes.push_back(x); //cout << "prime " << x << endl; } x += 2; } } public: UniformBucket(int _type, int _item_type) : Bucket(_type, 0), item_type(_item_type) { } UniformBucket(int _type, int _item_type, float _item_weight, vector& _items) : Bucket(_type, _item_weight*_items.size()), item_type(_item_type), item_weight(_item_weight) { items = _items; make_primes(); } UniformBucket(bufferlist& bl, int& off) : Bucket(bl, off) { bl.copy(off, sizeof(item_type), (char*)&item_type); off += sizeof(item_type); bl.copy(off, sizeof(item_weight), (char*)&item_weight); off += sizeof(item_weight); ::_decode(items, bl, off); make_primes(); } void _encode(bufferlist& bl) { char t = CRUSH_BUCKET_UNIFORM; bl.append((char*)&t, sizeof(t)); bl.append((char*)&id, sizeof(id)); bl.append((char*)&parent, sizeof(parent)); bl.append((char*)&type, sizeof(type)); bl.append((char*)&weight, sizeof(weight)); bl.append((char*)&item_type, sizeof(item_type)); bl.append((char*)&item_weight, sizeof(item_weight)); ::_encode(items, bl); } const char *get_bucket_type() const { return "uniform"; } bool is_uniform() const { return true; } int get_size() const { return items.size(); } // items void get_items(vector& i) const { i = items; } int get_item_type() const { return item_type; } float get_item_weight(int item) const { return item_weight; } void add_item(int item, float w, bool back=false) { if (items.empty()) item_weight = w; items.push_back(item); weight += item_weight; make_primes(); } void adjust_item_weight(int item, float w) { assert(0); } int choose_r(int x, int r, Hash& hash) const { //cout << "uniformbucket.choose_r(" << x << ", " << r << ")" << endl; //if (r >= get_size()) cout << "warning: r " << r << " >= " << get_size() << " uniformbucket.size" << endl; unsigned v = hash(x, get_id());// % get_size(); unsigned p = get_prime( hash(get_id(), x) ); // choose a prime based on hash(x, get_id(), 2) unsigned s = (x + v + (r+1)*p) % get_size(); return items[s]; } }; // list bucket.. RUSH_P sorta class ListBucket : public Bucket { protected: list items; list item_weight; list sum_weight; public: ListBucket(int _type) : Bucket(_type, 0) { } ListBucket(bufferlist& bl, int& off) : Bucket(bl, off) { ::_decode(items, bl, off); ::_decode(item_weight, bl, off); ::_decode(sum_weight, bl, off); } void _encode(bufferlist& bl) { char t = CRUSH_BUCKET_LIST; bl.append((char*)&t, sizeof(t)); bl.append((char*)&id, sizeof(id)); bl.append((char*)&parent, sizeof(parent)); bl.append((char*)&type, sizeof(type)); bl.append((char*)&weight, sizeof(weight)); ::_encode(items, bl); ::_encode(item_weight, bl); ::_encode(sum_weight, bl); } const char *get_bucket_type() const { return "list"; } bool is_uniform() const { return false; } int get_size() const { return items.size(); } void get_items(vector& i) const { for (list::const_iterator it = items.begin(); it != items.end(); it++) i.push_back(*it); } float get_item_weight(int item) const { list::const_iterator i = items.begin(); list::const_iterator w = item_weight.begin(); while (i != items.end()) { if (*i == item) return *w; i++; w++; } assert(0); return 0; } void add_item(int item, float w, bool back=false) { if (back) { items.push_back(item); item_weight.push_back(w); sum_weight.clear(); float s = 0.0; for (list::reverse_iterator i = item_weight.rbegin(); i != item_weight.rend(); i++) { s += *i; sum_weight.push_front(s); } weight += w; assert(weight == s); } else { items.push_front(item); item_weight.push_front(w); weight += w; sum_weight.push_front(weight); } } void adjust_item_weight(int item, float dw) { // find it list::iterator p = items.begin(); list::iterator pw = item_weight.begin(); list::iterator ps = sum_weight.begin(); while (*p != item) { *ps += dw; p++; pw++; ps++; // next! assert(p != items.end()); } assert(*p == item); *pw += dw; *ps += dw; } int choose_r(int x, int r, Hash& h) const { //cout << "linearbucket.choose_r(" << x << ", " << r << ")" << endl; list::const_iterator p = items.begin(); list::const_iterator pw = item_weight.begin(); list::const_iterator ps = sum_weight.begin(); while (p != items.end()) { const int item = *p; const float iw = *pw; const float tw = *ps; const float f = (float)(h(x, item, r, get_id()) % 10000) * tw / 10000.0; //cout << "item " << item << " iw = " << iw << " tw = " << tw << " f = " << f << endl; if (f < iw) { //cout << "linearbucket.choose_r(" << x << ", " << r << ") = " << item << endl; return item; } p++; pw++; ps++; // next! } assert(0); return 0; } }; // mixed bucket, based on RUSH_T type binary tree class TreeBucket : public Bucket { protected: //vector item_weight; // public: BinaryTree tree; map node_item; // node id -> item vector node_item_vec; // fast version of above map item_node; // item -> node id map item_weight; public: TreeBucket(int _type) : Bucket(_type, 0) { } TreeBucket(bufferlist& bl, int& off) : Bucket(bl, off) { tree._decode(bl, off); ::_decode(node_item, bl, off); ::_decode(node_item_vec, bl, off); ::_decode(item_node, bl, off); ::_decode(item_weight, bl, off); } void _encode(bufferlist& bl) { char t = CRUSH_BUCKET_TREE; bl.append((char*)&t, sizeof(t)); bl.append((char*)&id, sizeof(id)); bl.append((char*)&parent, sizeof(parent)); bl.append((char*)&type, sizeof(type)); bl.append((char*)&weight, sizeof(weight)); tree._encode(bl); ::_encode(node_item, bl); ::_encode(node_item_vec, bl); ::_encode(item_node, bl); ::_encode(item_weight, bl); } const char *get_bucket_type() const { return "tree"; } bool is_uniform() const { return false; } int get_size() const { return node_item.size(); } // items void get_items(vector& i) const { for (map::const_iterator it = node_item.begin(); it != node_item.end(); it++) i.push_back(it->second); } float get_item_weight(int i) const { assert(item_weight.count(i)); return ((map)item_weight)[i]; } void add_item(int item, float w, bool back=false) { item_weight[item] = w; weight += w; unsigned n = tree.add_node(w); node_item[n] = item; item_node[item] = n; while (node_item_vec.size() <= n) node_item_vec.push_back(0); node_item_vec[n] = item; } void adjust_item_weight(int item, float dw) { // adjust my weight weight += dw; item_weight[item] += dw; // adjust tree weights tree.adjust_node_weight(item_node[item], dw); } int choose_r(int x, int r, Hash& h) const { //cout << "mixedbucket.choose_r(" << x << ", " << r << ")" << endl; int n = tree.root(); while (!tree.terminal(n)) { // pick a point in [0,w) float w = tree.weight(n); float f = (float)(h(x, n, r, get_id()) % 10000) * w / 10000.0; // left or right? int l = tree.left(n); if (tree.exists(l) && f < tree.weight(l)) n = l; else n = tree.right(n); } //assert(node_item.count(n)); //return ((map)node_item)[n]; return node_item_vec[n]; } }; // straw bucket.. new thing! class StrawBucket : public Bucket { protected: map item_weight; map item_straw; list _items; list _straws; public: StrawBucket(int _type) : Bucket(_type, 0) { } StrawBucket(bufferlist& bl, int& off) : Bucket(bl, off) { ::_decode(item_weight, bl, off); calc_straws(); } void _encode(bufferlist& bl) { char t = CRUSH_BUCKET_TREE; bl.append((char*)&t, sizeof(t)); bl.append((char*)&id, sizeof(id)); bl.append((char*)&parent, sizeof(parent)); bl.append((char*)&type, sizeof(type)); bl.append((char*)&weight, sizeof(weight)); ::_encode(item_weight, bl); } const char *get_bucket_type() const { return "straw"; } bool is_uniform() const { return false; } int get_size() const { return item_weight.size(); } // items void get_items(vector& i) const { for (map::const_iterator it = item_weight.begin(); it != item_weight.end(); it++) i.push_back(it->first); } float get_item_weight(int item) const { assert(item_weight.count(item)); return ((map)item_weight)[item]; } void add_item(int item, float w, bool back=false) { item_weight[item] = w; weight += w; calc_straws(); } void adjust_item_weight(int item, float dw) { //cout << "adjust " << item << " " << dw << endl; weight += dw; item_weight[item] += dw; calc_straws(); } /* calculate straw lengths. this is kind of ugly. not sure if there's a closed form way to calculate this or not! */ void calc_straws() { //cout << get_id() << ": calc_straws ============" << endl; item_straw.clear(); _items.clear(); _straws.clear(); // reverse sort by weight; skip zero weight items map > reverse; for (map::iterator p = item_weight.begin(); p != item_weight.end(); p++) { //cout << get_id() << ":" << p->first << " " << p->second << endl; if (p->second > 0) { //p->second /= minw; reverse[p->second].insert(p->first); } } /* 1:2:7 item_straw[0] = 1.0; item_straw[1] = item_straw[0]*sqrt(1.0/.6); item_straw[2] = item_straw[1]*2.0; */ // work from low to high weights float straw = 1.0; float numleft = item_weight.size(); float wbelow = 0.0; float lastw = 0.0; map >::iterator next = reverse.begin(); //while (next != reverse.end()) { while (1) { //cout << "hi " << next->first << endl; map >::iterator cur = next; // set straw length for this set for (set::iterator s = cur->second.begin(); s != cur->second.end(); s++) { item_straw[*s] = straw; //cout << "straw " << *s << " w " << item_weight[*s] << " -> " << straw << endl; _items.push_back(*s); _straws.push_back(straw); } next++; if (next == reverse.end()) break; wbelow += (cur->first-lastw) * numleft; //cout << "wbelow " << wbelow << endl; numleft -= 1.0 * (float)cur->second.size(); //cout << "numleft now " << numleft << endl; float wnext = numleft * (next->first - cur->first); //cout << "wnext " << wnext << endl; float pbelow = wbelow / (wbelow+wnext); //cout << "pbelow " << pbelow << endl; straw *= pow((double)(1.0/pbelow), (double)1.0/numleft); lastw = cur->first; } //cout << "============" << endl; } int choose_r(int x, int r, Hash& h) const { //cout << "strawbucket.choose_r(" << x << ", " << r << ")" << endl; float high_draw = -1; int high = 0; list::const_iterator pi = _items.begin(); list::const_iterator ps = _straws.begin(); while (pi != _items.end()) { const int item = *pi; const float rnd = (float)(h(x, item, r) % 1000000) / 1000000.0; const float straw = *ps * rnd; if (high_draw < 0 || straw > high_draw) { high = *pi; high_draw = straw; } pi++; ps++; } return high; } }; inline Bucket* decode_bucket(bufferlist& bl, int& off) { char t; bl.copy(off, sizeof(t), (char*)&t); off += sizeof(t); switch (t) { case CRUSH_BUCKET_UNIFORM: return new UniformBucket(bl, off); case CRUSH_BUCKET_LIST: return new ListBucket(bl, off); case CRUSH_BUCKET_TREE: return new TreeBucket(bl, off); case CRUSH_BUCKET_STRAW: return new StrawBucket(bl, off); default: assert(0); } return 0; } } #endif