104 lines
2.6 KiB
Python
104 lines
2.6 KiB
Python
"""
|
||
geo_tools.core.raster
|
||
~~~~~~~~~~~~~~~~~~~~~
|
||
栅格数据处理预留接口。
|
||
|
||
当前提供基于 rasterio 的核心读写骨架;
|
||
如需完整栅格分析功能,请安装可选依赖:
|
||
``pip install geo-tools[raster]``
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
from typing import TYPE_CHECKING, Any
|
||
|
||
if TYPE_CHECKING:
|
||
import numpy as np
|
||
|
||
|
||
def _require_rasterio() -> Any:
|
||
"""检查 rasterio 是否可用,不可用时给出明确提示。"""
|
||
try:
|
||
import rasterio
|
||
return rasterio
|
||
except ImportError as exc:
|
||
raise ImportError(
|
||
"栅格处理功能需要 rasterio。\n"
|
||
"请执行:pip install geo-tools[raster] 或 pip install rasterio"
|
||
) from exc
|
||
|
||
|
||
def read_raster(
|
||
path: str | Path,
|
||
band: int = 1,
|
||
) -> tuple["np.ndarray", dict[str, Any]]:
|
||
"""读取栅格文件(单波段)。
|
||
|
||
Parameters
|
||
----------
|
||
path:
|
||
GeoTIFF 或其他 GDAL 支持格式的路径。
|
||
band:
|
||
波段号,1-indexed。
|
||
|
||
Returns
|
||
-------
|
||
(np.ndarray, dict)
|
||
栅格数组 和 rasterio 元数据字典(``meta``)。
|
||
"""
|
||
rasterio = _require_rasterio()
|
||
with rasterio.open(str(path)) as src:
|
||
data = src.read(band)
|
||
meta = src.meta.copy()
|
||
return data, meta
|
||
|
||
|
||
def write_raster(
|
||
data: "np.ndarray",
|
||
path: str | Path,
|
||
meta: dict[str, Any],
|
||
band: int = 1,
|
||
) -> Path:
|
||
"""将 numpy 数组写出为 GeoTIFF。
|
||
|
||
Parameters
|
||
----------
|
||
data:
|
||
2D numpy 数组(单波段)。
|
||
path:
|
||
输出路径(.tif)。
|
||
meta:
|
||
rasterio 元数据字典(从 ``read_raster`` 获取或自行构造)。
|
||
band:
|
||
写入的波段号,1-indexed。
|
||
|
||
Returns
|
||
-------
|
||
Path
|
||
"""
|
||
rasterio = _require_rasterio()
|
||
path = Path(path)
|
||
path.parent.mkdir(parents=True, exist_ok=True)
|
||
meta.update({"count": 1, "dtype": str(data.dtype)})
|
||
with rasterio.open(str(path), "w", **meta) as dst:
|
||
dst.write(data, band)
|
||
return path
|
||
|
||
|
||
def get_raster_info(path: str | Path) -> dict[str, Any]:
|
||
"""获取栅格文件的基本元信息(行列数、波段数、CRS、分辨率等)。"""
|
||
rasterio = _require_rasterio()
|
||
with rasterio.open(str(path)) as src:
|
||
return {
|
||
"width": src.width,
|
||
"height": src.height,
|
||
"count": src.count,
|
||
"dtype": src.dtypes[0],
|
||
"crs": str(src.crs),
|
||
"transform": src.transform,
|
||
"bounds": src.bounds,
|
||
"nodata": src.nodata,
|
||
"res": src.res,
|
||
}
|