嗨,大家好,我加载了一个csv作为dataframe,我想将所有列都强制转换为float,因为知道文件很大,可以写入所有列的名称

val spark = SparkSession.builder.master("local").appName("my-spark-app").getOrCreate()
val df = spark.read.option("header",true).option("inferSchema", "true").csv("C:/Users/mhattabi/Desktop/dataTest2.csv")


任何帮助将不胜感激谢谢

最佳答案

以此DataFrame为例:

val df = sqlContext.createDataFrame(Seq(("0", 0),("1", 1),("2", 0))).toDF("id", "c0")


与模式:

StructType(
    StructField(id,StringType,true),
    StructField(c0,IntegerType,false))


您可以通过.columns函数遍历DF列:

val castedDF = df.columns.foldLeft(df)((current, c) => current.withColumn(c, col(c).cast("float")))


因此,新的DF模式如下所示:

StructType(
    StructField(id,FloatType,true),
    StructField(c0,FloatType,false))


编辑:

如果要从转换中排除某些列,则可以执行以下操作(假设我们要排除列ID):

val exclude = Array("id")

val someCastedDF = (df.columns.toBuffer --= exclude).foldLeft(df)((current, c) =>
                                              current.withColumn(c, col(c).cast("float")))


其中exclude是我们要从转换中排除的所有列的数组。

因此,此新DF的架构为:

StructType(
    StructField(id,StringType,true),
    StructField(c0,FloatType,false))


请注意,也许这不是最好的解决方案,但它可能是一个起点。

09-15 17:42