当与PySpark进行乘法运算时,似乎PySpark失去了精度。
例如,当多个两个精度为38,10的小数时,它将返回38,6并四舍五入为三个小数,这是不正确的结果。
from decimal import Decimal
from pyspark.sql.types import DecimalType, StructType, StructField
schema = StructType([StructField("amount", DecimalType(38,10)), StructField("fx", DecimalType(38,10))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)
df.printSchema()
df = df.withColumn("amount_usd", df.amount * df.fx)
df.printSchema()
df.show()
结果
>>> df.printSchema()
root
|-- amount: decimal(38,10) (nullable = true)
|-- fx: decimal(38,10) (nullable = true)
|-- amount_usd: decimal(38,6) (nullable = true)
>>> df = df.withColumn("amount_usd", df.amount * df.fx)
>>> df.printSchema()
root
|-- amount: decimal(38,10) (nullable = true)
|-- fx: decimal(38,10) (nullable = true)
|-- amount_usd: decimal(38,6) (nullable = true)
>>> df.show()
+--------------+------------+----------+
| amount| fx|amount_usd|
+--------------+------------+----------+
|233.0000000000|1.1403218880|265.695000|
+--------------+------------+----------+
这是一个错误吗?有没有办法得到正确的结果?
最佳答案
我认为这是预期的行为。
Spark的Catalyst引擎将以输入语言(例如Python)编写的表达式转换为相同类型信息的Spark内部Catalyst表示形式。然后它将在该内部表示上运行。
如果您在spark's source code中检查文件sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
,它通常用于:
和
* In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2
* respectively, then the following operations have the following precision / scale:
* Operation Result Precision Result Scale
* ------------------------------------------------------------------------
* e1 * e2 p1 + p2 + 1 s1 + s2
现在让我们看一下乘法的代码。调用
adjustPrecisionScale
函数的位置: case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
} else {
DecimalType.bounded(p1 + p2 + 1, s1 + s2)
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType, nullOnOverflow)
adjustPrecisionScale
是发生魔术的地方,我在此处粘贴了function,以便您可以看到逻辑 private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
// Assumption:
assert(precision >= scale)
if (precision <= MAX_PRECISION) {
// Adjustment only needed when we exceed max precision
DecimalType(precision, scale)
} else if (scale < 0) {
// Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
// loss since we would cause a loss of digits in the integer part.
// In this case, we are likely to meet an overflow.
DecimalType(MAX_PRECISION, scale)
} else {
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
val intDigits = precision - scale
// If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
// preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
// The resulting scale is the maximum between what is available without causing a loss of
// digits for the integer part of the decimal and the minimum guaranteed scale, which is
// computed above
val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)
DecimalType(MAX_PRECISION, adjustedScale)
}
}
现在让我们来看您的示例,我们有
e1 = Decimal(233.00)
e2 = Decimal(1.1403218880)
每个都有
precision = 38
,scale = 10
,所以p1=p2=38
和s1=s2=10
。这两个的乘积应具有precision = p1+p2+1 = 77
和scale = s1 + s2 = 20
请注意,此处是
MAX_PRECISION=38
和MINIMUM_ADJUSTED_SCALE=6
。所以
p1+p2+1=77 > 38
,val intDigits = precision - scale = 77 - 20 = 57
minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) = min(20, 6) = 6
adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) = max(38-57, 6)=6
最后,返回带有
precision=38, and scale = 6
的DecimalType。这就是为什么您看到amount_usd
的类型为decimal(38,6)
的原因。并且在
Multiply
函数中,两个数字在进行乘法运算之前都已转换为DecimalType(38,6)
。如果您使用
Decimal(38,6)
运行代码,即schema = StructType([StructField("amount", DecimalType(38,6)), StructField("fx", DecimalType(38,6))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)
你会得到
+----------+--------+----------+
|amount |fx |amount_usd|
+----------+--------+----------+
|233.000000|1.140322|265.695026|
+----------+--------+----------+
为什么最终的号码是
265.695000
?这可能是由于Multiply
函数中的其他调整所致。但是你明白了。从
Multiply
代码中,您可以看到我们想避免在进行乘法运算时使用最大精度,如果我们改为18schema = StructType([StructField("amount", DecimalType(18,10)), StructField("fx", DecimalType(18,10))])
我们得到这个:
+--------------+------------+------------------------+
|amount |fx |amount_usd |
+--------------+------------+------------------------+
|233.0000000000|1.1403218880|265.69499990400000000000|
+--------------+------------+------------------------+
我们可以更好地近似于python计算的结果:
265.6949999039999754657515041
希望这可以帮助!
关于python - PySpark; DecimalType乘法精度损失,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/57965426/