vlambda博客
学习文章列表

一道关于二叉树的字节面试题的思考

技术人的精神,就是追根究底,把一个事情彻底弄清楚吧!

题目

众所周知,字节在一二面的末尾,会随机抽一道算法题,当场写代码。我抽到的题目如下:

二叉树根节点到叶子节点的所有路径和。给定一个仅包含数字 0−9 的二叉树,每一条从根节点到叶子节点的路径都可以用一个数字表示。例如根节点到叶子节点的一条路径是 1→2→3,那么这条路径就用 123 来代替。找出根节点到叶子节点的所有路径表示的数字之和。

例如:这棵二叉树一共有两条路径,根节点到左叶子节点的路径 12 代替,根节点到右叶子节点的路径用 13 代替。所以答案为12+13=25 。

递归解法

看到这个题目,首先想到的,其实就是找到这个二叉树的从根节点到叶子节点的所有路径。而要找到所有路径,第一想到的肯定是递归。通过左子树的递归拿到的路径、右子树的递归拿到的路径,以及根节点,得出最终的所有路径。

算法如下:

STEP1:如果已经是叶子节点,那么构造一条路径列表,该路径只有一个元素即叶子节点的值,然后返回【退出条件】。

STEP2:  递归找到左子树到叶子节点的所有路径列表。对于每条路径,将根节点加入,从而得到新的结果路径,并加入;

STEP3:递归找到右子树到叶子节点的所有路径列表。对于每条路径,将根节点加入,从而得到新的结果路径,并加入;

STEP4:  将左右子树的所有路径合并成最终的路径列表【组合子问题的解】。

有两点说明下:

  • 由于节点的值只有 0-9,因此,可以直接用字符串来表示路径。如果用 List[Integer] ,更灵活,不过会变成列表的列表,处理起来会有点绕。
  • 构建路径时,使用的是 StringBuilder 的 append 方法,而不是 insert 方法,因此构造的路径是逆序的。主要考虑到 insert 方法会导致数组频繁移动,效率低。具体可以看 StringBuilder 实现。

递归代码如下:

public List<Path> findAllPaths(TreeNode root) {
        List<Path> le = new ArrayList<>();
        List<Path> ri = new ArrayList<>();
        if (root != null) {

            if (root.left == null && root.right == null) {
                List<Path> single = new ArrayList<>();
                single.add(new Path(root.val));
                return single;
            }

            if (root.left != null) {
                le = findAllPaths(root.left);
                for (Path p: le) {
                    p.append(root.val);
                }
            }
            if (root.right != null) {
                ri = findAllPaths(root.right);
                for (Path p: ri) {
                    p.append(root.val);
                }
            }
        }
        List<Path> paths = new ArrayList<>();
        paths.addAll(le);
        paths.addAll(ri);
        return paths;
    }


class Path {
    StringBuilder s = new StringBuilder();
    public Path() { }

    public Path(Integer i) {
        s.append(i);
    }

    public Path(List list) {
        list.forEach( e-> {
            s.append(e);
        });
    }

    public Path(String str) this.s = new StringBuilder(str); }

    public Long getValue() {
        return Long.parseLong(s.reverse().toString());
    }

    public StringBuilder append(Integer i) {
        return s.append(i);
    }

    public String toString() {
        return s.reverse().toString();
    }
}


class TreeNode {

    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }

    public int height() {
        if (left == null && right == null) {
            return 1;
        }
        int leftHeight = 0;
        int rightHeight = 0;
        if (left != null) {
            leftHeight = left.height();
        }
        if (right != null) {
            rightHeight = right.height();
        }
        return 1 + max(leftHeight, rightHeight);
    }
}

「关键点」

实际上,我在面试当场没有做出来,但在面试后的十分钟,我就把代码写出来了。可能在面试的时候有点紧张,有个地方一直卡住了。

类似二叉树、动态规划的问题,由于有多条分支,从思维上来说,不像处理数组、链表那样是一种线性思维,而是需要一种非线性思维,因此,多做类似的题目,对思维的锻炼是很有益的,—— 能够帮助人摆脱固有的线性思维。

一般来说,算法问题,通常可以分为两步:1.  划分子问题;2.  将子问题的解组合成原问题的解。划分子问题,相对容易一点,但如果划分不合理,就难以想清楚如何去组合解。我一开始就想到了要用子树的解与根节点来组合,但是一直纠结在对求出单条路径的思考上,而不是把所有路径作为子问题的解。这样,我就难以想到如何去组合得到最终解。但面试结束之后,我脑子里闪过左子树的所有路径列表,顿时就明白如何组合了。因此,有时,把“所有”作为子问题的解,再跟上层节点组合,反而能容易地得到原问题的解。此外,递归要特别注意退出条件。

推荐可以多做二叉树、动态规划的题目,能够很好地锻炼划分子问题、组合子问题的解来求解的技能。

非递归算法

实现递归解法,只是一个开始。递归算法很简洁,但执行效率很低,而且容易栈溢出。如果一个足够大的二叉树,就能让递归代码无法执行下去。因此,需要寻求非递归实现。

非递归实现,往往需要借助于栈。我们需要模拟一下如何用栈来访问二叉树。如下图所示:

可以先找找规则,往往规则就是代码的路径。

  • 每次走到一个节点,先将节点值入栈;
  • 走到叶子节点时,说明已经走到路径的尽头,可以记录下这条路径。

第一版非递归实现如下。用一个栈来存储二叉树的访问节点。如果是叶子节点,就记录路径,然后将叶子节点出栈,继续访问。

public List<Path> findAllPathsNonRecDeadLoop(TreeNode root) {

        List<Path> allPaths = new ArrayList<>();
        Stack<Integer> s = new DyStack<Integer>();

        TreeNode p = root;
        while(p != null) {
            s.push(p.val);
            if (p.left == null && p.right == null) {
                allPaths.add(new Path(s.unmodifiedList()));
                s.pop();
                if (s.isEmpty()) {
                    break;
                }
            }
            if (p.left != null) {
                p = p.left;
            }
            else if (p.right != null) {
                p = p.right;
            }
        }
        return allPaths;
    }

不过,这个代码实现会陷入死循环。为什么呢?因为它会无止境重复进入左子树,而且回溯的时候,也没法找到父节点。

「回溯」

为了解决死循环的问题,我们需要加一些支持:进入某个节点时,必须记下该节点的父节点,以及该父节点是否访问过左右子树。这个信息用 TraceNode 来表示。由于始终需要回溯,因此,TraceNode 必须放在栈中,在适当的时候弹出,就像保存现场一样。当遍历的时候,需要记录已经访问的节点,不重复访问,也需要避免将中间节点重复压栈。

重新理一下。对于当前节点,有四种情形需要考虑:

  • 当前节点是叶子节点。记录路径、出栈 treeData, 出栈 traceNode ,回溯到父节点;
  • 当前节点不是叶子节点,有左子树,则需要记录该节点指针及左子树已访问,并进入左子树;
  • 当前节点不是叶子节点,有右子树,则需要记录该节点指针及右子树已访问,并进入右子树;
  • 当前节点不是叶子节点,有左右子树且均已访问,出栈 treeData, 出栈 traceNode ,回溯到父节点。

第二版的递归实现如下:

public List<Path> findAllPathsNonRec(TreeNode root) {

        List<Path> allPaths = new ArrayList<>();
        Stack<Integer> treeData = new DyStack<>();
        Stack<TraceNode> trace = new DyStack<>();

        TreeNode p = root;
        TraceNode traceNode = TraceNode.getNoAccessedNode(p);
        while(p != null) {
            if (p.left == null && p.right == null) {
                // 叶子节点的情形,需要记录路径,并回溯到父节点
                treeData.push(p.val);
                allPaths.add(new ListPath(treeData.unmodifiedList()));
                treeData.pop();
                if (treeData.isEmpty()) {
                    break;
                }
                traceNode = trace.pop();
                p = traceNode.getParent();
                continue;
            }
            else if (traceNode.needAccessLeft()) {
                // 需要访问左子树的情形
                treeData.push(p.val);
                trace.push(TraceNode.getLeftAccessedNode(p));
                p = p.left;
            }
            else if (traceNode.needAccessRight()) {
                // 需要访问右子树的情形
                if (traceNode.hasNoLeft()) {
                    treeData.push(p.val);
                }
                if (!traceNode.hasAccessedLeft()) {
                    // 访问左节点时已经入栈过,这里不重复入栈
                    treeData.push(p.val);
                }
                trace.push(TraceNode.getRightAccessedNode(p));
                p = p.right;
                if (p.left != null) {
                    traceNode = TraceNode.getNoAccessedNode(p);
                }
                else if (p.right != null) {
                    traceNode = TraceNode.getLeftAccessedNode(p);
                }
            }
            else if (traceNode.hasAllAccessed()) {
                // 左右子树都已经访问了,需要回溯到父节点
                if (trace.isEmpty()) {
                    break;
                }
                treeData.pop();
                traceNode = trace.pop();
                p = traceNode.getParent();
            }
        }
        return allPaths;
    }

class TraceNode {

    private TreeNode parent;
    private int accessed;  // 0 均未访问 1 已访问左 2 已访问右

    public TraceNode(TreeNode parent, int child) {
        this.parent = parent;
        this.accessed = child;
    }

    public static TraceNode getNoAccessedNode(TreeNode parent) {
        return new TraceNode(parent, 0);
    }

    public static TraceNode getLeftAccessedNode(TreeNode parent) {
        return new TraceNode(parent, 1);
    }

    public static TraceNode getRightAccessedNode(TreeNode parent) {
        return new TraceNode(parent, 2);
    }

    public boolean needAccessLeft() {
        return parent.left != null && accessed == 0;
    }

    public boolean needAccessRight() {
        return parent.right != null && accessed < 2;
    }

    public boolean hasAccessedLeft() {
        return parent.left == null || (parent.left != null && accessed == 1);
    }

    public boolean hasNoLeft() {
        return parent.left == null;
    }

    public boolean hasAllAccessed() {
        if (parent.left != null && parent.right == null && accessed == 1) {
            return true;
        }
        if (parent.right != null && accessed == 2) {
            return true;
        }
        return false;
    }

    public TreeNode getParent() {
        return parent;
    }

    public int getAccessed() {
        return accessed;
    }
}

关于是否已访问左右子树的判断都隐藏在 TraceNode 里,findAllPathsNonRec 方法不感知这个。后续如果觉得用 int 来表示 accessed 空间效率不高,可以内部重构,对 findAllPathsNonRec 无影响。这就是封装的益处。

测试

递归代码和非递归代码都是容易有 BUG 的,需要仔细测试下。测试用例通常至少要包括:

  • C1: 单个根节点树;
  • C2: 单个根节点 + 左节点;
  • C3: 单个根节点 + 右节点;
  • C4: 单个根节点 + 左右节点;
  • C5: 普通的二叉树,左右随机;
  • 复杂的二叉树,非常大。

如何构造复杂的二叉树呢?可以采用构造法。基于简单的 C2,C3,C4,将一棵树的根节点连接到另一棵树的左叶子节点或右叶子节点上。复杂结构总是由简单结构来组合而成。

测试代码如下。用 TreeBuilder 注解来表示构造的二叉树,从而能够批量拿到这些方法构造的树,进行测试。

   public static void main(String[] args) {
        TreePathSum treePathSum = new TreePathSum();
        Method[] methods = treePathSum.getClass().getDeclaredMethods();
        for (Method m: methods) {
            if (m.isAnnotationPresent(TreeBuilder.class)) {
                try {
                    TreeNode t = (TreeNode) m.invoke(treePathSum, null);
                    System.out.println("height: " + t.height());
                    treePathSum.test2(t);
                } catch (Exception ex) {
                    System.err.println(ex.getMessage());
                }

            }
        }
    }

    public void test(TreeNode root) {

        System.out.println("Rec Implementation");

        List<Path> paths = findAllPaths(root);
        Long sum = paths.stream().collect(Collectors.summarizingLong(Path::getValue)).getSum();
        System.out.println(paths);
        System.out.println(sum);

        System.out.println("Non Rec Implementation");

        List<Path> paths2 = findAllPathsNonRec(root);
        Long sum2 = paths2.stream().collect(Collectors.summarizingLong(Path::getValue)).getSum();
        System.out.println(paths2);
        System.out.println(sum2);

        assert sum == sum2;
    }

    public void test2(TreeNode root) {
        System.out.println("Rec Implementation");
        List<Path> paths = findAllPaths(root);
        System.out.println(paths);

        System.out.println("Non Rec Implementation");
        List<Path> paths2 = findAllPathsNonRec(root);
        System.out.println(paths2);

        assert paths.size() == paths2.size();
        for (int i=0; i < paths.size(); i++) {
            assert paths.get(i).toString().equals(paths2.get(i).toString());
        }

    }

    @TreeBuilder
    public TreeNode buildTreeOnlyRoot() {
        TreeNode tree = new TreeNode(9);
        return tree;
    }

    @TreeBuilder
    public TreeNode buildTreeWithL() {
        return buildTreeWithL(51);
    }

    public TreeNode buildTreeWithL(int rootVal, int leftVal) {
        TreeNode tree = new TreeNode(rootVal);
        TreeNode left = new TreeNode(leftVal);
        tree.left = left;
        return tree;
    }

    @TreeBuilder
    public TreeNode buildTreeWithR() {
        return buildTreeWithR(5,2);
    }

    public TreeNode buildTreeWithR(int rootVal, int rightVal) {
        TreeNode tree = new TreeNode(rootVal);
        TreeNode right = new TreeNode(rightVal);
        tree.right = right;
        return tree;
    }

    @TreeBuilder
    public TreeNode buildTreeWithLR() {
        return buildTreeWithLR(5,1,2);
    }

    public TreeNode buildTreeWithLR(int rootVal, int leftVal, int rightVal) {
        TreeNode tree = new TreeNode(rootVal);
        TreeNode left = new TreeNode(leftVal);
        TreeNode right = new TreeNode(rightVal);
        tree.right = right;
        tree.left = left;
        return tree;
    }

    Random rand = new Random(System.currentTimeMillis());

    @TreeBuilder
    public TreeNode buildTreeWithMore() {
        TreeNode tree = new TreeNode(5);
        TreeNode left = new TreeNode(1);
        TreeNode right = new TreeNode(2);
        TreeNode left2 = new TreeNode(3);
        TreeNode right2 = new TreeNode(4);
        tree.right = right;
        tree.left = left;
        left.left = left2;
        left.right = right2;
        return tree;
    }

    @TreeBuilder
    public TreeNode buildTreeWithMore2() {
        TreeNode tree = new TreeNode(5);
        TreeNode left = new TreeNode(1);
        TreeNode right = new TreeNode(2);
        TreeNode left2 = new TreeNode(3);
        TreeNode right2 = new TreeNode(4);
        tree.right = right;
        tree.left = left;
        right.left = left2;
        right.right = right2;
        return tree;
    }

    public TreeNode treeWithRandom() {
        int c = rand.nextInt(3);
        switch (c) {
            case 0return buildTreeWithL(rand.nextInt(9), rand.nextInt(9));
            case 1return buildTreeWithR(rand.nextInt(9), rand.nextInt(9));
            case 2return buildTreeWithLR(rand.nextInt(9), rand.nextInt(9), rand.nextInt(9));
            defaultreturn buildTreeOnlyRoot();
        }
    }

    public TreeNode linkRandom(TreeNode t1, TreeNode t2) {
        if (t2.left == null) {
            t2.left = t1;
        }
        else if (t2.right == null) {
            t2.right = t1;
        }
        else {
            int c = rand.nextInt(4);
            switch (c) {
                case 0: t2.left.left = t1;
                case 1: t2.left.right = t1;
                case 2: t2.right.left = t1;
                case 3: t2.right.right = t1;
                default: t2.left.left = t1;
            }
        }
        return t2;
    }

    @TreeBuilder
    public TreeNode buildTreeWithRandom() {
        TreeNode root = treeWithRandom();
        int i = 12;
        while (i > 0) {
            TreeNode t = treeWithRandom();
            root = linkRandom(root, t);
            i--;
        }
        return root;
    }

经测试,发现第二版非递归程序在某种情况下,还是有 BUG 。这说明某些基本情形还是没覆盖到。用如下测试用例调试,发现就有问题:

@TreeBuilder
    public TreeNode buildTreeWithMore4() {
        TreeNode tree = new TreeNode(5);
        TreeNode left = new TreeNode(1);
        TreeNode right = new TreeNode(2);
        TreeNode left2 = new TreeNode(3);
        TreeNode right2 = new TreeNode(4);
        TreeNode right3 = new TreeNode(6);
        tree.right = right;
        tree.left = left;
        left.right = right3;
        right.right = left2;
        left2.right = right2;
        return tree;
    }

「回溯再思考」

问题在哪里?当初次进入没有左子树的右子树时,会有问题。这说明,我还没有真正弄明白整个回溯过程。重新再理一下回溯过程:

  • 有一个用来指向当前访问节点的指针 p ;
  • 有一个用来存储已访问节点值的栈 treeData;
  • 有一个用来回溯的存储最近一次访问的节点信息的栈 trace ;
  • 有一个用来指明往哪个方向走的 traceNode 。

问题在于我没想清楚 traceNode 到底是什么含义。traceNode 的 parent 和 accessed 到底该存放什么。实际上,traceNode 和 p 是配套使用的。p 是当前进入的节点的指针,而 traceNode 用来指明进入 p 之后,该往哪里走。traceNode 的来源应该有两个:

  • 第一次进入 p 时,这时候,左右子树都没有访问过,parent 应该与 p 相同,而 accessed 总是初始化为 0 ;
  • 访问了 p 的左子树或右子树,回溯进入 p 时,这时候 parent 应该是 p 的父节点,从 trace 里拿到。

第二版非递归程序正是没有考虑到第一次进入 p 时的情况。如下代码所示。当 p 进入左子树时,需要将最近一次的父节点信息入栈 trace ,同时需要将 traceNode 设置为初始进入 p 时的情形。进入右子树类似。这一点正是第二版非递归程序没有想清楚的地方。

trace.push(TraceNode.getLeftAccessedNode(p));
p = p.left;
traceNode = TraceNode.getNoAccessedNode(p);

我们做一些修改,得到了第三版非递归程序。经测试是 OK 的。

public List<Path> findAllPathsNonRec(TreeNode root) {

        List<Path> allPaths = new ArrayList<>();
        Stack<Integer> treeData = new DyStack<>();
        Stack<TraceNode> trace = new DyStack<>();

        TreeNode p = root;
        TraceNode traceNode = TraceNode.getNoAccessedNode(p);
        while(p != null) {
            if (p.left == null && p.right == null) {
                // 叶子节点的情形,需要记录路径,并回溯到父节点
                treeData.push(p.val);
                allPaths.add(new ListPath(treeData.unmodifiedList()));
                treeData.pop();
                if (treeData.isEmpty()) {
                    break;
                }
                traceNode = trace.pop();
                p = traceNode.getParent();
                continue;
            }
            else if (traceNode.needAccessLeft()) {
                // 需要访问左子树的情形
                treeData.push(p.val);
                trace.push(TraceNode.getLeftAccessedNode(p));
                p = p.left;
                traceNode = TraceNode.getNoAccessedNode(p);
            }
            else if (traceNode.needAccessRight()) {
                // 需要访问右子树的情形
                if (traceNode.hasNoLeft()) {
                    treeData.push(p.val);
                }
                if (!traceNode.hasAccessedLeft()) {
                    // 访问左节点时已经入栈过,这里不重复入栈
                    treeData.push(p.val);
                }
                trace.push(TraceNode.getRightAccessedNode(p));
                p = p.right;
                traceNode = TraceNode.getNoAccessedNode(p);
            }
            else if (traceNode.hasAllAccessed()) {
                // 左右子树都已经访问了,需要回溯到父节点
                if (trace.isEmpty()) {
                    break;
                }
                treeData.pop();
                traceNode = trace.pop();
                p = traceNode.getParent();
            }
        }
        return allPaths;
    }

优化

「扩展性」

由于题目中所给的节点值为 0-9, 因此,前面取巧用了字符串来表示路径。如果节点值不为 0-9 呢?如果依然要用字符串表示,则需要分隔符。现在,我们用列表来表示路径。封装的好处,就在于可以替换实现,而尽量少地改变客户端代码(在这里是findAllPaths 和 findAllPathsNonRec 方法)。

这里,Path 类改成接口,原来的 Path 类改成 StringPath ,然后用 StringPath 替换 Path 。将原来 StringPath 用到的方法,定义成接口方法。只用到了 append 和 getValue 方法。不过,构造器方法参数也要兼容。这样,只要把原来的 StringPath 改成 ListPath ,其它基本不用动,就可以运行通过。

interface Path {
    void append(Integer i);
    Long getValue();
}

class StringPath implements Path // code as before }

class ListPath implements Path {
    List<Integer> path = new ArrayList<>();

    public ListPath(int i) {
        this.path.add(i);
    }

    public ListPath(List list) {
        this.path.addAll(list);
    }

    @Override
    public void append(Integer i) {
        path.add(i);
    }

    @Override
    public Long getValue() {
        StringBuilder s = new StringBuilder();
        path.forEach( e-> {
            s.append(e);
        });
        return Long.parseLong(s.reverse().toString());
    }

    public String toString() {
        return StringUtils.join(path.toArray(), "");
    }
}

小结

花了一天弄懂二叉树回溯的玩法。技术人的精神,就是追根究底,把一个事情彻底弄清楚吧!

在本文中,我们通过一个二叉树的路径寻找面试题,讨论了递归和非递归解法,探讨了非递归过程中遇到的问题,模拟了二叉树的回溯,对于理解二叉树的访问是很有益的。而对于回溯算法的理解,锻炼了非线性思维。此外,当程序有 BUG 时,往往是某个方面没想得足够明白导致。坚持思考,清晰定义,就能向正确再迈进一步。

不看答案,自己弄明白一个问题,收获大大的!

本文完整源代码见:ALLIN 工程里:zzz.study.datastructure.tree.TreePathSum