From 62591ebb1a7a2cc5dc54ff7ae3bd74e65f9d4b95 Mon Sep 17 00:00:00 2001 From: maverickstuder Date: Tue, 19 Nov 2024 17:16:47 +0100 Subject: [PATCH] RED-10200: Spike: Performant update logic for facts in working memory --- .../drools/EntityDroolsExecutionService.java | 189 ++++++++++++------ 1 file changed, 131 insertions(+), 58 deletions(-) diff --git a/redaction-service-v1/redaction-service-server-v1/src/main/java/com/iqser/red/service/redaction/v1/server/service/drools/EntityDroolsExecutionService.java b/redaction-service-v1/redaction-service-server-v1/src/main/java/com/iqser/red/service/redaction/v1/server/service/drools/EntityDroolsExecutionService.java index bbad6457..1723db9f 100644 --- a/redaction-service-v1/redaction-service-server-v1/src/main/java/com/iqser/red/service/redaction/v1/server/service/drools/EntityDroolsExecutionService.java +++ b/redaction-service-v1/redaction-service-server-v1/src/main/java/com/iqser/red/service/redaction/v1/server/service/drools/EntityDroolsExecutionService.java @@ -2,15 +2,20 @@ package com.iqser.red.service.redaction.v1.server.service.drools; import static com.iqser.red.service.redaction.v1.server.service.drools.ComponentDroolsExecutionService.RULES_LOGGER_GLOBAL; +import java.time.OffsetDateTime; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; import org.kie.api.runtime.KieContainer; import org.kie.api.runtime.KieSession; @@ -30,12 +35,15 @@ import com.iqser.red.service.redaction.v1.server.logger.TrackingAgendaEventListe import com.iqser.red.service.redaction.v1.server.model.NerEntities; import com.iqser.red.service.redaction.v1.server.model.dictionary.Dictionary; import com.iqser.red.service.redaction.v1.server.model.document.nodes.Document; +import com.iqser.red.service.redaction.v1.server.model.document.nodes.NodeType; import com.iqser.red.service.redaction.v1.server.model.document.nodes.SemanticNode; +import com.iqser.red.service.redaction.v1.server.model.document.nodes.SuperSection; import com.iqser.red.service.redaction.v1.server.service.ManualChangesApplicationService; import com.iqser.red.service.redaction.v1.server.service.document.EntityCreationService; import com.iqser.red.service.redaction.v1.server.service.document.EntityEnrichmentService; import com.iqser.red.service.redaction.v1.server.service.websocket.WebSocketService; import com.iqser.red.service.redaction.v1.server.utils.exception.DroolsTimeoutException; +import com.knecon.fforesight.tenantcommons.TenantContext; import io.micrometer.core.annotation.Timed; import io.micrometer.observation.ObservationRegistry; @@ -93,83 +101,129 @@ public class EntityDroolsExecutionService { addNumberOfPagesAndSectionsToAnalyseToTrace(document.getNumberOfPages(), sectionsToAnalyze.size()); - KieSession kieSession = kieContainer.newKieSession(); + List superSections = document.streamChildrenOfType(NodeType.SUPER_SECTION) + .map(SuperSection.class::cast) + .toList(); Set nodesInKieSession = sectionsToAnalyze.size() == document.streamAllSubNodes() .count() ? Collections.emptySet() : buildSet(sectionsToAnalyze, document); - EntityCreationService entityCreationService = new EntityCreationService(entityEnrichmentService, kieSession, nodesInKieSession); - RulesLogger logger = new RulesLogger(webSocketService, context); - if (settings.isDroolsDebug()) { - logger.enableAgendaTracking(); - logger.enableObjectTracking(); - } - kieSession.addEventListener(new TrackingAgendaEventListener(logger)); - kieSession.addEventListener(new ObjectTrackingEventListener(logger)); + List> futures = new ArrayList<>(); - kieSession.setGlobal("document", document); - kieSession.setGlobal("entityCreationService", entityCreationService); - kieSession.setGlobal("manualChangesApplicationService", manualChangesApplicationService); - kieSession.setGlobal("dictionary", dictionary); + String tenantId = TenantContext.getTenantId(); - if (hasGlobalWithName(kieSession, RULES_LOGGER_GLOBAL)) { - kieSession.setGlobal(RULES_LOGGER_GLOBAL, logger); - } + superSections.parallelStream() + .forEach(superSection -> { + Set nodesInKieSessionInSuperSection = nodesInKieSession.stream() + .filter(node -> !node.getTreeId().isEmpty() && node.getTreeId() + .get(0) + .equals(superSection.getTreeId() + .get(0))) + .collect(Collectors.toSet()); + CompletableFuture future = CompletableFuture.supplyAsync(() -> { + TenantContext.setTenantId(tenantId); - kieSession.insert(document); + KieSession kieSession = kieContainer.newKieSession(); + try { - document.getEntities() - .forEach(kieSession::insert); + EntityCreationService entityCreationService = new EntityCreationService(entityEnrichmentService, kieSession, nodesInKieSessionInSuperSection); + RulesLogger logger = new RulesLogger(webSocketService, context); + if (settings.isDroolsDebug()) { + logger.enableAgendaTracking(); + logger.enableObjectTracking(); + } + kieSession.addEventListener(new TrackingAgendaEventListener(logger)); + kieSession.addEventListener(new ObjectTrackingEventListener(logger)); - sectionsToAnalyze.forEach(kieSession::insert); + kieSession.setGlobal("document", document); + kieSession.setGlobal("entityCreationService", entityCreationService); + kieSession.setGlobal("manualChangesApplicationService", manualChangesApplicationService); + kieSession.setGlobal("dictionary", dictionary); - sectionsToAnalyze.stream() - .flatMap(SemanticNode::streamAllSubNodes) - .forEach(kieSession::insert); + if (hasGlobalWithName(kieSession, RULES_LOGGER_GLOBAL)) { + kieSession.setGlobal(RULES_LOGGER_GLOBAL, logger); + } - document.getPages() - .forEach(kieSession::insert); + kieSession.insert(document); - fileAttributes.stream() - .filter(f -> f.getValue() != null) - .forEach(kieSession::insert); + superSection.getEntities() + .forEach(kieSession::insert); - if (manualRedactions != null) { - manualRedactions.buildAll() - .stream() - .filter(BaseAnnotation::isLocal) - .forEach(kieSession::insert); - } + nodesInKieSessionInSuperSection.forEach(kieSession::insert); - kieSession.insert(nerEntities); + nodesInKieSessionInSuperSection.stream() + .flatMap(SemanticNode::streamAllSubNodes) + .forEach(kieSession::insert); - kieSession.getAgenda().getAgendaGroup("LOCAL_DICTIONARY_ADDS").setFocus(); + document.getPages() + .forEach(kieSession::insert); - CompletableFuture completableFuture = CompletableFuture.supplyAsync(() -> { - kieSession.fireAllRules(); - return null; + fileAttributes.stream() + .filter(f -> f.getValue() != null) + .forEach(kieSession::insert); + + if (manualRedactions != null) { + manualRedactions.buildAll() + .stream() + .filter(BaseAnnotation::isLocal) + .forEach(kieSession::insert); + } + + kieSession.insert(nerEntities); + + kieSession.getAgenda().getAgendaGroup("LOCAL_DICTIONARY_ADDS").setFocus(); + + + System.out.println("kieSession.getFactCount(): " + kieSession.getFactCount()); + kieSession.fireAllRules(); + + List resultingFileAttributes = getFileAttributes(kieSession); + + kieSession.dispose(); + + return new SuperSectionResult(resultingFileAttributes); + + } catch (Exception e) { + kieSession.dispose(); + throw new RuntimeException(e); + } + }); + futures.add(future); + }); + + List resultingFileAttributes = new ArrayList<>(); + + futures.parallelStream().forEach(future -> { + var start = System.currentTimeMillis(); + try { + SuperSectionResult result = future.get(settings.getDroolsExecutionTimeoutSecs(document.getNumberOfPages()), TimeUnit.SECONDS); + addOrUpdate(resultingFileAttributes, result.getFileAttributes()); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof TimeoutException) { + throw new DroolsTimeoutException(String.format("The file %s caused a timeout", context.getFileId()), cause, false, RuleFileType.ENTITY); + } + throw new RuntimeException(cause); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (TimeoutException e) { + throw new DroolsTimeoutException(String.format("The file %s caused a timeout", context.getFileId()), e, false, RuleFileType.ENTITY); + } + System.out.printf("Total time in %s : %d ms\n", future, System.currentTimeMillis() - start); }); - try { - completableFuture.get(settings.getDroolsExecutionTimeoutSecs(document.getNumberOfPages()), TimeUnit.SECONDS); - } catch (ExecutionException e) { - logger.error(e, "Exception during rule execution"); - kieSession.dispose(); - if (e.getCause() instanceof TimeoutException) { - throw new DroolsTimeoutException(String.format("The file %s caused a timeout",context.getFileId()), e, false, RuleFileType.ENTITY); - } - throw new RuntimeException(e); - } catch (InterruptedException e) { - logger.error(e, "Exception during rule execution"); - kieSession.dispose(); - throw new RuntimeException(e); - } catch (TimeoutException e) { - throw new DroolsTimeoutException(String.format("The file %s caused a timeout",context.getFileId()), e, false, RuleFileType.ENTITY); - } + addOrUpdate(fileAttributes, new ArrayList<>(resultingFileAttributes)); - List resultingFileAttributes = getFileAttributes(kieSession); - kieSession.dispose(); - return resultingFileAttributes; + return new ArrayList<>(resultingFileAttributes); + } + + + private static void addOrUpdate(List fileAttributes, List resultingFileAttributes) { + + for (FileAttribute resultingFileAttribute : resultingFileAttributes) { + fileAttributes.removeIf(fa -> fa.getLabel().equals(resultingFileAttribute.getLabel())); + fileAttributes.add(resultingFileAttribute); + } } @@ -213,4 +267,23 @@ public class EntityDroolsExecutionService { .anyMatch(global -> global.getName().equals(globalName)); } + + private static class SuperSectionResult { + + private final List fileAttributes; + + + public SuperSectionResult(List fileAttributes) { + + this.fileAttributes = fileAttributes; + } + + + public List getFileAttributes() { + + return fileAttributes; + } + + } + }