// Closed-interval binary search ("V1" listing).
// Returns a pointer to the first element of the sorted `arr` that is not
// less than `val` (same contract as std::lower_bound), or
// arr.data() + arr.size() when every element is less than `val`.
// Invariant: while the loop runs, the candidate answer index lies in [l, r].
T* lower_bound(vector<T>& arr, int val) {
    // With no elements there is no candidate index. Without this guard r
    // would start at -1, the loop would be skipped, and arr[l] below would
    // read arr[0] of an empty vector.
    if (arr.empty()) return arr.data() + arr.size();

    int l = 0;
    int r = static_cast<int>(arr.size()) - 1;
    while (l < r) {
        // mid must round toward l. When l + 1 == r this gives mid == l, so
        // the `r = mid` branch still shrinks the range. Rounding toward r
        // would leave r unchanged and loop forever.
        int mid = l + (r - l) / 2;
        // The predicate is phrased with LessT only. For std::upper_bound
        // semantics it would instead be !LessT(val, arr[mid]), and the final
        // check below would become LessT(val, arr[l]).
        if (LessT(arr[mid], val))
            l = mid + 1;  // arr[mid] < val: the answer is strictly right of mid
        else
            r = mid;      // arr[mid] >= val: mid itself is still a candidate
    }
    // The loop narrows to a single index without testing it. If even that
    // element is less than val, the answer is past the end.
    if (!LessT(arr[l], val)) return arr.data() + l;
    return arr.data() + arr.size();
}
// Half-open binary search ("V2" listing).
// Returns a pointer to the first element of the sorted `arr` that is not
// less than `val` (same contract as std::lower_bound), or
// arr.data() + arr.size() when no such element exists.
// Invariant: the answer index lies in [l, r], where r == arr.size() means
// "past the end", so no post-loop check is needed.
T* lower_bound(vector<T>& arr, int val) {
    int l = 0;
    int r = static_cast<int>(arr.size());
    while (l < r) {
        // l <= mid < r, so arr[mid] is always a valid element.
        int mid = l + (r - l) / 2;
        // Test the element AT mid. Indexing arr[mid - 1] here would read
        // arr[-1] when mid == 0 and would test the wrong element otherwise.
        if (LessT(arr[mid], val))
            l = mid + 1;  // everything up to and including mid is < val
        else
            r = mid;      // mid is >= val and may be the answer
    }
    return arr.data() + l;
}
using namespace std;

// Element type stored in the arrays that every search implementation handles.
using T = int;

// Ordering predicate shared by sort, std::lower_bound and the hand-written
// searches: true exactly when `lhs` orders strictly before `rhs`.
bool LessT(T lhs, T rhs) {
    return lhs < rhs;
}
structBS { virtual T* lower_bound(vector<T>& arr, int val)= 0; }; structBS_std : public BS { T* lower_bound(vector<T>& arr, int val) { return std::lower_bound(arr.data(), arr.data() + arr.size(), val, LessT); } };
structBS_V1 : public BS { T* lower_bound(vector<T>& arr, int val) { // [l,r] : possible result int l = 0, r = arr.size() - 1; while (l < r) { // About whether the mid should be closer to l or r // this is mattered when l + 1 == r int mid = l + (r - l) / 2; // note since Less is operator<, we should always identify the less relation // this would be LessT(arr[mid], val) for std::upper_bound if (LessT(arr[mid], val)) l = mid + 1; else r = mid; // take a look at the range shrinking process // l = mid +1 // or // r = mid // since l = mid + 1 is always ok for range reduce // so to keep r = mid always do so, the mid be closer to l would be nice. } if (!LessT(arr[l], val)) return arr.data() + l; else return arr.data() + arr.size(); } };
structBS_V2 : public BS { T* lower_bound(vector<T>& arr, int val) { // [l,r] : possible retval int l = 0, r = arr.size(); while (l < r) { int mid = l + (r - l) / 2; if (LessT(arr[mid - 1], val)) l = mid + 1; else r = mid; } return arr.data() + l; } };
intmain() { int testRunCount = 1000; int arrSize = 1000; srand(time(0));
BS* bss[] = { newBS_V1(), newBS_V2(), }; for (int r = 0; r < testRunCount; r++) { vector<T> data(arrSize); for (int i = 0; i < arrSize; i++) data[i] = rand() % (testRunCount * 2); sort(data.begin(), data.end(), LessT);
for (int t = 0; t < 3; ++t) { int tval = rand() % (testRunCount * 2); T* std_result = BS_std().lower_bound(data, tval);
for (auto bs : bss) { T* result = BS_V1().lower_bound(data, tval); assert(std_result == result); } } } }