Merge branch 'feature/RED-10072' into 'main'

RED-10072: AI description field and toggle for entities

See merge request fforesight/llm-service!25
This commit is contained in:
Maverick Studer 2024-11-07 14:43:58 +01:00
commit f3f917b5fe
10 changed files with 500 additions and 328 deletions

View File

@ -2,3 +2,21 @@ include:
- project: 'gitlab/gitlab'
ref: 'main'
file: 'ci-templates/gradle_java.yml'
deploy:
stage: deploy
tags:
- dind
script:
- echo "Building with gradle version ${BUILDVERSION}"
- gradle -Pversion=${BUILDVERSION} publish
- gradle bootBuildImage --publishImage -PbuildbootDockerHostNetwork=true -Pversion=${BUILDVERSION}
- echo "BUILDVERSION=$BUILDVERSION" >> version.env
artifacts:
reports:
dotenv: version.env
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
- if: $CI_COMMIT_BRANCH =~ /^release/
- if: $CI_COMMIT_BRANCH =~ /^feature/
- if: $CI_COMMIT_TAG

View File

@ -0,0 +1,5 @@
package com.knecon.fforesight.llm.service;
/**
 * An entity class to be extracted by the LLM NER step, together with the
 * AI-facing description of that class.
 *
 * @param name          the entity class name; used as the JSON key the model is asked
 *                      to return entities under (see the NER prompt construction)
 * @param aiDescription free-text description of the entity class that is inserted
 *                      verbatim into the NER system prompt
 */
public record EntityAiDescription(String name, String aiDescription) {
}

View File

@ -1,5 +1,6 @@
package com.knecon.fforesight.llm.service;
import java.util.List;
import java.util.Map;
import lombok.AccessLevel;
@ -19,11 +20,13 @@ import lombok.experimental.FieldDefaults;
/**
 * Request message for the LLM NER pipeline.
 *
 * <p>Carries the storage ids of all pre-computed document artefacts needed to build the
 * document graph, plus the entity classes (with AI descriptions) to extract.
 * NOTE(review): accessors such as {@code getChunksStorageId()} are used by callers, so
 * Lombok presumably generates getters via a class-level annotation — confirm in the full file.
 */
public class LlmNerMessage {
// Identity/correlation data for this request; also used as a log key.
Map<String, String> identifier;
// Entity classes to extract; when empty, the NER step is skipped entirely.
List<EntityAiDescription> entityAiDescriptions;
// Storage ids of the pre-computed document artefacts consumed by the pipeline.
String chunksStorageId;
String documentStructureStorageId;
String documentTextStorageId;
String documentPositionStorageId;
String documentPagesStorageId;
// Storage id under which the resulting entities are persisted.
String resultStorageId;
// Version counter; also present on LlmNerResponseMessage — presumably echoed back, verify.
long aiCreationVersion;
}

View File

@ -20,5 +20,6 @@ public class LlmNerResponseMessage {
int promptTokens;
int completionTokens;
int duration;
long aiCreationVersion;
}

View File

@ -1,26 +1,25 @@
package com.knecon.fforesight.llm.service;
public class SystemMessages {
import java.util.List;
public static final String NER = """
You are tasked with finding all named entities in the following document.
The named entities should be mapped to the classes PERSON, PII, ADDRESS, COMPANY, and COUNTRY.
A PERSON is any name referring to a human, excluding named methods (e.g. Klingbeil Test is not a name)
Each name should be its own entity, but first name, last name and possibly middle name should be merged. Remember that numbers are never a part of a name.
A PII is any personally identifiable information including but not limited to email address, telephone numbers, fax numbers. Further, use your own judgement to add anything else.
An Address describes a real life location and should always be as complete as possible.
A COMPANY is any company or approving body mentioned in the text. But only if it's not part of an ADDRESS
A COUNTRY is any country. But only if it's not part of an ADDRESS
The output should be strictly JSON format and nothing else, formatted as such:
{
"PERSON": ["Jennifer Durando, BS", "Charlène Hernandez", "Shaw A.", "G J J Lubbe"]
"PII": ["01223 45678", "mimi.lang@smithcorp.com", "+44 (0)1252 392460"],
"ADDRESS": ["Product Safety Labs 2394 US Highway 130 Dayton, NJ 08810 USA", "Syngenta Crop Protection, LLC 410 Swing Road Post Office Box 18300 Greensboro, NC 27419-8300 USA"]
"COMPANY": ["Syngenta", "EFSA"]
"COUNTRY": ["USA"]
}
Always replace linebreaks with whitespaces, but except that, ensure the entities match the text in the document exactly.
It is important you mention all present entities, more importantly, it is preferable to mention too many than too little.
import lombok.experimental.UtilityClass;
@UtilityClass
public class SystemMessageProvider {
public static final String PROMPT_CORRECTION = """
You are an AI assistant specialized in identifying and correcting JSON syntax errors.
The JSON provided below contains syntax errors and cannot be parsed correctly. Your objective is to transform it into a valid JSON format.
Please perform the following steps:
1. **Error Detection:** Identify all syntax errors within the JSON structure.
2. **Error Resolution:** Correct the identified syntax errors to rectify the JSON format.
3. **Data Sanitization:** Remove any elements or data that cannot be automatically fixed to maintain JSON validity.
4. **Validation:** Verify that the final JSON adheres to proper formatting and is fully valid.
**Output Requirements:**
- Return only the corrected and validated JSON.
- Do not include any additional text, explanations, or comments.
""";
public static String RULES_CO_PILOT = """
@ -343,4 +342,72 @@ public class SystemMessages {
intersects(TextRange textRange) -> boolean
""";
/**
 * Builds the NER system prompt from the configured entity classes.
 *
 * <p>The prompt lists each entity class with its AI description, then instructs the model
 * to return strict JSON whose keys are exactly the entity class names.
 *
 * @param entityAiDescriptions entity classes (name + description) to include; the prompt's
 *                             sample JSON contains one array per class, in iteration order
 * @return the complete system prompt text
 */
public String createNerPrompt(List<EntityAiDescription> entityAiDescriptions) {
    StringBuilder sb = new StringBuilder();
    sb.append("You are an AI assistant specialized in extracting named entities from text. ");
    sb.append("Your task is to identify and categorize all named entities in the provided document into the following classes:\n\n");
    for (EntityAiDescription entity : entityAiDescriptions) {
        sb.append("- **").append(entity.name()).append("**: ").append(entity.aiDescription()).append("\n");
    }
    sb.append("\n**Instructions:**\n\n");
    sb.append("1. **Entity Handling**:\n");
    sb.append(" - Use the classes described above and only those for classification.\n");
    sb.append(" - Include all relevant entities. Prefer inclusion over omission.\n");
    sb.append(" - Avoid duplicates within each category.\n");
    // Fixed: these two fragments were concatenated without a separator, producing
    // "specificity.For instance" in the rendered prompt.
    sb.append(" - Assign each entity to only one category, prioritizing specificity. ");
    sb.append("For instance, if a company's name is part of an address, classify it under ADDRESS only, not under COMPANY.\n");
    sb.append("2. **Output Format**: Provide the extracted entities in strict JSON format as shown below.\n");
    sb.append(" ```json\n");
    sb.append(" {\n");
    for (int i = 0; i < entityAiDescriptions.size(); i++) {
        EntityAiDescription entity = entityAiDescriptions.get(i);
        sb.append(" \"").append(entity.name()).append("\": [\"entity1\", \"entity2\"");
        // No comma after the last entry, so the sample JSON itself stays valid.
        sb.append(i < entityAiDescriptions.size() - 1 ? "],\n" : "]\n");
    }
    sb.append(" }\n");
    sb.append(" ```\n\n");
    sb.append(" - Ensure there is no additional text or explanation outside the JSON structure.\n");
    // Fixed: missing separator ("whitespaces.but except that") and awkward phrasing;
    // also dropped the duplicated "no additional text" bullet that appeared twice.
    sb.append(" - Always replace linebreaks with whitespaces. ");
    sb.append("Apart from that, ensure that the entities in the JSON exactly match the text from the document, preserving the original formatting and casing.\n\n");
    // Few-shot examples were deliberately removed: they tended to make the model hallucinate entities.
    return sb.toString();
}
}

View File

@ -0,0 +1,274 @@
package com.knecon.fforesight.llm.service.services;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;
import java.util.stream.Collectors;
import org.springframework.stereotype.Service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.primitives.Floats;
import com.iqser.red.storage.commons.service.StorageService;
import com.knecon.fforesight.llm.service.LlmNerMessage;
import com.knecon.fforesight.llm.service.document.DocumentData;
import com.knecon.fforesight.llm.service.document.DocumentGraphMapper;
import com.knecon.fforesight.llm.service.document.nodes.Document;
import com.knecon.fforesight.llm.service.utils.StorageIdUtils;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPage;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPageProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPositionData;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPositionDataProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentStructure;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentStructureProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentStructureWrapper;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentTextData;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentTextDataProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.EntryDataProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.NodeTypeProto;
import com.knecon.fforesight.tenantcommons.TenantContext;
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import lombok.experimental.FieldDefaults;
import lombok.extern.slf4j.Slf4j;
/**
 * Builds the in-memory {@link Document} graph for a {@link LlmNerMessage} by fetching the
 * document structure, text, position and page artefacts from storage.
 *
 * <p>Each artefact exists either in the protobuf format (storage id extension contains
 * "proto") or in the legacy JSON format; legacy data is read and converted to the proto
 * representation on the fly. All fetch methods return {@code null} when legacy JSON cannot
 * be found or parsed, so downstream mapping must tolerate missing artefacts.
 */
@Slf4j
@Service
@RequiredArgsConstructor
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class DocumentBuilderService {

    StorageService storageService;
    ObjectMapper mapper;

    /**
     * Fetches all four document artefacts referenced by the message and maps them to a
     * {@link Document} graph.
     */
    public Document build(LlmNerMessage llmNerMessage) {
        DocumentData documentData = new DocumentData();
        documentData.setDocumentStructureWrapper(new DocumentStructureWrapper(fetchDocumentStructure(llmNerMessage.getDocumentStructureStorageId())));
        documentData.setDocumentTextData(fetchDocumentTextData(llmNerMessage.getDocumentTextStorageId()));
        documentData.setDocumentPositionData(fetchDocumentPositionData(llmNerMessage.getDocumentPositionStorageId()));
        documentData.setDocumentPages(fetchAllDocumentPages(llmNerMessage.getDocumentPagesStorageId()));
        return DocumentGraphMapper.toDocumentGraph(documentData);
    }

    /** Reads the document structure, converting from legacy JSON when no proto artefact exists. */
    private DocumentStructureProto.DocumentStructure fetchDocumentStructure(String storageId) {
        DocumentStructureProto.DocumentStructure documentStructure;
        StorageIdUtils.StorageInfo storageInfo = StorageIdUtils.parseStorageId(storageId);
        if (storageInfo.fileTypeExtension().contains("proto")) {
            documentStructure = storageService.readProtoObject(TenantContext.getTenantId(), storageId, DocumentStructureProto.DocumentStructure.parser());
        } else {
            DocumentStructure oldDocumentStructure = getOldData(storageInfo.dossierId(), storageInfo.fileId(), storageInfo.fileTypeName(), DocumentStructure.class);
            if (oldDocumentStructure == null) {
                return null;
            }
            documentStructure = convertDocumentStructure(oldDocumentStructure);
        }
        return documentStructure;
    }

    /** Reads the document text data, converting from legacy JSON when no proto artefact exists. */
    private DocumentTextDataProto.AllDocumentTextData fetchDocumentTextData(String storageId) {
        DocumentTextDataProto.AllDocumentTextData documentTextData;
        StorageIdUtils.StorageInfo storageInfo = StorageIdUtils.parseStorageId(storageId);
        if (storageInfo.fileTypeExtension().contains("proto")) {
            documentTextData = storageService.readProtoObject(TenantContext.getTenantId(), storageId, DocumentTextDataProto.AllDocumentTextData.parser());
        } else {
            DocumentTextData[] oldDocumentTextData = getOldData(storageInfo.dossierId(), storageInfo.fileId(), storageInfo.fileTypeName(), DocumentTextData[].class);
            if (oldDocumentTextData == null) {
                return null;
            }
            documentTextData = convertAllDocumentTextData(oldDocumentTextData);
        }
        return documentTextData;
    }

    /** Reads the document position data, converting from legacy JSON when no proto artefact exists. */
    private DocumentPositionDataProto.AllDocumentPositionData fetchDocumentPositionData(String storageId) {
        DocumentPositionDataProto.AllDocumentPositionData documentPositionData;
        StorageIdUtils.StorageInfo storageInfo = StorageIdUtils.parseStorageId(storageId);
        if (storageInfo.fileTypeExtension().contains("proto")) {
            documentPositionData = storageService.readProtoObject(TenantContext.getTenantId(), storageId, DocumentPositionDataProto.AllDocumentPositionData.parser());
        } else {
            DocumentPositionData[] oldDocumentPositionData = getOldData(storageInfo.dossierId(), storageInfo.fileId(), storageInfo.fileTypeName(), DocumentPositionData[].class);
            if (oldDocumentPositionData == null) {
                return null;
            }
            documentPositionData = convertAllDocumentPositionData(oldDocumentPositionData);
        }
        return documentPositionData;
    }

    /** Reads the document pages, converting from legacy JSON when no proto artefact exists. */
    private DocumentPageProto.AllDocumentPages fetchAllDocumentPages(String storageId) {
        DocumentPageProto.AllDocumentPages allDocumentPages;
        StorageIdUtils.StorageInfo storageInfo = StorageIdUtils.parseStorageId(storageId);
        if (storageInfo.fileTypeExtension().contains("proto")) {
            allDocumentPages = storageService.readProtoObject(TenantContext.getTenantId(), storageId, DocumentPageProto.AllDocumentPages.parser());
        } else {
            DocumentPage[] oldDocumentPages = getOldData(storageInfo.dossierId(), storageInfo.fileId(), storageInfo.fileTypeName(), DocumentPage[].class);
            if (oldDocumentPages == null) {
                return null;
            }
            allDocumentPages = convertAllDocumentPages(oldDocumentPages);
        }
        return allDocumentPages;
    }

    /**
     * Reads a legacy JSON artefact and deserializes it to {@code valueType}.
     * Missing or unreadable legacy data is treated as "not present" — the method
     * logs the failure and returns {@code null} instead of failing hard.
     */
    private <T> T getOldData(String dossierId, String fileId, String fileType, Class<T> valueType) {
        String oldStorageId = StorageIdUtils.getStorageId(dossierId, fileId, fileType, ".json");
        // Fixed: was a stray System.out.println debug statement.
        log.debug("Looking for legacy JSON artefact {}", oldStorageId);
        try (InputStream inputStream = getObject(TenantContext.getTenantId(), oldStorageId)) {
            return mapper.readValue(inputStream, valueType);
        } catch (IOException e) {
            // Fixed: pass the exception as the last argument so the stack trace is preserved.
            log.error("Could not read JSON for {}", fileType, e);
            return null;
        }
    }

    /** Recursively converts a legacy structure entry (and its children) to the proto form. */
    private static EntryDataProto.EntryData convertEntryData(DocumentStructure.EntryData oldEntryData) {
        EntryDataProto.EntryData.Builder builder = EntryDataProto.EntryData.newBuilder();
        builder.setType(NodeTypeProto.NodeType.valueOf(oldEntryData.getType().name()));
        builder.addAllTreeId(Arrays.stream(oldEntryData.getTreeId()).boxed()
                .collect(Collectors.toList()));
        builder.addAllAtomicBlockIds(Arrays.asList(oldEntryData.getAtomicBlockIds()));
        builder.addAllPageNumbers(Arrays.asList(oldEntryData.getPageNumbers()));
        builder.putAllProperties(oldEntryData.getProperties());
        if (oldEntryData.getChildren() != null) {
            oldEntryData.getChildren()
                    .forEach(child -> builder.addChildren(convertEntryData(child)));
        }
        return builder.build();
    }

    /** Converts a legacy document structure to the proto form; tolerates a missing root. */
    private static DocumentStructureProto.DocumentStructure convertDocumentStructure(DocumentStructure oldStructure) {
        DocumentStructureProto.DocumentStructure.Builder newBuilder = DocumentStructureProto.DocumentStructure.newBuilder();
        if (oldStructure.getRoot() != null) {
            newBuilder.setRoot(convertEntryData(oldStructure.getRoot()));
        }
        return newBuilder.build();
    }

    /** Converts a single legacy page descriptor to the proto form. */
    private static DocumentPageProto.DocumentPage convertDocumentPage(DocumentPage oldPage) {
        return DocumentPageProto.DocumentPage.newBuilder()
                .setNumber(oldPage.getNumber())
                .setHeight(oldPage.getHeight())
                .setWidth(oldPage.getWidth())
                .setRotation(oldPage.getRotation())
                .build();
    }

    /** Converts all legacy page descriptors to the aggregated proto form. */
    private static DocumentPageProto.AllDocumentPages convertAllDocumentPages(DocumentPage[] oldPages) {
        DocumentPageProto.AllDocumentPages.Builder allPagesBuilder = DocumentPageProto.AllDocumentPages.newBuilder();
        for (DocumentPage oldPage : oldPages) {
            allPagesBuilder.addDocumentPages(convertDocumentPage(oldPage));
        }
        return allPagesBuilder.build();
    }

    /** Converts a single legacy position record, flattening each float[] into a proto Position. */
    private static DocumentPositionDataProto.DocumentPositionData convertDocumentPositionData(DocumentPositionData oldData) {
        DocumentPositionDataProto.DocumentPositionData.Builder builder = DocumentPositionDataProto.DocumentPositionData.newBuilder()
                .setId(oldData.getId())
                .addAllStringIdxToPositionIdx(Arrays.stream(oldData.getStringIdxToPositionIdx()).boxed()
                        .collect(Collectors.toList()));
        for (float[] pos : oldData.getPositions()) {
            DocumentPositionDataProto.DocumentPositionData.Position position = DocumentPositionDataProto.DocumentPositionData.Position.newBuilder()
                    .addAllValue(Floats.asList(pos))
                    .build();
            builder.addPositions(position);
        }
        return builder.build();
    }

    /** Converts all legacy position records to the aggregated proto form. */
    private static DocumentPositionDataProto.AllDocumentPositionData convertAllDocumentPositionData(DocumentPositionData[] oldDataList) {
        DocumentPositionDataProto.AllDocumentPositionData.Builder allDataBuilder = DocumentPositionDataProto.AllDocumentPositionData.newBuilder();
        for (DocumentPositionData oldData : oldDataList) {
            allDataBuilder.addDocumentPositionData(convertDocumentPositionData(oldData));
        }
        return allDataBuilder.build();
    }

    /** Converts a single legacy text record to the proto form. */
    private static DocumentTextDataProto.DocumentTextData convertDocumentTextData(DocumentTextData oldData) {
        return DocumentTextDataProto.DocumentTextData.newBuilder()
                .setId(oldData.getId())
                .setPage(oldData.getPage())
                .setSearchText(oldData.getSearchText())
                .setNumberOnPage(oldData.getNumberOnPage())
                .setStart(oldData.getStart())
                .setEnd(oldData.getEnd())
                .addAllLineBreaks(Arrays.stream(oldData.getLineBreaks()).boxed()
                        .collect(Collectors.toList()))
                .build();
    }

    /** Converts all legacy text records to the aggregated proto form. */
    private static DocumentTextDataProto.AllDocumentTextData convertAllDocumentTextData(DocumentTextData[] oldDataList) {
        DocumentTextDataProto.AllDocumentTextData.Builder allDataBuilder = DocumentTextDataProto.AllDocumentTextData.newBuilder();
        for (DocumentTextData oldData : oldDataList) {
            allDataBuilder.addDocumentTextData(convertDocumentTextData(oldData));
        }
        return allDataBuilder.build();
    }

    /**
     * Downloads the object to a temp file and returns a stream over it.
     * The temp file is removed when the stream is closed (DELETE_ON_CLOSE); if the download
     * itself fails, the temp file is deleted eagerly so it does not leak.
     */
    @SneakyThrows
    private InputStream getObject(String tenantId, String storageId) {
        File tempFile = File.createTempFile("temp", ".data");
        try {
            storageService.downloadTo(tenantId, storageId, tempFile);
        } catch (Exception e) {
            Files.deleteIfExists(tempFile.toPath());
            throw e;
        }
        return new BufferedInputStream(Files.newInputStream(tempFile.toPath(), StandardOpenOption.DELETE_ON_CLOSE));
    }
}

View File

@ -1,14 +1,6 @@
package com.knecon.fforesight.llm.service.services;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@ -27,37 +19,20 @@ import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.CompletionsUsage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.primitives.Floats;
import com.iqser.red.storage.commons.exception.StorageException;
import com.iqser.red.storage.commons.service.StorageService;
import com.knecon.fforesight.llm.service.ChunkingResponse;
import com.knecon.fforesight.llm.service.EntityAiDescription;
import com.knecon.fforesight.llm.service.LlmNerEntities;
import com.knecon.fforesight.llm.service.LlmNerEntity;
import com.knecon.fforesight.llm.service.LlmNerMessage;
import com.knecon.fforesight.llm.service.LlmServiceSettings;
import com.knecon.fforesight.llm.service.SystemMessages;
import com.knecon.fforesight.llm.service.document.DocumentData;
import com.knecon.fforesight.llm.service.document.DocumentGraphMapper;
import com.knecon.fforesight.llm.service.SystemMessageProvider;
import com.knecon.fforesight.llm.service.document.nodes.Document;
import com.knecon.fforesight.llm.service.document.textblock.TextBlock;
import com.knecon.fforesight.llm.service.models.Chunk;
import com.knecon.fforesight.llm.service.utils.FormattingUtils;
import com.knecon.fforesight.llm.service.utils.StorageIdUtils;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPage;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPageProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPositionData;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPositionDataProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentStructure;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentStructureProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentStructureWrapper;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentTextData;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentTextDataProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.EntryDataProto;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.NodeTypeProto;
import com.knecon.fforesight.tenantcommons.TenantContext;
import lombok.AccessLevel;
@ -72,8 +47,9 @@ import lombok.extern.slf4j.Slf4j;
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class LlmNerService {
StorageService storageService;
LlmRessource llmRessource;
DocumentBuilderService documentBuilderService;
StorageService storageService;
ObjectMapper mapper;
@ -86,31 +62,36 @@ public class LlmNerService {
long start = System.currentTimeMillis();
Document document = buildDocument(llmNerMessage);
ChunkingResponse chunks = readChunks(llmNerMessage.getChunksStorageId());
List<LlmNerEntity> allEntities = new ArrayList<>();
List<LlmNerEntity> allEntities = new LinkedList<>();
if (!llmNerMessage.getEntityAiDescriptions().isEmpty()) {
Document document = documentBuilderService.build(llmNerMessage);
ChunkingResponse chunks = readChunks(llmNerMessage.getChunksStorageId());
log.info("Finished data prep in {} for {}", FormattingUtils.humanizeDuration(System.currentTimeMillis() - start), llmNerMessage.getIdentifier());
allEntities = new LinkedList<>();
List<CompletableFuture<EntitiesWithUsage>> entityFutures = chunks.getData()
.stream()
.map(chunk -> Chunk.create(chunk, document))
.map(this::getLlmNerEntitiesFuture)
.toList();
log.info("Finished data prep in {} for {}", FormattingUtils.humanizeDuration(System.currentTimeMillis() - start), llmNerMessage.getIdentifier());
log.info("Awaiting {} api calls for {}", entityFutures.size(), llmNerMessage.getIdentifier());
for (CompletableFuture<EntitiesWithUsage> entityFuture : entityFutures) {
try {
EntitiesWithUsage entitiesWithUsage = entityFuture.get();
allEntities.addAll(entitiesWithUsage.entities());
completionTokenCount += entitiesWithUsage.completionsUsage().getCompletionTokens();
promptTokenCount += entitiesWithUsage.completionsUsage().getPromptTokens();
} catch (Exception e) {
log.error(e.getMessage(), e);
throw new RuntimeException(e);
List<CompletableFuture<EntitiesWithUsage>> entityFutures = chunks.getData()
.stream()
.map(chunk -> Chunk.create(chunk, document))
.map(chunk -> getLlmNerEntitiesFuture(chunk, llmNerMessage.getEntityAiDescriptions()))
.toList();
log.info("Awaiting {} api calls for {}", entityFutures.size(), llmNerMessage.getIdentifier());
for (CompletableFuture<EntitiesWithUsage> entityFuture : entityFutures) {
try {
EntitiesWithUsage entitiesWithUsage = entityFuture.get();
allEntities.addAll(entitiesWithUsage.entities());
completionTokenCount += entitiesWithUsage.completionTokens();
promptTokenCount += entitiesWithUsage.promptTokens();
} catch (Exception e) {
log.error(e.getMessage(), e);
throw new RuntimeException(e);
}
}
}
log.debug("Storing files for {}", llmNerMessage.getIdentifier());
storageService.storeJSONObject(TenantContext.getTenantId(), llmNerMessage.getResultStorageId(), new LlmNerEntities(allEntities));
long duration = System.currentTimeMillis() - start;
@ -124,43 +105,71 @@ public class LlmNerService {
}
private CompletableFuture<EntitiesWithUsage> getLlmNerEntitiesFuture(Chunk chunk) {
private CompletableFuture<EntitiesWithUsage> getLlmNerEntitiesFuture(Chunk chunk, List<EntityAiDescription> entityAiDescriptions) {
return CompletableFuture.supplyAsync(() -> getLlmNerEntities(chunk));
return CompletableFuture.supplyAsync(() -> getLlmNerEntities(chunk, entityAiDescriptions));
}
@SneakyThrows
private EntitiesWithUsage getLlmNerEntities(Chunk chunk) {
private EntitiesWithUsage getLlmNerEntities(Chunk chunk, List<EntityAiDescription> entityAiDescriptions) {
log.debug("Sending request with text of length {}", chunk.markdown().length());
long start = System.currentTimeMillis();
ChatCompletions chatCompletions = runNer(chunk.markdown());
String nerPrompt = SystemMessageProvider.createNerPrompt(entityAiDescriptions);
ChatCompletions chatCompletions = runLLM(nerPrompt, chunk.markdown());
log.debug("Got response back, used {} prompt tokens, {} completion tokens, took {}",
chatCompletions.getUsage().getPromptTokens(),
chatCompletions.getUsage().getCompletionTokens(),
FormattingUtils.humanizeDuration(System.currentTimeMillis() - start));
return mapEntitiesToDocument(chatCompletions, chunk.parts());
EntitiesWithUsage entitiesWithUsage;
try {
entitiesWithUsage = mapEntitiesToDocument(chatCompletions, chunk.parts());
} catch (JsonProcessingException e) {
String faultyResponse = chatCompletions.getChoices()
.get(0).getMessage().getContent();
ChatCompletions correctionCompletions = runLLM(SystemMessageProvider.PROMPT_CORRECTION, faultyResponse);
try {
entitiesWithUsage = mapEntitiesToDocument(correctionCompletions, chunk.parts());
int completionTokens = chatCompletions.getUsage().getCompletionTokens() + correctionCompletions.getUsage().getCompletionTokens();
int promptTokens = chatCompletions.getUsage().getPromptTokens() + correctionCompletions.getUsage().getPromptTokens();
entitiesWithUsage = new EntitiesWithUsage(entitiesWithUsage.entities(), completionTokens, promptTokens);
} catch (JsonProcessingException ex) {
throw new RuntimeException(ex);
}
}
return entitiesWithUsage;
}
public ChatCompletions runNer(String text) throws InterruptedException {
public ChatCompletions runLLM(String prompt, String input) throws InterruptedException {
List<ChatRequestMessage> chatMessages = new ArrayList<>();
chatMessages.add(new ChatRequestSystemMessage(SystemMessages.NER));
chatMessages.add(new ChatRequestUserMessage(text));
chatMessages.add(new ChatRequestSystemMessage(prompt));
chatMessages.add(new ChatRequestUserMessage(input));
ChatCompletionsOptions options = new ChatCompletionsOptions(chatMessages);
options.setResponseFormat(new ChatCompletionsJsonResponseFormat());
options.setTemperature(0.0);
options.setN(1); // only return one choice
return llmRessource.getChatCompletions(options);
}
private EntitiesWithUsage mapEntitiesToDocument(ChatCompletions chatCompletions, List<TextBlock> chunkParts) {
private EntitiesWithUsage mapEntitiesToDocument(ChatCompletions chatCompletions, List<TextBlock> chunkParts) throws JsonProcessingException {
EntitiesWithUsage allEntities = new EntitiesWithUsage(new LinkedList<>(), chatCompletions.getUsage());
for (ChatChoice choice : chatCompletions.getChoices()) {
EntitiesWithUsage allEntities = new EntitiesWithUsage(new LinkedList<>(), chatCompletions.getUsage().getCompletionTokens(), chatCompletions.getUsage().getPromptTokens());
if (!chatCompletions.getChoices().isEmpty()) {
ChatChoice choice = chatCompletions.getChoices()
.get(0);
Map<String, List<String>> entitiesPerType = parseResponse(choice);
List<LlmNerEntity> entitiesFromResponse = entitiesPerType.entrySet()
@ -177,16 +186,11 @@ public class LlmNerService {
}
private Map<String, List<String>> parseResponse(ChatChoice choice) {
private Map<String, List<String>> parseResponse(ChatChoice choice) throws JsonProcessingException {
String response = choice.getMessage().getContent();
try {
return mapper.readValue(response, new TypeReference<Map<String, List<String>>>() {
});
} catch (JsonProcessingException e) {
log.error("Response could not be parsed as JSON, response is {}", response);
throw new RuntimeException(e);
}
return mapper.readValue(response, new TypeReference<>() {
});
}
@ -232,231 +236,7 @@ public class LlmNerService {
}
/**
 * Assembles a {@link Document} graph from the four storage artefacts referenced in the message.
 * NOTE(review): this logic appears to be superseded by DocumentBuilderService.build — confirm
 * before further changes.
 */
private Document buildDocument(LlmNerMessage llmNerMessage) {
// Fetch each artefact (proto or converted legacy JSON) and hand the bundle to the mapper.
DocumentData documentData = new DocumentData();
documentData.setDocumentStructureWrapper(new DocumentStructureWrapper(fetchDocumentStructure(llmNerMessage.getDocumentStructureStorageId())));
documentData.setDocumentTextData(fetchDocumentTextData(llmNerMessage.getDocumentTextStorageId()));
documentData.setDocumentPositionData(fetchDocumentPositionData(llmNerMessage.getDocumentPositionStorageId()));
documentData.setDocumentPages(fetchAllDocumentPages(llmNerMessage.getDocumentPagesStorageId()));
return DocumentGraphMapper.toDocumentGraph(documentData);
}
private DocumentStructureProto.DocumentStructure fetchDocumentStructure(String storageId) {
DocumentStructureProto.DocumentStructure documentStructure;
StorageIdUtils.StorageInfo storageInfo = StorageIdUtils.parseStorageId(storageId);
if (storageInfo.fileTypeExtension().contains("proto")) {
documentStructure = storageService.readProtoObject(TenantContext.getTenantId(), storageId, DocumentStructureProto.DocumentStructure.parser());
} else {
DocumentStructure oldDocumentStructure = getOldData(storageInfo.dossierId(), storageInfo.fileId(), storageInfo.fileTypeName(), DocumentStructure.class);
if (oldDocumentStructure == null) {
return null;
}
documentStructure = convertDocumentStructure(oldDocumentStructure);
}
return documentStructure;
}
private DocumentTextDataProto.AllDocumentTextData fetchDocumentTextData(String storageId) {
DocumentTextDataProto.AllDocumentTextData documentTextData;
StorageIdUtils.StorageInfo storageInfo = StorageIdUtils.parseStorageId(storageId);
if (storageInfo.fileTypeExtension().contains("proto")) {
documentTextData = storageService.readProtoObject(TenantContext.getTenantId(), storageId, DocumentTextDataProto.AllDocumentTextData.parser());
} else {
DocumentTextData[] oldDocumentTextData = getOldData(storageInfo.dossierId(), storageInfo.fileId(), storageInfo.fileTypeName(), DocumentTextData[].class);
if (oldDocumentTextData == null) {
return null;
}
documentTextData = convertAllDocumentTextData(oldDocumentTextData);
}
return documentTextData;
}
private DocumentPositionDataProto.AllDocumentPositionData fetchDocumentPositionData(String storageId) {
DocumentPositionDataProto.AllDocumentPositionData documentPositionData;
StorageIdUtils.StorageInfo storageInfo = StorageIdUtils.parseStorageId(storageId);
if (storageInfo.fileTypeExtension().contains("proto")) {
documentPositionData = storageService.readProtoObject(TenantContext.getTenantId(), storageId, DocumentPositionDataProto.AllDocumentPositionData.parser());
} else {
DocumentPositionData[] oldDocumentPositionData = getOldData(storageInfo.dossierId(), storageInfo.fileId(), storageInfo.fileTypeName(), DocumentPositionData[].class);
if (oldDocumentPositionData == null) {
return null;
}
documentPositionData = convertAllDocumentPositionData(oldDocumentPositionData);
}
return documentPositionData;
}
/**
 * Loads the document page metadata behind the given storage id.
 * <p>
 * Proto-typed storage ids are read directly; anything else is treated as the
 * legacy JSON representation and converted on the fly.
 *
 * @param storageId storage id of the page metadata object
 * @return the page metadata, or {@code null} when the legacy JSON fallback is missing
 */
private DocumentPageProto.AllDocumentPages fetchAllDocumentPages(String storageId) {
    StorageIdUtils.StorageInfo info = StorageIdUtils.parseStorageId(storageId);
    if (info.fileTypeExtension().contains("proto")) {
        return storageService.readProtoObject(TenantContext.getTenantId(), storageId, DocumentPageProto.AllDocumentPages.parser());
    }
    // Legacy fallback: data stored as JSON before the proto migration.
    DocumentPage[] legacyPages = getOldData(info.dossierId(), info.fileId(), info.fileTypeName(), DocumentPage[].class);
    return legacyPages == null ? null : convertAllDocumentPages(legacyPages);
}
/**
 * Reads a legacy JSON object from storage and deserializes it into the given type.
 *
 * @param dossierId dossier the file belongs to
 * @param fileId    id of the file
 * @param fileType  logical file type name used to build the legacy storage id
 * @param valueType target type for Jackson deserialization
 * @return the deserialized object, or {@code null} when the object is missing or unreadable
 */
private <T> T getOldData(String dossierId, String fileId, String fileType, Class<T> valueType) {
    String oldStorageId = StorageIdUtils.getStorageId(dossierId, fileId, fileType, ".json");
    try (InputStream inputStream = getObject(TenantContext.getTenantId(), oldStorageId)) {
        return mapper.readValue(inputStream, valueType);
    } catch (IOException e) {
        // Parameterized logging with the exception as last argument keeps the stack trace.
        log.error("Could not read JSON for {} (storageId={})", fileType, oldStorageId, e);
        return null;
    }
}
/**
 * Recursively converts a legacy {@link DocumentStructure.EntryData} node (and all
 * of its children) into its proto representation.
 *
 * @param oldEntryData legacy structure node to convert
 * @return the equivalent proto node
 */
private static EntryDataProto.EntryData convertEntryData(DocumentStructure.EntryData oldEntryData) {
    EntryDataProto.EntryData.Builder entry = EntryDataProto.EntryData.newBuilder()
            .setType(NodeTypeProto.NodeType.valueOf(oldEntryData.getType().name()))
            .addAllTreeId(Arrays.stream(oldEntryData.getTreeId()).boxed().toList())
            .addAllAtomicBlockIds(Arrays.asList(oldEntryData.getAtomicBlockIds()))
            .addAllPageNumbers(Arrays.asList(oldEntryData.getPageNumbers()))
            .putAllProperties(oldEntryData.getProperties());
    var children = oldEntryData.getChildren();
    if (children != null) {
        // Depth-first: each child subtree is converted before it is attached.
        children.forEach(child -> entry.addChildren(convertEntryData(child)));
    }
    return entry.build();
}
/**
 * Converts a legacy {@link DocumentStructure} into its proto representation.
 * An absent root simply yields an empty structure.
 *
 * @param oldStructure legacy structure to convert
 * @return the equivalent proto structure
 */
private static DocumentStructureProto.DocumentStructure convertDocumentStructure(DocumentStructure oldStructure) {
    DocumentStructureProto.DocumentStructure.Builder structure = DocumentStructureProto.DocumentStructure.newBuilder();
    var root = oldStructure.getRoot();
    if (root != null) {
        structure.setRoot(convertEntryData(root));
    }
    return structure.build();
}
/**
 * Converts a single legacy {@link DocumentPage} into its proto representation,
 * copying number, dimensions and rotation one-to-one.
 *
 * @param oldPage legacy page to convert
 * @return the equivalent proto page
 */
private static DocumentPageProto.DocumentPage convertDocumentPage(DocumentPage oldPage) {
    DocumentPageProto.DocumentPage.Builder page = DocumentPageProto.DocumentPage.newBuilder();
    page.setNumber(oldPage.getNumber());
    page.setHeight(oldPage.getHeight());
    page.setWidth(oldPage.getWidth());
    page.setRotation(oldPage.getRotation());
    return page.build();
}
/**
 * Converts an array of legacy {@link DocumentPage} objects into the aggregate
 * proto container, preserving order.
 *
 * @param oldPages legacy pages to convert
 * @return proto container holding all converted pages
 */
private static DocumentPageProto.AllDocumentPages convertAllDocumentPages(DocumentPage[] oldPages) {
    DocumentPageProto.AllDocumentPages.Builder allPages = DocumentPageProto.AllDocumentPages.newBuilder();
    Arrays.stream(oldPages)
            .map(oldPage -> convertDocumentPage(oldPage))
            .forEach(allPages::addDocumentPages);
    return allPages.build();
}
/**
 * Converts a single legacy {@link DocumentPositionData} into its proto
 * representation, including the string-index mapping and every position tuple.
 *
 * @param oldData legacy position data to convert
 * @return the equivalent proto position data
 */
private static DocumentPositionDataProto.DocumentPositionData convertDocumentPositionData(DocumentPositionData oldData) {
    DocumentPositionDataProto.DocumentPositionData.Builder positionData = DocumentPositionDataProto.DocumentPositionData.newBuilder();
    positionData.setId(oldData.getId());
    positionData.addAllStringIdxToPositionIdx(Arrays.stream(oldData.getStringIdxToPositionIdx()).boxed().toList());
    for (float[] coordinates : oldData.getPositions()) {
        positionData.addPositions(DocumentPositionDataProto.DocumentPositionData.Position.newBuilder()
                .addAllValue(Floats.asList(coordinates))
                .build());
    }
    return positionData.build();
}
/**
 * Converts an array of legacy {@link DocumentPositionData} objects into the
 * aggregate proto container, preserving order.
 *
 * @param oldDataList legacy position data entries to convert
 * @return proto container holding all converted entries
 */
private static DocumentPositionDataProto.AllDocumentPositionData convertAllDocumentPositionData(DocumentPositionData[] oldDataList) {
    DocumentPositionDataProto.AllDocumentPositionData.Builder allData = DocumentPositionDataProto.AllDocumentPositionData.newBuilder();
    Arrays.stream(oldDataList)
            .map(oldData -> convertDocumentPositionData(oldData))
            .forEach(allData::addDocumentPositionData);
    return allData.build();
}
/**
 * Converts a single legacy {@link DocumentTextData} into its proto
 * representation, copying ids, page info, text span and line breaks one-to-one.
 *
 * @param oldData legacy text data to convert
 * @return the equivalent proto text data
 */
private static DocumentTextDataProto.DocumentTextData convertDocumentTextData(DocumentTextData oldData) {
    return DocumentTextDataProto.DocumentTextData.newBuilder()
            .setId(oldData.getId())
            .setPage(oldData.getPage())
            .setSearchText(oldData.getSearchText())
            .setNumberOnPage(oldData.getNumberOnPage())
            .setStart(oldData.getStart())
            .setEnd(oldData.getEnd())
            .addAllLineBreaks(Arrays.stream(oldData.getLineBreaks()).boxed().toList())
            .build();
}
/**
 * Converts an array of legacy {@link DocumentTextData} objects into the
 * aggregate proto container, preserving order.
 *
 * @param oldDataList legacy text data entries to convert
 * @return proto container holding all converted entries
 */
private static DocumentTextDataProto.AllDocumentTextData convertAllDocumentTextData(DocumentTextData[] oldDataList) {
    DocumentTextDataProto.AllDocumentTextData.Builder allData = DocumentTextDataProto.AllDocumentTextData.newBuilder();
    Arrays.stream(oldDataList)
            .map(oldData -> convertDocumentTextData(oldData))
            .forEach(allData::addDocumentTextData);
    return allData.build();
}
/**
 * Downloads the given storage object into a temp file and returns a stream over it.
 * <p>
 * The temp file is removed automatically when the returned stream is closed
 * (DELETE_ON_CLOSE); if the download or the stream creation fails, the temp file
 * is cleaned up eagerly instead of being leaked on disk.
 *
 * @param tenantId  tenant owning the object
 * @param storageId storage id of the object to download
 * @return a buffered stream over the downloaded content; caller must close it
 */
@SneakyThrows
private InputStream getObject(String tenantId, String storageId) {
    File tempFile = File.createTempFile("temp", ".data");
    try {
        storageService.downloadTo(tenantId, storageId, tempFile);
        return new BufferedInputStream(Files.newInputStream(tempFile.toPath(), StandardOpenOption.DELETE_ON_CLOSE));
    } catch (Exception e) {
        // Failure before the stream exists: DELETE_ON_CLOSE never triggers, so delete here.
        Files.deleteIfExists(tempFile.toPath());
        throw e;
    }
}
private record EntitiesWithUsage(List<LlmNerEntity> entities, CompletionsUsage completionsUsage) {
private record EntitiesWithUsage(List<LlmNerEntity> entities, int completionTokens, int promptTokens) {
}

View File

@ -11,7 +11,7 @@ import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.knecon.fforesight.llm.service.ChatEvent;
import com.knecon.fforesight.llm.service.SystemMessages;
import com.knecon.fforesight.llm.service.SystemMessageProvider;
import com.knecon.fforesight.tenantcommons.TenantContext;
import lombok.AccessLevel;
@ -35,7 +35,7 @@ public class LlmService {
public void rulesCopilot(List<String> prompt, String userId) {
List<ChatRequestMessage> chatMessages = new ArrayList<>();
chatMessages.add(new ChatRequestSystemMessage(SystemMessages.RULES_CO_PILOT));
chatMessages.add(new ChatRequestSystemMessage(SystemMessageProvider.RULES_CO_PILOT));
chatMessages.addAll(prompt.stream()
.map(ChatRequestUserMessage::new)
.toList());

View File

@ -49,7 +49,8 @@ public class MessageHandler {
LlmNerResponseMessage llmNerResponseMessage = new LlmNerResponseMessage(llmNerMessage.getIdentifier(),
usage.promptTokenCount(),
usage.completionTokenCount(),
Math.toIntExact(usage.durationMillis()));
Math.toIntExact(usage.durationMillis()),
llmNerMessage.getAiCreationVersion());
log.info("LLM NER finished for {}", llmNerMessage.getIdentifier());
sendFinishedMessage(llmNerResponseMessage, message);
}

View File

@ -4,6 +4,8 @@ import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -19,11 +21,12 @@ import lombok.SneakyThrows;
@Disabled
public class LlmNerServiceTest extends AbstractLlmServiceIntegrationTest {
public static final String DOCUMENT_TEXT = "DOCUMENT_TEXT";
public static final String DOCUMENT_POSITIONS = "DOCUMENT_POSITION";
public static final String DOCUMENT_STRUCTURE = "DOCUMENT_STRUCTURE";
public static final String DOCUMENT_PAGES = "DOCUMENT_PAGES";
public static final String DOCUMENT_CHUNKS = "DOCUMENT_CHUNKS";
public static final String DOCUMENT_TEXT = "DOCUMENT_TEXT.proto";
public static final String DOCUMENT_POSITIONS = "DOCUMENT_POSITION.proto";
public static final String DOCUMENT_STRUCTURE = "DOCUMENT_STRUCTURE.proto";
public static final String DOCUMENT_PAGES = "DOCUMENT_PAGES.proto";
public static final String DOCUMENT_CHUNKS = "DOCUMENT_CHUNKS.json";
public static final String STORAGE_ID = "08904e84-4a5a-4c15-bc13-200237af6434/4d81e891fd3e94dfe0b6c51073ef55b6.";
@Autowired
LlmNerService llmNerService;
@ -34,10 +37,10 @@ public class LlmNerServiceTest extends AbstractLlmServiceIntegrationTest {
@SneakyThrows
public void testLlmNer() {
Path folder = Path.of("/home/kschuettler/Downloads/New Folder (5)/18299ec0-7659-496a-a44a-194bbffb1700/1fb7d49ae389469f60db516cf81a3510");
Path folder = Path.of("/Users/maverickstuder/Downloads/10-09-2024-16-03-47_files_list");
LlmNerMessage message = prepStorage(folder);
llmNerService.runNer(message);
Path tmpFile = Path.of("tmp", "AAA_LLM_ENTITIES", "entities.json");
Path tmpFile = Path.of("/private/tmp", "LLM_ENTITIES", "entities.json");
Files.createDirectories(tmpFile.getParent());
storageService.downloadTo(TEST_TENANT, message.getResultStorageId(), tmpFile.toFile());
}
@ -60,7 +63,7 @@ public class LlmNerServiceTest extends AbstractLlmServiceIntegrationTest {
try (var in = new FileInputStream(relevantFile.toFile())) {
storageService.storeObject(TenantContext.getTenantId(),
folder + relevantFiles.stream()
STORAGE_ID + relevantFiles.stream()
.filter(filePath -> relevantFile.getFileName().toString().contains(filePath))
.findFirst()
.orElseThrow(),
@ -71,14 +74,34 @@ public class LlmNerServiceTest extends AbstractLlmServiceIntegrationTest {
private static LlmNerMessage buildMessage(Path folder) {
List<EntityAiDescription> entityAiDescriptions = new ArrayList<>();
// Add descriptions for each entity type with examples
entityAiDescriptions.add(new EntityAiDescription("PERSON",
"A PERSON is any name referring to a human, excluding named methods (e.g., 'Klingbeil Test' is not a name). Each name should be its own entity, but first name, last name, and possibly middle name should be merged. Numbers are never part of a name. "
+ "For example: 'Jennifer Durando, BS', 'Charlène Hernandez', 'Shaw A.', 'G J J Lubbe'."));
entityAiDescriptions.add(new EntityAiDescription("PII",
"PII refers to personally identifiable information such as email addresses, telephone numbers, fax numbers, or any other information that could uniquely identify an individual. "
+ "For example: '01223 45678', 'mimi.lang@smithcorp.com', '+44 (0)1252 392460'."));
entityAiDescriptions.add(new EntityAiDescription("ADDRESS",
"An ADDRESS describes a real-life location. It should be as complete as possible and may include elements such as street address, city, state, postal code, and country. "
+ "For example: 'Product Safety Labs 2394 US Highway 130 Dayton, NJ 08810 USA', 'Syngenta Crop Protection, LLC 410 Swing Road Post Office Box 18300 Greensboro, NC 27419-8300 USA'."));
entityAiDescriptions.add(new EntityAiDescription("COMPANY",
"A COMPANY is any corporate entity or approving body mentioned in the text, excluding companies mentioned as part of an address. "
+ "For example: 'Syngenta', 'EFSA'."));
entityAiDescriptions.add(new EntityAiDescription("COUNTRY",
"A COUNTRY is any recognized nation mentioned in the text. Countries mentioned as part of an address should not be listed separately. "
+ "For example: 'USA'."));
return LlmNerMessage.builder()
.identifier(Map.of("file", folder.getFileName().toString()))
.chunksStorageId(folder + DOCUMENT_CHUNKS)
.documentPagesStorageId(folder + DOCUMENT_PAGES)
.documentTextStorageId(folder + DOCUMENT_TEXT)
.documentPositionStorageId(folder + DOCUMENT_POSITIONS)
.documentStructureStorageId(folder + DOCUMENT_STRUCTURE)
.resultStorageId(folder + "result")
.entityAiDescriptions(entityAiDescriptions)
.chunksStorageId(STORAGE_ID + DOCUMENT_CHUNKS)
.documentPagesStorageId(STORAGE_ID + DOCUMENT_PAGES)
.documentTextStorageId(STORAGE_ID + DOCUMENT_TEXT)
.documentPositionStorageId(STORAGE_ID + DOCUMENT_POSITIONS)
.documentStructureStorageId(STORAGE_ID + DOCUMENT_STRUCTURE)
.resultStorageId(STORAGE_ID + "result")
.build();
}