개발/알고리즘

아호 코라식(Aho-Corasick) 알고리즘

스몰스테핑 2024. 6. 27. 17:11

 

아호 코라식 알고리즘(Aho-Corasick string matching algorithm)은 Alfred V. Aho와 Margaret J. Corasick이 고안한 문자열 매칭 알고리즘이다.

 

패턴 1개를 탐색하는 매칭 알고리즘은 선형 시간에 구현됨을 KMP 등 여러 알고리즘을 통해 증명되었으나, 패턴 집합에 대해 해당 알고리즘들을 사용해보면 패턴 개수에 비례해 그 속도가 느려진다는 점이 발생했다.

시간복잡도는 O(m + zn)이 되는 것이다.

  • m: 모든 패턴들의 길이 합
  • z: 패턴 수
  • n: text 크기

이를 보완한 것이 아호 코라식 알고리즘으로 시간복잡도는 O(m + n + k)이다. 패턴 집합에 대하여 패턴 길이와 텍스트의 선형 시간에 탐색을 처리할 수 있게 된다.

  • k: 텍스트 내에 패턴의 발생 수

 

아호 코라식 알고리즘을 구현하기 위해 트라이 자료 구조, 실패 링크 계산 로직, 출력 문자열 목록 생성KMP 알고리즘을 사용한다.

 

트라이 자료구조 방식에서 계산한 실패함수 자료구조

 

  • 실선 화살표 = 해당 상태에서 대응이 성공했을 경우 이동가능한 상태
  • 점선 화살표 = 실패 함수를 나타내는 것으로, KMP 알고리즘에선 getPi에 해당하는 부분이다.

 

구현 방식

  1. Trie. 자료 구조를 구현한다.
    • 기본적으로 TrieNode와 기능 함수인 insert(), find(), failure()를 구현한다.
    • find() 함수에는 KMP 문자열 매칭 알고리즘을.
    • failure()에는 실패링크 계산 로직을.
  2. 주어진 패턴을 전부 Trie에 insert한다.
  3. 실패함수를 계산한다.
  4. 주어진 문자열을 저눕 Trie에 find시켜 결과를 도출한다.

실패 함수를 계산할때, TrieNode의 fail 변수 값을 채워야한다. 그러기 위해선 트라이를 탐색해야한다.

트라이 탐색을 위해 BFS 탐색을 사용한다.

 

자세한 매칭 과정은 다음 글을 참고하면 좋다.

https://pangtrue.tistory.com/305

 

[알고리즘] 아호 코라식(Aho-Corasick) 알고리즘

1. 개요 다음과 같은 두 개의 문자열이 있습니다. S = "abababdddd" (String의 약자를 따서 S라 표현) W = "bab" (Word의 약자를 따서 W라 표현) 아시는거처럼 S에서 W가 존재하는지를 찾아내는, 이른바 일대일

pangtrue.tistory.com

 

 

구현 코드

class Trie(private val root: TrieNode = TrieNode()) {
    fun insert(key: String) {
        var cur = root

        for (i in key.indices) {
            val next = key[i]

            cur.childNode.putIfAbsent(next, TrieNode())
            cur = cur.childNode[next]!!
        }

        cur.isFinish = true
    }

    fun failure() { // KMP 비교 중 실패했을 때 이동할 지점 정의하는 함수
        val queue: Queue<TrieNode> = LinkedList()
        root.failed = root // root 회귀
        queue.add(root)

        // 트리에 담긴 문자열에 대해 접두사-접미사가 동일한 것을 뽑아내야 함 (BFS 사용)
        while (queue.isNotEmpty()) {
            val currentNode = queue.poll()

            // 알파벳 소문자만 있다고 가정하여 크기 26
            for (i in 0 until 26) {
                val next = (97 + i).toChar()
                val nextNode = currentNode.childNode[next] ?: continue

                // root 자식 노드의 fail은 그들의 부모인 root 노드가 된다
                if (currentNode == root) {
                    nextNode.failed = root
                } else {
                    var failure = currentNode.failed

                    while (failure!! != root && failure.childNode[next] == null) {
                        failure = failure.failed
                    }

                    if (failure.childNode[next] != null) failure = failure.childNode[next]
                    nextNode.failed = failure
                }

                if (nextNode.failed!!.isFinish) nextNode.isFinish = true
                queue.add(nextNode)
            }
        }
    }

    fun find(key: String): Boolean { // KMP 알고리즘
        var cur = root

        for (i in key.indices) {
            val next = key[i]

            while (cur != root && cur.childNode[next] == null) {
                cur = cur.failed!!
            }

            if (cur.childNode[next] != null) cur = cur.childNode[next]!!
            if (cur.isFinish) return true // 원본 문자열에서 패턴 문자열 하나를 찾았을 경우
        }

        return false
    }
}

class TrieNode {
    val childNode = HashMap<Char, TrieNode>()
    var failed: TrieNode? = null // 실패했을 때, 이동해야 할 지점을 가리킬 변수
    var isFinish = false // 완성된 문자열인지 표시
}

 

 

 

 

예시 코드 1번 (9250번: 문자열 집합 판별):

import java.io.BufferedReader
import java.io.BufferedWriter
import java.io.InputStreamReader
import java.io.OutputStreamWriter
import java.util.*


val trie = Trie()

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val bw = BufferedWriter(OutputStreamWriter(System.out))

    val n = readLine().toInt()
    repeat(n) {
        trie.insert(readLine())
    }
    trie.failure()

    val q = readLine().toInt()
    repeat(q) {
        bw.appendLine(if (trie.find(readLine())) "YES" else "NO")
    }
    bw.flush()
    bw.close()
}

class Trie(private val root: TrieNode = TrieNode()) {
    fun insert(key: String) {
        var cur = root

        for (i in key.indices) {
            val next = key[i]

            cur.children.putIfAbsent(next, TrieNode())
            cur = cur.children[next]!!
        }

        cur.isFinish = true
    }

    fun failure() { // KMP 비교 중 실패했을 때 이동할 지점 정의하는 함수
        val queue: Queue<TrieNode> = LinkedList()
        root.failed = root
        queue.add(root)

        while (queue.isNotEmpty()) {
            val currentNode = queue.poll()

            for (i in 0 until 26) {
                val next = (97 + i).toChar()
                val nextNode = currentNode.children[next] ?: continue

                if (currentNode == root) {
                    nextNode.failed = root
                } else {
                    var failure = currentNode.failed

                    while (failure!! != root && failure.children[next] == null) {
                        failure = failure.failed
                    }

                    if (failure.children[next] != null) failure = failure.children[next]
                    nextNode.failed = failure
                }

                if (nextNode.failed!!.isFinish) nextNode.isFinish = true
                queue.add(nextNode)
            }
        }
    }

    fun find(key: String): Boolean { // KMP
        var cur = root

        for (i in key.indices) {
            val next = key[i]

            while (cur != root && cur.children[next] == null) {
                cur = cur.failed!!
            }

            if (cur.children[next] != null) cur = cur.children[next]!!
            if (cur.isFinish) return true
        }

        return false
    }
}

class TrieNode {
    val children = HashMap<Char, TrieNode>()
    var failed: TrieNode? = null
    var isFinish = false
}

 

 

 

 

예시 코드 2번 (10256번: 돌연변이):

import java.io.BufferedReader
import java.io.BufferedWriter
import java.io.InputStreamReader
import java.io.OutputStreamWriter
import java.util.*

val data = HashMap<Char, Int>().apply {
    put('A', 0)
    put('G', 1)
    put('T', 2)
    put('C', 3)
}

fun main() = with(BufferedReader(InputStreamReader(System.`in`))) {
    val bw = BufferedWriter(OutputStreamWriter(System.out))

    val t = readLine().toInt()
    val sb = StringBuilder()
    repeat(t) {
        val (n, m) = readLine().split(" ").map { it.toInt() }
        val first = readLine()
        val second = readLine()

        val trie = Trie()
        trie.insert(second)
        for (i in 0 until m) {
            for (j in i + 1 until m) {
                trie.insert(reverseProcess(second, i, j + 1))
            }
        }

        trie.failure()
        sb.appendLine(trie.find(first))
    }

    bw.write(sb.toString())
    bw.flush()
    bw.close()
}

fun reverseProcess(str: String, start: Int, end: Int): String {
    val sb = StringBuilder().apply {
        append(str.substring(0, start))
        append(StringBuilder(str.substring(start, end)).reverse())
        append(str.substring(end, str.length))
    }

    return sb.toString()
}

class Trie(private val root: TrieNode = TrieNode()) {
    fun insert(key: String) {
        var cur = root

        for (i in key.indices) {
            cur = cur.children.computeIfAbsent(data[key[i]]!!) { TrieNode() }
        }
        cur.isFinish = true
    }

    fun failure() {
        val queue: Queue<TrieNode> = LinkedList()
        root.failed = root
        queue.add(root)

        while (!queue.isEmpty()) {
            val cur = queue.poll()

            for (c in data.keys) {
                val key = data[c]
                val childNode = cur.children[key] ?: continue

                if (cur == root) {
                    childNode.failed = root
                } else {
                    var failure = cur.failed

                    while (failure!! != root && failure.children[key] == null) {
                        failure = failure.failed
                    }

                    if (failure.children[key] != null) failure = failure.children[key]
                    childNode.failed = failure
                }

                if (childNode.failed!!.isFinish) childNode.isFinish = true
                queue.add(childNode)
            }
        }
    }

    fun find(key: String): Int {
        var cur = root
        var cnt = 0

        for (i in key.indices) {
            val temp = data[key[i]]

            while (cur != root && cur.children[temp] == null) {
                cur = cur.failed!!
            }

            if (cur.children[temp] != null) cur = cur.children[temp]!!
            if (cur.isFinish) cnt++
        }

        return cnt
    }
}

class TrieNode {
    val children = HashMap<Int, TrieNode>()
    var failed: TrieNode? = null
    var isFinish = false
}