En finir avec les problèmes de case class dans Spark

Dans un précédent article, nous avons vu qu'il est possible d'avoir des fonctions d'ordre supérieur avec des versions anciennes de Spark. Sauf que notre implémentation souffre d'un défaut que nous retrouvons assez souvent dès lors que nous commençons à mélanger case class et dataset. Voyons ensemble comment nous allons y remédier.

Nous allons créer un dataset représentant des familles. Pour cela, créons une case class représentant des personnes.

case class Person(name: String, age: Int)

Créons maintenant notre dataset.

val df =
  ss.createDataset(Seq(
      ("1", Seq(
        Person("John", 32),
        Person("Mary", 31)
      )),
      ("2", Seq(
        Person("Fred", 42),
        Person("Elsa", 40),
        Person("Daryll", 10)
      ))
  )).toDF("id", "family")
+---+--------------------------------------+
|id |family                                |
+---+--------------------------------------+
|1  |[[John, 32], [Mary, 31]]              |
|2  |[[Fred, 42], [Elsa, 40], [Daryll, 10]]|
+---+--------------------------------------+

Formalisons une requête qui ne conserve que les noms des personnes du dataset précédent. Pour cela nous allons utiliser l'une des fonctions d'ordre supérieur que nous avons vu dans l'article précédent sur le sujet.

import Implicits._

df.select(
  $"id",
  $"family".map((p: Person) => p.name).as("family")
)

Nous obtenons alors une exception.

Exception in thread "main" org.apache.spark.SparkException: Failed to execute user defined function($anonfun$4: (array<struct<name:string,age:int>>) => array<string>)
	at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1066)
	at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:151)
	at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:50)
	...
Caused by: java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to xag.Person
	at Example$$anonfun$main$2.apply(TestMain.scala:33)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	...

L'exception d'origine indique : java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to Person. Catalyst est incapable de reconstituer l'entité Person à partir du résultat de la version sérialisée de cette entité contenue dans le dataset.

Pour résoudre cette exception, nous allons utiliser com.twitter.chill.Externalizer. Twitter Chill est une bibliothèque d'extension pour Kryo et Kryo est une bibliothèque de sérialisation Java utilisé par Spark. Externalizer permet de mettre en place une sorte de "boîte" pour sérialiser et désérialiser n'importe quel type, même non sérialisable au sens Java.

Nous allons aussi utiliser une opération permettant d'effectuer un "nettoyage" sur les données rencontrées, en récupérant notamment les types sous-jacent cachés derrière Any, monnaie courante avec Spark. Nous allons utiliser une approche (ou plutôt idiome) typeclass pour associer cette opération à l'ensemble des types utilisés dans les données (en attendant la nouvelle syntaxe sensée arriver avec Scala 3 !). Voici l'interface de cette opération.

// The typeclass with the operation to clean data for Spark
trait CleanFromRow[A] extends Serializable {
  def clean(a: Any): A
}

Ci-dessous, nous retrouvons la classe implicite avec l'opération map, que nous avons définie dans l'article précédent avec un léger refactoring : 1/ on s'assure que l'opération clean s'applique bien au type A donné en entrée, 2/ on s'assure que le "nettoyage" et la sérialisation seront réalisés (serializeAndClean). Le refactoring ne s'applique qu'à l'opération map, mais nous pouvons aussi l'appliquer sur les autres opérations : flatMap et foldLeft.

object Implicits2 {

  implicit class RicherColumn2(column: Column) {

    def map[A : CleanFromRow: TypeTag, B : TypeTag](f: A => B): Column = {
      val f0 = udf[Seq[B], Seq[A]](serializeAndClean(s => s.map(f)))

      f0(column)
    }

  }

  import com.twitter.chill.Externalizer

  private def serializeAndClean[A: CleanFromRow, B](f: Seq[A] => B): Seq[A] => B = {
    val cleaner: Externalizer[A => A] =
      Externalizer(implicitly[CleanFromRow[A]].clean _)
    val fExt: Externalizer[Seq[A] => B] =
      Externalizer(f)

    values =>
      if (values == null) {
        null.asInstanceOf[B]
      } else {
        val fExt0: Seq[A] => B = fExt.get
        val cleaner0: A => A = cleaner.get

        fExt0(values.map(cleaner0))
      }
  }

}

Le comportement de l'opération de "nettoyage" va dépendre du type rencontré. Nous définissons donc ci-dessous différentes instances du type CleanFromRow pour cela (Int, String, Seq...). Nous utilisons aussi Magnoliade Jon Pretty pour associer l'opération clean à des case class. L'avantage de Magnolia, c'est qu'il est récursif. Du coup, ça fonctionne aussi avec des structures imbriquées de différentes profondeurs.

object CleanFromRow {
  import magnolia._
  import scala.reflect.ClassTag
  import language.experimental.macros

  type Typeclass[T] = CleanFromRow[T]

  private def instance[A]: Typeclass[A] =
    new Typeclass[A] {
      override def clean(a: Any): A = a.asInstanceOf[A]
    }

  // Instances to associate clean operation to basic types
  implicit val double: Typeclass[Double] = instance
  implicit val boolean: Typeclass[Boolean] = instance
  implicit val strCFR: Typeclass[String] = instance
  implicit val intCFR: Typeclass[Int] = instance
  implicit val longCFR: Typeclass[Long] = instance
  // add other typeclass instances for basic types...

  // Instance for Option type
  implicit def opt[T: Typeclass: Manifest]: Typeclass[Option[T]] =
    new Typeclass[Option[T]] {
      // this helps to avoid type erasure warning
      private val rc = implicitly[Manifest[T]].runtimeClass

      override def clean(a: Any): Option[T] =
        a match {
          case ox: Option[_]
              if ox.forall(x => rc.isAssignableFrom(x.getClass)) =>
            ox.asInstanceOf[Option[T]]
          case null => None
          case x    => Option(implicitly[Typeclass[T]].clean(x))
        }
    }

  // Instance for Seq type
  implicit def seq[T:Typeclass:Manifest]: Typeclass[Seq[T]] = {
    new Typeclass[Seq[T]] {
      // this helps to avoid type erasure warning
      private val rc = implicitly[Manifest[T]].runtimeClass

      override def clean(a: Any): Seq[T] =
        a match {
          case Nil => Nil
          case xs: Seq[_]
              if xs.forall(x => rc.isAssignableFrom(x.getClass)) =>
            xs.asInstanceOf[Seq[T]]
          case x: Seq[_] => x.map(implicitly[Typeclass[T]].clean)
        }
    }
  }

  // Instance generator for case classes
  def combine[T: ClassTag](ctx: CaseClass[CleanFromRow, T]): Typeclass[T] =
    new Typeclass[T] {
      override def clean(a: Any): T =
        a match {
          case a: T => a
          case r: Row =>
            val values: Seq[Any] =
              r.toSeq
                .zip(ctx.parameters)
                .map {
                  case (rowValue, param) => param.typeclass.clean(rowValue)
                }
            ctx.rawConstruct(values)
        }
    }

  implicit def gen[T]: CleanFromRow[T] = macro Magnolia.gen[T]
}

Reprenons maintenant notre exemple précédent en remplaçant Implicits pour Implicits2.

import Implicits2._

df.select(
  $"id",
  $"family".map((p: Person) => p.name).as("family")
)
+---+--------------------+
|id |family              |
+---+--------------------+
|1  |[John, Mary]        |
|2  |[Fred, Elsa, Daryll]|
+---+--------------------+

C'est le résultat attendu !


Annexe

Suite de

Fonctions d'ordre supérieur dans Spark 2 pour traiter des structures imbriquées
package utils

import com.twitter.chill.MeatLocker
import org.apache.spark.sql.Row

import scala.reflect.ClassTag

trait CleanFromRow[A] extends Serializable {
  def clean(a: Any): A
}

object CleanFromRow {

  private def instance[A]: Typeclass[A] = new Typeclass[A] {
    override def clean(a: Any): A = a.asInstanceOf[A]
  }

  implicit val double:  Typeclass[Double]  = instance
  implicit val boolean: Typeclass[Boolean] = instance
  implicit val strCFR:  Typeclass[String]  = instance
  implicit val intCFR:  Typeclass[Int]     = instance
  implicit val longCFR: Typeclass[Long]    = instance

  implicit def opt[T: Typeclass: Manifest]: Typeclass[Option[T]] =
    new Typeclass[Option[T]] {
      override def clean(a: Any): Option[T] = {
        a match {
          case x: Option[T] => x
          case null => None
          case x    => Option(implicitly[Typeclass[T]].clean(x))
        }
      }
    }

  implicit def seq[T: Typeclass: Manifest]: Typeclass[Seq[T]] =
    new Typeclass[Seq[T]] {
      override def clean(a: Any): Seq[T] = a match {
        case x: Seq[T] => x
        case x: Seq[_] => x.map(implicitly[Typeclass[T]].clean)
      }
    }

  import magnolia._

  import language.experimental.macros

  type Typeclass[T] = CleanFromRow[T]

  def combine[T: ClassTag](ctx: CaseClass[CleanFromRow, T]): CleanFromRow[T] = {
    new CleanFromRow[T] {

      override def clean(a: Any): T = {
        a match {
          case a: T => a
          case r: Row =>
            val values: Seq[Any] =
              r.toSeq
                .zip(ctx.parameters)
                .map({
                  case (rowValue, param) => param.typeclass.clean(rowValue)
                })
            ctx.rawConstruct(values)

        }

      }
    }
  }

  implicit def gen[T]: CleanFromRow[T] = macro Magnolia.gen[T]
}

case class Tata(a:     String, b: Int)
case class Toto(tatas: Seq[Tata])

object UnnestedSpark {

  def tataToX(seq: Seq[Tata]): String = {
    seq.map(_.a).headOption.getOrElse("")
  }

  def cleanF[A: CleanFromRow, B](f: Seq[A] => B): Seq[A] => B = {
    val g: MeatLocker[A => A] =
      com.twitter.chill.MeatLocker(implicitly[CleanFromRow[A]].clean _)
    val mf: MeatLocker[Seq[A] => B] = com.twitter.chill.MeatLocker(f)

    a =>
      {

        if (a == null) {
          null.asInstanceOf[B]
        } else {
          val f0: Seq[A] => B = mf.get
          val g0: A      => A = g.get
          f0(a.map(g0))
        }
      }

  }

  /*def main(args: Array[String]): Unit = {

    val ss = TestSparkSession.create()

    import ss.implicits._

    val df = ss.createDataset(Seq(Toto(Seq(Tata("a", 1))))).toDF()

    import org.apache.spark.sql.functions._

    ss.udf.register("tataToX", cleanF(tataToX))

    df.select(expr("tataToX(tatas)")).show(false)

  }*/

}

ss.udf.register("dmConfToX", cleanF(niveauPack))


def niveauPack(confs: Seq[DmConfiguration]): String = {
    confs.flatMap(_.niveauPack).mkString(", ")
  }

trait Typeclass[T] {
  def clean(a: Any): T
}

object Typeclass {
  implicit def seq[T:Typeclass:Manifest]: Typeclass[Seq[T]] = {
    new Typeclass[Seq[T]] {
      private val rc = implicitly[Manifest[T]].runtimeClass
      override def clean(a: Any): Seq[T] =
        a match {
          case Nil => Nil
          case xs: Seq[_]
              if xs.forall(x => rc.isAssignableFrom(x.getClass)) =>
            xs.asInstanceOf[Seq[T]]
          case x: Seq[_] => x.map(implicitly[Typeclass[T]].clean)
        }
    }
  }

  implicit def opt[T: Typeclass: Manifest]: Typeclass[Option[T]] =
    new Typeclass[Option[T]] {
      private val rc = implicitly[Manifest[T]].runtimeClass
      override def clean(a: Any): Option[T] =
        a match {
          case ox: Option[_]
              if ox.forall(x => rc.isAssignableFrom(x.getClass)) =>
            ox.asInstanceOf[Option[T]]
          case null => None
          case x    => Option(implicitly[Typeclass[T]].clean(x))
        }
    }
}