我试图将一些通用代码提取到抽象类中,但是遇到了问题。

假设我正在读取格式为“id | name”的文件:

case class Person(id: Int, name: String) extends Serializable

object Persons {
  def apply(lines: Dataset[String]): Dataset[Person] = {
    import lines.sparkSession.implicits._
    lines.map(line => {
      val fields = line.split("\\|")
      Person(fields(0).toInt, fields(1))
    })
  }
}

Persons(spark.read.textFile("persons.txt")).show()

伟大的。这很好。现在,假设我想读取带有“名称”字段的许多不同文件,因此我将提取所有常见逻辑:
trait Named extends Serializable { val name: String }

abstract class NamedDataset[T <: Named] {
  def createRecord(fields: Array[String]): T
  def apply(lines: Dataset[String]): Dataset[T] = {
    import lines.sparkSession.implicits._
    lines.map(line => createRecord(line.split("\\|")))
  }
}

case class Person(id: Int, name: String) extends Named

object Persons extends NamedDataset[Person] {
  override def createRecord(fields: Array[String]) =
    Person(fields(0).toInt, fields(1))
}

这失败,并出现两个错误:
Error:
Unable to find encoder for type stored in a Dataset.
Primitive types (Int, String, etc) and Product types (case classes)
are supported by importing spark.implicits._  Support for serializing
other types will be added in future releases.
lines.map(line => createRecord(line.split("\\|")))

Error:
not enough arguments for method map:
(implicit evidence$7: org.apache.spark.sql.Encoder[T])org.apache.spark.sql.Dataset[T].
Unspecified value parameter evidence$7.
lines.map(line => createRecord(line.split("\\|")))

我感觉这与隐式,TypeTag和/或ClassTag有关,但是我只是从Scala开始,还没有完全理解这些概念。

最佳答案

您必须进行两个小更改:

  • 由于仅支持原语和Product(作为错误消息状态),因此仅使Named特性为Serializable是不够的。您应该使它扩展Product(这意味着案例类和元组可以扩展它)
  • 实际上,Spark需要ClassTagTypeTag才能克服类型擦除并找出实际类型

  • 所以-这是一个工作版本:
    import scala.reflect.ClassTag
    import scala.reflect.runtime.universe.TypeTag
    
    trait Named extends Product { val name: String }
    
    abstract class NamedDataset[T <: Named : ClassTag : TypeTag] extends Serializable {
      def createRecord(fields: Array[String]): T
      def apply(lines: Dataset[String]): Dataset[T] = {
        import lines.sparkSession.implicits._
        lines.map(line => createRecord(line.split("\\|")))
      }
    }
    
    case class Person(id: Int, name: String) extends Named
    
    object Persons extends NamedDataset[Person] {
      override def createRecord(fields: Array[String]) =
        Person(fields(0).toInt, fields(1))
    }
    

    10-08 03:14