#pragma once
#ifdef __clang__
# pragma clang diagnostic push
# pragma GCC diagnostic ignored "-Wgnu-anonymous-struct" // k.h warning
# pragma GCC diagnostic ignored "-Wnested-anon-types" // k.h warning
# pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" // ATen.h VA_ARG warning, FORWARD_HAS_DEFAULT_ARGS
# pragma clang diagnostic ignored "-Wunused-function" // private.h generates 'unused function' warnings
# pragma clang diagnostic ignored "-Wc++1z-extensions" // nodiscard & fallthrough warnings
#elif defined __GNUC__
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wpedantic"
# pragma GCC diagnostic ignored "-Wunused-function"
#endif
#define KXVER 3
#include "k.h"
#undef P
#undef R
#undef U
#undef Z
#undef xs
#define KSHORT 5
#undef KH // conflict introduced with v1.10.0, include/ATen/ops/avg_pool2d_meta.h: template <bool KH..
#include <stack>
#include "torch/torch.h"
#include "private.h"
// access private name_ & buffers_ of Module
using Tensor = torch::Tensor;
using TensorDict = torch::OrderedDict<std::string, Tensor>;
using Module = torch::nn::Module;
using Moduleptr = std::shared_ptr<Module>;
using Modulemap = torch::OrderedDict<std::string, Moduleptr>;
using Generator = torch::Generator;
ACCESS_PRIVATE_FIELD(Module, c10::optional<std::string>, name_)
ACCESS_PRIVATE_FIELD(Module, TensorDict, buffers_)
ACCESS_PRIVATE_FIELD(Module, TensorDict, parameters_)
ACCESS_PRIVATE_FIELD(Module, Modulemap, children_)
#ifdef __clang__
# pragma clang diagnostic pop
#elif defined __GNUC__
# pragma GCC diagnostic pop
#endif
#define TORCH_ERROR(...) \
do { \
C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
} while (false)
#define KFN(f) reinterpret_cast<void *>(f)
#define KERR(e) krr((S)e)
#define KDICT xD(ktn(KS,0),ktn(0,0))
#define KTRY \
try {
#define KCATCH(x) \
} catch (const c10::Error &e) { \
return KERR(krrbuf(env().frame ? e.what() : e.what_without_backtrace())); \
} catch (const std::exception &e) { \
return KERR(krrbuf(e.what())); \
} catch (...) { \
return KERR(x); \
}
#ifdef __cplusplus
# define KEXT extern "C"
#else
# define KEXT
#endif
#ifdef _WIN32
# define KAPI KEXT __declspec(dllexport) K
#else
# define KAPI KEXT K
#endif
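// hypothetical usage sketch (not defined in this header): a k-callable entry
// point built from KAPI, with torch exceptions guarded by KTRY/KCATCH;
// kput/kget are declared later in this header
//  KAPI ksum(K x) {
//   KTRY
//    return kget(kput(x).sum());  // k array -> tensor -> sum -> k value
//   KCATCH("sum");
//  }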
#define Ksize torch::SmallVector<int64_t,8>
#define cs(x) ss((S)x)
#define KEX(x) kexpand(x.size(),(*x).data()) // k list from ExpandingArray
#define ESYM(v) std::visit(esym(), v)
using Ktype=signed char;
using Device=torch::Device;
using DeviceType=torch::DeviceType;
using Storage=torch::Storage;
using Scalar=torch::Scalar;
using TensorVector=std::vector<Tensor>;
using TensorDeque=std::deque<Tensor>;
using LongVector=std::vector<int64_t>;
using DoubleVector=std::vector<double>;
using IntArrayRef=torch::IntArrayRef;
using SymArrayRef=torch::ArrayRef<S>;
using DoubleArrayRef=torch::ArrayRef<double>;
template<size_t D,typename T=int64_t> using ExpandingArray=torch::ExpandingArray<D,T>;
template<size_t D,typename T=double> using Exdouble=torch::ExpandingArray<D,T>;
template<size_t D,typename T=int64_t> using Exoptional=torch::ExpandingArrayWithOptionalElem<D,T>;
using Dtype=torch::Dtype;
using TypeMeta=caffe2::TypeMeta;
using ScalarType=torch::ScalarType;
using TensorOptions=torch::TensorOptions;
using TensorList=torch::TensorList; // ArrayRef<Tensor>
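// sketch of the expanding-array helpers above (assumes kexpand, declared
// later, collapses a uniform list to a k scalar):
//  ExpandingArray<2> k(3);  // single scalar expands to {3,3}, e.g. kernel size
//  K r=KEX(k);              // k scalar/list via kexpand(size,data)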
// shorter names for commonly used module structures
using Modules = std::stack<Moduleptr>;
using Modulepairs = std::vector<std::pair<std::string, std::shared_ptr<Module>>>;
using AnyModule = torch::nn::AnyModule;
struct Empty{};
using Tuple = std::tuple<Tensor,Tensor>;
using Tuple3 = std::tuple<Tensor,Tensor,Tensor>;
using Tuple4 = std::tuple<Tensor,Tensor,Tensor,Tensor>;
using Nested = std::tuple<Tensor,Tuple>;
using Input = std::variant<Tensor,TensorVector,TensorDict,Empty>;
using Output = std::variant<Tensor,Tuple,Nested,TensorVector>;
using MetricData = std::vector<TensorVector>;
using Optimizer = torch::optim::Optimizer;
using Optptr = std::shared_ptr<Optimizer>;
typedef struct Pairs {
Ktype a = 0; // type: 1-dict, 2-list of pairs, 3-general list, 4-sym list
Ktype t = 0; // type of value in last pair processed
H i = 0; // next pair to process
H n = 0; // count of pairs
S k = 0; // name of current name,value pair
K x = 0; // k value with dict/pairs/list
union {
bool b; // boolean value from current pair
J j; // long value
float e; // float value
double f; // double value
S s; // symbol value
K v; // value (isn't sym or numeric scalar)
};
} Pairs;
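// processing sketch (hypothetical): xpairs() fills a Pairs struct from
// trailing name-value args, xpair() advances through each pair and
// psym/plong/pdouble extract typed values (see declarations below)
//  Pairs p; J n=xargc(x,i,p);  // n positional args, p holds name-value pairs
//  while(xpair(p))
//   switch(/* map psym(p) to a Setting */) {
//    case Setting::lr: lr=pdouble(p); break;
//    ..
//   }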
enum class Class:short {
undefined=0,
tensor,
vector,
dict,
module,
loss,
optimizer,
model,
train,
test
};
enum class Arg:short { // type of input(s) & output for callback modules
undefined=0,
boolean,
tensor,
tuple,
nested,
vector,
dict
};
enum class Cast:short {
undefined=0,
tensor, parameter, buffer, model, // basic structures
callback, moduledict, modulelist, parmdict, // container modules
sequential, seqdict, seqlist, seqnest, seqjoin,
adaptavg1d, adaptavg2d, adaptavg3d, adaptmax1d, adaptmax2d, // modules
adaptmax3d, adrop, attention, avgpool1d, avgpool2d,
avgpool3d, batchnorm1d, batchnorm2d, batchnorm3d, bilinear,
cat, celu, conv1d, conv2d, conv3d,
convtranspose1d, convtranspose2d, convtranspose3d, crossmap2d, decoder,
decoderlayer, drop, drop2d, drop3d, droppath,
elu, embed, embedbag, embedpos, embedseq,
encoder, encoderlayer, expand, fadrop, flatten,
fmaxpool2d, fmaxpool3d, fold, fork, gelu,
glu, groupnorm, gru, hardshrink, hardtanh,
identity, indexselect, instancenorm1d, instancenorm2d, instancenorm3d,
interpolate, layernorm, leakyrelu, linear, localnorm,
logsigmoid, logsoftmax, lppool1d, lppool2d, lstm,
matmul, maxpool1d, maxpool2d, maxpool3d, mish,
mul, nbeats, normalize, onehot, pad,
pad1d, pad2d, pad3d, permute, prelu,
randomcrop, randomflip, recur, reflect1d, reflect2d,
relu, relu6, replicate1d, replicate2d, replicate3d,
reshape, residual, rnn, rrelu, select,
selfattention, selu, sigmoid, silu, softmax,
softmax2d, softmin, softplus, softshrink, softsign,
squeeze, tanh, tanhshrink, threshold, transform,
transformer, transpose, unfold, unsqueeze, upsample,
zeropad2d, zscore,
pairwise, similar, // distance functions
bce, bcelogits, ce, cosineloss, ctc, hinge, // loss fns
huber, kl, l1, margin, mse, multilabel,
multimargin, multisoft, nll, poissonloss, smoothl1, softmargin, triplet,
adagrad, adam, adamw, lamb, lbfgs, rmsprop, sgd // optimizers
};
using Args=std::vector<Arg>; // module's forward arg type(s)
using Attrs=std::tuple<S, // 0 module symbol
Cast, // 1 module enumeration
size_t, // 2 typeid hash
const char*, // 3 description
bool, // 4 true if has non-templatized forward
Arg, // 5 result type
size_t, // 6 min number of arguments
size_t, // 7 max number of arguments
Args>; // 8 argument type(s)
using AttrRef=torch::ArrayRef<Attrs>; // array reference (for modules/losses)
using ModuleAttrs=std::array<Attrs,128>; // global list of module attributes
using LossAttrs=std::array<Attrs,21>; // global list of loss modules
enum class Tensormode:char { // tensor creation modes
undefined,
arange, complex, empty, eye, full, linspace, logspace,
ones, rand, randint, randn, randperm, range, sparse,
zeros
};
enum class Setting:uint8_t {
undefined,
addbias, addzero, affine, align, alloptions,
alpha, amsgrad, batchfirst, batchsize, benchmark,
beta, beta1, beta2, bi, bias,
blank, buffers, ceiling, centered, changetol,
channels, classes, clipgroup, clipnorm, clipvalue,
cols, complexfirst, countpad, cuda, cudadevices,
cudnn, cudnndeterministic, cudnnversion, dampening, decay,
decoder, decoderlayer, delta, detach, deterministic,
dictionary, dilate, dim, dim0, dim1,
divisor, dlayers, droplast, dropout, dtype,
elayers, encoder, encoderlayer, end, eps,
eval, fn, freeze, full, globalnorm,
gradtol, groups, heads, hidden, history,
ignore, in, in1, in2, ind,
indices, init, inplace, interopthreads, iter,
k, kdim, keepdim, kvbias, kvzeros,
lambda, lastoffset, layernorm, layers, length,
log, lower, lr, lrdecay, magma,
margin, max, maxnorm, mean, metrics,
min, mkl, mode, mps, momentum, nesterov,
norm, openmp, out, outpad, outsize,
p, pad, padflag, padindex, padmode,
parms, ratio, reduce, rescale, rows,
scale, search, shape, shuffle, shufflecuda, shuffleseed, size,
slope, smoothing, sparse, stackframe, start,
std, stride, swap, sync, task,
tasks, tensor, threads, threshold, track,
train, transpose, trustclip, trustmax, trustmin,
unbiased, upper, value, vdim, weight,
zeroinf
};
enum class State:char {
buffers, depth, loss, module, name, options, optimizer,
parms, parmgroup, pointer, size, train, test
};
enum class Attr:char {
undefined = 0,
ktype, // char
bytes, densedim, dim, elements, itemsize, nnz, numel, offset, // long scalars
ptr, ref, sparsedim, sptr, sref, tensorcount, weakref,
device, dtype, gradfn, gradient, inputmodule, outputmodule, // symbol
layout, memory, result,
coalesced, contiguous, contiguous2d, contiguous3d, defined, // boolean
gradflag, leaf, pinned, sparseflag,
size, stride, // long list
data, storage // other: list,dict,..
};
enum class Metric:char {
batchloss, loss, accuracy, matches, predict, output, hidden, hiddencell
};
using Metrics = std::vector<Metric>;
enum class Help:char {
undefined=0, backward, device, dtype, ktype
};
enum class Enum { // enums to match pytorch variants
undefined=-1,
area, batchmean, bicubic, bilinear, border,
circular, constant, conv1d, conv2d, conv3d,
convtranspose1d, convtranspose2d, convtranspose3d, fanin, fanout,
gelu, leakyrelu, linear, max, mean,
mish, nearest, nearestexact, none, reflect,
reflection, relu, replicate, same, sigmoid,
silu, sum, tanh, trilinear, valid,
zeros
};
struct TORCH_API Kmodule;
struct TORCH_API Ktag {
Class a = Class::undefined;
Cast c = Cast::undefined;
virtual ~Ktag() = default;
virtual void set(const Tensor& t) {TORCH_ERROR("unable to set tensor");}
virtual void set(const TensorVector& v) {TORCH_ERROR("unable to set dict");}
virtual void set(const TensorDict& d) {TORCH_ERROR("unable to set dictionary");}
virtual Tensor& tensor() {TORCH_ERROR("unable to retrieve tensor");}
virtual const Tensor& tensor() const {TORCH_ERROR("unable to retrieve tensor");}
virtual TensorVector& vector() {TORCH_ERROR("unable to retrieve vector");}
virtual const TensorVector& vector() const {TORCH_ERROR("unable to retrieve vector");}
virtual TensorDict& dict() {TORCH_ERROR("unable to retrieve dictionary");}
virtual const TensorDict& dict() const {TORCH_ERROR("unable to retrieve dictionary");}
virtual Kmodule* kmodule() {TORCH_ERROR("unable to retrieve module");}
virtual Module& module() {TORCH_ERROR("unable to retrieve module");}
virtual const Module& module() const {TORCH_ERROR("unable to retrieve module");}
virtual Moduleptr& moduleptr() {TORCH_ERROR("unable to retrieve module pointer");}
virtual const Moduleptr& moduleptr() const {TORCH_ERROR("unable to retrieve module pointer");}
virtual Optimizer& opt() {TORCH_ERROR("unable to retrieve optimizer");}
virtual const Optimizer& opt() const {TORCH_ERROR("unable to retrieve optimizer");}
virtual Optptr& optptr() {TORCH_ERROR("unable to retrieve optimizer pointer");}
virtual const Optptr& optptr() const {TORCH_ERROR("unable to retrieve optimizer pointer");}
};
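// retrieval sketch (illustrative): k-api objects come back from k as tagged
// pointers; xtag(), declared below, recovers the Ktag, whose virtual
// accessors throw via TORCH_ERROR unless overridden by the matching subclass
//  if(Ktag *g=xtag(x))
//   switch(g->a) {
//    case Class::tensor: {Tensor& t=g->tensor(); break;} // Kten overrides tensor()
//    case Class::vector: {auto& v=g->vector();   break;} // Kvec overrides vector()
//    default: TORCH_ERROR("unexpected class");
//   }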
struct TORCH_API Kten : public Ktag {
Tensor t;
Kten(const Tensor& x) : t(std::move(x)) {a=Class::tensor; c=Cast::tensor;}
Tensor& tensor() {return t;}
const Tensor& tensor() const {return t;}
void set(const Tensor& x) {t=std::move(x);}
};
struct TORCH_API Kvec : public Ktag {
TensorVector v;
Kvec(const TensorVector& x) : v(std::move(x)) {a=Class::vector; c=Cast::tensor;}
TensorVector& vector() {return v;}
const TensorVector& vector() const {return v;}
void set(const TensorVector& x) {v=std::move(x);}
};
struct TORCH_API Kdict : public Ktag {
TensorDict d;
Kdict(const TensorDict& x,Cast y=Cast::tensor) : d(std::move(x)) {a=Class::dict; c=y;}
TensorDict& dict() {return d;}
const TensorDict& dict() const {return d;}
void set(const TensorDict& x) {d=std::move(x);}
};
struct TORCH_API TrainOptions {
using Doubles = std::array<double,2>;
TORCH_ARG(int64_t, batchsize) = 32;
TORCH_ARG(int64_t, task) = 0;
TORCH_ARG(int64_t, tasks) = 1;
TORCH_ARG(int64_t, shuffleseed) = 0;
TORCH_ARG(bool, droplast) = false;
TORCH_ARG(bool, hidden) = false;
TORCH_ARG(bool, shuffle) = false;
TORCH_ARG(bool, shufflecuda) = false;
TORCH_ARG(bool, tensor) = false;
TORCH_ARG(bool, dictionary) = false;
TORCH_ARG(bool, sync) = false;
TORCH_ARG(bool, clipgroup) = false;
TORCH_ARG(c10::optional<Doubles>, clipnorm);
TORCH_ARG(c10::optional<double>, clipvalue);
TORCH_ARG(Metrics, metrics) = {Metric::loss};
};
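// TORCH_ARG(T,name) generates a private member name_ along with a chainable
// setter & a getter, e.g. (sketch):
//  TrainOptions o; o.batchsize(64).shuffle(true).droplast(true);
//  int64_t w=o.batchsize();  // -> 64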
struct TORCH_API TestOptions {
TORCH_ARG(int64_t, batchsize) = 100;
TORCH_ARG(int64_t, task) = 0;
TORCH_ARG(int64_t, tasks) = 1;
TORCH_ARG(bool, droplast) = false;
TORCH_ARG(bool, hidden) = false;
TORCH_ARG(bool, tensor) = false;
TORCH_ARG(bool, dictionary) = false;
TORCH_ARG(Metrics, metrics) = {Metric::loss};
};
struct TORCH_API ForwardOptions {
TORCH_ARG(Cast, in) = Cast::undefined; // type of module accepting input
TORCH_ARG(Cast, out) = Cast::undefined; // type of module returning output
TORCH_ARG(bool, f) = false; // true if non-templated forward exists
TORCH_ARG(Arg, r) = Arg::undefined; // type of result
TORCH_ARG(size_t, n) = 0; // number of required arguments
TORCH_ARG(size_t, m) = 0; // maximum number of arguments
TORCH_ARG(Args, a); // vector of argument type(s)
};
void forwardoptions(Cast,ForwardOptions&,const Module&); // initializes options during module construction
struct TORCH_API Kmodule : public Ktag {
Kmodule(Class x,Cast y,const Moduleptr& p) : m(std::move(p)) {a=x; c=y; forwardoptions(c,f,*m);}
Kmodule* kmodule() {return this;}
Module& module() {return *m;}
const Module& module() const {return *m;}
Moduleptr& moduleptr() {return m;}
const Moduleptr& moduleptr() const {return m;}
Moduleptr m; // generic module pointer, with specific run-time type
ForwardOptions f; // options describing forward calculation
c10::optional<Device> d; // initialized if m.to() called or if forward call uses k array
};
struct TORCH_API Kopt : public Ktag {
Optptr o; // shared ptr with optimizer
Moduleptr m; // single module or container holding all modules/tensors managed by optimizer
Kopt(Cast x,const Optptr& y,const Moduleptr& m) : o(std::move(y)),m(std::move(m)) {a=Class::optimizer; c=x;}
Optimizer& opt() {return *o;}
const Optimizer& opt() const {return *o;}
Optptr& optptr() {return o;}
const Optptr& optptr() const {return o;}
Module& module() {return *m;}
const Module& module() const {return *m;}
Moduleptr& moduleptr() {return m;}
const Moduleptr& moduleptr() const {return m;}
};
struct TORCH_API Data {
TORCH_ARG(int64_t, size) = -1; // size of tensors (along batching dimension)
TORCH_ARG(int64_t, batchsize) = -1; // size of batches
TORCH_ARG(int64_t, batch) = -1; // current batch
TORCH_ARG(int64_t, batches) = -1; // overall number of batches
public:
Input x = Empty(); // model input(s)
Input y = Empty(); // model target(s)
Output z; // model output for latest batch
Tensor l; // tensor loss for latest batch (if required)
Tensor p; // permutation index if shuffled
Generator g; // generator (used for permutation index across tasks)
MetricData m; // metrics stored in vector of vectors for each batch
};
struct TORCH_API Kmodel : public Ktag {
Kmodel(Kmodule *x,Kmodule *y,Kopt *z) : q(*x), l(*y), o(*z) {
a=Class::model; c=Cast::model;
}
Kmodule* kmodule() {return &q;}
Kmodule* kloss() {return &l;}
Kopt* kopt() {return &o;}
Module& module() {return *q.m;}
const Module& module() const {return *q.m;}
Moduleptr& moduleptr() {return q.m;}
const Moduleptr& moduleptr() const {return q.m;}
Optimizer& opt() {return *o.o;}
const Optimizer& opt() const {return *o.o;}
Optptr& optptr() {return o.o;}
const Optptr& optptr() const {return o.o;}
Kmodule q;
Kmodule l;
Kopt o;
TrainOptions train;
TestOptions test;
Data data;
Data testdata;
};
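// composition sketch (hypothetical): a model bundles module, loss & optimizer
// wrappers previously created via xmodule/xloss/xoptim (declared below)
//  Kmodel m(q,l,o);        // copies the three k-api wrappers
//  m.train.batchsize(32);  // per-model training options
//  m.data.x=Tensor();      // inputs/targets attached for batching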
S krrbuf(const char*);
void dictadd(K,S,K);
void dictadd(K,const char*,K);
bool xind(K,J);
bool xind(K,J,Ktype);
K kptr(void*);
bool ptrtype(K);
bool ptrflag(K);
bool mapped(K);
bool xptr(K);
bool xptr(K,J);
Ktag* xtag(K);
Ktag* xtag(K,J);
bool null(const char*);
bool null(const J);
bool match(const Scalar&,const Scalar&);
K kscalar(const Scalar&);
K resolvedict(K);
K resolve(K);
J xlen(K);
J xlen(K,J);
const char* kname(Ktype);
const char* kname(K);
const char* kname(K,J);
J ksizeof(Ktype);
Ktype maptype(TypeMeta);
TypeMeta maptype(Ktype);
S mapclass(Class);
S mapattr(Attr);
void print_tensor(std::ostream&,int64_t,const Tensor& t);
Enum emap(S);
S emap(Enum);
S inputname(const Input&);
S outputname(const Output&);
S statekey(State);
J statefind(State,K,bool r=false);
S statesym(State,bool,K,J j=-1);
K statedict(State,K,J j=-1);
K statetable(State,K);
J statedepth(K x,J j=-1);
S statemodule(K x,J j=-1);
S statename(K x,J j=-1);
K stateoptions(K x,J j=-1);
K stateparms(K x,J j=-1);
K statebuffers(K x,J j=-1);
J stategroup(K x,J j=-1);
K statesize(K x,J j=-1);
K statecol(State,K,short t=nh);
void stateparms(S,Module&,K,bool);
S nullsym();
K knull();
bool nullsym(S);
bool nullsym(K);
bool xnull(K);
bool xnull(K,J);
bool xempty(K);
bool xempty(K,J);
bool xarray(K,J);
bool xsym(K);
bool xsym(K,J);
bool xsym(K,S&);
bool xsym(K,J,S&);
bool xsyms(K,S&);
bool xsyms(K,SymArrayRef&);
bool xsyms(K,J,SymArrayRef&);
bool xdev(K,Device&);
bool xdev(K,J,Device&);
bool xint64(K,int64_t&);
bool xint64(K,J,int64_t&);
bool xint64(K,c10::optional<int64_t>&);
bool xint64(K,J,c10::optional<int64_t>&);
bool xlong(K,J&);
bool xlong(K,J,J&);
bool xlong(K,J&,J*&);
bool xlong(K,J,J&,J*&);
bool xdouble(K,double&);
bool xdouble(K,J,double&);
bool xdouble(K,J&,double *&);
bool xdouble(K,J,J&,double *&);
bool xdict(K);
bool xdict(K,J);
bool xstate(K);
bool xstate(K,J);
bool xsize(K,IntArrayRef&);
bool xsize(K,J,IntArrayRef&);
bool xsize(K,J,int64_t*);
bool xsize(K,J,double*);
bool xsize(K,J,J,int64_t*);
bool xsize(K,J,J,double*);
bool xten(K,Tensor&);
bool xten(K,J,Tensor&);
Tensor* xten(K);
Tensor* xten(K,J);
Tensor* xout(K);
bool xtenarg(K,J,Tensor&,Tensor&);
bool xtenarg(K,J,Tensor&,Tensor&,Tensor&);
bool xtenarg(K,Tensor&,Tensor&);
bool xtenarg(K,Tensor&,Tensor&,Tensor&);
TensorVector xtensors(K x,bool& p,const char* c);
Kmodule* xmodule(K);
Kmodule* xmodule(K,J);
Kmodule* xloss(K);
Kmodule* xloss(K,J);
Kopt* xoptim(K);
Kopt* xoptim(K,J);
Kmodel* xmodel(K);
Kmodel* xmodel(K,J);
bool xparm(K,Cast,S&,Tensor&);
TensorVector* xvec(K);
TensorVector* xvec(K,J);
TensorDict* xtensordict(K);
TensorDict* xtensordict(K,J);
bool xnum(K,double&);
bool xnum(K,J,double&);
bool xnum(K,Scalar&);
bool xnum(K,J,Scalar&);
bool xnumn(K,c10::optional<Scalar>&);
bool xnumn(K,J,c10::optional<Scalar>&);
bool xnumt(K,Scalar&);
bool xnumt(K,J,Scalar&);
bool xnumlist(K,J,Scalar&);
bool xbyte(K,Scalar&);
bool xbyte(K,J,Scalar&);
bool xscalar(K,Scalar&);
bool xscalar(K,J,Scalar&);
bool xbool(K,bool&);
bool xbool(K,J,bool&);
TypeMeta mtype(S);
S mtype(TypeMeta);
Dtype stype(S);
S stype(Dtype);
S stype(c10::optional<Dtype>);
bool xtype(K,Dtype&);
bool xtype(K,J,Dtype&);
bool xtype(K,c10::optional<Dtype>&);
bool xtype(K,J,c10::optional<Dtype>&);
bool xtype(K,TypeMeta&);
bool xtype(K,J,TypeMeta&);
bool xopt(S,TensorOptions&);
bool xopt(K,TensorOptions&);
bool xopt(K,J,TensorOptions&);
bool xmode(K,S&,Tensormode&);
bool xmode(K,J,S&,Tensormode&);
S modesym(Tensormode&);
bool xbacksym(K,bool&,bool&);
bool xbacksym(K,J,bool&,bool&);
bool xpairs(K,Pairs&);
bool xpairs(K,J,Pairs&);
bool xpair(Pairs&);
J xargc(K,J,Pairs&);
bool xnone(K,J);
S psym(const Pairs&);
Dtype ptype(const Pairs&);
void perr(const Pairs&,const char*);
bool pempty(const Pairs&);
bool pbool(const Pairs&);
J plong(const Pairs&);
double pdouble(const Pairs&);
void pnum(const Pairs&,Scalar&);
void psize(const Pairs&,IntArrayRef&,J n=-1);
void psize(const Pairs&,J,int64_t*);
void psize(const Pairs&,J,double*);
void pdoubles(const Pairs&,DoubleArrayRef&,J n=-1);
void pten(const Pairs&,Tensor&);
S& optdev(const Device&);
S& optdtype(const TypeMeta&);
S& optdtype(ScalarType);
S& optlayout(const torch::Layout&);
S& optmemory(const c10::optional<torch::MemoryFormat>&);
torch::MemoryFormat optmemory(S);
S& optgrad(const bool&);
S& optpin(const bool&);
K optkey();
K optval(const Tensor &t,K x,J i=-1);
K optval(const TensorOptions &o,K x,J i=-1);
K optmap(const Tensor&);
K optmap(const TensorOptions&);
S argname(Arg);
Arg argtype(S s,const char *c=nullptr);
K arglist(const Args&);
std::string kstring(K);
std::string kstring(K,J);
K kshow(K);
K kcast(Ktype,K);
K kbool(K);
J kfind(K,const std::string&);
K klist(J,const int64_t*);
K klist(J,const double*);
K klist(J,const c10::optional<int64_t>*);
K kexpand(J,const int64_t*);
K kexpand(J,const double*);
K kexpand(J,const c10::optional<int64_t>*e);
J xdv(K);
J xdv(K,J);
J dvd(K,J);
K dvv(K,J);
c10::optional<Device> firstdevice(const Tensor&);
c10::optional<Device> firstdevice(const TensorVector&);
c10::optional<Device> firstdevice(const TensorDict&);
c10::optional<Device> firstdevice(const Input&);
void sync(int64_t);
void sync(const Device&);
Device defaultdevice(const c10::optional<Device>);
S objdevice(const Tensor&);
S objdevice(const TensorVector&,S);
J objnum(int64_t);
J objnum(double);
J objnum(const Tensor&);
J objnum(const TensorVector&);
J objnum(const c10::optional<TensorVector>&);
J objnum(const TensorDeque&);
J objnum(const Module&);
J objnum(Cast,const Optimizer&);
J objbytes(int64_t);
J objbytes(double);
J objbytes(const Tensor&);
J objbytes(const TensorVector&);
J objbytes(const c10::optional<TensorVector>&);
J objbytes(const TensorDeque&);
J objbytes(const Module&);
J objbytes(Cast,const Optimizer&);
bool kfree(K);
bool kfree(K,J);
bool xfree(K);
void kfree(const std::vector<K>&);
void fn(K,const char*,void*,I);
void randomfn(K);
void mathfn(K);
K attr(K,Ktype,Attr);
S tensortype(Cast);
void castsym(S,Class&,Cast&);
Cast castsym(S);
// tensor & vector routines:
K kget(const Tensor&);
K kget(const LongVector&);
K kget(const DoubleVector&);
K kget(const TensorVector& v,K x=nullptr);
K kget(const TensorDict& d,K x=nullptr);
K kget(const TensorDeque&);
K kget(const Tuple&);
K kget(const Nested&);
K kget(const Input&);
K kget(const Output&);
K kin(const Input&);
K kout(const Output&);
bool broadcast(const Tensor&,const Tensor&);
Tensor kput(K);
Tensor kput(K,J);
TensorDict kputd(K);
TensorVector vec(K,bool b=false);
K kten(const Tensor&);
K kvec(const TensorVector&);
K kdict(const TensorDict&,Cast c=Cast::tensor);
inline K kresult(bool p,const Tensor& t) {return p ? kten(t) : kget(t);}
inline K kresult(bool p,const TensorVector& v) {return p ? kvec(v) : kget(v);}
inline K kresult(bool p,const Tuple& t) {return kresult(p, TensorVector{std::get<0>(t),std::get<1>(t)});}
inline K kresult(bool p,const Tuple3& t) {return kresult(p, TensorVector{std::get<0>(t),std::get<1>(t),std::get<2>(t)});}
inline K kresult(bool p,const Tuple4& t) {return kresult(p, TensorVector{std::get<0>(t),std::get<1>(t),std::get<2>(t),std::get<3>(t)});}
inline K koutput(TensorVector& v,const Tuple& t) {
using std::get;
switch(v.size()) {
case 0: v.push_back(get<0>(t)); v.push_back(get<1>(t)); break;
case 1: v[0]=get<0>(t); v.push_back(get<1>(t)); break;
default: v[0]=get<0>(t); v[1]=get<1>(t); break;
}
return (K)0;
}
inline K koutput(TensorVector& v,const Tuple3& t) {
using std::get;
switch(v.size()) {
case 0: v.push_back(get<0>(t)); v.push_back(get<1>(t)); v.push_back(get<2>(t)); break;
case 1: v[0]=get<0>(t); v.push_back(get<1>(t)); v.push_back(get<2>(t)); break;
case 2: v[0]=get<0>(t); v[1]=get<1>(t); v.push_back(get<2>(t)); break;
default: v[0]=get<0>(t); v[1]=get<1>(t); v[2]=get<2>(t); break;
}
return (K)0;
}
inline K koutput(TensorVector& v,const Tuple4& t) {
using std::get;
switch(v.size()) {
case 0: v.push_back(get<0>(t)); v.push_back(get<1>(t)); v.push_back(get<2>(t)); v.push_back(get<3>(t)); break;
case 1: v[0]=get<0>(t); v.push_back(get<1>(t)); v.push_back(get<2>(t)); v.push_back(get<3>(t)); break;
case 2: v[0]=get<0>(t); v[1]=get<1>(t); v.push_back(get<2>(t)); v.push_back(get<3>(t)); break;
case 3: v[0]=get<0>(t); v[1]=get<1>(t); v[2]=get<2>(t); v.push_back(get<3>(t)); break;
default: v[0]=get<0>(t); v[1]=get<1>(t); v[2]=get<2>(t); v[3]=get<3>(t); break;
}
return (K)0;
}
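// reuse sketch: koutput overwrites existing vector slots rather than growing
// the vector on each call, so positions remain stable across batches
//  TensorVector v;                    // reused across batches
//  koutput(v, std::make_tuple(a,b));  // first call appends both outputs
//  koutput(v, std::make_tuple(c,d));  // later calls assign in place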
K to(Kten*, const TensorOptions&,bool,bool);
void to(TensorVector&,const TensorOptions&,bool);
void to(TensorDict&,const TensorOptions&,bool);
J tensorlong(const Tensor&,Attr);
S tensorsym(const Tensor&,Attr);
K tensorsize(const Tensor&,Attr);
K tensorattr(const Tensor&,Ktype,Attr);
K vectorattr(const TensorVector&,Ktype,Attr);
K dictattr(const TensorDict&,Ktype,Attr);
K tensorinfo(const Tensor&,bool);
K vectorinfo(const TensorVector&,bool);
void tensorcopy(Tensor&,const Tensor&,bool async=false);
std::vector<int64_t> newsize(const Tensor&,int64_t,int64_t);
int64_t maxsize(const Tensor&, int64_t d=0);
int64_t maxsize(const TensorVector&, int64_t d=0);
int64_t maxsize(const TensorDict&, int64_t d=0);
int64_t checksize(const Input&,const Input&);
int64_t fullsize(const Tensor&, int64_t d=0,int64_t n=-1);
int64_t fullsize(const TensorVector&, int64_t d=0,int64_t n=-1);
int64_t fullsize(const TensorDict&, int64_t d=0,int64_t n=-1);
int64_t batches(int64_t w,int64_t n,bool b=false);
void batch(const Input& x,int64_t i,int64_t w,int64_t d=0,int64_t n=-1);
bool nextbatch(K,int64_t,int64_t);
void batchindex(K,int64_t,int64_t,int64_t);
void setsafe(Tensor& t,int64_t,const IntArrayRef&,const IntArrayRef&);
TensorVector tensorpick(Ktag *,K,bool,Cast,const char*);
void tensorfn(K);
// module routines
using Callbacks=std::array<Moduleptr,3>; // list of callbacks
ModuleAttrs moduleattrs();
Callbacks callbacks();
K kmodule(Cast,const Moduleptr&,Class a=Class::module);
K kmodule(Kmodule*);
void to(Kmodule*,const TensorOptions&,bool);
Moduleptr mcreate(K,J,Cast);
AnyModule anymodule(Cast,const Moduleptr&);
const c10::optional<std::string>& mname_(const Module&);
c10::optional<std::string>& mname_(Module&);
S mname(const Module&);
std::string mlabel(const char*);
std::string mlabel(const Module&);
std::string mlabel(const Moduleptr&);
std::string mlabel(Kmodule*);
J argstart(K,S);
Cast mcast(const Module&);
Cast mcast(const Moduleptr&);
Cast mcast(const Moduleptr&,bool);
Cast msym(S);
S msym(Cast);
S msym(const Module&);
void msyms(K x,S&,S&);
const Tensor *findtensor(const Module&,const std::string&,Cast);
K modexample(Cast);
K moduleoptions(bool,bool,Cast,const Module&);
K moduleget(bool,bool,const Module&);
Output mforward(Kmodule*,const Input&);
void nnfn(K);
// loss functions:
LossAttrs lossattrs();
Cast lmap(S);
S lmap(Cast);
S lmap(Kmodule*);
Tensor losscalc(Kmodule*,const Tensor&,const Tensor&);
Tensor losscalc(Kmodule*,const Tensor&,const Tensor&,const Tensor&);
Tensor losscalc(Kmodule*,const Tensor&,const Tensor&,const Tensor&,const Tensor&);
Tensor losscalc(Kmodel*,const Output&,const Input&);
K lossexample(Cast);
K lossoptions(bool,Cast,const Module&);
K lossget(bool,bool,Cast,const Module&);
void lossfn(K);
// optimization functions:
S omap(Cast);
size_t osize(const Optimizer&);
J buffersize(Attr,Cast,const Optimizer&);
K kopt(Cast,const Optptr&,const Moduleptr&);
K optget(bool,bool,Cast,const Optimizer&,const Module&);
K optsettings(bool,Cast,const Optimizer&);
K optdefaults(Cast);
void optfn(K);
// model functions:
K modelget(bool,bool,Kmodel*);
Input modelarg(K,J,const char*);
std::tuple<Input,Input> modelargs(K,const char* c);
void modelfn(K);
#ifdef __clang__
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" // FORWARD_HAS_DEFAULT_ARGS
#endif
//#include "knn.h"
#ifdef __clang__
# pragma clang diagnostic pop
#endif
// global environment
typedef struct Env {
I cuda; // number of CUDA devices
S nullsym=cs(""); // internal representation of null symbol
bool frame=false; // if true, error message returns stack frame
bool alloptions=true; // if true, return all option settings, else only non-defaults
bool complexfirst=true; // if true, return complex tensor as (real;imag) instead of (real,'imag)
std::vector<std::tuple<S,Device>> device;
std::array<std::tuple<Ktype,TypeMeta,char>,8> ktype = {{ //k type -> torch type
std::make_tuple(KE, torch::scalarTypeToTypeMeta(torch::kFloat), 'e'),
std::make_tuple(KF, torch::scalarTypeToTypeMeta(torch::kDouble), 'f'),
std::make_tuple(KJ, torch::scalarTypeToTypeMeta(torch::kLong), 'j'),
std::make_tuple(KI, torch::scalarTypeToTypeMeta(torch::kInt), 'i'),
std::make_tuple(KSHORT, torch::scalarTypeToTypeMeta(torch::kShort), 'h'),
std::make_tuple(KB, torch::scalarTypeToTypeMeta(torch::kBool), 'b'),
std::make_tuple(KG, torch::scalarTypeToTypeMeta(torch::kByte), 'x'),
std::make_tuple(KC, torch::scalarTypeToTypeMeta(torch::kChar), 'c')
}};
std::array<std::tuple<S,TypeMeta,Ktype,char>,12> dtype = {{ //sym -> torch type -> k type
std::make_tuple(cs("float"), torch::scalarTypeToTypeMeta(torch::kFloat), KE, 'e'),
std::make_tuple(cs("double"), torch::scalarTypeToTypeMeta(torch::kDouble), KF, 'f'),
std::make_tuple(cs("half"), torch::scalarTypeToTypeMeta(torch::kHalf), KE, 'e'),
std::make_tuple(cs("bool"), torch::scalarTypeToTypeMeta(torch::kBool), KB, 'b'),
std::make_tuple(cs("byte"), torch::scalarTypeToTypeMeta(torch::kByte), KG, 'x'),
std::make_tuple(cs("char"), torch::scalarTypeToTypeMeta(torch::kChar), KC, 'c'),
std::make_tuple(cs("long"), torch::scalarTypeToTypeMeta(torch::kLong), KJ, 'j'),
std::make_tuple(cs("int"), torch::scalarTypeToTypeMeta(torch::kInt), KI, 'i'),
std::make_tuple(cs("short"), torch::scalarTypeToTypeMeta(torch::kShort), KSHORT, 'h'),
std::make_tuple(cs("chalf"), torch::scalarTypeToTypeMeta(torch::kComplexHalf), KE, 'e'),
std::make_tuple(cs("cfloat"), torch::scalarTypeToTypeMeta(torch::kComplexFloat), KE, 'e'),
std::make_tuple(cs("cdouble"), torch::scalarTypeToTypeMeta(torch::kComplexDouble), KF, 'f')
}};
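// lookup sketch (illustrative): the ktype & dtype arrays above back the
// maptype/mtype/stype translations declared earlier
//  TypeMeta m=maptype((Ktype)KF);  // k double -> torch double
//  Ktype   k=maptype(m);           // torch type -> k list type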
std::array<std::tuple<S,torch::Layout>,2> layout = {{
std::make_tuple(cs("strided"),torch::kStrided),
std::make_tuple(cs("sparse"), torch::kSparse)
}};
std::array<std::tuple<S,bool>,2> gradient = {{
std::make_tuple(cs("grad"), true),
std::make_tuple(cs("nograd"), false)
}};
std::array<std::tuple<S,bool>,2> pin = {{
std::make_tuple(cs("pinned"), true),
std::make_tuple(cs("unpinned"), false)
}};
std::array<std::tuple<S,torch::MemoryFormat>,4> memory = {{
std::make_tuple(cs("contiguous"), torch::MemoryFormat::Contiguous),
std::make_tuple(cs("preserve"), torch::MemoryFormat::Preserve),
std::make_tuple(cs("channel2d"), torch::MemoryFormat::ChannelsLast),
std::make_tuple(cs("channel3d"), torch::MemoryFormat::ChannelsLast3d)
}};
std::array<std::tuple<S,Class>,9> kclass = {{ //higher level object names
std::make_tuple(cs("tensor"), Class::tensor),
std::make_tuple(cs("vector"), Class::vector),
std::make_tuple(cs("dictionary"), Class::dict),
std::make_tuple(cs("module"), Class::module),
std::make_tuple(cs("loss"), Class::loss),
std::make_tuple(cs("optimizer"), Class::optimizer),
std::make_tuple(cs("model"), Class::model),
std::make_tuple(cs("train"), Class::train),
std::make_tuple(cs("test"), Class::test)
}};
std::array<std::tuple<S,Arg>,6> arg = {{ // type of inputs & output for modules
std::make_tuple(cs("bool"), Arg::boolean), // accomodate bool arg of MultiHeadAttention
std::make_tuple(cs("tensor"), Arg::tensor),
std::make_tuple(cs("tuple"), Arg::tuple),
std::make_tuple(cs("nested"), Arg::nested),
std::make_tuple(cs("vector"), Arg::vector),
std::make_tuple(cs("dictionary"), Arg::dict)
}};
std::array<std::tuple<S,Cast>,3> tensortype = {{ //distiguish tensor from parameter & buffer
std::make_tuple(cs("tensor"), Cast::tensor),
std::make_tuple(cs("parameter"), Cast::parameter),
std::make_tuple(cs("buffer"), Cast::buffer)
}};
std::array<S,std::variant_size_v<Input>> in = {{
cs("tensor"),
cs("vector"),
cs("dictionary"),
cs("empty")
}};
std::array<S,std::variant_size_v<Output>> out = {{
cs("tensor"),
cs("tuple"),
cs("nested"),
cs("vector")
}};
std::array<std::tuple<S,Tensormode>,15> tensormode = {{ //tensor creation mode: map symbol -> enum
std::make_tuple(cs("arange"), Tensormode::arange),
std::make_tuple(cs("complex"), Tensormode::complex),
std::make_tuple(cs("empty"), Tensormode::empty),
std::make_tuple(cs("eye"), Tensormode::eye),
std::make_tuple(cs("full"), Tensormode::full),
std::make_tuple(cs("linspace"), Tensormode::linspace),
std::make_tuple(cs("logspace"), Tensormode::logspace),
std::make_tuple(cs("ones"), Tensormode::ones),
std::make_tuple(cs("randint"), Tensormode::randint),
std::make_tuple(cs("randn"), Tensormode::randn),
std::make_tuple(cs("randperm"), Tensormode::randperm),
std::make_tuple(cs("rand"), Tensormode::rand),
std::make_tuple(cs("range"), Tensormode::range),
std::make_tuple(cs("sparse"), Tensormode::sparse),
std::make_tuple(cs("zeros"), Tensormode::zeros)
}};
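// search sketch (hypothetical): symbol-to-enum maps like tensormode above
// are scanned linearly when parsing k args
//  for(const auto& m:env().tensormode)
//   if(std::get<0>(m)==s) return std::get<1>(m);
//  TORCH_ERROR("unrecognized tensor creation mode: ",s);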
ModuleAttrs modules=moduleattrs();
LossAttrs loss=lossattrs();
Callbacks cb=callbacks();
std::array<std::tuple<S,Setting>,16> cset = {{ // configuration settings
std::make_tuple(cs("mkl"), Setting::mkl),