Weblog posit ai: segmentação de imagem com u-net


Claro, é bom quando tenho uma foto de algum objeto, e uma rede neural pode me dizer que tipo de objeto esse é. Mais realisticamente, pode haver vários objetos salientes nessa imagem, e isso me diz o que são e onde estão. A última tarefa (conhecida como Detecção de objetos) parece especialmente prototípico das aplicações contemporâneas de IA que, ao mesmo tempo, são intelectualmente fascinantes e eticamente questionáveis. É diferente com o assunto deste submit: bem -sucedido Segmentação de imagem tem muitas aplicações inegavelmente úteis. Por exemplo, é uma qua qua non em medicina, neurociência, biologia e outras ciências da vida.

Então, o que, tecnicamente, é a segmentação de imagens e como podemos treinar uma rede neural para fazê -lo?

Segmentação de imagem em poucas palavras

Digamos que temos uma imagem com um monte de gatos. Em classificaçãoa questão é “O que é isso?” E a resposta que queremos ouvir é: “Cat”. Em Detecção de objetosperguntamos novamente “o que é isso”, mas agora que “o que” é implicitamente plural, e esperamos uma resposta como “há um gato, um gato e um gato, e eles estão aqui, aqui e aqui” (think about a rede apontando, por meio de desenhar Caixas delimitadorasou seja, retângulos em torno dos objetos detectados). Em Segmentaçãoqueremos mais: queremos toda a imagem coberta por “caixas”-que não são mais caixas, mas sindicatos de “boxe” do tamanho de pixels-ou de maneira diferente: Queremos que a rede rotule cada pixel na imagem.

Aqui está um exemplo do artigo sobre o qual vamos falar em um segundo. À esquerda está a imagem de entrada (células HeLa), em seguida é a verdade do solo e a terceira é a máscara de segmentação aprendida.


Weblog posit ai: segmentação de imagem com u-net

Figura 1: Exemplo de segmentação de Ronneberger et al. 2015.

Tecnicamente, é feita uma distinção entre Segmentação de classe e Segmentação da instância. Na segmentação da aula, referindo -se ao exemplo “Bunco de gatos”, existem dois rótulos possíveis: cada pixel é “gato” ou “não gato”. A segmentação da instância é mais difícil: aqui todo gato recebe seu próprio rótulo. (Além disso, por que isso deveria ser mais difícil? Apupando a cognição do tipo humano, não seria-se eu tiver o conceito de um gato, em vez de apenas “Cattiness”, eu “vejo” que existem dois gatos, não um. Mas, dependendo do que uma rede neural específica confia na maioria das texturas, cores, peças isoladas-as tarefas podem diferem muito.

A arquitetura de rede usada neste submit é adequada para Segmentação de classe tarefas e devem ser aplicáveis ​​a um grande número de aplicações práticas, científicas e não científicas. Falando em arquitetura de rede, como deve ficar?

Apresentando U-Internet

Dado seu sucesso na classificação da imagem, não podemos apenas usar uma arquitetura clássica como Início V (n)Assim, ResnetAssim, Resnext … , qualquer que seja? O problema é que nossa tarefa em questão – rotulando todos os pixels – não se encaixa tão bem com a idéia clássica de uma CNN. Com Convnets, a idéia é aplicar camadas sucessivas de convolução e agrupamento para criar mapas de granularidade decrescente, para finalmente chegar a um nível abstrato, onde apenas dizemos: “Sim, um gato”. A contraparte é, perdemos informações detalhadas: para a classificação ultimate, não importa se os cinco pixels na área superior esquerda são pretos ou brancos.

Na prática, o uso de arquiteturas clássicas (max) agrupamento ou convoluções com stride > 1 Para alcançar essas sucessivas abstrações – necessariamente resultando em diminuição da resolução espacial. Então, como podemos usar um convnet e ainda preservar informações detalhadas? Em seu artigo de 2015 U-Internet: redes convolucionais para segmentação de imagem biomédica (Ronneberger, Fischer e Brox 2015)Olaf Ronneberger et al. Chegou ao que quatro anos depois, em 2019, ainda é a abordagem mais widespread. (O que quer dizer algo, quatro anos por muito tempo, em um aprendizado profundo.)

A ideia é incrivelmente simples. Enquanto a codificação sucessiva (convolução / poolamento máximo) etapas, como de costume, reduz a resolução, a decodificação subsequente – temos que chegar a uma saída do tamanho igual à entrada, pois queremos rotular todos os pixels! – Não simplesmente usa a amostra da camada mais compactada. Em vez disso, durante a amostragem, a cada etapa alimentamos informações da camada correspondente, em resolução, na cadeia de redução de tamanho.

Para U-Internet, realmente uma foto diz mais do que muitas palavras:


Arquitetura U-Net de Ronneberger et al. 2015.

Figura 2: Arquitetura de rede U de Ronneberger et al. 2015.

Em cada estágio de amostragem nós concatenar A saída da camada anterior com a de sua contraparte no estágio de compressão. A saída ultimate é um máscara do tamanho da imagem unique, obtida by way of 1×1-Convolução; Nenhuma camada densa ultimate é necessária; em vez disso, a camada de saída é apenas uma camada convolucional com um único filtro.

Agora vamos realmente treinar uma rede U. Nós vamos usar o unet pacote Isso permite criar um modelo de bom desempenho em uma única linha:

remotes::install_github("r-tensorflow/unet")
library(unet)

# takes extra parameters, together with variety of downsizing blocks, 
# variety of filters to begin with, and variety of lessons to determine
# see ?unet for more information
mannequin <- unet(input_shape = c(128, 128, 3))

Então, temos um modelo e parece que desejaremos alimentá -lo com imagens 128×128 RGB. Agora, como temos essas imagens?

Os dados

Para ilustrar como as aplicações surgem mesmo fora da área de pesquisa médica, usaremos como exemplo o kaggle Desafio de mascaramento de imagem de Carvana. A tarefa é criar uma máscara de segmentação que separa carros do fundo. Para nosso objetivo atual, só precisamos prepare.zip e train_mask.zip do Arquivo fornecido para obtain. A seguir, assumimos que eles foram extraídos para um subdiretório chamado data-raw.

Vamos primeiro dar uma olhada em algumas imagens e suas máscaras de segmentação associadas.

As fotos são JPEGs de espaço RGB, enquanto as máscaras são GIFs em preto e branco.

Dividimos os dados em um treinamento e um conjunto de validação. Usaremos o último para monitorar o desempenho da generalização durante o treinamento.

knowledge <- tibble(
  img = record.information(right here::right here("data-raw/prepare"), full.names = TRUE),
  masks = record.information(right here::right here("data-raw/train_masks"), full.names = TRUE)
)

knowledge <- initial_split(knowledge, prop = 0.8)

Para alimentar os dados da rede, usaremos tfdatasets. Todo o pré-processamento acabará em um pipeline simples, mas primeiro examinaremos as ações necessárias passo a passo.

Oleoduto de pré -processamento

O primeiro passo é ler nas imagens, usando as funções apropriadas em tf$picture.

training_dataset <- coaching(knowledge) %>%  
  tensor_slices_dataset() %>% 
  dataset_map(~.x %>% list_modify(
    # decode_jpeg yields a 3d tensor of form (1280, 1918, 3)
    img = tf$picture$decode_jpeg(tf$io$read_file(.x$img)),
    # decode_gif yields a 4d tensor of form (1, 1280, 1918, 3),
    # so we take away the unneeded batch dimension and all however one 
    # of the three (an identical) channels
    masks = tf$picture$decode_gif(tf$io$read_file(.x$masks))(1,,,)(,,1,drop=FALSE)
  ))

Ao construir um pipeline de pré -processamento, é muito útil verificar os resultados intermediários. É fácil de fazer usando reticulate::as_iterator No conjunto de dados:

$img
tf.Tensor(
(((243 244 239)
  (243 244 239)
  (243 244 239)
  ...
 ...
  ...
  (175 179 178)
  (175 179 178)
  (175 179 178))), form=(1280, 1918, 3), dtype=uint8)

$masks
tf.Tensor(
(((0)
  (0)
  (0)
  ...
 ...
  ...
  (0)
  (0)
  (0))), form=(1280, 1918, 1), dtype=uint8)

Enquanto o uint8 O Datatype facilita a leitura dos valores de RGB para os seres humanos, a rede espera números de ponto flutuante. O código a seguir converte sua entrada e, além disso, escala valores para o intervalo (0,1):

training_dataset <- training_dataset %>% 
  dataset_map(~.x %>% list_modify(
    img = tf$picture$convert_image_dtype(.x$img, dtype = tf$float32),
    masks = tf$picture$convert_image_dtype(.x$masks, dtype = tf$float32)
  ))

Para reduzir o custo computacional, redimensionamos as imagens para o tamanho 128x128. Isso mudará a proporção e, portanto, distorcerá as imagens, mas não é um problema com o conjunto de dados especificado.

training_dataset <- training_dataset %>% 
  dataset_map(~.x %>% list_modify(
    img = tf$picture$resize(.x$img, measurement = form(128, 128)),
    masks = tf$picture$resize(.x$masks, measurement = form(128, 128))
  ))

Agora, é sabido que, no aprendizado profundo, o aumento de dados é elementary. Para a segmentação, há uma coisa a considerar, que é se uma transformação precisa ser aplicada à máscara – esse seria o caso das rotações, por exemplo, ou inverte. Aqui, os resultados serão bons o suficiente aplicando apenas transformações que preservam posições:

random_bsh <- operate(img) {
  img %>% 
    tf$picture$random_brightness(max_delta = 0.3) %>% 
    tf$picture$random_contrast(decrease = 0.5, higher = 0.7) %>% 
    tf$picture$random_saturation(decrease = 0.5, higher = 0.7) %>% 
    # be sure that we nonetheless are between 0 and 1
    tf$clip_by_value(0, 1) 
}

training_dataset <- training_dataset %>% 
  dataset_map(~.x %>% list_modify(
    img = random_bsh(.x$img)
  ))

Novamente, podemos usar as_iterator Para ver o que essas transformações fazem em nossas imagens:

Aqui está o pipeline completo de pré -processamento.

create_dataset <- operate(knowledge, prepare, batch_size = 32L) {
  
  dataset <- knowledge %>% 
    tensor_slices_dataset() %>% 
    dataset_map(~.x %>% list_modify(
      img = tf$picture$decode_jpeg(tf$io$read_file(.x$img)),
      masks = tf$picture$decode_gif(tf$io$read_file(.x$masks))(1,,,)(,,1,drop=FALSE)
    )) %>% 
    dataset_map(~.x %>% list_modify(
      img = tf$picture$convert_image_dtype(.x$img, dtype = tf$float32),
      masks = tf$picture$convert_image_dtype(.x$masks, dtype = tf$float32)
    )) %>% 
    dataset_map(~.x %>% list_modify(
      img = tf$picture$resize(.x$img, measurement = form(128, 128)),
      masks = tf$picture$resize(.x$masks, measurement = form(128, 128))
    ))
  
  # knowledge augmentation carried out on coaching set solely
  if (prepare) {
    dataset <- dataset %>% 
      dataset_map(~.x %>% list_modify(
        img = random_bsh(.x$img)
      )) 
  }
  
  # shuffling on coaching set solely
  if (prepare) {
    dataset <- dataset %>% 
      dataset_shuffle(buffer_size = batch_size*128)
  }
  
  # prepare in batches; batch measurement may should be tailored relying on
  # accessible reminiscence
  dataset <- dataset %>% 
    dataset_batch(batch_size)
  
  dataset %>% 
    # output must be unnamed
    dataset_map(unname) 
}

A criação de treinamento e conjunto de testes agora é apenas uma questão de duas chamadas de função.

training_dataset <- create_dataset(coaching(knowledge), prepare = TRUE)
validation_dataset <- create_dataset(testing(knowledge), prepare = FALSE)

E estamos prontos para treinar o modelo.

Treinando o modelo

Já mostramos como criar o modelo, mas vamos repeti -lo aqui e verificar a arquitetura do modelo:

mannequin <- unet(input_shape = c(128, 128, 3))
abstract(mannequin)
Mannequin: "mannequin"
______________________________________________________________________________________________
Layer (sort)                   Output Form        Param #    Linked to                    
==============================================================================================
input_1 (InputLayer)           ((None, 128, 128, 3 0                                          
______________________________________________________________________________________________
conv2d (Conv2D)                (None, 128, 128, 64 1792       input_1(0)(0)                   
______________________________________________________________________________________________
conv2d_1 (Conv2D)              (None, 128, 128, 64 36928      conv2d(0)(0)                    
______________________________________________________________________________________________
max_pooling2d (MaxPooling2D)   (None, 64, 64, 64)  0          conv2d_1(0)(0)                  
______________________________________________________________________________________________
conv2d_2 (Conv2D)              (None, 64, 64, 128) 73856      max_pooling2d(0)(0)             
______________________________________________________________________________________________
conv2d_3 (Conv2D)              (None, 64, 64, 128) 147584     conv2d_2(0)(0)                  
______________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 32, 32, 128) 0          conv2d_3(0)(0)                  
______________________________________________________________________________________________
conv2d_4 (Conv2D)              (None, 32, 32, 256) 295168     max_pooling2d_1(0)(0)           
______________________________________________________________________________________________
conv2d_5 (Conv2D)              (None, 32, 32, 256) 590080     conv2d_4(0)(0)                  
______________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 16, 16, 256) 0          conv2d_5(0)(0)                  
______________________________________________________________________________________________
conv2d_6 (Conv2D)              (None, 16, 16, 512) 1180160    max_pooling2d_2(0)(0)           
______________________________________________________________________________________________
conv2d_7 (Conv2D)              (None, 16, 16, 512) 2359808    conv2d_6(0)(0)                  
______________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 8, 8, 512)   0          conv2d_7(0)(0)                  
______________________________________________________________________________________________
dropout (Dropout)              (None, 8, 8, 512)   0          max_pooling2d_3(0)(0)           
______________________________________________________________________________________________
conv2d_8 (Conv2D)              (None, 8, 8, 1024)  4719616    dropout(0)(0)                   
______________________________________________________________________________________________
conv2d_9 (Conv2D)              (None, 8, 8, 1024)  9438208    conv2d_8(0)(0)                  
______________________________________________________________________________________________
conv2d_transpose (Conv2DTransp (None, 16, 16, 512) 2097664    conv2d_9(0)(0)                  
______________________________________________________________________________________________
concatenate (Concatenate)      (None, 16, 16, 1024 0          conv2d_7(0)(0)                  
                                                              conv2d_transpose(0)(0)          
______________________________________________________________________________________________
conv2d_10 (Conv2D)             (None, 16, 16, 512) 4719104    concatenate(0)(0)               
______________________________________________________________________________________________
conv2d_11 (Conv2D)             (None, 16, 16, 512) 2359808    conv2d_10(0)(0)                 
______________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTran (None, 32, 32, 256) 524544     conv2d_11(0)(0)                 
______________________________________________________________________________________________
concatenate_1 (Concatenate)    (None, 32, 32, 512) 0          conv2d_5(0)(0)                  
                                                              conv2d_transpose_1(0)(0)        
______________________________________________________________________________________________
conv2d_12 (Conv2D)             (None, 32, 32, 256) 1179904    concatenate_1(0)(0)             
______________________________________________________________________________________________
conv2d_13 (Conv2D)             (None, 32, 32, 256) 590080     conv2d_12(0)(0)                 
______________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTran (None, 64, 64, 128) 131200     conv2d_13(0)(0)                 
______________________________________________________________________________________________
concatenate_2 (Concatenate)    (None, 64, 64, 256) 0          conv2d_3(0)(0)                  
                                                              conv2d_transpose_2(0)(0)        
______________________________________________________________________________________________
conv2d_14 (Conv2D)             (None, 64, 64, 128) 295040     concatenate_2(0)(0)             
______________________________________________________________________________________________
conv2d_15 (Conv2D)             (None, 64, 64, 128) 147584     conv2d_14(0)(0)                 
______________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTran (None, 128, 128, 64 32832      conv2d_15(0)(0)                 
______________________________________________________________________________________________
concatenate_3 (Concatenate)    (None, 128, 128, 12 0          conv2d_1(0)(0)                  
                                                              conv2d_transpose_3(0)(0)        
______________________________________________________________________________________________
conv2d_16 (Conv2D)             (None, 128, 128, 64 73792      concatenate_3(0)(0)             
______________________________________________________________________________________________
conv2d_17 (Conv2D)             (None, 128, 128, 64 36928      conv2d_16(0)(0)                 
______________________________________________________________________________________________
conv2d_18 (Conv2D)             (None, 128, 128, 1) 65         conv2d_17(0)(0)                 
==============================================================================================
Complete params: 31,031,745
Trainable params: 31,031,745
Non-trainable params: 0
______________________________________________________________________________________________

A coluna “forma de saída” mostra a forma de U 8x8; Eles então sobem novamente, até chegarmos à resolução unique. Ao mesmo tempo, o número de filtros aumenta primeiro e depois desce novamente, até que na camada de saída tenhamos um único filtro. Você também pode ver o concatenate Camadas que anexam informações provenientes de “abaixo” a informações que vêm “lateralmente”.

Qual deve ser a função de perda aqui? Estamos rotulando cada pixel, então cada pixel contribui para a perda. Temos um problema binário – cada pixel pode ser “carro” ou “fundo” – por isso queremos que cada saída esteja próxima de 0 ou 1. Isso faz binário_crossentropy a função de perda adequada.

Durante o treinamento, acompanhamos a precisão da classificação, bem como o Coeficiente de dadosa métrica de avaliação usada na competição. O coeficiente de dados é uma maneira de medir a proporção de classificações corretas:

cube <- custom_metric("cube", operate(y_true, y_pred, clean = 1.0) {
  y_true_f <- k_flatten(y_true)
  y_pred_f <- k_flatten(y_pred)
  intersection <- k_sum(y_true_f * y_pred_f)
  (2 * intersection + clean) / (k_sum(y_true_f) + k_sum(y_pred_f) + clean)
})

mannequin %>% compile(
  optimizer = optimizer_rmsprop(lr = 1e-5),
  loss = "binary_crossentropy",
  metrics = record(cube, metric_binary_accuracy)
)

O ajuste do modelo leva algum tempo – quanto, é claro, dependerá do seu {hardware}. Mas a espera paga: Após cinco épocas, vimos um coeficiente de dados de ~ 0,87 no conjunto de validação e uma precisão de ~ 0,95.

Previsões

Obviamente, o que estamos interessados ​​são previsões. Vamos ver algumas máscaras geradas para itens do conjunto de validação:

batch <- validation_dataset %>% as_iterator() %>% iter_next()
predictions <- predict(mannequin, batch)

photos <- tibble(
  picture = batch((1)) %>% array_branch(1),
  predicted_mask = predictions(,,,1) %>% array_branch(1),
  masks = batch((2))(,,,1)  %>% array_branch(1)
) %>% 
  sample_n(2) %>% 
  map_depth(2, operate(x) {
    as.raster(x) %>% magick::image_read()
  }) %>% 
  map(~do.name(c, .x))


out <- magick::image_append(c(
  magick::image_append(photos$masks, stack = TRUE),
  magick::image_append(photos$picture, stack = TRUE), 
  magick::image_append(photos$predicted_mask, stack = TRUE)
  )
)

plot(out)

Da esquerda para a direita: Verdade do solo, imagem de entrada e máscara prevista da U-Net.

Figura 3: Da esquerda para a direita: Verdade do solo, imagem de entrada e máscara prevista da U-Internet.

Conclusão

Se houvesse uma competição pela maior quantia de utilidade e transparência arquitetônica, a U-Internet certamente seria um candidato. Sem muita ajuste, é possível obter resultados decentes. Se você puder colocar esse modelo em seu trabalho ou se tiver problemas para usá -lo, informe -nos! Obrigado pela leitura!

Ronneberger, Olaf, Philipp Fischer e Thomas Brox. 2015. “U-Internet: redes convolucionais para segmentação de imagem biomédica.” Corr ABS/1505.04597. http://arxiv.org/abs/1505.04597.

Deixe um comentário

O seu endereço de e-mail não será publicado. Campos obrigatórios são marcados com *