Linear Recurrence Relation & Binary Exponentiation & DP Counting

Summary

In this post, I will introduce a technique that appears in some DP counting problems: binary exponentiation of matrices.

Typically, in these problems, you are given a number n and asked to return the number of valid combinations. You need to figure out the initial state and the transform function between states; the problem can then be solved by DP.

If, during the iteration, the current state depends linearly on the previous states (a linear recurrence relation), there is a bonus: the transition can be written as a square matrix, and binary exponentiation of that matrix reduces the number of steps from O(n) to O(log n). With k states, each step is a k × k matrix multiplication, so the total cost is O(k^3 log n).
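As a sketch of why this works, assume the k states at step i are collected into a column vector v_i and the recurrence has the form v_i = A v_{i+1} for a fixed k × k matrix A (v_i, A and k are notation introduced here, not from a specific problem). Unrolling the recurrence and splitting the exponent by parity gives:

$$
\mathbf{v}_i = A\,\mathbf{v}_{i+1}
\quad\Longrightarrow\quad
\mathbf{v}_0 = A\,\mathbf{v}_1 = A^{2}\,\mathbf{v}_2 = \cdots = A^{n-1}\,\mathbf{v}_{n-1},
\qquad
A^{m} =
\begin{cases}
I, & m = 0,\\
\left(A^{2}\right)^{m/2}, & m \text{ even},\\
A \cdot A^{m-1}, & m \text{ odd}.
\end{cases}
$$

An odd step always leads to an even one, so the exponent is at least halved every two steps and A^{n-1} needs O(log n) matrix multiplications.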

OJ

Problem | Trick | Difficulty
1411. Number of Ways to Paint N × 3 Grid | DP in Linear Structures + Binary Exponentiation | 6 points
935. Knight Dialer | DP in Linear Structures + Binary Exponentiation | 5-6 points

Details

LC 1411. Number of Ways to Paint N × 3 Grid

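A naive O(12 × 12 × 3 × n) solution is not difficult. There are 12 valid ways to paint a single row (3 × 2 × 2), so there are 12n states. For each state, we try all 12 options for the current row and spend O(3) time checking whether each one conflicts with the previous row, i.e. O(12 × 3) per state.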

alg[i][prev] = sum {alg[i + 1][curr] if curr and prev do not conflict, for curr in the 12 options}
// base case: alg[n - 1][prev] = sum {1 if ok(curr, prev), for curr in the 12 options}
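A minimal C++ sketch of this naive DP, assuming colors are numbered 1 to 3 and the answer is taken modulo 10^9 + 7; validRows, ok, and numOfWaysNaive are hypothetical helpers. It keeps only one layer of alg at a time, and it starts from i = n (no rows left, so every entry is 1) instead of the base case at n - 1, which lets n = 1 work without a special case:

```cpp
#include <array>
#include <vector>

using Row = std::array<int, 3>;
const long long MOD = 1000000007LL;

// All 12 ways to paint one row of 3 cells with colors 1..3 so neighbors differ.
std::vector<Row> validRows() {
    std::vector<Row> rows;
    for (int a = 1; a <= 3; ++a)
        for (int b = 1; b <= 3; ++b)
            for (int c = 1; c <= 3; ++c)
                if (a != b && b != c) rows.push_back({a, b, c});
    return rows;
}

// ok(curr, prev): no column has the same color in both rows (O(3) check).
bool ok(const Row& curr, const Row& prev) {
    for (int j = 0; j < 3; ++j)
        if (curr[j] == prev[j]) return false;
    return true;
}

long long numOfWaysNaive(int n) {
    std::vector<Row> rows = validRows();
    int m = static_cast<int>(rows.size());  // 12
    // alg[prev]: ways to paint rows i..n-1 given that row i - 1 is rows[prev].
    // At i = n no rows remain, so every entry starts at 1.
    std::vector<long long> alg(m, 1);
    for (int i = n - 1; i >= 1; --i) {
        std::vector<long long> next(m, 0);
        for (int prev = 0; prev < m; ++prev)
            for (int curr = 0; curr < m; ++curr)
                if (ok(rows[curr], rows[prev]))
                    next[prev] = (next[prev] + alg[curr]) % MOD;
        alg = next;
    }
    // Row 0 has nothing above it: sum alg[1][r] over its 12 choices.
    long long total = 0;
    for (int r = 0; r < m; ++r) total = (total + alg[r]) % MOD;
    return total;
}
```

For example, numOfWaysNaive(2) returns 54 and numOfWaysNaive(3) returns 246.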

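A better DP decreases the number of states. Every valid row is either of pattern 121 (the first and third cells share a color) or of pattern 123 (all three colors differ), and rows of the same pattern have the same numbers of valid successors of each pattern. So at each i, only 2 states are needed.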

pattern 121: 121, 131, 212, 232, 313, 323.
pattern 123: 123, 132, 213, 231, 312, 321.

Next, consider which rows can follow each pattern, using 121 and 123 as representatives (a cell may not have the same color as the cell directly above it).

Row 121 can be followed by: 212, 213, 232, 312, 313;
that is, 3 121s and 2 123s.

Row 123 can be followed by: 212, 231, 312, 232;
that is, 2 121s and 2 123s.
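The counts above are checked by hand for the representatives 121 and 123. A small brute-force sketch (hypothetical helpers patternOf and transitionCounts, colors numbered 1 to 3) can confirm that every row of the same pattern has the same counts, which is what makes the 2-state reduction valid:

```cpp
#include <array>
#include <vector>

using Row = std::array<int, 3>;

// Pattern index: 0 for "121" (first and last cells equal), 1 for "123" (all distinct).
int patternOf(const Row& r) { return r[0] == r[2] ? 0 : 1; }

// trans[s][t]: how many rows of pattern t can be placed next to a row of pattern s.
// `consistent` is set to false if two rows of the same pattern disagree.
std::array<std::array<int, 2>, 2> transitionCounts(bool& consistent) {
    std::vector<Row> rows;
    for (int a = 1; a <= 3; ++a)
        for (int b = 1; b <= 3; ++b)
            for (int c = 1; c <= 3; ++c)
                if (a != b && b != c) rows.push_back({a, b, c});

    std::array<std::array<int, 2>, 2> trans{};  // zero-initialized
    std::array<bool, 2> seen{false, false};
    consistent = true;
    for (const Row& prev : rows) {
        std::array<int, 2> cnt{0, 0};
        for (const Row& curr : rows) {
            // Compatible if no column repeats a color.
            bool ok = curr[0] != prev[0] && curr[1] != prev[1] && curr[2] != prev[2];
            if (ok) ++cnt[patternOf(curr)];
        }
        int s = patternOf(prev);
        if (!seen[s]) {
            trans[s] = cnt;  // first row of this pattern sets the reference counts
            seen[s] = true;
        } else if (trans[s] != cnt) {
            consistent = false;
        }
    }
    return trans;
}
```

Run over all 12 rows, the counts come out as {{3, 2}, {2, 2}} with every row agreeing, matching the hand-derived numbers.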

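Let alg[i][0] and alg[i][1] be the number of ways to paint rows i to n - 1 when row i is of pattern 121 or 123, respectively. This gives the following transition equations.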

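// alg[i][0]: pattern 121, alg[i][1]: pattern 123
alg[i][0] = alg[i + 1][0] * 3 + alg[i + 1][1] * 2
alg[i][1] = alg[i + 1][0] * 2 + alg[i + 1][1] * 2

$$
\begin{bmatrix} \mathrm{alg}[i][0] \\ \mathrm{alg}[i][1] \end{bmatrix}
=
\begin{bmatrix} 3 & 2 \\ 2 & 2 \end{bmatrix}
\begin{bmatrix} \mathrm{alg}[i+1][0] \\ \mathrm{alg}[i+1][1] \end{bmatrix},
\qquad
\begin{bmatrix} \mathrm{alg}[n-1][0] \\ \mathrm{alg}[n-1][1] \end{bmatrix}
=
\begin{bmatrix} 6 \\ 6 \end{bmatrix}
$$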
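Before switching to matrix powers, the two-state recurrence can already be iterated directly in O(n) time. A minimal sketch, assuming n ≥ 1 and the modulus 10^9 + 7 from the problem statement; numOfWaysLinear is a hypothetical name, and the two variables hold alg[i][0] and alg[i][1] as i runs from n - 1 down to 0:

```cpp
// O(n) two-state DP for LC 1411, iterating alg from i = n - 1 down to 0.
long long numOfWaysLinear(int n) {
    const long long MOD = 1000000007LL;
    long long aba = 6, abc = 6;  // alg[n - 1][0] (pattern 121), alg[n - 1][1] (pattern 123)
    for (int i = n - 2; i >= 0; --i) {
        long long nextAba = (3 * aba + 2 * abc) % MOD;  // alg[i][0]
        long long nextAbc = (2 * aba + 2 * abc) % MOD;  // alg[i][1]
        aba = nextAba;
        abc = nextAbc;
    }
    return (aba + abc) % MOD;  // alg[0][0] + alg[0][1]
}
```

For instance, numOfWaysLinear(3) returns 246 (alg[0] = [138, 108]).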

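Once the transition is a constant square matrix A = [[3, 2], [2, 2]], we have alg[0] = A^(n - 1) · alg[n - 1], and binary exponentiation of matrices computes A^(n - 1) with O(log n) 2 × 2 matrix multiplications instead of n - 1 sequential steps.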

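#include <vector>
using namespace std;

using ll = long long;
const ll MOD = 1000000007;  // LC 1411 asks for the answer modulo 10^9 + 7

// The final result is the sum of the two entries of
// fastPow(a, n - 1) * [6, 6], where a = {{3, 2}, {2, 2}}.

// Multiplies two square matrices modulo MOD.
vector<vector<ll>> modMul(vector<vector<ll>> const &a, vector<vector<ll>> const &b) {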
    int n = a.size();
    vector<vector<ll>> result(n, vector<ll>(n));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            for (int k = 0; k < n; ++k) {
                result[i][j] += a[i][k] * b[k][j];
                result[i][j] %= MOD;
            }
        }
    }
    return result;
}

// Computes a^n modulo MOD using O(log n) matrix multiplications.
vector<vector<ll>> fastPow(vector<vector<ll>> const &a, int n) {
    if (n == 0) {
        // a^0 is the identity matrix
        int size = a.size();
        vector<vector<ll>> result(size, vector<ll>(size));
        for (int i = 0; i < size; ++i) {
            result[i][i] = 1;
        }
        return result;
    } else if (n == 1) {
        return a;
    } else {
        if ((n & 1) == 0) {
            // even n: a^n = (a^2)^(n / 2)
            vector<vector<ll>> a_square = modMul(a, a);
            return fastPow(a_square, n / 2);
        } else {
            // odd n: a^n = a * a^(n - 1)
            return modMul(a, fastPow(a, n - 1));
        }
    }
}
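The recursive fastPow above allocates new vectors at every level; an iterative variant over fixed-size 2 × 2 matrices is a common alternative. A minimal sketch, assuming n ≥ 1 as in the problem; Mat2, mul, matPow, and numOfWays are hypothetical names, not part of the code above:

```cpp
#include <array>

using Mat2 = std::array<std::array<long long, 2>, 2>;
const long long MOD = 1000000007LL;

// 2x2 matrix product modulo MOD (entries stay below MOD, so products fit in long long).
Mat2 mul(const Mat2& a, const Mat2& b) {
    Mat2 r{};
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 2; ++k)
                r[i][j] = (r[i][j] + a[i][k] * b[k][j]) % MOD;
    return r;
}

// Iterative binary exponentiation: scan the bits of e from low to high and
// multiply the current power a^(2^j) into the result when bit j is set.
Mat2 matPow(Mat2 a, long long e) {
    Mat2 result{{{1, 0}, {0, 1}}};  // identity
    while (e > 0) {
        if (e & 1) result = mul(result, a);
        a = mul(a, a);
        e >>= 1;
    }
    return result;
}

// alg[0] = A^(n - 1) * [6, 6]^T with A = [[3, 2], [2, 2]]; the answer sums its entries.
long long numOfWays(int n) {
    Mat2 p = matPow(Mat2{{{3, 2}, {2, 2}}}, n - 1);
    long long alg0 = (p[0][0] * 6 + p[0][1] * 6) % MOD;  // alg[0][0]
    long long alg1 = (p[1][0] * 6 + p[1][1] * 6) % MOD;  // alg[0][1]
    return (alg0 + alg1) % MOD;
}
```

The bit-scanning loop performs the same O(log n) multiplications as the recursion but needs no call stack, and std::array avoids heap allocations; the vector-based version generalizes more easily when the number of states is not known at compile time.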
