Added method and tests for a faster path to return the first match.
This commit is contained in:
parent
25eeef5168
commit
df503bae43
@ -11,177 +11,251 @@ import java.util.concurrent.LinkedBlockingDeque;
|
||||
|
||||
/**
|
||||
*
|
||||
* Based on the Aho-Corasick white paper, Bell technologies: ftp://163.13.200.222/assistant/bearhero/prog/%A8%E4%A5%A6/ac_bm.pdf
|
||||
* Based on the Aho-Corasick white paper, Bell technologies:
|
||||
* ftp://163.13.200.222/assistant/bearhero/prog/%A8%E4%A5%A6/ac_bm.pdf
|
||||
*
|
||||
* @author Robert Bor
|
||||
*/
|
||||
public class Trie {
|
||||
public class Trie
|
||||
{
|
||||
|
||||
private TrieConfig trieConfig;
|
||||
private TrieConfig trieConfig;
|
||||
|
||||
private State rootState;
|
||||
private State rootState;
|
||||
|
||||
private boolean failureStatesConstructed = false;
|
||||
private boolean failureStatesConstructed = false;
|
||||
|
||||
public Trie(TrieConfig trieConfig) {
|
||||
this.trieConfig = trieConfig;
|
||||
this.rootState = new State();
|
||||
}
|
||||
public Trie(TrieConfig trieConfig)
|
||||
{
|
||||
this.trieConfig = trieConfig;
|
||||
this.rootState = new State();
|
||||
}
|
||||
|
||||
public Trie() {
|
||||
this(new TrieConfig());
|
||||
}
|
||||
public Trie()
|
||||
{
|
||||
this(new TrieConfig());
|
||||
}
|
||||
|
||||
public Trie caseInsensitive() {
|
||||
this.trieConfig.setCaseInsensitive(true);
|
||||
return this;
|
||||
}
|
||||
public Trie caseInsensitive()
|
||||
{
|
||||
this.trieConfig.setCaseInsensitive(true);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Trie removeOverlaps() {
|
||||
this.trieConfig.setAllowOverlaps(false);
|
||||
return this;
|
||||
}
|
||||
public Trie removeOverlaps()
|
||||
{
|
||||
this.trieConfig.setAllowOverlaps(false);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Trie onlyWholeWords() {
|
||||
this.trieConfig.setOnlyWholeWords(true);
|
||||
return this;
|
||||
}
|
||||
public Trie onlyWholeWords()
|
||||
{
|
||||
this.trieConfig.setOnlyWholeWords(true);
|
||||
return this;
|
||||
}
|
||||
|
||||
public void addKeyword(String keyword) {
|
||||
if (keyword == null || keyword.length() == 0) {
|
||||
return;
|
||||
}
|
||||
State currentState = this.rootState;
|
||||
for (Character character : keyword.toCharArray()) {
|
||||
currentState = currentState.addState(character);
|
||||
}
|
||||
currentState.addEmit(keyword);
|
||||
}
|
||||
public void addKeyword(String keyword)
|
||||
{
|
||||
if (keyword == null || keyword.length() == 0) {
|
||||
return;
|
||||
}
|
||||
State currentState = this.rootState;
|
||||
for (Character character : keyword.toCharArray()) {
|
||||
currentState = currentState.addState(character);
|
||||
}
|
||||
currentState.addEmit(keyword);
|
||||
}
|
||||
|
||||
public Collection<Token> tokenize(String text) {
|
||||
public Collection<Token> tokenize(String text)
|
||||
{
|
||||
|
||||
Collection<Token> tokens = new ArrayList<Token>();
|
||||
Collection<Token> tokens = new ArrayList<Token>();
|
||||
|
||||
Collection<Emit> collectedEmits = parseText(text);
|
||||
int lastCollectedPosition = -1;
|
||||
for (Emit emit : collectedEmits) {
|
||||
if (emit.getStart() - lastCollectedPosition > 1) {
|
||||
tokens.add(createFragment(emit, text, lastCollectedPosition));
|
||||
}
|
||||
tokens.add(createMatch(emit, text));
|
||||
lastCollectedPosition = emit.getEnd();
|
||||
}
|
||||
if (text.length() - lastCollectedPosition > 1) {
|
||||
tokens.add(createFragment(null, text, lastCollectedPosition));
|
||||
}
|
||||
Collection<Emit> collectedEmits = parseText(text);
|
||||
int lastCollectedPosition = -1;
|
||||
for (Emit emit : collectedEmits) {
|
||||
if (emit.getStart() - lastCollectedPosition > 1) {
|
||||
tokens.add(createFragment(emit, text, lastCollectedPosition));
|
||||
}
|
||||
tokens.add(createMatch(emit, text));
|
||||
lastCollectedPosition = emit.getEnd();
|
||||
}
|
||||
if (text.length() - lastCollectedPosition > 1) {
|
||||
tokens.add(createFragment(null, text, lastCollectedPosition));
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
private Token createFragment(Emit emit, String text, int lastCollectedPosition) {
|
||||
return new FragmentToken(text.substring(lastCollectedPosition+1, emit == null ? text.length() : emit.getStart()));
|
||||
}
|
||||
private Token createFragment(Emit emit, String text, int lastCollectedPosition)
|
||||
{
|
||||
return new FragmentToken(text.substring(lastCollectedPosition + 1, emit == null ? text.length() : emit.
|
||||
getStart()));
|
||||
}
|
||||
|
||||
private Token createMatch(Emit emit, String text) {
|
||||
return new MatchToken(text.substring(emit.getStart(), emit.getEnd()+1), emit);
|
||||
}
|
||||
private Token createMatch(Emit emit, String text)
|
||||
{
|
||||
return new MatchToken(text.substring(emit.getStart(), emit.getEnd() + 1), emit);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public Collection<Emit> parseText(String text) {
|
||||
checkForConstructedFailureStates();
|
||||
@SuppressWarnings("unchecked")
|
||||
public Collection<Emit> parseText(String text)
|
||||
{
|
||||
checkForConstructedFailureStates();
|
||||
|
||||
int position = 0;
|
||||
State currentState = this.rootState;
|
||||
List<Emit> collectedEmits = new ArrayList<Emit>();
|
||||
for (Character character : text.toCharArray()) {
|
||||
if (trieConfig.isCaseInsensitive()) {
|
||||
character = Character.toLowerCase(character);
|
||||
}
|
||||
currentState = getState(currentState, character);
|
||||
storeEmits(position, currentState, collectedEmits);
|
||||
position++;
|
||||
}
|
||||
int position = 0;
|
||||
State currentState = this.rootState;
|
||||
List<Emit> collectedEmits = new ArrayList<Emit>();
|
||||
for (Character character : text.toCharArray()) {
|
||||
if (trieConfig.isCaseInsensitive()) {
|
||||
character = Character.toLowerCase(character);
|
||||
}
|
||||
currentState = getState(currentState, character);
|
||||
storeEmits(position, currentState, collectedEmits);
|
||||
position++;
|
||||
}
|
||||
|
||||
if (trieConfig.isOnlyWholeWords()) {
|
||||
removePartialMatches(text, collectedEmits);
|
||||
}
|
||||
if (trieConfig.isOnlyWholeWords()) {
|
||||
removePartialMatches(text, collectedEmits);
|
||||
}
|
||||
|
||||
if (!trieConfig.isAllowOverlaps()) {
|
||||
IntervalTree intervalTree = new IntervalTree((List<Intervalable>)(List<?>)collectedEmits);
|
||||
intervalTree.removeOverlaps((List<Intervalable>) (List<?>) collectedEmits);
|
||||
}
|
||||
if (!trieConfig.isAllowOverlaps()) {
|
||||
IntervalTree intervalTree = new IntervalTree((List<Intervalable>) (List<?>) collectedEmits);
|
||||
intervalTree.removeOverlaps((List<Intervalable>) (List<?>) collectedEmits);
|
||||
}
|
||||
|
||||
return collectedEmits;
|
||||
}
|
||||
return collectedEmits;
|
||||
}
|
||||
|
||||
private void removePartialMatches(String searchText, List<Emit> collectedEmits) {
|
||||
long size = searchText.length();
|
||||
List<Emit> removeEmits = new ArrayList<Emit>();
|
||||
for (Emit emit : collectedEmits) {
|
||||
if ((emit.getStart() == 0 ||
|
||||
!Character.isAlphabetic(searchText.charAt(emit.getStart() - 1))) &&
|
||||
(emit.getEnd() + 1 == size ||
|
||||
!Character.isAlphabetic(searchText.charAt(emit.getEnd() + 1)))) {
|
||||
continue;
|
||||
}
|
||||
removeEmits.add(emit);
|
||||
}
|
||||
public boolean matches(String text)
|
||||
{
|
||||
Emit firstMatch = firstMatch(text);
|
||||
return firstMatch != null;
|
||||
}
|
||||
|
||||
for (Emit removeEmit : removeEmits) {
|
||||
collectedEmits.remove(removeEmit);
|
||||
}
|
||||
}
|
||||
public Emit firstMatch(String text)
|
||||
{
|
||||
|
||||
private State getState(State currentState, Character character) {
|
||||
State newCurrentState = currentState.nextState(character);
|
||||
while (newCurrentState == null) {
|
||||
currentState = currentState.failure();
|
||||
newCurrentState = currentState.nextState(character);
|
||||
}
|
||||
return newCurrentState;
|
||||
}
|
||||
if (!trieConfig.isAllowOverlaps()) {
|
||||
// Slow path. Needs to find all the matches to detect overlaps.
|
||||
Collection<Emit> parseText = parseText(text);
|
||||
if (parseText != null && !parseText.isEmpty()) {
|
||||
return parseText.iterator().next();
|
||||
}
|
||||
} else {
|
||||
// Fast path. Returs first match found.
|
||||
|
||||
private void checkForConstructedFailureStates() {
|
||||
if (!this.failureStatesConstructed) {
|
||||
constructFailureStates();
|
||||
}
|
||||
}
|
||||
checkForConstructedFailureStates();
|
||||
|
||||
private void constructFailureStates() {
|
||||
Queue<State> queue = new LinkedBlockingDeque<State>();
|
||||
int position = 0;
|
||||
State currentState = this.rootState;
|
||||
for (Character character : text.toCharArray()) {
|
||||
if (trieConfig.isCaseInsensitive()) {
|
||||
character = Character.toLowerCase(character);
|
||||
}
|
||||
currentState = getState(currentState, character);
|
||||
|
||||
// First, set the fail state of all depth 1 states to the root state
|
||||
for (State depthOneState : this.rootState.getStates()) {
|
||||
depthOneState.setFailure(this.rootState);
|
||||
queue.add(depthOneState);
|
||||
}
|
||||
this.failureStatesConstructed = true;
|
||||
Collection<String> emitStrs = currentState.emit();
|
||||
if (emitStrs != null && !emitStrs.isEmpty()) {
|
||||
for (String emitStr : emitStrs) {
|
||||
final Emit emit = new Emit(position - emitStr.length() + 1, position, emitStr);
|
||||
|
||||
// Second, determine the fail state for all depth > 1 state
|
||||
while (!queue.isEmpty()) {
|
||||
State currentState = queue.remove();
|
||||
if (trieConfig.isOnlyWholeWords()) {
|
||||
if (!isPartialMatch(text, emit)) {
|
||||
return emit;
|
||||
}
|
||||
} else {
|
||||
return emit;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (Character transition : currentState.getTransitions()) {
|
||||
State targetState = currentState.nextState(transition);
|
||||
queue.add(targetState);
|
||||
position++;
|
||||
}
|
||||
|
||||
State traceFailureState = currentState.failure();
|
||||
while (traceFailureState.nextState(transition) == null) {
|
||||
traceFailureState = traceFailureState.failure();
|
||||
}
|
||||
State newFailureState = traceFailureState.nextState(transition);
|
||||
targetState.setFailure(newFailureState);
|
||||
targetState.addEmit(newFailureState.emit());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void storeEmits(int position, State currentState, List<Emit> collectedEmits) {
|
||||
Collection<String> emits = currentState.emit();
|
||||
if (emits != null && !emits.isEmpty()) {
|
||||
for (String emit : emits) {
|
||||
collectedEmits.add(new Emit(position-emit.length()+1, position, emit));
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private boolean isPartialMatch(String searchText, Emit emit)
|
||||
{
|
||||
return (emit.getStart() != 0 &&
|
||||
Character.isAlphabetic(searchText.charAt(emit.getStart() - 1))) ||
|
||||
(emit.getEnd() + 1 != searchText.length() &&
|
||||
Character.isAlphabetic(searchText.charAt(emit.getEnd() + 1)));
|
||||
}
|
||||
|
||||
private void removePartialMatches(String searchText, List<Emit> collectedEmits)
|
||||
{
|
||||
long size = searchText.length();
|
||||
List<Emit> removeEmits = new ArrayList<Emit>();
|
||||
for (Emit emit : collectedEmits) {
|
||||
if (isPartialMatch(searchText, emit)) {
|
||||
removeEmits.add(emit);
|
||||
}
|
||||
}
|
||||
|
||||
for (Emit removeEmit : removeEmits) {
|
||||
collectedEmits.remove(removeEmit);
|
||||
}
|
||||
}
|
||||
|
||||
private State getState(State currentState, Character character)
|
||||
{
|
||||
State newCurrentState = currentState.nextState(character);
|
||||
while (newCurrentState == null) {
|
||||
currentState = currentState.failure();
|
||||
newCurrentState = currentState.nextState(character);
|
||||
}
|
||||
return newCurrentState;
|
||||
}
|
||||
|
||||
private void checkForConstructedFailureStates()
|
||||
{
|
||||
if (!this.failureStatesConstructed) {
|
||||
constructFailureStates();
|
||||
}
|
||||
}
|
||||
|
||||
private void constructFailureStates()
|
||||
{
|
||||
Queue<State> queue = new LinkedBlockingDeque<State>();
|
||||
|
||||
// First, set the fail state of all depth 1 states to the root state
|
||||
for (State depthOneState : this.rootState.getStates()) {
|
||||
depthOneState.setFailure(this.rootState);
|
||||
queue.add(depthOneState);
|
||||
}
|
||||
this.failureStatesConstructed = true;
|
||||
|
||||
// Second, determine the fail state for all depth > 1 state
|
||||
while (!queue.isEmpty()) {
|
||||
State currentState = queue.remove();
|
||||
|
||||
for (Character transition : currentState.getTransitions()) {
|
||||
State targetState = currentState.nextState(transition);
|
||||
queue.add(targetState);
|
||||
|
||||
State traceFailureState = currentState.failure();
|
||||
while (traceFailureState.nextState(transition) == null) {
|
||||
traceFailureState = traceFailureState.failure();
|
||||
}
|
||||
State newFailureState = traceFailureState.nextState(transition);
|
||||
targetState.setFailure(newFailureState);
|
||||
targetState.addEmit(newFailureState.emit());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void storeEmits(int position, State currentState, List<Emit> collectedEmits)
|
||||
{
|
||||
Collection<String> emits = currentState.emit();
|
||||
if (emits != null && !emits.isEmpty()) {
|
||||
for (String emit : emits) {
|
||||
collectedEmits.add(new Emit(position - emit.length() + 1, position, emit));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -18,6 +18,14 @@ public class TrieTest {
|
||||
checkEmit(iterator.next(), 0, 2, "abc");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void keywordAndTextAreTheSameFirstMatch() {
|
||||
Trie trie = new Trie();
|
||||
trie.addKeyword("abc");
|
||||
Emit firstMatch = trie.firstMatch("abc");
|
||||
checkEmit(firstMatch, 0, 2, "abc");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void textIsLongerThanKeyword() {
|
||||
Trie trie = new Trie();
|
||||
@ -27,6 +35,14 @@ public class TrieTest {
|
||||
checkEmit(iterator.next(), 1, 3, "abc");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void textIsLongerThanKeywordFirstMatch() {
|
||||
Trie trie = new Trie();
|
||||
trie.addKeyword("abc");
|
||||
Emit firstMatch = trie.firstMatch(" abc");
|
||||
checkEmit(firstMatch, 1, 3, "abc");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void variousKeywordsOneMatch() {
|
||||
Trie trie = new Trie();
|
||||
@ -38,6 +54,16 @@ public class TrieTest {
|
||||
checkEmit(iterator.next(), 0, 2, "bcd");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void variousKeywordsFirstMatch() {
|
||||
Trie trie = new Trie();
|
||||
trie.addKeyword("abc");
|
||||
trie.addKeyword("bcd");
|
||||
trie.addKeyword("cde");
|
||||
Emit firstMatch = trie.firstMatch("bcd");
|
||||
checkEmit(firstMatch, 0, 2, "bcd");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void ushersTest() {
|
||||
Trie trie = new Trie();
|
||||
@ -53,6 +79,17 @@ public class TrieTest {
|
||||
checkEmit(iterator.next(), 2, 5, "hers");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void ushersTestFirstMatch() {
|
||||
Trie trie = new Trie();
|
||||
trie.addKeyword("hers");
|
||||
trie.addKeyword("his");
|
||||
trie.addKeyword("she");
|
||||
trie.addKeyword("he");
|
||||
Emit firstMatch = trie.firstMatch("ushers");
|
||||
checkEmit(firstMatch, 2, 3, "he");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void misleadingTest() {
|
||||
Trie trie = new Trie();
|
||||
@ -62,6 +99,14 @@ public class TrieTest {
|
||||
checkEmit(iterator.next(), 9, 12, "hers");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void misleadingTestFirstMatch() {
|
||||
Trie trie = new Trie();
|
||||
trie.addKeyword("hers");
|
||||
Emit firstMatch = trie.firstMatch("h he her hers");
|
||||
checkEmit(firstMatch, 9, 12, "hers");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recipes() {
|
||||
Trie trie = new Trie();
|
||||
@ -77,20 +122,26 @@ public class TrieTest {
|
||||
checkEmit(iterator.next(), 51, 58, "broccoli");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recipesFirstMatch() {
|
||||
Trie trie = new Trie();
|
||||
trie.addKeyword("veal");
|
||||
trie.addKeyword("cauliflower");
|
||||
trie.addKeyword("broccoli");
|
||||
trie.addKeyword("tomatoes");
|
||||
Emit firstMatch = trie.firstMatch("2 cauliflowers, 3 tomatoes, 4 slices of veal, 100g broccoli");
|
||||
|
||||
checkEmit(firstMatch, 2, 12, "cauliflower");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void longAndShortOverlappingMatch() {
|
||||
Trie trie = new Trie();
|
||||
trie.addKeyword("he");
|
||||
trie.addKeyword("hehehehe");
|
||||
Collection<Emit> emits = trie.parseText("hehehehehe");
|
||||
Iterator<Emit> iterator = emits.iterator();
|
||||
checkEmit(iterator.next(), 0, 1, "he");
|
||||
checkEmit(iterator.next(), 2, 3, "he");
|
||||
checkEmit(iterator.next(), 4, 5, "he");
|
||||
checkEmit(iterator.next(), 6, 7, "he");
|
||||
checkEmit(iterator.next(), 0, 7, "hehehehe");
|
||||
checkEmit(iterator.next(), 8, 9, "he");
|
||||
checkEmit(iterator.next(), 2, 9, "hehehehe");
|
||||
Emit firstMatch = trie.firstMatch("hehehehehe");
|
||||
|
||||
checkEmit(firstMatch, 0, 1, "he");
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -107,6 +158,17 @@ public class TrieTest {
|
||||
checkEmit(iterator.next(), 6, 7, "ab");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void nonOverlappingFirstMatch() {
|
||||
Trie trie = new Trie().removeOverlaps();
|
||||
trie.addKeyword("ab");
|
||||
trie.addKeyword("cba");
|
||||
trie.addKeyword("ababc");
|
||||
Emit firstMatch = trie.firstMatch("ababcbab");
|
||||
|
||||
checkEmit(firstMatch, 0, 4, "ababc");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startOfChurchillSpeech() {
|
||||
Trie trie = new Trie().removeOverlaps();
|
||||
@ -133,6 +195,15 @@ public class TrieTest {
|
||||
checkEmit(emits.iterator().next(), 20, 24, "sugar");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void partialMatchFirstMatch() {
|
||||
Trie trie = new Trie().onlyWholeWords();
|
||||
trie.addKeyword("sugar");
|
||||
Emit firstMatch = trie.firstMatch("sugarcane sugarcane sugar canesugar"); // left, middle, right test
|
||||
|
||||
checkEmit(firstMatch, 20, 24, "sugar");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void tokenizeFullSentence() {
|
||||
Trie trie = new Trie();
|
||||
@ -183,6 +254,18 @@ public class TrieTest {
|
||||
checkEmit(it.next(), 19, 23, "börkü");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void caseInsensitiveFirstMatch() {
|
||||
Trie trie = new Trie().caseInsensitive();
|
||||
trie.addKeyword("turning");
|
||||
trie.addKeyword("once");
|
||||
trie.addKeyword("again");
|
||||
trie.addKeyword("börkü");
|
||||
Emit firstMatch = trie.firstMatch("TurninG OnCe AgAiN BÖRKÜ");
|
||||
|
||||
checkEmit(firstMatch, 0, 6, "turning");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void tokenizeTokensInSequence() {
|
||||
Trie trie = new Trie();
|
||||
@ -214,6 +297,16 @@ public class TrieTest {
|
||||
checkEmit(it.next(), 5, 8, "this");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void unicodeIssueBug8ReportedByDwyerkFirstMatch() {
|
||||
String target = "LİKE THIS"; // The second character ('İ') is Unicode, which was read by AC as a 2-byte char
|
||||
Trie trie = new Trie().caseInsensitive().onlyWholeWords();
|
||||
assertEquals("THIS", target.substring(5,9)); // Java does it the right way
|
||||
trie.addKeyword("this");
|
||||
Emit firstMatch = trie.firstMatch(target);
|
||||
checkEmit(firstMatch, 5, 8, "this");
|
||||
}
|
||||
|
||||
private void checkEmit(Emit next, int expectedStart, int expectedEnd, String expectedKeyword) {
|
||||
assertEquals("Start of emit should have been "+expectedStart, expectedStart, next.getStart());
|
||||
assertEquals("End of emit should have been "+expectedEnd, expectedEnd, next.getEnd());
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user