[HUDI-5347] Cleaned up transient state from ExpressionPayload making it non-serializable (apache#7424)

- Internal state (cached records, writer schemas) is removed to make
   sure that the `ExpressionPayload` object is serializable at all times.
- `ExpressionPayload` caches are scoped down to `ThreadLocal` since
   some of the re-used components (AvroSerializer, AvroDeserializer, SafeProjection)
   have internal mutable state and are therefore not thread-safe (a minimal
   sketch of this pattern follows the change summary below).
alexeykudinkin committed Dec 12, 2022
1 parent 6cd17aa commit 3067881
Showing 2 changed files with 172 additions and 77 deletions.
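The commit message's thread-safety point boils down to one pattern: each thread lazily gets its own Caffeine cache, so the non-thread-safe values the cache holds are never shared across an executor's worker threads. Below is a minimal, self-contained sketch of that pattern; `ReusedCodec` and `ThreadLocalCacheSketch` are made-up illustrative names standing in for components like AvroSerializer, AvroDeserializer or SafeProjection, and are not part of the actual change.

import java.util.function.Supplier

import com.github.benmanes.caffeine.cache.{Cache, Caffeine}

// Illustrative stand-in for a component with internal mutable state
// (the real counterparts are AvroSerializer, AvroDeserializer, SafeProjection).
class ReusedCodec(val schemaKey: String)

object ThreadLocalCacheSketch {

  // Every thread lazily builds its own cache instance, so values cached by
  // one thread are never handed to another thread.
  private val cache: ThreadLocal[Cache[String, ReusedCodec]] =
    ThreadLocal.withInitial(new Supplier[Cache[String, ReusedCodec]] {
      override def get(): Cache[String, ReusedCodec] =
        Caffeine.newBuilder()
          .maximumSize(16)
          .build[String, ReusedCodec]()
    })

  // Look up (or create) the per-thread codec for a given schema key.
  def codecFor(schemaKey: String): ReusedCodec =
    cache.get().get(schemaKey, new java.util.function.Function[String, ReusedCodec] {
      override def apply(key: String): ReusedCodec = new ReusedCodec(key)
    })
}

The trade-off is that each executor thread compiles and caches its own copies, which is why the per-thread caches in the diff below keep small maximum sizes (16 and 1024 entries).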
@@ -17,7 +17,7 @@

package org.apache.spark.sql.hudi.command.payload

import com.github.benmanes.caffeine.cache.Caffeine
import com.github.benmanes.caffeine.cache.{Cache, Caffeine}
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord}
import org.apache.hudi.AvroConversionUtils.convertAvroSchemaToStructType
@@ -35,13 +35,12 @@ import org.apache.hudi.io.HoodieWriteHandle
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSerializer}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.{Expression, Projection}
import org.apache.spark.sql.catalyst.expressions.{Expression, Projection, SafeProjection}
import org.apache.spark.sql.hudi.SerDeUtils
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload._
import org.apache.spark.sql.types.BooleanType

import java.util.function.Function
import java.util.function.{Function, Supplier}
import java.util.{Base64, Properties}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -61,26 +60,12 @@ import scala.collection.mutable.ArrayBuffer
*/
class ExpressionPayload(@transient record: GenericRecord,
@transient orderingVal: Comparable[_])
extends DefaultHoodieRecordPayload(record, orderingVal) {
extends DefaultHoodieRecordPayload(record, orderingVal) with Logging {

def this(recordOpt: HOption[GenericRecord]) {
this(recordOpt.orElse(null), 0)
}

/**
* Target schema used for writing records into the table
*/
private var writeSchema: Schema = _

/**
* Original record's schema
*
* NOTE: To avoid excessive overhead of serializing original record's Avro schema along
* w/ _every_ record, we instead make it to be provided along with every request
* requiring this record to be deserialized
*/
private var recordSchema: Schema = _

override def combineAndGetUpdateValue(currentValue: IndexedRecord,
schema: Schema): HOption[IndexedRecord] = {
throw new IllegalStateException(s"Should not call this method for ${getClass.getCanonicalName}")
@@ -93,7 +78,7 @@ class ExpressionPayload(@transient record: GenericRecord,
override def combineAndGetUpdateValue(targetRecord: IndexedRecord,
schema: Schema,
properties: Properties): HOption[IndexedRecord] = {
init(properties)
val recordSchema = getRecordSchema(properties)

val sourceRecord = bytesToAvro(recordBytes, recordSchema)
val joinedRecord = joinRecord(sourceRecord, targetRecord)
@@ -136,8 +121,9 @@
// If the update condition matched then execute assignment expression
// to compute final record to update. We will return the first matched record.
if (conditionEvalResult) {
val writerSchema = getWriterSchema(properties)
val resultingRow = assignmentEvaluator.apply(inputRecord.asRow)
lazy val resultingAvroRecord = getAvroSerializerFor(writeSchema)
lazy val resultingAvroRecord = getAvroSerializerFor(writerSchema)
.serialize(resultingRow)
.asInstanceOf[GenericRecord]

@@ -201,11 +187,11 @@
* @return The result of the record to insert.
*/
private def processNotMatchedRecord(inputRecord: ConvertibleRecord, properties: Properties): HOption[IndexedRecord] = {
val insertConditionAndAssignmentsText =
properties.get(ExpressionPayload.PAYLOAD_INSERT_CONDITION_AND_ASSIGNMENTS)
val insertConditionAndAssignmentsText: String =
properties.get(ExpressionPayload.PAYLOAD_INSERT_CONDITION_AND_ASSIGNMENTS).toString
// Get the evaluator for each condition and insert assignment.
val insertConditionAndAssignments =
ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText.toString, inputRecord.asAvro.getSchema)
ExpressionPayload.getEvaluator(insertConditionAndAssignmentsText, inputRecord.asAvro.getSchema)
var resultRecordOpt: HOption[IndexedRecord] = null
for ((conditionEvaluator, assignmentEvaluator) <- insertConditionAndAssignments
if resultRecordOpt == null) {
@@ -215,8 +201,9 @@ class ExpressionPayload(@transient record: GenericRecord,
// If matched the insert condition then execute the assignment expressions to compute the
// result record. We will return the first matched record.
if (conditionEvalResult) {
val writerSchema = getWriterSchema(properties)
val resultingRow = assignmentEvaluator.apply(inputRecord.asRow)
val resultingAvroRecord = getAvroSerializerFor(writeSchema)
val resultingAvroRecord = getAvroSerializerFor(writerSchema)
.serialize(resultingRow)
.asInstanceOf[GenericRecord]

@@ -233,9 +220,9 @@
}

override def getInsertValue(schema: Schema, properties: Properties): HOption[IndexedRecord] = {
init(properties)

val recordSchema = getRecordSchema(properties)
val incomingRecord = ConvertibleRecord(bytesToAvro(recordBytes, recordSchema))

if (isDeleteRecord(incomingRecord.asAvro)) {
HOption.empty[IndexedRecord]()
} else if (isMORTable(properties)) {
@@ -268,20 +255,6 @@
writeRecord
}

private def init(props: Properties): Unit = {
if (writeSchema == null) {
ValidationUtils.checkArgument(props.containsKey(HoodieWriteConfig.WRITE_SCHEMA_OVERRIDE.key),
s"Missing ${HoodieWriteConfig.WRITE_SCHEMA_OVERRIDE.key} property")
writeSchema = parseSchema(props.getProperty(HoodieWriteConfig.WRITE_SCHEMA_OVERRIDE.key))
}

if (recordSchema == null) {
ValidationUtils.checkArgument(props.containsKey(PAYLOAD_RECORD_AVRO_SCHEMA),
s"Missing ${PAYLOAD_RECORD_AVRO_SCHEMA} property")
recordSchema = parseSchema(props.getProperty(PAYLOAD_RECORD_AVRO_SCHEMA))
}
}

/**
* Join the source record with the target record.
*
@@ -328,6 +301,11 @@ object ExpressionPayload {
val PAYLOAD_RECORD_AVRO_SCHEMA = "hoodie.payload.record.schema"

/**
* NOTE: PLEASE READ CAREFULLY
* Spark's [[SafeProjection]] are NOT thread-safe hence cache is scoped
* down to be thread-local to support the multi-threaded executors (like
* [[BoundedInMemoryQueueExecutor]], [[DisruptorExecutor]])
*
* To avoid compiling projections for Merge Into expressions for every record these
* are cached under a key of
* <ol>
@@ -338,68 +316,110 @@
* NOTE: Schema is required b/c this cache is static and might be shared by multiple
* executed statements w/in a single Spark session
*/
private val projectionsCache = Caffeine.newBuilder()
.maximumSize(1024)
.build[(String, Schema), Map[Projection, Projection]]()
private val projectionsCache = ThreadLocal.withInitial(
new Supplier[Cache[(String, Schema), Seq[(Projection, Projection)]]] {
override def get(): Cache[(String, Schema), Seq[(Projection, Projection)]] = {
Caffeine.newBuilder()
.maximumSize(1024)
.build[(String, Schema), Seq[(Projection, Projection)]]()
}
})

/**
* NOTE: PLEASE READ CAREFULLY
* Spark's [[AvroDeserializer]] are NOT thread-safe hence cache is scoped
* down to be thread-local to support the multi-threaded executors (like
* [[BoundedInMemoryQueueExecutor]], [[DisruptorExecutor]])
*/
private val avroDeserializerCache = ThreadLocal.withInitial(
new Supplier[Cache[Schema, HoodieAvroDeserializer]] {
override def get(): Cache[Schema, HoodieAvroDeserializer] =
Caffeine.newBuilder()
.maximumSize(16).build[Schema, HoodieAvroDeserializer]()
}
)

/**
* NOTE: PLEASE READ CAREFULLY
* Spark's [[AvroSerializer]] are NOT thread-safe hence cache is scoped
* down to be thread-local to support the multi-threaded executors (like
* [[BoundedInMemoryQueueExecutor]], [[DisruptorExecutor]])
*/
private val avroSerializerCache = ThreadLocal.withInitial(
new Supplier[Cache[Schema, HoodieAvroSerializer]] {
override def get(): Cache[Schema, HoodieAvroSerializer] =
Caffeine.newBuilder()
.maximumSize(16).build[Schema, HoodieAvroSerializer]()
}
)

private val schemaCache = Caffeine.newBuilder()
.maximumSize(16).build[String, Schema]()

private val mergedSchemaCache = Caffeine.newBuilder()
.maximumSize(16).build[(Schema, Schema), Schema]()

private val avroDeserializerCache = Caffeine.newBuilder()
.maximumSize(16).build[Schema, HoodieAvroDeserializer]()

private val avroSerializerCache = Caffeine.newBuilder()
.maximumSize(16).build[Schema, HoodieAvroSerializer]()

private def parseSchema(schemaStr: String): Schema = {
schemaCache.get(schemaStr,
new Function[String, Schema] {
override def apply(t: String): Schema = new Schema.Parser().parse(t)
})
}

private def getRecordSchema(props: Properties) = {
ValidationUtils.checkArgument(props.containsKey(PAYLOAD_RECORD_AVRO_SCHEMA),
s"Missing ${PAYLOAD_RECORD_AVRO_SCHEMA} property")
parseSchema(props.getProperty(PAYLOAD_RECORD_AVRO_SCHEMA))
}

private def getWriterSchema(props: Properties): Schema = {
ValidationUtils.checkArgument(props.containsKey(HoodieWriteConfig.WRITE_SCHEMA_OVERRIDE.key),
s"Missing ${HoodieWriteConfig.WRITE_SCHEMA_OVERRIDE.key} property")
parseSchema(props.getProperty(HoodieWriteConfig.WRITE_SCHEMA_OVERRIDE.key))
}

private def getAvroDeserializerFor(schema: Schema) = {
avroDeserializerCache.get(schema, new Function[Schema, HoodieAvroDeserializer] {
override def apply(t: Schema): HoodieAvroDeserializer =
sparkAdapter.createAvroDeserializer(schema, convertAvroSchemaToStructType(schema))
})
avroDeserializerCache.get()
.get(schema, new Function[Schema, HoodieAvroDeserializer] {
override def apply(t: Schema): HoodieAvroDeserializer =
sparkAdapter.createAvroDeserializer(schema, convertAvroSchemaToStructType(schema))
})
}

private def getAvroSerializerFor(schema: Schema) = {
avroSerializerCache.get(schema, new Function[Schema, HoodieAvroSerializer] {
override def apply(t: Schema): HoodieAvroSerializer =
sparkAdapter.createAvroSerializer(convertAvroSchemaToStructType(schema), schema, isNullable(schema))
})
avroSerializerCache.get()
.get(schema, new Function[Schema, HoodieAvroSerializer] {
override def apply(t: Schema): HoodieAvroSerializer =
sparkAdapter.createAvroSerializer(convertAvroSchemaToStructType(schema), schema, isNullable(schema))
})
}

/**
* Do the CodeGen for each condition and assignment expression. We cache the resulting
* projections to reduce the compile time for each method call.
*/
private def getEvaluator(serializedConditionAssignments: String,
inputSchema: Schema): Map[Projection, Projection] = {
projectionsCache.get((serializedConditionAssignments, inputSchema),
new Function[(String, Schema), Map[Projection, Projection]] {
override def apply(key: (String, Schema)): Map[Projection, Projection] = {
val (encodedConditionalAssignments, _) = key
val serializedBytes = Base64.getDecoder.decode(encodedConditionalAssignments)
val conditionAssignments = SerDeUtils.toObject(serializedBytes)
.asInstanceOf[Map[Expression, Seq[Expression]]]
conditionAssignments.map {
case (condition, assignments) =>
// NOTE: We reuse Spark's [[Projection]]s infra for actual evaluation of the
// expressions, allowing us to execute arbitrary expression against input
// [[InternalRow]] producing another [[InternalRow]] as an outcome
val conditionEvaluator = GenerateSafeProjection.generate(Seq(condition))
val assignmentEvaluator = GenerateSafeProjection.generate(assignments)

conditionEvaluator -> assignmentEvaluator
inputSchema: Schema): Seq[(Projection, Projection)] = {
projectionsCache.get()
.get((serializedConditionAssignments, inputSchema),
new Function[(String, Schema), Seq[(Projection, Projection)]] {
override def apply(key: (String, Schema)): Seq[(Projection, Projection)] = {
val (encodedConditionalAssignments, _) = key
val serializedBytes = Base64.getDecoder.decode(encodedConditionalAssignments)
val conditionAssignments = SerDeUtils.toObject(serializedBytes)
.asInstanceOf[Map[Expression, Seq[Expression]]]
conditionAssignments.toSeq.map {
case (condition, assignments) =>
// NOTE: We reuse Spark's [[Projection]]s infra for actual evaluation of the
// expressions, allowing us to execute arbitrary expression against input
// [[InternalRow]] producing another [[InternalRow]] as an outcome
val conditionEvaluator = SafeProjection.create(Seq(condition))
val assignmentEvaluator = SafeProjection.create(assignments)

conditionEvaluator -> assignmentEvaluator
}
}
}
})
})
}

private def getMergedSchema(source: Schema, target: Schema): Schema = {
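The hunk above replaces direct GenerateSafeProjection.generate calls with the SafeProjection.create facade and caches the resulting (condition, assignment) projection pairs per thread. The following is a rough standalone sketch of how such a pair evaluates a record; the two-column row layout, the literal condition and the ProjectionPairSketch wrapper are invented for illustration and are not taken from the Hudi code.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, EqualTo, Literal, SafeProjection}
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object ProjectionPairSketch {
  def main(args: Array[String]): Unit = {
    // Assumed input row layout: (id: Int, name: String)
    val idRef   = BoundReference(0, IntegerType, nullable = false)
    val nameRef = BoundReference(1, StringType, nullable = true)

    // "Condition" projection: a single boolean column, here `id = 1`
    val conditionEvaluator = SafeProjection.create(Seq(EqualTo(idRef, Literal(1))))
    // "Assignment" projection: the columns emitted when the condition holds
    val assignmentEvaluator = SafeProjection.create(Seq(idRef, nameRef))

    val row = InternalRow(1, UTF8String.fromString("a"))
    if (conditionEvaluator(row).getBoolean(0)) {
      // An object-backed InternalRow carrying the assigned values
      val result = assignmentEvaluator(row)
      println(s"assigned ${result.numFields} fields")
    }
  }
}

In ExpressionPayload the expressions themselves arrive Base64-encoded in the payload properties and are deserialized once per (expression text, schema) key, so the compiled projections are reused for every record handled by the same thread.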
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.types.{DataType, StructType}

/**
* A projection that could turn UnsafeRow into GenericInternalRow
*
* NOTE: PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
* This code is borrowed from Spark 3.1.x
* This code is borrowed, to fill in the gaps of Spark 2.x
*/
object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expression], Projection] {

override protected def createCodeGeneratedObject(in: Seq[Expression]): Projection = {
GenerateSafeProjection.generate(in)
}

override protected def createInterpretedObject(in: Seq[Expression]): Projection = {
throw new UnsupportedOperationException("Interpreted safe projection is not supported for Spark 2.x!")
}

/**
* Returns a SafeProjection for given StructType.
*/
def create(schema: StructType): Projection = create(schema.fields.map(_.dataType))

/**
* Returns a SafeProjection for given Array of DataTypes.
*/
def create(fields: Array[DataType]): Projection = {
createObject(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
}

/**
* Returns a SafeProjection for given sequence of Expressions (bounded).
*/
def create(exprs: Seq[Expression]): Projection = {
createObject(exprs)
}

/**
* Returns a SafeProjection for given sequence of Expressions, which will be bound to
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
create(bindReferences(exprs, inputSchema))
}

/**
* A helper function to bind given expressions to an input schema.
*/
private def bindReferences[A <: Expression](
expressions: Seq[A],
input: AttributeSeq): Seq[A] = {
expressions.map(BindReferences.bindReference(_, input))
}
}
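A short usage sketch for the shim above: build an UnsafeRow with Spark's UnsafeProjection and turn it back into an object-backed row with SafeProjection. The sample schema, values and the SafeProjectionUsageSketch wrapper are made up for illustration.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SafeProjection, UnsafeProjection}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object SafeProjectionUsageSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType, nullable = true)))

    // Pack a row into Spark's binary UnsafeRow format ...
    val toUnsafe = UnsafeProjection.create(schema)
    val unsafeRow = toUnsafe(InternalRow(42, UTF8String.fromString("hudi")))

    // ... and project it back into a GenericInternalRow backed by Java objects.
    val toSafe = SafeProjection.create(schema)
    val safeRow = toSafe(unsafeRow)

    println(s"${safeRow.getInt(0)}, ${safeRow.getUTF8String(1)}")
  }
}

Note that on Spark 2.x this shim always code-generates the projection; the interpreted fallback deliberately throws, as shown in createInterpretedObject above.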
