카라추바 알고리즘
10만자리 곱하기 10만자리같은 곱셈은 애시당초 자료형으로 표현할 수도 없고, 문자열로 바꾸어 하나하나 곱하면 $O(N^2)$의 복잡도를 갖게 된다.
이때 등장하는 것이 카라추바 알고리즘. 시간복잡도를 $O(N^{log3})$으로 줄여준다.
그 방법을 한번 살펴보자.
일단 256자리의 두 정수 a, b를 곱한다고 생각해볼때 a와 b를 다음과 같이 나눈다.
$a = a_1\times10^{128}+a_0$
$b = b_1\times10^{128}+b_0$
a1,b1은 각각 a,b의 첫 128자리, a0,b0는 각각 a,b의 뒷 128자리를 나타낸다.
그럼 이제 a * b의 계산 과정은 다음과 같이 나타낼 수 있다.
$a\times b = (a_1 \times b_1) \times 10^{256} + (a_1b_0+a_0b_1)\times10^{128} +a_0\times b_0$
여기서 $z2 = a1 \times b1$, $z0 = a0 \times b0$, $z1 = _1b_0+a_0b_1$ 이라 할때
이 상태에서는 n/2 자리수의 두 정수 곱 (128자리)가 총 4번 이루어진다. (z2 1번, z0 1번, z1 2번)
이 곱셈횟수를 줄이기 위해 수식을 아래 수식을 사용하여 바꾼다면
$(a_0+a_1)\times(b_0+b_1) = a_0\times b_0 + (a_1b_0 + a_0b_1) + a_1\times b_1$
z1를 다음과 같이 변경할 수 있고, $z1 = (a_0+a_1)\times(b_0+b_1) - z_0 - z_1$
이러면 $a\times b$는 결국 n/2 자리수의 곱셈 3번, 덧셈 2번, 뺄셈 2번으로 수행할 수 있다.
이제 이 n/2 자리수에도 재귀적으로 알고리즘을 적용하면 결국 $O(N^{log3})$의 시간복잡도를 얻을 수 있다.
코드
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Karatsuba {
public static StringBuilder sb = new StringBuilder();
public static java.io.BufferedReader br = new java.io.BufferedReader(new java.io.InputStreamReader(System.in));
static ArrayList<Integer> aaa = new ArrayList<Integer>();
static ArrayList<Integer> bbb= new ArrayList<Integer>();
static ArrayList<Integer> ans= new ArrayList<Integer>();
public static void main(String[] args) throws IOException {
String s1 = br.readLine().trim();
String s2 = br.readLine().trim();
for(int i = 0; i<s1.length(); i++) {
aaa.add(s1.charAt(i)-'0');
}
Collections.reverse(aaa);
for(int i = 0; i<s2.length(); i++) {
bbb.add(Integer.parseInt(s2.substring(i, i+1)));
}
Collections.reverse(bbb);
ans = karatsuba(aaa, bbb);
//앞자리수 0 없애주기
while (ans.size() > 1 && ans.get(ans.size() - 1) == 0)
ans.remove(ans.size() - 1);
String answer = "";
for(int i = ans.size()-1; i>=0; i--) {
answer = answer.concat(ans.get(i).toString());
}
System.out.println(answer);
}
static ArrayList<Integer> karatsuba(ArrayList<Integer> a, ArrayList<Integer> b) {
int a_size = a.size();
int b_size = b.size();
//A>B여야 한다
if (a_size < b_size)
return karatsuba(b, a);
//만약에 비어있는 애 오면 끝
if (a_size == 0 || b_size == 0)
return null;
//50이하에서는 보통 기본 곱셈이 빠르다
if (a_size <= 50)
return multiply(a, b);
int half = a_size / 2;
//a>b일때 b0 = null이 될 때도 있다. karatsuba에서나, sub, sum할때 예외처리를 해주자.
ArrayList<Integer> a0 = new ArrayList<Integer>(a.subList(0, half));
ArrayList<Integer> a1 = new ArrayList<Integer>(a.subList(half, a.size()));
ArrayList<Integer> b0 = new ArrayList<Integer>(b.subList(0, Math.min(b.size(), half)));
ArrayList<Integer> b1 = new ArrayList<Integer>(b.subList(Math.min(b.size(), half), b.size()));
// z2 = a1 * b1
ArrayList<Integer> z2 = karatsuba(a1, b1);
// z0 = a0 * b0
ArrayList<Integer> z0 = karatsuba(a0, b0);
// a0 = a0 + a1; b0 = b0 + b1
a0 = karatsubasum(a0, a1, 0);
b0 = karatsubasum(b0, b1, 0);
// z1 = (a0 * b0) - z0 - z2;
ArrayList<Integer> z1 = karatsuba(a0, b0);
z1 = karatsubasub(z1, z0);
z1 = karatsubasub(z1, z2);
// ret = z0 + z1 * 10^half + z2 * 10^(half*2)
ArrayList<Integer> ret = new ArrayList<Integer>();
ret = karatsubasum(ret, z0, 0);
ret = karatsubasum(ret, z1, half);
ret = karatsubasum(ret, z2, half * 2);
return ret;
}
public static ArrayList<Integer> ensureSize(ArrayList<Integer> list, int size) {
list.ensureCapacity(size);
while (list.size() < size) {
list.add(0);
}
return list;
}
static ArrayList<Integer> multiply(List<Integer> a, List<Integer> b){
ArrayList<Integer> c = new ArrayList<Integer>();
c = ensureSize(c, a.size()+b.size()+1);
for(int i =0; i<a.size(); i++) {
for(int j =0; j<b.size(); j++) {
c.set(i+j, c.get(i+j) + a.get(i)*b.get(j));
}
}
c = normalize(c);
return c;
}
//a = a + b*(10^k);
public static ArrayList<Integer> karatsubasum(ArrayList<Integer> a, ArrayList<Integer> b, int k){
if(b == null) {
return a;
}
a = ensureSize(a, Math.max(a.size(), b.size() + k));
for (int i = 0; i < b.size(); i++) {
a.set(i + k, a.get(i + k) + b.get(i));
}
a = normalize(a);
return a;
}
//a= a-b ; a>=b 일때
public static ArrayList<Integer> karatsubasub(ArrayList<Integer> a, ArrayList<Integer> b){
if(b == null) {
return a;
}
a = ensureSize(a, Math.max(a.size(), b.size()) + 1);
for (int i = 0; i < b.size(); i++) {
a.set(i, a.get(i) - b.get(i));
}
a = normalize(a);
return a;
}
public static ArrayList<Integer> normalize(ArrayList<Integer> num) {
num.add(0);
for (int i = 0; i < num.size() - 1; i++) {
if (num.get(i) < 0) {
int borrow = (Math.abs(num.get(i)) + 9) / 10;
num.set(i + 1, num.get(i + 1) - borrow);
num.set(i, num.get(i) + borrow * 10);
} else {
num.set(i + 1, num.get(i + 1) + num.get(i) / 10);
num.set(i, num.get(i) % 10);
}
}
if (num.get(num.size() - 1) == 0)
num.remove(num.size() - 1);
return num;
}
}
'알고리즘 > 스터디' 카테고리의 다른 글
12. 그래프 (2) (최단경로, 프림, 크루스칼) (0) | 2020.09.14 |
---|---|
11. 그래프 (1) (인접행렬, 인접리스트, DFS, BFS) (0) | 2020.09.10 |
10. 순열과 조합 [JAVA] (0) | 2020.09.02 |
9. 최대, 최소 찾기 (순차, 토너먼트, 선택 알고리즘) (0) | 2020.09.01 |
8. 탐색 - 선형 탐색, 이진 탐색, 이진 탐색 트리, 해시 탐색 (0) | 2020.08.25 |
Comment