Skip to content

Commit

Permalink
v2.1 改进分类器,更新模型
Browse files Browse the repository at this point in the history
  • Loading branch information
xpqiu committed Oct 8, 2014
1 parent 6623f35 commit 668183b
Show file tree
Hide file tree
Showing 55 changed files with 2,080 additions and 1,978 deletions.
1 change: 0 additions & 1 deletion .classpath
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" path="src/main/java"/>
<classpathentry kind="con" path="org.eclipse.m2e.MAVEN2_CLASSPATH_CONTAINER">
<attributes>
<attribute name="maven.pomderived" value="true"/>
Expand Down
20 changes: 14 additions & 6 deletions fnlp-core/src/main/java/org/fnlp/data/reader/SequenceReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,19 @@ private Instance readSequence() {
cur = null;
try {
ArrayList<ArrayList<String>> seq = new ArrayList<ArrayList<String>>();
ArrayList<String> first = new ArrayList(); //至少有一列元素
seq.add(first);
ArrayList<String> firstColumnList = new ArrayList(); //至少有一列元素
seq.add(firstColumnList);
ArrayList<String> labels = null;
if(hasTarget){
labels = new ArrayList<String>();
}
String content = null;

while ((content = reader.readLine()) != null) {
lineNo++;
// content = content.trim();
content = content.trim();
if (content.matches("^$")){
if(first.size()>0) //第一列个数>0
if(firstColumnList.size()>0) //第一列个数>0
break;
else
continue;
Expand Down Expand Up @@ -140,10 +141,17 @@ private Instance readSequence() {
}else{
ensure(colsnum,seq);
seq.get(colsnum).add(content.substring(start));
}
}
//debug
// if(colsnum>2){
// System.out.println(content);
// }
}
if (first.size() > 0){

if (firstColumnList.size() > 0){
cur = new Instance(seq, labels);
//debug
// cur.setSource(firstColumnList.toString());
}
seq = null;
labels = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ public class OnlineTrainer extends AbstractTrainer {
*/
public static float eps = 1e-10f;

public TrainMethod method = TrainMethod.FastAverage;

public boolean DEBUG = false;
public boolean shuffle = true;
public boolean finalOptimized = false;
Expand All @@ -75,9 +73,6 @@ public class OnlineTrainer extends AbstractTrainer {
public int iternum;
protected float[] weights;

public enum TrainMethod {
Perceptron, Average, FastAverage
}
public OnlineTrainer(AlphabetFactory af, int iternum) {
//默认特征生成器
Generator gen = new SFGenerator();
Expand Down Expand Up @@ -167,61 +162,60 @@ public Linear train(InstanceSet trainset, InstanceSet devset) {
long beginTimeIter, endTimeIter;
int iter = 0;
int frac = numSamples / 10;

float[] averageWeights = null;
if (method == TrainMethod.Average || method == TrainMethod.FastAverage) {
averageWeights = new float[weights.length];
}


//平均化感知器需要减去的权重
float[] extraweight = null;
extraweight = new float[weights.length];



beginTime = System.currentTimeMillis();

if (shuffle)
trainset.shuffle(random);


//遍历的总样本数
int k=0;

while (iter++ < iternum) {
if (!simpleOutput) {
System.out.print("iter "+iter+": ");
}
}

float err = 0;
float errtot = 0;
int cnt = 0;
int cnttot = 0;
int progress = frac;
int progress = frac;

if (shuffle)
trainset.shuffle(random);

beginTimeIter = System.currentTimeMillis();

float[] innerWeights = null;
if (method == TrainMethod.Average) {
innerWeights = Arrays.copyOf(weights, weights.length);
}

for (int ii = 0; ii < numSamples; ii++) {
for (int ii = 0; ii < numSamples; ii++) {

k++;
Instance inst = trainset.getInstance(ii);
Predict pred = (Predict) inferencer.getBest(inst,2);

float l = loss.calc(pred.getLabel(0), inst.getTarget());
if (l > 0) {
err += l;
errtot++;
update.update(inst, weights, pred.getLabel(0), c);
update.update(inst, weights, k, extraweight, pred.getLabel(0), c);

}else{
if (pred.size() > 1)
update.update(inst, weights, pred.getLabel(1), c);
update.update(inst, weights, k, extraweight, pred.getLabel(1), c);
}
cnt += inst.length();
cnttot++;
if (method == TrainMethod.Average) {
for (int i = 0; i < weights.length; i++) {
innerWeights[i] += weights[i];
}
}
cnttot++;

if (!simpleOutput && progress != 0 && ii % progress == 0) {
System.out.print('.');
progress += frac;
}
}
}

}//end for

float curErrRate = err / cnt;

Expand Down Expand Up @@ -253,17 +247,7 @@ public Linear train(InstanceSet trainset, InstanceSet devset) {
if (devset != null) {
evaluate(devset);
}
System.out.println();

if (method == TrainMethod.Average) {
for (int i = 0; i < innerWeights.length; i++) {
averageWeights[i] += innerWeights[i] / numSamples;
}
} else if (method == TrainMethod.FastAverage) {
for (int i = 0; i < weights.length; i++) {
averageWeights[i] += weights[i];
}
}
System.out.println();

if (interim) {
Linear p = new Linear(inferencer, trainset.getAlphabetFactory());
Expand All @@ -277,18 +261,15 @@ public Linear train(InstanceSet trainset, InstanceSet devset) {
if(MyArrays.viarance(hisErrRate) < eps){
System.out.println("convergence!");
break;
}
}

if (method == TrainMethod.Average || method == TrainMethod.FastAverage) {
for (int i = 0; i < averageWeights.length; i++) {
averageWeights[i] /= iternum;
}
weights = null;
weights = averageWeights;
inferencer.setWeights(weights);
}

}

}// end while 外循环

//平均化参数
for (int i = 0; i < weights.length; i++) {
weights[i] -= extraweight[i]/k;
}

System.out.print("Non-Zero Weight Numbers: " + MyArrays.countNoneZero(weights));
if (finalOptimized) {
int[] idx = MyArrays.getTop(weights.clone(), threshold, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
* 抽象参数更新类,采用PA算法
* \mathbf{w_{t+1}} = \w_t + {\alpha^*(\Phi(x,y)- \Phi(x,\hat{y}))}.
* \alpha =\frac{1- \mathbf{w_t}^T \left(\Phi(x,y) - \Phi(x,\hat{y})\right)}{||\Phi(x,y) - \Phi(x,\hat{y})||^2}.
* @author Feng Ji
*
*/
public abstract class AbstractPAUpdate implements Update {
Expand All @@ -52,28 +51,14 @@ public AbstractPAUpdate(Loss loss) {
this.loss = loss;
}

/**
* 参数更新方法
* @param inst 样本实例
* @param weights 权重
* @param predict 预测答案
* @param c 步长阈值
* @return 预测答案和标准答案之间的损失
*/
public float update(Instance inst, float[] weights, Object predict, float c) {
return update(inst, weights, inst.getTarget(), predict, c);
}
@Override
public float update(Instance inst, float[] weights, int k, float[] extraweight, Object predict, float c) {
return update(inst, weights, k, extraweight, inst.getTarget(), predict, c);
}


/**
* 参数更新方法
* @param inst 样本实例
* @param weights 权重
* @param target 对照答案
* @param predict 预测答案
* @param c 步长阈值
* @return 预测答案和对照答案之间的损失
*/
public float update(Instance inst, float[] weights, Object target,
@Override
public float update(Instance inst, float[] weights, int k, float[] extraweight, Object target,
Object predict, float c) {

int lost = diff(inst, weights, target, predict);
Expand All @@ -87,14 +72,17 @@ public float update(Instance inst, float[] weights, Object target,
alpha = alpha*inst.getWeight();
if(alpha>c){
alpha = c;
}else{
alpha=alpha;
}
}

int[] idx = diffv.indices();

for (int i = 0; i < idx.length; i++) {

weights[idx[i]] += diffv.get(idx[i]) * alpha;
for (int i = 0; i < idx.length; i++) {
float t = diffv.get(idx[i]) * alpha;
weights[idx[i]] += t;
extraweight[idx[i]] += t *k;
}
for (int i = 0; i < idx.length; i++) {

}
}

Expand All @@ -105,12 +93,12 @@ public float update(Instance inst, float[] weights, Object target,
}

/**
* 计算预测答案和对照答案之间的距离
* 计算预测类别和对照类别之间的距离
* @param inst 样本实例
* @param weights 权重
* @param target 对照答案
* @param predict 预测答案
* @return 预测答案和对照答案之间的距离
* @param target 对照类别
* @param predict 预测类别
* @return 预测类别和对照类别之间的距离
*/
protected abstract int diff(Instance inst, float[] weights, Object target,
Object predict);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

/**
* 线性分类的参数更新类,采用PA算法
* @author Feng Ji
*
*/
public class LinearMaxPAUpdate extends AbstractPAUpdate {

Expand Down Expand Up @@ -58,5 +56,4 @@ protected int diff(Instance inst, float[] weights, Object target,
return 1;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,32 @@
import org.fnlp.ml.types.Instance;

public interface Update {

public float update(Instance inst, float[] weights, Object predictLabel,
float c);

public float update(Instance inst, float[] weights, Object predictLabel,

/**
*
* @param inst 样本实例
* @param weights 权重
* @param k 目前遍历的样本数
* @param extraweight 平均化感知器需要减去的权重
* @param predictLabel 预测类别
* @param c 步长阈值
* @return 预测类别和真实类别之间的损失
*/
public float update(Instance inst, float[] weights, int k, float[] extraweight, Object predictLabel,
float c);

/**
*
* @param inst 样本实例
* @param weights 权重
* @param k 目前遍历的样本数
* @param extraweight 平均化感知器需要减去的权重
* @param predictLabel 预测类别
* @param goldenLabel 真实类别
* @param c 步长阈值
* @return 预测类别和真实类别之间的损失
*/
public float update(Instance inst, float[] weights, int k, float[] extraweight, Object predictLabel,
Object goldenLabel, float c);

}
Loading

0 comments on commit 668183b

Please sign in to comment.