Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support NULL return values from CASE statements #3531

Merged
merged 6 commits into from
Oct 15, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ public String getSymbol() {
*/
public SqlType resultType(final SqlType left, final SqlType right) {
if (left.baseType().isNumber() && right.baseType().isNumber()) {
if (left.baseType().canUpCast(right.baseType())) {
if (left.baseType().canImplicitlyCast(right.baseType())) {
if (right.baseType() != SqlBaseType.DECIMAL) {
return right;
}

return binaryResolver.apply(toDecimal(left), (SqlDecimal) right);
}

if (right.baseType().canUpCast(left.baseType())) {
if (right.baseType().canImplicitlyCast(left.baseType())) {
if (left.baseType() != SqlBaseType.DECIMAL) {
return left;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@ public boolean isNumber() {
}

/**
* Test to see if this type can be up-cast to another.
* Test to see if this type can be <i>implicitly</i> cast to another.
big-andy-coates marked this conversation as resolved.
Show resolved Hide resolved
*
* <p>This defines if KSQL supports <i>implicitly</i> converting one numeric type to another.
*
* <p>Types can always be upcast to themselves. Only numeric types can be upcast to different
* numeric types. Note: STRING to DECIMAL handling is not seen as up-casting, it's parsing.
* <p>Types can always be cast to themselves. Only numeric types can be implicitly cast to other
* numeric types. Note: STRING to DECIMAL handling is not seen as casting: it's parsing.
*
* @param to the target type.
* @return true if this type can be upcast to the supplied type.
* @return true if this type can be implicitly cast to the supplied type.
*/
public boolean canUpCast(final SqlBaseType to) {
public boolean canImplicitlyCast(final SqlBaseType to) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to make the intent more clear.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

return this.equals(to)
|| (isNumber() && to.isNumber() && this.ordinal() <= to.ordinal());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void shouldBeNumber() {
public void shouldNotUpCastIfNotNumber() {
nonNumberTypes().forEach(sqlType -> assertThat(
sqlType + " should not upcast",
sqlType.canUpCast(SqlBaseType.DOUBLE),
sqlType.canImplicitlyCast(SqlBaseType.DOUBLE),
is(false))
);
}
Expand All @@ -65,51 +65,51 @@ public void shouldNotUpCastIfNotNumber() {
public void shouldUpCastIfNumber() {
numberTypes().forEach(sqlType -> assertThat(
sqlType + " should upcast",
sqlType.canUpCast(SqlBaseType.DOUBLE),
sqlType.canImplicitlyCast(SqlBaseType.DOUBLE),
is(true))
);
}

@Test
public void shouldUpCastToSelf() {
allTypes().forEach(sqlType ->
assertThat(sqlType + " should upcast to self", sqlType.canUpCast(sqlType), is(true)));
assertThat(sqlType + " should upcast to self", sqlType.canImplicitlyCast(sqlType), is(true)));
}

@Test
public void shouldUpCastInt() {
assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.BIGINT), is(true));
assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.DECIMAL), is(true));
assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.DOUBLE), is(true));
assertThat(SqlBaseType.INTEGER.canImplicitlyCast(SqlBaseType.BIGINT), is(true));
assertThat(SqlBaseType.INTEGER.canImplicitlyCast(SqlBaseType.DECIMAL), is(true));
assertThat(SqlBaseType.INTEGER.canImplicitlyCast(SqlBaseType.DOUBLE), is(true));
}

@Test
public void shouldUpCastBigInt() {
assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.DECIMAL), is(true));
assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.DOUBLE), is(true));
assertThat(SqlBaseType.BIGINT.canImplicitlyCast(SqlBaseType.DECIMAL), is(true));
assertThat(SqlBaseType.BIGINT.canImplicitlyCast(SqlBaseType.DOUBLE), is(true));
}

@Test
public void shouldUpCastDecimal() {
assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.DOUBLE), is(true));
assertThat(SqlBaseType.DECIMAL.canImplicitlyCast(SqlBaseType.DOUBLE), is(true));
}

@Test
public void shouldNotDownCastBigInt() {
assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.BIGINT.canImplicitlyCast(SqlBaseType.INTEGER), is(false));
}

@Test
public void shouldNotDownCastDecimal() {
assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.BIGINT), is(false));
assertThat(SqlBaseType.DECIMAL.canImplicitlyCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.DECIMAL.canImplicitlyCast(SqlBaseType.BIGINT), is(false));
}

@Test
public void shouldNotDownCastDouble() {
assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.BIGINT), is(false));
assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.DECIMAL), is(false));
assertThat(SqlBaseType.DOUBLE.canImplicitlyCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.DOUBLE.canImplicitlyCast(SqlBaseType.BIGINT), is(false));
assertThat(SqlBaseType.DOUBLE.canImplicitlyCast(SqlBaseType.DECIMAL), is(false));
}

private static Stream<SqlBaseType> numberTypes() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlType;
Expand All @@ -64,6 +65,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.kafka.connect.data.Schema;

Expand Down Expand Up @@ -302,10 +304,32 @@ public Void visitIsNullPredicate(
@Override
public Void visitSearchedCaseExpression(
final SearchedCaseExpression node,
final ExpressionTypeContext expressionTypeContext
final ExpressionTypeContext context
) {
validateSearchedCaseExpression(node);
process(node.getWhenClauses().get(0).getResult(), expressionTypeContext);
final Optional<SqlType> whenType = validateWhenClauses(node.getWhenClauses(), context);

final Optional<SqlType> defaultType = node.getDefaultValue()
.map(ExpressionTypeManager.this::getExpressionSqlType);

if (whenType.isPresent() && defaultType.isPresent()) {
if (!whenType.get().equals(defaultType.get())) {
throw new KsqlException("Invalid Case expression. "
+ "Schema for the default clause should be the same as for 'THEN' clauses."
+ System.lineSeparator()
+ "THEN schema: " + whenType.get() + "."
+ System.lineSeparator()
+ "DEFAULT schema: " + defaultType.get() + "."
);
}

context.setSqlType(whenType.get());
} else if (whenType.isPresent()) {
context.setSqlType(whenType.get());
} else if (defaultType.isPresent()) {
context.setSqlType(defaultType.get());
} else {
throw new KsqlException("Invalid Case expression. All case branches have NULL schema");
}
return null;
}

Expand Down Expand Up @@ -449,38 +473,46 @@ public Void visitWhenClause(
throw VisitorUtil.illegalState(this, whenClause);
}

private void validateSearchedCaseExpression(
final SearchedCaseExpression searchedCaseExpression) {
final Schema firstResultSchema = getExpressionSchema(
searchedCaseExpression.getWhenClauses().get(0).getResult());
searchedCaseExpression.getWhenClauses()
.forEach(whenClause -> validateWhenClause(whenClause, firstResultSchema));
searchedCaseExpression.getDefaultValue()
.map(ExpressionTypeManager.this::getExpressionSchema)
.filter(defaultSchema -> !firstResultSchema.equals(defaultSchema))
.ifPresent(badSchema -> {
throw new KsqlException("Invalid Case expression."
+ " Schema for the default clause should be the same as schema for THEN clauses."
+ " Result scheme: " + firstResultSchema + "."
+ " Schema for default expression is " + badSchema);
});
}

private void validateWhenClause(final WhenClause whenClause,
final Schema expectedResultSchema) {
final Schema operandSchema = getExpressionSchema(whenClause.getOperand());
if (!operandSchema.equals(Schema.OPTIONAL_BOOLEAN_SCHEMA)) {
throw new KsqlException("When operand schema should be boolean. Schema for ("
+ whenClause.getOperand() + ") is " + operandSchema);
}
final Schema resultSchema = getExpressionSchema(whenClause.getResult());
if (!expectedResultSchema.equals(resultSchema)) {
throw new KsqlException("Invalid Case expression."
+ " Schemas for 'THEN' clauses should be the same."
+ " Result schema: " + expectedResultSchema + "."
+ " Schema for THEN expression '" + whenClause + "'"
+ " is " + resultSchema);
private Optional<SqlType> validateWhenClauses(
final List<WhenClause> whenClauses,
final ExpressionTypeContext context
) {
Optional<SqlType> previousResult = Optional.empty();
for (final WhenClause whenClause : whenClauses) {
process(whenClause.getOperand(), context);

final SqlType operandType = context.getSqlType();

if (operandType.baseType() != SqlBaseType.BOOLEAN) {
throw new KsqlException("When operand schema should be boolean."
big-andy-coates marked this conversation as resolved.
Show resolved Hide resolved
+ System.lineSeparator()
+ "Schema for '" + whenClause.getOperand() + "' is " + operandType
);
}

process(whenClause.getResult(), context);

final SqlType resultType = context.getSqlType();
if (resultType == null) {
continue; // `null` type
}

if (!previousResult.isPresent()) {
previousResult = Optional.of(resultType);
continue;
}

if (!previousResult.get().equals(resultType)) {
throw new KsqlException("Invalid Case expression. "
+ "Schemas for all 'THEN' clauses should be the same."
+ System.lineSeparator()
+ "THEN expression '" + whenClause + "' has schema: " + resultType + "."
+ System.lineSeparator()
+ "Previous THEN expression(s) schema: " + previousResult.get() + ".");
}
}

return previousResult;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,11 @@ public void shouldFailIfWhenIsNotBoolean() {
Optional.empty()
);
expectedException.expect(KsqlException.class);
expectedException.expectMessage("When operand schema should be boolean. Schema for ((TEST1.COL0 + 10)) is Schema{INT64}");
expectedException.expectMessage(
"When operand schema should be boolean."
+ System.lineSeparator()
+ "Schema for '(TEST1.COL0 + 10)' is BIGINT"
);

// When:
expressionTypeManager.getExpressionSqlType(expression);
Expand All @@ -444,7 +448,13 @@ public void shouldFailOnInconsistentWhenResultType() {
Optional.empty()
);
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Invalid Case expression. Schemas for 'THEN' clauses should be the same. Result schema: Schema{STRING}. Schema for THEN expression 'WHEN (TEST1.COL0 = 10) THEN 10' is Schema{INT32}");
expectedException.expectMessage(
"Invalid Case expression. Schemas for all 'THEN' clauses should be the same."
big-andy-coates marked this conversation as resolved.
Show resolved Hide resolved
+ System.lineSeparator()
+ "THEN expression 'WHEN (TEST1.COL0 = 10) THEN 10' has schema: INTEGER."
+ System.lineSeparator()
+ "Previous THEN expression(s) schema: STRING."
);

// When:
expressionTypeManager.getExpressionSqlType(expression);
Expand All @@ -463,7 +473,13 @@ public void shouldFailIfDefaultHasDifferentTypeToWhen() {
Optional.of(new BooleanLiteral("true"))
);
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Invalid Case expression. Schema for the default clause should be the same as schema for THEN clauses. Result scheme: Schema{STRING}. Schema for default expression is Schema{BOOLEAN}");
expectedException.expectMessage(
"Invalid Case expression. Schema for the default clause should be the same as for 'THEN' clauses."
big-andy-coates marked this conversation as resolved.
Show resolved Hide resolved
+ System.lineSeparator()
+ "THEN schema: STRING."
+ System.lineSeparator()
+ "DEFAULT schema: BOOLEAN."
);

// When:
expressionTypeManager.getExpressionSqlType(expression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
{
"name": "searched case with arithmetic expression in result",
"statements": [
"CREATE STREAM orders (orderid bigint, ORDERUNITS double) WITH (kafka_topic='test_topic', key='orderid', value_format='JSON');",
"CREATE STREAM orders (orderid bigint, ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN orderid + 2 END AS case_resault FROM orders;"
],
"inputs": [
Expand All @@ -43,7 +43,7 @@
{
"name": "searched case with null in when",
"statements": [
"CREATE STREAM orders (orderid bigint, ORDERUNITS double) WITH (kafka_topic='test_topic', key='orderid', value_format='JSON');",
"CREATE STREAM orders (orderid bigint, ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits > 2.0 THEN 'foo' ELSE 'default' END AS case_resault FROM orders;"
],
"inputs": [
Expand All @@ -58,6 +58,66 @@
}
]
},
{
"name": "searched case returning null in first branch",
"statements": [
"CREATE STREAM orders (ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN null WHEN orderunits < 4.0 THEN 'medium' ELSE 'large' END AS case_result FROM orders;"
],
"inputs": [
{"topic": "test_topic", "value": {"ORDERUNITS": 4.2}},
{"topic": "test_topic", "value": {"ORDERUNITS": 3.99}},
{"topic": "test_topic", "value": {"ORDERUNITS": 1.1}}
],
"outputs": [
{"topic": "S1", "value": {"CASE_RESULT": "large"}},
{"topic": "S1", "value": {"CASE_RESULT": "medium"}},
{"topic": "S1", "value": {"CASE_RESULT": null}}
]
},
{
"name": "searched case returning null in later branch",
"statements": [
"CREATE STREAM orders (ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN 'small' WHEN orderunits < 4.0 THEN null ELSE 'large' END AS case_result FROM orders;"
],
"inputs": [
{"topic": "test_topic", "value": {"ORDERUNITS": 4.2}},
{"topic": "test_topic", "value": {"ORDERUNITS": 3.99}},
{"topic": "test_topic", "value": {"ORDERUNITS": 1.1}}
],
"outputs": [
{"topic": "S1", "value": {"CASE_RESULT": "large"}},
{"topic": "S1", "value": {"CASE_RESULT": null}},
{"topic": "S1", "value": {"CASE_RESULT": "small"}}
]
},
{
"name": "searched case returning null in default branch",
"statements": [
"CREATE STREAM orders (ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN 'small' ELSE null END AS case_result FROM orders;"
],
"inputs": [
{"topic": "test_topic", "value": {"ORDERUNITS": 4.2}},
{"topic": "test_topic", "value": {"ORDERUNITS": 1.1}}
],
"outputs": [
{"topic": "S1", "value": {"CASE_RESULT": null}},
{"topic": "S1", "value": {"CASE_RESULT": "small"}}
]
},
{
"name": "searched case returning null in all branch",
"statements": [
"CREATE STREAM orders (ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN null ELSE null END AS case_result FROM orders;"
],
"expectedException": {
"type": "io.confluent.ksql.util.KsqlStatementException",
"message": "Invalid Case expression. All case branches have NULL schema"
}
},
{
"name": "searched case expression with structs, multiple expression and the same type",
"statements": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public <T> Optional<T> coerce(final Object value, final SqlType targetType) {
return coerceDecimal(value, (SqlDecimal) targetType);
}

if (!(value instanceof Number) || !valueSqlType.canUpCast(targetType.baseType())) {
if (!(value instanceof Number) || !valueSqlType.canImplicitlyCast(targetType.baseType())) {
return Optional.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ private static boolean coercionShouldBeSupported(
// Handled by parsing the string to a decimal:
return true;
}
return fromBaseType.canUpCast(toBaseType);
return fromBaseType.canImplicitlyCast(toBaseType);
}

private static List<SqlBaseType> supportedTypes() {
Expand Down