RED-5275: Calculate precision and recall for headline detection
This commit is contained in:
parent
9d88925ff1
commit
8d88b19915
@ -338,7 +338,6 @@ public class Section {
|
||||
|
||||
Set<Entity> expanded = new HashSet<>();
|
||||
for (var entity : entities) {
|
||||
System.out.println(entity.getWord());
|
||||
|
||||
if (!entity.getType().equals(type) || entity.getTextBefore() == null) {
|
||||
continue;
|
||||
|
||||
@ -0,0 +1,422 @@
|
||||
package com.iqser.red.service.redaction.v1.server;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.OffsetDateTime;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.assertj.core.api.Assertions;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.kie.api.KieServices;
|
||||
import org.kie.api.builder.KieBuilder;
|
||||
import org.kie.api.builder.KieFileSystem;
|
||||
import org.kie.api.builder.KieModule;
|
||||
import org.kie.api.runtime.KieContainer;
|
||||
import org.mockito.stubbing.Answer;
|
||||
import org.springframework.amqp.rabbit.core.RabbitTemplate;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.amqp.RabbitAutoConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.annotation.Import;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.test.context.junit4.SpringRunner;
|
||||
|
||||
import com.amazonaws.services.s3.AmazonS3;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.iqser.red.service.persistence.service.v1.api.model.common.JSONPrimitive;
|
||||
import com.iqser.red.service.persistence.service.v1.api.model.dossiertemplate.configuration.Colors;
|
||||
import com.iqser.red.service.persistence.service.v1.api.model.dossiertemplate.dossier.file.FileType;
|
||||
import com.iqser.red.service.persistence.service.v1.api.model.dossiertemplate.type.DictionaryEntry;
|
||||
import com.iqser.red.service.persistence.service.v1.api.model.dossiertemplate.type.Type;
|
||||
import com.iqser.red.service.redaction.v1.model.AnalyzeRequest;
|
||||
import com.iqser.red.service.redaction.v1.model.ChangeType;
|
||||
import com.iqser.red.service.redaction.v1.model.RedactionLog;
|
||||
import com.iqser.red.service.redaction.v1.model.StructureAnalyzeRequest;
|
||||
import com.iqser.red.service.redaction.v1.server.annotate.AnnotationService;
|
||||
import com.iqser.red.service.redaction.v1.server.client.DictionaryClient;
|
||||
import com.iqser.red.service.redaction.v1.server.client.LegalBasisClient;
|
||||
import com.iqser.red.service.redaction.v1.server.client.RulesClient;
|
||||
import com.iqser.red.service.redaction.v1.server.controller.RedactionController;
|
||||
import com.iqser.red.service.redaction.v1.server.redaction.service.AnalyzeService;
|
||||
import com.iqser.red.service.redaction.v1.server.redaction.service.ManualRedactionSurroundingTextService;
|
||||
import com.iqser.red.service.redaction.v1.server.redaction.utils.ResourceLoader;
|
||||
import com.iqser.red.service.redaction.v1.server.storage.RedactionStorageService;
|
||||
import com.iqser.red.storage.commons.StorageAutoConfiguration;
|
||||
import com.iqser.red.storage.commons.service.StorageService;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.ToString;
|
||||
|
||||
@RunWith(SpringRunner.class)
|
||||
@SpringBootTest(classes = Application.class, webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
|
||||
@Import(HeadlinesGoldStandardIntegrationTest.RedactionIntegrationTestConfiguration.class)
|
||||
public class HeadlinesGoldStandardIntegrationTest {
|
||||
|
||||
private static final String RULES = loadFromClassPath("drools/headlines.drl");
|
||||
|
||||
private static final String HEADLINE = "headline";
|
||||
|
||||
@Autowired
|
||||
private RedactionController redactionController;
|
||||
|
||||
@Autowired
|
||||
private AnnotationService annotationService;
|
||||
|
||||
@Autowired
|
||||
private AnalyzeService analyzeService;
|
||||
|
||||
@Autowired
|
||||
private ObjectMapper objectMapper;
|
||||
|
||||
@MockBean
|
||||
private RulesClient rulesClient;
|
||||
|
||||
@MockBean
|
||||
private DictionaryClient dictionaryClient;
|
||||
|
||||
@Autowired
|
||||
private RedactionStorageService redactionStorageService;
|
||||
|
||||
@Autowired
|
||||
private StorageService storageService;
|
||||
|
||||
@Autowired
|
||||
private ManualRedactionSurroundingTextService manualRedactionSurroundingTextService;
|
||||
|
||||
@MockBean
|
||||
private AmazonS3 amazonS3;
|
||||
|
||||
@MockBean
|
||||
private RabbitTemplate rabbitTemplate;
|
||||
|
||||
@MockBean
|
||||
private LegalBasisClient legalBasisClient;
|
||||
|
||||
private final Map<String, List<String>> dictionary = new HashMap<>();
|
||||
private final Map<String, List<String>> dossierDictionary = new HashMap<>();
|
||||
private final Map<String, String> typeColorMap = new HashMap<>();
|
||||
private final Map<String, Boolean> hintTypeMap = new HashMap<>();
|
||||
private final Map<String, Boolean> caseInSensitiveMap = new HashMap<>();
|
||||
private final Map<String, Boolean> recommendationTypeMap = new HashMap<>();
|
||||
private final Map<String, Integer> rankTypeMap = new HashMap<>();
|
||||
private final Colors colors = new Colors();
|
||||
private final Map<String, Long> reanlysisVersions = new HashMap<>();
|
||||
private final Set<String> deleted = new HashSet<>();
|
||||
|
||||
private final static String TEST_DOSSIER_TEMPLATE_ID = "123";
|
||||
private final static String TEST_DOSSIER_ID = "123";
|
||||
private final static String TEST_FILE_ID = "123";
|
||||
|
||||
|
||||
@Test
|
||||
public void testHeadlineDetection() {
|
||||
|
||||
List<Metrics> metrics = new ArrayList<>();
|
||||
metrics.add(getMetrics("files/RSS/01 - CGA100251 - Acute Oral Toxicity (Up and Down Procedure) - Rat (1).pdf",
|
||||
"files/Headlines/01 - CGA100251 - Acute Oral Toxicity (Up and Down Procedure) - Rat (1)_REDACTION_LOG.json"));
|
||||
metrics.add(getMetrics("files/Trinexapac/91 Trinexapac-ethyl_RAR_01_Volume_1_2018-02-23.pdf",
|
||||
"files/Headlines/91 Trinexapac-ethyl_RAR_01_Volume_1_2018-02-23_REDACTION_LOG.json"));
|
||||
metrics.add(getMetrics("files/Metolachlor/S-Metolachlor_RAR_01_Volume_1_2018-09-06.pdf", "files/Headlines/S-Metolachlor_RAR_01_Volume_1_2018-09-06_REDACTION_LOG.json"));
|
||||
|
||||
float precision = 0;
|
||||
float recall = 0;
|
||||
for (var m : metrics) {
|
||||
precision += m.getPrecision();
|
||||
recall += m.getRecall();
|
||||
}
|
||||
|
||||
precision = precision / metrics.size();
|
||||
recall = recall / metrics.size();
|
||||
|
||||
System.out.println("Precision is: " + precision + " recall is: " + recall);
|
||||
|
||||
Assertions.assertThat(precision).isGreaterThanOrEqualTo(0.45f);
|
||||
Assertions.assertThat(recall).isGreaterThanOrEqualTo(0.69f);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@SneakyThrows
|
||||
private Metrics getMetrics(String fileUrl, String redactionLogUrl) {
|
||||
|
||||
ClassPathResource redactionLogResource = new ClassPathResource(redactionLogUrl);
|
||||
|
||||
Set<Headline> goldStandardHeadlines = new HashSet<>();
|
||||
var goldStandardLog = objectMapper.readValue(redactionLogResource.getInputStream(), RedactionLog.class);
|
||||
goldStandardLog.getRedactionLogEntry().removeIf(r -> !r.isRedacted() || r.getChanges().get(r.getChanges().size() - 1).getType().equals(ChangeType.REMOVED));
|
||||
goldStandardLog.getRedactionLogEntry().forEach(e -> goldStandardHeadlines.add(new Headline(e.getPositions().get(0).getPage(), e.getValue())));
|
||||
|
||||
AnalyzeRequest request = prepareStorage(fileUrl);
|
||||
analyzeService.analyzeDocumentStructure(new StructureAnalyzeRequest(request.getDossierId(), request.getFileId()));
|
||||
analyzeService.analyze(request);
|
||||
|
||||
List<Headline> foundHeadlines = new ArrayList<>();
|
||||
var redactionLog = redactionStorageService.getRedactionLog(TEST_DOSSIER_ID, TEST_FILE_ID);
|
||||
redactionLog.getRedactionLogEntry().forEach(e -> foundHeadlines.add(new Headline(e.getPositions().get(0).getPage(), e.getValue())));
|
||||
|
||||
Set<Headline> correct = new HashSet<>();
|
||||
Set<Headline> missing;
|
||||
Set<Headline> falsePositive = new HashSet<>();
|
||||
for (Headline headline : foundHeadlines) {
|
||||
if (goldStandardHeadlines.contains(headline)) {
|
||||
correct.add(headline);
|
||||
} else {
|
||||
falsePositive.add(headline);
|
||||
}
|
||||
}
|
||||
|
||||
missing = goldStandardHeadlines.stream().filter(h -> !correct.contains(h)).collect(Collectors.toSet());
|
||||
|
||||
float precision = (float) correct.size() / ((float) correct.size() + (float) falsePositive.size());
|
||||
float recall = (float) correct.size() / ((float) correct.size() + (float) missing.size());
|
||||
|
||||
return new Metrics(precision, recall);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Configuration
|
||||
@EnableAutoConfiguration(exclude = {RabbitAutoConfiguration.class, StorageAutoConfiguration.class})
|
||||
public static class RedactionIntegrationTestConfiguration {
|
||||
|
||||
@Bean
|
||||
public KieContainer kieContainer() {
|
||||
|
||||
KieServices kieServices = KieServices.Factory.get();
|
||||
|
||||
KieFileSystem kieFileSystem = kieServices.newKieFileSystem();
|
||||
InputStream input = new ByteArrayInputStream(RULES.getBytes(StandardCharsets.UTF_8));
|
||||
kieFileSystem.write("src/test/resources/drools/headlines.drl", kieServices.getResources().newInputStreamResource(input));
|
||||
KieBuilder kieBuilder = kieServices.newKieBuilder(kieFileSystem);
|
||||
kieBuilder.buildAll();
|
||||
KieModule kieModule = kieBuilder.getKieModule();
|
||||
|
||||
return kieServices.newKieContainer(kieModule.getReleaseId());
|
||||
}
|
||||
|
||||
|
||||
@Bean
|
||||
@Primary
|
||||
public StorageService inmemoryStorage() {
|
||||
|
||||
return new FileSystemBackedStorageService();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@After
|
||||
public void cleanupStorage() {
|
||||
|
||||
if (this.storageService instanceof FileSystemBackedStorageService) {
|
||||
((FileSystemBackedStorageService) this.storageService).clearStorage();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Before
|
||||
public void stubClients() {
|
||||
|
||||
when(rulesClient.getVersion(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(0L);
|
||||
when(rulesClient.getRules(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(JSONPrimitive.of(RULES));
|
||||
|
||||
loadDictionaryForTest();
|
||||
loadTypeForTest();
|
||||
loadNerForTest();
|
||||
when(dictionaryClient.getVersion(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(0L);
|
||||
when(dictionaryClient.getAllTypesForDossierTemplate(TEST_DOSSIER_TEMPLATE_ID, false)).thenReturn(getTypeResponse());
|
||||
|
||||
when(dictionaryClient.getVersion(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(0L);
|
||||
|
||||
mockDictionaryCalls(null);
|
||||
mockDictionaryCalls(0L);
|
||||
|
||||
when(dictionaryClient.getColors(TEST_DOSSIER_TEMPLATE_ID)).thenReturn(colors);
|
||||
}
|
||||
|
||||
|
||||
private void mockDictionaryCalls(Long version) {
|
||||
|
||||
when(dictionaryClient.getDictionaryForType(HEADLINE + ":" + TEST_DOSSIER_TEMPLATE_ID, version)).then((Answer<Type>) invocation -> getDictionaryResponse(HEADLINE, false));
|
||||
|
||||
}
|
||||
|
||||
|
||||
private void loadDictionaryForTest() {
|
||||
|
||||
dictionary.computeIfAbsent(HEADLINE, v -> new ArrayList<>());
|
||||
}
|
||||
|
||||
|
||||
private void loadTypeForTest() {
|
||||
|
||||
typeColorMap.put(HEADLINE, "#f90707");
|
||||
hintTypeMap.put(HEADLINE, false);
|
||||
caseInSensitiveMap.put(HEADLINE, false);
|
||||
recommendationTypeMap.put(HEADLINE, false);
|
||||
rankTypeMap.put(HEADLINE, 155);
|
||||
|
||||
colors.setSkippedColor("#cccccc");
|
||||
colors.setRequestAddColor("#04b093");
|
||||
colors.setRequestRemoveColor("#04b093");
|
||||
}
|
||||
|
||||
|
||||
private List<Type> getTypeResponse() {
|
||||
|
||||
return typeColorMap.entrySet()
|
||||
.stream()
|
||||
.map(typeColor -> Type.builder()
|
||||
.id(typeColor.getKey() + ":" + TEST_DOSSIER_TEMPLATE_ID)
|
||||
.type(typeColor.getKey())
|
||||
.dossierTemplateId(TEST_DOSSIER_TEMPLATE_ID)
|
||||
.hexColor(typeColor.getValue())
|
||||
.isHint(hintTypeMap.get(typeColor.getKey()))
|
||||
.isCaseInsensitive(caseInSensitiveMap.get(typeColor.getKey()))
|
||||
.isRecommendation(recommendationTypeMap.get(typeColor.getKey()))
|
||||
.rank(rankTypeMap.get(typeColor.getKey()))
|
||||
.build())
|
||||
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
private Type getDictionaryResponse(String type, boolean isDossierDictionary) {
|
||||
|
||||
return Type.builder()
|
||||
.id(type + ":" + TEST_DOSSIER_TEMPLATE_ID)
|
||||
.hexColor(typeColorMap.get(type))
|
||||
.entries(isDossierDictionary ? toDictionaryEntry(dossierDictionary.get(type)) : toDictionaryEntry(dictionary.get(type)))
|
||||
.falsePositiveEntries(new ArrayList<>())
|
||||
.falseRecommendationEntries(new ArrayList<>())
|
||||
.isHint(hintTypeMap.get(type))
|
||||
.isCaseInsensitive(caseInSensitiveMap.get(type))
|
||||
.isRecommendation(recommendationTypeMap.get(type))
|
||||
.rank(rankTypeMap.get(type))
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
private List<DictionaryEntry> toDictionaryEntry(List<String> entries) {
|
||||
|
||||
if (entries == null) {
|
||||
entries = Collections.emptyList();
|
||||
}
|
||||
|
||||
List<DictionaryEntry> dictionaryEntries = new ArrayList<>();
|
||||
entries.forEach(entry -> {
|
||||
dictionaryEntries.add(DictionaryEntry.builder().value(entry).version(reanlysisVersions.getOrDefault(entry, 0L)).deleted(deleted.contains(entry)).build());
|
||||
});
|
||||
return dictionaryEntries;
|
||||
}
|
||||
|
||||
|
||||
@SneakyThrows
|
||||
private AnalyzeRequest prepareStorage(String file) {
|
||||
|
||||
return prepareStorage(file, "files/cv_service_empty_response.json");
|
||||
}
|
||||
|
||||
|
||||
@SneakyThrows
|
||||
private AnalyzeRequest prepareStorage(String file, String cvServiceResponseFile) {
|
||||
|
||||
ClassPathResource pdfFileResource = new ClassPathResource(file);
|
||||
ClassPathResource cvServiceResponseFileResource = new ClassPathResource(cvServiceResponseFile);
|
||||
|
||||
return prepareStorage(pdfFileResource.getInputStream(), cvServiceResponseFileResource.getInputStream());
|
||||
}
|
||||
|
||||
|
||||
@SneakyThrows
|
||||
private AnalyzeRequest prepareStorage(InputStream fileStream, InputStream cvServiceResponseFileStream) {
|
||||
|
||||
AnalyzeRequest request = AnalyzeRequest.builder()
|
||||
.dossierTemplateId(TEST_DOSSIER_TEMPLATE_ID)
|
||||
.dossierId(TEST_DOSSIER_ID)
|
||||
.fileId(TEST_FILE_ID)
|
||||
.lastProcessed(OffsetDateTime.now())
|
||||
.build();
|
||||
|
||||
storageService.storeObject(RedactionStorageService.StorageIdUtils.getStorageId(TEST_DOSSIER_ID, TEST_FILE_ID, FileType.TABLES), cvServiceResponseFileStream);
|
||||
storageService.storeObject(RedactionStorageService.StorageIdUtils.getStorageId(TEST_DOSSIER_ID, TEST_FILE_ID, FileType.ORIGIN), fileStream);
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
|
||||
private static String loadFromClassPath(String path) {
|
||||
|
||||
URL resource = ResourceLoader.class.getClassLoader().getResource(path);
|
||||
if (resource == null) {
|
||||
throw new IllegalArgumentException("could not load classpath resource: drools/rules.drl");
|
||||
}
|
||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(resource.openStream(), StandardCharsets.UTF_8))) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
String str;
|
||||
while ((str = br.readLine()) != null) {
|
||||
sb.append(str).append("\n");
|
||||
}
|
||||
return sb.toString();
|
||||
} catch (IOException e) {
|
||||
throw new IllegalArgumentException("could not load classpath resource: " + path, e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@SneakyThrows
|
||||
private void loadNerForTest() {
|
||||
|
||||
ClassPathResource responseJson = new ClassPathResource("files/ner_response.json");
|
||||
storageService.storeObject(RedactionStorageService.StorageIdUtils.getStorageId(TEST_DOSSIER_ID, TEST_FILE_ID, FileType.NER_ENTITIES), responseJson.getInputStream());
|
||||
}
|
||||
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode
|
||||
@AllArgsConstructor
|
||||
@ToString
|
||||
private class Metrics {
|
||||
|
||||
private float precision;
|
||||
private float recall;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode
|
||||
@AllArgsConstructor
|
||||
@ToString
|
||||
private class Headline {
|
||||
|
||||
private int page;
|
||||
private String headline;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,13 @@
|
||||
package drools
|
||||
|
||||
import com.iqser.red.service.redaction.v1.server.redaction.model.Section
|
||||
|
||||
global Section section
|
||||
|
||||
|
||||
// Headline rule: when a Section whose text is longer than one character is matched,
// redactHeadline is invoked on the global `section`.
rule "1: Find headlines"
|
||||
when
|
||||
// Matches Section facts with more than one character of text.
Section(text.length() > 1)
|
||||
then
|
||||
// NOTE(review): the arguments look like (type, rule number, reason, legal basis) — confirm
// against Section.redactHeadline, which is not visible here.
section.redactHeadline("headline", 1, "Headline found", "n-a.");
|
||||
end
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user