RED-10127: add list classification

* refactor headline font sizes
* remove title case
* no real drawbacks, mostly edge cases
* +0.1% F1 Score (2 Files with +8%)
This commit is contained in:
Kilian Schuettler 2024-10-10 10:34:59 +02:00
parent 28d79902ff
commit e9b406af16
4 changed files with 88 additions and 80 deletions

View File

@ -374,14 +374,7 @@ public class LayoutParsingPipeline {
classificationService.classify(classificationDocument, layoutParsingType, identifier);
List<TextPageBlock> headlines = classificationDocument.getPages()
.stream()
.flatMap(classificationPage -> classificationPage.getTextBlocks()
.stream()
.filter(tb -> tb instanceof TextPageBlock && tb.getClassification() != null && tb.getClassification().isHeadline())
.map(tb -> (TextPageBlock) tb))
.toList();
TableOfContents tableOfContents = outlineValidationService.createToC(headlines);
TableOfContents tableOfContents = outlineValidationService.createToC(classificationDocument);
classificationDocument.setTableOfContents(tableOfContents);
log.info("Building Sections for {}", identifier);

View File

@ -10,6 +10,7 @@ import java.util.TreeSet;
import org.springframework.stereotype.Service;
import com.knecon.fforesight.service.layoutparser.processor.model.ClassificationDocument;
import com.knecon.fforesight.service.layoutparser.processor.model.text.TextPageBlock;
import io.micrometer.observation.annotation.Observed;
@ -20,7 +21,9 @@ import lombok.extern.slf4j.Slf4j;
public class OutlineValidationService {
@Observed(name = "OutlineValidationService", contextualName = "create-toc")
public TableOfContents createToC(List<TextPageBlock> headlines) {
public TableOfContents createToC(ClassificationDocument classificationDocument) {
List<TextPageBlock> headlines = extractHeadlines(classificationDocument);
List<TableOfContentItem> mainSections = new ArrayList<>();
Map<Integer, TableOfContentItem> lastItemsPerDepth = new HashMap<>();
@ -60,4 +63,16 @@ public class OutlineValidationService {
return new TableOfContents(mainSections);
}
private static List<TextPageBlock> extractHeadlines(ClassificationDocument classificationDocument) {
return classificationDocument.getPages()
.stream()
.flatMap(classificationPage -> classificationPage.getTextBlocks()
.stream()
.filter(tb -> tb instanceof TextPageBlock && tb.getClassification() != null && tb.getClassification().isHeadline())
.map(tb -> (TextPageBlock) tb))
.toList();
}
}

View File

@ -62,74 +62,6 @@ public class DocuMineClassificationService {
}
private static List<Double> buildHeadlineFontSizes(ClassificationDocument document) {
if (document.getFontSizeCounter().getCountPerValue().size() <= 6) {
return document.getFontSizeCounter().getValuesInReverseOrder();
}
List<Map.Entry<Double, Integer>> sortedEntries = new ArrayList<>(document.getFontSizeCounter().getCountPerValue().entrySet());
sortedEntries.sort(Map.Entry.comparingByKey());
int totalCount = sortedEntries.stream()
.mapToInt(Map.Entry::getValue).sum();
int cumulativeCount = 0;
Iterator<Map.Entry<Double, Integer>> iterator = sortedEntries.iterator();
while (iterator.hasNext()) {
Map.Entry<Double, Integer> entry = iterator.next();
cumulativeCount += entry.getValue();
if (cumulativeCount > totalCount * 0.3) {
break; // We've filtered the bottom 30%, so stop.
}
iterator.remove();
}
if (sortedEntries.size() < 6) {
return document.getFontSizeCounter().getValuesInReverseOrder();
}
int clusterSize = Math.max(1, sortedEntries.size() / 6);
List<List<Double>> clusters = new ArrayList<>();
for (int i = 0; i < 6; i++) {
clusters.add(new ArrayList<>());
}
for (int i = 0; i < sortedEntries.size(); i++) {
int clusterIndex = Math.min(i / clusterSize, 5);
clusters.get(clusterIndex).add(sortedEntries.get(i).getKey());
}
return clusters.stream()
.map(cluster -> cluster.stream()
.mapToDouble(d -> d).average()
.orElseThrow())
.sorted(Comparator.reverseOrder())
.toList();
}
private List<AbstractPageBlock> getSurroundingBlocksOnPage(int originalIndex, List<AbstractBlockOnPage> textBlocks) {
int start = Math.max(originalIndex - SURROUNDING_BLOCKS_RADIUS, 0);
int end = Math.min(originalIndex + SURROUNDING_BLOCKS_RADIUS, textBlocks.size());
List<AbstractPageBlock> surroundingBlocks = new ArrayList<>(2 * SURROUNDING_BLOCKS_RADIUS);
for (int i = start; i < end; i++) {
if (i == originalIndex) {
continue;
}
if (textBlocks.get(i).block().getText().length() <= 1) {
continue;
}
if (textBlocks.get(i).page() != textBlocks.get(originalIndex).page()) {
continue;
}
surroundingBlocks.add(textBlocks.get(i).block());
}
return surroundingBlocks;
}
private void classifyBlock(HeadlineClassificationService headlineClassificationService,
int currentIndex,
List<AbstractBlockOnPage> allBlocks,
@ -331,6 +263,74 @@ public class DocuMineClassificationService {
return blocks;
}
private static List<Double> buildHeadlineFontSizes(ClassificationDocument document) {
if (document.getFontSizeCounter().getCountPerValue().size() <= 6) {
return document.getFontSizeCounter().getValuesInReverseOrder();
}
List<Map.Entry<Double, Integer>> sortedEntries = new ArrayList<>(document.getFontSizeCounter().getCountPerValue().entrySet());
sortedEntries.sort(Map.Entry.comparingByKey());
int totalCount = sortedEntries.stream()
.mapToInt(Map.Entry::getValue).sum();
int cumulativeCount = 0;
Iterator<Map.Entry<Double, Integer>> iterator = sortedEntries.iterator();
while (iterator.hasNext()) {
Map.Entry<Double, Integer> entry = iterator.next();
cumulativeCount += entry.getValue();
if (cumulativeCount > totalCount * 0.3) {
break; // We've filtered the bottom 30%, so stop.
}
iterator.remove();
}
if (sortedEntries.size() < 6) {
return document.getFontSizeCounter().getValuesInReverseOrder();
}
int clusterSize = Math.max(1, sortedEntries.size() / 6);
List<List<Double>> clusters = new ArrayList<>();
for (int i = 0; i < 6; i++) {
clusters.add(new ArrayList<>());
}
for (int i = 0; i < sortedEntries.size(); i++) {
int clusterIndex = Math.min(i / clusterSize, 5);
clusters.get(clusterIndex).add(sortedEntries.get(i).getKey());
}
return clusters.stream()
.map(cluster -> cluster.stream()
.mapToDouble(d -> d).average()
.orElseThrow())
.sorted(Comparator.reverseOrder())
.toList();
}
private List<AbstractPageBlock> getSurroundingBlocksOnPage(int originalIndex, List<AbstractBlockOnPage> textBlocks) {
int start = Math.max(originalIndex - SURROUNDING_BLOCKS_RADIUS, 0);
int end = Math.min(originalIndex + SURROUNDING_BLOCKS_RADIUS, textBlocks.size());
List<AbstractPageBlock> surroundingBlocks = new ArrayList<>(2 * SURROUNDING_BLOCKS_RADIUS);
for (int i = start; i < end; i++) {
if (i == originalIndex) {
continue;
}
if (textBlocks.get(i).block().getText().length() <= 1) {
continue;
}
if (!textBlocks.get(i).page().equals(textBlocks.get(originalIndex).page())) {
continue;
}
surroundingBlocks.add(textBlocks.get(i).block());
}
return surroundingBlocks;
}
}

View File

@ -79,7 +79,7 @@ public class OutlineDetectionTest extends AbstractTest {
var documentFile = new ClassPathResource(fileName).getFile();
long start = System.currentTimeMillis();
ClassificationDocument classificationDocument = parseLayout(fileName, LayoutParsingType.REDACT_MANAGER_WITHOUT_DUPLICATE_PARAGRAPH);
ClassificationDocument classificationDocument = parseLayout(fileName, LayoutParsingType.DOCUMINE_OLD);
Document document = buildGraph(fileName, classificationDocument);
layoutGridService.addLayoutGrid(documentFile, document, new File(tmpFileName), true);
OutlineObjectTree outlineObjectTree = classificationDocument.getOutlineObjectTree();
@ -102,7 +102,7 @@ public class OutlineDetectionTest extends AbstractTest {
TableOfContents tableOfContents = classificationDocument.getTableOfContents();
assertEquals(tableOfContents.getMainSections().size(), 10);
assertEquals(tableOfContents.getMainSections().size(), 9);
assertEquals(tableOfContents.getMainSections().subList(1, 9)
.stream()
.map(tableOfContentItem -> sanitizeString(tableOfContentItem.getHeadline().toString()))
@ -135,7 +135,7 @@ public class OutlineDetectionTest extends AbstractTest {
List<SemanticNode> childrenOfTypeSectionOrSuperSection = document.getChildrenOfTypeSectionOrSuperSection();
assertEquals(childrenOfTypeSectionOrSuperSection.size(), 10);
assertEquals(childrenOfTypeSectionOrSuperSection.size(), 9);
assertEquals(childrenOfTypeSectionOrSuperSection.subList(1, 9)
.stream()
.map(section -> sanitizeString(section.getHeadline().getLeafTextBlock().toString()))