PySpark seems to lose precision when doing decimal multiplication.

For example, when multiplying two decimals of precision (38,10), it returns (38,6) and rounds to three decimal places, which is an incorrect result.

from decimal import Decimal
from pyspark.sql import SparkSession
from pyspark.sql.types import DecimalType, StructType, StructField

# `spark` is predefined in the PySpark shell; when running as a script, create it:
spark = SparkSession.builder.getOrCreate()

schema = StructType([StructField("amount", DecimalType(38,10)), StructField("fx", DecimalType(38,10))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)

df.printSchema()
df = df.withColumn("amount_usd", df.amount * df.fx)
df.printSchema()
df.show()

Result:

>>> df.printSchema()
root
 |-- amount: decimal(38,10) (nullable = true)
 |-- fx: decimal(38,10) (nullable = true)

>>> df = df.withColumn("amount_usd", df.amount * df.fx)
>>> df.printSchema()
root
 |-- amount: decimal(38,10) (nullable = true)
 |-- fx: decimal(38,10) (nullable = true)
 |-- amount_usd: decimal(38,6) (nullable = true)

>>> df.show()
+--------------+------------+----------+
|        amount|          fx|amount_usd|
+--------------+------------+----------+
|233.0000000000|1.1403218880|265.695000|
+--------------+------------+----------+


Is this a bug? Is there a way to get the correct result?

Best Answer

I believe this is expected behavior.

Spark's Catalyst engine converts an expression written in an input language (e.g. Python) into Spark's internal Catalyst representation carrying the same type information, and then operates on that internal representation.

If you check the file sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala in Spark's source code, you will find this comment:

 * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
 * respectively, then the following operations have the following precision / scale:
 *   Operation    Result Precision                        Result Scale
 *   ------------------------------------------------------------------------
 *   e1 * e2      p1 + p2 + 1                             s1 + s2
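
So for multiplication the result type follows mechanically from the input types. As a quick sanity check, here is a minimal Python sketch of that rule (the helper name is my own, not a Spark API):

def multiply_result_type(p1, s1, p2, s2):
    # Rule from DecimalPrecision.scala: e1 * e2 -> precision p1+p2+1, scale s1+s2
    return p1 + p2 + 1, s1 + s2

print(multiply_result_type(38, 10, 38, 10))  # (77, 20) -- well past MAX_PRECISION = 38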

Now let's look at the code for multiplication, where the adjustPrecisionScale function is called:
    case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
      val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
        DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
      } else {
        DecimalType.bounded(p1 + p2 + 1, s1 + s2)
      }
      val widerType = widerDecimalType(p1, s1, p2, s2)
      CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
        resultType, nullOnOverflow)
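
One thing worth noting in this snippet is the decimalOperationsAllowPrecisionLoss branch: the lossy adjustment only happens when the SQL config spark.sql.decimalOperations.allowPrecisionLoss is true, which is the default. On Spark 2.3+ you should be able to turn it off, in which case Spark uses DecimalType.bounded and keeps the full scale, at the price of returning null when the value overflows:

# Sketch, assuming Spark 2.3+: keep scale s1+s2 instead of adjusting it;
# values that no longer fit become null instead of being rounded
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
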
adjustPrecisionScale is where the magic happens; I've pasted the function here so you can see the logic:
  private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
    // Assumption:
    assert(precision >= scale)

    if (precision <= MAX_PRECISION) {
      // Adjustment only needed when we exceed max precision
      DecimalType(precision, scale)
    } else if (scale < 0) {
      // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
      // loss since we would cause a loss of digits in the integer part.
      // In this case, we are likely to meet an overflow.
      DecimalType(MAX_PRECISION, scale)
    } else {
      // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
      val intDigits = precision - scale
      // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
      // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
      val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
      // The resulting scale is the maximum between what is available without causing a loss of
      // digits for the integer part of the decimal and the minimum guaranteed scale, which is
      // computed above
      val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)

      DecimalType(MAX_PRECISION, adjustedScale)
    }
  }
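
If you want to experiment with this logic outside of Spark, here is a rough line-by-line Python port of the function above (my own translation for experimentation, not part of any Spark API):

MAX_PRECISION = 38
MINIMUM_ADJUSTED_SCALE = 6

def adjust_precision_scale(precision, scale):
    # Python port of DecimalType.adjustPrecisionScale
    assert precision >= scale
    if precision <= MAX_PRECISION:
        # Adjustment only needed when we exceed max precision
        return precision, scale
    elif scale < 0:
        # Negative scale: no precision loss allowed, overflow is likely
        return MAX_PRECISION, scale
    else:
        int_digits = precision - scale
        min_scale_value = min(scale, MINIMUM_ADJUSTED_SCALE)
        adjusted_scale = max(MAX_PRECISION - int_digits, min_scale_value)
        return MAX_PRECISION, adjusted_scale

print(adjust_precision_scale(77, 20))  # (38, 6)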


Now let's apply this to your example, where we have
e1 = Decimal(233.00)
e2 = Decimal(1.1403218880)

Each has precision = 38 and scale = 10, so p1 = p2 = 38 and s1 = s2 = 10. The product of the two should have precision = p1 + p2 + 1 = 77 and scale = s1 + s2 = 20.
Note that here MAX_PRECISION = 38 and MINIMUM_ADJUSTED_SCALE = 6.

Since p1 + p2 + 1 = 77 > 38, the adjustment branch runs:

intDigits = precision - scale = 77 - 20 = 57
minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) = min(20, 6) = 6
adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) = max(38 - 57, 6) = 6

Finally, a DecimalType with precision = 38 and scale = 6 is returned. That is why you see amount_usd typed as decimal(38,6).

And in the Multiply case above, both operands are first promoted to the wider of the two input types, the multiplication is performed at that type, and CheckOverflow then casts the result to the adjusted DecimalType(38,6).

If you run the code with DecimalType(38,6) instead, i.e.
schema = StructType([StructField("amount", DecimalType(38,6)), StructField("fx", DecimalType(38,6))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)

you get
+----------+--------+----------+
|amount    |fx      |amount_usd|
+----------+--------+----------+
|233.000000|1.140322|265.695026|
+----------+--------+----------+
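
The difference comes from the inputs themselves: with scale 6 they are rounded as they are loaded, and the product of the rounded values is exact. A quick reproduction in plain Python (my own check, not Spark code):

from decimal import Decimal

amount = Decimal("233.000000")
fx = Decimal("1.140322")   # 1.1403218880 rounded to scale 6 on ingestion
print(amount * fx)         # 265.695026000000, which Spark displays as 265.695026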

So why is the final number in your original run 265.695000 rather than 265.695026? There the inputs kept their full (38,10) values, so the product 265.6949999040... was computed first and only then rounded by CheckOverflow to the adjusted scale of 6, giving 265.695000.
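
You can reproduce that final rounding step with Python's decimal module (as far as I can tell, Spark rounds HALF_UP when casting to the adjusted type; treat the rounding mode here as an assumption):

from decimal import Decimal, ROUND_HALF_UP

exact = Decimal("233.0000000000") * Decimal("1.1403218880")
print(exact)  # 265.69499990400000000000
print(exact.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP))  # 265.695000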

From the Multiply code you can see that we want to avoid exceeding the maximum precision during multiplication. If we switch the precision to 18 instead, i.e.
schema = StructType([StructField("amount", DecimalType(18,10)), StructField("fx", DecimalType(18,10))])


then p1 + p2 + 1 = 37 <= MAX_PRECISION, no adjustment is applied, and the result keeps the full scale of s1 + s2 = 20. We get this:
+--------------+------------+------------------------+
|amount        |fx          |amount_usd              |
+--------------+------------+------------------------+
|233.0000000000|1.1403218880|265.69499990400000000000|
+--------------+------------+------------------------+

This is a much closer approximation to the result computed in plain Python:
265.6949999039999754657515041
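
For reference, that value comes from multiplying the two Decimals directly. Note that the question constructs them from floats, and the binary float 1.1403218880 is not exactly representable, which is where the noisy ...754657... tail comes from:

from decimal import Decimal

# Decimal(1.1403218880) converts the *float*, dragging in binary representation error
print(Decimal(233.00) * Decimal(1.1403218880))
# 265.6949999039999754657515041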

Hope this helps!

Regarding "python - PySpark; DecimalType multiplication precision loss", see the similar question on Stack Overflow: https://stackoverflow.com/questions/57965426/
