BZOJ 3343 and Chunks

Problem Statement

There are N≤1000000 counters, each starting with an initial value given in the input. You are asked to perform Q≤3000 actions on these counters, each of which is one of:
  1. Increment every counter from L to R by W, where 1≤L≤R≤N and 1≤W≤1000. This command has the format "M <L> <R> <W>", where L, R, and W are integers.
  2. Count how many counters from L to R have a value of at least W, where 1≤L≤R≤N and 1≤W≤1000. This command has the format "A <L> <R> <W>", where L, R, and W are integers.

Core Idea

Any array of N integers can be broken up into chunks of ⌊√N⌋ elements each, where every chunk is full except possibly the last one, which holds the remainder. That makes ⌈N/⌊√N⌋⌉ chunks, which is roughly √N.
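The arithmetic behind this split can be sketched in C++ as below. `ChunkLayout` and its member names are hypothetical helpers introduced for illustration, not part of the solution further down; like that solution, the sketch assumes 1-indexed elements and a chunk length of ⌊√N⌋.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper (not from the original solution): the chunk layout of a
// 1-indexed array of n elements, using chunk length floor(sqrt(n)) as the solution does.
struct ChunkLayout {
    int n;      // number of elements
    int len;    // length of every chunk except possibly the last
    int count;  // number of chunks, ceil(n / len)

    explicit ChunkLayout(int n_) : n(n_), len(0), count(0) {
        assert(n_ >= 1);
        len = static_cast<int>(std::sqrt(static_cast<double>(n_)));
        count = (n_ + len - 1) / len;  // ceiling division: the last chunk may be shorter
    }
    int chunkOf(int i) const { return (i - 1) / len + 1; }        // chunk containing element i
    int left(int c) const { return (c - 1) * len + 1; }           // first element of chunk c
    int right(int c) const { return c * len < n ? c * len : n; }  // last element of chunk c
};
```

For example, N = 10 gives a chunk length of 3 and the four chunks [1, 3], [4, 6], [7, 9] and [10, 10].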

Although this may seem bloody obvious at first, splitting an extremely large array into chunks and then working on whole chunks can make a program vastly more efficient. Some examples:
  • Searching - if a block's lowest value is above the target or its highest value is below the target, skip the whole block
  • Counting - a sorted block can be binary searched in O(log √N) to find how many of its elements satisfy a condition, so counting over all blocks takes O(√N log √N) instead of O(N) for checking every element one by one
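As a sketch of the searching idea from the list above (an illustration only; BZOJ 3343 itself does not need it), the hypothetical `ChunkedSearch` below precomputes each chunk's minimum and maximum once, then skips every chunk whose range cannot contain the target. It uses 0-indexed positions and takes the chunk length as a parameter.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical illustration of "skip a chunk whose [min, max] cannot contain the target".
struct ChunkedSearch {
    std::vector<int> a;      // the values, 0-indexed
    int len;                 // chunk length
    std::vector<int> mn, mx; // mn[c], mx[c]: smallest and largest value in chunk c

    ChunkedSearch(std::vector<int> values, int chunkLen) : a(std::move(values)), len(chunkLen) {
        assert(len >= 1);
        const int n = static_cast<int>(a.size());
        for (int s = 0; s < n; s += len) {
            const int e = std::min(s + len, n);  // one past the chunk's last element
            mn.push_back(*std::min_element(a.begin() + s, a.begin() + e));
            mx.push_back(*std::max_element(a.begin() + s, a.begin() + e));
        }
    }

    // Returns the first index holding target, or -1 if it does not occur.
    int findFirst(int target) const {
        const int n = static_cast<int>(a.size());
        for (int c = 0; c < static_cast<int>(mn.size()); ++c) {
            if (mn[c] > target || mx[c] < target) continue;  // target cannot be in this chunk
            const int s = c * len, e = std::min(s + len, n);
            for (int i = s; i < e; ++i)
                if (a[i] == target) return i;
        }
        return -1;
    }
};
```

Precomputing the minimum and maximum is what makes skipping cheap; recomputing them on every search would scan the whole array anyway.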

Application

In this problem, we have two things to deal with - updating (increment) and querying (counting).
Notice the huge N (1 million!) and the large Q (3 thousand!). A naive solution walks over every counter in [L, R] for each command, which is O(N) per command and about 10^6 × 3000 = 3 billion operations in the worst case. So even O(N) per command is too slow.
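For reference, the naive approach ruled out above is easy to write down, and it is handy for checking the chunked solution on small random inputs. The struct name `BruteForce` is hypothetical, and the sketch assumes the query counts values that are at least W, matching the comparison used in the solution's code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical naive reference (1-indexed). Each command costs O(R - L + 1), i.e. O(N)
// in the worst case: fine for checking answers, too slow for N = 10^6 and Q = 3000.
struct BruteForce {
    std::vector<int> a;  // a[1..n]; a[0] is unused

    explicit BruteForce(const std::vector<int>& init) : a(init.size() + 1, 0) {
        for (std::size_t i = 0; i < init.size(); ++i) a[i + 1] = init[i];
    }
    void update(int l, int r, int w) {  // "M L R W": add W to every counter in [L, R]
        for (int i = l; i <= r; ++i) a[i] += w;
    }
    int query(int l, int r, int w) const {  // "A L R W": count counters in [L, R] that are >= W
        int cnt = 0;
        for (int i = l; i <= r; ++i)
            if (a[i] >= w) ++cnt;
        return cnt;
    }
};
```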
We have a few choices:
  • Binary search. On its own this is not even close to applicable: we are counting, not searching, and it gives no way to handle increments.
  • Fenwick tree (prefix sums). A Fenwick tree keeps prefix sums under updates in O(log N) per operation, but a sum cannot tell us how many values in a range are at least W, so it does not fit the second command.
  • Chunky counting (square-root decomposition). Each command handles at most two partial chunks element by element and every whole chunk in O(1) (increment) or O(log √N) (count), so one command costs roughly O(√N log √N). This is the method we will use.
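A rough cost estimate for the chunked approach, with constant factors ignored and B standing for the chunk length (a symbol introduced here for the estimate):

```latex
\begin{aligned}
B &= \lfloor \sqrt{N} \rfloor = 1000, \qquad N/B = 1000 \text{ chunks} \quad (N = 10^6)\\
T_{\text{update}} &= \underbrace{O(B)}_{\text{partial chunks}} + \underbrace{O(B \log B)}_{\text{re-sort 2 chunks}} + \underbrace{O(N/B)}_{\text{bonuses}}\\
T_{\text{query}} &= \underbrace{O(B)}_{\text{partial chunks}} + \underbrace{O\!\left(\tfrac{N}{B}\log B\right)}_{\text{binary searches}}\\
Q \cdot T_{\text{query}} &\approx 3000 \cdot 1000 \cdot \log_2 1000 \approx 3 \times 10^{7} \ll Q \cdot N = 3 \times 10^{9}
\end{aligned}
```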
But how? We handle the two commands in two different ways:

Increment

Say we want to increment all values between X and Y by 7. The array is already split into chunks, so how do we do this?
Some values between X and Y lie in chunks that the range only partly covers. For example, take N = 36 with chunks of size 6, X = 10 and Y = 20:

01 02 03 04 05 06 | 07 08 09 10 11 12 | 13 14 15 16 17 18 | 19 20 21 22 23 24 | 25 26 27 28 29 30 | 31 32 33 34 35 36
                             X                                 Y

Although the chunk [13, 18] fits entirely in the range we need to increment, we still need to deal with 10 through 12 and 19 through 20, which only partly cover their chunks. Let's deal with those first.
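Finding these pieces is only index arithmetic. The hypothetical `splitRange` helper below (1-indexed, chunk length `len`) mirrors the boundaries used by `update()` and `query()` in the solution: the chunks containing X and Y are always handled element by element, even when the range happens to cover them completely.

```cpp
#include <cassert>

// Hypothetical helper: split the 1-indexed range [x, y] into a partial left piece,
// a run of whole chunks, and a partial right piece, for chunks of length len.
struct RangeSplit {
    int leftFrom, leftTo;    // elements handled one by one at the left end
    int firstFull, lastFull; // whole chunks in between (none if firstFull > lastFull)
    int rightFrom, rightTo;  // elements handled one by one at the right end (none if rightFrom > rightTo)
};

RangeSplit splitRange(int x, int y, int len) {
    assert(1 <= x && x <= y && len >= 1);
    const int cx = (x - 1) / len + 1;  // chunk containing x
    const int cy = (y - 1) / len + 1;  // chunk containing y
    if (cx == cy) return {x, y, cx + 1, cx, 1, 0};  // inside one chunk: everything is partial
    return {x, cx * len, cx + 1, cy - 1, (cy - 1) * len + 1, y};
}
```

With X = 10, Y = 20 and chunks of size 6 it yields the left piece [10, 12], the single whole chunk 3 (elements 13 to 18) and the right piece [19, 20].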
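There really is no shortcut for those, so let's simply add 7 to each of them by hand. Although this seems slow, there are at most two such partial chunks and each holds at most √N elements, so it is only O(√N) in the worst case. Because their stored values changed, the sorted copies of those two chunks (used later for counting) have to be rebuilt, which costs O(√N log √N).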
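Now, to deal with the whole chunks. Clearly we cannot add 7 to every element of every whole chunk one by one, since that would be O(N) again. Instead, let's use a "bonus": for each chunk, every element's real value is its stored value plus the chunk's bonus. So now we simply increase the bonus of each whole chunk by 7! Adding the same amount to every element of a chunk does not change their order, so that chunk's sorted copy stays valid.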
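Putting the increment step together, here is a minimal sketch that assumes the same layout as the solution's arrays (raw values, a per-chunk sorted copy and a per-chunk bonus); the struct `ChunkedIncrement` and its member names are illustrative, not the solution's own. Whole chunks only have their bonus bumped, while the two partial chunks are edited element by element and then re-sorted.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical sketch, 1-indexed with chunk length len. a holds raw values, b holds a copy
// sorted within each chunk, add[c] is a bonus shared by all of chunk c.
// The real value of element i is a[i] + add[chunkOf(i)].
struct ChunkedIncrement {
    int n, len;
    std::vector<int> a, b, add;

    ChunkedIncrement(const std::vector<int>& init, int chunkLen)
        : n(static_cast<int>(init.size())), len(chunkLen), a(init.size() + 1), b(init.size() + 1) {
        assert(!init.empty() && chunkLen >= 1);
        add.assign(n / len + 2, 0);  // enough slots for chunks 1..ceil(n / len)
        for (int i = 1; i <= n; ++i) a[i] = init[i - 1];
        for (int c = 1; c <= chunkOf(n); ++c) resort(c);
    }

    int chunkOf(int i) const { return (i - 1) / len + 1; }

    // Rebuild the sorted copy of chunk c from the raw values (like reset() in the solution).
    void resort(int c) {
        const int l = (c - 1) * len + 1, r = std::min(c * len, n);
        std::copy(a.begin() + l, a.begin() + r + 1, b.begin() + l);
        std::sort(b.begin() + l, b.begin() + r + 1);
    }

    // Add w to every element of [x, y].
    void increment(int x, int y, int w) {
        const int cx = chunkOf(x), cy = chunkOf(y);
        if (cx == cy) {  // the range sits inside one chunk
            for (int i = x; i <= y; ++i) a[i] += w;
            resort(cx);
            return;
        }
        for (int i = x; i <= cx * len; ++i) a[i] += w;            // partial chunk on the left
        for (int i = (cy - 1) * len + 1; i <= y; ++i) a[i] += w;  // partial chunk on the right
        resort(cx);
        resort(cy);
        for (int c = cx + 1; c < cy; ++c) add[c] += w;  // whole chunks: bonus only, order unchanged
    }

    int realValue(int i) const { return a[i] + add[chunkOf(i)]; }
};
```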

Counting

To count the number of elements between X and Y that are at least V, we return to the chunks.
01 02 03 04 05 06 | 07 08 09 10 11 12 | 13 14 15 16 17 18 | 19 20 21 22 23 24 | 25 26 27 28 29 30 | 31 32 33 34 35 36
                             X                                 Y
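Again, pieces such as 10 through 12 and 19 through 20 in the above example only partly cover their chunks, so we count them manually, adding each element's chunk bonus before comparing.
Then, since every chunk also keeps a sorted copy of its values (built by reset() in the code), each whole chunk can be binary searched instead of iterated! A stored value s in chunk i counts when s + add[i] ≥ V, that is, when s ≥ V − add[i], so a single binary search for V − add[i] gives the count for the whole chunk. For more detail, refer to the code's comments.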
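The two counting cases can be sketched as two small functions. The names `countFullChunk` and `countPartial` are hypothetical, and both assume "counts" means "is at least v", as in the solution's comparisons. For a whole chunk, one `std::lower_bound` on the sorted copy answers in O(log √N); a partial chunk is checked element by element against its unsorted values.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Whole chunk: sorted holds the chunk's stored values in ascending order and bonus is its
// pending increment. A value s counts when s + bonus >= v, i.e. s >= v - bonus.
int countFullChunk(const std::vector<int>& sorted, int bonus, int v) {
    auto first = std::lower_bound(sorted.begin(), sorted.end(), v - bonus);
    return static_cast<int>(sorted.end() - first);  // everything from first onward qualifies
}

// Partial chunk: check raw[from..to] one by one, adding the chunk's bonus before comparing.
int countPartial(const std::vector<int>& raw, int from, int to, int bonus, int v) {
    assert(0 <= from && to < static_cast<int>(raw.size()));
    int cnt = 0;
    for (int i = from; i <= to; ++i)
        if (raw[i] + bonus >= v) ++cnt;
    return cnt;
}
```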

Source Code

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
int n;                                                       // n: total heights
int q;                                                       // q: total commands
int m;                                                       // m: total chunks
int chunk;                                                   // chunk: length of a chunk
int a[1000001];                                              // a: heights
int b[1000001];                                              // b: chunk-heights, chunk-sorted version of a, sorted by reset()
int pos[1000001];                                            // pos[i]: the chunk that element i is in
int add[1000001];                                            // add[i]: the bonus added to every element of chunk i

void reset(int x)                                            // sort chunk X
{
    int l=(x-1)*chunk+1,r=min(x*chunk,n);                    // calculate boundaries of chunk X; L is left, R is right
    for(int i=l;i<=r;i++)                                    // for each height in chunk X (between L and R)
        b[i]=a[i];                                           // assign the chunk-height to the height
    sort(b+l,b+r+1);                                         // sort the chunk-heights in the range of this chunk
}

int find(int x,int v)                                        // count the chunk-heights in chunk X that are at least V
{
    int l=(x-1)*chunk+1,r=min(x*chunk,n);                    // calculate boundaries of chunk X; L is left, R is right
    return int(b+r+1-lower_bound(b+l,b+r+1,v));              // b[l..r] is sorted, so every value from the first one >= V onward counts
}

void update(int x,int y,int v)                               // increase the height of the interval [x, y] by v
{
    if(pos[x]==pos[y])                                       // if x and y are in the same chunk:
    {
        for(int i=x;i<=y;i++)a[i]=a[i]+v;                    // increment all in the interval by V
    }
    else                                                     // otherwise:
    {
        for(int i=x;i<=pos[x]*chunk;i++)a[i]=a[i]+v;         // increment all from X to the end of X's chunk by V
        for(int i=(pos[y]-1)*chunk+1;i<=y;i++)a[i]=a[i]+v;   // increment all from the start of the chunk with Y to Y by V
    }
    reset(pos[x]);reset(pos[y]);                             // sort the chunks of X and Y after being incremented
    for(int i=pos[x]+1;i<pos[y];i++)                         // for each chunk between X and Y exclusive:
       add[i]+=v;                                            // increment the bonus for that chunk by V
}

int query(int x,int y,int v)                                 // count the heights in the interval [x, y] that are at least v
{
    int sum=0;                                               // set the answer to 0
    if(pos[x]==pos[y])                                       // if x and y are in the same chunk:
    {
        for(int i=x;i<=y;i++)                                // for each height in [x, y]
            if(a[i]+add[pos[i]]>=v) sum++;                   // if its height plus the chunk's bonus is at least v, increment answer
    }
    else                                                     // otherwise:
    {
        for(int i=x;i<=pos[x]*chunk;i++)                     // for all heights from X to the end of X's chunk
            if(a[i]+add[pos[i]]>=v)sum++;                    // if its height plus the chunk's bonus is at least v, increment answer
        for(int i=(pos[y]-1)*chunk+1;i<=y;i++)               // for all heights from the start of chunk Y to Y
            if(a[i]+add[pos[i]]>=v)sum++;                    // if its height plus the chunk's bonus is at least v, increment answer
    }
    for(int i=pos[x]+1;i<pos[y];i++)                         // for each chunk between X and Y exclusive:
        sum+=find(i,v-add[i]);                               // a stored value s counts when s+add[i] >= v, i.e. s >= v-add[i]
    return sum;                                              // return answer
}
int main()
{
    // freopen("bzoj3343.in","r",stdin);                     // uncomment to read from a file when testing locally
    // freopen("bzoj3343.out","w",stdout);                   // uncomment to write to a file when testing locally
    scanf("%d%d",&n,&q);                                     // read the number of heights and commands
    chunk=int(sqrt(n));                                      // set chunk length to the square root of N rounded down
    for(int i=1;i<=n;i++)                                    // read in the heights:
    {
        scanf("%d",&a[i]);                                   // scan input
        pos[i]=(i-1)/chunk+1;                                // register its position
        b[i]=a[i];                                           // register its chunk-height
    }
    if(n%chunk)m=n/chunk+1;                                  // if N isn't a multiple of the chunk length, add a shorter last chunk
    else m=n/chunk;                                          // otherwise the chunks divide N exactly
    for(int i=1;i<=m;i++)reset(i);                           // update and sort chunk-heights
    for(int i=1;i<=q;i++)                                    // read in commands
    {
        char ch[5];int x,y,v;                                // declare command and arguments
        scanf("%s%d%d%d",ch,&x,&y,&v);                       // read command and arguments
        if(ch[0]=='M')update(x,y,v);                         // "M" command: increase heights, run update
        else printf("%d\n",query(x,y,v));                    // "A" command: run query and print the result
    }
    return 0;
}
