# 题目大意给定一个长度为 n n n 的字符串 s s s ,s s s 仅由 a / b 组成,还有两个整数 A A A 和 B B B 你需要计算的是满足 a 的个数 ≥ A \ge A ≥ A 且 b 的个数 < B < B < B 的子串个数
# 数据范围1 ≤ n ≤ 3 × 1 0 5 1 \le n \le 3 \times 10^5 1 ≤ n ≤ 3 × 1 0 5 # 题解# 错解一开始想的是双指针去做
每次都是以 i i i 为终点,j j j 去追赶,先挪 i i i 使得 j → i j \to i j → i 满足了区间内 a 数量 ≥ A \ge A ≥ A 然后 b 数量需要 < B < B < B ,所以不断缩短区间以找到 j j j 的两个取值区间内,c n t b cnt_b c n t b 第一次小于 B B B 的位置 基于 1. 的位置,继续靠近 i i i ,并且满足 a 的数量 ≥ A \ge A ≥ A 的最后的一个位置 1. 和 2. 之间的区间就是可以的取值然后就 T L E \color{red}{TLE} T L E 了 这里 j j j 每次还得挪回去,有点类似 O ( N 2 ) O(N^2) O ( N 2 ) 了
C++ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 int count (int l, int r, char type) { if (type == 'a' ) return preA[r] - preA[l - 1 ]; return preB[r] - preB[l - 1 ]; } void solve () { cin >> n >> A >> B >> s, s = " " + s; preA = vector <int >(n + 1 ), preB = vector <int >(n + 1 ); for (int i = 1 ; i <= n; i++) preA[i] = preA[i - 1 ] + (s[i] == 'a' ), preB[i] = preB[i - 1 ] + (s[i] == 'b' ); int res = 0 ; int i = 1 , j = 1 ; while (i <= n) { while (i <= n && count (j, i, 'a' ) < A) i++; if (i == n + 1 ) break ; while (j <= i && count (j, i, 'b' ) >= B) j++; if (j > i || count (j, i, 'a' ) < A) { i++; continue ; } int l = j; while (j <= i && count (j, i, 'a' ) >= A) j++; int r = j - 1 ; res += r - l + 1 ; j = l; i++; } cout << res << '\n' ; }
# 正解枚举终点 i i i 的思想是对的,但是,我们可以用二分以 log \log log 的时间,找到 ( l , r ) (l, r) ( l , r ) 首先,r r r 是需要满足 preA[i] - preA[r] >= A ,我们要找到最大的 r r r preA[i] - preA[r] >= A 等价于 preA[r] <= preA[i] - A 所以等价于,我们找到满足 preA[r] > preA[i] - A 的第一个位置,然后 − 1 -1 − 1 就是我们需要的 r r r 此时的 r + 1 r + 1 r + 1 就是最大右边界
然后,l l l 是需要满足 preB[i] - preB[l] < B ,我们要找到最小的 l l l preB[i] - preB[l] < B 等价于 preB[l] > preB[i] - B ,也就是直接用 upper_bound 此时 l + 1 l + 1 l + 1 就是最小左边界
需要注意的是,这里的 upper_bound 的查找范围我们用的是 pre.begin(), pre.begin() + i ,实际查找就是 [ 0 , i − 1 ] [0, i - 1] [ 0 , i − 1 ] ,不用 pre.begin() + i + 1 的原因就是因为我们前面找的结果 + 1 +1 + 1 才是区间,如果用了的话,那可能刚刚好查到 i i i ,然后 + 1 +1 + 1 直接超界了
C++ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 #include <bits/stdc++.h> #define int long long using namespace std;int n, A, B;string s; vector<int > preA, preB; int count (int l, int r, char type) { if (type == 'a' ) return preA[r] - preA[l - 1 ]; return preB[r] - preB[l - 1 ]; } void solve () { cin >> n >> A >> B >> s, s = " " + s; preA = vector <int >(n + 1 ), preB = vector <int >(n + 1 ); for (int i = 1 ; i <= n; i++) preA[i] = preA[i - 1 ] + (s[i] == 'a' ), preB[i] = preB[i - 1 ] + (s[i] == 'b' ); int res = 0 ; for (int i = A; i <= n; i++) { int r = upper_bound (preA.begin (), preA.begin () + i, preA[i] - A) - preA.begin () - 1 ; int l = upper_bound (preB.begin (), preB.begin () + i, preB[i] - B) - preB.begin (); if (r >= l) res += (r - l + 1 ); } cout << res << '\n' ; }