Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Dec 26, 2023
1 parent 603e0b8 commit 936eab2
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ pip install paddleslim

安装develop版本:
```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git & cd PaddleSlim
git clone https://github.com/PaddlePaddle/PaddleSlim.git && cd PaddleSlim
python setup.py install
```

Expand Down
10 changes: 5 additions & 5 deletions ce_tests/dygraph/quant/src/imagenet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@ def __init__(self,
if self.mode == 'train':
self.transform = transforms.Compose([
transforms.RandomResizedCrop(image_size),
transforms.RandomHorizontalFlip(), transforms.Transpose(),
normalize
transforms.RandomHorizontalFlip(),
transforms.Transpose(), normalize
])
else:
self.transform = transforms.Compose([
transforms.Resize(resize_short_size),
transforms.CenterCrop(image_size), transforms.Transpose(),
normalize
transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])

if mode == 'train':
with open(train_file_list) as flist:
full_lines = [line.strip() for line in flist]
np.random.shuffle(full_lines)
if os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count
Expand Down
17 changes: 9 additions & 8 deletions ce_tests/dygraph/quant/src/qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main():
}
dygraph_qat = QAT(quant_config)
else:
print("use navie api")
print("use naive api")
dygraph_qat = ImperativeQuantAware(
weight_quantize_type=FLAGS.weight_quantize_type, )
dygraph_qat.quantize(model)
Expand Down Expand Up @@ -112,12 +112,13 @@ def main():
if not os.path.exists(output_dir):
os.makedirs(output_dir)

model.fit(train_dataset,
val_dataset,
batch_size=FLAGS.batch_size,
epochs=FLAGS.epoch,
save_dir=output_dir,
num_workers=FLAGS.num_workers)
model.fit(
train_dataset,
val_dataset,
batch_size=FLAGS.batch_size,
epochs=FLAGS.epoch,
save_dir=output_dir,
num_workers=FLAGS.num_workers)

# save
if FLAGS.enable_quant:
Expand Down Expand Up @@ -183,7 +184,7 @@ def main():
parser.add_argument(
"--enable_quant", action='store_true', help="enable quant model")
parser.add_argument(
"--use_naive_api", action='store_true', help="use the navie api")
"--use_naive_api", action='store_true', help="use the naive api")
parser.add_argument(
"--weight_quantize_type", type=str, default='abs_max', help="")

Expand Down
2 changes: 1 addition & 1 deletion demo/deep_mutual_learning/cifar100_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def preprocess(sample, is_training):
img = Image.fromarray(rgb_array, 'RGB')

if is_training:
# pad, ramdom crop, random_flip_left_right, random_rotation
# pad, random crop, random_flip_left_right, random_rotation
img = ImageOps.expand(img, (4, 4, 4, 4), fill=0)
left_top = np.random.randint(8, size=2)
img = img.crop((left_top[1], left_top[0], left_top[1] + IMAGE_SIZE,
Expand Down

0 comments on commit 936eab2

Please sign in to comment.