Added method and tests for a faster path to return the first match.

This commit is contained in:
ryan 2014-10-06 10:52:35 -07:00
parent 25eeef5168
commit df503bae43
2 changed files with 316 additions and 149 deletions

View File

@ -11,177 +11,251 @@ import java.util.concurrent.LinkedBlockingDeque;
/**
*
* Based on the Aho-Corasick white paper, Bell technologies: ftp://163.13.200.222/assistant/bearhero/prog/%A8%E4%A5%A6/ac_bm.pdf
* Based on the Aho-Corasick white paper, Bell technologies:
* ftp://163.13.200.222/assistant/bearhero/prog/%A8%E4%A5%A6/ac_bm.pdf
*
* @author Robert Bor
*/
public class Trie {
public class Trie
{
private TrieConfig trieConfig;
private TrieConfig trieConfig;
private State rootState;
private State rootState;
private boolean failureStatesConstructed = false;
private boolean failureStatesConstructed = false;
public Trie(TrieConfig trieConfig) {
this.trieConfig = trieConfig;
this.rootState = new State();
}
public Trie(TrieConfig trieConfig)
{
this.trieConfig = trieConfig;
this.rootState = new State();
}
public Trie() {
this(new TrieConfig());
}
public Trie()
{
this(new TrieConfig());
}
public Trie caseInsensitive() {
this.trieConfig.setCaseInsensitive(true);
return this;
}
public Trie caseInsensitive()
{
this.trieConfig.setCaseInsensitive(true);
return this;
}
public Trie removeOverlaps() {
this.trieConfig.setAllowOverlaps(false);
return this;
}
public Trie removeOverlaps()
{
this.trieConfig.setAllowOverlaps(false);
return this;
}
public Trie onlyWholeWords() {
this.trieConfig.setOnlyWholeWords(true);
return this;
}
public Trie onlyWholeWords()
{
this.trieConfig.setOnlyWholeWords(true);
return this;
}
public void addKeyword(String keyword) {
if (keyword == null || keyword.length() == 0) {
return;
}
State currentState = this.rootState;
for (Character character : keyword.toCharArray()) {
currentState = currentState.addState(character);
}
currentState.addEmit(keyword);
}
public void addKeyword(String keyword)
{
if (keyword == null || keyword.length() == 0) {
return;
}
State currentState = this.rootState;
for (Character character : keyword.toCharArray()) {
currentState = currentState.addState(character);
}
currentState.addEmit(keyword);
}
public Collection<Token> tokenize(String text) {
public Collection<Token> tokenize(String text)
{
Collection<Token> tokens = new ArrayList<Token>();
Collection<Token> tokens = new ArrayList<Token>();
Collection<Emit> collectedEmits = parseText(text);
int lastCollectedPosition = -1;
for (Emit emit : collectedEmits) {
if (emit.getStart() - lastCollectedPosition > 1) {
tokens.add(createFragment(emit, text, lastCollectedPosition));
}
tokens.add(createMatch(emit, text));
lastCollectedPosition = emit.getEnd();
}
if (text.length() - lastCollectedPosition > 1) {
tokens.add(createFragment(null, text, lastCollectedPosition));
}
Collection<Emit> collectedEmits = parseText(text);
int lastCollectedPosition = -1;
for (Emit emit : collectedEmits) {
if (emit.getStart() - lastCollectedPosition > 1) {
tokens.add(createFragment(emit, text, lastCollectedPosition));
}
tokens.add(createMatch(emit, text));
lastCollectedPosition = emit.getEnd();
}
if (text.length() - lastCollectedPosition > 1) {
tokens.add(createFragment(null, text, lastCollectedPosition));
}
return tokens;
}
return tokens;
}
private Token createFragment(Emit emit, String text, int lastCollectedPosition) {
return new FragmentToken(text.substring(lastCollectedPosition+1, emit == null ? text.length() : emit.getStart()));
}
private Token createFragment(Emit emit, String text, int lastCollectedPosition)
{
return new FragmentToken(text.substring(lastCollectedPosition + 1, emit == null ? text.length() : emit.
getStart()));
}
private Token createMatch(Emit emit, String text) {
return new MatchToken(text.substring(emit.getStart(), emit.getEnd()+1), emit);
}
private Token createMatch(Emit emit, String text)
{
return new MatchToken(text.substring(emit.getStart(), emit.getEnd() + 1), emit);
}
@SuppressWarnings("unchecked")
public Collection<Emit> parseText(String text) {
checkForConstructedFailureStates();
@SuppressWarnings("unchecked")
public Collection<Emit> parseText(String text)
{
checkForConstructedFailureStates();
int position = 0;
State currentState = this.rootState;
List<Emit> collectedEmits = new ArrayList<Emit>();
for (Character character : text.toCharArray()) {
if (trieConfig.isCaseInsensitive()) {
character = Character.toLowerCase(character);
}
currentState = getState(currentState, character);
storeEmits(position, currentState, collectedEmits);
position++;
}
int position = 0;
State currentState = this.rootState;
List<Emit> collectedEmits = new ArrayList<Emit>();
for (Character character : text.toCharArray()) {
if (trieConfig.isCaseInsensitive()) {
character = Character.toLowerCase(character);
}
currentState = getState(currentState, character);
storeEmits(position, currentState, collectedEmits);
position++;
}
if (trieConfig.isOnlyWholeWords()) {
removePartialMatches(text, collectedEmits);
}
if (trieConfig.isOnlyWholeWords()) {
removePartialMatches(text, collectedEmits);
}
if (!trieConfig.isAllowOverlaps()) {
IntervalTree intervalTree = new IntervalTree((List<Intervalable>)(List<?>)collectedEmits);
intervalTree.removeOverlaps((List<Intervalable>) (List<?>) collectedEmits);
}
if (!trieConfig.isAllowOverlaps()) {
IntervalTree intervalTree = new IntervalTree((List<Intervalable>) (List<?>) collectedEmits);
intervalTree.removeOverlaps((List<Intervalable>) (List<?>) collectedEmits);
}
return collectedEmits;
}
return collectedEmits;
}
private void removePartialMatches(String searchText, List<Emit> collectedEmits) {
long size = searchText.length();
List<Emit> removeEmits = new ArrayList<Emit>();
for (Emit emit : collectedEmits) {
if ((emit.getStart() == 0 ||
!Character.isAlphabetic(searchText.charAt(emit.getStart() - 1))) &&
(emit.getEnd() + 1 == size ||
!Character.isAlphabetic(searchText.charAt(emit.getEnd() + 1)))) {
continue;
}
removeEmits.add(emit);
}
public boolean matches(String text)
{
Emit firstMatch = firstMatch(text);
return firstMatch != null;
}
for (Emit removeEmit : removeEmits) {
collectedEmits.remove(removeEmit);
}
}
public Emit firstMatch(String text)
{
private State getState(State currentState, Character character) {
State newCurrentState = currentState.nextState(character);
while (newCurrentState == null) {
currentState = currentState.failure();
newCurrentState = currentState.nextState(character);
}
return newCurrentState;
}
if (!trieConfig.isAllowOverlaps()) {
// Slow path. Needs to find all the matches to detect overlaps.
Collection<Emit> parseText = parseText(text);
if (parseText != null && !parseText.isEmpty()) {
return parseText.iterator().next();
}
} else {
// Fast path. Returs first match found.
private void checkForConstructedFailureStates() {
if (!this.failureStatesConstructed) {
constructFailureStates();
}
}
checkForConstructedFailureStates();
private void constructFailureStates() {
Queue<State> queue = new LinkedBlockingDeque<State>();
int position = 0;
State currentState = this.rootState;
for (Character character : text.toCharArray()) {
if (trieConfig.isCaseInsensitive()) {
character = Character.toLowerCase(character);
}
currentState = getState(currentState, character);
// First, set the fail state of all depth 1 states to the root state
for (State depthOneState : this.rootState.getStates()) {
depthOneState.setFailure(this.rootState);
queue.add(depthOneState);
}
this.failureStatesConstructed = true;
Collection<String> emitStrs = currentState.emit();
if (emitStrs != null && !emitStrs.isEmpty()) {
for (String emitStr : emitStrs) {
final Emit emit = new Emit(position - emitStr.length() + 1, position, emitStr);
// Second, determine the fail state for all depth > 1 state
while (!queue.isEmpty()) {
State currentState = queue.remove();
if (trieConfig.isOnlyWholeWords()) {
if (!isPartialMatch(text, emit)) {
return emit;
}
} else {
return emit;
}
}
}
for (Character transition : currentState.getTransitions()) {
State targetState = currentState.nextState(transition);
queue.add(targetState);
position++;
}
State traceFailureState = currentState.failure();
while (traceFailureState.nextState(transition) == null) {
traceFailureState = traceFailureState.failure();
}
State newFailureState = traceFailureState.nextState(transition);
targetState.setFailure(newFailureState);
targetState.addEmit(newFailureState.emit());
}
}
}
}
private void storeEmits(int position, State currentState, List<Emit> collectedEmits) {
Collection<String> emits = currentState.emit();
if (emits != null && !emits.isEmpty()) {
for (String emit : emits) {
collectedEmits.add(new Emit(position-emit.length()+1, position, emit));
}
}
}
return null;
}
private boolean isPartialMatch(String searchText, Emit emit)
{
return (emit.getStart() != 0 &&
Character.isAlphabetic(searchText.charAt(emit.getStart() - 1))) ||
(emit.getEnd() + 1 != searchText.length() &&
Character.isAlphabetic(searchText.charAt(emit.getEnd() + 1)));
}
private void removePartialMatches(String searchText, List<Emit> collectedEmits)
{
long size = searchText.length();
List<Emit> removeEmits = new ArrayList<Emit>();
for (Emit emit : collectedEmits) {
if (isPartialMatch(searchText, emit)) {
removeEmits.add(emit);
}
}
for (Emit removeEmit : removeEmits) {
collectedEmits.remove(removeEmit);
}
}
private State getState(State currentState, Character character)
{
State newCurrentState = currentState.nextState(character);
while (newCurrentState == null) {
currentState = currentState.failure();
newCurrentState = currentState.nextState(character);
}
return newCurrentState;
}
private void checkForConstructedFailureStates()
{
if (!this.failureStatesConstructed) {
constructFailureStates();
}
}
private void constructFailureStates()
{
Queue<State> queue = new LinkedBlockingDeque<State>();
// First, set the fail state of all depth 1 states to the root state
for (State depthOneState : this.rootState.getStates()) {
depthOneState.setFailure(this.rootState);
queue.add(depthOneState);
}
this.failureStatesConstructed = true;
// Second, determine the fail state for all depth > 1 state
while (!queue.isEmpty()) {
State currentState = queue.remove();
for (Character transition : currentState.getTransitions()) {
State targetState = currentState.nextState(transition);
queue.add(targetState);
State traceFailureState = currentState.failure();
while (traceFailureState.nextState(transition) == null) {
traceFailureState = traceFailureState.failure();
}
State newFailureState = traceFailureState.nextState(transition);
targetState.setFailure(newFailureState);
targetState.addEmit(newFailureState.emit());
}
}
}
private void storeEmits(int position, State currentState, List<Emit> collectedEmits)
{
Collection<String> emits = currentState.emit();
if (emits != null && !emits.isEmpty()) {
for (String emit : emits) {
collectedEmits.add(new Emit(position - emit.length() + 1, position, emit));
}
}
}
}

View File

@ -18,6 +18,14 @@ public class TrieTest {
checkEmit(iterator.next(), 0, 2, "abc");
}
@Test
public void keywordAndTextAreTheSameFirstMatch() {
Trie trie = new Trie();
trie.addKeyword("abc");
Emit firstMatch = trie.firstMatch("abc");
checkEmit(firstMatch, 0, 2, "abc");
}
@Test
public void textIsLongerThanKeyword() {
Trie trie = new Trie();
@ -27,6 +35,14 @@ public class TrieTest {
checkEmit(iterator.next(), 1, 3, "abc");
}
@Test
public void textIsLongerThanKeywordFirstMatch() {
Trie trie = new Trie();
trie.addKeyword("abc");
Emit firstMatch = trie.firstMatch(" abc");
checkEmit(firstMatch, 1, 3, "abc");
}
@Test
public void variousKeywordsOneMatch() {
Trie trie = new Trie();
@ -38,6 +54,16 @@ public class TrieTest {
checkEmit(iterator.next(), 0, 2, "bcd");
}
@Test
public void variousKeywordsFirstMatch() {
Trie trie = new Trie();
trie.addKeyword("abc");
trie.addKeyword("bcd");
trie.addKeyword("cde");
Emit firstMatch = trie.firstMatch("bcd");
checkEmit(firstMatch, 0, 2, "bcd");
}
@Test
public void ushersTest() {
Trie trie = new Trie();
@ -53,6 +79,17 @@ public class TrieTest {
checkEmit(iterator.next(), 2, 5, "hers");
}
@Test
public void ushersTestFirstMatch() {
Trie trie = new Trie();
trie.addKeyword("hers");
trie.addKeyword("his");
trie.addKeyword("she");
trie.addKeyword("he");
Emit firstMatch = trie.firstMatch("ushers");
checkEmit(firstMatch, 2, 3, "he");
}
@Test
public void misleadingTest() {
Trie trie = new Trie();
@ -62,6 +99,14 @@ public class TrieTest {
checkEmit(iterator.next(), 9, 12, "hers");
}
@Test
public void misleadingTestFirstMatch() {
Trie trie = new Trie();
trie.addKeyword("hers");
Emit firstMatch = trie.firstMatch("h he her hers");
checkEmit(firstMatch, 9, 12, "hers");
}
@Test
public void recipes() {
Trie trie = new Trie();
@ -77,20 +122,26 @@ public class TrieTest {
checkEmit(iterator.next(), 51, 58, "broccoli");
}
@Test
public void recipesFirstMatch() {
Trie trie = new Trie();
trie.addKeyword("veal");
trie.addKeyword("cauliflower");
trie.addKeyword("broccoli");
trie.addKeyword("tomatoes");
Emit firstMatch = trie.firstMatch("2 cauliflowers, 3 tomatoes, 4 slices of veal, 100g broccoli");
checkEmit(firstMatch, 2, 12, "cauliflower");
}
@Test
public void longAndShortOverlappingMatch() {
Trie trie = new Trie();
trie.addKeyword("he");
trie.addKeyword("hehehehe");
Collection<Emit> emits = trie.parseText("hehehehehe");
Iterator<Emit> iterator = emits.iterator();
checkEmit(iterator.next(), 0, 1, "he");
checkEmit(iterator.next(), 2, 3, "he");
checkEmit(iterator.next(), 4, 5, "he");
checkEmit(iterator.next(), 6, 7, "he");
checkEmit(iterator.next(), 0, 7, "hehehehe");
checkEmit(iterator.next(), 8, 9, "he");
checkEmit(iterator.next(), 2, 9, "hehehehe");
Emit firstMatch = trie.firstMatch("hehehehehe");
checkEmit(firstMatch, 0, 1, "he");
}
@Test
@ -107,6 +158,17 @@ public class TrieTest {
checkEmit(iterator.next(), 6, 7, "ab");
}
@Test
public void nonOverlappingFirstMatch() {
Trie trie = new Trie().removeOverlaps();
trie.addKeyword("ab");
trie.addKeyword("cba");
trie.addKeyword("ababc");
Emit firstMatch = trie.firstMatch("ababcbab");
checkEmit(firstMatch, 0, 4, "ababc");
}
@Test
public void startOfChurchillSpeech() {
Trie trie = new Trie().removeOverlaps();
@ -133,6 +195,15 @@ public class TrieTest {
checkEmit(emits.iterator().next(), 20, 24, "sugar");
}
@Test
public void partialMatchFirstMatch() {
Trie trie = new Trie().onlyWholeWords();
trie.addKeyword("sugar");
Emit firstMatch = trie.firstMatch("sugarcane sugarcane sugar canesugar"); // left, middle, right test
checkEmit(firstMatch, 20, 24, "sugar");
}
@Test
public void tokenizeFullSentence() {
Trie trie = new Trie();
@ -183,6 +254,18 @@ public class TrieTest {
checkEmit(it.next(), 19, 23, "börkü");
}
@Test
public void caseInsensitiveFirstMatch() {
Trie trie = new Trie().caseInsensitive();
trie.addKeyword("turning");
trie.addKeyword("once");
trie.addKeyword("again");
trie.addKeyword("börkü");
Emit firstMatch = trie.firstMatch("TurninG OnCe AgAiN BÖRKÜ");
checkEmit(firstMatch, 0, 6, "turning");
}
@Test
public void tokenizeTokensInSequence() {
Trie trie = new Trie();
@ -214,6 +297,16 @@ public class TrieTest {
checkEmit(it.next(), 5, 8, "this");
}
@Test
public void unicodeIssueBug8ReportedByDwyerkFirstMatch() {
String target = "LİKE THIS"; // The second character ('İ') is Unicode, which was read by AC as a 2-byte char
Trie trie = new Trie().caseInsensitive().onlyWholeWords();
assertEquals("THIS", target.substring(5,9)); // Java does it the right way
trie.addKeyword("this");
Emit firstMatch = trie.firstMatch(target);
checkEmit(firstMatch, 5, 8, "this");
}
private void checkEmit(Emit next, int expectedStart, int expectedEnd, String expectedKeyword) {
assertEquals("Start of emit should have been "+expectedStart, expectedStart, next.getStart());
assertEquals("End of emit should have been "+expectedEnd, expectedEnd, next.getEnd());