Skip to content

Commit

Permalink
Python bindings: accept numpy.int64/float64 arguments for xoff, yoff,…
Browse files Browse the repository at this point in the history
… win_xsize, win_ysize, buf_xsize, buf_ysize arguments of ReadAsArray() (fixes OSGeo#8026)
  • Loading branch information
rouault committed Jul 4, 2023
1 parent a4aa6e5 commit b42eeea
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
39 changes: 39 additions & 0 deletions autotest/gcore/rasterio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3009,3 +3009,42 @@ def test_rasterio_gdal_rasterio_resampling():
)

assert data_avg1 == data_avg3


###############################################################################
# Test passing numpy.int64 values to ReadAsArray() arguments
# cf https://github.com/OSGeo/gdal/issues/8026


def test_rasterio_numpy_datatypes_for_xoff():
np = pytest.importorskip("numpy")

ds = gdal.Open("data/byte.tif")
assert np.array_equal(
ds.ReadAsArray(np.int64(1), np.int64(2), np.int64(3), np.int64(4)),
ds.ReadAsArray(1, 2, 3, 4),
)
assert np.array_equal(
ds.ReadAsArray(np.float64(1), np.float64(2), np.float64(3), np.float64(4)),
ds.ReadAsArray(1, 2, 3, 4),
)
assert np.array_equal(
ds.ReadAsArray(
np.float64(1.5),
np.float64(2.5),
np.float64(3.5),
np.float64(4.5),
buf_xsize=np.float64(1),
buf_ysize=np.float64(1),
resample_alg=gdal.GRIORA_Cubic,
),
ds.ReadAsArray(
1.5, 2.5, 3.5, 4.5, buf_xsize=1, buf_ysize=1, resample_alg=gdal.GRIORA_Cubic
),
)
assert np.array_equal(
ds.GetRasterBand(1).ReadAsArray(
np.int64(1), np.int64(2), np.int64(3), np.int64(4)
),
ds.GetRasterBand(1).ReadAsArray(1, 2, 3, 4),
)
33 changes: 33 additions & 0 deletions swig/include/gdal_array.i
Original file line number Diff line number Diff line change
Expand Up @@ -2279,6 +2279,19 @@ def SaveArray(src_array, filename, format="GTiff", prototype=None, interleave='b

return driver.CreateCopy(filename, OpenArray(src_array, prototype, interleave))

def _to_primitive_type(x):
"""Converts an object with a __int__ or __float__ method to the
corresponding primitive type, or return x."""
if x is None:
return x
if hasattr(x, "__int__"):
if hasattr(x, "is_integer") and x.is_integer():
return int(x)
elif not hasattr(x, "__float__"):
return int(x)
elif hasattr(x, "__float__"):
return float(x)
return x

def DatasetReadAsArray(ds, xoff=0, yoff=0, win_xsize=None, win_ysize=None, buf_obj=None,
buf_xsize=None, buf_ysize=None, buf_type=None,
Expand All @@ -2293,6 +2306,13 @@ def DatasetReadAsArray(ds, xoff=0, yoff=0, win_xsize=None, win_ysize=None, buf_o
if win_ysize is None:
win_ysize = ds.RasterYSize

xoff = _to_primitive_type(xoff)
yoff = _to_primitive_type(yoff)
win_xsize = _to_primitive_type(win_xsize)
win_ysize = _to_primitive_type(win_ysize)
buf_xsize = _to_primitive_type(buf_xsize)
buf_ysize = _to_primitive_type(buf_ysize)

if band_list is None:
band_list = list(range(1, ds.RasterCount + 1))

Expand Down Expand Up @@ -2387,6 +2407,9 @@ def DatasetWriteArray(ds, array, xoff=0, yoff=0,
"""Pure python implementation of writing a chunk of a GDAL file
from a numpy array. Used by the gdal.Dataset.WriteArray method."""

xoff = _to_primitive_type(xoff)
yoff = _to_primitive_type(yoff)

if band_list is None:
band_list = list(range(1, ds.RasterCount + 1))

Expand Down Expand Up @@ -2460,6 +2483,13 @@ def BandReadAsArray(band, xoff=0, yoff=0, win_xsize=None, win_ysize=None,
if win_ysize is None:
win_ysize = band.YSize

xoff = _to_primitive_type(xoff)
yoff = _to_primitive_type(yoff)
win_xsize = _to_primitive_type(win_xsize)
win_ysize = _to_primitive_type(win_ysize)
buf_xsize = _to_primitive_type(buf_xsize)
buf_ysize = _to_primitive_type(buf_ysize)

if buf_obj is None:
if buf_xsize is None:
buf_xsize = win_xsize
Expand Down Expand Up @@ -2522,6 +2552,9 @@ def BandWriteArray(band, array, xoff=0, yoff=0,
if array is None or len(array.shape) != 2:
raise ValueError("expected array of dim 2")

xoff = _to_primitive_type(xoff)
yoff = _to_primitive_type(yoff)

xsize = array.shape[1]
ysize = array.shape[0]

Expand Down

0 comments on commit b42eeea

Please sign in to comment.