こんにちは
sparkのjoinのとある挙動をしらなかった為、5時間くらいデバッグしました。
知らない方も多いと思いますのでまとめました。
きっかけはjoinの結果がなんか、おかしい。そして直感的にnullがあやしい。
といったところからはじまっています。
val schema = StructType(List(
StructField("id", StringType),
StructField("name", StringType)
))
val data = Arrays.asList(
Row("1", "first"),
Row("2", "second"),
Row("3", null)
)
val df = spark.createDataFrame(data, schema)
と
val schema2 = StructType(List(
StructField("id", StringType),
StructField("name", StringType),
StructField("age", IntegerType)
))
val data2 = Arrays.asList(
Row("1", "first", 10),
Row("2", "second", 20),
Row("3", null, 30)
)
val df2 = spark.createDataFrame(data2, schema2)
を
joinedDf = df.join(df2, Seq("id", "name"), "inner")
joinedDf.show
をすると
scala> df.join(df2, Seq("id", "name"), "inner").show
+---+------+---+
| id| name|age|
+---+------+---+
| 1| first| 10|
| 2|second| 20|
+---+------+---+
おや!?
おおくの人が結果として
"1", "first", 10 "2", "second", 20 "3", null, 30
を期待するとおもいますが、そうではありません。
実際は
"1", "first", 10 "2", "second", 20
でnullを含んだ行のjoinが意図通りできていません。
これはnull値を含んだjoinは落とされるからです。
(including-null-values-in-an-apache-spark-join , https://codeday.me/jp/qa/20190301/343391.html )
そこで
sparkはNULLもjoinできるようの特別なoperator <=>を用意してありますので
こちらを明示的につかってください。
(https://spark.apache.org/docs/latest/api/scala/#org.apache.spark.sql.Column)
def<=>(other: Any): Column
では実際にうごかしてみましょう。
まずはDockerFileが書かれたこのsemantive sparkのレポシトリーをcloneします。
git clone git@github.com:Semantive/docker-spark.git cd docker-spark
spark-shellをdockerで動かして、その上で上記のコードを実行してみます。
以下のコマンドでdockerでspark-shellを立ち上げる
docker run --rm -it -p 4040:4040 semantive/spark spark-shell
では
こちらがソースコードの全てです。
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import scala.collection.JavaConversions._
import java.util.Arrays
val schema = StructType(List(
StructField("id", StringType),
StructField("name", StringType)
))
val data = Arrays.asList(
Row("1", "first"),
Row("2", "second"),
Row("3", null)
)
val df = spark.createDataFrame(data, schema)
val schema2 = StructType(List(
StructField("id", StringType),
StructField("name", StringType),
StructField("age", IntegerType)
))
val data2 = Arrays.asList(
Row("1", "first", 10),
Row("2", "second", 20),
Row("3", null, 30)
)
val df2 = spark.createDataFrame(data2, schema2)
// Did you expect this? I didn't
df.join(df2, Seq("id", "name"), "inner").show
// This is null safe join operator
df.join(df2, df.col("id") <=> df2.col("id") && df.col("name") <=> df2.col("name"), "inner").show
// This is again, null value join is excluding
df.join(df2, df.col("id") === df2.col("id") && df.col("name") === df2.col("name"), "inner").show
結果は期待通りです
scala> df.join(df2, Seq("id", "name"), "inner").show
+---+------+---+
| id| name|age|
+---+------+---+
| 1| first| 10|
| 2|second| 20|
+---+------+---+
scala> df.join(df2, df.col("id") <=> df2.col("id") && df.col("name") <=> df2.col("name"), "inner").show
+---+------+---+------+---+
| id| name| id| name|age|
+---+------+---+------+---+
| 1| first| 1| first| 10|
| 2|second| 2|second| 20|
| 3| null| 3| null| 30|
+---+------+---+------+---+
scala> df.join(df2, df.col("id") === df2.col("id") && df.col("name") === df2.col("name"), "inner").show
+---+------+---+------+---+
| id| name| id| name|age|
+---+------+---+------+---+
| 1| first| 1| first| 10|
| 2|second| 2|second| 20|
+---+------+---+------+---+
以上です

Equality test that is safe for null values.