Construindo uma Rede Neural Hopfield em JavaScript

Confrontado com redes neurais na universidade, a rede Hopfield tornou-se uma das minhas favoritas. Fiquei surpreso que fosse o último da lista de laboratórios, pois seu trabalho pode ser demonstrado claramente por meio de imagens e não é tão difícil de implementar.





Este artigo demonstra como resolver o problema de restauração de imagens distorcidas usando uma rede neural de Hopfield, previamente treinada em imagens de referência.





Tentei descrever passo a passo e da forma mais simples possível o processo de implementação de um programa que permite brincar com uma rede neural direto no navegador, treinar a rede usando minhas próprias imagens desenhadas à mão e testar seu funcionamento em imagens distorcidas imagens.





Fontes no Github e demo .





Para implementação, você precisará de:





  • Navegador





  • Compreensão básica de redes neurais





  • Conhecimento básico de JavaScript / HTML





Um pouco de teoria

A rede neural Hopfield é uma rede neural totalmente conectada com uma matriz simétrica de conexões. Essa rede pode ser usada para organizar a memória associativa, como um filtro, e também para resolver alguns problemas de otimização.





-  .     ,   .   ,   , .





Diagrama de blocos da rede neural de Hopfield

, , . (, ), .   ,  , «» ( ).





:







  1. :



    w_ {ij} = \ left \ {\ begin {matrix} \ sum_ {k = 1} ^ {m} x_ {i} ^ {k} * x_ {j} ^ {k} & i \ neq j \\ 0 , & i = j \ end {matriz} \ right.



    m—  

    x_ {i} ^ {k}, x_ {j} ^ {k}eu- j- k- .





  2.   . :

    y_ {j} (0) = x_ {j}





  3. (   ):



    y_ {j} (t + 1) = f \ left (\ sum_ {i = 1} ^ {n} w_ {ij} * y_ {i} (t) \ right)



    f —   [-1; 1];

    t — ;

    j = 1 ... n;  n —  .





  4.   .  —   3, , , . ,   .





.





Demonstração do programa

    Canvas   . HTML  CSS ,     (  ).





Canvas , ( ) .   ,     Canvas (  «»     ).





 , 10×10   .   , ,   100 (  100  ).  — ,  −1  1, −1 — ,  1 — .





-   , .





//     10   
const gridSize = 10;
//     
const squareSize = 45;
//    (100)
const inputNodes = gridSize * gridSize;

//         ,
//      
let userImageState = [];
//      
let isDrawing = false;
//  
for (let i = 0; i < inputNodes; i += 1) {  
  userImageState[i] = -1;  
}

//   :
const userCanvas = document.getElementById('userCanvas');
const userContext = userCanvas.getContext('2d');
const netCanvas = document.getElementById('netCanvas');
const netContext = netCanvas.getContext('2d');
      
      



, .





//      
//   100  (gridSize * gridSize)
const drawGrid = (ctx) => {
  ctx.beginPath();
  ctx.fillStyle = 'white';
  ctx.lineWidth = 3;
  ctx.strokeStyle = 'black';
  for (let row = 0; row < gridSize; row += 1) {
    for (let column = 0; column < gridSize; column += 1) {
      const x = column * squareSize;
      const y = row * squareSize;
      ctx.rect(x, y, squareSize, squareSize);
      ctx.fill();
      ctx.stroke();
    }
  }
  ctx.closePath();
};
      
      



«» ,    .





//   
const handleMouseDown = (e) => {
  userContext.fillStyle = 'black';
  //      x, y
  //  squareSize  squareSize (4545 )
  userContext.fillRect(
    Math.floor(e.offsetX / squareSize) * squareSize,
    Math.floor(e.offsetY / squareSize) * squareSize,
    squareSize, squareSize,
  );

  //     ,
  //      
  const { clientX, clientY } = e;
  const coords = getNewSquareCoords(userCanvas, clientX, clientY, squareSize);
  const index = calcIndex(coords.x, coords.y, gridSize);

  //       
  if (isValidIndex(index, inputNodes) && userImageState[index] !== 1) {
    userImageState[index] = 1;
  }

  //   (   )
  isDrawing = true;
};

//     
const handleMouseMove = (e) => {
  //   , ..      ,    
  if (!isDrawing) return;

  //  ,   handleMouseDown
  //     isDrawing = true;
  userContext.fillStyle = 'black';

  userContext.fillRect(
    Math.floor(e.offsetX / squareSize) * squareSize,
    Math.floor(e.offsetY / squareSize) * squareSize,
    squareSize, squareSize,
  );

  const { clientX, clientY } = e;
  const coords = getNewSquareCoords(userCanvas, clientX, clientY, squareSize);
  const index = calcIndex(coords.x, coords.y, gridSize);

  if (isValidIndex(index, inputNodes) && userImageState[index] !== 1) {
    userImageState[index] = 1;
  }
};
      
      



  , , getNewSquareCoords, calcIndex  isValidIndex.  .





//      
//      
const calcIndex = (x, y, size) => x + y * size;

// ,     
const isValidIndex = (index, len) => index < len && index >= 0;

//        
//  ,      0  9
const getNewSquareCoords = (canvas, clientX, clientY, size) => {
  const rect = canvas.getBoundingClientRect();
  const x = Math.ceil((clientX - rect.left) / size) - 1;
  const y = Math.ceil((clientY - rect.top) / size) - 1;
  return { x, y };
};
      
      



.   .





const clearCurrentImage = () => {
  //    ,    
  //       
  drawGrid(userContext);
  drawGrid(netContext);
  userImageState = new Array(gridSize * gridSize).fill(-1);
};
      
      



  «» .





 — .   ( ).





...
const weights = [];  //   
for (let i = 0; i < inputNodes; i += 1) {
  weights[i] = new Array(inputNodes).fill(0); //       0
  userImageState[i] = -1;
}
...
      
      



    , , inputNodes .     100 ,      100 .





( )   .     . .





const memorizeImage = () => {
  for (let i = 0; i < inputNodes; i += 1) {
    for (let j = 0; j < inputNodes; j += 1) {
      if (i === j) weights[i][j] = 0;
      else {
        // ,       userImageState  
        //  -1  1,  -1 -  ,  1 -     
        weights[i][j] += userImageState[i] * userImageState[j];
      }
    }
  }
};
      
      



,   ,    ,   . :





// -  html   lodash:
<script src="https://cdnjs.cloudflare.com/ajax/libs/lodash.js/4.17.21/lodash.min.js"></script>
...
const recognizeSignal = () => {
  let prevNetState;
  //      .  
  //       
  // (2  ),     
  const currNetState = [...userImageState];
  do {
    //    , 
		// ..     
    prevNetState = [...currNetState];
    //      3  
    for (let i = 0; i < inputNodes; i += 1) {
      let sum = 0;
      for (let j = 0; j < inputNodes; j += 1) {
        sum += weights[i][j] * prevNetState[j];
      }
      //    ( - )
      currNetState[i] = sum >= 0 ? 1 : -1;
    }
    //      
    //     - isEqual
  } while (!_.isEqual(currNetState, prevNetState));

  //    ( ),   
  drawImageFromArray(currNetState, netContext);
};
      
      



    isEqual   lodash.





drawImageFromArray.       .





const drawImageFromArray = (data, ctx) => {
  const twoDimData = [];
  //     
  while (data.length) twoDimData.push(data.splice(0, gridSize));

  //   
  drawGrid(ctx);
  //     ( )
  for (let i = 0; i < gridSize; i += 1) {
    for (let j = 0; j < gridSize; j += 1) {
      if (twoDimData[i][j] === 1) {
        ctx.fillStyle = 'black';
        ctx.fillRect((j * squareSize), (i * squareSize), squareSize, squareSize);
      }
    }
  }
};
      
      



  HTML   .





HTML
const resetButton = document.getElementById('resetButton');
const memoryButton = document.getElementById('memoryButton');
const recognizeButton = document.getElementById('recognizeButton');

//    
resetButton.addEventListener('click', () => clearCurrentImage());
memoryButton.addEventListener('click', () => memorizeImage());
recognizeButton.addEventListener('click', () => recognizeSignal());

//    
userCanvas.addEventListener('mousedown', (e) => handleMouseDown(e));
userCanvas.addEventListener('mousemove', (e) => handleMouseMove(e));
//  ,         
userCanvas.addEventListener('mouseup', () => isDrawing = false);
userCanvas.addEventListener('mouseleave', () => isDrawing = false);

//  
drawGrid(userContext);
drawGrid(netContext);
      
      



, :





Imagens de referência para treinamento de rede

:





Tentando reconhecer a imagem distorcida da letra H
Tentando reconhecer uma imagem distorcida da letra T

! .





, m  , 0,15 * n( n—   ). , , , ,   ,               .





Fontes no Github e demo .





Em vez de literatura, as palestras foram usadas por um excelente professor em redes neurais - Sergei Mikhailovich Roshchin , pelo qual muitos agradecimentos a ele.








All Articles