首先是DATA类

import java.awt.print.Printable;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

/**
 * Loads labelled training samples from a tab-separated text file.
 *
 * <p>Each non-blank line holds the feature values followed by an integer label
 * in the last column. Samples are returned as a map from feature vector to
 * label, so two lines with identical feature values collapse into one entry
 * (the later label wins).
 */
public class Data {
    /** File read by the no-argument {@link #getTrainData()}. */
    private static final String DEFAULT_TRAIN_FILE = "G://download//testSet.txt";

    /**
     * Reads the training set from the default file.
     *
     * @return map from feature vector to label; empty if the file is missing
     */
    public Map<List<Double>, Integer> getTrainData() {
        return getTrainData(DEFAULT_TRAIN_FILE);
    }

    /**
     * Reads a training set from {@code path}.
     *
     * @param path location of the tab-separated sample file
     * @return map from feature vector to label; empty if the file is missing
     * @throws NumberFormatException if a non-blank line contains a field that
     *         is not a valid number
     */
    public Map<List<Double>, Integer> getTrainData(String path) {
        Map<List<Double>, Integer> data = new HashMap<List<Double>, Integer>();
        // try-with-resources guarantees the underlying file handle is released.
        try (Scanner in = new Scanner(new File(path))) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    // A blank (e.g. trailing) line carries no sample; parsing it would throw.
                    continue;
                }
                String[] fields = line.split("\t");
                int labelIndex = fields.length - 1;
                List<Double> point = new ArrayList<>();
                for (int i = 0; i < labelIndex; i++) {
                    point.add(Double.parseDouble(fields[i]));
                }
                data.put(point, Integer.parseInt(fields[labelIndex]));
            }
        } catch (FileNotFoundException e) {
            // Best effort, as before: report the problem and hand back an empty map.
            System.err.println("Training file not found: " + path);
            e.printStackTrace();
        }
        return data;
    }

    public static void main(String[] args) {
        Data data = new Data();
        data.getTrainData();
    }
}

  SVM类:

import java.awt.print.Printable;
import java.io.FileNotFoundException;
import java.io.ObjectInputStream.GetField;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry;

/**
 * A small support-vector-machine trainer using an SMO-style pairwise update of
 * the multipliers {@code alpha} and a Gaussian (RBF) kernel.
 *
 * <p>Training data is loaded through {@link Data#getTrainData()} in the
 * constructor; {@link #SMO()} then optimises {@code alpha} and {@code b}, and
 * {@link #predict(List)} classifies a point as {@code 1} or {@code -1}.
 */
public class SVM {
    private List<ArrayList<Double>> trainData;
    private List<Integer> labelTrainData;
    private double sigma;
    private double C;
    private List<Double> alpha;
    private double b;
    private List<Double> E;
    private int N;
    private int dim;
    private double tol;
    // Written by selectAlpha2 and read by update for the pair just selected.
    private double eta;
    private double eps;
    private double eps2;

    /**
     * Checks whether sample {@code id} satisfies the KKT conditions within
     * tolerance {@code tol}.
     *
     * @param id index of the training sample
     * @return {@code false} if the sample violates the condition for its alpha range
     */
    public boolean satisfyKkt(int id) {
        double ypgx = this.labelTrainData.get(id) * getGx(this.trainData.get(id)); // y*g(x)
        double a = this.alpha.get(id);
        if (Math.abs(a) <= this.eps) {
            // alpha == 0: requires y*g(x) >= 1
            if (ypgx - 1 < -this.tol) {
                return false;
            }
        } else if (Math.abs(a - this.C) <= this.eps) {
            // alpha == C: requires y*g(x) <= 1
            if (ypgx - 1 > this.tol) {
                return false;
            }
        } else {
            // 0 < alpha < C: requires y*g(x) == 1
            if (Math.abs(ypgx - 1) > this.tol) {
                return false;
            }
        }
        return true;
    }

    /** Recomputes the error cache {@code E[i] = g(x_i) - y_i} for every sample. */
    public void updateE() {
        for (int i = 0; i < this.N; i++) {
            double Ei = getGx(this.trainData.get(i)) - this.labelTrainData.get(i);
            this.E.set(i, Ei);
        }
    }

    /**
     * Linear kernel: the dot product of {@code X} and {@code Y}.
     *
     * @param X first vector (must have at least {@code Y.size()} entries)
     * @param Y second vector; its size determines how many components are used
     * @return the dot product
     */
    public double kernelLinear(List<Double> X, List<Double> Y) {
        int len = Y.size();
        double s = 0;
        for (int i = 0; i < len; i++) {
            s += X.get(i) * Y.get(i);
        }
        return s;
    }

    /**
     * Gaussian kernel {@code exp(-|X-Y|^2 / (2*sigma^2))}.
     *
     * @param X first vector (must have at least {@code Y.size()} entries)
     * @param Y second vector; its size determines how many components are used
     * @return the kernel value
     */
    public double kernelRBF(List<Double> X, List<Double> Y) {
        int len = Y.size();
        double s = 0;
        for (int i = 0; i < len; i++) {
            s += (X.get(i) - Y.get(i)) * (X.get(i) - Y.get(i));
        }
        s = Math.exp(-s / (2 * Math.pow(this.sigma, 2)));
        return s;
    }

    /**
     * Evaluates the decision function
     * {@code g(X) = sum_i alpha_i * y_i * K(X, x_i) + b} with the RBF kernel.
     *
     * @param X point to evaluate
     * @return the decision value
     */
    public double getGx(List<Double> X) {
        double s = 0;
        for (int i = 0; i < this.N; i++) {
            s += this.alpha.get(i) * this.labelTrainData.get(i) * kernelRBF(X, this.trainData.get(i));
        }
        s += this.b;
        return s;
    }

    /**
     * Jointly updates {@code alpha[x1]} and {@code alpha[x2]}, then {@code b}
     * and the error cache. Uses the {@code eta} value left in the field by the
     * preceding {@link #selectAlpha2(int)} call.
     *
     * @param x1 index of the first multiplier
     * @param x2 index of the second multiplier
     * @return {@code 1} if the multipliers changed, {@code 0} if either change
     *         was no larger than {@code eps2} and nothing was modified
     */
    public int update(int x1, int x2) {
        // Unbox once: comparing boxed Integer labels with == would compare references.
        int y1 = this.labelTrainData.get(x1);
        int y2 = this.labelTrainData.get(x2);
        double a1 = this.alpha.get(x1);
        double a2 = this.alpha.get(x2);
        double e1 = this.E.get(x1);
        double e2 = this.E.get(x2);

        // Clipping interval for the new alpha2.
        double low;
        double high;
        if (y1 == y2) {
            low = Math.max(0, a1 + a2 - this.C);
            high = Math.min(this.C, a2 + a1);
        } else {
            low = Math.max(0, a2 - a1);
            high = Math.min(this.C, a2 - a1 + this.C);
        }

        double newAlpha2 = a2 + y2 * (e1 - e2) / this.eta;
        if (newAlpha2 > high) {
            newAlpha2 = high;
        } else if (newAlpha2 < low) {
            newAlpha2 = low;
        }
        double newAlpha1 = a1 + y1 * y2 * (a2 - newAlpha2);

        // Snap values that are within eps of a bound onto the bound.
        if (Math.abs(newAlpha1) <= this.eps) {
            newAlpha1 = 0;
        }
        if (Math.abs(newAlpha2) <= this.eps) {
            newAlpha2 = 0;
        }
        if (Math.abs(newAlpha1 - this.C) <= this.eps) {
            newAlpha1 = this.C;
        }
        if (Math.abs(newAlpha2 - this.C) <= this.eps) {
            newAlpha2 = this.C;
        }

        // No meaningful step: leave all state untouched.
        if (Math.abs(newAlpha1 - a1) <= this.eps2) {
            return 0;
        }
        if (Math.abs(newAlpha2 - a2) <= this.eps2) {
            return 0;
        }

        List<Double> p1 = this.trainData.get(x1);
        List<Double> p2 = this.trainData.get(x2);
        double k11 = kernelRBF(p1, p1);
        double k22 = kernelRBF(p2, p2);
        double k21 = kernelRBF(p2, p1);
        double k12 = kernelRBF(p1, p2);

        double b1 = -e1 - y1 * k11 * (newAlpha1 - a1) - y2 * k21 * (newAlpha2 - a2) + this.b;
        double b2 = -e2 - y1 * k12 * (newAlpha1 - a1) - y2 * k22 * (newAlpha2 - a2) + this.b;

        if (newAlpha1 > 0 && newAlpha1 < this.C) {
            this.b = b1;
        } else if (newAlpha2 > 0 && newAlpha2 < this.C) {
            this.b = b2;
        } else {
            this.b = (b1 + b2) / 2;
        }

        this.alpha.set(x1, newAlpha1);
        this.alpha.set(x2, newAlpha2);
        updateE();
        return 1;
    }

    /** Computes {@code K(i,i) + K(j,j) - 2*K(i,j)} with the RBF kernel. */
    private double computeEta(int i, int j) {
        List<Double> pi = this.trainData.get(i);
        List<Double> pj = this.trainData.get(j);
        return kernelRBF(pi, pi) + kernelRBF(pj, pj) - 2 * kernelRBF(pi, pj);
    }

    /**
     * Chooses the partner index for {@code x1} and stores the matching
     * {@code eta} in the field for {@link #update(int, int)}.
     *
     * @param x1 index of the first multiplier
     * @return the chosen partner index, or {@code -1} if no sample yields a usable eta
     */
    public int selectAlpha2(int x1) {
        int x2 = -1;
        double maxDiff = -1;
        // First: among samples with 0 < alpha < C, pick the one maximising |E(x1) - E(i)|.
        for (int i = 0; i < this.N; ++i) {
            double ai = this.alpha.get(i);
            if (Math.abs(ai) <= this.eps || Math.abs(ai - this.C) <= this.eps) {
                continue;
            }
            double diff = Math.abs(this.E.get(x1) - this.E.get(i));
            if (diff > maxDiff) {
                maxDiff = diff;
                x2 = i;
            }
        }
        // Second: accept that candidate only if eta is non-zero.
        if (x2 != -1) {
            this.eta = computeEta(x1, x2);
            if (this.eta != 0) {
                return x2;
            }
        }
        // Third: otherwise fall back to the first sample in the whole set with |eta| > eps.
        for (int i = 0; i < this.N; i++) {
            if (i == x1) {
                continue;
            }
            this.eta = computeEta(x1, i);
            if (Math.abs(this.eta) > this.eps) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Optimises {@code alpha} and {@code b}. Each round first sweeps the
     * samples whose alpha is strictly between the bounds; if that changes
     * nothing, it sweeps every sample. Training stops when a round makes no
     * successful update. The round number is printed to standard output.
     */
    public void SMO() {
        int numChanged = 0;
        int cnt = 0;
        while (true) {
            cnt++;
            System.out.println(cnt);
            numChanged = 0;
            for (int x1 = 0; x1 < this.N; ++x1) {
                double a1 = this.alpha.get(x1);
                if (Math.abs(a1) <= this.eps || Math.abs(a1 - this.C) <= this.eps) {
                    continue;
                }
                numChanged += examine(x1);
            }
            if (numChanged == 0) {
                for (int x1 = 0; x1 < this.N; ++x1) {
                    // Count only updates that really changed something; counting
                    // rejected updates would keep the loop alive forever once no
                    // pair can make progress.
                    numChanged += examine(x1);
                }
            }
            if (numChanged == 0) {
                break;
            }
        }
    }

    /** Tries to fix a KKT violation at {@code x1}; returns 1 if alphas changed, else 0. */
    private int examine(int x1) {
        if (satisfyKkt(x1)) {
            return 0;
        }
        int x2 = selectAlpha2(x1);
        if (x2 == -1) {
            return 0;
        }
        return update(x1, x2);
    }

    /**
     * Loads the training set via {@link Data#getTrainData()} and initialises
     * the multipliers, the error cache and the hyper-parameters.
     */
    public SVM() {
        Data data = new Data();
        Map<List<Double>, Integer> Datas = data.getTrainData();
        this.trainData = new ArrayList<ArrayList<Double>>();
        this.labelTrainData = new ArrayList<Integer>();
        this.alpha = new ArrayList<Double>();
        this.E = new ArrayList<Double>();
        for (Map.Entry<List<Double>, Integer> entry : Datas.entrySet()) {
            // Copy rather than downcast so any List implementation is accepted.
            this.trainData.add(new ArrayList<Double>(entry.getKey()));
            this.labelTrainData.add(entry.getValue());
            this.alpha.add(0.0);
            // With alpha = 0 and b = 0, g(x) = 0 and therefore E = -y.
            this.E.add(0.0 - entry.getValue());
        }
        this.N = this.labelTrainData.size();
        // An empty data set (e.g. missing file) must not crash construction.
        this.dim = this.trainData.isEmpty() ? 0 : this.trainData.get(0).size();
        this.sigma = 12; // sigma=1
        this.C = 0.5; // c=6
        this.b = 0.0;
        this.tol = 0.001;
        this.eta = 0;
        this.eps = 0.0000001;
        this.eps2 = 0.00001;
    }

    /** @return the current bias term {@code b} */
    public double getB() {
        return this.b;
    }

    /**
     * Computes {@code w = sum_i alpha_i * y_i * x_i}, which is the weight
     * vector of the separating hyperplane when a linear kernel is used.
     *
     * @return array with one entry per feature dimension
     */
    public double[] getLinearW() {
        // One weight per feature dimension, not per training sample.
        double[] w = new double[this.dim];
        for (int i = 0; i < this.N; i++) {
            for (int j = 0; j < this.dim; j++) {
                w[j] += this.alpha.get(i) * this.labelTrainData.get(i) * this.trainData.get(i).get(j);
            }
        }
        return w;
    }

    /**
     * Classifies {@code x} by the sign of the decision function.
     *
     * @param x point to classify
     * @return {@code 1} if {@code g(x) > 0}, otherwise {@code -1}
     */
    public int predict(List<Double> x) {
        return getGx(x) > 0 ? 1 : -1;
    }

    /**
     * Trains on the default data set, writes each training point with its
     * predicted label to a result file, and prints the linear weights and bias.
     */
    public static void main(String[] args) throws FileNotFoundException {
        SVM s = new SVM();
        s.SMO();
        // try-with-resources flushes and closes the file even if writing fails.
        try (PrintWriter out = new PrintWriter("g://download//resultpoints.txt")) {
            for (int i = 0; i < s.N; i++) {
                out.write((s.trainData.get(i).get(0)).toString());
                out.write("\t");
                out.write((s.trainData.get(i).get(1)).toString());
                out.write("\t");
                out.write(Integer.toString(s.predict(s.trainData.get(i))));
                out.write("\n");
            }
        }
        // With a linear kernel, w and b directly give the line w.x + b = 0.
        double[] w = s.getLinearW();
        StringBuilder line = new StringBuilder();
        for (double wj : w) {
            line.append(wj).append(" ");
        }
        line.append(s.b).append("======");
        System.out.println(line);
    }
}

  

使用线性核函数的SVM得到的分类结果:

自己实现的SVM源码-LMLPHP

画图使用的是下面的Python代码:

from numpy import *
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Plot the classified points from the result file (columns: x1, x2, label,
# separated by tabs) together with the line W1*x1 + W2*x2 + B = 0.
# NOTE(review): the Java program writes "resultpoints.txt" while this script
# reads "myresult.txt" -- confirm which file is intended.
with open("g://download/myresult.txt") as f1:
    data = f1.readlines()

plt.figure(figsize=(8, 5), dpi=80)
axes = plt.subplot(111)

# Points whose label column equals 1 go to type1; everything else to type2.
type1_x = []
type1_y = []
type2_x = []
type2_y = []
for line in data:
    fields = line.strip().split('\t')
    if len(fields) < 3:
        # Skip blank or incomplete lines instead of failing on them.
        continue
    x1 = float(fields[0])
    x2 = float(fields[1])
    x3 = int(fields[2])
    if x3 == 1:
        type1_x.append(x1)
        type1_y.append(x2)
    else:
        type2_x.append(x1)
        type2_y.append(x2)

type1 = axes.scatter(type1_x, type1_y, s=40, c='red')
type2 = axes.scatter(type2_x, type2_y, s=40, c='green')

# Weights and bias copied by hand from the Java program's console output:
# 0.8148005405344305 -0.27263471796762484 -3.8392586254518437
W1 = 0.8148005405344305
W2 = -0.27263471796762484
B = -3.8392586254518437

# Solve W1*x + W2*y + B = 0 for y to draw the separating line.
x = np.linspace(-4, 10, 200)
y = (-W1 / W2) * x + (-B / W2)
axes.plot(x, y, 'b', lw=3)

plt.xlabel('x1')
plt.ylabel('x2')
axes.legend((type1, type2), ('0', '1'), loc=1)
plt.show()

  用高斯核,当C=6,sigma=1时候

自己实现的SVM源码-LMLPHP

高斯核,当c=0.5,sigma=1时候

自己实现的SVM源码-LMLPHP

当C=0.5,sigma=12时候

自己实现的SVM源码-LMLPHP

这说明C和sigma的取值对使用高斯核的SVM的分类结果影响很大。

sigma是高斯核函数的参数

05-12 10:23