Skip to content

Commit

Permalink
[FLINK-23475][runtime][checkpoint] Supports repartition partly finish…
Browse files Browse the repository at this point in the history
…ed broadcast states (apache#16714)

* [FLINK-23475][runtime][checkpoint] Support repartitioning partly finished broadcast states

* Fix the judgment of unfinished subtasks and correct a wrong comment / variable name

* Fix the tests

* Add todo
  • Loading branch information
gaoyunhaii committed Aug 12, 2021
1 parent 58ff344 commit 1c563a0
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -69,14 +70,23 @@ public List<List<OperatorStateHandle>> repartitionState(
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
unionStates = collectUnionStates(previousParallelSubtaskStates);

if (unionStates.isEmpty()) {
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
partlyFinishedBroadcastStates =
collectPartlyFinishedBroadcastStates(previousParallelSubtaskStates);

if (unionStates.isEmpty() && partlyFinishedBroadcastStates.isEmpty()) {
return previousParallelSubtaskStates;
}

// Initialize
mergeMapList = initMergeMapList(previousParallelSubtaskStates);

repartitionUnionState(unionStates, mergeMapList);

// TODO: Currently if some tasks are finished, we rescale all the
// remaining state. A better solution would be to leave the non-empty
// subtask states untouched and only fix the empty ones.
repartitionBroadcastState(partlyFinishedBroadcastStates, mergeMapList);
} else {

// Reorganize: group by (State Name -> StreamStateHandle + Offsets)
Expand Down Expand Up @@ -123,14 +133,31 @@ private List<Map<StreamStateHandle, OperatorStateHandle>> initMergeMapList(
return mergeMapList;
}

/** Collects all union-mode states from the given parallel subtask states. */
private Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
        collectUnionStates(List<List<OperatorStateHandle>> parallelSubtaskStates) {
    Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> collected =
            new HashMap<>();
    collectStates(parallelSubtaskStates, OperatorStateHandle.Mode.UNION)
            .forEach((stateName, stateEntry) -> collected.put(stateName, stateEntry.entries));
    return collected;
}

Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
unionStates = new HashMap<>(parallelSubtaskStates.size());
/**
 * Collects the broadcast-mode states that only a strict, non-empty subset of the subtasks has
 * reported (i.e. some subtasks finished without reporting them).
 */
private Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>
        collectPartlyFinishedBroadcastStates(
                List<List<OperatorStateHandle>> parallelSubtaskStates) {
    Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> collected =
            new HashMap<>();
    collectStates(parallelSubtaskStates, OperatorStateHandle.Mode.BROADCAST)
            .forEach(
                    (stateName, stateEntry) -> {
                        if (stateEntry.isPartiallyReported()) {
                            collected.put(stateName, stateEntry.entries);
                        }
                    });
    return collected;
}

/** Collect the states from given parallelSubtaskStates with the specific {@code mode}. */
private Map<String, StateEntry> collectStates(
List<List<OperatorStateHandle>> parallelSubtaskStates, OperatorStateHandle.Mode mode) {

Map<String, StateEntry> states = new HashMap<>(parallelSubtaskStates.size());

for (List<OperatorStateHandle> subTaskState : parallelSubtaskStates) {
for (int i = 0; i < parallelSubtaskStates.size(); ++i) {
final int subtaskIndex = i;
List<OperatorStateHandle> subTaskState = parallelSubtaskStates.get(i);
for (OperatorStateHandle operatorStateHandle : subTaskState) {
if (operatorStateHandle == null) {
continue;
Expand All @@ -141,36 +168,28 @@ private List<Map<StreamStateHandle, OperatorStateHandle>> initMergeMapList(
operatorStateHandle.getStateNameToPartitionOffsets().entrySet();

partitionOffsetEntries.stream()
.filter(
entry ->
entry.getValue()
.getDistributionMode()
.equals(OperatorStateHandle.Mode.UNION))
.filter(entry -> entry.getValue().getDistributionMode().equals(mode))
.forEach(
entry -> {
List<
Tuple2<
StreamStateHandle,
OperatorStateHandle.StateMetaInfo>>
stateLocations =
unionStates.computeIfAbsent(
entry.getKey(),
k ->
new ArrayList<>(
parallelSubtaskStates
.size()
* partitionOffsetEntries
.size()));

stateLocations.add(
StateEntry stateEntry =
states.computeIfAbsent(
entry.getKey(),
k ->
new StateEntry(
parallelSubtaskStates.size()
* partitionOffsetEntries
.size(),
parallelSubtaskStates.size()));
stateEntry.addEntry(
subtaskIndex,
Tuple2.of(
operatorStateHandle.getDelegateStateHandle(),
entry.getValue()));
});
}
}

return unionStates;
return states;
}

/** Group by the different named states. */
Expand Down Expand Up @@ -459,4 +478,26 @@ private static final class GroupByStateNameResults {
return byMode.get(mode);
}
}

/**
 * Bookkeeping for one named operator state while scanning all subtasks: the collected
 * (state handle, meta info) pairs plus the set of subtask indices that reported the state.
 */
private static final class StateEntry {

    /** All (state handle, meta info) pairs collected for this state name. */
    final List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> entries;

    /** Bit i is set iff subtask i has reported this state. */
    final BitSet reportedSubtaskIndices;

    /**
     * Total number of subtasks. Kept explicitly because {@link BitSet#size()} is the
     * allocated capacity rounded up to a multiple of 64 bits, not the parallelism.
     */
    final int parallelism;

    public StateEntry(int estimatedEntrySize, int parallelism) {
        this.entries = new ArrayList<>(estimatedEntrySize);
        this.reportedSubtaskIndices = new BitSet(parallelism);
        this.parallelism = parallelism;
    }

    void addEntry(
            int subtaskIndex,
            Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> entry) {
        this.entries.add(entry);
        reportedSubtaskIndices.set(subtaskIndex);
    }

    /**
     * Returns true iff a non-empty strict subset of the subtasks reported this state.
     *
     * <p>Note: comparing against {@link BitSet#size()} would be wrong — {@code new
     * BitSet(4).size()} is 64, so a state reported by all 4 subtasks would still be
     * classified as partially reported. Compare against the actual parallelism instead.
     */
    boolean isPartiallyReported() {
        int reported = reportedSubtaskIndices.cardinality();
        return reported > 0 && reported < parallelism;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,57 @@ public void testRepartitionBroadcastState() {
verifyOneKindPartitionableStateRescale(operatorState, operatorID);
}

@Test
public void testRepartitionBroadcastStateWithNullSubtaskState() {
    OperatorID operatorID = new OperatorID();
    OperatorState operatorState = new OperatorState(operatorID, 2, 4);

    // Subtask 0 is the only one reporting state; subtask 1 stays unreported (null).
    Map<String, OperatorStateHandle.StateMetaInfo> broadcastMetaInfo = new HashMap<>(2);
    broadcastMetaInfo.put(
            "t-5",
            new OperatorStateHandle.StateMetaInfo(
                    new long[] {0, 10, 20}, OperatorStateHandle.Mode.BROADCAST));
    broadcastMetaInfo.put(
            "t-6",
            new OperatorStateHandle.StateMetaInfo(
                    new long[] {30, 40, 50}, OperatorStateHandle.Mode.BROADCAST));

    OperatorStateHandle broadcastHandle =
            new OperatorStreamStateHandle(
                    broadcastMetaInfo, new ByteStreamStateHandle("test1", new byte[60]));
    operatorState.putState(
            0, OperatorSubtaskState.builder().setManagedOperatorState(broadcastHandle).build());

    verifyOneKindPartitionableStateRescale(operatorState, operatorID);
}

@Test
public void testRepartitionBroadcastStateWithEmptySubtaskState() {
    OperatorID operatorID = new OperatorID();
    OperatorState operatorState = new OperatorState(operatorID, 2, 4);

    // Subtask 0 reports the broadcast states; subtask 1 reports an empty snapshot below.
    Map<String, OperatorStateHandle.StateMetaInfo> broadcastMetaInfo = new HashMap<>(2);
    broadcastMetaInfo.put(
            "t-5",
            new OperatorStateHandle.StateMetaInfo(
                    new long[] {0, 10, 20}, OperatorStateHandle.Mode.BROADCAST));
    broadcastMetaInfo.put(
            "t-6",
            new OperatorStateHandle.StateMetaInfo(
                    new long[] {30, 40, 50}, OperatorStateHandle.Mode.BROADCAST));

    OperatorStateHandle broadcastHandle =
            new OperatorStreamStateHandle(
                    broadcastMetaInfo, new ByteStreamStateHandle("test1", new byte[60]));
    operatorState.putState(
            0, OperatorSubtaskState.builder().setManagedOperatorState(broadcastHandle).build());

    // Subtask 1 snapshotted successfully but had nothing to report.
    operatorState.putState(1, OperatorSubtaskState.builder().build());

    verifyOneKindPartitionableStateRescale(operatorState, operatorID);
}

/** Verify repartition logic on partitionable states with all modes. */
@Test
public void testReDistributeCombinedPartitionableStates() {
Expand Down

0 comments on commit 1c563a0

Please sign in to comment.