diff --git a/Toxic_Comment_Classification/LSTM-BS/generator.py b/Toxic_Comment_Classification/LSTM-BS/generator.py index 3384ce8..8c47e0a 100644 --- a/Toxic_Comment_Classification/LSTM-BS/generator.py +++ b/Toxic_Comment_Classification/LSTM-BS/generator.py @@ -29,6 +29,8 @@ def gen_poison_samples(train_inputs, train_labels, validation_inputs, validation save_p_data(c_trainset, poisam_path_train, beam_size, qsize, flip_label=flip_label) pos_index_test = np.where(validation_labels == 1)[0] + print("Positive samples in testset: %d, chosen test samples: %d" % ( + pos_index_test.shape[0], test_samples)) c_testset = validation_inputs[np.random.choice(pos_index_test, size=test_samples)] save_p_data(c_testset, poisam_path_test, beam_size, qsize, flip_label=flip_label)