Skip to content

Commit

Permalink
[BEAM-6783] byte[] breaks in BeamSQL codegen
Browse files Browse the repository at this point in the history
Fix nested and repeated BYTES fields.
  • Loading branch information
kanterov committed Jul 31, 2019
1 parent 60fab9b commit 6cc59ac
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.util.AbstractList;
import java.util.AbstractMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamJavaTypeFactory;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
Expand All @@ -44,6 +47,7 @@
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.calcite.DataContext;
Expand All @@ -58,6 +62,7 @@
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.GotoExpressionKind;
import org.apache.calcite.linq4j.tree.MemberDeclaration;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.plan.RelOptCluster;
Expand Down Expand Up @@ -336,7 +341,7 @@ private Expression castOutputTime(Expression value, FieldType toType) {
}

private static class InputGetterImpl implements RexToLixTranslator.InputGetter {
private static final Map<TypeName, String> typeGetterMap =
private static final Map<TypeName, String> TYPE_GETTER_MAP =
ImmutableMap.<TypeName, String>builder()
.put(TypeName.BYTE, "getByte")
.put(TypeName.BYTES, "getBytes")
Expand All @@ -354,7 +359,7 @@ private static class InputGetterImpl implements RexToLixTranslator.InputGetter {
.put(TypeName.ROW, "getRow")
.build();

private static final Map<String, String> logicalTypeGetterMap =
private static final Map<String, String> LOGICAL_TYPE_GETTER_MAP =
ImmutableMap.<String, String>builder()
.put(DateType.IDENTIFIER, "getDateTime")
.put(TimeType.IDENTIFIER, "getDateTime")
Expand All @@ -373,63 +378,135 @@ private InputGetterImpl(Expression input, Schema inputSchema) {

// NOTE(review): this listing is a rendered diff whose +/- markers were stripped, so the bounds
// check on the two lines after the signature and the delegating `return` below it appear to be
// the pre-change and post-change bodies of the same method — confirm against the real file.
@Override
public Expression field(BlockBuilder list, int index, Type storageType) {
if (index >= inputSchema.getFieldCount() || index < 0) {
throw new IllegalArgumentException("Unable to find field #" + index);
// Delegates to the static value(...) helper, passing this getter's own input expression and
// schema so that the same helper can also be called with a different row expression/schema.
return value(list, index, storageType, input, inputSchema);
}

// Builds the expression that reads field `index` of the row expression `input` (described by
// `schema`) and hands it to value(Expression, FieldType) for type-specific conversion.
//
// Throws IllegalArgumentException when `index` is outside the schema, or when no getter name
// is registered for the field's type.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. Where two
// near-identical statements follow each other (e.g. the two `expression` declarations, the
// `inputSchema`/`schema` and `typeGetterMap`/`TYPE_GETTER_MAP` pairs), the first looks like the
// removed line and the second like its replacement — confirm against the real file.
private static Expression value(
BlockBuilder list, int index, Type storageType, Expression input, Schema schema) {
if (index >= schema.getFieldCount() || index < 0) {
throw new IllegalArgumentException("Unable to find value #" + index);
}

final Expression expression = list.append("current", input);
// The variant below asks the builder for a name via list.newName("current") instead of
// passing the literal "current" — presumably to avoid name clashes when this helper is used
// more than once in the same generated scope; verify against BlockBuilder's documentation.
final Expression expression = list.append(list.newName("current"), input);
// Callers asking for Object storage get the untyped getValue(index) result, converted to Object.
if (storageType == Object.class) {
return Expressions.convert_(
Expressions.call(expression, "getValue", Expressions.constant(index)), Object.class);
}
FieldType fromType = inputSchema.getField(index).getType();
FieldType fromType = schema.getField(index).getType();
// Pick the getter method name: logical types are looked up by identifier, all others by TypeName.
String getter;
if (fromType.getTypeName().isLogicalType()) {
getter = logicalTypeGetterMap.get(fromType.getLogicalType().getIdentifier());
getter = LOGICAL_TYPE_GETTER_MAP.get(fromType.getLogicalType().getIdentifier());
} else {
getter = typeGetterMap.get(fromType.getTypeName());
getter = TYPE_GETTER_MAP.get(fromType.getTypeName());
}
if (getter == null) {
throw new IllegalArgumentException("Unable to get " + fromType.getTypeName());
}
Expression field = Expressions.call(expression, getter, Expressions.constant(index));
if (fromType.getTypeName().isLogicalType()) {
Expression millisField = Expressions.call(field, "getMillis");
String logicalId = fromType.getLogicalType().getIdentifier();

// Typed read of the field through the selected getter, e.g. <expression>.<getter>(index).
Expression value = Expressions.call(expression, getter, Expressions.constant(index));

// Type-specific conversion lives in the (Expression, FieldType) overload.
return value(value, fromType);
}

// Converts an already-read value expression according to its Beam field type `type`.
// Types that match none of the branches are returned unchanged.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. Statements that
// mention `field` or `fromType` (neither is a parameter of this method) look like removed lines
// from the previous implementation; the statements on `value`/`type` look like the replacement.
// Confirm against the real file.
// NOTE(review): nullOr(...) is defined outside this view; it presumably yields null when its
// first argument is null and the second argument otherwise — verify.
private static Expression value(Expression value, Schema.FieldType type) {
if (type.getTypeName().isLogicalType()) {
Expression millisField = Expressions.call(value, "getMillis");
String logicalId = type.getLogicalType().getIdentifier();
if (logicalId.equals(TimeType.IDENTIFIER)) {
field = nullOr(field, Expressions.convert_(millisField, int.class));
// TIME: getMillis() result converted to int.
return nullOr(value, Expressions.convert_(millisField, int.class));
} else if (logicalId.equals(DateType.IDENTIFIER)) {
// DATE: getMillis() result divided by MILLIS_PER_DAY, converted to int.
field =
value =
nullOr(
field,
value,
Expressions.convert_(
Expressions.divide(millisField, Expressions.constant(MILLIS_PER_DAY)),
int.class));
} else if (!logicalId.equals(CharType.IDENTIFIER)) {
// CharType needs no conversion; any other logical type identifier is rejected.
throw new IllegalArgumentException(
"Unknown LogicalType " + fromType.getLogicalType().getIdentifier());
"Unknown LogicalType " + type.getLogicalType().getIdentifier());
}
} else if (CalciteUtils.isDateTimeType(fromType)) {
field = nullOr(field, Expressions.call(field, "getMillis"));
} else if (fromType.getTypeName().isCompositeType()
|| (fromType.getTypeName().isCollectionType()
&& fromType.getCollectionElementType().getTypeName().isCompositeType())) {
field =
Expressions.condition(
Expressions.equal(field, Expressions.constant(null)),
Expressions.constant(null),
Expressions.call(WrappedList.class, "of", field));
} else if (fromType.getTypeName().isMapType()
&& fromType.getMapValueType().getTypeName().isCompositeType()) {
field = nullOr(field, Expressions.call(WrappedList.class, "ofMapValues", field));
} else if (fromType.getTypeName() == TypeName.BYTES) {
field =
Expressions.condition(
Expressions.equal(field, Expressions.constant(null)),
Expressions.constant(null),
Expressions.new_(ByteString.class, field));
} else if (type.getTypeName().isMapType()) {
// Maps: wrapped via map(...), parameterised by the map's value type.
return nullOr(value, map(value, type.getMapValueType()));
} else if (CalciteUtils.isDateTimeType(type)) {
return nullOr(value, Expressions.call(value, "getMillis"));
} else if (type.getTypeName().isCompositeType()) {
// Nested rows: wrapped via row(...), parameterised by the nested row schema.
return nullOr(value, row(value, type.getRowSchema()));
} else if (type.getTypeName().isCollectionType()) {
// Collections: wrapped via list(...), parameterised by the element type.
return nullOr(value, list(value, type.getCollectionElementType()));
} else if (type.getTypeName() == TypeName.BYTES) {
// BYTES: cast to byte[] where necessary, then wrapped in a new ByteString.
return nullOr(
value, Expressions.new_(ByteString.class, Types.castIfNecessary(byte[].class, value)));
}
return field;

return value;
}

/**
 * Builds an expression that instantiates an anonymous {@code WrappedList} subclass around
 * {@code input}.
 *
 * <p>The generated subclass declares a public {@code Object value(Object)} method whose body
 * returns the conversion that {@code value(Expression, FieldType)} produces for
 * {@code elementType}; {@code WrappedList#get} applies that method to each element it returns.
 *
 * @param input expression for the backing list; cast to {@link List} if necessary
 * @param elementType Beam field type of the list's elements
 * @return the constructor-call expression for the wrapper
 */
private static Expression list(Expression input, FieldType elementType) {
  ParameterExpression element = Expressions.parameter(Object.class);

  // Body of the generated value(Object) method: the converted element.
  BlockBuilder converterBody = new BlockBuilder();
  converterBody.add(value(element, elementType));

  Expression backingList = Types.castIfNecessary(List.class, input);
  MemberDeclaration converter =
      Expressions.methodDecl(
          Modifier.PUBLIC,
          Object.class,
          "value",
          ImmutableList.of(element),
          converterBody.toBlock());

  return Expressions.new_(
      WrappedList.class,
      ImmutableList.of(backingList),
      ImmutableList.<MemberDeclaration>of(converter));
}

/**
 * Builds an expression that instantiates an anonymous {@code WrappedMap} subclass around
 * {@code input}.
 *
 * <p>The generated subclass declares a public {@code Object value(Object)} method whose body
 * returns the conversion that {@code value(Expression, FieldType)} produces for
 * {@code mapValueType}. Only map values are converted; keys are passed through.
 *
 * @param input expression for the backing map; cast to {@link Map} if necessary
 * @param mapValueType Beam field type of the map's values
 * @return the constructor-call expression for the wrapper
 */
private static Expression map(Expression input, FieldType mapValueType) {
  ParameterExpression entryValue = Expressions.parameter(Object.class);

  // Body of the generated value(Object) method: the converted map value.
  BlockBuilder converterBody = new BlockBuilder();
  converterBody.add(value(entryValue, mapValueType));

  Expression backingMap = Types.castIfNecessary(Map.class, input);
  MemberDeclaration converter =
      Expressions.methodDecl(
          Modifier.PUBLIC,
          Object.class,
          "value",
          ImmutableList.of(entryValue),
          converterBody.toBlock());

  return Expressions.new_(
      WrappedMap.class,
      ImmutableList.of(backingMap),
      ImmutableList.<MemberDeclaration>of(converter));
}

/**
 * Builds an expression that instantiates an anonymous {@code WrappedRow} subclass around
 * {@code input}.
 *
 * <p>The generated subclass declares a public {@code Object field(Row, int)} method. Its body
 * holds one {@code if (index == i)} block per field of {@code schema}, each containing the
 * expression that {@code value(BlockBuilder, int, Type, Expression, Schema)} builds for field
 * {@code i}, followed by a {@code throw new IndexOutOfBoundsException()} for any other index.
 *
 * @param input expression for the backing row; cast to {@link Row} if necessary
 * @param schema schema of the wrapped row
 * @return the constructor-call expression for the wrapper
 */
private static Expression row(Expression input, Schema schema) {
  ParameterExpression rowParam = Expressions.parameter(Row.class);
  ParameterExpression indexParam = Expressions.parameter(int.class);
  BlockBuilder methodBody = new BlockBuilder(/* optimizing= */ false);

  final int fieldCount = schema.getFieldCount();
  for (int i = 0; i < fieldCount; i++) {
    // Each field gets its own child builder so its temporaries stay inside its `if` block.
    BlockBuilder fieldBlock = new BlockBuilder(/* optimizing= */ false, methodBody);
    Expression fieldValue = value(fieldBlock, i, /* storageType= */ null, rowParam, schema);
    fieldBlock.append(fieldValue);

    Expression matchesIndex =
        Expressions.equal(indexParam, Expressions.constant(i, int.class));
    methodBody.append(
        "if i=" + i, Expressions.block(Expressions.ifThen(matchesIndex, fieldBlock.toBlock())));
  }

  // No branch matched: the index is outside the schema.
  methodBody.add(Expressions.throw_(Expressions.new_(IndexOutOfBoundsException.class)));

  Expression backingRow = Types.castIfNecessary(Row.class, input);
  MemberDeclaration fieldAccessor =
      Expressions.methodDecl(
          Modifier.PUBLIC,
          Object.class,
          "field",
          ImmutableList.of(rowParam, indexParam),
          methodBody.toBlock());

  return Expressions.new_(
      WrappedRow.class,
      ImmutableList.of(backingRow),
      ImmutableList.<MemberDeclaration>of(fieldAccessor));
}
}

Expand Down Expand Up @@ -470,44 +547,73 @@ public Object get(String name) {
}
}

// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The first
// Javadoc/class header below, the `list` field and the private WrappedList constructor fragment
// look like removed lines from the previous WrappedList class that are interleaved with the new
// WrappedRow class — confirm against the real file.
/** WrappedList translates {@code Row} and {@code List} on access. */
public static class WrappedList extends AbstractList<Object> {
/**
 * WrappedRow translates {@code Row} on access.
 *
 * <p>Presents a {@code Row} as a read-only list of its fields: {@code size()} is the row's field
 * count and {@code get(index)} returns whatever the subclass's {@code field(row, index)} yields.
 */
public abstract static class WrappedRow extends AbstractList<Object> {
private final Row row;

protected WrappedRow(Row row) {
this.row = row;
}

@Override
public Object get(int index) {
return field(row, index);
}

private final List<Object> list;
// We could override get(int index) directly if we knew how to access `this.row` from linq4j;
// for now the row is passed in explicitly, which keeps this consistent with WrappedList.
protected abstract Object field(Row row, int index);

private WrappedList(List<Object> list) {
this.list = list;
@Override
public int size() {
return row.getFieldCount();
}
}

public static List<Object> of(List list) {
if (list instanceof WrappedList) {
return list;
}
return new WrappedList(list);
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The static
// `of(Row)` and `ofMapValues(...)` fragments below look like removed lines from the previous
// WrappedList class that are interleaved with the new WrappedMap class — confirm against the
// real file.
/**
 * WrappedMap translates {@code Map} on access.
 *
 * <p>Only values are translated, through the subclass's {@code value(Object)}; keys are used
 * as-is.
 */
public abstract static class WrappedMap<V> extends AbstractMap<Object, V> {
private final Map<Object, Object> map;

protected WrappedMap(Map<Object, Object> map) {
this.map = map;
}

public static List<Object> of(Row row) {
return new WrappedList(row.getValues());
// TODO transform keys, in this case, we need to do lookup, so it should be both ways:
//
// public abstract Object fromKey(K key)
// public abstract K toKey(Object key)

// Entry view over the backing map with values translated; null values are kept as null and are
// not passed to value(...).
@Override
public Set<Entry<Object, V>> entrySet() {
return Maps.transformValues(map, val -> (val == null) ? null : value(val)).entrySet();
}

public static Map<Object, List> ofMapValues(Map<Object, Row> map) {
return Maps.transformValues(map, val -> (val == null) ? null : WrappedList.of(val));
// Direct lookup in the backing map, then translation of the result.
// NOTE(review): unlike entrySet(), a null result (missing key or null value) is passed straight
// to value(...); confirm that every value(...) implementation tolerates null.
@Override
public V get(Object key) {
return value(map.get(key));
}

// Translates one raw map value into the representation exposed by this wrapper.
protected abstract V value(Object value);
}

// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The
// `public Object get(int index)` body that inspects `obj instanceof Row/List`, and the
// `return list.size();` line, look like removed lines from the previous implementation that are
// interleaved with the new generic class — confirm against the real file.
/**
 * WrappedList translates {@code List} on access.
 *
 * <p>{@code get(index)} returns the subclass's {@code value(...)} applied to the backing
 * element; {@code size()} is the backing list's size.
 */
public abstract static class WrappedList<T> extends AbstractList<T> {
private final List<Object> values;

protected WrappedList(List<Object> values) {
this.values = values;
}

@Override
public Object get(int index) {
Object obj = list.get(index);
if (obj instanceof Row) {
obj = of((Row) obj);
} else if (obj instanceof List) {
obj = of((List) obj);
}
return obj;
public T get(int index) {
return value(values.get(index));
}

// Translates one raw element into the representation exposed by this wrapper.
protected abstract T value(Object value);

@Override
public int size() {
return list.size();
return values.size();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.joda.time.DateTime;
import org.joda.time.Duration;
Expand Down Expand Up @@ -243,6 +244,55 @@ public void testSelectInnerRowOfNestedRow() {
pipeline.run().waitUntilFinish(Duration.standardMinutes(2));
}

/** Selecting a BYTES field that lives inside a nested row must return the original bytes. */
@Test
public void testNestedBytes() {
  // 0xBA 0xAD 0xCA 0xFE written as signed bytes.
  byte[] payload = {-70, -83, -54, -2};

  Schema innerSchema = Schema.of(Schema.Field.of("c_bytes", Schema.FieldType.BYTES));
  Schema outerSchema = Schema.of(Schema.Field.of("nested", Schema.FieldType.row(innerSchema)));
  Schema resultSchema = Schema.of(Schema.Field.of("f0", Schema.FieldType.BYTES));

  Row innerRow = Row.withSchema(innerSchema).addValue(payload).build();
  Row outerRow = Row.withSchema(outerSchema).addValue(innerRow).build();
  Row expectedRow = Row.withSchema(resultSchema).addValue(payload).build();

  PCollection<Row> input = pipeline.apply(Create.of(outerRow).withRowSchema(outerSchema));
  PCollection<Row> output =
      input.apply(SqlTransform.query("SELECT t.nested.c_bytes AS f0 FROM PCOLLECTION t"));

  PAssert.that(output).containsInAnyOrder(expectedRow);

  pipeline.run();
}

/**
 * Indexing into an ARRAY of BYTES that lives inside a nested row must return the original bytes.
 */
@Test
public void testNestedArrayOfBytes() {
  // 0xBA 0xAD 0xCA 0xFE written as signed bytes.
  byte[] payload = {-70, -83, -54, -2};

  Schema.FieldType bytesArray = Schema.FieldType.array(Schema.FieldType.BYTES);
  Schema innerSchema = Schema.of(Schema.Field.of("c_bytes", bytesArray));
  Schema outerSchema = Schema.of(Schema.Field.of("nested", Schema.FieldType.row(innerSchema)));
  Schema resultSchema = Schema.of(Schema.Field.of("f0", Schema.FieldType.BYTES));

  Row innerRow = Row.withSchema(innerSchema).addValue(ImmutableList.of(payload)).build();
  Row outerRow = Row.withSchema(outerSchema).addValue(innerRow).build();
  Row expectedRow = Row.withSchema(resultSchema).addValue(payload).build();

  PCollection<Row> input = pipeline.apply(Create.of(outerRow).withRowSchema(outerSchema));
  // The array holds a single element, which the query addresses as index 1.
  PCollection<Row> output =
      input.apply(SqlTransform.query("SELECT t.nested.c_bytes[1] AS f0 FROM PCOLLECTION t"));

  PAssert.that(output).containsInAnyOrder(expectedRow);

  pipeline.run();
}

@Test
public void testRowConstructor() {
BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(readOnlyTableProvider);
Expand Down

0 comments on commit 6cc59ac

Please sign in to comment.