Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

零散问题修复 #511

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion paconvert/api_alias_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
"torch.bilinear": "torch.nn.functional.bilinear",
"torch.celu_": "torch.nn.functional.celu_",
"torch.channel_shuffle": "torch.nn.functional.channel_shuffle",
"torch.concatenate": "torch.cat",
"torch.clip": "torch.clamp",
"torch.concatenate": "torch.cat",
"torch.conv1d": "torch.nn.functional.conv1d",
"torch.conv2d": "torch.nn.functional.conv2d",
"torch.conv3d": "torch.nn.functional.conv3d",
Expand Down
63 changes: 33 additions & 30 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"num_embeddings",
"embedding_dim",
"padding_idx",
"max_norm",
"norm_type",
"scale_grad_by_freq",
"sparse",
Expand All @@ -72,7 +73,8 @@
"padding_idx",
"norm_type",
"scale_grad_by_freq",
"sparse"
"sparse",
"max_norm"
]
},
"fairscale.nn.model_parallel.layers.RowParallelLinear": {
Expand Down Expand Up @@ -125,14 +127,14 @@
"unsupport_args": [
"window_size",
"alibi_slopes",
"deterministic",
"return_attn_probs"
"deterministic"
],
"kwargs_change": {
"q": "query",
"k": "key",
"v": "value",
"dropout_p": "dropout"
"dropout_p": "dropout",
"return_attn_probs": "return_softmax"
}
},
"flash_attn.flash_attn_interface.flash_attn_unpadded_func": {
Expand All @@ -157,15 +159,15 @@
"unsupport_args": [
"window_size",
"alibi_slopes",
"deterministic",
"return_attn_probs"
"deterministic"
],
"kwargs_change": {
"q": "query",
"k": "key",
"v": "value",
"dropout_p": "dropout",
"softmax_scale": "scale"
"softmax_scale": "scale",
"return_attn_probs": "return_softmax"
}
},
"flash_attn.layers.rotary.apply_rotary_emb_func": {
Expand All @@ -189,7 +191,10 @@
"x",
"weight",
"epsilon"
]
],
"kwargs_change": {
"weight": "norm_weight"
}
},
"os.environ.get": {
"Matcher": "OsEnvironGetMatcher",
Expand Down Expand Up @@ -2172,6 +2177,7 @@
},
"torch.Tensor.is_inference": {
"Matcher": "Is_InferenceMatcher",
"paddle_api": "paddle.Tensor.stop_gradient",
"min_input_args": 0
},
"torch.Tensor.is_pinned": {
Expand Down Expand Up @@ -3192,7 +3198,7 @@
]
},
"torch.Tensor.positive": {
"Matcher": "PositiveMatcher"
"Matcher": "PositiveMatcher"
},
"torch.Tensor.pow": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -3571,7 +3577,7 @@
}
},
"torch.Tensor.scatter_reduce": {
"Matcher": "ScatterReduceMatcher",
"Matcher": "ScatterReduceMatcher",
"paddle_api": "paddle.Tensor.put_along_axis",
"min_input_args": 3,
"args_list": [
Expand Down Expand Up @@ -4024,7 +4030,8 @@
]
},
"torch.Tensor.to_sparse_coo": {
"Matcher": "TensorToSparseCooMatcher"
"Matcher": "TensorToSparseCooMatcher",
"paddle_api": "paddle.Tensor.to_sparse_coo"
},
"torch.Tensor.tolist": {
"Matcher": "UnchangeMatcher",
Expand Down Expand Up @@ -7836,6 +7843,16 @@
"axis": 0
}
},
"torch.float_power": {
"Matcher": "FloatPowerMatcher",
"min_input_args": 2,
"args_list": [
"input",
"exponent",
"*",
"out"
]
},
"torch.floor": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.floor",
Expand All @@ -7849,16 +7866,6 @@
"input": "x"
}
},
"torch.float_power": {
"Matcher": "FloatPowerMatcher",
"min_input_args": 2,
"args_list": [
"input",
"exponent",
"*",
"out"
]
},
"torch.floor_divide": {
"Matcher": "Num2TensorBinaryMatcher",
"paddle_api": "paddle.floor_divide",
Expand Down Expand Up @@ -8518,6 +8525,7 @@
},
"torch.is_inference": {
"Matcher": "Is_InferenceMatcher",
"paddle_api": "paddle.Tensor.stop_gradient",
"min_input_args": 1,
"args_list": [
"input"
Expand Down Expand Up @@ -10723,7 +10731,7 @@
"min_input_args": 0
},
"torch.nn.GRUCell": {
"Matcher": "GenericMatcher",
"Matcher": "GRUCellMatcher",
"paddle_api": "paddle.nn.GRUCell",
"args_list": [
"input_size",
Expand Down Expand Up @@ -10967,7 +10975,7 @@
"min_input_args": 0
},
"torch.nn.LSTMCell": {
"Matcher": "GenericMatcher",
"Matcher": "LSTMCellMatcher",
"paddle_api": "paddle.nn.LSTMCell",
"args_list": [
"input_size",
Expand Down Expand Up @@ -11731,8 +11739,8 @@
"min_input_args": 3
},
"torch.nn.RNNCell": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.SimpleRNNCell",
"Matcher": "RNNCellMatcher",
"paddle_api": "paddle.nn.SimpleRNNCell",
"args_list": [
"input_size",
"hidden_size",
Expand Down Expand Up @@ -14867,11 +14875,6 @@
"dims": "perm"
}
},
"torch.pi": {
"Matcher": "GenericMatcher",
"paddle_api": "numpy.pi",
"min_input_args": 0
},
"torch.pinverse": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.linalg.pinv",
Expand Down
91 changes: 80 additions & 11 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3238,19 +3238,26 @@ def generate_code(self, kwargs):

class AvgPoolMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" in kwargs:
kwargs["x"] = kwargs.pop("input")
new_kwargs = {}
for k in list(kwargs.keys()):
if k == "input":
new_kwargs["x"] = kwargs.pop(k)
else:
new_kwargs[k] = kwargs.pop(k)

if "count_include_pad" in kwargs:
kwargs["exclusive"] = "not " + kwargs.pop("count_include_pad")
if "count_include_pad" in new_kwargs:
new_kwargs["exclusive"] = "not " + new_kwargs.pop("count_include_pad")
else:
kwargs["exclusive"] = "False"
new_kwargs["exclusive"] = "False"

API_TEMPLATE = textwrap.dedent(
"""
{}({})
"""
)
code = API_TEMPLATE.format(self.get_paddle_api(), self.kwargs_to_str(kwargs))
code = API_TEMPLATE.format(
self.get_paddle_api(), self.kwargs_to_str(new_kwargs)
)

return code

Expand Down Expand Up @@ -4358,13 +4365,18 @@ def generate_code(self, kwargs):
class ConstantLRMatcher(LRSchedulerMatcher):
def generate_code(self, kwargs):
optim = kwargs["optimizer"]
factor = 0.3333333333333333
total_iters = 5
if "factor" in kwargs:
factor = kwargs.pop("factor")
kwargs["values"] = "[{}*{}.get_lr(), {}.get_lr()]".format(
factor, optim, optim
)
else:
kwargs["values"] = "[{}.get_lr()/3, {}.get_lr()]".format(optim, optim)

if "total_iters" in kwargs:
total_iters = kwargs.pop("total_iters")
kwargs["values"] = "[{}*{}.get_lr(), {}.get_lr()]".format(factor, optim, optim)

kwargs["boundaries"] = "[{}]".format(total_iters)
return super().generate_code(kwargs)

Expand Down Expand Up @@ -4847,7 +4859,7 @@ def generate_code(self, kwargs):
class PositiveMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
"""
def positive(x):
if x.dtype != paddle.bool:
return x
Expand Down Expand Up @@ -5204,7 +5216,7 @@ def get_scalable_var(self):
if not (arg_name.startswith("*") and len(arg_name) > 1):
return None
return arg_name[1:]

def get_paddle_nodes(self, args, kwargs):
var_arg_name = self.get_scalable_var()
dest_var_arg_name = self.api_mapping.get("kwargs_change", {}).get(
Expand All @@ -5222,7 +5234,7 @@ def get_paddle_nodes(self, args, kwargs):
return ast.parse(code).body


class ScalableVarMatcher(BaseMatcher):
class ScalableVarMatcher(BaseMatcher):
def get_scalable_var(self):
args_list = self.api_mapping.get("args_list", [])
if len(args_list) != 1:
Expand Down Expand Up @@ -5779,3 +5791,60 @@ def generate_code(self, kwargs):
code = API_TEMPLATE.format(kwargs["device"])

return code


class GRUCellMatcher(BaseMatcher):
    """Convert ``torch.nn.GRUCell`` usages to a ``paddle_aux`` wrapper.

    Paddle's ``GRUCell.forward`` returns a tuple; the auxiliary subclass
    indexes element ``[0]`` so each call yields a single tensor, matching
    the torch cell's calling convention.
    """

    def get_paddle_api(self):
        # Generated code calls the wrapper emitted into paddle_aux,
        # not paddle.nn.GRUCell directly.
        return "paddle_aux.GRUCell"

    def generate_aux_code(self):
        # Helper class definition written once into the paddle_aux module.
        return textwrap.dedent(
            """
            class GRUCell(paddle.nn.GRUCell):
                def forward(self, inputs, states = None):
                    return super().forward(inputs, states)[0]
            """
        )

    def generate_code(self, kwargs):
        # Make sure the auxiliary wrapper exists, then delegate the
        # ordinary keyword translation to the generic matcher.
        self.write_aux_code()
        return GenericMatcher.generate_code(self, kwargs)


class LSTMCellMatcher(BaseMatcher):
    """Convert ``torch.nn.LSTMCell`` usages to a ``paddle_aux`` wrapper.

    Paddle's ``LSTMCell.forward`` returns a tuple; the auxiliary subclass
    indexes element ``[1]`` (the state pair) so each call matches the
    torch cell's calling convention.
    """

    def get_paddle_api(self):
        # Generated code calls the wrapper emitted into paddle_aux,
        # not paddle.nn.LSTMCell directly.
        return "paddle_aux.LSTMCell"

    def generate_aux_code(self):
        # Helper class definition written once into the paddle_aux module.
        return textwrap.dedent(
            """
            class LSTMCell(paddle.nn.LSTMCell):
                def forward(self, inputs, states = None):
                    return super().forward(inputs, states)[1]
            """
        )

    def generate_code(self, kwargs):
        # Make sure the auxiliary wrapper exists, then delegate the
        # ordinary keyword translation to the generic matcher.
        self.write_aux_code()
        return GenericMatcher.generate_code(self, kwargs)


class RNNCellMatcher(BaseMatcher):
    """Convert ``torch.nn.RNNCell`` usages to a ``paddle_aux`` wrapper.

    Paddle's ``SimpleRNNCell.forward`` returns a tuple; the auxiliary
    subclass indexes element ``[0]`` so each call yields a single tensor,
    matching the torch cell's calling convention.
    """

    def get_paddle_api(self):
        # Generated code calls the wrapper emitted into paddle_aux,
        # not paddle.nn.SimpleRNNCell directly.
        return "paddle_aux.SimpleRNNCell"

    def generate_aux_code(self):
        # Helper class definition written once into the paddle_aux module.
        return textwrap.dedent(
            """
            class SimpleRNNCell(paddle.nn.SimpleRNNCell):
                def forward(self, inputs, states = None):
                    return super().forward(inputs, states)[0]
            """
        )

    def generate_code(self, kwargs):
        # Make sure the auxiliary wrapper exists, then delegate the
        # ordinary keyword translation to the generic matcher.
        self.write_aux_code()
        return GenericMatcher.generate_code(self, kwargs)
4 changes: 4 additions & 0 deletions paconvert/attribute_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@
"torch.optim.Optimizer.state_dict": {
"Matcher": "UnchangeMatcher"
},
"torch.pi": {
"Matcher": "GenericMatcher",
"paddle_api": "numpy.pi"
},
"torch.short": {
"Matcher": "GenericMatcher",
"paddle_api": "'int16'"
Expand Down
1 change: 1 addition & 0 deletions tests/test_Tensor_scatter_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def test_case_7():
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_block_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_case_2():
A = torch.tensor([[4], [3], [2]])
B = torch.tensor([7, 6, 5])
C = torch.tensor(1)
result = torch.block_diag(torch.tensor([[4], [3], [2]]),
torch.tensor([7, 6, 5]),
result = torch.block_diag(torch.tensor([[4], [3], [2]]),
torch.tensor([7, 6, 5]),
torch.tensor(1))
"""
)
Expand Down
Loading