-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
/
Copy pathtest_csv.py
85 lines (67 loc) · 2.64 KB
/
test_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import tempfile

import numpy as np
import pandas as pd
def test_preserve_numpy_arrays_in_csv():
print("\nRunning: test_preserve_numpy_arrays_in_csv")
df = pd.DataFrame({
"id": [1, 2],
"embedding": [
np.array([0.1, 0.2, 0.3]),
np.array([0.4, 0.5, 0.6]),
],
})
with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
path = tmp.name
df.to_csv(path, index=False, preserve_complex=True)
df_loaded = pd.read_csv(path, preserve_complex=True)
assert isinstance(
df_loaded["embedding"][0], np.ndarray
), "Test Failed: The CSV did not preserve embeddings as NumPy arrays!"
print("PASS: test_preserve_numpy_arrays_in_csv")
def test_preserve_numpy_arrays_in_csv_empty_dataframe():
print("\nRunning: test_preserve_numpy_arrays_in_csv_empty_dataframe")
df = pd.DataFrame({"embedding": []})
expected = "embedding\n"
with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
path = tmp.name
df.to_csv(path, index=False, preserve_complex=True)
with open(path, encoding="utf-8") as f:
result = f.read()
msg = (
f"CSV output mismatch for empty DataFrame.\n"
f"Got:\n{result}\nExpected:\n{expected}"
)
assert result == expected, msg
print("PASS: test_preserve_numpy_arrays_in_csv_empty_dataframe")
def test_preserve_numpy_arrays_in_csv_mixed_dtypes():
print("\nRunning: test_preserve_numpy_arrays_in_csv_mixed_dtypes")
df = pd.DataFrame({
"id": [101, 102],
"name": ["alice", "bob"],
"scores": [
np.array([95.5, 88.0]),
np.array([76.0, 90.5]),
],
"age": [25, 30],
})
with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
path = tmp.name
df.to_csv(path, index=False, preserve_complex=True)
df_loaded = pd.read_csv(path, preserve_complex=True)
err_scores = "Failed: 'scores' column not deserialized as np.ndarray."
assert isinstance(df_loaded["scores"][0], np.ndarray), err_scores
assert df_loaded["id"].dtype == np.int64, (
"Failed: 'id' should still be int."
)
assert df_loaded["name"].dtype == object, (
"Failed: 'name' should still be object/string."
)
assert df_loaded["age"].dtype == np.int64, (
"Failed: 'age' should still be int."
)
print("PASS: test_preserve_numpy_arrays_in_csv_mixed_dtypes")
if __name__ == "__main__":
    # Run each test in declaration order when executed as a plain script,
    # then print a closing marker.
    for _run_test in (
        test_preserve_numpy_arrays_in_csv,
        test_preserve_numpy_arrays_in_csv_empty_dataframe,
        test_preserve_numpy_arrays_in_csv_mixed_dtypes,
    ):
        _run_test()
    print("\nDone.")