Skip to content

Commit

Permalink
add gradient API placeholder for SF solver
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenjia-xu committed Dec 27, 2024
1 parent 2fcdab1 commit 42d82e5
Showing 1 changed file with 51 additions and 26 deletions.
77 changes: 51 additions & 26 deletions genesis/engine/solvers/sf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class SFSolver(Solver):
Stable Fluid solver for eulerian-based gaseous simulation.
"""

# ------------------------------------------------------------------------------------
# --------------------------------- Initialization -----------------------------------
# ------------------------------------------------------------------------------------

def __init__(self, scene, sim, options):
super().__init__(scene, sim, options)

Expand All @@ -37,7 +41,6 @@ def set_jets(self, jets):
self.jets = jets

def build(self):

if self.is_active():
self.t = 0.0
self.setup_fields()
Expand Down Expand Up @@ -69,25 +72,6 @@ def init_fields(self):
for q in ti.static(range(self.grid.q.n)):
self.grid.q[i, j, k][q] = 0.0

def substep_pre_coupling(self, f):
self.advect_and_impulse(f, self.t)
self.divergence()

# projection
self.reset_swap()
self.pressure_to_swap()
for _ in range(self.solver_iters):
self.pressure_jacobi(self.p_swap.cur, self.p_swap.nxt)
self.p_swap.swap()
self.pressure_from_swap()
self.reset_swap()

self.subtract_gradient()
self.t += self.dt

def substep_post_coupling(self, f):
return

def reset_swap(self):
self.p_swap.cur.fill(0)
self.p_swap.nxt.fill(0)
Expand Down Expand Up @@ -166,7 +150,6 @@ def pressure_from_swap(self):
@ti.kernel
def subtract_gradient(self):
for i, j, k in ti.ndrange(*self.res):

pl = self.grid.p[self.compute_location(i, j, k, -1, 0, 0)]
pr = self.grid.p[self.compute_location(i, j, k, 1, 0, 0)]
pb = self.grid.p[self.compute_location(i, j, k, 0, -1, 0)]
Expand Down Expand Up @@ -252,21 +235,63 @@ def backtrace(self, vf, p, dt):
p -= dt * ((2 / 9) * v1 + (1 / 3) * v2 + (4 / 9) * v3)
return p

def get_state(self, f):
return None
# ------------------------------------------------------------------------------------
# ------------------------------------ stepping --------------------------------------
# ------------------------------------------------------------------------------------

def process_input(self, in_backward):
return None

def save_ckpt(self, f):
def substep_pre_coupling(self, f):
self.advect_and_impulse(f, self.t)
self.divergence()

# projection
self.reset_swap()
self.pressure_to_swap()
for _ in range(self.solver_iters):
self.pressure_jacobi(self.p_swap.cur, self.p_swap.nxt)
self.p_swap.swap()
self.pressure_from_swap()
self.reset_swap()

self.subtract_gradient()
self.t += self.dt

def substep_post_coupling(self, f):
return

def reset_grad(self):
return None

def set_state(self, f, state):
# ------------------------------------------------------------------------------------
# --------------------------------------- io -----------------------------------------
# ------------------------------------------------------------------------------------

def get_state(self, f):
return None

def reset_grad(self):
def set_state(self, f, state):
return None

# ------------------------------------------------------------------------------------
# ------------------------------------ gradient --------------------------------------
# ------------------------------------------------------------------------------------
def collect_output_grads(self):
"""
Collect gradients from downstream queried states.
"""
pass

def add_grad_from_state(self, state):
pass

def save_ckpt(self, ckpt_name):
pass

def load_ckpt(self, ckpt_name):
pass


class TexPair:
def __init__(self, cur, nxt):
Expand Down

0 comments on commit 42d82e5

Please sign in to comment.