Spark 9편: JDBC 병렬처리 시 주의 사항

안녕하세요 오늘은 BESPIN GLOBAL Data실 한제호님이 작성해주신 ‘Spark 9편: JDBC 병렬처리 시 주의 사항’에 대해 소개해드리도록 하겠습니다.

목차
1. 개요
2. 이슈 사항
3. 해결 방안
4. JDBC 병렬 처리 시 주의 사항

1. 개요

  • Spark은 JDBC 라이브러리를 통해 아래와 같이 RDB의 데이터를 수집할 수 있습니다.
jdbcDF = spark.read \\
    .format("jdbc") \\
    .option("url", "jdbc:postgresql:dbname") \\
    .option("dbtable", "public.tablename") \\
    .option("user", "username") \\
    .option("password", "password") \\
    .load()
  • 하지만 위와 같이 데이터를 조회하는 경우 Spark의 장점인 병렬처리로 데이터를 수집할 수 없으면 단일 Task(파티션)를 통해 데이터가 수집되기 때문에 대량의 데이터를 조회하는 경우 timeout, OOM(out of memory), Storage Spill이 발생할 수 있습니다.
  • Spark에서는 다수의 Task를 생성하여 병렬로 SQL을 수행하여 데이터를 수집할 수 있습니다.
    • 옵션 값
      • partitionColumn – partition 나눌 key가 되는 값 – 숫자, 날짜, 타임스탬프여야합니다.
      • lowerBound – partition의 최소값
      • upperBound – partition의 최대값
      • numPartitions – 파티셔닝하는 partition 최대 갯수
  • 사용 예시
df = spark.read \\
        .format("jdbc") \\
        .option("url", "jdbc:postgresql:dbname") \\
        .option("dbtable", "(select * from table) t") \\
        .option("user", "username") \\
        .option("password", "password") \\
        .option("partitionColumn", "KEY_COLUMN") \\
        .option("lowerBound", 0) \\
        .option("upperBound", 99) \\
        .option("numPartitions", 10) \\
        .load()

2. 이슈 사항

  • 위 형태로 쿼리를 수행하면 partitionColumn의 카디널리티가 높음에도 불구하고 예상과 달리 파티션 Skew가 발생합니다.
  • 예상했던 병렬 쿼리(각 파티션당 10개의 값 수집)
-- 1번 Task
select * from (select * from table) t where KEY_COLUMN < 10 or KEY_COLUMN is null
-- 2번 Task
select * from (select * from table) t where KEY_COLUMN >= 10 or KEY_COLUMN < 20
-- 3번 Task
select * from (select * from table) t where KEY_COLUMN >= 20 or KEY_COLUMN < 30 
-- 4번 Task
select * from (select * from table) t where KEY_COLUMN >= 30 or KEY_COLUMN < 40 
-- 5번 Task
select * from (select * from table) t where KEY_COLUMN >= 40 or KEY_COLUMN < 50 
-- 6번 Task
select * from (select * from table) t where KEY_COLUMN >= 50 or KEY_COLUMN < 60 
-- 7번 Task
select * from (select * from table) t where KEY_COLUMN >= 60 or KEY_COLUMN < 70 
-- 8번 Task
select * from (select * from table) t where KEY_COLUMN >= 70 or KEY_COLUMN < 80 
-- 9번 Task
select * from (select * from table) t where KEY_COLUMN >= 80 or KEY_COLUMN < 90 
-- 10번 Task
select * from (select * from table) t where KEY_COLUMN >= 90

실제 수행된 쿼리(첫번째, 마지막 파티션 14개 값 수집, 나머지 9개 값 수집)

-- 1번 Task
select * from (select * from table) t where KEY_COLUMN < 14 or KEY_COLUMN is null
-- 2번 Task
select * from (select * from table) t where KEY_COLUMN >= 14 or KEY_COLUMN < 23
-- 3번 Task
select * from (select * from table) t where KEY_COLUMN >= 23 or KEY_COLUMN < 32 
-- 4번 Task
select * from (select * from table) t where KEY_COLUMN >= 32 or KEY_COLUMN < 41 
-- 5번 Task
select * from (select * from table) t where KEY_COLUMN >= 41 or KEY_COLUMN < 50 
-- 6번 Task
select * from (select * from table) t where KEY_COLUMN >= 50 or KEY_COLUMN < 59 
-- 7번 Task
select * from (select * from table) t where KEY_COLUMN >= 59 or KEY_COLUMN < 68 
-- 8번 Task
select * from (select * from table) t where KEY_COLUMN >= 68 or KEY_COLUMN < 77
-- 9번 Task
select * from (select * from table) t where KEY_COLUMN >= 77 or KEY_COLUMN < 86 
-- 10번 Task
select * from (select * from table) t where KEY_COLUMN >= 86

3. 해결 방안

  • 오픈소스 Spark 코드 분석: 
    https://github.com/apache/spark/blob/v3.3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
  • JDBCRelation.scala 코드내 130줄부터 170줄까지 디버깅 결과 소수점에 대한 계산 시 버림 처리하는 로직으로 인해 위와 같은 현상이 발견됩니다. (val stride = preciseStride.toLong)
  • 소수점 버림 처리를 방지하기 위해서는 numPartitions 배수로 upperBound 설정이 필요합니다.
    • partitionColumn: KEY_COLUMN
    • lowerBound: 0
    • upperBound: 100
    • numPartitions: 10
  • 필자는 디버깅을 위해 해당 코드를 발췌하여 위 값에 따라 디버깅 할 수 있는 코드를 개발하였고 해당 코드는 아래 코드를 참조하기 바랍니다.
import scala.math.BigDecimal.RoundingMode

object Hello {
    def main(args: Array[String]) = {
        val numPart = 10
        val lowerBound =  0
        val upperBound = 100
        val column = "KEY_COLUMN"

        val numPartitions =
            if ((upperBound - lowerBound) >= numPart || (upperBound - lowerBound) < 0) {
                numPart
            } else {
                upperBound - lowerBound
            }
        
        println("numPartitions: " + numPartitions)

        val upperStride = (upperBound / BigDecimal(numPartitions)).setScale(18, RoundingMode.HALF_EVEN)
        val lowerStride = (lowerBound / BigDecimal(numPartitions)).setScale(18, RoundingMode.HALF_EVEN)

        println("upperStride: " + upperStride)
        println("lowerStride: " + lowerStride)

        val preciseStride = upperStride - lowerStride
        val stride = preciseStride.toLong

        println("preciseStride: " + preciseStride)
        println("stride: " + stride)

        val lostNumOfStrides = (preciseStride - stride) * numPartitions / stride
        val lowerBoundWithStrideAlignment = lowerBound + ((lostNumOfStrides / 2) * stride).setScale(0, RoundingMode.HALF_UP).toLong

        println("lostNumOfStrides: " + lostNumOfStrides)
        println("lowerBoundWithStrideAlignment: " + lowerBoundWithStrideAlignment)

        var i: Int = 0
        var currentValue = lowerBoundWithStrideAlignment
        while (i < numPartitions) {
            val lBoundValue = currentValue.toString()
            val lBound = if (i != 0) s"$column >= $lBoundValue" else null
            currentValue += stride
            val uBoundValue = currentValue.toString()
            val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null
            val whereClause =
                if (uBound == null) {
                lBound
                } else if (lBound == null) {
                s"$uBound or $column is null"
                } else {
                s"$lBound AND $uBound"
                }
            println(whereClause)
            i = i + 1
        }

    }
}

4. JDBC 병렬 처리 시 주의 사항

  • 병렬 쿼리가 실행되어 속도가 개선 되지만, 병렬 쿼리가 실행 되는만큼 DB에 무리가 가지 않는 선에서 쿼리가 실행 되어야합니다.
  • numPartitions수 만큼 파티션이 생성되기 때문에 Spark 클러스터 내 Executor가 낭비되지 않도록 구성이 필요합니다.
  • partitionColumn으로 활용되는 컬럼의 특정 값에 데이터가 몰리면 데이터 Skew가 발생되어 전체적인 Job의 속도가 느려지기 때문에 적절한 컬럼 설정이 필요합니다.

여기까지 ‘Spark 9편: JDBC 병렬처리 시 주의 사항’에 대해 소개해드렸습니다. 유익한 정보가 되셨길 바랍니다. 감사합니다. 

Written by 한 제호 / Data실

BESPIN GLOBAL