Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

零散问题修复 #511

Merged
merged 20 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paconvert/api_alias_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
"torch.bilinear": "torch.nn.functional.bilinear",
"torch.celu_": "torch.nn.functional.celu_",
"torch.channel_shuffle": "torch.nn.functional.channel_shuffle",
"torch.concatenate": "torch.cat",
"torch.clip": "torch.clamp",
"torch.concatenate": "torch.cat",
"torch.conv1d": "torch.nn.functional.conv1d",
"torch.conv2d": "torch.nn.functional.conv2d",
"torch.conv3d": "torch.nn.functional.conv3d",
Expand Down
67 changes: 35 additions & 32 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"num_embeddings",
"embedding_dim",
"padding_idx",
"max_norm",
"norm_type",
"scale_grad_by_freq",
"sparse",
Expand All @@ -72,7 +73,8 @@
"padding_idx",
"norm_type",
"scale_grad_by_freq",
"sparse"
"sparse",
"max_norm"
]
},
"fairscale.nn.model_parallel.layers.RowParallelLinear": {
Expand Down Expand Up @@ -125,14 +127,14 @@
"unsupport_args": [
"window_size",
"alibi_slopes",
"deterministic",
"return_attn_probs"
"deterministic"
],
"kwargs_change": {
"q": "query",
"k": "key",
"v": "value",
"dropout_p": "dropout"
"dropout_p": "dropout",
"return_attn_probs": "return_softmax"
}
},
"flash_attn.flash_attn_interface.flash_attn_unpadded_func": {
Expand All @@ -157,15 +159,15 @@
"unsupport_args": [
"window_size",
"alibi_slopes",
"deterministic",
"return_attn_probs"
"deterministic"
],
"kwargs_change": {
"q": "query",
"k": "key",
"v": "value",
"dropout_p": "dropout",
"softmax_scale": "scale"
"softmax_scale": "scale",
"return_attn_probs": "return_softmax"
}
},
"flash_attn.layers.rotary.apply_rotary_emb_func": {
Expand All @@ -189,7 +191,10 @@
"x",
"weight",
"epsilon"
]
],
"kwargs_change": {
"weight": "norm_weight"
}
},
"os.environ.get": {
"Matcher": "OsEnvironGetMatcher",
Expand Down Expand Up @@ -2172,6 +2177,7 @@
},
"torch.Tensor.is_inference": {
"Matcher": "Is_InferenceMatcher",
"paddle_api": "paddle.Tensor.stop_gradient",
"min_input_args": 0
},
"torch.Tensor.is_pinned": {
Expand Down Expand Up @@ -3192,7 +3198,7 @@
]
},
"torch.Tensor.positive": {
"Matcher": "PositiveMatcher"
"Matcher": "PositiveMatcher"
},
"torch.Tensor.pow": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -3571,7 +3577,7 @@
}
},
"torch.Tensor.scatter_reduce": {
"Matcher": "ScatterReduceMatcher",
"Matcher": "ScatterReduceMatcher",
"paddle_api": "paddle.Tensor.put_along_axis",
"min_input_args": 3,
"args_list": [
Expand Down Expand Up @@ -4024,7 +4030,8 @@
]
},
"torch.Tensor.to_sparse_coo": {
"Matcher": "TensorToSparseCooMatcher"
"Matcher": "TensorToSparseCooMatcher",
"paddle_api": "paddle.Tensor.to_sparse_coo"
},
"torch.Tensor.tolist": {
"Matcher": "UnchangeMatcher",
Expand Down Expand Up @@ -7836,6 +7843,16 @@
"axis": 0
}
},
"torch.float_power": {
"Matcher": "FloatPowerMatcher",
"min_input_args": 2,
"args_list": [
"input",
"exponent",
"*",
"out"
]
},
"torch.floor": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.floor",
Expand All @@ -7849,16 +7866,6 @@
"input": "x"
}
},
"torch.float_power": {
"Matcher": "FloatPowerMatcher",
"min_input_args": 2,
"args_list": [
"input",
"exponent",
"*",
"out"
]
},
"torch.floor_divide": {
"Matcher": "Num2TensorBinaryMatcher",
"paddle_api": "paddle.floor_divide",
Expand Down Expand Up @@ -8518,6 +8525,7 @@
},
"torch.is_inference": {
"Matcher": "Is_InferenceMatcher",
"paddle_api": "paddle.Tensor.stop_gradient",
"min_input_args": 1,
"args_list": [
"input"
Expand Down Expand Up @@ -10723,8 +10731,8 @@
"min_input_args": 0
},
"torch.nn.GRUCell": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.GRUCell",
"Matcher": "CellMatcher",
"paddle_api": "paddle_aux.GRUCell_1",
"args_list": [
"input_size",
"hidden_size",
Expand Down Expand Up @@ -10967,8 +10975,8 @@
"min_input_args": 0
},
"torch.nn.LSTMCell": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.LSTMCell",
"Matcher": "CellMatcher",
"paddle_api": "paddle_aux.LSTMCell_1",
[resolved review conversation anchored to this line — Xuxuanang marked it as resolved]
"args_list": [
"input_size",
"hidden_size",
Expand Down Expand Up @@ -11731,8 +11739,8 @@
"min_input_args": 3
},
"torch.nn.RNNCell": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.SimpleRNNCell",
"Matcher": "CellMatcher",
"paddle_api": "paddle_aux.SimpleRNNCell_1",
"args_list": [
"input_size",
"hidden_size",
Expand Down Expand Up @@ -14867,11 +14875,6 @@
"dims": "perm"
}
},
"torch.pi": {
"Matcher": "GenericMatcher",
"paddle_api": "numpy.pi",
"min_input_args": 0
},
"torch.pinverse": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.linalg.pinv",
Expand Down
58 changes: 47 additions & 11 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3238,19 +3238,26 @@ def generate_code(self, kwargs):

class AvgPoolMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" in kwargs:
kwargs["x"] = kwargs.pop("input")
new_kwargs = {}
for k in list(kwargs.keys()):
if k == "input":
new_kwargs["x"] = kwargs.pop(k)
else:
new_kwargs[k] = kwargs.pop(k)

if "count_include_pad" in kwargs:
kwargs["exclusive"] = "not " + kwargs.pop("count_include_pad")
if "count_include_pad" in new_kwargs:
new_kwargs["exclusive"] = "not " + new_kwargs.pop("count_include_pad")
else:
kwargs["exclusive"] = "False"
new_kwargs["exclusive"] = "False"

API_TEMPLATE = textwrap.dedent(
"""
{}({})
"""
)
code = API_TEMPLATE.format(self.get_paddle_api(), self.kwargs_to_str(kwargs))
code = API_TEMPLATE.format(
self.get_paddle_api(), self.kwargs_to_str(new_kwargs)
)

return code

Expand Down Expand Up @@ -4358,13 +4365,18 @@ def generate_code(self, kwargs):
class ConstantLRMatcher(LRSchedulerMatcher):
def generate_code(self, kwargs):
optim = kwargs["optimizer"]
factor = 0.3333333333333333
total_iters = 5
if "factor" in kwargs:
factor = kwargs.pop("factor")
kwargs["values"] = "[{}*{}.get_lr(), {}.get_lr()]".format(
factor, optim, optim
)
else:
kwargs["values"] = "[{}.get_lr()/3, {}.get_lr()]".format(optim, optim)

if "total_iters" in kwargs:
total_iters = kwargs.pop("total_iters")
kwargs["values"] = "[{}*{}.get_lr(), {}.get_lr()]".format(factor, optim, optim)

kwargs["boundaries"] = "[{}]".format(total_iters)
return super().generate_code(kwargs)

Expand Down Expand Up @@ -4847,7 +4859,7 @@ def generate_code(self, kwargs):
class PositiveMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
"""
def positive(x):
if x.dtype != paddle.bool:
return x
Expand Down Expand Up @@ -5204,7 +5216,7 @@ def get_scalable_var(self):
if not (arg_name.startswith("*") and len(arg_name) > 1):
return None
return arg_name[1:]

def get_paddle_nodes(self, args, kwargs):
var_arg_name = self.get_scalable_var()
dest_var_arg_name = self.api_mapping.get("kwargs_change", {}).get(
Expand All @@ -5222,7 +5234,7 @@ def get_paddle_nodes(self, args, kwargs):
return ast.parse(code).body


class ScalableVarMatcher(BaseMatcher):
class ScalableVarMatcher(BaseMatcher):
def get_scalable_var(self):
args_list = self.api_mapping.get("args_list", [])
if len(args_list) != 1:
Expand Down Expand Up @@ -5779,3 +5791,27 @@ def generate_code(self, kwargs):
code = API_TEMPLATE.format(kwargs["device"])

return code


class CellMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
class LSTMCell_1(paddle.nn.LSTMCell):
    def forward(self, inputs, states = None):
        return super().forward(inputs, states)[1]

class GRUCell_1(paddle.nn.GRUCell):
    def forward(self, inputs, states = None):
        return super().forward(inputs, states)[0]

class SimpleRNNCell_1(paddle.nn.SimpleRNNCell):
    def forward(self, inputs, states = None):
        return super().forward(inputs, states)[0]

[Review thread anchored to the `LSTMCell_1` line above (marked resolved by Xuxuanang):
- Collaborator: 其他都取的[0],这个取[1]是确定的吗 (The other cells take index [0]; are you sure this one takes [1]?)
- Contributor (author): 确定的 (Yes, confirmed.)
- Collaborator: 我看这里返回的是一个值,为何单测里测试的返回值有两个 (This returns a single value here — why do the unit tests check two return values?)
- Contributor (author): 返回的是包含两个元素的元组 (What is returned is a tuple containing two elements.) [screenshot attached: {5C169DC4-AFCD-40C2-97E8-98B9EA5379DE}]]
"""
)
return CODE_TEMPLATE

def generate_code(self, kwargs):
self.write_aux_code()
return GenericMatcher.generate_code(self, kwargs)
4 changes: 4 additions & 0 deletions paconvert/attribute_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@
"torch.optim.Optimizer.state_dict": {
"Matcher": "UnchangeMatcher"
},
"torch.pi": {
"Matcher": "GenericMatcher",
"paddle_api": "numpy.pi"
},
"torch.short": {
"Matcher": "GenericMatcher",
"paddle_api": "'int16'"
Expand Down
1 change: 1 addition & 0 deletions tests/test_Tensor_scatter_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def test_case_7():
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_block_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_case_2():
A = torch.tensor([[4], [3], [2]])
B = torch.tensor([7, 6, 5])
C = torch.tensor(1)
result = torch.block_diag(torch.tensor([[4], [3], [2]]),
torch.tensor([7, 6, 5]),
result = torch.block_diag(torch.tensor([[4], [3], [2]]),
torch.tensor([7, 6, 5]),
torch.tensor(1))
"""
)
Expand Down
Loading