13. 카라추바 알고리즘 (큰 수의 곱) [JAVA]
반응형

카라추바 알고리즘

10만자리 곱하기 10만자리같은 곱셈은 애시당초 자료형으로 표현할 수도 없고, 문자열로 바꾸어 하나하나 곱하면 $O(N^2)$의 복잡도를 갖게 된다.

이때 등장하는 것이 카라추바 알고리즘. 시간복잡도를 $O(N^{log3})$으로 줄여준다.

그 방법을 한번 살펴보자.

일단 256자리의 두 정수 a, b를 곱한다고 생각해볼때 a와 b를 다음과 같이 나눈다.

$a = a_1\times10^{128}+a_0​$

$b = b_1\times10^{128}+b_0​$

a1,b1은 각각 a,b의 첫 128자리, a0,b0는 각각 a,b의 뒷 128자리를 나타낸다.

그럼 이제 a * b의 계산 과정은 다음과 같이 나타낼 수 있다.

$a\times b = (a_1 \times b_1) \times 10^{256} + (a_1b_0+a_0b_1)\times10^{128} +a_0\times b_0$

여기서 $z2 = a1 \times b1$,   $z0 = a0 \times b0$,   $z1 = _1b_0+a_0b_1$ 이라 할때

이 상태에서는 n/2 자리수의 두 정수 곱 (128자리)가 총 4번 이루어진다. (z2 1번, z0 1번, z1 2번)

이 곱셈횟수를 줄이기 위해 수식을 아래 수식을 사용하여 바꾼다면

$(a_0+a_1)\times(b_0+b_1) = a_0\times b_0 + (a_1b_0 + a_0b_1) + a_1\times b_1​$

z1를 다음과 같이 변경할 수 있고, $z1 = (a_0+a_1)\times(b_0+b_1) - z_0 - z_1$

이러면 $a\times b$는 결국 n/2 자리수의 곱셈 3번, 덧셈 2번, 뺄셈 2번으로 수행할 수 있다.

이제 이 n/2 자리수에도 재귀적으로 알고리즘을 적용하면 결국 $O(N^{log3})$의 시간복잡도를 얻을 수 있다.

 

 

코드

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class Karatsuba {
	public static StringBuilder sb = new StringBuilder();
	public static java.io.BufferedReader br = new java.io.BufferedReader(new java.io.InputStreamReader(System.in));
	static ArrayList<Integer> aaa = new ArrayList<Integer>();
	static ArrayList<Integer> bbb= new ArrayList<Integer>();
	static ArrayList<Integer> ans= new ArrayList<Integer>();
	
	public static void main(String[] args) throws IOException {
		String s1 = br.readLine().trim();
		String s2 = br.readLine().trim();
		
		for(int i = 0; i<s1.length(); i++) {
			aaa.add(s1.charAt(i)-'0');			
		}
		Collections.reverse(aaa);
		
		
		for(int i = 0; i<s2.length(); i++) {
			bbb.add(Integer.parseInt(s2.substring(i, i+1)));		
		}
		Collections.reverse(bbb);
		ans = karatsuba(aaa, bbb);
		
		//앞자리수 0 없애주기
		while (ans.size() > 1 && ans.get(ans.size() - 1) == 0)
			ans.remove(ans.size() - 1);
		String answer = "";	
		
		for(int i = ans.size()-1; i>=0; i--) {
			answer = answer.concat(ans.get(i).toString());
		}
		System.out.println(answer);
		
	}

	static ArrayList<Integer> karatsuba(ArrayList<Integer> a, ArrayList<Integer> b) {
		int a_size = a.size();
		int b_size = b.size();
		//A>B여야 한다
		if (a_size < b_size)
			return karatsuba(b, a);
        //만약에 비어있는 애 오면 끝
		if (a_size == 0 || b_size == 0)
			return null;
        //50이하에서는 보통 기본 곱셈이 빠르다
		if (a_size <= 50)
			return multiply(a, b);

		int half = a_size / 2;
		//a>b일때 b0 = null이 될 때도 있다. karatsuba에서나, sub, sum할때 예외처리를 해주자.
        ArrayList<Integer> a0 = new ArrayList<Integer>(a.subList(0, half));
		ArrayList<Integer> a1 = new ArrayList<Integer>(a.subList(half, a.size()));
		ArrayList<Integer> b0 = new ArrayList<Integer>(b.subList(0, Math.min(b.size(), half)));
		ArrayList<Integer> b1 = new ArrayList<Integer>(b.subList(Math.min(b.size(), half), b.size()));

		// z2 = a1 * b1
		ArrayList<Integer> z2 = karatsuba(a1, b1);
		// z0 = a0 * b0
		ArrayList<Integer> z0 = karatsuba(a0, b0);
		// a0 = a0 + a1; b0 = b0 + b1
		a0 = karatsubasum(a0, a1, 0);
		b0 = karatsubasum(b0, b1, 0);

		// z1 = (a0 * b0) - z0 - z2;
		ArrayList<Integer> z1 = karatsuba(a0, b0);
		z1 = karatsubasub(z1, z0);
		z1 = karatsubasub(z1, z2);

		// ret = z0 + z1 * 10^half + z2 * 10^(half*2)
		ArrayList<Integer> ret = new ArrayList<Integer>();
		ret = karatsubasum(ret, z0, 0);
		ret = karatsubasum(ret, z1, half);
		ret = karatsubasum(ret, z2, half * 2);

		return ret;
	}

	public static ArrayList<Integer> ensureSize(ArrayList<Integer> list, int size) {
		list.ensureCapacity(size);
		while (list.size() < size) {
			list.add(0);
		}

		return list;
	}
	
	static ArrayList<Integer> multiply(List<Integer> a, List<Integer> b){
		ArrayList<Integer> c = new ArrayList<Integer>();
		c = ensureSize(c, a.size()+b.size()+1);
		for(int i =0; i<a.size(); i++) {
			for(int j =0; j<b.size(); j++) {			
				c.set(i+j, c.get(i+j) + a.get(i)*b.get(j));
			}
		}
		c = normalize(c);
		return c;
	}
	
	//a = a + b*(10^k);
	public static ArrayList<Integer> karatsubasum(ArrayList<Integer> a, ArrayList<Integer> b, int k){
		if(b == null) {
			return a;
		}
		a = ensureSize(a, Math.max(a.size(), b.size() + k));
		for (int i = 0; i < b.size(); i++) {
			a.set(i + k, a.get(i + k) + b.get(i));
		}	
		a = normalize(a);
		return a;
	}
	
	//a= a-b ; a>=b 일때
	public static ArrayList<Integer> karatsubasub(ArrayList<Integer> a, ArrayList<Integer> b){
		if(b == null) {
			return a;
		}
		a = ensureSize(a, Math.max(a.size(), b.size()) + 1);
		for (int i = 0; i < b.size(); i++) {
			a.set(i, a.get(i) - b.get(i));
		}
		a = normalize(a);
		return a;
	}

	public static ArrayList<Integer> normalize(ArrayList<Integer> num) {
		num.add(0);
		for (int i = 0; i < num.size() - 1; i++) {
			if (num.get(i) < 0) {
				int borrow = (Math.abs(num.get(i)) + 9) / 10;
				num.set(i + 1, num.get(i + 1) - borrow);
				num.set(i, num.get(i) + borrow * 10);
			} else {
				num.set(i + 1, num.get(i + 1) + num.get(i) / 10);
				num.set(i, num.get(i) % 10);
			}
		}
		if (num.get(num.size() - 1) == 0)
			num.remove(num.size() - 1);
		return num;
	}
}

 

반응형