RED-5275: Calculate precision and recall for headline detection

This commit is contained in:
deiflaender 2022-10-11 10:53:05 +02:00
parent 9d88925ff1
commit 8d88b19915
6 changed files with 32471 additions and 1 deletions

View File

@ -338,7 +338,6 @@ public class Section {
Set<Entity> expanded = new HashSet<>();
for (var entity : entities) {
System.out.println(entity.getWord());
if (!entity.getType().equals(type) || entity.getTextBefore() == null) {
continue;

View File

@ -0,0 +1,422 @@
package com.iqser.red.service.redaction.v1.server;
import static org.mockito.Mockito.when;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.kie.api.KieServices;
import org.kie.api.builder.KieBuilder;
import org.kie.api.builder.KieFileSystem;
import org.kie.api.builder.KieModule;
import org.kie.api.runtime.KieContainer;
import org.mockito.stubbing.Answer;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.amqp.RabbitAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Primary;
import org.springframework.core.io.ClassPathResource;
import org.springframework.test.context.junit4.SpringRunner;
import com.amazonaws.services.s3.AmazonS3;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.iqser.red.service.persistence.service.v1.api.model.common.JSONPrimitive;
import com.iqser.red.service.persistence.service.v1.api.model.dossiertemplate.configuration.Colors;
import com.iqser.red.service.persistence.service.v1.api.model.dossiertemplate.dossier.file.FileType;
import com.iqser.red.service.persistence.service.v1.api.model.dossiertemplate.type.DictionaryEntry;
import com.iqser.red.service.persistence.service.v1.api.model.dossiertemplate.type.Type;
import com.iqser.red.service.redaction.v1.model.AnalyzeRequest;
import com.iqser.red.service.redaction.v1.model.ChangeType;
import com.iqser.red.service.redaction.v1.model.RedactionLog;
import com.iqser.red.service.redaction.v1.model.StructureAnalyzeRequest;
import com.iqser.red.service.redaction.v1.server.annotate.AnnotationService;
import com.iqser.red.service.redaction.v1.server.client.DictionaryClient;
import com.iqser.red.service.redaction.v1.server.client.LegalBasisClient;
import com.iqser.red.service.redaction.v1.server.client.RulesClient;
import com.iqser.red.service.redaction.v1.server.controller.RedactionController;
import com.iqser.red.service.redaction.v1.server.redaction.service.AnalyzeService;
import com.iqser.red.service.redaction.v1.server.redaction.service.ManualRedactionSurroundingTextService;
import com.iqser.red.service.redaction.v1.server.redaction.utils.ResourceLoader;
import com.iqser.red.service.redaction.v1.server.storage.RedactionStorageService;
import com.iqser.red.storage.commons.StorageAutoConfiguration;
import com.iqser.red.storage.commons.service.StorageService;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.SneakyThrows;
import lombok.ToString;
@RunWith(SpringRunner.class)
@SpringBootTest(classes = Application.class, webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@Import(HeadlinesGoldStandardIntegrationTest.RedactionIntegrationTestConfiguration.class)
public class HeadlinesGoldStandardIntegrationTest {
private static final String RULES = loadFromClassPath("drools/headlines.drl");
private static final String HEADLINE = "headline";
@Autowired
private RedactionController redactionController;
@Autowired
private AnnotationService annotationService;
@Autowired
private AnalyzeService analyzeService;
@Autowired
private ObjectMapper objectMapper;
@MockBean
private RulesClient rulesClient;
@MockBean
private DictionaryClient dictionaryClient;
@Autowired
private RedactionStorageService redactionStorageService;
@Autowired
private StorageService storageService;
@Autowired
private ManualRedactionSurroundingTextService manualRedactionSurroundingTextService;
@MockBean
private AmazonS3 amazonS3;
@MockBean
private RabbitTemplate rabbitTemplate;
@MockBean
private LegalBasisClient legalBasisClient;
private final Map<String, List<String>> dictionary = new HashMap<>();
private final Map<String, List<String>> dossierDictionary = new HashMap<>();
private final Map<String, String> typeColorMap = new HashMap<>();
private final Map<String, Boolean> hintTypeMap = new HashMap<>();
private final Map<String, Boolean> caseInSensitiveMap = new HashMap<>();
private final Map<String, Boolean> recommendationTypeMap = new HashMap<>();
private final Map<String, Integer> rankTypeMap = new HashMap<>();
private final Colors colors = new Colors();
private final Map<String, Long> reanlysisVersions = new HashMap<>();
private final Set<String> deleted = new HashSet<>();
private final static String TEST_DOSSIER_TEMPLATE_ID = "123";
private final static String TEST_DOSSIER_ID = "123";
private final static String TEST_FILE_ID = "123";
@Test
public void testHeadlineDetection() {
List<Metrics> metrics = new ArrayList<>();
metrics.add(getMetrics("files/RSS/01 - CGA100251 - Acute Oral Toxicity (Up and Down Procedure) - Rat (1).pdf",
"files/Headlines/01 - CGA100251 - Acute Oral Toxicity (Up and Down Procedure) - Rat (1)_REDACTION_LOG.json"));
metrics.add(getMetrics("files/Trinexapac/91 Trinexapac-ethyl_RAR_01_Volume_1_2018-02-23.pdf",
"files/Headlines/91 Trinexapac-ethyl_RAR_01_Volume_1_2018-02-23_REDACTION_LOG.json"));
metrics.add(getMetrics("files/Metolachlor/S-Metolachlor_RAR_01_Volume_1_2018-09-06.pdf", "files/Headlines/S-Metolachlor_RAR_01_Volume_1_2018-09-06_REDACTION_LOG.json"));
float precision = 0;
float recall = 0;
for (var m : metrics) {
precision += m.getPrecision();
recall += m.getRecall();
}
precision = precision / metrics.size();
recall = recall / metrics.size();
System.out.println("Precision is: " + precision + " recall is: " + recall);
Assertions.assertThat(precision).isGreaterThanOrEqualTo(0.45f);
Assertions.assertThat(recall).isGreaterThanOrEqualTo(0.69f);
}
@SneakyThrows
private Metrics getMetrics(String fileUrl, String redactionLogUrl) {
ClassPathResource redactionLogResource = new ClassPathResource(redactionLogUrl);
Set<Headline> goldStandardHeadlines = new HashSet<>();
var goldStandardLog = objectMapper.readValue(redactionLogResource.getInputStream(), RedactionLog.class);
goldStandardLog.getRedactionLogEntry().removeIf(r -> !r.isRedacted() || r.getChanges().get(r.getChanges().size() - 1).getType().equals(ChangeType.REMOVED));
goldStandardLog.getRedactionLogEntry().forEach(e -> goldStandardHeadlines.add(new Headline(e.getPositions().get(0).getPage(), e.getValue())));
AnalyzeRequest request = prepareStorage(fileUrl);
analyzeService.analyzeDocumentStructure(new StructureAnalyzeRequest(request.getDossierId(), request.getFileId()));
analyzeService.analyze(request);
List<Headline> foundHeadlines = new ArrayList<>();
var redactionLog = redactionStorageService.getRedactionLog(TEST_DOSSIER_ID, TEST_FILE_ID);
redactionLog.getRedactionLogEntry().forEach(e -> foundHeadlines.add(new Headline(e.getPositions().get(0).getPage(), e.getValue())));
Set<Headline> correct = new HashSet<>();
Set<Headline> missing;
Set<Headline> falsePositive = new HashSet<>();
for (Headline headline : foundHeadlines) {
if (goldStandardHeadlines.contains(headline)) {
correct.add(headline);
} else {
falsePositive.add(headline);
}
}
missing = goldStandardHeadlines.stream().filter(h -> !correct.contains(h)).collect(Collectors.toSet());
float precision = (float) correct.size() / ((float) correct.size() + (float) falsePositive.size());
float recall = (float) correct.size() / ((float) correct.size() + (float) missing.size());
return new Metrics(precision, recall);
}
@Configuration
@EnableAutoConfiguration(exclude = {RabbitAutoConfiguration.class, StorageAutoConfiguration.class})
public static class RedactionIntegrationTestConfiguration {
@Bean
public KieContainer kieContainer() {
KieServices kieServices = KieServices.Factory.get();
KieFileSystem kieFileSystem = kieServices.newKieFileSystem();
InputStream input = new ByteArrayInputStream(RULES.getBytes(StandardCharsets.UTF_8));
kieFileSystem.write("src/test/resources/drools/headlines.drl", kieServices.getResources().newInputStreamResource(input));
KieBuilder kieBuilder = kieServices.newKieBuilder(kieFileSystem);
kieBuilder.buildAll();
KieModule kieModule = kieBuilder.getKieModule();
return kieServices.newKieContainer(kieModule.getReleaseId());
}
@Bean
@Primary
public StorageService inmemoryStorage() {
return new FileSystemBackedStorageService();
}
}
@After
public void cleanupStorage() {
if (this.storageService instanceof FileSystemBackedStorageService) {
((FileSystemBackedStorageService) this.storageService).clearStorage();
}
}
@Before
public void stubClients() {
when(rulesClient.getVersion(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(0L);
when(rulesClient.getRules(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(JSONPrimitive.of(RULES));
loadDictionaryForTest();
loadTypeForTest();
loadNerForTest();
when(dictionaryClient.getVersion(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(0L);
when(dictionaryClient.getAllTypesForDossierTemplate(TEST_DOSSIER_TEMPLATE_ID, false)).thenReturn(getTypeResponse());
when(dictionaryClient.getVersion(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(0L);
mockDictionaryCalls(null);
mockDictionaryCalls(0L);
when(dictionaryClient.getColors(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(colors);
}
private void mockDictionaryCalls(Long version) {
when(dictionaryClient.getDictionaryForType(HEADLINE + ":" + TEST_DOSSIER_TEMPLATE_ID, version)).then((Answer<Type>) invocation -> getDictionaryResponse(HEADLINE, false));
}
private void loadDictionaryForTest() {
dictionary.computeIfAbsent(HEADLINE, v -> new ArrayList<>());
}
private void loadTypeForTest() {
typeColorMap.put(HEADLINE, "#f90707");
hintTypeMap.put(HEADLINE, false);
caseInSensitiveMap.put(HEADLINE, false);
recommendationTypeMap.put(HEADLINE, false);
rankTypeMap.put(HEADLINE, 155);
colors.setSkippedColor("#cccccc");
colors.setRequestAddColor("#04b093");
colors.setRequestRemoveColor("#04b093");
}
private List<Type> getTypeResponse() {
return typeColorMap.entrySet()
.stream()
.map(typeColor -> Type.builder()
.id(typeColor.getKey() + ":" + TEST_DOSSIER_TEMPLATE_ID)
.type(typeColor.getKey())
.dossierTemplateId(TEST_DOSSIER_TEMPLATE_ID)
.hexColor(typeColor.getValue())
.isHint(hintTypeMap.get(typeColor.getKey()))
.isCaseInsensitive(caseInSensitiveMap.get(typeColor.getKey()))
.isRecommendation(recommendationTypeMap.get(typeColor.getKey()))
.rank(rankTypeMap.get(typeColor.getKey()))
.build())
.collect(Collectors.toList());
}
private Type getDictionaryResponse(String type, boolean isDossierDictionary) {
return Type.builder()
.id(type + ":" + TEST_DOSSIER_TEMPLATE_ID)
.hexColor(typeColorMap.get(type))
.entries(isDossierDictionary ? toDictionaryEntry(dossierDictionary.get(type)) : toDictionaryEntry(dictionary.get(type)))
.falsePositiveEntries(new ArrayList<>())
.falseRecommendationEntries(new ArrayList<>())
.isHint(hintTypeMap.get(type))
.isCaseInsensitive(caseInSensitiveMap.get(type))
.isRecommendation(recommendationTypeMap.get(type))
.rank(rankTypeMap.get(type))
.build();
}
private List<DictionaryEntry> toDictionaryEntry(List<String> entries) {
if (entries == null) {
entries = Collections.emptyList();
}
List<DictionaryEntry> dictionaryEntries = new ArrayList<>();
entries.forEach(entry -> {
dictionaryEntries.add(DictionaryEntry.builder().value(entry).version(reanlysisVersions.getOrDefault(entry, 0L)).deleted(deleted.contains(entry)).build());
});
return dictionaryEntries;
}
@SneakyThrows
private AnalyzeRequest prepareStorage(String file) {
return prepareStorage(file, "files/cv_service_empty_response.json");
}
@SneakyThrows
private AnalyzeRequest prepareStorage(String file, String cvServiceResponseFile) {
ClassPathResource pdfFileResource = new ClassPathResource(file);
ClassPathResource cvServiceResponseFileResource = new ClassPathResource(cvServiceResponseFile);
return prepareStorage(pdfFileResource.getInputStream(), cvServiceResponseFileResource.getInputStream());
}
@SneakyThrows
private AnalyzeRequest prepareStorage(InputStream fileStream, InputStream cvServiceResponseFileStream) {
AnalyzeRequest request = AnalyzeRequest.builder()
.dossierTemplateId(TEST_DOSSIER_TEMPLATE_ID)
.dossierId(TEST_DOSSIER_ID)
.fileId(TEST_FILE_ID)
.lastProcessed(OffsetDateTime.now())
.build();
storageService.storeObject(RedactionStorageService.StorageIdUtils.getStorageId(TEST_DOSSIER_ID, TEST_FILE_ID, FileType.TABLES), cvServiceResponseFileStream);
storageService.storeObject(RedactionStorageService.StorageIdUtils.getStorageId(TEST_DOSSIER_ID, TEST_FILE_ID, FileType.ORIGIN), fileStream);
return request;
}
private static String loadFromClassPath(String path) {
URL resource = ResourceLoader.class.getClassLoader().getResource(path);
if (resource == null) {
throw new IllegalArgumentException("could not load classpath resource: drools/rules.drl");
}
try (BufferedReader br = new BufferedReader(new InputStreamReader(resource.openStream(), StandardCharsets.UTF_8))) {
StringBuilder sb = new StringBuilder();
String str;
while ((str = br.readLine()) != null) {
sb.append(str).append("\n");
}
return sb.toString();
} catch (IOException e) {
throw new IllegalArgumentException("could not load classpath resource: " + path, e);
}
}
@SneakyThrows
private void loadNerForTest() {
ClassPathResource responseJson = new ClassPathResource("files/ner_response.json");
storageService.storeObject(RedactionStorageService.StorageIdUtils.getStorageId(TEST_DOSSIER_ID, TEST_FILE_ID, FileType.NER_ENTITIES), responseJson.getInputStream());
}
@Data
@EqualsAndHashCode
@AllArgsConstructor
@ToString
private class Metrics {
private float precision;
private float recall;
}
@Data
@EqualsAndHashCode
@AllArgsConstructor
@ToString
private class Headline {
private int page;
private String headline;
}
}

View File

@ -0,0 +1,13 @@
package drools
import com.iqser.red.service.redaction.v1.server.redaction.model.Section
global Section section
rule "1: Find headlines"
when
Section(text.length() > 1)
then
section.redactHeadline("headline", 1, "Headline found", "n-a.");
end