Merge branch 'refactor' into 'main'

Refactor

See merge request fforesight/llm-service!23
This commit is contained in:
Kilian Schüttler 2024-09-04 12:40:07 +02:00
commit cbcf3b605b
8 changed files with 151 additions and 94 deletions

View File

@ -4,21 +4,21 @@ public class SystemMessages {
public static final String NER = """
You are tasked with finding all named entities in the following document.
The named entities should be mapped to two classes PII, ADDRESS, COMPANY, and COUNTRY.
A PII is any personally identifiable information including but not limited to names, email address, telephone number, fax numbers. Further, use your own judgement to add anything else.
The named entities should be mapped to the classes PERSON, PII, ADDRESS, COMPANY, and COUNTRY.
A PERSON is any name referring to a human, excluding named methods (e.g. Klingbeil Test is not a name)
Each name should be its own entity, but first name, last name and possibly middle name should be merged. Remember that numbers are never a part of a name.
A PII is any personally identifiable information including but not limited to email address, telephone numbers, fax numbers. Further, use your own judgement to add anything else.
An Address describes a real life location and should always be as complete as possible.
A COMPANY is any company or approving body mentioned in the text. But only if it's not part of an ADDRESS
A COUNTRY is any country. But only if it's not part of an ADDRESS
The output should be strictly JSON format and nothing else, formatted as such:
```
{
"PII": ["Jennifer Durando, BS", "01223 45678", "mimi.lang@smithcorp.com", "+44 (0)1252 392460"],
"PERSON": ["Jennifer Durando, BS", "Charlène Hernandez", "Shaw A.", "G J J Lubbe"]
"PII": ["01223 45678", "mimi.lang@smithcorp.com", "+44 (0)1252 392460"],
"ADDRESS": ["Product Safety Labs 2394 US Highway 130 Dayton, NJ 08810 USA", "Syngenta Crop Protection, LLC 410 Swing Road Post Office Box 18300 Greensboro, NC 27419-8300 USA"]
"COMPANY": ["Syngenta", "EFSA"]
"COUNTRY": ["USA"]
}
```
Always replace linebreaks with whitespaces, but except that, ensure the entities match the text in the document exactly.
It is important you mention all present entities, more importantly, it is preferable to mention too many than too little.
""";

View File

@ -10,7 +10,6 @@ import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Stream;
import com.google.common.base.Functions;
import com.knecon.fforesight.llm.service.document.textblock.ConcatenatedTextBlock;
import com.knecon.fforesight.llm.service.document.textblock.TextBlock;
@ -30,6 +29,9 @@ public class ConsecutiveTextBlockCollector implements Collector<TextBlock, List<
public BiConsumer<List<ConcatenatedTextBlock>, TextBlock> accumulator() {
return (existingList, textBlock) -> {
if (textBlock.isEmpty()) {
return;
}
if (existingList.isEmpty()) {
ConcatenatedTextBlock ctb = ConcatenatedTextBlock.empty();
ctb.concat(textBlock);

View File

@ -296,6 +296,22 @@ public class DocumentTree {
}
public Optional<Entry> findEntryById(List<Integer> treeId) {
if (treeId.isEmpty()) {
return Optional.of(root);
}
Entry entry = root;
for (int id : treeId) {
if (entry.children.size() <= id) {
return Optional.empty();
}
entry = entry.children.get(id);
}
return Optional.of(entry);
}
public Stream<Entry> mainEntries() {
return root.children.stream();

View File

@ -0,0 +1,36 @@
package com.knecon.fforesight.llm.service.models;
import java.util.List;
import java.util.Optional;
import com.knecon.fforesight.llm.service.ChunkingResponseData;
import com.knecon.fforesight.llm.service.document.ConsecutiveTextBlockCollector;
import com.knecon.fforesight.llm.service.document.DocumentTree;
import com.knecon.fforesight.llm.service.document.nodes.Document;
import com.knecon.fforesight.llm.service.document.textblock.TextBlock;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public record Chunk(String markdown, List<TextBlock> parts) {
public static Chunk create(ChunkingResponseData chunkingResponseData, Document document) {
return new Chunk(chunkingResponseData.getText(), getChunkParts(document, chunkingResponseData.getTreeIds()));
}
private static List<TextBlock> getChunkParts(com.knecon.fforesight.llm.service.document.nodes.Document document, List<List<Integer>> treeIds) {
return treeIds.stream()
.map(treeId -> {
Optional<DocumentTree.Entry> entry = document.getDocumentTree().findEntryById(treeId);
if (entry.isEmpty()) {
throw new RuntimeException("Could not find node with id " + treeId);
}
return entry.get().getNode().getTextBlock();
})
.collect(new ConsecutiveTextBlockCollector());
}
}

View File

@ -15,6 +15,7 @@ import org.springframework.stereotype.Service;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
@ -25,18 +26,15 @@ import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.iqser.red.storage.commons.service.StorageService;
import com.knecon.fforesight.llm.service.ChunkingResponse;
import com.knecon.fforesight.llm.service.ChunkingResponseData;
import com.knecon.fforesight.llm.service.LlmNerEntities;
import com.knecon.fforesight.llm.service.LlmNerEntity;
import com.knecon.fforesight.llm.service.LlmNerMessage;
import com.knecon.fforesight.llm.service.SystemMessages;
import com.knecon.fforesight.llm.service.document.ConsecutiveTextBlockCollector;
import com.knecon.fforesight.llm.service.document.DocumentData;
import com.knecon.fforesight.llm.service.document.DocumentGraphMapper;
import com.knecon.fforesight.llm.service.document.DocumentTree;
import com.knecon.fforesight.llm.service.document.nodes.Document;
import com.knecon.fforesight.llm.service.document.nodes.SemanticNode;
import com.knecon.fforesight.llm.service.document.textblock.TextBlock;
import com.knecon.fforesight.llm.service.models.Chunk;
import com.knecon.fforesight.llm.service.utils.FormattingUtils;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPage;
import com.knecon.fforesight.service.layoutparser.internal.api.data.redaction.DocumentPositionData;
@ -56,38 +54,46 @@ import lombok.extern.slf4j.Slf4j;
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class LlmNerService {
public static final String JSON_PREFIX = "```json";
public static final String JSON_PREFIX2 = "```";
public static final String SUFFIX = "```";
StorageService storageService;
LlmRessource llmRessource;
ObjectMapper mapper;
@SneakyThrows
public Usage runNer(LlmNerMessage llmNerMessage) {
public synchronized Usage runNer(LlmNerMessage llmNerMessage) {
int completionTokenCount = 0;
int promptTokenCount = 0;
llmRessource.resetConcurrencyLimiter();
long start = System.currentTimeMillis();
Document document = buildDocument(llmNerMessage);
ChunkingResponse chunks = readChunks(llmNerMessage.getChunksStorageId());
List<LlmNerEntity> allEntities = Collections.synchronizedList(new LinkedList<>());
log.info("Finished data prep for {}", llmNerMessage.getIdentifier());
List<LlmNerEntity> allEntities = new LinkedList<>();
log.info("Finished data prep in {} for {}", FormattingUtils.humanizeDuration(System.currentTimeMillis() - start), llmNerMessage.getIdentifier());
List<CompletableFuture<EntitiesWithUsage>> entityFutures = chunks.getData()
.stream()
.map(chunk -> getLlmNerEntitiesFuture(chunk, document))
.map(chunk -> Chunk.create(chunk, document))
.map(this::getLlmNerEntitiesFuture)
.toList();
log.info("Awaiting api calls for {}", llmNerMessage.getIdentifier());
log.info("Awaiting {} api calls for {}", entityFutures.size(), llmNerMessage.getIdentifier());
for (CompletableFuture<EntitiesWithUsage> entityFuture : entityFutures) {
EntitiesWithUsage entitiesWithUsage = entityFuture.get();
allEntities.addAll(entitiesWithUsage.entities());
completionTokenCount += entitiesWithUsage.completionsUsage().getCompletionTokens();
promptTokenCount += entitiesWithUsage.completionsUsage().getPromptTokens();
try {
EntitiesWithUsage entitiesWithUsage = entityFuture.get();
allEntities.addAll(entitiesWithUsage.entities());
completionTokenCount += entitiesWithUsage.completionsUsage().getCompletionTokens();
promptTokenCount += entitiesWithUsage.completionsUsage().getPromptTokens();
} catch (Exception e) {
log.error(e.getMessage(), e);
throw new RuntimeException(e);
}
}
log.info("Storing files for {}", llmNerMessage.getIdentifier());
log.debug("Storing files for {}", llmNerMessage.getIdentifier());
storageService.storeJSONObject(TenantContext.getTenantId(), llmNerMessage.getResultStorageId(), new LlmNerEntities(allEntities));
long duration = System.currentTimeMillis() - start;
log.info("Found {} named entities for {} in {} with {} prompt tokens and {} completion tokens.",
@ -100,23 +106,23 @@ public class LlmNerService {
}
private CompletableFuture<EntitiesWithUsage> getLlmNerEntitiesFuture(ChunkingResponseData chunk, Document document) {
private CompletableFuture<EntitiesWithUsage> getLlmNerEntitiesFuture(Chunk chunk) {
return CompletableFuture.supplyAsync(() -> getLlmNerEntities(chunk, document));
return CompletableFuture.supplyAsync(() -> getLlmNerEntities(chunk));
}
@SneakyThrows
private EntitiesWithUsage getLlmNerEntities(ChunkingResponseData chunk, Document document) {
private EntitiesWithUsage getLlmNerEntities(Chunk chunk) {
log.debug("Sending request with text of length {}", chunk.getText().length());
log.debug("Sending request with text of length {}", chunk.markdown().length());
long start = System.currentTimeMillis();
ChatCompletions chatCompletions = runNer(chunk.getText());
ChatCompletions chatCompletions = runNer(chunk.markdown());
log.debug("Got response back, used {} prompt tokens, {} completion tokens, took {}",
chatCompletions.getUsage().getPromptTokens(),
chatCompletions.getUsage().getCompletionTokens(),
FormattingUtils.humanizeDuration(System.currentTimeMillis() - start));
return mapEntitiesToDocument(chatCompletions, getChunkParts(document, chunk.getTreeIds()), document);
return mapEntitiesToDocument(chatCompletions, chunk.parts());
}
@ -126,68 +132,47 @@ public class LlmNerService {
chatMessages.add(new ChatRequestSystemMessage(SystemMessages.NER));
chatMessages.add(new ChatRequestUserMessage(text));
ChatCompletionsOptions options = new ChatCompletionsOptions(chatMessages);
options.setResponseFormat(new ChatCompletionsJsonResponseFormat());
options.setTemperature(0.0);
return llmRessource.getChatCompletions(options);
}
private List<TextBlock> getChunkParts(Document document, List<List<Integer>> treeIds) {
return treeIds.stream()
.map(treeId -> document.getDocumentTree().getEntryById(treeId))
.map(DocumentTree.Entry::getNode)
.map(SemanticNode::getTextBlock)
.collect(new ConsecutiveTextBlockCollector());
}
private EntitiesWithUsage mapEntitiesToDocument(ChatCompletions chatCompletions, List<TextBlock> chunkParts, Document document) {
private EntitiesWithUsage mapEntitiesToDocument(ChatCompletions chatCompletions, List<TextBlock> chunkParts) {
EntitiesWithUsage allEntities = new EntitiesWithUsage(new LinkedList<>(), chatCompletions.getUsage());
for (ChatChoice choice : chatCompletions.getChoices()) {
String response = parseResponse(choice);
if (response == null) {
continue;
}
try {
Map<String, List<String>> entitiesPerType = mapper.readValue(response, new TypeReference<Map<String, List<String>>>() {
});
Map<String, List<String>> entitiesPerType = parseResponse(choice);
List<LlmNerEntity> entitiesFromResponse = entitiesPerType.entrySet()
.stream()
.flatMap(entitiesWithType -> entitiesWithType.getValue()
.stream()
.distinct()
.flatMap(entity -> findInChunks(entity, chunkParts, entitiesWithType.getKey(), document)))
.toList();
allEntities.entities().addAll(entitiesFromResponse);
} catch (JsonProcessingException e) {
logMalformedResponse(response);
log.error(e.getMessage());
}
List<LlmNerEntity> entitiesFromResponse = entitiesPerType.entrySet()
.stream()
.flatMap(entitiesWithType -> entitiesWithType.getValue()
.stream()
.distinct()
.flatMap(entity -> findInChunks(entity, chunkParts, entitiesWithType.getKey())))
.toList();
allEntities.entities().addAll(entitiesFromResponse);
}
return allEntities;
}
private static String parseResponse(ChatChoice choice) {
private Map<String, List<String>> parseResponse(ChatChoice choice) {
String response = choice.getMessage().getContent();
if (response.startsWith(JSON_PREFIX)) {
response = response.substring(JSON_PREFIX.length());
} else if (response.startsWith(JSON_PREFIX2)) {
response = response.substring(JSON_PREFIX2.length());
try {
return mapper.readValue(response, new TypeReference<Map<String, List<String>>>() {
});
} catch (JsonProcessingException e) {
log.error("Response could not be parsed as JSON, response is {}", response);
throw new RuntimeException(e);
}
if (response.endsWith(SUFFIX)) {
response = response.substring(0, response.length() - SUFFIX.length());
}
return response;
}
private Stream<LlmNerEntity> findInChunks(String entity, List<TextBlock> chunkParts, String type, Document document) {
private Stream<LlmNerEntity> findInChunks(String entity, List<TextBlock> chunkParts, String type) {
Pattern entityPattern = Pattern.compile(String.format("(?:\\b|\\s)(%s)(?:\\b|\\s)", Pattern.quote(entity)));
for (TextBlock chunkPart : chunkParts) {
@ -201,7 +186,7 @@ public class LlmNerService {
.toList();
if (!entitiesInCurrentChunk.stream()
.allMatch(nerEntity -> document.getTextBlock().subSequence(nerEntity.getStartOffset(), nerEntity.getEndOffset()).equals(nerEntity.getValue()))) {
.allMatch(nerEntity -> chunkPart.subSequence(nerEntity.getStartOffset(), nerEntity.getEndOffset()).equals(nerEntity.getValue()))) {
log.error("Entities have wrong value, expected {}, actual {}",
entity,
entitiesInCurrentChunk.stream()
@ -223,12 +208,6 @@ public class LlmNerService {
}
private static void logMalformedResponse(String response) {
log.error("Response could not be parsed as JSON, response is {}", response);
}
private ChunkingResponse readChunks(String chunksStorageId) {
return storageService.readJSONObject(TenantContext.getTenantId(), chunksStorageId, ChunkingResponse.class);

View File

@ -25,16 +25,15 @@ public class LlmRessource {
OpenAIAsyncClient asyncClient;
OpenAIClient client;
LlmServiceSettings settings;
Semaphore concurrencyLimitingSemaphore;
Semaphore concurrencyLimiter;
public LlmRessource(LlmServiceSettings settings) {
this.settings = settings;
this.concurrencyLimitingSemaphore = new Semaphore(settings.getConcurrency());
this.concurrencyLimiter = new Semaphore(settings.getConcurrency());
this.asyncClient = new OpenAIClientBuilder().credential(new KeyCredential(settings.getAzureOpenAiKey())).endpoint(settings.getAzureOpenAiEndpoint()).buildAsyncClient();
this.client = new OpenAIClientBuilder().credential(new KeyCredential(settings.getAzureOpenAiKey())).endpoint(settings.getAzureOpenAiEndpoint()).buildClient();
log.info("Initialized client for endpoint {} and key {}", settings.getAzureOpenAiEndpoint(), settings.getAzureOpenAiKey());
}
@ -48,11 +47,23 @@ public class LlmRessource {
public ChatCompletions getChatCompletions(ChatCompletionsOptions options) throws InterruptedException {
concurrencyLimitingSemaphore.acquire();
ChatCompletions chatCompletions = client.getChatCompletions(settings.getModel(), options);
concurrencyLimitingSemaphore.release();
try {
concurrencyLimiter.acquire();
return client.getChatCompletions(settings.getModel(), options);
} finally {
concurrencyLimiter.release();
}
}
return chatCompletions;
public void resetConcurrencyLimiter() {
int currentPermits = concurrencyLimiter.availablePermits();
if (currentPermits > settings.getConcurrency()) {
concurrencyLimiter.acquireUninterruptibly(currentPermits - settings.getConcurrency());
} else if (currentPermits < settings.getConcurrency()) {
concurrencyLimiter.release(settings.getConcurrency() - currentPermits);
}
}
}

View File

@ -15,15 +15,15 @@ public class FormattingUtils {
} else if (duration < 60 * 60 * 1000) {
long minutes = duration / (60 * 1000);
long remainingMillis = duration % (60 * 1000);
double seconds = remainingMillis / 1000.0;
return String.format("%d:%.1f m", minutes, seconds);
long seconds = remainingMillis / 1000;
return String.format("%d:%d m", minutes, seconds);
} else {
long hours = duration / (60 * 60 * 1000);
long remainingMillis = duration % (60 * 60 * 1000);
long minutes = remainingMillis / (60 * 1000);
remainingMillis = remainingMillis % (60 * 1000);
double seconds = remainingMillis / 1000.0;
return String.format("%d:%d:%.1f h", hours, minutes, seconds);
long seconds = remainingMillis / 1000;
return String.format("%d:%d:%d h", hours, minutes, seconds);
}
}
}

View File

@ -1,6 +1,9 @@
package com.knecon.fforesight.llm.service.queue;
import org.springframework.amqp.AmqpRejectAndDontRequeueException;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import org.springframework.amqp.core.Message;
import org.springframework.amqp.rabbit.annotation.RabbitHandler;
import org.springframework.amqp.rabbit.annotation.RabbitListener;
@ -26,6 +29,7 @@ import lombok.extern.slf4j.Slf4j;
public class MessageHandler {
public static final String LLM_NER_REQUEST_LISTENER_ID = "llm-ner-request-listener";
private final static String X_PIPELINE_PREFIX = "X-PIPE-";
LlmNerService llmNerService;
ObjectMapper mapper;
@ -36,10 +40,6 @@ public class MessageHandler {
@RabbitListener(id = LLM_NER_REQUEST_LISTENER_ID, concurrency = "1")
public void receiveNerRequest(Message message) {
if (message.getMessageProperties().isRedelivered()) {
throw new AmqpRejectAndDontRequeueException("Redelivered LLM NER Request, aborting...");
}
LlmNerMessage llmNerMessage = parseLlmNerMessage(message);
log.info("Starting NER with LLM for {}", llmNerMessage.getIdentifier());
@ -51,7 +51,20 @@ public class MessageHandler {
usage.completionTokenCount(),
Math.toIntExact(usage.durationMillis()));
log.info("LLM NER finished for {}", llmNerMessage.getIdentifier());
rabbitTemplate.convertAndSend(QueueNames.LLM_NER_RESPONSE_EXCHANGE, TenantContext.getTenantId(), llmNerResponseMessage);
sendFinishedMessage(llmNerResponseMessage, message);
}
public void sendFinishedMessage(LlmNerResponseMessage llmNerResponseMessage, Message message) {
rabbitTemplate.convertAndSend(QueueNames.LLM_NER_RESPONSE_EXCHANGE, TenantContext.getTenantId(), llmNerResponseMessage, m -> {
var forwardHeaders = message.getMessageProperties().getHeaders().entrySet()
.stream()
.filter(e -> e.getKey().toUpperCase(Locale.ROOT).startsWith(X_PIPELINE_PREFIX))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
m.getMessageProperties().getHeaders().putAll(forwardHeaders);
return m;
});
}