Posit AI Weblog: Que haja luz: Mais luz para tocha!



Posit AI Weblog: Que haja luz: Mais luz para tocha!

… Antes de começarmos, peço desculpas aos nossos leitores de língua espanhola… Tive que fazer uma escolha entre “haja” e “haya”e no remaining tudo dependia de um cara ou coroa…

Enquanto escrevo isto, estamos mais do que satisfeitos com a rápida adoção que vimos do torch – não apenas para uso imediato, mas também em pacotes que se baseiam nele, fazendo uso de sua funcionalidade principal.

Porém, em um cenário aplicado – um cenário que envolve treinamento e validação em sincronia, cálculo de métricas e ação sobre elas, e alteração dinâmica de hiperparâmetros durante o processo – às vezes pode parecer que há uma quantidade não negligenciável de código clichê envolvido. Por um lado, existe o loop principal sobre épocas e, dentro dele, os loops sobre lotes de treinamento e validação. Além disso, etapas como atualizar o modelo modo (treinamento ou validação, respectivamente), zerar e calcular gradientes e propagar atualizações de modelo de volta devem ser executados na ordem correta. Por último e não menos importante, deve-se tomar cuidado para que, a qualquer momento, os tensores estejam localizados no esperado dispositivo.

Não seria um sonho secomo dizia a well-liked série “Head First …” do início dos anos 2000, havia uma maneira de eliminar essas etapas manuais, mantendo a flexibilidade? Com luzhá.

Nesta postagem, nosso foco está em duas coisas: em primeiro lugar, no próprio fluxo de trabalho simplificado; e segundo, mecanismos genéricos que permitem customização. Para exemplos mais detalhados deste último, além de instruções concretas de codificação, iremos vincular à documentação (já extensa).

Treine, valide e depois teste: um fluxo de trabalho básico de aprendizado profundo com luz

Para demonstrar o fluxo de trabalho essencial, utilizamos um conjunto de dados que está prontamente disponível e não nos distrairá muito, em termos de pré-processamento: ou seja, o Cães vs. Gatos coleção que vem com torchdatasets. torchvision será necessário para transformações de imagem; além desses dois pacotes, tudo o que precisamos é torch e luz.

Dados

O conjunto de dados é baixado do Kaggle; você precisará editar o caminho abaixo para refletir a localização do seu próprio token Kaggle.

dir <- "~/Downloads/dogs-vs-cats" 

ds <- torchdatasets::dogs_vs_cats_dataset(
  dir,
  token = "~/.kaggle/kaggle.json",
  remodel = . %>%
    torchvision::transform_to_tensor() %>%
    torchvision::transform_resize(dimension = c(224, 224)) %>% 
    torchvision::transform_normalize(rep(0.5, 3), rep(0.5, 3)),
  target_transform = operate(x) as.double(x) - 1
)

Convenientemente, podemos usar dataset_subset() para particionar os dados em conjuntos de treinamento, validação e teste.

train_ids <- pattern(1:size(ds), dimension = 0.6 * size(ds))
valid_ids <- pattern(setdiff(1:size(ds), train_ids), dimension = 0.2 * size(ds))
test_ids <- setdiff(1:size(ds), union(train_ids, valid_ids))

train_ds <- dataset_subset(ds, indices = train_ids)
valid_ds <- dataset_subset(ds, indices = valid_ids)
test_ds <- dataset_subset(ds, indices = test_ids)

A seguir, instanciamos o respectivo dataloaderS.

train_dl <- dataloader(train_ds, batch_size = 64, shuffle = TRUE, num_workers = 4)
valid_dl <- dataloader(valid_ds, batch_size = 64, num_workers = 4)
test_dl <- dataloader(test_ds, batch_size = 64, num_workers = 4)

Para os dados é isso – nenhuma mudança no fluxo de trabalho até agora. Também não há diferença na forma como definimos o modelo.

Modelo

Para acelerar o treinamento, utilizamos o AlexNet pré-treinado ( Krizhevsky (2014)).

internet <- torch::nn_module(
  
  initialize = operate(output_size) {
    self$mannequin <- model_alexnet(pretrained = TRUE)

    for (par in self$parameters) {
      par$requires_grad_(FALSE)
    }

    self$mannequin$classifier <- nn_sequential(
      nn_dropout(0.5),
      nn_linear(9216, 512),
      nn_relu(),
      nn_linear(512, 256),
      nn_relu(),
      nn_linear(256, output_size)
    )
  },
  ahead = operate(x) {
    self$mannequin(x)(,1)
  }
  
)

Se você olhar de perto, verá que tudo o que fizemos até agora foi definir o modelo. Ao contrário de um torch-only, não iremos instanciá-lo e nem iremos movê-lo para uma eventual GPU.

Expandindo este último, podemos dizer mais: Todos do manuseio do dispositivo é gerenciado por luz. Ele investiga a existência de uma GPU compatível com CUDA e, se encontrar uma, garante que os pesos do modelo e os tensores de dados sejam movidos para lá de forma transparente sempre que necessário. O mesmo vale para a direção oposta: as previsões calculadas no conjunto de testes, por exemplo, são transferidas silenciosamente para a CPU, prontas para o usuário manipulá-las ainda mais em R. Mas quanto às previsões, ainda não chegamos lá: em para modelar o treinamento, onde a diferença feita por luz salta direto nos olhos.

Treinamento

Abaixo, você vê quatro chamadas para luzdois dos quais são obrigatórios em todas as configurações e dois dependem do caso. Os sempre necessários são setup() e match() :

  • Em setup()você diz luz qual deve ser a perda e qual otimizador usar. Opcionalmente, além da perda em si (a métrica primária, de certa forma, na medida em que informa a atualização do peso) você pode ter luz calcular os adicionais. Aqui, por exemplo, pedimos precisão na classificação. (Para um ser humano observando uma barra de progresso, uma precisão de duas lessons de 0,91 é muito mais indicativa do que uma perda de entropia cruzada de 1,26.)

  • Em match()você passa referências para o treinamento e validação dataloaderS. Embora exista um padrão para o número de épocas para treinar, normalmente você também desejará passar um valor personalizado para esse parâmetro.

As chamadas dependentes de caso aqui, então, são aquelas para set_hparams() e set_opt_hparams(). Aqui,

  • set_hparams() aparece porque, na definição do modelo, tivemos initialize() pegue um parâmetro, output_size. Quaisquer argumentos esperados por initialize() precisa ser passado por este método.

  • set_opt_hparams() existe porque queremos usar uma taxa de aprendizado não padrão com optim_adam(). Se estivéssemos satisfeitos com o padrão, tal decisão não seria adequada.

fitted <- internet %>%
  setup(
    loss = nn_bce_with_logits_loss(),
    optimizer = optim_adam,
    metrics = listing(
      luz_metric_binary_accuracy_with_logits()
    )
  ) %>%
  set_hparams(output_size = 1) %>%
  set_opt_hparams(lr = 0.01) %>%
  match(train_dl, epochs = 3, valid_data = valid_dl)

Veja como a saída me pareceu:

predict(fitted, test_dl)

probs <- torch_sigmoid(preds)
print(probs, n = 5)
torch_tensor
 1.2959e-01
 1.3032e-03
 6.1966e-05
 5.9575e-01
 4.5577e-03
... (the output was truncated (use n=-1 to disable))
( CPUFloatType{5000} )

E é isso para um fluxo de trabalho completo. Caso você tenha experiência anterior com Keras, isso deve parecer bastante acquainted. O mesmo pode ser dito da técnica de customização mais versátil, porém padronizada, implementada em luz.

Como fazer (quase) qualquer coisa (quase) a qualquer hora

Como Keras, luz tem o conceito de retornos de chamada que pode “se conectar” ao processo de treinamento e executar código R arbitrário. Especificamente, o código pode ser agendado para execução em qualquer um dos seguintes momentos:

  • quando o processo geral de treinamento começa ou termina (on_fit_begin() / on_fit_end());

  • quando uma época de treinamento mais validação começa ou termina (on_epoch_begin() / on_epoch_end());

  • quando durante uma época, a metade do treinamento (validação, respectivamente) começa ou termina (on_train_begin() / on_train_end(); on_valid_begin() / on_valid_end());

  • quando durante o treinamento (validação, respectivamente) um novo lote está prestes a ser processado ou foi processado (on_train_batch_begin() / on_train_batch_end(); on_valid_batch_begin() / on_valid_batch_end());

  • e até mesmo em pontos de referência específicos dentro da lógica de treinamento/validação “mais interna”, como “cálculo pós-perda”, “após retrocesso” ou “após etapa”.

Embora você possa implementar qualquer lógica que desejar usando esta técnica, luz já vem equipado com um conjunto muito útil de retornos de chamada.

Por exemplo:

  • luz_callback_model_checkpoint() salva periodicamente os pesos do modelo.

  • luz_callback_lr_scheduler() permite ativar um dos torchde programadores de taxa de aprendizagem. Existem diferentes escalonadores, cada um seguindo sua própria lógica de como ajustam dinamicamente a taxa de aprendizado.

  • luz_callback_early_stopping() encerra o treinamento quando o desempenho do modelo para de melhorar.

Os retornos de chamada são passados ​​para match() em uma lista. Aqui adaptamos nosso exemplo acima, garantindo que (1) os pesos do modelo sejam salvos após cada época e (2), o treinamento termine se a perda de validação não melhorar por duas épocas consecutivas.

fitted <- internet %>%
  setup(
    loss = nn_bce_with_logits_loss(),
    optimizer = optim_adam,
    metrics = listing(
      luz_metric_binary_accuracy_with_logits()
    )
  ) %>%
  set_hparams(output_size = 1) %>%
  set_opt_hparams(lr = 0.01) %>%
  match(train_dl,
      epochs = 10,
      valid_data = valid_dl,
      callbacks = listing(luz_callback_model_checkpoint(path = "./fashions"),
                       luz_callback_early_stopping(endurance = 2)))

E quanto a outros tipos de requisitos de flexibilidade – como no cenário de modelos múltiplos e interativos, equipados, cada um, com suas próprias funções de perda e otimizadores? Nesses casos, o código ficará um pouco mais longo do que vimos aqui, mas luz ainda pode ajudar consideravelmente na simplificação do fluxo de trabalho.

Para concluir, usando luzvocê não perde nada da flexibilidade que vem com torchao mesmo tempo que ganha muito em simplicidade de código, modularidade e facilidade de manutenção. Ficaremos felizes em saber que você tentará!

Obrigado por ler!

Foto de JD Rincs sobre Remover respingo

KRIJHEVSKY, Alex. 2014. “Um truque estranho para paralelizar redes neurais convolucionais.” CoRR abs/1404.5997. http://arxiv.org/abs/1404.5997.

Deixe um comentário

O seu endereço de e-mail não será publicado. Campos obrigatórios são marcados com *