Implementação do RandAugment no PyTorch: Guia Prático

Entenda o RandAugment em Pytorch, uma técnica de aumento de dados que melhora o desempenho de modelos de visão computacional. Este método aplica aleatoriamente uma série de transformações a imagens, enriquecendo o conjunto de dados de treinamento e aumentando a robustez do modelo. Descubra como implementar e personalizar o RandAugment para otimizar seus projetos de deep learning, explorando seus parâmetros e funcionalidades para obter os melhores resultados.

O que é RandAugment?

RandAugment em Pytorch é uma técnica de aumento de dados que aplica aleatoriamente transformações a imagens, ajudando a melhorar a capacidade de generalização dos modelos de deep learning. Diferente de outras técnicas de aumento de dados que exigem a escolha manual das transformações e suas magnitudes, o RandAugment automatiza esse processo, selecionando e aplicando as transformações de forma aleatória dentro de um conjunto predefinido.

Essa abordagem reduz a necessidade de ajuste fino manual e permite que o modelo aprenda a partir de uma variedade maior de exemplos, tornando-o mais robusto e menos propenso a overfitting. O RandAugment é particularmente útil quando se trabalha com conjuntos de dados limitados, onde o aumento de dados pode ter um impacto significativo no desempenho do modelo.

A utilização do RandAugment pode ser feita através da biblioteca torchvision do Pytorch, que oferece uma implementação fácil de usar e altamente personalizável. Ao ajustar os parâmetros do RandAugment, como o número de operações e a magnitude das transformações, é possível otimizar o processo de aumento de dados para conjuntos de dados e arquiteturas de modelos específicos.

Para quem está começando, entender o funcionamento do RandAugment e como ele se compara a outras técnicas de aumento de dados é fundamental. A escolha da técnica de aumento de dados correta pode fazer uma grande diferença no desempenho final do modelo, especialmente em tarefas de visão computacional onde a variabilidade dos dados é alta.

Implementando RandAugment em Pytorch

Para implementar o RandAugment em Pytorch, você precisará das bibliotecas torchvision e torch. Primeiro, importe as classes necessárias do torchvision.transforms.v2:

from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import RandAugment
from torchvision.transforms.functional import InterpolationMode

A função RandAugment() permite aumentar aleatoriamente uma imagem. Ela possui alguns argumentos que podem ser ajustados:

  • num_ops: Define o número de operações de aumento a serem aplicadas (opcional, padrão: 2). Deve ser maior ou igual a 0.
  • magnitude: Controla a intensidade das transformações (opcional, padrão: 9). Deve ser maior ou igual a 0 e menor que num_magnitude_bins.
  • num_magnitude_bins: Define o número de níveis de magnitude (opcional, padrão: 31). Deve ser maior ou igual a 1.
  • interpolation: Especifica o modo de interpolação (opcional, padrão: InterpolationMode.NEAREST). Para tensores, apenas InterpolationMode.NEAREST e InterpolationMode.BILINEAR são suportados.
  • fill: Permite alterar o fundo da imagem (opcional, padrão: None). Pode ser um valor único ou uma lista/tupla com 1 ou 3 elementos. Valores menores ou iguais a 0 resultam em preto, enquanto valores maiores ou iguais a 255 resultam em branco.

Para inicializar o RandAugment em Pytorch, você pode usar:

ra = RandAugment()

Ou, para personalizar os parâmetros:

ra = RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=31,
                 interpolation=InterpolationMode.NEAREST, fill=None)

Para verificar os valores dos parâmetros:

ra.num_ops
# 2

ra.magnitude
# 9

ra.num_magnitude_bins
# 31

ra.interpolation
# <InterpolationMode.NEAREST: 'nearest'>

print(ra.fill)
# None

O código demonstra como aplicar diferentes configurações de RandAugment em Pytorch a um conjunto de dados de imagens de animais (OxfordIIITPet). Ele carrega o conjunto de dados original e, em seguida, aplica transformações RandAugment com diferentes valores de preenchimento (fill) para alterar o fundo das imagens.

Primeiro, o conjunto de dados original é carregado sem nenhuma transformação:

origin_data = OxfordIIITPet(
    root="data",
    transform=None
)

Em seguida, o conjunto de dados é carregado com a transformação RandAugment padrão (sem argumentos):

noargs_data = OxfordIIITPet( # `noargs` is no arguments.
    root="data",
    transform=RandAugment()
)

Depois, o conjunto de dados é carregado com RandAugment e um valor de preenchimento de 150 (cinza):

fgray_data = OxfordIIITPet( # `f` is fill.
    root="data",
    transform=RandAugment(fill=150)
    # transform=RandAugment(fill=[150])
)

Finalmente, o conjunto de dados é carregado com RandAugment e um valor de preenchimento com cores roxas ([160, 32, 240]):

fpurple_data = OxfordIIITPet(
    root="data",
    transform=RandAugment(fill=[160, 32, 240])
)

Para exibir as imagens, o código utiliza a biblioteca matplotlib:

import matplotlib.pyplot as plt

def show_images1(data, main_title=None):
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    for i, (im, _) in zip(range(1, 6), data):
        plt.subplot(1, 5, i)
        plt.imshow(X=im)
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

show_images1(data=origin_data, main_title="origin_data")
print()
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
show_images1(data=noargs_data, main_title="noargs_data")
print()
show_images1(data=fgray_data, main_title="fgray_data")
show_images1(data=fpurple_data, main_title="fpurple_data")

O código acima exibe as imagens originais e as imagens transformadas com diferentes valores de preenchimento. A função show_images1 plota as primeiras cinco imagens de cada conjunto de dados em uma única figura, facilitando a comparação visual dos efeitos das transformações.

Além disso, o código também apresenta uma função show_images2 que aplica a transformação RandAugment diretamente durante a exibição das imagens. Isso permite visualizar o efeito de diferentes parâmetros de RandAugment em tempo real, sem a necessidade de criar conjuntos de dados transformados separados.

A função show_images2 é definida da seguinte forma:

def show_images2(data, main_title=None, no=2, m=9, nmb=31,
                 ip=InterpolationMode.NEAREST, f=None):
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    if main_title != "origin_data":
        for i, (im, _) in zip(range(1, 6), data):
            plt.subplot(1, 5, i)
            ra = RandAugment(num_ops=no, magnitude=m,
                             num_magnitude_bins=nmb,
                             interpolation=ip, fill=f)
            plt.imshow(X=ra(im))
            plt.xticks(ticks=[])
            plt.yticks(ticks=[])
    else:
        for i, (im, _) in zip(range(1, 6), data):
            plt.subplot(1, 5, i)
            plt.imshow(X=im)
            plt.xticks(ticks=[])
            plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

E é utilizada da seguinte forma:

show_images2(data=origin_data, main_title="origin_data")
print()
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
show_images2(data=origin_data, main_title="noargs_data")
print()
show_images2(data=origin_data, main_title="fgray_data", f=150)
show_images2(data=origin_data, main_title="fpurple_data", f=[160, 32, 240])

Este código é valioso para experimentar com

Leave a Comment