Adding support for DLQ for ZetaSQL (#25426)
* Adding support for DLQ for ZetaSQL

* fixed issue when not all fields are selected

* fixup

* fix spotless

* fix test
pabloem committed Feb 13, 2023
1 parent 849692e commit 299be58
Showing 4 changed files with 177 additions and 31 deletions.
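For context on what the change enables: when evaluation of a ZetaSQL expression throws for a given row (an arithmetic overflow, say), the row is diverted to a dead-letter output instead of failing the pipeline. A minimal usage sketch follows. It assumes the SqlTransform.withErrorsTransformer hook that the Calcite (BeamCalcRel) path already exposes and that this commit wires up for the ZetaSQL planner; the input data and the MapElements dead-letter sink are illustrative only.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.extensions.sql.zetasql.ZetaSQLQueryPlanner;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;

public class ZetaSqlDlqSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();

    Schema schema = Schema.builder().addInt64Field("Key").addStringField("Value").build();
    PCollection<Row> input =
        p.apply(
            Create.of(
                    Row.withSchema(schema).addValues(14L, "a").build(),
                    Row.withSchema(schema).addValues(15L, "b").build())
                .withRowSchema(schema));

    input.apply(
        SqlTransform.query(
                "SELECT Key * 7777 * 7777 * 77777 * 7777777 * 7777777777 AS num FROM PCOLLECTION")
            .withQueryPlannerClass(ZetaSQLQueryPlanner.class)
            // Assumed entry point: rows whose expression evaluation throws are
            // handed to this transform instead of failing the bundle.
            .withErrorsTransformer(
                MapElements.into(TypeDescriptors.strings())
                    .via((Row errRow) -> "dead-lettered: " + errRow)));

    p.run().waitUntilFinish();
  }
}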
File 1 of 4:
@@ -142,7 +142,10 @@ public static Coder<Row> generate(Schema schema) {
       }
       // There should never be duplicate encoding positions.
       Preconditions.checkState(
-          schema.getFieldCount() == Arrays.stream(encodingPosToRowIndex).distinct().count());
+          schema.getFieldCount() == Arrays.stream(encodingPosToRowIndex).distinct().count(),
+          "The input schema (%s) and map for position encoding (%s) do not match.",
+          schema.getFields(),
+          encodingPosToRowIndex);

       // Component coders are ordered by encoding position, but may encode a field with a different
       // row index.
@@ -311,7 +314,12 @@ static void encodeDelegate(
         boolean hasNullableFields)
         throws IOException {
       checkState(value.getFieldCount() == value.getSchema().getFieldCount());
-      checkState(encodingPosToIndex.length == value.getFieldCount());
+      checkState(
+          encodingPosToIndex.length == value.getFieldCount(),
+          "Unable to encode row. Expected %s values, but row has %s%s",
+          encodingPosToIndex.length,
+          value.getFieldCount(),
+          value.getSchema().getFieldNames());

       // Encode the field count. This allows us to handle compatible schema changes.
       VAR_INT_CODER.encode(value.getFieldCount(), outputStream);
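The two hunks above are a supporting change: both bare checkState calls gain Guava-style message templates, so a row-coder failure now reports the mismatched schema and field counts instead of a bare IllegalStateException. Since the template is rendered only on failure, the added context costs nothing on the success path. A standalone illustration of the overload (plain Guava here; Beam calls its vendored copy):

import com.google.common.base.Preconditions;

public class CheckStateDemo {
  public static void main(String[] args) {
    int expectedValues = 3;
    int actualValues = 2;
    // The %s template is rendered only when the check fails, so the richer
    // message is free when the state is valid.
    Preconditions.checkState(
        expectedValues == actualValues,
        "Unable to encode row. Expected %s values, but row has %s",
        expectedValues,
        actualValues);
  }
}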
File 2 of 4:
@@ -37,18 +37,24 @@
 import org.apache.beam.sdk.extensions.sql.impl.BeamSqlPipelineOptions;
 import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters;
 import org.apache.beam.sdk.extensions.sql.impl.rel.AbstractBeamCalcRel;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
 import org.apache.beam.sdk.extensions.sql.meta.provider.bigquery.BeamBigQuerySqlDialect;
 import org.apache.beam.sdk.extensions.sql.meta.provider.bigquery.BeamSqlUnparseContext;
 import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.utils.SelectHelpers;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster;
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelTraitSet;
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode;
@@ -64,7 +70,6 @@
 import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
-import org.checkerframework.checker.nullness.qual.NonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -83,6 +88,9 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel {
   private static final int MAX_PENDING_WINDOW = 32;
   private final BeamSqlUnparseContext context;

+  private static final TupleTag<Row> rows = new TupleTag<Row>("output") {};
+  private static final TupleTag<Row> errors = new TupleTag<Row>("errors") {};
+
   private static String columnName(int i) {
     return "_" + i;
   }
@@ -101,21 +109,36 @@ public Calc copy(RelTraitSet traitSet, RelNode input, RexProgram program) {

   @Override
   public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() {
-    return new Transform();
+    return buildPTransform(null);
   }

+  @Override
+  public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform(
+      @Nullable PTransform<PCollection<Row>, ? extends POutput> errorsTransformer) {
+    return new Transform(errorsTransformer);
+  }
+
   @AutoValue
   abstract static class TimestampedFuture {
-    private static TimestampedFuture create(Instant t, Future<Value> f) {
-      return new AutoValue_BeamZetaSqlCalcRel_TimestampedFuture(t, f);
+    private static TimestampedFuture create(Instant t, Future<Value> f, Row r) {
+      return new AutoValue_BeamZetaSqlCalcRel_TimestampedFuture(t, f, r);
     }

     abstract Instant timestamp();

     abstract Future<Value> future();
+
+    abstract Row row();
   }

   private class Transform extends PTransform<PCollectionList<Row>, PCollection<Row>> {
+
+    private final @Nullable PTransform<PCollection<Row>, ? extends POutput> errorsTransformer;
+
+    Transform(@Nullable PTransform<PCollection<Row>, ? extends POutput> errorsTransformer) {
+      this.errorsTransformer = errorsTransformer;
+    }
+
     @Override
     public PCollection<Row> expand(PCollectionList<Row> pinput) {
       Preconditions.checkArgument(
@@ -135,9 +158,10 @@ public PCollection<Row> expand(PCollectionList<Row> pinput) {
                 SqlStdOperatorTable.CASE, condition, rex, rexBuilder.makeNullLiteral(getRowType()));
       }

+      final Schema outputSchema = CalciteUtils.toSchema(getRowType());
+
       BeamSqlPipelineOptions options =
           pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class);
-      Schema outputSchema = CalciteUtils.toSchema(getRowType());
       CalcFn calcFn =
           new CalcFn(
               context.toSql(getProgram(), rex).toSqlString(DIALECT).getSql(),
@@ -147,7 +171,15 @@ public PCollection<Row> expand(PCollectionList<Row> pinput) {
               options.getZetaSqlDefaultTimezone(),
               options.getVerifyRowValues());

-      return upstream.apply(ParDo.of(calcFn)).setRowSchema(outputSchema);
+      PCollectionTuple tuple =
+          upstream.apply(ParDo.of(calcFn).withOutputTags(rows, TupleTagList.of(errors)));
+      tuple.get(errors).setRowSchema(calcFn.errorsSchema);
+
+      if (errorsTransformer != null) {
+        tuple.get(errors).apply(errorsTransformer);
+      }
+
+      return tuple.get(rows).setRowSchema(outputSchema);
     }
   }

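The expand method above uses Beam's standard multi-output pattern: ParDo.of(...).withOutputTags(...) declares a main tag plus additional tags, the DoFn writes to either one, and the resulting PCollectionTuple is unpacked per tag. A self-contained sketch of the same wiring, with hypothetical tags and a toy failure mode:

import java.util.Arrays;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

public class MultiOutputSketch {
  // Tags are anonymous subclasses so the element type survives erasure.
  private static final TupleTag<Integer> MAIN = new TupleTag<Integer>() {};
  private static final TupleTag<String> ERRORS = new TupleTag<String>() {};

  public static void main(String[] args) {
    Pipeline p = Pipeline.create();
    PCollection<Integer> input = p.apply(Create.of(Arrays.asList(1, 2, 0, 4)));

    PCollectionTuple tuple =
        input.apply(
            ParDo.of(
                    new DoFn<Integer, Integer>() {
                      @ProcessElement
                      public void process(@Element Integer x, MultiOutputReceiver out) {
                        try {
                          // Division fails for 0; the failure goes to the side output.
                          out.get(MAIN).output(100 / x);
                        } catch (ArithmeticException e) {
                          out.get(ERRORS).output(x + ": " + e);
                        }
                      }
                    })
                .withOutputTags(MAIN, TupleTagList.of(ERRORS)));

    PCollection<Integer> results = tuple.get(MAIN);
    PCollection<String> deadLetters = tuple.get(ERRORS);
    p.run().waitUntilFinish();
  }
}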
@@ -173,6 +205,8 @@ private static class CalcFn extends DoFn<Row, Row> {
     private final Schema outputSchema;
     private final String defaultTimezone;
     private final boolean verifyRowValues;
+
+    final Schema errorsSchema;
     private final List<Integer> referencedColumns;

     @FieldAccess("row")
@@ -205,6 +239,8 @@ private static class CalcFn extends DoFn<Row, Row> {
       }
       this.referencedColumns = columns.build();
       this.fieldAccess = FieldAccessDescriptor.withFieldIds(this.referencedColumns);
+      Schema inputRowSchema = SelectHelpers.getOutputSchema(inputSchema, fieldAccess);
+      this.errorsSchema = BeamSqlRelUtils.getErrorRowSchema(inputRowSchema);
     }
   }

@@ -242,30 +278,39 @@ public Duration getAllowedTimestampSkew() {

     @ProcessElement
     public void processElement(
-        @FieldAccess("row") Row row, @Timestamp Instant t, BoundedWindow w, OutputReceiver<Row> r)
+        @FieldAccess("row") Row row,
+        @Timestamp Instant t,
+        BoundedWindow w,
+        OutputReceiver<Row> r,
+        MultiOutputReceiver multiOutputReceiver)
         throws InterruptedException {
-      Map<String, Value> columns = new HashMap<>();
-      for (int i : referencedColumns) {
-        final Field field = inputSchema.getField(i);
-        columns.put(
-            columnName(i),
-            ZetaSqlBeamTranslationUtils.toZetaSqlValue(
-                row.getBaseValue(field.getName(), Object.class), field.getType()));
-      }
-
-      @NonNull
-      Future<Value> valueFuture = checkArgumentNotNull(stream).execute(columns, nullParams);

       @Nullable Queue<TimestampedFuture> pendingWindow = pending.get(w);
       if (pendingWindow == null) {
         pendingWindow = new ArrayDeque<>();
         pending.put(w, pendingWindow);
       }
-      pendingWindow.add(TimestampedFuture.create(t, valueFuture));
+      try {
+        Map<String, Value> columns = new HashMap<>();
+        for (int i : referencedColumns) {
+          final Field field = inputSchema.getField(i);
+          columns.put(
+              columnName(i),
+              ZetaSqlBeamTranslationUtils.toZetaSqlValue(
+                  row.getBaseValue(field.getName(), Object.class), field.getType()));
+        }
+        Future<Value> valueFuture = checkArgumentNotNull(stream).execute(columns, nullParams);
+        pendingWindow.add(TimestampedFuture.create(t, valueFuture, row));
+
+      } catch (UnsupportedOperationException | ArithmeticException | IllegalArgumentException e) {
+        multiOutputReceiver
+            .get(errors)
+            .output(Row.withSchema(errorsSchema).addValues(row, e.toString()).build());
+      }

       while ((!pendingWindow.isEmpty() && pendingWindow.element().future().isDone())
           || pendingWindow.size() > MAX_PENDING_WINDOW) {
-        outputRow(pendingWindow.remove(), r);
+        outputRow(pendingWindow.remove(), r, multiOutputReceiver.get(errors));
       }
     }

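One detail worth noting in processElement above: evaluation is asynchronous. Each element enqueues a future on a per-window queue, and the queue is drained whenever the head future has completed or the queue grows past MAX_PENDING_WINDOW (32), bounding in-flight work without blocking on every element. The try/catch covers only synchronous submission failures; exceptions surfacing from the future itself are routed to the error output later, in outputRow. The drain loop in isolation, over a generic queue (a sketch, not Beam code):

import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

public class PendingDrainSketch {
  private static final int MAX_PENDING = 32; // mirrors MAX_PENDING_WINDOW

  public static void main(String[] args) throws InterruptedException, ExecutionException {
    Queue<Future<Integer>> pending = new ArrayDeque<>();
    for (int i = 0; i < 100; i++) {
      // Submission side: enqueue work without blocking.
      final int n = i;
      pending.add(CompletableFuture.supplyAsync(() -> n * n));
      // Drain side: pop results eagerly while the head is already done; once
      // the queue exceeds the cap, remove() + get() blocks, which bounds the
      // number of in-flight futures.
      while ((!pending.isEmpty() && pending.element().isDone())
          || pending.size() > MAX_PENDING) {
        System.out.println(pending.remove().get());
      }
    }
    // Flush the remainder, as finishBundle does.
    while (!pending.isEmpty()) {
      System.out.println(pending.remove().get());
    }
  }
}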
@@ -274,9 +319,12 @@ public void finishBundle(FinishBundleContext c) throws InterruptedException {
       checkArgumentNotNull(stream).flush();
       for (Map.Entry<BoundedWindow, Queue<TimestampedFuture>> pendingWindow : pending.entrySet()) {
         OutputReceiver<Row> rowOutputReciever =
-            new OutputReceiverForFinishBundle(c, pendingWindow.getKey());
+            new OutputReceiverForFinishBundle(c, pendingWindow.getKey(), rows);
+        OutputReceiver<Row> errorOutputReciever =
+            new OutputReceiverForFinishBundle(c, pendingWindow.getKey(), errors);
+
         for (TimestampedFuture timestampedFuture : pendingWindow.getValue()) {
-          outputRow(timestampedFuture, rowOutputReciever);
+          outputRow(timestampedFuture, rowOutputReciever, errorOutputReciever);
         }
       }
     }
@@ -288,9 +336,13 @@ private static class OutputReceiverForFinishBundle implements OutputReceiver<Row> {
       private final FinishBundleContext c;
       private final BoundedWindow w;

-      private OutputReceiverForFinishBundle(FinishBundleContext c, BoundedWindow w) {
+      private final TupleTag<Row> tag;
+
+      private OutputReceiverForFinishBundle(
+          FinishBundleContext c, BoundedWindow w, TupleTag<Row> tag) {
         this.c = c;
         this.w = w;
+        this.tag = tag;
       }

       @Override
@@ -300,11 +352,11 @@ public void output(Row output) {

       @Override
       public void outputWithTimestamp(Row output, Instant timestamp) {
-        c.output(output, timestamp, w);
+        c.output(tag, output, timestamp, w);
       }
     }

-    private static RuntimeException extractException(ExecutionException e) {
+    private static RuntimeException extractException(Throwable e) {
       try {
         throw checkArgumentNotNull(e.getCause());
       } catch (RuntimeException r) {
@@ -314,12 +366,18 @@ private static RuntimeException extractException(ExecutionException e) {
       }
     }

-    private void outputRow(TimestampedFuture c, OutputReceiver<Row> r) throws InterruptedException {
+    private void outputRow(
+        TimestampedFuture c, OutputReceiver<Row> r, OutputReceiver<Row> errorOutputReceiver)
+        throws InterruptedException {
       final Value v;
       try {
         v = c.future().get();
       } catch (ExecutionException e) {
-        throw extractException(e);
+        errorOutputReceiver.outputWithTimestamp(
+            Row.withSchema(errorsSchema).addValues(c.row(), e.toString()).build(), c.timestamp());
+        return;
+      } catch (Throwable thr) {
+        throw extractException(thr);
       }
       if (!v.isNull()) {
         Row row = ZetaSqlBeamTranslationUtils.toBeamRow(v, outputSchema, verifyRowValues);
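The error rows built above pair the trimmed input row with e.toString(). The tests below confirm the original row is readable from the field named "row"; the message field's name is not visible in this diff, so the hypothetical consumer below reads it by index. A transform of this shape is what gets passed in as errorsTransformer:

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

/** Hypothetical DLQ consumer matching PTransform<PCollection<Row>, ? extends POutput>. */
public class LogErrors extends PTransform<PCollection<Row>, PCollection<String>> {
  @Override
  public PCollection<String> expand(PCollection<Row> errors) {
    return errors.apply(
        "FormatErrors",
        ParDo.of(
            new DoFn<Row, String>() {
              @ProcessElement
              public void process(@Element Row errorRow, OutputReceiver<String> out) {
                Row original = errorRow.getRow("row"); // field name confirmed by the tests below
                String message = errorRow.getString(1); // second field carries e.toString()
                out.output("dead-lettered " + original + ": " + message);
              }
            }));
  }
}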
File 3 of 4:
@@ -17,19 +17,25 @@
  */
 package org.apache.beam.sdk.extensions.sql.zetasql;

+import java.util.stream.Collectors;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters;
 import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
 import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
Expand Down Expand Up @@ -93,6 +99,70 @@ public void testSingleFieldAccess() throws IllegalAccessException {
pipeline.run().waitUntilFinish();
}

@Test
public void testErrorsInCalculation() throws IllegalAccessException {
String sql = "SELECT ts, Key*7777*7777*77777*7777777*7777777777 as num, Value FROM KeyValue";

PCollection<Row> rows = compile(sql);

final NodeGetter nodeGetter = new NodeGetter(rows);
pipeline.traverseTopologically(nodeGetter);

ParDo.MultiOutput<Row, Row> pardo =
(ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();

PCollection<Row> errors =
(PCollection<Row>)
nodeGetter.producer.getOutputs().get(pardo.getAdditionalOutputTags().get(0));
Assert.assertEquals(2, errors.getSchema().getFieldCount());

PAssert.that(errors.apply(Count.globally())).containsInAnyOrder(2L);
PAssert.that(errors)
.satisfies(
(SerializableFunction<Iterable<Row>, Void>)
input -> {
Assert.assertEquals(
Lists.newArrayList(input).stream()
.map(r -> r.getRow("row").getInt64("Key"))
.collect(Collectors.toSet()),
Sets.newHashSet(14L, 15L));
return null;
});
pipeline.run().waitUntilFinish();
}

@Test
public void testErrorsInCalculationWithSelectedCols() throws IllegalAccessException {
String sql = "SELECT ts, Key*7777*7777*77777*7777777*7777777777 as num FROM KeyValue";

PCollection<Row> rows = compile(sql);

final NodeGetter nodeGetter = new NodeGetter(rows);
pipeline.traverseTopologically(nodeGetter);

ParDo.MultiOutput<Row, Row> pardo =
(ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();

PCollection<Row> errors =
(PCollection<Row>)
nodeGetter.producer.getOutputs().get(pardo.getAdditionalOutputTags().get(0));
Assert.assertEquals(2, errors.getSchema().getFieldCount());

PAssert.that(errors.apply(Count.globally())).containsInAnyOrder(2L);
PAssert.that(errors)
.satisfies(
(SerializableFunction<Iterable<Row>, Void>)
input -> {
Assert.assertEquals(
Lists.newArrayList(input).stream()
.map(r -> r.getRow("row").getInt64("Key"))
.collect(Collectors.toSet()),
Sets.newHashSet(14L, 15L));
return null;
});
pipeline.run().waitUntilFinish();
}

@Test
public void testNoFieldAccess() throws IllegalAccessException {
String sql = "SELECT 1 FROM KeyValue";
File 4 of 4: (diff not loaded)
