diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py index 6cebafd94e761..ce85de9ba9208 100644 --- a/flink-python/pyflink/datastream/data_stream.py +++ b/flink-python/pyflink/datastream/data_stream.py @@ -641,6 +641,8 @@ def execute_and_collect(self, job_execution_name: str = None, limit: int = None) :param job_execution_name: The name of the job execution. :param limit: The limit for the collected elements. """ + JPythonConfigUtil = get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil + JPythonConfigUtil.configPythonOperator(self._j_data_stream.getExecutionEnvironment()) if job_execution_name is None and limit is None: return CloseableIterator(self._j_data_stream.executeAndCollect(), self.get_type()) elif job_execution_name is not None and limit is None: diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py index 9d0543acc3634..994cd60a91e17 100644 --- a/flink-python/pyflink/datastream/tests/test_data_stream.py +++ b/flink-python/pyflink/datastream/tests/test_data_stream.py @@ -387,11 +387,11 @@ def test_execute_and_collect(self): decimal.Decimal('2000000000000000000.061111111111111' '11111111111111'))] expected = test_data - ds = self.env.from_collection(test_data) + ds = self.env.from_collection(test_data).map(lambda a: a) with ds.execute_and_collect() as results: - actual = [] - for result in results: - actual.append(result) + actual = [result for result in results] + actual.sort() + expected.sort() self.assertEqual(expected, actual) def test_key_by_map(self): @@ -942,7 +942,7 @@ def test_partition_custom(self): expected_num_partitions = 5 def my_partitioner(key, num_partitions): - assert expected_num_partitions, num_partitions + assert expected_num_partitions == num_partitions return key % num_partitions partitioned_stream = ds.map(lambda x: x, output_type=Types.ROW([Types.STRING(), diff --git a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java index 6681c1f0fa456..7efaf44a19026 100644 --- a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java +++ b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java @@ -30,20 +30,16 @@ import org.apache.flink.python.PythonConfig; import org.apache.flink.python.PythonOptions; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.streaming.api.graph.StreamEdge; import org.apache.flink.streaming.api.graph.StreamGraph; -import org.apache.flink.streaming.api.graph.StreamNode; import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; -import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator; import org.apache.flink.streaming.api.operators.python.OneInputPythonFunctionOperator; -import org.apache.flink.streaming.api.operators.python.PythonKeyedProcessOperator; import org.apache.flink.streaming.api.operators.python.PythonPartitionCustomOperator; import org.apache.flink.streaming.api.operators.python.PythonTimestampsAndWatermarksOperator; -import org.apache.flink.streaming.api.operators.python.TwoInputPythonFunctionOperator; import org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation; import org.apache.flink.streaming.api.transformations.OneInputTransformation; +import org.apache.flink.streaming.api.transformations.PartitionTransformation; import org.apache.flink.streaming.api.transformations.TwoInputTransformation; import org.apache.flink.streaming.api.transformations.WithBoundedness; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; @@ -53,7 +49,6 @@ import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.util.Collection; import java.util.List; /** @@ -102,37 +97,6 @@ public static Configuration getEnvironmentConfig(StreamExecutionEnvironment env) return (Configuration) getConfigurationMethod.invoke(env); } - /** - * Configure the {@link OneInputPythonFunctionOperator} to be chained with the - * upstream/downstream operator by setting their parallelism, slot sharing group, co-location - * group to be the same, and applying a {@link ForwardPartitioner}. 1. operator with name - * "_keyed_stream_values_operator" should align with its downstream operator. 2. operator with - * name "_stream_key_by_map_operator" should align with its upstream operator. - */ - private static void alignStreamNode(StreamNode streamNode, StreamGraph streamGraph) { - if (streamNode.getOperatorName().equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) { - StreamEdge downStreamEdge = streamNode.getOutEdges().get(0); - StreamNode downStreamNode = streamGraph.getStreamNode(downStreamEdge.getTargetId()); - chainStreamNode(downStreamEdge, streamNode, downStreamNode); - downStreamEdge.setPartitioner(new ForwardPartitioner()); - } - - if (streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME) - || streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) { - StreamEdge upStreamEdge = streamNode.getInEdges().get(0); - StreamNode upStreamNode = streamGraph.getStreamNode(upStreamEdge.getSourceId()); - chainStreamNode(upStreamEdge, streamNode, upStreamNode); - } - } - - private static void chainStreamNode( - StreamEdge streamEdge, StreamNode firstStream, StreamNode secondStream) { - streamEdge.setPartitioner(new ForwardPartitioner<>()); - firstStream.setParallelism(secondStream.getParallelism()); - firstStream.setCoLocationGroup(secondStream.getCoLocationGroup()); - firstStream.setSlotSharingGroup(secondStream.getSlotSharingGroup()); - } - /** Set Python Operator Use Managed Memory. */ public static void declareManagedMemory( Transformation transformation, @@ -144,16 +108,6 @@ public static void declareManagedMemory( } } - private static void declareManagedMemory(Transformation transformation) { - if (isPythonOperator(transformation)) { - transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); - } - List> inputTransformations = transformation.getInputs(); - for (Transformation inputTransformation : inputTransformations) { - declareManagedMemory(inputTransformation); - } - } - /** * Generate a {@link StreamGraph} for transformations maintained by current {@link * StreamExecutionEnvironment}, and reset the merged env configurations with dependencies to @@ -165,56 +119,159 @@ public static StreamGraph generateStreamGraphWithDependencies( StreamExecutionEnvironment env, boolean clearTransformations) throws IllegalAccessException, NoSuchMethodException, InvocationTargetException, NoSuchFieldException { - Configuration mergedConfig = getEnvConfigWithDependencies(env); - - boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig); - - if (mergedConfig.getBoolean(PythonOptions.USE_MANAGED_MEMORY)) { - Field transformationsField = - StreamExecutionEnvironment.class.getDeclaredField("transformations"); - transformationsField.setAccessible(true); - for (Transformation transform : - (List>) transformationsField.get(env)) { - if (isPythonOperator(transform)) { - transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); - } - } - } + configPythonOperator(env); String jobName = getEnvironmentConfig(env) .getString( PipelineOptions.NAME, StreamExecutionEnvironment.DEFAULT_JOB_NAME); - StreamGraph streamGraph = env.getStreamGraph(jobName, clearTransformations); - Collection streamNodes = streamGraph.getStreamNodes(); - for (StreamNode streamNode : streamNodes) { - alignStreamNode(streamNode, streamGraph); - StreamOperatorFactory streamOperatorFactory = streamNode.getOperatorFactory(); - if (streamOperatorFactory instanceof SimpleOperatorFactory) { - StreamOperator streamOperator = - ((SimpleOperatorFactory) streamOperatorFactory).getOperator(); - if ((streamOperator instanceof OneInputPythonFunctionOperator) - || (streamOperator instanceof TwoInputPythonFunctionOperator) - || (streamOperator instanceof PythonKeyedProcessOperator)) { - AbstractPythonFunctionOperator pythonFunctionOperator = - (AbstractPythonFunctionOperator) streamOperator; + return env.getStreamGraph(jobName, clearTransformations); + } + + @SuppressWarnings("unchecked") + public static void configPythonOperator(StreamExecutionEnvironment env) + throws IllegalAccessException, NoSuchMethodException, InvocationTargetException, + NoSuchFieldException { + Configuration mergedConfig = getEnvConfigWithDependencies(env); + boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig); + + Field transformationsField = + StreamExecutionEnvironment.class.getDeclaredField("transformations"); + transformationsField.setAccessible(true); + List> transformations = + (List>) transformationsField.get(env); + for (Transformation transformation : transformations) { + alignTransformation(transformation); + if (isPythonOperator(transformation)) { + transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); + AbstractPythonFunctionOperator pythonFunctionOperator = + getPythonOperator(transformation); + if (pythonFunctionOperator != null) { Configuration oldConfig = pythonFunctionOperator.getPythonConfig().getMergedConfig(); pythonFunctionOperator.setPythonConfig( generateNewPythonConfig(oldConfig, mergedConfig)); - if (streamOperator instanceof PythonTimestampsAndWatermarksOperator) { - ((PythonTimestampsAndWatermarksOperator) streamOperator) + if (pythonFunctionOperator instanceof PythonTimestampsAndWatermarksOperator) { + ((PythonTimestampsAndWatermarksOperator) pythonFunctionOperator) .configureEmitProgressiveWatermarks(!executedInBatchMode); } } } } - setStreamPartitionCustomOperatorNumPartitions(streamNodes, streamGraph); + // Update the numPartitions of PartitionCustomOperator after aligned all + // operators. + for (Transformation transformation : transformations) { + Transformation upTransformation = transformation.getInputs().get(0); + if (upTransformation instanceof PartitionTransformation) { + upTransformation = upTransformation.getInputs().get(0); + } + AbstractPythonFunctionOperator pythonFunctionOperator = + getPythonOperator(upTransformation); + if (pythonFunctionOperator instanceof PythonPartitionCustomOperator) { + PythonPartitionCustomOperator partitionCustomFunctionOperator = + (PythonPartitionCustomOperator) pythonFunctionOperator; - return streamGraph; + partitionCustomFunctionOperator.setNumPartitions(transformation.getParallelism()); + } + } + } + + public static Configuration getMergedConfig( + StreamExecutionEnvironment env, TableConfig tableConfig) { + try { + Configuration config = new Configuration(getEnvironmentConfig(env)); + PythonDependencyUtils.merge(config, tableConfig.getConfiguration()); + Configuration mergedConfig = + PythonDependencyUtils.configurePythonDependencies(env.getCachedFiles(), config); + mergedConfig.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId()); + return mergedConfig; + } catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { + throw new TableException("Method getMergedConfig failed.", e); + } + } + + @SuppressWarnings("unchecked") + public static Configuration getMergedConfig(ExecutionEnvironment env, TableConfig tableConfig) { + try { + Field field = ExecutionEnvironment.class.getDeclaredField("cacheFile"); + field.setAccessible(true); + Configuration config = new Configuration(env.getConfiguration()); + PythonDependencyUtils.merge(config, tableConfig.getConfiguration()); + Configuration mergedConfig = + PythonDependencyUtils.configurePythonDependencies( + (List>) + field.get(env), + config); + mergedConfig.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId()); + return mergedConfig; + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new TableException("Method getMergedConfig failed.", e); + } + } + + /** + * Configure the {@link OneInputPythonFunctionOperator} to be chained with the + * upstream/downstream operator by setting their parallelism, slot sharing group, co-location + * group to be the same, and applying a {@link ForwardPartitioner}. 1. operator with name + * "_keyed_stream_values_operator" should align with its downstream operator. 2. operator with + * name "_stream_key_by_map_operator" should align with its upstream operator. + */ + private static void alignTransformation(Transformation transformation) + throws NoSuchFieldException, IllegalAccessException { + String transformName = transformation.getName(); + Transformation upTransform = transformation.getInputs().get(0); + String upTransformName = upTransform.getName(); + if (upTransformName.equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) { + chainTransformation(upTransform, transformation); + configForwardPartitioner(upTransform, transformation); + } + if (transformName.equals(STREAM_KEY_BY_MAP_OPERATOR_NAME) + || transformName.equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) { + + chainTransformation(transformation, upTransform); + configForwardPartitioner(upTransform, transformation); + } + } + + private static void chainTransformation( + Transformation firstTransformation, Transformation secondTransformation) { + firstTransformation.setSlotSharingGroup(secondTransformation.getSlotSharingGroup()); + firstTransformation.setCoLocationGroupKey(secondTransformation.getCoLocationGroupKey()); + firstTransformation.setParallelism(secondTransformation.getParallelism()); + } + + private static void configForwardPartitioner( + Transformation upTransformation, Transformation transformation) + throws IllegalAccessException, NoSuchFieldException { + // set ForwardPartitioner + PartitionTransformation partitionTransform = + new PartitionTransformation<>(upTransformation, new ForwardPartitioner<>()); + Field inputTransformationField = transformation.getClass().getDeclaredField("input"); + inputTransformationField.setAccessible(true); + inputTransformationField.set(transformation, partitionTransform); + } + + private static AbstractPythonFunctionOperator getPythonOperator( + Transformation transformation) { + StreamOperatorFactory operatorFactory = null; + if (transformation instanceof OneInputTransformation) { + operatorFactory = ((OneInputTransformation) transformation).getOperatorFactory(); + } else if (transformation instanceof TwoInputTransformation) { + operatorFactory = ((TwoInputTransformation) transformation).getOperatorFactory(); + } else if (transformation instanceof AbstractMultipleInputTransformation) { + operatorFactory = + ((AbstractMultipleInputTransformation) transformation).getOperatorFactory(); + } + if (operatorFactory instanceof SimpleOperatorFactory + && ((SimpleOperatorFactory) operatorFactory).getOperator() + instanceof AbstractPythonFunctionOperator) { + return (AbstractPythonFunctionOperator) + ((SimpleOperatorFactory) operatorFactory).getOperator(); + } + return null; } private static boolean isPythonOperator(StreamOperatorFactory streamOperatorFactory) { @@ -239,27 +296,6 @@ private static boolean isPythonOperator(Transformation transform) { } } - private static void setStreamPartitionCustomOperatorNumPartitions( - Collection streamNodes, StreamGraph streamGraph) { - for (StreamNode streamNode : streamNodes) { - StreamOperatorFactory streamOperatorFactory = streamNode.getOperatorFactory(); - if (streamOperatorFactory instanceof SimpleOperatorFactory) { - StreamOperator streamOperator = - ((SimpleOperatorFactory) streamOperatorFactory).getOperator(); - if (streamOperator instanceof PythonPartitionCustomOperator) { - PythonPartitionCustomOperator partitionCustomFunctionOperator = - (PythonPartitionCustomOperator) streamOperator; - // Update the numPartitions of PartitionCustomOperator after aligned all - // operators. - partitionCustomFunctionOperator.setNumPartitions( - streamGraph - .getStreamNode(streamNode.getOutEdges().get(0).getTargetId()) - .getParallelism()); - } - } - } - } - /** * Generator a new {@link PythonConfig} with the combined config which is derived from * oldConfig. @@ -295,36 +331,13 @@ private static boolean isExecuteInBatchMode( return !existsUnboundedSource; } - public static Configuration getMergedConfig( - StreamExecutionEnvironment env, TableConfig tableConfig) { - try { - Configuration config = new Configuration(getEnvironmentConfig(env)); - PythonDependencyUtils.merge(config, tableConfig.getConfiguration()); - Configuration mergedConfig = - PythonDependencyUtils.configurePythonDependencies(env.getCachedFiles(), config); - mergedConfig.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId()); - return mergedConfig; - } catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { - throw new TableException("Method getMergedConfig failed.", e); + private static void declareManagedMemory(Transformation transformation) { + if (isPythonOperator(transformation)) { + transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } - } - - @SuppressWarnings("unchecked") - public static Configuration getMergedConfig(ExecutionEnvironment env, TableConfig tableConfig) { - try { - Field field = ExecutionEnvironment.class.getDeclaredField("cacheFile"); - field.setAccessible(true); - Configuration config = new Configuration(env.getConfiguration()); - PythonDependencyUtils.merge(config, tableConfig.getConfiguration()); - Configuration mergedConfig = - PythonDependencyUtils.configurePythonDependencies( - (List>) - field.get(env), - config); - mergedConfig.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId()); - return mergedConfig; - } catch (NoSuchFieldException | IllegalAccessException e) { - throw new TableException("Method getMergedConfig failed.", e); + List> inputTransformations = transformation.getInputs(); + for (Transformation inputTransformation : inputTransformations) { + declareManagedMemory(inputTransformation); } } }