一.论文《QuickScorer:a Fast Algorithm to Rank Documents with Additive Ensembles of Regression Trees》是为了解决LTR模型的预测问题:如果LTR中的LambdaMart在生成模型时产生的树的数目和叶结点过多,在对样本打分预测时需要遍历每棵树,这样在线上使用时效率较慢。这篇文章主要就是利用bitvector方法加速打分预测。代码我找了很久没找到开源的,后来无意中在Solr ltr中看到被改动过了的源码,不过这个源码集成在solr中,这里暂时贴出来,后期再将其从solr中剥离出来,集成到ranklib中,以便使用。

二.图片解说

1. Ensemble trees原始打分过程

关于Additive Ensembles of Regression Trees模型的快速打分预测-LMLPHP

像gbdt,lambdamart,xgboost或lightgbm等这样的集成树模型在打分预测阶段,比如来了一个样本,这个样本是vector形式输入到每一棵树中,然后在每棵树中像if else这样的过程走到或映射到每棵树的一个节点中,这个节点就是每棵树的打分,然后将每棵树的打分乘上学习率(shrinkage)加和就是此样本的预测分。

2.论文中提到的打分过程

A.为回归树中的每个分枝打上true和false标签

关于Additive Ensembles of Regression Trees模型的快速打分预测-LMLPHP

比如图中样本X=[0.2,1.1,0.2],在回归树的branch中判断X[0],X[1],X[2]的true和false,比如图中根结点X[1]<=1.0,但样本X[1]=1.1,所以是false(走左边是true,右边是false),这样将所有branch打上true和false标签(可以直接打上false标志,不用考虑true),后面需要用到所有的false branch。

B.为每个branch分配一个bitvector

关于Additive Ensembles of Regression Trees模型的快速打分预测-LMLPHP

这个bitvector中的"0"表示true leaves,比如"001111"表示6个叶结点中的最左边两个叶结点是候选节点。“110011”表示在右子树中true的结点只有中间两个,作为候选结点。

C.打分阶段

关于Additive Ensembles of Regression Trees模型的快速打分预测-LMLPHP

此阶段是最后的打分预测阶段。根据前几个图的过程,将所有为false的branch对应的bitvector按位与,结果中仍为1的位就指示样本可能落在的叶结点。比如图中的结果是"001101",其中最左边(最高位)为1的那一位所在的位置便是最终落入的叶结点的编号。每棵回归树都这样操作得到各自的叶结点值,乘上学习率(shrinkage)后加和,就得到该样本的最终预测分。

三.代码

 import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils; import java.util.*; public class MultipleAdditiveTreesModel extends LTRScoringModel { // 特征名:索引(从0开始)
private final HashMap<String, Integer> fname2index = new HashMap();
private List<RegressionTree> trees; private MultipleAdditiveTreesModel.RegressionTree createRegressionTree(Map<String, Object> map) {
MultipleAdditiveTreesModel.RegressionTree rt = new MultipleAdditiveTreesModel.RegressionTree();
if(map != null) {
SolrPluginUtils.invokeSetters(rt, map.entrySet());
} return rt;
} private MultipleAdditiveTreesModel.RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
MultipleAdditiveTreesModel.RegressionTreeNode rtn = new MultipleAdditiveTreesModel.RegressionTreeNode();
if(map != null) {
SolrPluginUtils.invokeSetters(rtn, map.entrySet());
} return rtn;
} public void setTrees(Object trees) {
this.trees = new ArrayList();
Iterator var2 = ((List)trees).iterator(); while(var2.hasNext()) {
Object o = var2.next();
MultipleAdditiveTreesModel.RegressionTree rt = this.createRegressionTree((Map)o);
this.trees.add(rt);
}
} public void setTrees(List<RegressionTree> trees) {
this.trees = trees;
} public List<RegressionTree> getTrees() {
return this.trees;
} public MultipleAdditiveTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params); for(int i = 0; i < features.size(); ++i) {
String key = ((Feature)features.get(i)).getName();
this.fname2index.put(key, Integer.valueOf(i));//特征名:索引
} } public void validate() throws ModelException {
super.validate();
if(this.trees == null) {
throw new ModelException("no trees declared for model " + this.name);
} else {
Iterator var1 = this.trees.iterator(); while(var1.hasNext()) {
MultipleAdditiveTreesModel.RegressionTree tree = (MultipleAdditiveTreesModel.RegressionTree)var1.next();
tree.validate();
} }
} public float score(float[] modelFeatureValuesNormalized) {
float score = 0.0F; MultipleAdditiveTreesModel.RegressionTree t;
for(Iterator var3 = this.trees.iterator(); var3.hasNext(); score += t.score(modelFeatureValuesNormalized)) {
t = (MultipleAdditiveTreesModel.RegressionTree)var3.next();
} return score;
} public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
float[] fv = new float[featureExplanations.size()];
int index = 0; for(Iterator details = featureExplanations.iterator(); details.hasNext(); ++index) {
Explanation featureExplain = (Explanation)details.next();
fv[index] = featureExplain.getValue();
} ArrayList var12 = new ArrayList();
index = 0; for(Iterator var13 = this.trees.iterator(); var13.hasNext(); ++index) {
MultipleAdditiveTreesModel.RegressionTree t = (MultipleAdditiveTreesModel.RegressionTree)var13.next();
float score = t.score(fv);
Explanation p = Explanation.match(score, "tree " + index + " | " + t.explain(fv), new Explanation[0]);
var12.add(p);
} return Explanation.match(finalScore, this.toString() + " model applied to features, sum of:", var12);
} public String toString() {
StringBuilder sb = new StringBuilder(this.getClass().getSimpleName());
sb.append("(name=").append(this.getName());
sb.append(",trees=["); for(int ii = 0; ii < this.trees.size(); ++ii) {
if(ii > 0) {
sb.append(',');
} sb.append(this.trees.get(ii));
} sb.append("])");
return sb.toString();
} public class RegressionTree {
private Float weight;
private MultipleAdditiveTreesModel.RegressionTreeNode root; public void setWeight(float weight) {
this.weight = new Float(weight);
} public void setWeight(String weight) {
this.weight = new Float(weight);
} public float getWeight() {
return this.weight;
} public void setRoot(Object root) {
this.root = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)root);
} public RegressionTreeNode getRoot() {
return this.root;
} public float score(float[] featureVector) {
return this.weight.floatValue() * this.root.score(featureVector);
} public String explain(float[] featureVector) {
return this.root.explain(featureVector);
} public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("(weight=").append(this.weight);
sb.append(",root=").append(this.root);
sb.append(")");
return sb.toString();
} public RegressionTree() {
} public void validate() throws ModelException {
if(this.weight == null) {
throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a weight");
} else if(this.root == null) {
throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a tree");
} else {
this.root.validate();
}
}
} public class RegressionTreeNode {
private static final float NODE_SPLIT_SLACK = 1.0E-6F;
private float value = 0.0F;
private String feature;
private int featureIndex = -1;
private Float threshold;
private MultipleAdditiveTreesModel.RegressionTreeNode left;
private MultipleAdditiveTreesModel.RegressionTreeNode right; public void setValue(float value) {
this.value = value;
} public void setValue(String value) {
this.value = Float.parseFloat(value);
} public void setFeature(String feature) {
this.feature = feature;
Integer idx = (Integer)MultipleAdditiveTreesModel.this.fname2index.get(this.feature);
this.featureIndex = idx == null?-1:idx.intValue();
} public int getFeatureIndex() {
return this.featureIndex;
} public void setThreshold(float threshold) {
this.threshold = Float.valueOf(threshold + 1.0E-6F);
} public void setThreshold(String threshold) {
this.threshold = Float.valueOf(Float.parseFloat(threshold) + 1.0E-6F);
} public float getThreshold() {
return this.threshold;
} public void setLeft(Object left) {
this.left = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)left);
} public RegressionTreeNode getLeft() {
return this.left;
} public void setRight(Object right) {
this.right = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)right);
} public RegressionTreeNode getRight() {
return this.right;
} public boolean isLeaf() {
return this.feature == null;
} public float score(float[] featureVector) {
return this.isLeaf()?this.value:(this.featureIndex >= 0 && this.featureIndex < featureVector.length?(featureVector[this.featureIndex] <= this.threshold.floatValue()?this.left.score(featureVector):this.right.score(featureVector)):0.0F);
} public String explain(float[] featureVector) {
if(this.isLeaf()) {
return "val: " + this.value;
} else if(this.featureIndex >= 0 && this.featureIndex < featureVector.length) {
String rval;
if(featureVector[this.featureIndex] <= this.threshold.floatValue()) {
rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " <= " + this.threshold + ", Go Left | ";
return rval + this.left.explain(featureVector);
} else {
rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " > " + this.threshold + ", Go Right | ";
return rval + this.right.explain(featureVector);
}
} else {
return "\'" + this.feature + "\' does not exist in FV, Return Zero";
}
} public String toString() {
StringBuilder sb = new StringBuilder();
if(this.isLeaf()) {
sb.append(this.value);
} else {
sb.append("(feature=").append(this.feature);
sb.append(",threshold=").append(this.threshold.floatValue() - 1.0E-6F);
sb.append(",left=").append(this.left);
sb.append(",right=").append(this.right);
sb.append(')');
} return sb.toString();
} public RegressionTreeNode() {
} public void validate() throws ModelException {
if(this.isLeaf()) {
if(this.left != null || this.right != null) {
throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + this.left + " and right=" + this.right);
}
} else if(null == this.threshold) {
throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
} else if(null == this.left) {
throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
} else {
this.left.validate();
if(null == this.right) {
throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
} else {
this.right.validate();
}
}
}
} }
 import org.apache.commons.lang.ArrayUtils;
import org.apache.lucene.util.CloseableThreadLocal;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer; import java.util.*; public class QuickScorerTreesModel extends MultipleAdditiveTreesModel{ private static final long MAX_BITS = 0xFFFFFFFFFFFFFFFFL; // 64bits De Bruijn Sequence
// see: http://chessprogramming.wikispaces.com/DeBruijnsequence#Binary alphabet-B(2, 6)
private static final long HASH_BITS = 0x022fdd63cc95386dL;
private static final int[] hashTable = new int[64]; static {
long hash = HASH_BITS;
for (int i = 0; i < 64; ++i) {
hashTable[(int) (hash >>> 58)] = i;
hash <<= 1;
}
} /**
* Finds the index of rightmost bit with O(1) by using De Bruijn strategy.
*
* @param bits target bits (64bits)
* @see <a href="http://supertech.csail.mit.edu/papers/debruijn.pdf">http://supertech.csail.mit.edu/papers/debruijn.pdf</a>
*/
private static int findIndexOfRightMostBit(long bits) {
return hashTable[(int) (((bits & -bits) * HASH_BITS) >>> 58)];
} /**
* The number of trees of this model.
*/
private int treeNum; /**
* Weights of each tree.
*/
private float[] weights; /**
* List of all leaves of this model.
* We use tree instead of value to manage wide (i.e., more than 64 leaves) trees.
*/
private RegressionTreeNode[] leaves; /**
* Offsets of each leaf block correspond to each tree.
*/
private int[] leafOffsets; /**
* The number of conditions of this model.
*/
private int condNum; /**
* Thresholds of each condition.
* These thresholds are grouped by corresponding feature and each block is sorted by threshold values.
*/
private float[] thresholds; /**
* Corresponding featureIndex of each condition.
*/
private int[] featureIndexes; /**
* Offsets of each condition block correspond to each feature.
*/
private int[] condOffsets; /**
* Forward bitvectors of each condition which correspond to original additive trees.
*/
private long[] forwardBitVectors; /**
* Backward bitvectors of each condition which correspond to inverted additive trees.
*/
private long[] backwardBitVectors; /**
* Mappings from threasholdes index to tree indexes.
*/
private int[] treeIds; /**
* Bitvectors of each tree for calculating the score.
* We reuse bitvectors instance in each thread to prevent from re-allocating arrays.
*/
private CloseableThreadLocal<long[]> threadLocalTreeBitvectors = null; /**
* Boolean statistical tendency of this model.
* If conditions of the model tend to be false, we use inverted bitvectors for speeding up.
*/
private volatile float falseRatio = 0.5f; /**
* The decay factor for updating falseRatio in each evaluation step.
* This factor is used like "{@code ratio = preRatio * decay ratio * (1 - decay)}".
*/
private float falseRatioDecay = 0.99f; /**
* Comparable node cost for selecting leaf candidates.
*/
private static class NodeCost implements Comparable<NodeCost> {
private final int id;
private final int cost;
private final int depth;
private final int left;
private final int right; private NodeCost(int id, int cost, int depth, int left, int right) {
this.id = id;
this.cost = cost;
this.depth = depth;
this.left = left;
this.right = right;
} public int getId() {
return id;
} public int getLeft() {
return left;
} public int getRight() {
return right;
} /**
* Sorts by cost and depth.
* We prefer cheaper cost and deeper one.
*/
@Override
public int compareTo(NodeCost n) {
if (cost != n.cost) {
return Integer.compare(cost, n.cost);
} else if (depth != n.depth) {
return Integer.compare(n.depth, depth); // revere order
} else {
return Integer.compare(id, n.id);
}
}
} /**
* Comparable condition for constructing and sorting bitvectors.
*/
private static class Condition implements Comparable<Condition> {
private final int featureIndex;
private final float threshold;
private final int treeId;
private final long forwardBitvector;
private final long backwardBitvector; private Condition(int featureIndex, float threshold, int treeId, long forwardBitvector, long backwardBitvector) {
this.featureIndex = featureIndex;
this.threshold = threshold;
this.treeId = treeId;
this.forwardBitvector = forwardBitvector;
this.backwardBitvector = backwardBitvector;
} int getFeatureIndex() {
return featureIndex;
} float getThreshold() {
return threshold;
} int getTreeId() {
return treeId;
} long getForwardBitvector() {
return forwardBitvector;
} long getBackwardBitvector() {
return backwardBitvector;
} /*
* Sort by featureIndex and threshold with ascent order.
*/
@Override
public int compareTo(Condition c) {
if (featureIndex != c.featureIndex) {
return Integer.compare(featureIndex, c.featureIndex);
} else {
return Float.compare(threshold, c.threshold);
}
}
} /**
* Base class for traversing node with depth first order.
*/
private abstract static class Visitor {
private int nodeId = 0; int getNodeId() {
return nodeId;
} void visit(RegressionTree tree) {
nodeId = 0;
visit(tree.getRoot(), 0);
} private void visit(RegressionTreeNode node, int depth) {
if (node.isLeaf()) {
doVisitLeaf(node, depth);
} else {
// visit children first
visit(node.getLeft(), depth + 1);
visit(node.getRight(), depth + 1); doVisitBranch(node, depth);
}
++nodeId;
} protected abstract void doVisitLeaf(RegressionTreeNode node, int depth); protected abstract void doVisitBranch(RegressionTreeNode node, int depth);
} /**
* {@link Visitor} implementation for calculating the cost of each node.
*/
private static class NodeCostVisitor extends Visitor { private final Stack<AbstractMap.SimpleEntry<Integer, Integer>> idCostStack = new Stack<>();
private final PriorityQueue<NodeCost> nodeCostQueue = new PriorityQueue<>(); PriorityQueue<NodeCost> getNodeCostQueue() {
return nodeCostQueue;
} @Override
protected void doVisitLeaf(RegressionTreeNode node, int depth) {
nodeCostQueue.add(new NodeCost(getNodeId(), 0, depth, -1, -1));
idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), 1));
} @Override
protected void doVisitBranch(RegressionTreeNode node, int depth) {
// calculate the cost of this node from children costs
final AbstractMap.SimpleEntry<Integer, Integer> rightIdCost = idCostStack.pop();
final AbstractMap.SimpleEntry<Integer, Integer> leftIdCost = idCostStack.pop();
final int cost = Math.max(leftIdCost.getValue(), rightIdCost.getValue()); nodeCostQueue.add(new NodeCost(getNodeId(), cost, depth, leftIdCost.getKey(), rightIdCost.getKey()));
idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), cost + 1));
}
} /**
* {@link Visitor} implementation for extracting leaves and bitvectors.
*/
private static class QuickScorerVisitor extends Visitor { private final int treeId;
private final int leafNum;
private final Set<Integer> leafIdSet;
private final Set<Integer> skipIdSet; private final Stack<Long> bitsStack = new Stack<>();
private final List<RegressionTreeNode> leafList = new ArrayList<>();
private final List<Condition> conditionList = new ArrayList<>(); private QuickScorerVisitor(int treeId, int leafNum, Set<Integer> leafIdSet, Set<Integer> skipIdSet) {
this.treeId = treeId;
this.leafNum = leafNum;
this.leafIdSet = leafIdSet;
this.skipIdSet = skipIdSet;
} List<RegressionTreeNode> getLeafList() {
return leafList;
} List<Condition> getConditionList() {
return conditionList;
} private long reverseBits(long bits) {
long revBits = 0L;
long mask = (1L << (leafNum - 1));
for (int i = 0; i < leafNum; ++i) {
if ((bits & mask) != 0L) revBits |= (1L << i);
mask >>>= 1;
}
return revBits;
} @Override
protected void doVisitLeaf(RegressionTreeNode node, int depth) {
if (skipIdSet.contains(getNodeId())) return; bitsStack.add(1L << leafList.size()); // we use rightmost bit for detecting leaf
leafList.add(node);
} @Override
protected void doVisitBranch(RegressionTreeNode node, int depth) {
if (skipIdSet.contains(getNodeId())) return; if (leafIdSet.contains(getNodeId())) {
// an endpoint of QuickScorer
doVisitLeaf(node, depth);
return;
} final long rightBits = bitsStack.pop(); // bits of false branch
final long leftBits = bitsStack.pop(); // bits of true branch
/*
* NOTE:
* forwardBitvector = ~leftBits
* backwardBitvector = ~(reverse(rightBits))
*/
conditionList.add(
new Condition(node.getFeatureIndex(), node.getThreshold(), treeId, ~leftBits, ~reverseBits(rightBits)));
bitsStack.add(leftBits | rightBits);
}
} public QuickScorerTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
List<Feature> allFeatures, Map<String, Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params);
} /**
* Set falseRadioDecay parameter of this model.
*
* @param falseRatioDecay decay parameter for updating falseRatio
*/
public void setFalseRatioDecay(float falseRatioDecay) {
this.falseRatioDecay = falseRatioDecay;
} /**
* @see #setFalseRatioDecay(float)
*/
public void setFalseRatioDecay(String falseRatioDecay) {
this.falseRatioDecay = Float.parseFloat(falseRatioDecay);
} /**
* {@inheritDoc}
*/
@Override
public void validate() throws ModelException {
// validate trees before initializing QuickScorer
super.validate(); // initialize QuickScorer with validated trees
init(getTrees());
} /**
* Initializes quick scorer with given trees.
* 利用给定的树集初始化快速打分模型
*
* @param trees base additive trees model
*/
private void init(List<RegressionTree> trees) {
this.treeNum = trees.size();
this.weights = new float[trees.size()];
this.leafOffsets = new int[trees.size() + 1];
this.leafOffsets[0] = 0; // re-create tree bitvectors
if (this.threadLocalTreeBitvectors != null) this.threadLocalTreeBitvectors.close();
this.threadLocalTreeBitvectors = new CloseableThreadLocal<long[]>() {
@Override
protected long[] initialValue() {
return new long[treeNum];
}
}; int treeId = 0;
List<RegressionTreeNode> leafList = new ArrayList<>();
List<Condition> conditionList = new ArrayList<>();
for (RegressionTree tree : trees) {
// select up to 64 leaves from given tree
QuickScorerVisitor visitor = fitLeavesTo64bits(treeId, tree); // extract leaves and conditions with selected leaf candidates
visitor.visit(tree);
leafList.addAll(visitor.getLeafList());
conditionList.addAll(visitor.getConditionList()); // update weight, offset and treeId
this.weights[treeId] = tree.getWeight();
this.leafOffsets[treeId + 1] = this.leafOffsets[treeId] + visitor.getLeafList().size();
++treeId;
} // remap list to array for performance reason
this.leaves = leafList.toArray(new RegressionTreeNode[0]); // sort conditions by ascent order of featureIndex and threshold
Collections.sort(conditionList); // remap information of conditions
int idx = 0;
int preFeatureIndex = -1;
this.condNum = conditionList.size();
this.thresholds = new float[conditionList.size()];
this.forwardBitVectors = new long[conditionList.size()];
this.backwardBitVectors = new long[conditionList.size()];
this.treeIds = new int[conditionList.size()];
List<Integer> featureIndexList = new ArrayList<>();
List<Integer> condOffsetList = new ArrayList<>();
for (Condition condition : conditionList) {
this.thresholds[idx] = condition.threshold;
this.forwardBitVectors[idx] = condition.getForwardBitvector();
this.backwardBitVectors[idx] = condition.getBackwardBitvector();
this.treeIds[idx] = condition.getTreeId(); if (preFeatureIndex != condition.getFeatureIndex()) {
featureIndexList.add(condition.getFeatureIndex());
condOffsetList.add(idx);
preFeatureIndex = condition.getFeatureIndex();
} ++idx;
}
condOffsetList.add(conditionList.size()); // guard this.featureIndexes = ArrayUtils.toPrimitive(featureIndexList.toArray(new Integer[0]));
this.condOffsets = ArrayUtils.toPrimitive(condOffsetList.toArray(new Integer[0]));
} /**
* Checks costs of all nodes and select leaves up to 64.
*
* <p>NOTE:
* We can use {@link java.util.BitSet} instead of {@code long} to represent bitvectors longer than 64bits.
* However, this modification caused performance degradation in our experiments, and we decided to use this form.
*
* @param treeId index of given regression tree
* @param tree target regression tree
* @return QuickScorerVisitor with proper id sets
*/
private QuickScorerVisitor fitLeavesTo64bits(int treeId, RegressionTree tree) {
// calculate costs of all nodes
NodeCostVisitor nodeCostVisitor = new NodeCostVisitor();
nodeCostVisitor.visit(tree); // poll zero cost nodes (i.e., real leaves)
Set<Integer> leafIdSet = new HashSet<>();
Set<Integer> skipIdSet = new HashSet<>();
while (!nodeCostVisitor.getNodeCostQueue().isEmpty()) {
if (nodeCostVisitor.getNodeCostQueue().peek().cost > 0) break;
NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
leafIdSet.add(nodeCost.id);
} // merge leaves until the number of leaves reaches 64
while (leafIdSet.size() > 64) {
final NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
assert nodeCost.left >= 0 && nodeCost.right >= 0; // update leaves
leafIdSet.remove(nodeCost.left);
leafIdSet.remove(nodeCost.right);
leafIdSet.add(nodeCost.id); // register previous leaves to skip ids
skipIdSet.add(nodeCost.left);
skipIdSet.add(nodeCost.right);
} return new QuickScorerVisitor(treeId, leafIdSet.size(), leafIdSet, skipIdSet);
} /**
* {@inheritDoc}
*/
@Override
public float score(float[] modelFeatureValuesNormalized) {
assert threadLocalTreeBitvectors != null;
long[] treeBitvectors = threadLocalTreeBitvectors.get();
Arrays.fill(treeBitvectors, MAX_BITS); int falseNum = 0;
float score = 0.0f;
if (falseRatio <= 0.5) {
// use forward bitvectors
for (int i = 0; i < condOffsets.length - 1; ++i) {
final int featureIndex = featureIndexes[i];
for (int j = condOffsets[i]; j < condOffsets[i + 1]; ++j) {
if (modelFeatureValuesNormalized[featureIndex] <= thresholds[j]) break;
treeBitvectors[treeIds[j]] &= forwardBitVectors[j];
++falseNum;
}
} for (int i = 0; i < leafOffsets.length - 1; ++i) {
final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
score += weights[i] * leaves[leafOffsets[i] + leafIdx].score(modelFeatureValuesNormalized);
}
} else {
// use backward bitvectors
falseNum = condNum;
for (int i = 0; i < condOffsets.length - 1; ++i) {
final int featureIndex = featureIndexes[i];
for (int j = condOffsets[i + 1] - 1; j >= condOffsets[i]; --j) {
if (modelFeatureValuesNormalized[featureIndex] > thresholds[j]) break;
treeBitvectors[treeIds[j]] &= backwardBitVectors[j];
--falseNum;
}
} for (int i = 0; i < leafOffsets.length - 1; ++i) {
final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
score += weights[i] * leaves[leafOffsets[i + 1] - 1 - leafIdx].score(modelFeatureValuesNormalized);
}
} // update false ratio
falseRatio = falseRatio * falseRatioDecay + (falseNum * 1.0f / condNum) * (1.0f - falseRatioDecay);
return score;
} }
 import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.request.SolrQueryRequest;
import org.junit.Ignore;
import org.junit.Test; import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random; import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat; public class TestQuickScorerTreesModelBenchmark { /**
* 产生特征
* @param featureNum 特征个数
* @return
*/
private List<Feature> createDummyFeatures(int featureNum) {
List<Feature> features = new ArrayList<>();
for (int i = 0; i < featureNum; ++i) {
features.add(new Feature("fv_" + i, null) {
@Override
protected void validate() throws FeatureException { } @Override
public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, SolrQueryRequest request,
Query originalQuery, Map<String, String[]> efi) throws IOException {
return null;
} @Override
public LinkedHashMap<String, Object> paramsToMap() {
return null;
}
});
}
return features;
} private List<Normalizer> createDummyNormalizer(int featureNum) {
List<Normalizer> normalizers = new ArrayList<>();
for (int i = 0; i < featureNum; ++i) {
normalizers.add(new IdentityNormalizer());
}
return normalizers;
} /**
* 创建单棵树
* 递归调用自己
* @param leafNum 叶子个数
* @param features 特征
* @param rand 产生随机数
* @return
*/
private Map<String, Object> createRandomTree(int leafNum, List<Feature> features, Random rand) {
Map<String, Object> node = new HashMap<>();
if (leafNum == 1) {
// leaf
node.put("value", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
return node;
} // branch
node.put("feature", features.get(rand.nextInt(features.size())).getName());
node.put("threshold", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
node.put("left", createRandomTree(leafNum / 2, features, rand));
node.put("right", createRandomTree(leafNum - leafNum / 2, features, rand));
return node;
} /**
* 这里随机创建多棵树作为model测试
* @param treeNum 树的个数
* @param leafNum 叶子个数
* @param features 特征
* @param rand 产生随机数
* @return
*/
private List<Object> createRandomMultipleAdditiveTrees(int treeNum, int leafNum, List<Feature> features,
Random rand) {
List<Object> trees = new ArrayList<>();
for (int i = 0; i < treeNum; ++i) {
Map<String, Object> tree = new HashMap<>();
tree.put("weight", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5) 设置每棵树的学习率
tree.put("root", createRandomTree(leafNum, features, rand));
trees.add(tree);
}
return trees;
} /**
* 对比两个打分模型的分值是否一致
* @param featureNum 特征个数
* @param treeNum 树个数
* @param leafNum 叶子个数
* @param loopNum 样本个数
* @throws Exception
*/
private void compareScore(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
Random rand = new Random(0); List<Feature> features = createDummyFeatures(featureNum); //产生特征
List<Normalizer> norms = createDummyNormalizer(featureNum); //标准化 for (int i = 0; i < loopNum; ++i) {
List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand); MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
"dummy", features, null);
matModel.setTrees(trees);
matModel.validate(); QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
null);
qstModel.setTrees(trees);//设置提供的树模型
qstModel.validate();//对提供的树结构进行验证 float[] featureValues = new float[featureNum];
for (int j = 0; j < 100; ++j) {
for (int k = 0; k < featureNum; ++k) featureValues[k] = rand.nextFloat() - 0.5f; // [-0.5, 0.5) float expected = matModel.score(featureValues);
float actual = qstModel.score(featureValues);
assertThat(actual, is(expected));
//System.out.println("expected: " + expected + " actual: " + actual);
}
}
} /**
* 两个模型是否得分一致
*
* @throws Exception thrown if testcase failed to initialize models
*/
/*@Test
public void testAccuracy() throws Exception {
compareScore(25, 200, 32, 100);
//compareScore(19, 500, 31, 10000);
}*/ /**
* 对比两个打分模型打分的时间消耗
* @param featureNum 特征个数
* @param treeNum 树个数
* @param leafNum 叶子个数
* @param loopNum 样本个数
* @throws Exception
*/
private void compareTime(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
Random rand = new Random(0); //随机产生features
List<Feature> features = createDummyFeatures(featureNum);
//随机产生normalizer
List<Normalizer> norms = createDummyNormalizer(featureNum);
//随机创建trees
List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand); //初始化multiple additive trees model
MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
"dummy", features, null);
matModel.setTrees(trees);
matModel.validate(); //初始化quick scorer trees model
QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
null);
qstModel.setTrees(trees);
qstModel.validate(); //随机产生样本, loopNum * featureNum
float[][] featureValues = new float[loopNum][featureNum];
for (int i = 0; i < loopNum; ++i) {
for (int k = 0; k < featureNum; ++k) {
featureValues[i][k] = rand.nextFloat() * 2.0f - 1.0f; // [-1.0, 1.0)
}
} long start;
/*long matOpNsec = 0;
for (int i = 0; i < loopNum; ++i) {
start = System.nanoTime();
matModel.score(featureValues[i]);
matOpNsec += System.nanoTime() - start;
}
long qstOpNsec = 0;
for (int i = 0; i < loopNum; ++i) {
start = System.nanoTime();
qstModel.score(featureValues[i]);
qstOpNsec += System.nanoTime() - start;
}
System.out.println("MultipleAdditiveTreesModel : " + matOpNsec / 1000.0 / loopNum + " usec/op");
System.out.println("QuickScorerTreesModel : " + qstOpNsec / 1000.0 / loopNum + " usec/op");*/ long matOpNsec = 0;
start = System.currentTimeMillis();
for(int i = 0; i < loopNum; i++) {
matModel.score(featureValues[i]);
}
matOpNsec = System.currentTimeMillis() - start; long qstOpNsec = 0;
start = System.currentTimeMillis();
for(int i = 0; i < loopNum; i++) {
qstModel.score(featureValues[i]);
}
qstOpNsec = System.currentTimeMillis() - start; System.out.println("MultipleAdditiveTreesModel : " + matOpNsec); System.out.println("QuickScorerTreesModel : " + qstOpNsec); //assertThat(matOpNsec > qstOpNsec, is(true));
} /**
* 测试性能
* @throws Exception thrown if testcase failed to initialize models
*/ @Test
public void testPerformance() throws Exception {
//features,trees,leafs,samples
compareTime(20, 500, 61, 10000);
} }
05-16 04:27