Skip to content

Commit

Permalink
efficient memory useage for fermion operators
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Sep 28, 2024
1 parent ac6c743 commit c2592db
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 45 deletions.
1 change: 1 addition & 0 deletions lib/cgpt/lib/exports.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ EXPORT_FUNCTION(munge_byte_order)
EXPORT_FUNCTION(munge_reconstruct_third_row)
EXPORT_FUNCTION(create_fermion_operator)
EXPORT_FUNCTION(update_fermion_operator)
EXPORT_FUNCTION(set_mass_fermion_operator)
EXPORT_FUNCTION(delete_fermion_operator)
EXPORT_FUNCTION(apply_fermion_operator)
EXPORT_FUNCTION(apply_fermion_operator_dirdisp)
Expand Down
13 changes: 13 additions & 0 deletions lib/cgpt/lib/operators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ EXPORT(update_fermion_operator,{
return PyLong_FromLong(0);
});

EXPORT(set_mass_fermion_operator,{

void* p;
PyObject* _args;
if (!PyArg_ParseTuple(args, "lO", &p, &_args)) {
return NULL;
}

((cgpt_fermion_operator_base*)p)->set_mass(_args);

return PyLong_FromLong(0);
});

EXPORT(delete_fermion_operator,{

void* p;
Expand Down
1 change: 1 addition & 0 deletions lib/cgpt/lib/operators/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ class cgpt_fermion_operator_base {
virtual RealD dirdisp(int opcode, PyObject* in, PyObject* out, int dir, int disp) = 0;
virtual RealD deriv(int opcode, PyObject* mat, PyObject* in, PyObject* out) = 0;
virtual void update(PyObject* args) = 0;
virtual void set_mass(PyObject* args) = 0;
};
24 changes: 24 additions & 0 deletions lib/cgpt/lib/operators/implementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

template<typename WI> void cgpt_fermion_set_mass(WilsonFermion<WI>& op, PyObject* args) {
RealD mass = get_float(args,"mass");
op.mass = mass;
if (op.anisotropyCoeff.isAnisotropic){
op.diag_mass = op.mass + 1.0 + (Nd-1)*(op.anisotropyCoeff.nu / op.anisotropyCoeff.xi_0);
} else {
op.diag_mass = 4.0 + op.mass;
}
}

template<typename WI> void cgpt_fermion_set_mass(CayleyFermion5D<WI>& op, PyObject* args) {
RealD mass_plus = get_float(args,"mass_plus");
RealD mass_minus = get_float(args,"mass_minus");
op.SetMass(mass_plus, mass_minus);
}

template<typename T>
class cgpt_fermion_operator : public cgpt_fermion_operator_base {
public:
Expand Down Expand Up @@ -53,6 +70,10 @@ class cgpt_fermion_operator : public cgpt_fermion_operator_base {
op->ImportGauge(U);
}

virtual void set_mass(PyObject* args) {
cgpt_fermion_set_mass(*op, args);
}

};

template<typename T>
Expand Down Expand Up @@ -96,4 +117,7 @@ class cgpt_coarse_operator : public cgpt_fermion_operator_base {

op->ImportGauge(A, ASelfInv);
}

virtual void set_mass(PyObject* args) {
}
};
8 changes: 6 additions & 2 deletions lib/gpt/qcd/fermion/operator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class base(gpt.matrix_operator):
def __init__(self, name, U, params, otype, with_even_odd, daggered):
# keep constructor parameters
self.name = name
self.U = gpt.copy(U)
self.U = U
self.otype = otype
self.params_constructor = params
self.daggered = daggered
Expand Down Expand Up @@ -259,9 +259,13 @@ def adj(self):
)

def update(self, U):
gpt.copy(self.U, U)
self.U = U
self.params["U"] = [u.v_obj[0] for u in self.U]
self.interface.update(self.params)

def suspend(self):
self.interface.suspend()

def split(self, mpi_split):
split_grid = self.U_grid.split(mpi_split, self.U_grid.fdimensions)
U_split = [gpt.lattice(split_grid, x.otype) for x in self.U]
Expand Down
40 changes: 37 additions & 3 deletions lib/gpt/qcd/fermion/operator/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,46 +23,80 @@
operator_tag = {}
operator_limbo = {}

verbose = g.default.is_verbose("fermion-operator")


class interface:
def __init__(self):
self.obj = None

def setup(self, name, grid, params):
def _setup(self, name, grid, params):
assert self.obj is None

tag_params = {x: params[x] for x in params if x not in ["U"]}
tag_params = {
x: params[x] for x in params if x not in ["U", "mass", "mass_plus", "mass_minus"]
}
tag = f"{name}_{grid.precision.cgpt_dtype}_{tag_params}"

if tag in operator_limbo and len(operator_limbo[tag]) > 0:
self.obj = operator_limbo[tag].pop()
# set mass needs to precede update for clover-type fermions
cgpt.set_mass_fermion_operator(self.obj, params)
cgpt.update_fermion_operator(self.obj, params)

if verbose:
g.message(f"Re-used fermion operator {tag}")
else:
# create new operator
self.obj = cgpt.create_fermion_operator(name, grid.precision.cgpt_dtype, params)
operator_tag[self.obj] = tag

if verbose:
g.message("Status of allocated fermion operators:")
statistics = {}
for tag in operator_tag:
if operator_tag[tag] not in statistics:
statistics[operator_tag[tag]] = 1
else:
statistics[operator_tag[tag]] += 1
for tag in statistics:
g.message(f" {statistics[tag]} of type {tag}")

def setup(self, name, grid, params):
self.setup_arguments = (name, grid, params)
self._setup(*self.setup_arguments)

def __del__(self):
self.suspend()

def suspend(self):
if self.obj is not None:
tag = operator_tag[self.obj]
if tag not in operator_limbo:
operator_limbo[tag] = [self.obj]
else:
operator_limbo[tag].append(self.obj)

# cgpt.delete_fermion_operator(self.obj)
self.obj = None

def update(self, params):
if self.obj is None:
# wake from suspended state
self._setup(*self.setup_arguments)
cgpt.update_fermion_operator(self.obj, params)

def apply_unary_operator(self, opcode, o, i):
assert self.obj is not None
# Grid has different calling conventions which we adopt in cgpt:
return cgpt.apply_fermion_operator(self.obj, opcode, i.v_obj, o.v_obj)

def apply_dirdisp_operator(self, opcode, o, i, dir, disp):
assert self.obj is not None
# Grid has different calling conventions which we adopt in cgpt:
return cgpt.apply_fermion_operator_dirdisp(self.obj, opcode, i.v_obj, o.v_obj, dir, disp)

def apply_deriv_operator(self, opcode, m, u, v):
assert self.obj is not None
# Grid has different calling conventions which we adopt in cgpt:
return cgpt.apply_fermion_operator_deriv(
self.obj, opcode, [y for x in m for y in x.v_obj], u.v_obj, v.v_obj
Expand Down
6 changes: 6 additions & 0 deletions lib/gpt/qcd/pseudofermion/action/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, M, inverter, operator):
self.M = g.core.util.to_list(M)
self.inverter = inverter
self.operator = operator
self._suspend()

def _updated(self, fields):
U = fields[0:-1]
Expand All @@ -35,6 +36,11 @@ def _updated(self, fields):
m.update(U)
return self.M + [U, psi]

def _suspend(self):
for m in self.M:
# suspend operator until it is reactivated by update
m.suspend()

def _allocate_force(self, U):
frc = g.group.cartesian(U)
for f in frc:
Expand Down
78 changes: 46 additions & 32 deletions lib/gpt/qcd/pseudofermion/action/exact_one_flavor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ def _mat(dst, src):

def matrix_spectral_range(self, fields, algorithm):
evec, evals = algorithm(self.matrix(fields), fields[-1])
self._suspend()
return (min(evals.real), max(evals.real))

def __call__(self, fields):
phi = fields[-1]
M = self.matrix(fields)
psi = g(M * phi)
self._suspend()
return g.inner_product(phi, psi).real

def inv_sqrt_matrix(self, fields, rational_function):
Expand All @@ -85,18 +87,30 @@ def inv_sqrt_matrix(self, fields, rational_function):
m1 = self.m1
m2 = self.m2

P11 = [
self.operator(m1, m1 + d * (m2 - m1)).updated(U).propagator(self.inverter) for d in dm
]
P12 = [
self.operator(m1 + d * (m2 - m1), m2).updated(U).propagator(self.inverter) for d in dm
]
operator = self.operator
inverter = self.inverter

def _mat(dst, src):
dst @= diag * src
for i in range(len(cs)):
dst += cs[i] * (m2 - m1) * Pplus * g.gamma[5] * P12[i] * Pplus * src
dst -= cs[i] * (m2 - m1) * Pminus * g.gamma[5] * P11[i] * Pminus * src

op = operator(m1 + dm[i] * (m2 - m1), m2)
op.update(U)
P12_i = op.propagator(inverter)

dst += cs[i] * (m2 - m1) * Pplus * g.gamma[5] * P12_i * Pplus * src

P12_i = None

op = operator(m1, m1 + dm[i] * (m2 - m1))
op.update(U)
P11_i = op.propagator(inverter)

dst -= cs[i] * (m2 - m1) * Pminus * g.gamma[5] * P11_i * Pminus * src

P11_i = None
op = None

dst *= rational_function.norm

return g.matrix_operator(_mat)
Expand All @@ -118,48 +132,47 @@ def gradient(self, fields, dfields):

frc = self._allocate_force(U)

g.barrier()
g.message("checkmark 0")
g.barrier()
# g.barrier()
# g.message("checkmark 0")
# g.barrier()

inv_M12_adj = self.inverter(M12_adj)

g.barrier()
g.message("checkmark 1")
g.barrier()
# g.barrier()
# g.message("checkmark 1")
# g.barrier()

inv_M11_adj = self.inverter(M11_adj)

g.barrier()
g.message("checkmark 2")
g.barrier()

# g.barrier()
# g.message("checkmark 2")
# g.barrier()

m1 = self.m1
m2 = self.m2

w_plus = g(inv_M12_adj * M12.R * g.gamma[5] * M12.ImportUnphysicalFermion * Pplus * phi)

g.barrier()
g.message("checkmark 3")
g.barrier()
# g.barrier()
# g.message("checkmark 3")
# g.barrier()

w_minus = g(inv_M11_adj * M11.R * g.gamma[5] * M11.ImportUnphysicalFermion * Pminus * phi)

g.barrier()
g.message("checkmark 4")
g.barrier()
# g.barrier()
# g.message("checkmark 4")
# g.barrier()

w2_plus = g(g.gamma[5] * M12.R * M12.Dminus.adj() * w_plus)
w3_plus = g(Pplus * phi)

w2_minus = g(g.gamma[5] * M11.R * M11.Dminus.adj() * w_minus)
w3_minus = g(Pminus * phi)

g.barrier()
g.message("checkmark 5")
g.barrier()
# g.barrier()
# g.message("checkmark 5")
# g.barrier()

self._accumulate(frc, M12.M_projected_gradient(w_plus, w2_plus), m1 - m2)
self._accumulate(
frc,
Expand All @@ -182,4 +195,5 @@ def gradient(self, fields, dfields):
# not yet implemented
dS.append(None)

self._suspend()
return dS
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def MderivDag(op):
# M = ( OE OO ) = ( 0 1 ) ( 0 1 ) ( OE OO )
# Mhat = EE - EO OO^-1 OE
class MMdag_evenodd:
def M(self, op):
def M(op):
tmp = op.Mooee.vector_space[0].lattice()

def operator(dst, src):
Expand All @@ -57,7 +57,7 @@ def operator(dst, src):

return g.matrix_operator(mat=operator, vector_space=op.Mooee.vector_space)

def Mdag(self, op):
def Mdag(op):
tmp = op.Mooee.vector_space[0].lattice()

def operator(dst, src):
Expand All @@ -69,7 +69,7 @@ def operator(dst, src):

return g.matrix_operator(mat=operator, vector_space=op.Mooee.vector_space)

def MMdag(self, op):
def MMdag(op):
def spawn(op):
tmp = [op.Mooee.vector_space[0].lattice() for _ in [0, 1]]

Expand All @@ -94,7 +94,7 @@ def operator(dst, src):

return spawn(op)

def Mderiv(self, op):
def Mderiv(op):
tmp = [op.Mooee.vector_space[0].lattice() for _ in [0, 1]]

def operator(left, right):
Expand All @@ -115,7 +115,7 @@ def operator(left, right):

return operator

def MderivDag(self, op):
def MderivDag(op):
tmp = [op.Mooee.vector_space[0].lattice() for _ in [0, 1]]

def operator(left, right):
Expand Down
Loading

0 comments on commit c2592db

Please sign in to comment.