Temos o prazer de anunciar que a versão 0.2.0 do torch
acabei de pousar no CRAN.
Esta versão inclui muitas correções de bugs e alguns novos recursos interessantes que apresentaremos nesta postagem do weblog. Você pode ver o changelog completo no NOTÍCIAS.md arquivo.
Os recursos que discutiremos em detalhes são:
- Suporte inicial para rastreamento JIT
- Carregadores de dados multifuncionais
- Métodos de impressão para
nn_modules
Carregadores de dados multifuncionais
dataloaders
agora responda ao num_workers
argumento e executará o pré-processamento em trabalhadores paralelos.
Por exemplo, digamos que temos o seguinte conjunto de dados fictício que faz um cálculo longo:
library(torch)
dat <- dataset(
"mydataset",
initialize = perform(time, len = 10) {
self$time <- time
self$len <- len
},
.getitem = perform(i) {
Sys.sleep(self$time)
torch_randn(1)
},
.size = perform() {
self$len
}
)
ds <- dat(1)
system.time(ds(1))
person system elapsed
0.029 0.005 1.027
Criaremos agora dois dataloaders, um que executa sequencialmente e outro que executa em paralelo.
seq_dl <- dataloader(ds, batch_size = 5)
par_dl <- dataloader(ds, batch_size = 5, num_workers = 2)
Agora podemos comparar o tempo que leva para processar dois lotes sequencialmente com o tempo que leva em paralelo:
seq_it <- dataloader_make_iter(seq_dl)
par_it <- dataloader_make_iter(par_dl)
two_batches <- perform(it) {
dataloader_next(it)
dataloader_next(it)
"okay"
}
system.time(two_batches(seq_it))
system.time(two_batches(par_it))
person system elapsed
0.098 0.032 10.086
person system elapsed
0.065 0.008 5.134
Observe que são os lotes obtidos em paralelo, e não as observações individuais. Assim, poderemos oferecer suporte a conjuntos de dados com tamanhos de lote variáveis no futuro.
Usar vários trabalhadores é não necessariamente mais rápido que a execução serial porque há uma sobrecarga considerável ao passar tensores de um trabalhador para a sessão principal, bem como ao inicializar os trabalhadores.
Este recurso é habilitado pelo poderoso callr
pacote e funciona em todos os sistemas operacionais suportados por torch
. callr
vamos criar sessões R persistentes e, portanto, pagaremos apenas uma vez a sobrecarga de transferência de objetos de conjuntos de dados potencialmente grandes para os trabalhadores.
No processo de implementação deste recurso, fizemos com que os dataloaders se comportassem como coro
iteradores. Isso significa que agora você pode usar coro
Sintaxe do para percorrer os dataloaders:
coro::loop(for(batch in par_dl) {
print(batch$form)
})
(1) 5 1
(1) 5 1
Este é o primeiro torch
versão incluindo o recurso de carregadores de dados para vários trabalhadores, e você pode se deparar com casos extremos ao usá-lo. Informe-nos se encontrar algum problema.
Suporte JIT inicial
Programas que fazem uso do torch
pacote são inevitavelmente programas R e, portanto, eles sempre precisam de uma instalação R para serem executados.
A partir da versão 0.2.0, torch
permite que os usuários façam JIT rastrear
torch
Funções R no TorchScript. O rastreamento JIT (Simply in time) invocará uma função R com entradas de exemplo, registrará todas as operações que ocorreram quando a função foi executada e retornará um script_function
objeto contendo a representação TorchScript.
O bom disso é que os programas TorchScript são facilmente serializáveis, otimizáveis e podem ser carregados por outro programa escrito em PyTorch ou LibTorch sem exigir qualquer dependência de R.
Suponha que você tenha a seguinte função R que pega um tensor e faz uma multiplicação de matriz com uma matriz de peso fixo e depois adiciona um termo de polarização:
w <- torch_randn(10, 1)
b <- torch_randn(1)
fn <- perform(x) {
a <- torch_mm(x, w)
a + b
}
Esta função pode ser rastreada por JIT no TorchScript com jit_trace
passando a função e as entradas de exemplo:
x <- torch_ones(2, 10)
tr_fn <- jit_trace(fn, x)
tr_fn(x)
torch_tensor
-0.6880
-0.6880
( CPUFloatType{2,1} )
Agora tudo torch
as operações que aconteceram ao calcular o resultado desta função foram rastreadas e transformadas em um gráfico:
graph(%0 : Float(2:10, 10:1, requires_grad=0, gadget=cpu)):
%1 : Float(10:1, 1:1, requires_grad=0, gadget=cpu) = prim::Fixed(worth=-0.3532 0.6490 -0.9255 0.9452 -1.2844 0.3011 0.4590 -0.2026 -1.2983 1.5800 ( CPUFloatType{10,1} ))()
%2 : Float(2:1, 1:1, requires_grad=0, gadget=cpu) = aten::mm(%0, %1)
%3 : Float(1:1, requires_grad=0, gadget=cpu) = prim::Fixed(worth={-0.558343})()
%4 : int = prim::Fixed(worth=1)()
%5 : Float(2:1, 1:1, requires_grad=0, gadget=cpu) = aten::add(%2, %3, %4)
return (%5)
A função rastreada pode ser serializada com jit_save
:
jit_save(tr_fn, "linear.pt")
Ele pode ser recarregado em R com jit_load
mas também pode ser recarregado em Python com torch.jit.load
:
import torch
= torch.jit.load("linear.pt")
fn 2, 10)) fn(torch.ones(
tensor(((-0.6880),
(-0.6880)))
Quão authorized é isso?!
Este é apenas o suporte inicial para JIT em R. Continuaremos desenvolvendo isso. Especificamente, na próxima versão do torch
planejamos oferecer suporte ao rastreamento nn_modules
diretamente. Atualmente, é necessário desanexar todos os parâmetros antes de rastreá-los; veja um exemplo aqui. Isso permitirá que você também aproveite os benefícios do TorchScript para fazer seus modelos rodarem mais rápido!
Observe também que o rastreamento tem algumas limitações, especialmente quando seu código possui loops ou instruções de fluxo de controle que dependem de dados do tensor. Ver ?jit_trace
para saber mais.
Novo método de impressão para nn_modules
Nesta versão também melhoramos o nn_module
métodos de impressão para facilitar a compreensão do que está dentro.
Por exemplo, se você criar uma instância de um nn_linear
módulo você verá:
An `nn_module` containing 11 parameters.
── Parameters ──────────────────────────────────────────────────────────────────
● weight: Float (1:1, 1:10)
● bias: Float (1:1)
Você vê imediatamente o número complete de parâmetros no módulo, bem como seus nomes e formas.
Isso também funciona para módulos personalizados (possivelmente incluindo submódulos). Por exemplo:
my_module <- nn_module(
initialize = perform() {
self$linear <- nn_linear(10, 1)
self$param <- nn_parameter(torch_randn(5,1))
self$buff <- nn_buffer(torch_randn(5))
}
)
my_module()
An `nn_module` containing 16 parameters.
── Modules ─────────────────────────────────────────────────────────────────────
● linear: #11 parameters
── Parameters ──────────────────────────────────────────────────────────────────
● param: Float (1:5, 1:1)
── Buffers ─────────────────────────────────────────────────────────────────────
● buff: Float (1:5)
Esperamos que isso facilite a compreensão nn_module
objetos. Também melhoramos o suporte ao preenchimento automático para nn_modules
e agora mostraremos todos os submódulos, parâmetros e buffers enquanto você digita.
áudio da tocha
torchaudio
é uma extensão para torch
desenvolvido por Athos Damiani (@athospd
), fornecendo carregamento de áudio, transformações, arquiteturas comuns para processamento de sinal, pesos pré-treinados e acesso a conjuntos de dados comumente usados. Uma tradução quase literal da biblioteca Torchaudio do PyTorch para R.
torchaudio
ainda não está no CRAN, mas você já pode experimentar a versão de desenvolvimento disponível aqui.
Você também pode visitar o pkgdown
web site para exemplos e documentação de referência.
Outros recursos e correções de bugs
Graças às contribuições da comunidade, encontramos e corrigimos muitos bugs no torch
. Também adicionamos novos recursos, incluindo:
Você pode ver a lista completa de alterações no NOTÍCIAS.md arquivo.
Muito obrigado por ler esta postagem do weblog e sinta-se à vontade para entrar em contato com o GitHub para obter ajuda ou discussões!
A foto usada nesta prévia do submit é de Oleg Illarionov sobre Remover respingo