import torch from transformers import BertTokenizer, BertForSequenceClassification, BertModel import torch.nn.functional as F import numpy as np from itertools import chain, combinations, product from nltk.corpus import wordnet as wn, stopwords import nltk from sklearn.metrics.pairwise import cosine_similarity import tensorflow as tf import tensorflow_hub as hub import kagglehub import os import time import warnings import string nltk.download('wordnet') nltk.download('stopwords') # Suppress TensorFlow warnings and errors os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) # Load the fine-tuned BERT model and tokenizer model_path = 'fine_tuned_bert' model = BertForSequenceClassification.from_pretrained(model_path) tokenizer = BertTokenizer.from_pretrained(model_path) # Load a pre-trained BERT model for similarity checking bert_model = BertModel.from_pretrained('bert-base-uncased') # Download the latest version of Universal Sentence Encoder and set TFHUB_CACHE_DIR to a persistent directory path = kagglehub.model_download("google/universal-sentence-encoder/tensorFlow2/universal-sentence-encoder") print("Path to model files:", path) os.environ["TFHUB_CACHE_DIR"] = path # Load the Universal Sentence Encoder from the local path try: use_model = hub.load(path) print("Universal Sentence Encoder loaded successfully.") except Exception as e: print(f"Error loading Universal Sentence Encoder model: {e}") use_model = None stop_words = set(stopwords.words('english')) grammar_words = set(stop_words).union(set(string.punctuation)) def get_word_saliency(model, tokenizer, text, true_label): model.eval() inputs = tokenizer(text, return_tensors='pt', truncation=True) # Get input_ids and attention_mask input_ids = inputs['input_ids'] attention_mask = inputs['attention_mask'] # Get the embeddings embeddings = model.bert.embeddings.word_embeddings(input_ids) embeddings.retain_grad() # Forward pass outputs = model(inputs_embeds=embeddings, attention_mask=attention_mask) loss = F.cross_entropy(outputs.logits, torch.tensor([true_label])) # Backward pass model.zero_grad() loss.backward() grads = embeddings.grad if grads is not None: grads = grads.abs() else: raise RuntimeError("Gradients are None. Check the gradient computation.") # Compute word saliency scores saliency_scores = grads.sum(dim=-1).squeeze().cpu().numpy() words = tokenizer.convert_ids_to_tokens(input_ids.squeeze().cpu().numpy()) return words, saliency_scores def get_synonyms(word): synonyms = wn.synsets(word) lemmas = set(chain.from_iterable([syn.lemma_names() for syn in synonyms])) lemmas.discard(word) return list(lemmas) def get_embedding_similarity(word1, word2): inputs1 = tokenizer(word1, return_tensors='pt') inputs2 = tokenizer(word2, return_tensors='pt') with torch.no_grad(): emb1 = bert_model(**inputs1).last_hidden_state.mean(dim=1).cpu().numpy() emb2 = bert_model(**inputs2).last_hidden_state.mean(dim=1).cpu().numpy() return cosine_similarity(emb1, emb2)[0][0] def generate_synonym_dict(words, indices, top_n_synonyms, similarity_threshold, selected_word, selected_word_threshold): synonym_dict = {} for idx in indices: original_word = words[idx] synonyms = get_synonyms(original_word) if synonyms: threshold = selected_word_threshold if original_word.lower() == selected_word.lower() else similarity_threshold similar_synonyms = [syn for syn in synonyms if get_embedding_similarity(original_word, syn) >= threshold] if similar_synonyms: synonym_dict[idx] = similar_synonyms[:top_n_synonyms] else: synonym_dict[idx] = [original_word] else: synonym_dict[idx] = [original_word] return synonym_dict def test_substitutions(model, tokenizer, words, substitution_dict, true_label): perturbed_words = words.copy() for idx, syn in substitution_dict.items(): perturbed_words[idx] = syn perturbed_text = tokenizer.convert_tokens_to_string(perturbed_words) inputs = tokenizer(perturbed_text, return_tensors='pt', truncation=True) outputs = model(**inputs) logits = outputs.logits new_label = torch.argmax(logits, dim=1).item() confidence = torch.softmax(logits, dim=1)[0, new_label].item() if new_label != true_label: return perturbed_text, new_label, confidence return None, true_label, None def agws_attack(model, tokenizer, text, true_label, selected_word, num_words_to_perturb, top_n_synonyms=20, similarity_threshold=0.7, selected_word_threshold=0.5, time_limit=600): start_time = time.time() words, saliency_scores = get_word_saliency(model, tokenizer, text, true_label) max_index = len(words) - 1 indices = np.argsort(-saliency_scores)[:num_words_to_perturb] # Ensure indices are within range of the tokenized words indices = [i for i in indices if i <= max_index] # Ensure selected word is included in indices selected_word_indices = [i for i, word in enumerate(words) if word.lower() == selected_word.lower()] indices = np.unique(np.concatenate((indices, selected_word_indices))).astype(int) # Filter out the [CLS] and [SEP] tokens, stopwords, and punctuation indices = [i for i in indices if words[i] not in ['[CLS]', '[SEP]'] and words[i].lower() not in grammar_words] # Always include the selected word's index, even if not salient if not selected_word_indices: selected_word_indices = [i for i, word in enumerate(words) if word.lower() == selected_word.lower()] if selected_word_indices: indices = np.unique(np.concatenate((indices, selected_word_indices))).astype(int) synonym_dict = generate_synonym_dict(words, indices, top_n_synonyms, similarity_threshold, selected_word, selected_word_threshold) successful_perturbations = [] try: # Test all possible combinations to find successful perturbations for r in range(1, num_words_to_perturb + 1): for combination in combinations(indices, r): # Ensure selected word index is in the combination if not any(idx in combination for idx in selected_word_indices): continue synonym_combinations = list(product(*[synonym_dict[idx] for idx in combination])) for synonym_comb in synonym_combinations: if time.time() - start_time > time_limit: return successful_perturbations, True substitution_dict = dict(zip(combination, synonym_comb)) perturbed_text, new_label, confidence = test_substitutions(model, tokenizer, words, substitution_dict, true_label) if perturbed_text: successful_perturbations.append((substitution_dict, perturbed_text, new_label, confidence)) if len(successful_perturbations) >= 100: break if len(successful_perturbations) >= 100: break if len(successful_perturbations) >= 100: break except Exception as e: print(f"Error during AGWS attack: {e}") if not successful_perturbations: return [(text, true_label, {}, 1.0)], False # No successful perturbations found # Ensure unique sets of substitutions unique_perturbations = [] for substitution_dict, perturbed_text, new_label, confidence in successful_perturbations: if all(substitution_dict != existing[2] for existing in unique_perturbations): detailed_substitution_dict = {idx: (words[idx], syn) for idx, syn in substitution_dict.items()} unique_perturbations.append((perturbed_text, new_label, detailed_substitution_dict, confidence)) if len(unique_perturbations) >= 100: break return unique_perturbations, False def get_actual_label(model, tokenizer, text): inputs = tokenizer(text, return_tensors='pt', truncation=True) outputs = model(**inputs) label = torch.argmax(outputs.logits, dim=1).item() return label def get_use_similarity(text1, text2): try: embeddings = use_model([text1, text2]) similarity = cosine_similarity(embeddings.numpy())[0][1] except Exception as e: print(f"Error calculating USE similarity: {e}") similarity = None return similarity # Main execution selected_words = ["well", "get", "make"] # Array of selected words # Perform AGWS attack for each selected word and write results to output file for selected_word in selected_words: input_file = f"reviews_{selected_word}.txt" output_file = f"perturbations_{selected_word}.txt" # Read sentences from input file with open(input_file, "r") as reviews_file: lines = reviews_file.readlines() # Extract sentences sentences = [] sentence = "" for line in lines: if line.startswith("Sentence "): if sentence: sentences.append(sentence.strip()) sentence = line.split(":", 1)[1].strip() else: sentence += " " + line.strip() if sentence: sentences.append(sentence.strip()) with open(output_file, "w") as file: for sentence in sentences: try: actual_label = get_actual_label(model, tokenizer, sentence) unique_perturbations, time_limit_reached = agws_attack(model, tokenizer, sentence, true_label=actual_label, selected_word=selected_word, num_words_to_perturb=6, selected_word_threshold=0.5, time_limit=600) for i, (perturbed_text, new_label, perturbation, confidence) in enumerate(unique_perturbations): use_similarity = get_use_similarity(sentence, perturbed_text) file.write(f"Perturbation {i+1}:\n") file.write(f"Original sentence: {sentence}\n") file.write(f"Perturbed text: {perturbed_text}\n") file.write(f"Original label: {actual_label}\n") file.write(f"New label: {new_label}\n") file.write(f"Substitutions:\n") for idx, (original_word, substituted_word) in perturbation.items(): file.write(f" {original_word} --> {substituted_word}\n") file.write(f"Confidence: {confidence:.4f}\n") if use_similarity is not None: file.write(f"USE Similarity: {use_similarity:.4f}\n\n") else: file.write("USE Similarity: Error calculating similarity\n\n") file.flush() # Ensure writing to file after each unique example is found print(f"Perturbation {i+1} written to file for sentence: {sentence}") # If the time limit was reached, write the last found perturbation if time_limit_reached: continue # Continue to the next sentence if time limit is reached except Exception as e: print(f"Error processing sentence: {sentence}. Error: {e}")