diff --git a/src/main/java/org/ahocorasick/trie/Trie.java b/src/main/java/org/ahocorasick/trie/Trie.java index c24af01..8d8f0eb 100644 --- a/src/main/java/org/ahocorasick/trie/Trie.java +++ b/src/main/java/org/ahocorasick/trie/Trie.java @@ -12,120 +12,107 @@ import java.util.concurrent.LinkedBlockingDeque; /** * * Based on the Aho-Corasick white paper, Bell technologies: - * ftp://163.13.200.222/assistant/bearhero/prog/%A8%E4%A5%A6/ac_bm.pdf - * + * ftp://163.13.200.222/assistant/bearhero/prog/%A8%E4%A5%A6/ac_bm.pdf * @author Robert Bor */ -public class Trie -{ +public class Trie { - private TrieConfig trieConfig; + private TrieConfig trieConfig; - private State rootState; + private State rootState; - private boolean failureStatesConstructed = false; + private boolean failureStatesConstructed = false; - public Trie(TrieConfig trieConfig) - { - this.trieConfig = trieConfig; - this.rootState = new State(); - } + public Trie(TrieConfig trieConfig) { + this.trieConfig = trieConfig; + this.rootState = new State(); + } - public Trie() - { - this(new TrieConfig()); - } + public Trie() { + this(new TrieConfig()); + } - public Trie caseInsensitive() - { - this.trieConfig.setCaseInsensitive(true); - return this; - } + public Trie caseInsensitive() { + this.trieConfig.setCaseInsensitive(true); + return this; + } - public Trie removeOverlaps() - { - this.trieConfig.setAllowOverlaps(false); - return this; - } + public Trie removeOverlaps() { + this.trieConfig.setAllowOverlaps(false); + return this; + } - public Trie onlyWholeWords() - { - this.trieConfig.setOnlyWholeWords(true); - return this; - } + public Trie onlyWholeWords() { + this.trieConfig.setOnlyWholeWords(true); + return this; + } - public void addKeyword(String keyword) - { - if (keyword == null || keyword.length() == 0) { - return; - } - State currentState = this.rootState; - for (Character character : keyword.toCharArray()) { - currentState = currentState.addState(character); - } - currentState.addEmit(keyword); - 
} + public void addKeyword(String keyword) { + if (keyword == null || keyword.length() == 0) { + return; + } + State currentState = this.rootState; + for (Character character : keyword.toCharArray()) { + currentState = currentState.addState(character); + } + currentState.addEmit(keyword); + } - public Collection tokenize(String text) - { + public Collection tokenize(String text) { - Collection tokens = new ArrayList(); + Collection tokens = new ArrayList(); - Collection collectedEmits = parseText(text); - int lastCollectedPosition = -1; - for (Emit emit : collectedEmits) { - if (emit.getStart() - lastCollectedPosition > 1) { - tokens.add(createFragment(emit, text, lastCollectedPosition)); - } - tokens.add(createMatch(emit, text)); - lastCollectedPosition = emit.getEnd(); - } - if (text.length() - lastCollectedPosition > 1) { - tokens.add(createFragment(null, text, lastCollectedPosition)); - } + Collection collectedEmits = parseText(text); + int lastCollectedPosition = -1; + for (Emit emit : collectedEmits) { + if (emit.getStart() - lastCollectedPosition > 1) { + tokens.add(createFragment(emit, text, lastCollectedPosition)); + } + tokens.add(createMatch(emit, text)); + lastCollectedPosition = emit.getEnd(); + } + if (text.length() - lastCollectedPosition > 1) { + tokens.add(createFragment(null, text, lastCollectedPosition)); + } - return tokens; - } + return tokens; + } - private Token createFragment(Emit emit, String text, int lastCollectedPosition) - { - return new FragmentToken(text.substring(lastCollectedPosition + 1, emit == null ? text.length() : emit. - getStart())); - } + private Token createFragment(Emit emit, String text, int lastCollectedPosition) { + return new FragmentToken(text.substring(lastCollectedPosition+1, emit == null ? 
text.length() : emit.getStart())); + } - private Token createMatch(Emit emit, String text) - { - return new MatchToken(text.substring(emit.getStart(), emit.getEnd() + 1), emit); - } + private Token createMatch(Emit emit, String text) { + return new MatchToken(text.substring(emit.getStart(), emit.getEnd()+1), emit); + } - @SuppressWarnings("unchecked") - public Collection parseText(String text) - { - checkForConstructedFailureStates(); + @SuppressWarnings("unchecked") + public Collection parseText(String text) { + checkForConstructedFailureStates(); - int position = 0; - State currentState = this.rootState; - List collectedEmits = new ArrayList(); - for (Character character : text.toCharArray()) { - if (trieConfig.isCaseInsensitive()) { - character = Character.toLowerCase(character); - } - currentState = getState(currentState, character); - storeEmits(position, currentState, collectedEmits); - position++; - } + int position = 0; + State currentState = this.rootState; + List collectedEmits = new ArrayList(); + for (Character character : text.toCharArray()) { + if (trieConfig.isCaseInsensitive()) { + character = Character.toLowerCase(character); + } + currentState = getState(currentState, character); + storeEmits(position, currentState, collectedEmits); + position++; + } - if (trieConfig.isOnlyWholeWords()) { - removePartialMatches(text, collectedEmits); - } + if (trieConfig.isOnlyWholeWords()) { + removePartialMatches(text, collectedEmits); + } - if (!trieConfig.isAllowOverlaps()) { - IntervalTree intervalTree = new IntervalTree((List) (List) collectedEmits); - intervalTree.removeOverlaps((List) (List) collectedEmits); - } + if (!trieConfig.isAllowOverlaps()) { + IntervalTree intervalTree = new IntervalTree((List)(List)collectedEmits); + intervalTree.removeOverlaps((List) (List) collectedEmits); + } - return collectedEmits; - } + return collectedEmits; + } public boolean matches(String text) { @@ -135,7 +122,6 @@ public class Trie public Emit firstMatch(String text) 
{ - if (!trieConfig.isAllowOverlaps()) { // Slow path. Needs to find all the matches to detect overlaps. Collection parseText = parseText(text); @@ -143,10 +129,8 @@ public class Trie return parseText.iterator().next(); } } else { - // Fast path. Returs first match found. - + // Fast path. Returns first match found. checkForConstructedFailureStates(); - int position = 0; State currentState = this.rootState; for (Character character : text.toCharArray()) { @@ -154,12 +138,10 @@ public class Trie character = Character.toLowerCase(character); } currentState = getState(currentState, character); - Collection emitStrs = currentState.emit(); if (emitStrs != null && !emitStrs.isEmpty()) { for (String emitStr : emitStrs) { final Emit emit = new Emit(position - emitStr.length() + 1, position, emitStr); - if (trieConfig.isOnlyWholeWords()) { if (!isPartialMatch(text, emit)) { return emit; @@ -169,12 +151,9 @@ public class Trie } } } - position++; } - } - return null; } @@ -188,74 +167,68 @@ public class Trie private void removePartialMatches(String searchText, List collectedEmits) { - long size = searchText.length(); List removeEmits = new ArrayList(); for (Emit emit : collectedEmits) { if (isPartialMatch(searchText, emit)) { removeEmits.add(emit); } } - for (Emit removeEmit : removeEmits) { collectedEmits.remove(removeEmit); } } - private State getState(State currentState, Character character) - { - State newCurrentState = currentState.nextState(character); - while (newCurrentState == null) { - currentState = currentState.failure(); - newCurrentState = currentState.nextState(character); - } - return newCurrentState; - } + private State getState(State currentState, Character character) { + State newCurrentState = currentState.nextState(character); + while (newCurrentState == null) { + currentState = currentState.failure(); + newCurrentState = currentState.nextState(character); + } + return newCurrentState; + } - private void checkForConstructedFailureStates() - { - if 
(!this.failureStatesConstructed) { - constructFailureStates(); - } - } + private void checkForConstructedFailureStates() { + if (!this.failureStatesConstructed) { + constructFailureStates(); + } + } - private void constructFailureStates() - { - Queue queue = new LinkedBlockingDeque(); + private void constructFailureStates() { + Queue queue = new LinkedBlockingDeque(); - // First, set the fail state of all depth 1 states to the root state - for (State depthOneState : this.rootState.getStates()) { - depthOneState.setFailure(this.rootState); - queue.add(depthOneState); - } - this.failureStatesConstructed = true; + // First, set the fail state of all depth 1 states to the root state + for (State depthOneState : this.rootState.getStates()) { + depthOneState.setFailure(this.rootState); + queue.add(depthOneState); + } + this.failureStatesConstructed = true; - // Second, determine the fail state for all depth > 1 state - while (!queue.isEmpty()) { - State currentState = queue.remove(); + // Second, determine the fail state for all depth > 1 state + while (!queue.isEmpty()) { + State currentState = queue.remove(); - for (Character transition : currentState.getTransitions()) { - State targetState = currentState.nextState(transition); - queue.add(targetState); + for (Character transition : currentState.getTransitions()) { + State targetState = currentState.nextState(transition); + queue.add(targetState); - State traceFailureState = currentState.failure(); - while (traceFailureState.nextState(transition) == null) { - traceFailureState = traceFailureState.failure(); - } - State newFailureState = traceFailureState.nextState(transition); - targetState.setFailure(newFailureState); - targetState.addEmit(newFailureState.emit()); - } - } - } + State traceFailureState = currentState.failure(); + while (traceFailureState.nextState(transition) == null) { + traceFailureState = traceFailureState.failure(); + } + State newFailureState = traceFailureState.nextState(transition); + 
targetState.setFailure(newFailureState); + targetState.addEmit(newFailureState.emit()); + } + } + } - private void storeEmits(int position, State currentState, List collectedEmits) - { - Collection emits = currentState.emit(); - if (emits != null && !emits.isEmpty()) { - for (String emit : emits) { - collectedEmits.add(new Emit(position - emit.length() + 1, position, emit)); - } - } - } + private void storeEmits(int position, State currentState, List collectedEmits) { + Collection emits = currentState.emit(); + if (emits != null && !emits.isEmpty()) { + for (String emit : emits) { + collectedEmits.add(new Emit(position-emit.length()+1, position, emit)); + } + } + } }