BIM: Advanced FGSM Attack

Previously we talked about the Fast Gradient Sign Method (FGSM). We saw how this white-box technique cleverly exploits a model's gradients to perturb the input so that the model gives the wrong prediction.

FGSM perturbs the input just once. A modified version of the attack does so repeatedly, for a given number of iterations.

In FGSM, we feed the input into the model, compute the loss, and take the gradient of that loss with respect to the input. We then update the input in the direction of the gradient's sign so as to maximize the loss.
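In symbols, that single step can be written as below. The notation is an assumption made here for illustration: x is the input, y its true label, θ the model parameters, J the loss and ε the perturbation magnitude.

```latex
x_{\text{adv}} = x + \epsilon \cdot \operatorname{sign}\big(\nabla_x J(\theta, x, y)\big)
```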

BIM, the Basic Iterative Method, repeats this process for a given number of iterations, taking a small step each time. This can be done to perturb an image so that it is simply misclassified, or to push it toward a specific target class. In the targeted case we step against the gradient instead, so as to minimize the loss for the target class.
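To see the iterative update in isolation, here is a small dependency-free sketch. It is not the ResNet attack from this post: it assumes a hypothetical four-value "image" and a hand-made logistic model, and every name in it (`predict`, `bim_targeted`, `weights`, `bias`) is made up for illustration. It also projects the result back into an epsilon-ball around the original input after every step, which is how BIM is usually kept within a perturbation budget.

```python
import math


def sign(v):
    """Return -1, 0 or 1 depending on the sign of v."""
    return (v > 0) - (v < 0)


def predict(x, weights, bias):
    """Probability of class 1 under a toy logistic model (hypothetical)."""
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    return 1.0 / (1.0 + math.exp(-z))


def loss_gradient(x, label, weights, bias):
    """Gradient of the binary cross-entropy loss with respect to the input."""
    p = predict(x, weights, bias)
    return [(p - label) * w for w in weights]


def bim_targeted(x0, target, weights, bias, epsilon, alpha, num_iterations):
    """Targeted BIM: repeated signed steps, each followed by two clips."""
    x = list(x0)
    for _ in range(num_iterations):
        grad = loss_gradient(x, target, weights, bias)
        # Step against the gradient to lower the loss for the target label
        x = [xi - alpha * sign(g) for xi, g in zip(x, grad)]
        # Project back into the epsilon-ball around the original input
        x = [min(max(xi, oi - epsilon), oi + epsilon) for xi, oi in zip(x, x0)]
        # Keep every value a valid pixel intensity in [0, 1]
        x = [min(max(xi, 0.0), 1.0) for xi in x]
    return x


# Hypothetical toy model and input, chosen only for illustration
weights, bias = [2.0, -1.0, 0.5, -3.0], 1.5
x0 = [0.2, 0.8, 0.5, 0.6]

x_adv = bim_targeted(x0, target=1, weights=weights, bias=bias,
                     epsilon=0.1, alpha=0.02, num_iterations=10)

print("P(class 1) before:", predict(x0, weights, bias))
print("P(class 1) after: ", predict(x_adv, weights, bias))
```

No value moves by more than `epsilon`, yet the toy model's decision flips from class 0 to class 1. For an untargeted attack, the same loop would use the true label and add the signed step instead of subtracting it.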

The following code sample demonstrates how we take a cat image and perturb it enough to fool a model into believing it is a camel, while it still looks like a cat to the naked eye.

We will be using our cat image as usual.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Load a pre-trained model
model = models.resnet50(pretrained=True)
model.eval()
loss = nn.CrossEntropyLoss()

# Define the attack parameters
epsilon = 0.002  # Intended perturbation magnitude (note: the loop below never applies it)
alpha = 0.03  # Step size per iteration, in normalized input units
num_iterations = 20

# Load and preprocess the image
image_path = '/content/cat.png'
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
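# convert('RGB') drops any alpha channel the PNG may carry, so Normalize sees 3 channels
image = preprocess(Image.open(image_path).convert('RGB')).unsqueeze(0)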
image.requires_grad = True  # Set requires_grad to True

# Forward pass to get the predicted class probabilities
output = model(image)
probabilities = nn.Softmax(dim=1)(output)

# Get the initial predicted class
initial_prediction = torch.argmax(probabilities, dim=1)

# Perform the Basic Iterative Method (BIM) attack
perturbed_image = image.clone().detach().requires_grad_(True)
target_class = torch.tensor([354])  # ImageNet index 354: Arabian camel

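# Per-channel bounds that correspond to pixel values 0 and 1 after Normalize
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
lower, upper = (0 - mean) / std, (1 - mean) / std

for i in range(num_iterations):
    # Forward pass to get the predicted class probabilities
    output = model(perturbed_image)
    probabilities = nn.Softmax(dim=1)(output)

    # Get the predicted class of the perturbed image
    perturbed_prediction = torch.argmax(probabilities, dim=1)
    if perturbed_prediction.item() == target_class.item():
        # Attack successful, terminate the loop
        break

    # Calculate the loss with respect to the target class
    loss_value = loss(output, target_class)

    # Calculate the gradient of the loss w.r.t. the perturbed image
    gradient = torch.autograd.grad(loss_value, perturbed_image)[0]

    # Generate the perturbation using the sign of the gradient
    perturbation = alpha * torch.sign(gradient)

    # Subtract the perturbation: stepping against the gradient lowers the
    # loss for the target class (this is a targeted attack)
    perturbed_image = perturbed_image.detach() - perturbation

    # Clip so that pixel values stay in [0, 1] once the normalization is undone
    perturbed_image = torch.max(torch.min(perturbed_image, upper), lower)

    # Start the next iteration from a fresh leaf tensor
    perturbed_image = perturbed_image.requires_grad_(True)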

# Forward pass with the final perturbed image
perturbed_output = model(perturbed_image)
perturbed_probabilities = nn.Softmax(dim=1)(perturbed_output)

# Get the predicted class of the perturbed image
perturbed_prediction = torch.argmax(perturbed_probabilities, dim=1)
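# Load the class labels, assuming each line of the file looks like "<index>, <name>"
def preprocess_imagenet_classes(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()

    class_names = []
    for line in lines:
        # Split on the first ', ' only, so names that contain commas are kept
        # and the list index stays aligned with the class index
        parts = line.strip().split(', ', 1)
        if len(parts) == 2:
            class_names.append(parts[1])

    return class_names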

file_path = '/content/imagenet_classes.txt'
class_names = preprocess_imagenet_classes(file_path)

# Print the results
print("Initial Prediction:", initial_prediction.item(), class_names[initial_prediction.item()])
print("Perturbed Prediction:", perturbed_prediction.item(), class_names[perturbed_prediction.item()])

And this is our Arabian camel, according to ResNet50.

Interesting, isn’t it?
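One detail is worth keeping in mind when reading the attack code: the image is normalized before the attack, so in tensor space the range that corresponds to valid pixel values is no longer [0, 1]. The sketch below works out the per-channel bounds from the same mean and std values used in the preprocessing step. The helper names and the 8/255 budget are hypothetical, picked only to illustrate the conversion.

```python
# Mean and std passed to transforms.Normalize in the preprocessing step
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def normalized_bounds(mean, std):
    """Per-channel values that pixel intensities 0 and 1 map to after Normalize."""
    lower = [(0.0 - m) / s for m, s in zip(mean, std)]
    upper = [(1.0 - m) / s for m, s in zip(mean, std)]
    return lower, upper


def pixel_epsilon_to_normalized(epsilon_pixel, std):
    """Convert a budget given in [0, 1] pixel units to normalized units."""
    return [epsilon_pixel / s for s in std]


lower, upper = normalized_bounds(mean, std)
for name, lo, hi in zip("RGB", lower, upper):
    print(f"{name}: valid range [{lo:.3f}, {hi:.3f}]")

# Hypothetical budget of 8/255 per pixel, expressed per channel
print(pixel_epsilon_to_normalized(8 / 255, std))
```

Clamping a normalized tensor to these bounds, rather than to [0, 1], keeps the perturbed image valid once the normalization is undone. The same scaling applies to `epsilon` and `alpha` whenever they are meant in pixel units.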

Let’s run a quick experiment: how about we feed this perturbed image to another model, say ResNet152? Let’s try that out!

modelresnet152 = models.resnet152(pretrained=True)
modelresnet152.eval()

# Feed the perturbed image from the attack above to the second model
output = modelresnet152(perturbed_image)
probabilities = nn.Softmax(dim=1)(output)

# Get the class ResNet152 predicts for the perturbed image
resnet152_prediction = torch.argmax(probabilities, dim=1)
print("ResNet152 Prediction:", resnet152_prediction.item(), class_names[resnet152_prediction.item()])

ResNet152 wasn’t affected. Stupid experiment, I know, but worth trying. Apart from the two models having different weights, their architectures differ too: ResNet152 is far deeper than ResNet50.

So far, we have covered the basic white-box attacks. Next time we are going to look into some black-box adversarial attacks. Till then, toodles.