PROBLEM LINK:
Setter: Mohamed Anany
Tester: Encho Mishinev
Editorialist: Taranpreet Singh
DIFFICULTY:
Medium-Hard
PREREQUISITES:
KMP String Matching, Meet in the middle and Partial Sums.
PROBLEM:
Given a string S of length N and an array P of M patterns, find the number of permutations p of the patterns that match S. A permutation p matches S if we can choose M ranges [l_i, r_i] (1 \leq i \leq M) such that 1 \leq l_1 \leq r_1 < l_2 \leq r_2 < \ldots < l_M \leq r_M \leq N and, for each valid i, the substring S_{l_i} S_{l_i+1} \ldots S_{r_i} equals P_{p[i]}.
QUICK EXPLANATION
- Using KMP string matching, find all occurrences of every pattern in S. Then reverse S and all patterns and, again with KMP, find all occurrences of the reversed patterns in the reversed string.
- Iterate over all bitmasks of M bits with exactly M/2 bits set. For every permutation of the set bits (the left half), find the minimum position p such that these M/2 patterns fit, in that order and without overlap, in the prefix ending at p. With partial sums over p, we can then count, for any position, how many left permutations fit entirely at or before it.
- For the M-M/2 unset bits (the right half), try all their permutations. For each one, find the maximum position q such that these patterns fit, in that order and without overlap, in the suffix starting at q. Every left permutation ending before q combines with this right permutation, so we add val[q-1] to the answer, where val[x] denotes the number of left permutations ending at or before position x.
EXPLANATION
First, we need the positions where each pattern occurs in S. We can find all occurrences with the KMP algorithm (or any other string matching algorithm) and store the match positions for each pattern. From these we build a table NXT[i][j] holding the first position at or after j where the ith pattern occurs (or N if there is none). Only the first match matters: placing a pattern at its earliest possible occurrence makes it end as early as possible and leaves the most room for the patterns that follow, so a greedy earliest-match placement is optimal.
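A minimal C++ sketch of this preprocessing for a single pattern follows. It assumes 0-indexed strings and non-empty patterns; `prefixFunction` and `buildNextTable` are hypothetical names used for illustration (the setter's solution below splits the same work into `FAIL`, `KMP_Search` and `buildNext`).

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// fail[i] = length of the longest proper prefix of pat[0..i] that is also
// a suffix of it (the KMP failure / prefix function).
std::vector<int> prefixFunction(const std::string& pat) {
    std::vector<int> fail(pat.size(), 0);
    for (int i = 1; i < static_cast<int>(pat.size()); ++i) {
        int k = fail[i - 1];
        while (k > 0 && pat[i] != pat[k]) k = fail[k - 1];
        if (pat[i] == pat[k]) ++k;
        fail[i] = k;
    }
    return fail;
}

// nxt[j] = first position >= j where pat occurs in s, or n if it never does.
// The table has n + 1 entries so that nxt[n] == n can be queried safely.
std::vector<int> buildNextTable(const std::string& s, const std::string& pat) {
    assert(!pat.empty());  // patterns are non-empty in this problem
    const int n = static_cast<int>(s.size());
    const int m = static_cast<int>(pat.size());
    const std::vector<int> fail = prefixFunction(pat);
    std::vector<int> nxt(n + 1, n);
    int k = 0;  // length of the current partial match of pat
    for (int i = 0; i < n; ++i) {
        while (k > 0 && s[i] != pat[k]) k = fail[k - 1];
        if (s[i] == pat[k]) ++k;
        if (k == m) {                   // an occurrence ends at i
            nxt[i - m + 1] = i - m + 1;
            k = fail[k - 1];            // keep going; matches may overlap
        }
    }
    // Suffix minimum turns "a match starts here" into "first match at or after j".
    for (int j = n - 1; j >= 0; --j) nxt[j] = std::min(nxt[j], nxt[j + 1]);
    return nxt;
}
```

For example, in "abababa" the pattern "aba" starts at 0, 2 and 4, so the table is {0, 2, 2, 4, 4, 7, 7, 7}.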
The naive solution is to iterate over all permutations of the patterns and check each one for being matchable by greedily placing its patterns in order. This takes O(M! \cdot N) time (or O(M! \cdot M) with the NXT table), which will definitely time out for M = 14, since 14! \approx 8.7 \cdot 10^{10}.
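To make the naive approach concrete, here is a hedged C++ sketch of the greedy check and the brute-force count. For brevity it uses `std::string::find` instead of an NXT table, so each check costs more than O(M); `matchesInOrder` and `countNaive` are hypothetical helper names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <string>
#include <vector>

// Greedy check of one permutation: place each pattern at its earliest
// occurrence starting at or after the end of the previous one.
bool matchesInOrder(const std::string& s, const std::vector<std::string>& pats,
                    const std::vector<int>& order) {
    std::size_t cur = 0;  // first position not used yet
    for (int idx : order) {
        assert(!pats[idx].empty());  // patterns are non-empty in this problem
        std::size_t pos = s.find(pats[idx], cur);
        if (pos == std::string::npos) return false;
        cur = pos + pats[idx].size();
    }
    return true;
}

// Naive count over all M! orders; only feasible for small M.
long long countNaive(const std::string& s, const std::vector<std::string>& pats) {
    std::vector<int> order(pats.size());
    std::iota(order.begin(), order.end(), 0);  // start from the sorted order
    long long count = 0;
    do {
        if (matchesInOrder(s, pats, order)) ++count;
    } while (std::next_permutation(order.begin(), order.end()));
    return count;
}
```

For instance, with S = "aaa" and patterns {"aa", "a"}, both orders match, so the count is 2. A brute force like this is mainly useful as a reference when stress-testing the faster solution on small inputs.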
Let us use the meet-in-the-middle technique instead.
Split each permutation into two equal (or nearly equal) halves. The first M/2 patterns may form any subset of size M/2, so we try all subsets of size M/2, represented as bitmasks with M/2 bits set. These M/2 patterns may appear in any order, and each order corresponds to different permutations, so we iterate over all permutations of the current subset. We do the same for the remaining M-M/2 patterns. Every permutation of all M patterns arises from exactly one choice of (subset, order of the left half, order of the right half).
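The following sketch (a hypothetical function `countSplits`, not part of the official solutions) enumerates the same (bitmask, left order, right order) triples and counts them. The total equals M!, which confirms that every permutation is produced exactly once.

```cpp
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <vector>

// Counts every (mask, left order, right order) triple the meet-in-the-middle
// enumerates: mask has exactly m/2 set bits, the left half is an ordering of
// the set bits and the right half an ordering of the unset bits.
long long countSplits(int m) {
    assert(0 <= m && m < 31);  // keeps 1 << m within int
    long long total = 0;
    for (int mask = 0; mask < (1 << m); ++mask) {
        if (std::bitset<32>(mask).count() != static_cast<std::size_t>(m / 2)) continue;
        std::vector<int> left, right;
        for (int i = 0; i < m; ++i) {
            if ((mask >> i) & 1) left.push_back(i);
            else right.push_back(i);
        }
        long long leftOrders = 0, rightOrders = 0;
        // Both vectors start sorted, so next_permutation visits every order once.
        do ++leftOrders; while (std::next_permutation(left.begin(), left.end()));
        do ++rightOrders; while (std::next_permutation(right.begin(), right.end()));
        total += leftOrders * rightOrders;
    }
    return total;
}
```

For M = 4 there are \binom{4}{2} \cdot 2! \cdot 2! = 24 = 4! triples.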
Now comes the interesting part.
Suppose we have a permutation of the M/2 patterns whose bits are set in the bitmask (call it the left permutation) and a permutation of the M-M/2 patterns whose bits are unset (call it the right permutation). How can we check whether a given pair of left and right permutations forms a matchable permutation?
Let p be the smallest position in S such that all M/2 left patterns occur in S[0, p], in the order of the left permutation, without overlapping. Similarly, let q be the largest position in S such that all M-M/2 right patterns occur in S[q, N-1], in the order of the right permutation, without overlapping. Both can be found greedily: p by placing each left pattern at its earliest possible occurrence, and q by placing the right patterns from last to first, each at its latest possible occurrence.
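The combination of these two permutations is valid if and only if p < q. If p < q, the greedy placements of the two halves do not overlap, so all M patterns fit in S in the combined order. Conversely, any valid placement ends its left half at some e \geq p and starts its right half at some s \leq q with e < s, so p < q.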
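A small C++ sketch of this check follows, working directly on strings with `std::string::find` and `std::string::rfind` rather than NXT tables. The names `firstEnd`, `lastStart` and `combinable` are hypothetical; positions are 0-indexed and inclusive, as in the definitions of p and q above, and the sentinel values are chosen so that an impossible half makes the check fail.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// p: smallest index such that the `left` patterns occur, in order and without
// overlap, inside S[0..p]. Greedy: earliest occurrence of each pattern.
// Returns -1 for an empty list and n = |S| if the patterns cannot be placed.
int firstEnd(const std::string& s, const std::vector<std::string>& left) {
    std::size_t cur = 0;                          // first position still free
    for (const std::string& pat : left) {
        assert(!pat.empty());                     // non-empty in this problem
        std::size_t pos = s.find(pat, cur);       // earliest start >= cur
        if (pos == std::string::npos) return static_cast<int>(s.size());
        cur = pos + pat.size();
    }
    return static_cast<int>(cur) - 1;
}

// q: largest index such that the `right` patterns occur, in order and without
// overlap, inside S[q..n-1]. Greedy: latest occurrence, last pattern first.
// Returns n for an empty list and -1 if the patterns cannot be placed.
int lastStart(const std::string& s, const std::vector<std::string>& right) {
    std::size_t bound = s.size();                 // next pattern must end before bound
    for (auto it = right.rbegin(); it != right.rend(); ++it) {
        if (it->size() > bound) return -1;
        std::size_t pos = s.rfind(*it, bound - it->size());  // latest start that fits
        if (pos == std::string::npos) return -1;
        bound = pos;
    }
    return static_cast<int>(bound);
}

// The pair is valid iff p < q; the sentinels make impossible halves return false.
bool combinable(const std::string& s, const std::vector<std::string>& left,
                const std::vector<std::string>& right) {
    return firstEnd(s, left) < lastStart(s, right);
}
```

On S = "abcab", the left order ("ab", "c") gives p = 2 and the right order ("ab") gives q = 3, so the pair is valid; the left order ("c", "ab") gives p = 4 and is not.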
However, trying every pair of left and right permutations for every bitmask is essentially the same as iterating over all permutations, since we would still check each of the M! permutations individually.
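The key observation is monotonicity: if a right permutation with start q is compatible with a left permutation ending at p, it is also compatible with every left permutation ending at any position j \leq p. So, for a fixed right permutation, we only need to know how many left permutations end before q.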
Hence, for every bitmask with M/2 bits set, we iterate over all permutations of the M/2 left patterns and, using prefix sums, compute for every position p the number of left permutations that end at or before p. Then, for each right permutation with start q, the number of left permutations it can be paired with is the prefix-sum value at q-1. We add this count to the answer for every right permutation of every bitmask and finally print the answer.
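The counting step for one bitmask can be sketched as below. The hypothetical function `countCompatiblePairs` takes the p values of all left permutations and the q values of all right permutations (using the sentinels n and -1 for impossible halves) and counts the pairs with p < q in linear time. The array is shifted by one so that p = -1, an empty left half, has a slot.

```cpp
#include <cassert>
#include <vector>

// leftEnds: p of every left permutation (inclusive end; -1 = empty half, n = impossible).
// rightStarts: q of every right permutation (n = empty half, -1 = impossible).
// Returns the number of pairs with p < q in O(n + |leftEnds| + |rightStarts|).
long long countCompatiblePairs(int n, const std::vector<int>& leftEnds,
                               const std::vector<int>& rightStarts) {
    // cnt[p + 1] = number of left permutations ending exactly at p.
    std::vector<long long> cnt(n + 2, 0);
    for (int p : leftEnds) {
        assert(-1 <= p && p <= n);
        if (p < n) ++cnt[p + 1];
    }
    // After the prefix sums, cnt[i] = number of left permutations with p <= i - 1.
    for (int i = 1; i <= n + 1; ++i) cnt[i] += cnt[i - 1];
    long long pairs = 0;
    for (int q : rightStarts) {
        assert(-1 <= q && q <= n);
        if (q >= 0) pairs += cnt[q];  // left halves with p <= q - 1, i.e. p < q
    }
    return pairs;
}
```

Summing this count over all bitmasks gives the answer. The solutions below fold the same idea into their `calc` arrays, indexing by the exclusive end position instead of shifting.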
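For ease of implementation, we build the NXT table a second time on reversed data: reverse S and all patterns, and find the matches of the reversed patterns in the reversed string. Greedily placing the right patterns on the reversed string, using this reversed NXT table exactly as we use the original one, gives the length c of the shortest suffix of S that contains them without overlap, so q = N - c. Placed this way, the right permutation is visited in reverse order, but since all orders are enumerated anyway, every right permutation is still counted exactly once.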
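Here is a hedged sketch of the reversal trick. `shortestPrefix` and `lastStartViaReverse` are hypothetical helpers, and `std::string::find` stands in for the reversed NXT table. Unlike the official solutions, it reverses the order of the right patterns explicitly, so the returned q refers to the order it was given.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Earliest-match greedy: length of the shortest prefix of t containing `pats`
// in the given order without overlap, or -1 if impossible.
int shortestPrefix(const std::string& t, const std::vector<std::string>& pats) {
    std::size_t cur = 0;
    for (const std::string& pat : pats) {
        assert(!pat.empty());  // patterns are non-empty in this problem
        std::size_t pos = t.find(pat, cur);
        if (pos == std::string::npos) return -1;
        cur = pos + pat.size();
    }
    return static_cast<int>(cur);
}

// q for a right permutation computed on reversed data: reverse S, reverse each
// pattern and take the patterns last-to-first. The shortest prefix of the
// reversed string is the shortest suffix of S, so q = N - length.
int lastStartViaReverse(const std::string& s, const std::vector<std::string>& right) {
    std::string rs(s.rbegin(), s.rend());
    std::vector<std::string> rpats;
    for (auto it = right.rbegin(); it != right.rend(); ++it)
        rpats.emplace_back(it->rbegin(), it->rend());
    int len = shortestPrefix(rs, rpats);
    return len < 0 ? -1 : static_cast<int>(s.size()) - len;
}
```

On S = "abcab" with right patterns ("a", "b"), the reversed string is "bacba", the shortest matching prefix has length 2, and q = 5 - 2 = 3.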
Time Complexity
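Time complexity is O\left(\binom{M}{M/2} \cdot \left(N + (M/2)! \cdot M\right) + M \cdot N + \sum_i |P_i|\right). For each of the \binom{M}{M/2} bitmasks we reset and prefix-sum an array of size N and enumerate about (M/2)! permutations per half, each checked in O(M/2) steps; the M \cdot N + \sum_i |P_i| term covers running KMP and building both NXT tables.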
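Plugging in numbers gives a feel for the gain. Assuming M = 14 (the value mentioned above) and N = 10^5 (roughly the array sizes used in the solutions below; the exact limit is an assumption here):

```latex
\begin{aligned}
\binom{14}{7} &= 3432, \qquad 7! = 5040, \\
N + 7! \cdot 14 &= 10^5 + 70\,560 = 170\,560, \\
3432 \cdot 170\,560 &= 585\,361\,920 \approx 5.9 \cdot 10^8 .
\end{aligned}
```

That is far below the roughly 8.7 \cdot 10^{10} permutations the naive approach has to examine, before even counting the cost of checking each one.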
AUTHOR’S AND TESTER’S SOLUTIONS:
Setter’s solution
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
using ll = long long;
using ii = pair<int, int>;
const int N = 1e5 + 5, K = 14;
string S, P[K];
int n, k;
int nxt[K][N], rnxt[K][N];
int dp[1005][1 << K];
vector<int> FAIL(string pat) {
    int m = pat.size();
    vector<int> F(m + 1);
    int i = 0, j = -1;
    F[0] = -1;
    while (i < m) {
        while (j >= 0 && pat[i] != pat[j])
            j = F[j];
        i++, j++;
        F[i] = j;
    }
    return F;
}
vector<int> KMP_Search(string txt, string pat) {
    vector<int> F = FAIL(pat);
    int i = 0, j = 0;
    int n = txt.size(), m = pat.size();
    vector<int> ret;
    while (i < n) {
        while (j >= 0 && txt[i] != pat[j])
            j = F[j];
        i++, j++;
        if (j == m) {
            ret.pb(i - j);
            j = F[j];
        }
    }
    return ret;
}
vector<vector<int>> Match, rMatch;
void buildNext(int * arr, vector<int> matches) {
    //first matching >= i
    for (int i = 0; i <= n; i++)
        arr[i] = n;
    for (auto x : matches)
        arr[x] = x;
    for (int i = n - 2; i >= 0; --i)
        arr[i] = min(arr[i], arr[i + 1]);
}
int calc[N];
int solve(int idx, int mask) {
    if (idx > n)
        return 0;
    if (mask == (1 << k) - 1)
        return 1;
    int &ret = dp[idx][mask];
    if (~ret)
        return ret;
    ret = 0;
    for (int j = 0; j < k; j++)
        if (mask >> j & 1 ^ 1) {
            ret += solve(nxt[j][idx] + P[j].size(), mask | (1 << j));
        }
    return ret;
}
void Stress1() {
    ///solution for subtask 2
    ///first brute force dp[index][mask];
    ///use next[index]
    ///can solve subtasks 1-2
    memset(dp, -1, sizeof dp);
    cerr << solve(0, 0) << '\n';
}
int getFirst(vector<int> & x) {
    if (x.empty())
        return n;
    return x[0];
}
int getLast(vector<int> & x) {
    if (x.empty())
        return -1;
    return x.back();
}
void Stress2() {
    ///Solution for subtask 1
    if (k == 1) {
        cerr << !KMP_Search(S, P[0]).empty() << '\n';
    } else if (k == 2) {
        vector<int> x1 = KMP_Search(S, P[0]);
        vector<int> x2 = KMP_Search(S, P[1]);
        int ans = 0;
        if (getFirst(x1) + (int) P[0].size() <= getLast(x2))
            ans++;
        if (getFirst(x2) + (int) P[1].size() <= getLast(x1))
            ans++;
        cerr << ans << '\n';
    } else {
        vector<int> x[3];
        x[0] = KMP_Search(S, P[0]);
        x[1] = KMP_Search(S, P[1]);
        x[2] = KMP_Search(S, P[2]);
        int ans = 0;
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                if (i == j)
                    continue;
                for (auto v : x[3 - i - j]) {
                    // cout << i << ' ' << getFirst(x[i]) << ' ' << j << ' ' << getLast(x[j]) << ' ' << v << '\n';
                    if (getFirst(x[i]) + (int) P[i].size() <= v
                            && v + (int) P[3 - i - j].size() <= getLast(x[j])) {
                        ans++;
                        break;
                    }
                }
            }
        }
        cerr << ans << '\n';
    }
}
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> k;
    cin >> S;
    for (int i = 0; i < k; i++)
        cin >> P[i];
    ///get matchings and build next array for them
    ///where nxt[j][i] = next matching position for the j-th pattern which is more than or equal i
    for (int i = 0; i < k; i++) {
        Match.pb(KMP_Search(S, P[i]));
        buildNext(nxt[i], Match.back());
    }
    reverse(S.begin(), S.end());
    ///do the same as above for reverse strings and patterns
    for (int i = 0; i < k; i++) {
        reverse(P[i].begin(), P[i].end());
        rMatch.pb(KMP_Search(S, P[i]));
        buildNext(rnxt[i], rMatch.back());
        reverse(P[i].begin(), P[i].end());
    }
    reverse(S.begin(), S.end());
    ll ans = 0;
    for (int mask = 0; mask < (1 << k); mask++)
        if (__builtin_popcount(mask) == k / 2) {
            ///process the normal
            ///get the indexes
            ///brute force on all permutations and find the minimum suffix
            ///needed to get all of these matched for each permutation
            ///then use partial sum to pre-process the results
            vector<int> v;
            for (int i = 0; i < k; i++)
                if (mask >> i & 1) {
                    v.pb(i);
                }
            memset(calc, 0, sizeof calc);
            do {
                int cur = 0;
                for (auto x : v) {
                    if (rnxt[x][cur] == n)
                        goto fin1;
                    cur = rnxt[x][cur] + P[x].size();
                }
                ///i have the last cur digits covered
                calc[cur]++;
            fin1: ;
            } while (next_permutation(v.begin(), v.end()));
            ///partial sum
            for (int i = 1; i <= n; i++)
                calc[i] += calc[i - 1];
            ///solve the flip
            v.clear();
            int flip = ((1 << k) - 1) ^ mask;
            for (int i = 0; i < k; i++)
                if (flip >> i & 1) {
                    v.pb(i);
                }
            ///get the indexes brute force on them find the minimum
            ///prefix to cover these patterns find all suffixes in the previous calculation
            do {
                int cur = 0;
                for (auto x : v) {
                    if (nxt[x][cur] == n)
                        goto fin2;
                    cur = nxt[x][cur] + P[x].size();
                }
                ans += calc[n - cur];
            fin2: ;
            } while (next_permutation(v.begin(), v.end()));
        }
    cout << ans << '\n';
    return 0;
}
Tester’s solution
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <map>
using namespace std;
typedef long long llong;
int n,m;
char s[100111];
char pattern[100111];
int L = 0;
int pLens[15];
int isMatch[100111];
int mKey = 1;
int nextMatch[15][100111];
///Knuth-Morris-Pratt
int F[100111];
void findMatches()
{
    int i;
    int k;
    //Failure function
    F[1] = 0;
    for (i=2;i<=L;i++)
    {
        k = F[i-1];
        while(k != 0 && pattern[k+1] != pattern[i])
            k = F[k];
        if (k == 0)
        {
            if (pattern[1] == pattern[i])
                F[i] = 1;
            else
                F[i] = 0;
        }
        else
            F[i] = k+1;
    }
    //Matching
    mKey++;
    k = 0;
    for (i=1;i<=n;i++)
    {
        while(k != 0 && pattern[k+1] != s[i])
            k = F[k];
        if (k == 0)
        {
            if (pattern[1] == s[i])
                k = 1;
            else
                k = 0;
        }
        else
            k++;
        if (k == L)
        {
            isMatch[i - L + 1] = mKey;
            k = F[k];
        }
    }
}
map< pair<int,int>, llong > mem;
map< pair<int,int>, llong >::iterator myit;
llong solve(int ind, int mask)
{
    if (mask == ((1<<m)-1))
        return 1LL;
    myit = mem.find(make_pair(ind,mask));
    if (myit != mem.end())
        return (*myit).second;
    int i;
    llong ans = 0;
    for (i=1;i<=m;i++)
    {
        if (nextMatch[i][ind] > n)
            continue;
        if ( (mask&(1<<(i-1))) == 0 )
        {
            ans += solve(nextMatch[i][ind] + pLens[i], mask | (1<<(i-1)));
        }
    }
    mem.insert(make_pair(make_pair(ind,mask),ans));
    return ans;
}
int main()
{
    int i,j;
    scanf("%d %d",&n,&m);
    scanf("%s",s+1);
    for (i=1;i<=m;i++)
    {
        scanf("%s",pattern+1);
        L = strlen(pattern+1);
        findMatches();
        nextMatch[i][n+1] = n+1;
        for (j=n;j>=1;j--)
        {
            if (isMatch[j] == mKey)
                nextMatch[i][j] = j;
            else
                nextMatch[i][j] = nextMatch[i][j+1];
        }
        pLens[i] = L;
    }
    printf("%lld\n",solve(1, 0));
    return 0;
}
Editorialist’s solution
import java.util.*;
import java.io.*;
import java.text.*;
//Solution Credits: Taranpreet Singh
public class Main{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    long[] calc;
    void solve(int TC) throws Exception{
        int n = ni(), m = ni();
        String sf = n(), sr = new StringBuilder(sf).reverse().toString();
        int[][] nxtf = new int[m][n], nxtr = new int[m][n];int[] len = new int[m];
        calc = new long[n+1];
        for(int i = 0; i< m; i++){
            String x = n();len[i] = x.length();
            KMP(nxtf[i], sf, x);
            KMP(nxtr[i], sr, new StringBuilder(x).reverse().toString());
        }
        for(int mask = 0; mask< 1<<m; mask++){
            if(bit(mask)!=m/2)continue;
            int[] a = new int[bit(mask)];
            for(int i = 0, j = 0; i< m; i++)if(((mask>>i)&1)==1)a[j++] = i;
            Arrays.fill(calc, 0);
            permute(a, 0, nxtf,len, false);
            for(int i = 1; i< calc.length; i++)calc[i]+=calc[i-1];
            int[] b = new int[m-a.length];
            for(int i = 0, j = 0; i< m; i++)if(((mask>>i)&1)==0)b[j++] = i;
            permute(b, 0, nxtr,len, true);
        }
        pn(ans);
    }
    long ans = 0;
    void permute(int[] a, int pos,int[][] nxt,int[] len, boolean flag){
        if(pos==a.length){
            int cur = 0, n = nxt[0].length;
            for(int i = 0; i< a.length; i++){
                if(cur==n || nxt[a[i]][cur]==n)return;
                cur = nxt[a[i]][cur]+len[a[i]];
            }
            if(flag)ans+=calc[n-cur];
            else calc[cur]++;
        }else{
            for(int i = pos; i< a.length; i++){
                int tmp = a[i];
                a[i] = a[pos];
                a[pos] = tmp;
                permute(a,pos+1,nxt,len, flag);
                tmp = a[i];
                a[i] = a[pos];
                a[pos] = tmp;
            }
        }
    }
    void KMP(int[] nxt, String txt, String pat){
        int M = pat.length();
        int N = txt.length();
        Arrays.fill(nxt,N);
        int lps[] = new int[M];
        int i = 0,j = 0;
        computeLPSArray(pat, M, lps);
        while (i < N) {
            if (pat.charAt(j) == txt.charAt(i)) {
                j++;i++;
            }
            if (j == M) {
                nxt[i-j] = i-j;
                j = lps[j - 1];
            }else if (i < N && pat.charAt(j) != txt.charAt(i)) {
                if (j != 0)j = lps[j - 1];
                else i++;
            }
        }
        for(i = N-2; i>= 0; i--)nxt[i] = Math.min(nxt[i], nxt[i+1]);
    }
    void computeLPSArray(String pat, int M, int lps[]){
        int len = 0;
        int i = 1;
        lps[0] = 0;
        while (i < M) {
            if (pat.charAt(i) == pat.charAt(len)) {
                len++;
                lps[i] = len;
                i++;
            }else{
                if(len != 0){
                    len = lps[len - 1];
                }else{
                    lps[i] = len;
                    i++;
                }
            }
        }
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    long mod = (long)1e9+7, IINF = (long)1e18;
    final int INF = (int)1e9, MX = (int)2e3+1;
    DecimalFormat df = new DecimalFormat("0.00000000000");
    double PI = 3.1415926535897932384626433832792884197169399375105820974944, eps = 1e-8;
    static boolean multipleTC = false, memory = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        int T = (multipleTC)?ni():1;
        //Solution Credits: Taranpreet Singh
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        if(memory)new Thread(null, new Runnable() {public void run(){try{new Main().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
        else new Main().run();
    }
    long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
    int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
Feel free to share your approach if it differs. Suggestions are always welcome.