#include <iostream>
using namespace std;

/*函数作用:取待排序序列中low、mid、high三个位置上数据,选取他们中间的那个数据作为枢轴*/
int median(int arr[], int b[], int len1, int low, int high) {
	int mid = low + ((high - low) >> 1); //计算数组中间的元素的下标

	int &lowData = low >= len1 ? b[low - len1] : arr[low];
	int &midData = mid >= len1 ? b[mid - len1] : arr[mid];
	int &highData = high >= len1 ? b[high - len1] : arr[high];
	//使用三数取中法选择枢轴
	if (midData > highData) //目标: arr[mid] <= arr[high]
			{
		swap(midData, highData);
	}
	if (lowData > highData) //目标: arr[low] <= arr[high]
			{
		swap(lowData, highData);
	}
	if (midData > lowData) //目标: arr[low] >= arr[mid]
			{
		swap(midData, lowData);
	}
	//此时,arr[mid] <= arr[low] <= arr[high]
	return lowData;
	//low的位置上保存这三个位置中间的值
	//分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了
}

int kth_elem(int a[], int b[], int len1, int low, int high, int k) {
	int pivot = median(a, b, len1, low, high);

	//要么是选取数组中中位数作为枢纽元,保证最坏情况下,依然为线性O(N)的平均时间复杂度。
	int low_temp = low;
	int high_temp = high;
	while (low < high) {
		int tmp = high >= len1 ? b[high - len1] : a[high];
		while (low < high && tmp >= pivot) {
			--high;
			tmp = high >= len1 ? b[high - len1] : a[high];
		}
		if (low >= len1) {
			b[low - len1] = tmp;
		} else {
			a[low] = tmp;
		}

		int tmp1 = low >= len1 ? b[low - len1] : a[low];
		while (low < high && tmp1 < pivot) {
			++low;
			tmp1 = low >= len1 ? b[low - len1] : a[low];
		}
		if (high >= len1) {
			b[high - len1] = tmp1;
		} else {
			a[high] = tmp1;
		}
	}

	if (low >= len1) {
		b[low - len1] = pivot;
	} else {
		a[low] = pivot;
	}

	//以下就是主要思想中所述的内容
	if (low == k - 1) {
		if (low >= len1) {
			return b[low - len1];
		}
		return a[low];
	} else if (low > k - 1)
		return kth_elem(a, b, len1, low_temp, low - 1, k);
	else
		return kth_elem(a, b, len1, low + 1, high_temp, k);
}

void printArray(int* arr, int len) {
	if (!arr) {
		return;
	}
	for (int i = 0; i < len; ++i) {
		cout << arr[i] << " ";
	}
	cout << endl;
}

void print2SortedArray(int* a, int* b, int len1, int len2) {
	int* arr = new int[len1 + len2];
	for (int i = 0; i < len1; ++i) {
		arr[i] = a[i];
	}
	for (int i = len1, j = 0; j < len2; ++i, j++) {
		arr[i] = b[j];
	}
	sort(arr, arr + len1 + len2);
	printArray(arr, len1 + len2);
	delete arr;
}

int main() {
	int arr1[] = { 2, 12, 5, 10, 43, 24, 33, 4 };
	int arr2[] = { 10, 23, 41, 70, 84, 29, 6 };

	int len1 = sizeof(arr1) / sizeof(int);
	int len2 = sizeof(arr2) / sizeof(int);

	print2SortedArray(arr1, arr2, len1, len2);

	int mid1 = (len1 + len2) / 2 + 1;
	int mid2 = (len1 + len2) % 2 == 0 ? mid1 - 1 : mid1;

	int midData1 = kth_elem(arr1, arr2, len1, 0, len1 + len2 - 1, mid1);
	int midData2 = kth_elem(arr1, arr2, len1, 0, len1 + len2 - 1, mid2);

//	cout << midData1 << ',' << midData2 << endl;
	cout << "中位数: " << (midData1 + midData2) / 2 << endl;
	return 0;
}
04-26 19:16