Merge pull request #424 from DeepLink-org/tzy/add_dropout
feat: add dropout for ascend
jinminxi104 authored Nov 15, 2023
2 parents 5a4a549 + 9cc9bc4 commit 5fcb3da
Showing 4 changed files with 52 additions and 2 deletions.
10 changes: 10 additions & 0 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -421,6 +421,16 @@ def __init__(self):
super().__init__("Reciprocal")


class DropOutGenMaskV4(Operator):
def __init__(self):
super().__init__("DropOutGenMaskV4")


class DropOutDoMaskV3(Operator):
def __init__(self):
super().__init__("DropOutDoMaskV3")


def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return a, b, c

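The two classes above only bind the Ascend operator names; instances are created through dicp's graph-proxy mechanism by the conversion pass added in conversion.py below. A minimal sketch of that call pattern (op names taken from this diff; the surrounding Operator/get_proxy machinery belongs to the existing dicp code and is assumed here, not shown):

# Sketch only: how a conversion method references the new operator classes.
shape = self.get_proxy(ascend_op.Shape, (x,))                     # dynamic shape of the input
prob = self.get_proxy(ascend_op.Const, ([0.9], torch.float, []))  # keep probability as a scalar const
mask = self.get_proxy(ascend_op.DropOutGenMaskV4, (shape, prob))  # random keep/drop mask
out = self.get_proxy(ascend_op.DropOutDoMaskV3, (x, mask, prob))  # masked, rescaled output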
15 changes: 15 additions & 0 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -1307,3 +1307,18 @@ def Reciprocal(name, x):
op = OP(name, "Reciprocal")
op.set_input("x", x)
return op.to_node()

@staticmethod
def DropOutGenMaskV4(name, shape, prob):
op = OP(name, "DropOutGenMaskV4")
op.set_input("shape", shape)
op.set_input("prob", prob)
return op.to_node()

@staticmethod
def DropOutDoMaskV3(name, x, mask, keep_prob):
op = OP(name, "DropOutDoMaskV3")
op.set_input("x", x)
op.set_input("mask", mask)
op.set_input("keep_prob", keep_prob)
return op.to_node()
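
Functionally, DropOutGenMaskV4 generates a random keep/drop mask for the given shape and keep probability, and DropOutDoMaskV3 applies that mask to x and rescales the surviving elements by 1/keep_prob. A rough PyTorch reference of the combined behaviour (an illustration of the expected semantics, not the CANN kernels themselves):

import torch

def dropout_reference(x: torch.Tensor, keep_prob: float):
    # Analogue of DropOutGenMaskV4: Bernoulli(keep_prob) mask over x's shape.
    mask = torch.bernoulli(torch.full_like(x, keep_prob)).bool()
    # Analogue of DropOutDoMaskV3: zero the dropped elements and rescale the
    # rest so the expected value of the output matches the input.
    return x * mask / keep_prob, mask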
3 changes: 1 addition & 2 deletions dicp/dicp/vendor/AscendGraph/config.py
@@ -4,8 +4,7 @@


aten = torch.ops.aten
-decomp_keys = [aten.native_dropout.default,
-               aten.native_dropout_backward.default]
+decomp_keys = []


def get_decomp():
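
Emptying decomp_keys means aten.native_dropout and aten.native_dropout_backward are no longer decomposed into smaller aten ops before backend lowering, so the new conversions in conversion.py below see the dropout calls intact and can map them straight to the Ascend operators. For illustration only (not code from this repo), torch._decomp shows the difference a key list makes:

import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten

# With the old key list, a decomposition for native_dropout is looked up and
# would be applied before the backend ever saw the op (holds on recent PyTorch).
table = get_decompositions([aten.native_dropout.default])
assert aten.native_dropout.default in table

# With decomp_keys = [], the table is empty and native_dropout is left for the
# NativeDropout conversion added in this commit.
assert len(get_decompositions([])) == 0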
26 changes: 26 additions & 0 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -1044,3 +1044,29 @@ def AddCMul(self, a, b, c, value=1):
    @register_conversion(torch.ops.aten.reciprocal.default)
    def Reciprocal(self, x):
        return self.get_proxy(ascend_op.Reciprocal, (x,))

    @register_conversion(torch.ops.aten.native_dropout.default)
    def NativeDropout(self, x, p, train):
        assert train is True
        dtype = x.node.meta['val'].dtype
        p = 1. - p
        shape = self.get_proxy(ascend_op.Shape, (x,))
        prob = self.get_proxy(ascend_op.Const, ([float(p)], torch.float, []))
        mask = self.get_proxy(ascend_op.DropOutGenMaskV4, (shape, prob))
        prob_op = prob
        if dtype == torch.float16:
            cast = self.get_proxy(ascend_op.Cast, (prob, "FLOAT16"))
            prob_op = cast
        do_mask = self.get_proxy(ascend_op.DropOutDoMaskV3, (x, mask, prob_op))
        return self.get_proxy(ascend_op.IdentityN, (do_mask, mask))

    @register_conversion(torch.ops.aten.native_dropout_backward.default)
    def NativeDropoutBackward(self, grad_output, mask, scale):
        dtype = grad_output.node.meta['val'].dtype
        p = 1. - scale
        prob = self.get_proxy(ascend_op.Const, ([float(p)], torch.float, []))
        prob_op = prob
        if dtype == torch.float16:
            cast = self.get_proxy(ascend_op.Cast, (prob, "FLOAT16"))
            prob_op = cast
        return self.get_proxy(ascend_op.DropOutDoMaskV3, (grad_output, mask, prob_op))
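
The forward conversion lowers aten.native_dropout(x, p, train) to Shape -> Const(keep_prob = 1 - p) -> DropOutGenMaskV4 -> DropOutDoMaskV3 and returns the (output, mask) pair through IdentityN, casting the keep-prob constant to FLOAT16 when the input is half precision; the backward conversion reuses DropOutDoMaskV3 to apply the saved mask to grad_output. For reference, this is the aten op the conversions intercept and the scaling behaviour it is expected to reproduce (plain PyTorch, not dicp code):

import torch

x = torch.randn(4, 8)
p = 0.1                                                  # dropout probability
out, mask = torch.ops.aten.native_dropout(x, p, True)    # train=True

# Kept elements are rescaled by 1 / keep_prob; dropped elements become zero.
keep_prob = 1.0 - p
assert torch.allclose(out, x * mask / keep_prob)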
