Vamos ver as meninas? Ou ml.net no trabalho

Infelizmente, o mundo do aprendizado de máquina pertence ao python.





Há muito tempo está consolidado como uma linguagem de trabalho para o Data Silence, mas a Microsoft decidiu argumentar e apresentou sua própria ferramenta que pode ser facilmente integrada ao ecossistema que o mundo inteiro usa agora. Foi assim que nasceu o ML.NET, um sistema de aprendizado de máquina de código aberto e multiplataforma para desenvolvedores .NET.





Neste artigo, quero mostrar que usar ml.net não é mais difícil do que o resto das opções que são, em um exemplo realmente funcional, o link que deixarei abaixo. Este é um canal em um telegrama que automaticamente coleta dados, classifica (é o que vamos considerar) e posta. Quem se importa, seja bem-vindo.





Formulação do problema

Como um adolescente, eu realmente queria ter um bot legal onde eu pudesse olhar para as meninas, que não fosse embalado com anúncios para os olhos, mas apenas uma foto e pronto. Então, quando tive tempo livre, as estrelas e o desejo se juntaram, comecei imediatamente a resolver esse problema.





Coleção de dados

Para começar, comprei um upload de dados do Twitter para a tag que me interessa, que o serviço fornece em formato csv (vários arquivos diferentes: o próprio tweet, mídia, links). Depois de selecionar o arquivo de que preciso, escrevemos rapidamente uma classe para analisar os dados e filtrar as duplicatas. Como resultado, apenas as referências às imagens que participarão do treinamento são deixadas na memória. Isso é bom, mas mesmo assim as imagens precisam ser rotuladas, ou seja, divididas em categorias. No meu caso, escolhi: meninos, meninas, lixo e outros (no começo escolhi default, mas quando passei de strings para Enum, tive que mudar o nome da categoria). Todas essas fotos, eu carreguei, meticulosamente divididas em papais que refletem a marca da foto, então é hora da coisa mais interessante - o código.





Treinamento de modelo

, , .





 

 — .  , .





, .





( ; . Deep learning) — ( , , , ), (. feature/representation learning), .





, , , , . TensorFlow Inception , ImageNet.





" ", ( 2000 , 2 , , +- ).





, , , , , . 4 500 .





. , model nuget :





using Microsoft.ML; 
using Microsoft.ML.Data; 
      
      



, :





    private readonly string _inceptionTensorFlowModel; //    Inception 
    private MLContext mlContext;
    private ITransformer model;
    private DataViewSchema schema;
    private string modelName = "model.zip"; //     
    private string _setsPath = @"C:\datasets"; //     ,      
    
    
        public Model(string inceptionTensorFlowModel)
        {
            mlContext = new MLContext();
            _inceptionTensorFlowModel = inceptionTensorFlowModel;
        }
      
      



MLContext - .NET. "" , , DbContext EntityFramework.





ITransformer - , , , .





DataViewSchema - .





, "", , .





public class ImageData
    {
        [LoadColumn(0)]
        public string ImagePath;

        [LoadColumn(1)]
        public string Label;

  		//,   ,        
        public static (IEnumerable<ImageData> train, IEnumerable<ImageData> test) ReadData(string pathToFolder)
        {
            List<ImageData> list = new List<ImageData>();
            var directories = Directory.EnumerateDirectories(pathToFolder);
            foreach (var dir in directories)
            {
                if (!dir.Contains("girls") && !dir.Contains("boys") && !dir.Contains("trash") && !dir.Contains("other"))
                    continue;
                var label = dir.Split(@"\").Last();
                foreach (var file in Directory.GetFiles(dir))
                {
                    list.Add(new ImageData()
                    {
                        ImagePath = file,
                        Label = label
                    });
                }
            }
            list = list.Shuffle().ToList();
            return GetSets(list);
        }

				//      
        public static (IEnumerable<ImageData> train, IEnumerable<ImageData> test) GetSets(IEnumerable<ImageData> data)
        {
            var trainCount = data.Count() / 100 * 99;
            var train = data.Take(trainCount);
            var test = data.Skip(trainCount);
            return (train, test);
        }
    }
    public class ImagePrediction : ImageData
    {
        [ColumnName("Score")]
        public float[] Score;

        public string PredictedLabelValue;
    }
      
      



IEnumerable :





,





 public static IEnumerable<T> Shuffle<T>(this IEnumerable<T> source)
        {
            return source.Shuffle(new Random());
        }
        public static IEnumerable<T> Shuffle<T>(this IEnumerable<T> source, Random rng)
        {
            if (source == null) throw new ArgumentNullException("source");
            if (rng == null) throw new ArgumentNullException("rng");

            return source.ShuffleIterator(rng);
        }

        private static IEnumerable<T> ShuffleIterator<T>(
            this IEnumerable<T> source, Random rng)
        {
            var buffer = source.ToList();
            for (int i = 0; i < buffer.Count; i++)
            {
                int j = rng.Next(i, buffer.Count);
                yield return buffer[j];

                buffer[j] = buffer[i];
            }
        }
      
      



, :





private struct InceptionSettings
        {
            public const int ImageHeight = 224;
            public const int ImageWidth = 224;
            public const float Mean = 117;
            public const float Scale = 1;
            public const bool ChannelsLast = true;
        }
      
      



, .





:





private double TrainModel()
        {
            IEstimator<ITransformer> pipeline = mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: "", inputColumnName: nameof(ImageData.ImagePath))
                           .Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: InceptionSettings.ImageWidth, imageHeight: InceptionSettings.ImageHeight, inputColumnName: "input"))
                           .Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: InceptionSettings.ChannelsLast, offsetImage: InceptionSettings.Mean))
                           .Append(mlContext.Model.LoadTensorFlowModel(_inceptionTensorFlowModel).
                               ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
                           .Append(mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelKey", inputColumnName: "Label"))
                           .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelKey", featureColumnName: "softmax2_pre_activation"))
                           .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabelValue", "PredictedLabel"))
                           .AppendCacheCheckpoint(mlContext);

            var loadImages = ImageData.ReadData(_setsPath);
            IDataView trainingData = mlContext.Data.LoadFromEnumerable<ImageData>(loadImages.train);
            ITransformer model = pipeline.Fit(trainingData);
            IDataView testData = mlContext.Data.LoadFromEnumerable<ImageData>(loadImages.test);
            IDataView predictions = model.Transform(testData);
            List<ImagePrediction> imagePredictionData = mlContext.Data.CreateEnumerable<ImagePrediction>(predictions, true).ToList();
            MulticlassClassificationMetrics metrics =
                mlContext.MulticlassClassification.Evaluate(predictions,
                  labelColumnName: "LabelKey",
                  predictedLabelColumnName: "PredictedLabel");
            schema = trainingData.Schema;
            return metrics.LogLoss;
        }
      
      



:





IEstimator<ITransformer> pipeline = mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: "", inputColumnName: nameof(ImageData.ImagePath))
     .Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: InceptionSettings.ImageWidth, imageHeight: InceptionSettings.ImageHeight, inputColumnName: "input"))
     .Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: InceptionSettings.ChannelsLast, offsetImage: InceptionSettings.Mean))
                           
      
      



. , :





.Append(mlContext.Model.LoadTensorFlowModel(_inceptionTensorFlowModel).
    ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
      
      



 . , :





.Append(mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelKey", inputColumnName: "Label"))
      
      



ml.net, , .





:





.Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelKey", featureColumnName: "softmax2_pre_activation"))
      
      



:





.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabelValue", "PredictedLabel"))
.AppendCacheCheckpoint(mlContext);
      
      



:





var loadImages = ImageData.ReadData(_setsPath);
            IDataView trainingData = mlContext.Data.LoadFromEnumerable<ImageData>(loadImages.train);
            model = pipeline.Fit(trainingData);
      
      



, .





						IDataView testData = mlContext.Data.LoadFromEnumerable<ImageData>(loadImages.test);
            IDataView predictions = model.Transform(testData);
            List<ImagePrediction> imagePredictionData = mlContext.Data.CreateEnumerable<ImagePrediction>(predictions, true).ToList();
            MulticlassClassificationMetrics metrics =
                mlContext.MulticlassClassification.Evaluate(predictions,
                  labelColumnName: "LabelKey",
                  predictedLabelColumnName: "PredictedLabel");
      
      



. . , "" .





            schema = trainingData.Schema;
            return metrics.LogLoss;
      
      



LogLoss( ).





, .





, , :





  public void SaveModel() => mlContext.Model.Save(model, schema, Path.Combine(_setsPath, modelName));
      
      



, , :





    public void FitModel()
    {
        var LogLoss = TrainModel();
        Console.WriteLine($"LogLoss is {LogLoss}");
        SaveModel();
    }

      
      



, , , , , .





, , , .





:





    private PredictionEngine<ImageData, ImagePrediction> predictor;
      
      



, (+ ):





        public ImagePrediction ClassifySingleImage(string filePath)
        {
            if (model == null)
                LoadModel();
            if (predictor == null)
                predictor = mlContext.Model.CreatePredictionEngine<ImageData, ImagePrediction>(model);
            var imageData = new ImageData()
            {
                ImagePath = filePath
            };
            return predictor.Predict(imageData);
        }
        public void LoadModel() =>
            model = mlContext.Model.Load(Path.Combine(_setsPath, modelName), out schema);
      
      



, .





, :





 static void Main(string[] args)
        {
            Console.ForegroundColor = ConsoleColor.White;
            Stopwatch s = new Stopwatch();
            s.Start();

            Model model = new Model(@"C:\tensorflow_inception_graph.pb");
            model.FitModel();
            Console.WriteLine($"##### Model train ended for {s.Elapsed.Minutes}:{s.Elapsed.Seconds} #####");

            s.Restart();

            var res1 = model.ClassifySingleImage(@"C:\EugRqKFXUAYMTWz.jpg");
            Console.WriteLine($" > It's trash. Classification result is {res1.PredictedLabelValue} with score: {res1.Score.Max()}");
            Console.WriteLine($"##### Ended for {s.Elapsed.Minutes}:{s.Elapsed.Seconds} #####");

            s.Restart();

            var res2 = model.ClassifySingleImage(@"C:\EvpmOjIXcAMgj5r.jpg");
            Console.WriteLine($" > It's girl. Classification result is {res2.PredictedLabelValue} with score: {res1.Score.Max()}");
            Console.WriteLine($"##### Ended for {s.Elapsed.Minutes}:{s.Elapsed.Seconds} #####");
        }
      
      



:





Apesar das métricas um tanto fracas (eu ainda usei 20 imagens para testes): 0,55, mas o modelo lidou com suas tarefas perfeitamente. Este é o modelo que uso para meu bot nsfw , que recebe dados do Twitter, os classifica e publica.





Portanto, não é difícil treinar o modelo e adicionar ao seu projeto, o principal desejo é descobrir. E você nunca deve parar de aprender coisas novas.








All Articles