Skip to content

Commit

Permalink
Refactor enum logic
Browse files Browse the repository at this point in the history
Add some type ignores for bugs in mypy. These will be removed when moved
to pyright.
  • Loading branch information
GDYendell committed Sep 24, 2024
1 parent d254dd3 commit dda9e45
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 58 deletions.
72 changes: 35 additions & 37 deletions src/fastcs/backends/epics/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

from fastcs.attributes import AttrR, AttrRW, AttrW
from fastcs.backends.epics.util import (
MBB_MAX_CHOICES,
MBB_STATE_FIELDS,
convert_if_enum,
attr_is_enum,
enum_index_to_value,
enum_value_to_index,
)
from fastcs.controller import BaseController
from fastcs.datatypes import Bool, Float, Int, String, T
Expand Down Expand Up @@ -155,23 +156,25 @@ def _create_and_link_attribute_pvs(pv_prefix: str, mapping: Mapping) -> None:
def _create_and_link_read_pv(
pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T]
) -> None:
record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute)
if attr_is_enum(attribute):

_add_attr_pvi_info(record, pv_prefix, attr_name, "r")
async def async_record_set(value: T):
record.set(enum_value_to_index(attribute, value))
else:

async def async_record_set(value: T):
record.set(convert_if_enum(attribute, value))
async def async_record_set(value: T): # type: ignore
record.set(value)

record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute)
_add_attr_pvi_info(record, pv_prefix, attr_name, "r")

attribute.set_update_callback(async_record_set)


def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper:
if (
isinstance(attribute.datatype, String)
and attribute.allowed_values is not None
and len(attribute.allowed_values) <= MBB_MAX_CHOICES
):
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False))
if attr_is_enum(attribute):
# https://github.com/python/mypy/issues/16789
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) # type: ignore
return builder.mbbIn(pv, **state_keys)

match attribute.datatype:
Expand All @@ -192,40 +195,35 @@ def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper:
def _create_and_link_write_pv(
pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T]
) -> None:
async def on_update(value):
if (
isinstance(attribute.datatype, String)
and isinstance(value, int)
and attribute.allowed_values is not None
):
try:
value = attribute.allowed_values[value]
except IndexError:
raise IndexError(
f"Invalid index {value}, allowed values: {attribute.allowed_values}"
) from None

await attribute.process_without_display_update(value)
if attr_is_enum(attribute):

async def on_update(value):
await attribute.process_without_display_update(
enum_index_to_value(attribute, value)
)

async def async_write_display(value: T):
record.set(enum_value_to_index(attribute, value), process=False)

else:

async def on_update(value):
await attribute.process_without_display_update(value)

async def async_write_display(value: T): # type: ignore
record.set(value, process=False)

record = _get_output_record(
f"{pv_prefix}:{pv_name}", attribute, on_update=on_update
)

_add_attr_pvi_info(record, pv_prefix, attr_name, "w")

async def async_record_set(value: T):
record.set(convert_if_enum(attribute, value), process=False)

attribute.set_write_display_callback(async_record_set)
attribute.set_write_display_callback(async_write_display)


def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any:
if (
isinstance(attribute.datatype, String)
and attribute.allowed_values is not None
and len(attribute.allowed_values) <= MBB_MAX_CHOICES
):
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False))
if attr_is_enum(attribute):
state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) # type: ignore
return builder.mbbOut(pv, always_update=True, on_update=on_update, **state_keys)

match attribute.datatype:
Expand Down
74 changes: 61 additions & 13 deletions src/fastcs/backends/epics/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,75 @@
MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES)


def convert_if_enum(attribute: Attribute[T], value: T) -> T | int:
"""Check if `attribute` is a string enum and if so convert `value` to index of enum.
def attr_is_enum(attribute: Attribute) -> bool:
"""Check if the `Attribute` has a `String` datatype and has `allowed_values` set.
Args:
`attribute`: The attribute to be set
`value`: The value
attribute: The `Attribute` to check
Returns:
The index of the `value` if the `attribute` is an enum, else `value`
Raises:
ValueError: If `attribute` is an enum and `value` is not in its allowed values
`True` if `Attribute` is an enum, else `False`
"""
match attribute:
case Attribute(
datatype=String(), allowed_values=allowed_values
) if allowed_values is not None and len(allowed_values) <= MBB_MAX_CHOICES:
if value in allowed_values:
return allowed_values.index(value)
else:
raise ValueError(f"'{value}' not in allowed values {allowed_values}")
return True
case _:
return value
return False


def enum_value_to_index(attribute: Attribute[T], value: T) -> int:
"""Convert the given value to the index within the allowed_values of the Attribute
Args:
`attribute`: The attribute
`value`: The value to convert
Returns:
The index of the `value`
Raises:
ValueError: If `attribute` has no allowed values or `value` is not a valid
option
"""
if attribute.allowed_values is None:
raise ValueError(
"Cannot convert value to index for Attribute without allowed values"
)

try:
return attribute.allowed_values.index(value)
except ValueError:
raise ValueError(
f"{value} not in allowed values of {attribute}: {attribute.allowed_values}"
) from None


def enum_index_to_value(attribute: Attribute[T], index: int) -> T:
"""Lookup the value from the allowed_values of an attribute at the given index.
Parameters:
attribute: The `Attribute` to lookup the index from
index: The index of the value to retrieve
Returns:
The value at the specified index in the allowed values list.
Raises:
IndexError: If the index is out of bounds
"""
if attribute.allowed_values is None:
raise ValueError(
"Cannot lookup value by index for Attribute without allowed values"
)

try:
return attribute.allowed_values[index]
except IndexError:
raise IndexError(
f"Invalid index {index} into allowed values: {attribute.allowed_values}"
) from None
119 changes: 119 additions & 0 deletions tests/backends/epics/test_ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
_add_attr_pvi_info,
_add_pvi_info,
_add_sub_controller_pvi_info,
_create_and_link_read_pv,
_create_and_link_write_pv,
_get_input_record,
_get_output_record,
)
Expand All @@ -25,6 +27,54 @@
ONOFF_STATES = {"ZRST": "disabled", "ONST": "enabled"}


@pytest.mark.asyncio
async def test_create_and_link_read_pv(mocker: MockerFixture):
get_input_record = mocker.patch("fastcs.backends.epics.ioc._get_input_record")
add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info")
attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum")
record = get_input_record.return_value

attribute = mocker.MagicMock()

attr_is_enum.return_value = False
_create_and_link_read_pv("PREFIX", "PV", "attr", attribute)

get_input_record.assert_called_once_with("PREFIX:PV", attribute)
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r")

# Extract the callback generated and set in the function and call it
attribute.set_update_callback.assert_called_once_with(mocker.ANY)
record_set_callback = attribute.set_update_callback.call_args[0][0]
await record_set_callback(1)

record.set.assert_called_once_with(1)


@pytest.mark.asyncio
async def test_create_and_link_read_pv_enum(mocker: MockerFixture):
get_input_record = mocker.patch("fastcs.backends.epics.ioc._get_input_record")
add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info")
attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum")
record = get_input_record.return_value
enum_value_to_index = mocker.patch("fastcs.backends.epics.ioc.enum_value_to_index")

attribute = mocker.MagicMock()

attr_is_enum.return_value = True
_create_and_link_read_pv("PREFIX", "PV", "attr", attribute)

get_input_record.assert_called_once_with("PREFIX:PV", attribute)
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r")

# Extract the callback generated and set in the function and call it
attribute.set_update_callback.assert_called_once_with(mocker.ANY)
record_set_callback = attribute.set_update_callback.call_args[0][0]
await record_set_callback(1)

enum_value_to_index.assert_called_once_with(attribute, 1)
record.set.assert_called_once_with(enum_value_to_index.return_value)


@pytest.mark.parametrize(
"attribute,record_type,kwargs",
(
Expand Down Expand Up @@ -57,6 +107,75 @@ def test_get_input_record_raises(mocker: MockerFixture):
_get_input_record("PV", mocker.MagicMock())


@pytest.mark.asyncio
async def test_create_and_link_write_pv(mocker: MockerFixture):
get_output_record = mocker.patch("fastcs.backends.epics.ioc._get_output_record")
add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info")
attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum")
record = get_output_record.return_value

attribute = mocker.MagicMock()
attribute.process_without_display_update = mocker.AsyncMock()

attr_is_enum.return_value = False
_create_and_link_write_pv("PREFIX", "PV", "attr", attribute)

get_output_record.assert_called_once_with(
"PREFIX:PV", attribute, on_update=mocker.ANY
)
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w")

# Extract the write update callback generated and set in the function and call it
attribute.set_write_display_callback.assert_called_once_with(mocker.ANY)
write_display_callback = attribute.set_write_display_callback.call_args[0][0]
await write_display_callback(1)

record.set.assert_called_once_with(1, process=False)

# Extract the on update callback generated and set in the function and call it
on_update_callback = get_output_record.call_args[1]["on_update"]
await on_update_callback(1)

attribute.process_without_display_update.assert_called_once_with(1)


@pytest.mark.asyncio
async def test_create_and_link_write_pv_enum(mocker: MockerFixture):
get_output_record = mocker.patch("fastcs.backends.epics.ioc._get_output_record")
add_attr_pvi_info = mocker.patch("fastcs.backends.epics.ioc._add_attr_pvi_info")
attr_is_enum = mocker.patch("fastcs.backends.epics.ioc.attr_is_enum")
enum_value_to_index = mocker.patch("fastcs.backends.epics.ioc.enum_value_to_index")
enum_index_to_value = mocker.patch("fastcs.backends.epics.ioc.enum_index_to_value")
record = get_output_record.return_value

attribute = mocker.MagicMock()
attribute.process_without_display_update = mocker.AsyncMock()

attr_is_enum.return_value = True
_create_and_link_write_pv("PREFIX", "PV", "attr", attribute)

get_output_record.assert_called_once_with(
"PREFIX:PV", attribute, on_update=mocker.ANY
)
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w")

# Extract the write update callback generated and set in the function and call it
attribute.set_write_display_callback.assert_called_once_with(mocker.ANY)
write_display_callback = attribute.set_write_display_callback.call_args[0][0]
await write_display_callback(1)

enum_value_to_index.assert_called_once_with(attribute, 1)
record.set.assert_called_once_with(enum_value_to_index.return_value, process=False)

# Extract the on update callback generated and set in the function and call it
on_update_callback = get_output_record.call_args[1]["on_update"]
await on_update_callback(1)

attribute.process_without_display_update.assert_called_once_with(
enum_index_to_value.return_value
)


@pytest.mark.parametrize(
"attribute,record_type,kwargs",
(
Expand Down
38 changes: 30 additions & 8 deletions tests/backends/epics/test_util.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@
import pytest

from fastcs.attributes import AttrR
from fastcs.backends.epics.util import convert_if_enum
from fastcs.backends.epics.util import (
attr_is_enum,
enum_index_to_value,
enum_value_to_index,
)
from fastcs.datatypes import String


def test_convert_if_enum():
string_attr = AttrR(String())
enum_attr = AttrR(String(), allowed_values=["disabled", "enabled"])
def test_attr_is_enum():
assert not attr_is_enum(AttrR(String()))
assert attr_is_enum(AttrR(String(), allowed_values=["disabled", "enabled"]))

assert convert_if_enum(string_attr, "enabled") == "enabled"

assert convert_if_enum(enum_attr, "enabled") == 1
def test_enum_index_to_value():
"""Test enum_index_to_value."""
attribute = AttrR(String(), allowed_values=["disabled", "enabled"])

with pytest.raises(ValueError):
convert_if_enum(enum_attr, "off")
assert enum_index_to_value(attribute, 0) == "disabled"
assert enum_index_to_value(attribute, 1) == "enabled"
with pytest.raises(IndexError, match="Invalid index"):
enum_index_to_value(attribute, 2)

with pytest.raises(ValueError, match="Cannot lookup value by index"):
enum_index_to_value(AttrR(String()), 0)


def test_enum_value_to_index():
attribute = AttrR(String(), allowed_values=["disabled", "enabled"])

assert enum_value_to_index(attribute, "disabled") == 0
assert enum_value_to_index(attribute, "enabled") == 1
with pytest.raises(ValueError, match="not in allowed values"):
enum_value_to_index(attribute, "off")

with pytest.raises(ValueError, match="Cannot convert value to index"):
enum_value_to_index(AttrR(String()), "disabled")

0 comments on commit dda9e45

Please sign in to comment.