forked from magpie-align/magpie
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_dis.py
144 lines (116 loc) · 6.09 KB
/
gen_dis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import faiss
import argparse
import json
from tqdm import tqdm
from utils import load_dataset_from_file
################
# Configurations
################
def get_args(argv=None):
    """Parse command-line arguments for the similarity-calculation script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, which
            makes argparse fall back to ``sys.argv[1:]`` — byte-for-byte
            compatible with the original zero-argument call site.

    Returns:
        argparse.Namespace with the experiment and system settings.
    """
    parser = argparse.ArgumentParser(description="Similarity Calculation Manager.")
    # Experiment Settings
    parser.add_argument("--sentence_model", type=str, default="sentence-transformers/all-mpnet-base-v2")
    parser.add_argument("--input_file", type=str, default=None, help="Input dataset file name")
    parser.add_argument("--encoding_batch_size", type=int, default=65536, help="Batch size for encoding the sentences.")
    # NOTE(review): the doubled "distance_distance" spelling looks like a typo,
    # but args.distance_distance_threshold is read later in the script, so the
    # dest is kept unchanged and a correctly-spelled alias is added for new
    # callers. Both flags populate the same attribute.
    parser.add_argument("--distance_distance_threshold", "--distance_threshold",
                        dest="distance_distance_threshold", type=float, default=0.05,
                        help="distance_threshold for the similarity search.")
    parser.add_argument("--search_space_size", type=int, default=500, help="Number of examples to search for similarity.")
    parser.add_argument("--search_batch_size", type=int, default=1024, help="Batch size for searching for similarity.")
    # System Settings
    parser.add_argument("--device", type=int, default=0)
    # Bug fix: the original used type=bool, but bool("False") is True, so any
    # explicit value on the command line enabled saving. Parse the string
    # ourselves; the attribute name and the default (True) are unchanged.
    parser.add_argument("--save_faiss_index",
                        type=lambda s: str(s).lower() not in ("false", "0", "no", ""),
                        default=True, help="Save the Faiss index.")
    return parser.parse_args(argv)
# Resolve run configuration from the CLI. These module-level names are read
# throughout the script below.
args = get_args()
sentence_model = args.sentence_model
dataset_path = args.input_file
# Extract the file stem by hand: everything after the last '/' and before the
# last '.'. NOTE(review): rfind('.') scans the WHOLE path, not just the
# basename, and a path with no dot would drop its final character — assumes
# input_file is always a plain "<name>.<ext>"-style path; verify at call site.
dataset_name = dataset_path[dataset_path.rfind('/')+1:dataset_path.rfind('.')]
# Per-example neighbor distances are appended here as JSONL (see Step 2).
output_file = f"../data/{dataset_name}_distance.jsonl"
################
# Step 1 - Load the Dataset and Build the Faiss Index
################
# Load the JSONL dataset; load_dataset("json", ...) yields a DatasetDict with
# a single "train" split.
dataset = load_dataset("json", data_files=dataset_path)
print(dataset)
# The column being deduplicated/compared is "instruction".
inputs = dataset["train"]["instruction"]
print(f"The second instruction in the dataset is: {inputs[1]}")
model = SentenceTransformer(sentence_model)
# Pin the encoder to the requested GPU in float32.
model.to(device=f'cuda:{args.device}', dtype=torch.float32)
print(f"The model is loaded on device: {model.device}")
# Encode the sentences in the dataset into vectors
encoding_batch_size = args.encoding_batch_size # Adjust the batch size based on available memory
embeddings = []
for i in range(0, len(inputs), encoding_batch_size):
    batch_sentences = inputs[i:i+encoding_batch_size]
    batch_embeddings = model.encode(batch_sentences, convert_to_tensor=True, show_progress_bar=True)
    # Move each batch to CPU/NumPy immediately so GPU memory is released
    # between batches.
    embeddings.append(batch_embeddings.cpu().numpy())
# Concatenate the embeddings
embeddings = np.concatenate(embeddings, axis=0)
# Print out the shape of the concatenated embeddings to verify the results
print(f"The shape of the concatenated embeddings is: {embeddings.shape}")
# Add the encoded vectors to the dataset
print("Adding the embeddings to the dataset...")
dataset["train"] = dataset["train"].add_column("embeddings", embeddings.tolist())
# Build the Faiss index on the dataset
print("Building the Faiss index...")
dataset["train"].add_faiss_index(column="embeddings")
# Save the Faiss index
if args.save_faiss_index:
    print("Saving the Faiss index...")
    # Pull the raw faiss index object out of the datasets wrapper so it can be
    # persisted with faiss's own serializer.
    index = dataset["train"].get_index("embeddings")
    faiss_index = index.faiss_index
    index_file = f"../data/{dataset_name}.faiss"
    faiss.write_index(faiss_index, index_file)
################
# Step 2 - Find Similar Examples
################
distance_threshold = args.distance_distance_threshold
search_space_size = args.search_space_size
search_batch_size = args.search_batch_size
# Ceiling division: number of search batches over the whole split.
n_batches = (len(dataset["train"]) + search_batch_size - 1) // search_batch_size
print(f"Number of batches: {n_batches}")
# load the dataset in jsonl format
# (a second, plain-dict copy of the same file; this is what actually gets
# annotated and written out below)
unfilled_dataset = load_dataset_from_file(dataset_path)
# Append mode: re-running the script adds to any existing output file rather
# than replacing it.
with open(output_file, 'a') as file:
    for batch_idx in tqdm(range(n_batches)):
        start_idx = batch_idx * search_batch_size
        end_idx = min((batch_idx + 1) * search_batch_size, len(dataset["train"]))
        batch_indices = list(range(start_idx, end_idx))
        # Obtain the embeddings for the current batch
        batch_embeddings = embeddings[batch_indices]
        # Search for the most similar examples
        # (k nearest neighbors per query, including the query itself)
        search_results = dataset["train"].search_batch(queries=batch_embeddings, k=search_space_size, index_name="embeddings")
        total_scores = search_results.total_scores
        total_indices = search_results.total_indices
        for i in range(len(total_indices)):
            scores = total_scores[i]
            indices = total_indices[i]
            # NOTE(review): scores[1] assumes the query itself is always the
            # first hit (scores[0]); exact duplicates could break that
            # ordering — confirm against the index's tie-breaking.
            min_distance = float(scores[1]) # should exclude itself
            # NOTE(review): indexing an HF Dataset returns a fresh dict, so
            # this assignment (and the ones below on dataset["train"][...])
            # most likely does NOT persist into the Dataset — the JSONL write
            # at the bottom of the loop is the real output. Verify.
            dataset["train"][start_idx + i]["min_distance"] = min_distance
            # Keep only neighbors closer than the threshold.
            filtered_indices = [index for index, score in zip(indices, scores) if score < distance_threshold]
            # Should remove itself
            filtered_indices = [index for index in filtered_indices if index != start_idx + i]
            if len(filtered_indices) == 0:
                # No near-duplicates for this example.
                repeat_count = 0
                min_similar_conversation_id = None
                dataset["train"][start_idx + i]["repeat_count"] = repeat_count
                dataset["train"][start_idx + i]["min_similar_conversation_id"] = min_similar_conversation_id
            else:
                # Canonical representative = the lowest dataset index in the
                # duplicate cluster; if this example IS the lowest, point the
                # id at itself.
                min_similar_conversation_idx = int(min(filtered_indices))
                if min_similar_conversation_idx >= start_idx + i:
                    min_similar_conversation_id = dataset["train"][start_idx + i]["conversation_id"]
                else:
                    min_similar_conversation_id = dataset["train"][min_similar_conversation_idx]["conversation_id"]
                repeat_count = len(filtered_indices)
                dataset["train"][start_idx + i]["repeat_count"] = repeat_count
                dataset["train"][start_idx + i]["min_similar_conversation_id"] = min_similar_conversation_id
            # save the updated dataset
            # (one JSONL line per example, annotated with the three new fields)
            line = unfilled_dataset[start_idx + i]
            line["min_neighbor_distance"] = min_distance
            line["repeat_count"] = repeat_count
            line["min_similar_conversation_id"] = min_similar_conversation_id
            file.write(json.dumps(line) + '\n')
        print(f"Batch {batch_idx} is saved to the output file")
print("Distance calculation is completed.")