Fonctions d'ordre supérieur dans Spark 2 pour traiter des structures imbriquées

Nous l'avons annoncé dans un précédent article : Spark 2.4 est arrivé avec tout un lot de fonctions d'ordre supérieur (HOF - Higher Order Function) permettant de manipuler des structures imbriquées (en se limitant aux collections) dans les dataframes. Ce qui est bien pratique pour ce type de colonne, ce qui permet aussi de mieux les appréhender. Néanmoins, cette fonctionnalité est cantonnée à Spark 2.4+ et accessible uniquement dans les requêtes Spark SQL. Or actuellement, bien des datalakes n'ont pas encore migré en version 2.4 et les fournisseurs Hadoop on prem et cloud sont pour le moment limités à la version 2.3. Alors, comment faire ?

L'idée va être d'ajouter de nouvelles opérations au type Column utilisant des UDF (User Defined Function). Nous définissons ci-dessous les fonctions map, flatMap et foldLeft.

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.udf
import scala.reflect.runtime.universe._

object Implicits {
  
  implicit class RicherColumn(column: Column) {

    /** Transform each element of an array by applying the function g.
      */
    def map[A: TypeTag, B: TypeTag](g: A => B): Column = {
      val f = udf[Seq[B], Seq[A]](s => s.map(g))

      f(column)
    }

    /** Transform each element of an array by applying the function g
      * and using the elements of the resulting array.
      */
    def flatMap[A: TypeTag, B: TypeTag](g: A => Seq[B]): Column = {
      val f = udf[Seq[B], Seq[A]](s => s.flatMap(g))

      f(column)
    }

    /** Aggregate an array to a single value.
      */
    def foldLeft[A: TypeTag, B: TypeTag](init: B)(g: (B, A) => B): Column = {
      val f = udf[B, Seq[A]](s => s.foldLeft(init)(g))

      f(column)
    }
    
  }

}

Voyons à présent comment utiliser ces fonctions. Nous allons pour cela utiliser le dataset ci-dessous.

val df =
  Seq(
    ("abc", Seq("Hello world"),                   Seq(0L, 2L)),
    ("def", Seq("Wild Horses", "Paint It Black"), Seq(1L, 2L))
  ).toDF(
     "id",  "text",                               "timestamps"
  )
+---+-----------------------------+----------+
|id |text                         |timestamps|
+---+-----------------------------+----------+
|abc|[Hello world]                |[0, 2]    |
|def|[Wild Horses, Paint It Black]|[1, 2]    |
+---+-----------------------------+----------+

Utilisons à présent nos fonctions afin de réaliser quelques transformations sur nos données.

import Implicits._

val result =
  df.select(
    $"id",
    // multiply each timestamp by 2
    $"timestamps"
      .map((t: Long) => t * 2)
      .as("map"),
    // get the words in the arrays
    $"text"
      .flatMap((t: String) => t.split(" ").toSeq)
      .as("flatMap"),
    // compute the sum of timestamps
    $"timestamps"
      .foldLeft[Long, Long](0)(_ + _)
      .as("foldLeft")
  )

Ce qui donne :

+---+------+--------------------------------+--------+
|id |map   |flatMap                         |foldLeft|
+---+------+--------------------------------+--------+
|abc|[0, 4]|[Hello, world]                  |2       |
|def|[2, 4]|[Wild, Horses, Paint, It, Black]|3       |
+---+------+--------------------------------+--------+

À titre de comparaison, essayons d'obtenir un résultat équivalent avec Spark 2.4.

SELECT
  id,
  transform(timestamps, t -> t * 2) as transform,
  flatten(transform(text, t -> split(t, " "))) as flatten_transform,
  aggregate(timestamps, bigint(0), (s, v) -> s + v) as aggregate
FROM data

Ce qui donne :

+---+---------+--------------------------------+---------+
|id |transform|flatten_transform               |aggregate|
+---+---------+--------------------------------+---------+
|abc|[0, 4]   |[Hello, world]                  |2        |
|def|[2, 4]   |[Wild, Horses, Paint, It, Black]|3        |
+---+---------+--------------------------------+---------+

En dehors du changement de nom entre nos fonctions et la version 2.4 de Spark 😉 (map et transform, foldLeft et aggregate), nous pouvons remarquer qu'il n'y a pas de fonction équivalente à flatMap. Nous devons pour le coup utiliser une composition entre flatten et transform. Ceci dit, nous avons bien un résultat identique.

Nous avons vu ici quelques fonctions d'ordre supérieur sur dataframe pour les impatients 😉. Comprenez bien sûr que le lot présenté ici peut être complété (filter, take, zip, traverse). Comme il s'agit d'UDF, il n'y a pas d'optimisation possible. Néanmoins, l'avantage de ces fonctions est qu'elles peuvent être utilisées directement dans le code Scala.