Bert fix1 #87

Draft
Wants to merge 80 commits into base branch: inference

Commits (80)
2436152
push changes to print dot graph
Apr 17, 2023
7030ae0
padded input to 512, added isExact to slice_tensor
Apr 21, 2023
280d29a
Add multiprecision support to replicate
lockshaw May 7, 2023
8e84401
Add explicit template instantiations for replicate kernels
lockshaw May 7, 2023
a2ddb91
Fix incorrect instantiations
lockshaw May 8, 2023
52cc8e8
Add nop init_task for replicate
lockshaw May 8, 2023
9eb530a
Fix replicate init_task registration
lockshaw May 8, 2023
d7a219c
Hopefully print hip errors
lockshaw May 10, 2023
f323a6e
Instantiate extra hip replicate kernels
lockshaw May 10, 2023
9a75f24
fix
jiazhihao May 10, 2023
fe61cbe
Merge branch 'BertMLM_fixes' of https://github.com/flexflow/FlexFlow …
jiazhihao May 10, 2023
c9277f3
debug changs
jiazhihao May 10, 2023
fe66561
Add slice_tensor fix
May 12, 2023
9e04302
Merge branch 'BertMLM_fixes' of github.com:flexflow/FlexFlow into Ber…
May 12, 2023
1177748
Add logging for metrics
lockshaw May 12, 2023
5b7cace
Add the cuda metrics hack to hip kernel as well
lockshaw May 12, 2023
e798a91
Add parallel dim pretty printing
lockshaw May 12, 2023
90541cf
[Embedding] bug fix
jiazhihao May 12, 2023
63fcde6
Merge branch 'BertMLM_fixes' of https://github.com/flexflow/FlexFlow …
jiazhihao May 12, 2023
7862143
Add replica dim to pretty print
lockshaw May 12, 2023
9663d96
Merge remote-tracking branch 'refs/remotes/origin/BertMLM_fixes' into…
lockshaw May 12, 2023
ef43c36
Fix replicate issue with python hack
tnoyola May 12, 2023
dd8090e
Use local json submodule
lockshaw May 20, 2023
0dc6187
ofi conduit-related fixes
May 20, 2023
0950ac7
Add mpi flags for hip
lockshaw May 22, 2023
4b06040
fix fusion bug
jiazhihao May 24, 2023
6796b1c
Merge branch 'BertMLM_fixes' of https://github.com/flexflow/FlexFlow …
jiazhihao May 24, 2023
99e9f95
increase the max number of regions in a ZeroInitMeta from 64 to 128
jiazhihao May 24, 2023
282c44a
support mixed precision
jiazhihao May 24, 2023
992dcb9
undo changes to Fused::Transpose
jiazhihao May 24, 2023
f528774
undo changes to config.linux
jiazhihao May 24, 2023
a68150d
try to fix layernorm
jiazhihao Jun 2, 2023
2bf9afc
fix typo
jiazhihao Jun 2, 2023
f6f7a32
Add possible layernorm fix
lockshaw Jun 3, 2023
5e03b0a
Fix additional layernorm bug due to get_piece_size return size in bytes
lockshaw Jun 3, 2023
53fb8bd
Bugfixes
tnoyola Jun 3, 2023
449a14c
Actually check elementwise_affine
lockshaw Jun 3, 2023
c737be6
Revert "Actually check elementwise_affine"
tnoyola Jun 5, 2023
a98e09d
Change optimizer to adam with correct hyperparams
lockshaw Jun 6, 2023
66b805e
Merge remote-tracking branch 'refs/remotes/origin/BertMLM_fixes' into…
tnoyola Jun 6, 2023
4bec811
fix training bert model.
xinhaoc Jul 4, 2023
2d28c15
revert changes
xinhaoc Jul 4, 2023
2025d56
fix bert training issue. (#832)
xinhaoc Jul 5, 2023
5f793c1
Improve machine_view hash
lockshaw Jul 25, 2023
2c09397
Fix bugs in improved hashing
lockshaw Jul 25, 2023
862e9d7
fix weight dimension in layernorm
xinhaoc Jul 25, 2023
d29bf1d
Merge branch 'BertMLM_fixes' of https://github.com/flexflow/FlexFlow …
xinhaoc Jul 25, 2023
88ad5fa
Merge remote-tracking branch 'origin/master' into BertMLM_fixes
lockshaw Aug 8, 2023
2eee875
fix `preregister_task_variant` issue, linting
goliaro Aug 11, 2023
b9d1332
try to run graph_optimize on each node
jiazhihao Aug 14, 2023
b5b0815
remove unnecessary file
jiazhihao Aug 14, 2023
94e35d9
fix hip build
xinhaoc Aug 15, 2023
ac185e3
Merge branch 'BertMLM_fixes' of https://github.com/flexflow/FlexFlow …
xinhaoc Aug 15, 2023
ded175c
bypass simulator creation when only_data_parallel is specified
jiazhihao Aug 18, 2023
1f7e8b7
add nccl prints
jiazhihao Aug 18, 2023
3fb70f6
.
jiazhihao Aug 21, 2023
d652b62
rccl
xinhaoc Aug 29, 2023
b39528b
fix fuse
xinhaoc Oct 6, 2023
0cf3c8e
fix hip
xinhaoc Oct 6, 2023
17a1c4e
more fix to hip
xinhaoc Oct 6, 2023
f65044d
customized kernel for broadcasting add.
xinhaoc Nov 3, 2023
bcab56a
dropout
xinhaoc Dec 20, 2023
fa1fffc
optimizer
xinhaoc Dec 20, 2023
40d830c
opt
xinhaoc Dec 21, 2023
e825526
fix
xinhaoc Dec 21, 2023
d2bdb15
fix
xinhaoc Jan 12, 2024
3b9e1c6
.
xinhaoc Jan 12, 2024
9f8bb9e
fix
xinhaoc Jan 12, 2024
fb91122
remove print
xinhaoc Jan 12, 2024
c162d4c
fix hip
xinhaoc Jan 19, 2024
ea79317
fix multinodes
xinhaoc Jan 19, 2024
58d84ed
fix
xinhaoc Jan 19, 2024
01c9d4c
fix
xinhaoc Jan 25, 2024
a31f8e9
fix
xinhaoc Jan 26, 2024
9141c46
tp
xinhaoc Feb 2, 2024
8185289
timer
xinhaoc Feb 15, 2024
d958805
rmv
xinhaoc Feb 16, 2024
38dfd87
fix tp
xinhaoc Feb 22, 2024
355d4b4
try a fix
xinhaoc Feb 29, 2024
8488ba0
fix hip
xinhaoc Sep 13, 2024
17 changes: 10 additions & 7 deletions CMakeLists.txt
@@ -158,10 +158,6 @@ endif()
# option for nccl
option(FF_USE_NCCL "Run FlexFlow with NCCL" OFF)

if (FF_GPU_BACKEND STREQUAL "hip_rocm" AND FF_USE_NCCL STREQUAL "ON")
message(FATAL_ERROR "NCCL: ON for FF_GPU_BACKEND: hip_rocm. hip_rocm backend must have NCCL disabled.")
endif()

# option for avx2
option(FF_USE_AVX2 "Run FlexFlow with AVX2" OFF)

@@ -224,7 +220,9 @@ endif()

# NCCL
if(FF_USE_NCCL)
include(nccl)
if(FF_GPU_BACKEND STREQUAL "hip_cuda" OR FF_GPU_BACKEND STREQUAL "cuda")
include(nccl)
endif()
list(APPEND FF_CC_FLAGS
-DFF_USE_NCCL)
list(APPEND FF_NVCC_FLAGS
@@ -369,11 +367,13 @@ elseif(FF_GPU_BACKEND STREQUAL "hip_cuda" OR FF_GPU_BACKEND STREQUAL "hip_rocm")
elseif(FF_GPU_BACKEND STREQUAL "hip_rocm")
find_package(hipblas REQUIRED)
find_package(miopen REQUIRED)
if(FF_USE_NCCL)
find_package(rccl REQUIRED)
endif()
# find_package(rocrand REQUIRED)
find_library(HIP_RAND_LIBRARY hiprand REQUIRED)

add_compile_definitions(FF_USE_HIP_ROCM)

# The hip cmake config module defines three targets,
# hip::amdhip64, hip::host, and hip::device.
#
@@ -387,12 +387,15 @@ elseif(FF_GPU_BACKEND STREQUAL "hip_cuda" OR FF_GPU_BACKEND STREQUAL "hip_rocm")
# Docs (outdated):
# https://rocmdocs.amd.com/en/latest/Installation_Guide/Using-CMake-with-AMD-ROCm.html
target_link_libraries(flexflow hip::device roc::hipblas MIOpen ${HIP_RAND_LIBRARY})
if(FF_USE_NCCL)
target_link_libraries(flexflow rccl)
endif()
endif()
else()
message(FATAL_ERROR "Unsupported FF_GPU_BACKEND for cmake: ${FF_GPU_BACKEND}")
endif()

if(FF_USE_NCCL)
if(FF_USE_NCCL AND (FF_GPU_BACKEND STREQUAL "hip_cuda" OR FF_GPU_BACKEND STREQUAL "cuda"))
add_dependencies(flexflow ${NCCL_NAME})
endif()

5 changes: 1 addition & 4 deletions cmake/json.cmake
@@ -1,4 +1 @@
include(FetchContent)

FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.10.5/json.tar.xz)
FetchContent_MakeAvailable(json)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/json)
4 changes: 3 additions & 1 deletion config/config.inc
@@ -84,6 +84,8 @@ if [ "$FF_LEGION_NETWORKS" = "gasnet" ]; then
elif [ "$FF_GASNET_CONDUIT" = "ucx" ]; then
SET_LEGION_NETWORKS+=" -DFF_GASNET_CONDUIT=ucx"
SET_LEGION_NETWORKS+=" -DFF_UCX_URL=$FF_UCX_URL"
elif [ "$FF_GASNET_CONDUIT" = "ofi" ]; then
SET_LEGION_NETWORKS+=" -DFF_GASNET_CONDUIT=ofi"
fi
elif [ "$FF_LEGION_NETWORKS" = "ucx" ]; then
SET_LEGION_NETWORKS+=" -DFF_LEGION_NETWORKS=ucx"
@@ -182,7 +184,7 @@ if [ -n "$FF_GPU_BACKEND" ]; then
chmod +x "$(pwd)/nvidia_hipcc"
SET_CXX="-DCMAKE_CXX_COMPILER=$(pwd)/nvidia_hipcc -DCMAKE_CXX_LINKER=$(pwd)/nvidia_hipcc"
else
SET_CXX="-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -DCMAKE_CXX_LINKER=/opt/rocm/bin/hipcc"
SET_CXX="-DCMAKE_CXX_COMPILER=$ROCM_PATH/bin/hipcc -DCMAKE_CXX_LINKER=$ROCM_PATH/bin/hipcc -DHIP_PATH=$ROCM_PATH/hip -DCMAKE_CXX_FLAGS='-I${MPICH_DIR}/include' -DCMAKE_EXE_LINKER_FLAGS='-L${MPICH_DIR}/lib -lmpi' -DCMAKE_SHARED_LINKER_FLAGS='-L${MPICH_DIR}/lib -lmpi'"
fi
fi
fi
6 changes: 2 additions & 4 deletions config/config.linux
@@ -38,7 +38,7 @@ FF_USE_PYTHON=${FF_USE_PYTHON:-ON}
FF_LEGION_NETWORKS=${FF_LEGION_NETWORKS:-}

# select GASNET conduit
FF_GASNET_CONDUIT=${FF_GASNET_CONDUIT:-ibv}
FF_GASNET_CONDUIT=${FF_GASNET_CONDUIT:-ofi}

# set UCX URL
FF_UCX_URL=${FF_UCX_URL:-""}
@@ -70,11 +70,9 @@ FF_GPU_BACKEND=${FF_GPU_BACKEND:-cuda}
if [[ "${FF_GPU_BACKEND}" != @(cuda|hip_cuda|hip_rocm|intel) ]]; then
echo "Error, value of FF_GPU_BACKEND (${FF_GPU_BACKEND}) is invalid."
exit 1
elif [[ "$FF_GPU_BACKEND" == "cuda" || "$FF_GPU_BACKEND" = "hip_cuda" ]]; then
elif [[ "$FF_GPU_BACKEND" == "cuda" || "$FF_GPU_BACKEND" = "hip_cuda" || "$FF_GPU_BACKEND" == "hip_rocm" ]]; then
# enable NCCL
FF_USE_NCCL=${FF_USE_NCCL:-ON}
else
FF_USE_NCCL=OFF
fi

function get_build_configs() {
118 changes: 74 additions & 44 deletions examples/python/pytorch/mt5/mt5_ff.py
@@ -3,16 +3,17 @@
import sys

import numpy as np
import torch
from flexflow.core import *
from flexflow.torch.model import PyTorchModel
from transformers import MT5ForConditionalGeneration, T5Tokenizer

#from transformers import MT5ForConditionalGeneration, T5Tokenizer
from transformers import BertForMaskedLM, BertTokenizer
sys.path.append("./examples/python/pytorch/mt5")
from mt5_torch import DataPreparer, get_dataloaders, set_seed

BASE_DIR = "examples/python/pytorch/mt5"
DATA_DIR = os.path.join(BASE_DIR, "data")
NUMPY_DIR = os.path.join(DATA_DIR, "numpy")
NUMPY_DIR = os.path.join(DATA_DIR, "numpy_candle")


def data_to_numpy() -> None:
@@ -28,15 +29,17 @@ def data_to_numpy() -> None:
"""
model_params = {
"SEED": 42,
"MODEL": "google/mt5-small",
#"MODEL": "google/mt5-small",
"MODEL": "bert-base-uncased",
"TRAIN_BATCH_SIZE": None, # use the full dataset as one batch
"EVAL_BATCH_SIZE": None, # use the full dataset as one batch
"TRAIN_EPOCHS": 1, # unused
"MAX_SOURCE_TEXT_LENGTH": 48,
"MAX_TARGET_TEXT_LENGTH": 48,
}
set_seed(model_params)
tokenizer = T5Tokenizer.from_pretrained(model_params["MODEL"])
#tokenizer = T5Tokenizer.from_pretrained(model_params["MODEL"])
tokenizer = BertTokenizer.from_pretrained(model_params["MODEL"])
print("Getting dataloaders...")
train_loader, eval_loader = get_dataloaders(tokenizer, model_params)
assert len(train_loader) == 1
@@ -61,8 +64,8 @@ def preprocess_train() -> None:
y_shape = y.shape
assert len(y.shape) == 2, \
"`y` should have shape (num examples, sequence length)"
y_ids = np.empty((y_shape[0], y_shape[1] - 1), dtype=np.long)
lm_labels = np.empty((y_shape[0], y_shape[1] - 1), dtype=np.long)
y_ids = np.empty((y_shape[0], y_shape[1] - 1), dtype=np.int32)
lm_labels = np.empty((y_shape[0], y_shape[1] - 1), dtype=np.int32)
y_ids[:, :] = y[:, :-1]
lm_labels[:, :] = y[:, 1:]

@@ -81,36 +84,54 @@ def preprocess_train() -> None:
def top_level_task():
ffconfig = FFConfig()
ffmodel = FFModel(ffconfig)
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

#model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
#model = BertModel.from_pretrained("bert-base-uncased")
# Load train data as numpy arrays
print("Loading data...")
ids = np.load(os.path.join(NUMPY_DIR, "train_source_ids.npy"))
mask = np.load(os.path.join(NUMPY_DIR, "train_source_mask.npy"))
y_ids = np.load(os.path.join(NUMPY_DIR, "train_y_ids.npy"))
lm_labels = np.load(os.path.join(NUMPY_DIR, "train_lm_labels.npy"))
ids = np.load(os.path.join(NUMPY_DIR, "train_input_ids.npy")).astype('int32')
ids = np.pad(ids, ((0,0), (0,17)), 'constant')
#ids = np.random.randint(0, 5, (1000, 512))
#print('ids_shape', ids.shape)
#print('ids', ids)
mask = np.load(os.path.join(NUMPY_DIR, "train_attention_mask.npy")).astype('int32')
mask = np.pad(mask, ((0,0), (0,17)), 'constant')
#mask = np.random.randint(0, 2, (1000, 512))
#y_ids = np.load(os.path.join(NUMPY_DIR, "train_y_ids.npy"))
lm_labels = np.load(os.path.join(NUMPY_DIR, "train_labels.npy")).astype('int32')
lm_labels = np.pad(lm_labels, ((0,0), (0,17)), 'constant')
#lm_labels = np.random.randint(-1, 5, (1000, 512))
position_id = torch.arange(ids.shape[1], dtype=torch.int32).expand((1, -1)).numpy()
token_type_ids = torch.zeros(ids.shape[1], dtype=torch.int32).expand((1, -1)).numpy()


batch_size = ffconfig.batch_size
input_ids_shape = (batch_size, ids.shape[1])
attention_mask_shape = (batch_size, mask.shape[1])
decoder_input_ids_shape = (batch_size, y_ids.shape[1])
#decoder_input_ids_shape = (batch_size, y_ids.shape[1])
input_tensors = [
ffmodel.create_tensor(input_ids_shape, DataType.DT_INT64), # input_ids
ffmodel.create_tensor(attention_mask_shape, DataType.DT_INT64), # attention_mask
ffmodel.create_tensor(decoder_input_ids_shape, DataType.DT_INT64), # decoder_input_ids
ffmodel.create_tensor(input_ids_shape, DataType.DT_INT32), # input_ids
ffmodel.create_tensor(attention_mask_shape, DataType.DT_INT32), # attention_mask
#ffmodel.create_tensor(decoder_input_ids_shape, DataType.DT_INT64), # decoder_input_ids
]
encoder_seq_length = ids.shape[1]
decoder_seq_length = y_ids.shape[1]
seq_length = (encoder_seq_length, decoder_seq_length)
input_names = ["input_ids", "attention_mask", "decoder_input_ids"]
#decoder_seq_length = y_ids.shape[1]
#seq_length = (encoder_seq_length, decoder_seq_length)
seq_length = encoder_seq_length
#input_names = ["input_ids", "attention_mask", "decoder_input_ids"]
input_names = ["input_ids", "attention_mask"]

print("Tracing the model...")
print(batch_size)
hf_model = PyTorchModel(
model, is_hf_model=True, input_names=input_names,
batch_size=batch_size, seq_length=seq_length,
)
output_tensors = hf_model.torch_to_ff(ffmodel, input_tensors, verbose=True)
ffoptimizer = SGDOptimizer(ffmodel, lr=0.01)
#from flexflow.torch.model import file_to_ff
#file_to_ff("mt5.ff", ffmodel, input_tensors)
ffoptimizer = AdamOptimizer(ffmodel, alpha=1e-4, beta1=0.9, beta2=0.98, weight_decay=0.0, epsilon=2e-8)
# ffoptimizer = SGDOptimizer(ffmodel, lr=0.01)

print("Compiling the model...")
ffmodel.compile(
@@ -121,13 +142,21 @@ def top_level_task():
MetricsType.METRICS_SPARSE_CATEGORICAL_CROSSENTROPY,
],
)

# load weights here
ffmodel.load_bert_pretrained(checkpoint=model)

print("Creating data loaders...")
print('id_dtype', ids.dtype)
print('mask_dtype', mask.dtype)
print('labels_dtype', lm_labels.dtype)
input_ids_dl = ffmodel.create_data_loader(input_tensors[0], ids)
attention_mask_dl = ffmodel.create_data_loader(input_tensors[1], mask)
decoder_input_ids_dl = ffmodel.create_data_loader(input_tensors[2], y_ids)
#decoder_input_ids_dl = ffmodel.create_data_loader(input_tensors[2], y_ids)
# NOTE: We cast down the label tensor data to 32-bit to accommodate the
# label tensor's required dtype
token_type_ids_dl = ffmodel.create_data_loader(input_tensors[2], token_type_ids)
position_id_dl = ffmodel.create_data_loader(input_tensors[3], position_id)
labels_dl = ffmodel.create_data_loader(
ffmodel.label_tensor, lm_labels.astype("int32")
)
@@ -138,31 +167,32 @@ def top_level_task():
print("Training...")
epochs = ffconfig.epochs
ffmodel.fit(
x=[input_ids_dl, attention_mask_dl, decoder_input_ids_dl],
#x=[input_ids_dl, attention_mask_dl, decoder_input_ids_dl],
x=[input_ids_dl, attention_mask_dl, position_id_dl, token_type_ids_dl],
y=labels_dl, batch_size=batch_size, epochs=epochs,
)


if __name__ == "__main__":
# Generate the .tsv files if needed
if not os.path.exists(os.path.join(DATA_DIR, "train.tsv")) or \
not os.path.exists(os.path.join(DATA_DIR, "eval.tsv")):
DataPreparer.data_to_tsv()
# Convert the .tsv files to .npy if needed
if not os.path.exists(NUMPY_DIR):
os.mkdir(NUMPY_DIR)
prefixes = ["train_", "eval_"]
suffixes = ["source_ids.npy", "source_mask.npy", "target_ids.npy"]
npy_filenames = [
pre + suf for pre, suf in itertools.product(prefixes, suffixes)
]
if any(
not os.path.exists(os.path.join(NUMPY_DIR, filename))
for filename in npy_filenames
):
data_to_numpy()
# Preprocess the training data if needed
if not os.path.exists(os.path.join(NUMPY_DIR, "train_y_ids.npy")) or \
not os.path.exists(os.path.join(NUMPY_DIR, "train_lm_labels.npy")):
preprocess_train()
## Generate the .tsv files if needed
#if not os.path.exists(os.path.join(DATA_DIR, "train.tsv")) or \
# not os.path.exists(os.path.join(DATA_DIR, "eval.tsv")):
# DataPreparer.data_to_tsv()
## Convert the .tsv files to .npy if needed
#if not os.path.exists(NUMPY_DIR):
# os.mkdir(NUMPY_DIR)
#prefixes = ["train_", "eval_"]
#suffixes = ["source_ids.npy", "source_mask.npy", "target_ids.npy"]
#npy_filenames = [
# pre + suf for pre, suf in itertools.product(prefixes, suffixes)
#]
#if any(
# not os.path.exists(os.path.join(NUMPY_DIR, filename))
# for filename in npy_filenames
#):
# data_to_numpy()
## Preprocess the training data if needed
#if not os.path.exists(os.path.join(NUMPY_DIR, "train_y_ids.npy")) or \
# not os.path.exists(os.path.join(NUMPY_DIR, "train_lm_labels.npy")):
# preprocess_train()
top_level_task()
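
For reference, a minimal standalone sketch of the input-preparation pattern used in top_level_task() above. The file names and the pre-pad sequence length are assumptions for illustration; the padding simply brings the BERT sequences up to 512 tokens, matching the np.pad calls and int32 casts in the diff.

import os
import numpy as np
import torch

def prepare_bert_inputs(numpy_dir, pad_to=512):
    # Hypothetical .npy files of shape (num_examples, seq_len); cast to int32
    # to match the DT_INT32 FlexFlow input tensors created above.
    ids = np.load(os.path.join(numpy_dir, "train_input_ids.npy")).astype("int32")
    mask = np.load(os.path.join(numpy_dir, "train_attention_mask.npy")).astype("int32")
    pad = pad_to - ids.shape[1]
    ids = np.pad(ids, ((0, 0), (0, pad)), "constant")
    mask = np.pad(mask, ((0, 0), (0, pad)), "constant")
    # Position ids run 0..seq_len-1; token type ids are all zeros (single segment).
    position_id = torch.arange(ids.shape[1], dtype=torch.int32).expand((1, -1)).numpy()
    token_type_ids = torch.zeros(ids.shape[1], dtype=torch.int32).expand((1, -1)).numpy()
    return ids, mask, position_id, token_type_ids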
4 changes: 2 additions & 2 deletions examples/python/pytorch/mt5/mt5_torch.py
@@ -7,7 +7,7 @@
import os

import numpy as np
import pandas as pd
#import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import MT5ForConditionalGeneration, T5Tokenizer
@@ -311,5 +311,5 @@ def TorchMT5Trainer(
"MAX_TARGET_TEXT_LENGTH": 48,
"LEARNING_RATE": 1e-4,
}
device = torch.device(0)
device = torch.device('cpu')
TorchMT5Trainer(model_params, device)
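
The hard-coded torch.device('cpu') above avoids requiring a GPU for the reference PyTorch run. A more flexible variant (an assumption on our part, not part of this PR) would fall back to CPU only when CUDA is unavailable:

import torch
# Prefer GPU 0 when available, otherwise fall back to CPU (hypothetical variant):
device = torch.device(0) if torch.cuda.is_available() else torch.device("cpu")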
31 changes: 29 additions & 2 deletions gdb/pretty_print.py
@@ -61,7 +61,11 @@ def to_string(self):
size = dim['size']
degree = dim['degree']
parallel_idx = dim['parallel_idx']
toks.append(f'{i}=[s={size} d={degree} pi={parallel_idx}]')
if dim['is_replica_dim']:
is_replica = 'r=t'
else:
is_replica = 'r=f'
toks.append(f'{i}=[s={size} d={degree} pi={parallel_idx} {is_replica}]')
return f'TensorShape<{" ".join(toks)}>'

class ParallelTensorBasePrinter:
@@ -77,9 +81,31 @@ def to_string(self):
size = dim['size']
degree = dim['degree']
parallel_idx = dim['parallel_idx']
toks.append(f'{i}=[s={size} d={degree} pi={parallel_idx}]')
tok = f'{i}=[s={size} d={degree} pi={parallel_idx} '
if dim['is_replica_dim']:
tok += 'r=t'
else:
tok += 'r=f'
tok += ']'
toks.append(tok)
return f'ParallelTensorBase<{" ".join(toks)}>'

class ParallelDimPrinter:
def __init__(self, val):
self.val = val

def to_string(self):
size = self.val['size']
degree = self.val['degree']
parallel_idx = self.val['parallel_idx']
tok = f's={size} d={degree} pi={parallel_idx} '
if self.val['is_replica_dim']:
tok += 'r=t'
else:
tok += 'r=f'
return f'ParallelDim<{tok}>'


def build_pretty_printer():
pp = gdb.printing.RegexpCollectionPrettyPrinter(
"flexflow")
@@ -89,6 +115,7 @@ def build_pretty_printer():
pp.add_printer('Domain', '^Legion::Domain$', DomainPrinter)
pp.add_printer('ParallelTensorShape', '^FlexFlow::ParallelTensorShape$', TensorShapePrinter)
pp.add_printer('ParallelTensorBase', '^FlexFlow::ParallelTensorBase$', ParallelTensorBasePrinter)
pp.add_printer('ParallelDim', '^FlexFlow::ParallelDim$', ParallelDimPrinter)
return pp

gdb.printing.register_pretty_printer(
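
As a rough illustration (not part of the diff), this is the string format the new ParallelDim printer produces; a plain dict stands in for the gdb.Value so the snippet can run outside GDB:

def format_parallel_dim(dim):
    # Mirrors ParallelDimPrinter.to_string(): size, degree, parallel index,
    # and whether the dimension is a replica dimension.
    is_replica = 'r=t' if dim['is_replica_dim'] else 'r=f'
    return f"ParallelDim<s={dim['size']} d={dim['degree']} pi={dim['parallel_idx']} {is_replica}>"

print(format_parallel_dim(
    {'size': 1024, 'degree': 4, 'parallel_idx': 0, 'is_replica_dim': False}))
# ParallelDim<s=1024 d=4 pi=0 r=f>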
7 changes: 6 additions & 1 deletion include/flexflow/config.h
@@ -28,8 +28,10 @@
#error "Unknown device"
#endif
#include "tl/optional.hpp"
#ifdef FF_USE_NCCL
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
#include <nccl.h>
#else
#include <rccl.h>
#endif

namespace FlexFlow {
@@ -122,6 +124,7 @@ class FFConfig {
size_t workSpaceSize;
Legion::Context lg_ctx;
Legion::Runtime *lg_hlr;
Legion::IndexSpaceT<1> all_gpu_task_is;
Legion::FieldSpace field_space;
bool syntheticInput, profiling, perform_fusion;
size_t simulator_work_space_size;
@@ -135,6 +138,8 @@ class FFConfig {
bool enable_parameter_parallel;
bool enable_attribute_parallel;
bool enable_inplace_optimizations;
int data_parallelism_degree;
int tensor_parallelism_degree;
// Control Tensor Op Math Conversion
bool allow_tensor_op_math_conversion;
std::string dataset_path;