7 мощных методов визуализации сверточных нейронных сетей (CNN) с примерами кода

Сверточные нейронные сети (CNN) произвели революцию в области компьютерного зрения, добившись поразительной точности в таких задачах, как классификация изображений и обнаружение объектов. Однако понимание внутренней работы CNN и интерпретация ее решений может оказаться сложной задачей. Визуализация CNN может дать ценную информацию о том, как эти сети обрабатывают информацию и делают прогнозы. В этой статье мы рассмотрим семь мощных методов визуализации CNN, а также примеры кода в популярных средах глубокого обучения.

  1. Визуализация активации.
    Один из способов понять поведение CNN — это визуализировать активации ее отдельных слоев. Проверяя выходные данные каждого слоя, мы можем получить представление о том, какие функции изучает сеть. Самый распространенный подход – выбрать определенный слой и создать тепловые карты или карты объектов, выделяющие области входного изображения и активирующие соответствующие фильтры.

Пример кода (TensorFlow):

import tensorflow as tf
import matplotlib.pyplot as plt
model = tf.keras.applications.VGG16(weights='imagenet', include_top=True)
layer_name = 'block3_conv1'
model_vis = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer(layer_name).output)
img_path = 'path_to_input_image.jpg'
img = tf.keras.preprocessing.image.load_img(img_path, target_size=(224, 224))
x = tf.keras.preprocessing.image.img_to_array(img)
x = tf.keras.applications.vgg16.preprocess_input(x)
x = tf.expand_dims(x, axis=0)
activations = model_vis.predict(x)
plt.imshow(activations[0, :, :, 0])
plt.show()
  1. Визуализация фильтров.
    Еще один способ получить представление о CNN — это визуализировать изученные фильтры. Фильтры — это небольшие матрицы, которые свертываются с входным изображением для извлечения важных функций. Визуализируя эти фильтры, мы можем понять, какие типы узоров или текстур CNN ищет в изображении.

Пример кода (PyTorch):

import torch
import torchvision.models as models
import matplotlib.pyplot as plt
model = models.vgg16(pretrained=True)
layer_name = 'features.0'
model_vis = torch.nn.Sequential(*list(model.features.children())[:2])
img_path = 'path_to_input_image.jpg'
img = Image.open(img_path)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0)
model_vis.eval()
with torch.no_grad():
    activations = model_vis(input_batch)
plt.imshow(activations[0, 0].cpu().numpy())
plt.show()
  1. Сопоставление активации классов (CAM):
    CAM — это метод, который выделяет области изображения, которые наиболее важны для прогнозирования определенного класса. Он накладывает карту активации класса на входное изображение, предоставляя визуальное объяснение процесса принятия решений CNN.

Пример кода (TensorFlow):

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
model = tf.keras.applications.VGG16(weights='imagenet', include_top=True)
last_conv_layer = model.get_layer('block5_conv3')
classifier_layer = model.get_layer('fc1')
grad_model = tf.keras.models.Model([model.inputs], [last_conv_layer.output, model.output])
img_path = 'path_to_input_image.jpg'
img = tf.keras.preprocessing.image.load_img(img_path, target_size=(224, 224))
x = tf.keras.preprocessing.image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = tf.keras.applications.vgg16.preprocess_input(x)
with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(x)
    class_index = np.argmax(predictions[0])
    loss = predictions[:, class_index]
grads = tape.gradient(loss, conv_outputs)[0]
heatmap = tf.reduce_mean(grads, axis=(0, 1))
heatmap = tf.maximum(heatmap, 0)
heatmap /= tf.reduce_max(heatmap)
heatmap = heatmap.numpy()
plt.imshow(heatmap)
plt.show()
  1. Максимизация активации фильтра:
    Максимизация активации фильтра направлена ​​на создание входного изображения, которое максимально активирует определенный фильтр CNN. Визуализируя входное изображение, вызывающее наибольшую активацию конкретного фильтра, мы можем получить представление о том, к каким шаблонам или концепциям чувствителен фильтр.

Пример кода (PyTorch):

import torch
import torchvision.models as models
import matplotlib.pyplot as plt
model = models.vgg16(pretrained=True)
layer_name ='features.0'
model_vis = torch.nn.Sequential(*list(model.features.children())[:2])
filter_index = 10
learning_rate = 0.1
num_iterations = 100
input_image = torch.randn(1, 3, 224, 224, requires_grad=True)
optimizer = torch.optim.SGD([input_image], lr=learning_rate)
for i in range(num_iterations):
    optimizer.zero_grad()
    activations = model_vis(input_image)
    loss = -activations[0, filter_index].mean()
    loss.backward()
    optimizer.step()
plt.imshow(input_image[0].detach().numpy().transpose(1, 2, 0))
plt.show()
  1. Grad-CAM (градиентно-взвешенное сопоставление активации классов):
    Grad-CAM сочетает в себе идеи карт активации классов (CAM) и визуализации на основе градиента. Он создает тепловую карту, которая выделяет важные области изображения для прогнозирования определенного класса, используя градиенты, поступающие в окончательный сверточный слой.

Пример кода (PyTorch):

import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
model = models.vgg16(pretrained=True)
layer_name = 'features.29'
model_vis = torch.nn.Sequential(*list(model.features.children())[:30])
img_path = 'path_to_input_image.jpg'
img = Image.open(img_path)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0)
model_vis.eval()
model.eval()
with torch.no_grad():
    features = model_vis(input_batch)
    predictions = model(input_batch)
_, predicted_idx = torch.max(predictions, 1)
predicted_idx = predicted_idx.item()
grads = torch.autograd.grad(predictions[0, predicted_idx], features)[0][0]
alpha = torch.mean(grads, dim=(1, 2), keepdim=True)
heatmap = torch.sum(alpha * features, dim=0)
heatmap = torch.relu(heatmap).numpy()
plt.imshow(heatmap, cmap='jet')
plt.show()
  1. Чувствительность к окклюзии.
    Чувствительность к окклюзии — это метод, который включает в себя систематическое закрытие различных частей входного изображения и наблюдение за влиянием на прогноз CNN. Визуализируя изменения в прогнозируемых вероятностях классов, мы можем определить важные регионы, которые влияют на решение сети.

Пример кода (TensorFlow):

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
model = tf.keras.applications.VGG16(weights='imagenet', include_top=True)
img_path = 'path_to_input_image.jpg'
img = tf.keras.preprocessing.image.load_img(img_path, target_size=(224, 224))
x = tf.keras.preprocessing.image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = tf.keras.applications.vgg16.preprocess_input(x)
heatmap = np.zeros((x.shape[1], x.shape[2]))
for i in range(x.shape[1]):
    for j in range(x.shape[2]):
        img_copy = np.copy(x)
        img_copy[:, i, j, :] = 0
        predictions = model.predict(img_copy)
        class_index = np.argmax(predictions[0])
        heatmap[i, j] = predictions[0, class_index]
heatmap /= np.max(heatmap)
plt.imshow(heatmap)
plt.show()
  1. Послойное распространение релевантности (LRP):
    LRP — это метод, который присваивает оценки релевантности отдельным пикселям входного изображения на основе их вклада в окончательный прогноз CNN. Визуализируя эти показатели релевантности, мы можем определить регионы, которые сеть считает важными для принятия решений.

Пример кода (PyTorch):


import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

model = models.vgg16(pretrained=True)
layer_name = 'features.29'
model_vis = torch.nn.Sequential(*list(model.features.children())[:30])

img_path = 'path_to_input_image.jpg'
img = Image.open(img_path)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0)

model_vis.eval()
model.eval()

with torch.no_grad():
    features = model_vis(input_batch)
    predictions = model(input_batch)

_, predicted_idx = torch.max(predictions, 1)
predicted_idx = predicted_idx.item()

relevance_scores =