RED-10200: Spike: Performant update logic for facts in working memory

This commit is contained in:
maverickstuder 2024-11-19 17:16:47 +01:00
parent 3f606ad567
commit 62591ebb1a

View File

@ -2,15 +2,20 @@ package com.iqser.red.service.redaction.v1.server.service.drools;
import static com.iqser.red.service.redaction.v1.server.service.drools.ComponentDroolsExecutionService.RULES_LOGGER_GLOBAL;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.kie.api.runtime.KieContainer;
import org.kie.api.runtime.KieSession;
@ -30,12 +35,15 @@ import com.iqser.red.service.redaction.v1.server.logger.TrackingAgendaEventListe
import com.iqser.red.service.redaction.v1.server.model.NerEntities;
import com.iqser.red.service.redaction.v1.server.model.dictionary.Dictionary;
import com.iqser.red.service.redaction.v1.server.model.document.nodes.Document;
import com.iqser.red.service.redaction.v1.server.model.document.nodes.NodeType;
import com.iqser.red.service.redaction.v1.server.model.document.nodes.SemanticNode;
import com.iqser.red.service.redaction.v1.server.model.document.nodes.SuperSection;
import com.iqser.red.service.redaction.v1.server.service.ManualChangesApplicationService;
import com.iqser.red.service.redaction.v1.server.service.document.EntityCreationService;
import com.iqser.red.service.redaction.v1.server.service.document.EntityEnrichmentService;
import com.iqser.red.service.redaction.v1.server.service.websocket.WebSocketService;
import com.iqser.red.service.redaction.v1.server.utils.exception.DroolsTimeoutException;
import com.knecon.fforesight.tenantcommons.TenantContext;
import io.micrometer.core.annotation.Timed;
import io.micrometer.observation.ObservationRegistry;
@ -93,83 +101,129 @@ public class EntityDroolsExecutionService {
addNumberOfPagesAndSectionsToAnalyseToTrace(document.getNumberOfPages(), sectionsToAnalyze.size());
KieSession kieSession = kieContainer.newKieSession();
List<SuperSection> superSections = document.streamChildrenOfType(NodeType.SUPER_SECTION)
.map(SuperSection.class::cast)
.toList();
Set<SemanticNode> nodesInKieSession = sectionsToAnalyze.size() == document.streamAllSubNodes()
.count() ? Collections.emptySet() : buildSet(sectionsToAnalyze, document);
EntityCreationService entityCreationService = new EntityCreationService(entityEnrichmentService, kieSession, nodesInKieSession);
RulesLogger logger = new RulesLogger(webSocketService, context);
if (settings.isDroolsDebug()) {
logger.enableAgendaTracking();
logger.enableObjectTracking();
}
kieSession.addEventListener(new TrackingAgendaEventListener(logger));
kieSession.addEventListener(new ObjectTrackingEventListener(logger));
List<CompletableFuture<SuperSectionResult>> futures = new ArrayList<>();
kieSession.setGlobal("document", document);
kieSession.setGlobal("entityCreationService", entityCreationService);
kieSession.setGlobal("manualChangesApplicationService", manualChangesApplicationService);
kieSession.setGlobal("dictionary", dictionary);
String tenantId = TenantContext.getTenantId();
if (hasGlobalWithName(kieSession, RULES_LOGGER_GLOBAL)) {
kieSession.setGlobal(RULES_LOGGER_GLOBAL, logger);
}
superSections.parallelStream()
.forEach(superSection -> {
Set<SemanticNode> nodesInKieSessionInSuperSection = nodesInKieSession.stream()
.filter(node -> !node.getTreeId().isEmpty() && node.getTreeId()
.get(0)
.equals(superSection.getTreeId()
.get(0)))
.collect(Collectors.toSet());
CompletableFuture<SuperSectionResult> future = CompletableFuture.supplyAsync(() -> {
TenantContext.setTenantId(tenantId);
kieSession.insert(document);
KieSession kieSession = kieContainer.newKieSession();
try {
document.getEntities()
.forEach(kieSession::insert);
EntityCreationService entityCreationService = new EntityCreationService(entityEnrichmentService, kieSession, nodesInKieSessionInSuperSection);
RulesLogger logger = new RulesLogger(webSocketService, context);
if (settings.isDroolsDebug()) {
logger.enableAgendaTracking();
logger.enableObjectTracking();
}
kieSession.addEventListener(new TrackingAgendaEventListener(logger));
kieSession.addEventListener(new ObjectTrackingEventListener(logger));
sectionsToAnalyze.forEach(kieSession::insert);
kieSession.setGlobal("document", document);
kieSession.setGlobal("entityCreationService", entityCreationService);
kieSession.setGlobal("manualChangesApplicationService", manualChangesApplicationService);
kieSession.setGlobal("dictionary", dictionary);
sectionsToAnalyze.stream()
.flatMap(SemanticNode::streamAllSubNodes)
.forEach(kieSession::insert);
if (hasGlobalWithName(kieSession, RULES_LOGGER_GLOBAL)) {
kieSession.setGlobal(RULES_LOGGER_GLOBAL, logger);
}
document.getPages()
.forEach(kieSession::insert);
kieSession.insert(document);
fileAttributes.stream()
.filter(f -> f.getValue() != null)
.forEach(kieSession::insert);
superSection.getEntities()
.forEach(kieSession::insert);
if (manualRedactions != null) {
manualRedactions.buildAll()
.stream()
.filter(BaseAnnotation::isLocal)
.forEach(kieSession::insert);
}
nodesInKieSessionInSuperSection.forEach(kieSession::insert);
kieSession.insert(nerEntities);
nodesInKieSessionInSuperSection.stream()
.flatMap(SemanticNode::streamAllSubNodes)
.forEach(kieSession::insert);
kieSession.getAgenda().getAgendaGroup("LOCAL_DICTIONARY_ADDS").setFocus();
document.getPages()
.forEach(kieSession::insert);
CompletableFuture<Void> completableFuture = CompletableFuture.supplyAsync(() -> {
kieSession.fireAllRules();
return null;
fileAttributes.stream()
.filter(f -> f.getValue() != null)
.forEach(kieSession::insert);
if (manualRedactions != null) {
manualRedactions.buildAll()
.stream()
.filter(BaseAnnotation::isLocal)
.forEach(kieSession::insert);
}
kieSession.insert(nerEntities);
kieSession.getAgenda().getAgendaGroup("LOCAL_DICTIONARY_ADDS").setFocus();
System.out.println("kieSession.getFactCount(): " + kieSession.getFactCount());
kieSession.fireAllRules();
List<FileAttribute> resultingFileAttributes = getFileAttributes(kieSession);
kieSession.dispose();
return new SuperSectionResult(resultingFileAttributes);
} catch (Exception e) {
kieSession.dispose();
throw new RuntimeException(e);
}
});
futures.add(future);
});
List<FileAttribute> resultingFileAttributes = new ArrayList<>();
futures.parallelStream().forEach(future -> {
var start = System.currentTimeMillis();
try {
SuperSectionResult result = future.get(settings.getDroolsExecutionTimeoutSecs(document.getNumberOfPages()), TimeUnit.SECONDS);
addOrUpdate(resultingFileAttributes, result.getFileAttributes());
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof TimeoutException) {
throw new DroolsTimeoutException(String.format("The file %s caused a timeout", context.getFileId()), cause, false, RuleFileType.ENTITY);
}
throw new RuntimeException(cause);
} catch (InterruptedException e) {
throw new RuntimeException(e);
} catch (TimeoutException e) {
throw new DroolsTimeoutException(String.format("The file %s caused a timeout", context.getFileId()), e, false, RuleFileType.ENTITY);
}
System.out.printf("Total time in %s : %d ms\n", future, System.currentTimeMillis() - start);
});
try {
completableFuture.get(settings.getDroolsExecutionTimeoutSecs(document.getNumberOfPages()), TimeUnit.SECONDS);
} catch (ExecutionException e) {
logger.error(e, "Exception during rule execution");
kieSession.dispose();
if (e.getCause() instanceof TimeoutException) {
throw new DroolsTimeoutException(String.format("The file %s caused a timeout",context.getFileId()), e, false, RuleFileType.ENTITY);
}
throw new RuntimeException(e);
} catch (InterruptedException e) {
logger.error(e, "Exception during rule execution");
kieSession.dispose();
throw new RuntimeException(e);
} catch (TimeoutException e) {
throw new DroolsTimeoutException(String.format("The file %s caused a timeout",context.getFileId()), e, false, RuleFileType.ENTITY);
}
addOrUpdate(fileAttributes, new ArrayList<>(resultingFileAttributes));
List<FileAttribute> resultingFileAttributes = getFileAttributes(kieSession);
kieSession.dispose();
return resultingFileAttributes;
return new ArrayList<>(resultingFileAttributes);
}
/**
 * Merges {@code resultingFileAttributes} into {@code fileAttributes}, using the label as key:
 * every existing attribute whose label equals the label of an incoming attribute is removed,
 * then the incoming attribute is appended.
 *
 * <p>The whole merge runs while holding the monitor of {@code fileAttributes}. This method is
 * invoked from a parallel stream with one shared target list, and a plain {@code ArrayList}
 * must not be structurally modified by several threads at once. Locking on the target list
 * makes each merge atomic with respect to other {@code addOrUpdate} calls on the same list.
 *
 * @param fileAttributes          the mutable target list that is updated in place
 * @param resultingFileAttributes the attributes to add or to replace existing entries with
 */
private static void addOrUpdate(List<FileAttribute> fileAttributes, List<FileAttribute> resultingFileAttributes) {

    synchronized (fileAttributes) {
        for (FileAttribute resultingFileAttribute : resultingFileAttributes) {
            // Drop any previous value(s) for this label so the incoming attribute replaces them.
            fileAttributes.removeIf(fa -> fa.getLabel().equals(resultingFileAttribute.getLabel()));
            fileAttributes.add(resultingFileAttribute);
        }
    }
}
@ -213,4 +267,23 @@ public class EntityDroolsExecutionService {
.anyMatch(global -> global.getName().equals(globalName));
}
/**
 * Immutable carrier for the outcome of one asynchronous rule execution: it holds the list of
 * file attributes handed to its constructor and exposes it through a getter.
 */
private static class SuperSectionResult {

    /** The list reference passed at construction time; it is stored as-is, not copied. */
    private final List<FileAttribute> fileAttributes;


    /**
     * @param attributes the file attributes to carry
     */
    public SuperSectionResult(List<FileAttribute> attributes) {

        fileAttributes = attributes;
    }


    /**
     * @return the same list instance that was supplied to the constructor
     */
    public List<FileAttribute> getFileAttributes() {

        return this.fileAttributes;
    }

}
}