
Intro

My main goal here was simple: I just wanted to see the impact of hard-negative choices on a modern embedder. Nothing more, nothing less.

I noticed a model trained with BM25 hard negatives performed extremely strongly, and I wanted to ask the question ‘why?’. For a bit of background: Filevine, where I currently work, is a legal tech company, so I’m inundated with legal text of every variety. What’s striking about legal text is that it can range from nonsensical to extremely logical, and it’s almost always set up like some kind of discrete-math word problem. It’s very funky stuff. My gut hypothesis is that this is probably why most frontier models suck at law right now: legal text appears normal and isn’t super varied (so I think a lot of it gets filtered out in pretraining), and most of the juicy stuff isn’t public.

Anyway, back to the show: we’ll take a standard dataset (legal summarization; another constraint) and find hard negatives three ways: via BM25, via a high-ranked (on MTEB) embedding model (Alibaba-NLP/gte-Qwen2-1.5B-instruct) plus a high-ranked re-ranker (Alibaba-NLP/gte-multilingual-reranker-base), and via the high-ranked embedding model alone (Alibaba-NLP/gte-Qwen2-1.5B-instruct). We’ll do this to see how the choice of hard negatives impacts downstream legal retrieval tasks.

Experimental Setup

We’re going to take the joelniklaus/legal_case_document_summarization dataset (selected after eyeballing the data and a gut feeling that it ‘looked good’ – highly scientific, I know), use it to create a legal training dataset, and then test on six MTEB legal tasks: LegalBenchCorporateLobbying, LegalSummarization, AILACasedocs, LegalQuAD, LegalBenchConsumerContractsQA, AILAStatutes.
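
For reference, evaluating a trained model on those six tasks looks roughly like this (a minimal sketch assuming a recent version of the mteb package; the model path is illustrative):

import mteb
from sentence_transformers import SentenceTransformer

tasks = mteb.get_tasks(tasks=[
    "LegalBenchCorporateLobbying", "LegalSummarization", "AILACasedocs",
    "LegalQuAD", "LegalBenchConsumerContractsQA", "AILAStatutes",
])
model = SentenceTransformer("output/ModernBERT-base-bm25/final")  # illustrative path
mteb.MTEB(tasks=tasks).run(model, output_folder="results")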

some notes on the evals:

AILA tasks:

These tasks require some kind of knowledge about Indian law (which we don’t have).

Task 1: Identifying relevant prior cases

We provide ~3000 case documents of cases that were judged in the Supreme Court of India. For each query, the task is to retrieve the most similar / relevant case document with respect to the situation in the given query.

Task 2: Identifying relevant statutes

We have identified a set of 197 statutes (Sections of Acts) from Indian law that are relevant to some of the queries. We provide the title and description of these statutes. For each query, the task is to identify the most relevant statutes (from among the 197 statutes). Note that the task can be modelled either as an unsupervised retrieval task (where you search for relevant statutes) or as a supervised classification task (e.g., trying to predict for each statute whether it is relevant). For the latter, the case documents provided for Task 1 can be utilized. However, if a team wishes to apply supervised models, then it is their responsibility to create the necessary training data.

LegalQuAD

This is a German dataset, and our training set is not German, so (duh) I wouldn’t expect our training to help here given that we won’t have any German contrastive pairs. But we’ll throw it in to see what kind of random effects we can observe.

Basic Analysis:

I want to do some basic analysis of this legal dataset by comparing it against other domain-specific data (math, medicine) and general data (OpenWeb), to give the reader a feel for what legal text looks like.
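
The snippets below assume the corpora are already loaded into legal_docs, math_docs, openweb_docs, other_math, and med. A minimal loading sketch (the legal dataset is the one above; the medical and math word-problem names are my best guesses matching the column names used below, so swap in whatever corpora you prefer):

import spacy
import pandas as pd
from collections import Counter
from datasets import load_dataset

nlp = spacy.load("en_core_web_sm")

legal_docs = load_dataset("joelniklaus/legal_case_document_summarization", split="train")
med = load_dataset("gamino/wiki_medical_terms", split="train")                      # 'page_text' column (name assumed)
other_math = load_dataset("microsoft/orca-math-word-problems-200k", split="train")  # 'question' column (name assumed)
openweb_docs = load_dataset("Skylion007/openwebtext", split="train")                # 'text' column (large!)
math_docs = ...  # a competition-math corpus with a 'prompt' column (name not given here)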

Part of Speech Distribution

def get_pos_distribution(documents, sample_size=1000):
    """
    Compute the part-of-speech distribution over a sample of documents.
    We only process the first `sample_size` documents for efficiency.
    """
    sampled_docs = documents[:sample_size]
    pos_counts = Counter()
    total_tokens = 0
    
    for doc_text in sampled_docs:
        doc = nlp(doc_text)
        for token in doc:
            pos_counts[token.pos_] += 1
            total_tokens += 1
    
    # Convert counts to proportions
    pos_dist = {pos: count / total_tokens for pos, count in pos_counts.items()}
    return pos_dist

pos_legal = get_pos_distribution(legal_docs['judgement'])
pos_math = get_pos_distribution(math_docs['prompt'])
pos_openweb = get_pos_distribution(openweb_docs['text'])
pos_math2 = get_pos_distribution(other_math['question'])
pos_med = get_pos_distribution(med['page_text'])

# Convert to DataFrame for easy comparison
df_pos = pd.DataFrame([pos_legal, pos_math, pos_openweb, pos_math2, pos_med], 
                      index=["Legal", "Math", "OpenWeb", "Word Math", "Wiki Med"]).fillna(0)
df_pos = df_pos.transpose().sort_values(by=["Legal", "Math", "OpenWeb", "Word Math", "Wiki Med"], ascending=False)
print(df_pos)
          Legal      Math   OpenWeb  Word Math  Wiki Med
NOUN   0.175572  0.189513  0.175243   0.210162  0.249462
ADP    0.135626  0.068894  0.097755   0.097116  0.112305
DET    0.125641  0.071518  0.079878   0.091713  0.071084
VERB   0.090413  0.053803  0.105108   0.084351  0.087313
PUNCT  0.089545  0.167888  0.125426   0.152996  0.122003
PROPN  0.076987  0.064055  0.087082   0.039782  0.050926
AUX    0.060485  0.036361  0.048429   0.048208  0.050527
PRON   0.042584  0.030702  0.064171   0.042777  0.020965
ADJ    0.041660  0.049893  0.061888   0.057811  0.102179
NUM    0.029754  0.077861  0.020058   0.088942  0.017541
CCONJ  0.029093  0.025234  0.028108   0.032279  0.034960
SPACE  0.026921  0.032041  0.020384   0.000056  0.021435
SCONJ  0.025912  0.015993  0.018093   0.026596  0.011554
PART   0.023696  0.004484  0.025851   0.012094  0.012989
ADV    0.023398  0.009787  0.036608   0.013690  0.032312
X      0.002462  0.019848  0.002509   0.001092  0.000881
SYM    0.000129  0.081552  0.002469   0.000280  0.001497
INTJ   0.000122  0.000574  0.000937   0.000056  0.000067

Named Entity Recognition Distribution

def get_ner_distribution(documents, sample_size=1000):
    sampled_docs = documents[:sample_size]
    ner_counts = Counter()
    total_ents = 0

    for doc_text in sampled_docs:
        doc = nlp(doc_text)
        for ent in doc.ents:
            ner_counts[ent.label_] += 1
            total_ents += 1
    
    ner_dist = {ent_type: count / total_ents for ent_type, count in ner_counts.items()} if total_ents > 0 else {}
    return ner_dist

ner_legal = get_ner_distribution(legal_docs['judgement'])
ner_math = get_ner_distribution(math_docs['prompt'])
ner_openweb = get_ner_distribution(openweb_docs['text'])
ner_math2 = get_ner_distribution(other_math['question'])
ner_med = get_ner_distribution(med['page_text'])

df_ner = pd.DataFrame([ner_legal, ner_math, ner_openweb, ner_math2, ner_med], 
                      index=["Legal", "Math", "OpenWeb", "Word Math", "Wiki Med"]).fillna(0)
df_ner = df_ner.transpose().sort_values(by=["Legal", "Math", "OpenWeb", "Word Math", "Wiki Med"], ascending=False)
print(df_ner)
                Legal      Math   OpenWeb  Word Math  Wiki Med
ORG          0.263405  0.068710  0.213852   0.040993  0.253864
CARDINAL     0.238252  0.252273  0.118554   0.517965  0.164357
DATE         0.131707  0.081172  0.138075   0.095008  0.153279
PERSON       0.126937  0.181543  0.218140   0.083193  0.112978
LAW          0.104598  0.002358  0.003750   0.000723  0.000867
GPE          0.053235  0.017178  0.121413   0.048469  0.107683
ORDINAL      0.027613  0.031324  0.022764   0.034242  0.027509
NORP         0.016257  0.008420  0.060991   0.011575  0.051120
WORK_OF_ART  0.008050  0.000337  0.013449   0.000723  0.006770
LOC          0.007295  0.001347  0.012527   0.000482  0.016860
PRODUCT      0.007116  0.008420  0.009745   0.005787  0.011197
FAC          0.004185  0.001010  0.009714   0.000482  0.002543
TIME         0.003254  0.019198  0.014541   0.016639  0.012374
QUANTITY     0.002291  0.022903  0.007363   0.139860  0.008804
MONEY        0.002089  0.301785  0.016201   0.000241  0.003460
EVENT        0.001604  0.000337  0.006563   0.000482  0.003500
PERCENT      0.001207  0.001684  0.009929   0.000241  0.061709
LANGUAGE     0.000905  0.000000  0.002429   0.002894  0.001127

BERT Embeddings Clusters

import numpy as np
import umap
import matplotlib.pyplot as plt
import seaborn as sns
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2').cpu()

def get_embeddings(documents, sample_size=10000):
    sample_docs = documents[:sample_size]
    embeddings = model.encode(sample_docs, show_progress_bar=True)
    return embeddings

emb_legal = get_embeddings(legal_docs['judgement'])
emb_math = get_embeddings(math_docs['prompt'])
emb_openweb = get_embeddings(openweb_docs['text'])
emb_other_math = get_embeddings(other_math['question'])
emb_med = get_embeddings(med['page_text'])

# Combine them with labels
all_embeddings = np.vstack([emb_legal, emb_math, emb_openweb, emb_other_math, emb_med])
labels = (["Legal"] * len(emb_legal)) + (["Math"] * len(emb_math)) + (["OpenWeb"] * len(emb_openweb)) + (["Orca Math Word Problems"] * len(emb_other_math)) + (["Wiki Med"] * len(emb_med))

# Apply UMAP or t-SNE
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
# Alternatively: reducer = TSNE(n_components=2, random_state=42)

embedding_2d = reducer.fit_transform(all_embeddings)

# Plot
plt.figure(figsize=(10, 8))
sns.scatterplot(
    x=embedding_2d[:,0],
    y=embedding_2d[:,1],
    hue=labels,
    palette="deep"
)
plt.title("Sentence-BERT Embeddings (UMAP) - Dataset Clusters")
plt.show()

Med and Math are on their own islands, while Legal shares a lot of similarities with ‘regular text’. In my experience, this checks out: legal text uses ‘normal words’ but uses them differently.

Lexical Entropy



import math
import re
from collections import Counter
import nltk
from nltk.tokenize import word_tokenize

nltk.download('punkt')

def compute_lexical_entropy(documents):
    """
    Computes the lexical entropy (in bits) for a list of documents.
    H(X) = - sum( p(x) * log2(p(x)) ) for all x in the vocabulary.
    """
    freq = Counter()
    for doc in documents:
        word_tokens = word_tokenize(doc.lower())
        # drop punctuation-only tokens
        word_tokens = [re.sub(r'[^\w\s]', '', token) for token in word_tokens if re.sub(r'[^\w\s]', '', token)]
        freq.update(word_tokens)

    total_tokens = sum(freq.values())
    
    # Calculate entropy
    entropy = 0.0
    for word, count in freq.items():
        p = count / total_tokens
        entropy -= p * math.log2(p)  # minus sign because p*log2(p) is negative
    return entropy

entropy_legal = compute_lexical_entropy(legal_docs['judgement'])
entropy_math = compute_lexical_entropy(math_docs['prompt'])
entropy_openweb = compute_lexical_entropy(openweb_docs['text'])
entropy_other_math = compute_lexical_entropy(other_math['question'])
entropy_med = compute_lexical_entropy(med['page_text'])

print(f"Lexical Entropy (Legal Summarization): {entropy_legal:.4f} bits")
print(f"Lexical Entropy (Competition Math):   {entropy_math:.4f} bits")
print(f"Lexical Entropy (OpenWeb):           {entropy_openweb:.4f} bits")
print(f"Lexical Entropy (Math Word Problems):           {entropy_other_math:.4f} bits")
print(f"Lexical Entropy (Wiki Med):           {entropy_med:.4f} bits")
Lexical Entropy (Legal Summarization): 9.5891 bits
Lexical Entropy (Competition Math):   9.1562 bits
Lexical Entropy (OpenWeb):           11.0669 bits
Lexical Entropy (Math Word Problems):           9.4199 bits
Lexical Entropy (Wiki Med):           10.9722 bits

OpenWeb has the highest lexical entropy (~11.07), suggesting the broadest and most diverse vocabulary with the least repetition of the same tokens. Wiki Med is second (~10.97), indicating it also uses a fairly large and varied vocabulary, likely due to the breadth of medical topics. Legal Summarization sits in the middle (~9.59): higher than the math corpora but lower than OpenWeb and Wiki Med.

The Math datasets (Competition Math and Math Word Problems) have the lowest entropy (~9.16–9.42), consistent with a narrower or more specialized vocabulary (repetitions of numbers, math symbols, certain keywords).

Higher entropy typically means more linguistic variety; the math datasets are more repetitive and domain-specific (numbers, symbols, certain repeated words). Legal is in between: it has some specialized vocabulary, but it’s not as wide-ranging as OpenWeb or Wiki Med, which cover many topics.
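
As a quick sanity check on the metric: a corpus that endlessly repeats one word carries 0 bits, while a uniform vocabulary of 1024 distinct words carries exactly log2(1024) = 10 bits:

print(compute_lexical_entropy(["the the the the"]))                       # ~0.0
print(compute_lexical_entropy([" ".join(f"w{i}" for i in range(1024))]))  # ~10.0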

Readability


import statistics
import textstat

def compute_readability_scores(documents, sample_size=1000):
    """
    Computes average readability scores on a subset of the dataset
    to avoid extremely large computations.
    Returns: (avg_flesch, avg_fog)
    """
    # Sample to speed up (adjust as needed)
    sampled_docs = documents[:sample_size]
    
    flesch_scores = []
    fog_scores = []
    
    for doc in sampled_docs:
        # textstat expects raw text (sentences, words)
        flesch = textstat.flesch_reading_ease(doc)
        fog = textstat.gunning_fog(doc)
        
        flesch_scores.append(flesch)
        fog_scores.append(fog)
    
    avg_flesch = statistics.mean(flesch_scores)
    avg_fog = statistics.mean(fog_scores)
    
    return avg_flesch, avg_fog


# Example usage
legal_flesch, legal_fog = compute_readability_scores(legal_docs['judgement'])
math_flesch, math_fog = compute_readability_scores(math_docs['prompt'])
openweb_flesch, openweb_fog = compute_readability_scores(openweb_docs['text'])
other_math_flesch, other_math_fog = compute_readability_scores(other_math['question'])
med_flesch, med_fog = compute_readability_scores(med['page_text'])

print("==== Flesch Reading Ease (Higher = Easier to Read) ====")
print(f"Legal Summarization: {legal_flesch:.2f}")
print(f"Competition Math:    {math_flesch:.2f}")
print(f"OpenWeb:             {openweb_flesch:.2f}")
print(f"Math Word Problems:             {other_math_flesch:.2f}")
print(f"Wiki Med:             {med_flesch:.2f}")

print("\n==== Gunning Fog Index (Lower = Easier to Read) ====")
print(f"Legal Summarization: {legal_fog:.2f}")
print(f"Competition Math:    {math_fog:.2f}")
print(f"OpenWeb:             {openweb_fog:.2f}")
print(f"Math Word Problems:             {other_math_fog:.2f}")
print(f"Wiki Med:             {med_fog:.2f}")
==== Flesch Reading Ease (Higher = Easier to Read) ====
Legal Summarization: 54.08
Competition Math:    75.10
OpenWeb:             59.90
Math Word Problems:             82.43
Wiki Med:             36.06

==== Gunning Fog Index (Lower = Easier to Read) ====
Legal Summarization: 12.70
Competition Math:    7.85
OpenWeb:             11.40
Math Word Problems:             7.45
Wiki Med:             13.38

High Flesch, low Fog = simpler sentences, shorter words, more straightforward syntax. Low Flesch, high Fog = complex sentences, longer words, domain-specific jargon.

Wiki Med is the hardest to read, i.e., the most domain-heavy in vocabulary and syntax. Legal Summarization is also formal, but not as severe as Wiki Med. The math text scores as ‘easy’ (“Find x if …,” “What is the value of …,” etc.), but given that we know it’s technically ‘hard’, this sort of metric doesn’t really tell us much here.

Quick Conclusions:

OpenWeb = high diversity, moderate difficulty.

Wiki Med = high diversity, highest difficulty (lots of specialized medical terms).

Legal Summarization = medium diversity, high difficulty (formal style).

Math Word Problems & Competition Math = lower diversity, easier style (short sentences, specialized numeric vocab, but not big words).

I don’t know how to visualize my bias yet that legal text is like discrete-math word problems, written in specialized language (like medical text) at extreme length. Having gone through many legal docs of late, that’s just how I’ve come to think about it. I think it gives the appearance of having overlap with webtext, but it ‘just doesn’t’. And I think this largely comes from law being a protocol/communication layer over morals/values/ethics, where it uses that framework/logic to construct text that “isn’t strictly logical” (sometimes). “1 + 1 = 3 because ‘jury’, and here’s 10,000 words on why.”

It’s very hard and I really don’t think many models so far have been able to grok it. Onward.

Code

In this portion, I’ll provide some of the code I used for generating the initial hard negatives before this morphed into my own personal rabbit-hole project.

# Standard library
import hashlib
import pickle
from typing import List, Dict, Any, Tuple, Optional, Union
from collections import deque
from concurrent.futures import ThreadPoolExecutor

# Deep learning & numerical
import torch
import torch.multiprocessing as mp
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Data handling
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

# Text processing
import bm25s
import Stemmer  # For stemming


class FastEmbedder(mp.Process):
    """Parallel processor for creating document embeddings using GPU acceleration."""
    
    def __init__(
        self, 
        gpu_id: int,
        input_queue: mp.Queue,
        output_queue: mp.Queue,
        model_id: str = "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        batch_size: int = 256,
        embedding_size: int = 1536
    ):
        """Initialize the FastEmbedder process.
        
        Args:
            gpu_id: GPU device identifier
            input_queue: Queue for receiving document batches
            output_queue: Queue for sending computed embeddings
            model_id: Identifier for the embedding model
            batch_size: Number of documents to process at once
            embedding_size: Dimension of the embedding vectors
        """
        super().__init__()
        self.gpu_id = gpu_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.model_id = model_id
        self.batch_size = batch_size
        self.embedding_size = embedding_size

    def setup(self):
        """Initialize GPU device and load model."""
        self.device = torch.device(f"cuda:{self.gpu_id}")
        torch.cuda.set_device(self.device)
        self.model = SentenceTransformer(self.model_id).cuda()
        self.model.half()

    def process_batch(self, batch: Dict[str, Any], current_batch_size: int, max_seq_len: int = 32768) -> Dict[str, Any]:
        """Process a batch of documents and create embeddings.
        
        Args:
            batch: Dictionary containing texts and their indices
            current_batch_size: Current batch size for processing
            max_seq_len: Maximum sequence length for tokenization
            
        Returns:
            Dictionary containing embeddings and document indices
        """
        start_idx = 0
        all_embeddings = []
        all_indices = []
        
        while start_idx < len(batch['texts']):
            end_idx = start_idx + current_batch_size
            
            try:
                doc_texts = batch['texts'][start_idx:end_idx]
                doc_indices = batch['doc_indices'][start_idx:end_idx]
                
                tokenized = self.model.tokenizer(
                    doc_texts, 
                    truncation=True, 
                    max_length=max_seq_len
                )
                truncated_texts = self.model.tokenizer.batch_decode(
                    tokenized['input_ids'],
                    skip_special_tokens=True
                )
                
                embeddings = self.model.encode(
                    truncated_texts,
                    convert_to_numpy=True,
                    normalize_embeddings=True
                ).astype(np.float32)
                
                torch.cuda.empty_cache()
                all_embeddings.extend(embeddings)
                all_indices.extend(doc_indices)
                start_idx = end_idx
                
            except RuntimeError as e:
                if "CUBLAS_STATUS_ALLOC_FAILED" in str(e):
                    self.model.to('cpu')
                    torch.cuda.synchronize(self.gpu_id)
                    torch.cuda.empty_cache()
                    torch.cuda.set_device(self.gpu_id)
                    self.setup()
                    return self.process_batch(batch, current_batch_size, max_seq_len)
                
                if "out of memory" in str(e):
                    torch.cuda.empty_cache()
                    if max_seq_len > 1024:
                        new_seq_len = max_seq_len // 2
                        return self.process_batch(batch, current_batch_size, new_seq_len)
                    self.batch_size = max(1, current_batch_size // 2)
                    return self.process_batch(batch, self.batch_size, max_seq_len)
                raise e
                
        return {
            'embeddings': np.array(all_embeddings),
            'doc_indices': all_indices
        }

    def run(self):
        """Main process loop for embedding generation."""
        self.setup()
        while True:
            batch = self.input_queue.get()
            if batch is None:
                self.output_queue.put(None)
                break
                
            result = self.process_batch(batch, self.batch_size)
            self.output_queue.put(result)


class ParallelFaissSearcher(mp.Process):
    """Parallel processor for finding hard negatives using FAISS similarity search."""
    
    def __init__(
        self,
        gpu_id: int,
        input_queue: mp.Queue,
        output_queue: mp.Queue,
        index: str,
        faiss_to_doc_idx: Dict[int, int],
        doc_list: List[str],
        query_idx_relations: Dict[int, Dict[str, set]],
        rerank: bool = True,
        model_id: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        faiss_use_gpu: bool = True
    ):
        """Initialize the ParallelFaissSearcher process.
        
        Args:
            gpu_id: GPU device identifier
            input_queue: Queue for receiving query batches
            output_queue: Queue for sending results
            index: Path to a saved FAISS index file (loaded in setup())
            faiss_to_doc_idx: Mapping from FAISS indices to document indices
            doc_list: List of original documents
            query_idx_relations: Mapping of query relationships
            rerank: Whether to use reranking
            model_id: Identifier for reranking model
            faiss_use_gpu: Whether to use GPU for FAISS
        """
        super().__init__()
        self.gpu_id = gpu_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.cpu_index = index
        self.faiss_to_doc_idx = faiss_to_doc_idx
        self.doc_list = doc_list
        self.query_idx_relations = query_idx_relations
        self.rerank = rerank
        self.model_id = model_id
        self.faiss_use_gpu = faiss_use_gpu

    def setup(self):
        """Initialize GPU resources and models."""
        if self.faiss_use_gpu:
            res = faiss.StandardGpuResources()
            # load the saved index inside this worker process, then move it to the GPU
            cpu_index = faiss.read_index(self.cpu_index)
            self.index = faiss.index_cpu_to_gpu(res, self.gpu_id, cpu_index)
            print(f"FAISS success on gpu: {self.gpu_id}")
        else:
            self.index = faiss.read_index(self.cpu_index)
  
        if self.rerank:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.reranker = AutoModelForSequenceClassification.from_pretrained(
                self.model_id, trust_remote_code=True,
                torch_dtype=torch.float16
            ).to(f"cuda:{self.gpu_id}")
            self.reranker.eval()


    def find_hard_negative(self, query_idx: int, k: int = 10) -> Optional[int]:
        """Find hard negative example for a given query.
        
        Args:
            query_idx: Index of the query document
            k: Number of nearest neighbors to consider
            
        Returns:
            Index of the selected hard negative, or None if not found
        """
        print(f"gpu: {self.gpu_id}: {query_idx}")
        query_text = self.doc_list[query_idx]
        positive_idxs = self.query_idx_relations[query_idx]['positive']
        
        faiss_query_idx = [idx for idx, doc_idx in self.faiss_to_doc_idx.items() 
                          if doc_idx == query_idx][0]
        query_embedding = self.index.reconstruct(faiss_query_idx).reshape(1, -1)
        print(f"gpu: {self.gpu_id}: reconstructing {query_idx}")

        D, I = self.index.search(query_embedding, k)
        print(f"gpu: {self.gpu_id}: search-passed {query_idx}")
        
        candidates = []
        for faiss_idx in I[0]:
            doc_idx = self.faiss_to_doc_idx[faiss_idx]
            if doc_idx != query_idx and doc_idx not in positive_idxs:
                candidates.append((doc_idx, self.doc_list[doc_idx]))
                
        if not self.rerank:
            return candidates[0][0] if candidates else None
            
        if candidates:
            pairs = [[query_text, c[1]] for c in candidates]
            with torch.no_grad():
                inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=32768)
                inputs = {k: v.to(f"cuda:{self.gpu_id}") for k, v in inputs.items()}
                scores = self.reranker(**inputs, return_dict=True).logits.view(-1, ).float()
                # print(scores)
            # scores = self.reranker.predict(pairs)
            best_idx = max(range(len(scores)), key=lambda i: scores[i])
            return candidates[best_idx][0]
            
        return None

    def run(self):
        """Main process loop for finding hard negatives."""
        self.setup()
        while True:
            batch = self.input_queue.get()
            if batch is None:
                self.output_queue.put(None)
                break

            results = {}
            for query_idx in batch:
                result = self.find_hard_negative(query_idx)
                if result:
                    results[query_idx] = result
                    
            self.output_queue.put(results)


def create_bm25_data(dataset_config: dict, save_name: str = 'summarizer_bm25') -> Dataset:
    """Create training dataset using BM25 retrieval for hard negative mining.
    
    This function processes legal documents and summaries to create training examples
    with hard negative pairs selected using BM25 similarity scores.
    
    Args:
        dataset_config: Dict specifying the dataset name, splits, and text-pair columns
        save_name: Name for saving the processed dataset
        
    Returns:
        Dataset containing query-positive-negative triplets
    """
    def get_hard_negatives(
        query: str, 
        positive_ids: set, 
        stemmer, 
        retriever, 
        k: int = 2, 
        threshold: float = 0.5, 
        single: bool = True
    ) -> Union[Tuple[str, float], List[Tuple[str, float]]]:
        """Get hard negative examples that are semantically similar but not relevant.
        
        Args:
            query: Input text to find negatives for
            positive_ids: Set of known positive document IDs
            stemmer: Stemmer for text preprocessing
            retriever: BM25 retriever instance
            k: Number of negatives to retrieve
            threshold: Similarity threshold for hard negatives
            single: Whether to return single negative or multiple
            
        Returns:
            Either tuple of (negative_text, score) or list of such tuples
        """
        query_tokens = bm25s.tokenize(query, stemmer=stemmer)
        results, scores = retriever.retrieve(query_tokens, k=len(corpus))

        if single:
            for i in range(results.shape[1]):
                doc_id, score = results[0, i], scores[0, i]
                if 0 < score < threshold and doc_id not in positive_ids:
                    return corpus[doc_id], score
            return None, 0.0

        hard_negatives = []
        for i in range(results.shape[1]):
            doc_id, score = results[0, i], scores[0, i]
            if 0 < score < threshold and doc_id not in positive_ids:
                hard_negatives.append((corpus[doc_id], score))
                if len(hard_negatives) >= k:
                    break
        return hard_negatives

    def process_query_threaded(query_data: Tuple[str, List[str]], shared_stemmer, shared_retriever) -> List[Tuple[str, str, Tuple[str, float]]]:
        """Process a single query to generate training examples using shared resources.
        
        Args:
            query_data: Tuple of (query_text, positive_examples)
            shared_stemmer: Stemmer instance shared across threads
            shared_retriever: BM25 retriever shared across threads
            
        Returns:
            List of (query, positive, negative) training examples
        """
        query, positives = query_data
        positives = [positives] if isinstance(positives, str) else positives
        len_positives = len(positives)
        positives = deque(positives)
        positive_ids = {i for i, doc in enumerate(corpus) if doc in positives}
        
        results = []
        while positives:
            pos = positives.popleft()
            neg = get_hard_negatives(
                query,
                positive_ids,
                shared_stemmer, 
                shared_retriever,
                k = 500 if len_positives < 10 else len_positives+500,
                threshold = 0.5,
                single=True
            )
            if neg and neg[0] is not None:
                results.append((query, pos, neg))
                
        return results

    # Initialize data structures
    train_data = {
        "query": [],
        "positive": [],
        "negative": []
    }
    queries = {}
    
    ds = load_dataset(dataset_config["name"])
    for split in dataset_config["splits"]:
        for text1_col, text2_col in dataset_config["pairs"]:
            for item in ds[split]:
                text1, text2 = item[text1_col], item[text2_col]
                for base_text, paired_text in [(text1, text2), (text2, text1)]:
                    if base_text not in queries:
                        queries[base_text] = {"positive": [], "negative": []}
                    queries[base_text]["positive"].append(paired_text)

    # Create BM25 index
    corpus = list(queries.keys())
    stemmer = Stemmer.Stemmer("english")
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    # Process queries in parallel
    with ThreadPoolExecutor(max_workers=96) as executor:
        query_data = [(query, queries[query]['positive']) for query in queries]
        futures = [executor.submit(process_query_threaded, qd, stemmer, retriever) for qd in query_data]
        
        for future in futures:
            for query, pos, (neg, score) in future.result():
                if neg:
                    train_data['query'].append(query)
                    train_data['positive'].append(pos)
                    train_data['negative'].append(neg)

    # Create and filter dataset
    dataset = Dataset.from_dict(train_data)
    valid_indices = [i for i in range(len(dataset)) 
                    if not any(dataset[i][field] is None 
                    for field in ['query', 'positive', 'negative'])]
    dataset = dataset.select(valid_indices)
    dataset.save_to_disk(save_name)
    
    return dataset


def create_fast_index(
    doc_list: List[str],
    num_gpus: int = 8,
    batch_size: int = 256,
    embedding_size: int = 1536,
    model_id: str = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
) -> Tuple[faiss.IndexFlatIP, Dict[int, int]]:
    """Create FAISS index from document list using parallel processing.
    
    Args:
        doc_list: List of documents to embed
        num_gpus: Number of GPUs to use
        batch_size: Batch size per GPU
        embedding_size: Dimension of embeddings
        model_id: Model identifier
        
    Returns:
        FAISS index and mapping from FAISS indices to document indices
    """
    # single queue in case a gpu fails
    input_queue = mp.Queue()
    output_queue = mp.Queue()
    
    workers = []
    for gpu_id in range(num_gpus):
        worker = FastEmbedder(
            gpu_id,
            input_queue,
            output_queue,
            model_id=model_id,
            batch_size=batch_size,
            embedding_size=embedding_size
        )
        worker.start()
        workers.append(worker)

    index = faiss.IndexFlatIP(embedding_size)
    faiss_to_doc_idx = {}
    current_faiss_idx = 0

    # Distribute work
    for i in range(0, len(doc_list), batch_size * num_gpus):
        for gpu_id in range(num_gpus):
            start_idx = i + (gpu_id * batch_size)
            end_idx = min(start_idx + batch_size, len(doc_list))
            if start_idx >= len(doc_list):
                break
                
            batch_docs = doc_list[start_idx:end_idx]
            input_queue.put({
                'texts': batch_docs,
                'doc_indices': list(range(start_idx, end_idx))
            })

    # Send stop signals and collect results
    for q in workers:
        input_queue.put(None)

    completed = 0
    while completed < num_gpus:
        result = output_queue.get()
        if result is None:
            completed += 1
            continue
            
        index.add(result['embeddings'])
        for doc_idx in result['doc_indices']:
            faiss_to_doc_idx[current_faiss_idx] = doc_idx
            current_faiss_idx += 1

    faiss.write_index(index, 'embeddings.index')
    with open('faiss_to_doc_idx.pkl', 'wb') as f:
        pickle.dump(faiss_to_doc_idx, f)

    return index, faiss_to_doc_idx


def create_document_list_with_relations(dataset: Dataset) -> Tuple[List[str], Dict[str, int], Dict[int, str], Dict[int, Dict[str, set]]]:
    """Create document lists and relationship mappings from dataset.
    
    Args:
        dataset: HuggingFace dataset containing query-positive-negative triplets
        
    Returns:
        Tuple containing:
        - List of unique documents
        - Document to index mapping
        - Index to document mapping
        - Query relationship mapping
    """
    unique_docs = set()
    query_relations = {}
    
    for batch in dataset:
        if not (batch['query'] == "" or batch['positive'] == "" or batch['negative'] == ""):
            unique_docs.update([batch['query'], batch['positive'], batch['negative']])
            
            if batch['query'] not in query_relations:
                query_relations[batch['query']] = {
                    'positive': set(),
                    'negative': set()
                }
            
            query_relations[batch['query']]['positive'].add(batch['positive'])
            query_relations[batch['query']]['negative'].add(batch['negative'])
    
    doc_list = sorted(list(unique_docs))
    doc_to_idx = {doc: idx for idx, doc in enumerate(doc_list)}
    idx_to_doc = {idx: doc for idx, doc in enumerate(doc_list)}
    
    query_idx_relations = {}
    for query, relations in query_relations.items():
        query_idx = doc_to_idx[query]
        query_idx_relations[query_idx] = {
            'positive': {doc_to_idx[pos] for pos in relations['positive']},
            'negative': {doc_to_idx[neg] for neg in relations['negative']}
        }
    
    return doc_list, doc_to_idx, idx_to_doc, query_idx_relations

def create_hardneg_dataset(
    query_idx_relations: Dict[int, Dict[str, set]],
    hard_negatives: Dict[int, int],
    doc_list: List[str],
    random_n: int = 0
) -> Dataset:
    """Create dataset with hard negatives.
    
    Args:
        query_idx_relations: Query relationship mapping
        hard_negatives: Mapping of queries to hard negative indices
        doc_list: List of documents
        random_n: Number of extra random negatives per query (accepted so the
            callers below line up; unused in this simplified version)
        
    Returns:
        HuggingFace dataset with query-positive-hardnegative triplets
    """
    new_dataset = []
    
    for query_idx, relations in query_idx_relations.items():
        if query_idx in hard_negatives:
            new_example = {
                'query': doc_list[query_idx],
                'positive': doc_list[list(relations['positive'])[0]],
                'negative': doc_list[hard_negatives[query_idx]]
            }
            new_dataset.append(new_example)
    
    return Dataset.from_dict({
        'query': [ex['query'] for ex in new_dataset],
        'positive': [ex['positive'] for ex in new_dataset],
        'negative': [ex['negative'] for ex in new_dataset]
    })


def create_parallel_hard_negatives(
    index: str,
    faiss_to_doc_idx: Dict[int, int],
    doc_list: List[str],
    query_idx_relations: Dict[int, Dict[str, set]],
    num_gpus: int = 8,
    rerank: bool = True,
    save_both: bool = True,
    faiss_gpu: bool = True,
    save_name: str = "embeddings_huge_hardneg_dataset_rerank",
    model_id: Union[str, None] = "Alibaba-NLP/gte-multilingual-reranker-base",
    k_hard: int = 1,
    random_n: int = 0,
) -> Dict[int, int]:
    """
    Create hard negative examples using parallel GPU processing with FAISS.
    
    Args:
        index: Path to a saved FAISS index containing document embeddings
        faiss_to_doc_idx: Mapping from FAISS indices to document indices
        doc_list: List of document texts
        query_idx_relations: Dictionary mapping query indices to their positive/negative relations
        num_gpus: Number of GPUs to use for parallel processing
        rerank: Whether to use reranking on candidates
        save_both: Whether to save both the reranked and non-reranked datasets
        faiss_gpu: Whether to use GPU acceleration for FAISS
        save_name: Name for saving the resulting dataset
        model_id: Name of HF Reranker model
        k_hard: the number of hard negatives to find for each query
        random_n: the number of random negatives to add to each query
        
    Returns:
        Dict[int, int]: Mapping of query indices to their hard negative indices
        
    The function:
    1. Distributes work across multiple GPUs
    2. Finds hard negatives using FAISS similarity search
    3. Optionally reranks candidates
    4. Creates and saves a dataset with the hard negatives
    """
    input_queue = mp.Queue()
    output_queue = mp.Queue()
    
    workers = []
    for gpu_id in range(num_gpus):
        worker = ParallelFaissSearcher(
            gpu_id, 
            input_queue, output_queue,
            index, 
            faiss_to_doc_idx, 
            doc_list, 
            query_idx_relations, 
            rerank,
            model_id=model_id,
            faiss_use_gpu=faiss_gpu
        )
        workers.append(worker)

    query_indices = list(query_idx_relations.keys())
    batch_size = max(1, len(query_indices) // (num_gpus * 4))
    
    for i in range(0, len(query_indices), batch_size):
        gpu_id = (i // batch_size) % num_gpus
        batch = query_indices[i:i + batch_size]
        input_queue.put(batch)

    for _ in workers:
        input_queue.put(None)
    
    for worker in workers: 
        worker.start()

    hard_negatives = {}
    hard_negatives_rerank = {}
    completed = 0
    while completed < num_gpus:
        result = output_queue.get()
        if result is None:
            completed += 1
            continue
        if rerank and save_both:
            # assumes each worker returns a (plain_idx, reranked_idx) pair per query
            for key, value in result.items():
                hard_negatives[key] = value[0]
                hard_negatives_rerank[key] = value[1]
        elif rerank:
            hard_negatives_rerank.update(result)
        else:
            hard_negatives.update(result)

    if rerank:
        if save_both:
            hardneg_dataset = create_hardneg_dataset(
                query_idx_relations, hard_negatives, doc_list, random_n
            )
            hardneg_dataset.save_to_disk(save_name)
        hardneg_dataset = create_hardneg_dataset(
            query_idx_relations, hard_negatives_rerank, doc_list, random_n
        )
        hardneg_dataset.save_to_disk(save_name + '_rerank')
    else:
        hardneg_dataset = create_hardneg_dataset(
            query_idx_relations, hard_negatives, doc_list, random_n
        )
        hardneg_dataset.save_to_disk(save_name)

    return hard_negatives


def create_aligned_datasets(datasets_dict: DatasetDict) -> DatasetDict:
    """Create aligned datasets using ordered hash-based matching."""
    # Get the smallest dataset as reference
    smallest_dataset = min(datasets_dict.values(), key=len)
    
    # Create ordered master list of hash pairs from smallest dataset
    master_hashes = [(hashlib.sha256(q.encode()).hexdigest(),
                     hashlib.sha256(p.encode()).hexdigest())
                    for q, p in zip(smallest_dataset['query'], 
                                  smallest_dataset['positive'])]
    
    # Find matching indices in each dataset that preserve order
    aligned_datasets = {}
    for name, dataset in datasets_dict.items():
        current_hashes = [(hashlib.sha256(q.encode()).hexdigest(),
                         hashlib.sha256(p.encode()).hexdigest())
                        for q, p in zip(dataset['query'], dataset['positive'])]
        
        # Keep track of which master hashes we've found
        found_indices = []
        for master_hash in master_hashes:
            if master_hash in current_hashes:
                idx = current_hashes.index(master_hash)
                found_indices.append(idx)
        
        aligned_datasets[name] = dataset.select(found_indices)
    
    def verify_alignment():
        print("\nAlignment Statistics:")
        sizes = [len(dataset) for dataset in aligned_datasets.values()]
        print(f"All datasets aligned to size: {sizes[0]}")
        assert all(size == sizes[0] for size in sizes), "Dataset sizes don't match"
        
        base_hashes = [(hashlib.sha256(q.encode()).hexdigest(),
                       hashlib.sha256(p.encode()).hexdigest())
                      for q, p in zip(list(aligned_datasets.values())[0]['query'],
                                    list(aligned_datasets.values())[0]['positive'])]
        
        for dataset in aligned_datasets.values():
            current_hashes = [(hashlib.sha256(q.encode()).hexdigest(),
                             hashlib.sha256(p.encode()).hexdigest())
                            for q, p in zip(dataset['query'], dataset['positive'])]
            assert base_hashes == current_hashes, "Hash sequences don't match exactly"
    
    verify_alignment()
    return aligned_datasets

def save_aligned_data(aligned_splits: dict, base_path: str) -> None:
    """Save each dataset as a DatasetDict to disk."""
    for dataset_name, dataset in aligned_splits.items():
        save_path = f"{base_path}/{dataset_name}"
        dataset.save_to_disk(save_path)

Essentially, we create a queue/job setup that tries to saturate the model/FAISS bottlenecks. We allocate one process per GPU (n_processes == n_gpus). You have to know your model and data sizes so your GPU processes aren’t competing for non-existent memory (e.g., 20 processes on only 8 GPUs), so the back-of-the-napkin math says: keep it simple and set n_processes equal to n_gpus.
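
The napkin math itself is simple (illustrative numbers, not measured):

# a 1.5B-param embedder in fp16 is roughly 3 GB of weights per process,
# before activation headroom for long sequences, so one process per GPU
# is the safe default on typical cards
params = 1.5e9
bytes_per_param = 2  # fp16, since we call model.half() in setup()
print(f"~{params * bytes_per_param / 1e9:.1f} GB of weights per process")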

We multiprocess on each GPU in our stack (8x) and find hard negatives first via BM25, then with the embedding model alone, then with the model plus a re-ranker.

Here is an image of what it looks like (or at least what I hope to see): the re-ranker providing the best result for the given query. Caveat for the astute reader: we’re using an untuned re-ranker (not tuned for legal data), so I don’t expect it to be the best, just better than average. In reality, I have no real idea what to expect; I sort of just want to ‘see with my own eyes what happens’ when we play around with various hard-negative mining strategies.

ok, less pontificating, more code:

Define our dataset/model configs

dataset_config = {
    "name": "joelniklaus/legal_case_document_summarization",
    "splits": ["train", "test"],
    "pairs": [("judgement", "summary")]
}
num_gpus = 8
batch_size = 1
embedding_size = 1536
model_id = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"

Create our datasets

# create our base bm25 dataset
bm25_dataset = create_bm25_data(dataset_config, save_name="legal_case_document_summarization_bm25")

# build the document list/relations, then create our FAISS index with our main embedding model
doc_list, doc_to_idx, idx_to_doc, query_idx_relations = create_document_list_with_relations(bm25_dataset)

index, faiss_to_doc_idx = create_fast_index(
  doc_list, num_gpus=num_gpus, batch_size=batch_size, 
  embedding_size=embedding_size, model_id=model_id
)

_ = create_parallel_hard_negatives(
    "embeddings.index", faiss_to_doc_idx, doc_list, query_idx_relations,
    num_gpus=num_gpus, rerank=True, save_both=True, faiss_gpu=True,
    save_name="legal_case_document_summarization", k_hard=1, random_n=0
)

(our datasets will be created at this point.)

super simple training script

import argparse

from datasets import load_from_disk
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

import torch
torch.set_float32_matmul_precision('high')

def train(dataset_path, dataset_name):
    torch.set_float32_matmul_precision('high')
    # lr & model name (hard-coded here rather than parsed)
    lr = 8e-5
    model_name = "answerdotai/ModernBERT-base"
    model_shortname = model_name.split("/")[-1]

    # 1. Load a model to finetune
    model = SentenceTransformer(model_name)

    # 2. Load a dataset to finetune on
    train_dataset = load_from_disk(dataset_path)

    # 3. Define a loss function
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)  # Increase mini_batch_size if you have enough VRAM

    run_name = f"{model_shortname}-{dataset_name}"
    # 4. (Optional) Specify training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"output/{model_shortname}/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        warmup_ratio=0.05,
        fp16=False,  # Set to False if GPU can't handle FP16
        bf16=True,  # Set to True if GPU supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # (Cached)MultipleNegativesRankingLoss benefits from no duplicates
        learning_rate=lr,
        # Optional tracking/debugging parameters:
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        logging_steps=200,
        run_name=run_name,  # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed
    )

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()

    # 8. Save the model
    model.save_pretrained(f"output/{run_name}/final")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str)
    parser.add_argument("--dataset_name", type=str)
    args = parser.parse_args()
    train(args.dataset_path, args.dataset_name)
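
To produce the three models, run the script once per mined dataset, e.g. python train.py --dataset_path legal_case_document_summarization_bm25 --dataset_name bm25 (file name illustrative).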

The unreasonable effectiveness of bm25?

In my opinion, there’s no real edge here for anyone. We can discount AILA/LegalQuAD for now and notice that on the other three tasks, each method (bm25, rerank, no-rerank) gets one victory apiece.

As for the foreign legal data, bm25 does better on the two Indian legal tasks, but on the German task bm25 is the worst.

I do notice one interesting thing across all tasks, however: rerank and no-rerank seem to be inversely correlated on certain performance spikes. When rerank goes up, no-rerank goes down, and vice versa. So let’s create combinations of these pairs, train on that data, and see what happens:
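
A minimal sketch of building those combinations, using the triplet datasets saved earlier (concatenate_datasets simply stacks rows, so each query contributes one triplet per mining strategy):

from datasets import load_from_disk, concatenate_datasets

bm25 = load_from_disk("legal_case_document_summarization_bm25")
plain = load_from_disk("legal_case_document_summarization")
rerank = load_from_disk("legal_case_document_summarization_rerank")

combos = {
    "bm25+rerank": concatenate_datasets([bm25, rerank]),
    "bm25+plain": concatenate_datasets([bm25, plain]),
    "plain+rerank": concatenate_datasets([plain, rerank]),
}
for name, ds in combos.items():
    ds.save_to_disk(f"combined_{name}")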

Cool! Much better performance. It seems some combination of the simplicity of bm25 and the rerank method yields superior performance in all areas except LegalSummarization, which is interesting considering our training dataset is exclusively legal summarization data; contextually, I suppose that sort of makes sense. We don’t know, however, whether this is due exclusively to having more negative examples per query, so we’ll have to test with random data (1 vs. 2 vs. 3 negatives, etc.) to see if it’s just the presence of an extra example that makes the difference. We’ll test that later.

The other idea I had was to let these models overtrain for many epochs ‘just to see what happens’. So you’ll notice that all the models are trained for 30 epochs, which is probably far, far too much. The reason is that I had a hunch that because bm25 is “semantically simpler”, it should be the most resistant to overfitting of the three. I wanted to see whether the model would overfit to the training data and what would happen if we let it overtrain for a long time. It turned out that’s not necessarily the case: rerank actually performs better than bm25 on the subtasks we’re most curious about.

Also, and somewhat serendipitously, based on observations a poster made in a random twitter thread, I was curious what would happen if we average/ensemble the weights of the three models ((bm25 + rerank + no-rerank) / 3) at each checkpoint and measure the perf:

Much better on the foreign legal tasks. The takeaway is probably that if you’re compute-constrained and want the best-performing model, this is your best bet all things considered: simply take the average of models trained on various data mixes. It kind of goes to show that implicit regularization is so unbelievably powerful that it should probably be one of your first go-to performance tools.
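
For reference, the weight averaging here is just a uniform ‘model soup’. A minimal sketch, assuming three SentenceTransformer checkpoints trained on the three data mixes (paths are illustrative):

import torch
from sentence_transformers import SentenceTransformer

paths = [
    "output/ModernBERT-base/ModernBERT-base-bm25/checkpoint-200",
    "output/ModernBERT-base/ModernBERT-base-rerank/checkpoint-200",
    "output/ModernBERT-base/ModernBERT-base-no-rerank/checkpoint-200",
]
models = [SentenceTransformer(p) for p in paths]

# uniform average of every parameter across the three checkpoints
avg_state = {
    key: torch.stack([m.state_dict()[key].float() for m in models]).mean(dim=0)
    for key in models[0].state_dict()
}
soup = models[0]
soup.load_state_dict(avg_state)
soup.save_pretrained("output/soup-bm25-rerank-norerank")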

Intuitively, the way I’ve been thinking about this is that bm25-picked hard negatives alone probably act as a form of regularization in some simplistic capacity.

A few interesting rabbit-hole ideas

Looking at the training curves, an idea I had was to see if we could somehow use the stability of bm25 and add varying degrees of noise/hard-negatives to get more lift in performance, as well as test varying amounts of hard-negatives per query:

what impact does bm25 have when it’s supplemented by specifically chosen vs random hardnegs;

what if we progressively add more hard-negatives as the training steps move forward (epoch 1: 1 hardneg, epoch 2: 2 hardnegs, … epoch 2000: 2000 hardnegs)? A toy sketch of this idea follows the list;

what if we make things theoretically more difficult/different-looking as training continues, either by always showing the model a fresh batch of hard-negatives, or by increasing the number of hard-negatives per batch?
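
Here’s a toy sketch of the progressive idea from the list above (names are illustrative; the mined pool per query is assumed to be pre-sorted hardest-first):

def negatives_for_epoch(query_idx, epoch, mined_negatives, cap=32):
    """Show the first min(epoch, cap) mined negatives for this query."""
    return mined_negatives[query_idx][:min(epoch, cap)]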

At this point I’m mostly just trying things based on the loss curves/eval scores. For instance, I notice that bm25 is somewhat resistant to overtraining, or at least has stable training grad-norms, so I picked a random idea to start and let the research muse take me where it meant to after that.

CODE

I started writing one comprehensive/thoughtful code base, but when the research god whispers, you hack something together and just run the script.

Below are all the pieces of code for the random experiments I ran, in no particular order:

Random

At the end of each epoch, we take each sample (query, positive) in our dataset and randomly choose a hard negative to pair with the query/positive pair. (This method did horribly.)

# pseudo code:
import random

def on_epoch_end(self, **kwargs):
    for qid, relations in self.query_idx_relations.items():
        cache_entries = self.query_caches[qid]
        current_len = len(cache_entries)
        if current_len < self.k_cache:
            needed = self.k_cache - current_len
            # Example: randomly pick candidates that aren't known positives
            candidates = list(set(self.idx_to_doc) - set(relations['positive']))
            new_candidates = random.sample(candidates, min(len(candidates), needed))
            new_entries = [self.doc_list[cand_idx] for cand_idx in new_candidates]
            self.query_caches[qid].extend(new_entries)
            ....

Model-Based Mining:

At the end of each epoch (or after some point; we can set the start with ‘epoch_start’), we take the current model and embed every piece of text in our dataset ({query1, positive1, …, queryN, positiveN}). For each original query/positive pair, we find the K most similar texts to the query that aren’t its corresponding positive. We add some parameters to select K and to optionally mix in random hard negatives as noise.
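
A minimal single-GPU sketch of that mining step (the real pipeline above parallelizes it across GPUs; k and the names are illustrative):

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

def mine_hard_negatives(model_path, doc_list, query_idx_relations, k=5):
    model = SentenceTransformer(model_path)
    emb = model.encode(doc_list, normalize_embeddings=True, convert_to_numpy=True).astype(np.float32)
    index = faiss.IndexFlatIP(emb.shape[1])
    index.add(emb)
    hard_negs = {}
    for qid, rel in query_idx_relations.items():
        # over-fetch so we can drop the query itself and its known positives
        _, I = index.search(emb[qid:qid + 1], k + len(rel['positive']) + 1)
        hard_negs[qid] = [int(i) for i in I[0] if i != qid and i not in rel['positive']][:k]
    return hard_negs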

import os

import torch.distributed as dist
from datasets import load_from_disk
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    get_cosine_schedule_with_warmup,
)


class HardNegMiningCallback(TrainerCallback):
    """
    A callback that, at the end of each epoch:
      - If we're on the main process, sets CUDA_VISIBLE_DEVICES to '6,7'
      - Calls create_hardnegs(...)
      - Waits for completion
      - Reloads the dataset with newly generated negatives
    """

    def __init__(
        self,
        trainer,                  # <--- pass in the trainer itself
        create_hardnegs_fn,
        original_dataset_path,
        adaptive_dataset_path,
        model_path=None,
        gpus_for_hardneg="6,7",
        reload_after=True,
        num_gpus_for_hardneg=2,
        batch_size=256,
        embedding_size=1536,
        epoch_start=1,
        k_hard=1,
        random_n=0,
    ):
        """
        Args:
            trainer: The HF Trainer (or a subclass) that we want to manipulate.
            create_hardnegs_fn: A function with signature 
                create_hardnegs(dataset_path, save_path, model_path, num_gpus, batch_size, embedding_size)
            original_dataset_path: path to the dataset we started training on
            adaptive_dataset_path: where we save the newly created dataset with fresh negatives
            model_path: path to the current model or checkpoint (trainer.args.output_dir by default)
            gpus_for_hardneg: which GPUs (by index) to use for negative generation (e.g. "6,7")
            reload_after: if True, reload the newly generated dataset into trainer.train_dataset
            num_gpus_for_hardneg: how many GPUs we actually want to use inside create_hardnegs
            batch_size, embedding_size: additional parameters to pass to create_hardnegs
            epoch_start: only start doing negative generation after this epoch number
            k_hard: number of hard negatives to generate per query
            random_n: number of random negatives to add per query
        """
        super().__init__()
        self.trainer = trainer  # store a reference to the trainer
        self.create_hardnegs_fn = create_hardnegs_fn
        self.original_dataset_path = original_dataset_path
        self.adaptive_dataset_path = adaptive_dataset_path
        self.model_path = model_path
        self.gpus_for_hardneg = gpus_for_hardneg
        self.reload_after = reload_after
        self.num_gpus_for_hardneg = num_gpus_for_hardneg
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.epoch_start = epoch_start
        self.k_hard = k_hard
        self.random_n = random_n

        self.epoch_counter = 0

    def on_epoch_end(
        self,
        args,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        self.epoch_counter += 1
        if self.epoch_counter < self.epoch_start:
            return  # skip until epoch_start is reached

        # Determine if we're rank=0 in distributed training
        # Hugging Face sets this in TrainerState or we can check:
        # if self.trainer.is_world_process_zero():
        #   or
        # if state.is_world_process_zero
        # (They generally do the same in modern versions.)
        is_main_process = getattr(self.trainer, "is_world_process_zero", None)
        if callable(is_main_process):
            is_main_process = self.trainer.is_world_process_zero()
        elif hasattr(state, "is_world_process_zero"):
            is_main_process = state.is_world_process_zero
        else:
            # Fallback: assume single process or assume main
            is_main_process = True

        if is_main_process:
            print(f"\n[HardNegMiningCallback] Epoch {self.epoch_counter} ended. Creating new hard negatives on GPUs {self.gpus_for_hardneg}.")

            # 1) Save the original CUDA_VISIBLE_DEVICES so we can restore it later
            original_cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", "")
            print(f"[HardNegMiningCallback] Original CUDA_VISIBLE_DEVICES: {original_cuda_env}")
            try:
                # 2) Set env so that only GPUs 6,7 are visible for negative generation
                os.environ["CUDA_VISIBLE_DEVICES"] = self.gpus_for_hardneg
                print(f"[HardNegMiningCallback] Setting CUDA_VISIBLE_DEVICES to {os.environ['CUDA_VISIBLE_DEVICES']}")

                # 3) Determine the current model checkpoint path
                model_ckpt_path = self.model_path or self.trainer.args.output_dir

                # grab the checkpoint directories under the output dir
                checkpoints = [d for d in os.listdir(model_ckpt_path) if d.startswith('checkpoint-')]

                # sort by global step and take the most recent checkpoint
                checkpoints.sort(key=lambda x: int(x.split('-')[1]))
                checkpoint_path = os.path.join(model_ckpt_path, checkpoints[-1])

                # 4) Call the user-provided function
                self.create_hardnegs_fn(
                    dataset_path=self.original_dataset_path,
                    save_path=self.adaptive_dataset_path,
                    model_path=checkpoint_path,
                    num_gpus=self.num_gpus_for_hardneg,
                    batch_size=self.batch_size,
                    embedding_size=self.embedding_size,
                    k_hard=self.k_hard,
                    random_n=self.random_n,
                )
            finally:
                # 5) Restore original CUDA_VISIBLE_DEVICES for training
                os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_env
                print(f"[HardNegMiningCallback] Restored CUDA_VISIBLE_DEVICES to {os.environ['CUDA_VISIBLE_DEVICES']}")

        # 6) Optional: Wait for all ranks if in distributed mode
        if dist.is_initialized():
            dist.barrier()

        # 7) Reload the updated dataset for the next epoch
        if self.reload_after:
            print(f"[HardNegMiningCallback] Reloading dataset from {self.adaptive_dataset_path}")
            updated_dataset = load_from_disk(self.adaptive_dataset_path)
            self.trainer.train_dataset = updated_dataset
            # Some versions store internally in `_train_dataset`
            self.trainer._train_dataset = updated_dataset  
            self.trainer.train_dataloader = self.trainer.get_train_dataloader()

            # If you want to keep updating original_dataset_path to the new path each time
            self.original_dataset_path = self.adaptive_dataset_path

    def build_new_training_lr_schedule(self, new_dataset):
        current_lr = self.trainer.optimizer.param_groups[0]['lr']

        # Calculate new total steps based on the new dataset size
        num_training_steps = (
            len(new_dataset)
            * self.trainer.args.num_train_epochs
            // (self.trainer.args.per_device_train_batch_size * self.trainer.args.gradient_accumulation_steps)
        )

        # get_cosine_schedule_with_warmup takes no initial_lr argument; to make
        # the new cosine start from where we left off, reset each param group's
        # base lr to the current lr before building the scheduler.
        for group in self.trainer.optimizer.param_groups:
            group['lr'] = current_lr
            group['initial_lr'] = current_lr

        new_scheduler = get_cosine_schedule_with_warmup(
            self.trainer.optimizer,
            num_warmup_steps=0,  # no warmup since we're continuing
            num_training_steps=num_training_steps,
        )

        self.trainer.lr_scheduler = new_scheduler
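
For completeness, here's roughly how this callback gets wired into a run. This is a sketch, not the exact training script: `model` and `training_args` are whatever you're already using, the paths are placeholders, and the hyperparameters mirror the 'rand' setup described in the results below.

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=load_from_disk("data/bm25_triplets"),
)

mining_cb = HardNegMiningCallback(
    trainer=trainer,
    create_hardnegs_fn=create_parallel_hard_negatives,
    original_dataset_path="data/bm25_triplets",
    adaptive_dataset_path="data/adaptive_triplets",
    gpus_for_hardneg="6,7",
    epoch_start=10,   # train on the static bm25 negatives for 10 epochs first
    k_hard=1,         # then mine the single hardest negative per query
    random_n=25,      # plus 25 random negatives as less-difficult noise
)
trainer.add_callback(mining_cb)
trainer.train()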

Margin Based Random Mining:

So in this version, we do the same model-based mining as before, but we add a margin parameter to decide whether the model has ‘learned’ an example. Instead of removing a triplet prematurely, we keep showing the model the same triplet as long as the margin stays above a certain threshold; once it drops below, we swap in a new randomly selected hard negative. The thought here is to keep only truly hard negatives in the cache and evict the ones the model has learned.
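
One note: compute_hardness and compute_batch_hardness, used throughout the callbacks below, aren't shown. Here's a minimal sketch of what they'd look like, assuming a sentence-transformers-style model with .encode, and matching the cos(q, neg) - cos(q, pos) margin that the FAISS version computes later:

import numpy as np

def _cos(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

def compute_hardness(model, q_text, pos_text, neg_text):
    """Margin = cos(q, neg) - cos(q, pos); higher means the negative is still
    'hard' (the model hasn't pushed it away from the query yet)."""
    q, p, n = model.encode([q_text, pos_text, neg_text])
    return _cos(q, n) - _cos(q, p)

def compute_batch_hardness(model, q_texts, pos_texts, neg_texts):
    """Batched version of the same margin; returns one score per triplet."""
    def norm(x):
        return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-8)
    q = norm(np.asarray(model.encode(q_texts)))
    p = norm(np.asarray(model.encode(pos_texts)))
    n = norm(np.asarray(model.encode(neg_texts)))
    return (np.sum(q * n, axis=1) - np.sum(q * p, axis=1)).tolist()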

class AdaptiveHardNegMiningCallback(TrainerCallback):
    """
    Maintains a per-query cache of negatives. On each epoch end (or step end), it:
      - samples from the cache for training
      - updates each query's cache by removing "easy" negatives
      - periodically refills from a global pool or partial re-embedding
    """
    def __init__(
        self,
        trainer,                  # reference to the Trainer (or your model)
        train_dataset,                 # list of query texts or IDs
        model,                   # the retrieval model or sentence-transformer
        k_cache=5,               # how many negatives we keep in each cache
        remove_threshold=-0.05,  # if margin < -0.05 => negative is "easy"
        refill_every=2,          # every 2 epochs, we try to refill or re-embed

        create_hardnegs_fn=None,
        original_dataset_path=None,
        adaptive_dataset_path=None,
        model_path=None,
        gpus_for_hardneg="6,7",
        reload_after=True,
        num_gpus_for_hardneg=2,
        batch_size=256,
        embedding_size=1536,
        epoch_start=1,
        k_hard=1,
        random_n=0,
    ):
        super().__init__()
        self.trainer = trainer
        self.model = model
        #self.global_pool = global_pool or []
        self.k_cache = k_cache
        self.remove_threshold = remove_threshold
        self.refill_every = refill_every


        self.create_hardnegs_fn = create_hardnegs_fn
        self.original_dataset_path = original_dataset_path
        self.adaptive_dataset_path = adaptive_dataset_path
        self.model_path = model_path
        self.gpus_for_hardneg = gpus_for_hardneg
        self.reload_after = reload_after
        self.num_gpus_for_hardneg = num_gpus_for_hardneg
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.epoch_start = epoch_start
        self.k_hard = k_hard
        self.random_n = random_n


        # Build initial caches
        # Each cache is a list of (neg_text, hardness) or just neg_text
        self.doc_list, self.doc_to_idx, self.idx_to_doc, self.query_idx_relations = create_document_list_with_relations(train_dataset)
        
        self.query_caches = {}
        for qid in self.query_idx_relations:
            # pick up to k_cache negatives for this query
            negs = list(self.query_idx_relations[qid]['negative'])[:k_cache]
            self.query_caches[qid] = [(n, 999.0) for n in negs]  # (neg_id, hardness placeholder)

        self.epoch_counter = 0

    def on_epoch_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Optionally, we can sample a brand-new training dataset from caches here.
        Or do it on_step_end if you prefer a finer granularity.
        """
        # We can rebuild or shuffle the "train_dataset" from these caches
        self._build_train_dataset_from_caches()
        # No direct return needed, the trainer will continue

    def on_epoch_end(
        self,
        args,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):

        self.epoch_counter += 1
        # 1) Update caches for each query -> remove easy negatives
        self._update_caches_remove_easy()

        # 2) Periodically refill caches from global pool or partial re-embedding
        if (self.epoch_counter % self.refill_every) == 0:
            self._refill_caches()

        # 3) Optionally: re-build training dataset from updated caches
        # so next epoch sees the new negatives
        self._build_train_dataset_from_caches()

    ################################################################
    #  Internals
    ################################################################
    def _build_train_dataset_from_caches(self):
        """
        Rebuild a 'training dataset' from the caches if we want to feed it
        into Trainer. We can do something like:
          train_dataset = List of InputExample(query, pos, neg)
        Then: self.trainer.train_dataset = <some huggingface Dataset>
        """
        new_samples = []
        for qid, relations in self.query_idx_relations.items():
            cache_entries = self.query_caches[qid]
            for (neg_idx, hardness) in cache_entries:
                new_example = {
                    'query': self.doc_list[qid],
                    'positive': self.doc_list[list(relations['positive'])[0]],
                    'negative': self.doc_list[neg_idx]
                }
                new_samples.append(new_example)

        # Now wrap that in a HF Dataset, or a custom dataset for your trainer
        # For quick pseudo-code:
        from datasets import Dataset
        train_dataset = Dataset.from_list(new_samples)
        self.trainer.train_dataset = train_dataset
        self.trainer._train_dataset = train_dataset  # if needed
        # You might want to re-init dataloader:
        self.trainer.train_dataloader = self.trainer.get_train_dataloader()

    def _update_caches_remove_easy(self):
        """
        For each (query, neg), check if it's now 'easy' for the model:
         - if margin < remove_threshold => remove from the cache
         - otherwise, keep it (it's still hard)
        """
        for qid, relations in self.query_idx_relations.items():
            new_cache = []
            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(relations['positive'])[0]]
            for (neg_id, old_hardness) in self.query_caches[qid]:
                neg_text = self.doc_list[neg_id]
                margin = compute_hardness(self.model, q_text, pos_text, neg_text)
                if margin < self.remove_threshold:
                    # It's become easy => we remove it
                    pass
                else:
                    # It's still hard, keep it (store the id, not the text)
                    new_cache.append((neg_id, margin))
            self.query_caches[qid] = new_cache

    def _refill_caches(self):
        """
        If the caches have fewer than k_cache negatives, we pick new ones from
        the global pool or re-embed & find new top-K. For simplicity, let's do random
        from global_pool here.
        """
        # can do re-embed here
        for qid, relations in self.query_idx_relations.items():
            cache_entries = self.query_caches[qid]
            current_len = len(cache_entries)
            if current_len < self.k_cache:
                needed = self.k_cache - current_len
                # Example: random pick from the rest of the corpus
                import random
                candidates = list(set(self.idx_to_doc) - set(relations['positive']) - {qid})
                new_candidates = random.sample(candidates, min(len(candidates), needed))
                # Compute hardness for each new candidate
                q_text = self.doc_list[qid]
                pos_text = self.doc_list[list(relations['positive'])[0]]
                new_entries = []
                for cand_idx in new_candidates:
                    cand_text = self.doc_list[cand_idx]
                    margin = compute_hardness(self.model, q_text, pos_text, cand_text)
                    new_entries.append((cand_idx, margin))  # store the id, consistent with the cache

                self.query_caches[qid].extend(new_entries)

    def build_new_training_lr_schedule(self, new_dataset):
        current_lr = self.trainer.optimizer.param_groups[0]['lr']

        # Calculate new total steps based on the new dataset size
        num_training_steps = (
            len(new_dataset)
            * self.trainer.args.num_train_epochs
            // (self.trainer.args.per_device_train_batch_size * self.trainer.args.gradient_accumulation_steps)
        )

        # get_cosine_schedule_with_warmup takes no initial_lr argument; to make
        # the new cosine start from where we left off, reset each param group's
        # base lr to the current lr before building the scheduler.
        for group in self.trainer.optimizer.param_groups:
            group['lr'] = current_lr
            group['initial_lr'] = current_lr

        new_scheduler = get_cosine_schedule_with_warmup(
            self.trainer.optimizer,
            num_warmup_steps=0,  # no warmup since we're continuing
            num_training_steps=num_training_steps,
        )

        self.trainer.lr_scheduler = new_scheduler

Margin Based Bandit Mining:

In this version, we combine the original model-based mining with the margin-based mining, and instead of subjectively picking a margin threshold for emptying a sample’s hard-negative cache, we use an Epsilon Greedy Bandit to pick the margin value. This one is cool, so I’m patting myself on the back for it.

Epsilon Greedy knowledge check: we provide a list of arms, where each index holds the margin threshold assigned to that arm; an epsilon value, which is the probability of exploring instead of exploiting the highest reward; and an initial reward of 0 for each arm. At each epoch’s start, the select_arm function picks an arm either randomly (with probability epsilon) or greedily as the best performer so far (the one with the max q-value).

At the epoch’s end, we update the q-value for the chosen arm based on the reward from our eval set.

class EpsilonGreedyBandit:
    """
    Each arm is a candidate threshold.
    We track (mean_reward, count) for each arm.
    We pick arms via epsilon-greedy.
    """
    def __init__(self, arms, epsilon=0.1, initial_value=0.0):
        """
        :param arms: list of thresholds (floats)
        :param epsilon: probability of exploring
        :param initial_value: initial assumed reward for each arm
        """
        self.arms = arms
        self.epsilon = epsilon
        self.q_values = [initial_value for _ in arms]  # estimated mean reward
        self.counts = [0 for _ in arms]                # how many times each arm was used

    def select_arm(self):
        """
        Epsilon-greedy selection.
        """
        if random.random() < self.epsilon:
            # Explore
            return random.randint(0, len(self.arms)-1)
        else:
            # Exploit (pick best so far)
            return int(np.argmax(self.q_values))

    def update(self, arm_index, reward):
        """
        Update estimates for the chosen arm.
        """
        self.counts[arm_index] += 1
        c = self.counts[arm_index]
        old_q = self.q_values[arm_index]
        # incremental mean update
        new_q = old_q + (reward - old_q) / c
        self.q_values[arm_index] = new_q

    def get_arm_value(self, arm_index):
        return self.arms[arm_index]

    def __repr__(self):
        return f"EpsilonGreedyBandit(arms={self.arms}, q_values={self.q_values}, counts={self.counts})"


class AdaptiveBanditHardNegMiningCallback(TrainerCallback):
    """
    Maintains a per-query cache of negatives. 
    Uses a bandit to pick 'remove_threshold' at the start of each epoch.
    """

    def __init__(
        self,
        trainer,  # reference to the Trainer
        train_dataset,
        model,
        k_cache=5,
        bandit_arms=None,    # list of candidate thresholds
        epsilon=0.1,
        refill_every=2,
        eval_metric_key="eval_loss"  # or "eval_accuracy", "eval_mrr", etc.
    ):
        super().__init__()
        self.trainer = trainer
        self.model = model
        self.k_cache = k_cache
        self.refill_every = refill_every
        self.eval_metric_key = eval_metric_key

        # 1) Build bandit
        if bandit_arms is None:
            # Provide default arms if not specified
            bandit_arms = [-0.1, -0.05, 0.0, 0.05, 0.1]
        self.bandit = EpsilonGreedyBandit(arms=bandit_arms, epsilon=epsilon)

        # 2) Build initial caches, etc. (similar to your original code)
        self.doc_list, self.doc_to_idx, self.idx_to_doc, self.query_idx_relations = create_document_list_with_relations(train_dataset)

        self.query_caches = {}
        for qid in self.query_idx_relations:
            negs = list(self.query_idx_relations[qid]['negative'])  # example
            # pick up to k_cache
            chosen = negs[:k_cache]
            self.query_caches[qid] = [(nid, 999.0) for nid in chosen]

        self.epoch_counter = 0
        self.current_arm_index = None  # which threshold arm we used this epoch
        self.remove_threshold = None   # the actual threshold for this epoch

    def on_epoch_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """
        1) Pick an arm from bandit => set remove_threshold
        2) Rebuild the training dataset
        """
        self.epoch_counter += 1

        # (A) Bandit chooses which arm to use
        self.current_arm_index = self.bandit.select_arm()
        self.remove_threshold = self.bandit.get_arm_value(self.current_arm_index)
        print(f"[BanditCallback] Epoch={self.epoch_counter}, chosen threshold={self.remove_threshold:.3f}")

        # (B) Re-build / shuffle training data
        self._build_train_dataset_from_caches()

    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """
        1) Update caches => remove easy negatives
        2) Optionally refill
        3) Evaluate => measure reward => update bandit
        """
        # 1) Remove easy negatives
        self._update_caches_remove_easy_batch()

        # 2) Refill if needed
        if (self.epoch_counter % self.refill_every) == 0:
            self._refill_caches_batch()

        # 3) Evaluate => compute reward
        # We can do a trainer.evaluate(...) or rely on existing results in state.log_history
        # For example, let's see if we have an eval result in state.log_history
        # We'll find the last logged eval_metric_key
        reward = 0.0
        for entry in reversed(state.log_history):
            if self.eval_metric_key in entry:
                reward = entry[self.eval_metric_key]
                break
        # the bandit maximizes reward, so flip the sign for loss-style metrics
        if "loss" in self.eval_metric_key:
            reward = -reward

        print(f"[BanditCallback] Epoch={self.epoch_counter}, reward={reward:.4f}")

        # 4) Update bandit with reward
        self.bandit.update(self.current_arm_index, reward)

        # Optionally re-build the dataset after removing negatives
        self._build_train_dataset_from_caches()

    ############################################
    #   The same helper methods as before
    ############################################
    def _build_train_dataset_from_caches(self):
        # Rebuild a 'training dataset' from caches
        new_samples = []
        for qid, rels in self.query_idx_relations.items():
            cache_entries = self.query_caches[qid]
            q_text = self.doc_list[qid]
            # pick 1 positive from rels
            pos_text_id = list(rels['positive'])[0]
            pos_text = self.doc_list[pos_text_id]

            for (neg_id, hardness) in cache_entries:
                neg_text = self.doc_list[neg_id]
                new_samples.append({
                    "query": q_text,
                    "positive": pos_text,
                    "negative": neg_text
                })

        from datasets import Dataset
        train_dataset = Dataset.from_list(new_samples)

        self.trainer.train_dataset = train_dataset
        self.trainer._train_dataset = train_dataset
        self.trainer.train_dataloader = self.trainer.get_train_dataloader()

    def _update_caches_remove_easy(self):
        for qid, rels in self.query_idx_relations.items():
            q_text = self.doc_list[qid]
            pos_text_id = list(rels['positive'])[0]
            pos_text = self.doc_list[pos_text_id]
            new_cache = []

            for (neg_id, old_hardness) in self.query_caches[qid]:
                neg_text = self.doc_list[neg_id]
                margin = compute_hardness(self.model, q_text, pos_text, neg_text)
                # use self.remove_threshold from the bandit
                if margin < self.remove_threshold:
                    # remove
                    pass
                else:
                    new_cache.append((neg_id, margin))

            self.query_caches[qid] = new_cache

    def _refill_caches(self):
        # just a skeleton: pick random negatives from the corpus to fill up to k_cache
        import random
        for qid, rels in self.query_idx_relations.items():
            current_len = len(self.query_caches[qid])
            if current_len < self.k_cache:
                needed = self.k_cache - current_len
                q_text = self.doc_list[qid]
                pos_text = self.doc_list[list(rels['positive'])[0]]

                # randomly sample from all doc_ids except positives or qid
                all_candidates = list(set(self.doc_to_idx.values()) - {qid} - set(rels['positive']))
                new_cands = random.sample(all_candidates, min(len(all_candidates), needed))
                new_entries = []
                for cidx in new_cands:
                    cand_text = self.doc_list[cidx]
                    m = compute_hardness(self.model, q_text, pos_text, cand_text)
                    new_entries.append((cidx, m))
                self.query_caches[qid].extend(new_entries)

    def _refill_caches_batch(self):
        """
        Similar to _refill_caches, but we do all margin computations in batch mode
        to reduce overhead.
        """

        import random

        batch_q = []
        batch_p = []
        batch_n = []
        index_map = []  # (qid, candidate_neg_id, index_in_batch)

        # 1) Gather all needed candidates from each query
        for qid, rels in self.query_idx_relations.items():
            current_len = len(self.query_caches[qid])
            needed = self.k_cache - current_len
            if needed <= 0:
                continue

            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(rels['positive'])[0]]

            # Sample from all candidates except qid or positives
            all_candidates = (
                set(self.doc_to_idx.values()) 
                - {qid} 
                - set(rels['positive'])
            )
            if len(all_candidates) == 0:
                continue
            
            needed_cands = random.sample(list(all_candidates), min(len(all_candidates), needed))

            # 2) Collect them in a batch list
            for cand_idx in needed_cands:
                cand_text = self.doc_list[cand_idx]
                batch_q.append(q_text)
                batch_p.append(pos_text)
                batch_n.append(cand_text)
                index_map.append((qid, cand_idx))

        if not batch_q:
            # nothing to refill
            return

        # 3) Batch compute margins in chunks
        chunk_size = 256
        margin_list = []
        for start in range(0, len(batch_q), chunk_size):
            end = start + chunk_size
            q_chunk = batch_q[start:end]
            p_chunk = batch_p[start:end]
            n_chunk = batch_n[start:end]
            margins = compute_batch_hardness(self.model, q_chunk, p_chunk, n_chunk)
            margin_list.extend(margins)

        # 4) Assign margins back to the caches
        # index_map[i] -> (qid, cand_idx)
        # margin_list[i] -> margin
        for i, (qid, cand_idx) in enumerate(index_map):
            margin_val = margin_list[i]
            # just append to the cache
            self.query_caches[qid].append((cand_idx, margin_val))


    def _update_caches_remove_easy_batch(self):
        # 1) Gather all (q_text, p_text, n_text) in a big list
        batch_q = []
        batch_p = []
        batch_n = []
        index_map = []  # keep track of (qid, neg_id, index_in_batch)

        for qid, rels in self.query_idx_relations.items():
            q_text = self.doc_list[qid]
            pos_text_id = list(rels['positive'])[0] 
            p_text = self.doc_list[pos_text_id]
            
            for idx, (neg_id, old_hardness) in enumerate(self.query_caches[qid]):
                neg_text = self.doc_list[neg_id]
                batch_q.append(q_text)
                batch_p.append(p_text)
                batch_n.append(neg_text)
                index_map.append((qid, idx))  # so we know which cache entry this corresponds to

        # 2) Possibly chunk it if it's very large
        margin_list = []
        chunk_size = 256
        for start in range(0, len(batch_q), chunk_size):
            end = start + chunk_size
            q_chunk = batch_q[start:end]
            p_chunk = batch_p[start:end]
            n_chunk = batch_n[start:end]
            margins = compute_batch_hardness(self.model, q_chunk, p_chunk, n_chunk)
            margin_list.extend(margins)

        # margin_list is now same length as index_map
        # 3) Go back through index_map and remove easy
        for (i, (qid, cache_idx)) in enumerate(index_map):
            margin = margin_list[i]
            if margin < self.remove_threshold:
                # mark for removal
                self.query_caches[qid][cache_idx] = None
            else:
                # store updated hardness
                self.query_caches[qid][cache_idx] = (self.query_caches[qid][cache_idx][0], margin)

        # 4) filter out removed
        for qid in self.query_idx_relations:
            self.query_caches[qid] = [x for x in self.query_caches[qid] if x is not None]

    def build_new_training_lr_schedule(self, new_dataset):
        current_lr = self.trainer.optimizer.param_groups[0]['lr']

        # Calculate new total steps based on the new dataset size
        num_training_steps = (
            len(new_dataset)
            * self.trainer.args.num_train_epochs
            // (self.trainer.args.per_device_train_batch_size * self.trainer.args.gradient_accumulation_steps)
        )

        # get_cosine_schedule_with_warmup takes no initial_lr argument; to make
        # the new cosine start from where we left off, reset each param group's
        # base lr to the current lr before building the scheduler.
        for group in self.trainer.optimizer.param_groups:
            group['lr'] = current_lr
            group['initial_lr'] = current_lr

        new_scheduler = get_cosine_schedule_with_warmup(
            self.trainer.optimizer,
            num_warmup_steps=0,  # no warmup since we're continuing
            num_training_steps=num_training_steps,
        )

        self.trainer.lr_scheduler = new_scheduler

Margin Bandit Mining w/ Self Selection using FAISS:

We use the same bandit-based structure from Margin Based Bandit Mining, but instead of only selecting random examples, we use the same embedding strategy from Model-Based Mining to select the K most difficult hard-negatives with the model itself, plus any random noise we might want to add via random_n. So here we basically have self-reinforcement with FAISS combined with reinforcement learning on the hard-negative parameters themselves. This was really cool and I think it’s a genuinely good way to do bandit mining. The problem is that it’s slow and incremental with respect to choosing parameters. I think what I’ll probably do is take this method, work out some rough ‘scaling laws’, and figure out how to make good cold-start parameter choices given the known data/params/etc.
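
Before the full callback, here's the core of the FAISS self-selection step in isolation. This is a sketch of the idea, not the actual create_parallel_hard_negatives from earlier: embed the corpus with the current model, index it, and take each query's nearest neighbors as hard-negative candidates.

import faiss
import numpy as np

def faiss_hard_negative_candidates(model, doc_texts, query_texts, k_hard=3):
    """Embed docs with the current model, index them with inner product over
    L2-normalized vectors (i.e. cosine), and return each query's top-k
    nearest doc ids as hard-negative candidates."""
    doc_embs = np.asarray(model.encode(doc_texts), dtype="float32")
    faiss.normalize_L2(doc_embs)  # in-place row normalization
    index = faiss.IndexFlatIP(doc_embs.shape[1])
    index.add(doc_embs)

    q_embs = np.asarray(model.encode(query_texts), dtype="float32")
    faiss.normalize_L2(q_embs)
    # overshoot by 1 so an exact self-match can be dropped downstream
    _, idx = index.search(q_embs, k_hard + 1)
    return idx  # shape (n_queries, k_hard + 1); filter self/positives after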


class AdaptiveBanditFaissCallback(TrainerCallback):
    """
    1) Each epoch-end, re-embed entire doc_list using the current model => build a FAISS index
    2) Maintain a local 'query_caches' for each query => do remove_easy + refill
       using the FAISS index to find truly "hard" negatives
    3) Re-build the trainer's dataset from caches => next epoch sees it
    4) Use a bandit for dynamic 'remove_threshold'
    """

    def __init__(
        self,
        trainer,
        train_dataset,
        model_path,
        k_cache=5,
        bandit_arms=None,
        epsilon=0.1,
        refill_every=2,
        eval_metric_key="eval_loss",
        epoch_start=1,
        embedding_dim=768,
        chunk_size=256,
        create_hardnegs_fn=create_parallel_hard_negatives,

        # args forwarded to create_hardnegs_fn
        original_dataset_path=None,
        adaptive_dataset_path=None,
        gpus_for_hardneg="6,7",
        reload_after=True,
        num_gpus_for_hardneg=2,
        batch_size=256,
        embedding_size=1536,
        k_hard=1,
        random_n=0,

    ):
        super().__init__()
        self.trainer = trainer
        self.model_path = model_path
        self.k_cache = k_cache
        self.refill_every = refill_every
        self.eval_metric_key = eval_metric_key
        self.epoch_start = epoch_start
        self.embedding_dim = embedding_dim
        self.chunk_size = chunk_size
        self.create_hardnegs_fn = create_hardnegs_fn

        # dataset args
        self.original_dataset_path = original_dataset_path
        self.adaptive_dataset_path = adaptive_dataset_path
        self.num_gpus_for_hardneg = num_gpus_for_hardneg
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.k_hard = k_hard
        self.random_n = random_n
        self.gpus_for_hardneg = gpus_for_hardneg
        self.reload_after = reload_after

        # 1) Bandit
        if bandit_arms is None:
            bandit_arms = [-0.1, -0.05, 0.0, 0.05, 0.1]
        self.bandit = EpsilonGreedyBandit(arms=bandit_arms, epsilon=epsilon)

        # 2) Basic data: build initial caches, etc. (same as the previous callbacks)
        self.doc_list, self.doc_to_idx, self.idx_to_doc, self.query_idx_relations = create_document_list_with_relations(train_dataset)

        self.query_caches = {}
        for qid in self.query_idx_relations:
            negs = list(self.query_idx_relations[qid]['negative'])  # example
            # pick up to k_cache
            chosen = negs[:k_cache]
            self.query_caches[qid] = [(nid, 999.0) for nid in chosen]

        # For doc -> index row mapping:
        #   a simple approach: doc_id == index in doc_list
        #   Or if you have doc_to_idx, store that here.   
        self.epoch_counter = 0
        self.remove_threshold = None
        self.current_arm_index = None

        # 3) Build initial training dataset from caches
        self._build_train_dataset_from_caches()

    def on_epoch_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        self.epoch_counter += 1
        if self.epoch_counter < self.epoch_start:
            return

        # Bandit picks threshold
        self.current_arm_index = self.bandit.select_arm()
        self.remove_threshold = self.bandit.get_arm_value(self.current_arm_index)
        print(f"[BanditFaissCallback] Epoch={self.epoch_counter}, chosen threshold={self.remove_threshold:.3f}")

    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """
        1) Re-embed entire doc_list => build FAISS index
        2) remove_easy_batch
        3) refill_caches_batch (using FAISS to get truly 'hard' negatives)
        4) measure reward => update bandit
        5) rebuild training dataset => next epoch
        """
        if self.epoch_counter < self.epoch_start:
            return

        # (A) measure reward from log_history
        reward = 0.0
        for entry in reversed(state.log_history):
            if self.eval_metric_key in entry:
                reward = entry[self.eval_metric_key]
                break
        # the bandit maximizes reward, so flip the sign for loss-style metrics
        if "loss" in self.eval_metric_key:
            reward = -reward
        print(f"[BanditFaissCallback] Epoch={self.epoch_counter}, reward={reward:.4f}")

        # (B) bandit update
        self.bandit.update(self.current_arm_index, reward)

        # (C) Only run on main process
        if self._is_main_process(state):
            # re-embed doc_list => build FAISS index
            print(f"[BanditFaissCallback] Re-embedding entire doc_list to build FAISS index.")
            self._build_faiss_index()  # produces self.faiss_index plus the faiss_id <-> doc_id maps

        if dist.is_initialized():
            dist.barrier()

        # (D) remove easy negatives in batch (using the newly built index for margin computations)
        self._update_caches_remove_easy_batch()

        # (E) refill from FAISS (rather than random), picking truly "hard" negatives
        if self.epoch_counter % self.refill_every == 0:
            self._refill_caches_batch()

        # (F) rebuild the trainer dataset
        self._build_train_dataset_from_caches()

    # ---------------------------------------------------------------
    #  FAISS index + doc embeddings
    # ---------------------------------------------------------------
    def _build_faiss_index(self):
        """
        Re-encode all docs in doc_list, build a FAISS index in memory, store it as self.faiss_index.
        We'll do something like:

          embs = self.model.encode(doc_list, batch_size=..., convert_to_tensor=True)
          self.faiss_index = build_faiss_index(embs, index_type="IndexFlatIP")

        """
        # 1) Determine the current model checkpoint path
        model_ckpt_path = self.model_path or self.trainer.args.output_dir

        # grab the checkpoint directories and sort by global step
        checkpoints = [d for d in os.listdir(model_ckpt_path) if d.startswith('checkpoint-')]
        checkpoints.sort(key=lambda x: int(x.split('-')[1]))
        # take the most recent checkpoint
        checkpoint_path = os.path.join(model_ckpt_path, checkpoints[-1])

        # 2) Call the user-provided function
        self.create_hardnegs_fn(
            dataset_path=self.original_dataset_path,
            save_path=self.adaptive_dataset_path,
            model_path=checkpoint_path,
            num_gpus=self.num_gpus_for_hardneg,
            batch_size=self.batch_size,
            embedding_size=self.embedding_size,
            k_hard=self.k_hard,
            random_n=self.random_n,
        )
        self.faiss_index = faiss.read_index("embeddings.index")
        self.faiss_to_doc_idx = pickle.load(open('faiss_to_doc_idx.pkl', 'rb'))
        self.doc_id_to_faiss_id = {v: k for k, v in self.faiss_to_doc_idx.items()}

    def _compute_margin_faiss(self, qid, neg_id, pos_id):
        """
        We'll fetch embeddings from self.faiss_index by reconstruct(neg_id),
        reconstruct(pos_id). For the query, we might store it or also do a dynamic approach.

        For now, let's assume query is also in doc_list and qid is the same as doc_id.
        If queries differ from doc_list, we might store them separately or do model.encode once.

        This depends on how your queries are formed, but let's assume qid is a doc in doc_list.
        """
        import numpy as np
        # reconstruct embeddings from faiss (map doc ids to faiss row ids first)
        neg_vec = self.faiss_index.reconstruct(self.doc_id_to_faiss_id[neg_id])  # shape (d,)
        pos_vec = self.faiss_index.reconstruct(self.doc_id_to_faiss_id[pos_id])
        q_vec   = self.faiss_index.reconstruct(self.doc_id_to_faiss_id[qid])     # if queries are also among doc_list

        # cos_sim
        def cos_sim(a, b):
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)

        cos_qn = cos_sim(q_vec, neg_vec)
        cos_qp = cos_sim(q_vec, pos_vec)
        margin = cos_qn - cos_qp
        return margin

    # ---------------------------------------------------------------
    #  Update + Refill
    # ---------------------------------------------------------------
    def _update_caches_remove_easy_batch(self):
        """
        We do batch approach: gather all (qid, pos, neg) from caches => compute margin using FAISS
        => remove if margin < self.remove_threshold.
        """
        self.faiss_index = faiss.read_index("embeddings.index")
        self.faiss_to_doc_idx = pickle.load(open('faiss_to_doc_idx.pkl', 'rb'))
        self.doc_id_to_faiss_id = {v: k for k, v in self.faiss_to_doc_idx.items()}

        batch_q = []
        batch_p = []
        batch_n = []
        index_map = []

        for qid, rels in self.query_idx_relations.items():
            pos_list = list(rels.get('positive', []))
            if not pos_list:
                continue
            pos_id = pos_list[0]  # assume 1 positive
            cache_entries = self.query_caches[qid]
            for idx, (neg_id, old_hardness) in enumerate(cache_entries):
                batch_q.append(qid)
                batch_p.append(pos_id)
                batch_n.append(neg_id)
                index_map.append((qid, idx))

        if not batch_q:
            return

        # compute margins in a loop or in parallel
        # We'll do it in a naive loop of size len(batch_q), though we could do partial reconstruct on GPU, etc.
        margin_list = []
        for (q, p, n) in zip(batch_q, batch_p, batch_n):
            m = self._compute_margin_faiss(q, n, p)
            margin_list.append(m)

        # remove or keep
        for i, (qid, cache_idx) in enumerate(index_map):
            margin_val = margin_list[i]
            if margin_val < self.remove_threshold:
                self.query_caches[qid][cache_idx] = None
            else:
                self.query_caches[qid][cache_idx] = (self.query_caches[qid][cache_idx][0], margin_val)

        # filter out removed
        for qid in self.query_idx_relations:
            self.query_caches[qid] = [x for x in self.query_caches[qid] if x is not None]

    def _refill_caches_batch(self):
        """
        Instead of random, we do a FAISS-based 'retrieve_hard_negatives' for each query that needs more negs.
        """
        self.faiss_index = faiss.read_index("embeddings.index")
        self.faiss_to_doc_idx = pickle.load(open('faiss_to_doc_idx.pkl', 'rb'))
        self.doc_id_to_faiss_id = {v: k for k, v in self.faiss_to_doc_idx.items()}

        needed_info = []
        for qid, rels in self.query_idx_relations.items():
            current_len = len(self.query_caches[qid])
            needed = self.k_cache - current_len
            if needed > 0:
                needed_info.append((qid, needed))

        if not needed_info:
            return

        # we can retrieve top-K from FAISS for each qid. Suppose queries are also in doc_list => same ID
        for (qid, needed) in needed_info:
            pos_list = list(self.query_idx_relations[qid].get('positive', []))
            if not pos_list:
                continue
            pos_id = pos_list[0]
            # retrieve top-K 'closest' docs to qid
            cands = self._retrieve_hard_negatives_faiss(qid, top_k=needed*5)  # retrieve some overshoot
            # filter out qid itself, positives, or duplicates
            filter_set = set(pos_list) | {qid}
            cands = [d for d in cands if d not in filter_set]
            cands = cands[:needed]  # now we have 'needed' neg IDs

            # compute margins + skip if margin < remove_threshold if you like
            for doc_id in cands:
                margin_val = self._compute_margin_faiss(qid, doc_id, pos_id)
                if margin_val < self.remove_threshold:
                    continue
                self.query_caches[qid].append((doc_id, margin_val))

    def _retrieve_hard_negatives_faiss(self, qid, top_k=10):
        """
        Simple example: we do a search in FAISS for the 'closest' docs to qid. 
        Because qid is also a doc in doc_list, we reconstruct qid embedding, do index.search(...).
        """
        import numpy as np
        q_vec = self.faiss_index.reconstruct(self.doc_id_to_faiss_id[qid])
        # For IndexFlatIP, q_vec should be L2-normalized to match the doc vectors.
        # index.search expects shape (n_queries, dim), so reshape q_vec into one row.
        query_batch = np.expand_dims(q_vec, axis=0)
        distances, indices = self.faiss_index.search(query_batch, top_k + 50)  # overshoot
        # indices shape: (1, top_k + 50); filter out qid, positives, and -1 padding
        retrieved_ids = []
        for faiss_idx in indices[0]:
            if faiss_idx == -1:  # FAISS pads with -1 when there aren't enough results
                continue
            doc_idx = self.faiss_to_doc_idx[faiss_idx]
            if doc_idx != qid and doc_idx not in self.query_idx_relations[qid]['positive']:
                retrieved_ids.append(doc_idx)
        return retrieved_ids

    # ---------------------------------------------------------------
    # Build dataset from caches
    # ---------------------------------------------------------------
    def _build_train_dataset_from_caches(self):
        new_samples = []
        for qid, rels in self.query_idx_relations.items():
            pos_list = list(rels.get('positive', []))
            if not pos_list:
                continue
            pos_id = pos_list[0]
            q_text = self.doc_list[qid]
            p_text = self.doc_list[pos_id]

            for (neg_id, hardness) in self.query_caches[qid]:
                neg_text = self.doc_list[neg_id]
                new_samples.append({
                    "query": q_text,
                    "positive": p_text,
                    "negative": neg_text
                })
        from datasets import Dataset
        ds = Dataset.from_list(new_samples)
        self.trainer.train_dataset = ds
        self.trainer._train_dataset = ds
        self.trainer.train_dataloader = self.trainer.get_train_dataloader()

    def _is_main_process(self, state: TrainerState):
        is_main = True
        fn = getattr(self.trainer, "is_world_process_zero", None)
        if callable(fn):
            is_main = self.trainer.is_world_process_zero()
        elif hasattr(state, "is_world_process_zero"):
            is_main = state.is_world_process_zero
        return is_main

    def build_new_training_lr_schedule(self, new_dataset):
        current_lr = self.trainer.optimizer.param_groups[0]['lr']

        # Calculate new total steps based on the new dataset size
        num_training_steps = (
            len(new_dataset)
            * self.trainer.args.num_train_epochs
            // (self.trainer.args.per_device_train_batch_size * self.trainer.args.gradient_accumulation_steps)
        )

        # get_cosine_schedule_with_warmup takes no initial_lr argument; to make
        # the new cosine start from where we left off, reset each param group's
        # base lr to the current lr before building the scheduler.
        for group in self.trainer.optimizer.param_groups:
            group['lr'] = current_lr
            group['initial_lr'] = current_lr

        new_scheduler = get_cosine_schedule_with_warmup(
            self.trainer.optimizer,
            num_warmup_steps=0,  # no warmup since we're continuing
            num_training_steps=num_training_steps,
        )

        self.trainer.lr_scheduler = new_scheduler

Final Results

ok, so a few things to note:

  1. ‘rand’ is HardNegMiningCallback with k_hard=1 and 25 random samples, with ‘start_epoch’ set to 10 (meaning we don’t start mining until the 10th epoch). We train on the bm25 dataset for 10 epochs, then use the last epoch’s model to find the most difficult negative and create a single triplet, along with 25 other triplets where the hard negative is chosen randomly. I wanted to add some less-difficult noise to the system. It produced pretty solid results and was usually the second-best model. I think the ultimate takeaway is probably the 10 epochs of bm25 more than anything.
  2. ‘new3’ is AdaptiveBanditFaissCallback with k_hard=3 and using the combo dataset of bm25 + rerank + no-rerank.
  3. the simple act of taking 3 models (bm25, no-rerank, rerank) and averaging their weights produces a very strong model; I’m unclear on the mechanics of why this works, but it’s pretty cool (see the sketch after this list)
  4. for the task most similar to our training data (legal summarization), the cheapest option, ‘bm25 + rerank + no-rerank’ (e.g., combine 3 datasets and get 3 examples per query/positive pair), produces the best model
  5. I didn’t get to explore the bandit models as much as I wanted to, but maybe I’ll write a bit more once some of these other runs finish training.
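
Since point 3 surprised me, here's the weight-averaging trick sketched out (this is sometimes called a 'model soup'). The paths are placeholders, and it assumes the three checkpoints share an architecture and floating-point weights:

import torch

# element-wise mean of three fine-tuned checkpoints (placeholder paths)
paths = ["out/bm25", "out/no_rerank", "out/rerank"]
state_dicts = [torch.load(f"{p}/pytorch_model.bin", map_location="cpu") for p in paths]

avg = {}
for key in state_dicts[0]:
    avg[key] = torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)

# load the averaged weights into one of the models and evaluate as usual
# model.load_state_dict(avg)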

Admittedly this was pretty fun. We have a constrained dataset/problem that could honestly be solved with more data/compute (The Bitter Lesson, anyone?). But I wanted to pull various levers, measure their effects, and learn what I could about what makes this thing tick. And at some point you either do good science or you don’t. So I think this is a good start.

TODO list (that I admittedly don’t have a plan for):

  • Experiment with different embedding dimensions
  • Compare performance across model sizes (small, base, large)
  • Vary the similarity thresholds in hard negative selection
  • Test different k values for nearest neighbor retrieval
  • Compare performance with and without embedding normalization
  • Analyze performance with different batch sizes
  • Study the impact of learning rates
  • Investigate gradient accumulation steps
  • Test different loss functions (InfoNCE, triplet loss variants)
  • Combine BM25 and reranking with different weighting schemes
  • Test ensemble methods of multiple rerankers
  • Explore sparse-dense retrieval combinations