Em postagens recentes, exploramos aspectos essenciais torch
funcionalidade: tensoresa condição sine qua non de toda estrutura de aprendizagem profunda; autograduação, torch
implementação de diferenciação automática de modo reverso; módulosblocos de construção combináveis de redes neurais; e otimizadoresos – bem – algoritmos de otimização que torch
fornece.
Mas ainda não tivemos o nosso momento de “olá mundo”, pelo menos não se por “olá mundo” você quer dizer o inevitável experiência de aprendizagem profunda na classificação de animais de estimação. Gato ou cachorro? Beagle ou boxeador? Chinook ou Chihuahua? Vamos nos distinguir fazendo uma pergunta (um pouco) diferente: Que tipo de pássaro?
Tópicos que abordaremos em nosso caminho:
Os principais papéis de
torch
conjuntos de dados e carregadores de dadosrespectivamente.Como se inscrever
remodel
s, tanto para pré-processamento de imagens quanto para aumento de dados.Como usar o Resnet (Ele e outros. 2015)um modelo pré-treinado que vem com
torchvision
para aprendizagem por transferência.Como usar escalonadores de taxa de aprendizagem e, em specific, o algoritmo de taxa de aprendizagem de um ciclo (@abs-1708-07120).
Como encontrar uma boa taxa de aprendizagem inicial.
Para maior comodidade, o código está disponível em Colaborador Google – não é necessário copiar e colar.
Carregamento e pré-processamento de dados
O conjunto de dados de exemplo usado aqui está disponível em Kaggle.
Convenientemente, pode ser obtido usando torchdatasets
que utiliza pins
para autenticação, recuperação e armazenamento. Para habilitar pins
para gerenciar seus downloads do Kaggle, siga as instruções aqui.
Este conjunto de dados é muito “limpo”, ao contrário das imagens com as quais estamos acostumados, por exemplo, ImageNet. Para ajudar na generalização, introduzimos ruído durante o treinamento – em outras palavras, realizamos aumento de dados. Em torchvision
o aumento de dados faz parte de um pipeline de processamento de imagem que primeiro converte uma imagem em um tensor e depois aplica quaisquer transformações, como redimensionamento, corte, normalização ou várias formas de distorção.
Abaixo estão as transformações realizadas no conjunto de treinamento. Observe como a maioria deles é para aumento de dados, enquanto a normalização é feita para cumprir o que é esperado pelo ResNet.
Pipeline de pré-processamento de imagem
library(torch)
library(torchvision)
library(torchdatasets)
library(dplyr)
library(pins)
library(ggplot2)
system <- if (cuda_is_available()) torch_device("cuda:0") else "cpu"
train_transforms <- operate(img) {
img %>%
# first convert picture to tensor
transform_to_tensor() %>%
# then transfer to the GPU (if obtainable)
(operate(x) x$to(system = system)) %>%
# knowledge augmentation
transform_random_resized_crop(dimension = c(224, 224)) %>%
# knowledge augmentation
transform_color_jitter() %>%
# knowledge augmentation
transform_random_horizontal_flip() %>%
# normalize in accordance to what's anticipated by resnet
transform_normalize(imply = c(0.485, 0.456, 0.406), std = c(0.229, 0.224, 0.225))
}
No conjunto de validação, não queremos introduzir ruído, mas ainda precisamos redimensionar, cortar e normalizar as imagens. O conjunto de teste deve ser tratado de forma idêntica.
E agora, vamos obter os dados, bem divididos em conjuntos de treinamento, validação e teste. Além disso, informamos aos objetos R correspondentes quais transformações eles devem aplicar:
train_ds <- bird_species_dataset("knowledge", obtain = TRUE, remodel = train_transforms)
valid_ds <- bird_species_dataset("knowledge", break up = "legitimate", remodel = valid_transforms)
test_ds <- bird_species_dataset("knowledge", break up = "take a look at", remodel = test_transforms)
Duas coisas a serem observadas. Primeiro, as transformações fazem parte do conjunto de dados conceito, em oposição ao carregador de dados encontraremos em breve. Segundo, vamos dar uma olhada em como as imagens foram armazenadas no disco. A estrutura geral de diretórios (começando em knowledge
que especificamos como o diretório raiz a ser usado) é este:
knowledge/bird_species/practice
knowledge/bird_species/legitimate
knowledge/bird_species/take a look at
No practice
, legitimate
e take a look at
diretórios, diferentes lessons de imagens residem em suas próprias pastas. Por exemplo, aqui está o structure do diretório para as três primeiras lessons do conjunto de testes:
knowledge/bird_species/take a look at/ALBATROSS/
- knowledge/bird_species/take a look at/ALBATROSS/1.jpg
- knowledge/bird_species/take a look at/ALBATROSS/2.jpg
- knowledge/bird_species/take a look at/ALBATROSS/3.jpg
- knowledge/bird_species/take a look at/ALBATROSS/4.jpg
- knowledge/bird_species/take a look at/ALBATROSS/5.jpg
knowledge/take a look at/'ALEXANDRINE PARAKEET'/
- knowledge/bird_species/take a look at/'ALEXANDRINE PARAKEET'/1.jpg
- knowledge/bird_species/take a look at/'ALEXANDRINE PARAKEET'/2.jpg
- knowledge/bird_species/take a look at/'ALEXANDRINE PARAKEET'/3.jpg
- knowledge/bird_species/take a look at/'ALEXANDRINE PARAKEET'/4.jpg
- knowledge/bird_species/take a look at/'ALEXANDRINE PARAKEET'/5.jpg
knowledge/take a look at/'AMERICAN BITTERN'/
- knowledge/bird_species/take a look at/'AMERICAN BITTERN'/1.jpg
- knowledge/bird_species/take a look at/'AMERICAN BITTERN'/2.jpg
- knowledge/bird_species/take a look at/'AMERICAN BITTERN'/3.jpg
- knowledge/bird_species/take a look at/'AMERICAN BITTERN'/4.jpg
- knowledge/bird_species/take a look at/'AMERICAN BITTERN'/5.jpg
Este é exatamente o tipo de structure esperado por torch
é image_folder_dataset()
– e realmente bird_species_dataset()
instancia um subtipo desta classe. Se tivéssemos baixado os dados manualmente, respeitando a estrutura de diretórios necessária, poderíamos ter criado os conjuntos de dados da seguinte forma:
# e.g.
train_ds <- image_folder_dataset(
file.path(data_dir, "practice"),
remodel = train_transforms)
Agora que obtivemos os dados, vamos ver quantos itens existem em cada conjunto.
train_ds$.size()
valid_ds$.size()
test_ds$.size()
31316
1125
1125
Esse conjunto de treinamento é realmente grande! Portanto, é recomendado rodar isso na GPU ou apenas brincar com o pocket book Colab fornecido.
Com tantas amostras, estamos curiosos para saber quantas lessons existem.
class_names <- test_ds$lessons
size(class_names)
225
Então nós fazer temos um conjunto de treinamento substancial, mas a tarefa também é formidável: vamos distinguir nada menos que 225 espécies diferentes de aves.
Carregadores de dados
Enquanto conjuntos de dados saber o que fazer com cada merchandise, carregadores de dados saiba como tratá-los coletivamente. Quantas amostras compõem um lote? Queremos alimentá-los sempre na mesma ordem ou, em vez disso, escolher uma ordem diferente para cada época?
batch_size <- 64
train_dl <- dataloader(train_ds, batch_size = batch_size, shuffle = TRUE)
valid_dl <- dataloader(valid_ds, batch_size = batch_size)
test_dl <- dataloader(test_ds, batch_size = batch_size)
Os carregadores de dados também podem ser consultados quanto ao seu comprimento. Agora comprimento significa: Quantos lotes?
train_dl$.size()
valid_dl$.size()
test_dl$.size()
490
18
18
Alguns pássaros
A seguir, vamos ver algumas imagens do conjunto de teste. Podemos recuperar o primeiro lote – imagens e lessons correspondentes – criando um iterador a partir do dataloader
e ligando subsequent()
nele:
# for show functions, right here we are literally utilizing a batch_size of 24
batch <- train_dl$.iter()$.subsequent()
batch
é uma lista, sendo o primeiro merchandise os tensores de imagem:
(1) 24 3 224 224
E a segunda, as aulas:
(1) 24
As lessons são codificadas como números inteiros, para serem usadas como índices em um vetor de nomes de lessons. Usaremos isso para rotular as imagens.
lessons <- batch((2))
lessons
torch_tensor
1
1
1
1
1
2
2
2
2
2
3
3
3
3
3
4
4
4
4
4
5
5
5
5
( GPULongType{24} )
Os tensores de imagem têm forma batch_size x num_channels x top x width
. Para plotar usando as.raster()
precisamos remodelar as imagens para que os canais fiquem por último. Também desfazemos a normalização aplicada pelo dataloader
.
Aqui estão as primeiras vinte e quatro imagens:
library(dplyr)
pictures <- as_array(batch((1))) %>% aperm(perm = c(1, 3, 4, 2))
imply <- c(0.485, 0.456, 0.406)
std <- c(0.229, 0.224, 0.225)
pictures <- std * pictures + imply
pictures <- pictures * 255
pictures(pictures > 255) <- 255
pictures(pictures < 0) <- 0
par(mfcol = c(4,6), mar = rep(1, 4))
pictures %>%
purrr::array_tree(1) %>%
purrr::set_names(class_names(as_array(lessons))) %>%
purrr::map(as.raster, max = 255) %>%
purrr::iwalk(~{plot(.x); title(.y)})
Modelo
A espinha dorsal do nosso modelo é uma instância pré-treinada do ResNet.
mannequin <- model_resnet18(pretrained = TRUE)
Mas queremos distinguir entre as nossas 225 espécies de aves, enquanto a ResNet foi treinada em 1000 lessons diferentes. O que podemos fazer? Simplesmente substituímos a camada de saída.
A nova camada de saída também é a única cujos pesos iremos treinar – deixando todos os outros parâmetros ResNet como estão. Tecnicamente, nós poderia realizar retropropagação por meio do modelo completo, esforçando-se também para ajustar os pesos do ResNet. No entanto, isso retardaria significativamente o treinamento. Na verdade, a escolha não é tudo ou nada: depende de nós quantos parâmetros originais manter fixos e quantos “liberar” para ajuste fino. Para a tarefa em questão, ficaremos satisfeitos em apenas treinar a camada de saída recém-adicionada: Com a abundância de animais, incluindo pássaros, no ImageNet, esperamos que o ResNet treinado saiba muito sobre eles!
Para substituir a camada de saída, o modelo é modificado no native:
num_features <- mannequin$fc$in_features
mannequin$fc <- nn_linear(in_features = num_features, out_features = size(class_names))
Agora coloque o modelo modificado na GPU (se disponível):
mannequin <- mannequin$to(system = system)
Treinamento
Para otimização, usamos perda de entropia cruzada e gradiente descendente estocástico.
criterion <- nn_cross_entropy_loss()
optimizer <- optim_sgd(mannequin$parameters, lr = 0.1, momentum = 0.9)
Encontrando uma taxa de aprendizagem idealmente eficiente
Definimos a taxa de aprendizagem para 0.1
mas isso é apenas uma formalidade. Como se tornou amplamente conhecido pelas excelentes palestras de rápido.aifaz sentido dedicar algum tempo antecipadamente para determinar uma taxa de aprendizagem eficiente. Enquanto estiver fora da caixa, torch
não fornece uma ferramenta como o localizador de taxa de aprendizagem do quick.ai, a lógica é simples de implementar. Veja como encontrar uma boa taxa de aprendizado, traduzida para R de Postagem de Sylvain Gugger:
# ported from: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
losses <- c()
log_lrs <- c()
find_lr <- operate(init_value = 1e-8, final_value = 10, beta = 0.98) {
num <- train_dl$.size()
mult = (final_value/init_value)^(1/num)
lr <- init_value
optimizer$param_groups((1))$lr <- lr
avg_loss <- 0
best_loss <- 0
batch_num <- 0
coro::loop(for (b in train_dl) )
}
find_lr()
df <- knowledge.body(log_lrs = log_lrs, losses = losses)
ggplot(df, aes(log_lrs, losses)) + geom_point(dimension = 1) + theme_classic()
A melhor taxa de aprendizagem não é aquela em que a perda é mínima. Em vez disso, deveria ser escolhido um pouco mais cedo na curva, enquanto a perda ainda está diminuindo. 0.05
parece uma escolha sensata.
Este valor nada mais é do que uma âncora. Agendadores de taxa de aprendizagem permitir que as taxas de aprendizagem evoluam de acordo com algum algoritmo comprovado. Entre outros, torch
implementa aprendizagem de um ciclo (@abs-1708-07120), taxas de aprendizagem cíclicas (Smith 2015)e recozimento de cosseno com reinicializações a quente (Loschilov e Hutter 2016).
Aqui, usamos lr_one_cycle()
passando nosso valor recém-descoberto, idealmente eficiente e, esperançosamente, 0.05
como uma taxa máxima de aprendizagem. lr_one_cycle()
começará com uma taxa baixa e aumentará gradualmente até atingir o máximo permitido. Depois disso, a taxa de aprendizagem diminuirá lenta e continuamente, até cair ligeiramente abaixo do seu valor inicial.
Tudo isso não acontece por época, mas exatamente uma vez, razão pela qual o nome foi one_cycle
nele. Veja como fica a evolução das taxas de aprendizagem em nosso exemplo:
Antes de começarmos o treinamento, vamos reinicializar rapidamente o modelo, para começar do zero:
mannequin <- model_resnet18(pretrained = TRUE)
mannequin$parameters %>% purrr::stroll(operate(param) param$requires_grad_(FALSE))
num_features <- mannequin$fc$in_features
mannequin$fc <- nn_linear(in_features = num_features, out_features = size(class_names))
mannequin <- mannequin$to(system = system)
criterion <- nn_cross_entropy_loss()
optimizer <- optim_sgd(mannequin$parameters, lr = 0.05, momentum = 0.9)
E instancie o agendador:
num_epochs = 10
scheduler <- optimizer %>%
lr_one_cycle(max_lr = 0.05, epochs = num_epochs, steps_per_epoch = train_dl$.size())
Ciclo de treinamento
Agora treinamos por dez épocas. Para cada lote de treinamento, chamamos scheduler$step()
para ajustar a taxa de aprendizagem. Notavelmente, isso tem que ser feito depois optimizer$step()
.
train_batch <- operate(b) {
optimizer$zero_grad()
output <- mannequin(b((1)))
loss <- criterion(output, b((2))$to(system = system))
loss$backward()
optimizer$step()
scheduler$step()
loss$merchandise()
}
valid_batch <- operate(b) {
output <- mannequin(b((1)))
loss <- criterion(output, b((2))$to(system = system))
loss$merchandise()
}
for (epoch in 1:num_epochs) {
mannequin$practice()
train_losses <- c()
coro::loop(for (b in train_dl) {
loss <- train_batch(b)
train_losses <- c(train_losses, loss)
})
mannequin$eval()
valid_losses <- c()
coro::loop(for (b in valid_dl) {
loss <- valid_batch(b)
valid_losses <- c(valid_losses, loss)
})
cat(sprintf("nLoss at epoch %d: coaching: %3f, validation: %3fn", epoch, imply(train_losses), imply(valid_losses)))
}
Loss at epoch 1: coaching: 2.662901, validation: 0.790769
Loss at epoch 2: coaching: 1.543315, validation: 1.014409
Loss at epoch 3: coaching: 1.376392, validation: 0.565186
Loss at epoch 4: coaching: 1.127091, validation: 0.575583
Loss at epoch 5: coaching: 0.916446, validation: 0.281600
Loss at epoch 6: coaching: 0.775241, validation: 0.215212
Loss at epoch 7: coaching: 0.639521, validation: 0.151283
Loss at epoch 8: coaching: 0.538825, validation: 0.106301
Loss at epoch 9: coaching: 0.407440, validation: 0.083270
Loss at epoch 10: coaching: 0.354659, validation: 0.080389
Parece que o modelo fez um bom progresso, mas ainda não sabemos nada sobre a precisão da classificação em termos absolutos. Verificaremos isso no conjunto de testes.
Precisão do conjunto de testes
Finalmente, calculamos a precisão no conjunto de teste:
mannequin$eval()
test_batch <- operate(b) {
output <- mannequin(b((1)))
labels <- b((2))$to(system = system)
loss <- criterion(output, labels)
test_losses <<- c(test_losses, loss$merchandise())
# torch_max returns an inventory, with place 1 containing the values
# and place 2 containing the respective indices
predicted <- torch_max(output$knowledge(), dim = 2)((2))
whole <<- whole + labels$dimension(1)
# add variety of right classifications on this batch to the mixture
right <<- right + (predicted == labels)$sum()$merchandise()
}
test_losses <- c()
whole <- 0
right <- 0
for (b in enumerate(test_dl)) {
test_batch(b)
}
imply(test_losses)
(1) 0.03719
test_accuracy <- right/whole
test_accuracy
(1) 0.98756
Um resultado impressionante, considerando quantas espécies diferentes existem!
Conclusão
Esperamos que esta tenha sido uma introdução útil à classificação de imagens com torch
bem como aos seus elementos arquitetônicos não específicos do domínio, como conjuntos de dados, carregadores de dados e agendadores de taxa de aprendizagem. Postagens futuras explorarão outros domínios, bem como irão além do “olá mundo” no reconhecimento de imagens. Obrigado por ler!
Ele, Kaiming, Xiangyu Zhang, Shaoqing Ren e Jian Solar. 2015. “Aprendizagem residual profunda para reconhecimento de imagens.” CoRR abs/1512.03385. http://arxiv.org/abs/1512.03385.
Loshchilov, Ilya e Frank Hutter. 2016. “SGDR: Descida gradiente estocástica com reinicializações.” CoRR abs/1608.03983. http://arxiv.org/abs/1608.03983.
Smith, Leslie N. 2015. “Chega de jogos incômodos de adivinhação de taxa de aprendizagem.” CoRR abs/1506.01186. http://arxiv.org/abs/1506.01186.