Skip to content

Commit

Permalink
Merge pull request #426 from ironjr/master
Browse files Browse the repository at this point in the history
Fixed device mismatch error of Symbolic_KANLayer
  • Loading branch information
KindXiaoming authored Aug 28, 2024
2 parents c0d9981 + e29c384 commit 173dadd
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions kan/Symbolic_KANLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-1
self.funs[j][i] = fun
self.funs_avoid_singularity[j][i] = fun_avoid_singularity
if random == False:
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.])
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
else:
self.affine.data[j][i] = torch.rand(4,) * 2 - 1
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
return None
else:
#initialize from x & y and fun
Expand All @@ -237,9 +237,9 @@ def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-1
self.funs[j][i] = fun
self.funs_avoid_singularity[j][i] = fun
if random == False:
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.])
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
else:
self.affine.data[j][i] = torch.rand(4,) * 2 - 1
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
return None

def swap(self, i1, i2, mode='in'):
Expand Down

0 comments on commit 173dadd

Please sign in to comment.