[FLINK-22378][table] Derive type of SOURCE_WATERMARK() from time attribute
twalthr committed Apr 26, 2021
1 parent 19ca330 commit cdd1732
Showing 39 changed files with 378 additions and 165 deletions.
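
To illustrate the effect of the change, here is a hedged sketch (not part of the commit; the table name, schema, and the 'kafka' connector are placeholders) of a DDL that declares a watermark with SOURCE_WATERMARK(). With this commit, the function's return type is derived from the declared time attribute, so a TIMESTAMP_LTZ(3) column yields a TIMESTAMP_LTZ(3) watermark. Both SQL DDL and Table API schema declarations go through the schema resolution path changed below, so the derivation applies in either case.

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class SourceWatermarkExample {
    public static void main(String[] args) {
        TableEnvironment tEnv =
                TableEnvironment.create(EnvironmentSettings.newInstance().build());

        // SOURCE_WATERMARK() now takes its return type from the declared time
        // attribute, so the watermark below has type TIMESTAMP_LTZ(3).
        // 'kafka' is only a placeholder; any connector that emits source
        // watermarks could be used here.
        tEnv.executeSql(
                "CREATE TABLE events (\n"
                        + "  id BIGINT,\n"
                        + "  ts TIMESTAMP_LTZ(3),\n"
                        + "  WATERMARK FOR ts AS SOURCE_WATERMARK()\n"
                        + ") WITH (\n"
                        + "  'connector' = 'kafka'\n"
                        + ")");
    }
}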
@@ -20,7 +20,6 @@

import org.apache.flink.connectors.hive.FlinkHiveException;
import org.apache.flink.table.api.SqlParserException;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.catalog.Catalog;
import org.apache.flink.table.catalog.CatalogManager;
@@ -36,7 +35,6 @@
import org.apache.flink.table.operations.ddl.CreateTableASOperation;
import org.apache.flink.table.operations.ddl.CreateTableOperation;
import org.apache.flink.table.planner.calcite.FlinkPlannerImpl;
import org.apache.flink.table.planner.calcite.SqlExprToRexConverter;
import org.apache.flink.table.planner.delegation.ParserImpl;
import org.apache.flink.table.planner.delegation.PlannerContext;
import org.apache.flink.table.planner.delegation.hive.copy.HiveASTParseException;
@@ -80,7 +78,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

/** A Parser that uses Hive's planner to parse a statement. */
@@ -176,13 +173,12 @@ public class HiveParser extends ParserImpl {
CatalogManager catalogManager,
Supplier<FlinkPlannerImpl> validatorSupplier,
Supplier<CalciteParser> calciteParserSupplier,
Function<TableSchema, SqlExprToRexConverter> sqlExprToRexConverterCreator,
PlannerContext plannerContext) {
super(
catalogManager,
validatorSupplier,
calciteParserSupplier,
sqlExprToRexConverterCreator);
plannerContext.getSqlExprToRexConverterFactory());
this.plannerContext = plannerContext;
this.catalogReader =
plannerContext.createCatalogReader(
@@ -23,7 +23,6 @@
import org.apache.flink.table.catalog.CatalogManager;
import org.apache.flink.table.delegation.Parser;
import org.apache.flink.table.descriptors.DescriptorProperties;
import org.apache.flink.table.planner.calcite.SqlExprToRexConverterFactory;
import org.apache.flink.table.planner.delegation.ParserFactory;
import org.apache.flink.table.planner.delegation.PlannerContext;

@@ -36,18 +35,13 @@ public class HiveParserFactory implements ParserFactory {

@Override
public Parser create(CatalogManager catalogManager, PlannerContext plannerContext) {
SqlExprToRexConverterFactory sqlExprToRexConverterFactory =
plannerContext::createSqlExprToRexConverter;
return new HiveParser(
catalogManager,
() ->
plannerContext.createFlinkPlanner(
catalogManager.getCurrentCatalog(),
catalogManager.getCurrentDatabase()),
plannerContext::createCalciteParser,
tableSchema ->
sqlExprToRexConverterFactory.create(
plannerContext.getTypeFactory().buildRelNodeRowType(tableSchema)),
plannerContext);
}

@@ -261,9 +261,11 @@ protected TableEnvironmentImpl(
return Optional.empty();
}
},
(sqlExpression, inputSchema) -> {
(sqlExpression, inputRowType, outputType) -> {
try {
return getParser().parseSqlExpression(sqlExpression, inputSchema);
return getParser()
.parseSqlExpression(
sqlExpression, inputRowType, outputType);
} catch (Throwable t) {
throw new ValidationException(
String.format("Invalid SQL expression: %s", sqlExpression),
@@ -144,7 +144,8 @@ private ComputedColumn resolveComputedColumn(
UnresolvedComputedColumn unresolvedColumn, List<Column> inputColumns) {
final ResolvedExpression resolvedExpression;
try {
resolvedExpression = resolveExpression(inputColumns, unresolvedColumn.getExpression());
resolvedExpression =
resolveExpression(inputColumns, unresolvedColumn.getExpression(), null);
} catch (Exception e) {
throw new ValidationException(
String.format(
@@ -189,22 +190,26 @@ private List<WatermarkSpec> resolveWatermarkSpecs(
final ResolvedExpression watermarkExpression;
try {
watermarkExpression =
resolveExpression(inputColumns, watermarkSpec.getWatermarkExpression());
resolveExpression(
inputColumns,
watermarkSpec.getWatermarkExpression(),
validatedTimeColumn.getDataType());
} catch (Exception e) {
throw new ValidationException(
String.format(
"Invalid expression for watermark '%s'.", watermarkSpec.toString()),
e);
}
validateWatermarkExpression(watermarkExpression.getOutputDataType().getLogicalType());
final LogicalType outputType = watermarkExpression.getOutputDataType().getLogicalType();
final LogicalType timeColumnType = validatedTimeColumn.getDataType().getLogicalType();
validateWatermarkExpression(outputType);

if (!(watermarkExpression.getOutputDataType().getLogicalType().getTypeRoot()
== validatedTimeColumn.getDataType().getLogicalType().getTypeRoot())) {
if (!(outputType.getTypeRoot() == timeColumnType.getTypeRoot())) {
throw new ValidationException(
String.format(
"The watermark output type %s is different from input time filed type %s.",
watermarkExpression.getOutputDataType(),
validatedTimeColumn.getDataType()));
"The watermark declaration's output data type '%s' is different "
+ "from the time field's data type '%s'.",
outputType, timeColumnType));
}

return Collections.singletonList(
@@ -348,13 +353,15 @@ private void validatePrimaryKey(UniqueConstraint primaryKey, List<Column> column
}
}

private ResolvedExpression resolveExpression(List<Column> columns, Expression expression) {
private ResolvedExpression resolveExpression(
List<Column> columns, Expression expression, @Nullable DataType outputDataType) {
final LocalReferenceExpression[] localRefs =
columns.stream()
.map(c -> localRef(c.getName(), c.getDataType()))
.toArray(LocalReferenceExpression[]::new);
return resolverBuilder
.withLocalReferences(localRefs)
.withOutputDataType(outputDataType)
.build()
.resolve(Collections.singletonList(expression))
.get(0);
@@ -19,11 +19,14 @@
package org.apache.flink.table.delegation;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.catalog.UnresolvedIdentifier;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.operations.Operation;
import org.apache.flink.table.operations.QueryOperation;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

import javax.annotation.Nullable;

import java.util.List;

@@ -58,11 +61,13 @@ public interface Parser {
* Entry point for parsing SQL expressions expressed as a String.
*
* @param sqlExpression the SQL expression to parse
* @param inputSchema the schema of the fields in sql expression
* @param inputRowType the fields available in the SQL expression
* @param outputType expected top-level output type if available
* @return resolved expression
* @throws org.apache.flink.table.api.SqlParserException when failed to parse the sql expression
*/
ResolvedExpression parseSqlExpression(String sqlExpression, TableSchema inputSchema);
ResolvedExpression parseSqlExpression(
String sqlExpression, RowType inputRowType, @Nullable LogicalType outputType);

/**
* Returns completion hints for the given statement at the given cursor position. The completion
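
As a point of reference, here is a hedged sketch of how the revised parseSqlExpression entry point might be called; the Parser instance, field names, and types are illustrative assumptions rather than code from this commit. The RowType describes the fields the expression may reference, and the nullable output type (typically the time attribute's type) is what lets SOURCE_WATERMARK() pick up its result type.

import org.apache.flink.table.delegation.Parser;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.TimestampType;

final class ParseSqlExpressionSketch {

    /** Resolves a watermark expression against a given (internal) Parser instance. */
    static ResolvedExpression parseWatermark(Parser parser) {
        // Fields that the SQL expression may reference.
        RowType inputRowType =
                RowType.of(
                        new LogicalType[] {new BigIntType(), new TimestampType(3)},
                        new String[] {"id", "ts"});
        // Expected top-level output type: the time attribute's type. May be null
        // when no specific output type is expected (e.g. computed columns).
        LogicalType outputType = new TimestampType(3);
        return parser.parseSqlExpression("SOURCE_WATERMARK()", inputRowType, outputType);
    }
}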
@@ -43,6 +43,8 @@
import org.apache.flink.table.types.DataType;
import org.apache.flink.util.Preconditions;

import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -117,6 +119,8 @@ public static List<ResolverRule> getAllResolverRules() {

private final Map<String, LocalReferenceExpression> localReferences;

private final @Nullable DataType outputDataType;

private final Map<Expression, LocalOverWindow> localOverWindows;

private final boolean isGroupedAggregation;
@@ -130,6 +134,7 @@ private ExpressionResolver(
FieldReferenceLookup fieldLookup,
List<OverWindow> localOverWindows,
List<LocalReferenceExpression> localReferences,
@Nullable DataType outputDataType,
boolean isGroupedAggregation) {
this.config = Preconditions.checkNotNull(config).getConfiguration();
this.tableLookup = Preconditions.checkNotNull(tableLookup);
@@ -149,6 +154,7 @@ private ExpressionResolver(
"Duplicate local reference: " + u);
},
LinkedHashMap::new));
this.outputDataType = outputDataType;
this.localOverWindows = prepareOverWindows(localOverWindows);
this.isGroupedAggregation = isGroupedAggregation;
}
@@ -323,6 +329,11 @@ public List<LocalReferenceExpression> getLocalReferences() {
return new ArrayList<>(localReferences.values());
}

@Override
public Optional<DataType> getOutputDataType() {
return Optional.ofNullable(outputDataType);
}

@Override
public Optional<LocalOverWindow> getOverWindow(Expression alias) {
return Optional.ofNullable(localOverWindows.get(alias));
@@ -443,6 +454,7 @@ public static class ExpressionResolverBuilder {
private final SqlExpressionResolver sqlExpressionResolver;
private List<OverWindow> logicalOverWindows = new ArrayList<>();
private List<LocalReferenceExpression> localReferences = new ArrayList<>();
private @Nullable DataType outputDataType;
private boolean isGroupedAggregation;

private ExpressionResolverBuilder(
@@ -471,6 +483,11 @@ public ExpressionResolverBuilder withLocalReferences(
return this;
}

public ExpressionResolverBuilder withOutputDataType(@Nullable DataType outputDataType) {
this.outputDataType = outputDataType;
return this;
}

public ExpressionResolverBuilder withGroupedAggregation(boolean isGroupedAggregation) {
this.isGroupedAggregation = isGroupedAggregation;
return this;
Expand All @@ -486,6 +503,7 @@ public ExpressionResolver build() {
new FieldReferenceLookup(queryOperations),
logicalOverWindows,
localReferences,
outputDataType,
isGroupedAggregation);
}
}
@@ -19,14 +19,18 @@
package org.apache.flink.table.expressions.resolver;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

import javax.annotation.Nullable;

/** Translates a SQL expression string into a {@link ResolvedExpression}. */
@Internal
public interface SqlExpressionResolver {

/** Translates the given SQL expression string or throws a {@link ValidationException}. */
ResolvedExpression resolveExpression(String sqlExpression, TableSchema inputSchema);
ResolvedExpression resolveExpression(
String sqlExpression, RowType inputRowType, @Nullable LogicalType outputType);
}
@@ -19,13 +19,19 @@
package org.apache.flink.table.expressions.resolver.rules;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.SqlCallExpression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
import org.apache.flink.table.expressions.resolver.SqlExpressionResolver;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.RowType.RowField;

import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

@@ -37,32 +43,49 @@ final class ResolveSqlCallRule implements ResolverRule {

@Override
public List<Expression> apply(List<Expression> expression, ResolutionContext context) {
return expression.stream()
.map(expr -> expr.accept(new TranslateSqlCallsVisitor(context)))
.collect(Collectors.toList());
// only the top-level expressions may access the output data type
final LogicalType outputType =
context.getOutputDataType().map(DataType::getLogicalType).orElse(null);
final TranslateSqlCallsVisitor visitor = new TranslateSqlCallsVisitor(context, outputType);
return expression.stream().map(expr -> expr.accept(visitor)).collect(Collectors.toList());
}

private static class TranslateSqlCallsVisitor extends RuleExpressionVisitor<Expression> {

TranslateSqlCallsVisitor(ResolutionContext resolutionContext) {
private final @Nullable LogicalType outputType;

TranslateSqlCallsVisitor(
ResolutionContext resolutionContext, @Nullable LogicalType outputType) {
super(resolutionContext);
this.outputType = outputType;
}

@Override
public Expression visit(SqlCallExpression sqlCall) {
final SqlExpressionResolver resolver = resolutionContext.sqlExpressionResolver();

final TableSchema.Builder builder = TableSchema.builder();
final List<RowField> fields = new ArrayList<>();
// input references
resolutionContext
.referenceLookup()
.getAllInputFields()
.forEach(f -> builder.field(f.getName(), f.getOutputDataType()));
.forEach(
f ->
fields.add(
new RowField(
f.getName(),
f.getOutputDataType().getLogicalType())));
// local references
resolutionContext
.getLocalReferences()
.forEach(refs -> builder.field(refs.getName(), refs.getOutputDataType()));
return resolver.resolveExpression(sqlCall.getSqlExpression(), builder.build());
.forEach(
refs ->
fields.add(
new RowField(
refs.getName(),
refs.getOutputDataType().getLogicalType())));
return resolver.resolveExpression(
sqlCall.getSqlExpression(), new RowType(false, fields), outputType);
}

@Override
@@ -76,8 +99,10 @@ protected Expression defaultMethod(Expression expression) {
}

private List<Expression> resolveChildren(List<Expression> lookupChildren) {
final TranslateSqlCallsVisitor visitor =
new TranslateSqlCallsVisitor(resolutionContext, null);
return lookupChildren.stream()
.map(child -> child.accept(this))
.map(child -> child.accept(visitor))
.collect(Collectors.toList());
}
}
@@ -31,6 +31,7 @@
import org.apache.flink.table.expressions.resolver.lookups.FieldReferenceLookup;
import org.apache.flink.table.expressions.resolver.lookups.TableReferenceLookup;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.DataType;

import java.util.List;
import java.util.Optional;
@@ -87,6 +88,9 @@ interface ResolutionContext {
/** Access to available local references. */
List<LocalReferenceExpression> getLocalReferences();

/** Access to the expected top-level output data type. */
Optional<DataType> getOutputDataType();

/** Access to available local over windows. */
Optional<LocalOverWindow> getOverWindow(Expression alias);
