diff options
-rw-r--r-- | changes/bug7191 | 3 | ||||
-rw-r--r-- | src/common/container.c | 112 | ||||
-rw-r--r-- | src/test/test_containers.c | 36 |
3 files changed, 134 insertions, 17 deletions
diff --git a/changes/bug7191 b/changes/bug7191 new file mode 100644 index 000000000..f50b16e3a --- /dev/null +++ b/changes/bug7191 @@ -0,0 +1,3 @@ + o Bugfixes + - The smartlist_bsearch_idx() function was broken for lists of length zero + or one; fix it. This fixes bug7191. diff --git a/src/common/container.c b/src/common/container.c index d941048b0..d4a2f89c9 100644 --- a/src/common/container.c +++ b/src/common/container.c @@ -571,31 +571,109 @@ smartlist_bsearch_idx(const smartlist_t *sl, const void *key, int (*compare)(const void *key, const void **member), int *found_out) { - int hi = smartlist_len(sl) - 1, lo = 0, cmp, mid; + int hi, lo, cmp, mid, len, diff; + + tor_assert(sl); + tor_assert(compare); + tor_assert(found_out); + + len = smartlist_len(sl); + + /* Check for the trivial case of a zero-length list */ + if (len == 0) { + *found_out = 0; + /* We already know smartlist_len(sl) is 0 in this case */ + return 0; + } + + /* Okay, we have a real search to do */ + tor_assert(len > 0); + lo = 0; + hi = len - 1; + + /* + * These invariants are always true: + * + * For all i such that 0 <= i < lo, sl[i] < key + * For all i such that hi < i <= len, sl[i] > key + */ while (lo <= hi) { - mid = (lo + hi) / 2; + diff = hi - lo; + /* + * We want mid = (lo + hi) / 2, but that could lead to overflow, so + * instead diff = hi - lo (non-negative because of loop condition), and + * then hi = lo + diff, mid = (lo + lo + diff) / 2 = lo + (diff / 2). + */ + mid = lo + (diff / 2); cmp = compare(key, (const void**) &(sl->list[mid])); - if (cmp>0) { /* key > sl[mid] */ - lo = mid+1; - } else if (cmp<0) { /* key < sl[mid] */ - hi = mid-1; - } else { /* key == sl[mid] */ + if (cmp == 0) { + /* sl[mid] == key; we found it */ *found_out = 1; return mid; - } - } - /* lo > hi. */ - { - tor_assert(lo >= 0); - if (lo < smartlist_len(sl)) { - cmp = compare(key, (const void**) &(sl->list[lo])); + } else if (cmp > 0) { + /* + * key > sl[mid] and an index i such that sl[i] == key must + * have i > mid if it exists. + */ + + /* + * Since lo <= mid <= hi, hi can only decrease on each iteration (by + * being set to mid - 1) and hi is initially len - 1, mid < len should + * always hold, and this is not symmetric with the left end of list + * mid > 0 test below. A key greater than the right end of the list + * should eventually lead to lo == hi == mid == len - 1, and then + * we set lo to len below and fall out to the same exit we hit for + * a key in the middle of the list but not matching. Thus, we just + * assert for consistency here rather than handle a mid == len case. + */ + tor_assert(mid < len); + /* Move lo to the element immediately after sl[mid] */ + lo = mid + 1; + } else { + /* This should always be true in this case */ tor_assert(cmp < 0); - } else if (smartlist_len(sl)) { - cmp = compare(key, (const void**) &(sl->list[smartlist_len(sl)-1])); - tor_assert(cmp > 0); + + /* + * key < sl[mid] and an index i such that sl[i] == key must + * have i < mid if it exists. + */ + + if (mid > 0) { + /* Normal case, move hi to the element immediately before sl[mid] */ + hi = mid - 1; + } else { + /* These should always be true in this case */ + tor_assert(mid == lo); + tor_assert(mid == 0); + /* + * We were at the beginning of the list and concluded that every + * element e compares e > key. + */ + *found_out = 0; + return 0; + } } } + + /* + * lo > hi; we have no element matching key but we have elements falling + * on both sides of it. The lo index points to the first element > key. + */ + tor_assert(lo == hi + 1); /* All other cases should have been handled */ + tor_assert(lo >= 0); + tor_assert(lo <= len); + tor_assert(hi >= 0); + tor_assert(hi <= len); + + if (lo < len) { + cmp = compare(key, (const void **) &(sl->list[lo])); + tor_assert(cmp < 0); + } else { + cmp = compare(key, (const void **) &(sl->list[len-1])); + tor_assert(cmp > 0); + } + *found_out = 0; return lo; } diff --git a/src/test/test_containers.c b/src/test/test_containers.c index 10146c5f6..399ef8e90 100644 --- a/src/test/test_containers.c +++ b/src/test/test_containers.c @@ -16,6 +16,15 @@ compare_strs_(const void **a, const void **b) return strcmp(s1, s2); } +/** Helper: return a tristate based on comparing the strings in <b>a</b> and + * *<b>b</b>. */ +static int +compare_strs_for_bsearch_(const void *a, const void **b) +{ + const char *s1 = a, *s2 = *b; + return strcmp(s1, s2); +} + /** Helper: return a tristate based on comparing the strings in *<b>a</b> and * *<b>b</b>, excluding a's first character, and ignoring case. */ static int @@ -204,6 +213,8 @@ test_container_smartlist_strings(void) /* Test bsearch_idx */ { int f; + smartlist_t *tmp = NULL; + test_eq(0, smartlist_bsearch_idx(sl," aaa",compare_without_first_ch_,&f)); test_eq(f, 0); test_eq(0, smartlist_bsearch_idx(sl," and",compare_without_first_ch_,&f)); @@ -216,6 +227,31 @@ test_container_smartlist_strings(void) test_eq(f, 0); test_eq(7, smartlist_bsearch_idx(sl," zzzz",compare_without_first_ch_,&f)); test_eq(f, 0); + + /* Test trivial cases for list of length 0 or 1 */ + tmp = smartlist_new(); + test_eq(0, smartlist_bsearch_idx(tmp, "foo", + compare_strs_for_bsearch_, &f)); + test_eq(f, 0); + smartlist_insert(tmp, 0, (void *)("bar")); + test_eq(1, smartlist_bsearch_idx(tmp, "foo", + compare_strs_for_bsearch_, &f)); + test_eq(f, 0); + test_eq(0, smartlist_bsearch_idx(tmp, "aaa", + compare_strs_for_bsearch_, &f)); + test_eq(f, 0); + test_eq(0, smartlist_bsearch_idx(tmp, "bar", + compare_strs_for_bsearch_, &f)); + test_eq(f, 1); + /* ... and one for length 2 */ + smartlist_insert(tmp, 1, (void *)("foo")); + test_eq(1, smartlist_bsearch_idx(tmp, "foo", + compare_strs_for_bsearch_, &f)); + test_eq(f, 1); + test_eq(2, smartlist_bsearch_idx(tmp, "goo", + compare_strs_for_bsearch_, &f)); + test_eq(f, 0); + smartlist_free(tmp); } /* Test reverse() and pop_last() */ |