Inclusion-exclusion principle

 Published on 2018-05-09

When we write programs and various software solutions, we often come across situations where we need to count objects or states that meet certain criteria. Most often, we do this by calculating the required number directly. However, in some situations it is actually easier to count the objects or states that do not meet the criteria, and then subtract that number from the total number of objects or states. For example, if we need to calculate how many people in the United Kingdom are currently free (not incarcerated), it is easier to determine the number of people currently in prison and subtract that number from the total population according to the latest census. The number of prisoners is easier to determine because it is much smaller, and we should be able to retrieve it by contacting each facility where sentences are being served.

In addition, some problems require us to analyze more complicated situations, namely when there are multiple sets of objects or possible values and we need to use all of them to solve a bigger problem. For example, how many numbers from 1 to 100 are divisible by 3 or 5? One algorithm is to list all numbers from 1 to 100 and check how many of them are divisible by 3 or 5. In this case, there are 47 such numbers.

A much more efficient approach (one that also works for values much bigger than 100) is to use the so-called inclusion-exclusion principle. The inclusion-exclusion principle is a counting technique that generalizes the basic rule for the number of elements in the union of two sets: that number is equal to the number of elements in the first set plus the number of elements in the second set, minus the number of elements in their intersection (which would otherwise be counted twice), as shown in the figure below.

For the problem discussed earlier, this means that we can find how many numbers are divisible by 3 or 5 as follows. First, we calculate how many numbers are divisible by 3 (this is easy, because that number is 100/3=33 if we throw away the remainder), and then how many are divisible by 5 (100/5=20). From their sum (33+20=53) we subtract the count of numbers divisible by both 3 and 5 (these are exactly the numbers divisible by 3*5=15, so there are 100/15=6 of them). The final result is 53-6=47.

The inclusion-exclusion principle lets us solve similar problems even when there are more than two sets. For example, to calculate the number of elements in the union of 3 sets (named A, B and C), we need to add up the following values, each with the sign shown in front of it (+ to add, - to subtract):

Sign | Value
+ | the number of elements in the first set
+ | the number of elements in the second set
+ | the number of elements in the third set
- | the number of elements in the intersection of the first and second set
- | the number of elements in the intersection of the first and third set
- | the number of elements in the intersection of the second and third set
+ | the number of elements in the intersection of the first, second and third set

You can also see this on the picture given below.

If you did not already notice it in the table above, it is actually very simple to find out which sign (+ or -) stands in front of each group. If the group has an odd number of sets (one, three, five, etc.), it gets a + sign, and if it has an even number of sets (two, four, etc.), it gets a - sign. Looking at the table again, there is a "-" sign in front of "the number of elements in the intersection of the first and second set" (because that group has two sets), and a "+" sign in front of "the number of elements in the intersection of the first, second and third set" (because that group has three sets).
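For reference, the same rule can be written as a single formula for n sets A_1, ..., A_n (this notation is ours; the post describes the rule in words). Every non-empty subset S of the sets contributes the size of the intersection of its sets, with the factor (-1)^{|S|+1}, which is +1 for an odd number of sets and -1 for an even number:

```latex
\left| A_1 \cup A_2 \cup \cdots \cup A_n \right|
  = \sum_{\emptyset \ne S \subseteq \{1, \dots, n\}} (-1)^{|S|+1} \left| \bigcap_{i \in S} A_i \right|
```

For n=3 this produces exactly the seven rows of the table above.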

In the first example (divisibility by 3 or 5), we initially added the numbers (100/3) and (100/5), because they correspond to groups of one set (100/3 for "divisible by 3", and 100/5 for "divisible by 5"), and then we subtracted the number (100/15), because it comes from the intersection of two sets ("divisible by 3" and "divisible by 5").

Note that counting how many of the first 100 numbers are divisible by both X and Y by dividing 100 by their product (X*Y) only works when X and Y are coprime (their only common divisor is 1), which is true here. If they were not, we would need another way to count the elements in the intersection of the two sets; in general, a number is divisible by both X and Y exactly when it is divisible by their least common multiple. The inclusion-exclusion principle only tells us how to combine the sizes of the sets and of their intersections into the size of the union; how to calculate the number of elements in each intersection is something we have to figure out for each problem. In this particular case it was simple, because 3 and 5 are distinct primes, and therefore coprime.

Subsets

In the previous section, we talked about which sign stands in front of each group. But how can we use a computer to create all these groups of 1 set, 2 sets (so that we can calculate their intersection), 3 sets, etc.? The answer is simple: we generate all subsets of the list of sets that we have. For example, if we have 2 sets (A and B), then the subsets of this list are: {} (the empty subset does not interest us, so we ignore it), {A}, {B}, {A, B}. If we have 3 sets (A, B and C), then the subsets are: {} (ignored), {A}, {B}, {C}, {A, B}, {A, C}, {B, C}, {A, B, C}.

We already know which sign (+ or -) goes in front of each of these groups, because it depends only on the number of sets in the subset, which is easy to determine: 1 for {C}, 2 for {B, C}, 3 for {A, B, C}, etc.

Since the principle of inclusion (+) and exclusion (-) needs all subsets of a list of sets (for example, A, B, C), it's time to see how these subsets can be generated efficiently with a computer. This is a problem that we have already discussed in one of the previous posts, and the simplest way to solve it involves looking at the bits of an integer. Specifically, if we have a list of N elements (A, B, C, D, ...), then this list has 2^N subsets.

Since there are 2^N subsets, we can use a single "for" loop that goes through all integers from 0 to 2^N-1. Written in binary, those integers look like this (for N=3):

Binary | Decimal | Subset
000 | 0 | []
001 | 1 | [A]
010 | 2 | [B]
011 | 3 | [A, B]
100 | 4 | [C]
101 | 5 | [A, C]
110 | 6 | [B, C]
111 | 7 | [A, B, C]

Let the rightmost bit (0 or 1) of each number indicate whether A is part of that subset, let the second bit from the right indicate whether B is part of the subset, etc. For example, 001 corresponds to the subset [A], because only the rightmost bit is equal to 1. Similarly, 111 corresponds to [A, B, C], because all bits are equal to 1. In practice, each of the N elements in the list (A, B, C, D, ...) is usually replaced with a number from 0 to N-1, which is easier for the computer to work with, so we actually use a list like (0, 1, 2, 3, ..., N-1).

The following program prints all subsets (A, B, C, D, etc.) for a given number of elements N. The empty subset is printed as an empty line.

#include <iostream>
#include <string>
using namespace std;

int main() 
{
     int N;
     cin >> N;
     
     //(1 << N) is equal to 2^N
     for (int i=0; i < (1 << N); i++) 
     {
          
          string current = "";
          for (int j=0; j < N; j++) 
          {
               
               //check if the j-th bit (from the right) is equal to 1
               if (((1 << j) & i) != 0) 
               {
                    char ch = ('A' + j);
                    current += ch;
               }
          }
          
          //print the current subset
          cout << current << endl;
     }
     
     return 0;
}

The last thing that we will mention in this section is that the condition ((1 << j) & i) != 0 checks whether the j-th bit (counting from the right, starting at 0) of the number "i" is equal to 1. The expression (1 << j) is a number that has only one bit set (at position j), and the AND (&) operator keeps, bit by bit, only the bits that are equal to 1 in both numbers. Since only one bit of (1 << j) is equal to 1, the result is different from 0 only when the bit at the same position in "i" is also equal to 1.

Example

In the previous sections, we explained that the inclusion-exclusion principle is based on the idea that some problems, where we need to calculate the number of elements in a union of sets, can be solved using the number of elements in the sets themselves and in their intersections. We talked about how to create all subsets (which we need in order to calculate the intersections), and how to decide whether a subset affects the final result positively (inclusion, with a plus sign) or negatively (exclusion, with a minus sign).

In fact, we have thoroughly explained the principle, and how to use it. In addition, we already saw an example where it was used to count how many numbers are divisible by one or more primes (3, 5, etc.).

In this section, we will try to look at another example (much more complex than before) in order to better understand things. Don't worry if you don't understand everything, as this is a pretty hard problem.

Let's explain the problem that we want to solve. We are given two numbers N and M, and we want to calculate how many numbers from 1 to N have no common divisor with M; in other words, how many of them are coprime with M. We don't consider 1 to be a common divisor, because otherwise the task would make no sense (all numbers are divisible by 1).

For example, if N=20 and M=6, our program should print 7, because there are 7 numbers less than or equal to 20 that don't have a common divisor with 6: 1, 5, 7, 11, 13, 17, 19. Note that 8 is not in this list because 2 is a common divisor of 8 and M=6; similarly, 9 is not in this list because 3 is a common divisor of 9 and M=6, etc.

This task can be solved by applying the inclusion-exclusion principle. The idea comes from the example above, where we eliminated numbers such as 8 and 9 because they share a prime divisor (2 or 3) with M. If we eliminate every number that has a common divisor with M, then the final result is N minus the count of eliminated numbers. More specifically, if N=20 and there are 13 numbers that have a common divisor with M=6, then there are 20-13=7 numbers that do not. (If this is not clear to you, please look at the first few paragraphs of this post.)

Furthermore, to eliminate a number (that is, to mark it as having a common divisor with M), it is enough to find a single number greater than 1 that divides both that number and M. It is best to look only at the prime divisors of M: any common divisor greater than 1 has at least one prime factor (every number greater than 1 is a product of primes), and that prime also divides both numbers. So a number has a common divisor with M exactly when it is divisible by at least one prime divisor of M.

Now we can use the same algorithm as before. Let's examine the numbers N=20 and M=6 again. The prime divisors of M=6 are 2 and 3, so we need to determine how many numbers less than or equal to N are divisible by 2, how many are divisible by 3, and how many are divisible by both (that is, by 2*3=6; distinct primes are always coprime, so dividing by their product is valid). In this case, 20/2=10 numbers are divisible by 2, 20/3=6 numbers are divisible by 3, and 20/6=3 numbers are divisible by both 2 and 3. According to the rule discussed above (subsets with an odd number of elements get a plus sign, and those with an even number of elements get a minus sign), we include (with a + sign) 10 and 6, and we exclude (with a - sign) 3. The value 3 gets the minus sign because it comes from 20/6, the number of elements in the intersection of two sets ("divisible by 2" and "divisible by 3"), and two is an even number. The final calculation is 10+6-3=13, so 13 numbers have a common divisor with M, and there are 20-13=7 numbers that do not.

The implementation of this idea for any positive integers N and M is given below.

#include <iostream>
#include <vector>
#include <string>
using namespace std;

vector<int> find_prime_divisors(int N) 
{
     
     vector<int> divisors;
     
     //keep trying divisors while i*i <= N: once N has no divisor up to its
     //square root, what remains of N is either 1 or a prime
     //(i <= N / i is the same test as i*i <= N, but it cannot overflow)
     for (int i=2; i <= N / i; i++) 
     {
          if ((N%i) == 0) 
          {
               divisors.push_back(i);
               
               //we found a divisor, no need to add it multiple times
               while ((N%i) == 0)
               {
                    N /= i;
               }
          }
     }
     
     //maybe there is a prime number in N now, or N was a prime at the start?
     if (N >= 2) 
     {
          divisors.push_back(N);
     }
     
     return divisors;
}



//the number of elements in a subset
//is equal to the number of 1 bits
int elements_in_subset(int subset) 
{
     
     int one_bits = 0;
     
     //analyze the number bit by bit...
     //division by 2 deletes the last bit
     while (subset > 0) 
     {
          if ((subset % 2) == 1) 
          {
               one_bits++;
          }
          
          subset /= 2;
     }
     
     return one_bits;
}


int calculate(int N, int M) 
{
     
     vector<int> divisors = find_prime_divisors(M);
     int divisors_count = divisors.size();
     
     int bad_numbers_count = 0;
     
     //we analyze all subsets - there are a total of 2^(divisors_count)
     //we start from i=1, because we don't care about the empty subset [].
     for (int i=1; i < (1 << divisors_count); i++) 
     {
          
          int product = 1; //product of the elements in the subset
          //so that we know how much to add or subtract.
          //for example, for "divisible by 2" and "divisible by 3", we have product=2*3
          
          for (int j=0; j < divisors_count; j++) 
          {
               if ((i & (1 << j)) != 0) 
               {
                    
                    //in the product we have those values that are
                    //part of the subset defined by "i"
                    product *= divisors[j];
               }
          }
          
          //how many values are in the intersection? for example,
          //if N=20, then there are 20/6=3 numbers divisible by both 2 and 3.
          int number_of_values = (N / product);
          
          if ((elements_in_subset(i) % 2) == 1) 
          {
               //we add if there is an odd number of elements
               bad_numbers_count += number_of_values;
          }
          else 
          {
               //we subtract if there is an even number of elements
               bad_numbers_count -= number_of_values;
          }
     }
     
     return (N - bad_numbers_count);
}

int main() 
{
     int N, M;
     cin >> N >> M;
     
     cout << calculate(N, M) << endl;
     return 0;
}

Although the problem solved by the previous program is quite complex, it is a great example to consider, because many problems that can be solved with the inclusion-exclusion principle are based on a similar idea. We also saw how to create subsets using bit manipulation: we can easily create all subsets, and we can determine the number of elements in a subset by counting the bits equal to 1. If we have a problem where we can figure out how to calculate the number of elements in the intersections of the sets defined by it, then we can simply add (include) or subtract (exclude) those values to calculate the number of elements in their union.