Nota: Para acompanhar esta postagem, você precisará torch
versão 0.5, que até o momento em que este livro foi escrito ainda não estava no CRAN. Enquanto isso, instale a versão de desenvolvimento em GitHub.
Cada domínio tem seus conceitos, e estes são o que precisamos entender, em algum momento, em nossa jornada de copiar e fazer funcionar até a utilização proposital e deliberada. Além disso, infelizmente, cada domínio tem o seu jargão, segundo o qual os termos são usados de uma forma que é tecnicamente correta, mas não consegue evocar uma imagem clara para os ainda não iniciados. O JIT do (Py-)Torch é um exemplo.
Introdução terminológica
“The JIT”, muito falado no mundo PyTorch e uma característica eminente do R torch
também, são duas coisas ao mesmo tempo – dependendo de como você olha para isso: um compilador otimizador; e um passe livre para execução em muitos ambientes onde nem R nem Python estão presentes.
Compilado, interpretado, compilado just-in-time
“JIT” é um acrônimo comum para “simply in time” (a saber: compilação). Compilação significa gerar código executável por máquina; é algo que precisa acontecer com todos os programas para que possam ser executados. A questão é quando.
O código C, por exemplo, é compilado “manualmente”, em algum momento arbitrário antes da execução. Muitas outras linguagens, entretanto (entre elas Java, R e Python) são – em suas implementações padrão, pelo menos – interpretado: Eles vêm com executáveis (java
, R
e python
resp.) que criam código de máquina em tempo de execuçãobaseado no programa authentic conforme escrito ou em um formato intermediário chamado bytecódigo. A interpretação pode prosseguir linha por linha, como quando você insere algum código no REPL (loop read-eval-print) do R, ou em pedaços (se houver um script ou aplicativo inteiro a ser executado). Neste último caso, como o interpretador sabe o que provavelmente será executado em seguida, ele pode implementar otimizações que de outra forma seriam impossíveis. Este processo é comumente conhecido como compilação simply in time. Assim, na linguagem geral, a compilação JIT é uma compilação, mas em um momento em que o programa já está em execução.
O torch
compilador just-in-time
Comparado a essa noção de JIT, ao mesmo tempo genérica (em termos técnicos) e específica (em termos de tempo), o que as pessoas da (Py-)Torch têm em mente quando falam de “o JIT” é ao mesmo tempo definido de forma mais restrita (em termos de operações) e mais inclusivo (no tempo): O que se entende é o processo completo desde o fornecimento de entrada de código que pode ser convertido em uma representação intermediária (IR), by way of geração desse IR, by way of otimização sucessiva do mesmo pelo compilador JIT, by way of conversão (novamente, pelo compilador) para bytecode, para – finalmente – execução, novamente cuidada pelo mesmo compilador, que agora atua como uma máquina digital.
Se isso parece complicado, não tenha medo. Para realmente fazer uso desse recurso do R, não é necessário aprender muito em termos de sintaxe; uma única função, complementada por alguns ajudantes especializados, está a conter toda a carga pesada. O que importa, porém, é entender um pouco sobre como funciona a compilação JIT, para que você saiba o que esperar e não seja surpreendido por resultados inesperados.
O que está por vir (neste texto)
Este put up tem mais três partes.
Na primeira, explicamos como fazer uso dos recursos JIT em R torch
. Além da sintaxe, nos concentramos na semântica (o que acontece essencialmente quando você “rastreia JIT” um trecho de código) e como isso afeta o resultado.
No segundo, “espiamos um pouco por baixo do capô”; sinta-se à vontade para folhear rapidamente se isso não lhe interessa muito.
Na terceira, mostramos um exemplo de uso de compilação JIT para permitir a implantação em um ambiente que não possui R instalado.
Como fazer uso torch
Compilação JIT
No mundo Python, ou mais especificamente, nas encarnações Python de estruturas de aprendizado profundo, existe um verbo mágico “hint” que se refere a uma maneira de obter uma representação gráfica a partir da execução de código avidamente. Ou seja, você executa um trecho de código – uma função, digamos, contendo operações PyTorch – em entradas de exemplo. Essas entradas de exemplo são arbitrárias em termos de valores, mas (naturalmente) precisam estar em conformidade com as formas esperadas pela função. O rastreamento registrará então as operações conforme executadas, ou seja: aquelas operações que eram de fato executado, e apenas aqueles. Quaisquer caminhos de código não inseridos serão condenados ao esquecimento.
Também em R, o rastreamento é como obtemos uma primeira representação intermediária. Isso é feito usando a função apropriadamente nomeada jit_trace()
. Por exemplo:
Agora podemos chamar a função rastreada como a authentic:
f_t(torch_randn(c(3, 3)))
torch_tensor
3.19587
( CPUFloatType{} )
O que acontece se houver fluxo de controle, como um if
declaração?
f <- operate(x) {
if (as.numeric(torch_sum(x)) > 0) torch_tensor(1) else torch_tensor(2)
}
f_t <- jit_trace(f, torch_tensor(c(2, 2)))
Aqui o rastreamento deve ter entrado no if
filial. Agora chame a função rastreada com um tensor que não soma um valor maior que zero:
torch_tensor
1
( CPUFloatType{1} )
É assim que funciona o rastreamento. Os caminhos não percorridos ficam perdidos para sempre. A lição aqui é nunca ter fluxo de controle dentro de uma função que será rastreada.
Antes de prosseguirmos, vamos citar rapidamente dois dos mais utilizados, além jit_trace()
funções no torch
Ecossistema JIT: jit_save()
e jit_load()
. Aqui estão eles:
jit_save(f_t, "/tmp/f_t")
f_t_new <- jit_load("/tmp/f_t")
Uma primeira olhada nas otimizações
Otimizações realizadas pelo torch
O compilador JIT acontece em etapas. Na primeira passagem, vemos coisas como eliminação de código morto e pré-cálculo de constantes. Tome esta função:
f <- operate(x) {
a <- 7
b <- 11
c <- 2
d <- a + b + c
e <- a + b + c + 25
x + d
}
Aqui cálculo de e
é inútil – nunca é usado. Consequentemente, na representação intermediária, e
nem aparece. Além disso, como os valores de a
, b
e c
já são conhecidos em tempo de compilação, a única constante presente no IR é d
sua soma.
Bem, podemos verificar isso por nós mesmos. Para espiar o IR – o IR inicial, para ser mais preciso – primeiro rastreamos f
e, em seguida, acesse a função rastreada graph
propriedade:
f_t <- jit_trace(f, torch_tensor(0))
f_t$graph
graph(%0 : Float(1, strides=(1), requires_grad=0, gadget=cpu)):
%1 : float = prim::Fixed(worth=20.)()
%2 : int = prim::Fixed(worth=1)()
%3 : Float(1, strides=(1), requires_grad=0, gadget=cpu) = aten::add(%0, %1, %2)
return (%3)
E realmente, o único cálculo registrado é aquele que adiciona 20 ao tensor passado.
Até agora, falamos sobre a passagem inicial do compilador JIT. Mas o processo não para por aí. Nas passagens subsequentes, a otimização se expande para o domínio das operações tensoriais.
Tome a seguinte função:
f <- operate(x) {
m1 <- torch_eye(5, gadget = "cuda")
x <- x$mul(m1)
m2 <- torch_arange(begin = 1, finish = 25, gadget = "cuda")$view(c(5,5))
x <- x$add(m2)
x <- torch_relu(x)
x$matmul(m2)
}
Por mais inofensiva que essa função possa parecer, ela incorre em bastante sobrecarga de agendamento. Uma GPU separada núcleo (uma função C, para ser paralelizada em muitos threads CUDA) é necessária para cada um dos torch_mul()
, torch_add()
, torch_relu()
e torch_matmul()
.
Sob certas condições, diversas operações podem ser encadeadas (ou fundidopara usar o termo técnico) em um único. Aqui, três desses quatro métodos (ou seja, todos, exceto torch_matmul()
) operar pontualmente; isto é, eles modificam cada elemento de um tensor isoladamente. Conseqüentemente, eles não apenas se prestam de maneira preferrred à paralelização individualmente, – o mesmo seria verdade para uma função que fosse compor (“fundi-los”): Para calcular uma função composta “multiplique e adicione então ReLU”
( relu() circ (+) circ
) em um tensorelemento
nada precisa ser conhecido sobre outros elementos do tensor. A operação agregada poderia então ser executada na GPU em um único kernel.
Para que isso aconteça, normalmente você teria que escrever um código CUDA personalizado. Graças ao compilador JIT, em muitos casos você não precisa fazer isso: ele criará esse kernel instantaneamente. graph_for()
Para ver a fusão em ação, usamos graph
(um método) em vez de
v <- jit_trace(f, torch_eye(5, gadget = "cuda"))
v$graph_for(torch_eye(5, gadget = "cuda"))
graph(%x.1 : Tensor):
%1 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0) = prim::Fixed(worth=)()
%24 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0), %25 : bool = prim::TypeCheck(sorts=(Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0)))(%x.1)
%26 : Tensor = prim::If(%25)
block0():
%x.14 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0) = prim::TensorExprGroup_0(%24)
-> (%x.14)
block1():
%34 : Perform = prim::Fixed(identify="fallback_function", fallback=1)()
%35 : (Tensor) = prim::CallFunction(%34, %x.1)
%36 : Tensor = prim::TupleUnpack(%35)
-> (%36)
%14 : Tensor = aten::matmul(%26, %1) # :7:0
return (%14)
with prim::TensorExprGroup_0 = graph(%x.1 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0)):
%4 : int = prim::Fixed(worth=1)()
%3 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0) = prim::Fixed(worth=)()
%7 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0) = prim::Fixed(worth=)()
%x.10 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0) = aten::mul(%x.1, %7) # :4:0
%x.6 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0) = aten::add(%x.10, %3, %4) # :5:0
%x.2 : Float(5, 5, strides=(5, 1), requires_grad=0, gadget=cuda:0) = aten::relu(%x.6) # :6:0
return (%x.2)
(uma propriedade): TensorExprGroup
A partir deste resultado, aprendemos que três das quatro operações foram agrupadas para formar um TensorExprGroup
. Esse
será compilado em um único kernel CUDA. A multiplicação de matrizes, entretanto – não sendo uma operação pontual – deve ser executada por si só. Neste ponto, paramos nossa exploração das otimizações JIT e passamos para o último tópico: implantação de modelo em ambientes sem R. Se você quiser saber mais, Thomas Viehmann’s weblog
torch
tem postagens que fornecem detalhes incríveis sobre a compilação (Py-)Torch JIT.
sem R jit_load()
Nosso plano é o seguinte: Definimos e treinamos um modelo, em R. Depois, rastreamos e salvamos. O arquivo salvo é então ed em outro ambiente, um ambiente que não possui R instalado. Qualquer linguagem que tenha uma implementação do Torch servirá, desde que a implementação inclua a funcionalidade JIT. A maneira mais direta de mostrar como isso funciona é usando Python. Para implantação com C++, consulte o instruções detalhadas
no website PyTorch.
Definir modelo
library(torch)
web <- nn_module(
initialize = operate() {
self$l1 <- nn_linear(3, 8)
self$l2 <- nn_linear(8, 16)
self$l3 <- nn_linear(16, 1)
self$d1 <- nn_dropout(0.2)
self$d2 <- nn_dropout(0.2)
},
ahead = operate(x) {
x %>%
self$l1() %>%
nnf_relu() %>%
self$d1() %>%
self$l2() %>%
nnf_relu() %>%
self$d2() %>%
self$l3()
}
)
train_model <- web()
Nosso modelo de exemplo é um perceptron multicamadas simples. Observe, porém, que ele possui duas camadas de eliminação. As camadas de abandono se comportam de maneira diferente durante o treinamento e a avaliação; e como aprendemos, as decisões tomadas durante o rastreamento são imutáveis. Isso é algo que precisaremos cuidar quando terminarmos de treinar o modelo.
Modelo de treinamento em conjunto de dados de brinquedo
toy_dataset <- dataset(
identify = "toy_dataset",
initialize = operate(input_dim, n) {
df <- na.omit(df)
self$x <- torch_randn(n, input_dim)
self$y <- self$x(, 1, drop = FALSE) * 0.2 -
self$x(, 2, drop = FALSE) * 1.3 -
self$x(, 3, drop = FALSE) * 0.5 +
torch_randn(n, 1)
},
.getitem = operate(i) {
record(x = self$x(i, ), y = self$y(i))
},
.size = operate() {
self$x$measurement(1)
}
)
input_dim <- 3
n <- 1000
train_ds <- toy_dataset(input_dim, n)
train_dl <- dataloader(train_ds, shuffle = TRUE)
Para fins de demonstração, criamos um conjunto de dados de brinquedo com três preditores e um alvo escalar.
optimizer <- optim_adam(train_model$parameters, lr = 0.001)
num_epochs <- 10
train_batch <- operate(b) {
optimizer$zero_grad()
output <- train_model(b$x)
goal <- b$y
loss <- nnf_mse_loss(output, goal)
loss$backward()
optimizer$step()
loss$merchandise()
}
for (epoch in 1:num_epochs) {
train_loss <- c()
coro::loop(for (b in train_dl) {
loss <- train_batch(b)
train_loss <- c(train_loss, loss)
})
cat(sprintf("nEpoch: %d, loss: %3.4fn", epoch, imply(train_loss)))
}
Epoch: 1, loss: 2.6753
Epoch: 2, loss: 1.5629
Epoch: 3, loss: 1.4295
Epoch: 4, loss: 1.4170
Epoch: 5, loss: 1.4007
Epoch: 6, loss: 1.2775
Epoch: 7, loss: 1.2971
Epoch: 8, loss: 1.2499
Epoch: 9, loss: 1.2824
Epoch: 10, loss: 1.2596
Treinamos o tempo suficiente para garantir que podemos distinguir a saída de um modelo não treinado daquela de um modelo treinado. eval
Rastrear
modo Agora, para implantação, queremos um modelo que faça não eval()
elimine quaisquer elementos tensores. Isso significa que antes de rastrear, precisamos colocar o modelo em
train_model$eval()
train_model <- jit_trace(train_model, torch_tensor(c(1.2, 3, 0.1)))
jit_save(train_model, "/tmp/mannequin.zip")
modo.
O modelo salvo agora pode ser copiado para um sistema diferente.
Modelo de consulta do Python jit.load()
Para fazer uso deste modelo do Python, nós (1, 1, 1)
e chame-o como faríamos em R. Vejamos: Para um tensor de entrada de
import torch
= torch.jit.load("/tmp/mannequin.zip")
deploy_model 1, 1, 1), dtype = torch.float)) deploy_model(torch.tensor((
tensor((-1.3630), gadget='cuda:0', grad_fn=)
esperamos uma previsão em torno de -1,6:
Isso é próximo o suficiente para nos garantir que o modelo implantado manteve os pesos do modelo treinado.
Conclusão torch
Neste put up, nos concentramos em resolver um pouco da confusão terminológica que cerca o compilador JIT e mostrou como treinar um modelo em R, rastrear
e consulte o modelo recém-carregado do Python. Deliberadamente, não entramos em casos complexos e/ou extremos – em R, esse recurso ainda está em desenvolvimento ativo. Se você tiver problemas com seu próprio código que usa JIT, não hesite em criar um problema no GitHub!
E como sempre – obrigado pela leitura! Foto de Johnny Kennaugh sobre
r-blogueiros