-
Notifications
You must be signed in to change notification settings - Fork 0
/
weight_vector.h
114 lines (88 loc) · 2.37 KB
/
weight_vector.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Implementation of dense weight vectors
//
// Copyright (C) 2012 Heidelberg University
//
// Author: Sascha Fendrich
//
// This file is part of Sol.
//
// Sol is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Sol is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with Sol. If not, see <http://www.gnu.org/licenses/>.
#ifndef WEIGHT_VECTOR_H
#define WEIGHT_VECTOR_H
#include <cstring>
#include "sparse_vector.h"
// Dense weight vector with a global scale factor that is applied on read.
//
// The effective weight at index i is scale_ * vector_[i] (see GetWeight), so
// Scale() and RegularizeL2() only update scale_ instead of touching every
// entry. The bias term is stored separately in bias_ and is not multiplied by
// scale_.
//
// NOTE(review): the class holds a raw buffer (vector_) and declares its own
// copy constructor and destructor, but no copy-assignment operator. The
// implicitly generated operator= copies the pointer, so assigning one
// WeightVector to another would presumably share vector_ and release it twice
// (assuming the destructor frees it). Avoid assignment, or add an operator=
// alongside the copy constructor in the implementation file.
class WeightVector
{
public:
  // Creates a vector with `size` weight entries.
  WeightVector (int size);
  // Copy constructor (parameter was mangled to an HTML entity; restored).
  WeightVector (const WeightVector &copy);
  ~WeightVector ();

  // Number of entries in the weight array (bias not included).
  int size () const;
  // Bias term, kept outside the scaled weight array.
  float bias () const;
  void set_bias (float bias);
  // Zeroes all weights and the cached squared L2 norm.
  void clear ();

  // Effective (scaled) weight access.
  float GetWeight (int index) const;
  void SetWeight (int index, float value);

  // Operations against sparse vectors (implemented out of line).
  void PlusEquals (const SparseVector &rhs);
  void PlusEquals (float scalar, const SparseVector &rhs);
  float InnerProduct (const SparseVector &rhs) const;

  // Multiplies all effective weights by `factor` via scale_.
  void Scale (float factor);
  // Cached squared L2 norm of the weights.
  float squaredL2Norm () const;

  void RegularizeL1 (const float factor);
  void RegularizeL2 (const float factor);

private:
  float *vector_;        // raw (unscaled) weights, size_ entries
  float bias_;           // bias term, not affected by scale_
  float scale_;          // multiplier applied to vector_ on read
  int size_;             // number of entries in vector_
  float squaredL2Norm_;  // cached value returned by squaredL2Norm()
};
// Returns how many entries the weight array holds; the bias lives in its own
// member and is not counted.
inline int WeightVector::size () const
{
  const int dimension = size_;
  return dimension;
}
// Returns the bias term. It is stored apart from the weight array, so the
// scale factor maintained by Scale() is never applied to it.
inline float WeightVector::bias () const
{
  const float current_bias = bias_;
  return current_bias;
}
inline void WeightVector::set_bias (float bias)
{
bias_ = bias;
}
inline void WeightVector::clear ()
{
memset (vector_, 0, size_ * sizeof (float));
squaredL2Norm_ = 0;
}
// Returns the effective weight at `index`: the raw stored entry multiplied by
// the scale factor that Scale() accumulates. `index` is not range-checked.
inline float WeightVector::GetWeight (int index) const
{
  const float raw = vector_[index];
  return raw * scale_;
}
inline void WeightVector::SetWeight (int index, float value)
{
vector_[index] = value / scale_;
}
inline void WeightVector::Scale (float factor)
{
scale_ *= factor;
}
// Returns the cached squared L2 norm held in squaredL2Norm_; nothing is
// recomputed from the weight array here.
inline float WeightVector::squaredL2Norm () const
{
  const float cached_norm = squaredL2Norm_;
  return cached_norm;
}
// L2 regularization step: shrinks every entry of the weight array (bias
// excluded) by the multiplier (1 - factor), delegating to Scale() so vector_
// itself is not rewritten.
// NOTE(review): factor >= 1 drives the scale to zero or below; once it is zero
// SetWeight divides by zero. This is the open question in the TODO below.
inline void WeightVector::RegularizeL2 (const float factor)
{
  const double shrink = 1.0 - factor;
  Scale (shrink); // TODO: need minimum?
}
#endif