[FLINK-25856][python] Fix use of UserDefinedType in from_elements
This closes apache#18826.
HuangXingBo committed Feb 21, 2022
1 parent a7411b6 commit 6056680
Showing 4 changed files with 125 additions and 13 deletions.
74 changes: 73 additions & 1 deletion flink-python/pyflink/table/tests/test_table_environment_api.py
@@ -38,7 +38,7 @@
from pyflink.table.explain_detail import ExplainDetail
from pyflink.table.expressions import col, source_watermark
from pyflink.table.table_descriptor import TableDescriptor
-from pyflink.table.types import RowType, Row
+from pyflink.table.types import RowType, Row, UserDefinedType
from pyflink.table.udf import udf
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import (
@@ -533,8 +533,80 @@ def test_collect_for_all_data_types(self):
self.assertEqual(expected_result, collected_result)


class VectorUDT(UserDefinedType):

@classmethod
def sql_type(cls):
return DataTypes.ROW(
[
DataTypes.FIELD("type", DataTypes.TINYINT()),
DataTypes.FIELD("size", DataTypes.INT()),
DataTypes.FIELD("indices", DataTypes.ARRAY(DataTypes.INT())),
DataTypes.FIELD("values", DataTypes.ARRAY(DataTypes.DOUBLE())),
]
)

@classmethod
def module(cls):
return "pyflink.ml.core.linalg"

def serialize(self, obj):
if isinstance(obj, DenseVector):
values = [float(v) for v in obj._values]
return 1, None, None, values
else:
raise TypeError("Cannot serialize %r of type %r".format(obj, type(obj)))

def deserialize(self, datum):
pass


class DenseVector(object):
__UDT__ = VectorUDT()

def __init__(self, values):
self._values = values

def size(self) -> int:
return len(self._values)

def get(self, i: int):
return self._values[i]

def to_array(self):
return self._values

@property
def values(self):
return self._values

def __str__(self):
return "[" + ",".join([str(v) for v in self._values]) + "]"

def __repr__(self):
return "DenseVector([%s])" % (", ".join(str(i) for i in self._values))


class BatchTableEnvironmentTests(PyFlinkBatchTableTestCase):

def test_udt(self):
self.t_env.from_elements([
(DenseVector([1, 2, 3, 4]), 0., 1.),
(DenseVector([2, 2, 3, 4]), 0., 2.),
(DenseVector([3, 2, 3, 4]), 0., 3.),
(DenseVector([4, 2, 3, 4]), 0., 4.),
(DenseVector([5, 2, 3, 4]), 0., 5.),
(DenseVector([11, 2, 3, 4]), 1., 1.),
(DenseVector([12, 2, 3, 4]), 1., 2.),
(DenseVector([13, 2, 3, 4]), 1., 3.),
(DenseVector([14, 2, 3, 4]), 1., 4.),
(DenseVector([15, 2, 3, 4]), 1., 5.),
],
DataTypes.ROW([
DataTypes.FIELD("features", VectorUDT()),
DataTypes.FIELD("label", DataTypes.DOUBLE()),
DataTypes.FIELD("weight", DataTypes.DOUBLE())]))

def test_explain_with_multi_sinks(self):
t_env = self.t_env
source = t_env.from_elements([(1, "Hi", "Hello"), (2, "Hello", "Hello")], ["a", "b", "c"])
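The deserialize method of VectorUDT above is left as a stub: test_udt only exercises type verification and the Python-to-SQL direction. A minimal implementation consistent with the serialize method above might look like this (a sketch, not part of this commit):

    def deserialize(self, datum):
        # serialize() emits the row (type, size, indices, values); a type tag
        # of 1 marks a dense vector, which can be rebuilt from its values field.
        if datum[0] == 1:
            return DenseVector(datum[3])
        raise TypeError("Cannot deserialize %r" % (datum,))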
9 changes: 7 additions & 2 deletions flink-python/pyflink/table/types.py
@@ -2202,12 +2202,17 @@ def verify_varbinary(obj):
verify_value = verify_varbinary

elif isinstance(data_type, UserDefinedType):
-        verifier = _create_type_verifier(data_type.sql_type(), name=name)
+        sql_type = data_type.sql_type()
+        verifier = _create_type_verifier(sql_type, name=name)

def verify_udf(obj):
if not (hasattr(obj, '__UDT__') and obj.__UDT__ == data_type):
raise ValueError(new_msg("%r is not an instance of type %r" % (obj, data_type)))
-            verifier(data_type.to_sql_type(obj))
+            data = data_type.to_sql_type(obj)
+            if isinstance(sql_type, RowType):
+                # remove the RowKind value in the first position.
+                data = data[1:]
+            verifier(data)

verify_value = verify_udf

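The data[1:] step above compensates for PyFlink's internal row encoding: converting a value through a RowType prepends the row's RowKind, so a RowType-backed UDT would otherwise hand the verifier one extra leading element. A hedged illustration using the VectorUDT test class (the concrete values are assumptions, with 0 standing for RowKind.INSERT):

    data = data_type.to_sql_type(DenseVector([1.0, 2.0]))
    # e.g. (0, 1, None, None, [1.0, 2.0]) -- the leading 0 is the RowKind
    data = data[1:]
    # (1, None, None, [1.0, 2.0]) -- exactly the four fields sql_type() declares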
PythonBridgeUtils.java
@@ -24,12 +24,16 @@
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.api.java.typeutils.MapTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.typeinfo.python.PickledByteArrayTypeInfo;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.DateType;
@@ -305,8 +309,16 @@ public static Object getPickledBytesFromJavaObject(Object obj, TypeInformation<?
&& BasicTypeInfo.getInfoFor(dataType.getTypeClass()) == FLOAT_TYPE_INFO) {
// Serialization of float type with pickler loses precision.
return pickler.dumps(String.valueOf(obj));
-        } else {
+        } else if (dataType instanceof PickledByteArrayTypeInfo
+                || dataType instanceof BasicTypeInfo) {
             return pickler.dumps(obj);
+        } else {
+            // other typeinfos will use the corresponding serializer to serialize data.
+            TypeSerializer serializer = dataType.createSerializer(null);
+            ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos();
+            DataOutputViewStreamWrapper baosWrapper = new DataOutputViewStreamWrapper(baos);
+            serializer.serialize(obj, baosWrapper);
+            return pickler.dumps(baos.toByteArray());
}
}
}
PythonTableUtils.java
@@ -25,16 +25,21 @@
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.io.CollectionInputFormat;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.MapTypeInfo;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.streaming.api.typeinfo.python.PickledByteArrayTypeInfo;
import org.apache.flink.table.api.Types;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;

+import java.io.IOException;
import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
@@ -72,7 +77,7 @@ private PythonTableUtils() {}
final List<Object[]> data,
final TypeInformation<Row> dataType,
final ExecutionConfig config) {
Function<Object, Object> converter = converter(dataType);
Function<Object, Object> converter = converter(dataType, config);
return new CollectionInputFormat<>(
data.stream()
.map(objects -> (Row) converter.apply(objects))
@@ -92,7 +97,7 @@ private PythonTableUtils() {}
@SuppressWarnings("unchecked")
public static <T> InputFormat<T, ?> getCollectionInputFormat(
final List<T> data, final TypeInformation<T> dataType, final ExecutionConfig config) {
Function<Object, Object> converter = converter(dataType);
Function<Object, Object> converter = converter(dataType, config);
return new CollectionInputFormat<>(
data.stream()
.map(objects -> (T) converter.apply(objects))
@@ -253,7 +258,8 @@ private static BiFunction<Integer, Function<Integer, Object>, Object> arrayConst
};
}

private static Function<Object, Object> converter(final TypeInformation<?> dataType) {
private static Function<Object, Object> converter(
final TypeInformation<?> dataType, final ExecutionConfig config) {
if (dataType.equals(Types.BOOLEAN())) {
return b -> b instanceof Boolean ? b : null;
}
@@ -409,7 +415,7 @@ private static Function<Object, Object> converter(final TypeInformation<?> dataT
? ((BasicArrayTypeInfo<?, ?>) dataType).getComponentInfo()
: ((ObjectArrayTypeInfo<?, ?>) dataType).getComponentInfo();
boolean primitive = dataType instanceof PrimitiveArrayTypeInfo;
Function<Object, Object> elementConverter = converter(elementType);
Function<Object, Object> elementConverter = converter(elementType, config);
BiFunction<Integer, Function<Integer, Object>, Object> arrayConstructor =
arrayConstructor(elementType, primitive);
return c -> {
@@ -431,9 +437,9 @@
}
if (dataType instanceof MapTypeInfo) {
Function<Object, Object> keyConverter =
converter(((MapTypeInfo<?, ?>) dataType).getKeyTypeInfo());
converter(((MapTypeInfo<?, ?>) dataType).getKeyTypeInfo(), config);
Function<Object, Object> valueConverter =
converter(((MapTypeInfo<?, ?>) dataType).getValueTypeInfo());
converter(((MapTypeInfo<?, ?>) dataType).getValueTypeInfo(), config);
return c ->
c instanceof Map
? ((Map<?, ?>) c)
@@ -450,7 +456,7 @@ private static Function<Object, Object> converter(final TypeInformation<?> dataT
TypeInformation<?>[] fieldTypes = ((RowTypeInfo) dataType).getFieldTypes();
List<Function<Object, Object>> fieldConverters =
Arrays.stream(fieldTypes)
-                        .map(PythonTableUtils::converter)
+                        .map(x -> PythonTableUtils.converter(x, config))
.collect(Collectors.toList());
return c -> {
if (c != null && c.getClass().isArray()) {
@@ -480,7 +486,7 @@ private static Function<Object, Object> converter(final TypeInformation<?> dataT
TypeInformation<?>[] fieldTypes = ((TupleTypeInfo<?>) dataType).getFieldTypes();
List<Function<Object, Object>> fieldConverters =
Arrays.stream(fieldTypes)
-                        .map(PythonTableUtils::converter)
+                        .map(x -> PythonTableUtils.converter(x, config))
.collect(Collectors.toList());
return c -> {
if (c != null && c.getClass().isArray()) {
@@ -505,7 +511,24 @@ private static Function<Object, Object> converter(final TypeInformation<?> dataT
};
}

-        return Function.identity();
+        return c -> {
+            if (c.getClass() != byte[].class || dataType instanceof PickledByteArrayTypeInfo) {
+                return c;
+            }
+
+            // other typeinfos will use the corresponding serializer to deserialize data.
+            byte[] b = (byte[]) c;
+            TypeSerializer<?> dataSerializer = dataType.createSerializer(config);
+            ByteArrayInputStreamWithPos bais = new ByteArrayInputStreamWithPos();
+            DataInputViewStreamWrapper baisWrapper = new DataInputViewStreamWrapper(bais);
+            bais.setBuffer(b, 0, b.length);
+            try {
+                return dataSerializer.deserialize(baisWrapper);
+            } catch (IOException e) {
+                throw new IllegalStateException(
+                        "Failed to deserialize the object with datatype " + dataType, e);
+            }
+        };
}

private static int getOffsetFromLocalMillis(final long millisLocal) {
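Taken together, the two Java-side changes cover both directions of the Python/JVM boundary: the converter fallback above decodes serializer-encoded bytes arriving from from_elements, while getPickledBytesFromJavaObject serializes non-basic values (such as UDT data) with the type's own serializer before pickling the bytes back to Python. A hedged end-to-end sketch of the usage this commit enables, assuming the test classes above plus a working deserialize such as the one sketched earlier:

    from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment

    t_env = TableEnvironment.create(EnvironmentSettings.in_batch_mode())
    table = t_env.from_elements(
        [(DenseVector([1.0, 2.0]), 0.0)],
        DataTypes.ROW([
            DataTypes.FIELD("features", VectorUDT()),
            DataTypes.FIELD("label", DataTypes.DOUBLE())]))
    # Collecting round-trips the UDT column through the new serializer path.
    with table.execute().collect() as results:
        for row in results:
            print(row)  # e.g. Row(features=DenseVector([1.0, 2.0]), label=0.0)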
