Redes neurais convolucionais (CNNs) são ótimas – elas são capazes de detectar características em uma imagem não importa onde. Bem, não exatamente. Elas não são indiferentes a qualquer tipo de movimento. Deslocar para cima ou para baixo, ou para a esquerda ou para a direita, é bom; girar em torno de um eixo não é. Isso é por causa de como a convolução funciona: atravessar por linha, depois atravessar por coluna (ou vice-versa). Se quisermos “mais” (por exemplo, detecção bem-sucedida de um objeto de cabeça para baixo), precisamos estender a convolução para uma operação que seja rotação-equivariante. Uma operação que é equivariante para algum tipo de ação não apenas registrará o recurso movido em si, mas também manterá o controle de qual ação concreta o fez aparecer onde está.
Este é o segundo put up de uma série que apresenta CNNs de grupo equivariantes (GCNNs). O primeiro foi uma introdução de alto nível sobre por que os queremos e como eles funcionam. Lá, introduzimos o jogador-chave, o grupo de simetria, que especifica quais tipos de transformações devem ser tratados de forma equivariante. Se você não leu, dê uma olhada naquele put up primeiro, já que aqui farei uso da terminologia e dos conceitos que ele introduziu.
Hoje, codificamos um GCNN simples do zero. O código e a apresentação seguem rigorosamente um caderno fornecido como parte do programa 2022 da Universidade de Amsterdã Curso de Aprendizagem Profunda. Eles não podem ser agradecidos o suficiente por disponibilizar materiais de aprendizagem tão excelentes.
No que se segue, minha intenção é explicar o pensamento geral e como a arquitetura resultante é construída a partir de módulos menores, cada um dos quais recebe um propósito claro. Por esse motivo, não reproduzirei todo o código aqui; em vez disso, usarei o pacote gcnn
. Seus métodos são bastante anotados; então, para ver alguns detalhes, não hesite em olhar o código.
A partir de hoje, gcnn
implementa um grupo de simetria: (C_4)aquele que serve como um exemplo de execução ao longo do put up um. Ele é diretamente extensível, no entanto, fazendo uso de hierarquias de classe por toda parte.
Etapa 1: O grupo de simetria (C_4)
Ao codificar um GCNN, a primeira coisa que precisamos fornecer é uma implementação do grupo de simetria que gostaríamos de usar. Aqui, é (C_4)o grupo de quatro elementos que gira 90 graus.
Podemos perguntar gcnn
para criar um para nós e inspecionar seus elementos.
torch_tensor
0.0000
1.5708
3.1416
4.7124
( CPUFloatType{4} )
Os elementos são representados pelos seus respectivos ângulos de rotação: (0), (frac{pi}{2}), (pi)e (frac{3 pi}{2}).
Os grupos estão cientes da identidade e sabem como construir o inverso de um elemento:
C_4$identification
g1 <- elems(2)
C_4$inverse(g1)
torch_tensor
0
( CPUFloatType{1} )
torch_tensor
4.71239
( CPUFloatType{} )
Aqui, o que mais nos importa são os elementos do grupo Ação. Em termos de implementação, precisamos distinguir entre eles agindo um sobre o outro e sua ação no espaço vetorial (mathbb{R}^2)onde nossas imagens de entrada vivem. A primeira parte é a mais fácil: ela pode ser implementada simplesmente adicionando ângulos. Na verdade, é isso que gcnn
faz quando pedimos para deixar g1
agir em g2
:
g2 <- elems(3)
# in C_4$left_action_on_H(), H stands for the symmetry group
C_4$left_action_on_H(torch_tensor(g1)$unsqueeze(1), torch_tensor(g2)$unsqueeze(1))
torch_tensor
4.7124
( CPUFloatType{1,1} )
O que há com o unsqueeze()
s? Desde (C_4)é o máximo razão de ser é fazer parte de uma rede neural, left_action_on_H()
funciona com lotes de elementos, não com tensores escalares.
As coisas são um pouco menos diretas quando a ação do grupo é (mathbb{R}^2) está em causa. Aqui, precisamos do conceito de um representação de grupo. Este é um tópico envolvente, no qual não entraremos aqui. Em nosso contexto atual, funciona mais ou menos assim: temos um sinal de entrada, um tensor no qual gostaríamos de operar de alguma forma. (Esse “algum modo” será a convolução, como veremos em breve.) Para tornar essa operação grupo-equivariante, primeiro temos a representação aplicando o inverso ação de grupo para a entrada. Feito isso, continuamos com a operação como se nada tivesse acontecido.
Para dar um exemplo concreto, digamos que a operação seja uma medição. Think about um corredor, parado no sopé de uma trilha na montanha, pronto para correr na subida. Gostaríamos de registrar sua altura. Uma opção que temos é fazer a medição e deixá-lo correr para cima. Nossa medição será tão válida na montanha quanto foi aqui embaixo. Alternativamente, podemos ser educados e não fazê-lo esperar. Uma vez que ele esteja lá em cima, pedimos que desça e, quando ele estiver de volta, medimos sua altura. O resultado é o mesmo: a altura do corpo é equivariante (mais do que isso: invariante, até) à ação de correr para cima ou para baixo. (Claro, altura é uma medida bem chata. Mas algo mais interessante, como frequência cardíaca, não teria funcionado tão bem neste exemplo.)
Retornando à implementação, verifica-se que as ações do grupo são codificadas como matrizes. Há uma matriz para cada elemento do grupo. Para (C_4)o chamado padrão representação é uma matriz de rotação:
( start{bmatriz} cos(theta) & -sin(theta) sin(theta) & cos(theta) finish{bmatriz} )
Em gcnn
a função que aplica essa matriz é left_action_on_R2()
. Assim como seu irmão, ele foi projetado para trabalhar com lotes (de elementos de grupo, bem como (mathbb{R}^2) vetores). Tecnicamente, o que ele faz é girar a grade na qual a imagem é definida e, então, reamostrar a imagem. Para tornar isso mais concreto, o código desse método se parece com o seguinte.
Aqui está uma cabra.
img_path <- system.file("imgs", "z.jpg", bundle = "gcnn")
img <- torchvision::base_loader(img_path) |> torchvision::transform_to_tensor()
img$permute(c(2, 3, 1)) |> as.array() |> as.raster() |> plot()
Primeiro, chamamos C_4$left_action_on_R2()
para girar a grade.
# Grid form is (2, 1024, 1024), for a second, 1024 x 1024 picture.
img_grid_R2 <- torch::torch_stack(torch::torch_meshgrid(
listing(
torch::torch_linspace(-1, 1, dim(img)(2)),
torch::torch_linspace(-1, 1, dim(img)(3))
)
))
# Rework the picture grid with the matrix illustration of some group ingredient.
transformed_grid <- C_4$left_action_on_R2(C_4$inverse(g1)$unsqueeze(1), img_grid_R2)
Segundo, nós reamostramos a imagem na grade transformada. A cabra agora olha para o céu.
Etapa 2: A convolução de elevação
Queremos aproveitar os recursos existentes e eficientes torch
funcionalidade tanto quanto possível. Concretamente, queremos usar nn_conv2d()
. O que precisamos, porém, é de um kernel de convolução que seja equivariante não apenas à tradução, mas também à ação de (C_4). Isso pode ser alcançado tendo um kernel para cada rotação possível.
Implementar essa ideia é exatamente o que LiftingConvolution
faz. O princípio é o mesmo de antes: primeiro, a grade é rotacionada e, então, o kernel (matriz de peso) é reamostrado para a grade transformada.
Por que, porém, chamar isso de convolução de elevação? O kernel de convolução typical opera em (mathbb{R}^2); enquanto nossa versão estendida opera em combinações de (mathbb{R}^2) e (C_4). Em linguagem matemática, tem sido levantado para o produto semi-direto (mathbb{R}^2rvezes C_4).
lifting_conv <- LiftingConvolution(
group = CyclicGroup(order = 4),
kernel_size = 5,
in_channels = 3,
out_channels = 8
)
x <- torch::torch_randn(c(2, 3, 32, 32))
y <- lifting_conv(x)
y$form
(1) 2 8 4 28 28
Uma vez que, internamente, LiftingConvolution
usa uma dimensão adicional para realizar o produto de translações e rotações, a saída não é quadridimensional, mas pentadimensional.
Etapa 3: Agrupar convoluções
Agora que estamos no “espaço estendido do grupo”, podemos encadear uma série de camadas onde tanto a entrada quanto a saída são convolução de grupo camadas. Por exemplo:
group_conv <- GroupConvolution(
group = CyclicGroup(order = 4),
kernel_size = 5,
in_channels = 8,
out_channels = 16
)
z <- group_conv(y)
z$form
(1) 2 16 4 24 24
Tudo o que resta a ser feito é empacotar isso. É isso que gcnn::GroupEquivariantCNN()
faz.
Etapa 4: CNN de grupo equivalente
Podemos chamar GroupEquivariantCNN()
assim mesmo.
cnn <- GroupEquivariantCNN(
group = CyclicGroup(order = 4),
kernel_size = 5,
in_channels = 1,
out_channels = 1,
num_hidden = 2, # variety of group convolutions
hidden_channels = 16 # variety of channels per group conv layer
)
img <- torch::torch_randn(c(4, 1, 32, 32))
cnn(img)$form
(1) 4 1
À primeira vista, isso GroupEquivariantCNN
parece qualquer CNN antiga… não fosse o group
argumento.
Agora, quando inspecionamos sua saída, vemos que a dimensão adicional desapareceu. Isso ocorre porque, após uma sequência de camadas de convolução de grupo para grupo, o módulo projeta para baixo para uma representação que, para cada merchandise de lote, retém apenas canais. Portanto, ele faz a média não apenas sobre os locais — como normalmente fazemos — mas também sobre a dimensão do grupo. Uma camada linear ultimate fornecerá então a saída do classificador solicitada (de dimensão out_channels
).
E aí temos a arquitetura completa. É hora de um mundo actual(ish) teste.
Dígitos girados!
A ideia é treinar duas convnets, uma CNN “regular” e uma grupo-equivariante, no conjunto de treinamento MNIST typical. Então, ambas são avaliadas em um conjunto de teste aumentado onde cada imagem é rotacionada aleatoriamente por uma rotação contínua entre 0 e 360 graus. Não esperamos GroupEquivariantCNN
para ser “perfeito” – não se nos equiparmos com (C_4) como um grupo de simetria. Estritamente, com (C_4)a equivariância se estende por apenas quatro posições. Mas esperamos que ela tenha um desempenho significativamente melhor do que a arquitetura padrão shift-equivariant-only.
Primeiro, preparamos os dados; em specific, o conjunto de teste aumentado.
dir <- "/tmp/mnist"
train_ds <- torchvision::mnist_dataset(
dir,
obtain = TRUE,
rework = torchvision::transform_to_tensor
)
test_ds <- torchvision::mnist_dataset(
dir,
practice = FALSE,
rework = perform(x) >
torchvision::transform_to_tensor()
)
train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)
Como é a aparência?
Primeiro definimos e treinamos uma CNN convencional. É tão semelhante a GroupEquivariantCNN()
em termos de arquitetura, tanto quanto possível, e recebe o dobro do número de canais ocultos, para ter uma capacidade geral comparável.
default_cnn <- nn_module(
"default_cnn",
initialize = perform(kernel_size, in_channels, out_channels, num_hidden, hidden_channels) {
self$conv1 <- torch::nn_conv2d(in_channels, hidden_channels, kernel_size)
self$convs <- torch::nn_module_list()
for (i in 1:num_hidden) {
self$convs$append(torch::nn_conv2d(hidden_channels, hidden_channels, kernel_size))
}
self$avg_pool <- torch::nn_adaptive_avg_pool2d(1)
self$final_linear <- torch::nn_linear(hidden_channels, out_channels)
},
ahead = perform(x) >
self$conv1()
)
fitted <- default_cnn |>
luz::setup(
loss = torch::nn_cross_entropy_loss(),
optimizer = torch::optim_adam,
metrics = listing(
luz::luz_metric_accuracy()
)
) |>
luz::set_hparams(
kernel_size = 5,
in_channels = 1,
out_channels = 10,
num_hidden = 4,
hidden_channels = 32
) %>%
luz::set_opt_hparams(lr = 1e-2, weight_decay = 1e-4) |>
luz::match(train_dl, epochs = 10, valid_data = test_dl)
Prepare metrics: Loss: 0.0498 - Acc: 0.9843
Legitimate metrics: Loss: 3.2445 - Acc: 0.4479
Não é de surpreender que a precisão no conjunto de teste não seja tão boa.
Em seguida, treinamos a versão equivariante do grupo.
fitted <- GroupEquivariantCNN |>
luz::setup(
loss = torch::nn_cross_entropy_loss(),
optimizer = torch::optim_adam,
metrics = listing(
luz::luz_metric_accuracy()
)
) |>
luz::set_hparams(
group = CyclicGroup(order = 4),
kernel_size = 5,
in_channels = 1,
out_channels = 10,
num_hidden = 4,
hidden_channels = 16
) |>
luz::set_opt_hparams(lr = 1e-2, weight_decay = 1e-4) |>
luz::match(train_dl, epochs = 10, valid_data = test_dl)
Prepare metrics: Loss: 0.1102 - Acc: 0.9667
Legitimate metrics: Loss: 0.4969 - Acc: 0.8549
Para a CNN de grupo equivalente, as precisões em conjuntos de teste e treinamento são muito mais próximas. Esse é um bom resultado! Vamos encerrar o exploit de hoje retomando um pensamento do primeiro put up, de nível mais alto.
Um desafio
Voltando ao conjunto de teste aumentado, ou melhor, às amostras de dígitos exibidas, notamos um problema. Na linha dois, coluna quatro, há um dígito que “em circunstâncias normais” deveria ser um 9, mas, muito provavelmente, é um 6 invertido. (Para um humano, o que sugere isso é a coisa parecida com um rabisco que parece ser encontrada mais frequentemente com seis do que com nove.) No entanto, você poderia perguntar: isso ter ser um problema? Talvez a rede só exact aprender as sutilezas, os tipos de coisas que um humano detectaria?
Da forma como vejo, tudo depende do contexto: o que realmente deve ser realizado e como um aplicativo será usado. Com dígitos em uma letra, não vejo razão para que um único dígito apareça de cabeça para baixo; consequentemente, a equivariância de rotação completa seria contraproducente. Em poucas palavras, chegamos ao mesmo imperativo canônico que os defensores do aprendizado de máquina justo e correto continuam nos lembrando:
Pense sempre na forma como um aplicativo será usado!
No nosso caso, porém, há outro aspecto nisso, um aspecto técnico. gcnn::GroupEquivariantCNN()
é um wrapper simples, no qual todas as suas camadas fazem uso do mesmo grupo de simetria. Em princípio, não há necessidade de fazer isso. Com mais esforço de codificação, grupos diferentes podem ser usados dependendo da posição de uma camada na hierarquia de detecção de recursos.
Aqui, deixe-me finalmente dizer por que escolhi a imagem da cabra. A cabra é vista através de uma cerca vermelha e branca, um padrão – ligeiramente rotacionado, devido ao ângulo de visão – feito de quadrados (ou bordas, se preferir). Agora, para tal cerca, tipos de equivariância de rotação, como a codificada por (C_4) faz muito sentido. A cabra em si, porém, preferimos não olhar para o céu, do jeito que ilustrei (C_4) ação antes. Assim, o que faríamos em uma tarefa de classificação de imagens do mundo actual é usar camadas bastante flexíveis na parte inferior e camadas cada vez mais restritas no topo da hierarquia.
Obrigado pela leitura!
Foto por Marjan Blan | @marjanblan sobre Desaparecer