Count of Smaller Numbers after Self | Number of Swaps to Sort | Algorithm Swap
You are given an integer array nums and you have to return a new counts array. The counts array has the property where counts[i] is the number of smaller elements to the right of nums[i].
Examples:
Example 1:
Input: [5,2,6,1]
Output: [2,1,1,0]
Explanation:
For the number 5, there are 2 numbers smaller than it after. (2 and 1)
For the number 2, there is 1 number smaller than it after. (1)
For the number 6, there is also 1 number smaller than it after. (1)
For the number 1, there are no numbers smaller than it after.
Hence, we have [2, 1, 1, 0].
Number of swaps to sort
Another version of the question is:
If we sort the nums array by finding the smallest pair i, j where i < j and nums[i] > nums[j] and swapping them, how many swaps are needed?
To answer that question, we just have to sum up the numbers in the above output array: 2 + 1 + 1 + 0 = 4 swaps.
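The sum can be checked on the example. The sketch below assumes the sort is done with adjacent swaps, as in bubble sort, which is only one reading of the description above; the function name adjacent_swaps_to_sort is made up for illustration.

```python
from typing import List

def adjacent_swaps_to_sort(nums: List[int]) -> int:
    # Bubble sort on a copy, counting every swap of two out-of-order neighbours
    arr = list(nums)
    swaps = 0
    for end in range(len(arr) - 1, 0, -1):
        for i in range(end):
            if arr[i] > arr[i + 1]:
                arr[i], arr[i + 1] = arr[i + 1], arr[i]
                swaps += 1
    return swaps

counts = [2, 1, 1, 0]  # output array from Example 1
print(sum(counts))                           # 4
print(adjacent_swaps_to_sort([5, 2, 6, 1]))  # 4
```

Each adjacent swap removes exactly one out-of-order pair, which is why the two numbers agree under this assumption.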
Explanation
Intuition
The brute force way to solve this question is easy and intuitive: we simply go through the list of elements and, for each element, go through the elements after it and count how many are smaller than it. This results in an O(N^2) runtime, so this approach is not the optimal solution.
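A minimal sketch of that brute force approach, with a made-up function name, could look like this:

```python
from typing import List

def count_smaller_brute_force(nums: List[int]) -> List[int]:
    counts = [0] * len(nums)
    for i in range(len(nums)):
        # Look only at the elements to the right of position i
        for j in range(i + 1, len(nums)):
            if nums[j] < nums[i]:
                counts[i] += 1
    return counts

print(count_smaller_brute_force([5, 2, 6, 1]))  # [2, 1, 1, 0]
```

The two nested loops over the array are where the O(N^2) runtime comes from.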
Observe that to reduce our solution's complexity, we need to compute the smaller count for multiple numbers in one go. This can only be done using some kind of sorted order.
But sorting destroys the original order of the array. What can we do about that?
Recall from the introduction to divide and conquer questions that the common approach is to divide the given data into two components, assume each component is solved, and then try to merge the results.
What if we divide the numbers into two components by index and then sort them separately?
Since we divided the original array by index, after the two components are both sorted, every element in the left component still has a smaller original index than any element in the right component.
We can utilize this fact when we combine the two arrays together.
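As a small illustration of this fact, the sketch below counts the pairs made of one element from each component in which the left element is larger. The helper name cross_pairs is hypothetical, and it uses a quadratic check purely to show that sorting each component separately leaves this number unchanged.

```python
from typing import List

def cross_pairs(left: List[int], right: List[int]) -> int:
    # Number of pairs (a, b) with a from left, b from right and a > b
    return sum(1 for a in left for b in right if a > b)

nums = [5, 2, 6, 1]
mid = len(nums) // 2
left, right = nums[:mid], nums[mid:]

before = cross_pairs(left, right)
after = cross_pairs(sorted(left), sorted(right))
print(before, after)  # 2 2
```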
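Thus, to solve this problem, we first split the given data into two components, the left and the right. We then assume that the sub-problem for each component is already solved; that is, each component is sorted and, for each number, we know the count of smaller numbers to its right within its own component. Now all we need to know is, for each number in the left component, how many elements in the right component are smaller than it.
Because both components are sorted, we can find this while merging them: when we take a number from the left component, every right-component element that has already been merged is smaller than it, so we add that count to the number's total. Elements of the right component need nothing extra, since no left-component element lies to their right.
Once the merge finishes, every count is complete for the combined range, and the problem is solved.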
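One way to obtain those cross-component counts is during the merge itself. The sketch below runs a single merge on the two halves of [5, 2, 6, 1], assuming each half is already sorted by value and already carries its own counts; the name merge_and_count and the (original index, value) pair layout are illustrative choices.

```python
from typing import List, Tuple

Pair = Tuple[int, int]  # (original index, value)

def merge_and_count(left: List[Pair], right: List[Pair], counts: List[int]) -> List[Pair]:
    merged = []
    l = r = 0
    while l < len(left) or r < len(right):
        take_left = r >= len(right) or (l < len(left) and left[l][1] <= right[r][1])
        if take_left:
            # The r right-half elements already merged are all smaller than left[l]
            counts[left[l][0]] += r
            merged.append(left[l])
            l += 1
        else:
            merged.append(right[r])
            r += 1
    return merged

# Halves of [5, 2, 6, 1]: [5, 2] and [6, 1], each sorted by value
counts = [1, 0, 1, 0]          # counts within each half only
left = [(1, 2), (0, 5)]
right = [(3, 1), (2, 6)]
merged = merge_and_count(left, right, counts)
print(counts)  # [2, 1, 1, 0]
print(merged)  # [(3, 1), (1, 2), (0, 5), (2, 6)]
```

Taking the left element on ties (the `<=` comparison) keeps equal values from being counted as smaller.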
So, what is the run time of our improved solution? Each recursive call splits the problem into two components of half the size, solves each of them, and then merges them in O(N) time. Thus we have
T(N) = 2T(N/2) + O(N)
This recurrence will yield a total run time of O(N log N).
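For readers who want the intermediate step, the recurrence can be unrolled as follows, assuming for simplicity that N is a power of two and that the merge costs at most cN for some constant c:

```latex
\begin{aligned}
T(N) &= 2\,T(N/2) + cN \\
     &= 4\,T(N/4) + 2cN \\
     &= \dots \\
     &= 2^{k}\,T(N/2^{k}) + k\,cN \\
     &= N\,T(1) + cN\log_2 N \qquad (k = \log_2 N) \\
     &= O(N \log N)
\end{aligned}
```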
Space Complexity: O(N), for the index-value pairs and the temporary lists created while merging.
Implementation
from typing import List

def count_smaller(nums: List[int]) -> List[int]:
    # smaller_arr[i] accumulates the number of smaller elements to the right of nums[i]
    smaller_arr = [0] * len(nums)

    def merge_sort(nums):
        if len(nums) <= 1:
            return nums
        mid = len(nums) // 2
        left = merge_sort(nums[:mid])
        right = merge_sort(nums[mid:])
        return merge(left, right)

    def merge(left, right):
        result = []
        l, r = 0, 0
        while l < len(left) or r < len(right):
            if r >= len(right) or (l < len(left) and left[l][1] <= right[r][1]):
                result.append(left[l])
                # the r right-half elements merged so far are all smaller than left[l]
                smaller_arr[left[l][0]] += r
                l += 1
            else:
                result.append(right[r])
                r += 1
        return result

    # sort (index, value) pairs so each value remembers its original index
    merge_sort(list(enumerate(nums)))
    return smaller_arr

if __name__ == "__main__":
    nums = [int(x) for x in input().split()]
    res = count_smaller(nums)
    print(" ".join(map(str, res)))

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.stream.Collectors;

class Solution {
    // A value together with its index in the original array
    public static class Element {
        int val;
        int ind;

        public Element(int val, int ind) {
            this.val = val;
            this.ind = ind;
        }
    }

    public static List<Integer> smallerArr = new ArrayList<Integer>();

    public static List<Element> mergeSort(List<Element> nums) {
        if (nums.size() <= 1) {
            return nums;
        }
        int mid = nums.size() / 2;
        List<Element> splitLeft = new ArrayList<Element>();
        List<Element> splitRight = new ArrayList<Element>();
        for (int i = 0; i < nums.size(); i++) {
            if (i < mid) {
                splitLeft.add(nums.get(i));
            } else {
                splitRight.add(nums.get(i));
            }
        }
        List<Element> left = mergeSort(splitLeft);
        List<Element> right = mergeSort(splitRight);
        return merge(left, right);
    }

    public static List<Element> merge(List<Element> left, List<Element> right) {
        List<Element> result = new ArrayList<Element>();
        int l = 0;
        int r = 0;
        while (l < left.size() || r < right.size()) {
            if (r >= right.size() || (l < left.size() && left.get(l).val <= right.get(r).val)) {
                result.add(left.get(l));
                // the r right-half elements merged so far are all smaller than left.get(l)
                smallerArr.set(left.get(l).ind, smallerArr.get(left.get(l).ind) + r);
                l += 1;
            } else {
                result.add(right.get(r));
                r += 1;
            }
        }
        return result;
    }

    public static List<Integer> countSmaller(List<Integer> nums) {
        // start from a fresh list so repeated calls do not accumulate old counts
        smallerArr = new ArrayList<Integer>();
        for (int i = 0; i < nums.size(); i++) {
            smallerArr.add(0);
        }
        List<Element> temp = new ArrayList<Element>();
        for (int i = 0; i < nums.size(); i++) {
            temp.add(new Element(nums.get(i), i));
        }
        mergeSort(temp);
        return smallerArr;
    }

    public static List<String> splitWords(String s) {
        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        List<Integer> nums = splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList());
        scanner.close();
        List<Integer> res = countSmaller(nums);
        System.out.println(res.stream().map(String::valueOf).collect(Collectors.joining(" ")));
    }
}

"use strict";

function countSmaller(nums) {
    const smallerArr = Array(nums.length).fill(0);

    function merge(left, right) {
        const result = [];
        let l = 0;
        let r = 0;
        while (l < left.length || r < right.length) {
            if (r >= right.length || (l < left.length && left[l][1] <= right[r][1])) {
                result.push(left[l]);
                // the r right-half elements merged so far are all smaller than left[l]
                smallerArr[left[l][0]] += r;
                l += 1;
            } else {
                result.push(right[r]);
                r += 1;
            }
        }
        return result;
    }

    function mergeSort(nums) {
        if (nums.length <= 1) return nums;
        const mid = Math.floor(nums.length / 2);
        const left = mergeSort(nums.slice(0, mid));
        const right = mergeSort(nums.slice(mid));
        return merge(left, right);
    }

    // pair each value with its original index: [index, value]
    const temp = nums.map((e, i) => [i, e]);
    mergeSort(temp);
    return smallerArr;
}

function splitWords(s) {
    return s === "" ? [] : s.split(" ");
}

function* main() {
    const nums = splitWords(yield).map((v) => parseInt(v, 10));
    const res = countSmaller(nums);
    console.log(res.join(" "));
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line) => gen.next(line).done && process.exit();
    let buf = "";
    next();
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop();
        lines.forEach(next);
    });
    process.stdin.on("end", () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}

#include <algorithm>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>

std::vector<std::vector<int>> merge(std::vector<std::vector<int>>& left, std::vector<std::vector<int>>& right, std::vector<int>& counts) {
    std::vector<std::vector<int>> res;
    int l = 0;
    int r = 0;
    while (l < left.size() || r < right.size()) {
        if (r >= right.size() || (l < left.size() && left[l][1] <= right[r][1])) {
            res.emplace_back(left[l]);
            counts[left[l][0]] = counts[left[l][0]] + r;
            l++;
        } else {
            res.emplace_back(right[r]);
            r++;
        }
    }
    return res;
}

std::vector<std::vector<int>> merge_sort(std::vector<std::vector<int>>& nums, std::vector<int>& counts) {
    if (nums.size() <= 1) return nums;
    int mid = nums.size() / 2;
    std::vector<std::vector<int>> split_left(nums.begin(), nums.begin() + mid);
    std::vector<std::vector<int>> split_right(nums.begin() + mid, nums.end());
    std::vector<std::vector<int>> left = merge_sort(split_left, counts);
    std::vector<std::vector<int>> right = merge_sort(split_right, counts);
    return merge(left, right, counts);
}

std::vector<int> count_smaller(std::vector<int>& nums) {
    std::vector<int> counts(nums.size(), 0);
    std::vector<std::vector<int>> idx_num_mapping;
    for (int i = 0; i < nums.size(); i++) {
        std::vector<int> idx_num_pair{i, nums[i]};
        idx_num_mapping.emplace_back(idx_num_pair);
    }
    merge_sort(idx_num_mapping, counts);
    return counts;
}

template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    ss >> std::boolalpha;
    std::vector<T> v;
    std::copy(std::istream_iterator<T>{ss}, std::istream_iterator<T>{}, std::back_inserter(v));
    return v;
}

template<typename T>
void put_words(const std::vector<T>& v) {
    if (!v.empty()) {
        std::copy(v.begin(), std::prev(v.end()), std::ostream_iterator<T>{std::cout, " "});
        std::cout << v.back();
    }
    std::cout << '\n';
}

int main() {
    std::vector<int> nums = get_words<int>();
    std::vector<int> res = count_smaller(nums);
    put_words(res);
}

If the problem asks for the number of swaps, we can simply keep a single running counter, adding r to it each time we take an element from the left component, and we don't have to keep the counts array.
from typing import List

def number_of_swaps_to_sort(nums: List[int]) -> int:
    count = 0

    def merge(left, right):
        nonlocal count
        result = []
        l, r = 0, 0
        while l < len(left) or r < len(right):
            if r >= len(right) or (l < len(left) and left[l][1] <= right[r][1]):
                result.append(left[l])
                count += r
                l += 1
            else:
                result.append(right[r])
                r += 1
        return result

    def merge_sort(nums):
        if len(nums) <= 1:
            return nums
        mid = len(nums) // 2
        left = merge_sort(nums[:mid])
        right = merge_sort(nums[mid:])
        return merge(left, right)

    merge_sort(list(enumerate(nums)))
    return count

if __name__ == "__main__":
    nums = [int(x) for x in input().split()]
    res = number_of_swaps_to_sort(nums)
    print(res)

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
import java.util.stream.Collectors;

class Solution {
    // A value together with its index in the original array
    public static class Number {
        int index;
        int val;
        public Number(int i, int v) {
            index = i;
            val = v;
        }
    }

    private static int count;

    private static List<Number> merge(List<Number> left, List<Number> right) {
        List<Number> result = new ArrayList<>();
        int l = 0;
        int r = 0;
        while (l < left.size() || r < right.size()) {
            if (r >= right.size() || (l < left.size() && left.get(l).val <= right.get(r).val)) {
                result.add(left.get(l));
                // the r right-half elements merged so far are all smaller than left.get(l)
                count += r;
                l++;
            } else {
                result.add(right.get(r));
                r++;
            }
        }
        return result;
    }

    private static List<Number> mergeSort(List<Number> nums) {
        if (nums.size() <= 1) {
            return nums;
        }
        int mid = nums.size() / 2;
        List<Number> left = mergeSort(nums.subList(0, mid));
        List<Number> right = mergeSort(nums.subList(mid, nums.size()));
        return merge(left, right);
    }

    public static int numberOfSwapsToSort(List<Integer> nums) {
        // reset the counter so repeated calls do not accumulate old results
        count = 0;
        List<Number> numbers = new ArrayList<>();
        for (int i = 0; i < nums.size(); i++) {
            numbers.add(new Number(i, nums.get(i)));
        }

        mergeSort(numbers);
        return count;
    }

    public static List<String> splitWords(String s) {
        return s.isEmpty() ? List.of() : Arrays.asList(s.split(" "));
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        List<Integer> nums = splitWords(scanner.nextLine()).stream().map(Integer::parseInt).collect(Collectors.toList());
        scanner.close();
        int res = numberOfSwapsToSort(nums);
        System.out.println(res);
    }
}

"use strict";

function numberOfSwapsToSort(nums) {
    let count = 0;

    function merge(left, right) {
        const result = [];
        let l = 0;
        let r = 0;
        while (l < left.length || r < right.length) {
            if (r >= right.length || (l < left.length && left[l][1] <= right[r][1])) {
                result.push(left[l]);
                // the r right-half elements merged so far are all smaller than left[l]
                count += r;
                l += 1;
            } else {
                result.push(right[r]);
                r += 1;
            }
        }
        return result;
    }

    function mergeSort(nums) {
        if (nums.length <= 1) return nums;
        const mid = Math.floor(nums.length / 2);
        const left = mergeSort(nums.slice(0, mid));
        const right = mergeSort(nums.slice(mid));
        return merge(left, right);
    }

    // pair each value with its original index: [index, value]
    const temp = nums.map((e, i) => [i, e]);
    mergeSort(temp);
    return count;
}

function splitWords(s) {
    return s === "" ? [] : s.split(" ");
}

function* main() {
    const nums = splitWords(yield).map((v) => parseInt(v, 10));
    const res = numberOfSwapsToSort(nums);
    console.log(res);
}

class EOFError extends Error {}
{
    const gen = main();
    const next = (line) => gen.next(line).done && process.exit();
    let buf = "";
    next();
    process.stdin.setEncoding("utf8");
    process.stdin.on("data", (data) => {
        const lines = (buf + data).split("\n");
        buf = lines.pop();
        lines.forEach(next);
    });
    process.stdin.on("end", () => {
        buf && next(buf);
        gen.throw(new EOFError());
    });
}

#include <algorithm>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

std::vector<std::vector<int>> merge(std::vector<std::vector<int>>& left, std::vector<std::vector<int>>& right, int& count) {
    std::vector<std::vector<int>> res;
    int l = 0;
    int r = 0;
    while (l < left.size() || r < right.size()) {
        if (r >= right.size() || (l < left.size() && left[l][1] <= right[r][1])) {
            res.emplace_back(left[l]);
            count += r;
            l++;
        } else {
            res.emplace_back(right[r]);
            r++;
        }
    }
    return res;
}

std::vector<std::vector<int>> merge_sort(std::vector<std::vector<int>>& nums, int& count) {
    if (nums.size() <= 1) return nums;
    int mid = nums.size() / 2;
    std::vector<std::vector<int>> split_left(nums.begin(), nums.begin() + mid);
    std::vector<std::vector<int>> split_right(nums.begin() + mid, nums.end());
    std::vector<std::vector<int>> left = merge_sort(split_left, count);
    std::vector<std::vector<int>> right = merge_sort(split_right, count);
    return merge(left, right, count);
}

int number_of_swaps_to_sort(std::vector<int>& nums) {
    int count = 0;
    std::vector<std::vector<int>> idx_num_mapping;
    for (int i = 0; i < nums.size(); i++) {
        std::vector<int> idx_num_pair{i, nums[i]};
        idx_num_mapping.emplace_back(std::move(idx_num_pair));
    }
    merge_sort(idx_num_mapping, count);
    return count;
}

template<typename T>
std::vector<T> get_words() {
    std::string line;
    std::getline(std::cin, line);
    std::istringstream ss{line};
    ss >> std::boolalpha;
    std::vector<T> v;
    std::copy(std::istream_iterator<T>{ss}, std::istream_iterator<T>{}, std::back_inserter(v));
    return v;
}

int main() {
    std::vector<int> nums = get_words<int>();
    int res = number_of_swaps_to_sort(nums);
    std::cout << res << '\n';
}