-
Notifications
You must be signed in to change notification settings - Fork 846
/
dlrm_s_pytorch.py
1907 lines (1711 loc) · 73.3 KB
/
dlrm_s_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Description: an implementation of a deep learning recommendation model (DLRM)
# The model input consists of dense and sparse features. The former is a vector
# of floating point values. The latter is a list of sparse indices into
# embedding tables, which consist of vectors of floating point values.
# The selected vectors are passed to mlp networks denoted by triangles,
# in some cases the vectors are interacted through operators (Ops).
#
# output:
# vector of values
# model: |
# /\
# /__\
# |
# _____________________> Op <___________________
# / | \
# /\ /\ /\
# /__\ /__\ ... /__\
# | | |
# | Op Op
# | ____/__\_____ ____/__\____
# | |_Emb_|____|__| ... |_Emb_|__|___|
# input:
# [ dense features ] [sparse indices] , ..., [sparse indices]
#
# More precise definition of model layers:
# 1) fully connected layers of an mlp
# z = f(y)
# y = Wx + b
#
# 2) embedding lookup (for a list of sparse indices p=[p1,...,pk])
# z = Op(e1,...,ek)
# obtain vectors e1=E[:,p1], ..., ek=E[:,pk]
#
# 3) Operator Op can be one of the following
# Sum(e1,...,ek) = e1 + ... + ek
# Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek]
# Cat(e1,...,ek) = [e1', ..., ek']'
# where ' denotes transpose operation
#
# References:
# [1] Maxim Naumov, Dheevatsa Mudigere, Hao-Jun Michael Shi, Jianyu Huang,
# Narayanan Sundaram, Jongsoo Park, Xiaodong Wang, Udit Gupta, Carole-Jean Wu,
# Alisson G. Azzolini, Dmytro Dzhulgakov, Andrey Mallevich, Ilia Cherniavskii,
# Yinghai Lu, Raghuraman Krishnamoorthi, Ansha Yu, Volodymyr Kondratenko,
# Stephanie Pereira, Xianjie Chen, Wenlin Chen, Vijay Rao, Bill Jia, Liang Xiong,
# Misha Smelyanskiy, "Deep Learning Recommendation Model for Personalization and
# Recommendation Systems", CoRR, arXiv:1906.00091, 2019
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
# miscellaneous
import builtins
import datetime
import json
import sys
import time
# onnx
# The onnx import causes deprecation warnings every time workers
# are spawned during testing. So, we filter out those warnings.
import warnings
# data generation
import dlrm_data_pytorch as dp
# For distributed run
import extend_distributed as ext_dist
import mlperf_logger
# numpy
import numpy as np
import optim.rwsadagrad as RowWiseSparseAdagrad
import sklearn.metrics
# pytorch
import torch
import torch.nn as nn
# dataloader
try:
from internals import fbDataLoader, fbInputBatchFormatter
has_internal_libs = True
except ImportError:
has_internal_libs = False
from torch._ops import ops
from torch.autograd.profiler import record_function
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
# mixed-dimension trick
from tricks.md_embedding_bag import md_solver, PrEmbeddingBag
# quotient-remainder trick
from tricks.qr_embedding_bag import QREmbeddingBag
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
try:
import onnx
except ImportError as error:
print("Unable to import onnx. ", error)
# from torchviz import make_dot
# import torch.nn.functional as Functional
# from torch.nn.parameter import Parameter
# Resolve the exception object stored in `exc`: builtins.IOError when that
# attribute exists, otherwise the fallback given as the third argument.
# NOTE(review): the fallback is the *string* "FileNotFoundError", not an
# exception class; where `exc` is consumed is not visible in this chunk.
exc = getattr(builtins, "IOError", "FileNotFoundError")
def time_wrap(use_gpu):
    """Return the current wall-clock timestamp from ``time.time()``.

    When ``use_gpu`` is truthy, ``torch.cuda.synchronize()`` is called first so
    that the timestamp is taken after the synchronization returns.
    """
    if not use_gpu:
        return time.time()
    torch.cuda.synchronize()
    return time.time()
def dlrm_wrap(X, lS_o, lS_i, use_gpu, device, ndevices=1):
    """Run the module-level ``dlrm`` model on one batch inside a profiler range.

    The sparse inputs ``lS_o`` / ``lS_i`` are moved to ``device`` only when
    ``use_gpu`` is set and ``ndevices == 1``; each may be either a list of
    tensors or a single (stacked) tensor. ``X`` is always moved to ``device``.

    Returns whatever ``dlrm(X, lS_o, lS_i)`` returns.
    """

    def _move(batch):
        # A list is moved element by element; anything else is moved whole.
        if isinstance(batch, list):
            return [tensor.to(device) for tensor in batch]
        return batch.to(device)

    with record_function("DLRM forward"):
        if use_gpu and ndevices == 1:
            lS_i = _move(lS_i)
            lS_o = _move(lS_o)
        return dlrm(X.to(device), lS_o, lS_i)
def loss_fn_wrap(Z, T, use_gpu, device):
    """Compute the loss selected by the module-level ``args.loss_function``.

    For "mse" and "bce" the result of ``dlrm.loss_fn`` is returned directly.
    For "wbce" the per-element loss is scaled by ``dlrm.loss_ws`` looked up
    with the (long-cast) targets, and the mean is returned. Any other value
    falls through and returns ``None``. ``use_gpu`` is not used.
    """
    with record_function("DLRM loss compute"):
        loss_name = args.loss_function
        if loss_name in ("mse", "bce"):
            return dlrm.loss_fn(Z, T.to(device))
        if loss_name == "wbce":
            # Pick one weight per sample according to its target class.
            class_ids = T.data.view(-1).long()
            sample_weights = dlrm.loss_ws[class_ids].view_as(T).to(device)
            elementwise = dlrm.loss_fn(Z, T.to(device))
            weighted = sample_weights * elementwise
            return weighted.mean()
# The following function is a wrapper to avoid checking this multiple times in the
# loop below.
def unpack_batch(b):
    """Normalize a data-loader batch into a 6-tuple.

    With ``args.data_generation == "internal"`` the batch is handed to
    ``fbInputBatchFormatter`` together with ``args.data_size``. Otherwise the
    first four entries of ``b`` are returned followed by an all-ones tensor
    shaped like ``b[3]`` (unweighted samples) and ``None``.
    """
    if args.data_generation != "internal":
        # Experiment with unweighted samples
        unit_weights = torch.ones(b[3].size())
        return b[0], b[1], b[2], b[3], unit_weights, None
    return fbInputBatchFormatter(b, args.data_size)
class LRPolicyScheduler(_LRScheduler):
    """Learning-rate policy with linear warmup, a hold phase and quadratic decay.

    * step < ``num_warmup_steps``: base LR scaled linearly up towards 1.
    * ``decay_start_step`` <= step < ``decay_start_step + num_decay_steps``:
      base LR scaled by ``((remaining decay steps) / num_decay_steps) ** 2``,
      floored at ``1e-7``.
    * otherwise: if decay is configured (``num_decay_steps > 0``) the last
      computed LR is held, else the base LR is used unchanged.

    The process exits via ``sys.exit`` if the decay is configured to start
    before the warmup finishes.
    """

    def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
        self.num_warmup_steps = num_warmup_steps
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_start_step + num_decay_steps
        self.num_decay_steps = num_decay_steps
        # Must exist before the base constructor runs: it performs an initial
        # step, which calls get_lr(). Without warmup the "hold" branch would
        # otherwise read an attribute that was never assigned.
        self.last_lr = None

        if self.decay_start_step < self.num_warmup_steps:
            sys.exit("Learning rate warmup must finish before the decay starts")

        super(LRPolicyScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the list of learning rates for the current ``_step_count``."""
        step_count = self._step_count
        if step_count < self.num_warmup_steps:
            # warmup
            scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps
            lr = [base_lr * scale for base_lr in self.base_lrs]
            self.last_lr = lr
        elif self.decay_start_step <= step_count and step_count < self.decay_end_step:
            # decay
            decayed_steps = step_count - self.decay_start_step
            scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2
            min_lr = 0.0000001
            lr = [max(min_lr, base_lr * scale) for base_lr in self.base_lrs]
            self.last_lr = lr
        elif self.num_decay_steps > 0 and self.last_lr is not None:
            # freeze at last, either because we're after decay
            # or because we're between warmup and decay
            lr = self.last_lr
        else:
            # do not adjust: either no decay is configured, or nothing has
            # been computed yet (no warmup and decay has not started)
            lr = self.base_lrs
        return lr
### define dlrm in PyTorch ###
class DLRM_Net(nn.Module):
def create_mlp(self, ln, sigmoid_layer):
    """Build an MLP as ``torch.nn.Sequential`` from the layer widths in ``ln``.

    For each consecutive pair ``(n, m)`` in ``ln`` a ``nn.Linear(n, m)`` is
    created whose weights are drawn from N(0, sqrt(2 / (m + n))) and whose
    biases are drawn from N(0, sqrt(1 / m)) using NumPy's global RNG. Each
    linear layer is followed by ``nn.Sigmoid`` when its index equals
    ``sigmoid_layer`` and by ``nn.ReLU`` otherwise.
    """
    stack = nn.ModuleList()
    for idx, (fan_in, fan_out) in enumerate(zip(ln[:-1], ln[1:])):
        linear = nn.Linear(int(fan_in), int(fan_out), bias=True)

        # Custom Xavier-style fill; weights are drawn before biases so the
        # NumPy RNG stream is consumed in a fixed order.
        weight_std = np.sqrt(2 / (fan_out + fan_in))
        weight_init = np.random.normal(
            0.0, weight_std, size=(fan_out, fan_in)
        ).astype(np.float32)
        bias_std = np.sqrt(1 / fan_out)
        bias_init = np.random.normal(0.0, bias_std, size=fan_out).astype(np.float32)

        linear.weight.data = torch.tensor(weight_init, requires_grad=True)
        linear.bias.data = torch.tensor(bias_init, requires_grad=True)
        stack.append(linear)

        activation = nn.Sigmoid() if idx == sigmoid_layer else nn.ReLU()
        stack.append(activation)

    # Wrap all layers in a Sequential container.
    return torch.nn.Sequential(*stack)
def create_emb(self, m, ln, weighted_pooling=None):
    """Create the embedding tables and their per-row pooling weights.

    Iterates over the table sizes in ``ln``. In a distributed run
    (``ext_dist.my_size > 1``) tables whose index is not in
    ``self.local_emb_indices`` are skipped. Each table is a
    ``QREmbeddingBag``, a ``PrEmbeddingBag`` or a plain ``nn.EmbeddingBag``
    depending on the QR / MD flags and thresholds; the latter two are filled
    from U(-sqrt(1/n), sqrt(1/n)) via NumPy's global RNG.

    Returns ``(emb_l, v_W_l)``: an ``nn.ModuleList`` of tables and a parallel
    list holding ``None`` (when ``weighted_pooling`` is None) or a ones tensor
    of length ``n`` for each table.
    """
    tables = nn.ModuleList()
    row_weights = []
    for i in range(0, ln.size):
        if ext_dist.my_size > 1 and i not in self.local_emb_indices:
            continue
        n = ln[i]

        if self.qr_flag and n > self.qr_threshold:
            table = QREmbeddingBag(
                n,
                m,
                self.qr_collisions,
                operation=self.qr_operation,
                mode="sum",
                sparse=True,
            )
        elif self.md_flag and n > self.md_threshold:
            base = max(m)
            _m = m[i] if n > self.md_threshold else base
            table = PrEmbeddingBag(n, _m, base)
            # use np initialization as in the plain branch for consistency
            bound = np.sqrt(1 / n)
            init = np.random.uniform(low=-bound, high=bound, size=(n, _m)).astype(
                np.float32
            )
            table.embs.weight.data = torch.tensor(init, requires_grad=True)
        else:
            table = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
            bound = np.sqrt(1 / n)
            init = np.random.uniform(low=-bound, high=bound, size=(n, m)).astype(
                np.float32
            )
            table.weight.data = torch.tensor(init, requires_grad=True)

        row_weights.append(
            None if weighted_pooling is None else torch.ones(n, dtype=torch.float32)
        )
        tables.append(table)
    return tables, row_weights
def __init__(
    self,
    m_spa=None,
    ln_emb=None,
    ln_bot=None,
    ln_top=None,
    arch_interaction_op=None,
    arch_interaction_itself=False,
    sigmoid_bot=-1,
    sigmoid_top=-1,
    sync_dense_params=True,
    loss_threshold=0.0,
    ndevices=-1,
    qr_flag=False,
    qr_operation="mult",
    qr_collisions=0,
    qr_threshold=200,
    md_flag=False,
    md_threshold=200,
    weighted_pooling=None,
    loss_function="bce",
):
    """Configure the model and build its embedding tables, MLPs and loss.

    Nothing beyond ``nn.Module`` initialisation happens unless ``m_spa``,
    ``ln_emb``, ``ln_bot``, ``ln_top`` and ``arch_interaction_op`` are all
    provided. Embedding tables are only created here when ``ndevices <= 1``.
    Any ``weighted_pooling`` value other than ``None`` / "fixed" is stored as
    "learned". The "wbce" loss reads its class weights from the module-level
    ``args.loss_weights``. An unsupported ``loss_function`` terminates the
    process via ``sys.exit``.
    """
    super(DLRM_Net, self).__init__()

    mandatory = (m_spa, ln_emb, ln_bot, ln_top, arch_interaction_op)
    if any(value is None for value in mandatory):
        return

    # save arguments
    self.ndevices = ndevices
    self.output_d = 0
    self.parallel_model_batch_size = -1
    self.parallel_model_is_not_prepared = True
    self.arch_interaction_op = arch_interaction_op
    self.arch_interaction_itself = arch_interaction_itself
    self.sync_dense_params = sync_dense_params
    self.loss_threshold = loss_threshold
    self.loss_function = loss_function
    if weighted_pooling is None or weighted_pooling == "fixed":
        self.weighted_pooling = weighted_pooling
    else:
        self.weighted_pooling = "learned"

    # QR embedding settings (only stored when enabled)
    self.qr_flag = qr_flag
    if self.qr_flag:
        self.qr_collisions = qr_collisions
        self.qr_operation = qr_operation
        self.qr_threshold = qr_threshold

    # MD embedding settings (only stored when enabled)
    self.md_flag = md_flag
    if self.md_flag:
        self.md_threshold = md_threshold

    # If running distributed, get local slice of embedding tables
    if ext_dist.my_size > 1:
        n_emb = len(ln_emb)
        if n_emb < ext_dist.my_size:
            sys.exit(
                "only (%d) sparse features for (%d) devices, table partitions will fail"
                % (n_emb, ext_dist.my_size)
            )
        self.n_global_emb = n_emb
        self.n_local_emb, self.n_emb_per_rank = ext_dist.get_split_lengths(n_emb)
        self.local_emb_slice = ext_dist.get_my_slice(n_emb)
        self.local_emb_indices = list(range(n_emb))[self.local_emb_slice]

    # create operators
    if ndevices <= 1:
        self.emb_l, w_list = self.create_emb(m_spa, ln_emb, weighted_pooling)
        if self.weighted_pooling == "learned":
            self.v_W_l = nn.ParameterList([Parameter(w) for w in w_list])
        else:
            self.v_W_l = w_list
    self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
    self.top_l = self.create_mlp(ln_top, sigmoid_top)

    # quantization state (disabled until quantize_embedding is called)
    self.quantize_emb = False
    self.emb_l_q = []
    self.quantize_bits = 32

    # specify the loss function
    if self.loss_function == "mse":
        self.loss_fn = torch.nn.MSELoss(reduction="mean")
    elif self.loss_function == "bce":
        self.loss_fn = torch.nn.BCELoss(reduction="mean")
    elif self.loss_function == "wbce":
        self.loss_ws = torch.tensor(
            np.fromstring(args.loss_weights, dtype=float, sep="-")
        )
        self.loss_fn = torch.nn.BCELoss(reduction="none")
    else:
        sys.exit(
            "ERROR: --loss-function=" + self.loss_function + " is not supported"
        )
def apply_mlp(self, x, layers):
    """Apply ``layers`` (a callable container such as Sequential) to ``x``."""
    output = layers(x)
    return output
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
    """Look up and pool embeddings for every table, one batched call per table.

    ``lS_i[k]`` / ``lS_o[k]`` are the indices and offsets passed to table
    ``k``. When ``v_W_l[k]`` is not None, per-sample weights are gathered from
    it with the indices. With ``self.quantize_emb`` set, the packed tables in
    ``self.emb_l_q`` are used through the 4-bit or 8-bit quantized
    embedding-bag ops (selected by ``self.quantize_bits``) and the packed size
    is printed; otherwise ``emb_l[k]`` is called directly.

    Returns the list of pooled results, one entry per table.
    """
    pooled = []
    for k, indices in enumerate(lS_i):
        offsets = lS_o[k]
        row_weights = v_W_l[k]
        per_sample_weights = (
            None if row_weights is None else row_weights.gather(0, indices)
        )

        if not self.quantize_emb:
            table = emb_l[k]
            pooled.append(
                table(indices, offsets, per_sample_weights=per_sample_weights)
            )
            continue

        packed = self.emb_l_q[k]
        packed_bytes = packed.element_size() * packed.nelement()
        print("quantized emb sizes:", packed_bytes, packed_bytes)
        if self.quantize_bits == 4:
            QV = ops.quantized.embedding_bag_4bit_rowwise_offsets(
                packed,
                indices,
                offsets,
                per_sample_weights=per_sample_weights,
            )
        elif self.quantize_bits == 8:
            QV = ops.quantized.embedding_bag_byte_rowwise_offsets(
                packed,
                indices,
                offsets,
                per_sample_weights=per_sample_weights,
            )
        pooled.append(QV)
    return pooled
# using quantizing functions from caffe2/aten/src/ATen/native/quantized/cpu
def quantize_embedding(self, bits):
    """Replace the embedding tables with prepacked quantized weights.

    ``self.emb_l_q`` is reset to one slot per table and filled with the 4-bit
    or 8-bit prepacked form of each table's ``weight``. For any other ``bits``
    value the method returns as soon as the first table is reached, leaving
    ``self.emb_l`` and the quantization flags untouched. On success
    ``self.emb_l`` is dropped (set to None) and ``quantize_emb`` /
    ``quantize_bits`` are updated.
    """
    tables = self.emb_l
    self.emb_l_q = [None] * len(tables)
    for slot, table in enumerate(tables):
        if bits == 4:
            packed = ops.quantized.embedding_bag_4bit_prepack(table.weight)
        elif bits == 8:
            packed = ops.quantized.embedding_bag_byte_prepack(table.weight)
        else:
            return
        self.emb_l_q[slot] = packed
    self.emb_l = None
    self.quantize_emb = True
    self.quantize_bits = bits
def interact_features(self, x, ly):
    """Combine the dense vector ``x`` with the pooled embeddings ``ly``.

    * "dot": ``x`` and every entry of ``ly`` are stacked into a
      ``(batch, num_vectors, d)`` tensor, all pairwise dot products are
      computed with ``bmm``, and the strictly-lower-triangular entries (or
      lower-triangular including the diagonal when
      ``self.arch_interaction_itself`` is set) are appended to ``x``.
    * "cat": ``x`` and all entries of ``ly`` are concatenated along dim 1.

    Any other ``self.arch_interaction_op`` terminates the process via
    ``sys.exit``. Returns the resulting 2-D tensor.
    """
    if self.arch_interaction_op == "dot":
        # concatenate dense and sparse features
        (batch_size, d) = x.shape
        T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
        # perform a dot product between every pair of feature vectors
        Z = torch.bmm(T, torch.transpose(T, 1, 2))
        # Z is T @ T^T, so it is square in its last two dimensions.
        _, ni, nj = Z.shape
        offset = 1 if self.arch_interaction_itself else 0
        # (row, col) pairs of the lower triangle, row-major.
        pairs = [(i, j) for i in range(ni) for j in range(i + offset)]
        # Explicit long dtype: with no pairs (a single feature vector and no
        # self-interaction) torch.tensor([]) would default to float, which is
        # not a valid index tensor.
        li = torch.tensor([i for i, _ in pairs], dtype=torch.long)
        lj = torch.tensor([j for _, j in pairs], dtype=torch.long)
        Zflat = Z[:, li, lj]
        # concatenate dense features and interactions
        R = torch.cat([x] + [Zflat], dim=1)
    elif self.arch_interaction_op == "cat":
        # concatenation features (into a row vector)
        R = torch.cat([x] + ly, dim=1)
    else:
        sys.exit(
            "ERROR: --arch-interaction-op="
            + self.arch_interaction_op
            + " is not supported"
        )
    return R
def forward(self, dense_x, lS_o, lS_i):
    """Dispatch to the forward implementation matching the run configuration.

    ``ext_dist.my_size > 1`` selects ``distributed_forward``; otherwise
    ``self.ndevices <= 1`` selects ``sequential_forward`` and anything else
    selects ``parallel_forward``.
    """
    if ext_dist.my_size > 1:
        # multi-node multi-device run
        impl = self.distributed_forward
    elif self.ndevices <= 1:
        # single device run
        impl = self.sequential_forward
    else:
        # single-node multi-device run
        impl = self.parallel_forward
    return impl(dense_x, lS_o, lS_i)
def distributed_forward(self, dense_x, lS_o, lS_i):
    """Forward pass for a multi-rank run (``ext_dist.my_size > 1``).

    The dense input is sliced to this rank's share of the batch and the sparse
    inputs are sliced to this rank's embedding tables. Embedding results are
    exchanged with ``ext_dist.alltoall``; the bottom MLP runs between issuing
    that request and waiting on it. The output is clamped to
    ``[loss_threshold, 1 - loss_threshold]`` when the threshold lies strictly
    between 0 and 1.

    Terminates the process via ``sys.exit`` when the batch is smaller than, or
    not divisible by, the number of ranks, or when list lengths are
    inconsistent.
    """
    n_ranks = ext_dist.my_size
    batch_size = dense_x.size()[0]
    # WARNING: # of ranks must be <= batch size in distributed_forward call
    if batch_size < n_ranks:
        sys.exit(
            "ERROR: batch_size (%d) must be larger than number of ranks (%d)"
            % (batch_size, ext_dist.my_size)
        )
    if batch_size % n_ranks != 0:
        sys.exit(
            "ERROR: batch_size %d can not split across %d ranks evenly"
            % (batch_size, ext_dist.my_size)
        )

    dense_x = dense_x[ext_dist.get_my_slice(batch_size)]
    lS_o = lS_o[self.local_emb_slice]
    lS_i = lS_i[self.local_emb_slice]

    n_tables = len(self.emb_l)
    if not (n_tables == len(lS_o) == len(lS_i)):
        sys.exit(
            "ERROR: corrupted model input detected in distributed_forward call"
        )

    # embeddings
    with record_function("DLRM embedding forward"):
        ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)

    # At this point each rank holds lookups for the entire batch but only for
    # its own tables. The all-to-all below redistributes them so that each
    # rank gets all tables for its part of the batch, matching the layout of
    # the bottom-MLP output.
    if len(self.emb_l) != len(ly):
        sys.exit("ERROR: corrupted intermediate result in distributed_forward call")

    a2a_req = ext_dist.alltoall(ly, self.n_emb_per_rank)

    with record_function("DLRM bottom nlp forward"):
        x = self.apply_mlp(dense_x, self.bot_l)

    ly = list(a2a_req.wait())

    # interactions
    with record_function("DLRM interaction forward"):
        z = self.interact_features(x, ly)

    # top mlp
    with record_function("DLRM top nlp forward"):
        p = self.apply_mlp(z, self.top_l)

    # clamp output if needed
    threshold = self.loss_threshold
    if not (0.0 < threshold and threshold < 1.0):
        return p
    return torch.clamp(p, min=threshold, max=(1.0 - threshold))
def sequential_forward(self, dense_x, lS_o, lS_i):
    """Single-device forward pass.

    Runs the bottom MLP on the dense input, the embedding lookups on the
    sparse input, the feature interaction, and finally the top MLP. The
    result is clamped to ``[loss_threshold, 1 - loss_threshold]`` when the
    threshold lies strictly between 0 and 1; otherwise it is returned as is.
    """
    # dense features -> bottom mlp
    dense_out = self.apply_mlp(dense_x, self.bot_l)
    # sparse features -> embeddings (a list of pooled vectors)
    sparse_out = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
    # interact features (dense and sparse)
    interacted = self.interact_features(dense_out, sparse_out)
    # top mlp produces the prediction
    prediction = self.apply_mlp(interacted, self.top_l)

    # clamp output if needed
    threshold = self.loss_threshold
    if not (0.0 < threshold and threshold < 1.0):
        return prediction
    return torch.clamp(prediction, min=threshold, max=(1.0 - threshold))
def parallel_forward(self, dense_x, lS_o, lS_i):
    """Single-node multi-device forward pass.

    The MLPs are replicated across devices (data parallelism) and the
    embedding tables are placed round-robin on ``cuda:<k % ndevices>`` (model
    parallelism). The number of devices actually used is the minimum of
    ``self.ndevices``, the batch size and the number of tables. The model is
    redistributed whenever the batch size differs from the one it was last
    prepared for. The gathered output is clamped to
    ``[loss_threshold, 1 - loss_threshold]`` when the threshold lies strictly
    between 0 and 1.
    """
    ### prepare model (overwrite) ###
    batch_size = dense_x.size()[0]
    ndevices = min(self.ndevices, batch_size, len(self.emb_l))
    device_ids = range(ndevices)

    def device_of(table_idx):
        # Round-robin placement of table k on device k % ndevices.
        return torch.device("cuda:" + str(table_idx % ndevices))

    # The model must be redistributed if the mini-batch size changes (common
    # for the last mini-batch of a dataset).
    if self.parallel_model_batch_size != batch_size:
        self.parallel_model_is_not_prepared = True

    if self.parallel_model_is_not_prepared or self.sync_dense_params:
        # replicate mlp (data parallelism)
        self.bot_l_replicas = replicate(self.bot_l, device_ids)
        self.top_l_replicas = replicate(self.top_l, device_ids)
        self.parallel_model_batch_size = batch_size

    if self.parallel_model_is_not_prepared:
        # distribute embeddings (model parallelism)
        placed_tables = []
        placed_weights = []
        for k, emb in enumerate(self.emb_l):
            target = device_of(k)
            placed_tables.append(emb.to(target))
            if self.weighted_pooling == "learned":
                placed_weights.append(Parameter(self.v_W_l[k].to(target)))
            elif self.weighted_pooling == "fixed":
                placed_weights.append(self.v_W_l[k].to(target))
            else:
                placed_weights.append(None)
        self.emb_l = nn.ModuleList(placed_tables)
        if self.weighted_pooling == "learned":
            self.v_W_l = nn.ParameterList(placed_weights)
        else:
            self.v_W_l = placed_weights
        self.parallel_model_is_not_prepared = False

    ### prepare input (overwrite) ###
    # scatter dense features (data parallelism)
    dense_x = scatter(dense_x, device_ids, dim=0)

    # distribute sparse features (model parallelism)
    if not (len(self.emb_l) == len(lS_o) == len(lS_i)):
        sys.exit("ERROR: corrupted model input detected in parallel_forward call")

    offsets_per_table = []
    indices_per_table = []
    for k, _ in enumerate(self.emb_l):
        target = device_of(k)
        offsets_per_table.append(lS_o[k].to(target))
        indices_per_table.append(lS_i[k].to(target))
    lS_o = offsets_per_table
    lS_i = indices_per_table

    ### compute results in parallel ###
    # bottom mlp: replicas applied to the scattered dense inputs; the output
    # follows the distribution of dense_x.
    x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)

    # embeddings
    ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)

    # butterfly shuffle (implemented inefficiently for now)
    # Each table's result covers the entire batch on one device. Scatter
    # every result along the batch dimension so each device ends up with all
    # tables for its part of the batch, matching the bottom-MLP output.
    if len(self.emb_l) != len(ly):
        sys.exit("ERROR: corrupted intermediate result in parallel_forward call")

    scattered = [scatter(ly[k], device_ids, dim=0) for k, _ in enumerate(self.emb_l)]
    # regroup from per-table to per-device ordering
    ly = [list(per_device) for per_device in zip(*scattered)]

    # interactions
    z = [self.interact_features(x[k], ly[k]) for k in range(ndevices)]

    # top mlp: replicas applied to the per-device interaction results
    p = parallel_apply(self.top_l_replicas, z, None, device_ids)

    ### gather the distributed results ###
    p0 = gather(p, self.output_d, dim=0)

    # clamp output if needed
    threshold = self.loss_threshold
    if not (0.0 < threshold and threshold < 1.0):
        return p0
    return torch.clamp(p0, min=threshold, max=(1.0 - threshold))
def dash_separated_ints(value):
    """argparse ``type=`` validator for strings such as ``"4-3-2"``.

    Every dash-separated token must parse with ``int``. The original string
    is returned unchanged; ``argparse.ArgumentTypeError`` is raised otherwise.
    """
    try:
        for token in value.split("-"):
            int(token)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "%s is not a valid dash separated list of ints" % value
        )
    return value
def dash_separated_floats(value):
    """argparse ``type=`` validator for strings such as ``"1.0-0.5"``.

    Every dash-separated token must parse with ``float``. The original string
    is returned unchanged; ``argparse.ArgumentTypeError`` is raised otherwise.
    """
    try:
        for token in value.split("-"):
            float(token)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "%s is not a valid dash separated list of floats" % value
        )
    return value
# Evaluate `dlrm` over `test_ld` and report accuracy (and, with
# args.mlperf_logging, the sklearn metrics listed below).
#
# Returns (model_metrics_dict, is_best): a dict with run settings, the model
# state_dict and the test accuracy (plus "test_auc" when a new best AUC is
# reached under mlperf logging), and a flag telling whether this evaluation
# beat best_auc_test (mlperf logging) or best_acc_test (otherwise).
#
# NOTE(review): nbatches, nbatches_test, ndevices and writer are neither
# parameters nor locals here; presumably module-level globals set elsewhere.
# NOTE(review): leading indentation is missing from this copy of the file, so
# the extent of each `with record_function(...)` scope cannot be confirmed.
def inference(
args,
dlrm,
best_acc_test,
best_auc_test,
test_ld,
device,
use_gpu,
log_iter=-1,
):
# running totals for the non-mlperf accuracy computation
test_accu = 0
test_samp = 0
if args.mlperf_logging:
# per-batch predictions and targets, concatenated after the loop
scores = []
targets = []
for i, testBatch in enumerate(test_ld):
# early exit if nbatches was set by the user and was exceeded
if nbatches > 0 and i >= nbatches:
break
X_test, lS_o_test, lS_i_test, T_test, W_test, CBPP_test = unpack_batch(
testBatch
)
# Skip the batch if batch size not multiple of total ranks
if ext_dist.my_size > 1 and X_test.size(0) % ext_dist.my_size != 0:
print("Warning: Skiping the batch %d with size %d" % (i, X_test.size(0)))
continue
# forward pass
Z_test = dlrm_wrap(
X_test,
lS_o_test,
lS_i_test,
use_gpu,
device,
ndevices=ndevices,
)
### gather the distributed results on each rank ###
# For some reason it requires explicit sync before all_gather call if
# tensor is on GPU memory
if Z_test.is_cuda:
torch.cuda.synchronize()
(_, batch_split_lengths) = ext_dist.get_split_lengths(X_test.size(0))
if ext_dist.my_size > 1:
Z_test = ext_dist.all_gather(Z_test, batch_split_lengths)
if args.mlperf_logging:
S_test = Z_test.detach().cpu().numpy() # numpy array
T_test = T_test.detach().cpu().numpy() # numpy array
scores.append(S_test)
targets.append(T_test)
else:
with record_function("DLRM accuracy compute"):
# compute loss and accuracy
S_test = Z_test.detach().cpu().numpy() # numpy array
T_test = T_test.detach().cpu().numpy() # numpy array
mbs_test = T_test.shape[0] # = mini_batch_size except last
# a prediction counts as correct when it rounds to the target value
A_test = np.sum((np.round(S_test, 0) == T_test).astype(np.uint8))
test_accu += A_test
test_samp += mbs_test
if args.mlperf_logging:
with record_function("DLRM mlperf sklearn metrics compute"):
scores = np.concatenate(scores, axis=0)
targets = np.concatenate(targets, axis=0)
# metric name -> callable(y_true, y_score); the lambdas round the scores
# before handing them to sklearn as y_pred
metrics = {
"recall": lambda y_true, y_score: sklearn.metrics.recall_score(
y_true=y_true, y_pred=np.round(y_score)
),
"precision": lambda y_true, y_score: sklearn.metrics.precision_score(
y_true=y_true, y_pred=np.round(y_score)
),
"f1": lambda y_true, y_score: sklearn.metrics.f1_score(
y_true=y_true, y_pred=np.round(y_score)
),
"ap": sklearn.metrics.average_precision_score,
"roc_auc": sklearn.metrics.roc_auc_score,
"accuracy": lambda y_true, y_score: sklearn.metrics.accuracy_score(
y_true=y_true, y_pred=np.round(y_score)
),
}
validation_results = {}
for metric_name, metric_function in metrics.items():
validation_results[metric_name] = metric_function(targets, scores)
writer.add_scalar(
"mlperf-metrics-test/" + metric_name,
validation_results[metric_name],
log_iter,
)
acc_test = validation_results["accuracy"]
else:
# NOTE(review): test_samp stays 0 if no batch was processed (empty loader
# or every batch skipped), which would make this division fail.
acc_test = test_accu / test_samp
writer.add_scalar("Test/Acc", acc_test, log_iter)
model_metrics_dict = {
"nepochs": args.nepochs,
"nbatches": nbatches,
"nbatches_test": nbatches_test,
"state_dict": dlrm.state_dict(),
"test_acc": acc_test,
}
if args.mlperf_logging:
# under mlperf logging "best" is decided by ROC AUC, not accuracy
is_best = validation_results["roc_auc"] > best_auc_test
if is_best:
best_auc_test = validation_results["roc_auc"]
model_metrics_dict["test_auc"] = best_auc_test
print(
"recall {:.4f}, precision {:.4f},".format(
validation_results["recall"],
validation_results["precision"],
)
+ " f1 {:.4f}, ap {:.4f},".format(
validation_results["f1"], validation_results["ap"]
)
+ " auc {:.4f}, best auc {:.4f},".format(
validation_results["roc_auc"], best_auc_test
)
+ " accuracy {:3.3f} %, best accuracy {:3.3f} %".format(
validation_results["accuracy"] * 100, best_acc_test * 100
),
flush=True,
)
else:
is_best = acc_test > best_acc_test
if is_best:
best_acc_test = acc_test
print(
" accuracy {:3.3f} %, best {:3.3f} %".format(
acc_test * 100, best_acc_test * 100
),
flush=True,
)
return model_metrics_dict, is_best
def run():
### parse arguments ###
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
# model related parameters
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
parser.add_argument(
"--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
)
# j will be replaced with the table number
parser.add_argument("--arch-mlp-bot", type=dash_separated_ints, default="4-3-2")
parser.add_argument("--arch-mlp-top", type=dash_separated_ints, default="4-2-1")
parser.add_argument(
"--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
)
parser.add_argument("--arch-interaction-itself", action="store_true", default=False)
parser.add_argument("--weighted-pooling", type=str, default=None)
# embedding table options
parser.add_argument("--md-flag", action="store_true", default=False)
parser.add_argument("--md-threshold", type=int, default=200)
parser.add_argument("--md-temperature", type=float, default=0.3)
parser.add_argument("--md-round-dims", action="store_true", default=False)
parser.add_argument("--qr-flag", action="store_true", default=False)
parser.add_argument("--qr-threshold", type=int, default=200)
parser.add_argument("--qr-operation", type=str, default="mult")
parser.add_argument("--qr-collisions", type=int, default=4)
# activations and loss
parser.add_argument("--activation-function", type=str, default="relu")
parser.add_argument("--loss-function", type=str, default="mse") # or bce or wbce
parser.add_argument(
"--loss-weights", type=dash_separated_floats, default="1.0-1.0"
) # for wbce
parser.add_argument("--loss-threshold", type=float, default=0.0) # 1.0e-7
parser.add_argument("--round-targets", type=bool, default=False)
# data
parser.add_argument("--data-size", type=int, default=1)
parser.add_argument("--num-batches", type=int, default=0)
parser.add_argument(
"--data-generation",
type=str,
choices=["random", "dataset", "internal"],
default="random",
) # synthetic, dataset or internal
parser.add_argument(
"--rand-data-dist", type=str, default="uniform"
) # uniform or gaussian
parser.add_argument("--rand-data-min", type=float, default=0)
parser.add_argument("--rand-data-max", type=float, default=1)
parser.add_argument("--rand-data-mu", type=float, default=-1)
parser.add_argument("--rand-data-sigma", type=float, default=1)
parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log")
parser.add_argument("--data-set", type=str, default="kaggle") # or terabyte
parser.add_argument("--raw-data-file", type=str, default="")
parser.add_argument("--processed-data-file", type=str, default="")
parser.add_argument("--data-randomize", type=str, default="total") # or day or none
parser.add_argument("--data-trace-enable-padding", type=bool, default=False)
parser.add_argument("--max-ind-range", type=int, default=-1)
parser.add_argument("--data-sub-sample-rate", type=float, default=0.0) # in [0, 1]
parser.add_argument("--num-indices-per-lookup", type=int, default=10)
parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--memory-map", action="store_true", default=False)
# training
parser.add_argument("--mini-batch-size", type=int, default=1)
parser.add_argument("--nepochs", type=int, default=1)
parser.add_argument("--learning-rate", type=float, default=0.01)
parser.add_argument("--print-precision", type=int, default=5)
parser.add_argument("--numpy-rand-seed", type=int, default=123)
parser.add_argument("--sync-dense-params", type=bool, default=True)
parser.add_argument("--optimizer", type=str, default="sgd")
parser.add_argument(
"--dataset-multiprocessing",
action="store_true",
default=False,
help="The Kaggle dataset can be multiprocessed in an environment \
with more than 7 CPU cores and more than 20 GB of memory. \n \
The Terabyte dataset can be multiprocessed in an environment \
with more than 24 CPU cores and at least 1 TB of memory.",
)
# inference
parser.add_argument("--inference-only", action="store_true", default=False)
# quantize
parser.add_argument("--quantize-mlp-with-bit", type=int, default=32)
parser.add_argument("--quantize-emb-with-bit", type=int, default=32)
# onnx
parser.add_argument("--save-onnx", action="store_true", default=False)
# gpu
parser.add_argument("--use-gpu", action="store_true", default=False)
# distributed
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dist-backend", type=str, default="")
# debugging and profiling
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
parser.add_argument("--test-mini-batch-size", type=int, default=-1)
parser.add_argument("--test-num-workers", type=int, default=-1)
parser.add_argument("--print-time", action="store_true", default=False)
parser.add_argument("--print-wall-time", action="store_true", default=False)