[算法] DP-被3整除的子序列

[算法] DP-被3整除的子序列_第1张图片[算法] DP-被3整除的子序列_第2张图片
原题连接:https://ac.nowcoder.com/acm/skill/detail/acm/1301

这道题挺有意思的,不过值得注意到的一点是这里的子序列并非是连续的子串,而且子序列是有顺序的但不一定要连续,例如示例3中的333(粗体代表第一个3,正常体代表第二个,斜体代表第三个) 所得到的7种子序列为:3、3、333、3333333

用到动态规划,题目可以理解为前n个(包括n)长度的数字串中有多少能被3整除的子序列。我们这么想,假设当前子序列的长度为k,每当子序列的长度+1变成k+1的时候,前k+1的数字串中总共的子序列数量就等于原来子序列数量*2+1。原来子序列的数量*2是因为还要加上在每一个子序列的基础上加上第k+1数的组合序列再加上第k+1单个数组成的序列。这么讲很抽象,简单来一个例子。

例如:
子序列为:123 (len = 3) 子序列为:
【1,2,3,12,23,13,123】 总共7种
现在变为:1234 (len = 4) 子序列为:
【1,2,3,12,23,13,123, 14,24,34,124,234,134,12344】 总共15种 (7*2+1)
****黄色部分是在子序列长度为3的串基础上后加上第4个数字4得来的,最后的4是它本身。

这样子,问题就变得很简单了。我们用数组dp保存mod3为0、1、2子序列个数,(dp[0] 代表序列能被3整除的个数,dp[1]为mod3为1的个数,dp[2]代表mod3为2的个数)
假设当前子序列长度为k,当序列变为k+1时,我们只需要判断新加进来的第k+1这个数是否能被3整除,并更新dp数组就解决了。
假设新加进来的数为n,n%3会有三种情况:

  1. n%3 == 0 这时候原序列中能被三整除的数在后面添加一个能被三整除的数(n)仍然能被三整除,并且n本身就是一个能被三整除的子序列。所以有:
    dp[0] = dp[0] + dp[0] +1; dp[1] = dp[1] + dp[1]; dp[2] = dp[2] + dp[2]

  2. n%3 == 1 这时候原序列中模3为2的数在后面添加模3为1的数(n)就是能被三整除的数,并且n本身是一个模3为1的子序列,所以有:
    所以有:dp[0] = dp[0] + dp[2], dp[1] = dp[1] + dp[0] +1,dp[2] = dp[2] + dp[1]

  3. n%3 == 2 这时候原序列中模3为1的数在后面添加模3为2的数(n)就是能被三整除的数,并且n本身是一个模3为2的子序列,所以有:
    dp[0] = dp[0] + dp[1], dp[1] = dp[1] + dp[2], dp[2] = dp[2] + dp[0] +1

因为dp数组的更新会相互影响,所以代码中用了m0,m1,m2分别记录所增加的模3为0、1、2的个数。

最后,别忘记了在每次更新dp数组的时候要(mod1e9+7), 因为一旦数字串中3、6、9的个数特别多的时候数量就会变得特别大,尤其是在数字串全是由3、6、9组合的时候。

代码如下


import java.util.Scanner;

public class Main{

	public static void main(String[] args) {

		Scanner in = new Scanner(System.in);

		String str = in.next();

		int length = str.length();
		
		//分别存放mod3后余数为0、1、2的个数
		int[] dp = new int[3];
		int m0,m1,m2;
		for (int i = 0; i < length; i++) {
			m0 = m1 = m2 = 0;
			int n = (Integer.valueOf(str.substring(i, i+1)))%3;
			
			if(n == 0) {
				m0+=dp[0]+1;
				m1+=dp[1];
				m2+=dp[2];
			}
			else if(n==1){
				m2+=dp[1];
				m0+=dp[2];
				m1+=dp[0]+1;
			}
			else {
				m0+=dp[1];
				m1+=dp[2];
				m2+=dp[0]+1;
			}
			dp[0]+= m0;
			dp[1]+= m1;
			dp[2]+= m2;
			//每次结果mod1e9+7防止溢出
			for(int k=0;k<3;k++)
				dp[k] = (int)(dp[k]%(1e9+7));
		}
		
		System.out.println(dp[0]);
		
	}
}

你可能感兴趣的:(算法)