我以最容易想到的方式将Java的apache.commons.math的this spline interpolation algorithm转换为Scala(请参见下文)。我最终获得的功能的运行速度比原始Java代码慢2至3倍。我的猜测是问题出在Array.fill调用带来的额外循环中,但我想不出一种简单的方法来摆脱它们。关于如何使此代码性能更好的任何建议? (以更简洁和/或更实用的方式编写它也将是一件好事-在这方面的建议也将受到赞赏。)

type Real = Double

def mySplineInterpolate(x: Array[Real], y: Array[Real]) = {

  if (x.length != y.length)
    throw new DimensionMismatchException(x.length, y.length)

  if (x.length < 3)
    throw new NumberIsTooSmallException(x.length, 3, true)

  // Number of intervals.  The number of data points is n + 1.
  val n = x.length - 1

  // Differences between knot points
  val h = Array.tabulate(n)(i => x(i+1) - x(i))

  var mu: Array[Real] = Array.fill(n)(0)
  var z: Array[Real] = Array.fill(n+1)(0)
  var i = 1
  while (i < n) {
    val g = 2.0 * (x(i+1) - x(i-1)) - h(i-1) * mu(i-1)
    mu(i) = h(i) / g
    z(i) = (3.0 * (y(i+1) * h(i-1) - y(i) * (x(i+1) - x(i-1))+ y(i-1) * h(i)) /
            (h(i-1) * h(i)) - h(i-1) * z(i-1)) / g
    i += 1
  }

  // cubic spline coefficients --  b is linear, c quadratic, d is cubic (original y's are constants)
  var b: Array[Real] = Array.fill(n)(0)
  var c: Array[Real] = Array.fill(n+1)(0)
  var d: Array[Real] = Array.fill(n)(0)

  var j = n-1
  while (j >= 0) {
    c(j) = z(j) - mu(j) * c(j + 1)
    b(j) = (y(j+1) - y(j)) / h(j) - h(j) * (c(j+1) + 2.0 * c(j)) / 3.0
    d(j) = (c(j+1) - c(j)) / (3.0 * h(j))
    j -= 1
  }

  Array.tabulate(n)(i => Polynomial(Array(y(i), b(i), c(i), d(i))))
}

最佳答案

您可以摆脱所有Array.fill,因为新数组始终以0或null初始化,具体取决于它是值还是引用(布尔值用false初始化,字符用\0初始化)。

您也许可以通过压缩数组来简化循环,但是只会使其变慢。函数编程(无论如何在JVM上)都可以帮助您更快地实现这一目标的唯一方法是,如果将其设为非严格(例如,使用Stream或视图),然后继续使用而不是全部使用它。

关于java - Scala vs Java中的样条插值性能,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/14330737/

10-12 00:13
查看更多