Skip to content

Commit ba8fd9e

Browse files
authored
Merge pull request #1 from radumg/BiLT
BiLT - LinearRegression & NaiveBayes Classifier
2 parents 937b970 + 42117d2 commit ba8fd9e

File tree

5 files changed

+341
-0
lines changed

5 files changed

+341
-0
lines changed

src/DynAI.sln

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
Microsoft Visual Studio Solution File, Format Version 12.00
3+
# Visual Studio 15
4+
VisualStudioVersion = 15.0.26430.16
5+
MinimumVisualStudioVersion = 10.0.40219.1
6+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DynAI", "DynAI\MachineLearning.csproj", "{5990DC45-28B6-4E1E-B3EA-C6549EE270F7}"
7+
EndProject
8+
Global
9+
GlobalSection(SolutionConfigurationPlatforms) = preSolution
10+
Debug|Any CPU = Debug|Any CPU
11+
Release|Any CPU = Release|Any CPU
12+
EndGlobalSection
13+
GlobalSection(ProjectConfigurationPlatforms) = postSolution
14+
{5990DC45-28B6-4E1E-B3EA-C6549EE270F7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
15+
{5990DC45-28B6-4E1E-B3EA-C6549EE270F7}.Debug|Any CPU.Build.0 = Debug|Any CPU
16+
{5990DC45-28B6-4E1E-B3EA-C6549EE270F7}.Release|Any CPU.ActiveCfg = Release|Any CPU
17+
{5990DC45-28B6-4E1E-B3EA-C6549EE270F7}.Release|Any CPU.Build.0 = Release|Any CPU
18+
EndGlobalSection
19+
GlobalSection(SolutionProperties) = preSolution
20+
HideSolutionNode = FALSE
21+
EndGlobalSection
22+
EndGlobal

src/DynAI/MachineLearning.cs

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
using Accord.MachineLearning.Bayes;
2+
using Accord.Math;
3+
using Accord.Math.Random;
4+
using Accord.Statistics.Filters;
5+
using Accord.Statistics.Models.Regression.Linear;
6+
using System;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
using System.Text;
10+
using System.Threading.Tasks;
11+
12+
namespace AI.MachineLearning
13+
{
14+
#region Linear Regression
15+
16+
/// <summary>
17+
/// Class providing support for linear regression ML algorithms
18+
/// </summary>
19+
public class LinearRegression
20+
{
21+
#region public properties
22+
// dataset
23+
public double[] inputs { get; private set; }
24+
public double[] outputs { get; private set; }
25+
public double testValue { get; private set; }
26+
27+
// regression
28+
public SimpleLinearRegression regression { get; private set; }
29+
public OrdinaryLeastSquares ols;
30+
31+
// state & result
32+
public bool learned { get; private set; }
33+
public double result { get; private set; }
34+
#endregion
35+
36+
/// <summary>
37+
/// Constructs a new LinearRegression machine.
38+
/// </summary>
39+
public LinearRegression(List<double> inputList, List<double> outputList)
40+
{
41+
// validation
42+
if (inputList == null || outputList == null) throw new ArgumentNullException("Neither the input list nor the output list can be NULL");
43+
44+
// initialise seed value
45+
Generator.Seed = new Random().Next();
46+
47+
// process input and output lists into arrays
48+
inputs = inputList.ToArray();
49+
outputs = outputList.ToArray();
50+
51+
// set up linear regression using OLS
52+
regression = new SimpleLinearRegression();
53+
ols = new OrdinaryLeastSquares();
54+
55+
// nulls
56+
testValue = new double();
57+
result = new double();
58+
this.learned = false;
59+
}
60+
61+
/// <summary>
62+
/// Use the object's inputs and outputs to learn the model of the linear regression, using OrdinaryLeastSquares
63+
/// </summary>
64+
public LinearRegression Learn()
65+
{
66+
regression = this.ols.Learn(inputs, outputs);
67+
learned = true;
68+
69+
return this;
70+
}
71+
72+
/// <summary>
73+
/// Using the learned model, predict an output for the specified input
74+
/// </summary>
75+
/// <param name="test">The value to use as input for the prediction</param>
76+
/// <returns>The predicted value</returns>
77+
public double Predict(double test)
78+
{
79+
// don't predict if we haven't learned the model yet
80+
if (this.learned != true) throw new Exception("Cannot predict before the machine has learned.");
81+
82+
// check we haven't already predicted for this input
83+
if (test == this.testValue && this.learned == true) return this.result;
84+
85+
// predict
86+
this.testValue = test;
87+
this.result = this.regression.Transform(this.testValue);
88+
89+
return this.result;
90+
}
91+
}
92+
#endregion
93+
94+
#region Classifiers
95+
/// <summary>
96+
/// Class providing support for Naive Bayes classification machines.
97+
/// </summary>
98+
public class NaiveBayes
99+
{
100+
#region public properties
101+
// dataset
102+
public string[][] dataset { get; private set; }
103+
public string[] columns { get; private set; }
104+
public string outputColumn { get; private set; }
105+
public int[][] inputs;
106+
public int[] outputs;
107+
108+
// classifier
109+
public Accord.MachineLearning.Bayes.NaiveBayes classifier;
110+
public NaiveBayesLearning learner;
111+
public Codification codebook { get; private set; }
112+
113+
// state & result
114+
public bool learned { get; private set; }
115+
public string[] testValue { get; private set; }
116+
public string result { get; private set; }
117+
public double[] probs { get; private set; }
118+
#endregion
119+
120+
/// <summary>
121+
/// Constructs a new NaiveBayes classification machine.
122+
/// </summary>
123+
public NaiveBayes(string[][] data, List<string> columnList, string outputColumn)
124+
{
125+
// validation
126+
if (data == null || columnList == null || outputColumn==null) throw new ArgumentNullException("Neither the input list nor the column list can be NULL");
127+
128+
// initialise seed value
129+
Generator.Seed = new Random().Next();
130+
131+
// process input and output lists into arrays
132+
this.dataset = data;
133+
this.columns = columnList.ToArray();
134+
this.outputColumn = outputColumn;
135+
136+
// Create a new codification codebook to
137+
// convert strings into discrete symbols
138+
this.codebook = new Codification(columns, this.dataset);
139+
140+
// Extract input and output pairs to train
141+
int[][] symbols = this.codebook.Transform(this.dataset);
142+
this.inputs = symbols.Get(null, 0, -1); // Gets all rows, from 0 to the last (but not the last)
143+
this.outputs = symbols.GetColumn(-1); // Gets only the last column
144+
145+
// Create a new Naive Bayes learning
146+
this.learner = new NaiveBayesLearning();
147+
148+
// nulls
149+
testValue = null;
150+
result = null;
151+
probs = null;
152+
this.learned = false;
153+
}
154+
155+
/// <summary>
156+
/// Use the object's inputs and outputs to learn the model of the linear regression, using OrdinaryLeastSquares
157+
/// </summary>
158+
public NaiveBayes Learn()
159+
{
160+
this.classifier = this.learner.Learn(inputs, outputs);
161+
this.learned = true;
162+
163+
return this;
164+
}
165+
166+
/// <summary>
167+
/// Using the learned model, predict an output for the specified input
168+
/// </summary>
169+
/// <param name="test">The value to use as input for the prediction</param>
170+
/// <returns>The predicted value</returns>
171+
public string Predict(string[] test)
172+
{
173+
// don't predict if we haven't learned the model yet
174+
if (this.learned != true) throw new Exception("Cannot predict before the machine has learned.");
175+
176+
// check we haven't already predicted for this input
177+
if (test == this.testValue && this.learned == true) return this.result;
178+
179+
// predict
180+
// First encode the test instance
181+
int[] instance = this.codebook.Transform(test);
182+
183+
// Let us obtain the numeric output that represents the answer
184+
int codeword = this.classifier.Decide(instance);
185+
186+
// Now let us convert the numeric output to an actual answer
187+
this.result = this.codebook.Revert(this.outputColumn, codeword);
188+
189+
// We can also extract the probabilities for each possible answer
190+
this.probs = this.classifier.Probabilities(instance);
191+
192+
return this.result;
193+
}
194+
}
195+
#endregion
196+
197+
#region Helpers
198+
199+
#endregion
200+
201+
}

src/DynAI/MachineLearning.csproj

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<Project ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
3+
<Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" />
4+
<PropertyGroup>
5+
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
6+
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
7+
<ProjectGuid>{5990DC45-28B6-4E1E-B3EA-C6549EE270F7}</ProjectGuid>
8+
<OutputType>Library</OutputType>
9+
<AppDesignerFolder>Properties</AppDesignerFolder>
10+
<RootNamespace>AI</RootNamespace>
11+
<AssemblyName>AI</AssemblyName>
12+
<TargetFrameworkVersion>v4.5.2</TargetFrameworkVersion>
13+
<FileAlignment>512</FileAlignment>
14+
<NuGetPackageImportStamp>
15+
</NuGetPackageImportStamp>
16+
</PropertyGroup>
17+
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
18+
<DebugSymbols>true</DebugSymbols>
19+
<DebugType>full</DebugType>
20+
<Optimize>false</Optimize>
21+
<OutputPath>bin\Debug\</OutputPath>
22+
<DefineConstants>DEBUG;TRACE</DefineConstants>
23+
<ErrorReport>prompt</ErrorReport>
24+
<WarningLevel>4</WarningLevel>
25+
<DocumentationFile>bin\Debug\AI.xml</DocumentationFile>
26+
</PropertyGroup>
27+
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
28+
<DebugType>pdbonly</DebugType>
29+
<Optimize>true</Optimize>
30+
<OutputPath>bin\Release\</OutputPath>
31+
<DefineConstants>TRACE</DefineConstants>
32+
<ErrorReport>prompt</ErrorReport>
33+
<WarningLevel>4</WarningLevel>
34+
</PropertyGroup>
35+
<ItemGroup>
36+
<Reference Include="Accord, Version=3.7.0.0, Culture=neutral, PublicKeyToken=fa1a88e29555ccf7, processorArchitecture=MSIL">
37+
<HintPath>..\packages\Accord.3.7.0\lib\net45\Accord.dll</HintPath>
38+
</Reference>
39+
<Reference Include="Accord.MachineLearning, Version=3.7.0.0, Culture=neutral, PublicKeyToken=fa1a88e29555ccf7, processorArchitecture=MSIL">
40+
<HintPath>..\packages\Accord.MachineLearning.3.7.0\lib\net45\Accord.MachineLearning.dll</HintPath>
41+
</Reference>
42+
<Reference Include="Accord.Math, Version=3.7.0.0, Culture=neutral, PublicKeyToken=fa1a88e29555ccf7, processorArchitecture=MSIL">
43+
<HintPath>..\packages\Accord.Math.3.7.0\lib\net45\Accord.Math.dll</HintPath>
44+
</Reference>
45+
<Reference Include="Accord.Math.Core, Version=3.7.0.0, Culture=neutral, PublicKeyToken=fa1a88e29555ccf7, processorArchitecture=MSIL">
46+
<HintPath>..\packages\Accord.Math.3.7.0\lib\net45\Accord.Math.Core.dll</HintPath>
47+
</Reference>
48+
<Reference Include="Accord.Statistics, Version=3.7.0.0, Culture=neutral, PublicKeyToken=fa1a88e29555ccf7, processorArchitecture=MSIL">
49+
<HintPath>..\packages\Accord.Statistics.3.7.0\lib\net45\Accord.Statistics.dll</HintPath>
50+
</Reference>
51+
<Reference Include="System" />
52+
<Reference Include="System.Core" />
53+
<Reference Include="System.Xml.Linq" />
54+
<Reference Include="System.Data.DataSetExtensions" />
55+
<Reference Include="Microsoft.CSharp" />
56+
<Reference Include="System.Data" />
57+
<Reference Include="System.Net.Http" />
58+
<Reference Include="System.Xml" />
59+
</ItemGroup>
60+
<ItemGroup>
61+
<Compile Include="MachineLearning.cs" />
62+
<Compile Include="Properties\AssemblyInfo.cs" />
63+
</ItemGroup>
64+
<ItemGroup>
65+
<None Include="packages.config" />
66+
</ItemGroup>
67+
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
68+
<Import Project="..\packages\Accord.3.7.0\build\Accord.targets" Condition="Exists('..\packages\Accord.3.7.0\build\Accord.targets')" />
69+
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
70+
<PropertyGroup>
71+
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
72+
</PropertyGroup>
73+
<Error Condition="!Exists('..\packages\Accord.3.7.0\build\Accord.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\Accord.3.7.0\build\Accord.targets'))" />
74+
</Target>
75+
</Project>

src/DynAI/Properties/AssemblyInfo.cs

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System.Reflection;
2+
using System.Runtime.CompilerServices;
3+
using System.Runtime.InteropServices;
4+
5+
// General Information about an assembly is controlled through the following
6+
// set of attributes. Change these attribute values to modify the information
7+
// associated with an assembly.
8+
[assembly: AssemblyTitle("DynAI Machine Learning")]
9+
[assembly: AssemblyDescription("")]
10+
[assembly: AssemblyConfiguration("")]
11+
[assembly: AssemblyCompany("")]
12+
[assembly: AssemblyProduct("DynAI Machine Learning")]
13+
[assembly: AssemblyCopyright("Copyright © Radu Gidei 2017")]
14+
[assembly: AssemblyTrademark("")]
15+
[assembly: AssemblyCulture("")]
16+
17+
// Setting ComVisible to false makes the types in this assembly not visible
18+
// to COM components. If you need to access a type in this assembly from
19+
// COM, set the ComVisible attribute to true on that type.
20+
[assembly: ComVisible(false)]
21+
22+
// The following GUID is for the ID of the typelib if this project is exposed to COM
23+
[assembly: Guid("5990dc45-28b6-4e1e-b3ea-c6549ee270f7")]
24+
25+
// Version information for an assembly consists of the following four values:
26+
//
27+
// Major Version
28+
// Minor Version
29+
// Build Number
30+
// Revision
31+
//
32+
// You can specify all the values or you can default the Build and Revision Numbers
33+
// by using the '*' as shown below:
34+
// [assembly: AssemblyVersion("1.0.*")]
35+
[assembly: AssemblyVersion("0.1.0.0")]
36+
[assembly: AssemblyFileVersion("0.1.0.0")]

src/DynAI/packages.config

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<packages>
3+
<package id="Accord" version="3.7.0" targetFramework="net452" />
4+
<package id="Accord.MachineLearning" version="3.7.0" targetFramework="net452" />
5+
<package id="Accord.Math" version="3.7.0" targetFramework="net452" />
6+
<package id="Accord.Statistics" version="3.7.0" targetFramework="net452" />
7+
</packages>

0 commit comments

Comments
 (0)