Fix N-width sorting bug

This commit is contained in:
Viktor Lofgren 2023-05-28 11:52:00 +02:00
parent a57ab427b3
commit 6814c90625
4 changed files with 213 additions and 49 deletions

View File

@ -17,11 +17,12 @@ class SortAlgoInsertionSort {
for (int i = 1; i < span / sz; i++) {
long key = array.get(start + (long) i * sz);
int j;
for (j = i - 1; j >= 0 && array.get(start + (long) j * sz) > key; j--) {
array.swap(start + (long) j * sz, start + (long) (j + 1) * sz);
long j;
for (j = i - 1; j >= 0 && array.get(start + j* sz) > key; j--) {
array.swapn(sz, start + j *sz, start + (j + 1)*sz);
}
array.set(start + (long) (j + 1) * sz, key);
array.set(start + (j + 1) * sz, key);
}
}

View File

@ -79,7 +79,7 @@ class SortAlgoQuickSort {
static long _quickSortPartition(LongArraySort array, long low, long high) {
long pivotPoint = ((low + high) / (2L));
long pivotPoint = low + ((high - low) / 2L);
long pivot = array.get(pivotPoint);
long i = low - 1;
@ -102,9 +102,15 @@ class SortAlgoQuickSort {
static long _quickSortPartitionN(LongArraySort array, int wordSize, long low, long high) {
long pivotPoint = ((low + high) / (2L*wordSize)) * wordSize;
long delta = (high - low) / (2L);
long pivotPoint = low + (delta / wordSize) * wordSize;
long pivot = array.get(pivotPoint);
assert (pivotPoint - low) >= 0;
assert (pivotPoint - low) % wordSize == 0;
long i = low - wordSize;
long j = high + wordSize;

View File

@ -0,0 +1,200 @@
package nu.marginalia.array.algo;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import nu.marginalia.array.LongArray;
import nu.marginalia.array.page.LongArrayPage;
import nu.marginalia.array.page.PagingLongArray;
import nu.marginalia.array.scheme.PowerOf2PartitioningScheme;
import nu.marginalia.util.test.TestUtil;
import org.apache.commons.lang3.ArrayUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
 * Tests for the N-width sort operations of {@link LongArray} (called here with a width of 2),
 * together with stress tests for the plain quick, insertion and merge sorts.
 *
 * <p>The N-width tests do not only check that the result is ordered: they also verify that
 * every key is still associated with the same set of values after sorting, which catches
 * sorts that reorder keys without moving the rest of the record along with them.
 */
@Tag("slow")
class LongArraySortNTest {
    LongArray basic;
    LongArray paged;
    LongArray shifted;

    /** Key-to-values snapshots of the fixtures, captured in {@link #setUp()} before sorting. */
    Long2ObjectOpenHashMap<LongOpenHashSet> basicPairs;
    Long2ObjectOpenHashMap<LongOpenHashSet> pagedPairs;
    Long2ObjectOpenHashMap<LongOpenHashSet> shiftedPairs;

    /** Number of longs in each fixture; even, so the data splits cleanly into pairs. */
    final int size = 1026;

    /**
     * Fills three differently constructed arrays with the same data: random values in
     * [0, 1000) at even indexes and negated random values at odd indexes.
     */
    @BeforeEach
    public void setUp() {
        basic = LongArrayPage.onHeap(size);
        paged = PagingLongArray.newOnHeap(new PowerOf2PartitioningScheme(32), size);
        shifted = LongArrayPage.onHeap(size + 30).shifted(30);

        var random = new Random();
        long[] values = new long[size];
        for (int i = 0; i < size; i++) {
            values[i] = random.nextInt(0, 1000);
        }
        // Negate every odd index so the second long of each pair is told apart from the first
        for (int i = 1; i < size; i += 2) {
            values[i] = -values[i];
        }

        basic.transformEach(0, size, (i, old) -> values[(int) i]);
        paged.transformEach(0, size, (i, old) -> values[(int) i]);
        shifted.transformEach(0, size, (i, old) -> values[(int) i]);

        basicPairs = asPairs(basic);
        pagedPairs = asPairs(paged);
        shiftedPairs = asPairs(shifted);
    }

    /** A sort over the range [start, end) of an array, as exercised by the stress tests. */
    @FunctionalInterface
    interface SortOperation {
        void sort(LongArray array, long start, long end) throws IOException;
    }

    @Test
    public void quickSortStressTest() throws IOException {
        LongArray array = LongArray.allocate(65536);
        sortAlgorithmTester(array, LongArraySort::quickSort);
    }

    @Test
    public void insertionSortStressTest() throws IOException {
        LongArray array = LongArray.allocate(8192);
        sortAlgorithmTester(array, LongArraySort::insertionSort);
    }

    @Test
    public void mergeSortStressTest() throws IOException {
        LongArray array = LongArray.allocate(65536);
        Path tempDir = Files.createTempDirectory(getClass().getSimpleName());
        try {
            sortAlgorithmTester(array, (a, s, e) -> a.mergeSort(s, e, tempDir));
        } finally {
            // Clean up even when an assertion fails, so failed runs do not leave directories behind
            TestUtil.clearTempDir(tempDir);
        }
    }

    /**
     * Runs {@code operation} over ranges of increasing length that all start at index 6, and
     * checks for each range that the result is sorted, that the elements directly outside the
     * range are untouched, and that the XOR checksum of the contents is unchanged.
     *
     * @param array     the array to sort in; its size determines the largest range tested
     * @param operation the sort under test
     * @throws IOException if the sort operation throws it
     */
    void sortAlgorithmTester(LongArray array, SortOperation operation) throws IOException {
        long[] values = new long[(int) array.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = i;
        }
        ArrayUtils.shuffle(values);

        long sentinelA = 0xFEEDBEEFL;
        long sentinelB = 0xB000B000L;

        int start = 6;
        for (int end = start + 1; end < values.length - 1; end += 97) {
            // Use sentinel values to catch if the sort algorithm overwrites end values
            array.set(start - 1, sentinelA);
            array.set(end, sentinelB);

            long orderInvariantChecksum = 0;
            for (long i = 0; i < end - start; i++) {
                array.set(start + i, values[start + (int) i]);

                // Try to checksum the contents to catch bugs where the result is sorted
                // but a value has been duplicated, overwriting another
                orderInvariantChecksum ^= values[start + (int) i];
            }

            operation.sort(array, start, end);

            assertTrue(array.isSorted(start, end), "Array wasn't sorted");
            assertEquals(sentinelA, array.get(start - 1), "Start position sentinel overwritten");
            assertEquals(sentinelB, array.get(end), "End position sentinel overwritten");

            long actualChecksum = 0;
            for (long i = start; i < end; i++) {
                actualChecksum ^= array.get(i);
            }
            assertEquals(orderInvariantChecksum, actualChecksum, "Checksum validation failed");
        }
    }

    /** Asserts that {@code sorted} still holds exactly the key-to-values pairs in {@code expectedPairs}. */
    private void compare(LongArray sorted, Long2ObjectOpenHashMap<LongOpenHashSet> expectedPairs) {
        var actual = asPairs(sorted);
        assertEquals(expectedPairs, actual);
    }

    @Test
    void insertionSortN() {
        basic.insertionSortN(2, 0, size);
        assertTrue(basic.isSortedN(2, 0, size));

        paged.insertionSortN(2, 0, size);
        assertTrue(paged.isSortedN(2, 0, size));

        shifted.insertionSortN(2, 0, size);
        assertTrue(shifted.isSortedN(2, 0, size));

        compare(basic, basicPairs);
        compare(paged, pagedPairs);
        compare(shifted, shiftedPairs);
    }

    @Test
    void quickSortN() {
        basic.quickSortN(2, 0, size);
        assertTrue(basic.isSortedN(2, 0, size));

        paged.quickSortN(2, 0, size);
        assertTrue(paged.isSortedN(2, 0, size));

        shifted.quickSortN(2, 0, size);
        assertTrue(shifted.isSortedN(2, 0, size));

        compare(basic, basicPairs);
        compare(paged, pagedPairs);
        compare(shifted, shiftedPairs);
    }

    @Test
    void mergeSortN() throws IOException {
        // Use a private scratch directory rather than a hard-coded, platform-specific "/tmp"
        Path tempDir = Files.createTempDirectory(getClass().getSimpleName());
        try {
            basic.mergeSortN(2, 0, size, tempDir);
            assertTrue(basic.isSortedN(2, 0, size));

            paged.mergeSortN(2, 0, size, tempDir);
            assertTrue(paged.isSortedN(2, 0, size));

            shifted.mergeSortN(2, 0, size, tempDir);
            assertTrue(shifted.isSortedN(2, 0, size));

            compare(basic, basicPairs);
            compare(paged, pagedPairs);
            compare(shifted, shiftedPairs);
        } finally {
            TestUtil.clearTempDir(tempDir);
        }
    }

    /**
     * Reads {@code array} as consecutive (key, value) pairs and groups the values by key.
     *
     * @param array the array to read; the loop reads index {@code i + 1} for every even
     *              {@code i} below {@code array.size()}, so the size is expected to be even
     * @return a map from each key to the set of values that appear next to it
     */
    private Long2ObjectOpenHashMap<LongOpenHashSet> asPairs(LongArray array) {
        Long2ObjectOpenHashMap<LongOpenHashSet> map = new Long2ObjectOpenHashMap<>();

        for (long i = 0; i < array.size(); i += 2) {
            long key = array.get(i);
            long val = array.get(i + 1);

            // Single lookup per pair; the set is only created the first time a key is seen
            LongOpenHashSet valuesForKey = map.get(key);
            if (valuesForKey == null) {
                valuesForKey = new LongOpenHashSet();
                map.put(key, valuesForKey);
            }
            valuesForKey.add(val);
        }

        return map;
    }
}

View File

@ -126,18 +126,6 @@ class LongArraySortTest {
assertTrue(shifted.isSorted(0, 128));
}
/**
 * Runs insertionSortN with 2 as its first argument over the range [0, size) of each
 * fixture array, and asserts that isSortedN accepts the same range afterwards.
 */
@Test
void insertionSortN() {
    final int width = 2;

    basic.insertionSortN(width, 0, size);
    final boolean basicOk = basic.isSortedN(width, 0, size);
    assertTrue(basicOk);

    paged.insertionSortN(width, 0, size);
    final boolean pagedOk = paged.isSortedN(width, 0, size);
    assertTrue(pagedOk);

    shifted.insertionSortN(width, 0, size);
    final boolean shiftedOk = shifted.isSortedN(width, 0, size);
    assertTrue(shiftedOk);
}
@Test
void quickSort() {
basic.quickSort(0, size);
@ -150,36 +138,6 @@ class LongArraySortTest {
assertTrue(shifted.isSorted(0, size));
}
/**
 * Runs quickSortN with 2 as its first argument over the range [0, size) of each fixture
 * array, and asserts that isSortedN accepts the same range afterwards.
 */
@Test
void quickSortN() {
    final int width = 2;

    basic.quickSortN(width, 0, size);
    // Diagnostic aid: when the result is not sorted, print every second element
    // before the assertion below fails
    if (!basic.isSortedN(width, 0, size)) {
        int pos = 0;
        while (pos < size) {
            System.out.println(basic.get(pos));
            pos += 2;
        }
    }
    assertTrue(basic.isSortedN(width, 0, size));

    paged.quickSortN(width, 0, size);
    final boolean pagedOk = paged.isSortedN(width, 0, size);
    assertTrue(pagedOk);

    shifted.quickSortN(width, 0, size);
    final boolean shiftedOk = shifted.isSortedN(width, 0, size);
    assertTrue(shiftedOk);
}
/**
 * Runs mergeSortN with 2 as its first argument over the range [0, size) of each fixture
 * array, passing "/tmp" as the path argument, and asserts that isSortedN accepts the same
 * range afterwards.
 */
@Test
void mergeSortN() throws IOException {
    final int width = 2;
    final Path scratchDir = Path.of("/tmp");

    basic.mergeSortN(width, 0, size, scratchDir);
    final boolean basicOk = basic.isSortedN(width, 0, size);
    assertTrue(basicOk);

    paged.mergeSortN(width, 0, size, scratchDir);
    final boolean pagedOk = paged.isSortedN(width, 0, size);
    assertTrue(pagedOk);

    shifted.mergeSortN(width, 0, size, scratchDir);
    final boolean shiftedOk = shifted.isSortedN(width, 0, size);
    assertTrue(shiftedOk);
}
@Test
void mergeSort() throws IOException {
basic.mergeSort(0, size, Path.of("/tmp"));
@ -192,7 +150,6 @@ class LongArraySortTest {
assertTrue(shifted.isSorted(0, size));
}
@Test
void keepUniqueFuzz() {
var array = LongArray.allocate(1000);