Files
geo_tools/tests/test_vector.py
2026-03-04 17:07:07 +08:00

101 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""tests/test_vector.py —— 矢量操作单元测试。"""
import geopandas as gpd
import pytest
from shapely.geometry import Point, Polygon

from geo_tools.core.vector import (
    add_area_column,
    clip_to_extent,
    dissolve_by,
    drop_invalid_geometries,
    explode_multipart,
    reproject,
    set_crs,
    spatial_join,
)
class TestReproject:
    """Tests for ``reproject``: target CRS, row preservation, missing-CRS error."""

    def test_basic_reproject(self, sample_points_gdf):
        # Reprojecting to Web Mercator must set the target CRS and keep every row.
        result = reproject(sample_points_gdf, "EPSG:3857")
        assert result.crs.to_epsg() == 3857
        assert len(result) == len(sample_points_gdf)

    def test_reproject_preserves_count(self, sample_points_gdf):
        # Compare against the input instead of a hard-coded fixture size, so the
        # test does not break when the fixture gains or loses points.
        result = reproject(sample_points_gdf, "EPSG:4490")
        assert result.crs.to_epsg() == 4490
        assert len(result) == len(sample_points_gdf)

    def test_reproject_no_crs_raises(self):
        # A GeoDataFrame with no CRS cannot be reprojected; expect a ValueError
        # whose message mentions "CRS".
        gdf = gpd.GeoDataFrame(geometry=[Point(0, 0)])  # no CRS set
        with pytest.raises(ValueError, match="CRS"):
            reproject(gdf, "EPSG:4326")
class TestSetCrs:
    """Tests for ``set_crs``: assigning a CRS and guarding against overwrites."""

    def test_set_crs_on_new_gdf(self):
        # A freshly built GeoDataFrame carries no CRS, so assignment must succeed.
        bare = gpd.GeoDataFrame(geometry=[Point(116.4, 39.9)])
        assert set_crs(bare, "EPSG:4326").crs.to_epsg() == 4326

    def test_overwrite_blocked_by_default(self):
        # Without overwrite=True, replacing an existing CRS is rejected.
        tagged = gpd.GeoDataFrame(geometry=[Point(0, 0)], crs="EPSG:4326")
        with pytest.raises(ValueError, match="overwrite"):
            set_crs(tagged, "EPSG:3857")

    def test_overwrite_allowed(self):
        # Passing overwrite=True lets the caller replace the existing CRS.
        tagged = gpd.GeoDataFrame(geometry=[Point(0, 0)], crs="EPSG:4326")
        updated = set_crs(tagged, "EPSG:3857", overwrite=True)
        assert updated.crs.to_epsg() == 3857
class TestClipToExtent:
    """Tests for ``clip_to_extent`` with a bbox tuple and with a GeoDataFrame mask."""

    def test_clip_by_bbox(self, sample_points_gdf):
        # bbox (minx, miny, maxx, maxy) that contains only Beijing (116.4, 39.9)
        bbox = (115.0, 38.0, 118.0, 41.0)
        result = clip_to_extent(sample_points_gdf, bbox)
        assert len(result) == 1
        # Counting rows alone would not catch the wrong point being kept:
        # every surviving geometry must lie inside the requested extent.
        minx, miny, maxx, maxy = result.total_bounds
        assert bbox[0] <= minx and maxx <= bbox[2]
        assert bbox[1] <= miny and maxy <= bbox[3]

    def test_clip_by_geodataframe(self, sample_points_gdf, sample_polygon_gdf):
        # The polygon covers 115-122°E, 38-41°N, so it should contain Beijing.
        result = clip_to_extent(sample_points_gdf, sample_polygon_gdf)
        assert len(result) >= 1
class TestDissolveBy:
    """Tests for ``dissolve_by``: grouping by a field and CRS preservation."""

    def test_dissolve_by_field(self, sample_multi_polygon_gdf):
        gdf = sample_multi_polygon_gdf.copy()
        # Broadcast a single label so every row lands in the same group,
        # whatever the fixture's row count (a 2-element list would raise a
        # length mismatch if the fixture ever changed size).
        gdf["group"] = "X"
        result = dissolve_by(gdf, by="group")
        assert len(result) == 1

    def test_dissolve_preserves_crs(self, sample_multi_polygon_gdf):
        gdf = sample_multi_polygon_gdf.copy()
        # One distinct label per row, so no rows are merged.
        gdf["group"] = [f"G{i}" for i in range(len(gdf))]
        result = dissolve_by(gdf, by="group")
        assert result.crs == gdf.crs
        # One output row per distinct group.
        assert len(result) == len(gdf)
class TestAddAreaColumn:
    """Tests for ``add_area_column``."""

    def test_area_column_added(self, sample_polygon_gdf):
        result = add_area_column(sample_polygon_gdf, col_name="area_m2")
        assert "area_m2" in result.columns
        assert len(result) == len(sample_polygon_gdf)
        # Check every row, not only the first one.
        # NOTE(review): assumes the fixture holds only non-degenerate polygons.
        assert (result["area_m2"] > 0).all()
class TestDropInvalidGeometries:
    """Tests for ``drop_invalid_geometries`` in drop mode and in fix mode."""

    # Unit square: a simple, valid polygon.
    VALID_SQUARE = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
    # "Bowtie": the ring crosses itself at (0.5, 0.5), making the polygon invalid.
    BOWTIE = Polygon([(0, 0), (1, 1), (1, 0), (0, 1)])

    def test_drop_invalid(self):
        gdf = gpd.GeoDataFrame(
            geometry=[self.VALID_SQUARE, self.BOWTIE], crs="EPSG:4326"
        )
        result = drop_invalid_geometries(gdf)
        assert len(result) == 1
        # The survivor must be the valid square, not the bowtie.
        assert result.geometry.iloc[0].equals(self.VALID_SQUARE)

    def test_fix_invalid(self):
        gdf = gpd.GeoDataFrame(geometry=[self.BOWTIE], crs="EPSG:4326")
        result = drop_invalid_geometries(gdf, fix=True)
        # With fix=True the invalid row is kept (repaired) rather than dropped.
        assert len(result) == 1
        assert result.geometry.is_valid.all()
        # An empty geometry is trivially valid, so also rule out a "repair"
        # that simply discards the shape.
        assert not result.geometry.is_empty.any()