안녕하세요 오늘은 BESPIN GLOBAL Data실 한제호님이 작성해주신 ‘Spark 9편: JDBC 병렬처리 시 주의 사항’에 대해 소개해드리도록 하겠습니다.
목차
1. 개요
2. 이슈 사항
3. 해결 방안
4. JDBC 병렬 처리 시 주의 사항
1. 개요
- Spark은 JDBC 라이브러리를 통해 아래와 같이 RDB의 데이터를 수집할 수 있습니다.
jdbcDF = spark.read \\
.format("jdbc") \\
.option("url", "jdbc:postgresql:dbname") \\
.option("dbtable", "public.tablename") \\
.option("user", "username") \\
.option("password", "password") \\
.load()
- 하지만 위와 같이 데이터를 조회하는 경우 Spark의 장점인 병렬처리로 데이터를 수집할 수 없으면 단일 Task(파티션)를 통해 데이터가 수집되기 때문에 대량의 데이터를 조회하는 경우 timeout, OOM(out of memory), Storage Spill이 발생할 수 있습니다.
- Spark에서는 다수의 Task를 생성하여 병렬로 SQL을 수행하여 데이터를 수집할 수 있습니다.
- 옵션 값
- partitionColumn – partition 나눌 key가 되는 값 – 숫자, 날짜, 타임스탬프여야합니다.
- lowerBound – partition의 최소값
- upperBound – partition의 최대값
- numPartitions – 파티셔닝하는 partition 최대 갯수
- 옵션 값
- 사용 예시
df = spark.read \\
.format("jdbc") \\
.option("url", "jdbc:postgresql:dbname") \\
.option("dbtable", "(select * from table) t") \\
.option("user", "username") \\
.option("password", "password") \\
.option("partitionColumn", "KEY_COLUMN") \\
.option("lowerBound", 0) \\
.option("upperBound", 99) \\
.option("numPartitions", 10) \\
.load()
2. 이슈 사항
- 위 형태로 쿼리를 수행하면 partitionColumn의 카디널리티가 높음에도 불구하고 예상과 달리 파티션 Skew가 발생합니다.
- 예상했던 병렬 쿼리(각 파티션당 10개의 값 수집)
-- 1번 Task
select * from (select * from table) t where KEY_COLUMN < 10 or KEY_COLUMN is null
-- 2번 Task
select * from (select * from table) t where KEY_COLUMN >= 10 or KEY_COLUMN < 20
-- 3번 Task
select * from (select * from table) t where KEY_COLUMN >= 20 or KEY_COLUMN < 30
-- 4번 Task
select * from (select * from table) t where KEY_COLUMN >= 30 or KEY_COLUMN < 40
-- 5번 Task
select * from (select * from table) t where KEY_COLUMN >= 40 or KEY_COLUMN < 50
-- 6번 Task
select * from (select * from table) t where KEY_COLUMN >= 50 or KEY_COLUMN < 60
-- 7번 Task
select * from (select * from table) t where KEY_COLUMN >= 60 or KEY_COLUMN < 70
-- 8번 Task
select * from (select * from table) t where KEY_COLUMN >= 70 or KEY_COLUMN < 80
-- 9번 Task
select * from (select * from table) t where KEY_COLUMN >= 80 or KEY_COLUMN < 90
-- 10번 Task
select * from (select * from table) t where KEY_COLUMN >= 90
실제 수행된 쿼리(첫번째, 마지막 파티션 14개 값 수집, 나머지 9개 값 수집)
-- 1번 Task
select * from (select * from table) t where KEY_COLUMN < 14 or KEY_COLUMN is null
-- 2번 Task
select * from (select * from table) t where KEY_COLUMN >= 14 or KEY_COLUMN < 23
-- 3번 Task
select * from (select * from table) t where KEY_COLUMN >= 23 or KEY_COLUMN < 32
-- 4번 Task
select * from (select * from table) t where KEY_COLUMN >= 32 or KEY_COLUMN < 41
-- 5번 Task
select * from (select * from table) t where KEY_COLUMN >= 41 or KEY_COLUMN < 50
-- 6번 Task
select * from (select * from table) t where KEY_COLUMN >= 50 or KEY_COLUMN < 59
-- 7번 Task
select * from (select * from table) t where KEY_COLUMN >= 59 or KEY_COLUMN < 68
-- 8번 Task
select * from (select * from table) t where KEY_COLUMN >= 68 or KEY_COLUMN < 77
-- 9번 Task
select * from (select * from table) t where KEY_COLUMN >= 77 or KEY_COLUMN < 86
-- 10번 Task
select * from (select * from table) t where KEY_COLUMN >= 86
3. 해결 방안
- 오픈소스 Spark 코드 분석:
https://github.com/apache/spark/blob/v3.3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala - JDBCRelation.scala 코드내 130줄부터 170줄까지 디버깅 결과 소수점에 대한 계산 시 버림 처리하는 로직으로 인해 위와 같은 현상이 발견됩니다. (val stride = preciseStride.toLong)
- 소수점 버림 처리를 방지하기 위해서는 numPartitions 배수로 upperBound 설정이 필요합니다.
- partitionColumn: KEY_COLUMN
- lowerBound: 0
- upperBound: 100
- numPartitions: 10
- 필자는 디버깅을 위해 해당 코드를 발췌하여 위 값에 따라 디버깅 할 수 있는 코드를 개발하였고 해당 코드는 아래 코드를 참조하기 바랍니다.
import scala.math.BigDecimal.RoundingMode
object Hello {
def main(args: Array[String]) = {
val numPart = 10
val lowerBound = 0
val upperBound = 100
val column = "KEY_COLUMN"
val numPartitions =
if ((upperBound - lowerBound) >= numPart || (upperBound - lowerBound) < 0) {
numPart
} else {
upperBound - lowerBound
}
println("numPartitions: " + numPartitions)
val upperStride = (upperBound / BigDecimal(numPartitions)).setScale(18, RoundingMode.HALF_EVEN)
val lowerStride = (lowerBound / BigDecimal(numPartitions)).setScale(18, RoundingMode.HALF_EVEN)
println("upperStride: " + upperStride)
println("lowerStride: " + lowerStride)
val preciseStride = upperStride - lowerStride
val stride = preciseStride.toLong
println("preciseStride: " + preciseStride)
println("stride: " + stride)
val lostNumOfStrides = (preciseStride - stride) * numPartitions / stride
val lowerBoundWithStrideAlignment = lowerBound + ((lostNumOfStrides / 2) * stride).setScale(0, RoundingMode.HALF_UP).toLong
println("lostNumOfStrides: " + lostNumOfStrides)
println("lowerBoundWithStrideAlignment: " + lowerBoundWithStrideAlignment)
var i: Int = 0
var currentValue = lowerBoundWithStrideAlignment
while (i < numPartitions) {
val lBoundValue = currentValue.toString()
val lBound = if (i != 0) s"$column >= $lBoundValue" else null
currentValue += stride
val uBoundValue = currentValue.toString()
val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null
val whereClause =
if (uBound == null) {
lBound
} else if (lBound == null) {
s"$uBound or $column is null"
} else {
s"$lBound AND $uBound"
}
println(whereClause)
i = i + 1
}
}
}
4. JDBC 병렬 처리 시 주의 사항
- 병렬 쿼리가 실행되어 속도가 개선 되지만, 병렬 쿼리가 실행 되는만큼 DB에 무리가 가지 않는 선에서 쿼리가 실행 되어야합니다.
- numPartitions수 만큼 파티션이 생성되기 때문에 Spark 클러스터 내 Executor가 낭비되지 않도록 구성이 필요합니다.
- partitionColumn으로 활용되는 컬럼의 특정 값에 데이터가 몰리면 데이터 Skew가 발생되어 전체적인 Job의 속도가 느려지기 때문에 적절한 컬럼 설정이 필요합니다.
여기까지 ‘Spark 9편: JDBC 병렬처리 시 주의 사항’에 대해 소개해드렸습니다. 유익한 정보가 되셨길 바랍니다. 감사합니다.
Written by 한 제호 / Data실
BESPIN GLOBAL