82 lines
3.1 KiB
Python
82 lines
3.1 KiB
Python
"""tests/test_io.py —— IO 读写单元测试。"""
|
|
|
|
import pytest
|
|
import geopandas as gpd
|
|
from pathlib import Path
|
|
|
|
from geo_tools.io.readers import read_vector, read_gpkg, list_gpkg_layers, read_csv_points
|
|
from geo_tools.io.writers import write_vector, write_gpkg, write_csv
|
|
|
|
|
|
class TestReadVector:
    """Tests for ``read_vector``: basic loading, CRS reprojection and error paths."""

    def test_read_geojson(self, tmp_geojson_path):
        """The GeoJSON fixture loads as a GeoDataFrame of 3 rows with a CRS set."""
        result = read_vector(tmp_geojson_path)

        assert isinstance(result, gpd.GeoDataFrame)
        assert len(result) == 3
        assert result.crs is not None

    def test_read_with_crs_reprojection(self, tmp_geojson_path):
        """Passing ``crs`` returns data whose CRS resolves to EPSG code 3857."""
        projected = read_vector(tmp_geojson_path, crs="EPSG:3857")
        epsg_code = projected.crs.to_epsg()

        assert epsg_code == 3857

    def test_read_nonexistent_raises(self, tmp_path):
        """Reading a path that does not exist raises ``FileNotFoundError``."""
        missing = tmp_path / "nonexistent.geojson"

        with pytest.raises(FileNotFoundError):
            read_vector(missing)

    def test_read_unsupported_format_raises(self, tmp_path):
        """A file with an unrecognised extension (``.xyz``) raises ``ValueError``.

        The error message is expected to contain the pattern given to ``match``.
        """
        unsupported = tmp_path / "data.xyz"
        unsupported.write_text("dummy")

        with pytest.raises(ValueError, match="不支持"):
            read_vector(unsupported)
|
|
|
|
|
|
class TestWriteReadRoundtrip:
    """Write-then-read round trips through each supported output format."""

    def test_geojson_roundtrip(self, sample_points_gdf, tmp_output_dir):
        """GeoJSON written by ``write_vector`` reads back with the same rows and columns."""
        out = tmp_output_dir / "out.geojson"
        write_vector(sample_points_gdf, out)

        loaded = read_vector(out)

        assert len(loaded) == len(sample_points_gdf)
        assert list(loaded.columns) == list(sample_points_gdf.columns)

    def test_gpkg_roundtrip(self, sample_points_gdf, tmp_output_dir):
        """A named GeoPackage layer written by ``write_gpkg`` reads back with the same row count."""
        out = tmp_output_dir / "out.gpkg"
        write_gpkg(sample_points_gdf, out, layer="points")

        loaded = read_gpkg(out, layer="points")

        assert len(loaded) == len(sample_points_gdf)

    def test_gpkg_multiple_layers(self, sample_points_gdf, sample_polygon_gdf, tmp_output_dir):
        """Writing a second layer with ``mode="a"`` keeps the first; both are listed."""
        out = tmp_output_dir / "multi.gpkg"
        write_gpkg(sample_points_gdf, out, layer="points")
        write_gpkg(sample_polygon_gdf, out, layer="polygons", mode="a")

        layers = list_gpkg_layers(out)

        assert "points" in layers
        assert "polygons" in layers

    def test_csv_roundtrip(self, sample_points_gdf, tmp_output_dir):
        """CSV output stores geometry as a WKT column that can be parsed back."""
        import pandas as pd

        # _read_csv_vector is a private helper; it is imported locally so the
        # rest of the module does not depend on it. ``Path`` comes from the
        # module-level import (no local re-import needed).
        from geo_tools.io.readers import _read_csv_vector

        out = tmp_output_dir / "out.csv"
        write_csv(sample_points_gdf, out)

        # write_csv emits geometry as a WKT text column; verify with plain pandas.
        df = pd.read_csv(out)
        assert "geometry" in df.columns  # WKT geometry column is present
        assert len(df) == len(sample_points_gdf)  # row count preserved

        # Parse the WKT column back via the private reader _read_csv_vector
        # (read_csv_points is not exercised here).
        gdf_back = _read_csv_vector(Path(out), wkt_col="geometry")
        assert len(gdf_back) == len(sample_points_gdf)
|
|
|
|
|
|
class TestReadCsvPoints:
    """Tests for ``read_csv_points`` building point geometries from coordinate columns."""

    def test_read_csv_with_latlon(self, tmp_path):
        """``longitude``/``latitude`` columns become EPSG:4326 points with x=lon, y=lat."""
        import pandas as pd

        lons = [116.4, 121.5]
        lats = [39.9, 31.2]
        csv_path = tmp_path / "points.csv"
        pd.DataFrame({
            "longitude": lons,
            "latitude": lats,
            "name": ["北京", "上海"],
        }).to_csv(csv_path, index=False)

        gdf = read_csv_points(csv_path)

        assert len(gdf) == 2
        assert gdf.crs.to_epsg() == 4326
        # Guard against swapped axes: x must come from longitude, y from latitude.
        assert list(gdf.geometry.x) == pytest.approx(lons)
        assert list(gdf.geometry.y) == pytest.approx(lats)
|