题目背景:
10.20 NOIP模拟T2 bzoj1084
分析:DP
听说这貌似是一道原题,听说我貌似8天前做过这道题,听说这道题的复杂度是O(nk),听说我8天前写的O(nk),然后今天写出了O(n3k)···(手动再见
首先我们来看这道题,m <= 2,这就是在告诉我们要在上面做文章。
首先考虑m = 1的情况,定义数组f[i][j]表示当前枚举到第i行,已经选择了j个矩阵,然后从前面0 ~ i - 1直接转移过来
f[i][j] = f[i - 1][j]
f[i][j] = max(f[l][j - 1] + sum[i] - sum[l])
(l >= 0 && l < i, sum为矩阵权值的前缀和)
这样转移的复杂度为O(n2k)
Source:
inline void read_in() {R(n), R(m), R(k);for (int i = 1; i <= n; ++i)for (int j = 0; j < m; ++j)R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}inline void solve_m_1() {static int f[MAXN][12];memset(f, 128, sizeof(f));for (int i = 0; i <= n; ++i) f[i][0] = 0;for (int c = 1; c <= k; ++c)for (int i = 1; i <= n; ++i) {f[i][c] = f[i - 1][c];for (int l = 0; l < i; ++l)f[i][c] = std::max(f[i][c], sum[0][i] - sum[0][l] + f[l][c - 1]);}std::cout << f[n][k];
}
再来讲第二种方法,定义数组f[i][j][0/1]表示,当前枚举到第i行,已经选择了j个矩阵,当前的位置是否被选择了。(如果上一个被选择了,可以考虑直接将这一个接在上一个后面,矩阵数量不增加)
f[i][j][0] = max(f[i - 1][j][1], f[i - 1][j][0])
f[i][j][1] = max(f[i - 1][j - 1][0], f[i - 1][j - 1][1], f[i -1][j][1]) + a[i];
解释:f[i - 1][j -1][0] à 上一个位置没有被选择
f[i - 1][j - 1][1] à 上一个位置被选择了,这一个位置重新开始一个新的矩阵
f[i - 1][j][1] à 上一个位置被选择了,这个位置接在上一个所在的矩阵上面
这样做的复杂度是O(nk)的
Source:
inline void read_in() {R(n), R(m), R(k);for (int i = 1; i <= n; ++i)for (int j = 0; j < m; ++j)R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}inline void solve_m_1() {static int f[MAXN][MAXK][2]; memset(f, 128, sizeof(f)), f[0][0][0] = 0;for (int i = 1; i <= n; ++i) for (int j = 0; j <= k; ++j) {f[i][j][0] = std::max(f[i - 1][j][0], f[i - 1][j][1]);if (j > 0) f[i][j][1] = std::max(f[i - 1][j - 1][0], f[i - 1][j - 1][1]) + a[0][i];f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][1] + a[0][i]);}std::cout << std::max(f[n][k][0], f[n][k][1]);
}
我们继续讲下一种,首先因为它只有一行,我们就相当于选择k个互不相交的子段,这是一种费用流的思想,然后我们再做分析发现,它的性质很特殊,我们可以直接利用线段树来维护,每一次选择一段区间相当于把这段区间里面的数据全部取反,这样的复杂度是O(klogn)
Source:
/*created by scarlyw
*/
// 注:本代码为bzoj3502的代码,原题当中是选择最多k个子区间,
// 而本题则是要求至少k个,只需要把代码中写有注释的地方更改一下即可
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>inline char read() {static const int IN_LEN = 1024 * 1024;static char buf[IN_LEN], *s, *t;if (s == t) {t = (s = buf) + fread(buf, 1, IN_LEN, stdin);if (s == t) return -1;}return *s++;
}///*
template<class T>
inline void R(T &x) {static char c;static bool iosig;for (c = read(), iosig = false; !isdigit(c); c = read()) {if (c == -1) return ;if (c == '-') iosig = true; }for (x = 0; isdigit(c); c = read()) x = ((x << 2) + x << 1) + (c ^ '0');if (iosig) x = -x;
}
//*/const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;
inline void write_char(char c) {if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;*oh++ = c;
}template<class T>
inline void W(T x) {static int buf[30], cnt;if (x == 0) write_char('0');else {if (x < 0) write_char('-'), x = -x;for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;while (cnt) write_char(buf[cnt--]);}
}inline void flush() {fwrite(obuf, 1, oh - obuf, stdout);
}/*
template<class T>
inline void R(T &x) {static char c;static bool iosig;for (c = getchar(), iosig = false; !isdigit(c); c = getchar())if (c == '-') iosig = true; for (x = 0; isdigit(c); c = getchar()) x = ((x << 2) + x << 1) + (c ^ '0');if (iosig) x = -x;
}
//*/const int MAXN = 1000000 + 10;struct data {int lp, rp, p1, p2;long long lx, mx, rx, sum;inline void init(int l, long long x) {lp = rp = p1 = p2 = l, lx = mx = rx = sum = x;}friend inline data operator + (const data &a, const data &b) {data t;t.sum = a.sum + b.sum;t.lp = a.lp, t.lx = a.lx;if (t.lx < a.sum + b.lx) t.lx = a.sum + b.lx, t.lp = b.lp;t.rp = b.rp, t.rx = b.rx;if (t.rx < b.sum + a.rx) t.rx = b.sum + a.rx, t.rp = a.rp;t.mx = a.rx + b.lx, t.p1 = a.rp, t.p2 = b.lp;if (a.mx > t.mx) t.mx = a.mx, t.p1 = a.p1, t.p2 = a.p2;if (b.mx > t.mx) t.mx = b.mx, t.p1 = b.p1, t.p2 = b.p2;return t;}
} ;struct node {data min, max;bool flag;inline void init(int l, long long val) {min.init(l, -val), max.init(l, val);}
} tree[(MAXN << 1) + 200000];
int n, k;
int a[MAXN];inline void reverse(int k) {std::swap(tree[k].min, tree[k].max), tree[k].flag ^= 1;
}inline void push_down(int k) {if (tree[k].flag) reverse(k << 1), reverse(k << 1 | 1), tree[k].flag ^= 1;
}inline void update(int k) {tree[k].min = tree[k << 1].min + tree[k << 1 | 1].min;tree[k].max = tree[k << 1].max + tree[k << 1 | 1].max;
}inline void build_tree(int k, int l, int r) {if (l == r) return tree[k].init(l, a[l]);int mid = l + r >> 1;build_tree(k << 1, l, mid), build_tree(k << 1 | 1, mid + 1, r);update(k);
}inline void rever(int k, int l, int r, int ql, int qr) {if (ql <= l && r <= qr) return reverse(k);push_down(k);int mid = l + r >> 1;if (ql <= mid) rever(k << 1, l, mid, ql, qr);if (qr > mid) rever(k << 1 | 1, mid + 1, r, ql, qr);update(k);
}inline void solve() {R(n), R(k);for (int i = 1; i <= n; ++i) R(a[i]);build_tree(1, 1, n);long long ans = 0;for (int i = 1; i <= k; ++i) {data cur = tree[1].max;if (cur.mx > 0) ans += cur.mx; else break ;//取消上面的if判断,直接修改成ans += cur.mx即可 rever(1, 1, n, cur.p1, cur.p2);}printf("%lld", ans);
}int main() {solve();return 0;
}
然后我们再来看m = 2的情况。
先来讲讲暴力的O(n3k)的做法,虽然说得是这么高的复杂度,但是因为常数的确挺小,然后,最大数据也就跑了100ms多一些。定义数组f[i][j][k]表示当前第1列选到第i个数,第二列选到第j个数,已经选择了k个矩阵了,那么转移方程很好想。
f[i][j][k] = max(f[i - 1][j][k], f[i][j - 1][k])
f[i][j][k] = max(f[l][j][k - 1] + sum[0][i] - sum[0][l]) (l < i&& l >= 0) sum[0]为第一列的矩阵前缀和
f[i][j][k] = max(f[i][l][k - 1] + sum[1][j] - sum[1][l]) (l < j&& l >= 0) sum[1]为第二列的矩阵前缀和
如果i == j时,可以选择两列一起选择,所以
f[i][j][k] = max(f[l][l][k - 1] + sum[1][i] - sum[1][l] + sum[0][i] -sum[0][l]) (i == j && l < i && l >= 0)
这样DP的复杂度是O(n3k)的
Source:
inline void read_in() {R(n), R(m), R(k);for (int i = 1; i <= n; ++i)for (int j = 0; j < m; ++j)R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}inline void solve_m_2() {static int f[MAXN][MAXN][12];memset(f, 128, sizeof(f));for (int i = 0; i <= n; ++i)for (int j = 0; j <= n; ++j) f[i][j][0] = 0;for (int c = 1; c <= k; ++c) {for (int i = 1; i <= n; ++i)for (int j = 1; j <= n; ++j) {f[i][j][c] = std::max(f[i - 1][j][c], f[i][j - 1][c]);for (int l = 0; l < i; ++l)f[i][j][c] = std::max(f[i][j][c], f[l][j][c - 1] + sum[0][i] - sum[0][l]);for (int l = 0; l < j; ++l)f[i][j][c] = std::max(f[i][j][c], f[i][l][c - 1] + sum[1][j] - sum[1][l]);if (i == j) {for (int l = 0; l < i; ++l)f[i][j][c] = std::max(f[i][j][c], f[l][l][c - 1] + sum[0][i] - sum[0][l] + sum[1][i] - sum[1][l]);}}}std::cout << f[n][n][k];
}
我们再来看下一种方法,定义f[i][j][k]表示当前枚举到第i行,已经选择了j个矩阵,第i行的状态,第i行的状态一共有5种,分别是0:表示两个都不选,1:只选择第一列的,2:只选择第二列的,3:两列都选并且两列在不同的矩阵中,4:两列都选,并且两列在同一个矩阵中,然后分情况进行讨论及转移即可。复杂度为O(nk),常数略大。
转移直接见代码吧,比较清楚。
Source:
inline void read_in() {R(n), R(m), R(k);for (int i = 1; i <= n; ++i)for (int j = 0; j < m; ++j)R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}inline void solve_m_2() {static int f[MAXN][MAXK][5];/*f[i][j][0] : i行,选择j个,本行两列均未选 f[i][j][1] : i行,选择j个,本行选择第一列 f[i][j][2] : i行,选择j个,本行选择第二列 f[i][j][3] : i行,选择j个,本行选择一二列,一二列不在同一矩阵中 f[i][j][4] : i行,选择j个,本行选择一二列,一二列在同一矩阵中 */ memset(f, 128, sizeof(f)), f[0][0][0] = 0;for (int i = 1; i <= n; ++i)for (int j = 0; j <= k; ++j) {for (int t = 0; t < 5; ++t) {f[i][j][0] = std::max(f[i][j][0], f[i - 1][j][t]);if (j > 0) {f[i][j][1] = std::max(f[i][j][1], f[i - 1][j - 1][t] + a[0][i]);f[i][j][2] = std::max(f[i][j][2], f[i - 1][j - 1][t] + a[1][i]);f[i][j][4] = std::max(f[i][j][4], f[i - 1][j - 1][t] + a[0][i] + a[1][i]);}if (j > 1) f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 2][t] + a[0][i] + a[1][i]) ;}f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][1] + a[0][i]);f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][3] + a[0][i]);f[i][j][2] = std::max(f[i][j][2], f[i - 1][j][2] + a[1][i]);f[i][j][2] = std::max(f[i][j][2], f[i - 1][j][3] + a[1][i]);f[i][j][3] = std::max(f[i][j][3], f[i - 1][j][3] + a[0][i] + a[1][i]);if (j > 0) {f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 1][1]+ a[0][i] + a[1][i]);f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 1][2]+ a[0][i] + a[1][i]);}f[i][j][4] = std::max(f[i][j][4], f[i - 1][j][4] + a[0][i] + a[1][i]);}int ans = -INF;for (int i = 0; i < 5; ++i) ans = std::max(f[n][k][i], ans);std::cout << ans;
}
最后贴两份总的代码
m = 1部分用O(n2k)实现, m = 2部分用O(n3k)实现。
Source:
/*created by scarlyw
*/
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cctype>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <ctime>inline char read() {static const int IN_LEN = 1024 * 1024;static char buf[IN_LEN], *s, *t;if (s == t) {t = (s = buf) + fread(buf, 1, IN_LEN, stdin);if (s == t) return -1;}return *s++;
}///*
template<class T>
inline void R(T &x) {static bool iosig;static char c;for (iosig = false, c = read(); !isdigit(c); c = read()) {if (c == -1) return ;if (c == '-') iosig = true;}for (x = 0; isdigit(c); c = read()) x = ((x << 2) + x << 1) + (c ^ '0');if (iosig) x = -x;
}
//*/const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;inline void write_char(char c) {if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;*oh++ = c;
}template<class T>
inline void W(T x) {static int buf[30], cnt;if (x == 0) write_char('0');else {if (x < 0) write_char('-'), x = -x;for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;while (cnt) write_char(buf[cnt--]);}
}inline void flush() {fwrite(obuf, 1, oh - obuf, stdout);
}/*
template<class T>
inline void R(T &x) {static bool iosig;static char c;for (iosig = false, c = getchar(); !isdigit(c); c = getchar()) {if (c == -1) return ;if (c == '-') iosig = true;}for (x = 0; isdigit(c); c = getchar()) x = ((x << 2) + x << 1) + (c ^ '0');if (iosig) x = -x;
}
//*/const int MAXN = 100 + 10;int n, m, k;
int a[2][MAXN], sum[2][MAXN];inline void read_in() {R(n), R(m), R(k);for (int i = 1; i <= n; ++i)for (int j = 0; j < m; ++j)R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}inline void solve_m_1() {static int f[MAXN][12];memset(f, 128, sizeof(f));for (int i = 0; i <= n; ++i) f[i][0] = 0;for (int c = 1; c <= k; ++c)for (int i = 1; i <= n; ++i) {f[i][c] = f[i - 1][c];for (int l = 0; l < i; ++l)f[i][c] = std::max(f[i][c], sum[0][i] - sum[0][l] + f[l][c - 1]);}std::cout << f[n][k];
}inline void solve_m_2() {static int f[MAXN][MAXN][12];memset(f, 128, sizeof(f));for (int i = 0; i <= n; ++i)for (int j = 0; j <= n; ++j) f[i][j][0] = 0;for (int c = 1; c <= k; ++c) {for (int i = 1; i <= n; ++i)for (int j = 1; j <= n; ++j) {f[i][j][c] = std::max(f[i - 1][j][c], f[i][j - 1][c]);for (int l = 0; l < i; ++l)f[i][j][c] = std::max(f[i][j][c], f[l][j][c - 1] + sum[0][i] - sum[0][l]);for (int l = 0; l < j; ++l)f[i][j][c] = std::max(f[i][j][c], f[i][l][c - 1] + sum[1][j] - sum[1][l]);if (i == j) {for (int l = 0; l < i; ++l)f[i][j][c] = std::max(f[i][j][c], f[l][l][c - 1] + sum[0][i] - sum[0][l] + sum[1][i] - sum[1][l]);}}}std::cout << f[n][n][k];
}int main() {
// freopen("matrix.in", "r", stdin);
// freopen("matrix.out", "w", stdout);read_in();if (m == 1) solve_m_1();else solve_m_2();return 0;
}
m = 1,2部分均用O(nk)实现。
Source:
/*created by scarlyw
*/
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
#include <queue>inline char read() {static const int IN_LEN = 1024 * 1024;static char buf[IN_LEN], *s, *t;if (s == t) {t = (s = buf) + fread(buf, 1, IN_LEN, stdin);if (s == t) return -1;}return *s++;
}///*
template<class T>
inline void R(T &x) {static char c;static bool iosig;for (c = read(), iosig = false; !isdigit(c); c = read()) {if (c == -1) return ;if (c == '-') iosig = true; }for (x = 0; isdigit(c); c = read()) x = ((x << 2) + x << 1) + (c ^ '0');if (iosig) x = -x;
}
//*/const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;
inline void write_char(char c) {if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;*oh++ = c;
}template<class T>
inline void W(T x) {static int buf[30], cnt;if (x == 0) write_char('0');else {if (x < 0) write_char('-'), x = -x;for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;while (cnt) write_char(buf[cnt--]);}
}inline void flush() {fwrite(obuf, 1, oh - obuf, stdout);
}/*
template<class T>
inline void R(T &x) {static char c;static bool iosig;for (c = getchar(), iosig = false; !isdigit(c); c = getchar())if (c == '-') iosig = true; for (x = 0; isdigit(c); c = getchar()) x = ((x << 2) + x << 1) + (c ^ '0');if (iosig) x = -x;
}
//*/const int MAXN = 100 + 10;
const int MAXK = 12;
const int INF = 1000000000;int n, m, k;
int a[2][MAXN], sum[2][MAXN];inline void read_in() {R(n), R(m), R(k);for (int i = 1; i <= n; ++i)for (int j = 0; j < m; ++j)R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}inline void solve_m_1() {static int f[MAXN][MAXK][2]; memset(f, 128, sizeof(f)), f[0][0][0] = 0;for (int i = 1; i <= n; ++i) for (int j = 0; j <= k; ++j) {f[i][j][0] = std::max(f[i - 1][j][0], f[i - 1][j][1]);if (j > 0) f[i][j][1] = std::max(f[i - 1][j - 1][0], f[i - 1][j - 1][1]) + a[0][i];f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][1] + a[0][i]);}std::cout << std::max(f[n][k][0], f[n][k][1]);
}inline void solve_m_2() {static int f[MAXN][MAXK][5];memset(f, 128, sizeof(f)), f[0][0][0] = 0;for (int i = 1; i <= n; ++i)for (int j = 0; j <= k; ++j) {for (int t = 0; t < 5; ++t) {f[i][j][0] = std::max(f[i][j][0], f[i - 1][j][t]);if (j > 0) {f[i][j][1] = std::max(f[i][j][1], f[i - 1][j - 1][t] + a[0][i]);f[i][j][2] = std::max(f[i][j][2], f[i - 1][j - 1][t] + a[1][i]);f[i][j][4] = std::max(f[i][j][4], f[i - 1][j - 1][t] + a[0][i] + a[1][i]);}if (j > 1) f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 2][t] + a[0][i] + a[1][i]) ;}f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][1] + a[0][i]);f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][3] + a[0][i]);f[i][j][2] = std::max(f[i][j][2], f[i - 1][j][2] + a[1][i]);f[i][j][2] = std::max(f[i][j][2], f[i - 1][j][3] + a[1][i]);f[i][j][3] = std::max(f[i][j][3], f[i - 1][j][3] + a[0][i] + a[1][i]);if (j > 0) {f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 1][1]+ a[0][i] + a[1][i]);f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 1][2]+ a[0][i] + a[1][i]);}f[i][j][4] = std::max(f[i][j][4], f[i - 1][j][4] + a[0][i] + a[1][i]);}int ans = -INF;for (int i = 0; i < 5; ++i) ans = std::max(f[n][k][i], ans);std::cout << ans;
}int main() {read_in();if (m == 1) solve_m_1();else solve_m_2();return 0;
}