Median of Two Sorted Arrays

Median of Two Sorted Arrays

描述

给定两个长度为 m与 n 的有序数组, 输出两个数组的中位数, 如果为偶数, 输出上下中位数的平均数, 要求复杂度\(O(log(m+n))\)

示例:
输入: [1, 3], [2]
输出: 2.0

输入: [1, 2], [3, 4]
输出: 2.5

分析

中位数不是一个很简洁的概念, 有数组l, 长度为 m, 则中位数为:

mid = l[len(l) / 2] if len(l) % 2 != 0 else (l[len(l) / 2] + l[len(l) / 2 - 1]) / 2

因为存在这种情况, 计算中位数时边界条件往往会导致代码不好看, 所以, 对于这道题目, 排序后直接输出会使得代码比较好看, 两个有序数组 merge, 复杂度为\(O(m+n)\)


在 python 中, 可以有以下解法:

class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        l = sorted(nums1 + nums2)
        if len(l) == 0:
            return 0
        if len(l) % 2 != 0:
            return l[len(l) / 2]
        return (l[len(l) / 2] + l[len(l) / 2 - 1]) / 2.0

挺简洁的, 而且, 从时间来看, 速度快于50%的提交…

然而, 复杂度是是不符合要求的, 继续想一下, 对数复杂度的容易想到二分

二分思路并不复杂, 基本思想是, 一个数组中有大约一半的数字比中位数大, 另一半比中位数小, 从这两部分中去掉相同数量的数字, 中位数保持不变

如果两个数组长度相等, 最终每个数组都会只剩下两到三个数字, sort 并取中位数即可, 如果两个数组长度不相等, 每次去除较短中位数的一半, 会有一个数组先到达下限值, 之后做一次常数级别的查找即可

由于在数组长度为偶数时, 不能判定数组的下中位数是不是两个数组总和的中位数, 因此, 在每次去除时, 左右各少去除一个, 防止漏掉

在这种思想的指导下, 一份能 ac 但是十分丑陋的代码出现了, 如下所示:

class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        mstart = 0
        mend = len(nums1) - 1
        nstart = 0
        nend = len(nums2) - 1
        return self._findMedianSortedArrays(nums1, nums2, mstart, mend, nstart, nend)

    def _findMedianSortedArrays(self, m, n, mstart, mend, nstart, nend):
        if mend - mstart <= 3 or nend - nstart <= 3:
            if mend - mstart + nend - nstart <= 16:
                t = sorted(m[mstart:mend + 1] + n[nstart:nend + 1])
                if (len(m) + len(n)) % 2 == 0:
                    return (t[len(t) / 2] + t[len(t) / 2 - 1]) / 2.0
                return t[len(t) / 2]
            l = m[mstart:mend + 1]
            s = n[nstart:nend + 1]
            if len(s) > len(l):
                s, l = l, s
            k1 = (len(l) + len(s)) / 2
            if (len(m) + len(n)) % 2 == 0:
                k2 = k1 - 1
                return (self._findKth(l, s, k1) + self._findKth(l, s, k2)) / 2.0
            return self._findKth(l, s, k1)
        minlen = mend - mstart + 1
        if nend - nstart + 1 < minlen:
            minlen = nend - nstart
        mmid = (mstart + mend) / 2
        nmid = (nstart + nend) / 2
        if m[mmid] <= n[nmid]:
            return self._findMedianSortedArrays(m, n, mstart + minlen / 2 - 1, mend, nstart, nend - minlen / 2 + 1)
        else:
            return self._findMedianSortedArrays(n, m, nstart + minlen / 2 - 1, nend, mstart, mend - minlen / 2 + 1)

    def _findKth(self, m, n, k):
        o = sorted(m + n)
        k1 = len(m) / 2
        res = m[k1]
        realK = k1
        for i in n:
            if i < m[k1]:
                realK += 1
        if realK == k:
            return res
        if realK < k:
            delta = k - realK
            t = sorted(n + m[k1 + 1:k1 + k -realK + 1])
            for i in t:
                if i >= res:
                    delta -= 1
                    if delta == 0:
                        return i
        if realK > k:
            delta = realK - k
            t = sorted(n + m[k1 - delta:k1])
            for i in t[::-1]:
                if i <= res:
                    delta -= 1
                    if delta == 0:
                        return i

又臭又长, 导致这种情况的原因是中位数的不确定性, 如果能确定到底是要求第几位, 代码会好看很多, 而我们可以把这种不确定放在一开始, 在递归内保持逻辑的简洁, 即最终的答案

答案

class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        total = len(nums1) + len(nums2)
        if total % 2 != 0:
            return self.findKth(nums1, nums2, total / 2 + 1)
        else:
            return (self.findKth(nums1, nums2, total / 2) + self.findKth(nums1, nums2, total / 2 + 1)) / 2.0

    def findKth(self, m, n, k):
        ml = len(m)
        nl = len(n)
        if ml > nl:
            return self.findKth(n, m,k)
        if ml == 0:
            if nl == 0:
                return 0
            else:
                return n[k-1]
        if k == 1:
            return m[0] if m[0] < n[0] else n[0]

        pa = k / 2 if k / 2 < ml else ml
        pb = k - pa
        if m[pa - 1] > n[pb - 1]:
            return self.findKth(m, n[pb::], k - pb)
        if m[pa - 1] < n[pb - 1]:
            return self.findKth(m[pa::], n, k - pa)
        return m[pa - 1]
打赏