Simple Machine Learning classification with ML.NET (custom code without Model Builder)

TL;DR: check out my GitHub project for bank transaction classification: https://github.com/jernejk/MLSample.SimpleTransactionTagging

UPDATE: If you are looking to learn machine learning with minimum code and effort, check out Simplified Machine Learning for Developers with ML.NET. The post below was written before ML.NET Model Builder was released and before I made a video about it.

Goal

For a long time, I tried to get into practical machine learning, but I found it wasn't very accessible to most developers.

Recently, as a challenge, I prepared a lightning talk explaining how to solve a relatable, everyday problem with machine learning in just 15-20 minutes.

For the demo, I decided on bank transaction category classification, as this is something I do every week. While I use the finance service Toshl, which collects and classifies all of my bank transactions, I still had to spend about an hour every week fixing misclassified transactions. I wanted to solve this problem without writing a complicated rule engine, and machine learning seemed like the right approach.

For this example, we'll use .NET Core with ML.NET and a set of sample transactions. For more ML.NET examples, see https://github.com/dotnet/machinelearning-samples/. We'll use multi-class classification, which fits our problem well: each transaction belongs to exactly one of several categories.
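To make "multi-class" concrete: the trainer produces one score per category, and the prediction is the category with the highest score. The sketch below is plain C# with no ML.NET dependency; the MultiClassSketch name, the category list, and the raw scores are hypothetical. It uses a softmax to turn raw scores into probabilities, which is how a maximum-entropy (multinomial logistic regression) model such as the SDCA trainer used later produces them.

```csharp
using System;
using System.Linq;

// Hypothetical, plain C# illustration (no ML.NET) of how a multi-class classifier picks one label.
public static class MultiClassSketch
{
    // Converts raw per-class scores into probabilities that sum to 1 (softmax).
    public static double[] Softmax(double[] rawScores)
    {
        // Subtracting the max score avoids overflow in Math.Exp and doesn't change the result.
        double max = rawScores.Max();
        double[] exps = rawScores.Select(s => Math.Exp(s - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Returns the label with the highest probability, together with that probability.
    public static (string Label, double Probability) Predict(string[] labels, double[] rawScores)
    {
        double[] probabilities = Softmax(rawScores);
        int best = 0;
        for (int i = 1; i < probabilities.Length; i++)
        {
            if (probabilities[i] > probabilities[best])
            {
                best = i;
            }
        }
        return (labels[best], probabilities[best]);
    }
}
```

For example, `MultiClassSketch.Predict(new[] { "coffee & tea", "rent", "reimbursement" }, new[] { 2.0, 0.5, -1.0 })` picks "coffee & tea" with a probability of roughly 0.79.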

My data comes from the finance app Toshl, which already has the correct category label for each transaction.

Code

I started by looking at the GitHub Issue Labeler sample and simply swapping its training data for my transactions. That worked surprisingly well, so I started a new project based on that prototype.

Next, I settled on the following input structure, loaded from JSON. Description and transaction type are used as inputs (features) and category as the expected result (label).

using System.Runtime.Serialization;

[DataContract]
public class TransactionData
{
    [DataMember(Name = "desc")]
    public string Description { get; set; }

    [DataMember(Name = "category")]
    public string Category { get; set; }

    [DataMember(Name = "transactionType")]
    public string TransactionType { get; set; }
}
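For reference, here is a minimal sketch of what training.json might contain and how it maps onto the fields above. The property names come from the DataMember attributes; the two sample rows reuse descriptions from later in this post, and their shape is an assumption about the file. The original project deserializes with Newtonsoft.Json (JsonConvert), which honors DataMember names. This sketch swaps in System.Text.Json, which does not read DataContract/DataMember attributes, so the hypothetical TransactionJson type uses JsonPropertyName instead.

```csharp
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

// Hypothetical stand-in for TransactionData, annotated for System.Text.Json.
public class TransactionJson
{
    [JsonPropertyName("desc")]
    public string Description { get; set; }

    [JsonPropertyName("category")]
    public string Category { get; set; }

    [JsonPropertyName("transactionType")]
    public string TransactionType { get; set; }
}

public static class TrainingDataLoader
{
    // Example rows in the assumed shape of training.json.
    public const string SampleJson = @"[
        { ""desc"": ""AMERICAN CONCEPTS PT BRISBANE"", ""category"": ""coffee & tea"", ""transactionType"": ""expense"" },
        { ""desc"": ""ANZ M-BANKING PAYMENT TRANSFER 513542 TO SPIRE REALITY"", ""category"": ""rent"", ""transactionType"": ""expense"" }
    ]";

    // Deserializes a JSON array of transactions into a list.
    public static List<TransactionJson> Parse(string json) =>
        JsonSerializer.Deserialize<List<TransactionJson>>(json);
}
```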

The next data structure is the prediction output. Its column name needs to be PredictedLabel, while the field name can be anything you like.

using Microsoft.ML.Data;

public class TransactionPrediction
{
    [ColumnName("PredictedLabel")]
    public string Category;
}

With the data loaded from JSON, the next step is to train a model and save it.
This step doesn't evaluate the model. Evaluation isn't required for training to work, but it's highly recommended (think of it as an integration test for very fragile code).
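Although the training service skips evaluation, the idea is simple: hold back some labeled transactions, predict their categories, and compare against the known labels. ML.NET has built-in support for this (for example `mlContext.Data.TrainTestSplit` and `mlContext.MulticlassClassification.Evaluate`, whose metrics include `MicroAccuracy` and `MacroAccuracy`). The plain C# sketch below only illustrates what those two accuracy numbers mean; the EvaluationSketch name and the sample labels are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical, plain C# illustration of micro- vs macro-accuracy for multi-class predictions.
public static class EvaluationSketch
{
    // Micro-accuracy: fraction of all examples whose predicted label matches the expected one.
    public static double MicroAccuracy(IReadOnlyList<string> expected, IReadOnlyList<string> predicted)
    {
        Validate(expected, predicted);
        int correct = expected.Where((label, i) => label == predicted[i]).Count();
        return (double)correct / expected.Count;
    }

    // Macro-accuracy: accuracy is computed separately for each expected category and then averaged,
    // so rare categories count as much as common ones.
    public static double MacroAccuracy(IReadOnlyList<string> expected, IReadOnlyList<string> predicted)
    {
        Validate(expected, predicted);
        return expected
            .Select((label, i) => (Label: label, Correct: label == predicted[i]))
            .GroupBy(pair => pair.Label)
            .Select(group => group.Count(pair => pair.Correct) / (double)group.Count())
            .Average();
    }

    private static void Validate(IReadOnlyList<string> expected, IReadOnlyList<string> predicted)
    {
        if (expected.Count == 0 || expected.Count != predicted.Count)
        {
            throw new ArgumentException("Expected and predicted labels must be non-empty and the same length.");
        }
    }
}
```

With expected labels `rent, rent, rent, coffee & tea` and a model that always answers `rent`, micro-accuracy is 0.75 but macro-accuracy is only 0.5, which is why it's worth checking both when some categories are rare.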

using System.Collections.Generic;
using Microsoft.ML;

public class BankTransactionTrainingService
{
    public void Train(IEnumerable<TransactionData> trainingData, string modelSavePath)
    {
        var mlContext = new MLContext(seed: 0);

        // Configure ML pipeline
        var pipeline = LoadDataProcessPipeline(mlContext);
        var trainingPipeline = GetTrainingPipeline(mlContext, pipeline);
        var trainingDataView = mlContext.Data.LoadFromEnumerable(trainingData);

        // Generate training model.
        var trainingModel = trainingPipeline.Fit(trainingDataView);

        // Save training model to disk.
        mlContext.Model.Save(trainingModel, trainingDataView.Schema, modelSavePath);
    }

    private IEstimator<ITransformer> LoadDataProcessPipeline(MLContext mlContext)
    {
        // Configure data pipeline based on the features in TransactionData.
        // Description and TransactionType are the inputs and Category is the expected result.
        var dataProcessPipeline = mlContext
            .Transforms.Conversion.MapValueToKey(inputColumnName: nameof(TransactionData.Category), outputColumnName: "Label")
            .Append(mlContext.Transforms.Text.FeaturizeText(inputColumnName: nameof(TransactionData.Description), outputColumnName: "DescriptionFeaturized"))
            .Append(mlContext.Transforms.Text.FeaturizeText(inputColumnName: nameof(TransactionData.TransactionType), outputColumnName: "TransactionTypeFeaturized"))
            // Merge the two features into a single "Features" vector, which the trainer uses as input.
            .Append(mlContext.Transforms.Concatenate("Features", "DescriptionFeaturized", "TransactionTypeFeaturized"))
            .AppendCacheCheckpoint(mlContext);

        return dataProcessPipeline;
    }

    private IEstimator<ITransformer> GetTrainingPipeline(MLContext mlContext, IEstimator<ITransformer> pipeline)
    {
        // Use the multi-class SDCA (maximum entropy) trainer to predict the label from the features.
        // The trainer outputs the predicted category as a key in "PredictedLabel",
        // so map it back to the original category string.
        return pipeline
            .Append(GetSdcaTrainer(mlContext))
            .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
    }

    private IEstimator<ITransformer> GetSdcaTrainer(MLContext mlContext)
    {
        return mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "Features");
    }
}
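MapValueToKey and MapKeyToValue are what let the trainer work with category strings: categories become numeric keys before training, and the predicted key is turned back into a string afterwards. The LabelKeyMap class below is a simplified, hypothetical analogue of that round trip in plain C#; it is not how ML.NET stores keys internally, and it simply numbers categories in the order they first appear.

```csharp
using System.Collections.Generic;

// Hypothetical, simplified analogue of MapValueToKey / MapKeyToValue (not ML.NET's implementation).
public class LabelKeyMap
{
    private readonly Dictionary<string, int> _keyByLabel = new Dictionary<string, int>();
    private readonly List<string> _labelByKey = new List<string>();

    // Builds the mapping from training labels, assigning keys in order of first appearance.
    public LabelKeyMap(IEnumerable<string> trainingLabels)
    {
        foreach (string label in trainingLabels)
        {
            if (!_keyByLabel.ContainsKey(label))
            {
                _keyByLabel[label] = _labelByKey.Count;
                _labelByKey.Add(label);
            }
        }
    }

    // Number of distinct categories seen during training.
    public int Count => _labelByKey.Count;

    // Value -> key: what the trainer sees as "Label".
    public int ToKey(string label) =>
        _keyByLabel.TryGetValue(label, out int key)
            ? key
            : throw new KeyNotFoundException($"Unknown category: {label}");

    // Key -> value: turns a predicted key back into a category string.
    public string ToLabel(int key) => _labelByKey[key];
}
```

One consequence of this design: a category that never appears in the training data has no key, so the model can never predict it. The training data needs to cover every category you care about.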

The code below will load the model and predict the category.

using System.IO;
using Microsoft.ML;

public class BankTransactionLabelService
{
    private readonly MLContext _mlContext;
    private PredictionEngine<TransactionData, TransactionPrediction> _predEngine;

    public BankTransactionLabelService()
    {
        _mlContext = new MLContext(seed: 0);
    }

    public void LoadModel(string modelPath)
    {
        ITransformer loadedModel;
        using (var stream = new FileStream(modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
        {
            // Load also returns the model's input schema, which we don't need here.
            loadedModel = _mlContext.Model.Load(stream, out _);
        }

        // Note: PredictionEngine is not thread-safe, so don't share one instance across threads.
        _predEngine = _mlContext.Model.CreatePredictionEngine<TransactionData, TransactionPrediction>(loadedModel);
    }

    public string PredictCategory(TransactionData transaction)
    {
        TransactionPrediction prediction = _predEngine.Predict(transaction);
        return prediction?.Category;
    }
}

And that's all the machine learning code we need!

Let’s see how everything comes together.

using System;
using System.Collections.Generic;
using System.IO;
using Newtonsoft.Json;

public static class Program
{
    public static void Main(string[] args)
    {
        // training.json contains manually chosen transactions with some modifications.
        Console.WriteLine("Loading training data...");
        var trainingData = JsonConvert.DeserializeObject<List<TransactionData>>(File.ReadAllText("training.json"));

        Console.WriteLine("Training the model...");
        var trainingService = new BankTransactionTrainingService();
        trainingService.Train(trainingData, "Model.zip");

        Console.WriteLine("Prepare transaction labeler...");
        var labelService = new BankTransactionLabelService();
        labelService.LoadModel("Model.zip");

        Console.WriteLine("Predict some transactions based on their description and type...");

        // Should be "coffee & tea".
        MakePrediction(labelService, "AMERICAN CONCEPTS PT BRISBANE", "expense");

        // The number in the description differs for every transfer, but the prediction still works. Result: rent
        MakePrediction(labelService, "ANZ M-BANKING PAYMENT TRANSFER 513542 TO SPIRE REALITY", "expense");

        // In fact, searching for just part of the description gives the same result.
        MakePrediction(labelService, "SPIRE REALITY", "expense");

        // If we change the transaction type, we get a reimbursement instead.
        MakePrediction(labelService, "SPIRE REALITY", "income");
    }

    private static void MakePrediction(BankTransactionLabelService labelService, string description, string transactionType)
    {
        string prediction = labelService.PredictCategory(new TransactionData
        {
            Description = description,
            TransactionType = transactionType
        });

        Console.WriteLine($"{description} ({transactionType}) => {prediction}");
    }
}

In the example above, we identify a coffee shop, rent (in two different ways), and a reimbursement that shares its description with the rent payment. With machine learning, we can predict a category even when the description doesn't exactly match anything in the training data, almost like a fuzzy search.
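The fuzzy-search behaviour likely comes largely from how FeaturizeText turns text into features: among other things, it counts word and character n-grams, so a short description like "SPIRE REALITY" shares many n-grams with the full transfer description. The sketch below is a rough, plain C# approximation that uses only lowercase character trigrams and a Jaccard overlap score; it is not ML.NET's actual featurizer, and the NgramSketch name is hypothetical.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical, rough approximation of character n-gram features (not ML.NET's FeaturizeText).
public static class NgramSketch
{
    // Lowercases the text and returns its distinct character trigrams.
    public static HashSet<string> CharTrigrams(string text)
    {
        string normalized = text.ToLowerInvariant();
        var trigrams = new HashSet<string>();
        for (int i = 0; i + 3 <= normalized.Length; i++)
        {
            trigrams.Add(normalized.Substring(i, 3));
        }
        return trigrams;
    }

    // Jaccard similarity: shared trigrams divided by all distinct trigrams of both texts.
    public static double Similarity(string a, string b)
    {
        HashSet<string> left = CharTrigrams(a);
        HashSet<string> right = CharTrigrams(b);
        int union = left.Union(right).Count();
        return union == 0 ? 0.0 : (double)left.Intersect(right).Count() / union;
    }
}
```

Every trigram of "SPIRE REALITY" also appears in the full rent transfer description, while it shares none with "AMERICAN CONCEPTS PT BRISBANE". A trained model weighs such n-grams (plus the transaction type) instead of comparing strings directly, but the overlap shows why the shortened description can still land on rent.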

Conclusion

In the end, it took me 2 hours to write the initial prototype and about 4 hours to clean up incorrect data (I'm working with 1k+ real transactions). The results are very good, with about 95% accuracy!

With ML.NET, most of the work is in selecting the right machine learning configuration and preparing the data. For example, description and transaction type are great features for predicting labels. In contrast, adding features like amount or date/time, which more often than not aren't directly linked to the label, may push predictions way off or leave the model unable to predict anything useful at all.

In the past, I made a few attempts to classify my transactions with a rule engine, but a few days of work failed to deliver much long-term value. I abandoned the project because the projected effort of maintaining the rules was too great.