Biomedical text is a catch-all term that broadly encompasses documents such as research articles, clinical trial reports, and patient records, serving as rich repositories of information about various biological, medical, and scientific concepts. Research papers in the biomedical field present novel breakthroughs in areas like drug discovery, drug side effects, and new disease treatments. Clinical trial reports offer in-depth details on the safety, efficacy, and side effects of new medications or treatments. Meanwhile, patient records contain comprehensive medical histories, diagnoses, treatment plans, and outcomes recorded by physicians and healthcare professionals.
Mining these texts allows practitioners to extract valuable insights that can power various downstream tasks: extracting adverse drug reactions, building automated medical coding algorithms, or constructing information retrieval and question-answering systems that help surface information from vast research corpora. However, one issue affecting biomedical document processing is the often unstructured nature of the text. Researchers might use different terms to refer to the same concept: what one researcher calls a “heart attack” might be referred to as a “myocardial infarction” by another. Similarly, in drug-related documentation, technical and common names may be used interchangeably; “Acetaminophen”, for instance, is the technical name of the drug more commonly known as “Paracetamol”. The prevalence of abbreviations adds another layer of complexity: “Nitric Oxide” might simply appear as “NO” in another context. Although all these variants refer to the same underlying concept, the variation makes it difficult for a layperson or a text-processing algorithm to recognize that they do. This is where Entity Linking becomes crucial.
- What is Entity Linking?
- Where do LLMs come in here?
- Experimental Setup
- Processing the Dataset
- Zero-Shot Entity Linking using the LLM
- LLM with Retrieval Augmented Generation for Entity Linking
- Zero-Shot Entity Extraction with the LLM and an External KB Linker
- Fine-tuned Entity Extraction with the LLM and an External KB Linker
- Benchmarking Scispacy
- Takeaways
- Limitations
- References
When text is unstructured, accurately identifying and standardizing medical concepts becomes crucial. To achieve this, medical terminology systems such as the Unified Medical Language System (UMLS) [1], the Systematized Nomenclature of Medicine – Clinical Terms (SNOMED-CT) [2], and Medical Subject Headings (MeSH) [3] play an essential role. These systems provide a comprehensive and standardized set of medical concepts, each uniquely identified by an alphanumeric code.
Entity linking involves recognizing and extracting entities within the text and mapping them to standardized concepts in a large terminology. In this context, a Knowledge Base (KB) refers to a detailed database containing standardized information and concepts related to the terminology, such as medical terms, diseases, and drugs. Typically, a KB is expert-curated and designed, containing detailed information about the concepts, including variations of the terms that could be used to refer to the concept, or how it is related to other concepts.
Entity recognition entails extracting words or phrases that are significant in the context of our task; here, that usually means biomedical terms such as drugs and diseases. Lookup-based methods or machine learning/deep learning-based systems are typically used for entity recognition. Linking the entities to a KB usually involves a retriever system that indexes the KB: it takes each extracted entity from the previous step and retrieves likely identifiers from the KB. The retriever here is an abstraction; it may be sparse (BM-25), dense (embedding-based), or even a generative system, like a Large Language Model (LLM), that has encoded the KB in its parameters. A minimal sketch of this two-step pipeline is shown below.
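The function names in this sketch are purely illustrative (they do not appear in the repository); it only makes the two-stage structure concrete, with the actual extractors and retrievers filled in later in this article.

def recognize_entities(text: str) -> list[str]:
    """Stage 1: extract biomedical entity mentions (drugs, diseases, ...) from raw text."""
    raise NotImplementedError  # lookup-based, ML/DL-based, or LLM-based extraction

def link_entities(entities: list[str], retriever) -> dict[str, list[str]]:
    """Stage 2: map each extracted mention to likely KB identifiers via a retriever."""
    return {entity: retriever.query(entity, top_n=5) for entity in entities}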
I’ve been curious for a while about the best ways to integrate LLMs into biomedical and clinical text-processing pipelines. Given that Entity Linking is an important part of such pipelines, I decided to explore how best LLMs can be utilized for this task. Specifically I investigated the following setups:
- Zero-Shot Entity Linking with an LLM: Leveraging an LLM to directly identify all entities and concept IDs from input biomedical texts without any fine-tuning
- LLM with Retrieval Augmented Generation (RAG): Utilizing the LLM within a RAG framework by injecting information about relevant concept IDs in the prompt to identify the relevant concept IDs.
- Zero-Shot Entity Extraction with LLM with an External KB Linker: Employing the LLM for zero-shot entity extraction from biomedical texts, with an external linker/retriever for mapping the entities to concept IDs.
- Fine-tuned Entity Extraction with an External KB Linker: Fine-tuning the LLM on the entity extraction task first, and then using it as an entity extractor with an external linker/retriever for mapping the entities to concept IDs.
- Comparison with an existing pipeline: How do these methods fare compared to Scispacy, a commonly used library for biomedical text processing?
All code and resources related to this article are available at this Github repository, under the entity_linking folder. Feel free to clone the repository and run the notebooks directly to reproduce these experiments. Please let me know if you have any feedback or observations, or if you notice any mistakes!
To conduct these experiments, we utilize the Mistral-7B Instruct model [9] as our Large Language Model (LLM). For the medical terminology to link entities against, we utilize the MeSH terminology. To quote the National Library of Medicine website:
“The Medical Subject Headings (MeSH) thesaurus is a controlled and hierarchically-organized vocabulary produced by the National Library of Medicine. It is used for indexing, cataloging, and searching of biomedical and health-related information.”
We utilize the BioCreative-V-CDR-Corpus [4,5,6,7,8] for evaluation. This dataset contains annotations of disease and chemical entities, along with their corresponding MeSH IDs. For evaluation purposes, we randomly sample 100 data points from the test set. We used a version of the MeSH KB provided by Scispacy [10,11], which contains information about the MeSH identifiers, such as definitions and entities corresponding to each ID.
For performance evaluation, we calculate two sets of metrics. The first relates to entity extraction performance. The original dataset annotates all mentions of entities in the text at the substring level, and a strict evaluation would check whether the algorithm has output every occurrence of every entity. We simplify this for easier evaluation: we lower-case and de-duplicate the entities in the ground truth, then calculate the Precision, Recall, and F1 score for each instance and macro-average each metric over the test set.
Suppose you have a set of actual entities, ground_truth, and a set of entities predicted by a model, pred, for each input text. The true positives TP can be determined by identifying the common elements between pred and ground_truth, essentially by calculating the intersection of these two sets. For each input, we can then calculate:

precision = len(TP) / len(pred)
recall = len(TP) / len(ground_truth)
f1 = 2 * precision * recall / (precision + recall)

and finally calculate the macro-average for each metric by summing the per-instance scores and dividing by the number of datapoints in our test set.
For evaluating the overall entity linking performance, we again calculate the same metrics. In this case, for each input datapoint, we have a set of tuples, where each tuple is an (entity, mesh_id) pair. The metrics are otherwise calculated the same way.
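As a quick sanity check of how these per-instance metrics behave, consider a made-up entity-extraction example (the entity names here are illustrative, not taken from the dataset):

ground_truth = {"naloxone", "hypertension", "clonidine"}
pred = {"naloxone", "hypertension", "headache"}

TP = pred.intersection(ground_truth)                 # {"naloxone", "hypertension"}
precision = len(TP) / len(pred)                      # 2/3 ≈ 0.67
recall = len(TP) / len(ground_truth)                 # 2/3 ≈ 0.67
f1 = 2 * precision * recall / (precision + recall)   # ≈ 0.67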
Right, let’s kick things off by first defining some helper functions for processing our dataset.
def parse_dataset(file_path):
    """
    Parse the BioCreative Dataset.

    Args:
    - file_path (str): Path to the file containing the documents.

    Returns:
    - list of dict: A list where each element is a dictionary representing a document.
    """
    documents = []
    current_doc = None

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            # Title lines look like "<doc_id>|t|<title>" and mark the start of a new document
            if "|t|" in line:
                if current_doc:
                    documents.append(current_doc)
                id_, title = line.split("|t|", 1)
                current_doc = {'id': id_, 'title': title, 'abstract': '', 'annotations': []}
            # Abstract lines look like "<doc_id>|a|<abstract>"
            elif "|a|" in line:
                _, abstract = line.split("|a|", 1)
                current_doc['abstract'] = abstract
            # Remaining tab-separated lines are either relation ("CID") or entity annotations
            else:
                parts = line.split("\t")
                if parts[1] == "CID":
                    continue
                annotation = {
                    'text': parts[3],
                    'type': parts[4],
                    'identifier': parts[5]
                }
                current_doc['annotations'].append(annotation)

    if current_doc:
        documents.append(current_doc)

    return documents
def deduplicate_annotations(documents):
    """
    Filter documents to ensure annotation consistency.

    Args:
    - documents (list of dict): The list of documents to be checked.
    """
    for doc in documents:
        doc["annotations"] = remove_duplicates(doc["annotations"])

def remove_duplicates(dict_list):
    """
    Remove duplicate dictionaries from a list of dictionaries.

    Args:
    - dict_list (list of dict): A list of dictionaries from which duplicates are to be removed.

    Returns:
    - list of dict: A list of dictionaries after removing duplicates.
    """
    unique_dicts = []
    seen = set()
    for d in dict_list:
        dict_tuple = tuple(sorted(d.items()))
        if dict_tuple not in seen:
            seen.add(dict_tuple)
            unique_dicts.append(d)
    return unique_dicts
We first parse the dataset from the text files provided in the original dataset. The original dataset includes the title, abstract, and all entities annotated with their entity type (Disease or Chemical), their substring indices indicating their exact location in the text, along with their MeSH IDs. While processing our dataset, we make a few simplifications. We disregard the substring indices and the entity type. Moreover, we de-duplicate annotations that share the same entity name and MeSH ID. At this stage, we only de-duplicate in a case-sensitive manner, meaning if the same entity appears in both lower and upper case across the document, we retain both instances in our processing so far.
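For reference, this is roughly how the helpers are wired together before evaluation. The file name below is a placeholder for wherever the BioCreative-V test split lives locally, and the seed is an arbitrary choice; the repository notebooks may differ in these details.

import random

# Placeholder path: point this at the downloaded BioCreative-V CDR test file
test_docs = parse_dataset("CDR_TestSet.PubTator.txt")
deduplicate_annotations(test_docs)

# Randomly sample 100 documents for evaluation, as described in the experimental setup
random.seed(42)
test_set_subsample = random.sample(test_docs, 100)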
First, we aim to determine whether the LLM already possesses an understanding of MeSH terminology due to its pre-training, and if it can function as a zero-shot entity linker. By zero-shot, we mean the LLM’s capability to directly link entities to their MeSH IDs from biomedical text based on its intrinsic knowledge, without depending on an external KB linker. This hypothesis is not entirely unrealistic, considering the availability of information about MeSH online, which makes it possible that the model might have encountered MeSH-related information during its pre-training phase. However, even if the LLM was trained with such information, it is unlikely that this alone would enable the model to perform zero-shot entity linking effectively, due to the complexity of biomedical terminology and the precision required for accurate entity linking.
To evaluate this, we provide the input text to the LLM and directly prompt it to predict the entities and corresponding MeSH IDs. Additionally, we create a few-shot prompt by sampling three data points from the training dataset. It is important to clarify the distinction in the use of “zero-shot” and “few-shot” here: “zero-shot” refers to the LLM as a whole performing entity linking without prior specific training on this task, while “few-shot” refers to the prompting strategy employed in this context.
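The prompt construction and output parsing helpers (build_few_shot_prompt and parse_answer), along with SYSTEM_PROMPT and the few-shot examples, are defined in the repository. As a rough, hypothetical sketch of what they could look like, assume the model is instructed to emit one “entity | MeSH ID” pair per line; the exact prompt wording, argument structure, and output format in the repo may differ.

def build_few_shot_prompt(system_prompt, input_example, few_shot_examples):
    """Hypothetical sketch: task instructions plus worked examples, followed by the actual input."""
    # Mistral-Instruct expects alternating user/assistant turns, so the task
    # description is folded into the first user message.
    messages = []
    for i, example in enumerate(few_shot_examples):
        prefix = system_prompt + "\n\n" if i == 0 else ""
        messages.append({"role": "user", "content": prefix + example["title"] + " " + example["abstract"]})
        # Assumed output format: one "entity | MeSH ID" pair per line
        answer = "\n".join(f'{ann["text"]} | {ann["identifier"]}' for ann in example["annotations"])
        messages.append({"role": "assistant", "content": answer})
    messages.append({"role": "user", "content": input_example["title"] + " " + input_example["abstract"]})
    return messages

def parse_answer(generated_text):
    """Hypothetical sketch: parse "entity | MeSH ID" lines into the dict format used by the metric functions."""
    parsed = []
    for line in generated_text.split("\n"):
        if "|" not in line:
            continue
        entity, identifier = line.split("|", 1)
        parsed.append({"text": entity.strip(), "identifier": identifier.strip()})
    return parsed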
To calculate our metrics, we define functions for evaluating the performance:
def calculate_entity_metrics(gt, pred):
    """
    Calculate precision, recall, and F1-score for entity recognition.

    Args:
    - gt (list of dict): A list of dictionaries representing the ground truth entities.
      Each dictionary should have a key "text" with the entity text.
    - pred (list of dict): A list of dictionaries representing the predicted entities.
      Similar to `gt`, each dictionary should have a key "text".

    Returns:
    tuple: A tuple containing precision, recall, and F1-score (in that order).
    """
    ground_truth_set = set([x["text"].lower() for x in gt])
    predicted_set = set([x["text"].lower() for x in pred])

    # True positives are predicted items that are in the ground truth
    true_positives = len(predicted_set.intersection(ground_truth_set))

    # Precision calculation
    if len(predicted_set) == 0:
        precision = 0
    else:
        precision = true_positives / len(predicted_set)

    # Recall calculation
    if len(ground_truth_set) == 0:
        recall = 0
    else:
        recall = true_positives / len(ground_truth_set)

    # F1-score calculation
    if precision + recall == 0:
        f1_score = 0
    else:
        f1_score = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1_score
def calculate_mesh_metrics(gt, pred):
    """
    Calculate precision, recall, and F1-score for matching MeSH (Medical Subject Headings) codes.

    Args:
    - gt (list of dict): Ground truth data
    - pred (list of dict): Predicted data

    Returns:
    tuple: A tuple containing precision, recall, and F1-score (in that order).
    """
    ground_truth = []
    for item in gt:
        mesh_codes = item["identifier"]
        if mesh_codes == "-1":
            mesh_codes = "None"
        mesh_codes_split = mesh_codes.split("|")
        for elem in mesh_codes_split:
            combined_elem = {"entity": item["text"].lower(), "identifier": elem}
            if combined_elem not in ground_truth:
                ground_truth.append(combined_elem)

    predicted = []
    for item in pred:
        mesh_codes = item["identifier"]
        mesh_codes_split = mesh_codes.strip().split("|")
        for elem in mesh_codes_split:
            combined_elem = {"entity": item["text"].lower(), "identifier": elem}
            if combined_elem not in predicted:
                predicted.append(combined_elem)

    # True positives are predicted items that are in the ground truth
    true_positives = len([x for x in predicted if x in ground_truth])

    # Precision calculation
    if len(predicted) == 0:
        precision = 0
    else:
        precision = true_positives / len(predicted)

    # Recall calculation
    if len(ground_truth) == 0:
        recall = 0
    else:
        recall = true_positives / len(ground_truth)

    # F1-score calculation
    if precision + recall == 0:
        f1_score = 0
    else:
        f1_score = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1_score
Let’s now run the model and get our predictions:
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
model.eval()

mistral_few_shot_answers = []
for item in tqdm(test_set_subsample):
    few_shot_prompt_messages = build_few_shot_prompt(SYSTEM_PROMPT, item, few_shot_example)
    input_ids = tokenizer.apply_chat_template(few_shot_prompt_messages, tokenize=True, return_tensors="pt").cuda()
    outputs = model.generate(input_ids=input_ids, max_new_tokens=200, do_sample=False)
    # https://github.com/huggingface/transformers/issues/17117#issuecomment-1124497554
    gen_text = tokenizer.batch_decode(outputs.detach().cpu().numpy()[:, input_ids.shape[1]:], skip_special_tokens=True)[0]
    mistral_few_shot_answers.append(parse_answer(gen_text.strip()))
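We can then score the few-shot predictions with the metric functions defined above, using the same macro-averaging pattern that is reused later for the RAG runs:

entity_scores = [calculate_entity_metrics(gt["annotations"], pred)
                 for gt, pred in zip(test_set_subsample, mistral_few_shot_answers)]
entity_macro_scores = {
    "macro-precision": sum(x[0] for x in entity_scores) / len(entity_scores),
    "macro-recall": sum(x[1] for x in entity_scores) / len(entity_scores),
    "macro-f1": sum(x[2] for x in entity_scores) / len(entity_scores),
}

mesh_scores = [calculate_mesh_metrics(gt["annotations"], pred)
               for gt, pred in zip(test_set_subsample, mistral_few_shot_answers)]
mesh_macro_scores = {
    "macro-precision": sum(x[0] for x in mesh_scores) / len(mesh_scores),
    "macro-recall": sum(x[1] for x in mesh_scores) / len(mesh_scores),
    "macro-f1": sum(x[2] for x in mesh_scores) / len(mesh_scores),
}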
At the entity extraction level, the LLM performs quite well, considering it has not been explicitly fine-tuned for this task. However, its performance as a zero-shot linker is quite poor, with an overall score of less than 1%. This outcome is intuitive, though: the output space of MeSH labels is vast, and exactly mapping entities to a specific MeSH ID is a hard task.
Retrieval Augmented Generation (RAG) [12] refers to a framework that combines LLMs with an external KB equipped with a querying function, such as a retriever/linker. For each incoming query, the system first retrieves knowledge relevant to the query from the KB using the querying function. It then combines the retrieved knowledge and the query, providing this combined prompt to the LLM to perform the task. This approach is based on the understanding that LLMs may not have all the necessary knowledge or information to answer an incoming query effectively. Thus, knowledge is injected into the model by querying an external knowledge source.
Using a RAG framework can offer several advantages:
- An existing LLM can be utilized for a new domain or task without the need for domain-specific fine-tuning, as the relevant information can be queried and provided to the model through a prompt.
- LLMs can sometimes provide incorrect answers (hallucinate) when responding to queries. Employing RAG with LLMs can significantly reduce such hallucinations, as the answers provided by the LLM are more likely to be grounded in facts due to the knowledge supplied to it.
Considering that the LLM lacks specific knowledge of MeSH terminologies, we investigate whether a RAG setup could enhance performance. In this approach, for each input paragraph, we utilize a BM-25 retriever to query the KB. For each MeSH ID, we have access to a general description of the ID and the entity names associated with it. After retrieval, we inject this information into the model through the prompt for entity linking.
To investigate how the number of retrieved IDs provided as context affects entity linking, we run this setup with the top 10, 30, and 50 retrieved documents provided to the model, and quantify its performance on entity extraction and MeSH concept identification.
Let’s first define our BM-25 Retriever:
from rank_bm25 import BM25Okapi
from typing import List, Tuple, Dict
from nltk.tokenize import word_tokenize
from tqdm import tqdm

class BM25Retriever:
    """
    A class for retrieving documents using the BM25 algorithm.

    Attributes:
        index (List[List[str]]): A list of [document ID, document text] pairs.
        tokenized_docs (List[List[str]]): Tokenized version of the documents in `index`.
        bm25 (BM25Okapi): An instance of the BM25Okapi model from the rank_bm25 package.
    """
    def __init__(self, docs_with_ids: List[List[str]]):
        """
        Initializes the BM25Retriever with a list of documents.

        Args:
            docs_with_ids (List[List[str]]): A list of [document ID, document text] pairs.
        """
        self.index = docs_with_ids
        self.tokenized_docs = self._tokenize_docs([x[1] for x in self.index])
        self.bm25 = BM25Okapi(self.tokenized_docs)

    def _tokenize_docs(self, docs: List[str]) -> List[List[str]]:
        """
        Tokenizes the documents using NLTK's word_tokenize.

        Args:
            docs (List[str]): A list of documents to be tokenized.

        Returns:
            List[List[str]]: A list of tokenized documents.
        """
        return [word_tokenize(doc.lower()) for doc in docs]

    def query(self, query: str, top_n: int = 10) -> List[str]:
        """
        Queries the BM25 model and retrieves the IDs of the top N documents.

        Args:
            query (str): The query string.
            top_n (int): The number of top documents to retrieve.

        Returns:
            List[str]: The document IDs of the top N documents, ranked by BM25 score.
        """
        tokenized_query = word_tokenize(query.lower())
        scores = self.bm25.get_scores(tokenized_query)
        doc_scores_with_ids = [(doc_id, scores[i]) for i, (doc_id, _) in enumerate(self.index)]
        top_doc_ids_and_scores = sorted(doc_scores_with_ids, key=lambda x: x[1], reverse=True)[:top_n]
        return [x[0] for x in top_doc_ids_and_scores]
We now process our KB file and create a BM-25 retriever instance that indexes it. While indexing the KB, we index each MeSH ID using a concatenation of its aliases, canonical name, and definition.
def process_index(index):
    """
    Processes the initial document index to combine aliases, canonical names, and definitions into a single text index.

    Args:
    - index (Dict): The MeSH knowledge base

    Returns:
    List[List[str]]: A list of [concept ID, combined text index] pairs.
    """
    processed_index = []
    for key, value in tqdm(index.items()):
        assert(type(value["aliases"]) != list)
        aliases_text = " ".join(value["aliases"].split(","))
        text_index = (aliases_text + " " + value.get("canonical_name", "")).strip()
        if "definition" in value:
            text_index += " " + value["definition"]
        processed_index.append([value["concept_id"], text_index])
    return processed_index
mesh_data = read_jsonl_file("mesh_2020.jsonl")
process_mesh_kb(mesh_data)
mesh_data_kb = {x["concept_id"]:x for x in mesh_data}
mesh_data_dict = process_index({x["concept_id"]:x for x in mesh_data})
retriever = BM25Retriever(mesh_data_dict)
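The RAG prompt builder build_rag_prompt used in the loop below is defined in the repository. As a rough, hypothetical sketch of the idea, assume the retrieved MeSH entries are serialized into the prompt as “ID: canonical name - definition” lines before the document; the exact formatting in the repo may differ.

def build_rag_prompt(system_prompt, input_example, retrieved_contexts):
    """Hypothetical sketch: inject the retrieved MeSH entries and the document into a single user message."""
    context_lines = []
    for ctx in retrieved_contexts:
        line = f'{ctx["concept_id"]}: {ctx.get("canonical_name", "")}'
        if "definition" in ctx:
            line += f' - {ctx["definition"]}'
        context_lines.append(line)
    user_message = (
        system_prompt
        + "\n\nCandidate MeSH entries:\n" + "\n".join(context_lines)
        + "\n\nDocument:\n" + input_example["title"] + " " + input_example["abstract"]
    )
    return [{"role": "user", "content": user_message}]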
mistral_rag_answers = {10: [], 30: [], 50: []}

for k in [10, 30, 50]:
    for item in tqdm(test_set_subsample):
        relevant_mesh_ids = retriever.query(item["title"] + " " + item["abstract"], top_n=k)
        relevant_contexts = [mesh_data_kb[x] for x in relevant_mesh_ids]
        rag_prompt = build_rag_prompt(SYSTEM_RAG_PROMPT, item, relevant_contexts)
        input_ids = tokenizer.apply_chat_template(rag_prompt, tokenize=True, return_tensors="pt").cuda()
        outputs = model.generate(input_ids=input_ids, max_new_tokens=200, do_sample=False)
        gen_text = tokenizer.batch_decode(outputs.detach().cpu().numpy()[:, input_ids.shape[1]:], skip_special_tokens=True)[0]
        mistral_rag_answers[k].append(parse_answer(gen_text.strip()))
entity_scores_at_k = {}
mesh_scores_at_k = {}

for key, value in mistral_rag_answers.items():
    entity_scores = [calculate_entity_metrics(gt["annotations"], pred) for gt, pred in zip(test_set_subsample, value)]
    macro_precision_entity = sum([x[0] for x in entity_scores]) / len(entity_scores)
    macro_recall_entity = sum([x[1] for x in entity_scores]) / len(entity_scores)
    macro_f1_entity = sum([x[2] for x in entity_scores]) / len(entity_scores)
    entity_scores_at_k[key] = {"macro-precision": macro_precision_entity, "macro-recall": macro_recall_entity, "macro-f1": macro_f1_entity}

    mesh_scores = [calculate_mesh_metrics(gt["annotations"], pred) for gt, pred in zip(test_set_subsample, value)]
    macro_precision_mesh = sum([x[0] for x in mesh_scores]) / len(mesh_scores)
    macro_recall_mesh = sum([x[1] for x in mesh_scores]) / len(mesh_scores)
    macro_f1_mesh = sum([x[2] for x in mesh_scores]) / len(mesh_scores)
    mesh_scores_at_k[key] = {"macro-precision": macro_precision_mesh, "macro-recall": macro_recall_mesh, "macro-f1": macro_f1_mesh}
In general, the RAG setup improves the overall MeSH Identification process, compared to the original zero-shot setup. But what is the impact of the number of documents provided as information to the model? We plot the scores as a function of the number of retrieved IDs provided to the model as context.
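A minimal way to produce such a plot from the dictionaries computed above (matplotlib is assumed here; any plotting library would do):

import matplotlib.pyplot as plt

ks = sorted(mesh_scores_at_k.keys())
plt.plot(ks, [entity_scores_at_k[k]["macro-f1"] for k in ks], marker="o", label="Entity extraction macro-F1")
plt.plot(ks, [mesh_scores_at_k[k]["macro-f1"] for k in ks], marker="o", label="MeSH linking macro-F1")
plt.xlabel("Number of retrieved MeSH IDs in context")
plt.ylabel("Macro-F1")
plt.legend()
plt.show()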