Skip to content

Commit

Permalink
feat: implement comparisons for TIME/DATE (#7734)
Browse files Browse the repository at this point in the history
* feat: implement comparisons for TIME/DATE

* rename some stuff

* add compareutil test, reject time/timestamp comparisons

* checkstyle
  • Loading branch information
Zara Lim committed Jun 28, 2021
1 parent cd1a988 commit 78b9ae8
Show file tree
Hide file tree
Showing 67 changed files with 7,769 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -713,27 +713,66 @@ private String visitBooleanComparisonExpression(final ComparisonExpression.Type
}
}

private String visitTimestampComparisonExpression(
private String visitTimeComparisonExpression(
final ComparisonExpression.Type type,
final SqlType left,
final SqlType right
) {
final String comparator = SQL_COMPARE_TO_JAVA.get(type);
if (comparator == null) {
throw new KsqlException("Unexpected timestamp comparison: " + type.getValue());
throw new KsqlException("Unexpected scalar comparison: " + type.getValue());
}

final String compareLeft;
final String compareRight;

if (left.baseType() == SqlBaseType.TIME || right.baseType() == SqlBaseType.TIME) {
compareLeft = toTime(left, 1);
compareRight = toTime(right, 2);
} else if (
left.baseType() == SqlBaseType.TIMESTAMP || right.baseType() == SqlBaseType.TIMESTAMP
) {
compareLeft = toTimestamp(left, 1);
compareRight = toTimestamp(right, 2);
} else {
compareLeft = toDate(left, 1);
compareRight = toDate(right, 2);
}

return String.format(
"(%s.compareTo(%s) %s 0)",
toTimestamp(left, 1),
toTimestamp(right, 2),
compareLeft,
compareRight,
comparator
);
}

private String toTime(final SqlType schema, final int index) {
switch (schema.baseType()) {
case TIME:
return "%" + index + "$s";
case STRING:
return "SqlTimeTypes.parseTime(%" + index + "$s)";
default:
throw new KsqlException("Unexpected comparison to TIME: " + schema.baseType());
}
}

private String toDate(final SqlType schema, final int index) {
switch (schema.baseType()) {
case DATE:
return "%" + index + "$s";
case STRING:
return "SqlTimeTypes.parseDate(%" + index + "$s)";
default:
throw new KsqlException("Unexpected comparison to DATE: " + schema.baseType());
}
}

private String toTimestamp(final SqlType schema, final int index) {
switch (schema.baseType()) {
case TIMESTAMP:
case DATE:
return "%" + index + "$s";
case STRING:
return "SqlTimeTypes.parseTimestamp(%" + index + "$s)";
Expand All @@ -755,9 +794,8 @@ public Pair<String, SqlType> visitComparisonExpression(
|| right.getRight().baseType() == SqlBaseType.DECIMAL) {
exprFormat += visitBytesComparisonExpression(
node.getType(), left.getRight(), right.getRight());
} else if (left.getRight().baseType() == SqlBaseType.TIMESTAMP
|| right.getRight().baseType() == SqlBaseType.TIMESTAMP) {
exprFormat += visitTimestampComparisonExpression(
} else if (left.getRight().baseType().isTime() || right.getRight().baseType().isTime()) {
exprFormat += visitTimeComparisonExpression(
node.getType(), left.getRight(), right.getRight());
} else {
switch (left.getRight().baseType()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ final class ComparisonUtil {
.add(handler(SqlBaseType.ARRAY, ComparisonUtil::handleArray))
.add(handler(SqlBaseType.MAP, ComparisonUtil::handleMap))
.add(handler(SqlBaseType.STRUCT, ComparisonUtil::handleStruct))
.add(handler(SqlBaseType.TIMESTAMP, ComparisonUtil::handleTimestamp))
.add(handler(SqlBaseType.TIME, ComparisonUtil::handleTime))
.add(handler(SqlBaseType.DATE, ComparisonUtil::handleDateOrTimestamp))
.add(handler(SqlBaseType.TIMESTAMP, ComparisonUtil::handleDateOrTimestamp))
.build();

private ComparisonUtil() {
Expand Down Expand Up @@ -77,7 +79,7 @@ private static boolean handleNumber(final Type operator, final SqlType right) {
}

private static boolean handleString(final Type operator, final SqlType right) {
return right.baseType() == SqlBaseType.STRING || right.baseType() == SqlBaseType.TIMESTAMP;
return right.baseType() == SqlBaseType.STRING || right.baseType().isTime();
}

private static boolean handleBoolean(final Type operator, final SqlType right) {
Expand All @@ -96,8 +98,14 @@ private static boolean handleStruct(final Type operator, final SqlType right) {
return right.baseType() == SqlBaseType.STRUCT && isEqualityOperator(operator);
}

private static boolean handleTimestamp(final Type operator, final SqlType right) {
return right.baseType() == SqlBaseType.TIMESTAMP || right.baseType() == SqlBaseType.STRING;
private static boolean handleDateOrTimestamp(final Type operator, final SqlType right) {
return right.baseType() == SqlBaseType.DATE
|| right.baseType() == SqlBaseType.TIMESTAMP
|| right.baseType() == SqlBaseType.STRING;
}

private static boolean handleTime(final Type operator, final SqlType right) {
return right.baseType() == SqlBaseType.TIME || right.baseType() == SqlBaseType.STRING;
}

private static boolean isEqualityOperator(final Type operator) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import static io.confluent.ksql.execution.testutil.TestExpressions.COL1;
import static io.confluent.ksql.execution.testutil.TestExpressions.COL3;
import static io.confluent.ksql.execution.testutil.TestExpressions.COL7;
import static io.confluent.ksql.execution.testutil.TestExpressions.DATECOL;
import static io.confluent.ksql.execution.testutil.TestExpressions.MAPCOL;
import static io.confluent.ksql.execution.testutil.TestExpressions.SCHEMA;
import static io.confluent.ksql.execution.testutil.TestExpressions.TIMECOL;
import static io.confluent.ksql.execution.testutil.TestExpressions.TIMESTAMPCOL;
import static io.confluent.ksql.execution.testutil.TestExpressions.literal;
import static io.confluent.ksql.name.SourceName.of;
Expand Down Expand Up @@ -83,6 +85,7 @@
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Time;
Expand Down Expand Up @@ -821,6 +824,98 @@ public void shouldGenerateCorrectCodeForDecimalUnaryPlus() {
assertThat(java, is("(COL8.plus(new MathContext(2, RoundingMode.UNNECESSARY)))"));
}

@Test
public void shouldGenerateCorrectCodeForTimeTimeLT() {
// Given:
final ComparisonExpression compExp = new ComparisonExpression(
Type.LESS_THAN,
TIMECOL,
TIMECOL
);

// When:
final String java = sqlToJavaVisitor.process(compExp);

// Then:
assertThat(java, containsString("(COL12.compareTo(COL12) < 0)"));
}

@Test
public void shouldGenerateCorrectCodeForTimeStringEQ() {
// Given:
final ComparisonExpression compExp = new ComparisonExpression(
Type.EQUAL,
TIMECOL,
new StringLiteral("01:23:45")
);

// When:
final String java = sqlToJavaVisitor.process(compExp);

// Then:
assertThat(java, containsString("(COL12.compareTo(SqlTimeTypes.parseTime(\"01:23:45\")) == 0)"));
}

@Test
public void shouldThrowOnTimestampTimeLEQ() {
// Given:
final ComparisonExpression compExp = new ComparisonExpression(
Type.LESS_THAN_OR_EQUAL,
TIMESTAMPCOL,
TIMECOL
);

// Then:
final Exception e = assertThrows(KsqlException.class, () -> sqlToJavaVisitor.process(compExp));
assertThat(e.getMessage(), is("Unexpected comparison to TIME: TIMESTAMP"));
}

@Test
public void shouldThrowOnTimeDateNEQ() {
// Given:
final ComparisonExpression compExp = new ComparisonExpression(
Type.NOT_EQUAL,
TIMECOL,
DATECOL
);

// Then:
final Exception e = assertThrows(KsqlException.class, () -> sqlToJavaVisitor.process(compExp));
assertThat(e.getMessage(), is("Unexpected comparison to TIME: DATE"));
}

@Test
public void shouldGenerateCorrectCodeForDateDateLT() {
// Given:
final ComparisonExpression compExp = new ComparisonExpression(
Type.LESS_THAN,
DATECOL,
DATECOL
);

// When:
final String java = sqlToJavaVisitor.process(compExp);

// Then:
assertThat(java, containsString("(COL13.compareTo(COL13) < 0)"));
}

@Test
public void shouldGenerateCorrectCodeForDateStringEQ() {
// Given:
final ComparisonExpression compExp = new ComparisonExpression(
Type.EQUAL,
DATECOL,
new StringLiteral("2021-06-23")
);

// When:
final String java = sqlToJavaVisitor.process(compExp);

// Then:
assertThat(java, containsString("(COL13.compareTo(SqlTimeTypes.parseDate(\"2021-06-23\")) == 0)"));
}

@Test
public void shouldGenerateCorrectCodeForTimestampTimestampLT() {
// Given:
Expand Down Expand Up @@ -869,6 +964,22 @@ public void shouldGenerateCorrectCodeForTimestampStringGEQ() {
assertThat(java, containsString("(SqlTimeTypes.parseTimestamp(\"2020-01-01T00:00:00\").compareTo(COL10) >= 0)"));
}

@Test
public void shouldGenerateCorrectCodeForTimestampDateGT() {
// Given:
final ComparisonExpression compExp = new ComparisonExpression(
Type.GREATER_THAN,
TIMESTAMPCOL,
DATECOL
);

// When:
final String java = sqlToJavaVisitor.process(compExp);

// Then:
assertThat(java, containsString("(COL10.compareTo(COL13) > 0)"));
}

@Test
public void shouldGenerateCorrectCodeForIntervalUnit() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ private TestExpressions() {
.valueColumn(ColumnName.of("COL9"), SqlTypes.decimal(2, 1))
.valueColumn(ColumnName.of("COL10"), SqlTypes.TIMESTAMP)
.valueColumn(ColumnName.of("COL11"), SqlTypes.BOOLEAN)
.valueColumn(ColumnName.of("COL12"), SqlTypes.TIME)
.valueColumn(ColumnName.of("COL13"), SqlTypes.DATE)
.build();

public static final UnqualifiedColumnReferenceExp COL0 = columnRef("COL0");
Expand All @@ -49,6 +51,8 @@ private TestExpressions() {
public static final UnqualifiedColumnReferenceExp COL8 = columnRef("COL8");
public static final UnqualifiedColumnReferenceExp TIMESTAMPCOL = columnRef("COL10");
public static final UnqualifiedColumnReferenceExp COL11 = columnRef("COL11");
public static final UnqualifiedColumnReferenceExp TIMECOL = columnRef("COL12");
public static final UnqualifiedColumnReferenceExp DATECOL = columnRef("COL13");

private static UnqualifiedColumnReferenceExp columnRef(final String name) {
return new UnqualifiedColumnReferenceExp(ColumnName.of(name));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,24 @@ public class ComparisonUtilTest {
SqlTypes.array(SqlTypes.STRING),
SqlTypes.map(SqlTypes.BIGINT, SqlTypes.STRING),
SqlTypes.struct().field("foo", SqlTypes.BIGINT).build(),
SqlTypes.TIMESTAMP
SqlTypes.TIMESTAMP,
SqlTypes.TIME,
SqlTypes.DATE
);

private static final List<List<Boolean>> expectedResults = ImmutableList.of(
ImmutableList.of(true, false, false, false, false, false, false, false, false, false), // Boolean
ImmutableList.of(false, true, true, true, true, false, false, false, false, false), // Int
ImmutableList.of(false, true, true, true, true, false, false, false, false, false), // BigInt
ImmutableList.of(false, true, true, true, true, false, false, false, false, false), // Double
ImmutableList.of(false, true, true, true, true, false, false, false, false, false), // Decimal
ImmutableList.of(false, false, false, false, false, true, false, false, false, true), // String
ImmutableList.of(false, false, false, false, false, false, true, false, false, false), // Array
ImmutableList.of(false, false, false, false, false, false, false, true, false, false), // Map
ImmutableList.of(false, false, false, false, false, false, false, false, true, false), // Struct
ImmutableList.of(false, false, false, false, false, true, false, false, false, true) // Timestamp
ImmutableList.of(true, false, false, false, false, false, false, false, false, false, false, false), // Boolean
ImmutableList.of(false, true, true, true, true, false, false, false, false, false, false, false), // Int
ImmutableList.of(false, true, true, true, true, false, false, false, false, false, false, false), // BigInt
ImmutableList.of(false, true, true, true, true, false, false, false, false, false, false, false), // Double
ImmutableList.of(false, true, true, true, true, false, false, false, false, false, false, false), // Decimal
ImmutableList.of(false, false, false, false, false, true, false, false, false, true, true, true), // String
ImmutableList.of(false, false, false, false, false, false, true, false, false, false, false, false), // Array
ImmutableList.of(false, false, false, false, false, false, false, true, false, false, false, false), // Map
ImmutableList.of(false, false, false, false, false, false, false, false, true, false, false, false), // Struct
ImmutableList.of(false, false, false, false, false, true, false, false, false, true, false, true), // Timestamp
ImmutableList.of(false, false, false, false, false, true, false, false, false, false, true, false), // Time
ImmutableList.of(false, false, false, false, false, true, false, false, false, true, false, true) // Date
);

@Test
Expand Down
Loading

0 comments on commit 78b9ae8

Please sign in to comment.