不管是实验室研究机器学习算法或是公司研发,都有需要自己改进算法的时候,下面就说说怎么在weka里增加改进的机器学习算法。

  一 添加分类算法的流程

  1 编写的分类器必须继承 Classifier或是Classifier的子类;下面用比较简单的zeroR举例说明;

  2 复写接口 buildClassifier,其是主要的方法之一,功能是构造分类器,训练模型;

  3 复写接口 classifyInstance,功能是预测给定实例的类别(返回预测类别的索引,数值型类别则返回预测值);或实现 distributionForInstance,功能是得到该实例在所有类别上的概率分布;

  4 复写接口 getCapabilities,用来声明该分类器能够处理的数据类型;它决定在界面中该分类器是否可选,不满足条件时显示为灰色;

  5 参数option的set/get方法;

  6 globalInfo和seedTipText方法,功能是说明作用;

  7 见 第二部分,把这个分类器增加到weka应用程序上;

  zeroR.java源码

  

/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/ /*
* ZeroR.java
* Copyright (C) 1999 Eibe Frank
*
*/ package weka.classifiers.rules; import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import java.io.*;
import java.util.*;
import weka.core.*; /**
* Class for building and using a 0-R classifier. Predicts the mean
* (for a numeric class) or the mode (for a nominal class).
*
* @author Eibe Frank ([email protected])
* @version $Revision: 1.11 $
*/
public class ZeroR extends Classifier implements WeightedInstancesHandler { /** The class value 0R predicts. */
private double m_ClassValue; /** The number of instances in each class (null if class numeric). */
private double [] m_Counts; /** The class attribute. */
private Attribute m_Class; /**
* Returns a string describing classifier
* @return a description suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return "Class for building and using a 0-R classifier. Predicts the mean "
+ "(for a numeric class) or the mode (for a nominal class).";
} /**
* Generates the classifier.
*
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances) throws Exception { double sumOfWeights = ; m_Class = instances.classAttribute();
m_ClassValue = ;
switch (instances.classAttribute().type()) {
case Attribute.NUMERIC:
m_Counts = null;
break;
case Attribute.NOMINAL:
m_Counts = new double [instances.numClasses()];
for (int i = ; i < m_Counts.length; i++) {
m_Counts[i] = ;
}
sumOfWeights = instances.numClasses();
break;
default:
throw new Exception("ZeroR can only handle nominal and numeric class"
+ " attributes.");
}
Enumeration enu = instances.enumerateInstances();
while (enu.hasMoreElements()) {
Instance instance = (Instance) enu.nextElement();
if (!instance.classIsMissing()) {
if (instances.classAttribute().isNominal()) {
m_Counts[(int)instance.classValue()] += instance.weight();
} else {
m_ClassValue += instance.weight() * instance.classValue();
}
sumOfWeights += instance.weight();
}
}
if (instances.classAttribute().isNumeric()) {
if (Utils.gr(sumOfWeights, )) {
m_ClassValue /= sumOfWeights;
}
} else {
m_ClassValue = Utils.maxIndex(m_Counts);
Utils.normalize(m_Counts, sumOfWeights);
}
} /**
* Classifies a given instance.
*
* @param instance the instance to be classified
* @return index of the predicted class
*/
public double classifyInstance(Instance instance) { return m_ClassValue;
} /**
* Calculates the class membership probabilities for the given test instance.
*
* @param instance the instance to be classified
* @return predicted class probability distribution
* @exception Exception if class is numeric
*/
public double [] distributionForInstance(Instance instance)
throws Exception { if (m_Counts == null) {
double[] result = new double[];
result[] = m_ClassValue;
return result;
} else {
return (double []) m_Counts.clone();
}
} /**
* Returns a description of the classifier.
*
* @return a description of the classifier as a string.
*/
public String toString() { if (m_Class == null) {
return "ZeroR: No model built yet.";
}
if (m_Counts == null) {
return "ZeroR predicts class value: " + m_ClassValue;
} else {
return "ZeroR predicts class value: " + m_Class.value((int) m_ClassValue);
}
} /**
* Main method for testing this class.
*
* @param argv the options
*/
public static void main(String [] argv) { try {
System.out.println(Evaluation.evaluateModel(new ZeroR(), argv));
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
}

  二 添加模糊聚类算法流程

  1.按照weka接口,写好一个模糊聚类算法,源码见最下面的 FuzzyCMeans.java;

  2.把源码拷贝到weka.clusterers路径下;

  3.修改 weka.gui.GenericObjectEditor.props ,在#Lists the Clusterers I want to choose from 的 weka.clusterers.Clusterer=\下加入:weka.clusterers.FuzzyCMeans

  4. 相应的修改 weka.gui.GenericPropertiesCreator.props ,此处不用修改,因为包 weka.clusterers 已经存在,若加入新的包时则必须修改这里,加入新的包;

FuzzyCMeans.java源码:

/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/ /*
* FuzzyCMeans.java
* Copyright (C) 2007 Wei Xiaofei
*
*/
package weka.clusterers; import weka.classifiers.rules.DecisionTableHashKey;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues; import java.util.Enumeration;
import java.util.HashMap;
import java.util.Random;
import java.util.Vector; /**
<!-- globalinfo-start -->
* Cluster data using the Fuzzy C means algorithm
* <p/>
<!-- globalinfo-end -->
*
<!-- options-start -->
* Valid options are: <p/>
*
* <pre> -N &lt;num&gt;
* number of clusters.
* (default 2).</pre>
*
* <pre> -F &lt;num&gt;
* exponent.
* (default 2).</pre>
*
* <pre> -S &lt;num&gt;
* Random number seed.
* (default 10)</pre>
*
<!-- options-end -->
*
* @author Wei Xiaofei
* @version 1.03
* @see RandomizableClusterer
*/
public class FuzzyCMeans
extends RandomizableClusterer
implements NumberOfClustersRequestable, WeightedInstancesHandler { /** for serialization */
static final long serialVersionUID = -2134543132156464L; /**
* replace missing values in training instances
* 替换训练集中的缺省值
*/
private ReplaceMissingValues m_ReplaceMissingFilter; /**
* number of clusters to generate
* 产生聚类的个数
*/
private int m_NumClusters = ; /**
* D: d(i,j)=||c(i)-x(j)||为第i个聚类中心与第j个数据点间的欧几里德距离
*/
private Matrix D; // private Matrix U; /**
* holds the fuzzifier
* 模糊算子(加权指数)
*/
private double m_fuzzifier = ; /**
* holds the cluster centroids
* 聚类中心
*/
private Instances m_ClusterCentroids; /**
* Holds the standard deviations of the numeric attributes in each cluster
* 每个聚类的标准差
*/
private Instances m_ClusterStdDevs; /**
* For each cluster, holds the frequency counts for the values of each
* nominal attribute
*/
private int [][][] m_ClusterNominalCounts; /**
* The number of instances in each cluster
* 每个聚类包含的实例个数
*/
private int [] m_ClusterSizes; /**
* attribute min values
* 属性最小值
*/
private double [] m_Min; /**
* attribute max values
* 属性最大值
*/
private double [] m_Max; /**
* Keep track of the number of iterations completed before convergence
* 迭代次数
*/
private int m_Iterations = ; /**
* Holds the squared errors for all clusters
* 平方误差
*/
private double [] m_squaredErrors; /**
* the default constructor
* 初始构造器
*/
public FuzzyCMeans () {
super(); m_SeedDefault = ;//初始化种子个数
setSeed(m_SeedDefault);
} /**
* Returns a string describing this clusterer
* @return a description of the evaluator suitable for
* displaying in the explorer/experimenter gui
* 全局信息, 在图形介面显示
*/
public String globalInfo() {
return "Cluster data using the fuzzy k means algorithm";
} /**
* Returns default capabilities of the clusterer.
*
* @return the capabilities of this clusterer
* 聚类容器
*/
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities(); result.disableAll();
result.enable(Capability.NO_CLASS); // attributes
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.MISSING_VALUES); return result;
} /**
* Generates a clusterer. Has to initialize all fields of the clusterer
* that are not being set via options.
*
* @param data set of instances serving as training data
* @throws Exception if the clusterer has not been
* generated successfully
* 聚类产生函数
*/
public void buildClusterer(Instances data) throws Exception { // can clusterer handle the data?检测数据能否聚类
getCapabilities().testWithFail(data); m_Iterations = ; m_ReplaceMissingFilter = new ReplaceMissingValues();
Instances instances = new Instances(data);//实例
instances.setClassIndex(-);
m_ReplaceMissingFilter.setInputFormat(instances);
instances = Filter.useFilter(instances, m_ReplaceMissingFilter); m_Min = new double [instances.numAttributes()];
m_Max = new double [instances.numAttributes()];
for (int i = ; i < instances.numAttributes(); i++) {
m_Min[i] = m_Max[i] = Double.NaN;//随机分配不定值
} m_ClusterCentroids = new Instances(instances, m_NumClusters);//聚类中心
int[] clusterAssignments = new int [instances.numInstances()]; for (int i = ; i < instances.numInstances(); i++) {
updateMinMax(instances.instance(i));//更新最大最小值
} Random RandomO = new Random(getSeed());//随机数
int instIndex;
HashMap initC = new HashMap();
DecisionTableHashKey hk = null;
/* 利用决策表随机生成聚类中心 */
for (int j = instances.numInstances() - ; j >= ; j--) {
instIndex = RandomO.nextInt(j+);
hk = new DecisionTableHashKey(instances.instance(instIndex),
instances.numAttributes(), true);
if (!initC.containsKey(hk)) {
m_ClusterCentroids.add(instances.instance(instIndex));
initC.put(hk, null);
}
instances.swap(j, instIndex); if (m_ClusterCentroids.numInstances() == m_NumClusters) {
break;
}
} m_NumClusters = m_ClusterCentroids.numInstances();//聚类个数=聚类中心个数 D = new Matrix(solveD(instances).getArray());//求聚类中心到每个实例的距离 int i, j;
int n = instances.numInstances();
Instances [] tempI = new Instances[m_NumClusters];
m_squaredErrors = new double [m_NumClusters];
m_ClusterNominalCounts = new int [m_NumClusters][instances.numAttributes()][]; Matrix U = new Matrix(solveU(instances).getArray());//初始化隶属矩阵U
double q = ;//初始化价值函数值
while (true) {
m_Iterations++;
for (i = ; i < instances.numInstances(); i++) {
Instance toCluster = instances.instance(i);
int newC = clusterProcessedInstance(toCluster, true);//聚类处理实例,即输入的实例应该聚到哪一个簇?! clusterAssignments[i] = newC;
} // update centroids 更新聚类中心
m_ClusterCentroids = new Instances(instances, m_NumClusters);
for (i = ; i < m_NumClusters; i++) {
tempI[i] = new Instances(instances, );
}
for (i = ; i < instances.numInstances(); i++) {
tempI[clusterAssignments[i]].add(instances.instance(i));
} for (i = ; i < m_NumClusters; i++) { double[] vals = new double[instances.numAttributes()];
for (j = ; j < instances.numAttributes(); j++) { double sum1 = , sum2 = ;
for (int k = ; k < n; k++) {
sum1 += U.get(i, k) * U.get(i, k) * instances.instance(k).value(j);
sum2 += U.get(i, k) * U.get(i, k);
}
vals[j] = sum1 / sum2; }
m_ClusterCentroids.add(new Instance(1.0, vals)); } D = new Matrix(solveD(instances).getArray());
U = new Matrix(solveU(instances).getArray());//计算新的聿属矩阵U
double q1 = ;//新的价值函数值
for (i = ; i < m_NumClusters; i++) {
for (j = ; j < n; j++) {
/* 计算价值函数值 即q1 += U(i,j)^m * d(i,j)^2 */
q1 += Math.pow(U.get(i, j), getFuzzifier()) * D.get(i, j) * D.get(i, j);
}
} /* 上次价值函数值的改变量(q1 -q)小于某个阀值(这里用机器精度:2.2204e-16) */
if (q1 - q < 2.2204e-16) {
break;
}
q = q1;
} /* 计算标准差 跟K均值一样 */
m_ClusterStdDevs = new Instances(instances, m_NumClusters);
m_ClusterSizes = new int [m_NumClusters];
for (i = ; i < m_NumClusters; i++) {
double [] vals2 = new double[instances.numAttributes()];
for (j = ; j < instances.numAttributes(); j++) {
if (instances.attribute(j).isNumeric()) {//判断属性是否是数值型的?!
vals2[j] = Math.sqrt(tempI[i].variance(j));
} else {
vals2[j] = Instance.missingValue();
}
}
m_ClusterStdDevs.add(new Instance(1.0, vals2));//1.0代表权值, vals2代表属性值
m_ClusterSizes[i] = tempI[i].numInstances();
}
} /**
* clusters an instance that has been through the filters
*
* @param instance the instance to assign a cluster to
* @param updateErrors if true, update the within clusters sum of errors
* @return a cluster number
* 聚类一个实例, 返回实例应属于哪一个簇的编号
* 首先计算输入的实例到所有聚类中心的距离, 哪里距离最小
* 这个实例就属于哪一个聚类中心所在簇
*/
private int clusterProcessedInstance(Instance instance, boolean updateErrors) {
double minDist = Integer.MAX_VALUE;
int bestCluster = ;
for (int i = ; i < m_NumClusters; i++) {
double dist = distance(instance, m_ClusterCentroids.instance(i));
if (dist < minDist) {
minDist = dist;
bestCluster = i;
}
}
if (updateErrors) {
m_squaredErrors[bestCluster] += minDist;
}
return bestCluster;
} /**
* Classifies a given instance.
*
* @param instance the instance to be assigned to a cluster
* @return the number of the assigned cluster as an interger
* if the class is enumerated, otherwise the predicted value
* @throws Exception if instance could not be classified
* successfully
* 分类一个实例, 调用clusterProcessedInstance()函数
*/
public int clusterInstance(Instance instance) throws Exception {
m_ReplaceMissingFilter.input(instance);
m_ReplaceMissingFilter.batchFinished();
Instance inst = m_ReplaceMissingFilter.output(); return clusterProcessedInstance(inst, false);
} /**
* 计算矩阵D, 即 d(i,j)=||c(i)-x(j)||
*/
private Matrix solveD(Instances instances) {
int n = instances.numInstances();
Matrix D = new Matrix(m_NumClusters, n);
for (int i = ; i < m_NumClusters; i++) {
for (int j = ; j < n; j++) {
D.set(i, j, distance(instances.instance(j), m_ClusterCentroids.instance(i)));
if (D.get(i, j) == ) {
D.set(i, j, 0.000000000001);
}
}
} return D;
} /**
* 计算聿属矩阵U, 即U(i,j) = 1 / sum(d(i,j)/ d(k,j))^(2/(m-1)
*/
private Matrix solveU(Instances instances) {
int n = instances.numInstances();
int i, j;
Matrix U = new Matrix(m_NumClusters, n); for (i = ; i < m_NumClusters; i++) {
for (j = ; j < n; j++) {
double sum = ;
for (int k = ; k < m_NumClusters; k++) {
//d(i,j)/d(k,j)^(2/(m-1)
sum += Math.pow(D.get(i, j) / D.get(k, j), /(getFuzzifier() - ));
}
U.set(i, j, Math.pow(sum, -));
}
}
return U;
}
/**
* Calculates the distance between two instances
*
* @param first the first instance
* @param second the second instance
* @return the distance between the two given instances
* 计算两个实例之间的距离, 返回欧几里德距离
*/
private double distance(Instance first, Instance second) { double val1;
double val2;
double dist = 0.0; for (int i = ; i <first.numAttributes(); i++) {
val1 = first.value(i);
val2 = second.value(i); dist += (val1 - val2) * (val1 - val2);
}
dist = Math.sqrt(dist);
return dist;
} /**
* Updates the minimum and maximum values for all the attributes
* based on a new instance.
*
* @param instance the new instance
* 更新所有属性最大最小值, 跟K均值里的函数一样
*/
private void updateMinMax(Instance instance) { for (int j = ;j < m_ClusterCentroids.numAttributes(); j++) {
if (!instance.isMissing(j)) {
if (Double.isNaN(m_Min[j])) {
m_Min[j] = instance.value(j);
m_Max[j] = instance.value(j);
} else {
if (instance.value(j) < m_Min[j]) {
m_Min[j] = instance.value(j);
} else {
if (instance.value(j) > m_Max[j]) {
m_Max[j] = instance.value(j);
}
}
}
}
}
} /**
* Returns the number of clusters.
*
* @return the number of clusters generated for a training dataset.
* @throws Exception if number of clusters could not be returned
* successfully
* 返回聚类个数
*/
public int numberOfClusters() throws Exception {
return m_NumClusters;
} /**
* 返回模糊算子, 即加权指数
*
* @return 加权指数
* @throws Exception 加权指数不能成功返回
*/
public double fuzzifier() throws Exception {
return m_fuzzifier;
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
* 返回一个枚举描述的活动选项(菜单)
*/
public Enumeration listOptions () {
Vector result = new Vector(); result.addElement(new Option(
"\tnumber of clusters.\n"
+ "\t(default 2).",
"N", , "-N <num>")); result.addElement(new Option(
"\texponent.\n"
+ "\t(default 2.0).",
"F", , "-F <num>")); Enumeration en = super.listOptions();
while (en.hasMoreElements())
result.addElement(en.nextElement()); return result.elements();
} /**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
* 返回文本信息
*/
public String numClustersTipText() {
return "set number of clusters";
} /**
* set the number of clusters to generate
*
* @param n the number of clusters to generate
* @throws Exception if number of clusters is negative
* 设置聚类个数
*/
public void setNumClusters(int n) throws Exception {
if (n <= ) {
throw new Exception("Number of clusters must be > 0");
}
m_NumClusters = n;
} /**
* gets the number of clusters to generate
*
* @return the number of clusters to generate
* 取聚类个数
*/
public int getNumClusters() {
return m_NumClusters;
} /**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
* 返回文本信息
*/
public String fuzzifierTipText() {
return "set fuzzifier";
} /**
* set the fuzzifier
*
* @param f fuzzifier
* @throws Exception if exponent is negative
* 设置模糊算子
*/
public void setFuzzifier(double f) throws Exception {
if (f <= ) {
throw new Exception("F must be > 1");
}
m_fuzzifier= f;
} /**
* get the fuzzifier
*
* @return m_fuzzifier
* 取得模糊算子
*/
public double getFuzzifier() {
return m_fuzzifier;
} /**
* Parses a given list of options. <p/>
*
<!-- options-start -->
* Valid options are: <p/>
*
* <pre> -N &lt;num&gt;
* number of clusters.
* (default 2).</pre>
*
* <pre> -F &lt;num&gt;
* fuzzifier.
* (default 2.0).</pre>
*
* <pre> -S &lt;num&gt;
* Random number seed.
* (default 10)</pre>
*
<!-- options-end -->
*
* @param options the list of options as an array of strings
* @throws Exception if an option is not supported
* 设置活动选项
*/
public void setOptions (String[] options)
throws Exception { String optionString = Utils.getOption('N', options); if (optionString.length() != ) {
setNumClusters(Integer.parseInt(optionString));
} optionString = Utils.getOption('F', options); if (optionString.length() != ) {
setFuzzifier((new Double(optionString)).doubleValue());
}
super.setOptions(options);
} /**
* Gets the current settings of FuzzyCMeans
*
* @return an array of strings suitable for passing to setOptions()
* 取得活动选项
*/
public String[] getOptions () {
int i;
Vector result;
String[] options; result = new Vector(); result.add("-N");
result.add("" + getNumClusters()); result.add("-F");
result.add("" + getFuzzifier()); options = super.getOptions();
for (i = ; i < options.length; i++)
result.add(options[i]); return (String[]) result.toArray(new String[result.size()]);
} /**
* return a string describing this clusterer
*
* @return a description of the clusterer as a string
* 结果显示
*/
public String toString() {
int maxWidth = ;
for (int i = ; i < m_NumClusters; i++) {
for (int j = ;j < m_ClusterCentroids.numAttributes(); j++) {
if (m_ClusterCentroids.attribute(j).isNumeric()) {
double width = Math.log(Math.abs(m_ClusterCentroids.instance(i).value(j))) /
Math.log(10.0);
width += 1.0;
if ((int)width > maxWidth) {
maxWidth = (int)width;
}
}
}
}
StringBuffer temp = new StringBuffer();
String naString = "N/A";
for (int i = ; i < maxWidth+; i++) {
naString += " ";
}
temp.append("\nFuzzy C-means\n======\n");
temp.append("\nNumber of iterations: " + m_Iterations+"\n");
temp.append("Within cluster sum of squared errors: " + Utils.sum(m_squaredErrors)); temp.append("\n\nCluster centroids:\n");
for (int i = ; i < m_NumClusters; i++) {
temp.append("\nCluster "+i+"\n\t");
temp.append("\n\tStd Devs: ");
for (int j = ; j < m_ClusterStdDevs.numAttributes(); j++) {
if (m_ClusterStdDevs.attribute(j).isNumeric()) {
temp.append(" "+Utils.doubleToString(m_ClusterStdDevs.instance(i).value(j),
maxWidth+, ));
} else {
temp.append(" "+naString);
}
}
}
temp.append("\n\n");
return temp.toString();
} /**
* Gets the the cluster centroids
*
* @return the cluster centroids
* 取得聚类中心
*/
public Instances getClusterCentroids() {
return m_ClusterCentroids;
} /**
* Gets the standard deviations of the numeric attributes in each cluster
*
* @return the standard deviations of the numeric attributes
* in each cluster
* 聚得标准差
*/
public Instances getClusterStandardDevs() {
return m_ClusterStdDevs;
} /**
* Returns for each cluster the frequency counts for the values of each
* nominal attribute
*
* @return the counts
*/
public int [][][] getClusterNominalCounts() {
return m_ClusterNominalCounts;
} /**
* Gets the squared error for all clusters
*
* @return the squared error
* 取得平方差
*/
public double getSquaredError() {
return Utils.sum(m_squaredErrors);
} /**
* Gets the number of instances in each cluster
*
* @return The number of instances in each cluster
* 取每个簇的实例个数
*/
public int [] getClusterSizes() {
return m_ClusterSizes;
} /**
* Main method for testing this class.
*
* @param argv should contain the following arguments: <p>
* -t training file [-N number of clusters]
* 主函数
*/
public static void main (String[] argv) {
runClusterer(new FuzzyCMeans (), argv);
}
}
05-11 17:46