Skip to content

Commit

Permalink
Closes Bears-R-Us#3664: streamline get_max_array_rank checks in unit …
Browse files Browse the repository at this point in the history
…testing
  • Loading branch information
ajpotts committed Aug 21, 2024
1 parent 3965193 commit 9023f70
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 217 deletions.
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,6 @@ env =
D:ARKOUDA_VERBOSE=True
D:ARKOUDA_CLIENT_TIMEOUT=0
D:ARKOUDA_LOG_LEVEL=DEBUG
markers =
skip_if_max_rank_less_than
skip_if_max_rank_greater_than
11 changes: 3 additions & 8 deletions tests/array_api/array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,9 @@
SIZES = [1, 0, 0, 1, 5, 4, 50]
DIMS = [0, 1, 2, 1, 1, 2, 2]

def get_server_max_array_dims():
try:
return json.load(open('serverConfig.json', 'r'))['max_array_dims']
except (ValueError, FileNotFoundError, TypeError, KeyError):
return 1

class TestArrayCreation:
@pytest.mark.skipif(get_server_max_array_dims() < 2, reason="test_zeros requires server with 'max_array_dims' >= 2")
@pytest.mark.skip_if_max_rank_less_than(2)
def test_zeros(self):
for shape, size, dim in zip(SHAPES, SIZES, DIMS):
for dtype in ak.ScalarDTypes:
Expand All @@ -29,7 +24,7 @@ def test_zeros(self):
assert a.dtype == dtype
assert a.tolist() == np.zeros(shape, dtype=dtype).tolist()

@pytest.mark.skipif(get_server_max_array_dims() < 2, reason="test_ones requires server with 'max_array_dims' >= 2")
@pytest.mark.skip_if_max_rank_less_than(2)
def test_ones(self):
for shape, size, dim in zip(SHAPES, SIZES, DIMS):
for dtype in ak.ScalarDTypes:
Expand All @@ -40,7 +35,7 @@ def test_ones(self):
assert a.dtype == dtype
assert a.tolist() == np.ones(shape, dtype=dtype).tolist()

@pytest.mark.skipif(get_server_max_array_dims() < 2, reason="test_from_numpy requires server with 'max_array_dims' >= 2")
@pytest.mark.skip_if_max_rank_less_than(2)
def test_from_numpy(self):
# TODO: support 0D (scalar) arrays
# (need changes to the create0D command from #2967)
Expand Down
62 changes: 11 additions & 51 deletions tests/array_api/array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
s = SEED


def get_server_max_array_dims():
try:
return json.load(open("serverConfig.json", "r"))["max_array_dims"]
except (ValueError, FileNotFoundError, TypeError, KeyError):
return 1


def randArr(shape):
global s
s += 2
Expand All @@ -25,10 +18,7 @@ def randArr(shape):

class TestManipulation:

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_broadcast requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_broadcast(self):
a = xp.ones((1, 6, 1))
b = xp.ones((5, 1, 10))
Expand All @@ -47,10 +37,7 @@ def test_broadcast(self):
assert (abcd[2] == 1).all()
assert (abcd[3] == 1).all()

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_concat requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_concat(self):
a = randArr((5, 3, 10))
b = randArr((5, 3, 2))
Expand Down Expand Up @@ -83,10 +70,7 @@ def test_concat(self):
assert hijConcat.shape == (18,)
assert hijConcat.tolist() == hijNP.tolist()

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_expand_dims requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_expand_dims(self):
a = randArr((5, 3))
alist = a.tolist()
Expand Down Expand Up @@ -121,10 +105,7 @@ def test_expand_dims(self):
with pytest.raises(IndexError):
xp.expand_dims(a, axis=-4)

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_flip requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_flip(self):
# 1D case
a = xp.arange(10)
Expand Down Expand Up @@ -164,10 +145,7 @@ def test_flip(self):
with pytest.raises(IndexError):
xp.flip(r, axis=-4)

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_permute_dims requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_permute_dims(self):
r = randArr((7, 8, 9))

Expand All @@ -194,10 +172,7 @@ def test_permute_dims(self):
with pytest.raises(IndexError):
xp.permute_dims(r, (0, 1, -4))

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_reshape requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_reshape(self):
r = randArr((2, 6, 12))
nr = np.asarray(r.tolist())
Expand Down Expand Up @@ -226,10 +201,7 @@ def test_reshape(self):
# more than one dimension can't be inferred
xp.reshape(r, (2, -1, -1))

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_roll requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_roll(self):
# 1D case
a = xp.arange(10)
Expand Down Expand Up @@ -272,10 +244,7 @@ def test_roll(self):
with pytest.raises(IndexError):
xp.roll(r, 3, axis=-4)

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_squeeze requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_squeeze(self):
r1 = randArr((1, 2, 3))
r2 = randArr((2, 1, 3))
Expand Down Expand Up @@ -308,10 +277,7 @@ def test_squeeze(self):
with pytest.raises(ValueError):
xp.squeeze(r4, axis=1)

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_stack_unstack requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_stack_unstack(self):
a = randArr((5, 4))
b = randArr((5, 4))
Expand All @@ -337,10 +303,7 @@ def test_stack_unstack(self):
assert bp.tolist() == b.tolist()
assert cp.tolist() == c.tolist()

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_tile requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(2)
def test_tile(self):
a = randArr((2, 3))

Expand All @@ -350,10 +313,7 @@ def test_tile(self):
assert at.shape == npat.shape
assert at.tolist() == npat.tolist()

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_repeat requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_repeat(self):
a = randArr((5, 10))
r = randArr((50,))
Expand Down
12 changes: 1 addition & 11 deletions tests/array_api/binary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,8 @@
SCALAR_TYPES.remove("bool_")


def get_server_max_array_dims():
try:
return json.load(open("serverConfig.json", "r"))["max_array_dims"]
except (ValueError, FileNotFoundError, TypeError, KeyError):
return 1


class TestArrayCreation:
@pytest.mark.skipif(
get_server_max_array_dims() < 2,
reason="test_binops requires server with 'max_array_dims' >= 2",
)
@pytest.mark.skip_if_max_rank_less_than(2)
@pytest.mark.parametrize("op", ["+", "-", "*", "/"])
@pytest.mark.parametrize("dtype", SCALAR_TYPES)
def test_binops(self, op, dtype):
Expand Down
27 changes: 4 additions & 23 deletions tests/array_api/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,14 @@
s = SEED


def get_server_max_array_dims():
try:
return json.load(open("serverConfig.json", "r"))["max_array_dims"]
except (ValueError, FileNotFoundError, TypeError, KeyError):
return 1


def randArr(shape):
global s
s += 2
return xp.asarray(ak.randint(0, 100, shape, dtype=ak.int64, seed=s))


class TestIndexing:
@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_rank_changing_assignment requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_rank_changing_assignment(self):
a = randArr((5, 6, 7))
b = randArr((5, 6))
Expand All @@ -47,10 +37,7 @@ def test_rank_changing_assignment(self):
a[:, :, :] = e
assert a.tolist() == e.tolist()

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_nd_assignment requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_nd_assignment(self):
a = randArr((5, 6, 7))
bnp = randArr((5, 6, 7)).to_ndarray()
Expand All @@ -64,10 +51,7 @@ def test_nd_assignment(self):
a[:] = 5
assert (a == 5).all()

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_pdarray_index requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_pdarray_index(self):
a = randArr((5, 6, 7))
anp = np.asarray(a.tolist())
Expand Down Expand Up @@ -98,10 +82,7 @@ def test_pdarray_index(self):
xnp = anp[:]
assert x.tolist() == xnp.tolist()

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_none_index requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_none_index(self):
a = randArr((10, 10))
anp = np.asarray(a.tolist())
Expand Down
32 changes: 5 additions & 27 deletions tests/array_api/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,8 @@
SEED = 314159


def get_server_max_array_dims():
try:
return json.load(open("serverConfig.json", "r"))["max_array_dims"]
except (ValueError, FileNotFoundError, TypeError, KeyError):
return 1


class TestSearchingFunctions:
@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_argmax requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_argmax(self):
a = xp.asarray(ak.randint(0, 100, (4, 5, 6), dtype=ak.int64, seed=SEED))
a[3, 2, 1] = 101
Expand All @@ -37,10 +27,7 @@ def test_argmax(self):
assert aArgmax1Keepdims.shape == (4, 1, 6)
assert aArgmax1Keepdims[3, 0, 1] == 2

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_argmin requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_argmin(self):
a = xp.asarray(ak.randint(0, 100, (4, 5, 6), dtype=ak.int64, seed=SEED))
a[3, 2, 1] = -1
Expand All @@ -55,10 +42,7 @@ def test_argmin(self):
assert aArgmin1Keepdims.shape == (4, 1, 6)
assert aArgmin1Keepdims[3, 0, 1] == 2

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_nonzero requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_nonzero(self):
a = xp.zeros((4, 5, 6), dtype=ak.int64)
a[0, 1, 0] = 1
Expand All @@ -74,10 +58,7 @@ def test_nonzero(self):
assert sorted(nz[1].tolist()) == sorted([1, 2, 2, 2])
assert sorted(nz[2].tolist()) == sorted([0, 3, 2, 1])

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_where requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_where(self):
a = xp.zeros((4, 5, 6), dtype=ak.int64)
a[1, 2, 3] = 1
Expand All @@ -96,10 +77,7 @@ def test_where(self):
assert d[0, 0, 0] == c[0, 0, 0]
assert d[3, 3, 3] == c[3, 3, 3]

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_search_sorted requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_search_sorted(self):
a = xp.asarray(ak.randint(0, 100, 1000, dtype=ak.float64))
b = xp.asarray(ak.randint(0, 100, (10, 10), dtype=ak.float64))
Expand Down
12 changes: 1 addition & 11 deletions tests/array_api/set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
s = SEED


def get_server_max_array_dims():
try:
return json.load(open("serverConfig.json", "r"))["max_array_dims"]
except (ValueError, FileNotFoundError, TypeError, KeyError):
return 1


def randArr(shape):
global s
s += 2
Expand All @@ -25,10 +18,7 @@ def randArr(shape):

class TestSetFunction:

@pytest.mark.skipif(
get_server_max_array_dims() < 3,
reason="test_set_functions requires server with 'max_array_dims' >= 3",
)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_set_functions(self):

for shape in [(1000), (20, 50), (2, 10, 50)]:
Expand Down
12 changes: 4 additions & 8 deletions tests/array_api/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@
SHAPES = [(1,), (25,), (5, 10), (10, 5)]
SEED = 12345
SCALAR_TYPES = list(ak.ScalarDTypes)
SCALAR_TYPES.remove('bool_')
SCALAR_TYPES.remove("bool_")


def get_server_max_array_dims():
try:
return json.load(open('serverConfig.json', 'r'))['max_array_dims']
except (ValueError, FileNotFoundError, TypeError, KeyError):
return 1
class TestArrayCreation:

@pytest.mark.skipif(get_server_max_array_dims() < 2, reason="test_argsort requires server with 'max_array_dims' >= 2")
@pytest.mark.skip_if_max_rank_less_than(2)
def test_argsort(self):
for shape in SHAPES:
for dtype in ak.ScalarDTypes:
Expand Down Expand Up @@ -50,7 +46,7 @@ def test_argsort(self):
for j in range(shape[1] - 1):
assert a[i, b[i, j]] <= a[i, b[i, j + 1]]

@pytest.mark.skipif(get_server_max_array_dims() < 2, reason="test_sort requires server with 'max_array_dims' >= 2")
@pytest.mark.skip_if_max_rank_less_than(2)
@pytest.mark.parametrize("dtype", SCALAR_TYPES)
@pytest.mark.parametrize("shape", SHAPES)
def test_sort(self, dtype, shape):
Expand Down
Loading

0 comments on commit 9023f70

Please sign in to comment.