题目要求

Given a singly linked list, return a random node's value from the linked list. Each node must have the same probability of being chosen.

Follow up:
What if the linked list is extremely large and its length is unknown to you? Could you solve this efficiently without using extra space?

Example:

// Init a singly linked list [1,2,3].
ListNode head = new ListNode(1);
head.next = new ListNode(2);
head.next.next = new ListNode(3);
Solution solution = new Solution(head);

// getRandom() should return either 1, 2, or 3 randomly. Each element should have equal probability of returning.
solution.getRandom();

要求从单链表中,随机返回一个节点的值,要求每个节点被选中的概率是相等的。

思路和代码

在等概率随机选择算法中,最经典的算法就是蓄水池算法。可以参考同类型题目398 random pick index。这里再次整理一下蓄水池算法的思路和简单证明。

假如一共有N个物品,需要从其中挑选出K个物品,要求确保N个物品中每个物品都能够被等概率选中。对于这种等概率问题,简答的做法是通过随机数获取选中物品的下标。但是蓄水池算法允许我们从数据流的角度来随机获得K个物品,即在并不知道总体的样本数有多少的情况下,随机抽取K个物品。

蓄水池算法的思路如下:

  1. 选中前K个物品放入蓄水池
  2. 对于第K+1个物品,其被选中并替换蓄水池中任意一个物品的概率为K/(K+1)
  3. 对于第K+i个物品,其被选中并替换蓄水池中任意一个物品的概率为K/(K+i)
  4. 重复这个步骤直到K+i=N

对于这个算法,我们可以采用归纳法进行简单证明。已知对于前K个物品,每个物品的被选中的概率为1,满足了K/K=1的概率。
对于K+i-1个物品,假设每个物品被选中的概率为K/(K+i-1)。证明对于前K+i个物品,每个物品被放入蓄水池中的概率为K/(K+i)

  1. 对于第K+i个物品,其被选中并替换蓄水池中任意一个物品的概率为K/(K+i)
  2. 对于之前在蓄水池中的物品,其仍在蓄水池中的概率为之前被选中在蓄水池中概率乘以这一次未被换出蓄水池的概率,即P = P(上一轮在蓄水池中) * P(这一轮没有被替换掉)。对此进行计算,P(上一轮在蓄水池中) * P(这一轮没有被替换掉) = P(上一轮在蓄水池中) * (1-P(这一轮被替换掉)) = (K / (K+i-1)) * (1 - (P * 1/K)),算出P = K/(K+i)
  3. 证明对于前K+i个物品,每个物品被放入蓄水池中的概率为K/(K+i),当K+i等于N时,每个物品被选中的概率为K/N

在本题中,使用蓄水池算法的N为单链表的长度,K为1。

代码如下:

    private ListNode head;
    private Random r;
    /** @param head The linked list's head.
    Note that the head is guaranteed to be not null, so it contains at least one node. */
    public Solution(ListNode head) {
        this.head = head;
        this.r = new Random();
    }

    /** Returns a random node's value. */
    public int getRandom() {
        ListNode tmp = this.head;
        int result = 0;
        int index = 1;
        do{
            if(r.nextInt(index) == 0) {
                result = tmp.val;
            }
            tmp = tmp.next;
            index++;
        }while(tmp != null);
        return result;
    }
03-05 15:23