[BEAM-7882] Invoke Spark API incompatible methods by reflection
It also adds Spark 3-specific versions of the methods
iemejia committed Aug 17, 2019
1 parent 96abacb commit c662f28
Showing 1 changed file with 58 additions and 10 deletions.
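Both changed methods follow the same pattern: check the running Spark version through the context (version().startsWith("3")), look up the overload that exists in that version with getDeclaredMethod, and invoke it reflectively so the same compiled code works against Spark 2 and Spark 3. Below is a minimal sketch of that pattern with the same exception handling as the diff; the helper class and its signature are illustrative and not part of the commit:

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

// Hypothetical helper, for illustration only; SparkCompat in the diff inlines this logic at each call site.
final class ReflectiveCall {
  @SuppressWarnings("unchecked")
  static <T> T invoke(Object target, String methodName, Class<?>[] parameterTypes, Object... args) {
    try {
      // Resolve whichever overload the Spark version on the classpath provides, then call it.
      Method method = target.getClass().getDeclaredMethod(methodName, parameterTypes);
      return (T) method.invoke(target, args);
    } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
      throw new RuntimeException("Error invoking " + methodName, e);
    }
  }
}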
@@ -17,13 +17,16 @@
 */
package org.apache.beam.runners.spark.util;

+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.runners.spark.translation.SparkCombineFn;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
@@ -39,7 +42,25 @@ public class SparkCompat {
 */
  public static <T> JavaDStream<WindowedValue<T>> joinStreams(
      JavaStreamingContext streamingContext, List<JavaDStream<WindowedValue<T>>> dStreams) {
-    return streamingContext.union(dStreams.remove(0), dStreams);
+    try {
+      if (streamingContext.sparkContext().version().startsWith("3")) {
+        // This invokes by reflection the equivalent of:
+        // return streamingContext.union(
+        //     JavaConverters.asScalaIteratorConverter(dStreams.iterator()).asScala().toSeq());
+        Method method = streamingContext.getClass().getDeclaredMethod("union", JavaDStream[].class);
+        Object result =
+            method.invoke(streamingContext, new Object[] {dStreams.toArray(new JavaDStream[0])});
+        return (JavaDStream<WindowedValue<T>>) result;
+      }
+      // This invokes by reflection the equivalent of:
+      // return streamingContext.union(dStreams.remove(0), dStreams);
+      Method method =
+          streamingContext.getClass().getDeclaredMethod("union", JavaDStream.class, List.class);
+      Object result = method.invoke(streamingContext, dStreams.remove(0), dStreams);
+      return (JavaDStream<WindowedValue<T>>) result;
+    } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+      throw new RuntimeException("Error invoking Spark union", e);
+    }
}
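As the reflective lookups above indicate, Spark 2.x exposes JavaStreamingContext.union(JavaDStream, List) while Spark 3.x only offers a varargs-style overload taking a JavaDStream[], so the matching overload has to be resolved at runtime. A hypothetical caller (names are illustrative, not part of this commit) uses the helper the same way on either version; note that the Spark 2 branch mutates the list via dStreams.remove(0), so a mutable copy is passed:

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.runners.spark.util.SparkCompat;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

// Hypothetical usage sketch: merge per-source streams into one, independent of the Spark version.
final class JoinStreamsUsage {
  static <T> JavaDStream<WindowedValue<T>> merge(
      JavaStreamingContext streamingContext, List<JavaDStream<WindowedValue<T>>> streams) {
    // Copy into a mutable list because the Spark 2 branch calls dStreams.remove(0).
    return SparkCompat.joinStreams(streamingContext, new ArrayList<>(streams));
  }
}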

/**
@@ -52,14 +73,41 @@ public static <K, InputT, AccumT, OutputT> JavaPairRDD<K, WindowedValue<OutputT>
      JavaPairRDD<K, SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>>
          accumulatePerKey,
      SparkCombineFn<KV<K, InputT>, InputT, AccumT, OutputT> sparkCombineFn) {
-    Function<
-            SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>,
-            Iterable<WindowedValue<OutputT>>>
-        flatMapFunction =
-            windowedAccumulator ->
-                sparkCombineFn
-                    .extractOutputStream(windowedAccumulator)
-                    .collect(Collectors.toList());
-    return accumulatePerKey.flatMapValues(flatMapFunction);
+    try {
+      if (accumulatePerKey.context().version().startsWith("3")) {
+        FlatMapFunction<
+                SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>,
+                WindowedValue<OutputT>>
+            flatMapFunction =
+                (FlatMapFunction<
+                        SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>,
+                        WindowedValue<OutputT>>)
+                    windowedAccumulator ->
+                        sparkCombineFn.extractOutputStream(windowedAccumulator).iterator();
+        // This invokes by reflection the equivalent of:
+        // return accumulatePerKey.flatMapValues(flatMapFunction);
+        Method method =
+            accumulatePerKey.getClass().getDeclaredMethod("flatMapValues", FlatMapFunction.class);
+        Object result = method.invoke(accumulatePerKey, flatMapFunction);
+        return (JavaPairRDD<K, WindowedValue<OutputT>>) result;
+      }
+
+      Function<
+              SparkCombineFn.WindowedAccumulator<KV<K, InputT>, InputT, AccumT, ?>,
+              Iterable<WindowedValue<OutputT>>>
+          flatMapFunction =
+              windowedAccumulator ->
+                  sparkCombineFn
+                      .extractOutputStream(windowedAccumulator)
+                      .collect(Collectors.toList());
+      // This invokes by reflection the equivalent of:
+      // return accumulatePerKey.flatMapValues(flatMapFunction);
+      Method method =
+          accumulatePerKey.getClass().getDeclaredMethod("flatMapValues", Function.class);
+      Object result = method.invoke(accumulatePerKey, flatMapFunction);
+      return (JavaPairRDD<K, WindowedValue<OutputT>>) result;
+    } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+      throw new RuntimeException("Error invoking Spark flatMapValues", e);
+    }
}
}
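The second method handles the same kind of incompatibility on JavaPairRDD.flatMapValues: the overload resolved on Spark 2.x takes a Function returning an Iterable, while the one resolved on Spark 3.x takes a FlatMapFunction whose call() returns an Iterator, which is why two differently shaped lambdas are built above. A small sketch of the two value-function shapes, using toy string splitting purely for illustration:

import java.util.Arrays;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;

// Illustrative only: the lambda shapes expected by the Spark 2 and Spark 3 flatMapValues overloads.
final class FlatMapValuesShapes {
  // Spark 2.x style: Function<V, Iterable<U>>, returning an Iterable of output values.
  static final Function<String, Iterable<String>> ITERABLE_STYLE =
      value -> Arrays.asList(value.split(","));

  // Spark 3.x style: FlatMapFunction<V, U>, whose call() returns an Iterator of output values.
  static final FlatMapFunction<String, String> ITERATOR_STYLE =
      value -> Arrays.asList(value.split(",")).iterator();
}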
