[입 개발] Spark 에서 Database 빨리 덤프하는 법(Parallelism)

전통적으로 Hadoop 기반에서 Database를 덤프할 때는 sqoop을 많이 사용합니다. 그런데… Spark에서는 그냥 바로 database에서 jdbc를 통해서 데이터를 읽을 수도 있습니다.(의외로 이걸 모르는 경우가 많습니다.)

그래서 우리는 다음과 같이 Database에서 데이터를 덤프합니다.

    val df = spark.read
      .format("jdbc")
      .option("url", mysqlUrl)
      .option("driver", "com.mysql.cj.jdbc.Driver")
      .option("user", mysqlUsername)
      .option("password", mysqlPassword)
      .option("useSSL", "false")
      .option("dbtable", table)
      .load()

그런데 위의 코드는 데이터가 적을 때는 잘 동작하지만, 데이터 량이 많아질 수록 점점 느려집니다. 데이터량이 몇백만 몇천만 밖에 안되는 거 같은데, 위의 코드는 몇십분씩 동작합니다.(사실… 돌다가 죽어서 다시 재실행 하는 시간이 더 큰…) 그런데도 잘 동작하지 않습니다.

이건 결국 여러개의 Spark Executor에서 돌지 못하고 한넘이 너무 많은 데이터를 다루다가 결국 메모리가 터져버리는 케이스가 대부분입니다. 그럼 어떻게 해야 할까요? Spark 은 원래 여러 Executor로 실행하기 위해서 쓰는 거 아닌가요 라고 생각할 수 있습니다. 넵넵 맞습니다. 맞구요.(퍽퍽)

그래서 Spark JDBC에는 옵션으로 numPatitions 이라는 것을 제공합니다.

    val df = spark.read
      .format("jdbc")
      .option("url", mysqlUrl)
      .option("driver", "com.mysql.cj.jdbc.Driver")
      .option("user", mysqlUsername)
      .option("password", mysqlPassword)
      .option("useSSL", "false")
      .option("dbtable", table)
      .option("numPartitions", numPartitions)
      .load()

오옷 numPartitions만 주면 내부적으로 나눠져서 제대로 동작할 듯 합니다. 그런데 실제로 돌려보면 제대로 동작하지 않습니다. 흑… 원칙적으로 numPatitions가 우리의 해결책이 맞습니다. 다만, 이 옵션을 위해서 추가로 설정해줘야 할 값들이 있습니다. Database를 기준으로 하기 때문에, 어떤 column을 기준으로 분할 할 것인지, 그리고 그 구간을 어떻게 할 것인지를 정해줘야 합니다.

그래서 partitionColumn, lowerBound, upperBound 라는 값이 있습니다. 그러면 저 사이의 값들을 numPartitions 만큼 나눠져서 가져오게 됩니다. 다만 여기서 upperBound 값이 너무 크면 실제 데이터들이 적게 나눠질 수 있습니다. 그런데 또 upperBound 값을 너무 작게 잡으면, 데이터를 다 덤프하지 못할 수 있습니다. 그래서 저는 다음과 같이 구하고 있습니다.

max 개수를 가져오고, 그걸 upperBound 로 설정하고 있습니다.

    val partitionSize = 2000000

    val sizeDF = spark.read
      .format("jdbc")
      .option("url", mysqlUrl)
      .option("driver", "com.mysql.cj.jdbc.Driver")
      .option("user", mysqlUsername)
      .option("password", mysqlPassword)
      .option("useSSL", "false")
      .option("query", s"select max($partitionKey) from $table")
      .load()

    val maxId = sizeDF.collect()(0)(0).toString.toLong
    val numPartitions = (maxId / partitionSize) + 1

    val df = spark.read
      .format("jdbc")
      .option("url", mysqlUrl)
      .option("driver", "com.mysql.cj.jdbc.Driver")
      .option("user", mysqlUsername)
      .option("password", mysqlPassword)
      .option("useSSL", "false")
      .option("dbtable", table)
      .option("partitionColumn", partitionKey)
      .option("numPartitions", numPartitions)
      .option("lowerBound", 1)
      .option("upperBound", maxId)
      .load()

그렇지만 데이터가 너무 크면 partition 개수가 또 너무 많아질 수 있으므로 여기에 대한 적절한 조절이 필요합니다.