forked from KevinCoble/AIToolbox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSVMExtensions.swift
150 lines (129 loc) · 5.36 KB
/
SVMExtensions.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
//
// SVMExtensions.swift
// AIToolbox
//
// Created by Kevin Coble on 1/17/16.
// Copyright © 2016 Kevin Coble. All rights reserved.
//
// This file contains extensions to the SVMModel class to get it to play nicely with the rest of the library
// This code doesn't go in the SVM file, as I want to keep that close to the original LIBSVM source material
import Foundation
/// Errors thrown by the `SVMModel` `Classifier` extension.
///
/// - invalidModelType: the model's `type` is not a classification SVM type, so it cannot be
///   used for classifier training or classification.
enum SVMError: Error { case invalidModelType }
// MARK: - Classifier conformance

extension SVMModel : Classifier {

    /// Returns the length of the input vectors this model was trained on.
    ///
    /// The value is read from the first support vector, so a model with no support
    /// vectors (e.g. untrained) reports 0.
    public func getInputDimension() -> Int
    {
        return supportVector.first?.count ?? 0
    }

    /// Returns the number of trainable parameters.
    ///
    /// NOTE(review): this currently reports `totalSupportVectors`, which the author marked as
    /// not yet the correct parameter count. Left unchanged so callers see the same value.
    public func getParameterDimension() -> Int
    {
        return totalSupportVectors     //!! This needs to be calculated correctly
    }

    /// Returns the number of classes the model distinguishes between.
    public func getNumberOfClasses() -> Int
    {
        return numClasses
    }

    /// Sets the model parameters from a flat array.
    ///
    /// Not implemented yet: the call is a no-op and never throws.
    public func setParameters(_ parameters: [Double]) throws
    {
        //!! This needs to be filled in
    }

    /// Returns the model parameters as a flat array.
    ///
    /// Not implemented yet: always returns an empty array.
    public func getParameters() throws -> [Double]
    {
        //!! This needs to be filled in
        return []
    }

    /// Required by `Classifier`; the supplied initializer is ignored.
    public func setCustomInitializer(_ function: ((_ trainData: MLDataSet)->[Double])!) {
        // Ignore, as SVM doesn't use an initialization
    }

    /// Trains the SVM on a classification data set in a single batch.
    ///
    /// - Parameter trainData: Training samples; its `dataType` must not be `.regression`.
    /// - Throws: `SVMError.invalidModelType` if the model is not a classification SVM,
    ///   `DataTypeError.invalidDataType` for regression data, and
    ///   `MachineLearningError.dataWrongDimension` if the data could not be converted to a `DataSet`.
    public func trainClassifier(_ trainData: MLClassificationDataSet) throws
    {
        // Verify the SVMModel is one of the classification types
        guard type == .c_SVM_Classification || type == .ν_SVM_Classification else {
            throw SVMError.invalidModelType
        }

        // Verify the data set is the right type
        if trainData.dataType == .regression { throw DataTypeError.invalidDataType }

        // Train on the data (ignore initialization, as SVMs do single-batch training)
        if let dataSet = trainData as? DataSet {
            train(dataSet)
            return
        }

        // Convert the data set to a DataSet class, as the SVM was ported from public domain code that
        // uses specific properties that were added to the DataSet class but are not in the MLDataSet protocols.
        // A failed conversion is reported rather than silently skipping training.
        guard let convertedData = DataSet(fromClassificationDataSet: trainData) else {
            throw MachineLearningError.dataWrongDimension
        }
        train(convertedData)
    }

    /// Incremental training is not supported; SVM training here is single-batch (see `trainClassifier`).
    ///
    /// - Throws: Always `MachineLearningError.continuationNotSupported`.
    public func continueTrainingClassifier(_ trainData: MLClassificationDataSet) throws
    {
        // SVMs use one-batch training, so there is no state to continue from
        throw MachineLearningError.continuationNotSupported
    }

    /// Classifies a single input vector using one-vs-one voting over every pair of classes.
    ///
    /// For each class pair (i, j), the decision value is the coefficient-weighted kernel sum over
    /// the support vectors of classes i and j, minus that pair's `ρ` entry. A positive value votes
    /// for class i; otherwise the vote goes to class j. Ties in the vote count go to the lowest
    /// class index.
    /// - Parameter inputs: Input vector (assumed to have `getInputDimension()` elements; not checked here).
    /// - Returns: The entry of `labels` for the class with the most votes.
    /// - Note: Assumes a trained model (`numClasses >= 1`); with zero classes the ranges below trap.
    public func classifyOne(_ inputs: [Double]) ->Int
    {
        // Index of the first support vector of each class (support vectors are laid out class by class)
        var coeffStart = [0]
        for index in 0..<numClasses-1 {
            coeffStart.append(coeffStart[index] + supportVectorCount[index])
        }

        // Kernel value between the input and every support vector
        let kernelValue = (0..<totalSupportVectors).map {
            Kernel.calcKernelValue(kernelParams, x: inputs, y: supportVector[$0])
        }

        // One vote slot per class
        var vote = [Int](repeating: 0, count: numClasses)

        // Evaluate the separating function for each class pair; `permutation` indexes ρ in pair order
        var permutation = 0
        for i in 0..<numClasses {
            for j in i+1..<numClasses {
                var sum = 0.0
                for k in 0..<supportVectorCount[i] {
                    sum += coefficients[j-1][coeffStart[i]+k] * kernelValue[coeffStart[i]+k]
                }
                for k in 0..<supportVectorCount[j] {
                    sum += coefficients[i][coeffStart[j]+k] * kernelValue[coeffStart[j]+k]
                }
                sum -= ρ[permutation]
                permutation += 1

                if sum > 0 {
                    vote[i] += 1
                }
                else {
                    vote[j] += 1
                }
            }
        }

        // Pick the class with the most votes (strict '>' keeps the first class on a tie)
        var maxIndex = 0
        for index in 1..<numClasses where vote[index] > vote[maxIndex] {
            maxIndex = index
        }
        return labels[maxIndex]
    }

    /// Classifies every sample in `testData`, writing each predicted class back into the data set.
    ///
    /// - Throws: `SVMError.invalidModelType` if the model is not a classification SVM,
    ///   `DataTypeError.invalidDataType` if the data set is not `.classification`,
    ///   `MachineLearningError.notTrained` if there are no support vectors,
    ///   `DataTypeError.wrongDimensionOnInput` if the input size differs from the support vectors,
    ///   `MachineLearningError.dataWrongDimension` if the working `DataSet` cannot be built,
    ///   plus anything thrown by `getClass`/`setClass`.
    public func classify(_ testData: MLClassificationDataSet) throws
    {
        // Verify the SVMModel is one of the classification types (reject only when it matches neither)
        guard type == .c_SVM_Classification || type == .ν_SVM_Classification else {
            throw SVMError.invalidModelType
        }

        // Verify the data set is the right type
        if testData.dataType != .classification { throw DataTypeError.invalidDataType }
        guard let firstSupportVector = supportVector.first else { throw MachineLearningError.notTrained }
        if testData.inputDimension != firstSupportVector.count { throw DataTypeError.wrongDimensionOnInput }

        // Put the data into a DataSet for SVM (it uses a DataSet so that it can be both regressor and classifier)
        guard let data = DataSet(dataType: .classification, withInputsFrom: testData) else {
            throw MachineLearningError.dataWrongDimension
        }

        // Predict
        predictValues(data)

        // Transfer the predictions back to the classifier data set
        for index in 0..<testData.size {
            let resultClass = try data.getClass(index)
            try testData.setClass(index, newClass: resultClass)
        }
    }
}