-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
70 lines (51 loc) · 1.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch as th
import torch.distributed as dist
import argparse
import time
import sys
import os
from grid import condition
from multigrid import MultiGrid, FullMultiGrid
from input.func import origin_func, target_func
if __name__ == "__main__":
    # Entry point: solve a 2-D problem with a (possibly distributed)
    # multigrid method and print a corner of the resulting solution.
    parser = argparse.ArgumentParser()
    parser.add_argument('--n', '-n', type=int, default=16,
                        help='global grid size (split evenly across processes)')
    parser.add_argument('--process', '-p', type=int, default=1,
                        help='processes per side of the p x p process grid')
    parser.add_argument('--device', type=str, default='cpu',
                        help='torch device string, e.g. "cpu" or "cuda"')
    args = parser.parse_args()

    # Each rank owns an (n/p) x (n/p) patch of the global grid.
    n, p = args.n // args.process, args.process
    device = th.device(args.device)

    # LOCAL_RANK is set by the torchrun launcher; default to 0 for
    # plain single-process invocation.
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    # 2-D coordinates (x, y) of this rank inside the p x p process grid.
    index = th.tensor([rank % p, rank // p], device=device)

    if p > 1:
        # Distributed run: the launcher must have started exactly p*p
        # workers. Raise explicitly rather than assert — asserts are
        # stripped under `python -O`, and int(None) would be an opaque
        # TypeError when WORLD_SIZE is unset.
        world_size = os.environ.get("WORLD_SIZE")
        if world_size is None or int(world_size) != p ** 2:
            raise RuntimeError(
                f"expected WORLD_SIZE == p**2 == {p ** 2}, got {world_size!r}"
            )
        dist.init_process_group(backend='gloo')

    MG_method = MultiGrid(index, p, device)
    cond_method = condition(n, index, p, device)
    # Right-hand side / boundary data built from the problem functions.
    b = cond_method(origin_func, target_func)

    # Width of the local solution patch: n-1 in the serial case
    # (presumably boundary points excluded — TODO confirm against grid.py),
    # n per patch in the distributed decomposition.
    w = n - 1 if p == 1 else n
    u = th.zeros(w ** 2, device=device)

    # Fixed number of multigrid cycles.
    for _ in range(25):
        u = MG_method(u, b, n=n)

    if rank == 0:
        # Show the top-left 8x8 corner of rank 0's solution patch.
        print(u.view(w, w)[:8, :8])