Spark | ML | Random Forest | Load a trained model from the .txt output of RandomForestClassificationModel.toDebugString

Problem description

Using Spark 1.6 and the ML library, I am saving the output of a trained RandomForestClassificationModel with toDebugString():

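 val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
 val stringModel = rfModel.toDebugString
 // save stringModel into a .txt file on the driver (the file name is just an example)
 java.nio.file.Files.write(java.nio.file.Paths.get("rfModel.txt"), stringModel.getBytes(java.nio.charset.StandardCharsets.UTF_8))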

My idea is to read this .txt file later and load the trained random forest from it. Is that possible?

Thanks!

Recommended answer

At least for Spark 2.1.0, you can do this with the following Java code (sorry, no Scala). However, relying on an undocumented format that may change without notice is probably not the smartest idea.

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.net.URL;
import java.util.*;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.nio.charset.StandardCharsets.US_ASCII;

/**
 * RandomForest.
 */
public abstract class RandomForest {

    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);

    protected final List<Node> trees = new ArrayList<>();

    /**
     * @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
     * @throws IOException
     */
    public RandomForest(final URL model) throws IOException {
        try (final BufferedReader reader = new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
            Node node;
            while ((node = load(reader)) != null) {
                trees.add(node);
            }
        }
        if (trees.isEmpty()) throw new IOException("Failed to read trees from " + model);
        if (LOG.isDebugEnabled()) LOG.debug("Found " + trees.size() + " trees.");
    }

    private static Node load(final BufferedReader reader) throws IOException {
        final Pattern ifPattern = Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");
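        // also accept negative leaf values (e.g. regression) and exponents such as 1.0E7 or 2.5E-3
        final Pattern predictPattern = Pattern.compile("Predict: (-?\\d+(?:\\.\\d+)?(?:E-?\\d+)?)");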
        Node root = null;
        final List<Node> stack = new ArrayList<>();
        String line;
        while ((line = reader.readLine()) != null) {
            final String trimmed = line.trim();
            //System.out.println(trimmed);
            if (trimmed.startsWith("RandomForest")) {
                // skip the header line of the first tree ("Tree 0 ...")
                reader.readLine();
            } else if (trimmed.startsWith("Tree")) {
                break;
            } else if (trimmed.startsWith("If")) {
                // extract feature index
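                // extract feature index, operator and operand
                final Matcher m = ifPattern.matcher(trimmed);
                if (!m.matches()) throw new IOException("Unexpected split line: " + trimmed);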
                final int featureIndex = Integer.parseInt(m.group(1));
                final String operator = m.group(2);
                final String operand = m.group(3);
                final Predicate<Float> predicate;
                if ("<=".equals(operator)) {
                    predicate = new LessOrEqual(Float.parseFloat(operand));
                } else if (">".equals(operator)) {
                    predicate = new Greater(Float.parseFloat(operand));
                } else if ("in".equals(operator)) {
                    predicate = new In(parseFloatArray(operand));
                } else if ("not in".equals(operator)) {
                    predicate = new NotIn(parseFloatArray(operand));
                } else {
                    throw new IOException("Unknown operator '" + operator + "' in: " + trimmed);
                }
                final Node node = new Node(featureIndex, predicate);

                if (stack.isEmpty()) {
                    root = node;
                } else {
                    insert(stack, node);
                }
                stack.add(node);
            } else if (trimmed.startsWith("Predict")) {
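                final Matcher m = predictPattern.matcher(trimmed);
                if (!m.matches()) throw new IOException("Unexpected leaf line: " + trimmed);
                // a tree consisting of a single leaf has no parent node to attach the value to
                if (stack.isEmpty()) throw new IOException("Single-leaf trees are not supported: " + trimmed);
                final Object node = Float.parseFloat(m.group(1));
                insert(stack, node);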
            }
        }
        return root;
    }

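    /**
     * Attaches a node (a Node or a Float leaf value) to the deepest node on the stack that
     * still has a free child slot, popping nodes whose children are both set. This relies on
     * toDebugString() listing each tree in pre-order (node, left subtree, right subtree).
     */
    private static void insert(final List<Node> stack, final Object node) {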
        Node parent = stack.get(stack.size() - 1);
        while (parent.getLeftChild() != null && parent.getRightChild() != null) {
            stack.remove(stack.size() - 1);
            parent = stack.get(stack.size() - 1);
        }
        if (parent.getLeftChild() == null) parent.setLeftChild(node);
        else parent.setRightChild(node);
    }

    private static float[] parseFloatArray(final String set) {
        final StringTokenizer st = new StringTokenizer(set, "{,}");
        final float[] floats = new float[st.countTokens()];
        for (int i=0; st.hasMoreTokens(); i++) {
            floats[i] = Float.parseFloat(st.nextToken());
        }
        return floats;
    }

    public abstract float predict(final float[] features);

    public String toDebugString() {
        try {
            final StringWriter sw = new StringWriter();
            for (int i=0; i<trees.size(); i++) {
                sw.write("Tree " + i + ":\n");
                print(sw, "", trees.get(i));
            }
            return sw.toString();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    private static void print(final Writer w, final String indent, final Object object) throws IOException {
        if (object instanceof Number) {
            w.write(indent + "Predict: " + object + "\n");
        } else if (object instanceof Node) {
            final Node node = (Node) object;
            // this node, then the left subtree, then "Else" and the right subtree
            w.write(indent + node + "\n");
            print(w, indent + " ", node.getLeftChild());
            w.write(indent + "Else\n");
            print(w, indent + " ", node.getRightChild());
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
    }

    /**
     * Node.
     */
    protected static class Node {

        private final int featureIndex;
        private final Predicate<Float> predicate;
        private Object leftChild;
        private Object rightChild;

        public Node(final int featureIndex, final Predicate<Float> predicate) {
            Objects.requireNonNull(predicate);
            this.featureIndex = featureIndex;
            this.predicate = predicate;
        }

        public void setLeftChild(final Object leftChild) {
            this.leftChild = leftChild;
        }

        public void setRightChild(final Object rightChild) {
            this.rightChild = rightChild;
        }

        public Object getLeftChild() {
            return leftChild;
        }

        public Object getRightChild() {
            return rightChild;
        }

        public Object eval(final float[] features) {
            Object result = this;
            do {
                final Node node = (Node)result;
                result = node.predicate.test(features[node.featureIndex]) ? node.leftChild : node.rightChild;
            } while (result instanceof Node);

            return result;
        }

        @Override
        public String toString() {
            return "If (feature " + featureIndex + " " + predicate + ")";
        }

    }

    private static class LessOrEqual implements Predicate<Float> {
        private final float value;

        public LessOrEqual(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f <= value;
        }

        @Override
        public String toString() {
            return "<= " + value;
        }
    }

    private static class Greater implements Predicate<Float> {
        private final float value;

        public Greater(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f > value;
        }

        @Override
        public String toString() {
            return "> " + value;
        }
    }

    private static class In implements Predicate<Float> {
        private final float[] array;

        public In(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i=0; i<array.length; i++) {
                if (array[i] == f) return true;
            }
            return false;
        }

        @Override
        public String toString() {
            return "in " + Arrays.toString(array);
        }
    }

    private static class NotIn implements Predicate<Float> {
        private final float[] array;

        public NotIn(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i=0; i<array.length; i++) {
                if (array[i] == f) return false;
            }
            return true;
        }

        @Override
        public String toString() {
            return "not in " + Arrays.toString(array);
        }
    }
}
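The core trick in load() and insert() is that toDebugString() prints each tree in pre-order, so a stack of open split nodes is enough to rebuild it. Below is a stripped-down, self-contained sketch of that idea. It only handles numeric `<=` splits (not the categorical `in` / `not in` ones), the class and method names are made up for illustration, and the sample tree text is hand-written to mimic the layout rather than captured from Spark. The leaf pattern accepting negative values and exponents is an assumption about what regression leaves can look like.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class DebugTreeParserDemo {

    // Same split-line shape as the answer's ifPattern, restricted here to numeric "<=" splits.
    static final Pattern IF = Pattern.compile("If \\(feature (\\d+) <= (\\S+)\\)");
    // Leaf values may be negative or use an exponent (assumption, e.g. for regression).
    static final Pattern PREDICT = Pattern.compile("Predict: (-?\\d+(?:\\.\\d+)?(?:E-?\\d+)?)");

    /** Internal node: go left when features[feature] <= threshold. Children are Split or Double. */
    static final class Split {
        final int feature;
        final double threshold;
        Object left;
        Object right;

        Split(int feature, double threshold) {
            this.feature = feature;
            this.threshold = threshold;
        }
    }

    /** Rebuilds one tree from its pre-order debug text; "Else" lines carry no extra information. */
    static Split parse(String text) throws IOException {
        Split root = null;
        List<Split> stack = new ArrayList<>();
        BufferedReader reader = new BufferedReader(new StringReader(text));
        String line;
        while ((line = reader.readLine()) != null) {
            String trimmed = line.trim();
            Object node;
            Matcher m;
            if ((m = IF.matcher(trimmed)).matches()) {
                node = new Split(Integer.parseInt(m.group(1)), Double.parseDouble(m.group(2)));
            } else if ((m = PREDICT.matcher(trimmed)).matches()) {
                node = Double.parseDouble(m.group(1));
            } else {
                continue; // "Else (...)" and blank lines
            }
            if (stack.isEmpty()) {
                if (!(node instanceof Split)) throw new IOException("single-leaf tree: " + trimmed);
                root = (Split) node;
            } else {
                // Attach to the deepest split that still has a free child slot.
                Split parent = stack.get(stack.size() - 1);
                while (parent.left != null && parent.right != null) {
                    stack.remove(stack.size() - 1);
                    parent = stack.get(stack.size() - 1);
                }
                if (parent.left == null) parent.left = node; else parent.right = node;
            }
            if (node instanceof Split) stack.add((Split) node);
        }
        return root;
    }

    /** Walks from the root to a leaf and returns the leaf value. */
    static double eval(Split root, double[] features) {
        Object current = root;
        while (current instanceof Split) {
            Split s = (Split) current;
            current = features[s.feature] <= s.threshold ? s.left : s.right;
        }
        return (Double) current;
    }

    public static void main(String[] args) throws IOException {
        // Hand-written tree in the debug-string layout (hypothetical, not real Spark output).
        String text = String.join("\n",
                "If (feature 0 <= 0.5)",
                " If (feature 1 <= 2.0)",
                "  Predict: 0.0",
                " Else (feature 1 > 2.0)",
                "  Predict: 1.0",
                "Else (feature 0 > 0.5)",
                " Predict: -2.5E-3");
        Split root = parse(text);
        System.out.println(eval(root, new double[] {0.2, 1.0}));
        System.out.println(eval(root, new double[] {0.2, 3.0}));
        System.out.println(eval(root, new double[] {0.9, 0.0}));
    }
}
```

The "Else" lines can be skipped because, in pre-order, the right subtree always starts right after the left subtree has been completed, which is exactly when the pop loop finds the parent's right slot still free.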

To use the class for classification, use:

import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;

/**
 * RandomForestClassifier.
 */
public class RandomForestClassifier extends RandomForest {

    public RandomForestClassifier(final URL model) throws IOException {
        super(model);
    }

    @Override
    public float predict(final float[] features) {
        final Map<Object, Integer> counts = new HashMap<>();
        trees.stream().map(node -> node.eval(features))
                .forEach(result -> {
                    Integer count = counts.get(result);
                    if (count == null) {
                        counts.put(result, 1);
                    } else {
                        counts.put(result, count + 1);
                    }
                });
        return (Float)counts.entrySet()
                .stream()
                .sorted((o1, o2) -> Integer.compare(o2.getValue(), o1.getValue()))
                .map(Map.Entry::getKey)
                .findFirst().get();
    }
}
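predict() above takes a hard majority vote over the leaf labels returned by the trees. When two labels tie, the winner depends on HashMap iteration order, so it is not well defined. A variant with a deterministic tie-break could look like the sketch below; preferring the smallest label is an arbitrary choice made here for illustration, and the class name is hypothetical. Also note that, as far as I know, Spark's own RandomForestClassificationModel averages per-leaf class probabilities rather than counting hard votes, and those class counts are not part of the toDebugString() output, so predictions from this re-implementation can occasionally differ from Spark's.

```java
import java.util.Map;
import java.util.TreeMap;

public class MajorityVote {

    /**
     * Hard majority vote over per-tree predicted labels. Ties are broken in favour of the
     * smallest label (an arbitrary but deterministic choice).
     */
    static float vote(float[] treeLabels) {
        if (treeLabels.length == 0) throw new IllegalArgumentException("no trees");
        // TreeMap iterates labels in ascending order, which makes the tie-break deterministic.
        Map<Float, Integer> counts = new TreeMap<>();
        for (float label : treeLabels) {
            counts.merge(label, 1, Integer::sum);
        }
        float best = Float.NaN;
        int bestCount = -1;
        for (Map.Entry<Float, Integer> e : counts.entrySet()) {
            // Strictly greater: a later label with an equal count does not replace the winner.
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(vote(new float[] {1f, 0f, 1f, 2f})); // label 1 has the most votes
        System.out.println(vote(new float[] {2f, 0f, 2f, 0f})); // tie between 0 and 2
    }
}
```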

For regression:

import java.io.IOException;
import java.net.URL;

/**
 * RandomForestRegressor.
 */
public class RandomForestRegressor extends RandomForest {

    public RandomForestRegressor(final URL model) throws IOException {
        super(model);
    }

    @Override
    public float predict(final float[] features) {
        return (float)trees
                .stream()
                .mapToDouble(node -> ((Number)node.eval(features)).doubleValue())
                .average()
                .getAsDouble();
    }
}
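The constructors take a java.net.URL, while the question writes a plain .txt file on the driver. A local path can be turned into a file: URL with Path.toUri().toURL(). The sketch below round-trips a temporary file through URL.openStream(), the same call the RandomForest constructor uses; the class name and file content are placeholders. With the answer's classes on the classpath, the same URL would be passed to new RandomForestClassifier(url) or new RandomForestRegressor(url).

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

public class LocalModelUrl {

    /** Turns a local file path into the URL form the RandomForest constructors expect. */
    static URL toUrl(Path file) throws IOException {
        return file.toUri().toURL();
    }

    /** Reads the content back through URL.openStream(), line by line, as the parser does. */
    static String readAll(URL url) throws IOException {
        StringBuilder sb = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(url.openStream(), StandardCharsets.US_ASCII))) {
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line).append('\n');
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) throws IOException {
        // Stand-in for the .txt written by the question's code (placeholder content).
        Path tmp = Files.createTempFile("rf-model", ".txt");
        Files.write(tmp, "Predict: 1.0\n".getBytes(StandardCharsets.US_ASCII));
        URL url = toUrl(tmp);
        System.out.println(url.getProtocol());
        System.out.print(readAll(url));
        Files.delete(tmp);
    }
}
```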

07-25 12:06