Skip to content

Commit

Permalink
Merge pull request #21 from kb1dds/kernel_proj
Browse files Browse the repository at this point in the history
Kernel projection for Sheaf.minimalExtend()
  • Loading branch information
kb1dds authored Jun 19, 2018
2 parents 93cf55e + f4fc242 commit d4771f3
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion pysheaf/pysheaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,8 +872,45 @@ def minimalExtend(self,assignment, activeCells=None, testSupport=None, method='n
ac=[idx for idx in range(len(self.cells)) if idx not in support]
else:
ac=activeCells

if method == 'KernelProj':
if not self.isLinear():
raise NotImplementedError('KernelProj only works for sheaves of vector spaces')

if ord != 2:
warn('Kernel projection requires order 2 in minimalExtend')

# Compile dictionary of rows
rowstarts=dict()
rowidx=0
for i in support:
rowstarts[i]=rowidx
rowidx+=self.cells[i].stalkDim

newassignment = Section([sc for sc in assignment.sectionCells])

# Optimize each active cell independently
for i in ac:
if self.cells[i].stalkDim > 0:
# Matrix of all restrictions out of this cell into the support
mat=np.zeros((sum([self.cells[j].stalkDim for j in support]),
self.cells[i].stalkDim))

for cf in self.cofaces(i): # Iterate over all cofaces of this activeCell
try:
supportidx=support.index(cf.index)
mat[rowstarts[supportidx]:rowstarts[supportidx]+self.cells[supportidx].stalkDim,:]=cf.restriction.matrix
except ValueError:
pass

if self.isNumeric():
# Use least squares to solve for assignment rooted at this cell given the existing assignment
asg,bnds=self.serializeAssignment(assignment,activeCells=support) # Confusingly, activeSupport here refers *only* to the support of the assignment
result=np.linalg.lstsq(mat,asg)

newassignment.sectionCells.append(SectionCell(i,result[0]))

return newassignment
elif self.isNumeric():
initial_guess, bounds = self.serializeAssignment(assignment,ac)
res=scipy.optimize.minimize( fun = lambda sec: self.consistencyRadius(self.deserializeAssignment(sec,ac,assignment), testSupport=testSupport, ord=ord),
x0 = initial_guess,
Expand Down

0 comments on commit d4771f3

Please sign in to comment.