Skip to content

Commit

Permalink
[dicp][ascend] Fix hf llama-inference precision. (#414)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdx1989 authored Nov 16, 2023
1 parent 5fcb3da commit 2429afb
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 94 deletions.
10 changes: 10 additions & 0 deletions dicp/dicp/dynamo_bridge/op_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch.fx.proxy import Proxy
from typing import Any, Dict, Tuple
from dicp.dynamo_bridge.compile_fx import is_torch_210
from dicp.vendor.AscendGraph.codegen.utils import symint_in_shape


class OpSetTransformer:
Expand All @@ -24,13 +25,22 @@ def __init__(self, module, conversions):
super().__init__(module)
self._conversions = conversions
self.sym_to_inputs = {}
self.sym_in_args = {}

def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Proxy:
    """Create the placeholder proxy and record symbolic-shape inputs.

    Delegates to the parent transformer, attaches the current FX traceback
    meta to the new node, then tracks where dynamic (SymInt) values enter
    the graph:

    - an input that is itself a ``torch.SymInt`` is recorded in
      ``self.sym_to_inputs`` keyed by the symbol's string form;
    - a tensor input whose shape contains SymInt dimensions records, per
      symbol, the ``(proxy, dim_index)`` of its *first* occurrence in
      ``self.sym_in_args`` (dynamic-shape feature).

    Returns the placeholder ``Proxy`` produced by the parent class.
    """
    proxy = super().placeholder(target, args, kwargs)
    proxy.node.meta = fx_traceback.get_current_meta()
    fake_tensor = proxy.node.meta['val']
    if isinstance(fake_tensor, torch.SymInt):
        # The input itself is a symbolic integer.
        self.sym_to_inputs[fake_tensor.node.str()] = proxy
    elif symint_in_shape(fake_tensor.shape):
        # Remember the position of each SymInt dimension within this
        # input's shape; first occurrence of a symbol wins.
        for idx, dim in enumerate(fake_tensor.shape):
            if isinstance(dim, torch.SymInt):
                sym_name = dim.node.str()
                if sym_name not in self.sym_in_args:
                    self.sym_in_args[sym_name] = (proxy, idx)
    return proxy

def get_proxy(self, target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] = {}):
Expand Down
14 changes: 12 additions & 2 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self):
super().__init__("Range")


class CumSum(Operator):
class Cumsum(Operator):
    """Ascend graph operator registered under the name "Cumsum".

    Presumably maps to the Ascend cumulative-sum op — confirm against the
    codegen that consumes this operator name.
    """

    def __init__(self):
        super().__init__("Cumsum")

Expand Down Expand Up @@ -170,7 +170,7 @@ def __init__(self):
super().__init__("TopK")


class ScatterElement(Operator):
class ScatterElements(Operator):
    """Ascend graph operator registered under the name "ScatterElements".

    Presumably maps to the Ascend/ONNX-style ScatterElements op — confirm
    against the codegen that consumes this operator name.
    """

    def __init__(self):
        super().__init__("ScatterElements")

Expand Down Expand Up @@ -245,11 +245,21 @@ def __init__(self):
super().__init__("Cast")


class CastToCpu(Operator):
    """Ascend graph operator registered under the name "CastToCpu".

    NOTE(review): semantics (device transfer vs. dtype cast on host) are not
    visible here — verify in the codegen handling of "CastToCpu".
    """

    def __init__(self):
        super().__init__("CastToCpu")


class Identity(Operator):
    """Ascend graph operator registered under the name "Identity"."""

    def __init__(self):
        super().__init__("Identity")


class IdentityInp(Operator):
    """Ascend graph operator registered under the name "IdentityInp".

    NOTE(review): how this differs from "Identity" (presumably an in-place /
    input-aliasing variant) is not visible here — verify in the codegen.
    """

    def __init__(self):
        super().__init__("IdentityInp")


class IdentityN(Operator):
    """Ascend graph operator registered under the name "IdentityN"."""

    def __init__(self):
        super().__init__("IdentityN")
Expand Down
Loading

0 comments on commit 2429afb

Please sign in to comment.