3D teeth instance segmentation. In the dark, but not alone

3D teeth segmentation, from getting the data to the final result. Roughly.





Disclaimer

This article is not educational in any sense of the word and is purely informative. The author is not responsible for the time you spend reading it.





About the author

Hi everyone, my name is Andrey (27). I'll try to keep it short. Why programming? By training I have a bachelor's degree in electromechanics, and I know the profession. I worked for 2 years as a power engineer at a drilling company, quite successfully, but instead of taking a promotion I handed in my resignation: I had burned out, and it simply wasn't for me. I like creating things and finding solutions to complex problems, and I have been hugging a PC for as long as I can remember. The choice was obvious. At first (six months ago) I seriously considered signing up for I courses or similar ones. I read the reviews, talked to participants and realized that getting information was not the problem. Then I found the site, got a Python foundation there and started my journey with it (I am now gradually studying everything ML-related there). I was immediately interested in machine learning, computer vision in particular. I had a problem to solve, and here I am (for me, this is a great way to learn).





1. Introduction

After several unsuccessful attempts, I decided to use two lightweight models to get the desired result. The first segments all teeth as a single category [0, 1], and the second splits them into categories [0, 8]. But let's take it step by step.





2. Finding and preparing the data

After spending more than one evening searching for data to work with, I concluded that a free jaw scan of good quality and in a usable format (*.stl, *.nrrd, etc.) was not to be found. The best I found was a sample dataset in 3D Slicer: the head of a patient after jaw surgery.





Obviously I don't need the whole head, so I cropped the source in the same program to 163 * 112 * 120 px (in this post {x * y * z = wdh} and 1 px = 0.5 mm), leaving only the teeth and the associated maxillofacial structures.
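For orientation, the crop converts to physical size as follows. A trivial sketch that only uses the numbers given above (crop size and 0.5 mm per voxel):

```python
# Crop size (x, y, z) in voxels and the voxel spacing stated above (1 px = 0.5 mm)
shape_px = (163, 112, 120)
spacing_mm = 0.5

# physical extent of the crop along each axis, in millimetres
extent_mm = tuple(n * spacing_mm for n in shape_px)
print(extent_mm)  # (81.5, 56.0, 60.0)
```

So the whole working volume is roughly 8 x 6 x 6 cm.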





, - . . , - "autothreshold" , , , , ( ).





- ( )? -
- ( )? -

12~14. , 4 . , .





Smooth 0.5.

, ( ) , . , , N- , random-crop .





import nrrd
import torch
import torchvision.transforms as tf


class DataBuilder:
    def __init__(self,
                 data_path,
                 list_of_categories,
                 num_of_chunks: int = 0,
                 augmentation_coeff: int = 0,
                 num_of_classes: int = 0,
                 normalise: bool = False,
                 fit: bool = True,
                 data_format: int = 0,
                 save_data: bool = False
                 ):
        self.data_path = data_path
        self.number_of_chunks = num_of_chunks
        self.augmentation_coeff = augmentation_coeff
        self.list_of_cats = list_of_categories
        self.num_of_cls = num_of_classes
        self.normalise = normalise
        self.fit = fit
        self.data_format = data_format
        self.save_data = save_data

    def forward(self):
        # Pipeline: load -> pad -> normalise -> augment -> split into chunks -> remap categories
        data = self.get_data()
        if self.fit:
            data = self.fit_data(data)
        if self.normalise:
            data = self.pre_normalize(data)
        if self.augmentation_coeff != 0:
            data = self.data_augmentation(data, self.augmentation_coeff)
        if self.number_of_chunks != 0:
            data = self.new_chunks(data, self.number_of_chunks)
        if self.num_of_cls != 0:
            data = self.category_splitter(data, self.num_of_cls, self.list_of_cats)
        if self.save_data:
            torch.save(data, self.data_path[-14:] + '.pt')

        # Insert the channel dimension, e.g. (21, 72, 120, 120) -> (21, 1, 72, 120, 120)
        return torch.unsqueeze(data, 1)

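    def get_data(self):
        # data_format: 0 = *.nrrd file, 1 = saved tensor (torch.load), 2 = tensor passed in as data_path
        if self.data_format == 0:
            return torch.from_numpy(nrrd.read(self.data_path)[0])
        elif self.data_format == 1:
            return torch.load(self.data_path).cpu()
        elif self.data_format == 2:
            return torch.unsqueeze(self.data_path, 0).cpu()
        # unknown format: stop here rather than pass None down the pipeline
        raise ValueError('Available types are: "nrrd", "tensor" or "self.tensor(w/o load)"')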

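    @staticmethod
    def fit_data(some_data):
        # Zero-pad the (x, y, z) = (163, 112, 120) crop to (168, 120, 120).
        # (x, y, z) -> (y, z, x), so that ZeroPad2d can prepend 5 zero slices along x
        data = torch.movedim(some_data, (1, 0), (0, -1))
        data_add_x = torch.nn.ZeroPad2d((5, 0, 0, 0))
        data = data_add_x(data)
        # back to (x, y, z), then prepend 8 zero slices along y (z is not padded)
        data = torch.movedim(data, -1, 0)
        data_add_y = torch.nn.ZeroPad2d((0, 0, 8, 0))

        return data_add_y(data)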

    @staticmethod
    def pre_normalize(some_data):
        min_d, max_d = torch.min(some_data), torch.max(some_data)

        return (some_data - min_d) / (max_d - min_d)

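    @staticmethod
    def data_augmentation(some_data, aug_n):
        torch.manual_seed(17)
        tr_data = []
        for e in range(aug_n):
            # degrees=(20*e, 20*e) is a fixed rotation by exactly 20*e degrees (e = 0 keeps the slice as is);
            # the default nearest-neighbour interpolation does not blend mask labels
            transform = tf.RandomRotation(degrees=(20*e, 20*e))
            for image in some_data:
                # each 2D slice along dim 0 becomes (1, H, W), the layout the transform expects
                image = torch.unsqueeze(image, 0)
                image = transform(image)
                tr_data.append(image)

        # list of aug_n * len(some_data) tensors of shape (1, H, W)
        return tr_data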

    def new_chunks(self, some_data, n_ch):
        data = torch.stack(some_data, 0) if self.augmentation_coeff != 0 else some_data
        data = torch.squeeze(data, 1)
        chunks = torch.chunk(data, n_ch, 0)

        return torch.stack(chunks)

    @staticmethod
    def category_splitter(some_data, alpha, list_of_categories):
        data, _ = torch.squeeze(some_data, 1).to(torch.int64), alpha
        for i in list_of_categories:
            data = torch.where(data < i, _, data)
            _ += 1

        return data - alpha
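fit_data reaches 168*120*120 by moving axes around so that ZeroPad2d can reach them: it prepends 5 zero slices along x and 8 along y, and leaves z untouched. As a sketch (not the author's code; fit_data_pad is a hypothetical name), the same tensor can be obtained with a single torch.nn.functional.pad call, whose padding tuple lists (before, after) pairs starting from the last dimension:

```python
import torch
import torch.nn.functional as F


def fit_data(some_data):
    # Copy of DataBuilder.fit_data from the listing above, for comparison
    data = torch.movedim(some_data, (1, 0), (0, -1))
    data = torch.nn.ZeroPad2d((5, 0, 0, 0))(data)
    data = torch.movedim(data, -1, 0)
    return torch.nn.ZeroPad2d((0, 0, 8, 0))(data)


def fit_data_pad(volume):
    # Pairs from the LAST dim backwards: z -> (0, 0), y -> (8, 0), x -> (5, 0)
    return F.pad(volume, (0, 0, 8, 0, 5, 0))


volume = torch.randn(163, 112, 120)  # stand-in for the (x, y, z) scan
a, b = fit_data(volume), fit_data_pad(volume)
print(a.shape)            # torch.Size([168, 120, 120])
print(torch.equal(a, b))  # True
```

Both versions only prepend zeros, so the original voxels end up at [5:, 8:, :] of the padded volume.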

      
      



3D U-net. :





  • ( ).





  • Zero-padding to 168*120*120 (from 163*112*120).





  • Normalization to 0...1 (from ~-2000...16000).





  • N- .





  • ( 1, 1, 72, 120, 120).





  • 28 (. ):





    • for the 1st model;





    • 9 (8 + background) for the 2nd model.





Dataloader
import torch.utils.data as tud


class ToothDataset(tud.Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self): return len(self.images)

    def __getitem__(self, index):
        if self.masks is not None:
            return self.images[index, :, :, :, :],\
                    self.masks[index, :, :, :, :]
        else:
            return self.images[index, :, :, :, :]


def get_loaders(images, masks,
                batch_size: int = 1,
                num_workers: int = 1,
                pin_memory: bool = True):

    train_ds = ToothDataset(images=images,
                            masks=masks)

    data_loader = tud.DataLoader(train_ds,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 pin_memory=pin_memory)

    return data_loader
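A quick usage sketch of the loader with random tensors in place of the real data (shapes reduced from, e.g., (21, 1, 72, 120, 120) to keep it light). It uses torch's built-in TensorDataset, which for the images-plus-masks case indexes both tensors along the first dimension just like ToothDataset does; it is a stand-in, not the article's class:

```python
import torch
import torch.utils.data as tud

# Random stand-ins: 21 chunks of (1, depth, height, width), images in [0, 1], binary masks
images = torch.rand(21, 1, 8, 16, 16)
masks = torch.randint(0, 2, (21, 1, 8, 16, 16))

# TensorDataset returns (images[i], masks[i]), as ToothDataset.__getitem__ does
dataset = tud.TensorDataset(images, masks)
loader = tud.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([1, 1, 8, 16, 16]) torch.Size([1, 1, 8, 16, 16])
print(len(loader))       # 21 batches per epoch
```

With batch_size=1 every batch is one 5D volume (batch, channel, depth, height, width), which is what a 3D U-Net expects.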

      
      



         Semantic                          Instance                          Predictions
Data     (27*, 1, 56*, 120, 120) [0...1]   (27*, 1, 56*, 120, 120) [0, 1]    (1, 1, 168, 120, 120) [0...1]
Masks    (27*, 1, 56*, 120, 120) [0, 1]    (27*, 1, 56*, 120, 120) [0, 8]    -





* , , - .





3.

- . U-Net. , .





2D U-Net

, . - Adam, Dice-loss(implement), / 4, [64, 128, 256, 512] (, , - ). 60-80 epochs . Transfer learning .





model.summary()
model = UNet(dim=2, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device)
print(summary(model, (1, 168, 120)))

"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 168, 120]             640
              ReLU-2         [-1, 64, 168, 120]               0
       BatchNorm2d-3         [-1, 64, 168, 120]             128
            Conv2d-4         [-1, 64, 168, 120]          36,928
              ReLU-5         [-1, 64, 168, 120]               0
       BatchNorm2d-6         [-1, 64, 168, 120]             128
         MaxPool2d-7           [-1, 64, 84, 60]               0
         DownBlock-8  [[-1, 64, 84, 60], [-1, 64, 168, 120]]  0
            Conv2d-9          [-1, 128, 84, 60]          73,856
             ReLU-10          [-1, 128, 84, 60]               0
      BatchNorm2d-11          [-1, 128, 84, 60]             256
           Conv2d-12          [-1, 128, 84, 60]         147,584
             ReLU-13          [-1, 128, 84, 60]               0
      BatchNorm2d-14          [-1, 128, 84, 60]             256
        MaxPool2d-15          [-1, 128, 42, 30]               0
        DownBlock-16  [[-1, 128, 42, 30], [-1, 128, 84, 60]]  0
           Conv2d-17          [-1, 256, 42, 30]         295,168
             ReLU-18          [-1, 256, 42, 30]               0
      BatchNorm2d-19          [-1, 256, 42, 30]             512
           Conv2d-20          [-1, 256, 42, 30]         590,080
             ReLU-21          [-1, 256, 42, 30]               0
      BatchNorm2d-22          [-1, 256, 42, 30]             512
        MaxPool2d-23          [-1, 256, 21, 15]               0
        DownBlock-24  [[-1, 256, 21, 15], [-1, 256, 42, 30]]  0
           Conv2d-25          [-1, 512, 21, 15]       1,180,160
             ReLU-26          [-1, 512, 21, 15]               0
      BatchNorm2d-27          [-1, 512, 21, 15]           1,024
           Conv2d-28          [-1, 512, 21, 15]       2,359,808
             ReLU-29          [-1, 512, 21, 15]               0
      BatchNorm2d-30          [-1, 512, 21, 15]           1,024
        DownBlock-31  [[-1, 512, 21, 15], [-1, 512, 21, 15]]  0
  ConvTranspose2d-32          [-1, 256, 42, 30]         524,544
             ReLU-33          [-1, 256, 42, 30]               0
      BatchNorm2d-34          [-1, 256, 42, 30]             512
      Concatenate-35          [-1, 512, 42, 30]               0
           Conv2d-36          [-1, 256, 42, 30]       1,179,904
             ReLU-37          [-1, 256, 42, 30]               0
      BatchNorm2d-38          [-1, 256, 42, 30]             512
           Conv2d-39          [-1, 256, 42, 30]         590,080
             ReLU-40          [-1, 256, 42, 30]               0
      BatchNorm2d-41          [-1, 256, 42, 30]             512
          UpBlock-42          [-1, 256, 42, 30]               0
  ConvTranspose2d-43          [-1, 128, 84, 60]         131,200
             ReLU-44          [-1, 128, 84, 60]               0
      BatchNorm2d-45          [-1, 128, 84, 60]             256
      Concatenate-46          [-1, 256, 84, 60]               0
           Conv2d-47          [-1, 128, 84, 60]         295,040
             ReLU-48          [-1, 128, 84, 60]               0
      BatchNorm2d-49          [-1, 128, 84, 60]             256
           Conv2d-50          [-1, 128, 84, 60]         147,584
             ReLU-51          [-1, 128, 84, 60]               0
      BatchNorm2d-52          [-1, 128, 84, 60]             256
          UpBlock-53          [-1, 128, 84, 60]               0
  ConvTranspose2d-54         [-1, 64, 168, 120]          32,832
             ReLU-55         [-1, 64, 168, 120]               0
      BatchNorm2d-56         [-1, 64, 168, 120]             128
      Concatenate-57        [-1, 128, 168, 120]               0
           Conv2d-58         [-1, 64, 168, 120]          73,792
             ReLU-59         [-1, 64, 168, 120]               0
      BatchNorm2d-60         [-1, 64, 168, 120]             128
           Conv2d-61         [-1, 64, 168, 120]          36,928
             ReLU-62         [-1, 64, 168, 120]               0
      BatchNorm2d-63         [-1, 64, 168, 120]             128
          UpBlock-64         [-1, 64, 168, 120]               0
           Conv2d-65          [-1, 1, 168, 120]              65
================================================================
Total params: 7,702,721
Trainable params: 7,702,721
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.08
Forward/backward pass size (MB): 7434.08
Params size (MB): 29.38
Estimated Total Size (MB): 7463.54
"""
      
      



Exp. No. 1: 2D U-Net, [x, z]

, - . , . numpy - *.stl 6. , :





1. [x, y]. 2. [x, z]. 3. [y, z]
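Each of the three planes corresponds to a different way of cutting the volume into a stack of 2D images for a 2D U-Net. A minimal numpy sketch, assuming the volume is stored as (x, y, z) like the cropped scan (163, 112, 120): the axis orthogonal to the plane is moved to the front, so indexing along dim 0 yields one 2D slice.

```python
import numpy as np

volume = np.random.rand(163, 112, 120).astype(np.float32)  # stand-in scan stored as (x, y, z)

# Move the axis orthogonal to the plane to the front: slices[i] is then one 2D image
slices_xy = np.moveaxis(volume, 2, 0)  # 120 slices in the [x, y] plane, each 163 x 112
slices_xz = np.moveaxis(volume, 1, 0)  # 112 slices in the [x, z] plane, each 163 x 120
slices_yz = volume                     # 163 slices in the [y, z] plane, each 112 x 120

print(slices_xy.shape, slices_xz.shape, slices_yz.shape)
# (120, 163, 112) (112, 163, 120) (163, 112, 120)
```

With the padded (168, 120, 120) volume, [x, z] slices are 168 x 120, the input size (1, 168, 120) used in the 2D summary.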

100% , ? , .





, , , , , .





Exp. No. 2: 2- 2D U-Net, [y, z]

, , :





Exp. No. 3: 2- 2D U-Net, [y, z], 50%

3D . , (24*, 120, 120). ? - (~22. ). (1063gtx) .





24*

. :





  • (1512, 120, 120) - 63;





  • batch size (24, 120, 120) - , ;





  • (24) / ( 24/2/2/2=3 3*2*2*2=24, / 2 / 1);





  • , . .summary()
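The divisibility requirement on the depth (24/2/2/2 = 3, 3*2*2*2 = 24) can be checked mechanically. A small sketch, assuming the layer pattern visible in the summaries: a U-Net with n_blocks blocks has n_blocks - 1 pooling layers, each halving the depth, and the transposed convolutions must double it back to the original size so that the skip connections can be concatenated. The helper names are hypothetical:

```python
def pooling_steps(n_blocks):
    # In the summaries, a U-Net with n_blocks blocks has n_blocks - 1 MaxPool layers
    return n_blocks - 1


def depth_fits(depth, n_blocks):
    # Each pooling halves the depth (rounding down) and each up-convolution doubles it,
    # so the depth comes back unchanged only if it is divisible by 2 ** pooling_steps
    return depth % 2 ** pooling_steps(n_blocks) == 0


print(depth_fits(24, 4), 24 // 2 ** pooling_steps(4))  # True 3
print(depth_fits(20, 4))  # False: 20 -> 10 -> 5 -> 2 on the way down, 2 -> 4 -> 8 -> 16 on the way up
print(depth_fits(72, 3), depth_fits(168, 3))  # True True (the lighter 3-block models)
print(1512 // 24)  # 63 chunks of depth 24
```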





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=4, start_filters=64).to(device)
print(summary(model, (1, 24, 120, 120)))

"""
  ----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1     [-1, 64, 24, 120, 120]             1,792
              ReLU-2     [-1, 64, 24, 120, 120]                 0
       BatchNorm3d-3     [-1, 64, 24, 120, 120]               128
            Conv3d-4     [-1, 64, 24, 120, 120]           110,656
              ReLU-5     [-1, 64, 24, 120, 120]                 0
       BatchNorm3d-6     [-1, 64, 24, 120, 120]               128
         MaxPool3d-7        [-1, 64, 12, 60, 60]                0
         DownBlock-8  [[-1, 64, 12, 60, 60], [-1, 64, 24, 120, 120]]               0
            Conv3d-9       [-1, 128, 12, 60, 60]          221,312
             ReLU-10       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-11       [-1, 128, 12, 60, 60]              256
           Conv3d-12       [-1, 128, 12, 60, 60]          442,496
             ReLU-13       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-14       [-1, 128, 12, 60, 60]              256
        MaxPool3d-15       [-1, 128, 6, 30, 30]                 0
        DownBlock-16  [[-1, 128, 6, 30, 30], [-1, 128, 12, 60, 60]]               0
           Conv3d-17       [-1, 256, 6, 30, 30]           884,992
             ReLU-18       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-19       [-1, 256, 6, 30, 30]               512
           Conv3d-20       [-1, 256, 6, 30, 30]         1,769,728
             ReLU-21       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-22       [-1, 256, 6, 30, 30]               512
        MaxPool3d-23       [-1, 256, 3, 15, 15]                 0
        DownBlock-24  [[-1, 256, 3, 15, 15], [-1, 256, 6, 30, 30]]               0
           Conv3d-25       [-1, 512, 3, 15, 15]         3,539,456
             ReLU-26       [-1, 512, 3, 15, 15]                 0
      BatchNorm3d-27       [-1, 512, 3, 15, 15]             1,024
           Conv3d-28       [-1, 512, 3, 15, 15]         7,078,400
             ReLU-29       [-1, 512, 3, 15, 15]                 0
      BatchNorm3d-30       [-1, 512, 3, 15, 15]             1,024
        DownBlock-31  [[-1, 512, 3, 15, 15], [-1, 512, 3, 15, 15]]               0
  ConvTranspose3d-32       [-1, 256, 6, 30, 30]         1,048,832
             ReLU-33       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-34       [-1, 256, 6, 30, 30]               512
      Concatenate-35       [-1, 512, 6, 30, 30]                 0
           Conv3d-36       [-1, 256, 6, 30, 30]         3,539,200
             ReLU-37       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-38       [-1, 256, 6, 30, 30]               512
           Conv3d-39       [-1, 256, 6, 30, 30]         1,769,728
             ReLU-40       [-1, 256, 6, 30, 30]                 0
      BatchNorm3d-41       [-1, 256, 6, 30, 30]               512
          UpBlock-42       [-1, 256, 6, 30, 30]                 0
  ConvTranspose3d-43       [-1, 128, 12, 60, 60]          262,272
             ReLU-44       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-45       [-1, 128, 12, 60, 60]              256
      Concatenate-46       [-1, 256, 12, 60, 60]                0
           Conv3d-47       [-1, 128, 12, 60, 60]          884,864
             ReLU-48       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-49       [-1, 128, 12, 60, 60]              256
           Conv3d-50       [-1, 128, 12, 60, 60]          442,496
             ReLU-51       [-1, 128, 12, 60, 60]                0
      BatchNorm3d-52       [-1, 128, 12, 60, 60]              256
          UpBlock-53       [-1, 128, 12, 60, 60]                0
  ConvTranspose3d-54       [-1, 64, 24, 120, 120]          65,600
             ReLU-55       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-56       [-1, 64, 24, 120, 120]             128
      Concatenate-57      [-1, 128, 24, 120, 120]               0
           Conv3d-58       [-1, 64, 24, 120, 120]         221,248
             ReLU-59       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-60       [-1, 64, 24, 120, 120]             128
           Conv3d-61       [-1, 64, 24, 120, 120]         110,656
             ReLU-62       [-1, 64, 24, 120, 120]               0
      BatchNorm3d-63       [-1, 64, 24, 120, 120]             128
          UpBlock-64       [-1, 64, 24, 120, 120]               0
           Conv3d-65        [-1, 1, 24, 120, 120]              65
================================================================
Total params: 22,400,321
Trainable params: 22,400,321
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.61
Forward/backward pass size (MB): 15974.12
Params size (MB): 85.45
Estimated Total Size (MB): 16060.18
----------------------------------------------------------------
"""
      
      



Exp. No. 4: 3D U-Net, [y, z], *0.38

~60% (25 epochs) , .





Exp. No. 5: 3D U-Net, [y, z], 65 epochs ~ 1.5

. , (.№3) - :





Exp. No. 6: 3D U-Net, [x, z], 105 epochs ~ 2.1

"" . ~400 ( ~22) [18, 32, 64, 128] / 3. RMSProp. (1, 1, 72*, 120, 120). ?





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=1, n_blocks=3, start_filters=18).to(device)
print(summary(model, (1, 72, 120, 120)))  # input size without the batch dimension, as in the table below

"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1     [-1, 18, 72, 120, 120]             504
              ReLU-2     [-1, 18, 72, 120, 120]               0
       BatchNorm3d-3     [-1, 18, 72, 120, 120]              36
            Conv3d-4     [-1, 18, 72, 120, 120]           8,766
              ReLU-5     [-1, 18, 72, 120, 120]               0
       BatchNorm3d-6     [-1, 18, 72, 120, 120]              36
         MaxPool3d-7       [-1, 18, 36, 60, 60]               0
         DownBlock-8  [[-1, 18, 36, 60, 60], [-1, 18, 72, 120, 120]]               0
            Conv3d-9       [-1, 36, 36, 60, 60]          17,532
             ReLU-10       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-11       [-1, 36, 36, 60, 60]              72
           Conv3d-12       [-1, 36, 36, 60, 60]          35,028
             ReLU-13       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-14       [-1, 36, 36, 60, 60]              72
        MaxPool3d-15        [-1, 36, 18, 30, 30]              0
        DownBlock-16  [[-1, 36, 18, 30, 30], [-1, 36, 36, 60, 60]]               0
           Conv3d-17        [-1, 72, 18, 30, 30]         70,056
             ReLU-18        [-1, 72, 18, 30, 30]              0
      BatchNorm3d-19        [-1, 72, 18, 30, 30]            144
           Conv3d-20        [-1, 72, 18, 30, 30]        140,040
             ReLU-21        [-1, 72, 18, 30, 30]              0
      BatchNorm3d-22        [-1, 72, 18, 30, 30]            144
        DownBlock-23  [[-1, 72, 18, 30, 30], [-1, 72, 18, 30, 30]]               0
  ConvTranspose3d-24       [-1, 36, 36, 60, 60]          20,772
             ReLU-25       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-26       [-1, 36, 36, 60, 60]              72
      Concatenate-27       [-1, 72, 36, 60, 60]               0
           Conv3d-28       [-1, 36, 36, 60, 60]          70,020
             ReLU-29       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-30       [-1, 36, 36, 60, 60]              72
           Conv3d-31       [-1, 36, 36, 60, 60]          35,028
             ReLU-32       [-1, 36, 36, 60, 60]               0
      BatchNorm3d-33       [-1, 36, 36, 60, 60]              72
          UpBlock-34       [-1, 36, 36, 60, 60]               0
  ConvTranspose3d-35     [-1, 18, 72, 120, 120]           5,202
             ReLU-36     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-37     [-1, 18, 72, 120, 120]              36
      Concatenate-38     [-1, 36, 72, 120, 120]               0
           Conv3d-39     [-1, 18, 72, 120, 120]          17,514
             ReLU-40     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-41     [-1, 18, 72, 120, 120]              36
           Conv3d-42     [-1, 18, 72, 120, 120]           8,766
             ReLU-43     [-1, 18, 72, 120, 120]               0
      BatchNorm3d-44     [-1, 18, 72, 120, 120]              36
          UpBlock-45     [-1, 18, 72, 120, 120]               0
           Conv3d-46      [-1, 1, 72, 120, 120]              19
================================================================
Total params: 430,075
Trainable params: 430,075
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.32
Forward/backward pass size (MB): 5744.38
Params size (MB): 1.64
Estimated Total Size (MB): 5747.34
----------------------------------------------------------------
"""
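The totals in the printed summaries can be recounted by hand, which also shows why the 3D network is almost three times heavier than the 2D one with the same filters: each 3x3x3 kernel has 27 weights instead of 9. The sketch below does not use the UNet class; it only assumes the layer pattern visible in the tables (two biased convolutions, each followed by BatchNorm, per block; one 2x2(x2) transposed convolution per decoder step; a final 1x1(x1) convolution). unet_param_count is a hypothetical name:

```python
def conv_params(c_in, c_out, kernel_volume):
    # weights plus one bias per output channel
    return c_in * c_out * kernel_volume + c_out


def unet_param_count(dim, in_channels, out_channels, n_blocks, start_filters):
    k = 3 ** dim   # 3x3 or 3x3x3 convolution: 9 or 27 weights per kernel
    t = 2 ** dim   # 2x2 or 2x2x2 transposed convolution
    bn = 2         # BatchNorm: weight and bias per channel
    filters = [start_filters * 2 ** i for i in range(n_blocks)]

    total, c_in = 0, in_channels
    for f in filters:  # encoder: conv, BN, conv, BN
        total += conv_params(c_in, f, k) + bn * f + conv_params(f, f, k) + bn * f
        c_in = f
    for f_low, f in zip(filters[::-1], filters[-2::-1]):  # decoder, deepest level first
        total += conv_params(f_low, f, t) + bn * f   # up-convolution + BN
        total += conv_params(2 * f, f, k) + bn * f   # conv after concatenating the skip + BN
        total += conv_params(f, f, k) + bn * f       # second conv + BN
    return total + conv_params(filters[0], out_channels, 1)  # final 1x1(x1) conv


print(unet_param_count(2, 1, 1, 4, 64))  # 7702721  (2D U-Net)
print(unet_param_count(3, 1, 1, 4, 64))  # 22400321 (3D U-Net, 4 blocks, 64 filters)
print(unet_param_count(3, 1, 1, 3, 18))  # 430075   (3D U-Net, 3 blocks, 18 filters)
print(unet_param_count(3, 1, 9, 3, 9))   # 108036   (multiclass 3D U-Net)
```

Dropping one block and starting with 18 instead of 64 filters is what brings the count from ~22 million down to ~430 thousand, since the deepest convolutions dominate the total.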
      
      



72*

, (168, 120, 120), (72, 120, 120). , . , 2 , . 9 (1512, 120, 120) .. 9 , 21(batch size) (72, 120, 120). 72 , 24*().
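The chunk arithmetic behind 72* works out as follows: 9 rotations of the 168 padded slices give 1512 slices, which split into 21 chunks of 72 (or 27 of 56, as in the table). A sketch with a small spatial size in place of 120 x 120; it also shows a torch.chunk caveat worth keeping in mind when choosing the number of chunks:

```python
import torch

slices = 168 * 9                           # 168 padded slices x 9 rotations (0, 20, ..., 160 degrees)
print(slices, slices // 72, slices // 56)  # 1512 21 27

stack = torch.zeros(slices, 4, 4)          # small stand-in for the (1512, 120, 120) stack
chunks = torch.chunk(stack, 21, dim=0)     # the call made in new_chunks()
print(len(chunks), tuple(chunks[0].shape))  # 21 (72, 4, 4)

# Caveat: torch.chunk may return fewer chunks than requested, and the last chunk can be
# smaller, which would make torch.stack() in new_chunks() fail. Six rows "into 4":
print([len(c) for c in torch.chunk(torch.arange(6), 4)])  # [2, 2, 2]
```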





Exp. No. 7: 3D U-Net, [x, z], (65 epochs) ~ 14.

, ( "" ). , . semantic segmentation , .





3D ( ) (1512, 120, 120) --> 21*(1, 72, 120, 120), ~*(30, 30, 30) ( ). 2 : 3- , ( ); , .





, 1 epochs "" ~13, 2 (>80). 1 epochs. , .





. 8 + . loss function .





training loop
import torch
from tqdm import tqdm
from _loss_f import LossFunction


class TrainFunction:
    def __init__(self,
                 data_loader,
                 device_for_training,
                 model_name,
                 model_name_pretrained,
                 model,
                 optimizer,
                 scale,
                 learning_rate: float = 1e-2,
                 num_epochs: int = 1,
                 transfer_learning: bool = False,
                 binary_loss_f: bool = True
                 ):
        self.data_loader = data_loader
        self.device = device_for_training
        self.model_name_pretrained = model_name_pretrained
        self.semantic_binary = binary_loss_f
        self.num_epochs = num_epochs
        self.model_name = model_name
        self.transfer = transfer_learning
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.model = model
        self.scale = scale

    def forward(self):
        print('Running on the:', torch.cuda.get_device_name(self.device))
        if self.transfer:
            # transfer learning: start from previously saved weights
            self.model.load_state_dict(torch.load(self.model_name_pretrained))
        optimizer = self.optimizer(self.model.parameters(), lr=self.learning_rate)
        for epoch in range(self.num_epochs):
            self.train_loop(self.data_loader, self.model, optimizer, self.scale, epoch)
            if (epoch + 1) % 10 == 0:
                # checkpoint every 10 epochs, e.g. models/<model_name>10_epoch.pth
                torch.save(self.model.state_dict(),
                           'models/' + self.model_name + str(epoch + 1) + '_epoch.pth')

    def train_loop(self, loader, model, optimizer, scales, i):
        loop, epoch_loss = tqdm(loader), 0
        loop.set_description('Epoch %i' % (self.num_epochs - i))
        for batch_idx, (data, targets) in enumerate(loop):
            data, targets = data.to(device=self.device, dtype=torch.float), \
                            targets.to(device=self.device, dtype=torch.long)
            optimizer.zero_grad()
            # forward pass and loss in mixed precision; the GradScaler (scales) handles the backward pass
            with torch.cuda.amp.autocast():
                predictions = model(data)
                loss = LossFunction(predictions, targets,
                                    device_for_training=self.device,
                                    semantic_binary=self.semantic_binary
                                    ).forward()
            scales.scale(loss).backward()
            scales.step(optimizer)
            scales.update()
            epoch_loss += (1 - loss.item())*100
            loop.set_postfix(loss=loss.item())
        print('Epoch-acc', round(epoch_loss / (batch_idx+1), 2))

      
      



4.

Dice-loss , '' [0, 1]. , ( [0, 1]), ( "" "" ) Dice-loss , .





categorical_dice_loss
import torch


class LossFunction:
    def __init__(self,
                 prediction,
                 target,
                 device_for_training,
                 semantic_binary: bool = True,
                 ):
        self.prediction = prediction
        self.device = device_for_training
        self.target = target
        self.semantic_binary = semantic_binary

    def forward(self):
        if self.semantic_binary:
            return self.dice_loss(self.prediction, self.target)
        return self.categorical_dice_loss(self.prediction, self.target)

    @staticmethod
    def dice_loss(predictions, targets, alpha=1e-5):
        intersection = 2. * (predictions * targets).sum()
        denomination = (torch.square(predictions) + torch.square(targets)).sum()
        dice_loss = 1 - torch.mean((intersection + alpha) / (denomination + alpha))

        return dice_loss

    def categorical_dice_loss(self, prediction, target):
        pr, tr = self.prepare_for_multiclass_loss_f(prediction, target)
        target_categories, losses = torch.unique(tr).tolist(), 0
        for num_category in target_categories:
            categorical_target = torch.where(tr == num_category, 1, 0)
            categorical_prediction = pr[num_category]  # channel of this class, shape (D, H, W)
            losses += self.dice_loss(categorical_prediction, categorical_target).to(self.device)

        return losses / len(target_categories)

    @staticmethod
    def prepare_for_multiclass_loss_f(prediction, target):
        prediction_prepared = torch.squeeze(prediction, 0)
        target_prepared = torch.squeeze(target, 0)
        target_prepared = torch.squeeze(target_prepared, 0)

        return prediction_prepared, target_prepared
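A quick sanity check of the dice_loss formula above on tiny tensors (the function is copied so the snippet runs on its own). Because the denominator sums squares, a half-confident prediction (0.5 everywhere inside the mask) is not treated as correct: its loss is about 0.2.

```python
import torch


def dice_loss(predictions, targets, alpha=1e-5):
    # Same formula as LossFunction.dice_loss above (squared terms in the denominator)
    intersection = 2. * (predictions * targets).sum()
    denomination = (torch.square(predictions) + torch.square(targets)).sum()
    return 1 - (intersection + alpha) / (denomination + alpha)


mask = torch.tensor([[0., 1., 1.], [0., 1., 0.]])

print(dice_loss(mask, mask).item())        # 0.0: perfect overlap
print(dice_loss(1 - mask, mask).item())    # ~1.0: no overlap at all
print(dice_loss(0.5 * mask, mask).item())  # ~0.2: right voxels, but only 50% confident
```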

      
      



, "categorical_dice_loss":





  • ( );





  • , batch ;





  • "" "" , [0, 1] Dice-loss;





  • , batch. .





, , one-hot , ( ), , . , , , . (5).
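categorical_dice_loss squeezes the batch dimension away, so it only works with batch size 1, and it averages over the classes present in the current target. One common alternative, sketched here as an assumption rather than the author's method, is a per-class soft Dice computed over one-hot targets for the whole batch at once. Unlike categorical_dice_loss, it applies softmax to the raw network outputs; multiclass_dice_loss is a hypothetical name:

```python
import torch
import torch.nn.functional as F


def multiclass_dice_loss(logits, target, eps=1e-5):
    """logits: (N, C, D, H, W) raw network outputs; target: (N, D, H, W) labels in [0, C)."""
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)                       # assumption: softmax over classes
    one_hot = F.one_hot(target.long(), num_classes)            # (N, D, H, W, C)
    one_hot = one_hot.permute(0, 4, 1, 2, 3).to(probs.dtype)   # (N, C, D, H, W)
    dims = (0, 2, 3, 4)                                        # sum over batch and space, keep classes
    intersection = 2.0 * (probs * one_hot).sum(dims)
    denominator = (probs.square() + one_hot.square()).sum(dims)
    dice = (intersection + eps) / (denominator + eps)          # one score per class, shape (C,)
    present = one_hot.sum(dims) > 0                            # as in the article: classes in the target
    return 1 - dice[present].mean()


target = torch.randint(0, 9, (2, 4, 6, 6))                      # batch of 2, labels 0..8
perfect = 100.0 * F.one_hot(target, 9).permute(0, 4, 1, 2, 3).float()
print(multiclass_dice_loss(perfect, target).item())             # close to 0
print(multiclass_dice_loss(torch.randn(2, 9, 4, 6, 6), target).item())  # somewhere between 0 and 1
```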





5.

".. ". *.nrrd .





import nrrd
import numpy as np

data_path = 'some_path.nrrd'
# read the volume into a numpy array plus the header (metadata)
data, meta_data = nrrd.read(data_path)

print(data.shape, np.max(data), np.min(data), meta_data, sep="\n")

# Output:
# (163, 112, 120)
# 14982
# -2254
# OrderedDict([('type', 'short'), ('dimension', 3), ('space', 'left-posterior-superior'), ('sizes', array([163, 112, 120])), ('space directions', array([[-0.5,  0. ,  0. ],
#        [ 0. , -0.5,  0. ],
#        [ 0. ,  0. ,  0.5]])), ('kinds', ['domain', 'domain', 'domain']), ('endian', 'little'), ('encoding', 'gzip'), ('space origin', array([131.57200623,  80.7661972 ,  32.29940033]))])
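The header also defines how voxel indices map to patient coordinates in the 'left-posterior-superior' space: in the NRRD format, each row of 'space directions' is the step vector for one axis, added to 'space origin'. A minimal numpy sketch using the values printed above (voxel_to_lps is a hypothetical helper); the 0.5 steps match 1 px = 0.5 mm:

```python
import numpy as np

# Values copied from the header printed above
space_origin = np.array([131.57200623, 80.7661972, 32.29940033])
space_directions = np.array([[-0.5, 0.0, 0.0],   # one step along axis 0 (x)
                             [0.0, -0.5, 0.0],   # one step along axis 1 (y)
                             [0.0, 0.0, 0.5]])   # one step along axis 2 (z)


def voxel_to_lps(index):
    # position = origin + sum_i index[i] * space_directions[i]
    return space_origin + np.asarray(index, dtype=float) @ space_directions


print(voxel_to_lps((0, 0, 0)))        # equals the space origin
print(voxel_to_lps((162, 111, 119)))  # opposite corner of the 163 x 112 x 120 crop
```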
      
      



- , ? , , , .





, 8 12 . ( ) - ( 3- ) . , , "" -1 , ..





It looks as crazy as it sounds

- , . , . Skimage Stl.





from skimage.measure import marching_cubes
import nrrd
import numpy as np
from stl import mesh

path = 'some_path.nrrd'
data = nrrd.read(path)[0]


def three_d_creator(some_data):
    # marching_cubes returns vertices, faces (vertex indices), normals and values
    vertices, faces, _normals, _values = marching_cubes(some_data)
    # one empty record per triangle
    cube = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
    for i, f in enumerate(faces):
        for j in range(3):
            cube.vectors[i][j] = vertices[f[j]]
    cube.save('name.stl')

    return cube


teeth_mesh = three_d_creator(data)
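The nested loop in three_d_creator copies each triangle's three vertices one by one. numpy fancy indexing does the same in one step; a sketch with random stand-ins for the marching_cubes output. Assuming the mesh was created with faces.shape[0] records, cube.vectors[:] = vertices[faces] should be a drop-in replacement for the loop:

```python
import numpy as np

# Stand-ins for marching_cubes output: vertex coordinates and triangles as vertex indices
vertices = np.random.rand(10, 3).astype(np.float32)
faces = np.random.randint(0, 10, size=(6, 3))

# Loop version, as in three_d_creator
looped = np.zeros((faces.shape[0], 3, 3), dtype=np.float32)
for i, f in enumerate(faces):
    for j in range(3):
        looped[i][j] = vertices[f[j]]

# Fancy indexing: row i of the result holds the 3 vertices of triangle i
vectorized = vertices[faces]
print(vectorized.shape, np.array_equal(looped, vectorized))  # (6, 3, 3) True
```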
      
      



, "" . , , Win 10 3D Builder - . "" 3D . " " .





vedo. , , .





npy to stl
import numpy as np
from vedo import Volume, show, write

# load the predicted volume saved with numpy (Volume needs the array, not the .npy path)
prediction = np.load('some_data_path.npy')


def show_save(data, save=False):
    data_multiclass = Volume(data, c='Set2', alpha=(0.1, 1), alphaUnit=0.87, mode=1)
    data_multiclass.addScalarBar3D(nlabels=9)
    show([(data_multiclass, "Multiclass teeth segmentation prediction")], bg='black', N=1, axes=1).close()
    if save:
        write(data_multiclass.isosurface(), 'some_name_.stl')


show_save(prediction, save=True)
      
      



.





. :





model.summary()
model = UNet(dim=3, in_channels=1, out_channels=9, n_blocks=3, start_filters=9).to(device)
print(summary(model, (1, 168, 120, 120)))  # 168*: see the note below the table
    
"""
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1      [-1, 9, 168, 120, 120]            252
              ReLU-2      [-1, 9, 168, 120, 120]              0
       BatchNorm3d-3      [-1, 9, 168, 120, 120]             18
            Conv3d-4      [-1, 9, 168, 120, 120]          2,196
              ReLU-5      [-1, 9, 168, 120, 120]              0
       BatchNorm3d-6      [-1, 9, 168, 120, 120]             18
         MaxPool3d-7        [-1, 9, 84, 60, 60]               0
         DownBlock-8  [[-1, 9, 84, 60, 60], [-1, 9, 168, 120, 120]]               0
            Conv3d-9       [-1, 18, 84, 60, 60]           4,392
             ReLU-10       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-11       [-1, 18, 84, 60, 60]              36
           Conv3d-12       [-1, 18, 84, 60, 60]           8,766
             ReLU-13       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-14       [-1, 18, 84, 60, 60]              36
        MaxPool3d-15       [-1, 18, 42, 30, 30]               0
        DownBlock-16  [[-1, 18, 42, 30, 30], [-1, 18, 84, 60, 60]]               0
           Conv3d-17       [-1, 36, 42, 30, 30]          17,532
             ReLU-18       [-1, 36, 42, 30, 30]               0
      BatchNorm3d-19       [-1, 36, 42, 30, 30]              72
           Conv3d-20       [-1, 36, 42, 30, 30]          35,028
             ReLU-21       [-1, 36, 42, 30, 30]               0
      BatchNorm3d-22       [-1, 36, 42, 30, 30]              72
        DownBlock-23  [[-1, 36, 42, 30, 30], [-1, 36, 42, 30, 30]]               0
  ConvTranspose3d-24       [-1, 18, 84, 60, 60]           5,202
             ReLU-25       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-26       [-1, 18, 84, 60, 60]              36
      Concatenate-27       [-1, 36, 84, 60, 60]               0
           Conv3d-28       [-1, 18, 84, 60, 60]          17,514
             ReLU-29       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-30       [-1, 18, 84, 60, 60]              36
           Conv3d-31       [-1, 18, 84, 60, 60]           8,766
             ReLU-32       [-1, 18, 84, 60, 60]               0
      BatchNorm3d-33       [-1, 18, 84, 60, 60]              36
          UpBlock-34       [-1, 18, 84, 60, 60]               0
  ConvTranspose3d-35      [-1, 9, 168, 120, 120]          1,305
             ReLU-36      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-37      [-1, 9, 168, 120, 120]             18
      Concatenate-38     [-1, 18, 168, 120, 120]              0
           Conv3d-39      [-1, 9, 168, 120, 120]          4,383
             ReLU-40      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-41      [-1, 9, 168, 120, 120]             18
           Conv3d-42      [-1, 9, 168, 120, 120]          2,196
             ReLU-43      [-1, 9, 168, 120, 120]              0
      BatchNorm3d-44      [-1, 9, 168, 120, 120]             18
          UpBlock-45      [-1, 9, 168, 120, 120]              0
           Conv3d-46      [-1, 9, 168, 120, 120]             90
================================================================
Total params: 108,036
Trainable params: 108,036
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.96
Forward/backward pass size (MB): 12170.30
Params size (MB): 0.41
Estimated Total Size (MB): 12174.66
----------------------------------------------------------------
    """
      
      



* ([9, 18, 36, 72]), - 9*(168, 120, 120)
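The multiclass model returns 9 channels per voxel, and a label volume with values 0...8 comes from an argmax over the channel axis. A minimal sketch with random tensors; restricting the labels to the first model's binary mask (threshold 0.5 after a sigmoid) is an assumption for illustration, not necessarily how the two models were combined here:

```python
import torch

# Hypothetical raw outputs for one volume (small spatial size for the example)
binary_logits = torch.randn(1, 1, 8, 12, 12)  # model 1: one channel, teeth vs. background
class_logits = torch.randn(1, 9, 8, 12, 12)   # model 2: 9 channels (background + 8 categories)

teeth = torch.sigmoid(binary_logits[:, 0]) > 0.5  # assumed threshold, shape (1, 8, 12, 12)
labels = class_logits.argmax(dim=1)                # class id 0..8 per voxel, shape (1, 8, 12, 12)
combined = torch.where(teeth, labels, torch.zeros_like(labels))  # labels only inside the tooth mask

print(combined.shape, int(combined.min()), int(combined.max()))
```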





Exp. No. 8: Intermediate segmentation into 8 categories

, , . ? - "" 8- , . , 12 (GPU) .





Exp. No. 9: Full segmentation

6. Afterword

, , - . . , , 2 , . , ? , , 28 , , "" / ? U-net GCNN Pytorch - Pytorch3D? , , bounding box( 1 ). , , .





An example of an undirected graph for 28 categories with "delimiters"

Special thanks to my wife, Alena, for her support during this "dive into the darkness".





Thank you all for your attention. Constructive criticism and suggestions, whether corrections or ideas for new projects, are welcome.







