Чтобы извлечь недиагональные элементы из тензора PyTorch, вы можете использовать различные методы. Вот несколько подходов:
- Использование индексации. Вы можете использовать возможности индексации тензоров PyTorch для извлечения недиагональных элементов. Вот пример:
import torch
# Create a tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Get non-diagonal elements using indexing
non_diag_elements = tensor[tensor != torch.diag(tensor)]
print(non_diag_elements)
- Использование функции
torch.masked_select()
: эта функция позволяет извлекать элементы из тензора на основе заданного условия. Вот пример:
import torch
# Create a tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Create a mask to select non-diagonal elements
mask = torch.ones_like(tensor, dtype=torch.bool)
mask.fill_diagonal_(0)
# Use torch.masked_select() to extract non-diagonal elements
non_diag_elements = torch.masked_select(tensor, mask)
print(non_diag_elements)
- Использование функции
torch.flatten()
: вы можете сгладить тензор, а затем отфильтровать диагональные элементы. Вот пример:
import torch
# Create a tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Flatten the tensor
flattened_tensor = torch.flatten(tensor)
# Filter out diagonal elements
non_diag_elements = flattened_tensor[torch.arange(len(flattened_tensor)) % (tensor.size(1) + 1) != 0]
print(non_diag_elements)