vlambda博客
学习文章列表

一步一步分析最长公共子序列问题

这两年一直有人跟我说选择大于努力,一开始不了解其中深意。最近开始感慨颇深,昨天有个老哥来我们公司,28岁就从字节退休了,问我们公司要不要投资。😂好吧,别人家的28岁。在合适的时机能做出选择确实很重要啊,不过为了让我们能有更大的选择权,我们暂时还是得先努力提高技术呀,所以咱们今天还是继续啃算法。

今天的题目也是一道动态规划题,它是这样的:
给定两个字符串s1跟s2,返回这两个字符串的最长公共子序列的长度。一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字符的相对顺序的情况下删除某些字符(也可以不删除任何字符)后组成的新字符串。若这两个字符串没有公共子序列,则返回 0。

我们先来看两个例子:

输入: s1 = "abdca"       s2 = "cbda"
输出: 3
解释: 最长公共子序列是"bda"
输入: s1 = "passport"       s2 = "ppsspt"
输出: 5
解释: 最长公共子序列是"psspt"

基本解法

那我们还是从一个基本的暴力递归开始尝试解题。我们可以尝试这两个字符串的所有子序列去找到最长的,一次我们比较一个字符串,那么对于在s1上i位置的字符跟在s2上j位置的字符,我们有两种选择:

  1. 如果s1[i]s2[j]相等,那我们再去递归剩下的元素。

  2. 如果不等,那我们就要跳过s1[i]或者s2[j]去继续递归。

总体思路上还是跟前面我们讨论回文子序列时差不多,我们可以尝试着写出代码了:

public int findLCSLength(String s1, String s2) {
        return findLCSLengthRecursive(s1, s2, 00);
    }

private int findLCSLengthRecursive(String s1, String s2, int i1, int i2) {
        if(i1 == s1.length() || i2 == s2.length())
            return 0;

        if(s1.charAt(i1) == s2.charAt(i2))
            return 1 + findLCSLengthRecursive(s1, s2, i1+1, i2+1);

        int c1 = findLCSLengthRecursive(s1, s2, i1, i2+1);
        int c2 = findLCSLengthRecursive(s1, s2, i1+1, i2);

        return Math.max(c1, c2);
    }

这个时间复杂度还是不得了的,有O(2^(m+n)),m跟n分别是s1跟s2的长度,而空间复杂度也有O(m+n)

那我们还是照例看看能不能用我们熟悉的自上而下进行优化。

自上而下

我们这边还是用一个数组来缓存已经解决过的子问题的结果。观察一下我们的递归函数findLCSLengthRecursive(String s1, String s2, int i1, int i2),就只有那两个索引是变化的,那我们还是要用一个二维数组来保存结果。(其实也可以使用哈希表,用(i1 + “|” + i2)作为key即可)。

我们再来把我们的递归算法用缓存优化:

public int findLCSLength(String s1, String s2) {
        Integer[][] dp = new Integer[s1.length()][s2.length()];
        return findLCSLengthRecursive(dp, s1, s2, 00);
    }

private int findLCSLengthRecursive(Integer[][] dp, String s1, String s2, int i1, int i2) {
        if (i1 == s1.length() || i2 == s2.length())
            return 0;

        if (dp[i1][i2] == null) {
            if (s1.charAt(i1) == s2.charAt(i2))
                dp[i1][i2] = 1 + findLCSLengthRecursive(dp, s1, s2, i1 + 1, i2 + 1);
            else {
                int c1 = findLCSLengthRecursive(dp, s1, s2, i1, i2 + 1);
                int c2 = findLCSLengthRecursive(dp, s1, s2, i1 + 1, i2);
                dp[i1][i2] = Math.max(c1, c2);
            }
        }

        return dp[i1][i2];
    }

这个应该是我们最喜欢的环节了,只要能把基本的递归解法写出来,这个步骤就是锦上添花而已,不要花太多脑力。

自下而上

通常我们也是可以自下而上来思考这道题目的,既然我们想要给定的两个字符串的所有子序列,我们是可以用二维数组来保存结果的。两个字符串的长度将会决定数组的大小。对于在s1中的索引i,跟s2中的索引j,我们有两种选择:

  1. 如果s1[i]s2[j]相等,那最长公共子序列的长度就等于1+在两个字符串中到i-1跟j-1的最长公共子序列的长度。

  2. 如果不等,那我们选择通过跳过s1[i]或者s2[j]获得的公共子序列长度最长的那个。

整体思路还是跟之前讨论最长回文字符串很像的,大家感兴趣的可以看我之前的那篇文章。
在这个情况下,我们的核心逻辑如下:

if s1[i] == s2[j] 
  dp[i][j] = 1 + dp[i-1][j-1
else 
  dp[i][j] = max(dp[i-1][j], dp[i][j-1])

可能这么说还是比较抽象,我们还是来看图。从长度为0的子序列开始,只要有任意一个字符串长度是0,最长公共子序列的长度就是0。

1.png

i:0, j:0-5 and i:0-4, j:0 => dp[i][j] = 0, 因为只要有一个字符串长度是0,最长公共子序列的长度就是0。

一步一步分析最长公共子序列问题
2.jpg

i:1, j:1 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]),因为s1[i] != s2[j]。

一步一步分析最长公共子序列问题
3.jpg

i:1, j:2 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]), 因为s1[i] != s2[j]。

一步一步分析最长公共子序列问题
4.png

i:1, j:3 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]), 因为s1[i] != s2[j]。

一步一步分析最长公共子序列问题
5.png

i:1, j:4 => dp[i][j] = 1 + dp[i-1][j-1], 因为s1[i] == s2[j]。

一步一步分析最长公共子序列问题
6.jpg

i:1, j:5 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]),因为s1[i] != s2[j]。趋势不错,我们已经渐渐找到公共子序列了。

一步一步分析最长公共子序列问题
7.jpg

i:2, j:1 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]), 因为s1[i] != s2[j]。

一步一步分析最长公共子序列问题
8.png

i:2, j:2=> dp[i][j] = 1 + dp[i-1][j-1], 因为s1[i] == s2[j]。

一步一步分析最长公共子序列问题
9.png

i:2, j:3-5 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]), 因为s1[i] != s2[j]。

一步一步分析最长公共子序列问题
10.png

i:31, j:1 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]), 因为s1[i] != s2[j]。

一步一步分析最长公共子序列问题
11.png

i:3, j:2 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]), 因为s1[i] != s2[j]。

一步一步分析最长公共子序列问题
12.png

i:3, j:3=> dp[i][j] = 1 + dp[i-1][j-1], 因为s1[i] == s2[j]。


一步一步分析最长公共子序列问题
13.png

i:3, j:4-5 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]), 因为s1[i] != s2[j]。


一步一步分析最长公共子序列问题
14.png

i:4, j:1=> dp[i][j] = 1 + dp[i-1][j-1], 因为s1[i] == s2[j]。
15.png

i:4, j:2-4 => dp[i][j] = max(dp[i-1][j], dp[i][j-1]), 因为s1[i] != s2[j]。

16.png

i:4, j:5=> dp[i][j] = 1 + dp[i-1][j-1], 因为s1[i] == s2[j]。


我这边怕有些同学不好理解,把过程详细地展示出来了。从这面这张表可以看出,我们的最长公共子序列的长度是3,在dp[4][5]

过程捋清楚了,我们就可以动手来实现看看了:

public int findLCSLength(String s1, String s2) {
        int[][] dp = new int[s1.length()+1][s2.length()+1];
        int maxLength = 0;
        for(int i=1; i <= s1.length(); i++) {
            for(int j=1; j <= s2.length(); j++) {
                if(s1.charAt(i-1) == s2.charAt(j-1))
                    dp[i][j] = 1 + dp[i-1][j-1];
                else
                    dp[i][j] = Math.max(dp[i-1][j], dp[i][j-1]);

                maxLength = Math.max(maxLength, dp[i][j]);
            }
        }
        return maxLength;
    }

这时候时间复杂度跟空间复杂度都是O(m+n),m跟n是两个字符串的长度。

但是仔细观察一下代码哈,其实我们只是需要前面一行的数据来推导出当前的结果,再往前的数据就用不到了。这意味着我们还是有优化的空间的。我们可以进一步这么写:

static int findLCSLength(String s1, String s2) {
        int[][] dp = new int[2][s2.length()+1];
        int maxLength = 0;
        for(int i=1; i <= s1.length(); i++) {
            for(int j=1; j <= s2.length(); j++) {
                if(s1.charAt(i-1) == s2.charAt(j-1))
                    dp[i%2][j] = 1 + dp[(i-1)%2][j-1];
                else
                    dp[i%2][j] = Math.max(dp[(i-1)%2][j], dp[i%2][j-1]);

                maxLength = Math.max(maxLength, dp[i%2][j]);
            }
        }
        return maxLength;
    }

这就算完事儿啦,总体看下来还是跟之前类似,只不过有一些细节在里面大家要能够发现并优化掉。做完这道题又是快乐的一天,Happy coding~