机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别

一、问题与解决方案:通过多元分类算法进行手写数字识别。手写数字图片是分辨率为8*8的灰度图片,已经预先进行过处理,读取了各像素点的灰度值,并进行了标记。其中第0列是序号(不参与运算),第1-64列是像素值,第65列是结果。我们以64个像素值为特征进行多元分类,算法采用SDCA最大熵分类算法。二、源码:先贴出全部代码:namespace MulticlassClassification_Mnist { class ...

机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别

一、问题与解决方案

通过多元分类算法进行手写数字识别。手写数字图片是分辨率为8*8的灰度图片,已经预先进行过处理,读取了各像素点的灰度值,并进行了标记。

其中第0列是序号(不参与运算),第1-64列是像素值,第65列是结果。

我们以64个像素值为特征进行多元分类,算法采用SDCA最大熵分类算法。

二、源码

先贴出全部代码:

namespace MulticlassClassification_Mnist{ class Program {  static readonly string TrainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "optdigits-full.csv");  static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip");  static void Main(string[] args)  {MLContext mlContext = new MLContext(seed: 1); TrainAndSaveModel(mlContext);TestSomePredictions(mlContext);Console.WriteLine("Hit any key to finish the app");Console.ReadKey();  } public static void TrainAndSaveModel(MLContext mlContext)  {// STEP 1: 准备数据var fulldata = mlContext.Data.LoadFromTextFile(path: TrainDataPath,  columns: new[]  {new TextLoader.Column("Serial", DataKind.Single, 0),new TextLoader.Column("PixelValues", DataKind.Single, 1, 64),new TextLoader.Column("Number", DataKind.Single, 65)  },  hasHeader: true,  separatorChar: ','  );var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.2);var trainData = trainTestData.TrainSet;var testData = trainTestData.TestSet;// STEP 2: 配置数据处理管道  var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue);// STEP 3: 配置训练算法var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(labelColumnName: "Label", featureColumnName: "PixelValues");var trainingPipeline = dataProcessPipeline.Append(trainer)  .Append(mlContext.Transforms.Conversion.MapKeyToValue("Number", "Label"));// STEP 4: 训练模型使其与数据集拟合Console.WriteLine("=============== Train the model fitting to the DataSet ===============");  ITransformer trainedModel = trainingPipeline.Fit(trainData);// STEP 5:评估模型的准确性Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");var predictions = trainedModel.Transform(testData);var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Number", scoreColumnName: "Score");PrintMultiClassClassificationMetrics(trainer.ToString(), 
metrics);// STEP 6:保存模型  mlContext.ComponentCatalog.RegisterAssembly(typeof(DebugConversion).Assembly);mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath);Console.WriteLine("The model is saved to {0}", ModelPath);  }  private static void TestSomePredictions(MLContext mlContext)  {// Load Model  ITransformer trainedModel = mlContext.Model.Load(ModelPath, out var modelInputSchema);// Create prediction engine var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel);//num 1InputData MNIST1 = new InputData(){ PixelValues = new float[] { 0, 0, 0, 0, 14, 13, 1, 0, 0, 0, 0, 5, 16, 16, 2, 0, 0, 0, 0, 14, 16, 12, 0, 0, 0, 1, 10, 16, 16, 12, 0, 0, 0, 3, 12, 14, 16, 9, 0, 0, 0, 0, 0, 5, 16, 15, 0, 0, 0, 0, 0, 4, 16, 14, 0, 0, 0, 0, 0, 1, 13, 16, 1, 0 }}; var resultprediction1 = predEngine.Predict(MNIST1);resultprediction1.PrintToConsole(); } } class InputData {  public float Serial;  [VectorType(
原文地址:https://www.guoxiongfei.cn/cntech/18360.html