From af988ea7b9595c94fb4f577a41d42c6e4e90cd6f Mon Sep 17 00:00:00 2001 From: missum Date: Wed, 4 Mar 2026 17:07:07 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AC=A1=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 26 +++ .gitignore | 81 ++++++++++ README.md | 128 +++++++++++++++ data/sample/sample_points.geojson | 102 ++++++++++++ data/sample/sample_regions.geojson | 120 ++++++++++++++ docs/projection_guide.md | 248 ++++++++++++++++++++++++++++ geo_tools/__init__.py | 178 ++++++++++++++++++++ geo_tools/analysis/__init__.py | 25 +++ geo_tools/analysis/spatial_ops.py | 149 +++++++++++++++++ geo_tools/analysis/stats.py | 136 ++++++++++++++++ geo_tools/config/__init__.py | 5 + geo_tools/config/project_enum.py | 43 +++++ geo_tools/config/settings.py | 98 +++++++++++ geo_tools/core/__init__.py | 80 +++++++++ geo_tools/core/geometry.py | 169 +++++++++++++++++++ geo_tools/core/projection.py | 220 +++++++++++++++++++++++++ geo_tools/core/raster.py | 103 ++++++++++++ geo_tools/core/vector.py | 210 ++++++++++++++++++++++++ geo_tools/io/__init__.py | 31 ++++ geo_tools/io/readers.py | 251 +++++++++++++++++++++++++++++ geo_tools/io/writers.py | 207 ++++++++++++++++++++++++ geo_tools/utils/__init__.py | 30 ++++ geo_tools/utils/config.py | 85 ++++++++++ geo_tools/utils/logger.py | 107 ++++++++++++ geo_tools/utils/validators.py | 145 +++++++++++++++++ logs/.gitkeep | 1 + output/.gitkeep | 1 + pyproject.toml | 98 +++++++++++ scripts/example_workflow.py | 104 ++++++++++++ tests/__init__.py | 1 + tests/conftest.py | 80 +++++++++ tests/test1.py | 20 +++ tests/test_analysis.py | 80 +++++++++ tests/test_geometry.py | 110 +++++++++++++ tests/test_io.py | 81 ++++++++++ tests/test_proj.py | 21 +++ tests/test_vector.py | 100 ++++++++++++ 37 files changed, 3674 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 README.md create mode 100644 
data/sample/sample_points.geojson create mode 100644 data/sample/sample_regions.geojson create mode 100644 docs/projection_guide.md create mode 100644 geo_tools/__init__.py create mode 100644 geo_tools/analysis/__init__.py create mode 100644 geo_tools/analysis/spatial_ops.py create mode 100644 geo_tools/analysis/stats.py create mode 100644 geo_tools/config/__init__.py create mode 100644 geo_tools/config/project_enum.py create mode 100644 geo_tools/config/settings.py create mode 100644 geo_tools/core/__init__.py create mode 100644 geo_tools/core/geometry.py create mode 100644 geo_tools/core/projection.py create mode 100644 geo_tools/core/raster.py create mode 100644 geo_tools/core/vector.py create mode 100644 geo_tools/io/__init__.py create mode 100644 geo_tools/io/readers.py create mode 100644 geo_tools/io/writers.py create mode 100644 geo_tools/utils/__init__.py create mode 100644 geo_tools/utils/config.py create mode 100644 geo_tools/utils/logger.py create mode 100644 geo_tools/utils/validators.py create mode 100644 logs/.gitkeep create mode 100644 output/.gitkeep create mode 100644 pyproject.toml create mode 100644 scripts/example_workflow.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test1.py create mode 100644 tests/test_analysis.py create mode 100644 tests/test_geometry.py create mode 100644 tests/test_io.py create mode 100644 tests/test_proj.py create mode 100644 tests/test_vector.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..170f3e2 --- /dev/null +++ b/.env.example @@ -0,0 +1,26 @@ +# ============================================================ +# geo_tools 环境变量配置示例 +# 复制本文件为 .env 并按实际路径修改,.env 已在 .gitignore 中忽略 +# ============================================================ + +# ── 目录配置 ───────────────────────────────────────────────── +# 输出文件根目录(绝对路径或相对于项目根) +GEO_TOOLS_OUTPUT_DIR=output + +# 日志文件目录 +GEO_TOOLS_LOG_DIR=logs + +# ── 坐标系配置 
──────────────────────────────────────────────── +# 默认投影坐标系 EPSG 编码(地理坐标系:4326,中国常用:4490) +GEO_TOOLS_DEFAULT_CRS=EPSG:4326 + +# ── 日志配置 ────────────────────────────────────────────────── +# 日志等级:DEBUG / INFO / WARNING / ERROR / CRITICAL +GEO_TOOLS_LOG_LEVEL=INFO + +# 是否同时写出日志文件(true / false) +GEO_TOOLS_LOG_TO_FILE=true + +# ── 性能配置 ────────────────────────────────────────────────── +# 并行处理时最大 CPU 核数(0 = 自动检测) +GEO_TOOLS_MAX_WORKERS=0 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2448ce8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,81 @@ +# ── Python ────────────────────────────────────────────────────────────────── +__pycache__/ +*.py[cod] +*$py.class +*.so +*.egg +*.egg-info/ +dist/ +build/ +.eggs/ +.mypy_cache/ +.pytest_cache/ +.ruff_cache/ +.coverage +htmlcov/ +.tox/ + +# ── 虚拟环境 ──────────────────────────────────────────────────────────────── +.venv/ +venv/ +env/ +.env +!.env.example + +# ── IDE ───────────────────────────────────────────────────────────────────── +.idea/ +.vscode/ +*.swp +*.swo +.DS_Store +Thumbs.db + +# ── 日志与输出(保留目录结构,忽略内容)────────────────────────────────────── +logs/* +!logs/.gitkeep +output/* +!output/.gitkeep + +# ── GIS 真实数据(保留示例数据,忽略用户数据)──────────────────────────────── +data/* +!data/sample/ +!data/sample/** + +# ── GIS 大型文件格式 ───────────────────────────────────────────────────────── +*.shp +*.dbf +*.shx +*.prj +*.cpg +*.sbn +*.sbx +*.fbn +*.fbx +*.ain +*.aih +*.atx +*.ixs +*.mxs +*.ovr +*.ecw +*.img +*.jp2 +*.sid +*.tif +*.tiff +*.geotiff +# FileGDB — 整个 .gdb 文件夹 +*.gdb/ +# 但允许提交示例 gdb(如果有) +# !data/sample/*.gdb/ + +# ── Jupyter Notebook ──────────────────────────────────────────────────────── +.ipynb_checkpoints/ +*.ipynb_checkpoints + +# ── 临时文件 ───────────────────────────────────────────────────────────────── +tmp/ +temp/ +*.tmp +*.bak +*.orig diff --git a/README.md b/README.md new file mode 100644 index 0000000..542a57e --- /dev/null +++ b/README.md @@ -0,0 +1,128 @@ +# geo_tools + +> 专业地理信息数据处理工具库 —— 基于 
geopandas / shapely / fiona + +[![Python](https://img.shields.io/badge/python-3.10%2B-blue)](https://www.python.org) +[![License: MIT](https://img.shields.io/badge/license-MIT-green)](LICENSE) + +--- + +## 功能特性 + +- **统一 IO 接口**:一行代码读写 Shapefile、GeoJSON、GeoPackage、**File Geodatabase (GDB)**、KML、CSV 等格式 +- **核心几何运算**:基于 Shapely 2.x 的缓冲区、集合运算、有效性检查与自动修复 +- **坐标系处理**:重投影、CRS 信息查询、批量坐标转换,内置中国常用 CRS 常量 +- **空间分析**:叠置分析、最近邻、按位置选择、面积加权均值、属性统计汇总 +- **配置驱动**:通过 `.env` 或环境变量控制输出路径、日志级别、默认 CRS 等 +- **栅格预留接口**:为 rasterio 集成预留扩展点 + +## 项目结构 + +``` +geo_tools/ +├── geo_tools/ # 主包 +│ ├── config/ # Pydantic BaseSettings 全局配置 +│ ├── core/ # 核心处理(vector / geometry / projection / raster) +│ ├── io/ # 数据读写(readers / writers,含 GDB) +│ ├── analysis/ # 空间分析(spatial_ops / stats) +│ └── utils/ # 通用工具(logger / validators / config) +├── scripts/ # 独立处理脚本 +├── tests/ # pytest 测试套件 +├── data/sample/ # 示例数据(GeoJSON) +├── output/ # 处理结果输出目录 +├── logs/ # 日志文件目录 +├── docs/ # 文档 +└── pyproject.toml # 项目配置与依赖 +``` + +## 快速开始 + +### 安装依赖 + +```bash +# 推荐使用 conda 安装地理库(避免 GDAL 编译问题) +conda install -c conda-forge geopandas shapely fiona pyproj + +# 然后安装本项目(开发模式) +pip install -e ".[dev]" +``` + +### 基本使用 + +```python +import geo_tools + +# 读取矢量数据(自动识别格式) +gdf = geo_tools.read_vector("data/sample/sample_points.geojson") + +# 读写 File Geodatabase +layers = geo_tools.list_gdb_layers("path/to/data.gdb") +gdf = geo_tools.read_gdb("path/to/data.gdb", layer="my_layer") +geo_tools.write_gdb(gdf, "output/result.gdb", layer="result") + +# 坐标系转换 +gdf_proj = geo_tools.reproject(gdf, "EPSG:3857") + +# 缓冲区分析 +from geo_tools.core.geometry import buffer_geometry +buffered_geom = buffer_geometry(gdf.geometry[0], distance=1000) + +# 空间叠置 +from geo_tools.analysis.spatial_ops import overlay +result = geo_tools.overlay(layer_a, layer_b, how="intersection") + +# 面积加权均值 +from geo_tools.analysis.stats import area_weighted_mean +result = area_weighted_mean(polygon_gdf, value_col="soil_ph", group_col="region") +``` + +### 配置 + 
+复制 `.env.example` 为 `.env` 并按需修改: + +```bash +GEO_TOOLS_OUTPUT_DIR=D:/output +GEO_TOOLS_DEFAULT_CRS=EPSG:4490 +GEO_TOOLS_LOG_LEVEL=DEBUG +``` + +## 运行测试 + +```bash +# 运行全部测试 +pytest tests/ -v + +# 运行带覆盖率报告 +pytest tests/ -v --cov=geo_tools --cov-report=html +``` + +## 运行示例脚本 + +```bash +python scripts/example_workflow.py +``` + +## GDB 支持说明 + +本项目通过 `fiona>=1.9` 的 `OpenFileGDB` 驱动读写 Esri File Geodatabase(`.gdb`)。 + +| 操作 | 驱动 | 要求 | +|------|------|------| +| 读取 GDB | `OpenFileGDB` | fiona >= 1.9(内置) | +| 写出 GDB | `OpenFileGDB` | fiona >= 1.9(内置) | +| 编辑 GDB(高级) | `FileGDB` | 需要 ESRI FileGDB API | + +```python +# 列出所有图层 +layers = geo_tools.list_gdb_layers("data.gdb") + +# 读取指定图层 +gdf = geo_tools.read_gdb("data.gdb", layer="土地利用", crs="EPSG:4490") + +# 写出到 GDB(新建或追加图层) +geo_tools.write_gdb(result_gdf, "output.gdb", layer="分析结果", mode="w") +``` + +## 许可证 + +MIT License diff --git a/data/sample/sample_points.geojson b/data/sample/sample_points.geojson new file mode 100644 index 0000000..b8ee39c --- /dev/null +++ b/data/sample/sample_points.geojson @@ -0,0 +1,102 @@ +{ + "type": "FeatureCollection", + "name": "sample_points", + "crs": { + "type": "name", + "properties": { + "name": "urn:ogc:def:crs:OGC:1.3:CRS84" + } + }, + "features": [ + { + "type": "Feature", + "id": 1, + "properties": { + "id": 1, + "name": "北京", + "city": "Beijing", + "value": 10.5, + "category": "A" + }, + "geometry": { + "type": "Point", + "coordinates": [ + 116.4074, + 39.9042 + ] + } + }, + { + "type": "Feature", + "id": 2, + "properties": { + "id": 2, + "name": "上海", + "city": "Shanghai", + "value": 20.0, + "category": "B" + }, + "geometry": { + "type": "Point", + "coordinates": [ + 121.4737, + 31.2304 + ] + } + }, + { + "type": "Feature", + "id": 3, + "properties": { + "id": 3, + "name": "广州", + "city": "Guangzhou", + "value": 15.3, + "category": "A" + }, + "geometry": { + "type": "Point", + "coordinates": [ + 113.2644, + 23.1291 + ] + } + }, + { + "type": "Feature", + "id": 4, + 
"properties": { + "id": 4, + "name": "成都", + "city": "Chengdu", + "value": 8.7, + "category": "C" + }, + "geometry": { + "type": "Point", + "coordinates": [ + 104.0668, + 30.5728 + ] + } + }, + { + "type": "Feature", + "id": 5, + "properties": { + "id": 5, + "name": "武汉", + "city": "Wuhan", + "value": 12.1, + "category": "B" + }, + "geometry": { + "type": "Point", + "coordinates": [ + 114.3054, + 30.5931 + ] + } + } + ] +} \ No newline at end of file diff --git a/data/sample/sample_regions.geojson b/data/sample/sample_regions.geojson new file mode 100644 index 0000000..63f1160 --- /dev/null +++ b/data/sample/sample_regions.geojson @@ -0,0 +1,120 @@ +{ + "type": "FeatureCollection", + "name": "sample_regions", + "crs": { + "type": "name", + "properties": { + "name": "urn:ogc:def:crs:OGC:1.3:CRS84" + } + }, + "features": [ + { + "type": "Feature", + "id": 1, + "properties": { + "region_id": 1, + "name": "华北", + "area_km2": 1540000 + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + 110.0, + 36.0 + ], + [ + 120.0, + 36.0 + ], + [ + 120.0, + 42.5 + ], + [ + 110.0, + 42.5 + ], + [ + 110.0, + 36.0 + ] + ] + ] + } + }, + { + "type": "Feature", + "id": 2, + "properties": { + "region_id": 2, + "name": "华东", + "area_km2": 790000 + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + 118.0, + 29.0 + ], + [ + 122.5, + 29.0 + ], + [ + 122.5, + 35.0 + ], + [ + 118.0, + 35.0 + ], + [ + 118.0, + 29.0 + ] + ] + ] + } + }, + { + "type": "Feature", + "id": 3, + "properties": { + "region_id": 3, + "name": "华南", + "area_km2": 450000 + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [ + 110.0, + 21.0 + ], + [ + 117.0, + 21.0 + ], + [ + 117.0, + 25.0 + ], + [ + 110.0, + 25.0 + ], + [ + 110.0, + 21.0 + ] + ] + ] + } + } + ] +} \ No newline at end of file diff --git a/docs/projection_guide.md b/docs/projection_guide.md new file mode 100644 index 0000000..bd11818 --- /dev/null +++ b/docs/projection_guide.md @@ -0,0 +1,248 @@ +# 
geo_tools.core.projection 使用说明 + +> 坐标系查询、坐标转换、投影推荐工具,基于 [pyproj](https://pyproj4.github.io/pyproj/)。 + +--- + +## 导入方式 + +```python +# 推荐:从顶层包导入 +from geo_tools.core.projection import ( + get_crs_info, + crs_to_epsg, + transform_coordinates, + transform_point, + suggest_projected_crs, + WGS84, CGCS2000, WEB_MERCATOR, CGCS2000_UTM_50N, +) + +# 或直接通过 geo_tools 导入 +import geo_tools +``` + +--- + +## CRS 常量 + +模块内置了中国地理信息处理中最常用的 CRS 快捷常量,可直接作为参数传入所有函数: + +| 常量名 | EPSG | 说明 | +|--------|------|------| +| `WGS84` | `EPSG:4326` | WGS84 地理坐标系(经纬度,最通用) | +| `CGCS2000` | `EPSG:4490` | 中国国家大地坐标系 2000(经纬度) | +| `WEB_MERCATOR` | `EPSG:3857` | Web Mercator 投影(网络地图常用,单位:米) | +| `CGCS2000_UTM_50N` | `EPSG:4508` | CGCS2000 / 3° 高斯-克吕格 50 带(单位:米) | + +```python +from geo_tools.core.projection import WGS84, CGCS2000 + +# 直接用常量替代字符串 +gdf = gdf.to_crs(CGCS2000) +``` + +--- + +## 函数说明 + +### `get_crs_info(crs_input)` — 查询 CRS 信息 + +返回坐标系的详细描述字典,方便快速了解一个未知 EPSG 的含义。 + +**参数** +- `crs_input`:EPSG 代码字符串(如 `"EPSG:4523"`)、整数(如 `4523`)或 proj 字符串。 + +**返回值**(`dict`) + +| 键 | 含义 | +|----|------| +| `name` | 坐标系名称 | +| `epsg` | EPSG 整数编号(无法识别时为 `None`) | +| `unit` | 坐标单位(`degree` / `metre`) | +| `is_geographic` | 是否为地理坐标系(经纬度) | +| `is_projected` | 是否为投影坐标系(平面直角) | +| `datum` | 基准面名称 | + +```python +from geo_tools.core.projection import get_crs_info + +# 查询读取到的 GDB 数据的 CRS 含义 +info = get_crs_info("EPSG:4523") +print(info) +# { +# 'name': 'CGCS2000 / 3-degree Gauss-Kruger zone 35', +# 'epsg': 4523, +# 'unit': 'metre', +# 'is_geographic': False, +# 'is_projected': True, +# 'datum': 'China Geodetic Coordinate System 2000' +# } + +# 直接传整数 +info = get_crs_info(32650) +print(info["name"]) # WGS 84 / UTM zone 50N +``` + +--- + +### `crs_to_epsg(crs_input)` — 获取 EPSG 编号 + +将任意 CRS 描述转为整数 EPSG 编号,无法识别时返回 `None`(不抛异常)。 + +```python +from geo_tools.core.projection import crs_to_epsg + +epsg = crs_to_epsg("EPSG:4490") +print(epsg) # 4490 + +epsg = crs_to_epsg("WGS 84") +print(epsg) # 4326 + +epsg = 
crs_to_epsg("invalid_crs") +print(epsg) # None +``` + +--- + +### `transform_coordinates(xs, ys, source_crs, target_crs)` — 批量坐标转换 + +将一组坐标点从源坐标系批量转换到目标坐标系,返回转换后的 `(xs, ys)` 列表。 + +**参数** +- `xs`:X 坐标序列(地理 CRS 时为**经度**) +- `ys`:Y 坐标序列(地理 CRS 时为**纬度**) +- `source_crs`:源坐标系 +- `target_crs`:目标坐标系 +- `always_xy`(关键字参数):强制按 (经度/X, 纬度/Y) 顺序处理,默认 `True`,**建议不修改** + +```python +from geo_tools.core.projection import transform_coordinates, WGS84, WEB_MERCATOR + +# 将北京、上海、广州的 WGS84 经纬度转为 Web Mercator 米制坐标 +lons = [116.4074, 121.4737, 113.2644] +lats = [39.9042, 31.2304, 23.1291] + +xs, ys = transform_coordinates(lons, lats, WGS84, WEB_MERCATOR) +print(xs) # [12959618.8, 13521606.3, 12608870.0](单位:米) +print(ys) # [4859767.2, 3649094.2, 2641877.0] + +# 国家坐标系转换:CGCS2000 经纬度 → CGCS2000 3° 高斯带(50带) +from geo_tools.core.projection import CGCS2000, CGCS2000_UTM_50N +xs_proj, ys_proj = transform_coordinates(lons, lats, CGCS2000, CGCS2000_UTM_50N) +``` + +--- + +### `transform_point(x, y, source_crs, target_crs)` — 单点坐标转换 + +`transform_coordinates` 的单点版本,直接返回 `(x, y)` 元组。 + +```python +from geo_tools.core.projection import transform_point, WGS84, CGCS2000 + +# 单点:WGS84 → CGCS2000(两者数值非常接近,差异在毫米级) +x, y = transform_point(116.4074, 39.9042, WGS84, CGCS2000) +print(f"CGCS2000 坐标:经度={x:.6f}, 纬度={y:.6f}") + +# 单点:经纬度 → 投影坐标(米) +from geo_tools.core.projection import WEB_MERCATOR +mx, my = transform_point(116.4074, 39.9042, WGS84, WEB_MERCATOR) +print(f"墨卡托坐标:X={mx:.2f}m, Y={my:.2f}m") +``` + +--- + +### `suggest_projected_crs(lon, lat)` — 自动推荐投影 CRS + +根据数据中心坐标(WGS84 经纬度)自动推荐适合**面积/距离计算**的 UTM 投影带,避免在地理坐标系下计算面积出错。 + +**参数** +- `lon`:中心经度(WGS84) +- `lat`:中心纬度(WGS84,北半球为正) + +**返回值**:EPSG 代码字符串,如 `"EPSG:32650"` + +```python +from geo_tools.core.projection import suggest_projected_crs + +# 云南马关县(约 104.4°E, 23.0°N) +proj_crs = suggest_projected_crs(lon=104.4, lat=23.0) +print(proj_crs) # EPSG:32648 (WGS84 UTM zone 48N) + +# 北京(116.4°E, 39.9°N) +proj_crs = suggest_projected_crs(lon=116.4, 
lat=39.9) +print(proj_crs) # EPSG:32650 (WGS84 UTM zone 50N) + +# 实际场景:读取 GDB 后用推荐的投影计算面积 +import geo_tools + +gdf = geo_tools.read_gdb("data.gdb", layer="图斑") +cx, cy = gdf.geometry.unary_union.centroid.x, gdf.geometry.unary_union.centroid.y + +# 如果数据是投影坐标系(单位:米),先转到地理坐标系再推荐 +if gdf.crs.is_projected: + cx, cy = geo_tools.transform_point(cx, cy, gdf.crs, "EPSG:4326") + +proj_crs = suggest_projected_crs(cx, cy) +gdf_proj = geo_tools.reproject(gdf, proj_crs) # 重投影 +gdf_proj = geo_tools.add_area_column(gdf_proj) # 计算面积(单位:m²) +``` + +--- + +## 常见场景示例 + +### 场景一:不认识数据的 CRS,先查一下 + +```python +import geo_tools + +gdf = geo_tools.read_gdb("临时数据库.gdb", layer="马关综合后图斑") +# 读取完成:CRS=EPSG:4523 + +info = geo_tools.get_crs_info(gdf.crs) +print(info["name"]) # CGCS2000 / 3-degree Gauss-Kruger zone 35 +print(info["unit"]) # metre(投影坐标系,单位是米) +print(info["is_projected"]) # True +``` + +### 场景二:统一坐标系后叠置分析 + +```python +import geo_tools +from geo_tools.core.projection import CGCS2000 + +layer_a = geo_tools.read_gdb("a.gdb", layer="林地") # EPSG:4523 +layer_b = geo_tools.read_vector("b.geojson") # EPSG:4326 + +# 统一到 CGCS2000 地理坐标系后再做叠置 +layer_a = geo_tools.reproject(layer_a, CGCS2000) +layer_b = geo_tools.reproject(layer_b, CGCS2000) + +result = geo_tools.overlay(layer_a, layer_b, how="intersection") +``` + +### 场景三:在地理坐标系数据上正确计算面积 + +```python +import geo_tools +from geo_tools.core.projection import suggest_projected_crs + +gdf = geo_tools.read_vector("data.geojson") # EPSG:4326,单位是度 + +# 自动推荐合适的投影 +proj = suggest_projected_crs(lon=105.0, lat=25.0) # 云贵地区 + +gdf = geo_tools.add_area_column(gdf, projected_crs=proj) +print(gdf["area_m2"].describe()) +``` + +--- + +## 注意事项 + +> [!WARNING] +> 在**地理坐标系**(EPSG:4326 / 4490)下直接调用 `geometry.area` 得到的是"平方度",**不是平方米**,面积计算会严重失真。始终用 `add_area_column()` 或先 `reproject()` 到投影坐标系后再计算。 + +> [!NOTE] +> `WGS84`(EPSG:4326)与 `CGCS2000`(EPSG:4490)的坐标数值差异极小(通常 < 1 米),在普通精度的分析中可视为等价,但正式国家项目中必须使用 CGCS2000。 diff --git a/geo_tools/__init__.py 
b/geo_tools/__init__.py new file mode 100644 index 0000000..cd0631c --- /dev/null +++ b/geo_tools/__init__.py @@ -0,0 +1,178 @@ +""" +geo_tools +~~~~~~~~~ +专业地理信息数据处理工具库。 + +核心依赖:geopandas、shapely、fiona、pyproj。 + +快速开始 +-------- +>>> import geo_tools +>>> gdf = geo_tools.read_vector("data/sample/sample_points.geojson") +>>> gdf_proj = geo_tools.reproject(gdf, "EPSG:3857") +>>> print(gdf_proj.crs) + +GDB 读写 +-------- +>>> layers = geo_tools.list_gdb_layers("path/to/data.gdb") +>>> gdf = geo_tools.read_gdb("path/to/data.gdb", layer="my_layer") +>>> geo_tools.write_gdb(gdf, "output/result.gdb", layer="result_layer") +""" + +from importlib.metadata import PackageNotFoundError, version + +# ── 版本 ────────────────────────────────────────────────────────────────────── +try: + __version__ = version("geo-tools") +except PackageNotFoundError: + __version__ = "0.1.0-dev" + +# ── 配置 & 日志 ─────────────────────────────────────────────────────────────── +from geo_tools.config.settings import settings +from geo_tools.utils.logger import get_logger, set_global_level +from geo_tools.utils.validators import ( + SUPPORTED_VECTOR_EXTENSIONS, + is_supported_vector_format, + is_valid_crs, + validate_crs, + validate_geometry, + validate_vector_path, +) + +# ── IO ──────────────────────────────────────────────────────────────────────── +from geo_tools.io.readers import ( + list_gdb_layers, + list_gpkg_layers, + read_csv_points, + read_gdb, + read_gpkg, + read_vector, +) +from geo_tools.io.writers import ( + write_csv, + write_gdb, + write_gpkg, + write_vector, +) + +# ── 核心处理 ────────────────────────────────────────────────────────────────── +from geo_tools.core.geometry import ( + buffer_geometry, + bounding_box, + centroid, + contains, + convex_hull, + difference, + distance_between, + fix_geometry, + intersect, + intersects, + is_valid_geometry, + symmetric_difference, + unary_union, + union, + within, +) +from geo_tools.core.projection import ( + CGCS2000, + CGCS2000_UTM_50N, + 
WEB_MERCATOR, + WGS84, + crs_to_epsg, + get_crs_info, + suggest_projected_crs, + transform_coordinates, + transform_point, +) +from geo_tools.core.vector import ( + add_area_column, + clip_to_extent, + dissolve_by, + drop_invalid_geometries, + explode_multipart, + reproject, + set_crs, + spatial_join, +) + +# ── 空间分析 ────────────────────────────────────────────────────────────────── +from geo_tools.analysis.spatial_ops import ( + buffer_and_overlay, + nearest_features, + overlay, + select_by_location, +) +from geo_tools.analysis.stats import ( + area_weighted_mean, + count_by_polygon, + summarize_attributes, +) + +__all__ = [ + "__version__", + "settings", + # utils + "get_logger", + "set_global_level", + "is_valid_crs", + "validate_crs", + "validate_geometry", + "is_supported_vector_format", + "validate_vector_path", + "SUPPORTED_VECTOR_EXTENSIONS", + # io - readers + "read_vector", + "read_gdb", + "list_gdb_layers", + "read_gpkg", + "list_gpkg_layers", + "read_csv_points", + # io - writers + "write_vector", + "write_gdb", + "write_gpkg", + "write_csv", + # core - geometry + "is_valid_geometry", + "fix_geometry", + "buffer_geometry", + "centroid", + "bounding_box", + "convex_hull", + "intersect", + "union", + "difference", + "symmetric_difference", + "unary_union", + "contains", + "within", + "intersects", + "distance_between", + # core - projection + "WGS84", + "CGCS2000", + "WEB_MERCATOR", + "CGCS2000_UTM_50N", + "get_crs_info", + "crs_to_epsg", + "transform_coordinates", + "transform_point", + "suggest_projected_crs", + # core - vector + "reproject", + "set_crs", + "clip_to_extent", + "dissolve_by", + "explode_multipart", + "drop_invalid_geometries", + "spatial_join", + "add_area_column", + # analysis + "buffer_and_overlay", + "overlay", + "nearest_features", + "select_by_location", + "area_weighted_mean", + "summarize_attributes", + "count_by_polygon", +] diff --git a/geo_tools/analysis/__init__.py b/geo_tools/analysis/__init__.py new file mode 100644 index 
def _align_crs(reference: gpd.GeoDataFrame, other: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """Return ``other`` reprojected to ``reference``'s CRS when both define one and they differ."""
    if reference.crs is not None and other.crs is not None and reference.crs != other.crs:
        return other.to_crs(reference.crs)
    return other


def buffer_and_overlay(
    source: gpd.GeoDataFrame,
    distance: float,
    target: gpd.GeoDataFrame,
    how: str = "intersection",
    projected_crs: str | None = None,
) -> gpd.GeoDataFrame:
    """Buffer every feature of ``source`` and overlay the buffers with ``target``.

    Parameters
    ----------
    source:
        Layer whose geometries are buffered.
    distance:
        Buffer distance, in the units of the CRS the buffer is computed in.
    target:
        Layer the buffers are overlaid with.
    how:
        Overlay type passed to ``geopandas.overlay``: ``"intersection"``,
        ``"union"``, ``"difference"``, ``"symmetric_difference"`` or
        ``"identity"``.
    projected_crs:
        When given, both layers are reprojected to this CRS before buffering
        (use a planar CRS so that ``distance`` is metric) and the result is
        reprojected back to the original CRS of ``source``. When ``None`` the
        buffer is computed in ``source``'s current CRS (degrees for a
        geographic CRS).

    Returns
    -------
    gpd.GeoDataFrame
        The overlay result (``keep_geom_type=False``).
    """
    original_crs = source.crs

    if projected_crs:
        source = source.to_crs(projected_crs)
        target = target.to_crs(projected_crs)
    else:
        # Same automatic alignment as overlay(): without it the two layers
        # would be overlaid in different coordinate systems.
        target = _align_crs(source, target)

    buffered = source.copy()
    # Write into the *active* geometry column; a hard-coded "geometry" key
    # would leave the real geometry unbuffered when the column is named
    # differently.
    buffered[source.geometry.name] = source.geometry.buffer(distance)
    logger.debug("缓冲区完成(distance=%.2f),执行叠置分析(how=%s)", distance, how)

    result = gpd.overlay(buffered, target, how=how, keep_geom_type=False)

    if projected_crs:
        result = result.to_crs(original_crs)

    logger.info("叠置分析完成:%d 条结果", len(result))
    return result


def overlay(
    df1: gpd.GeoDataFrame,
    df2: gpd.GeoDataFrame,
    how: str = "intersection",
    keep_geom_type: bool = True,
) -> gpd.GeoDataFrame:
    """Wrap ``geopandas.overlay`` and reproject ``df2`` to ``df1``'s CRS first.

    Parameters
    ----------
    df1, df2:
        Layers to overlay; the result is expressed in ``df1``'s CRS.
    how:
        Overlay type: ``"intersection"``, ``"union"``, ``"difference"``,
        ``"symmetric_difference"`` or ``"identity"``.
    keep_geom_type:
        Forwarded unchanged to ``geopandas.overlay``.
    """
    df2 = _align_crs(df1, df2)
    result = gpd.overlay(df1, df2, how=how, keep_geom_type=keep_geom_type)
    logger.debug("overlay(%s):%d 条结果", how, len(result))
    return result


def _k_nearest_join(
    source: gpd.GeoDataFrame,
    target: gpd.GeoDataFrame,
    k: int,
    max_distance: float | None,
) -> gpd.GeoDataFrame:
    """Brute-force k-nearest left join, mirroring the ``sjoin_nearest`` column layout.

    Cost is O(len(source) * len(target)) distance evaluations; it is only used
    for ``k > 1``, which ``sjoin_nearest`` cannot express.
    """
    # Positional labels make the neighbour lookup independent of whatever
    # (possibly non-unique) index the target carries.
    target_geoms = target.geometry.reset_index(drop=True)

    pairs: list[tuple[int, int, float]] = []
    for src_pos, geom in enumerate(source.geometry):
        if geom is None or geom.is_empty:
            continue  # no neighbours; the row is kept by the left join below
        dist = target_geoms.distance(geom).dropna()
        if max_distance is not None:
            dist = dist[dist <= max_distance]
        pairs.extend((src_pos, tgt_pos, d) for tgt_pos, d in dist.nsmallest(k).items())

    pair_df = pd.DataFrame(pairs, columns=["_src_pos", "_tgt_pos", "nearest_distance"]).astype(
        {"_src_pos": "int64", "_tgt_pos": "float64", "nearest_distance": "float64"}
    )
    # One output row per (source, neighbour) pair, plus one NaN row for every
    # source feature without a neighbour (how="left" semantics).
    rows = pd.DataFrame({"_src_pos": range(len(source))}).merge(pair_df, on="_src_pos", how="left")
    matched = rows["_tgt_pos"].notna().to_numpy()

    result = source.iloc[rows["_src_pos"].to_numpy()].copy()

    attrs = pd.DataFrame(target.drop(columns=[target.geometry.name]))
    clashes = [col for col in attrs.columns if col in result.columns]
    result = result.rename(columns={col: f"{col}_left" for col in clashes})
    attrs = attrs.rename(columns={col: f"{col}_right" for col in clashes})

    picked = attrs.iloc[rows.loc[matched, "_tgt_pos"].astype("int64").to_numpy()]
    slots = rows.index[matched]

    def _spread(values):
        # Place the matched values at their output rows, NaN elsewhere.
        return pd.Series(values, index=slots).reindex(rows.index).to_numpy()

    result["index_right"] = _spread(picked.index.to_numpy())
    for col in picked.columns:
        result[col] = _spread(picked[col].to_numpy())
    result["nearest_distance"] = rows["nearest_distance"].to_numpy()
    return result


def nearest_features(
    source: gpd.GeoDataFrame,
    target: gpd.GeoDataFrame,
    k: int = 1,
    max_distance: float | None = None,
) -> gpd.GeoDataFrame:
    """Attach the ``k`` nearest ``target`` features to every ``source`` feature.

    Parameters
    ----------
    source:
        Query layer.
    target:
        Layer that is searched; reprojected to ``source``'s CRS when needed.
    k:
        Number of neighbours per source feature (``>= 1``). ``k == 1`` uses
        ``geopandas.sjoin_nearest``; ``k > 1`` falls back to a brute-force
        search, because ``sjoin_nearest`` only returns the single nearest
        feature.
    max_distance:
        Maximum search distance in CRS units; ``None`` means unlimited.

    Returns
    -------
    gpd.GeoDataFrame
        ``source`` rows joined (left join) with the attributes of their
        neighbours, an ``index_right`` column and a ``nearest_distance``
        column. A source row appears once per neighbour.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k!r}")

    target = _align_crs(source, target)

    if k == 1:
        result = gpd.sjoin_nearest(
            source,
            target,
            how="left",
            max_distance=max_distance,
            distance_col="nearest_distance",
            lsuffix="left",
            rsuffix="right",
        )
    else:
        result = _k_nearest_join(source, target, k, max_distance)

    logger.debug("最近邻分析完成(k=%d):%d 条结果", k, len(result))
    return result
def area_weighted_mean(
    gdf: gpd.GeoDataFrame,
    value_col: str,
    group_col: str | None = None,
    projected_crs: str = "EPSG:3857",
) -> pd.Series | pd.DataFrame:
    """Compute the area-weighted mean of an attribute of polygon features.

    Parameters
    ----------
    gdf:
        Input polygon layer; it is not modified.
    value_col:
        Attribute to average.
    group_col:
        Optional grouping attribute; ``None`` yields one overall result.
    projected_crs:
        CRS used to measure areas when ``gdf`` is not already projected.

    Returns
    -------
    pd.Series
        Without grouping: ``area_weighted_mean`` and ``total_area``.
    pd.DataFrame
        With grouping: columns ``group_col``, ``area_weighted_mean`` and
        ``total_area``.

    Raises
    ------
    ValueError
        If ``gdf`` has no CRS, so areas cannot be measured.
    """
    if gdf.crs is None:
        raise ValueError("gdf has no CRS; set one before computing area-weighted means")

    projected = gdf if gdf.crs.is_projected else gdf.to_crs(projected_crs)

    frame = pd.DataFrame(
        {
            "_value": gdf[value_col].to_numpy(),
            "_area": projected.geometry.area.to_numpy(),
        }
    )
    # Rows without a value must not contribute their area to the denominator,
    # otherwise the mean is biased towards zero.
    has_value = frame["_value"].notna()
    frame["_weighted"] = (frame["_value"] * frame["_area"]).where(has_value)
    frame["_valid_area"] = frame["_area"].where(has_value, 0.0)

    if group_col is None:
        mean = frame["_weighted"].sum() / frame["_valid_area"].sum()
        return pd.Series({"area_weighted_mean": mean, "total_area": frame["_area"].sum()})

    frame[group_col] = gdf[group_col].to_numpy()
    sums = frame.groupby(group_col)[["_weighted", "_valid_area", "_area"]].sum()
    summary = pd.DataFrame(
        {
            "area_weighted_mean": sums["_weighted"] / sums["_valid_area"],
            "total_area": sums["_area"],
        }
    )
    return summary.reset_index()


def summarize_attributes(
    gdf: gpd.GeoDataFrame,
    columns: list[str] | None = None,
    group_col: str | None = None,
    agg_funcs: list[str] | None = None,
) -> pd.DataFrame:
    """Aggregate attribute columns (count, mean, min, max, sum, std, ...).

    Parameters
    ----------
    gdf:
        Input layer.
    columns:
        Columns to summarise; ``None`` selects every numeric column.
    group_col:
        Optional grouping column; ``None`` summarises the whole table. The
        grouping column itself is never summarised.
    agg_funcs:
        Aggregation names; defaults to
        ``["count", "mean", "min", "max", "sum", "std"]``.

    Returns
    -------
    pd.DataFrame
        Without grouping: one row per column, with a ``column`` field and one
        field per aggregation. With grouping: one row per group, with
        two-level ``(column, aggregation)`` column labels.

    Raises
    ------
    ValueError
        If no column is left to summarise.
    """
    if agg_funcs is None:
        agg_funcs = ["count", "mean", "min", "max", "sum", "std"]

    df = gdf.drop(columns=["geometry"], errors="ignore")

    if columns is None:
        columns = df.select_dtypes(include="number").columns.tolist()
    if group_col is not None:
        # A numeric grouping key is auto-selected above; keeping it would
        # duplicate the column and make groupby fail.
        columns = [col for col in columns if col != group_col]

    if not columns:
        raise ValueError("未找到数值列,请显式指定 columns 参数。")

    if group_col is None:
        return df[columns].agg(agg_funcs).T.rename_axis("column").reset_index()

    return df.groupby(group_col)[columns].agg(agg_funcs).reset_index()


def count_by_polygon(
    points: gpd.GeoDataFrame,
    polygons: gpd.GeoDataFrame,
    count_col: str = "point_count",
) -> gpd.GeoDataFrame:
    """Count the points that fall within each polygon.

    Parameters
    ----------
    points:
        Point layer; reprojected to the polygons' CRS when they differ.
    polygons:
        Polygon layer.
    count_col:
        Name of the count column written to the result (an existing column of
        that name is overwritten).

    Returns
    -------
    gpd.GeoDataFrame
        Copy of ``polygons`` with an integer ``count_col`` column.
    """
    if points.crs != polygons.crs:
        points = points.to_crs(polygons.crs)

    # Join against geometry only, keyed by row position: this avoids attribute
    # name clashes and works for named or non-unique polygon indexes.
    polygon_shapes = polygons.reset_index(drop=True)[[polygons.geometry.name]]
    joined = gpd.sjoin(points, polygon_shapes, how="inner", predicate="within")
    counts = joined["index_right"].value_counts()

    result = polygons.copy()
    result[count_col] = counts.reindex(range(len(polygons)), fill_value=0).to_numpy().astype(int)
    return result
enum import Enum, unique + +# 坐标系枚举 +@unique +class CRS(Enum): + WGS84 = "EPSG:4326" + CGCS2000 = "EPSG:4490" + WEB_MERCATOR = "EPSG:3857" + CGCS2000_3_DEGREE_ZONE_25 = "EPSG:4513" + CGCS2000_3_DEGREE_ZONE_26 = "EPSG:4514" + CGCS2000_3_DEGREE_ZONE_27 = "EPSG:4515" + CGCS2000_3_DEGREE_ZONE_28 = "EPSG:4516" + CGCS2000_3_DEGREE_ZONE_29 = "EPSG:4517" + CGCS2000_3_DEGREE_ZONE_30 = "EPSG:4518" + CGCS2000_3_DEGREE_ZONE_31 = "EPSG:4519" + CGCS2000_3_DEGREE_ZONE_32 = "EPSG:4520" + CGCS2000_3_DEGREE_ZONE_33 = "EPSG:4521" + CGCS2000_3_DEGREE_ZONE_34 = "EPSG:4522" + CGCS2000_3_DEGREE_ZONE_35 = "EPSG:4523" + CGCS2000_3_DEGREE_ZONE_36 = "EPSG:4524" + CGCS2000_3_DEGREE_ZONE_37 = "EPSG:4525" + CGCS2000_3_DEGREE_ZONE_38 = "EPSG:4526" + CGCS2000_3_DEGREE_ZONE_39 = "EPSG:4527" + CGCS2000_3_DEGREE_ZONE_40 = "EPSG:4528" + CGCS2000_3_DEGREE_ZONE_41 = "EPSG:4529" + CGCS2000_3_DEGREE_ZONE_42 = "EPSG:4530" + CGCS2000_3_DEGREE_ZONE_43 = "EPSG:4531" + CGCS2000_3_DEGREE_ZONE_44 = "EPSG:4532" + CGCS2000_3_DEGREE_ZONE_45 = "EPSG:4533" + CGCS2000_6_DEGREE_ZONE_13 = "EPSG:4491" + CGCS2000_6_DEGREE_ZONE_14 = "EPSG:4492" + CGCS2000_6_DEGREE_ZONE_15 = "EPSG:4493" + CGCS2000_6_DEGREE_ZONE_16 = "EPSG:4494" + CGCS2000_6_DEGREE_ZONE_17 = "EPSG:4495" + CGCS2000_6_DEGREE_ZONE_18 = "EPSG:4496" + CGCS2000_6_DEGREE_ZONE_19 = "EPSG:4497" + CGCS2000_6_DEGREE_ZONE_20 = "EPSG:4498" + CGCS2000_6_DEGREE_ZONE_21 = "EPSG:4499" + CGCS2000_6_DEGREE_ZONE_22 = "EPSG:4500" + CGCS2000_6_DEGREE_ZONE_23 = "EPSG:4501" diff --git a/geo_tools/config/settings.py b/geo_tools/config/settings.py new file mode 100644 index 0000000..7ecffcf --- /dev/null +++ b/geo_tools/config/settings.py @@ -0,0 +1,98 @@ +""" +geo_tools.config.settings +~~~~~~~~~~~~~~~~~~~~~~~~~ +全局配置,通过 Pydantic BaseSettings 从环境变量 / .env 文件加载。 + +使用方式 +-------- +>>> from geo_tools.config.settings import settings +>>> print(settings.default_crs) +'EPSG:4326' +""" + +from __future__ import annotations + +import multiprocessing +from pathlib import Path + +from 
class GeoToolsSettings(BaseSettings):
    """Process-wide runtime configuration.

    Every field can be overridden with an environment variable carrying the
    ``GEO_TOOLS_`` prefix, or through a ``.env`` file (see ``.env.example``).
    """

    model_config = SettingsConfigDict(
        env_prefix="GEO_TOOLS_",
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # -- directories ---------------------------------------------------------
    # Directory for processing results.
    output_dir: Path = Path("output")
    # Directory for log files.
    log_dir: Path = Path("logs")

    # -- coordinate reference system ----------------------------------------
    # Default CRS as an EPSG code string, e.g. EPSG:4326 (WGS84),
    # EPSG:4490 (CGCS2000) or EPSG:3857 (Web Mercator).
    default_crs: str = "EPSG:4326"

    # -- logging -------------------------------------------------------------
    # One of DEBUG / INFO / WARNING / ERROR / CRITICAL.
    log_level: str = "INFO"
    # Whether log records are also written to a file.
    log_to_file: bool = True

    # -- performance ---------------------------------------------------------
    # Worker count for parallel processing; values <= 0 are replaced by
    # (CPU count - 1), with a minimum of 1, after validation.
    max_workers: int = 0

    # -- validators ----------------------------------------------------------

    @field_validator("log_level")
    @classmethod
    def validate_log_level(cls, v: str) -> str:
        """Upper-case the level and reject anything outside the known set."""
        allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
        normalized = v.upper()
        if normalized in allowed:
            return normalized
        raise ValueError(f"log_level 必须是 {allowed} 之一,收到:{v!r}")

    @field_validator("default_crs")
    @classmethod
    def validate_crs(cls, v: str) -> str:
        """Strip surrounding whitespace and reject an empty CRS string."""
        stripped = v.strip()
        if stripped:
            return stripped
        raise ValueError("default_crs 不能为空")

    @model_validator(mode="after")
    def resolve_max_workers(self) -> "GeoToolsSettings":
        """Replace a non-positive ``max_workers`` with CPU count minus one (at least 1)."""
        if self.max_workers > 0:
            return self
        self.max_workers = max(1, multiprocessing.cpu_count() - 1)
        return self

    def ensure_dirs(self) -> None:
        """Create the output and log directories; existing ones are left alone."""
        for directory in (self.output_dir, self.log_dir):
            directory.mkdir(parents=True, exist_ok=True)


# Module-level singleton shared by the whole package.
settings = GeoToolsSettings()
def buffer_geometry(
    geom: BaseGeometry,
    distance: float,
    cap_style: int = 1,
    join_style: int = 1,
    resolution: int = 16,
) -> BaseGeometry:
    """Buffer a geometry by ``distance``.

    Parameters
    ----------
    geom:
        Input geometry.
    distance:
        Buffer distance, in the units of the geometry's CRS (degrees for a
        geographic CRS).
    cap_style:
        Line-end style, using Shapely's ``BufferCapStyle`` values:
        1 = round, 2 = flat, 3 = square.
    join_style:
        Corner style, using Shapely's ``BufferJoinStyle`` values:
        1 = round, 2 = mitre, 3 = bevel.
    resolution:
        Number of segments used to approximate a quarter circle (default 16).
        The parameter keeps its historical name for callers; it is forwarded
        to Shapely as ``quad_segs``.

    Returns
    -------
    BaseGeometry
        The buffered geometry.
    """
    # Shapely 2.x names this option ``quad_segs``; ``resolution`` is the
    # legacy spelling and is on its way out.
    return geom.buffer(
        distance,
        quad_segs=resolution,
        cap_style=cap_style,
        join_style=join_style,
    )
def get_crs_info(crs_input: str | int) -> dict[str, str | int | bool | None]:
    """Return basic facts about a CRS.

    Parameters
    ----------
    crs_input:
        Anything ``pyproj.CRS.from_user_input`` accepts, e.g. ``"EPSG:4326"``
        or ``4490``.

    Returns
    -------
    dict
        Keys:

        - ``name``: CRS name.
        - ``epsg``: EPSG code from ``CRS.to_epsg()`` (may be ``None``).
        - ``unit``: unit name of the first axis, or ``None`` when the CRS
          reports no axes.
        - ``is_geographic`` / ``is_projected``: booleans.
        - ``datum``: datum name, or ``None`` when no datum is available.
    """
    crs = CRS.from_user_input(crs_input)
    axes = crs.axis_info
    datum = crs.datum
    return {
        "name": crs.name,
        "epsg": crs.to_epsg(),
        "unit": str(axes[0].unit_name) if axes else None,
        "is_geographic": crs.is_geographic,
        "is_projected": crs.is_projected,
        "datum": datum.name if datum else None,
    }
def suggest_projected_crs(lon: float, lat: float, use_3degree: bool = True) -> str:
    """Recommend a CGCS2000 Gauss-Kruger projected CRS for a location.

    The zone is chosen from the longitude only; ``lat`` is accepted for API
    symmetry and appears only in the debug log.

    Parameters
    ----------
    lon:
        Centre longitude in degrees.
    lat:
        Centre latitude in degrees (not used for zone selection).
    use_3degree:
        ``True`` selects the 3-degree zoning (zones 25-45, EPSG:4513-4533);
        ``False`` selects the 6-degree zoning (zones 13-23, EPSG:4491-4501).

    Returns
    -------
    str
        An EPSG code string such as ``"EPSG:4527"`` (CGCS2000 / 3-degree
        Gauss-Kruger zone 39). When the longitude falls outside 75-135 degrees E
        a warning is logged and a fixed fallback zone is returned
        (zone 36 for 3-degree, zone 18 for 6-degree).
    """
    if use_3degree:
        # Zone n has central meridian 3*n and covers [3n - 1.5, 3n + 1.5).
        # Floor division gives a half-up rule, so every boundary longitude is
        # assigned consistently (round() would use banker's rounding).
        zone_number = int((lon + 1.5) // 3)
        central_meridian = zone_number * 3

        if 75 <= central_meridian <= 135:
            # EPSG:4513 is zone 25; codes are consecutive up to zone 45.
            epsg = 4513 + zone_number - 25
        else:
            epsg = 4524  # zone 36, central meridian 108 degrees E
            logger.warning("经度范围超出3度带范围,默认使用36带(108°E)")
    else:
        # Zone n covers [6(n-1), 6n) with central meridian 6n - 3.
        zone_number = int(lon // 6) + 1
        central_meridian = zone_number * 6 - 3

        if 75 <= central_meridian <= 135:
            # EPSG:4491 is zone 13; codes are consecutive up to zone 23.
            epsg = 4491 + zone_number - 13
        else:
            epsg = 4496  # zone 18, central meridian 105 degrees E
            logger.warning("经度范围超出6度带范围,默认使用18带(105°E)")

    logger.debug("建议投影 CRS:EPSG:%d(lon=%.2f, lat=%.2f)", epsg, lon, lat)
    return f"EPSG:{epsg}"
+ Examples + -------- + >>> # 指定目标 CRS + >>> gdf_proj = reproject_gdf(gdf, "EPSG:4490") + + >>> # 自动推荐 CGCS2000 高斯-克吕格 带号(用于面积计算) + >>> gdf_utm = reproject_gdf(gdf, auto_crs=True) + + >>> # 配合 GDB 读取 + >>> gdf = read_gdb("data.gdb", layer="图斑") + >>> gdf_proj = reproject_gdf(gdf, "EPSG:4326") + """ + + + if gdf.crs is None: + raise ValueError("GeoDataFrame 未定义 CRS,请先调用 set_crs() 设置坐标系。") + + if auto_crs: + # 先统一到地理坐标系,再取中心点推荐 CGCS2000 高斯-克吕格 带号 + if gdf.crs.is_projected: + center = gdf.to_crs("EPSG:4490").geometry.unary_union.centroid + else: + center = gdf.geometry.unary_union.centroid + target_crs = suggest_projected_crs(center.x, center.y) + logger.info("auto_crs:自动推荐投影 CRS = %s", target_crs) + + if target_crs is None: + raise ValueError("请指定 target_crs,或设置 auto_crs=True 自动推荐投影。") + + src_crs_str = gdf.crs.to_string() + result = gdf.to_crs(target_crs) + + if verbose: + tgt_info = get_crs_info(target_crs) + logger.info( + "要素类重投影完成:%s → %s(%s,单位:%s,要素数:%d)", + src_crs_str, + tgt_info.get("epsg") or target_crs, + tgt_info.get("name"), + tgt_info.get("unit"), + len(result), + ) + + return result + diff --git a/geo_tools/core/raster.py b/geo_tools/core/raster.py new file mode 100644 index 0000000..11cb30e --- /dev/null +++ b/geo_tools/core/raster.py @@ -0,0 +1,103 @@ +""" +geo_tools.core.raster +~~~~~~~~~~~~~~~~~~~~~ +栅格数据处理预留接口。 + +当前提供基于 rasterio 的核心读写骨架; +如需完整栅格分析功能,请安装可选依赖: +``pip install geo-tools[raster]`` +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import numpy as np + + +def _require_rasterio() -> Any: + """检查 rasterio 是否可用,不可用时给出明确提示。""" + try: + import rasterio + return rasterio + except ImportError as exc: + raise ImportError( + "栅格处理功能需要 rasterio。\n" + "请执行:pip install geo-tools[raster] 或 pip install rasterio" + ) from exc + + +def read_raster( + path: str | Path, + band: int = 1, +) -> tuple["np.ndarray", dict[str, Any]]: + """读取栅格文件(单波段)。 + + Parameters + 
def reproject(gdf: gpd.GeoDataFrame, target_crs: str | int) -> gpd.GeoDataFrame:
    """Reproject a GeoDataFrame to ``target_crs``.

    Parameters
    ----------
    gdf:
        Input GeoDataFrame; its CRS must be defined.
    target_crs:
        Target CRS, e.g. ``"EPSG:3857"`` or ``4490``.

    Returns
    -------
    gpd.GeoDataFrame
        A new GeoDataFrame in ``target_crs``; the input is left untouched.
        When the data is already in the target CRS a copy is returned.

    Raises
    ------
    ValueError
        If ``gdf`` has no CRS.
    """
    if gdf.crs is None:
        raise ValueError("GeoDataFrame 未定义 CRS,请先设置坐标系。")

    # Compare through the CRS object itself so that string and integer
    # targets are both handled. The previous check compared ``to_epsg()``
    # with ``None`` for every string target, so a CRS without an EPSG code
    # was returned unprojected.
    if gdf.crs == target_crs:
        return gdf.copy()

    logger.debug("重投影:%s → %s(共 %d 条)", gdf.crs, target_crs, len(gdf))
    return gdf.to_crs(target_crs)
def drop_invalid_geometries(gdf: gpd.GeoDataFrame, *, fix: bool = False) -> gpd.GeoDataFrame:
    """Drop or repair invalid geometries.

    Parameters
    ----------
    gdf:
        Input GeoDataFrame.
    fix:
        When ``True``, run every invalid or missing geometry through
        ``geo_tools.core.geometry.fix_geometry`` instead of dropping the row.
        Rows that cannot be repaired are kept, with whatever
        ``fix_geometry`` returned for them, and are reported in a warning.

    Returns
    -------
    gpd.GeoDataFrame
        The input object itself when ``validate_geometry`` reports no invalid
        and no null geometries; otherwise a new GeoDataFrame.
    """
    stats = validate_geometry(gdf)
    if stats["invalid"] == 0 and stats["null"] == 0:
        return gdf

    if fix:
        from geo_tools.core.geometry import fix_geometry

        gdf = gdf.copy()
        # Use the active geometry column; it is not necessarily "geometry".
        geom_col = gdf.geometry.name
        broken = ~gdf.geometry.is_valid | gdf.geometry.isna()
        gdf.loc[broken, geom_col] = gdf.loc[broken, geom_col].apply(fix_geometry)

        # Re-check so the log reports what was actually repaired.
        still_broken = int((~gdf.geometry.is_valid | gdf.geometry.isna()).sum())
        logger.info("已修复 %d 个无效几何", int(broken.sum()) - still_broken)
        if still_broken:
            logger.warning("仍有 %d 个几何无法修复或为空", still_broken)
    else:
        before = len(gdf)
        gdf = gdf[gdf.geometry.is_valid & gdf.geometry.notna()].copy()
        logger.info("已删除 %d 个无效/空几何", before - len(gdf))
    return gdf
def read_vector(
    path: str | Path,
    layer: str | int | None = None,
    crs: str | int | None = None,
    encoding: str = "utf-8",
    **kwargs: Any,
) -> gpd.GeoDataFrame:
    """Read vector data from any supported format, chosen by file suffix.

    Parameters
    ----------
    path:
        Data path; a file, or a FileGDB ``*.gdb`` directory.
    layer:
        Layer name or index for multi-layer formats (GPKG, GDB); may be
        omitted for single-layer data.
    crs:
        For non-CSV inputs: the data is reprojected to this CRS after reading
        (``None`` keeps the source CRS). For ``.csv`` inputs the value is
        passed to ``_read_csv_vector``, which assigns it as the CRS of the
        coordinates instead of reprojecting.
    encoding:
        Text encoding of the attribute data; Shapefiles with Chinese text
        often need ``"gbk"``. Also applied when reading ``.csv`` files.
    **kwargs:
        Extra arguments forwarded to ``geopandas.read_file`` (or, for CSV, to
        ``_read_csv_vector``).

    Returns
    -------
    gpd.GeoDataFrame
    """
    path = validate_vector_path(path)
    suffix = path.suffix.lower()

    logger.info("读取矢量数据:%s(格式:%s,图层:%s)", path, suffix or "目录", layer)

    if suffix == ".csv":
        # Forward the encoding too; it used to be dropped on this branch.
        return _read_csv_vector(path, crs=crs, encoding=encoding, **kwargs)

    # Generic fiona / geopandas read path.
    read_kwargs: dict[str, Any] = {"encoding": encoding, **kwargs}
    if layer is not None:
        read_kwargs["layer"] = layer

    gdf = gpd.read_file(str(path), **read_kwargs)

    if crs is not None:
        logger.debug("重投影到 %s", crs)
        gdf = gdf.to_crs(crs)

    logger.info("读取完成:共 %d 条要素,CRS=%s", len(gdf), gdf.crs)
    return gdf
def read_gpkg(
    gpkg_path: str | Path,
    layer: str | int | None = None,
    crs: str | int | None = None,
    **kwargs: Any,
) -> gpd.GeoDataFrame:
    """Read one layer from a GeoPackage (``.gpkg``) file.

    Parameters
    ----------
    gpkg_path:
        Path to the ``.gpkg`` file.
    layer:
        Layer name or index. When omitted the first layer is read, and a
        warning is logged if the file holds more than one layer.
    crs:
        Optional CRS to reproject to after reading.
    **kwargs:
        Extra arguments forwarded to ``geopandas.read_file``.

    Raises
    ------
    FileNotFoundError
        If ``gpkg_path`` does not exist.
    ValueError
        If no layer was given and the file contains no layers.
    """
    source = Path(gpkg_path)
    if not source.exists():
        raise FileNotFoundError(f"GPKG 文件不存在:{source}")

    layer_names = fiona.listlayers(str(source))
    if layer is None:
        if not layer_names:
            raise ValueError(f"GPKG 中没有可用图层:{source}")
        layer, *others = layer_names
        if others:
            logger.warning(
                "GPKG 包含多个图层 %s,默认读取第一层 %r。", layer_names, layer
            )

    frame = gpd.read_file(str(source), layer=layer, **kwargs)
    return frame if crs is None else frame.to_crs(crs)
def write_vector(
    gdf: gpd.GeoDataFrame,
    path: str | Path,
    layer: str | None = None,
    driver: str | None = None,
    encoding: str = "utf-8",
    mode: Literal["w", "a"] = "w",
    **kwargs: Any,
) -> Path:
    """Write vector data, choosing the format from the file suffix.

    Parameters
    ----------
    gdf:
        GeoDataFrame to write.
    path:
        Target path (a file, or a ``.gdb`` directory).
    layer:
        Layer name, for multi-layer formats such as GPKG and GDB.
    driver:
        Explicit driver name; inferred from the suffix when ``None``.
    encoding:
        Attribute encoding; Shapefiles with Chinese text often need ``"gbk"``.
    mode:
        ``"w"`` to overwrite, ``"a"`` to append a layer (GPKG / GDB).
    **kwargs:
        Extra arguments forwarded to ``GeoDataFrame.to_file``. They are not
        used when the target is a ``.csv`` file.

    Returns
    -------
    Path
        The path that was written.
    """
    target = Path(path)

    # CSV goes through the WKT writer; driver/layer/mode do not apply there.
    if target.suffix.lower() == ".csv":
        return _write_csv_vector(gdf, target)

    resolved_driver = _infer_driver(target) if driver is None else driver

    # The parent must exist for plain files and for ``.gdb`` directories alike.
    target.parent.mkdir(parents=True, exist_ok=True)

    options: dict[str, Any] = dict(
        driver=resolved_driver,
        encoding=encoding,
        mode=mode,
        **kwargs,
    )
    if layer is not None:
        options["layer"] = layer

    logger.info(
        "写出矢量数据:%s(驱动:%s,图层:%s,模式:%s,要素数:%d)",
        target, resolved_driver, layer, mode, len(gdf),
    )
    gdf.to_file(str(target), **options)
    logger.info("写出完成:%s", target)
    return target
def _infer_driver(path: Path) -> str:
    """Map a path's file extension to a driver name.

    The lookup is case-insensitive and uses only the suffix. ``.gdb`` maps to
    ``OpenFileGDB`` whether or not the path exists on disk.

    Raises
    ------
    ValueError
        If the extension is not one of the known vector formats.
    """
    drivers_by_suffix: dict[str, str] = {
        ".shp": "ESRI Shapefile",
        ".geojson": "GeoJSON",
        ".json": "GeoJSON",
        ".gpkg": "GPKG",
        ".gdb": "OpenFileGDB",
        ".kml": "KML",
        ".fgb": "FlatGeobuf",
        ".gml": "GML",
        ".dxf": "DXF",
    }
    suffix = path.suffix.lower()
    try:
        return drivers_by_suffix[suffix]
    except KeyError:
        raise ValueError(
            f"无法自动推断 fiona 驱动,未知扩展名:{suffix!r}。"
            f"请显式传入 driver=... 参数。"
        ) from None
_LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

# Names of loggers already configured by get_logger, so handlers are attached once.
_initialized: set[str] = set()

# Root of the logger namespace that set_global_level adjusts.
_PACKAGE_LOGGER = "geo_tools"


def _resolve_level(level: str) -> int:
    """Translate a level name such as ``"debug"`` into its numeric value.

    Raises
    ------
    ValueError
        If ``level`` is not a registered logging level name.
    """
    numeric = logging.getLevelName(level.upper())
    # For unknown names getLevelName returns the string "Level <name>".
    if not isinstance(numeric, int):
        raise ValueError(f"Unknown log level: {level!r}")
    return numeric


def _apply_level(logger: logging.Logger, numeric_level: int) -> None:
    """Set ``numeric_level`` on ``logger`` and on every handler attached to it."""
    logger.setLevel(numeric_level)
    for handler in logger.handlers:
        handler.setLevel(numeric_level)


def get_logger(
    name: str,
    level: str | None = None,
    log_file: Path | str | None = None,
    *,
    propagate: bool = False,
) -> logging.Logger:
    """Return a logger with a console handler and, optionally, a file handler.

    Parameters
    ----------
    name:
        Logger name, normally ``__name__``.
    level:
        Level name; ``None`` falls back to ``settings.log_level``.
    log_file:
        Log file path. When ``None`` and ``settings.log_to_file`` is true, the
        file ``settings.log_dir / "geo_tools.log"`` is used. Only honoured the
        first time a given ``name`` is configured.
    propagate:
        Whether records propagate to parent loggers. Defaults to ``False`` to
        avoid duplicated output.

    Returns
    -------
    logging.Logger

    Raises
    ------
    ValueError
        If the resolved level name is unknown.
    """
    # Imported lazily to avoid a circular import with the settings module.
    from geo_tools.config.settings import settings as _settings

    if level is None:
        level = _settings.log_level
    numeric_level = _resolve_level(level)

    logger = logging.getLogger(name)
    logger.propagate = propagate

    # Already configured: only the level may change. The handlers have to
    # follow, otherwise lowering the level has no visible effect because each
    # handler still filters at the level it was created with.
    if name in _initialized:
        _apply_level(logger, numeric_level)
        return logger

    logger.setLevel(numeric_level)
    formatter = logging.Formatter(_LOG_FORMAT, datefmt=_DATE_FORMAT)

    # Console handler.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(numeric_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler: explicit path first, otherwise the settings-driven default.
    _resolve_log_file = log_file
    if _resolve_log_file is None and _settings.log_to_file:
        _settings.ensure_dirs()
        _resolve_log_file = _settings.log_dir / "geo_tools.log"

    if _resolve_log_file is not None:
        file_path = Path(_resolve_log_file)
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(file_path, encoding="utf-8")
        file_handler.setLevel(numeric_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    _initialized.add(name)
    return logger


def set_global_level(level: str) -> None:
    """Set the level of the ``geo_tools`` logger and every logger below it.

    Loggers returned by ``get_logger`` carry their own level and handlers and
    do not propagate by default, so changing only the ``geo_tools`` logger
    leaves them untouched. Every registered logger in the ``geo_tools.``
    namespace is therefore updated, together with its handlers.

    Parameters
    ----------
    level:
        Target level name, e.g. ``"DEBUG"``.

    Raises
    ------
    ValueError
        If ``level`` is not a known level name; no logger is modified then.
    """
    numeric = _resolve_level(level)
    prefix = _PACKAGE_LOGGER + "."
    targets = [logging.getLogger(_PACKAGE_LOGGER)]
    # loggerDict also holds PlaceHolder entries; keep real Logger objects only.
    # Snapshot the items so the registry cannot change size while iterating.
    for logger_name, candidate in list(logging.Logger.manager.loggerDict.items()):
        if isinstance(candidate, logging.Logger) and logger_name.startswith(prefix):
            targets.append(candidate)
    for target in targets:
        _apply_level(target, numeric)
def is_valid_crs(crs_input: str | int) -> bool:
    """Report whether pyproj can parse ``crs_input`` as a CRS.

    Parameters
    ----------
    crs_input:
        EPSG code (integer or ``"EPSG:4326"``-style string) or a proj string.

    Returns
    -------
    bool
        ``True`` when ``CRS.from_user_input`` accepts the value, else ``False``.

    Raises
    ------
    ImportError
        If pyproj itself cannot be imported.
    """
    # Import outside the try block: a missing or broken pyproj installation is
    # an environment problem and must not be reported as "invalid CRS".
    from pyproj import CRS
    from pyproj.exceptions import CRSError

    try:
        CRS.from_user_input(crs_input)
    except (CRSError, TypeError, ValueError):
        # Only input-related failures mean "not a valid CRS".
        return False
    return True
[build-system]
requires = ["setuptools>=68", "wheel"]
# The PEP 517 backend shipped by setuptools is "setuptools.build_meta";
# "setuptools.backends.legacy:build" is not an importable module, so any
# `pip install .` / `python -m build` failed before the build could start.
build-backend = "setuptools.build_meta"

[project]
name = "geo-tools"
version = "0.1.0"
description = "专业地理信息数据处理工具库 —— 基于 geopandas / shapely / fiona"
readme = "README.md"
requires-python = ">=3.10"
license = { text = "MIT" }
authors = [{ name = "geo_tools contributors" }]
keywords = ["gis", "geopandas", "shapely", "fiona", "spatial", "geospatial"]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Science/Research",
    "Topic :: Scientific/Engineering :: GIS",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

dependencies = [
    "geopandas>=0.14",
    "shapely>=2.0",
    "fiona>=1.9",
    "pyproj>=3.6",
    "pandas>=2.0",
    "numpy>=1.24",
    "pydantic>=2.0",
    "pydantic-settings>=2.0",
    "python-dotenv>=1.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.4",
    "pytest-cov>=4.1",
    "pytest-timeout>=2.1",
    "black>=23.0",
    "isort>=5.12",
    "flake8>=6.0",
    "mypy>=1.5",
]
notebook = [
    "jupyter>=1.0",
    "matplotlib>=3.7",
    "contextily>=1.4",
    "folium>=0.15",
]
raster = [
    "rasterio>=1.3",
    "xarray>=2023.6",
    "rio-cogeo>=4.0",
]

[project.urls]
Homepage = "https://github.com/your-org/geo_tools"
Repository = "https://github.com/your-org/geo_tools"
"Bug Tracker" = "https://github.com/your-org/geo_tools/issues"

[tool.setuptools.packages.find]
where = ["."]
include = ["geo_tools*"]

# ── pytest ─────────────────────────────────────────────────────────────────
[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "-v --tb=short"
# `timeout` is provided by the pytest-timeout plugin listed in the dev extra.
timeout = 120

# ── coverage ────────────────────────────────────────────────────────────────
[tool.coverage.run]
source = ["geo_tools"]
omit = ["tests/*", "scripts/*", "examples/*"]

[tool.coverage.report]
show_missing = true
skip_covered = false

# ── black ───────────────────────────────────────────────────────────────────
[tool.black]
line-length = 100
target-version = ["py310", "py311", "py312"]

# ── isort ───────────────────────────────────────────────────────────────────
def main() -> None:
    """Run the end-to-end demo: read, validate, reproject, analyse, write.

    Reads the two sample GeoJSON files under ``DATA_DIR``, logs the result of
    each step through the module-level ``logger`` and writes one GeoJSON file
    and one two-layer GeoPackage into ``OUTPUT_DIR``.
    """
    logger.info("=" * 60)
    logger.info("geo_tools 端到端工作流示例 v%s", geo_tools.__version__)
    logger.info("=" * 60)

    # ── 1. Read the sample point data ────────────────────────────
    logger.info("\n[步骤 1] 读取示例点数据(GeoJSON)")
    points = geo_tools.read_vector(DATA_DIR / "sample_points.geojson")
    logger.info(" 读取完成:%d 条要素,CRS=%s", len(points), points.crs)
    logger.info(" 字段:%s", list(points.columns))

    # ── 2. Read the sample region polygons ───────────────────────
    logger.info("\n[步骤 2] 读取示例区域多边形(GeoJSON)")
    regions = geo_tools.read_vector(DATA_DIR / "sample_regions.geojson")
    logger.info(" 区域列表:%s", regions["name"].tolist())

    # ── 3. Geometry validity check ───────────────────────────────
    logger.info("\n[步骤 3] 几何有效性校验")
    stats = geo_tools.validate_geometry(points)
    logger.info(" 点数据校验结果:%s", stats)
    stats = geo_tools.validate_geometry(regions)
    logger.info(" 面数据校验结果:%s", stats)

    # ── 4. CRS information ───────────────────────────────────────
    logger.info("\n[步骤 4] 查询 CRS 信息")
    crs_info = geo_tools.get_crs_info("EPSG:4326")
    logger.info(" WGS84 信息:%s", crs_info)
    # The suggestion is only logged; the later steps do not use it.
    proj_crs = geo_tools.suggest_projected_crs(116.4, 39.9)
    logger.info(" 北京适合的投影 CRS:%s", proj_crs)

    # ── 5. Reprojection ──────────────────────────────────────────
    logger.info("\n[步骤 5] 重投影到 Web Mercator(用于可视化)")
    points_3857 = geo_tools.reproject(points, "EPSG:3857")
    logger.info(" 重投影完成:CRS=%s", points_3857.crs)

    # ── 6. Area-weighted mean ────────────────────────────────────
    logger.info("\n[步骤 6] 面积加权均值计算(示例:用 buffer 生成面数据)")
    # Buffer the points first so there are polygons to weight by area.
    points_buffered = points.to_crs("EPSG:3857").copy()
    # NOTE(review): the distance is in EPSG:3857 units; Web Mercator stretches
    # distances away from the equator, so this is presumably not a true 100 km
    # on the ground — confirm whether that matters for the demo.
    points_buffered["geometry"] = points_buffered.geometry.buffer(100_000)  # 100 km buffer
    points_buffered = points_buffered.to_crs("EPSG:4326")
    from geo_tools.analysis.stats import area_weighted_mean
    aw_result = area_weighted_mean(points_buffered, value_col="value")
    logger.info(" 全局面积加权均值:%.4f", aw_result["area_weighted_mean"])

    # ── 7. Select by location ────────────────────────────────────
    logger.info("\n[步骤 7] 按位置选择:筛选华南区域内的城市")
    hua_nan = regions[regions["name"] == "华南"]
    points_in_huanan = geo_tools.select_by_location(points, hua_nan, predicate="intersects")
    logger.info(" 华南区域内的城市:%s", points_in_huanan["name"].tolist())

    # ── 8. Attribute summary ─────────────────────────────────────
    logger.info("\n[步骤 8] 属性统计汇总")
    from geo_tools.analysis.stats import summarize_attributes
    summary = summarize_attributes(points, columns=["value"], group_col="category")
    logger.info(" 按分类汇总:\n%s", summary.to_string(index=False))

    # ── 9. Write the results ─────────────────────────────────────
    logger.info("\n[步骤 9] 写出处理结果")
    out_geojson = OUTPUT_DIR / "result_points_3857.geojson"
    geo_tools.write_vector(points_3857, out_geojson)
    logger.info(" GeoJSON 写出:%s", out_geojson)

    out_gpkg = OUTPUT_DIR / "results.gpkg"
    # First layer uses the default mode; the second is appended with mode="a".
    geo_tools.write_gpkg(points, out_gpkg, layer="original_points")
    geo_tools.write_gpkg(regions, out_gpkg, layer="regions", mode="a")
    logger.info(" GPKG 写出(2 图层):%s", out_gpkg)

    logger.info("\n" + "=" * 60)
    logger.info("工作流演示完成!输出目录:%s", OUTPUT_DIR)
    logger.info("=" * 60)
"""Manual smoke check: read one FileGDB layer and print its CRS.

Run directly: ``python tests/test1.py [GDB_PATH] [LAYER]``. The read only
happens under the ``__main__`` guard, so importing this module no longer
touches a machine-specific path.
"""
import os
import sys
from pathlib import Path

# Set before geo_tools is imported.
# NOTE(review): presumably so GDAL/OGR sees the option when it loads — confirm.
os.environ["OGR_ORGANIZE_POLYGONS"] = "SKIP"

# Make the project root importable when the package is not installed.
sys.path.insert(0, str(Path(__file__).parent.parent))

import geo_tools

# Defaults keep the original developer-machine values; override on the command line.
DEFAULT_GDB_PATH = r"E:\@三普\@临时文件夹\临时数据库.gdb"
DEFAULT_LAYER = "马关综合后图斑"


def main(gdb_path: str = DEFAULT_GDB_PATH, layer: str = DEFAULT_LAYER) -> None:
    """Read ``layer`` from the FileGDB at ``gdb_path`` and print the result's CRS."""
    # geo_tools.list_gdb_layers(gdb_path) can be used to discover layer names.
    gdf = geo_tools.read_gdb(gdb_path, layer=layer)
    print(gdf.crs)


if __name__ == "__main__":
    # Optional positional overrides: GDB path, then layer name.
    main(*sys.argv[1:3])
class TestSelectByLocation:
    """select_by_location against the rectangular polygon fixture.

    Of the three fixture points only (116.4, 39.9) falls inside the
    115-122 E / 38-41 N rectangle, and it is strictly interior, so both
    predicates must select exactly one feature.
    """

    def test_select_points_in_polygon(self, sample_points_gdf, sample_polygon_gdf):
        result = select_by_location(sample_points_gdf, sample_polygon_gdf, predicate="intersects")
        assert len(result) == 1

    def test_select_within(self, sample_points_gdf, sample_polygon_gdf):
        result = select_by_location(sample_points_gdf, sample_polygon_gdf, predicate="within")
        # The previous `>= 0` check could never fail. No fixture point lies on
        # the rectangle's boundary, so the count is exact.
        assert len(result) == 1
class TestIsValidGeometry:
    """is_valid_geometry for a simple ring, a missing geometry and a bow-tie."""

    def test_valid_polygon(self):
        unit_square = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
        verdict = is_valid_geometry(unit_square)
        assert verdict is True

    def test_none_returns_false(self):
        verdict = is_valid_geometry(None)
        assert verdict is False

    def test_invalid_self_intersecting(self):
        # Bow-tie ring: edges (0,0)-(1,1) and (1,0)-(0,1) cross each other.
        crossed_ring = [(0, 0), (1, 1), (1, 0), (0, 1)]
        assert is_valid_geometry(Polygon(crossed_ring)) is False
class TestSpatialRelations:
    """contains / within on nested squares, and distance on a 3-4-5 triangle."""

    @staticmethod
    def _nested_squares():
        """Return ``(outer, inner)``: a 10x10 square and a unit square inside it."""
        outer = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])
        inner = Polygon([(1, 1), (2, 1), (2, 2), (1, 2)])
        return outer, inner

    def test_contains_true(self):
        outer, inner = self._nested_squares()
        assert contains(outer, inner) is True

    def test_within(self):
        outer, inner = self._nested_squares()
        assert within(inner, outer) is True

    def test_distance(self):
        origin, far_corner = Point(0, 0), Point(3, 4)
        measured = distance_between(origin, far_corner)
        assert measured == pytest.approx(5.0)
class TestReadVector:
    """read_vector: plain read, reprojection on read, and both error paths."""

    def test_read_geojson(self, tmp_geojson_path):
        loaded = read_vector(tmp_geojson_path)
        assert isinstance(loaded, gpd.GeoDataFrame)
        # The fixture file holds three features with a CRS set.
        assert len(loaded) == 3
        assert loaded.crs is not None

    def test_read_with_crs_reprojection(self, tmp_geojson_path):
        reprojected = read_vector(tmp_geojson_path, crs="EPSG:3857")
        epsg_code = reprojected.crs.to_epsg()
        assert epsg_code == 3857

    def test_read_nonexistent_raises(self, tmp_path):
        absent = tmp_path / "nonexistent.geojson"
        with pytest.raises(FileNotFoundError):
            read_vector(absent)

    def test_read_unsupported_format_raises(self, tmp_path):
        # The file exists, so only its extension can be the reason for rejection.
        unsupported = tmp_path / "data.xyz"
        unsupported.write_text("dummy")
        with pytest.raises(ValueError, match="不支持"):
            read_vector(unsupported)
"""Smoke test for CRS lookup through the project CRS enum.

The file name matches pytest's ``test_*.py`` pattern, so the module is imported
during collection. The checks therefore live in a test function, and the
informational printing only runs when the file is executed as a script.
"""
import sys
import os
os.environ["OGR_ORGANIZE_POLYGONS"] = "SKIP"

from pathlib import Path


# Make the project root importable when the package is not installed.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import geo_tools
from geo_tools.core import projection
from geo_tools.config.project_enum import CRS


def test_get_crs_info_for_cgcs2000_zone():
    """get_crs_info accepts the value of a CRS enum member without raising."""
    info = projection.get_crs_info(CRS.CGCS2000_6_DEGREE_ZONE_18.value)
    assert info is not None


if __name__ == "__main__":
    print(projection.get_crs_info(CRS.CGCS2000_6_DEGREE_ZONE_18.value))
    print(type(CRS.CGCS2000_3_DEGREE_ZONE_27))
    # For an ad-hoc reprojection: read a file with geo_tools.read_vector(...),
    # pass it to projection.reproject_gdf(gdf, CRS.<member>) and call
    # .to_file(...) on the result.
class TestSetCrs:
    """set_crs: assigning to a CRS-less frame and the overwrite guard."""

    def test_set_crs_on_new_gdf(self):
        bare = gpd.GeoDataFrame(geometry=[Point(116.4, 39.9)])
        tagged = set_crs(bare, "EPSG:4326")
        assert tagged.crs.to_epsg() == 4326

    def test_overwrite_blocked_by_default(self):
        already_tagged = gpd.GeoDataFrame(geometry=[Point(0, 0)], crs="EPSG:4326")
        # Replacing an existing CRS must be requested explicitly.
        with pytest.raises(ValueError, match="overwrite"):
            set_crs(already_tagged, "EPSG:3857")

    def test_overwrite_allowed(self):
        already_tagged = gpd.GeoDataFrame(geometry=[Point(0, 0)], crs="EPSG:4326")
        replaced = set_crs(already_tagged, "EPSG:3857", overwrite=True)
        assert replaced.crs.to_epsg() == 3857
class TestDropInvalidGeometries:
    """drop_invalid_geometries: dropping by default, repairing with ``fix=True``."""

    # Self-intersecting "bow-tie" ring shared by both tests.
    _BOWTIE_RING = [(0, 0), (1, 1), (1, 0), (0, 1)]

    def test_drop_invalid(self):
        from shapely.geometry import Polygon

        square = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
        bowtie = Polygon(self._BOWTIE_RING)
        mixed = gpd.GeoDataFrame(geometry=[square, bowtie], crs="EPSG:4326")

        kept = drop_invalid_geometries(mixed)

        # Only the valid square should survive.
        assert len(kept) == 1

    def test_fix_invalid(self):
        from shapely.geometry import Polygon

        bowtie = Polygon(self._BOWTIE_RING)
        broken = gpd.GeoDataFrame(geometry=[bowtie], crs="EPSG:4326")

        repaired = drop_invalid_geometries(broken, fix=True)

        assert len(repaired) == 1
        assert repaired.geometry.is_valid.all()