There are N males and N females (N <= 10^5). Each male has an index that is unique among the males, and each female has an index that is unique among the females. All males stand in one line, at equal distances, in increasing order of index. Similarly, the females stand in another line.
Now the males and females are paired up, and a line is drawn between the two members of each pair. We have to find how many pairs of lines intersect each other.
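As a baseline for checking faster solutions, here is a direct O(N^2) count. It assumes, beyond what the statement says, that the input is a list of (male_index, female_index) tuples and that the two lines are parallel with indexes increasing in the same direction. Under that assumption, two lines cross exactly when the male order and the female order of their endpoints disagree. The function name is hypothetical.

```python
from itertools import combinations


def count_crossings_bruteforce(pairs):
    """Hypothetical helper: count intersecting pairs of lines by checking every pair.

    `pairs` is an assumed input format: a list of (male_index, female_index)
    tuples where every male index and every female index appears once.
    Two lines cross exactly when the male order and female order disagree.
    """
    count = 0
    for (m1, f1), (m2, f2) in combinations(pairs, 2):
        if (m1 < m2) != (f1 < f2):
            count += 1
    return count


print(count_crossings_bruteforce([(10, 3), (20, 1), (30, 2)]))  # 2
```

With N = 10^5 this checks about 5·10^9 pairs, so it is only useful for testing on small inputs.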
We map the male indexes to values between 1 and N (the i-th smallest index becomes i), and do the same for the female indexes. Let's define an array A: A[i] = j if the male with value i is paired with the female with value j. Two lines (i, A[i]) and (j, A[j]) with i < j intersect exactly when A[i] > A[j], so the answer is the number of pairs i, j such that i < j and A[i] > A[j]. This is the same as counting inversions in an array. We can use a Binary Indexed Tree (BIT) or Enhanced Merge Sort to count such pairs.
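A minimal sketch of building A, assuming the same hypothetical (male_index, female_index) input format. Sorting the pairs by male index lists the female indexes in male order; the result is 0-based, unlike the 1-based A in the text. It keeps the raw female indexes rather than mapping them to 1..N, which does not change the answer because only their relative order matters. The naive counter simply checks the inversion definition.

```python
def build_A(pairs):
    """Hypothetical helper: element i is the female index paired with the
    (i+1)-th smallest male index (0-based list)."""
    return [f for _, f in sorted(pairs)]


def count_inversions_naive(A):
    """O(N^2): count pairs i < j with A[i] > A[j]."""
    n = len(A)
    return sum(1 for i in range(n) for j in range(i + 1, n) if A[i] > A[j])


pairs = [(30, 2), (10, 3), (20, 1)]
A = build_A(pairs)
print(A)                          # [3, 1, 2]
print(count_inversions_naive(A))  # 2
```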
See the explanation below for the implementation and further details.
We need to count how many pairs i,j exist such that A[i] > A[j] and i < j.
Suppose we fix i and look at A[i]. We need to count how many j > i satisfy A[j] < A[i]. So if we have a data structure into which we have inserted exactly the values A[j] with j > i, we just need to count how many of its values are less than A[i].
That's where a BIT comes in. Read the TopCoder tutorial on BIT first. To insert A[j], we add 1 at position A[j] of the BIT. To count the inserted values less than A[i], we take the sum of the BIT over positions 1 to A[i]-1.
A BIT needs one array slot per possible value, which is practical only up to roughly 10^6 positions, while the original indexes may be larger. So we map all A[i] values to numbers between 1 and N: we sort the values and replace each one by its rank (1 to N). This preserves their relative order, so the number of inversions does not change.
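A small BIT sketch with the two operations used in this editorial: update(x, delta) adds delta at position x, and read(x) returns the sum over positions 1 to x-1, that is, the number of inserted values strictly less than x. The class name FenwickTree is a hypothetical choice for illustration.

```python
class FenwickTree:
    """Hypothetical BIT over positions 1..n, storing counts."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # index 0 is unused

    def update(self, x, delta):
        """Add `delta` at position x (1 <= x <= n)."""
        while x <= self.n:
            self.tree[x] += delta
            x += x & -x

    def read(self, x):
        """Return the sum of values at positions 1..x-1 (strictly less than x)."""
        x -= 1
        total = 0
        while x > 0:
            total += self.tree[x]
            x -= x & -x
        return total


bit = FenwickTree(5)
for v in (4, 2, 5):
    bit.update(v, 1)
print(bit.read(5))  # 2  (the inserted values 4 and 2 are less than 5)
print(bit.read(1))  # 0
```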
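The mapping step can be sketched as coordinate compression, assuming the values are distinct (true here, since female indexes are unique). compress is a hypothetical helper name.

```python
def compress(values):
    """Hypothetical helper: replace each distinct value by its rank 1..N,
    preserving relative order (assumes no duplicates)."""
    rank = {v: r for r, v in enumerate(sorted(values), start=1)}
    return [rank[v] for v in values]


print(compress([1000000000, 7, 42]))  # [3, 1, 2]
```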
    // Assume all numbers have been mapped to values between 1 and N.
    ans = 0
    for i = 1 to N:
        ans += foo(i)
    print ans

    // function foo
    foo(i):
        clear the BIT               // start from an empty BIT for each i
        for j = i+1 to N:
            update(A[j], 1)         // update(x, 1) increases the value at index x in the BIT by 1
        return read(A[i])           // read(x) returns the sum of the values at all indexes less than x
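But the above pseudocode is too slow: each call foo(i) performs up to N updates, so the whole loop takes O(N^2 log N) time. How can we do it in one loop? For each i we need the values A[j] with j > i, so we iterate i from N down to 1. When we reach i, the BIT already contains A[i+1], ..., A[N], so we query it and then insert A[i].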
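    // Assume all numbers have been mapped to values between 1 and N.
    ans = 0
    for i = N down to 1:
        ans += read(A[i])           // inserted values (j > i) that are less than A[i]
        update(A[i], 1)             // insert A[i] so it is counted for smaller i
    print ans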
Each update and read takes O(log N) time, and we perform N of each, so the total complexity is O(N log N), the same as the sorting used for the mapping.
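Putting the pieces together, here is a minimal end-to-end sketch of the single-loop BIT method, again assuming the hypothetical list-of-tuples input format. It sorts by male index, compresses the female indexes to 1..N, then scans i from N down to 1 as in the pseudocode.

```python
def count_intersections(pairs):
    """Hypothetical helper: count intersecting pairs of lines in O(N log N).

    `pairs` is an assumed input format: (male_index, female_index) tuples
    with unique indexes on each side.
    """
    # A[i] = female index paired with the i-th male in increasing index order.
    A = [f for _, f in sorted(pairs)]
    n = len(A)

    # Map the female indexes to 1..N so the BIT needs only N + 1 slots.
    rank = {v: r for r, v in enumerate(sorted(A), start=1)}
    A = [rank[v] for v in A]

    tree = [0] * (n + 1)

    def update(x):
        # Add 1 at position x.
        while x <= n:
            tree[x] += 1
            x += x & -x

    def read(x):
        # Sum over positions 1..x-1.
        x -= 1
        s = 0
        while x > 0:
            s += tree[x]
            x -= x & -x
        return s

    ans = 0
    for i in range(n - 1, -1, -1):  # i = N down to 1 in 1-based terms
        ans += read(A[i])           # inserted values (j > i) smaller than A[i]
        update(A[i])
    return ans


print(count_intersections([(30, 2), (10, 3), (20, 1)]))  # 2
```

Python integers do not overflow, but in a fixed-width language ans should be a 64-bit type: with N = 10^5 the answer can reach N(N-1)/2, about 5·10^9.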
For the approach using Enhanced Merge Sort, here is a proper explanation.
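The linked Enhanced Merge Sort explanation is not reproduced here; the following is a standard merge-sort inversion count, offered as a sketch that may differ in details from it. While merging, whenever an element of the right half is taken before the remaining elements of the left half, each of those remaining elements forms an inversion with it. The function name is hypothetical.

```python
def count_inversions_merge(A):
    """Hypothetical helper: return (sorted copy of A, number of pairs i < j
    with A[i] > A[j])."""
    if len(A) <= 1:
        return list(A), 0
    mid = len(A) // 2
    left, inv_left = count_inversions_merge(A[:mid])
    right, inv_right = count_inversions_merge(A[mid:])
    merged = []
    inv = inv_left + inv_right
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            # right[j] is smaller than every remaining element of `left`,
            # and all of those come earlier in A: len(left) - i inversions.
            merged.append(right[j])
            inv += len(left) - i
            j += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, inv


print(count_inversions_merge([3, 1, 2])[1])  # 2
```

Because merge sort only compares values, this method needs no mapping to 1..N, and it also runs in O(N log N) time.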