743-E. Vladik and cards
Tags: binary search, bitmask, dp, sequence
出處:https://codeforces.com/contest/743/problem/E
提交:https://codeforces.com/contest/743/submission/110432027
問題:給定 $n\in[1,1000]$ 個介於 $[1,8]$ 的正整數,希望能找到一個最長子序列使得其中所有的元素值彼此的數量差距不超過 1,不存在的元素數量須視為 0 並納入考慮,而且同元素值之間必須彼此相鄰,不可以與其他元素值交錯。請問此最長子序列的長度。
解法:
這題的難點在於非常不容易聯想到可以對元素數量作 binary search,以及用 bitmask 記錄已經選過的元素,這兩大技巧,如果有想到這兩個突破點的話實作上應該不太困難。之所以可以對元素數量作 binary search 的原因在於,假設我們取出的子序列內少量的元素個數皆為 b、多量的元素個數皆為 b+1,那麼必定亦可再從中挑出少量的元素個數皆為 b-1、多量的元素個數皆為 b 的子序列,而且它也滿足限制需求。所以最佳解必定發生在有最大合法的 b 的時候。
正式的流程如下。每次先固定一個 b 值,接著用 DFS 走訪的時候每次去試試看以下三種選擇,當我們停在一個元素上面的時候,我可能可以 (1: 跳過它)、(2: 持續往後選包含它自己在內的同種數字總共 b 個)、(3: 持續往後選包含它自己在內的同種數字總共 b+1 個),只要是選擇 2 和 3,就要註記在 bitmask 內已經選過這個數字了,這邊記得可以先計算每個數字出現第幾次的位置在哪裡,這樣的話只要常數時間就能輕易地跳轉。另外我們也必須使用 DP 陣列記錄一個子問題的最佳解,也就是最長子序列,那麼持續往前回溯到最後就能獲得全域最佳解。接著對不同的 b 重複同樣流程,就能獲得最終答案。
時間:如果每次的遞迴都是常數時間操作,那麼 DP 的時間複雜度就是空間複雜度,在固定 b 值的條件下應該是 $\mathcal O(n\cdot2^8)$,但總時間還要加上外層的 binary search,因此應該要再乘以 $log_2(n/8)$。
心得:一看到元素個數不是太多的時候,就要提高警覺 binary search 可能派上用場!另外元素種類極少也意味著它可以壓進一個 bitmask 的敘述。
實作:這題的陷阱其實還蠻多的,而且有點貝戈戈,這邊羅列如下:
- 這邊的 solve 函式有隱含一個條件就是 $b>0$,那麼 $b=0$ 的時候該怎麼辦呢?答案就藏在下面 31 行的地方,在 $b=0$ 的時候我們不作遞迴,改直接統計數字的種類即可。
- 這題的 DP 不像有些題目可以光從儲存值的正負性就能判斷是否走訪過,因為我們已經用負無限大代替不合法的解,而它繼續往上作累加的時候也必須被記憶,但此時仍然是負數,吾人只好多開一個陣列 visit 來真正記錄是否走訪過。
- 一般的 binary search 都是比較和目標值的大小,過大則範圍往下壓,過小則範圍往上提,但這題不同,函數值一開始會隨著 b 的變大而變大,之後 b 再大則函數值掉入深淵 (第 36 行的 < 0),所以說當函數值不是負的時候,我大可以直接儲存當前答案,並逕自提高範圍,就不用再比較大小關係。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
#include <bits/stdc++.h>
using namespace std;
set<int> nums; // 總共出現過哪些數字
vector<int> pos[8]; // 記錄每種數字自己出現的所有位置
bool visit[1000][256]; // 記錄哪個 dp 的 entry 已經更新過可以直接拿來用
int n, b, a[1000], dp[1000][256], times[1000]; // 記錄每個數字是第幾次出現自己這個大小的數值
int solve(int i, int bitmask) { // assert (i <= n);
if (i == n) return (bitmask<255) ? INT_MIN : 0; // 數字沒選完視為不合法,用 INT_MIN 表示不可列入答案。
if (visit[i][bitmask]) return dp[i][bitmask];
int ans = solve(i+1, bitmask);
if (!((bitmask>>a[i]) & 1)) { // assert (pos[a[i]][times[i]] == i);
if (times[i]+b-1 < pos[a[i]].size()) // 如果夠選的話
ans = max(ans, b + solve(pos[a[i]][times[i]+b-1]+1, bitmask | (1<<a[i]))); // 含自己總共 b 個
if (times[i]+b < pos[a[i]].size()) // 如果夠選的話
ans = max(ans, b+1 + solve(pos[a[i]][times[i]+b]+1, bitmask | (1<<a[i]))); // 含自己總共 (b+1) 個
}
visit[i][bitmask] = true;
return dp[i][bitmask] = ans;
}
int main() { ios_base::sync_with_stdio(false); cin.tie(nullptr); // IO 優化
cin >> n;
for (int i=0; i<n; i++) {
cin >> a[i]; a[i]--; nums.insert(a[i]);
times[i] = pos[a[i]].size();
pos[a[i]].push_back(i);
}
int t, lb=1, ub=n/8, ans=nums.size(); // b=0 時候的答案
while (lb <= ub) {
memset(visit, false, sizeof visit); // 粉重要!!!不可忽略。
b = (lb + ub) / 2;
t = solve(0,0);
if (t < 0) ub = b - 1;
else {
ans = t;
lb = b + 1;
}
}
cout << ans;
return 0;
}
|
結果:
