# Add data handling support for Pydantic instances (streamlit#9290)
## Describe your changes

This adds support for correctly handling Pydantic model instances as
input for Streamlit commands such as `st.dataframe`, `st.data_editor`,
charts, `st.map`, `st.write`, `st.selectbox`, and many more.
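
A minimal usage sketch of the new behavior (the model and field names below are illustrative, not part of this change):

```python
import streamlit as st
from pydantic import BaseModel


class Task(BaseModel):
    name: str
    done: bool
    priority: int


task = Task(name="Write docs", done=False, priority=2)

# A Pydantic model instance is now handled like other dict-like inputs:
st.dataframe(task)    # shown as a key-value table
st.data_editor(task)  # editable key-value view
st.write(task)        # routed to st.json
st.json(task)         # serialized like a plain dict
```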

This also applies a small refactoring to the `st.json` handling: dict
values, keys, and items views are now visualized via `st.json` when used
in `st.write` instead of being printed out as strings.
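
For example, dict views passed to `st.write` now use the interactive JSON viewer (the values here are only for illustration):

```python
import streamlit as st

settings = {"theme": "dark", "page_size": 25}

# Previously rendered via their string representation, now shown via st.json:
st.write(settings.keys())    # KeysView
st.write(settings.values())  # ValuesView
st.write(settings.items())   # ItemsView
```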

## Testing Plan

- Added a Pydantic model case to `data_mocks` -> this is used by various
tests covering many relevant commands.

---

**Contribution License Agreement**

By submitting this pull request you agree that all contributions to this
project are made under the Apache 2.0 license.
lukasmasuch committed Aug 16, 2024
1 parent 066dad7 commit e31343f
Showing 6 changed files with 101 additions and 15 deletions.
14 changes: 10 additions & 4 deletions lib/streamlit/dataframe_util.py
@@ -22,16 +22,16 @@
import math
import re
from collections import ChainMap, UserDict, UserList, deque
from collections.abc import ItemsView, Mapping
from collections.abc import ItemsView
from enum import Enum, EnumMeta, auto
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Final,
Iterable,
List,
Mapping,
Protocol,
Sequence,
TypeVar,
@@ -44,12 +44,14 @@

from streamlit import config, errors, logger, string_util
from streamlit.type_util import (
CustomDict,
NumpyShape,
has_callable_attr,
is_custom_dict,
is_dataclass_instance,
is_list_like,
is_namedtuple,
is_pydantic_model,
is_type,
)

@@ -160,8 +162,9 @@ def iloc(self) -> _iLocIndexer: ...
"pa.Table",
"np.ndarray[Any, np.dtype[Any]]",
Iterable[Any],
Dict[Any, Any],
"Mapping[Any, Any]",
DBAPICursor,
CustomDict,
None,
]

@@ -657,7 +660,9 @@ def convert_anything_to_pandas_df(
return _dict_to_pandas_df(dataclasses.asdict(data))

# Support for dict-like objects
if isinstance(data, (ChainMap, MappingProxyType, UserDict)):
if isinstance(data, (ChainMap, MappingProxyType, UserDict)) or is_pydantic_model(
data
):
return _dict_to_pandas_df(dict(data))

# Try to convert to pandas.DataFrame. This will raise an error if df is not
@@ -1141,6 +1146,7 @@ def determine_data_format(input_data: Any) -> DataFormat:
or is_dataclass_instance(input_data)
or is_namedtuple(input_data)
or is_custom_dict(input_data)
or is_pydantic_model(input_data)
):
return DataFormat.KEY_VALUE_DICT
elif isinstance(input_data, (ItemsView, enumerate)):
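
Conceptually, the new branch treats a Pydantic model instance like the other dict-like inputs: `dict(model)` yields the field name/value pairs, which are then turned into a key-value DataFrame. A rough standalone sketch of that conversion (the exact construction inside `_dict_to_pandas_df` may differ):

```python
import pandas as pd
from pydantic import BaseModel


class Element(BaseModel):
    name: str
    is_widget: bool


element = Element(name="st.slider", is_widget=True)

# BaseModel supports iteration over (field, value) pairs, so dict() works
# without any Pydantic-specific conversion code:
as_dict = dict(element)  # {'name': 'st.slider', 'is_widget': True}

# Approximation of the resulting key-value DataFrame:
df = pd.Series(as_dict, name="value").to_frame()
print(df)
```
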
17 changes: 12 additions & 5 deletions lib/streamlit/elements/json.py
@@ -21,7 +21,12 @@

from streamlit.proto.Json_pb2 import Json as JsonProto
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.type_util import is_custom_dict, is_namedtuple
from streamlit.type_util import (
is_custom_dict,
is_list_like,
is_namedtuple,
is_pydantic_model,
)

if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
@@ -86,11 +91,13 @@ def json(
if is_namedtuple(body):
body = body._asdict()

if isinstance(body, (map, enumerate)):
body = list(body)
if isinstance(
body, (ChainMap, types.MappingProxyType, UserDict)
) or is_pydantic_model(body):
body = dict(body) # type: ignore

if isinstance(body, (ChainMap, types.MappingProxyType, UserDict)):
body = dict(body)
if is_list_like(body):
body = list(body)

if not isinstance(body, str):
try:
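
With this change, `st.json` normalizes its input in stages: namedtuples become dicts, mapping-like objects (now including Pydantic model instances) become plain dicts, and remaining list-like bodies become lists before serialization. A short illustration (the model below is an example):

```python
import types
from collections import ChainMap, UserDict

import streamlit as st
from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int


# All of these mapping-like inputs are converted to plain dicts before
# being serialized to JSON:
st.json(Point(x=1, y=2))
st.json(types.MappingProxyType({"a": 1}))
st.json(ChainMap({"a": 1}, {"b": 2}))
st.json(UserDict({"c": 3}))
```
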
13 changes: 10 additions & 3 deletions lib/streamlit/elements/write.py
@@ -18,6 +18,7 @@
import inspect
import types
from collections import ChainMap, UserDict, UserList
from collections.abc import ItemsView, KeysView, ValuesView
from io import StringIO
from typing import (
TYPE_CHECKING,
@@ -258,13 +259,13 @@ def write(self, *args: Any, unsafe_allow_html: bool = False, **kwargs) -> None:
- write(string) : Prints the formatted Markdown string, with
support for LaTeX expression, emoji shortcodes, and colored text.
See docs for st.markdown for more.
- write(data_frame) : Displays any dataframe-compatible value
as read-only table.
- write(dataframe) : Displays any dataframe-like object in a table.
- write(dict) : Displays dict-like in an interactive viewer.
- write(list) : Displays list-like in an interactive viewer.
- write(error) : Prints an exception specially.
- write(func) : Displays information about a function.
- write(module) : Displays information about the module.
- write(class) : Displays information about a class.
- write(dict) : Displays dict in an interactive widget.
- write(mpl_fig) : Displays a Matplotlib figure.
- write(generator) : Streams the output of a generator.
- write(openai.Stream) : Streams the output of an OpenAI stream.
@@ -276,8 +277,10 @@ def write(self, *args: Any, unsafe_allow_html: bool = False, **kwargs) -> None:
- write(bokeh_fig) : Displays a Bokeh figure.
- write(sympy_expr) : Prints SymPy expression using LaTeX.
- write(htmlable) : Prints _repr_html_() for the object if available.
- write(db_cursor) : Displays DB API 2.0 cursor results in a table.
- write(obj) : Prints str(obj) if otherwise unknown.
unsafe_allow_html : bool
Whether to render HTML within ``*args``. This only applies to
strings or objects falling back on ``_repr_html_()``. If this is
@@ -457,10 +460,14 @@ def flush_buffer():
UserDict,
ChainMap,
UserList,
ItemsView,
KeysView,
ValuesView,
),
)
or type_util.is_custom_dict(arg)
or type_util.is_namedtuple(arg)
or type_util.is_pydantic_model(arg)
):
flush_buffer()
self.dg.json(arg)
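
The routing decision in `st.write` can be summarized by a simplified standalone check like the one below. This is only a sketch based on the branch visible in this hunk; the real isinstance tuple contains additional types not shown here, and the `type_util` helpers are Streamlit internals:

```python
from collections import ChainMap, UserDict, UserList
from collections.abc import ItemsView, KeysView, ValuesView
from types import MappingProxyType
from typing import Any

from streamlit import type_util


def renders_as_json(arg: Any) -> bool:
    """Rough check for arguments that st.write hands off to st.json."""
    return (
        isinstance(
            arg,
            (
                dict,
                MappingProxyType,
                UserDict,
                ChainMap,
                UserList,
                ItemsView,
                KeysView,
                ValuesView,
            ),
        )
        or type_util.is_custom_dict(arg)
        or type_util.is_namedtuple(arg)
        or type_util.is_pydantic_model(arg)
    )
```
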
17 changes: 17 additions & 0 deletions lib/streamlit/type_util.py
@@ -112,6 +112,11 @@ def is_type(obj: object, fqn_type_pattern: str | re.Pattern[str]) -> bool:
return fqn_type_pattern.match(fqn_type) is not None


def _is_type_instance(obj: object, type_to_check: str) -> bool:
"""Check if instance of type without importing expensive modules."""
return type_to_check in [get_fqn(t) for t in type(obj).__mro__]


def get_fqn(the_type: type) -> str:
"""Get module.type_name for a given type."""
return f"{the_type.__module__}.{the_type.__qualname__}"
@@ -287,6 +292,18 @@ def is_pydeck(obj: object) -> TypeGuard[Deck]:
return is_type(obj, "pydeck.bindings.deck.Deck")


def is_pydantic_model(obj) -> bool:
"""True if input looks like a Pydantic model instance."""

if isinstance(obj, type):
# The obj is a class, but we
# only want to check for instances
# of Pydantic models, so we return False.
return False

return _is_type_instance(obj, "pydantic.main.BaseModel")


def is_custom_dict(obj: object) -> TypeGuard[CustomDict]:
"""True if input looks like one of the Streamlit custom dictionaries."""
from streamlit.runtime.context import StreamlitCookies, StreamlitHeaders
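
The detection above compares fully-qualified class names along the instance's MRO, so pydantic itself is never imported by the check. A standalone sketch of the same idea (the helper names here are illustrative):

```python
def _fqn(t: type) -> str:
    """module.qualname of a type, mirroring get_fqn above."""
    return f"{t.__module__}.{t.__qualname__}"


def looks_like_pydantic_model(obj: object) -> bool:
    """True for instances of pydantic.main.BaseModel subclasses,
    detected without importing pydantic."""
    if isinstance(obj, type):
        # Only instances count; the class itself is not data.
        return False
    return any(_fqn(t) == "pydantic.main.BaseModel" for t in type(obj).__mro__)
```

Because the check walks the MRO, it also matches user-defined subclasses of `BaseModel`, and it stays cheap when pydantic is not installed or simply never imported.
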
39 changes: 36 additions & 3 deletions lib/tests/streamlit/data_mocks.py
@@ -260,7 +260,7 @@ def data_generator():
1,
DataFormat.LIST_OF_VALUES,
["st.number_input", "st.text_area", "st.text_input"],
"markdown",
"json",
False,
list,
),
@@ -277,7 +277,7 @@ def data_generator():
1,
DataFormat.LIST_OF_VALUES,
["number", "text", "text"],
"markdown",
"json",
False,
list,
),
@@ -298,7 +298,7 @@ def data_generator():
("st.text_area", "text"),
("st.text_input", "text"),
],
"markdown",
"json",
False,
list,
),
@@ -1149,3 +1149,36 @@
)
except ModuleNotFoundError:
print("Xarray not installed. Skipping Xarray dataframe integration tests.") # noqa: T201

###################################
########## Pydantic Types #########
###################################
try:
from pydantic import BaseModel

class ElementPydanticModel(BaseModel):
name: str
is_widget: bool
usage: float

SHARED_TEST_CASES.extend(
[
(
"Pydantic Model",
ElementPydanticModel(
name="st.number_input", is_widget=True, usage=0.32
),
CaseMetadata(
3,
1,
DataFormat.KEY_VALUE_DICT,
["st.number_input", True, 0.32],
"json",
False,
dict,
),
),
]
)
except ModuleNotFoundError:
print("Pydantic not installed. Skipping Pydantic dataframe tests.") # noqa: T201
16 changes: 16 additions & 0 deletions lib/tests/streamlit/type_util_test.py
@@ -86,6 +86,22 @@ def test_is_namedtuple(self):
res = type_util.is_namedtuple(John)
self.assertTrue(res)

@pytest.mark.require_integration
def test_is_pydantic_model(self):
from pydantic import BaseModel

class OtherObject:
foo: int
bar: str

class BarModel(BaseModel):
foo: int
bar: str

self.assertTrue(type_util.is_pydantic_model(BarModel(foo=1, bar="test")))
self.assertFalse(type_util.is_pydantic_model(BarModel))
self.assertFalse(type_util.is_pydantic_model(OtherObject))

def test_to_bytes(self):
bytes_obj = b"some bytes"
self.assertTrue(type_util.is_bytes_like(bytes_obj))
