From 05bc4fbf09d1b59e1d645ae9a9152630fdef27ad Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 27 Apr 2024 17:46:28 -0500 Subject: [PATCH 01/49] examples(webassembly): use MNIST as test model --- examples/webassembly/Cargo.toml | 1 + examples/webassembly/src/lib.rs | 28 +++++++++++++++++++------- examples/webassembly/src/mnist.ort | Bin 0 -> 34480 bytes examples/webassembly/src/upsample.ort | Bin 16040 -> 0 bytes 4 files changed, 22 insertions(+), 7 deletions(-) create mode 100644 examples/webassembly/src/mnist.ort delete mode 100644 examples/webassembly/src/upsample.ort diff --git a/examples/webassembly/Cargo.toml b/examples/webassembly/Cargo.toml index 4c235e54..cc3b5f16 100644 --- a/examples/webassembly/Cargo.toml +++ b/examples/webassembly/Cargo.toml @@ -16,6 +16,7 @@ web-sys = "0.3" tracing = "0.1" tracing-subscriber = "0.3" tracing-subscriber-wasm = "0.1" +image = { version = "0.25", default-features = false, features = [ "jpeg" ]} [dev-dependencies] wasm-bindgen-test = "0.3" diff --git a/examples/webassembly/src/lib.rs b/examples/webassembly/src/lib.rs index 589bc11d..17edf461 100644 --- a/examples/webassembly/src/lib.rs +++ b/examples/webassembly/src/lib.rs @@ -1,22 +1,36 @@ -use ndarray::{Array4, ArrayViewD}; -use ort::Session; +use image::{ImageBuffer, Luma, Pixel}; +use ort::{ArrayExtensions, Session}; use wasm_bindgen::prelude::*; -static MODEL_BYTES: &[u8] = include_bytes!("upsample.ort"); +static IMAGE_BYTES: &[u8] = include_bytes!("../../../tests/data/mnist_5.jpg"); +static MODEL_BYTES: &[u8] = include_bytes!("mnist.ort"); pub fn upsample_inner() -> ort::Result<()> { let session = Session::builder()? 
.commit_from_memory_directly(MODEL_BYTES) .expect("Could not read model from memory"); - let array = Array4::::zeros((1, 224, 224, 3)); + let image_buffer: ImageBuffer, Vec> = image::load_from_memory(IMAGE_BYTES).unwrap().to_luma8(); + + let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); + (channels[c] as f32) / 255.0 + }); let outputs = session.run(ort::inputs![array]?)?; - assert_eq!(outputs.len(), 1); - let output: ArrayViewD = outputs[0].try_extract_tensor()?; + let mut probabilities: Vec<(usize, f32)> = outputs[0] + .try_extract_tensor()? + .softmax(ndarray::Axis(1)) + .iter() + .copied() + .enumerate() + .collect::>(); + + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); - assert_eq!(output.shape(), [1, 448, 448, 3]); + assert_eq!(probabilities[0].0, 5, "Expecting class '5' (got {})", probabilities[0].0); Ok(()) } diff --git a/examples/webassembly/src/mnist.ort b/examples/webassembly/src/mnist.ort new file mode 100644 index 0000000000000000000000000000000000000000..184589976a1d9e70d593395924514f6978e7376a GIT binary patch literal 34480 zcmbSy30zHG_wa3wNKzVALKR`|PvNaJ5Auk;h2y;S7n|h}0M;GfJLe zp)C>_$%;e?Ng~mnJI&4Vzj{P$lt>i9V4{pj^nl@d)+g|Ue#R1MFc7pj4hot%)6vF; zk)k$to0~IPr&Nb2KaZe~rHeGD^`q9Y6ox)w0(#$*MZf~?TTF%a4+3>=3C z#CV2{=pI~w3QiGiTw)w?2r-q%h`CT z9--ZpfewSoVKXgf22YL%iwv6*<>)#)GAM9BSZK7P5Ns2Z$zfRVSAjuq2LFtAQuFV4 zH#029`m-1oeBQ^fU?+-U=X-yT+l}#sxLY%D92pcjEg(F|v5!bJjg1q07W@?K{pw&? 
z(DOSkp;!J_Tmo;sWV|#7-i#;M7V07JL?T5tR*C)oI%MRTtPqO^i^7PF(q@#}FKzO_ z{IO?!f?nb1oxtkDz=c6doQP@peQ2=E{gD-B*>Blt#v)MJ|W)|nwpzc8P;aN*uVUf`ofyO z+Ctu?$4E9lgasjhU_XZ81q^&!vHY@uf-ctlCPxKF2Sf#jg)&`|v4Vd>4Oj@jY%%l~7HH;*C@?tWmsBL9 zO(f@t&?m?Uc7^&+jR>19__Bb-B=teatuSxA`VAlGKg!c{;K<+SXu#~KF#qs?K-R^a z{ykpsL&)2|`62X0Mnwb%21SajTa1_bE!Ydp3W^F~h}jgbv4Wq!udyx13Aqdi37__BAftu+2s&DV^1|PI4R#`=u*4h` z&Ji|7xYRQA>l~OOT#bx|2Sm-B9byy^8fern@D~XE!q};jv2j0sUpKTRe8G<3oAeOc zLd~T|G#f9>e*uJi3Uv_j)SD@FW?0C*5Q8bJslaz;AU&jOi8K~(=DWaCV{zCsA;FI1 zk|q@qwP%!f@XVmdUN*h_BYt1+*y%2EWb2-EE(-AoItAO(Lud=@g!HgvIwL6;DBILoak<9jm@>im`g&me|@8v{XBuvhTp&wLQ(W@J=YM8MP_|FDR_ zpa|B!+@fEY)02aOL;Pok1qLx%jrk(j5^5`~8-g8o4{xu5-oImLxgM`z`4+|s{z+{! z@;@Gee{4AyE)Sxb7V<)kg?tNkg+q!3-O}}3B=Jp}+kYD``15aTIZN4J^OR$Qg!Wru zu;dV~J;FXMICQFwolrz}HBw{Q6!u&~pWsi+dDyakf-WJ?e+~mS2dzatSXO@TDg;?! zPx5AWV#q+v31az0nZf6GjQu565V9`2qOju*+SBr!FrZ(Z3a1PKVi=x ziuqB*eEQCucV@7FT?~%1$>_;o1v7Y6neiDMVM(%PP{byQ3xk_%y$TT)ItI(w38&2< zixt$G0nZZa#^5RobRvUSY*-8f;k*g(?G zVOU^G>&9piyKIWWBBI!yDk3P*zpuZIe=mP)|ERDTL81OOeMKU_7k}*SWb!N&;gDI+ zXsP~168+Nla!Fg8$qHj7GQR}>ZqJ6Me&Qd-yxmFSn+UnFTOGui)*J!>|&q=h|M z2`%VoVmhVyy#Eh-ZzSy)iM^Kj6!v@nJRfWr&$!`FT+>_FnaF5q9IlLRIZu2f?FS4C zKDEr_-|a}w12!nFg`LYxPC7r%GFlo>zC^FoE=t;NOjfY-KgA>Dt+9n2M+q&+S~FUz zw_=6 zfe^1M>ld&x!wks?dqd$ck?^Jd2s(sAsO4Yx2p+Se!e>YM3pw_A_dCu&MsH=Z(wH+D zEww*SqO+y_B!%Xxq`+FpOYo=G=|GGdkLO9sAL-bedHV(GZdlg~s-f!WDiG*fs z8GS}e{g7p}5OYgEI!btI3=89>_ADj*A_-rb2R7!{v!g$IY=3Dsr0xi55rsDX&c%90 z7cw2vyzFDN)Sp!n8-h>13|ig`=S%u>7-n7U+2HS^Lc%}(555q;Fy^L&FP%r9B>cA$ zzBCTDZIkZv*nPrZF^Ejw|Mqn=>zBsBNSQzukJRrWj23MDo8MC;eG?fLNlfB>~^E&b13%g@lOa9 z690wsN>xHLQbwN9QvY8_bpM#IPVI_1{3km--{++e!G=lJSx{i`*rAHwjz8uf7})=`t&F7(ijB$Jd>CDy;Gv|-~2{NUy;O~J`)y>>k@v=KlpDX{CWvr z>bD~EL(sp1VTS&ecj57`a9-KAm@;6Y`nmwk)*lF!BqI%21Z zM4y-ri9VTEjF$R;S)#k;_3WvH_kdx+p7gr@Q^HUA7e92_b811q3xi1#yMnAQqowO6 zWwekdscm;j+l^s?EnPSNe(&>-b(0NDlGvKa=xQb_jjNo|Qr&qH-7Vun39pD@!Jag( z>k@vAgfGo6d(6;sufTR^?7GjMuM75K84Qv5!AKcrMoVp*Fj|OHcy7coFk&FIsf2IO 
zFl+u*M;{5_^B??B3BO6gm&D0Lf4?94YahdIm^MoMa$|jUY`iol!n+S?{!yY&>Swj2 z?aO4D*kAb<^cyqyPx}+;+adLDf9mch@j=k#$!MuvJ4UyxyQ`$_%&;(4TKE5Wz3ju* z4v8&c9jRcl(inM0OJmw8(Jk2f=Y9w!eMJllc7?}6!f{=~ulWc6jf7t>;Y;(c$ovrW zuV9#=zwNh$cNBls#lEko7kjTE$hk0>C-Gm931_s_XCH|^X?u#KJ&|Eytkmw`^Yiy- z0>ZnstrA;8ex5Q}X@2f9TB`e`MEAet=f0%xCc}bV>G|9oR!Hng z>zXg&=lp|TA>kjF@C6&fabLop#{6N;zw$2h3Ay;oZ#EZOv*+*9x{j3iFUYttTI#bU zqXmDZZ7ONoGc2&Bbrqh0|HpeYc7D&3*mGlJPcnIFZHpK!J(|5nnF}E$A?1wA7v|qg%#jAZcqeER2=f`%mjROQpz%J^L2o zX$oNVXR^}RG8rw6MTlRp^{@4p^sQi6u&d90gd<(2C!Km{pLoP4}za53`R=)XQYf9qoqDuGFtHWpVt9b3E!DvK}KqSf`sq; z4}LUj|MpAcZ+^8rul~O7*fX`Q*?Tp?L-}8Gmpx|x^{!Ol3%XP!e2ES=7uzt-@0g|j zv&SWW)mM-a-hc`3LWTD)zx`4Bn?I7e{rcwS&$FSwuT#QBRG2?8n%Nlo5ly)4SnO z!}~ZVzm(**P2<<}NQ9wJ74c~QY%WW=2?RS*-x&E1JAaSShMlj`E$p;m&g%;SFi^-V zC6f;~k&#vT#9ikbNpd;|Njsjv#&PPP-t9D0glVDESygf`%oW!3Y(vLqXwof%KSHMR zD7tgaG+a=2lg5n>C-=N9(BU57_}eb{!u$kT_^^z;_+rin8imjWInVIW<1S(!7|P!? zlO-Nio6vUeYOd*d1=nD>ni^?aQJZn=VSecf4o}~JqGj3eZNO}Dm^L}s*>u9FHlNA4 z1$W7l$P?s5+YI;+beT7ettAbdnxHdiqw!D8$+dTQq zm6&zGVErlA@ldjG(uWjWONbRSy1{W1*+|4x2brK2vs!FO7Swcq#h`_<*p_?au= zgHO>5OJ}0~^&0wmOEG60y^Zu*X;F4*^*s2w+=+B}p$OA++tW9@WW(;0?pMb+?z0%z(?xk8W zm!a3UP25LZ4e%-!Dw+?G$Co0B#xZ-E-Efx#7rEh*(Q@3bT}#M-#0V~yTMagKx5?7i z^HB77E+@NXFRmY^2zif;QFVo;Q_a$~Fjq<5X^(6-nqxm4@B1Ajk5;DO$2OnQrnnn% zXnmFU8MA>7ygvb*9rln7^HQkis!Uw7IE5a6<$-51e!|MRr^MsSg6Q;1Dl~h;6MpU0 z4D^0h!d)MoNM0<}#DON4h{@76c;$Q$d1|x=yxTk%H}Cesylc%t=3u0(%t>N_K8d)?BIgS%f z#N%pjOZwq;CYTh}7OQS+2Olk#&rm*78nH28O+${+bEk zov9?rcqMiknMO7qScX=;*Ws*;UDRS!6_lJfM)QN(;LbxakdRYIT9^HRF7rQO?3vFn zbInFLP%YzRw6+^k1(;shiW15DPl>rVLeE zeilsK$I^=MWcnm)97zf5Ulv{AfpudF=&|xPIJ$Nfb!`|+W6o8QLhERJKA{WEwVQ@L z*DT|@mOFAyfwoj`(J(T6%~le)LJU18oaHJCM}ivpibupO7fH5M@D*hOA8tq+zZ4k8zaJ{K#0Z3h>$Kf{KX z`ou0|Fb3Y&qs?w9WX^!w+^S8Ecsye`Hm+VsBTdw4$qRSBwTcYgb0m|$RY3SHjrGLn zP#$IxMci(vffr;Z(&iqSw3SLM&Fk8Ws$Q=l(cR*4U*r?Wv5kP#K5FEKU3aKbISSYO zGWdOq+L83DSD=gKav3OaaL3Q1UOj*C)~$@vAQ*v0e`z0@m~q`Jj{uir}SpDdH`RQyrSv{S;Q$KE(5k%l 
zIRDCVBDcLgUfntX&d0W;NwUw0X;%Bv>F-^+{#tT)`q^meS64}D-@PrZ=~4>MmQyUh zRYsh;1$ohZ3nV6O21#!&Af`tXP<7h~(lbGZ#NF9QWDbtQpm2ioFQw9p z_h#by+SxRz){kn&DVHe^-2v5ieCh1nmNdt$l~d60VW{F$TxdCmUP~T9qQYaY~N}{-&+oJ+Ilz*HQi^CR`1nuv-=6&(6x~7zc&nDTc=}BuWC|b zp^R$pTH|R$UFu_|LbrJA14#^5^{Whms7vDhF z7zLU&=`}Ic+D9U7+o7A)^3veL8{w?&MxtJ*N#n1r0fWb;xM9RVn9%f`=&d!w>S^cU z(@-1KxOSF!hWZeXm+E9{-(-i4igloLWG5LFAIxVB?MvG1EeF-`L~y>pp2IyiN%Mmz zg-HHW_OX^R~`$HIl~*3`Xx z1`ID$Mb*^hcz5H@;nCEs ztShhL@-IZ=jBXtvx=FOGy<(G-?+Ovo=TXTit25?~BksN1hn(JV~k-#gTO* zDrx2PD)h1MkM1Mp;)U}_Z@xH!$Aj|4ceL%HLs}R);*^B9#Sx(FQiwHgylBU>!* zII1W*hHpEl^BKdd(57%A=koX_r?$R5of+VP>vNY=+0J*0RSgEBUzss283ug!HBY&_ z*>A}EjjQ;+)24#wvt2MkZynThzX~O1EQ$QfgLJHHA?n#AkfZGpQ3 zb6!Pg@^U-wv^zrP?XSk>+QFoUZz4Tf z(`zzlOAL2l;zMpy&^x~H&S4Tc+LGR~x()S-ufbwVTXfX(;75jDBMTm`2ENvkygZys zT-z<>dVGneRt|O0tgDV&6OWL)N+I;PemAJX4WPPz2Rc*~^4V1#a7%I~_?sK!Qoqiq z@FE%74LeGk)E|mR)aa6kIY5(`4-4cQ7dC6RDq>&S$Ok zAj`69iT@Z&(0cupKRbCc9pa@xTdDg(xpOl~I(d)`{yBsgl)BO$v$Am%8Hvu{BeC_F zbkxe3OQxP2M0*6^nB2xHa5x1(>j%T zG;^y3PPtx40t3{^tyhI$<;j(n3~q(%=7pf*cSE8#zyd$j+tDYsC!p}s20q`b4c*wB zg(24_P~ES&bbrEoJk#|F_1m!pPmk5de(8?%%J-x2NYRY0>TXW@DT!%*wOCyCeG?AO z(4|jTrNOESb3FRB4LaHskS%+unMuvCtfpH0jY%lFXlk!{BVN0C(VXjL^ z$DoJ2iTe)Dc*thxIjlRmS=|bZqcb4gbr>{!?1&rBx|6d{=J1-!?~&{t=Sg6#3i>p5 zgFJBwbQ&{@Zv2pdFKhaMiPbBvqntN6@Mw|PbgvOzH0vVq5AsIE%;PXPF^*3bS z6Zo>T4n31|nY{M(NAdljlC!FBpkAYd=$R-x`SmAs(s%psok#5|1$bO(hp1zqXNR)x-tstRwhh&L zb&4)f>O?ON>w+`iYY`32<$Uzyws20#7OhAvF@$#qE)cI@?7*r+ES^P@Oc;3yae!2>J3{N1pmVM$j-cG0O zM^DDC^<8P$n%UU-d=LiZ2lJgfJg0_Q)A{=qgVAunGVD0N4{h~)H93CZ0~GXU_corj zWava&tQu1YwTq_OKU z*gV4q`z)ORR?EfY;-GAD!Xuw#x@M6j>2pwaJyJ!!8b-}B!n`z362EE~eOEpYYgJoQ zuCkrDk9RX4nASw<-^gHE=l-~IpE=#N?F(m7Za~k<`oh4zH@FQirlH=3R`?@fGwm@0 zLDPodvei}e`;tjGXQ(x;INiuw**_%n-W9?3ih$B!Ic>Ujgcp(5x1wWIvbp;1%kUF7 z6~D)5VBC%2Fyll#4ei^P4!6FB>%&iSswb_{e4A`peARa{cE@I%_F0=?|AACNwvmoF zav$z~IwP(dLSfSQFBogqjU;J@yM7Iw!p=xbz;6zGvpq_x)Srkx|## zeV4t{D_dDyX_1Ui{l1aerc=Zp1{#)Cor@#xm)>&iRrk@xw0FFTTP-O%$L`B)J#f;M 
z7#bT~OE%8ijimtaXEE58a8Ek}smi$%D-ULVfPzl3g6kvnDE7u7aq(dI*OIBe)6a(}=D@@jz`rO`vg z&e3wzTlNYxy3U~TueYJ*b4@C;@g}!sIMDF!n_;;|JXdTw7wZ>nqE0<;qgujJy3$ym z>~USf70lQ{CoRe6)vmY^#dAKq!}xS^vl2_^5JUJVgA4Yago2*PR#BvGN|Ad|f2? z-+m@mcI`&(vb~5?t`f}i=!ikq7En6w3>VY28(a@r4eG&*VD{s4p!jwxEN^Zm<3+&*eF$$68C?+4xm)iFcpxR_D&W?g}I(7gswjpM4fGx*j9d*3q;!+Z^wQc1OE`_Ec+Md*1%~GrmWu90t|ull`sa z__y-cXxA}bIP;(!UflY%WCkUQXl%P|8S_$c!ZaFE7LUJ;ia?owlBMI z`v{Gw)5nLItI5or`E=TZSbq2WP2A#155cpd9jXk>A%;;q=$AJ;>9IB%rNyn4h+V~6 zSU7JcE`?kIeRHXNQ3I^}))vpjyd_2t8bRmITzW3Tg4_DR1e2DXhcR-OVD_Q*V0?8b zE*(dx>_rR5(e=Ig_g)kpXt6o0Hwu5MTJd|L9uk#@S*V;40M(6|)OzS}ng{DJ^TTR* zGS!7Xz|VNmbv3Pd)sD{WmO5}_~sl~svd^b8>6X}k~e9-pF+npo(HAD&EPcf!v%$rGd%`Lvfb2yodIb5FEOtSBKJPe$(nD_0WP2W$x4tM6rpov~L z?3(t3+huBrE}J@2%=$#eJTS$8vSO+qn}Vy_w?`S*Q$(D%gzR~J3q#)^bT3XL9d}

z$*gH3Y43w_uwlIton;zH4k(&t-*xuqkNt&?@ zAD`+;jn8a_wFhOL)^~eHriP2@fE`xYG;A~-Sic^-Y&%JZ%=6|COwS}fmZ?~A$emsu z+Yfda{Uk$vBm>1XGI(qfgnbDjYR8lzt0#fTn#bTis;qQ%Wq6M>Z={M zz8k-g+U!pFz~BRUA>Wngeet8y_SoUJZUw~t{bE!cFasLY2NFx$dj8hP4n#b3JbakY zg~X0JU&Wj$7mdCU<10mS-J$b08 z&M!NgO6s!1aK@V?@y*Bre3-liFD;Tm)r{r*t{@ATZdOf4dH2Cl(aF^Fd>^*9>5|Qp ztYFUQnP?IbK@OcW#D{8boL|H-@~W~oK0O^kD)LMaBMuSMv>Fs*dCm%vOVM@{zI@>3Ycn>$hPfo?qPq#n0R{xX0oWp$1{bG2?t}%nE&cRo8 zj+(K(pg4UI=A2^BV&BPuag;JWb!ZNF;U?0l(gt-qIVZyOx>upqsZ+Qu zNW-Z^c^qD9?Bld=Pb?navrfn}cNDUo7}JVSEpkgSm*lRrBXx6c@MB>G@Ui!ZmZ2gy zSHl`ha=X$G=XQdzlMk75AedH!Y(}r@yQI&viR7N+WvW>@g6kYe&}m6EG5!3W+Yudz zS(|Ue(fzCNSyeCbrTrWDT>r;hh|V*f8_;Dv>-`EikHl@{^aOK^_X(mAzRq;vfF_u$HJe;gkipg8Low#l z8tCg4=UA*Vm;c;xzbs#IAOd#ZYi>)_Y^je=(6_^XL}x|CTJ3^0oy3IxE!`9 zUxH-yr}WS`S3DjYMW+Bb11uGlY!GdjqCZT>Cr5$QwUp>o7w#$vGP zxEu5XlxaY8E?ME7Lo#ms)<(_x#) zuClgNtt^ZV8^4QRU6)E`U0V*T=DCv}cl1af+poYcFM!A;aqyz1Cv~=}gU2^!5na(s znmnTeUMkn4mF60>V)SMR?6wg_i+q7^A3|#~gK)UZeNf5m2PcNMrS83!V9dHsbhA!2 zys5OK=QhlulN0Pw``Q#TWs4)7dNc;|%hky7BYQzh)`-_Bz6lRaszJkc9Z?A~p|!J2 zs6urnnejFSpZ85AsiWNJ_%6NZ38g%&pFak_s2t>~^Sgm{elF=0F%+KJ^#uQQW8thv zIDIUaP8Yf9kn_7Esmzve?#AXCxW8mSS(ZEklCRWoji;;7N%J68H(L+8`c#s>xt(Zr zzZv-SBB2$VCy~b*b-dS}YW2x9 z^@I|Xbt;4V7OI$Fx{_RKlSuE(vZmRq;_$HJH0WEshIH|t4Pz5{xOQU-ZZ|c+tL-+F znCGO>WGxxWeb1tb^X1X6rzx>mrjJ*rG|=>+vAAdEZs6N$W32l+ak<-$l3T1A3+MI2%HLsglPTS>o z#u3BsaNB}o(f_qFHT0dz4Okn@-npHmA4_w{8|n@{yJcW#rVjPhlgA^KH6;FEA*k2w zf{Dft#UICO;6bht9*GW-D_P&6)r&KO0fMRg*UR_fCYI<4r;`LD?Q<#KRR z_Xd1A=0gKaU-5Z6rs7lU3ex*T0OsHRMl`}zah#Gh8I*etRv%0xj}+c=Lv9@)VeN+C zj*PS9-urv}!ft2>mDZ+9+&NooZze`Po^b}-<3$=o9Cv$BYxkq##1 zt|o0J<&d5m5wgD`dD=PyH(D6+Gv}MoL>)&mP+7KAwtYHW2&X9CxmNuAyBEGHm`0qR zn$VJIo$;ak1@MwtM!xRb3nHaOL@g~5)%uF@j-LVEE1pDMW~`^4bJ|mLlY3nEcKy)9 z=NddWdO}9q4WvJ7Ka;eeMPeCh0am@_aPIMo;`${&LC0Yn+pk5!UA{Hx)p;51mVB@D zwM-@W%Y5S}L|!De3)9JviTQlbYA-A^?Mh!dvG-^fo5}q5FJaB+DQNI&2Jf?{gfuo8 z;Apmg(in0IV)pL9L7O{Mw+C%7NU4r2v>EJ}6m3hxO;=M*r*3q7oD*ug=0h-hPgZu} z0*StU6bcU}b20vj&?>wQH4`tv`fGFO(?oAt+0GUwOirNH`}Rz* 
zP@}t{sH8X%V{(V!Q=>2BgZxr5xU&YW^$S45Nfa(+#~hqA2z`HXLrCbVLfcysSAr+#gkP| z^<>xOHfR}TPM`WJkmAt0#NV|$OlWUhTBo;#YEABe6(K9}>BmC)x`!c^{pm=H4r%sa>)HbIZ{v$SXO zHuRbG64VW6(|SYptmipT7EH*3&mYsEX?+&{NcW{rwkYG{qv1F%xi4PxolBM+u!7Zo z6KHN09Us3D`VnCh7h&iKcDLBC@+)o4w?+RL3J*N^tWfOU{QX;34GaDWY6?gKzVo zHyEV%!aGY=(p^iMu=lpLICw@9RZTkr`=k6Y?{pm<*}57u8!D-L$F^mYrj$gr+p6H0x3b)-oh5Ye zsZ6xHrAy5$e{hzU<{$ zkU6=N9FS{=Hn$g$oRCzq&?AknntG0mcoYaxE9Z0fpS7l{IwG_^Ef2#omJ*pxUgYVl zV_>K0iuzN#k~ZCC=+u#W`H~0vr2X_%j4!r8!$4&^eVE@UXCuAa2H-5(}lcs?D^lj>3rLEPoe8^ zq^p07!l-OjdiT+NI2i0jZ=edcI-12>jTuKCl9yygt4Mfwd=;K*Jp>I(5~!cOJ;ZmA zr|q1TU}l?ZWXkL7FuQFrW?ij8_vy+`V>;F0B!@_|{Lvz$|f5SD{+SZI{zic3j>pw!Ib8osQ-WH8N#nZE9 zJ?I*TE&R7HE~w+Kh2534=&7j@cxCN9veDI-^!vC37u>I;^CRDvid;T}m%J)|YCHoX z`9Z{SxHUewJ(jdJO(pIdk8*uKZY%EVA_ww{Q7|?xg>3DnN7NHm!?%K&{HKHGz+g@& zG|XELlbCPQB%|xwH$+YC- z0K6%VCp)gz(y0l@*?XgRu>WHk{hG0z9Gr6t64p#7Ushe+OscU!?FwN$d< z-4<@h3VEuxRS~t6YDh!ZbM$q`0Q4A;i@R2 zmy5|T*A8f0!$BgmWde>pmW@&UJ!x8oIVr!Y4f=!PXTp`(DAb?!N%`TfqcFIzVel1 zdgtNT-Jnox>s~^$!=94jaj&ua=J}-is4`r9Mc(P_135ZJqY--es^Ye4Z|XmMDK4>0 zqd7e$V;qLl4yQC>>$5N9*SCt2j~J!jSO2H~Bx2Q{?-hr!PYAskI5SX>$bt2r?!$o) zHFT&Pfj(h<@!_U%*h{_@DulS9iEkayk*%x<>GWhj9 zqvYF9ks;IfAHP3l{}FY)@W&6~8%`Gn!Y7R%O8SD$h-UGun587_=tnMjmJBXg(~gEj zOo9WhL!kAcedKhIKL~M4!v6I=Dr^4whV-v*s%;zuwW9nMcG|E%N?QwmmS)iUkt>V~ zDuDxiZUIA8`xqf$=ZXaQK}~e9EnKviXuR(S760+XOEGrTMz}&Fv{T z>12@9of9TU!4)5vTLf4EmTy{MHJAd9y zUYi{u86rcDYnYB*^x`3=w1C?^XejJS=}F~XX3=kbw(&Zyt{?-qtzY4(fr^cgp@*&82oj~@VOu(ggu0pH%dARLiE40`q12F~9 zxYsj|5#{XrAbRVBGL}JjbHo)g^=E(3(YVEP8V7mzPA|CtB^R=I#Q73>^eS;3b_DSI zL|CS`h4h+wlau=vk5BimMB^Q!@W6^yoWZ>`!WFfFuOBwy>wPjft=gPVnD>^beeOw3 zN2h|$q5$$!*_O)pNP?LoRPloBR{X4R9a_%?Vw|~wUY|0LWDE$xTc0N&x9$+c$FwEe zBFMaGP-j_$#UtST&xja_>O> z7wq7|cJ9OGxInad*$vk(wxG4)&uEijD#iWH%ol&W0NPkrx(PgQ#M4s;FO$U{iDXsW z4GfOjh8Gh$fYM$kGC1)CacJiUJ~LJD$M_WVv-HD;vLF1NjT6Ylt(IstHWRmHn?aw9 z1=!aAJQ?|B5ZhB!!m5R7s5M>(OZpuEt2Hv>l$7yg-aIFA;$2s~eFV8ND-+4mtMl0! 
z8$-e!M{+~IoB*>UH=*+8blT8+1ihkH1ZMLW;+oP9)Z?-@(Qz*4{V!CJl%`-n-%T)b z*bLxSj36IIvhPAQ78CtG2Ds@_JB-@8o6`-hDq*!?vWNd$=-f zuG>bwUHuL^k2FwS)ds7AH<11h#?tcJ>v>&Y%6p}E#;c7+p!R7xpVPS=DA%;1vQJ{^ zfbO;=Nu?4Wq;G?u)jD)gXD=G!sfNqW48-`2HV}EXhzxR@K-;z22eld=xV2ajK{sQF(DuJIZehGttEKE zl;ESGQ{dT#Q>5DWD%jrXiJQ7~Msn+oqm%1dapUUp(o3p7bmqM{qA|{ly6sG-Qx?vE zAiDyx+Et5^5*6CTO`ktsCrfW#T*22RZ-dca6=0u>9>TVn`g<4U-mi!?`EQ;F`-tQm-9CyQuA@ z`%;ZD*!2m2?TsESy{bf1KYP-cF?Y#kds&*0c9V1eLeOE{I80X8Lo?+;*pI6L^V=zK z=e+~otDS;-?=OPzpwZ|ss{rQRRK#;{--Ew}4*uLvF{P^=b(*YdBiZ-1M4vsZDYNu!2AGt2LT$O+VE;c&S0)&hpVTml)L)8SHs3a1b? zms&Ztr7yvbR~e^{=O;cU7tgFAUC$kcg*^uoy>usd*UA~cPd3L#YM$&qdOxI}yu+Dy zIRwi-P@LBHoj5F81r1+BBXmf|W%{=u*KBkNN#6?DTb`2J&(4x$S7R{lEJyFt!(?4l z8lAVyg9gTW(+~1p$lbfzD4$aXkq_^}v`K?-=AJb43^&5dnz#93MHOVNpDG_Pat?pF zu`lYk3x^Q7osjY+025jM<6UISqV{dY$m{L#rcxj`$xRs&R?LHM5y5Etu?G$cdJBF1 zO5m+xBH8CU3#;r>`F#@`xT!~dXpcxMG(NJOX0z`tujxdh@_kEO=IDAf*{iUFvs`eC9QV#8Djf{yruOS_ONk45=C1+6(*uases?5`S%e}r-ZH_Iqb3aOs`vP|s zw!oJ+Y(HO2L9}ci_>E2gj}#GpGTsNyMbo$is_n_+b*6Ybb{>km0G_#KkMB!WY2t~? zgp^gm0xd%$OXArIA^mBmR`yb zgRn19-iW$ts;F-_3H^I@!-z3U@lxPQEUq4hr1d(e-s6U%Cri<0S`zw{Q^Y${XxD`O z+>lq^js^q!a?7fdh}V}^pb_nb8cRo_{48(mAH9=5`@o&3rfn4GX|V5)>SodH@8dyE zZY-IEsr1RGgY-+_2s#t4V&}FCP^I}AXe9TAliTldg`NduPE`;pE}D*823!*>Yi05u z?f22nVlQ%d*CkkXJp_L0bj8*62K)*qALL&ZkbMKrf!9S9d@|S*bUnIYz?412v@wk6 z9+(PU*|*oKyog;VJJIY{YV2>J#ZBvj&`fm=w{3DT@-oZG`en`Jyx&i< zZ|iUv)d?M^9X~_PKTM=kKlLQ~uS=niT?%#f9?01!#gQ-L*U%X^ebDsIZuZ@u9$qa? 
z<5HDwaGg9?qrwD3d^_?r&{t-7NpleUcCHkjk2rv_Z*t(&b4}b(#_r))#lzl1H;7q+ z0?Ci=3b&_Q!133b2y-amtzFPW{S0iNclfeGRX*XS7;md7kmTDFNuSwMNrYE1Nj6^t zh0P|UNogf_;xzj^!>fEUO>He1?y-&>O&&}e+AGub`g>qjJP-Qm`}sKKL}E5Eg$pej zhT99z;)z{;WZXk7dLTBAzkJ3C?0Pz5ubT$Yy}TTPd(Xl9KOTbWu4d4?;fgnHeX&U9 z7<4Lr0uAv)@X^udFj+SW5__H_BM%NLz3z09Y@3pf3E}Cu<=q{!_qjP$Q8r}X(OrOG z&8cu`c25)~){&Z1a;R?*j0?Ea5Cb)E>BA+k2rGd-p6B3njtqXz&VY-_tzhMCPu#y- z6~oM@V7bv=h=@AH%`fRnn{qm1QHKbe*1-wS+|MQD55J3*=JdycyB)!PReQ{PDQ5qd z(IS<#D{$OGGhFnEp!dodxOcf7oH6r89fQSizWf%y>r*itezk&^f9{7%lq}Kig*zlg zb)}Q%U570^hhsvsVePil+^y$3Xs1^>{LJ%OFq(_wv+ph^OP<$I*Gn$+bcr4oG-ZLp zqW*O6oj7_cr&yfi5C=xvkMld>Iah0P59HFeL-DgZv3K-Gepsw0Z7p+!JoBGVxR*Df zT4M|8`8>t3?`0S2Ffxp6?dgO2D(k?_U61nO%NV)@dTSI4Yigd14Pcjkdr7xp#awOk;oX@f`V3>J^_?I}+bF-IHyK7qX zyj3mSG^{M0=oHE+X*O`8eLZ2TttM{blTkL(15XCcL}OJKs#vT{yByZTq|E_XH*zf| z+*%FKLVaL>=L!<4Ye5|vOlWkW103)%L#|I}6mL2RH`cb}=blJ};Wru}S}qgCr`lt- zI0KHjuZ8Lj&LpUO1zCBVa=VwFbv$mBN>`7)N=skm(oWH1NW=3INLg3{vtDL_&!w%f z^_m`<+b@T&>Fk?7y%pSq8L3=R{53c~T^B|BTjNcQ0AAIimcRYEiJbj81mBLm#wRIU z=6+n51>2jah%dXJ0KaPrn2>Nfb#cv#1Cmij+z!6-lJfOhnW@Yn3QtA!R6;$4p5G;XA$m$M=2z_dL&j z?sM}onM&>mzxIg?(tLwhygq?!ehYg=-Mj8bfcdL$jlV;zk21+PWLve=mmV+w!2| zwIsE5XF{%3E*q?%%Q6e}vFe!wq2*3i*_p#ehI|5P+Q;G>8`$JKk+_W2z_jXHxJNk! 
zhn-B~6o37|g0Jy#`PWg3Re34ohde1O!VD)yxRL+3X>_G#33)9)PKoybWq)ds%@cEK zcPT^ZzsKqB$NTV1GmTxo!^0%0hdAB*6Ry#E3GQ3;m}#LslfDznC-<+wGFe3~%Ha%- zUcQJm-pgj~U;2^z$@Q@BpdOdJd@=lx4q(NHM?!`34)NgU&Q`KvdszRFPHx}SK6q-i z8;Qqx)8_gHTF<7?)HH9B?juY3oqjYeFoFVR?ZS_rq{-^aOU^;ad-W~L=flEJ;xrvY zrY_NkRS%gBLdKkG>dmN+&P-+@;K$Fr7(!=fMo?ne1h_k66x*J5n5pit;*Qyd!qK&Z zX`0j%uG7`xDJu7o&^AHJr2Y6~FIBEzU}MCzhXJA>f*}5Ay99MlIt67YfXcv~eImR=kRDPh}r z+p=LOxqk?aI{FvR9q?fF_6q!I%P7vTQ;rQEx(hM{zQZ-;$)MIcoIRd1nt2ZUjB@Lg zDSGX0yb;+Ad&25aeuf;b*GfhY=b1RxXFj=X(ILMppFnhOlx2R4I{R>FB|rFbCJdC4 z<#n5?G3tFG|Lj*cjGiyUws{H=QENIo-Byiul6rJ%%u31+i^I~p8L0js10Hpm!t>xzj2Y!kF`C*` zXzl_N>wUTLxB5}tN(tt)Pz<=C7%f85al;k0I0W&n4nTZ6A^uBQO43Nn+r%DEl?ghh$Y6nJO@?UUR`wow-``1gMLW#mB( zCOa@`S0$QXlA*mChj6G?I6q{OJh^|ZgPl1XZMPEM7YjvZJW7`QHx_UoW9_-3i!ofo z=5CxDJ&yDno8bn`q6fAk#g48=ace*%?`Pw~zUNIL`-P*aR(Cu0GM6Ir`Esmb!XTRd z`aV_*80YIXm*A(}QDs1I2MNUVR>SR;?8;?`w!__-9=Hj9u6% z@Su9Dr^30`OP24cFPm+bgDGcS;n$h{n4L8M_Sfu!c0P-1KK)9(NLxqzQA3?&rk%$> zpW}J2M|R>@sUtufD$jb^Os1Dr+7NIc19G(LQ0Us;k2($qM#hda?N9pE8VX8BMb18(DAV zA{h6kRJ`Yn2^D=ar?&Bt%&AZ72ntplB$R-sfpzs*0QFqT<+Ay6btSO0(So~sX&he9YT`Oy z@?dzYH*b+5=1XLs;Ee=($hlqu$8=6(e&-$h&=5=^`Wv}zcQe@24{|X4(r6m?d#bkhO;EqTUa+_2|Mg@5Zq>{fchd;KBhK|>FzXzLC@uwd*Da>=4AtTRD$}Z$wHaSs-Wxd)$wRaAo*2(*jbh||0O%7(IWfs(AbrAMl z*T;7$M_AyLAetx4(W0>hyndB5`y2RIs=nybHDACxxRT!{Qzjj$ z!}HeTsN-7%w{N#1-qT2f^I^<=43>DR;TMfi`gB!gy}!$~PEbCBvpl zc5tt}bD?R=c*7iG2pUWu@5|B3-=}Et{WG+qJ2&^~cF&5L=CE~{cGWmf^-Z=68K zTjGWM|2VQOU4ib;BI(Bb5|WlLpwNACWbdg-hwh2+`92HE9FfI;?nkgK_z;)XwibGV`(oMR&ewh~yNq0i|d|PH7CSCEoXaiO1sFCq}d3%-~>D!5k=bdR&?*eW8mZUDbe2?jtMxSVLF#!hpr1h zuKO8oat*`t2^n1Hj!Ot@iAdu+%=gh@h5H#(xHpS@GNw}ers+&Zy9E-3oVS%ImJ%~X zB!4;sG3E~PHR&`rdn5TqYtxTY^Jz}n5Kg1i9*S;j@%rBeQdXuZJ*u6F^tFw`UKKgvrPZpXY@Z}BN*dvcQgEV?W`8;-2CVd*(ltgn$QE7x7gf|KIe za^*=Z|MNlo^CypweA@s|*SCld2>KH@`tD`1O6r&?y_k)zwqkG2>r&jHo$%i7E~@Wf zG(9dA^SzAmXjUP}oH$11ku`8CR)!XSGJ@P2%5SnEEOgB4gO14gMDa@F-hum98|ZW9G>X->&}*3^6!O-SJnzKty)HG8uT%j^^y{Ex ziKn#mU1x0+n(We-T}R7NdmW 
zDbSZ!CTE>SyzMiXF^v&y;lw+bWmbZ3WnAcv>os2H%T9c0?g}A&??A)Eq3pcKo`yEu z!R{mK?BLI9@O5SrH{50|Z4Exde_x$Iv2!$O*Xoh1saH9S(C~$g>7Jx{!whz>^kMTJ zE3t)YHX=jM`^dZWVjbJYbN%0~V4uZf)8 zKVqU?A-`bS7x-|-j2*Wd1{bs}sSjT(^0g3h-|-z#7JP>dUGo@OO9S-%q(l5N6{dP% zKeIf2o;#+bMXx%t#Od}jtm*M9_U8Lu$loeM5TQU`EBr+H_U>%eH#P2w>Mf8@N#c)3 zHQ?D<(HQ4`3Zs7f5Nhs|nd-ns{>QB}-uaw>AsSqSdG4QiUCxnIhW|#_E9La*W*g?r zyN!bs2=|O0#3Zb=NY`U2eb-l~&)S2?JM0FuZ%*ap($}!Y39S$x+80l(osDynccaAb z-YmM%fJKHnkW9Tan3{O z9d2rvh33w)ss7GX+%s_(nuvCic&{D=C{3ga=}Q>#wiYK|+kr-sGoa+)F(^KwOkr#w z^@a!F`0fLqyjYD1J7X#D)+c_;2ovV{r3_zx`G}8#)M@#^Eqp~Yu^g8Ns70xGy3U)I zalFVa*rUkoeak?!X^!}L`XO>YH;tZoNK(k+HkcXT!F8K^;o`mtbZnIq{?I>zT8Hmp z%i~mxIOaiB*@MWr&YC-7Do6QZC3@VUM1L$lfXb;ZzVnm?Z5?_W|K@CmGE>G(ooA5h zZ$FBEaE5&K&FPj(7@PE_7QH^q<9~L{#?Lz<5NuP~i*JK*sirQovN7aZG!tRC@lm!$ z;x?!^r?B1$YAjFKcd;JUw_0670`(CaA<`Y z40g50us#dehpn447c+v) zSgWQ5sYb|PYS}Tm;Ot2jriLUZw^L-7v;m?Fvfz>2Ky=Sq!bhsN;tKzVFzEXrCRBdu zxY;Fa3EzXuWSsE6wiE@AxyI$yKH|K?zw^bpbEu`{J~!aH0-Jp4J=_=atxh*JpjCet zGq=43@3n@o{MZ7ve8fTSamsG=ZYaY!2Zc3vkqbV?Sa21vJ3qa0#V=k&aIrJTqmzFl zB)?rrTfbb!bK3jy(79m9w6tI&+;aJY;j*xUvP6C2a(k;;ova?_AGRo z_}i?L=$K$kzOO8)`odIdIQEp!(!7h|#yZrtF$=dz%*VqPiJakb3;b4oiyyOk6UMJP z!4-w?g!wws6lZ)0AaE2L^ko2eom@em-Bj3tqGhalYz`ZnXie8_uh6c``^$5bUD0{e zJKXiX3!Pv72FaJ1B0t&TnD*QV3#=01LZ}&kmRk){9W{LFkW#QU{>VvRsRV^z$1%mB zmMcCr4;RWO;a}qrw2NO)OO5+c@yiH$wzm;?EGh^2F;B!zzpZJ$=_0cIBgXD0f#Rfo zuJn9tFZ$hb0rBieQQIJG?s_k0kRB4nY{DhrW>zfL8Lq%Z8|~P{`7^o0<_dKGi!1F) zeTr{AbV+f5D;|m1Ld|9)sg(u8Ws6BzV|AF94H-+eVdbJz9}b|O{uuJ#b_RdT8PcW~ zPcd-RU>s&QhK>9?mmL_o3g!kVqJ3@y4!_q70l$R0x!fe&cix^3dAFm|7km0)CPVM+ z^(lI3JPizei)YWNi)*eZ;FaxTcqwmv+VDyNy$?Ba9^0fqA^kP%_uh}Y^_8)x_8DK_ zt1o+ZBvM>@@+-F8o+i>B=mOqcUo_dc9lg4*<8jEw&Rx#^8Vez;iX6CVoJzP`@TW>{Bl5GwaEm%jvzZF@5&vI52 zVots0Y!`H|Oh~fY0cFZ1NO?#mD*f7p#b!USA$L74S)C5DUWI7VTaTi-T5z+y%XJ%k z;D5)xg1GU;oJ(*JtBLN%B5X`(zUy|@{&oHw{r`B z+`v|ck?`yF2{^tag1!bG;0|f65x1r1^RqNG@i2c5^0dF<+O9aPPsqS_nPyRy*I&3f zQxa#VCQ#$&Gf*c{33@5BpuA=#+y5+?El4teO}S&q^U_)FQ*}RlGXF578h_)QzO>^; 
zEg3rXRSC5Ahr^Bz3ASWjBR)t8A)UXnoYLO4umPOeeYdq-cC{OR{wBw5&A-HLAEr(3 z8pTCVI z;kRIwu`V?HHDG3mi*UHN4j$k19PS6KhNok$u~jN+Ff8&G*MF8Q^R2zj=DlcVOS-;_ zyVM?Y7d&gZV>&H-=~gMqxY#Ok%()0j-!?#X<7qsl`jYb<^cXXIj)COwd7P_|ufHX^ znQhIjJU8KO5q|vci8iak@r*?U#I$K}wTWXPewiL!SUrZHFwp}$PK;y48(kq>wH~Ky zPv!fJcZQKKhtLUsKd@MP6$T8GqX^x*+&0rZZew>EY%=YJXCFu4?9ttD`&c5U>r{mw zluY2|QcZAkXvX{J*6<3sZQQT4LlvgUeMm_AQ?-W@Qw&VTGOJ0hJ7 zH7miC0r8yY{-vDdbOTtO7mRWTy!o?|H?Xr=_yHHq!-5-gpjgnSiro7HJ9VUJ!*@w? zc=8C=o-Kzr1J|HlrjRFcX~JVNo~&2>T=uLW9&`sq(v|mm^g^_oi}`H~ZRc7+>E;Of z9=MehHy*W0ekaCN3zINA;wJtOd(!nOD`4DtKkh=uoQi_GvXqhKP3dmUT-GXGbXwy< zIkCU+uk0=Eq3d8+aH<`wOTXeQK})S+^JTpJdj^~S-jl)^I6cp@5k@fLRJT z_L?SI|4x(Dm&bs}IDyk&6Ha~h?1uUW`JjGLO(+y=uz+%7id=93m)~?19eQX3 zqtk(wziD*LZT4 z?FO(+HEtapd}EAf2ttd?uFS3n>>` zZR>);3u-D}ZPdo`Z-!yMeEW5PVXw^AS;3iF|Y40B8v+y{$;%xRZRj40@11hyDogyDVH z;((n3R=-Ue-Q)J~n}_&tMyu`M`MinPFHM0mm5kB8K@QIc8KCc+cJ5_TDBB(|jRJ#q zh})aAnctxaRG?}jRvEpH-?P|(S*9t%R-K(VLEi)Jyxk`73WqY;VK2~XrhqdM^4kT4 zvAFBYBKqjLA1hUUauYkI)09tsbnf7L{9OJRgC{U1CAYYe z0mtFvsvYb`bZ>SsQi`9G?!?J%xCwnN*D#UwWo&NUNj9kh4kBbT44Z!u$NaRRxraXB zg~fxZ`jYUQwCEhB?oHwJ9$w-9l-A> zlfsGArSC^`7B7d7l#I_?8bxZ-Gg(rSvh`}WsVq81j;5!iuu&g(u{rr!Z2FmxQ1z=3 zw2ERI7U;9HrA~O{u^4qnElqP}C~;m*b|lup#)Dn|#im_;sh zYR)dWx?UO{b{}K8{WMwU(nMZv-BhOj+>8b+8^#7c{eYxoOhu!7L<;_I@Q$G*w5^n< zudN<*;E)Vg+B=P(eMrCpT~a5V(f|Py{sZo>tb!qxgCR6~FLtVC&;^|;uov_SaMKTc1c_oZd=hoCliDLf3D3~Rm_Ql;k@D$m}7uxuxeJ zZC@=!$dyCydA&&ORII2~eE=G6jpanz#-yepO?~F*h?Yr9lisUn`qO(Ebq7zw6>}Zo z_z@=zX^>}wmghqBhP`;_7Q)!pxiGi>HoW~Xj*I{D2@{^psDpwmQRAW&r9*HrarYRsbG4=6VdRJnY6&T7{h%p@WZxM;?mo@QN}=tf|c*0!IY2C zs9Oxx9wFS?0rGTihXE>_Q(#y6J>c^7M{`jrKOorY2S}fY#qL=S@WQVRTI$n4z1SH% ztP0R_Ya4oW&fqJ)|3I332JN4ZV&4rU$#IxAyODGYZJ#93i*{$SIy8VSNdAO2vxf_H zogfl9^ryO}Jg^&kUDW^WZ!THjNjtkv_WQ|oJw8lR z3=2rjw~V5mq{7Mi87yby7+Nd*AN0E9O@8eutZC>CHsog)#ttxp&QrNe)2y6nXhl$d zOEYL1KH-akoavd?H0tdXNs|mOL-Wa%Y|OsB(5vD%%s=u4%ji5MAneI(jW6${`SUiGo;ho?7gf{|y^?)a9Q3{3iZP_qj zsslD1e$RKR48hc>Yka2EdwgW43-mw=PkaxdlSlhur}||a@gxSE^y2vFvGuU4<1UKy 
zg*kGE3jGl8rmmbU=HNLR6Ha*1@(Ke~dESdu5<+qJfXUEjQW>19YliZuK=#^rBwbWr zNryFukZ$?nCi=*>k9s{ioj0sa6Y21C7n-o zB3_T9IcuJv-G>aW-!~QZ$=;AHNl?WpDYwztB?K9oV}lfmVnZd6AEZd+;6a}p;%I=j z71hbs;Ak^L@Lr!_>8enU<+Wv4eY_Z6NR{#XSuOsyY8N@y7{mBvWx5kPk?y^ehkFt-NdE^lPdOhjyQeLDYj_%?E^3*_A^+u>%X4wI z7&gd&CT)sA$DOILbtl8F@sBZCD^m0+_XIo-9|OI0Y?yK7RPt5PAd8t-`IIYj=&tY{ z&tL8XI%l8btasyCSIhx)emWV>Jeo&sqp#uh+rT8Hfv%e8__~KVR?qi-X z;<66ueTc;F&a2@3rIULw{Ib=Qbx}-p#A3KUIFKrW4vF;tn$qEK&dhXq92{5{W!1-h zD}0&q8%>U=l3s}h-CNF)hM>9E6ktsY_L!2KZ?x#mIYo*$TYwVYVy?Q@2lwn6&utX_ zMT1Ixj8pewW1X+U!uCR_Q`#o1aqi)R-&R%wUZyb5WtDLAFNZ1)Mcjg_9DzSy$rMI= zfT{l&(u>H2{7-M7?x+?EYFWef?zLts6Z`sdz{zI4?MdT z^MW+lDlc<(#bpVq8aUF*xL2tC`W6hEvmGet9H-&oh@FRb;bUt(vXDw=ETjtVMY{`l z2RG{1Cl>Q|R58n)Nw9gwd2By_7S#nUq!x$6Y{Ot(93*hZlk(5t2**6Ao?^~^+a>b{ zen(Na#tiBpU7FE%x}=>HR2Uf5G%)f)D>zAUl&C{W%x{Ceo~;3^k4WM>H$<=`i!D~VZhNRBvJj7_rLn?nMlpaXmQo^7j=DnyHJ;_vgF1|66gCp%TK*O^P zJu434?dblbEND6!c?}24v(jX;=@5U{;|xZb$$(MLOWeENm)kP|$Wv=MeH(IwOjcym z+)-CCwRIBrbXEzLJ=jX4C*I*-RO*mR|B+nhfo)u>wJG!~Siw)Py2B+1{0gmprj$Kb ziSHlq0yohcxG$^?w~GWW`<|WXKIkPXKe~xe5A>&^Q{V8*fn}s@wVsY-mZD#w0*x=L z7OhCgBrSI@{%n&BbY@@R!}Kr1_eV3S3XH*gY=0J|$ipIgYbuzg&dx*|GGCNsTID+2 zx&nPFUsS=~ZlA%d23&-2&R?uJPMstxzKKQ%b=ruHPnD?0M~f>!9~I7ydtg?jG&(9#5YCtN$$svriS; zFrlfxRpDFr1(U9wRONSz9XM>F3ag46SeY^R3pwdb5%`&ttltXSCo8V}B9EJX5V$JR zT0>IN`mWA1D)82&-eN^-m%Rh6=e|iJuOC{KT_xI(5wwI^*7?vaw@&i%N~NNCx_IQi z4;$(Emqot2&-ykyzygb#I6%-xJgc85uI`kzE{&^aHtRI$^`u-fs2fhx>d&!(*NEB8 zQmM?^Ttkyu$I;=)Wo*EL0@n1kno7Hu)5*MOv`Y-+|2|G)@p9GtE7`lSswILAwJK-t zo|#x1g%z^yZB2}muVHKQ^4Qz^8<>J&JIzkaziYv2Z$S84Esi0gi7XrkJ~st%{h)!%j<aewA`> zDOszh%%gnwHhS@9F1_Dd!6q;K!92vbXjxSYNp9Xk@j;DjZbmwj3~HlW2QtWW;W->H zaP(Te+F)Ra3F#}Hp`hZy)V1R@i?L;Fz}O_%8=$~AL%1Q zyC5}+1yxtEk7pg(PwyjaxQUAOoTASdcG8I@@p9~Ur9Iq!AkTiRd&4DcQKW9q@AzW! 
z6nc~@Q5huek1mVsNb^-Q#av6GLGSFSXZ-yecmH@?|Hj?F$N%=o=;`-w-1c}Ynf{m8 z+Q0ohege*cf*YXV2GZjl+~Wb=?tWq!e{LsllH@ lKL#PLU*Osm|CxVHjm(XDyc2u+*iUnJ>FL*VpFO@F{|ADG#i#%P literal 0 HcmV?d00001 diff --git a/examples/webassembly/src/upsample.ort b/examples/webassembly/src/upsample.ort deleted file mode 100644 index b3e43d0099baebc9f777844560edea37aecf3dac..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16040 zcmb`O4Qy0bcE_Ko%es_Omt`q+x71--)=iP)PvTe+qQs?c#8Qi;krhRZjX&1Tni+RI z5HE^Sgd&7cg-{eB6s-_KD}+#lA`~ICil`Musf$ud5lY>ZvXpL9H%%#}DN9+F>F@vE zIi5S?Ng1SGbhzidd++(5bMCqKzW3&NSrkQov~}CJ7e}?x;;06xnmUSU@)x3L@A@dJ z`-^Zzf9ZL?V8%3g(UBVwm~u8)}e{f zJwq{GG@Z|HE78^cFwymO+MX?A%fZWuEq(ux&u^fQA|pG05t8BwYPai6bP=j(B6d8h{La|ufSG-sejk|K8s~9Z^EZiY53uc8fT>t)8%+;V z#qnyOR$P4Cv-xBmyQpRFW@_o(MlBl~{<7FE&7H<--aDV`pU8DS6K$rCVxb&~$F>yw zyI^*W(`S};+2g*Sx}MJ;;JVwQ>*5us&TsD;>wG$jPSJi9S>t5iYDhUGx$a8hbB*Pw zY>9L$z*i)TR#MG$# z=?wfe=qFI^Gs&FLg;9m^$qKHHZ9d)}MQ<~KQ;>XE0>MVz1m-2$6rT+w+#m+M#02M< zh~p&pWemGV@E>68GN_4u>N`l^5&HJgcapw8r0-Sw)<9p-Zw1swzm-4D=jF!=NIr5o zQeLZQP9-Y_UY63o2SL;TDNemmAA_EyvvmD3G=tBZsmCDcknV#YCHX5})IT|@Wf>RE zvHCe({qg0&6#4axT?q|$tv>iB7G*gDAK&WPJIA6MSi?Vt?}}!9={GF>zXj? 
zsqaxwLpg}%7~>R4F8$K2eL?%|@K`3JiO}`*P=9u4BA42k85}~4W=i^NkFq|oKJQnH zHjO9Tl@M=3jC8RcL5(U~JdmP%5)4H!+H-$+;^0sWo%rh063!XIWWzxj>Cr(}z#(TaRZuHh3{XMUE?%g)fJBo_*?Q*6AP zp#LC(h|tCL2(|V_-LpuSa=V*7RDNj<%QiwxSwgMsRn)5E{ljDA{;AGo_>ZD-P9LmF z_;J>;7ul&}7f(QoqCF4tdD$BJ;`-GtU0OF@()1?{;(b}ZbU->+s_|t5wRCCxt90@< z#JI>#w@vj-oz49>S?5fo)+xt!+*kjtg?jq4xlyf+p~2_+$HtzT82Xc4L+RXT|JYNz z##0mhJIBXH)8B0wd}?Ph@Okd%WN+wCI2Xdj)VZ>r{-Hlzv~)vl7_%9YF6G$U?9<+c z7#FpL{)ksmy>@kivw=yx`XqZ(JvKF}H7TS~<-SeivaI{T)P(NiqNtI3KKT&(C|kS; zU$OMEivG$sk>*UbqTIuTsFp6}-!yx?w@C-4Qo#sm7}=mxCMFaHQ*< znxvm{4{@&O=BMI0j_;RJDdklQd$oKd%yam%!;9uDENkfB#s!gLpjvu1W1MQ~lARx5 zkhdYmMc18XN$$AcQ%P>bCRsmKn&QXvN23-IUEQTEj_->U%W!T9-CPlNv2QD}w~4jV zV0UdAs~tm6Ue)n=6Z+l7_-<$$q`4QC^nKJ+n0v&yl;*czmkoVOlb9`IemDFsXZCL; z`SM)PAOHU3(7-M&*`91>H+ie$HeN^ZV~8uER?fxO*o@(2x8f=Gdd=vYjdnCZ?P%ZR>Z!oveyQ{DXntFn{fNs=NL)U&AAE5^v!ndK< zp?c_TZ10Ayas$@~9c904fZk<)+5mm@8|=%_=S)f;^nkNM=f&5dQ`}g0K~rDDF6hQG zYUoYH1-eX}BIx+9VLvp(G8%+tIM@fFS?;ZJP%ZQW4(}%DgT>ejU1MU#p$8m52cTup zae}rMI!{t`K~v=8W@v^*ISAcm(q4sbS5rf}mwuDvu7}=cvO6F>|MfwiW7;G{muTPL zCg*fF{GE>z9=2*09`$jbOYrK;&{)21@+4z*S#4ejTa?M>CPui$8caQ%YEHF5$t`b7 z2N(g~4&xb8)eG-<>)}JZEhf)WtFP?vG9Rxw)#7+hN6`*$?g&?$`&Sq*o>Dap?|5fS zzx#K=@+`IbR+^8^__yQl3-P?dc+Zz4$L3}Pqm z>)h|tjCa3}!#m!7)9-#CF}aDckAt_4IG+yi>Wr6;T6mB1-4eWow8`J&)av`Fyq%j@ z&WSqS@wtBo1H8Qf-mvL+|8|>P%h*cr{G4cR<^4~9cR#?p9pHUjg7>M(d&}T)+fZ^2 zHqY&Iync%A8{P=~%f`#*K6v+M6TIRl(%!!sT4wSNYkT*<^7*v!Q}BC?{|J7+@iXw; zLKHop@A+{dzMs+0^XCS<&*z6W-tGC=t8xN4iEA z1H4Y-rK1Vn*W+^d(7!b%Pf@FHrR!f$@SXoy@b5I^-M{1T?%#gX@BSSzxrwpDlPzB6 z$Fp$X)~fX%;MEz=*s5B1_wQ~A-a^`hahzIxL;s57S$rSBF%s3?EckaYz}pM&{tcUc z_iwk!wTu;Bm@mb6Nxt~{=a`HR26*=ayxRfZ$0c~5n!MNW!gzAaU$Xv*C)e%-I6I7& zt-bK>*LwIcPqvsmORc`4U-Rd!unyV)-0@F^I9_4A=gBm@)tN_cB&^pR&Z8TmwI^Z2|C4A^#qsi0MrFiAyr~|k2Ck4MwGtT`w4(~YoO|ScP#N;N% zmg4ZpR`PycuU(s>69LXr<2AMh-f?b~;5;<>Xc?Sx@teCpBz{Z(Iq{_nkbME(WPrCF zK8#=5o9VHNuDftT%avT7A{#9~aT% zeonzV?n%ROKhK%m!&p=l;cEZ*Qe4_D26&yuOGgvD?~lvjLqFG;JVjlK$1TZY-iPvj 
zj9byI|5EV#ee}BD=i%M&V}|Q~pE0?cv9d$gy_B0@-#43|iK6!dycXjbTh##XcuU|z zzgL=ko?3mwb3yN<=H3nA&#KviKX1_A{W$>d`qQS*{dvvg6^xbsFn@T364~p-7VRVZ zOn(>rok0I=p#M^d{_7^sP>1?s+x62f@3oq+{krLXhfZ6JSN|?}xAjSQ*(Vx-HbV_2 zZ?d-ScVnCd>LZ_brN5U-wYEmX*8W=A_LgYBweO{7SQW2=;&RNcpDR1AIsWZwoBxc- zuCWeZPXB8Xd&Q~Fc-X31c%O&6rq_L2NSkbqsMXi=>i^2AHganDX9<>mU)UdDy$bKT zcbIOMvnJpD&jd?0Wiz8#-x3X*zS}I~j{|)l1^Uj~c-cD!bwjHm$(O8c1U-mzafFXP zv$h5I;T?=WGO~O~N}E;O!0YhE2cY?KZiVu?#E1n|~j{ng)P7u-w}9=b(1{xf0-QG+sJ7 z;2m!zeCSW3$&0+3VCvt@$TO#c*i?r`rW@{CU-El6t7&K027@y9L-am z@fuqT?>Ki$a2C=gKO<`O^*9#hiH}3{Vwb~y(SH>D*c0d{wvDbl(HZ{k&dc$1Rf&m>uq?>?jvc-P_y`a9$2@`rsXB z6MUE_FPMCrS~kV86$oX1UIx!4JW#%TyAa<`81H#<8Q$^EnSPJ&hbH$i7Kpj|E00&W z?`li1F^w`Yb5}SET6wv$=8pNc zG>VQe)_wakc=si1x`j6a$*y`x@*!(GK#e$eexJ1K7dQjCUUGJ~Mju)`uf58&w?x_d z$v24l9d~hjXI2ycjyqphbM%?}K;w(X%a$H^kJmc*FkYKXzCo?N9xwPZ`Cf?GI`Z#D z-uWp`-OxqGd(5Wb9q**+_ZXftxrea~dnBIk`|HTRiveDz@zT)*?|94ML;u#8toN4c z8}46vU#V{f^u3C9rYE^kaev;Tzx#6--uK7drqBJ^Z*o0jOZ6A$T=Dy-b`tz8!&4qD zGG00&c=zYi61=-6-=REVyfb}Z8`?14xOV;|*7if#v?KKB3*RUr`rc~r7(9s))-3_m2I?tzj;JKgI z18n(s%j5&eGu)h?G?n$$^QjfU1^<&^2rnDpjletJ7I@*f@7qnjN3F|ojA2yf$F^Gv56>5AS%#Ouzeg#^i3sR)XjEL9N_Zz8~PV7%v?S@Q$|x zKJ;&;$>*unx8nI@ziatf!M_uXcmIyUJKk%i-~Bseasy*4@z3Y4jfJmoSUsLgjhBuZ zc;D}CmEb)z`KaOfe#f|y{473q+sK!t4+?nu0=&rpZ##S#&$P)kjFoK_<8jXK3-Im) zc(d>x&r2nE*G--@JjL0|{PkCOE@&e^?*w=|jAv|BFTCTehY$VRV)87t`c|4RZTkM> z=Y@D)VZ7(dG`!=TG5ww|7fkMDtne!Jubuq765wq#UOGD99d9Lk=wGAB)6|vV`Tp6) z0lM*F!N0SNcmGboJKiDF@BSS#xr4Ef6VG<`&$9uZzOQw>74VMtumo=jZSwaN_2c;0 z&U#qIIqmE3jR5aJfR{G??%!)Budp-zQS;~Py`B7g!|=K=bb;}*`Cfo`y#();$+Ohz z>-jlTBSgtPKVYywr{+P3?aPLvL7pdW^ z+|WLskNNQ_e7m@gT+?qAvutLsL;CwjeZM&lRYUL4_7Y@1M_b4_*(lwzVFjcjIV?{g zRC2$kH6nfTQ8B59gynmy^oF|2`QqzT&!e~gAMMk9+In)$?OThisjP>*_-`FTJrlEu(6O!$py7^`fzw_t>V1#?f?cD4XsdYfQ^{%J~u;#lvwz{d^)6|NVx> zNT=q=`|ZRp?Ul(6>(6n?@1cCIuYGIxP&UWkSq^72JL_{h^@l*a`OI8@pvPD*d Date: Sat, 27 Apr 2024 17:47:02 -0500 Subject: [PATCH 02/49] fix: stub c++ exception handling for wasm --- src/wasm.rs | 6 ++++++ 1 file changed, 
6 insertions(+) diff --git a/src/wasm.rs b/src/wasm.rs index fe8bdf6d..51cf82f3 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -212,6 +212,12 @@ mod emscripten_shims { let c = str::from_utf8_unchecked(slice::from_raw_parts(str, len)); tracing::error!("Emscripten error: {c}"); } + + // despite disabling exceptions literally everywhere when compiling, we still have to stub this... + #[no_mangle] + pub unsafe extern "C" fn __cxa_throw(_ptr: *const (), _type: *const (), _destructor: *const ()) -> ! { + std::process::abort(); + } } #[no_mangle] From 80be2068296b04e8ae0a442d4b9c2d635362d07e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 27 Apr 2024 17:51:56 -0500 Subject: [PATCH 03/49] chore(sys): update WASM build --- ort-sys/dist.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index e283a36e..0b27077a 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -13,7 +13,7 @@ cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/ rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_rocm-v1.17.3-x86_64-unknown-linux-gnu.tgz 50E39B38484A0676B3D24365149CE9E7760F658B552EFE5A9382AB503D73D7E7 # todo: update WASM build to 1.17.3 -none wasm32-wasi https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz 41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089 -none wasm32-wasi-preview1 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz 41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089 -none wasm32-wasi-preview2 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz 41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089 -none wasm32-unknown-unknown 
https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b2-v1.17.1-wasm32-unknown-unknown.tgz 41A5713B37EEE40A0D7608B9E77AEB3E1A5DCE6845496A5F5E65F89A13E45089 +none wasm32-wasi https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b3-v1.17.1-wasm32-unknown-unknown.tgz 5D240CF7A0E92B6B43E94014B399027110449CCBE43E30AA66F7A1B5BF7425A6 +none wasm32-wasi-preview1 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b3-v1.17.1-wasm32-unknown-unknown.tgz 5D240CF7A0E92B6B43E94014B399027110449CCBE43E30AA66F7A1B5BF7425A6 +none wasm32-wasi-preview2 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b3-v1.17.1-wasm32-unknown-unknown.tgz 5D240CF7A0E92B6B43E94014B399027110449CCBE43E30AA66F7A1B5BF7425A6 +none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b3-v1.17.1-wasm32-unknown-unknown.tgz 5D240CF7A0E92B6B43E94014B399027110449CCBE43E30AA66F7A1B5BF7425A6 From 9b8bdf9065570f2a94d23e6954eeaa1d88357f02 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 6 May 2024 13:43:18 -0500 Subject: [PATCH 04/49] docs: update readme --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 262f3dbd..1193b6ef 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,6 @@ - **[Twitter](https://twitter.com/)** uses `ort` to serve homepage recommendations to hundreds of millions of users. - **[Bloop](https://bloop.ai/)** uses `ort` to power their semantic code search feature. -- **[pyke Diffusers](https://github.com/pykeio/diffusers)** uses `ort` for efficient Stable Diffusion image generation on both CPUs & GPUs. - **[edge-transformers](https://github.com/npc-engine/edge-transformers)** uses `ort` for accelerated transformer model inference at the edge. - **[Ortex](https://github.com/relaypro-open/ortex)** uses `ort` for safe ONNX Runtime bindings in Elixir. 
- **[Supabase](https://supabase.com/)** uses `ort` to remove cold starts for their edge functions. From 1c0a5e4145ce68780b0a3459f4d7b7bd2cc96f82 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 6 May 2024 14:08:16 -0500 Subject: [PATCH 05/49] feat: direct downcasts of `ValueRef` & `ValueRefMut` --- src/value/mod.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/value/mod.rs b/src/value/mod.rs index 61b49d18..868403f1 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -178,6 +178,14 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { ValueRef { inner, lifetime: PhantomData } } + /// Attempts to downcast a temporary dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed + /// variant, like [`TensorRef`]. + #[inline] + pub fn downcast(self) -> Result> { + let dt = self.dtype()?; + if OtherType::can_downcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() } + } + pub fn into_dyn(self) -> ValueRef<'v, DynValueTypeMarker> { unsafe { std::mem::transmute(self) } } @@ -203,6 +211,14 @@ impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { ValueRefMut { inner, lifetime: PhantomData } } + /// Attempts to downcast a temporary mutable dynamic value (like [`DynValue`] or [`DynTensor`]) to a more + /// strongly typed variant, like [`TensorRefMut`]. + #[inline] + pub fn downcast(self) -> Result> { + let dt = self.dtype()?; + if OtherType::can_downcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() } + } + pub fn into_dyn(self) -> ValueRefMut<'v, DynValueTypeMarker> { unsafe { std::mem::transmute(self) } } From 1d8b815f4fa8f08b3fb268b2efc59fc16b5773aa Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Mon, 6 May 2024 14:09:18 -0500 Subject: [PATCH 06/49] feat: implement more kernel attribute types still need constant inputs and functions for getting I/O lengths & names from indexes --- src/operator/kernel.rs | 77 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 5 deletions(-) diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 733d6158..61e98107 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -1,9 +1,9 @@ use std::{ - ffi::CString, + ffi::{c_char, CString}, ptr::{self, NonNull} }; -use crate::{error::status_to_result, ortsys, value::ValueRefMut, Error, Result, Value, ValueRef}; +use crate::{error::status_to_result, ortsys, value::ValueRefMut, Allocator, DowncastableTarget, DynValue, Error, Result, Value, ValueRef}; pub trait Kernel { fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()>; @@ -25,19 +25,19 @@ impl KernelAttributes { } #[allow(private_bounds)] - pub fn get(&self, name: impl AsRef) -> Option { + pub fn get<'s, T: GetKernelAttribute<'s>>(&'s self, name: impl AsRef) -> Option { let name = CString::new(name.as_ref()).ok()?; T::get_from(self.0.as_ptr(), name.as_ptr()) } } -pub trait GetKernelAttribute { +pub trait GetKernelAttribute<'s> { fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option where Self: Sized; } -impl GetKernelAttribute for f32 { +impl<'s> GetKernelAttribute<'s> for f32 { fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option where Self: Sized @@ -48,6 +48,73 @@ impl GetKernelAttribute for f32 { } } +impl<'s> GetKernelAttribute<'s> for i64 { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + let mut value = Self::default(); + status_to_result(ortsys![unsafe KernelInfoGetAttribute_int64(info, name, &mut value)]).ok()?; + Some(value) + } +} + +impl<'s> GetKernelAttribute<'s> for String { + fn get_from(info: *mut 
ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + let mut size = ort_sys::size_t::default(); + status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, ptr::null_mut(), &mut size)]).ok()?; + let mut out = vec![0u8; size]; + status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, out.as_mut_ptr().cast::(), &mut size)]).ok()?; + String::from_utf8(out).ok() + } +} + +impl<'s> GetKernelAttribute<'s> for Vec { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + let mut size = ort_sys::size_t::default(); + status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, ptr::null_mut(), &mut size)]).ok()?; + let mut out = vec![0f32; size]; + status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, out.as_mut_ptr(), &mut size)]).ok()?; + Some(out) + } +} + +impl<'s> GetKernelAttribute<'s> for Vec { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + let mut size = ort_sys::size_t::default(); + status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, ptr::null_mut(), &mut size)]).ok()?; + let mut out = vec![0i64; size]; + status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, out.as_mut_ptr(), &mut size)]).ok()?; + Some(out) + } +} + +impl<'s, T: DowncastableTarget> GetKernelAttribute<'s> for ValueRef<'s, T> { + fn get_from(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Option + where + Self: Sized + { + // TODO: This should probably be customizable - docs say the allocator is required for "internal tensor state", but it's + // not clear if this also includes tensor data (and thus it should instead be allocated on an appropriate device). 
+ let allocator = Allocator::default(); + + let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); + status_to_result(ortsys![unsafe KernelInfoGetAttribute_tensor(info, name, allocator.ptr.as_ptr(), &mut value_ptr)]).ok()?; + unsafe { ValueRef::new(DynValue::from_ptr(NonNull::new(value_ptr)?, None)) } + .downcast() + .ok() + } +} + pub struct KernelContext { ptr: NonNull } From 076ddf4861c6b9b61ce7ed9e0ba026f7bd3c6465 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 6 May 2024 14:13:17 -0500 Subject: [PATCH 07/49] fix: kernel array attribute size type --- src/operator/kernel.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 61e98107..1801f92a 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -66,7 +66,7 @@ impl<'s> GetKernelAttribute<'s> for String { { let mut size = ort_sys::size_t::default(); status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, ptr::null_mut(), &mut size)]).ok()?; - let mut out = vec![0u8; size]; + let mut out = vec![0u8; size as _]; status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, out.as_mut_ptr().cast::(), &mut size)]).ok()?; String::from_utf8(out).ok() } @@ -79,7 +79,7 @@ impl<'s> GetKernelAttribute<'s> for Vec { { let mut size = ort_sys::size_t::default(); status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, ptr::null_mut(), &mut size)]).ok()?; - let mut out = vec![0f32; size]; + let mut out = vec![0f32; size as _]; status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, out.as_mut_ptr(), &mut size)]).ok()?; Some(out) } @@ -92,7 +92,7 @@ impl<'s> GetKernelAttribute<'s> for Vec { { let mut size = ort_sys::size_t::default(); status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, ptr::null_mut(), &mut size)]).ok()?; - let mut out = vec![0i64; size]; + let mut out = vec![0i64; size as _]; status_to_result(ortsys![unsafe 
KernelInfoGetAttributeArray_int64(info, name, out.as_mut_ptr(), &mut size)]).ok()?; Some(out) } From ba0867dec6d2e8243c7237d33b435b085ec4548e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 6 May 2024 17:21:43 -0500 Subject: [PATCH 08/49] fix: exclude null byte in string kernel attribute extraction --- src/operator/kernel.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 1801f92a..61758f6e 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -68,7 +68,7 @@ impl<'s> GetKernelAttribute<'s> for String { status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, ptr::null_mut(), &mut size)]).ok()?; let mut out = vec![0u8; size as _]; status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, out.as_mut_ptr().cast::(), &mut size)]).ok()?; - String::from_utf8(out).ok() + CString::from_vec_with_nul(out).ok().and_then(|c| c.into_string().ok()) } } From 9c80410f6f19cab618b2eb22fd6f21e2f65e8130 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sun, 19 May 2024 23:13:13 -0500 Subject: [PATCH 09/49] fix: bad cfg values --- src/execution_providers/cuda.rs | 2 +- src/execution_providers/tensorrt.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/execution_providers/cuda.rs b/src/execution_providers/cuda.rs index 6881068b..77200e3c 100644 --- a/src/execution_providers/cuda.rs +++ b/src/execution_providers/cuda.rs @@ -171,7 +171,7 @@ impl ExecutionProvider for CUDAExecutionProvider { } fn supported_by_platform(&self) -> bool { - cfg!(any(all(target_os = "linux", any(target_os = "aarch64", target_os = "x86_64")), all(target_os = "windows", target_arch = "x86_64"))) + cfg!(any(all(target_os = "linux", any(target_arch = "aarch64", target_arch = "x86_64")), all(target_os = "windows", target_arch = "x86_64"))) } #[allow(unused, unreachable_code)] diff --git a/src/execution_providers/tensorrt.rs b/src/execution_providers/tensorrt.rs index e1df3ad0..fe581c34 100644 --- a/src/execution_providers/tensorrt.rs +++ b/src/execution_providers/tensorrt.rs @@ -220,7 +220,7 @@ impl ExecutionProvider for TensorRTExecutionProvider { } fn supported_by_platform(&self) -> bool { - cfg!(any(all(target_os = "linux", any(target_os = "aarch64", target_os = "x86_64")), all(target_os = "windows", target_arch = "x86_64"))) + cfg!(any(all(target_os = "linux", any(target_arch = "aarch64", target_arch = "x86_64")), all(target_os = "windows", target_arch = "x86_64"))) } #[allow(unused, unreachable_code)] From 58ba9912caf9b08ef13c720760fe59eb087f2761 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 19 May 2024 23:14:31 -0500 Subject: [PATCH 10/49] chore: remove old widestring error --- src/error.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/error.rs b/src/error.rs index 07375207..7bdb2ba7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -180,10 +180,6 @@ pub enum Error { /// Attempt to build a Rust `CString` when the original string contains a null character. 
#[error("Failed to build CString when original contains null: {0}")] FfiStringNull(#[from] std::ffi::NulError), - /// Attempt to build a `WideCString` when the original string contains a null character. - #[cfg(all(windows, feature = "profiling"))] - #[error("Failed to build CString when original contains null: {0}")] - WideFfiStringNull(#[from] widestring::error::ContainsNul), #[error("`{0}` should be a null pointer")] /// ORT pointer should have been null PointerShouldBeNull(&'static str), From ce5aaba2ca319bffb9dbdceb3fbc29d5dd013a7e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Tue, 21 May 2024 09:40:05 -0500 Subject: [PATCH 11/49] feat: add telemetry option to environment builder, closes #203 --- src/environment.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/environment.rs b/src/environment.rs index 812c4c44..343cdc95 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -75,6 +75,7 @@ pub struct EnvironmentGlobalThreadPoolOptions { /// Struct used to build an `Environment`. pub struct EnvironmentBuilder { name: String, + telemetry: bool, execution_providers: Vec, global_thread_pool_options: Option } @@ -83,6 +84,7 @@ impl Default for EnvironmentBuilder { fn default() -> Self { EnvironmentBuilder { name: "default".to_string(), + telemetry: true, execution_providers: vec![], global_thread_pool_options: None } @@ -100,6 +102,12 @@ impl EnvironmentBuilder { self } + #[must_use] + pub fn with_telemetry(mut self, enable: bool) -> Self { + self.telemetry = enable; + self + } + /// Sets a list of execution providers which all sessions created in this environment will register. 
/// /// If a session is created in this environment with [`crate::SessionBuilder::with_execution_providers`], those EPs @@ -177,6 +185,12 @@ impl EnvironmentBuilder { }; debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created"); + if self.telemetry { + ortsys![unsafe EnableTelemetryEvents(env_ptr) -> Error::CreateEnvironment]; + } else { + ortsys![unsafe DisableTelemetryEvents(env_ptr) -> Error::CreateEnvironment]; + } + unsafe { *G_ENV.cell.get() = Some(Arc::new(Environment { execution_providers: self.execution_providers, From 17b4170cb081751f2fa1f1425ed6fb3ec6a6b81c Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 25 May 2024 23:17:43 -0500 Subject: [PATCH 12/49] chore: update ONNX Runtime to v1.18.0 --- README.md | 4 ++-- docs/migrating/version-mapping.mdx | 4 ++-- ort-sys/dist.txt | 28 ++++++++++++---------------- ort-sys/src/lib.rs | 2 +- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 1193b6ef..55cc4b8e 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,10 @@


Coverage Results Crates.io Open Collective backers and sponsors
- Crates.io ONNX Runtime + Crates.io ONNX Runtime -`ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.17 wrapper for Rust based on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference on both CPU & GPU. +`ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.18 wrapper for Rust based on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference on both CPU & GPU. ## 📖 Documentation - [Guide](https://ort.pyke.io/) diff --git a/docs/migrating/version-mapping.mdx b/docs/migrating/version-mapping.mdx index d2861ce0..e238b5a8 100644 --- a/docs/migrating/version-mapping.mdx +++ b/docs/migrating/version-mapping.mdx @@ -6,7 +6,7 @@ description: Information about `ort`'s versioning and relation to ONNX Runtime v ## A note on SemVer `ort` versions pre-2.0 were not SemVer compatible. From v2.0 onwards, breaking API changes are accompanied by a **major version update**. -Updates to the version of ONNX Runtime used by `ort` may occur on **minor** version updates, i.e. 2.0 ships with ONNX Runtime 1.17.3, but 2.1 may ship with 1.18.0. ONNX Runtime is generally forward compatible, but in case you require a specific version of ONNX Runtime, you should pin the minor version in your `Cargo.toml` using a [tilde requirement](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#tilde-requirements): +Updates to the version of ONNX Runtime used by `ort` may occur on **minor** version updates, i.e. 2.0 ships with ONNX Runtime 1.18.0, but 2.1 may ship with 1.19.0. ONNX Runtime is generally forward compatible, but in case you require a specific version of ONNX Runtime, you should pin the minor version in your `Cargo.toml` using a [tilde requirement](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#tilde-requirements): ```toml [dependencies] ort = { version = "~2.0", ... 
} @@ -16,7 +16,7 @@ ort = { version = "~2.0", ... } | **ort** | **ONNX Runtime** | | -------- | ----------------:| -| v2.0.0+ | v1.17.3 | +| v2.0.0+ | v1.18.0 | | v1.16.0-v1.16.2 | v1.16.0 | | v1.15.0-v1.15.5 | v1.15.1 | | v1.14.2-v1.14.8 | v1.14.1 | diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 0b27077a..6a622d45 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -1,19 +1,15 @@ -none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_static-v1.17.3-aarch64-apple-darwin.tgz 4D3EFABA9B329900B400570FBC1A1F72899149EB3756F9540772701944DC5E49 -none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_static-v1.17.3-aarch64-pc-windows-msvc.tgz 7BCECBBC15F64C631051C894C5044FB01658F171D7B5D4E9635D7D7F464B8B3E -none aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_static-v1.17.3-aarch64-unknown-linux-gnu.tgz CA36FB040F127C5CAFA081BAC713240EE8C3F4D9F1BD7B789B789FFDD4885F0F +none aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-unknown-linux-gnu.tgz 5337059CE144C2ACBEE4744E0E59644ED03196AF6423062C82567240DE7BE235 +cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu12-v1.18.0-x86_64-unknown-linux-gnu.tgz D37C85BB1CE639135B4C168DEC12120FDDC223D4F33193C11B7CFDAF755D4C92 +cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu11-v1.18.0-x86_64-unknown-linux-gnu.tgz 168478A99F4C514B1BD8A8C142ED502501AEEDA038497389BDDAC37B9F12ED77 +rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_rocm-v1.18.0-x86_64-unknown-linux-gnu.tgz D6113A895DEB0BCBC28FD7E23A201DE4C5FBA6BADEB49F3190A084A36C24B43D +none x86_64-unknown-linux-gnu 
https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-unknown-linux-gnu.tgz F486F4B9F040FF533DCD6B26E074BEB5F9092E8E4C67F72D08696D9EB4C9C082 -none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_static-v1.17.3-x86_64-apple-darwin.tgz 938EFA25B53283CA4768E6A779E8BFCA3423468C617FBA458A6E2CA04D2B2E3F -none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_static-v1.17.3-x86_64-pc-windows-msvc.tgz F28C68199A3CDE9AD0D748994894AE2920BDB0A258D3F390E715975FBAB5B4DA -none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_static-v1.17.3-x86_64-unknown-linux-gnu.tgz 6FAF334246A635808FCDAC6D5550C4D56814B1E92CCA5FC0642AF41437BF071F +none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-pc-windows-msvc.tgz 9A1BF23A73D680290B52C22AAD039B490AC5AAA66FC21C06343A41369747B514 +cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu12-v1.18.0-x86_64-pc-windows-msvc.tgz A9457AC9AC5D6BE1F98B3BEE3D6AF5C074C9984F7CC7D1E660EA8082EBF65D48 +cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu11-v1.18.0-x86_64-pc-windows-msvc.tgz C5C62263BDD82B58ED15A6467D0729B21F26E78EA0E49E1E5197ECBA80783903 +none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-pc-windows-msvc.tgz 08A22E94EBA56BF30ECBB2DC9DD9F90A4583C8372BAFC7FE3DAB6C28A06544CE -cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_cuda11-v1.17.3-x86_64-pc-windows-msvc.tgz 9AE21DECB9BE1270CD850276AAC1AB9C2E2AE978B563B733804F5DF7DACC3BE5 -cu12 x86_64-pc-windows-msvc 
https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_cuda12-v1.17.3-x86_64-pc-windows-msvc.tgz 8748005ED6F11A3DA741601FA41D0C6DCCD9FABF704A0853BD3D2CE82FBD49F0 -cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_cuda11-v1.17.3-x86_64-unknown-linux-gnu.tgz D4C264EA6805790D4C8B51D166EF6BD407FB3ECC641B33AEFE77FCD5BF0C6ECA -cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_cuda12-v1.17.3-x86_64-unknown-linux-gnu.tgz C23A69F709DF91D2BB185D76A29901BD4BE81CD66D85DBCFC5E8BCD851DE9891 -rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.3/ortrs-msort_dylib_rocm-v1.17.3-x86_64-unknown-linux-gnu.tgz 50E39B38484A0676B3D24365149CE9E7760F658B552EFE5A9382AB503D73D7E7 +none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-apple-darwin.tgz F8DB068DFACFE3B00B9F0181B79780C6971CD1A6EAEB9D9A7FC2129CEB8413A5 +none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-apple-darwin.tgz E6E0457CB9C727DBA818D10245D3A2A29203CB037546B39C217E4CC9FB61ABE8 -# todo: update WASM build to 1.17.3 -none wasm32-wasi https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b3-v1.17.1-wasm32-unknown-unknown.tgz 5D240CF7A0E92B6B43E94014B399027110449CCBE43E30AA66F7A1B5BF7425A6 -none wasm32-wasi-preview1 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b3-v1.17.1-wasm32-unknown-unknown.tgz 5D240CF7A0E92B6B43E94014B399027110449CCBE43E30AA66F7A1B5BF7425A6 -none wasm32-wasi-preview2 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b3-v1.17.1-wasm32-unknown-unknown.tgz 5D240CF7A0E92B6B43E94014B399027110449CCBE43E30AA66F7A1B5BF7425A6 -none 
wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.17.1/ortrs-pkort_static_b3-v1.17.1-wasm32-unknown-unknown.tgz 5D240CF7A0E92B6B43E94014B399027110449CCBE43E30AA66F7A1B5BF7425A6 +none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-pkort_static-v1.18.0-wasm32-unknown-unknown.tgz 8CDB6A4809652CAAB9F1AC704D8E549472836FC4A95187A63E5FE4E50EB9A15E diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index ee08e634..835da814 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -10,7 +10,7 @@ #[doc(hidden)] pub mod internal; -pub const ORT_API_VERSION: u32 = 17; +pub const ORT_API_VERSION: u32 = 18; pub use std::ffi::{c_char, c_int, c_ulong, c_ulonglong, c_ushort, c_void}; From 943a797d5560fdf8ccfec368739b0df72297f701 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 25 May 2024 23:35:11 -0500 Subject: [PATCH 13/49] fix(sys): update API to 1.18 --- ort-sys/src/lib.rs | 45 +++++++++++++++++++++++++++++++-- src/execution_providers/rocm.rs | 9 +++++++ src/operator/bound.rs | 4 +++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index 835da814..f5d3b130 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -298,7 +298,8 @@ pub struct OrtAllocator { #[doc = "< Free a block of memory previously allocated with OrtAllocator::Alloc"] pub Free: ::std::option::Option<_system!(unsafe fn(this_: *mut OrtAllocator, p: *mut ::std::os::raw::c_void))>, #[doc = "< Return a pointer to an ::OrtMemoryInfo that describes this allocator"] - pub Info: ::std::option::Option<_system!(unsafe fn(this_: *const OrtAllocator) -> *const OrtMemoryInfo)> + pub Info: ::std::option::Option<_system!(unsafe fn(this_: *const OrtAllocator) -> *const OrtMemoryInfo)>, + pub Reserve: ::std::option::Option<_system!(unsafe fn(this_: *const OrtAllocator, size: size_t) -> *mut ::std::os::raw::c_void)> } #[test] fn bindgen_test_layout_OrtAllocator() { @@ -522,6 
+523,7 @@ pub struct OrtROCMProviderOptions { pub user_compute_stream: *mut ::std::os::raw::c_void, #[doc = " \\brief ROCM memory arena configuration parameters"] pub default_memory_arena_cfg: *mut OrtArenaCfg, + pub enable_hip_graph: ::std::os::raw::c_int, #[doc = " \\brief Enable TunableOp for using.\n Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.\n This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE."] pub tunable_op_enable: ::std::os::raw::c_int, #[doc = " \\brief Enable TunableOp for tuning.\n Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default.\n This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE."] @@ -1831,6 +1833,39 @@ pub struct OrtApi { num_keys: size_t ) -> OrtStatusPtr ) + >, + pub SessionOptionsAppendExecutionProvider_VitisAI: ::std::option::Option< + _system!( + unsafe fn( + options: *mut OrtSessionOptions, + provider_options_keys: *const *const ::std::os::raw::c_char, + provider_options_values: *const *const ::std::os::raw::c_char, + num_keys: size_t + ) -> OrtStatusPtr + ) + >, + pub KernelContext_GetScratchBuffer: ::std::option::Option< + _system!( + unsafe fn( + context: *const OrtKernelContext, + mem_info: *const OrtMemoryInfo, + count_or_bytes: size_t, + out: *mut *mut ::std::os::raw::c_void + ) -> OrtStatusPtr + ) + >, + pub KernelInfoGetAllocator: + ::std::option::Option<_system!(unsafe fn(info: *const OrtKernelInfo, mem_type: OrtMemType, out: *mut *mut OrtAllocator) -> OrtStatusPtr)>, + pub AddExternalInitializersFromMemory: ::std::option::Option< + _system!( + unsafe fn( + options: *mut OrtSessionOptions, + external_initializer_file_names: *const *const ortchar, + external_initializer_file_buffer_array: *const *mut ::std::os::raw::c_char, + external_initializer_file_lengths: *const size_t, + num_external_initializer_files: size_t + ) -> OrtStatusPtr + ) > } #[test] @@ -3254,7 +3289,13 @@ pub struct 
OrtCustomOp { pub KernelComputeV2: ::std::option::Option<_system!(unsafe fn(op_kernel: *mut ::std::os::raw::c_void, context: *mut OrtKernelContext) -> OrtStatusPtr)>, pub InferOutputShapeFn: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp, arg1: *mut OrtShapeInferContext) -> OrtStatusPtr)>, pub GetStartVersion: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp) -> ::std::os::raw::c_int)>, - pub GetEndVersion: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp) -> ::std::os::raw::c_int)> + pub GetEndVersion: ::std::option::Option<_system!(unsafe fn(op: *const OrtCustomOp) -> ::std::os::raw::c_int)>, + pub GetMayInplace: + ::std::option::Option<_system!(unsafe fn(input_index: *mut *mut ::std::os::raw::c_int, output_index: *mut *mut ::std::os::raw::c_int) -> size_t)>, + pub ReleaseMayInplace: ::std::option::Option<_system!(unsafe fn(input_index: *mut ::std::os::raw::c_int, output_index: *mut *mut ::std::os::raw::c_int))>, + pub GetAliasMap: + ::std::option::Option<_system!(unsafe fn(input_index: *mut *mut ::std::os::raw::c_int, output_index: *mut *mut ::std::os::raw::c_int) -> size_t)>, + pub ReleaseAliasMap: ::std::option::Option<_system!(unsafe fn(input_index: *mut ::std::os::raw::c_int, output_index: *mut *mut ::std::os::raw::c_int))> } #[test] fn bindgen_test_layout_OrtCustomOp() { diff --git a/src/execution_providers/rocm.rs b/src/execution_providers/rocm.rs index 50e7de54..c2c28857 100644 --- a/src/execution_providers/rocm.rs +++ b/src/execution_providers/rocm.rs @@ -12,6 +12,7 @@ pub struct ROCmExecutionProvider { do_copy_in_default_stream: bool, user_compute_stream: Option<*mut c_void>, default_memory_arena_cfg: Option<*mut ort_sys::OrtArenaCfg>, + enable_hip_graph: bool, tunable_op_enable: bool, tunable_op_tuning_enable: bool, tunable_op_max_tuning_duration_ms: i32 @@ -30,6 +31,7 @@ impl Default for ROCmExecutionProvider { do_copy_in_default_stream: true, user_compute_stream: None, default_memory_arena_cfg: None, + 
enable_hip_graph: false, tunable_op_enable: false, tunable_op_tuning_enable: false, tunable_op_max_tuning_duration_ms: 0 @@ -80,6 +82,12 @@ impl ROCmExecutionProvider { self } + #[must_use] + pub fn with_hip_graph(mut self, enable: bool) -> Self { + self.enable_hip_graph = enable; + self + } + #[must_use] pub fn with_tunable_op(mut self, enable: bool) -> Self { self.tunable_op_enable = enable; @@ -135,6 +143,7 @@ impl ExecutionProvider for ROCmExecutionProvider { has_user_compute_stream: self.user_compute_stream.is_some().into(), user_compute_stream: self.user_compute_stream.unwrap_or_else(std::ptr::null_mut), default_memory_arena_cfg: self.default_memory_arena_cfg.unwrap_or_else(std::ptr::null_mut), + enable_hip_graph: self.enable_hip_graph.into(), tunable_op_enable: self.tunable_op_enable.into(), tunable_op_tuning_enable: self.tunable_op_tuning_enable.into(), tunable_op_max_tuning_duration_ms: self.tunable_op_max_tuning_duration_ms diff --git a/src/operator/bound.rs b/src/operator/bound.rs index 241ab6a8..f46c6c15 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -43,6 +43,10 @@ impl BoundOperator { GetVariadicInputMinArity: Some(BoundOperator::::GetVariadicInputMinArity), GetVariadicOutputHomogeneity: Some(BoundOperator::::GetVariadicOutputHomogeneity), GetVariadicOutputMinArity: Some(BoundOperator::::GetVariadicOutputMinArity), + GetAliasMap: None, + ReleaseAliasMap: None, + GetMayInplace: None, + ReleaseMayInplace: None, InferOutputShapeFn: if O::get_infer_shape_function().is_some() { Some(BoundOperator::::InferOutputShapeFn) } else { From fd41c8ffc750995786a97d9094f885303fd102cd Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sun, 26 May 2024 00:36:17 -0500 Subject: [PATCH 14/49] ci(test): run wasm test with `--debug` --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 79ab7bf6..1b9fb83f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,7 +61,7 @@ jobs: - name: Run tests working-directory: examples/webassembly run: | - wasm-pack test --node + wasm-pack test --node --debug # Disable cross-compile until cross updates aarch64-unknown-linux-gnu to Ubuntu 22.04 # ref https://github.com/cross-rs/cross/pull/973 #cross-compile: From 9a7f0465d004eebe26aaaa2ae207b8678853cd2e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 26 May 2024 00:39:15 -0500 Subject: [PATCH 15/49] Revert "ci(test): run wasm test with `--debug`" This reverts commit fd41c8ffc750995786a97d9094f885303fd102cd. --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1b9fb83f..79ab7bf6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,7 +61,7 @@ jobs: - name: Run tests working-directory: examples/webassembly run: | - wasm-pack test --node --debug + wasm-pack test --node # Disable cross-compile until cross updates aarch64-unknown-linux-gnu to Ubuntu 22.04 # ref https://github.com/cross-rs/cross/pull/973 #cross-compile: From 9453d040b2085149234949699329b7e43ba9887b Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Wed, 29 May 2024 14:00:13 -0500 Subject: [PATCH 16/49] fix(wasm): disable `image::load_from_memory` --- examples/webassembly/src/lib.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/webassembly/src/lib.rs b/examples/webassembly/src/lib.rs index 17edf461..dbc0d73a 100644 --- a/examples/webassembly/src/lib.rs +++ b/examples/webassembly/src/lib.rs @@ -10,13 +10,19 @@ pub fn upsample_inner() -> ort::Result<()> { .commit_from_memory_directly(MODEL_BYTES) .expect("Could not read model from memory"); - let image_buffer: ImageBuffer, Vec> = image::load_from_memory(IMAGE_BYTES).unwrap().to_luma8(); - - let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { - let pixel = image_buffer.get_pixel(i as u32, j as u32); - let channels = pixel.channels(); - (channels[c] as f32) / 255.0 - }); + // NOTE: An earlier nightly version of Rust 1.78 includes a patch required for ONNX Runtime to link properly, but a + // later version enables debug assertions in `dlmalloc`, which surfaces an allocation bug in the `image` crate: + // https://github.com/rustwasm/wasm-pack/issues/1389 Because of this, using `image::load_from_memory` crashes. + // For demonstration purposes, we're replacing the image loading code shown below with zeros(). In a real application, + // you can get the image from another source, like an HTML canvas. + // + // let image_buffer: ImageBuffer, Vec> = image::load_from_memory(IMAGE_BYTES).unwrap().to_luma8(); + // let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + // let pixel = image_buffer.get_pixel(i as u32, j as u32); + // let channels = pixel.channels(); + // (channels[c] as f32) / 255.0 + // }); + let array = ndarray::Array4::::zeros((1, 1, 28, 28)); let outputs = session.run(ort::inputs![array]?)?; From eb0f4029ceb7ff1eae2acb26ba2193022dc6cde6 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Wed, 29 May 2024 15:07:48 -0500 Subject: [PATCH 17/49] fix(wasm): update artifact --- ort-sys/dist.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 6a622d45..066a2b52 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -12,4 +12,4 @@ none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/ms none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-apple-darwin.tgz F8DB068DFACFE3B00B9F0181B79780C6971CD1A6EAEB9D9A7FC2129CEB8413A5 none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-apple-darwin.tgz E6E0457CB9C727DBA818D10245D3A2A29203CB037546B39C217E4CC9FB61ABE8 -none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-pkort_static-v1.18.0-wasm32-unknown-unknown.tgz 8CDB6A4809652CAAB9F1AC704D8E549472836FC4A95187A63E5FE4E50EB9A15E +none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-pkort_static-v1.18.0-wasm32-unknown-unknown.tgz 8AB76874E977961A1CFA9714973521AA1B85F0F40D31EF38492CCA659BE58BF5 From 721bf0056c29062868b4677c27662d5728e66ba6 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Wed, 29 May 2024 15:11:02 -0500 Subject: [PATCH 18/49] ci(test): run if examples or dist.txt changed --- .github/workflows/test.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 79ab7bf6..2780d33f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -7,7 +7,9 @@ on: paths: - '.github/workflows/test.yml' - 'src/**/*.rs' + - 'examples/**/*' - 'ort-sys/**/*.rs' + - 'ort-sys/**/dist.txt' - 'build.rs' - 'Cargo.toml' - '.cargo/**/*' @@ -16,7 +18,9 @@ on: paths: - '.github/workflows/test.yml' - 'src/**/*.rs' + - 'examples/**/*' - 'ort-sys/**/*.rs' + - 'ort-sys/**/dist.txt' - 'build.rs' - 'Cargo.toml' - '.cargo/**/*' From 5dce2c272572abd980a7091b956fd05365c70441 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 3 Jun 2024 18:53:50 -0500 Subject: [PATCH 19/49] fix(sys): update windows artifacts to include DirectML again --- ort-sys/dist.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 066a2b52..b50dc69e 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -4,10 +4,10 @@ cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/ rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_rocm-v1.18.0-x86_64-unknown-linux-gnu.tgz D6113A895DEB0BCBC28FD7E23A201DE4C5FBA6BADEB49F3190A084A36C24B43D none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-unknown-linux-gnu.tgz F486F4B9F040FF533DCD6B26E074BEB5F9092E8E4C67F72D08696D9EB4C9C082 -none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-pc-windows-msvc.tgz 9A1BF23A73D680290B52C22AAD039B490AC5AAA66FC21C06343A41369747B514 -cu12 x86_64-pc-windows-msvc 
https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu12-v1.18.0-x86_64-pc-windows-msvc.tgz A9457AC9AC5D6BE1F98B3BEE3D6AF5C074C9984F7CC7D1E660EA8082EBF65D48 -cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu11-v1.18.0-x86_64-pc-windows-msvc.tgz C5C62263BDD82B58ED15A6467D0729B21F26E78EA0E49E1E5197ECBA80783903 -none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-pc-windows-msvc.tgz 08A22E94EBA56BF30ECBB2DC9DD9F90A4583C8372BAFC7FE3DAB6C28A06544CE +none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-pc-windows-msvc.tgz D9807ACE93E87CC45286A9B1892138FE4F28D1C764E30E9FC0B20DBE300063BA +cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu12-v1.18.0-x86_64-pc-windows-msvc.tgz D40FBAA7C4348A4CB5E38F59BB172A93829C6D457B1F97D68472D551A6E961E4 +cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu11-v1.18.0-x86_64-pc-windows-msvc.tgz 3BDA7C8BCB97DFB58114A391ECA7A1C1395A47CB88490103153B36CFF8C1CD48 +none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-pc-windows-msvc.tgz 6D9CFE125807CB9EC4C37D903457B82D505BB77CE7DEF4F750467613BBC2702A none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-apple-darwin.tgz F8DB068DFACFE3B00B9F0181B79780C6971CD1A6EAEB9D9A7FC2129CEB8413A5 none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-apple-darwin.tgz E6E0457CB9C727DBA818D10245D3A2A29203CB037546B39C217E4CC9FB61ABE8 From 1d89f822d841060925bf54e52e434b6c98fc2d9b Mon Sep 17 
00:00:00 2001 From: Florian Kasischke Date: Wed, 5 Jun 2024 19:08:46 +0200 Subject: [PATCH 20/49] feat(sys): support `pkg-config` Co-authored-by: Florian Kasischke --- ort-sys/Cargo.toml | 1 + ort-sys/build.rs | 43 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index 9a6a3cf2..e638ae2f 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -44,3 +44,4 @@ ureq = { version = "2.1", optional = true, default-features = false, features = tar = { version = "0.4", optional = true } flate2 = { version = "1.0", optional = true } sha2 = { version = "0.10", optional = true } +pkg-config = "0.3.30" \ No newline at end of file diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 440023cb..ccbd5684 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -376,6 +376,29 @@ fn prepare_libort_dir() -> (PathBuf, bool) { } } +fn try_setup_with_pkg_config() -> bool { + match pkg_config::Config::new().probe("libonnxruntime") { + Ok(lib) => { + // Setting the link paths + for path in lib.link_paths { + println!("cargo:rustc-link-search=native={}", path.display()); + } + + // Setting the libraries to link against + for lib in lib.libs { + println!("cargo:rustc-link-lib={}", lib); + } + + println!("Using onnxruntime found by pkg-config."); + true + } + Err(_) => { + println!("onnxruntime not found using pkg-config, falling back to manual setup."); + false + } + } +} + fn real_main(link: bool) { println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_LOCATION); println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_PROFILE); @@ -400,14 +423,20 @@ fn main() { } if cfg!(feature = "load-dynamic") { - // we only need to execute the real main step if we are using the download strategy... 
- if cfg!(feature = "download-binaries") && env::var(ORT_ENV_SYSTEM_LIB_LOCATION).is_err() { - // but we don't need to link to the binaries we download (so all we are doing is downloading them and placing them in - // the output directory) - real_main(false); + if !try_setup_with_pkg_config() { + // Only execute the real main step if pkg-config fails and if we are using the download + // strategy + if cfg!(feature = "download-binaries") && env::var(ORT_ENV_SYSTEM_LIB_LOCATION).is_err() { + // but we don't need to link to the binaries we download (so all we are doing is + // downloading them and placing them in the output directory) + real_main(false); // but we don't need to link to the binaries we download + } } } else { - // if we are not using the load-dynamic feature then we need to link to dylibs. - real_main(true); + // If pkg-config setup was successful, we don't need further action + // Otherwise, if we are not using the load-dynamic feature, we need to link to the dylibs. + if !try_setup_with_pkg_config() { + real_main(true); + } } } From b00c9341bc9e25c87a6634de2fe330b54fc3f8a3 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sat, 8 Jun 2024 13:19:34 -0500 Subject: [PATCH 21/49] docs: nuke mintlify --- docs/mint.json | 102 - docs/next-env.d.ts | 5 + docs/next.config.mjs | 11 + docs/package.json | 23 + docs/pages/_app.mdx | 5 + docs/pages/_meta.json | 37 + docs/{introduction.mdx => pages/index.mdx} | 97 +- docs/{ => pages}/migrating/opsets.mdx | 0 docs/{ => pages}/migrating/v2.mdx | 7 - .../{ => pages}/migrating/version-mapping.mdx | 0 docs/{ => pages}/perf/execution-providers.mdx | 34 +- docs/{ => pages}/perf/io-binding.mdx | 0 docs/{ => pages}/setup/cargo-features.mdx | 0 docs/pages/setup/linking.mdx | 109 + docs/{ => pages}/setup/platforms.mdx | 18 +- docs/{ => pages}/setup/webassembly.mdx | 8 +- .../{ => pages}/troubleshooting/compiling.mdx | 0 .../troubleshooting/performance.mdx | 81 +- docs/pnpm-lock.yaml | 3200 +++++++++++++++++ docs/{ => public}/assets/banner.png | Bin docs/{ => public}/assets/icon.png | Bin .../{ => public}/assets/sample-onnx-graph.png | Bin docs/{ => public}/assets/trend-banner.png | Bin docs/setup/linking.mdx | 106 - docs/theme.config.jsx | 33 + docs/tsconfig.json | 28 + 26 files changed, 3575 insertions(+), 329 deletions(-) delete mode 100644 docs/mint.json create mode 100644 docs/next-env.d.ts create mode 100644 docs/next.config.mjs create mode 100644 docs/package.json create mode 100644 docs/pages/_app.mdx create mode 100644 docs/pages/_meta.json rename docs/{introduction.mdx => pages/index.mdx} (56%) rename docs/{ => pages}/migrating/opsets.mdx (100%) rename docs/{ => pages}/migrating/v2.mdx (98%) rename docs/{ => pages}/migrating/version-mapping.mdx (100%) rename docs/{ => pages}/perf/execution-providers.mdx (94%) rename docs/{ => pages}/perf/io-binding.mdx (100%) rename docs/{ => pages}/setup/cargo-features.mdx (100%) create mode 100644 docs/pages/setup/linking.mdx rename docs/{ => pages}/setup/platforms.mdx (63%) rename docs/{ => pages}/setup/webassembly.mdx (83%) rename docs/{ => pages}/troubleshooting/compiling.mdx (100%) rename docs/{ => 
pages}/troubleshooting/performance.mdx (55%) create mode 100644 docs/pnpm-lock.yaml rename docs/{ => public}/assets/banner.png (100%) rename docs/{ => public}/assets/icon.png (100%) rename docs/{ => public}/assets/sample-onnx-graph.png (100%) rename docs/{ => public}/assets/trend-banner.png (100%) delete mode 100644 docs/setup/linking.mdx create mode 100644 docs/theme.config.jsx create mode 100644 docs/tsconfig.json diff --git a/docs/mint.json b/docs/mint.json deleted file mode 100644 index 8a6237ba..00000000 --- a/docs/mint.json +++ /dev/null @@ -1,102 +0,0 @@ -{ - "$schema": "https://mintlify.com/schema.json", - "name": "ort", - "logo": { - "dark": "/assets/banner.png", - "light": "/assets/banner.png" - }, - "favicon": "/assets/icon.png", - "colors": { - "primary": "#F74C00", - "light": "#F74C00", - "background": { - "light": "#FFFFFF", - "dark": "#000000" - }, - "dark": "#F74C00", - "anchors": { - "from": "#F74C00", - "to": "#eb8e65" - } - }, - "tabs": [ - { - "name": "API Reference", - "url": "https://docs.rs/ort/2.0.0-rc.2/ort/" - } - ], - "anchors": [ - { - "name": "Sponsor", - "icon": "hand-holding-heart", - "url": "https://opencollective.com/pyke-osai" - }, - { - "name": "Crates.io", - "icon": "rust", - "url": "https://crates.io/crates/ort" - }, - { - "name": "GitHub", - "icon": "github", - "url": "https://github.com/pykeio/ort" - }, - { - "name": "Discord", - "icon": "discord", - "url": "https://discord.gg/uQtsNu2xMa" - } - ], - "navigation": [ - { - "group": "Get Started", - "pages": [ - "introduction" - ] - }, - { - "group": "Setup", - "pages": [ - "setup/platforms", - "setup/webassembly", - "setup/linking", - "setup/cargo-features" - ] - }, - { - "group": "Fundamentals", - "pages": [ - "fundamentals/environment", - "fundamentals/session", - "fundamentals/value" - ] - }, - { - "group": "Performance", - "pages": [ - "perf/execution-providers", - "perf/io-binding" - ] - }, - { - "group": "Troubleshooting", - "pages": [ - "troubleshooting/precision", - 
"troubleshooting/performance", - "troubleshooting/compiling" - ] - }, - { - "group": "Migration & versioning", - "pages": [ - "migrating/version-mapping", - "migrating/v2" - ] - } - ], - "footerSocials": { - "website": "https://pyke.io/", - "github": "https://github.com/pykeio/ort", - "discord": "https://discord.gg/uQtsNu2xMa" - } -} diff --git a/docs/next-env.d.ts b/docs/next-env.d.ts new file mode 100644 index 00000000..4f11a03d --- /dev/null +++ b/docs/next-env.d.ts @@ -0,0 +1,5 @@ +/// +/// + +// NOTE: This file should not be edited +// see https://nextjs.org/docs/basic-features/typescript for more information. diff --git a/docs/next.config.mjs b/docs/next.config.mjs new file mode 100644 index 00000000..47e0f5ea --- /dev/null +++ b/docs/next.config.mjs @@ -0,0 +1,11 @@ +import nextra from 'nextra'; + +export default nextra({ + theme: 'nextra-theme-docs', + themeConfig: './theme.config.jsx' +})({ + output: 'export', + images: { + unoptimized: true + } +}); diff --git a/docs/package.json b/docs/package.json new file mode 100644 index 00000000..36fdb98c --- /dev/null +++ b/docs/package.json @@ -0,0 +1,23 @@ +{ + "private": true, + "name": "ort-docs", + "version": "0.0.0", + "scripts": { + "dev": "next dev", + "build": "next build", + "start": "next start" + }, + "dependencies": { + "next": "^14.2.3", + "nextra": "^2.13.4", + "nextra-theme-docs": "^2.13.4", + "react": "^18.3.1", + "react-dom": "^18.3.1" + }, + "devDependencies": { + "@types/node": "20.14.2", + "@types/react": "^18.3.3", + "@types/react-dom": "^18.3.0", + "typescript": "^5.4.5" + } +} diff --git a/docs/pages/_app.mdx b/docs/pages/_app.mdx new file mode 100644 index 00000000..c466f982 --- /dev/null +++ b/docs/pages/_app.mdx @@ -0,0 +1,5 @@ +import font from 'next/font/google'; + +export default function App({ Component, pageProps }) { + return ; +} diff --git a/docs/pages/_meta.json b/docs/pages/_meta.json new file mode 100644 index 00000000..14840b87 --- /dev/null +++ b/docs/pages/_meta.json @@ -0,0 
+1,37 @@ +{ + "-- Links": { + "type": "separator", + "title": "Links" + }, + "link-oc": { + "title": "Sponsor ↗", + "href": "https://opencollective.com/pyke-osai", + "newWindow": true + }, + "link-api": { + "title": "API Reference ↗", + "href": "https://docs.rs/ort/2.0.0-rc.2/ort" + }, + "link-crates": { + "title": "Crates.io ↗", + "href": "https://crates.io/crates/ort", + "newWindow": true + }, + "-- Docs": { + "type": "separator", + "title": "Docs" + }, + "index": "Introduction", + "setup": { + "title": "Setup" + }, + "perf": { + "title": "Performance" + }, + "troubleshooting": { + "title": "Troubleshooting" + }, + "migrating": { + "title": "Migration & versioning" + } +} diff --git a/docs/introduction.mdx b/docs/pages/index.mdx similarity index 56% rename from docs/introduction.mdx rename to docs/pages/index.mdx index 0ff676ba..97ab9f5e 100644 --- a/docs/introduction.mdx +++ b/docs/pages/index.mdx @@ -2,14 +2,17 @@ title: Introduction --- +import Image from 'next/image'; +import { Callout, Card, Cards, Steps } from 'nextra/components'; +

ort is an open-source Rust binding for ONNX Runtime.

- + These docs are for the latest alpha version of `ort`, `2.0.0-rc.2`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. - + `ort` makes it easy to deploy your machine learning models to production via [ONNX Runtime](https://onnxruntime.ai/), a hardware-accelerated inference engine. With `ort` + ONNX Runtime, you can run almost any ML model (including ResNet, YOLOv8, BERT, LLaMA) on almost any hardware, often far faster than PyTorch, and with the added bonus of Rust's efficiency. @@ -29,52 +32,54 @@ Converting a neural network to a graph representation like ONNX opens the door t # Getting started - - If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! Just add it to your Cargo dependencies: - ```toml - [dependencies] - ort = "2.0.0-rc.2" - ``` - - - Your model will need to be converted to an ONNX graph before you can use it. - - The awesome folks at Hugging Face have [a guide](https://huggingface.co/docs/transformers/serialization) to export 🤗 Transformers models to ONNX with 🤗 Optimum. - - For any PyTorch model: [`torch.onnx`](https://pytorch.org/docs/stable/onnx.html) - - For `scikit-learn` models: [`sklearn-onnx`](https://onnx.ai/sklearn-onnx/) - - For TensorFlow, Keras, TFlite, & TensorFlow.js: [`tf2onnx`](https://github.com/onnx/tensorflow-onnx) - - For PaddlePaddle: [`Paddle2ONNX`](https://github.com/PaddlePaddle/Paddle2ONNX) - - - Once you've got a model, load it via `ort` by creating a [`Session`](/fundamentals/session): - - ```rust - use ort::{GraphOptimizationLevel, Session}; - - let model = Session::builder()? - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_intra_threads(4)? - .commit_from_file("yolov8m.onnx")?; - ``` - - - Preprocess your inputs, then `run()` the session to perform inference. 
- - ```rust - let outputs = model.run(ort::inputs!["image" => image]?)?; - let predictions = outputs["output0"].try_extract_tensor::()?; - ... - ``` - - There are some more useful examples [in the `ort` repo](https://github.com/pykeio/ort/tree/main/examples)! - + +### Add ort to your Cargo.toml +If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! Just add it to your Cargo dependencies: +```toml +[dependencies] +ort = "2.0.0-rc.2" +``` + +### Convert your model +Your model will need to be converted to an ONNX graph before you can use it. +- The awesome folks at Hugging Face have [a guide](https://huggingface.co/docs/transformers/serialization) to export 🤗 Transformers models to ONNX with 🤗 Optimum. +- For any PyTorch model: [`torch.onnx`](https://pytorch.org/docs/stable/onnx.html) +- For `scikit-learn` models: [`sklearn-onnx`](https://onnx.ai/sklearn-onnx/) +- For TensorFlow, Keras, TFlite, & TensorFlow.js: [`tf2onnx`](https://github.com/onnx/tensorflow-onnx) +- For PaddlePaddle: [`Paddle2ONNX`](https://github.com/PaddlePaddle/Paddle2ONNX) + +### Load your model +Once you've got a model, load it via `ort` by creating a [`Session`](/fundamentals/session): + +```rust +use ort::{GraphOptimizationLevel, Session}; + +let model = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(4)? + .commit_from_file("yolov8m.onnx")?; +``` + +### Perform inference +Preprocess your inputs, then `run()` the session to perform inference. + +```rust +let outputs = model.run(ort::inputs!["image" => image]?)?; +let predictions = outputs["output0"].try_extract_tensor::()?; +... +``` + +There are some more useful examples [in the `ort` repo](https://github.com/pykeio/ort/tree/main/examples)! + # Next steps - - Use [execution providers](/perf/execution-providers) to enable hardware acceleration in your app and unlock the full power of your GPU or NPU. 
- - - We'd love to see what you've made with `ort`! Show off your project in [GitHub Discussions](https://github.com/pykeio/ort/discussions/categories/show-and-tell) or on our [Discord](https://discord.gg/uQtsNu2xMa). - + +### Unlock more performance with EPs +Use [execution providers](/perf/execution-providers) to enable hardware acceleration in your app and unlock the full power of your GPU or NPU. + +### Show off your project! +We'd love to see what you've made with `ort`! Show off your project in [GitHub Discussions](https://github.com/pykeio/ort/discussions/categories/show-and-tell) or on our [Discord](https://discord.gg/uQtsNu2xMa). + diff --git a/docs/migrating/opsets.mdx b/docs/pages/migrating/opsets.mdx similarity index 100% rename from docs/migrating/opsets.mdx rename to docs/pages/migrating/opsets.mdx diff --git a/docs/migrating/v2.mdx b/docs/pages/migrating/v2.mdx similarity index 98% rename from docs/migrating/v2.mdx rename to docs/pages/migrating/v2.mdx index a2f20201..a3c1afdf 100644 --- a/docs/migrating/v2.mdx +++ b/docs/pages/migrating/v2.mdx @@ -141,13 +141,6 @@ let noise_pred = unet.run(ort::inputs![ ]?)?; ``` -You can also supply `ort::inputs!` your `IoBinding` by specifying `bind =`: -```rust -let binding = model.create_binding()?; -... -let outputs = model.run(ort::inputs![bind = binding]?)?; -``` - ### Tensor creation no longer requires the session's allocator In previous versions, `Value::from_array` took an allocator parameter. The allocator was only used because the string data contained in string tensors had to be cloned into ONNX Runtime-managed memory. However, 99% of users only ever use primitive tensors, so the extra parameter served little purpose. The new `Tensor::from_array` function now takes only an array, and the logic for converting string arrays has been moved to a new function, `DynTensor::from_string_array`. 
diff --git a/docs/migrating/version-mapping.mdx b/docs/pages/migrating/version-mapping.mdx similarity index 100% rename from docs/migrating/version-mapping.mdx rename to docs/pages/migrating/version-mapping.mdx diff --git a/docs/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx similarity index 94% rename from docs/perf/execution-providers.mdx rename to docs/pages/perf/execution-providers.mdx index 03447b1e..f2e7b9e0 100644 --- a/docs/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -3,6 +3,8 @@ title: Execution providers description: Learn how to enable execution providers to leverage hardware acceleration. --- +import { Callout, Tabs } from 'nextra/components'; + Execution providers (EPs) enable ONNX Runtime to execute ONNX graphs with hardware acceleration. If you have specialized hardware like a GPU or NPU, execution providers can provide a massive performance boost to your `ort` applications. For more information on the intricacies of execution providers, see the [ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/). ONNX Runtime must be compiled with support for each execution provider. pyke provides precompiled binaries for some of the most common EPs, so you won't need to compile ONNX Runtime from source. Below is a table showing available EPs, their support in `ort`, and their binary availability status. @@ -28,12 +30,12 @@ ONNX Runtime must be compiled with support for each execution provider. pyke pro | Microsoft Azure | ❌ | ❌ | ❓ | | Rockchip RKNPU | ❌ | ❌ | ❓ | - + Some EPs supported by ONNX Runtime are not supported by `ort` due to a lack of hardware for testing. If your preferred EP is missing support and you've got the hardware, please [open an issue](https://github.com/pykeio/ort/issues/new)! - + ## Registering execution providers - + To use an execution provider with `ort`, you'll need to enable its respective Cargo feature, e.g. 
the `cuda` feature to use CUDA, or the `coreml` feature to use CoreML. ```toml Cargo.toml @@ -42,7 +44,7 @@ ONNX Runtime must be compiled with support for each execution provider. pyke pro ``` See [Cargo features](/setup/cargo-features) for the full list of features. - + In order to configure sessions to use certain execution providers, you must **register** them when creating an environment or session. You can do this via the `SessionBuilder::with_execution_providers` method. For example, to register the CUDA execution provider for a session: @@ -167,9 +169,9 @@ fn main() -> anyhow::Result<()> { } ``` - + `ort::init` must come before you create any sessions, otherwise the configuration will not take effect! - + Sessions configured with their own execution providers will *extend* the execution provider defaults, rather than overriding them. @@ -181,32 +183,32 @@ If it seems like the execution provider is not registering properly, or you are ### CoreML Statically linking to CoreML (the default behavior when using downloaded binaries + the `coreml` Cargo feature) requires an additional Rust flag in order to link properly. You'll need to provide the flag `-C link-arg=-fapple-link-rtlib` to `rustc`. You can do this via an entry in [`.cargo/config.toml`](https://doc.rust-lang.org/cargo/reference/config.html#hierarchical-structure), in a build script, or in an environment variable. - - + + See [Configuration: Hierarchical structure](https://doc.rust-lang.org/cargo/reference/config.html#hierarchical-structure) for more information on where the configuration file can be placed. - ```toml .cargo/config.toml + ```toml filename=".cargo/config.toml" copy [target.aarch64-apple-darwin] rustflags = ["-Clink-arg=-fapple-link-rtlib"] [target.x86_64-apple-darwin] rustflags = ["-Clink-arg=-fapple-link-rtlib"] ``` - - + + Add the following to the `build.rs` script of any **binary** crate that uses `ort`. 
- ```rust build.rs + ```rust filename="build.rs" copy fn main() { println!("cargo:rustc-link-arg=-fapple-link-rtlib"); } ``` Library crates do not need this flag, and the usage of it in a library crate will not transitively apply to any binary crates dependent on it. - - - ```shell + + + ```shell copy $ RUSTFLAGS="-Clink-arg=-fapple-link-rtlib" cargo build ``` - + diff --git a/docs/perf/io-binding.mdx b/docs/pages/perf/io-binding.mdx similarity index 100% rename from docs/perf/io-binding.mdx rename to docs/pages/perf/io-binding.mdx diff --git a/docs/setup/cargo-features.mdx b/docs/pages/setup/cargo-features.mdx similarity index 100% rename from docs/setup/cargo-features.mdx rename to docs/pages/setup/cargo-features.mdx diff --git a/docs/pages/setup/linking.mdx b/docs/pages/setup/linking.mdx new file mode 100644 index 00000000..bdb49234 --- /dev/null +++ b/docs/pages/setup/linking.mdx @@ -0,0 +1,109 @@ +--- +title: Linking +description: Here's how `ort` links to ONNX Runtime, and how to configure its behavior. +--- + +import { Callout, Tabs, Steps } from 'nextra/components'; + +`ort` provides its own builds of ONNX Runtime to make your experience as painless as possible, but in some cases, you'll want to use a custom build of ONNX Runtime with `ort`. Luckily, we make this very easy by handling all of the linking configuration automagically. Just point `ort` to the output of ONNX Runtime's build pipeline and it'll Just Work™. + +## Static linking +Most ONNX Runtime compile configurations will support static linking - just run `build.sh` without the `--build_shared_lib` argument. You should prefer static linking if your execution providers support it, as it avoids many issues and follows de facto Rust practices. If you compile both static libraries and dynamic libraries, `ort` will prefer linking to the static libraries. + +To direct `ort` to your statically built binaries, use the `ORT_LIB_LOCATION` environment variable when running `cargo build`. 
Point it to the location where the static libraries (`.a`/`.lib` files) are compiled to. This will typically be `onnxruntime/build/`. For example: +```shell +$ ORT_LIB_LOCATION=~/onnxruntime/build/Linux cargo build +``` + +For iOS (or for other platforms if you are compiling multiple profiles at once), you'll need to manually specify the profile with the `ORT_LIB_PROFILE` environment variable. If not specified, `ort` will prefer `Release` over `RelWithDebInfo` over `MinSizeRel` over `Debug`. + +## Dynamic linking +Some execution providers unfortunately only support dynamic linking. Dynamic linking doesn't play well with the Rust ecosystem, though `ort` tries to alleviate the pain as much as possible. + +When it comes to dynamic linking, there are two options: `load-dynamic`, or standard compile-time dynamic linking. We recommend `load-dynamic` as it gives more control and is often far less troublesome to work with. + +### Runtime loading with `load-dynamic` +The `load-dynamic` Cargo feature solves a few of the issues with dynamic linking by **loading the library at runtime** rather than **linking at compile time**. This means that the path to the ONNX Runtime library can be configured at runtime, and the executable will not just completely fail to start if the binary couldn't be found. + +To use `load-dynamic`: + + +#### Enable the feature in Cargo.toml +```toml filename="Cargo.toml" +[dependencies] +ort = { version = "2", features = [ "load-dynamic" ] } +``` + +### Point ort to the dylib + + + ```rust main.rs + fn main() -> anyhow::Result<()> { + // Find our custom ONNX Runtime dylib path somehow + // (i.e. resolving it from the root of our program's install folder) + let dylib_path = crate::internal::find_onnxruntime_dylib()?; + // The path should point to the `libonnxruntime` binary, which looks like: + // - on Unix: /etc/.../libonnxruntime.so + // - on Windows: C:\Program Files\...\onnxruntime.dll + + // Initialize ort with the path to the dylib. 
This **must** be called before any usage of `ort`! + // `init_from` returns an `EnvironmentBuilder` which you can use to further configure the environment + // before `.commit()`ing; see the Environment docs for more information on what you can configure. + ort::init_from(dylib_path).commit()?; + + Ok(()) + } + ``` + + + Set the `ORT_DYLIB_PATH` environment variable to the path to `libonnxruntime.so`/`onnxruntime.dll`. + + ```shell + $ ORT_DYLIB_PATH=../onnxruntime-build/linux-x64/libonnxruntime.so ./mirai + ``` + + + + + +`ORT_DYLIB_PATH` is relative to the executable. Cargo examples and tests are compiled to a different directory than binary crates: `target//examples` and `target//deps` respectively. Keep this in mind if you're going to use relative paths. + +### Compile-time dynamic linking +For compile-time dynamic linking, you'll need to configure your environment in the exact same way as if you were [statically linking](#static-linking). + +Note that the dylibs then have to be placed in a certain location for them to be found by the executable. For Windows, this is either somewhere on the `PATH`, or in the same folder as the executable. On macOS and Linux, they have to be placed somewhere in the `LD_LIBRARY_PATH`, or you can use rpath to configure the executable to search for dylibs in its parent folder. We've had the least issues with rpath, but YMMV. 
+ +To configure rpath, you'll need to: + +#### Enable rpath in Cargo.toml +```toml filename="Cargo.toml" copy +[profile.dev] +rpath = true + +[profile.release] +rpath = true + +# do this for any other profiles +``` + +### Configure the path in the linker args in .cargo/config.toml to be relative to the executable + + + ```toml filename="~/.cargo/config.toml" copy + [target.x86_64-unknown-linux-gnu] + rustflags = [ "-Clink-args=-Wl,-rpath,\\$ORIGIN" ] + + # do this for any other Linux targets as well + ``` + + + ```toml filename="~/.cargo/config.toml" copy + [target.x86_64-apple-darwin] + rustflags = [ "-Clink-args=-Wl,-rpath,@loader_path" ] + + # do this for any other macOS targets as well + ``` + + + + diff --git a/docs/setup/platforms.mdx b/docs/pages/setup/platforms.mdx similarity index 63% rename from docs/setup/platforms.mdx rename to docs/pages/setup/platforms.mdx index 02443452..f83d131b 100644 --- a/docs/setup/platforms.mdx +++ b/docs/pages/setup/platforms.mdx @@ -3,6 +3,8 @@ title: Platform support description: ONNX Runtime, and by extension `ort`, supports a wide variety of platforms. For most desktop users, pre-built binaries are available, so setting up `ort` is as simple as adding it to your `Cargo.toml`! --- +import { Callout } from 'nextra/components'; + Here are the supported platforms and binary availability status, as of v2.0.0-rc.2. * 🟢 - Supported. Dynamic & static binaries provided by pyke. @@ -19,14 +21,18 @@ Here are the supported platforms and binary availability status, as of v2.0.0-rc | **Android** | ❌ | ❌ | ⭕ | ⭕ | ❌ | | **Web** | ❌ | ❌ | ❌ | ❌ | 🔷¶ | -\* Recent version of Windows 10/11 required for pyke binaries.
-† glibc ≥ 2.31 (Ubuntu ≥ 20.04) required for pyke binaries.
-‡ glibc ≥ 2.35 (Ubuntu ≥ 22.04) required for pyke binaries.
-§ macOS ≥ 10.15 required.
-¶ WASM supports a limited subset of ONNX Runtime features. For more info, see [the docs on WebAssembly support](/setup/webassembly). +
+

\* Recent version of Windows 10/11 required for pyke binaries.

+

† glibc ≥ 2.31 (Ubuntu ≥ 20.04) required for pyke binaries.

+

‡ glibc ≥ 2.35 (Ubuntu ≥ 22.04) required for pyke binaries.

+

§ macOS ≥ 10.15 required.

+

¶ WASM supports a limited subset of ONNX Runtime features. For more info, see [the docs on WebAssembly support](/setup/webassembly).

+
If your platform is marked as 🟢 or 🔷, you're in luck! Almost no setup will be required to get `ort` up and running. For platforms marked as ⭕, you'll need to [compile ONNX Runtime from source](https://onnxruntime.ai/docs/build/) and then [link `ort` to your custom binaries](/setup/linking) (but don't worry, we made this setup as simple as possible!) -Certain execution providers may not have binaries available. You can check EP binary support in the [Execution providers](/perf/execution-providers) documentation. + + Certain execution providers may not have binaries available. You can check EP binary support in the [Execution providers](/perf/execution-providers) documentation. + diff --git a/docs/setup/webassembly.mdx b/docs/pages/setup/webassembly.mdx similarity index 83% rename from docs/setup/webassembly.mdx rename to docs/pages/setup/webassembly.mdx index 87b1ccf9..3e5e5cc0 100644 --- a/docs/setup/webassembly.mdx +++ b/docs/pages/setup/webassembly.mdx @@ -5,19 +5,13 @@ description: Deploy ONNX models to the web WebAssembly support in `ort` is currently experimental. If you experience any issues using `ort` in WebAssembly, please [open an issue](https://github.com/pykeio/ort/issues/new). -Development of WASM support is done in a separate branch for now, so you'll have to add `ort` as a Git dependency: -```toml Cargo.toml -[dependencies] -ort = { git = "https://github.com/pykeio/ort.git", branch = "wasm32-unknown-unknown" } -``` - By nature, some features of ONNX Runtime are not available in the web. These include: - **Support for `.onnx` models.** You instead need to [convert `.onnx` models to the `.ort` format](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html). - **Runtime graph optimizations**, aka `SessionBuilder::with_optimization_level`. You can statically optimize the graph using the `.ort` conversion tool, though. 
- **Loading models with `commit_from_file`/`commit_from_url`.** You can create models from a slice of bytes in memory with `SessionBuilder::commit_from_memory` or `SessionBuilder::commit_from_memory_directly`. Additionally, you'll need to call `ort::wasm::initialize()` at the earliest possible point in your code, before you use any `ort` APIs: -```rust main.rs +```rust filename="main.rs" copy use ort::Session; static MODEL_BYTES: &[u8] = include_bytes!("../model.ort"); diff --git a/docs/troubleshooting/compiling.mdx b/docs/pages/troubleshooting/compiling.mdx similarity index 100% rename from docs/troubleshooting/compiling.mdx rename to docs/pages/troubleshooting/compiling.mdx diff --git a/docs/troubleshooting/performance.mdx b/docs/pages/troubleshooting/performance.mdx similarity index 55% rename from docs/troubleshooting/performance.mdx rename to docs/pages/troubleshooting/performance.mdx index 6bf41128..e407895e 100644 --- a/docs/troubleshooting/performance.mdx +++ b/docs/pages/troubleshooting/performance.mdx @@ -2,53 +2,56 @@ title: 'Troubleshoot: Performance' --- +import { Callout, Tabs, Steps } from 'nextra/components'; + ## Execution providers don't seem to register `ort` is designed to fail gracefully when an execution provider is not available. It logs failure events through [`tracing`](https://crates.io/crates/tracing), thus you'll need a library that subscribes to `tracing` events to see the logs. The simplest way to do this is to use [`tracing-subscriber`](https://crates.io/crates/tracing-subscriber). 
- - ```toml Cargo.toml - [dependencies] - tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } + +### Add tracing-subscriber to your dependencies +```toml Cargo.toml +[dependencies] +tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } +``` + +### Initialize the subscriber in the main function +```rust main.rs +fn main() { + tracing_subscriber::fmt::init(); +} +``` + +### Show debug messages from ort +Set the environment variable `RUST_LOG` to `ort=debug` to see all debug messages from `ort`. + + + ```powershell + $env:RUST_LOG = 'ort=debug'; + cargo run ``` - - - ```rust main.rs - fn main() { - tracing_subscriber::fmt::init(); - } + + + ```cmd + set RUST_LOG=ort=debug + cargo run ``` - - - Set the environment variable `RUST_LOG` to `ort=debug` to see all debug messages from `ort`. - - - ```powershell - $env:RUST_LOG = 'ort=debug'; - cargo run - ``` - - - ```cmd - set RUST_LOG=ort=debug - cargo run - ``` - - - ```shell - RUST_LOG="ort=debug" cargo run - ``` - - - ```shell - RUST_LOG="ort=debug" cargo run - ``` - - - + + + ```shell + RUST_LOG="ort=debug" cargo run + ``` + + + ```shell + RUST_LOG="ort=debug" cargo run + ``` + + + -You can also detect EP regsitration failures programmatically. See [Execution providers: Fallback behavior](/perf/execution-providers#fallback-behavior) for more info. +You can also detect EP regsitration failures programmatically. See [Execution providers: Fallback behavior](/perf/execution-providers#fallback-behavior) for more info. 
## Inference is slower than expected There are a few things you could try to improve performance: diff --git a/docs/pnpm-lock.yaml b/docs/pnpm-lock.yaml new file mode 100644 index 00000000..4bebc4f8 --- /dev/null +++ b/docs/pnpm-lock.yaml @@ -0,0 +1,3200 @@ +lockfileVersion: '9.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +importers: + + .: + dependencies: + next: + specifier: ^14.2.3 + version: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra: + specifier: ^2.13.4 + version: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra-theme-docs: + specifier: ^2.13.4 + version: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: + specifier: ^18.3.1 + version: 18.3.1 + react-dom: + specifier: ^18.3.1 + version: 18.3.1(react@18.3.1) + devDependencies: + '@types/node': + specifier: 20.14.2 + version: 20.14.2 + '@types/react': + specifier: ^18.3.3 + version: 18.3.3 + '@types/react-dom': + specifier: ^18.3.0 + version: 18.3.0 + typescript: + specifier: ^5.4.5 + version: 5.4.5 + +packages: + + '@babel/runtime@7.24.7': + resolution: {integrity: sha512-UwgBRMjJP+xv857DCngvqXI3Iq6J4v0wXmwc6sapg+zyhbwmQX67LUEFrkK5tbyJ30jGuG3ZvWpBiB9LCy1kWw==} + engines: {node: '>=6.9.0'} + + '@braintree/sanitize-url@6.0.4': + resolution: {integrity: sha512-s3jaWicZd0pkP0jf5ysyHUI/RE7MHos6qlToFcGWXVp+ykHOy77OUMrfbgJ9it2C5bow7OIQwYYaHjk9XlBQ2A==} + + '@headlessui/react@1.7.19': + resolution: {integrity: sha512-Ll+8q3OlMJfJbAKM/+/Y2q6PPYbryqNTXDbryx7SXLIDamkF6iQFbriYHga0dY44PvDhvvBWCx1Xj4U5+G4hOw==} + engines: {node: '>=10'} + peerDependencies: + react: ^16 || ^17 || ^18 + react-dom: ^16 || ^17 || ^18 + + '@mdx-js/mdx@2.3.0': + resolution: {integrity: 
sha512-jLuwRlz8DQfQNiUCJR50Y09CGPq3fLtmtUQfVrj79E0JWu3dvsVcxVIcfhR5h0iXu+/z++zDrYeiJqifRynJkA==} + + '@mdx-js/react@2.3.0': + resolution: {integrity: sha512-zQH//gdOmuu7nt2oJR29vFhDv88oGPmVw6BggmrHeMI+xgEkp1B2dX9/bMBSYtK0dyLX/aOmesKS09g222K1/g==} + peerDependencies: + react: '>=16' + + '@napi-rs/simple-git-android-arm-eabi@0.1.16': + resolution: {integrity: sha512-dbrCL0Pl5KZG7x7tXdtVsA5CO6At5ohDX3myf5xIYn9kN4jDFxsocl8bNt6Vb/hZQoJd8fI+k5VlJt+rFhbdVw==} + engines: {node: '>= 10'} + cpu: [arm] + os: [android] + + '@napi-rs/simple-git-android-arm64@0.1.16': + resolution: {integrity: sha512-xYz+TW5J09iK8SuTAKK2D5MMIsBUXVSs8nYp7HcMi8q6FCRO7yJj96YfP9PvKsc/k64hOyqGmL5DhCzY9Cu1FQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [android] + + '@napi-rs/simple-git-darwin-arm64@0.1.16': + resolution: {integrity: sha512-XfgsYqxhUE022MJobeiX563TJqyQyX4FmYCnqrtJwAfivESVeAJiH6bQIum8dDEYMHXCsG7nL8Ok0Dp8k2m42g==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@napi-rs/simple-git-darwin-x64@0.1.16': + resolution: {integrity: sha512-tkEVBhD6vgRCbeWsaAQqM3bTfpIVGeitamPPRVSbsq8qgzJ5Dx6ZedH27R7KSsA/uao7mZ3dsrNLXbu1Wy5MzA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@napi-rs/simple-git-linux-arm-gnueabihf@0.1.16': + resolution: {integrity: sha512-R6VAyNnp/yRaT7DV1Ao3r67SqTWDa+fNq2LrNy0Z8gXk2wB9ZKlrxFtLPE1WSpWknWtyRDLpRlsorh7Evk7+7w==} + engines: {node: '>= 10'} + cpu: [arm] + os: [linux] + + '@napi-rs/simple-git-linux-arm64-gnu@0.1.16': + resolution: {integrity: sha512-LAGI0opFKw/HBMCV2qIBK3uWSEW9h4xd2ireZKLJy8DBPymX6NrWIamuxYNyCuACnFdPRxR4LaRFy4J5ZwuMdw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@napi-rs/simple-git-linux-arm64-musl@0.1.16': + resolution: {integrity: sha512-I57Ph0F0Yn2KW93ep+V1EzKhACqX0x49vvSiapqIsdDA2PifdEWLc1LJarBolmK7NKoPqKmf6lAKKO9lhiZzkg==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@napi-rs/simple-git-linux-x64-gnu@0.1.16': + resolution: {integrity: 
sha512-AZYYFY2V7hlcQASPEOWyOa3e1skzTct9QPzz0LiDM3f/hCFY/wBaU2M6NC5iG3d2Kr38heuyFS/+JqxLm5WaKA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@napi-rs/simple-git-linux-x64-musl@0.1.16': + resolution: {integrity: sha512-9TyMcYSBJwjT8jwjY9m24BZbu7ozyWTjsmYBYNtK3B0Um1Ov6jthSNneLVvouQ6x+k3Ow+00TiFh6bvmT00r8g==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@napi-rs/simple-git-win32-arm64-msvc@0.1.16': + resolution: {integrity: sha512-uslJ1WuAHCYJWui6xjsyT47SjX6KOHDtClmNO8hqKz1pmDSNY7AjyUY8HxvD1lK9bDnWwc4JYhikS9cxCqHybw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + '@napi-rs/simple-git-win32-x64-msvc@0.1.16': + resolution: {integrity: sha512-SoEaVeCZCDF1MP+M9bMSXsZWgEjk4On9GWADO5JOulvzR1bKjk0s9PMHwe/YztR9F0sJzrCxwtvBZowhSJsQPg==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@napi-rs/simple-git@0.1.16': + resolution: {integrity: sha512-C5wRPw9waqL2jk3jEDeJv+f7ScuO3N0a39HVdyFLkwKxHH4Sya4ZbzZsu2JLi6eEqe7RuHipHL6mC7B2OfYZZw==} + engines: {node: '>= 10'} + + '@next/env@14.2.3': + resolution: {integrity: sha512-W7fd7IbkfmeeY2gXrzJYDx8D2lWKbVoTIj1o1ScPHNzvp30s1AuoEFSdr39bC5sjxJaxTtq3OTCZboNp0lNWHA==} + + '@next/swc-darwin-arm64@14.2.3': + resolution: {integrity: sha512-3pEYo/RaGqPP0YzwnlmPN2puaF2WMLM3apt5jLW2fFdXD9+pqcoTzRk+iZsf8ta7+quAe4Q6Ms0nR0SFGFdS1A==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@next/swc-darwin-x64@14.2.3': + resolution: {integrity: sha512-6adp7waE6P1TYFSXpY366xwsOnEXM+y1kgRpjSRVI2CBDOcbRjsJ67Z6EgKIqWIue52d2q/Mx8g9MszARj8IEA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@next/swc-linux-arm64-gnu@14.2.3': + resolution: {integrity: sha512-cuzCE/1G0ZSnTAHJPUT1rPgQx1w5tzSX7POXSLaS7w2nIUJUD+e25QoXD/hMfxbsT9rslEXugWypJMILBj/QsA==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@next/swc-linux-arm64-musl@14.2.3': + resolution: {integrity: sha512-0D4/oMM2Y9Ta3nGuCcQN8jjJjmDPYpHX9OJzqk42NZGJocU2MqhBq5tWkJrUQOQY9N+In9xOdymzapM09GeiZw==} + engines: {node: '>= 
10'} + cpu: [arm64] + os: [linux] + + '@next/swc-linux-x64-gnu@14.2.3': + resolution: {integrity: sha512-ENPiNnBNDInBLyUU5ii8PMQh+4XLr4pG51tOp6aJ9xqFQ2iRI6IH0Ds2yJkAzNV1CfyagcyzPfROMViS2wOZ9w==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@next/swc-linux-x64-musl@14.2.3': + resolution: {integrity: sha512-BTAbq0LnCbF5MtoM7I/9UeUu/8ZBY0i8SFjUMCbPDOLv+un67e2JgyN4pmgfXBwy/I+RHu8q+k+MCkDN6P9ViQ==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@next/swc-win32-arm64-msvc@14.2.3': + resolution: {integrity: sha512-AEHIw/dhAMLNFJFJIJIyOFDzrzI5bAjI9J26gbO5xhAKHYTZ9Or04BesFPXiAYXDNdrwTP2dQceYA4dL1geu8A==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + '@next/swc-win32-ia32-msvc@14.2.3': + resolution: {integrity: sha512-vga40n1q6aYb0CLrM+eEmisfKCR45ixQYXuBXxOOmmoV8sYST9k7E3US32FsY+CkkF7NtzdcebiFT4CHuMSyZw==} + engines: {node: '>= 10'} + cpu: [ia32] + os: [win32] + + '@next/swc-win32-x64-msvc@14.2.3': + resolution: {integrity: sha512-Q1/zm43RWynxrO7lW4ehciQVj+5ePBhOK+/K2P7pLFX3JaJ/IZVC69SHidrmZSOkqz7ECIOhhy7XhAFG4JYyHA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@popperjs/core@2.11.8': + resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==} + + '@swc/counter@0.1.3': + resolution: {integrity: sha512-e2BR4lsJkkRlKZ/qCHPw9ZaSxc0MVUd7gtbtaB7aMvHeJVYe8sOB8DBZkP2DtISHGSku9sCK6T6cnY0CtXrOCQ==} + + '@swc/helpers@0.5.5': + resolution: {integrity: sha512-KGYxvIOXcceOAbEk4bi/dVLEK9z8sZ0uBB3Il5b1rhfClSpcX0yfRO0KmTkqR2cnQDymwLB+25ZyMzICg/cm/A==} + + '@tanstack/react-virtual@3.5.1': + resolution: {integrity: sha512-jIsuhfgy8GqA67PdWqg73ZB2LFE+HD9hjWL1L6ifEIZVyZVAKpYmgUG4WsKQ005aEyImJmbuimPiEvc57IY0Aw==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 + + '@tanstack/virtual-core@3.5.1': + resolution: {integrity: sha512-046+AUSiDru/V9pajE1du8WayvBKeCvJ2NmKPy/mR8/SbKKrqmSbj7LJBfXE+nSq4f5TBXvnCzu0kcYebI9WdQ==} + + 
'@theguild/remark-mermaid@0.0.5': + resolution: {integrity: sha512-e+ZIyJkEv9jabI4m7q29wZtZv+2iwPGsXJ2d46Zi7e+QcFudiyuqhLhHG/3gX3ZEB+hxTch+fpItyMS8jwbIcw==} + peerDependencies: + react: ^18.2.0 + + '@theguild/remark-npm2yarn@0.2.1': + resolution: {integrity: sha512-jUTFWwDxtLEFtGZh/TW/w30ySaDJ8atKWH8dq2/IiQF61dPrGfETpl0WxD0VdBfuLOeU14/kop466oBSRO/5CA==} + + '@types/acorn@4.0.6': + resolution: {integrity: sha512-veQTnWP+1D/xbxVrPC3zHnCZRjSrKfhbMUlEA43iMZLu7EsnTtkJklIuwrCPbOi8YkvDQAiW05VQQFvvz9oieQ==} + + '@types/d3-scale-chromatic@3.0.3': + resolution: {integrity: sha512-laXM4+1o5ImZv3RpFAsTRn3TEkzqkytiOY0Dz0sq5cnd1dtNlk6sHLon4OvqaiJb28T0S/TdsBI3Sjsy+keJrw==} + + '@types/d3-scale@4.0.8': + resolution: {integrity: sha512-gkK1VVTr5iNiYJ7vWDI+yUFFlszhNMtVeneJ6lUTKPjprsvLLI9/tgEGiXJOnlINJA8FyA88gfnQsHbybVZrYQ==} + + '@types/d3-time@3.0.3': + resolution: {integrity: sha512-2p6olUZ4w3s+07q3Tm2dbiMZy5pCDfYwtLXXHUnVzXgQlZ/OyPtUz6OL382BkOuGlLXqfT+wqv8Fw2v8/0geBw==} + + '@types/debug@4.1.12': + resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==} + + '@types/estree-jsx@1.0.5': + resolution: {integrity: sha512-52CcUVNFyfb1A2ALocQw/Dd1BQFNmSdkuC3BkZ6iqhdMfQz7JWOFRuJFloOzjk+6WijU56m9oKXFAXc7o3Towg==} + + '@types/estree@1.0.5': + resolution: {integrity: sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==} + + '@types/hast@2.3.10': + resolution: {integrity: sha512-McWspRw8xx8J9HurkVBfYj0xKoE25tOFlHGdx4MJ5xORQrMGZNqJhVQWaIbm6Oyla5kYOXtDiopzKRJzEOkwJw==} + + '@types/hast@3.0.4': + resolution: {integrity: sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==} + + '@types/js-yaml@4.0.9': + resolution: {integrity: sha512-k4MGaQl5TGo/iipqb2UDG2UwjXziSWkh0uysQelTlJpX1qGlpUZYm8PnO4DxG1qBomtJUdYJ6qR6xdIah10JLg==} + + '@types/katex@0.16.7': + resolution: {integrity: 
sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==} + + '@types/mdast@3.0.15': + resolution: {integrity: sha512-LnwD+mUEfxWMa1QpDraczIn6k0Ee3SMicuYSSzS6ZYl2gKS09EClnJYGd8Du6rfc5r/GZEk5o1mRb8TaTj03sQ==} + + '@types/mdast@4.0.4': + resolution: {integrity: sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA==} + + '@types/mdx@2.0.13': + resolution: {integrity: sha512-+OWZQfAYyio6YkJb3HLxDrvnx6SWWDbC0zVPfBRzUk0/nqoDyf6dNxQi3eArPe8rJ473nobTMQ/8Zk+LxJ+Yuw==} + + '@types/ms@0.7.34': + resolution: {integrity: sha512-nG96G3Wp6acyAgJqGasjODb+acrI7KltPiRxzHPXnP3NgI28bpQDRv53olbqGXbfcgF5aiiHmO3xpwEpS5Ld9g==} + + '@types/node@20.14.2': + resolution: {integrity: sha512-xyu6WAMVwv6AKFLB+e/7ySZVr/0zLCzOa7rSpq6jNwpqOrUbcACDWC+53d4n2QHOnDou0fbIsg8wZu/sxrnI4Q==} + + '@types/prop-types@15.7.12': + resolution: {integrity: sha512-5zvhXYtRNRluoE/jAp4GVsSduVUzNWKkOZrCDBWYtE7biZywwdC2AcEzg+cSMLFRfVgeAFqpfNabiPjxFddV1Q==} + + '@types/react-dom@18.3.0': + resolution: {integrity: sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==} + + '@types/react@18.3.3': + resolution: {integrity: sha512-hti/R0pS0q1/xx+TsI73XIqk26eBsISZ2R0wUijXIngRK9R/e7Xw/cXVxQK7R5JjW+SV4zGcn5hXjudkN/pLIw==} + + '@types/unist@2.0.10': + resolution: {integrity: sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA==} + + '@types/unist@3.0.2': + resolution: {integrity: sha512-dqId9J8K/vGi5Zr7oo212BGii5m3q5Hxlkwy3WpYuKPklmBEvsbMYYyLxAQpSffdLl/gdW0XUpKWFvYmyoWCoQ==} + + '@ungap/structured-clone@1.2.0': + resolution: {integrity: sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==} + + acorn-jsx@5.3.2: + resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} + peerDependencies: + acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 + + acorn@8.11.3: + resolution: 
{integrity: sha512-Y9rRfJG5jcKOE0CLisYbojUjIrIEE7AGMzA/Sm4BslANhbS+cDMpgBdcPT91oJ7OuJ9hYJBx59RjbhxVnrF8Xg==} + engines: {node: '>=0.4.0'} + hasBin: true + + ansi-sequence-parser@1.1.1: + resolution: {integrity: sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==} + + ansi-styles@3.2.1: + resolution: {integrity: sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==} + engines: {node: '>=4'} + + arch@2.2.0: + resolution: {integrity: sha512-Of/R0wqp83cgHozfIYLbBMnej79U/SVGOOyuB3VVFv1NRM/PSFMK12x9KVtiYzJqmnU5WR2qp0Z5rHb7sWGnFQ==} + + arg@1.0.0: + resolution: {integrity: sha512-Wk7TEzl1KqvTGs/uyhmHO/3XLd3t1UeU4IstvPXVzGPM522cTjqjNZ99esCkcL52sjqjo8e8CTBcWhkxvGzoAw==} + + argparse@1.0.10: + resolution: {integrity: sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==} + + argparse@2.0.1: + resolution: {integrity: sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==} + + astring@1.8.6: + resolution: {integrity: sha512-ISvCdHdlTDlH5IpxQJIex7BWBywFWgjJSVdwst+/iQCoEYnyOaQ95+X1JGshuBjGp6nxKUy1jMgE3zPqN7fQdg==} + hasBin: true + + bail@2.0.2: + resolution: {integrity: sha512-0xO6mYd7JB2YesxDKplafRpsiOzPt9V02ddPCLbY1xYGPOX24NTyN50qnUxgCPcSoYMhKpAuBTjQoRZCAkUDRw==} + + busboy@1.6.0: + resolution: {integrity: sha512-8SFQbg/0hQ9xy3UNTB0YEnsNBbWfhf7RtnzpL7TkBiTBRfrQ9Fxcnz7VJsleJpyp6rVLvXiuORqjlHi5q+PYuA==} + engines: {node: '>=10.16.0'} + + caniuse-lite@1.0.30001629: + resolution: {integrity: sha512-c3dl911slnQhmxUIT4HhYzT7wnBK/XYpGnYLOj4nJBaRiw52Ibe7YxlDaAeRECvA786zCuExhxIUJ2K7nHMrBw==} + + ccount@2.0.1: + resolution: {integrity: sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==} + + chalk@2.3.0: + resolution: {integrity: sha512-Az5zJR2CBujap2rqXGaJKaPHyJ0IrUimvYNX+ncCy8PJP4ltOGTrHUIo097ZaL2zMeKYpiCdqDvS6zdrTFok3Q==} + engines: {node: '>=4'} + + 
character-entities-html4@2.1.0: + resolution: {integrity: sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA==} + + character-entities-legacy@3.0.0: + resolution: {integrity: sha512-RpPp0asT/6ufRm//AJVwpViZbGM/MkjQFxJccQRHmISF/22NBtsHqAWmL+/pmkPWoIUJdWyeVleTl1wydHATVQ==} + + character-entities@2.0.2: + resolution: {integrity: sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==} + + character-reference-invalid@2.0.1: + resolution: {integrity: sha512-iBZ4F4wRbyORVsu0jPV7gXkOsGYjGHPmAyv+HiHG8gi5PtC9KI2j1+v8/tlibRvjoWX027ypmG/n0HtO5t7unw==} + + client-only@0.0.1: + resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==} + + clipboardy@1.2.2: + resolution: {integrity: sha512-16KrBOV7bHmHdxcQiCvfUFYVFyEah4FI8vYT1Fr7CGSA4G+xBWMEfUEQJS1hxeHGtI9ju1Bzs9uXSbj5HZKArw==} + engines: {node: '>=4'} + + clsx@2.1.1: + resolution: {integrity: sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==} + engines: {node: '>=6'} + + color-convert@1.9.3: + resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==} + + color-name@1.1.3: + resolution: {integrity: sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==} + + comma-separated-tokens@2.0.3: + resolution: {integrity: sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==} + + commander@7.2.0: + resolution: {integrity: sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==} + engines: {node: '>= 10'} + + commander@8.3.0: + resolution: {integrity: sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==} + engines: {node: '>= 12'} + + compute-scroll-into-view@3.1.0: + resolution: {integrity: 
sha512-rj8l8pD4bJ1nx+dAkMhV1xB5RuZEyVysfxJqB1pRchh1KVvwOv9b7CGB8ZfjTImVv2oF+sYMUkMZq6Na5Ftmbg==} + + cose-base@1.0.3: + resolution: {integrity: sha512-s9whTXInMSgAp/NVXVNuVxVKzGH2qck3aQlVHxDCdAEPgtMKwc4Wq6/QKhgdEdgbLSi9rBTAcPoRa6JpiG4ksg==} + + cross-spawn@5.1.0: + resolution: {integrity: sha512-pTgQJ5KC0d2hcY8eyL1IzlBPYjTkyH72XRZPnLyKus2mBfNjQs3klqbJU2VILqZryAZUt9JOb3h/mWMy23/f5A==} + + csstype@3.1.3: + resolution: {integrity: sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==} + + cytoscape-cose-bilkent@4.1.0: + resolution: {integrity: sha512-wgQlVIUJF13Quxiv5e1gstZ08rnZj2XaLHGoFMYXz7SkNfCDOOteKBE6SYRfA9WxxI/iBc3ajfDoc6hb/MRAHQ==} + peerDependencies: + cytoscape: ^3.2.0 + + cytoscape@3.29.2: + resolution: {integrity: sha512-2G1ycU28Nh7OHT9rkXRLpCDP30MKH1dXJORZuBhtEhEW7pKwgPi77ImqlCWinouyE1PNepIOGZBOrE84DG7LyQ==} + engines: {node: '>=0.10'} + + d3-array@2.12.1: + resolution: {integrity: sha512-B0ErZK/66mHtEsR1TkPEEkwdy+WDesimkM5gpZr5Dsg54BiTA5RXtYW5qTLIAcekaS9xfZrzBLF/OAkB3Qn1YQ==} + + d3-array@3.2.4: + resolution: {integrity: sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg==} + engines: {node: '>=12'} + + d3-axis@3.0.0: + resolution: {integrity: sha512-IH5tgjV4jE/GhHkRV0HiVYPDtvfjHQlQfJHs0usq7M30XcSBvOotpmH1IgkcXsO/5gEQZD43B//fc7SRT5S+xw==} + engines: {node: '>=12'} + + d3-brush@3.0.0: + resolution: {integrity: sha512-ALnjWlVYkXsVIGlOsuWH1+3udkYFI48Ljihfnh8FZPF2QS9o+PzGLBslO0PjzVoHLZ2KCVgAM8NVkXPJB2aNnQ==} + engines: {node: '>=12'} + + d3-chord@3.0.1: + resolution: {integrity: sha512-VE5S6TNa+j8msksl7HwjxMHDM2yNK3XCkusIlpX5kwauBfXuyLAtNg9jCp/iHH61tgI4sb6R/EIMWCqEIdjT/g==} + engines: {node: '>=12'} + + d3-color@3.1.0: + resolution: {integrity: sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==} + engines: {node: '>=12'} + + d3-contour@4.0.2: + resolution: {integrity: 
sha512-4EzFTRIikzs47RGmdxbeUvLWtGedDUNkTcmzoeyg4sP/dvCexO47AaQL7VKy/gul85TOxw+IBgA8US2xwbToNA==} + engines: {node: '>=12'} + + d3-delaunay@6.0.4: + resolution: {integrity: sha512-mdjtIZ1XLAM8bm/hx3WwjfHt6Sggek7qH043O8KEjDXN40xi3vx/6pYSVTwLjEgiXQTbvaouWKynLBiUZ6SK6A==} + engines: {node: '>=12'} + + d3-dispatch@3.0.1: + resolution: {integrity: sha512-rzUyPU/S7rwUflMyLc1ETDeBj0NRuHKKAcvukozwhshr6g6c5d8zh4c2gQjY2bZ0dXeGLWc1PF174P2tVvKhfg==} + engines: {node: '>=12'} + + d3-drag@3.0.0: + resolution: {integrity: sha512-pWbUJLdETVA8lQNJecMxoXfH6x+mO2UQo8rSmZ+QqxcbyA3hfeprFgIT//HW2nlHChWeIIMwS2Fq+gEARkhTkg==} + engines: {node: '>=12'} + + d3-dsv@3.0.1: + resolution: {integrity: sha512-UG6OvdI5afDIFP9w4G0mNq50dSOsXHJaRE8arAS5o9ApWnIElp8GZw1Dun8vP8OyHOZ/QJUKUJwxiiCCnUwm+Q==} + engines: {node: '>=12'} + hasBin: true + + d3-ease@3.0.1: + resolution: {integrity: sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==} + engines: {node: '>=12'} + + d3-fetch@3.0.1: + resolution: {integrity: sha512-kpkQIM20n3oLVBKGg6oHrUchHM3xODkTzjMoj7aWQFq5QEM+R6E4WkzT5+tojDY7yjez8KgCBRoj4aEr99Fdqw==} + engines: {node: '>=12'} + + d3-force@3.0.0: + resolution: {integrity: sha512-zxV/SsA+U4yte8051P4ECydjD/S+qeYtnaIyAs9tgHCqfguma/aAQDjo85A9Z6EKhBirHRJHXIgJUlffT4wdLg==} + engines: {node: '>=12'} + + d3-format@3.1.0: + resolution: {integrity: sha512-YyUI6AEuY/Wpt8KWLgZHsIU86atmikuoOmCfommt0LYHiQSPjvX2AcFc38PX0CBpr2RCyZhjex+NS/LPOv6YqA==} + engines: {node: '>=12'} + + d3-geo@3.1.1: + resolution: {integrity: sha512-637ln3gXKXOwhalDzinUgY83KzNWZRKbYubaG+fGVuc/dxO64RRljtCTnf5ecMyE1RIdtqpkVcq0IbtU2S8j2Q==} + engines: {node: '>=12'} + + d3-hierarchy@3.1.2: + resolution: {integrity: sha512-FX/9frcub54beBdugHjDCdikxThEqjnR93Qt7PvQTOHxyiNCAlvMrHhclk3cD5VeAaq9fxmfRp+CnWw9rEMBuA==} + engines: {node: '>=12'} + + d3-interpolate@3.0.1: + resolution: {integrity: sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==} + engines: {node: 
'>=12'} + + d3-path@1.0.9: + resolution: {integrity: sha512-VLaYcn81dtHVTjEHd8B+pbe9yHWpXKZUC87PzoFmsFrJqgFwDe/qxfp5MlfsfM1V5E/iVt0MmEbWQ7FVIXh/bg==} + + d3-path@3.1.0: + resolution: {integrity: sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==} + engines: {node: '>=12'} + + d3-polygon@3.0.1: + resolution: {integrity: sha512-3vbA7vXYwfe1SYhED++fPUQlWSYTTGmFmQiany/gdbiWgU/iEyQzyymwL9SkJjFFuCS4902BSzewVGsHHmHtXg==} + engines: {node: '>=12'} + + d3-quadtree@3.0.1: + resolution: {integrity: sha512-04xDrxQTDTCFwP5H6hRhsRcb9xxv2RzkcsygFzmkSIOJy3PeRJP7sNk3VRIbKXcog561P9oU0/rVH6vDROAgUw==} + engines: {node: '>=12'} + + d3-random@3.0.1: + resolution: {integrity: sha512-FXMe9GfxTxqd5D6jFsQ+DJ8BJS4E/fT5mqqdjovykEB2oFbTMDVdg1MGFxfQW+FBOGoB++k8swBrgwSHT1cUXQ==} + engines: {node: '>=12'} + + d3-sankey@0.12.3: + resolution: {integrity: sha512-nQhsBRmM19Ax5xEIPLMY9ZmJ/cDvd1BG3UVvt5h3WRxKg5zGRbvnteTyWAbzeSvlh3tW7ZEmq4VwR5mB3tutmQ==} + + d3-scale-chromatic@3.1.0: + resolution: {integrity: sha512-A3s5PWiZ9YCXFye1o246KoscMWqf8BsD9eRiJ3He7C9OBaxKhAd5TFCdEx/7VbKtxxTsu//1mMJFrEt572cEyQ==} + engines: {node: '>=12'} + + d3-scale@4.0.2: + resolution: {integrity: sha512-GZW464g1SH7ag3Y7hXjf8RoUuAFIqklOAq3MRl4OaWabTFJY9PN/E1YklhXLh+OQ3fM9yS2nOkCoS+WLZ6kvxQ==} + engines: {node: '>=12'} + + d3-selection@3.0.0: + resolution: {integrity: sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==} + engines: {node: '>=12'} + + d3-shape@1.3.7: + resolution: {integrity: sha512-EUkvKjqPFUAZyOlhY5gzCxCeI0Aep04LwIRpsZ/mLFelJiUfnK56jo5JMDSE7yyP2kLSb6LtF+S5chMk7uqPqw==} + + d3-shape@3.2.0: + resolution: {integrity: sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==} + engines: {node: '>=12'} + + d3-time-format@4.1.0: + resolution: {integrity: sha512-dJxPBlzC7NugB2PDLwo9Q8JiTR3M3e4/XANkreKSUxF8vvXKqm1Yfq4Q5dl8budlunRVlUUaDUgFt7eA8D6NLg==} + engines: {node: '>=12'} + + 
d3-time@3.1.0: + resolution: {integrity: sha512-VqKjzBLejbSMT4IgbmVgDjpkYrNWUYJnbCGo874u7MMKIWsILRX+OpX/gTk8MqjpT1A/c6HY2dCA77ZN0lkQ2Q==} + engines: {node: '>=12'} + + d3-timer@3.0.1: + resolution: {integrity: sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==} + engines: {node: '>=12'} + + d3-transition@3.0.1: + resolution: {integrity: sha512-ApKvfjsSR6tg06xrL434C0WydLr7JewBB3V+/39RMHsaXTOG0zmt/OAXeng5M5LBm0ojmxJrpomQVZ1aPvBL4w==} + engines: {node: '>=12'} + peerDependencies: + d3-selection: 2 - 3 + + d3-zoom@3.0.0: + resolution: {integrity: sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==} + engines: {node: '>=12'} + + d3@7.9.0: + resolution: {integrity: sha512-e1U46jVP+w7Iut8Jt8ri1YsPOvFpg46k+K8TpCb0P+zjCkjkPnV7WzfDJzMHy1LnA+wj5pLT1wjO901gLXeEhA==} + engines: {node: '>=12'} + + dagre-d3-es@7.0.10: + resolution: {integrity: sha512-qTCQmEhcynucuaZgY5/+ti3X/rnszKZhEQH/ZdWdtP1tA/y3VoHJzcVrO9pjjJCNpigfscAtoUB5ONcd2wNn0A==} + + dayjs@1.11.11: + resolution: {integrity: sha512-okzr3f11N6WuqYtZSvm+F776mB41wRZMhKP+hc34YdW+KmtYYK9iqvHSwo2k9FEH3fhGXvOPV6yz2IcSrfRUDg==} + + debug@4.3.5: + resolution: {integrity: sha512-pt0bNEmneDIvdL1Xsd9oDQ/wrQRkXDT4AUWlNZNPKvW5x/jyO9VFXkJUP07vQ2upmw5PlaITaPKc31jK13V+jg==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + + decode-named-character-reference@1.0.2: + resolution: {integrity: sha512-O8x12RzrUF8xyVcY0KJowWsmaJxQbmy0/EtnNtHRpsOcT7dFk5W598coHqBVpmWo1oQQfsCqfCmkZN5DJrZVdg==} + + delaunator@5.0.1: + resolution: {integrity: sha512-8nvh+XBe96aCESrGOqMp/84b13H9cdKbG5P2ejQCh4d4sK9RL4371qou9drQjMhvnPmhWl5hnmqbEE0fXr9Xnw==} + + dequal@2.0.3: + resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==} + engines: {node: '>=6'} + + devlop@1.1.0: + resolution: {integrity: 
sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==} + + diff@5.2.0: + resolution: {integrity: sha512-uIFDxqpRZGZ6ThOk84hEfqWoHx2devRFvpTZcTHur85vImfaxUbTW9Ryh4CpCuDnToOP1CEtXKIgytHBPVff5A==} + engines: {node: '>=0.3.1'} + + dompurify@3.1.5: + resolution: {integrity: sha512-lwG+n5h8QNpxtyrJW/gJWckL+1/DQiYMX8f7t8Z2AZTPw1esVrqjI63i7Zc2Gz0aKzLVMYC1V1PL/ky+aY/NgA==} + + elkjs@0.9.3: + resolution: {integrity: sha512-f/ZeWvW/BCXbhGEf1Ujp29EASo/lk1FDnETgNKwJrsVvGZhUWCZyg3xLJjAsxfOmt8KjswHmI5EwCQcPMpOYhQ==} + + entities@4.5.0: + resolution: {integrity: sha512-V0hjH4dGPh9Ao5p0MoRY6BVqtwCjhz6vI5LT8AJ55H+4g9/4vbHx1I54fS0XuclLhDHArPQCiMjDxjaL8fPxhw==} + engines: {node: '>=0.12'} + + escape-string-regexp@1.0.5: + resolution: {integrity: sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==} + engines: {node: '>=0.8.0'} + + escape-string-regexp@5.0.0: + resolution: {integrity: sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw==} + engines: {node: '>=12'} + + esprima@4.0.1: + resolution: {integrity: sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==} + engines: {node: '>=4'} + hasBin: true + + estree-util-attach-comments@2.1.1: + resolution: {integrity: sha512-+5Ba/xGGS6mnwFbXIuQiDPTbuTxuMCooq3arVv7gPZtYpjp+VXH/NkHAP35OOefPhNG/UGqU3vt/LTABwcHX0w==} + + estree-util-build-jsx@2.2.2: + resolution: {integrity: sha512-m56vOXcOBuaF+Igpb9OPAy7f9w9OIkb5yhjsZuaPm7HoGi4oTOQi0h2+yZ+AtKklYFZ+rPC4n0wYCJCEU1ONqg==} + + estree-util-is-identifier-name@2.1.0: + resolution: {integrity: sha512-bEN9VHRyXAUOjkKVQVvArFym08BTWB0aJPppZZr0UNyAqWsLaVfAqP7hbaTJjzHifmB5ebnR8Wm7r7yGN/HonQ==} + + estree-util-to-js@1.2.0: + resolution: {integrity: sha512-IzU74r1PK5IMMGZXUVZbmiu4A1uhiPgW5hm1GjcOfr4ZzHaMPpLNJjR7HjXiIOzi25nZDrgFTobHTkV5Q6ITjA==} + + estree-util-value-to-estree@1.3.0: + resolution: {integrity: 
sha512-Y+ughcF9jSUJvncXwqRageavjrNPAI+1M/L3BI3PyLp1nmgYTGUXU6t5z1Y7OWuThoDdhPME07bQU+d5LxdJqw==} + engines: {node: '>=12.0.0'} + + estree-util-visit@1.2.1: + resolution: {integrity: sha512-xbgqcrkIVbIG+lI/gzbvd9SGTJL4zqJKBFttUl5pP27KhAjtMKbX/mQXJ7qgyXpMgVy/zvpm0xoQQaGL8OloOw==} + + estree-walker@3.0.3: + resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + + execa@0.8.0: + resolution: {integrity: sha512-zDWS+Rb1E8BlqqhALSt9kUhss8Qq4nN3iof3gsOdyINksElaPyNBtKUMTR62qhvgVWR0CqCX7sdnKe4MnUbFEA==} + engines: {node: '>=4'} + + extend-shallow@2.0.1: + resolution: {integrity: sha512-zCnTtlxNoAiDc3gqY2aYAWFx7XWWiasuF2K8Me5WbN8otHKTUKBwjPtNpRs/rbUZm7KxWAaNj7P1a/p52GbVug==} + engines: {node: '>=0.10.0'} + + extend@3.0.2: + resolution: {integrity: sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==} + + flexsearch@0.7.43: + resolution: {integrity: sha512-c5o/+Um8aqCSOXGcZoqZOm+NqtVwNsvVpWv6lfmSclU954O3wvQKxxK8zj74fPaSJbXpSLTs4PRhh+wnoCXnKg==} + + focus-visible@5.2.0: + resolution: {integrity: sha512-Rwix9pBtC1Nuy5wysTmKy+UjbDJpIfg8eHjw0rjZ1mX4GNLz1Bmd16uDpI3Gk1i70Fgcs8Csg2lPm8HULFg9DQ==} + + get-stream@3.0.0: + resolution: {integrity: sha512-GlhdIUuVakc8SJ6kK0zAFbiGzRFzNnY4jUuEbV9UROo4Y+0Ny4fjvcZFVTeDA4odpFyOQzaw6hXukJSq/f28sQ==} + engines: {node: '>=4'} + + git-up@7.0.0: + resolution: {integrity: sha512-ONdIrbBCFusq1Oy0sC71F5azx8bVkvtZtMJAsv+a6lz5YAmbNnLD6HAB4gptHZVLPR8S2/kVN6Gab7lryq5+lQ==} + + git-url-parse@13.1.1: + resolution: {integrity: sha512-PCFJyeSSdtnbfhSNRw9Wk96dDCNx+sogTe4YNXeXSJxt7xz5hvXekuRn9JX7m+Mf4OscCu8h+mtAl3+h5Fo8lQ==} + + github-slugger@2.0.0: + resolution: {integrity: sha512-IaOQ9puYtjrkq7Y0Ygl9KDZnrf/aiUJYUpVf89y8kyaxbRG7Y1SrX/jaumrv81vc61+kiMempujsM3Yw7w5qcw==} + + graceful-fs@4.2.11: + resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} + + gray-matter@4.0.3: + resolution: 
{integrity: sha512-5v6yZd4JK3eMI3FqqCouswVqwugaA9r4dNZB1wwcmrD02QkV5H0y7XBQW8QwQqEaZY1pM9aqORSORhJRdNK44Q==} + engines: {node: '>=6.0'} + + has-flag@2.0.0: + resolution: {integrity: sha512-P+1n3MnwjR/Epg9BBo1KT8qbye2g2Ou4sFumihwt6I4tsUX7jnLcX4BTOSKg/B1ZrIYMN9FcEnG4x5a7NB8Eng==} + engines: {node: '>=0.10.0'} + + hash-obj@4.0.0: + resolution: {integrity: sha512-FwO1BUVWkyHasWDW4S8o0ssQXjvyghLV2rfVhnN36b2bbcj45eGiuzdn9XOvOpjV3TKQD7Gm2BWNXdE9V4KKYg==} + engines: {node: '>=12'} + + hast-util-from-dom@5.0.0: + resolution: {integrity: sha512-d6235voAp/XR3Hh5uy7aGLbM3S4KamdW0WEgOaU1YoewnuYw4HXb5eRtv9g65m/RFGEfUY1Mw4UqCc5Y8L4Stg==} + + hast-util-from-html-isomorphic@2.0.0: + resolution: {integrity: sha512-zJfpXq44yff2hmE0XmwEOzdWin5xwH+QIhMLOScpX91e/NSGPsAzNCvLQDIEPyO2TXi+lBmU6hjLIhV8MwP2kw==} + + hast-util-from-html@2.0.1: + resolution: {integrity: sha512-RXQBLMl9kjKVNkJTIO6bZyb2n+cUH8LFaSSzo82jiLT6Tfc+Pt7VQCS+/h3YwG4jaNE2TA2sdJisGWR+aJrp0g==} + + hast-util-from-parse5@8.0.1: + resolution: {integrity: sha512-Er/Iixbc7IEa7r/XLtuG52zoqn/b3Xng/w6aZQ0xGVxzhw5xUFxcRqdPzP6yFi/4HBYRaifaI5fQ1RH8n0ZeOQ==} + + hast-util-is-element@3.0.0: + resolution: {integrity: sha512-Val9mnv2IWpLbNPqc/pUem+a7Ipj2aHacCwgNfTiK0vJKl0LF+4Ba4+v1oPHFpf3bLYmreq0/l3Gud9S5OH42g==} + + hast-util-parse-selector@4.0.0: + resolution: {integrity: sha512-wkQCkSYoOGCRKERFWcxMVMOcYE2K1AaNLU8DXS9arxnLOUEWbOXKXiJUNzEpqZ3JOKpnha3jkFrumEjVliDe7A==} + + hast-util-raw@9.0.3: + resolution: {integrity: sha512-ICWvVOF2fq4+7CMmtCPD5CM4QKjPbHpPotE6+8tDooV0ZuyJVUzHsrNX+O5NaRbieTf0F7FfeBOMAwi6Td0+yQ==} + + hast-util-to-estree@2.3.3: + resolution: {integrity: sha512-ihhPIUPxN0v0w6M5+IiAZZrn0LH2uZomeWwhn7uP7avZC6TE7lIiEh2yBMPr5+zi1aUCXq6VoYRgs2Bw9xmycQ==} + + hast-util-to-parse5@8.0.0: + resolution: {integrity: sha512-3KKrV5ZVI8if87DVSi1vDeByYrkGzg4mEfeu4alwgmmIeARiBLKCZS2uw5Gb6nU9x9Yufyj3iudm6i7nl52PFw==} + + hast-util-to-text@4.0.2: + resolution: {integrity: 
sha512-KK6y/BN8lbaq654j7JgBydev7wuNMcID54lkRav1P0CaE1e47P72AWWPiGKXTJU271ooYzcvTAn/Zt0REnvc7A==} + + hast-util-whitespace@2.0.1: + resolution: {integrity: sha512-nAxA0v8+vXSBDt3AnRUNjyRIQ0rD+ntpbAp4LnPkumc5M9yUbSMa4XDU9Q6etY4f1Wp4bNgvc1yjiZtsTTrSng==} + + hastscript@8.0.0: + resolution: {integrity: sha512-dMOtzCEd3ABUeSIISmrETiKuyydk1w0pa+gE/uormcTpSYuaNJPbX1NU3JLyscSLjwAQM8bWMhhIlnCqnRvDTw==} + + html-void-elements@3.0.0: + resolution: {integrity: sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==} + + iconv-lite@0.6.3: + resolution: {integrity: sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==} + engines: {node: '>=0.10.0'} + + inline-style-parser@0.1.1: + resolution: {integrity: sha512-7NXolsK4CAS5+xvdj5OMMbI962hU/wvwoxk+LWR9Ek9bVtyuuYScDN6eS0rUm6TxApFpw7CX1o4uJzcd4AyD3Q==} + + internmap@1.0.1: + resolution: {integrity: sha512-lDB5YccMydFBtasVtxnZ3MRBHuaoE8GKsppq+EchKL2U4nK/DmEpPHNH8MZe5HkMtpSiTSOZwfN0tzYjO/lJEw==} + + internmap@2.0.3: + resolution: {integrity: sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==} + engines: {node: '>=12'} + + intersection-observer@0.12.2: + resolution: {integrity: sha512-7m1vEcPCxXYI8HqnL8CKI6siDyD+eIWSwgB3DZA+ZTogxk9I4CDnj4wilt9x/+/QbHI4YG5YZNmC6458/e9Ktg==} + + is-alphabetical@2.0.1: + resolution: {integrity: sha512-FWyyY60MeTNyeSRpkM2Iry0G9hpr7/9kD40mD/cGQEuilcZYS4okz8SN2Q6rLCJ8gbCt6fN+rC+6tMGS99LaxQ==} + + is-alphanumerical@2.0.1: + resolution: {integrity: sha512-hmbYhX/9MUMF5uh7tOXyK/n0ZvWpad5caBA17GsC6vyuCqaWliRG5K1qS9inmUhEMaOBIW7/whAnSwveW/LtZw==} + + is-buffer@2.0.5: + resolution: {integrity: sha512-i2R6zNFDwgEHJyQUtJEk0XFi1i0dPFn/oqjK3/vPCcDeJvW5NQ83V8QbicfF1SupOaB0h8ntgBC2YiE7dfyctQ==} + engines: {node: '>=4'} + + is-decimal@2.0.1: + resolution: {integrity: sha512-AAB9hiomQs5DXWcRB1rqsxGUstbRroFOPPVAomNk/3XHR5JyEZChOyTWe2oayKnsSsr/kcGqF+z6yuH6HHpN0A==} + + is-extendable@0.1.1: + 
resolution: {integrity: sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==} + engines: {node: '>=0.10.0'} + + is-hexadecimal@2.0.1: + resolution: {integrity: sha512-DgZQp241c8oO6cA1SbTEWiXeoxV42vlcJxgH+B3hi1AiqqKruZR3ZGF8In3fj4+/y/7rHvlOZLZtgJ/4ttYGZg==} + + is-obj@3.0.0: + resolution: {integrity: sha512-IlsXEHOjtKhpN8r/tRFj2nDyTmHvcfNeu/nrRIcXE17ROeatXchkojffa1SpdqW4cr/Fj6QkEf/Gn4zf6KKvEQ==} + engines: {node: '>=12'} + + is-plain-obj@3.0.0: + resolution: {integrity: sha512-gwsOE28k+23GP1B6vFl1oVh/WOzmawBrKwo5Ev6wMKzPkaXaCDIQKzLnvsA42DRlbVTWorkgTKIviAKCWkfUwA==} + engines: {node: '>=10'} + + is-plain-obj@4.1.0: + resolution: {integrity: sha512-+Pgi+vMuUNkJyExiMBt5IlFoMyKnr5zhJ4Uspz58WOhBF5QoIZkFyNHIbBAtHwzVAgk5RtndVNsDRN61/mmDqg==} + engines: {node: '>=12'} + + is-reference@3.0.2: + resolution: {integrity: sha512-v3rht/LgVcsdZa3O2Nqs+NMowLOxeOm7Ay9+/ARQ2F+qEoANRcqrjAZKGN0v8ymUetZGgkp26LTnGT7H0Qo9Pg==} + + is-ssh@1.4.0: + resolution: {integrity: sha512-x7+VxdxOdlV3CYpjvRLBv5Lo9OJerlYanjwFrPR9fuGPjCiNiCzFgAWpiLAohSbsnH4ZAys3SBh+hq5rJosxUQ==} + + is-stream@1.1.0: + resolution: {integrity: sha512-uQPm8kcs47jx38atAcWTVxyltQYoPT68y9aWYdV6yWXSyW8mzSat0TL6CiWdZeCdF3KrAvpVtnHbTv4RN+rqdQ==} + engines: {node: '>=0.10.0'} + + isexe@2.0.0: + resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==} + + js-tokens@4.0.0: + resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} + + js-yaml@3.14.1: + resolution: {integrity: sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==} + hasBin: true + + js-yaml@4.1.0: + resolution: {integrity: sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==} + hasBin: true + + jsonc-parser@3.2.1: + resolution: {integrity: 
sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==} + + katex@0.16.10: + resolution: {integrity: sha512-ZiqaC04tp2O5utMsl2TEZTXxa6WSC4yo0fv5ML++D3QZv/vx2Mct0mTlRx3O+uUkjfuAgOkzsCmq5MiUEsDDdA==} + hasBin: true + + khroma@2.1.0: + resolution: {integrity: sha512-Ls993zuzfayK269Svk9hzpeGUKob/sIgZzyHYdjQoAdQetRKpOLj+k/QQQ/6Qi0Yz65mlROrfd+Ev+1+7dz9Kw==} + + kind-of@6.0.3: + resolution: {integrity: sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==} + engines: {node: '>=0.10.0'} + + kleur@4.1.5: + resolution: {integrity: sha512-o+NO+8WrRiQEE4/7nwRJhN1HWpVmJm511pBHUxPLtp0BUISzlBplORYSmTclCnJvQq2tKu/sgl3xVpkc7ZWuQQ==} + engines: {node: '>=6'} + + layout-base@1.0.2: + resolution: {integrity: sha512-8h2oVEZNktL4BH2JCOI90iD1yXwL6iNW7KcCKT2QZgQJR2vbqDsldCTPRU9NifTCqHZci57XvQQ15YTu+sTYPg==} + + lodash-es@4.17.21: + resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==} + + lodash.get@4.4.2: + resolution: {integrity: sha512-z+Uw/vLuy6gQe8cfaFWD7p0wVv8fJl3mbzXh33RS+0oW2wvUqiRXiQ69gLWSLpgB5/6sU+r6BlQR0MBILadqTQ==} + + longest-streak@3.1.0: + resolution: {integrity: sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==} + + loose-envify@1.4.0: + resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} + hasBin: true + + lru-cache@4.1.5: + resolution: {integrity: sha512-sWZlbEP2OsHNkXrMl5GYk/jKk70MBng6UU4YI/qGDYbgf6YbP4EvmqISbXCoJiRKs+1bSpFHVgQxvJ17F2li5g==} + + markdown-extensions@1.1.1: + resolution: {integrity: sha512-WWC0ZuMzCyDHYCasEGs4IPvLyTGftYwh6wIEOULOF0HXcqZlhwRzrK0w2VUlxWA98xnvb/jszw4ZSkJ6ADpM6Q==} + engines: {node: '>=0.10.0'} + + markdown-table@3.0.3: + resolution: {integrity: sha512-Z1NL3Tb1M9wH4XESsCDEksWoKTdlUafKc4pt0GRwjUyXaCFZ+dc3g2erqB6zm3szA2IUSi7VnPI+o/9jnxh9hw==} + + match-sorter@6.3.4: + resolution: 
{integrity: sha512-jfZW7cWS5y/1xswZo8VBOdudUiSd9nifYRWphc9M5D/ee4w4AoXLgBEdRbgVaxbMuagBPeUC5y2Hi8DO6o9aDg==} + + mdast-util-definitions@5.1.2: + resolution: {integrity: sha512-8SVPMuHqlPME/z3gqVwWY4zVXn8lqKv/pAhC57FuJ40ImXyBpmO5ukh98zB2v7Blql2FiHjHv9LVztSIqjY+MA==} + + mdast-util-find-and-replace@2.2.2: + resolution: {integrity: sha512-MTtdFRz/eMDHXzeK6W3dO7mXUlF82Gom4y0oOgvHhh/HXZAGvIQDUvQ0SuUx+j2tv44b8xTHOm8K/9OoRFnXKw==} + + mdast-util-from-markdown@1.3.1: + resolution: {integrity: sha512-4xTO/M8c82qBcnQc1tgpNtubGUW/Y1tBQ1B0i5CtSoelOLKFYlElIr3bvgREYYO5iRqbMY1YuqZng0GVOI8Qww==} + + mdast-util-gfm-autolink-literal@1.0.3: + resolution: {integrity: sha512-My8KJ57FYEy2W2LyNom4n3E7hKTuQk/0SES0u16tjA9Z3oFkF4RrC/hPAPgjlSpezsOvI8ObcXcElo92wn5IGA==} + + mdast-util-gfm-footnote@1.0.2: + resolution: {integrity: sha512-56D19KOGbE00uKVj3sgIykpwKL179QsVFwx/DCW0u/0+URsryacI4MAdNJl0dh+u2PSsD9FtxPFbHCzJ78qJFQ==} + + mdast-util-gfm-strikethrough@1.0.3: + resolution: {integrity: sha512-DAPhYzTYrRcXdMjUtUjKvW9z/FNAMTdU0ORyMcbmkwYNbKocDpdk+PX1L1dQgOID/+vVs1uBQ7ElrBQfZ0cuiQ==} + + mdast-util-gfm-table@1.0.7: + resolution: {integrity: sha512-jjcpmNnQvrmN5Vx7y7lEc2iIOEytYv7rTvu+MeyAsSHTASGCCRA79Igg2uKssgOs1i1po8s3plW0sTu1wkkLGg==} + + mdast-util-gfm-task-list-item@1.0.2: + resolution: {integrity: sha512-PFTA1gzfp1B1UaiJVyhJZA1rm0+Tzn690frc/L8vNX1Jop4STZgOE6bxUhnzdVSB+vm2GU1tIsuQcA9bxTQpMQ==} + + mdast-util-gfm@2.0.2: + resolution: {integrity: sha512-qvZ608nBppZ4icQlhQQIAdc6S3Ffj9RGmzwUKUWuEICFnd1LVkN3EktF7ZHAgfcEdvZB5owU9tQgt99e2TlLjg==} + + mdast-util-math@2.0.2: + resolution: {integrity: sha512-8gmkKVp9v6+Tgjtq6SYx9kGPpTf6FVYRa53/DLh479aldR9AyP48qeVOgNZ5X7QUK7nOy4yw7vg6mbiGcs9jWQ==} + + mdast-util-mdx-expression@1.3.2: + resolution: {integrity: sha512-xIPmR5ReJDu/DHH1OoIT1HkuybIfRGYRywC+gJtI7qHjCJp/M9jrmBEJW22O8lskDWm562BX2W8TiAwRTb0rKA==} + + mdast-util-mdx-jsx@2.1.4: + resolution: {integrity: 
sha512-DtMn9CmVhVzZx3f+optVDF8yFgQVt7FghCRNdlIaS3X5Bnym3hZwPbg/XW86vdpKjlc1PVj26SpnLGeJBXD3JA==} + + mdast-util-mdx@2.0.1: + resolution: {integrity: sha512-38w5y+r8nyKlGvNjSEqWrhG0w5PmnRA+wnBvm+ulYCct7nsGYhFVb0lljS9bQav4psDAS1eGkP2LMVcZBi/aqw==} + + mdast-util-mdxjs-esm@1.3.1: + resolution: {integrity: sha512-SXqglS0HrEvSdUEfoXFtcg7DRl7S2cwOXc7jkuusG472Mmjag34DUDeOJUZtl+BVnyeO1frIgVpHlNRWc2gk/w==} + + mdast-util-phrasing@3.0.1: + resolution: {integrity: sha512-WmI1gTXUBJo4/ZmSk79Wcb2HcjPJBzM1nlI/OUWA8yk2X9ik3ffNbBGsU+09BFmXaL1IBb9fiuvq6/KMiNycSg==} + + mdast-util-to-hast@12.3.0: + resolution: {integrity: sha512-pits93r8PhnIoU4Vy9bjW39M2jJ6/tdHyja9rrot9uujkN7UTU9SDnE6WNJz/IGyQk3XHX6yNNtrBH6cQzm8Hw==} + + mdast-util-to-hast@13.1.0: + resolution: {integrity: sha512-/e2l/6+OdGp/FB+ctrJ9Avz71AN/GRH3oi/3KAx/kMnoUsD6q0woXlDT8lLEeViVKE7oZxE7RXzvO3T8kF2/sA==} + + mdast-util-to-markdown@1.5.0: + resolution: {integrity: sha512-bbv7TPv/WC49thZPg3jXuqzuvI45IL2EVAr/KxF0BSdHsU0ceFHOmwQn6evxAh1GaoK/6GQ1wp4R4oW2+LFL/A==} + + mdast-util-to-string@3.2.0: + resolution: {integrity: sha512-V4Zn/ncyN1QNSqSBxTrMOLpjr+IKdHl2v3KVLoWmDPscP4r9GcCi71gjgvUV1SFSKh92AjAG4peFuBl2/YgCJg==} + + mermaid@10.9.1: + resolution: {integrity: sha512-Mx45Obds5W1UkW1nv/7dHRsbfMM1aOKA2+Pxs/IGHNonygDHwmng8xTHyS9z4KWVi0rbko8gjiBmuwwXQ7tiNA==} + + micromark-core-commonmark@1.1.0: + resolution: {integrity: sha512-BgHO1aRbolh2hcrzL2d1La37V0Aoz73ymF8rAcKnohLy93titmv62E0gP8Hrx9PKcKrqCZ1BbLGbP3bEhoXYlw==} + + micromark-extension-gfm-autolink-literal@1.0.5: + resolution: {integrity: sha512-z3wJSLrDf8kRDOh2qBtoTRD53vJ+CWIyo7uyZuxf/JAbNJjiHsOpG1y5wxk8drtv3ETAHutCu6N3thkOOgueWg==} + + micromark-extension-gfm-footnote@1.1.2: + resolution: {integrity: sha512-Yxn7z7SxgyGWRNa4wzf8AhYYWNrwl5q1Z8ii+CSTTIqVkmGZF1CElX2JI8g5yGoM3GAman9/PVCUFUSJ0kB/8Q==} + + micromark-extension-gfm-strikethrough@1.0.7: + resolution: {integrity: sha512-sX0FawVE1o3abGk3vRjOH50L5TTLr3b5XMqnP9YDRb34M0v5OoZhG+OHFz1OffZ9dlwgpTBKaT4XW/AsUVnSDw==} + + 
micromark-extension-gfm-table@1.0.7: + resolution: {integrity: sha512-3ZORTHtcSnMQEKtAOsBQ9/oHp9096pI/UvdPtN7ehKvrmZZ2+bbWhi0ln+I9drmwXMt5boocn6OlwQzNXeVeqw==} + + micromark-extension-gfm-tagfilter@1.0.2: + resolution: {integrity: sha512-5XWB9GbAUSHTn8VPU8/1DBXMuKYT5uOgEjJb8gN3mW0PNW5OPHpSdojoqf+iq1xo7vWzw/P8bAHY0n6ijpXF7g==} + + micromark-extension-gfm-task-list-item@1.0.5: + resolution: {integrity: sha512-RMFXl2uQ0pNQy6Lun2YBYT9g9INXtWJULgbt01D/x8/6yJ2qpKyzdZD3pi6UIkzF++Da49xAelVKUeUMqd5eIQ==} + + micromark-extension-gfm@2.0.3: + resolution: {integrity: sha512-vb9OoHqrhCmbRidQv/2+Bc6pkP0FrtlhurxZofvOEy5o8RtuuvTq+RQ1Vw5ZDNrVraQZu3HixESqbG+0iKk/MQ==} + + micromark-extension-math@2.1.2: + resolution: {integrity: sha512-es0CcOV89VNS9wFmyn+wyFTKweXGW4CEvdaAca6SWRWPyYCbBisnjaHLjWO4Nszuiud84jCpkHsqAJoa768Pvg==} + + micromark-extension-mdx-expression@1.0.8: + resolution: {integrity: sha512-zZpeQtc5wfWKdzDsHRBY003H2Smg+PUi2REhqgIhdzAa5xonhP03FcXxqFSerFiNUr5AWmHpaNPQTBVOS4lrXw==} + + micromark-extension-mdx-jsx@1.0.5: + resolution: {integrity: sha512-gPH+9ZdmDflbu19Xkb8+gheqEDqkSpdCEubQyxuz/Hn8DOXiXvrXeikOoBA71+e8Pfi0/UYmU3wW3H58kr7akA==} + + micromark-extension-mdx-md@1.0.1: + resolution: {integrity: sha512-7MSuj2S7xjOQXAjjkbjBsHkMtb+mDGVW6uI2dBL9snOBCbZmoNgDAeZ0nSn9j3T42UE/g2xVNMn18PJxZvkBEA==} + + micromark-extension-mdxjs-esm@1.0.5: + resolution: {integrity: sha512-xNRBw4aoURcyz/S69B19WnZAkWJMxHMT5hE36GtDAyhoyn/8TuAeqjFJQlwk+MKQsUD7b3l7kFX+vlfVWgcX1w==} + + micromark-extension-mdxjs@1.0.1: + resolution: {integrity: sha512-7YA7hF6i5eKOfFUzZ+0z6avRG52GpWR8DL+kN47y3f2KhxbBZMhmxe7auOeaTBrW2DenbbZTf1ea9tA2hDpC2Q==} + + micromark-factory-destination@1.1.0: + resolution: {integrity: sha512-XaNDROBgx9SgSChd69pjiGKbV+nfHGDPVYFs5dOoDd7ZnMAE+Cuu91BCpsY8RT2NP9vo/B8pds2VQNCLiu0zhg==} + + micromark-factory-label@1.1.0: + resolution: {integrity: sha512-OLtyez4vZo/1NjxGhcpDSbHQ+m0IIGnT8BoPamh+7jVlzLJBH98zzuCoUeMxvM6WsNeh8wx8cKvqLiPHEACn0w==} + + micromark-factory-mdx-expression@1.0.9: 
+ resolution: {integrity: sha512-jGIWzSmNfdnkJq05c7b0+Wv0Kfz3NJ3N4cBjnbO4zjXIlxJr+f8lk+5ZmwFvqdAbUy2q6B5rCY//g0QAAaXDWA==} + + micromark-factory-space@1.1.0: + resolution: {integrity: sha512-cRzEj7c0OL4Mw2v6nwzttyOZe8XY/Z8G0rzmWQZTBi/jjwyw/U4uqKtUORXQrR5bAZZnbTI/feRV/R7hc4jQYQ==} + + micromark-factory-title@1.1.0: + resolution: {integrity: sha512-J7n9R3vMmgjDOCY8NPw55jiyaQnH5kBdV2/UXCtZIpnHH3P6nHUKaH7XXEYuWwx/xUJcawa8plLBEjMPU24HzQ==} + + micromark-factory-whitespace@1.1.0: + resolution: {integrity: sha512-v2WlmiymVSp5oMg+1Q0N1Lxmt6pMhIHD457whWM7/GUlEks1hI9xj5w3zbc4uuMKXGisksZk8DzP2UyGbGqNsQ==} + + micromark-util-character@1.2.0: + resolution: {integrity: sha512-lXraTwcX3yH/vMDaFWCQJP1uIszLVebzUa3ZHdrgxr7KEU/9mL4mVgCpGbyhvNLNlauROiNUq7WN5u7ndbY6xg==} + + micromark-util-character@2.1.0: + resolution: {integrity: sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==} + + micromark-util-chunked@1.1.0: + resolution: {integrity: sha512-Ye01HXpkZPNcV6FiyoW2fGZDUw4Yc7vT0E9Sad83+bEDiCJ1uXu0S3mr8WLpsz3HaG3x2q0HM6CTuPdcZcluFQ==} + + micromark-util-classify-character@1.1.0: + resolution: {integrity: sha512-SL0wLxtKSnklKSUplok1WQFoGhUdWYKggKUiqhX+Swala+BtptGCu5iPRc+xvzJ4PXE/hwM3FNXsfEVgoZsWbw==} + + micromark-util-combine-extensions@1.1.0: + resolution: {integrity: sha512-Q20sp4mfNf9yEqDL50WwuWZHUrCO4fEyeDCnMGmG5Pr0Cz15Uo7KBs6jq+dq0EgX4DPwwrh9m0X+zPV1ypFvUA==} + + micromark-util-decode-numeric-character-reference@1.1.0: + resolution: {integrity: sha512-m9V0ExGv0jB1OT21mrWcuf4QhP46pH1KkfWy9ZEezqHKAxkj4mPCy3nIH1rkbdMlChLHX531eOrymlwyZIf2iw==} + + micromark-util-decode-string@1.1.0: + resolution: {integrity: sha512-YphLGCK8gM1tG1bd54azwyrQRjCFcmgj2S2GoJDNnh4vYtnL38JS8M4gpxzOPNyHdNEpheyWXCTnnTDY3N+NVQ==} + + micromark-util-encode@1.1.0: + resolution: {integrity: sha512-EuEzTWSTAj9PA5GOAs992GzNh2dGQO52UvAbtSOMvXTxv3Criqb6IOzJUBCmEqrrXSblJIJBbFFv6zPxpreiJw==} + + micromark-util-encode@2.0.0: + resolution: {integrity: 
sha512-pS+ROfCXAGLWCOc8egcBvT0kf27GoWMqtdarNfDcjb6YLuV5cM3ioG45Ys2qOVqeqSbjaKg72vU+Wby3eddPsA==} + + micromark-util-events-to-acorn@1.2.3: + resolution: {integrity: sha512-ij4X7Wuc4fED6UoLWkmo0xJQhsktfNh1J0m8g4PbIMPlx+ek/4YdW5mvbye8z/aZvAPUoxgXHrwVlXAPKMRp1w==} + + micromark-util-html-tag-name@1.2.0: + resolution: {integrity: sha512-VTQzcuQgFUD7yYztuQFKXT49KghjtETQ+Wv/zUjGSGBioZnkA4P1XXZPT1FHeJA6RwRXSF47yvJ1tsJdoxwO+Q==} + + micromark-util-normalize-identifier@1.1.0: + resolution: {integrity: sha512-N+w5vhqrBihhjdpM8+5Xsxy71QWqGn7HYNUvch71iV2PM7+E3uWGox1Qp90loa1ephtCxG2ftRV/Conitc6P2Q==} + + micromark-util-resolve-all@1.1.0: + resolution: {integrity: sha512-b/G6BTMSg+bX+xVCshPTPyAu2tmA0E4X98NSR7eIbeC6ycCqCeE7wjfDIgzEbkzdEVJXRtOG4FbEm/uGbCRouA==} + + micromark-util-sanitize-uri@1.2.0: + resolution: {integrity: sha512-QO4GXv0XZfWey4pYFndLUKEAktKkG5kZTdUNaTAkzbuJxn2tNBOr+QtxR2XpWaMhbImT2dPzyLrPXLlPhph34A==} + + micromark-util-sanitize-uri@2.0.0: + resolution: {integrity: sha512-WhYv5UEcZrbAtlsnPuChHUAsu/iBPOVaEVsntLBIdpibO0ddy8OzavZz3iL2xVvBZOpolujSliP65Kq0/7KIYw==} + + micromark-util-subtokenize@1.1.0: + resolution: {integrity: sha512-kUQHyzRoxvZO2PuLzMt2P/dwVsTiivCK8icYTeR+3WgbuPqfHgPPy7nFKbeqRivBvn/3N3GBiNC+JRTMSxEC7A==} + + micromark-util-symbol@1.1.0: + resolution: {integrity: sha512-uEjpEYY6KMs1g7QfJ2eX1SQEV+ZT4rUD3UcF6l57acZvLNK7PBZL+ty82Z1qhK1/yXIY4bdx04FKMgR0g4IAag==} + + micromark-util-symbol@2.0.0: + resolution: {integrity: sha512-8JZt9ElZ5kyTnO94muPxIGS8oyElRJaiJO8EzV6ZSyGQ1Is8xwl4Q45qU5UOg+bGH4AikWziz0iN4sFLWs8PGw==} + + micromark-util-types@1.1.0: + resolution: {integrity: sha512-ukRBgie8TIAcacscVHSiddHjO4k/q3pnedmzMQ4iwDcK0FtFCohKOlFbaOL/mPgfnPsL3C1ZyxJa4sbWrBl3jg==} + + micromark-util-types@2.0.0: + resolution: {integrity: sha512-oNh6S2WMHWRZrmutsRmDDfkzKtxF+bc2VxLC9dvtrDIRFln627VsFP6fLMgTryGDljgLPjkrzQSDcPrjPyDJ5w==} + + micromark@3.2.0: + resolution: {integrity: 
sha512-uD66tJj54JLYq0De10AhWycZWGQNUvDI55xPgk2sQM5kn1JYlhbCMTtEeT27+vAhW2FBQxLlOmS3pmA7/2z4aA==} + + mri@1.2.0: + resolution: {integrity: sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==} + engines: {node: '>=4'} + + ms@2.1.2: + resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} + + nanoid@3.3.7: + resolution: {integrity: sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==} + engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} + hasBin: true + + next-mdx-remote@4.4.1: + resolution: {integrity: sha512-1BvyXaIou6xy3XoNF4yaMZUCb6vD2GTAa5ciOa6WoO+gAUTYsb1K4rI/HSC2ogAWLrb/7VSV52skz07vOzmqIQ==} + engines: {node: '>=14', npm: '>=7'} + peerDependencies: + react: '>=16.x <=18.x' + react-dom: '>=16.x <=18.x' + + next-seo@6.5.0: + resolution: {integrity: sha512-MfzUeWTN/x/rsKp/1n0213eojO97lIl0unxqbeCY+6pAucViHDA8GSLRRcXpgjsSmBxfCFdfpu7LXbt4ANQoNQ==} + peerDependencies: + next: ^8.1.1-canary.54 || >=9.0.0 + react: '>=16.0.0' + react-dom: '>=16.0.0' + + next-themes@0.2.1: + resolution: {integrity: sha512-B+AKNfYNIzh0vqQQKqQItTS8evEouKD7H5Hj3kmuPERwddR2TxvDSFZuTj6T7Jfn1oyeUyJMydPl1Bkxkh0W7A==} + peerDependencies: + next: '*' + react: '*' + react-dom: '*' + + next@14.2.3: + resolution: {integrity: sha512-dowFkFTR8v79NPJO4QsBUtxv0g9BrS/phluVpMAt2ku7H+cbcBJlopXjkWlwxrk/xGqMemr7JkGPGemPrLLX7A==} + engines: {node: '>=18.17.0'} + hasBin: true + peerDependencies: + '@opentelemetry/api': ^1.1.0 + '@playwright/test': ^1.41.2 + react: ^18.2.0 + react-dom: ^18.2.0 + sass: ^1.3.0 + peerDependenciesMeta: + '@opentelemetry/api': + optional: true + '@playwright/test': + optional: true + sass: + optional: true + + nextra-theme-docs@2.13.4: + resolution: {integrity: sha512-2XOoMfwBCTYBt8ds4ZHftt9Wyf2XsykiNo02eir/XEYB+sGeUoE77kzqfidjEOKCSzOHYbK9BDMcg2+B/2vYRw==} + peerDependencies: + next: '>=9.5.3' + nextra: 2.13.4 + react: 
'>=16.13.1' + react-dom: '>=16.13.1' + + nextra@2.13.4: + resolution: {integrity: sha512-7of2rSBxuUa3+lbMmZwG9cqgftcoNOVQLTT6Rxf3EhBR9t1EI7b43dted8YoqSNaigdE3j1CoyNkX8N/ZzlEpw==} + engines: {node: '>=16'} + peerDependencies: + next: '>=9.5.3' + react: '>=16.13.1' + react-dom: '>=16.13.1' + + non-layered-tidy-tree-layout@2.0.2: + resolution: {integrity: sha512-gkXMxRzUH+PB0ax9dUN0yYF0S25BqeAYqhgMaLUFmpXLEk7Fcu8f4emJuOAY0V8kjDICxROIKsTAKsV/v355xw==} + + npm-run-path@2.0.2: + resolution: {integrity: sha512-lJxZYlT4DW/bRUtFh1MQIWqmLwQfAxnqWG4HhEdjMlkrJYnJn0Jrr2u3mgxqaWsdiBc76TYkTG/mhrnYTuzfHw==} + engines: {node: '>=4'} + + npm-to-yarn@2.2.1: + resolution: {integrity: sha512-O/j/ROyX0KGLG7O6Ieut/seQ0oiTpHF2tXAcFbpdTLQFiaNtkyTXXocM1fwpaa60dg1qpWj0nHlbNhx6qwuENQ==} + engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} + + p-finally@1.0.0: + resolution: {integrity: sha512-LICb2p9CB7FS+0eR1oqWnHhp0FljGLZCWBE9aix0Uye9W8LTQPwMTYVGWQWIw9RdQiDg4+epXQODwIYJtSJaow==} + engines: {node: '>=4'} + + p-limit@3.1.0: + resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==} + engines: {node: '>=10'} + + parse-entities@4.0.1: + resolution: {integrity: sha512-SWzvYcSJh4d/SGLIOQfZ/CoNv6BTlI6YEQ7Nj82oDVnRpwe/Z/F1EMx42x3JAOwGBlCjeCH0BRJQbQ/opHL17w==} + + parse-numeric-range@1.3.0: + resolution: {integrity: sha512-twN+njEipszzlMJd4ONUYgSfZPDxgHhT9Ahed5uTigpQn90FggW4SA/AIPq/6a149fTbE9qBEcSwE3FAEp6wQQ==} + + parse-path@7.0.0: + resolution: {integrity: sha512-Euf9GG8WT9CdqwuWJGdf3RkUcTBArppHABkO7Lm8IzRQp0e2r/kkFnmhu4TSK30Wcu5rVAZLmfPKSBBi9tWFog==} + + parse-url@8.1.0: + resolution: {integrity: sha512-xDvOoLU5XRrcOZvnI6b8zA6n9O9ejNk/GExuz1yBuWUGn9KA97GI6HTs6u02wKara1CeVmZhH+0TZFdWScR89w==} + + parse5@7.1.2: + resolution: {integrity: sha512-Czj1WaSVpaoj0wbhMzLmWD69anp2WH7FXMB9n1Sy8/ZFF9jolSQVMu1Ij5WIyGmcBmhk7EOndpO4mIpihVqAXw==} + + path-key@2.0.1: + resolution: {integrity: 
sha512-fEHGKCSmUSDPv4uoj8AlD+joPlq3peND+HRYyxFz4KPw4z926S/b8rIuFs2FYJg3BwsxJf6A9/3eIdLaYC+9Dw==} + engines: {node: '>=4'} + + periscopic@3.1.0: + resolution: {integrity: sha512-vKiQ8RRtkl9P+r/+oefh25C3fhybptkHKCZSPlcXiJux2tJF55GnEj3BVn4A5gKfq9NWWXXrxkHBwVPUfH0opw==} + + picocolors@1.0.1: + resolution: {integrity: sha512-anP1Z8qwhkbmu7MFP5iTt+wQKXgwzf7zTyGlcdzabySa9vd0Xt392U0rVmz9poOaBj0uHJKyyo9/upk0HrEQew==} + + postcss@8.4.31: + resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==} + engines: {node: ^10 || ^12 || >=14} + + property-information@6.5.0: + resolution: {integrity: sha512-PgTgs/BlvHxOu8QuEN7wi5A0OmXaBcHpmCSTehcs6Uuu9IkDIEo13Hy7n898RHfrQ49vKCoGeWZSaAK01nwVig==} + + protocols@2.0.1: + resolution: {integrity: sha512-/XJ368cyBJ7fzLMwLKv1e4vLxOju2MNAIokcr7meSaNcVbWz/CPcW22cP04mwxOErdA5mwjA8Q6w/cdAQxVn7Q==} + + pseudomap@1.0.2: + resolution: {integrity: sha512-b/YwNhb8lk1Zz2+bXXpS/LK9OisiZZ1SNsSLxN1x2OXVEhW2Ckr/7mWE5vrC1ZTiJlD9g19jWszTmJsB+oEpFQ==} + + react-dom@18.3.1: + resolution: {integrity: sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==} + peerDependencies: + react: ^18.3.1 + + react@18.3.1: + resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} + engines: {node: '>=0.10.0'} + + reading-time@1.5.0: + resolution: {integrity: sha512-onYyVhBNr4CmAxFsKS7bz+uTLRakypIe4R+5A824vBSkQy/hB3fZepoVEf8OVAxzLvK+H/jm9TzpI3ETSm64Kg==} + + regenerator-runtime@0.14.1: + resolution: {integrity: sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==} + + rehype-katex@7.0.0: + resolution: {integrity: sha512-h8FPkGE00r2XKU+/acgqwWUlyzve1IiOKwsEkg4pDL3k48PiE0Pt+/uLtVHDVkN1yA4iurZN6UES8ivHVEQV6Q==} + + rehype-pretty-code@0.9.11: + resolution: {integrity: sha512-Eq90eCYXQJISktfRZ8PPtwc5SUyH6fJcxS8XOMnHPUQZBtC6RYo67gGlley9X2nR8vlniPj0/7oCDEYHKQa/oA==} + 
engines: {node: '>=16'} + peerDependencies: + shiki: '*' + + rehype-raw@7.0.0: + resolution: {integrity: sha512-/aE8hCfKlQeA8LmyeyQvQF3eBiLRGNlfBJEvWH7ivp9sBqs7TNqBL5X3v157rM4IFETqDnIOO+z5M/biZbo9Ww==} + + remark-gfm@3.0.1: + resolution: {integrity: sha512-lEFDoi2PICJyNrACFOfDD3JlLkuSbOa5Wd8EPt06HUdptv8Gn0bxYTdbU/XXQ3swAPkEaGxxPN9cbnMHvVu1Ig==} + + remark-math@5.1.1: + resolution: {integrity: sha512-cE5T2R/xLVtfFI4cCePtiRn+e6jKMtFDR3P8V3qpv8wpKjwvHoBA4eJzvX+nVrnlNy0911bdGmuspCSwetfYHw==} + + remark-mdx@2.3.0: + resolution: {integrity: sha512-g53hMkpM0I98MU266IzDFMrTD980gNF3BJnkyFcmN+dD873mQeD5rdMO3Y2X+x8umQfbSE0PcoEDl7ledSA+2g==} + + remark-parse@10.0.2: + resolution: {integrity: sha512-3ydxgHa/ZQzG8LvC7jTXccARYDcRld3VfcgIIFs7bI6vbRSxJJmzgLEIIoYKyrfhaY+ujuWaf/PJiMZXoiCXgw==} + + remark-reading-time@2.0.1: + resolution: {integrity: sha512-fy4BKy9SRhtYbEHvp6AItbRTnrhiDGbqLQTSYVbQPGuRCncU1ubSsh9p/W5QZSxtYcUXv8KGL0xBgPLyNJA1xw==} + + remark-rehype@10.1.0: + resolution: {integrity: sha512-EFmR5zppdBp0WQeDVZ/b66CWJipB2q2VLNFMabzDSGR66Z2fQii83G5gTBbgGEnEEA0QRussvrFHxk1HWGJskw==} + + remove-accents@0.5.0: + resolution: {integrity: sha512-8g3/Otx1eJaVD12e31UbJj1YzdtVvzH85HV7t+9MJYk/u3XmkOUJ5Ys9wQrf9PCPK8+xn4ymzqYCiZl6QWKn+A==} + + robust-predicates@3.0.2: + resolution: {integrity: sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==} + + rw@1.3.3: + resolution: {integrity: sha512-PdhdWy89SiZogBLaw42zdeqtRJ//zFd2PgQavcICDUgJT5oW10QCRKbJ6bg4r0/UY2M6BWd5tkxuGFRvCkgfHQ==} + + sade@1.8.1: + resolution: {integrity: sha512-xal3CZX1Xlo/k4ApwCFrHVACi9fBqJ7V+mwhBsuf/1IOKbBy098Fex+Wa/5QMubw09pSZ/u8EY8PWgevJsXp1A==} + engines: {node: '>=6'} + + safer-buffer@2.1.2: + resolution: {integrity: sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==} + + scheduler@0.23.2: + resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==} + + 
scroll-into-view-if-needed@3.1.0: + resolution: {integrity: sha512-49oNpRjWRvnU8NyGVmUaYG4jtTkNonFZI86MmGRDqBphEK2EXT9gdEUoQPZhuBM8yWHxCWbobltqYO5M4XrUvQ==} + + section-matter@1.0.0: + resolution: {integrity: sha512-vfD3pmTzGpufjScBh50YHKzEu2lxBWhVEHsNGoEXmCmn2hKGfeNLYMzCJpe8cD7gqX7TJluOVpBkAequ6dgMmA==} + engines: {node: '>=4'} + + shebang-command@1.2.0: + resolution: {integrity: sha512-EV3L1+UQWGor21OmnvojK36mhg+TyIKDh3iFBKBohr5xeXIhNBcx8oWdgkTEEQ+BEFFYdLRuqMfd5L84N1V5Vg==} + engines: {node: '>=0.10.0'} + + shebang-regex@1.0.0: + resolution: {integrity: sha512-wpoSFAxys6b2a2wHZ1XpDSgD7N9iVjg29Ph9uV/uaP9Ex/KXlkTZTeddxDPSYQpgvzKLGJke2UU0AzoGCjNIvQ==} + engines: {node: '>=0.10.0'} + + shiki@0.14.7: + resolution: {integrity: sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==} + + signal-exit@3.0.7: + resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==} + + slash@3.0.0: + resolution: {integrity: sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==} + engines: {node: '>=8'} + + sort-keys@5.0.0: + resolution: {integrity: sha512-Pdz01AvCAottHTPQGzndktFNdbRA75BgOfeT1hH+AMnJFv8lynkPi42rfeEhpx1saTEI3YNMWxfqu0sFD1G8pw==} + engines: {node: '>=12'} + + source-map-js@1.2.0: + resolution: {integrity: sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==} + engines: {node: '>=0.10.0'} + + source-map@0.7.4: + resolution: {integrity: sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA==} + engines: {node: '>= 8'} + + space-separated-tokens@2.0.2: + resolution: {integrity: sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==} + + sprintf-js@1.0.3: + resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==} + + streamsearch@1.1.0: + resolution: 
{integrity: sha512-Mcc5wHehp9aXz1ax6bZUyY5afg9u2rv5cqQI3mRrYkGC8rW2hM02jWuwjtL++LS5qinSyhj2QfLyNsuc+VsExg==} + engines: {node: '>=10.0.0'} + + stringify-entities@4.0.4: + resolution: {integrity: sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg==} + + strip-bom-string@1.0.0: + resolution: {integrity: sha512-uCC2VHvQRYu+lMh4My/sFNmF2klFymLX1wHJeXnbEJERpV/ZsVuonzerjfrGpIGF7LBVa1O7i9kjiWvJiFck8g==} + engines: {node: '>=0.10.0'} + + strip-eof@1.0.0: + resolution: {integrity: sha512-7FCwGGmx8mD5xQd3RPUvnSpUXHM3BWuzjtpD4TXsfcZ9EL4azvVVUscFYwD9nx8Kh+uCBC00XBtAykoMHwTh8Q==} + engines: {node: '>=0.10.0'} + + style-to-object@0.4.4: + resolution: {integrity: sha512-HYNoHZa2GorYNyqiCaBgsxvcJIn7OHq6inEga+E6Ke3m5JkoqpQbnFssk4jwe+K7AhGa2fcha4wSOf1Kn01dMg==} + + styled-jsx@5.1.1: + resolution: {integrity: sha512-pW7uC1l4mBZ8ugbiZrcIsiIvVx1UmTfw7UkC3Um2tmfUq9Bhk8IiyEIPl6F8agHgjzku6j0xQEZbfA5uSgSaCw==} + engines: {node: '>= 12.0.0'} + peerDependencies: + '@babel/core': '*' + babel-plugin-macros: '*' + react: '>= 16.8.0 || 17.x.x || ^18.0.0-0' + peerDependenciesMeta: + '@babel/core': + optional: true + babel-plugin-macros: + optional: true + + stylis@4.3.2: + resolution: {integrity: sha512-bhtUjWd/z6ltJiQwg0dUfxEJ+W+jdqQd8TbWLWyeIJHlnsqmGLRFFd8e5mA0AZi/zx90smXRlN66YMTcaSFifg==} + + supports-color@4.5.0: + resolution: {integrity: sha512-ycQR/UbvI9xIlEdQT1TQqwoXtEldExbCEAJgRo5YXlmSKjv6ThHnP9/vwGa1gr19Gfw+LkFd7KqYMhzrRC5JYw==} + engines: {node: '>=4'} + + title@3.5.3: + resolution: {integrity: sha512-20JyowYglSEeCvZv3EZ0nZ046vLarO37prvV0mbtQV7C8DJPGgN967r8SJkqd3XK3K3lD3/Iyfp3avjfil8Q2Q==} + hasBin: true + + titleize@1.0.0: + resolution: {integrity: sha512-TARUb7z1pGvlLxgPk++7wJ6aycXF3GJ0sNSBTAsTuJrQG5QuZlkUQP+zl+nbjAh4gMX9yDw9ZYklMd7vAfJKEw==} + engines: {node: '>=0.10.0'} + + trim-lines@3.0.1: + resolution: {integrity: sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==} + + trough@2.2.0: + 
resolution: {integrity: sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==} + + ts-dedent@2.2.0: + resolution: {integrity: sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==} + engines: {node: '>=6.10'} + + tslib@2.6.3: + resolution: {integrity: sha512-xNvxJEOUiWPGhUuUdQgAJPKOOJfGnIyKySOc09XkKsgdUV/3E2zvwZYdejjmRgPCgcym1juLH3226yA7sEFJKQ==} + + type-fest@1.4.0: + resolution: {integrity: sha512-yGSza74xk0UG8k+pLh5oeoYirvIiWo5t0/o3zHHAO2tRDiZcxWP7fywNlXhqb6/r6sWvwi+RsyQMWhVLe4BVuA==} + engines: {node: '>=10'} + + typescript@5.4.5: + resolution: {integrity: sha512-vcI4UpRgg81oIRUFwR0WSIHKt11nJ7SAVlYNIu+QpqeyXP+gpQJy/Z4+F0aGxSE4MqwjyXvW/TzgkLAx2AGHwQ==} + engines: {node: '>=14.17'} + hasBin: true + + undici-types@5.26.5: + resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} + + unified@10.1.2: + resolution: {integrity: sha512-pUSWAi/RAnVy1Pif2kAoeWNBa3JVrx0MId2LASj8G+7AiHWoKZNTomq6LG326T68U7/e263X6fTdcXIy7XnF7Q==} + + unist-util-find-after@5.0.0: + resolution: {integrity: sha512-amQa0Ep2m6hE2g72AugUItjbuM8X8cGQnFoHk0pGfrFeT9GZhzN5SW8nRsiGKK7Aif4CrACPENkA6P/Lw6fHGQ==} + + unist-util-generated@2.0.1: + resolution: {integrity: sha512-qF72kLmPxAw0oN2fwpWIqbXAVyEqUzDHMsbtPvOudIlUzXYFIeQIuxXQCRCFh22B7cixvU0MG7m3MW8FTq/S+A==} + + unist-util-is@5.2.1: + resolution: {integrity: sha512-u9njyyfEh43npf1M+yGKDGVPbY/JWEemg5nH05ncKPfi+kBbKBJoTdsogMu33uhytuLlv9y0O7GH7fEdwLdLQw==} + + unist-util-is@6.0.0: + resolution: {integrity: sha512-2qCTHimwdxLfz+YzdGfkqNlH0tLi9xjTnHddPmJwtIG9MGsdbutfTc4P+haPD7l7Cjxf/WZj+we5qfVPvvxfYw==} + + unist-util-position-from-estree@1.1.2: + resolution: {integrity: sha512-poZa0eXpS+/XpoQwGwl79UUdea4ol2ZuCYguVaJS4qzIOMDzbqz8a3erUCOmubSZkaOuGamb3tX790iwOIROww==} + + unist-util-position@4.0.4: + resolution: {integrity: 
sha512-kUBE91efOWfIVBo8xzh/uZQ7p9ffYRtUbMRZBNFYwf0RK8koUMx6dGUfwylLOKmaT2cs4wSW96QoYUSXAyEtpg==} + + unist-util-position@5.0.0: + resolution: {integrity: sha512-fucsC7HjXvkB5R3kTCO7kUjRdrS0BJt3M/FPxmHMBOm8JQi2BsHAHFsy27E0EolP8rp0NzXsJ+jNPyDWvOJZPA==} + + unist-util-remove-position@4.0.2: + resolution: {integrity: sha512-TkBb0HABNmxzAcfLf4qsIbFbaPDvMO6wa3b3j4VcEzFVaw1LBKwnW4/sRJ/atSLSzoIg41JWEdnE7N6DIhGDGQ==} + + unist-util-remove-position@5.0.0: + resolution: {integrity: sha512-Hp5Kh3wLxv0PHj9m2yZhhLt58KzPtEYKQQ4yxfYFEO7EvHwzyDYnduhHnY1mDxoqr7VUwVuHXk9RXKIiYS1N8Q==} + + unist-util-remove@4.0.0: + resolution: {integrity: sha512-b4gokeGId57UVRX/eVKej5gXqGlc9+trkORhFJpu9raqZkZhU0zm8Doi05+HaiBsMEIJowL+2WtQ5ItjsngPXg==} + + unist-util-stringify-position@3.0.3: + resolution: {integrity: sha512-k5GzIBZ/QatR8N5X2y+drfpWG8IDBzdnVj6OInRNWm1oXrzydiaAT2OQiA8DPRRZyAKb9b6I2a6PxYklZD0gKg==} + + unist-util-stringify-position@4.0.0: + resolution: {integrity: sha512-0ASV06AAoKCDkS2+xw5RXJywruurpbC4JZSm7nr7MOt1ojAzvyyaO+UxZf18j8FCF6kmzCZKcAgN/yu2gm2XgQ==} + + unist-util-visit-parents@4.1.1: + resolution: {integrity: sha512-1xAFJXAKpnnJl8G7K5KgU7FY55y3GcLIXqkzUj5QF/QVP7biUm0K0O2oqVkYsdjzJKifYeWn9+o6piAK2hGSHw==} + + unist-util-visit-parents@5.1.3: + resolution: {integrity: sha512-x6+y8g7wWMyQhL1iZfhIPhDAs7Xwbn9nRosDXl7qoPTSCy0yNxnKc+hWokFifWQIDGi154rdUqKvbCa4+1kLhg==} + + unist-util-visit-parents@6.0.1: + resolution: {integrity: sha512-L/PqWzfTP9lzzEa6CKs0k2nARxTdZduw3zyh8d2NVBnsyvHjSX4TWse388YrrQKbvI8w20fGjGlhgT96WwKykw==} + + unist-util-visit@3.1.0: + resolution: {integrity: sha512-Szoh+R/Ll68QWAyQyZZpQzZQm2UPbxibDvaY8Xc9SUtYgPsDzx5AWSk++UUt2hJuow8mvwR+rG+LQLw+KsuAKA==} + + unist-util-visit@4.1.2: + resolution: {integrity: sha512-MSd8OUGISqHdVvfY9TPhyK2VdUrPgxkUtWSuMHF6XAAFuL4LokseigBnZtPnJMu+FbynTkFNnFlyjxpVKujMRg==} + + unist-util-visit@5.0.0: + resolution: {integrity: sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg==} + + uuid@9.0.1: 
+ resolution: {integrity: sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==} + hasBin: true + + uvu@0.5.6: + resolution: {integrity: sha512-+g8ENReyr8YsOc6fv/NVJs2vFdHBnBNdfE49rshrTzDWOlUx4Gq7KOS2GD8eqhy2j+Ejq29+SbKH8yjkAqXqoA==} + engines: {node: '>=8'} + hasBin: true + + vfile-location@5.0.2: + resolution: {integrity: sha512-NXPYyxyBSH7zB5U6+3uDdd6Nybz6o6/od9rk8bp9H8GR3L+cm/fC0uUTbqBmUTnMCUDslAGBOIKNfvvb+gGlDg==} + + vfile-matter@3.0.1: + resolution: {integrity: sha512-CAAIDwnh6ZdtrqAuxdElUqQRQDQgbbIrYtDYI8gCjXS1qQ+1XdLoK8FIZWxJwn0/I+BkSSZpar3SOgjemQz4fg==} + + vfile-message@3.1.4: + resolution: {integrity: sha512-fa0Z6P8HUrQN4BZaX05SIVXic+7kE3b05PWAtPuYP9QLHsLKYR7/AlLW3NtOrpXRLeawpDLMsVkmk5DG0NXgWw==} + + vfile-message@4.0.2: + resolution: {integrity: sha512-jRDZ1IMLttGj41KcZvlrYAaI3CfqpLpfpf+Mfig13viT6NKvRzWZ+lXz0Y5D60w6uJIBAOGq9mSHf0gktF0duw==} + + vfile@5.3.7: + resolution: {integrity: sha512-r7qlzkgErKjobAmyNIkkSpizsFPYiUPuJb5pNW1RB4JcYVZhs4lIbVqk8XPk033CV/1z8ss5pkax8SuhGpcG8g==} + + vfile@6.0.1: + resolution: {integrity: sha512-1bYqc7pt6NIADBJ98UiG0Bn/CHIVOoZ/IyEkqIruLg0mE1BKzkOXY2D6CSqQIcKqgadppE5lrxgWXJmXd7zZJw==} + + vscode-oniguruma@1.7.0: + resolution: {integrity: sha512-L9WMGRfrjOhgHSdOYgCt/yRMsXzLDJSL7BPrOZt73gU0iWO4mpqzqQzOz5srxqTvMBaR0XZTSrVWo4j55Rc6cA==} + + vscode-textmate@8.0.0: + resolution: {integrity: sha512-AFbieoL7a5LMqcnOF04ji+rpXadgOXnZsxQr//r83kLPr7biP7am3g9zbaZIaBGwBRWeSvoMD4mgPdX3e4NWBg==} + + web-namespaces@2.0.1: + resolution: {integrity: sha512-bKr1DkiNa2krS7qxNtdrtHAmzuYGFQLiQ13TsorsdT6ULTkPLKuu5+GsFpDlg6JFjUTwX2DyhMPG2be8uPrqsQ==} + + web-worker@1.3.0: + resolution: {integrity: sha512-BSR9wyRsy/KOValMgd5kMyr3JzpdeoR9KVId8u5GVlTTAtNChlsE4yTxeY7zMdNSyOmoKBv8NH2qeRY9Tg+IaA==} + + which@1.3.1: + resolution: {integrity: sha512-HxJdYWq1MTIQbJ3nw0cqssHoTNU267KlrDuGZ1WYlxDStUtKUhOaJmh112/TZmHxxUfuJqPXSOm7tDyas0OSIQ==} + hasBin: true + + yallist@2.1.2: + resolution: {integrity: 
sha512-ncTzHV7NvsQZkYe1DW7cbDLm0YpzHmZF5r/iyP3ZnQtMiJ+pjzisCiMNI+Sj+xQF5pXhSHxSB3uDbsBTzY/c2A==} + + yocto-queue@0.1.0: + resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} + engines: {node: '>=10'} + + zod@3.23.8: + resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==} + + zwitch@2.0.4: + resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==} + +snapshots: + + '@babel/runtime@7.24.7': + dependencies: + regenerator-runtime: 0.14.1 + + '@braintree/sanitize-url@6.0.4': {} + + '@headlessui/react@1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/react-virtual': 3.5.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + client-only: 0.0.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@mdx-js/mdx@2.3.0': + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/mdx': 2.0.13 + estree-util-build-jsx: 2.2.2 + estree-util-is-identifier-name: 2.1.0 + estree-util-to-js: 1.2.0 + estree-walker: 3.0.3 + hast-util-to-estree: 2.3.3 + markdown-extensions: 1.1.1 + periscopic: 3.1.0 + remark-mdx: 2.3.0 + remark-parse: 10.0.2 + remark-rehype: 10.1.0 + unified: 10.1.2 + unist-util-position-from-estree: 1.1.2 + unist-util-stringify-position: 3.0.3 + unist-util-visit: 4.1.2 + vfile: 5.3.7 + transitivePeerDependencies: + - supports-color + + '@mdx-js/react@2.3.0(react@18.3.1)': + dependencies: + '@types/mdx': 2.0.13 + '@types/react': 18.3.3 + react: 18.3.1 + + '@napi-rs/simple-git-android-arm-eabi@0.1.16': + optional: true + + '@napi-rs/simple-git-android-arm64@0.1.16': + optional: true + + '@napi-rs/simple-git-darwin-arm64@0.1.16': + optional: true + + '@napi-rs/simple-git-darwin-x64@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-arm-gnueabihf@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-arm64-gnu@0.1.16': + optional: true + + 
'@napi-rs/simple-git-linux-arm64-musl@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-x64-gnu@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-x64-musl@0.1.16': + optional: true + + '@napi-rs/simple-git-win32-arm64-msvc@0.1.16': + optional: true + + '@napi-rs/simple-git-win32-x64-msvc@0.1.16': + optional: true + + '@napi-rs/simple-git@0.1.16': + optionalDependencies: + '@napi-rs/simple-git-android-arm-eabi': 0.1.16 + '@napi-rs/simple-git-android-arm64': 0.1.16 + '@napi-rs/simple-git-darwin-arm64': 0.1.16 + '@napi-rs/simple-git-darwin-x64': 0.1.16 + '@napi-rs/simple-git-linux-arm-gnueabihf': 0.1.16 + '@napi-rs/simple-git-linux-arm64-gnu': 0.1.16 + '@napi-rs/simple-git-linux-arm64-musl': 0.1.16 + '@napi-rs/simple-git-linux-x64-gnu': 0.1.16 + '@napi-rs/simple-git-linux-x64-musl': 0.1.16 + '@napi-rs/simple-git-win32-arm64-msvc': 0.1.16 + '@napi-rs/simple-git-win32-x64-msvc': 0.1.16 + + '@next/env@14.2.3': {} + + '@next/swc-darwin-arm64@14.2.3': + optional: true + + '@next/swc-darwin-x64@14.2.3': + optional: true + + '@next/swc-linux-arm64-gnu@14.2.3': + optional: true + + '@next/swc-linux-arm64-musl@14.2.3': + optional: true + + '@next/swc-linux-x64-gnu@14.2.3': + optional: true + + '@next/swc-linux-x64-musl@14.2.3': + optional: true + + '@next/swc-win32-arm64-msvc@14.2.3': + optional: true + + '@next/swc-win32-ia32-msvc@14.2.3': + optional: true + + '@next/swc-win32-x64-msvc@14.2.3': + optional: true + + '@popperjs/core@2.11.8': {} + + '@swc/counter@0.1.3': {} + + '@swc/helpers@0.5.5': + dependencies: + '@swc/counter': 0.1.3 + tslib: 2.6.3 + + '@tanstack/react-virtual@3.5.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/virtual-core': 3.5.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@tanstack/virtual-core@3.5.1': {} + + '@theguild/remark-mermaid@0.0.5(react@18.3.1)': + dependencies: + mermaid: 10.9.1 + react: 18.3.1 + unist-util-visit: 5.0.0 + transitivePeerDependencies: + - supports-color + + 
'@theguild/remark-npm2yarn@0.2.1': + dependencies: + npm-to-yarn: 2.2.1 + unist-util-visit: 5.0.0 + + '@types/acorn@4.0.6': + dependencies: + '@types/estree': 1.0.5 + + '@types/d3-scale-chromatic@3.0.3': {} + + '@types/d3-scale@4.0.8': + dependencies: + '@types/d3-time': 3.0.3 + + '@types/d3-time@3.0.3': {} + + '@types/debug@4.1.12': + dependencies: + '@types/ms': 0.7.34 + + '@types/estree-jsx@1.0.5': + dependencies: + '@types/estree': 1.0.5 + + '@types/estree@1.0.5': {} + + '@types/hast@2.3.10': + dependencies: + '@types/unist': 2.0.10 + + '@types/hast@3.0.4': + dependencies: + '@types/unist': 3.0.2 + + '@types/js-yaml@4.0.9': {} + + '@types/katex@0.16.7': {} + + '@types/mdast@3.0.15': + dependencies: + '@types/unist': 2.0.10 + + '@types/mdast@4.0.4': + dependencies: + '@types/unist': 3.0.2 + + '@types/mdx@2.0.13': {} + + '@types/ms@0.7.34': {} + + '@types/node@20.14.2': + dependencies: + undici-types: 5.26.5 + + '@types/prop-types@15.7.12': {} + + '@types/react-dom@18.3.0': + dependencies: + '@types/react': 18.3.3 + + '@types/react@18.3.3': + dependencies: + '@types/prop-types': 15.7.12 + csstype: 3.1.3 + + '@types/unist@2.0.10': {} + + '@types/unist@3.0.2': {} + + '@ungap/structured-clone@1.2.0': {} + + acorn-jsx@5.3.2(acorn@8.11.3): + dependencies: + acorn: 8.11.3 + + acorn@8.11.3: {} + + ansi-sequence-parser@1.1.1: {} + + ansi-styles@3.2.1: + dependencies: + color-convert: 1.9.3 + + arch@2.2.0: {} + + arg@1.0.0: {} + + argparse@1.0.10: + dependencies: + sprintf-js: 1.0.3 + + argparse@2.0.1: {} + + astring@1.8.6: {} + + bail@2.0.2: {} + + busboy@1.6.0: + dependencies: + streamsearch: 1.1.0 + + caniuse-lite@1.0.30001629: {} + + ccount@2.0.1: {} + + chalk@2.3.0: + dependencies: + ansi-styles: 3.2.1 + escape-string-regexp: 1.0.5 + supports-color: 4.5.0 + + character-entities-html4@2.1.0: {} + + character-entities-legacy@3.0.0: {} + + character-entities@2.0.2: {} + + character-reference-invalid@2.0.1: {} + + client-only@0.0.1: {} + + clipboardy@1.2.2: + 
dependencies: + arch: 2.2.0 + execa: 0.8.0 + + clsx@2.1.1: {} + + color-convert@1.9.3: + dependencies: + color-name: 1.1.3 + + color-name@1.1.3: {} + + comma-separated-tokens@2.0.3: {} + + commander@7.2.0: {} + + commander@8.3.0: {} + + compute-scroll-into-view@3.1.0: {} + + cose-base@1.0.3: + dependencies: + layout-base: 1.0.2 + + cross-spawn@5.1.0: + dependencies: + lru-cache: 4.1.5 + shebang-command: 1.2.0 + which: 1.3.1 + + csstype@3.1.3: {} + + cytoscape-cose-bilkent@4.1.0(cytoscape@3.29.2): + dependencies: + cose-base: 1.0.3 + cytoscape: 3.29.2 + + cytoscape@3.29.2: {} + + d3-array@2.12.1: + dependencies: + internmap: 1.0.1 + + d3-array@3.2.4: + dependencies: + internmap: 2.0.3 + + d3-axis@3.0.0: {} + + d3-brush@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-transition: 3.0.1(d3-selection@3.0.0) + + d3-chord@3.0.1: + dependencies: + d3-path: 3.1.0 + + d3-color@3.1.0: {} + + d3-contour@4.0.2: + dependencies: + d3-array: 3.2.4 + + d3-delaunay@6.0.4: + dependencies: + delaunator: 5.0.1 + + d3-dispatch@3.0.1: {} + + d3-drag@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-selection: 3.0.0 + + d3-dsv@3.0.1: + dependencies: + commander: 7.2.0 + iconv-lite: 0.6.3 + rw: 1.3.3 + + d3-ease@3.0.1: {} + + d3-fetch@3.0.1: + dependencies: + d3-dsv: 3.0.1 + + d3-force@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-quadtree: 3.0.1 + d3-timer: 3.0.1 + + d3-format@3.1.0: {} + + d3-geo@3.1.1: + dependencies: + d3-array: 3.2.4 + + d3-hierarchy@3.1.2: {} + + d3-interpolate@3.0.1: + dependencies: + d3-color: 3.1.0 + + d3-path@1.0.9: {} + + d3-path@3.1.0: {} + + d3-polygon@3.0.1: {} + + d3-quadtree@3.0.1: {} + + d3-random@3.0.1: {} + + d3-sankey@0.12.3: + dependencies: + d3-array: 2.12.1 + d3-shape: 1.3.7 + + d3-scale-chromatic@3.1.0: + dependencies: + d3-color: 3.1.0 + d3-interpolate: 3.0.1 + + d3-scale@4.0.2: + dependencies: + d3-array: 3.2.4 + d3-format: 3.1.0 + d3-interpolate: 3.0.1 + d3-time: 3.1.0 + 
d3-time-format: 4.1.0 + + d3-selection@3.0.0: {} + + d3-shape@1.3.7: + dependencies: + d3-path: 1.0.9 + + d3-shape@3.2.0: + dependencies: + d3-path: 3.1.0 + + d3-time-format@4.1.0: + dependencies: + d3-time: 3.1.0 + + d3-time@3.1.0: + dependencies: + d3-array: 3.2.4 + + d3-timer@3.0.1: {} + + d3-transition@3.0.1(d3-selection@3.0.0): + dependencies: + d3-color: 3.1.0 + d3-dispatch: 3.0.1 + d3-ease: 3.0.1 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-timer: 3.0.1 + + d3-zoom@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-transition: 3.0.1(d3-selection@3.0.0) + + d3@7.9.0: + dependencies: + d3-array: 3.2.4 + d3-axis: 3.0.0 + d3-brush: 3.0.0 + d3-chord: 3.0.1 + d3-color: 3.1.0 + d3-contour: 4.0.2 + d3-delaunay: 6.0.4 + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-dsv: 3.0.1 + d3-ease: 3.0.1 + d3-fetch: 3.0.1 + d3-force: 3.0.0 + d3-format: 3.1.0 + d3-geo: 3.1.1 + d3-hierarchy: 3.1.2 + d3-interpolate: 3.0.1 + d3-path: 3.1.0 + d3-polygon: 3.0.1 + d3-quadtree: 3.0.1 + d3-random: 3.0.1 + d3-scale: 4.0.2 + d3-scale-chromatic: 3.1.0 + d3-selection: 3.0.0 + d3-shape: 3.2.0 + d3-time: 3.1.0 + d3-time-format: 4.1.0 + d3-timer: 3.0.1 + d3-transition: 3.0.1(d3-selection@3.0.0) + d3-zoom: 3.0.0 + + dagre-d3-es@7.0.10: + dependencies: + d3: 7.9.0 + lodash-es: 4.17.21 + + dayjs@1.11.11: {} + + debug@4.3.5: + dependencies: + ms: 2.1.2 + + decode-named-character-reference@1.0.2: + dependencies: + character-entities: 2.0.2 + + delaunator@5.0.1: + dependencies: + robust-predicates: 3.0.2 + + dequal@2.0.3: {} + + devlop@1.1.0: + dependencies: + dequal: 2.0.3 + + diff@5.2.0: {} + + dompurify@3.1.5: {} + + elkjs@0.9.3: {} + + entities@4.5.0: {} + + escape-string-regexp@1.0.5: {} + + escape-string-regexp@5.0.0: {} + + esprima@4.0.1: {} + + estree-util-attach-comments@2.1.1: + dependencies: + '@types/estree': 1.0.5 + + estree-util-build-jsx@2.2.2: + dependencies: + '@types/estree-jsx': 1.0.5 + estree-util-is-identifier-name: 
2.1.0 + estree-walker: 3.0.3 + + estree-util-is-identifier-name@2.1.0: {} + + estree-util-to-js@1.2.0: + dependencies: + '@types/estree-jsx': 1.0.5 + astring: 1.8.6 + source-map: 0.7.4 + + estree-util-value-to-estree@1.3.0: + dependencies: + is-plain-obj: 3.0.0 + + estree-util-visit@1.2.1: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/unist': 2.0.10 + + estree-walker@3.0.3: + dependencies: + '@types/estree': 1.0.5 + + execa@0.8.0: + dependencies: + cross-spawn: 5.1.0 + get-stream: 3.0.0 + is-stream: 1.1.0 + npm-run-path: 2.0.2 + p-finally: 1.0.0 + signal-exit: 3.0.7 + strip-eof: 1.0.0 + + extend-shallow@2.0.1: + dependencies: + is-extendable: 0.1.1 + + extend@3.0.2: {} + + flexsearch@0.7.43: {} + + focus-visible@5.2.0: {} + + get-stream@3.0.0: {} + + git-up@7.0.0: + dependencies: + is-ssh: 1.4.0 + parse-url: 8.1.0 + + git-url-parse@13.1.1: + dependencies: + git-up: 7.0.0 + + github-slugger@2.0.0: {} + + graceful-fs@4.2.11: {} + + gray-matter@4.0.3: + dependencies: + js-yaml: 3.14.1 + kind-of: 6.0.3 + section-matter: 1.0.0 + strip-bom-string: 1.0.0 + + has-flag@2.0.0: {} + + hash-obj@4.0.0: + dependencies: + is-obj: 3.0.0 + sort-keys: 5.0.0 + type-fest: 1.4.0 + + hast-util-from-dom@5.0.0: + dependencies: + '@types/hast': 3.0.4 + hastscript: 8.0.0 + web-namespaces: 2.0.1 + + hast-util-from-html-isomorphic@2.0.0: + dependencies: + '@types/hast': 3.0.4 + hast-util-from-dom: 5.0.0 + hast-util-from-html: 2.0.1 + unist-util-remove-position: 5.0.0 + + hast-util-from-html@2.0.1: + dependencies: + '@types/hast': 3.0.4 + devlop: 1.1.0 + hast-util-from-parse5: 8.0.1 + parse5: 7.1.2 + vfile: 6.0.1 + vfile-message: 4.0.2 + + hast-util-from-parse5@8.0.1: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + devlop: 1.1.0 + hastscript: 8.0.0 + property-information: 6.5.0 + vfile: 6.0.1 + vfile-location: 5.0.2 + web-namespaces: 2.0.1 + + hast-util-is-element@3.0.0: + dependencies: + '@types/hast': 3.0.4 + + hast-util-parse-selector@4.0.0: + dependencies: + 
'@types/hast': 3.0.4 + + hast-util-raw@9.0.3: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + '@ungap/structured-clone': 1.2.0 + hast-util-from-parse5: 8.0.1 + hast-util-to-parse5: 8.0.0 + html-void-elements: 3.0.0 + mdast-util-to-hast: 13.1.0 + parse5: 7.1.2 + unist-util-position: 5.0.0 + unist-util-visit: 5.0.0 + vfile: 6.0.1 + web-namespaces: 2.0.1 + zwitch: 2.0.4 + + hast-util-to-estree@2.3.3: + dependencies: + '@types/estree': 1.0.5 + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/unist': 2.0.10 + comma-separated-tokens: 2.0.3 + estree-util-attach-comments: 2.1.1 + estree-util-is-identifier-name: 2.1.0 + hast-util-whitespace: 2.0.1 + mdast-util-mdx-expression: 1.3.2 + mdast-util-mdxjs-esm: 1.3.1 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + style-to-object: 0.4.4 + unist-util-position: 4.0.4 + zwitch: 2.0.4 + transitivePeerDependencies: + - supports-color + + hast-util-to-parse5@8.0.0: + dependencies: + '@types/hast': 3.0.4 + comma-separated-tokens: 2.0.3 + devlop: 1.1.0 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + web-namespaces: 2.0.1 + zwitch: 2.0.4 + + hast-util-to-text@4.0.2: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + hast-util-is-element: 3.0.0 + unist-util-find-after: 5.0.0 + + hast-util-whitespace@2.0.1: {} + + hastscript@8.0.0: + dependencies: + '@types/hast': 3.0.4 + comma-separated-tokens: 2.0.3 + hast-util-parse-selector: 4.0.0 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + + html-void-elements@3.0.0: {} + + iconv-lite@0.6.3: + dependencies: + safer-buffer: 2.1.2 + + inline-style-parser@0.1.1: {} + + internmap@1.0.1: {} + + internmap@2.0.3: {} + + intersection-observer@0.12.2: {} + + is-alphabetical@2.0.1: {} + + is-alphanumerical@2.0.1: + dependencies: + is-alphabetical: 2.0.1 + is-decimal: 2.0.1 + + is-buffer@2.0.5: {} + + is-decimal@2.0.1: {} + + is-extendable@0.1.1: {} + + is-hexadecimal@2.0.1: {} + + is-obj@3.0.0: {} + + 
is-plain-obj@3.0.0: {} + + is-plain-obj@4.1.0: {} + + is-reference@3.0.2: + dependencies: + '@types/estree': 1.0.5 + + is-ssh@1.4.0: + dependencies: + protocols: 2.0.1 + + is-stream@1.1.0: {} + + isexe@2.0.0: {} + + js-tokens@4.0.0: {} + + js-yaml@3.14.1: + dependencies: + argparse: 1.0.10 + esprima: 4.0.1 + + js-yaml@4.1.0: + dependencies: + argparse: 2.0.1 + + jsonc-parser@3.2.1: {} + + katex@0.16.10: + dependencies: + commander: 8.3.0 + + khroma@2.1.0: {} + + kind-of@6.0.3: {} + + kleur@4.1.5: {} + + layout-base@1.0.2: {} + + lodash-es@4.17.21: {} + + lodash.get@4.4.2: {} + + longest-streak@3.1.0: {} + + loose-envify@1.4.0: + dependencies: + js-tokens: 4.0.0 + + lru-cache@4.1.5: + dependencies: + pseudomap: 1.0.2 + yallist: 2.1.2 + + markdown-extensions@1.1.1: {} + + markdown-table@3.0.3: {} + + match-sorter@6.3.4: + dependencies: + '@babel/runtime': 7.24.7 + remove-accents: 0.5.0 + + mdast-util-definitions@5.1.2: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + unist-util-visit: 4.1.2 + + mdast-util-find-and-replace@2.2.2: + dependencies: + '@types/mdast': 3.0.15 + escape-string-regexp: 5.0.0 + unist-util-is: 5.2.1 + unist-util-visit-parents: 5.1.3 + + mdast-util-from-markdown@1.3.1: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + decode-named-character-reference: 1.0.2 + mdast-util-to-string: 3.2.0 + micromark: 3.2.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-decode-string: 1.1.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-stringify-position: 3.0.3 + uvu: 0.5.6 + transitivePeerDependencies: + - supports-color + + mdast-util-gfm-autolink-literal@1.0.3: + dependencies: + '@types/mdast': 3.0.15 + ccount: 2.0.1 + mdast-util-find-and-replace: 2.2.2 + micromark-util-character: 1.2.0 + + mdast-util-gfm-footnote@1.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + 
micromark-util-normalize-identifier: 1.1.0 + + mdast-util-gfm-strikethrough@1.0.3: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + + mdast-util-gfm-table@1.0.7: + dependencies: + '@types/mdast': 3.0.15 + markdown-table: 3.0.3 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-gfm-task-list-item@1.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + + mdast-util-gfm@2.0.2: + dependencies: + mdast-util-from-markdown: 1.3.1 + mdast-util-gfm-autolink-literal: 1.0.3 + mdast-util-gfm-footnote: 1.0.2 + mdast-util-gfm-strikethrough: 1.0.3 + mdast-util-gfm-table: 1.0.7 + mdast-util-gfm-task-list-item: 1.0.2 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-math@2.0.2: + dependencies: + '@types/mdast': 3.0.15 + longest-streak: 3.1.0 + mdast-util-to-markdown: 1.5.0 + + mdast-util-mdx-expression@1.3.2: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-mdx-jsx@2.1.4: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + ccount: 2.0.1 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + parse-entities: 4.0.1 + stringify-entities: 4.0.4 + unist-util-remove-position: 4.0.2 + unist-util-stringify-position: 3.0.3 + vfile-message: 3.1.4 + transitivePeerDependencies: + - supports-color + + mdast-util-mdx@2.0.1: + dependencies: + mdast-util-from-markdown: 1.3.1 + mdast-util-mdx-expression: 1.3.2 + mdast-util-mdx-jsx: 2.1.4 + mdast-util-mdxjs-esm: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-mdxjs-esm@1.3.1: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + 
mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-phrasing@3.0.1: + dependencies: + '@types/mdast': 3.0.15 + unist-util-is: 5.2.1 + + mdast-util-to-hast@12.3.0: + dependencies: + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-definitions: 5.1.2 + micromark-util-sanitize-uri: 1.2.0 + trim-lines: 3.0.1 + unist-util-generated: 2.0.1 + unist-util-position: 4.0.4 + unist-util-visit: 4.1.2 + + mdast-util-to-hast@13.1.0: + dependencies: + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + '@ungap/structured-clone': 1.2.0 + devlop: 1.1.0 + micromark-util-sanitize-uri: 2.0.0 + trim-lines: 3.0.1 + unist-util-position: 5.0.0 + unist-util-visit: 5.0.0 + vfile: 6.0.1 + + mdast-util-to-markdown@1.5.0: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + longest-streak: 3.1.0 + mdast-util-phrasing: 3.0.1 + mdast-util-to-string: 3.2.0 + micromark-util-decode-string: 1.1.0 + unist-util-visit: 4.1.2 + zwitch: 2.0.4 + + mdast-util-to-string@3.2.0: + dependencies: + '@types/mdast': 3.0.15 + + mermaid@10.9.1: + dependencies: + '@braintree/sanitize-url': 6.0.4 + '@types/d3-scale': 4.0.8 + '@types/d3-scale-chromatic': 3.0.3 + cytoscape: 3.29.2 + cytoscape-cose-bilkent: 4.1.0(cytoscape@3.29.2) + d3: 7.9.0 + d3-sankey: 0.12.3 + dagre-d3-es: 7.0.10 + dayjs: 1.11.11 + dompurify: 3.1.5 + elkjs: 0.9.3 + katex: 0.16.10 + khroma: 2.1.0 + lodash-es: 4.17.21 + mdast-util-from-markdown: 1.3.1 + non-layered-tidy-tree-layout: 2.0.2 + stylis: 4.3.2 + ts-dedent: 2.2.0 + uuid: 9.0.1 + web-worker: 1.3.0 + transitivePeerDependencies: + - supports-color + + micromark-core-commonmark@1.1.0: + dependencies: + decode-named-character-reference: 1.0.2 + micromark-factory-destination: 1.1.0 + micromark-factory-label: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-factory-title: 1.1.0 + micromark-factory-whitespace: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-chunked: 1.1.0 + 
micromark-util-classify-character: 1.1.0 + micromark-util-html-tag-name: 1.2.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-subtokenize: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-autolink-literal@1.0.5: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-extension-gfm-footnote@1.1.2: + dependencies: + micromark-core-commonmark: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-strikethrough@1.0.7: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-classify-character: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-table@1.0.7: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-tagfilter@1.0.2: + dependencies: + micromark-util-types: 1.1.0 + + micromark-extension-gfm-task-list-item@1.0.5: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm@2.0.3: + dependencies: + micromark-extension-gfm-autolink-literal: 1.0.5 + micromark-extension-gfm-footnote: 1.1.2 + micromark-extension-gfm-strikethrough: 1.0.7 + micromark-extension-gfm-table: 1.0.7 + micromark-extension-gfm-tagfilter: 1.0.2 + micromark-extension-gfm-task-list-item: 1.0.5 + micromark-util-combine-extensions: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-extension-math@2.1.2: + dependencies: + '@types/katex': 
0.16.7 + katex: 0.16.10 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-mdx-expression@1.0.8: + dependencies: + '@types/estree': 1.0.5 + micromark-factory-mdx-expression: 1.0.9 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-mdx-jsx@1.0.5: + dependencies: + '@types/acorn': 4.0.6 + '@types/estree': 1.0.5 + estree-util-is-identifier-name: 2.1.0 + micromark-factory-mdx-expression: 1.0.9 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-extension-mdx-md@1.0.1: + dependencies: + micromark-util-types: 1.1.0 + + micromark-extension-mdxjs-esm@1.0.5: + dependencies: + '@types/estree': 1.0.5 + micromark-core-commonmark: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-position-from-estree: 1.1.2 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-extension-mdxjs@1.0.1: + dependencies: + acorn: 8.11.3 + acorn-jsx: 5.3.2(acorn@8.11.3) + micromark-extension-mdx-expression: 1.0.8 + micromark-extension-mdx-jsx: 1.0.5 + micromark-extension-mdx-md: 1.0.1 + micromark-extension-mdxjs-esm: 1.0.5 + micromark-util-combine-extensions: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-destination@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-label@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-factory-mdx-expression@1.0.9: + dependencies: + '@types/estree': 1.0.5 + micromark-util-character: 1.2.0 
+ micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-position-from-estree: 1.1.2 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-factory-space@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-types: 1.1.0 + + micromark-factory-title@1.1.0: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-whitespace@1.1.0: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-character@1.2.0: + dependencies: + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-character@2.1.0: + dependencies: + micromark-util-symbol: 2.0.0 + micromark-util-types: 2.0.0 + + micromark-util-chunked@1.1.0: + dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-classify-character@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-combine-extensions@1.1.0: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-decode-numeric-character-reference@1.1.0: + dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-decode-string@1.1.0: + dependencies: + decode-named-character-reference: 1.0.2 + micromark-util-character: 1.2.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-symbol: 1.1.0 + + micromark-util-encode@1.1.0: {} + + micromark-util-encode@2.0.0: {} + + micromark-util-events-to-acorn@1.2.3: + dependencies: + '@types/acorn': 4.0.6 + '@types/estree': 1.0.5 + '@types/unist': 2.0.10 + estree-util-visit: 1.2.1 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-util-html-tag-name@1.2.0: {} + + micromark-util-normalize-identifier@1.1.0: + 
dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-resolve-all@1.1.0: + dependencies: + micromark-util-types: 1.1.0 + + micromark-util-sanitize-uri@1.2.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-encode: 1.1.0 + micromark-util-symbol: 1.1.0 + + micromark-util-sanitize-uri@2.0.0: + dependencies: + micromark-util-character: 2.1.0 + micromark-util-encode: 2.0.0 + micromark-util-symbol: 2.0.0 + + micromark-util-subtokenize@1.1.0: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-util-symbol@1.1.0: {} + + micromark-util-symbol@2.0.0: {} + + micromark-util-types@1.1.0: {} + + micromark-util-types@2.0.0: {} + + micromark@3.2.0: + dependencies: + '@types/debug': 4.1.12 + debug: 4.3.5 + decode-named-character-reference: 1.0.2 + micromark-core-commonmark: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-chunked: 1.1.0 + micromark-util-combine-extensions: 1.1.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-encode: 1.1.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-subtokenize: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + transitivePeerDependencies: + - supports-color + + mri@1.2.0: {} + + ms@2.1.2: {} + + nanoid@3.3.7: {} + + next-mdx-remote@4.4.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@mdx-js/mdx': 2.3.0 + '@mdx-js/react': 2.3.0(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + vfile: 5.3.7 + vfile-matter: 3.0.1 + transitivePeerDependencies: + - supports-color + + next-seo@6.5.0(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + 
next-themes@0.2.1(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@next/env': 14.2.3 + '@swc/helpers': 0.5.5 + busboy: 1.6.0 + caniuse-lite: 1.0.30001629 + graceful-fs: 4.2.11 + postcss: 8.4.31 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + styled-jsx: 5.1.1(react@18.3.1) + optionalDependencies: + '@next/swc-darwin-arm64': 14.2.3 + '@next/swc-darwin-x64': 14.2.3 + '@next/swc-linux-arm64-gnu': 14.2.3 + '@next/swc-linux-arm64-musl': 14.2.3 + '@next/swc-linux-x64-gnu': 14.2.3 + '@next/swc-linux-x64-musl': 14.2.3 + '@next/swc-win32-arm64-msvc': 14.2.3 + '@next/swc-win32-ia32-msvc': 14.2.3 + '@next/swc-win32-x64-msvc': 14.2.3 + transitivePeerDependencies: + - '@babel/core' + - babel-plugin-macros + + nextra-theme-docs@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@headlessui/react': 1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@popperjs/core': 2.11.8 + clsx: 2.1.1 + escape-string-regexp: 5.0.0 + flexsearch: 0.7.43 + focus-visible: 5.2.0 + git-url-parse: 13.1.1 + intersection-observer: 0.12.2 + match-sorter: 6.3.4 + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-seo: 6.5.0(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-themes: 0.2.1(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + 
scroll-into-view-if-needed: 3.1.0 + zod: 3.23.8 + + nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@headlessui/react': 1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@mdx-js/mdx': 2.3.0 + '@mdx-js/react': 2.3.0(react@18.3.1) + '@napi-rs/simple-git': 0.1.16 + '@theguild/remark-mermaid': 0.0.5(react@18.3.1) + '@theguild/remark-npm2yarn': 0.2.1 + clsx: 2.1.1 + github-slugger: 2.0.0 + graceful-fs: 4.2.11 + gray-matter: 4.0.3 + katex: 0.16.10 + lodash.get: 4.4.2 + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-mdx-remote: 4.4.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + p-limit: 3.1.0 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + rehype-katex: 7.0.0 + rehype-pretty-code: 0.9.11(shiki@0.14.7) + rehype-raw: 7.0.0 + remark-gfm: 3.0.1 + remark-math: 5.1.1 + remark-reading-time: 2.0.1 + shiki: 0.14.7 + slash: 3.0.0 + title: 3.5.3 + unist-util-remove: 4.0.0 + unist-util-visit: 5.0.0 + zod: 3.23.8 + transitivePeerDependencies: + - supports-color + + non-layered-tidy-tree-layout@2.0.2: {} + + npm-run-path@2.0.2: + dependencies: + path-key: 2.0.1 + + npm-to-yarn@2.2.1: {} + + p-finally@1.0.0: {} + + p-limit@3.1.0: + dependencies: + yocto-queue: 0.1.0 + + parse-entities@4.0.1: + dependencies: + '@types/unist': 2.0.10 + character-entities: 2.0.2 + character-entities-legacy: 3.0.0 + character-reference-invalid: 2.0.1 + decode-named-character-reference: 1.0.2 + is-alphanumerical: 2.0.1 + is-decimal: 2.0.1 + is-hexadecimal: 2.0.1 + + parse-numeric-range@1.3.0: {} + + parse-path@7.0.0: + dependencies: + protocols: 2.0.1 + + parse-url@8.1.0: + dependencies: + parse-path: 7.0.0 + + parse5@7.1.2: + dependencies: + entities: 4.5.0 + + path-key@2.0.1: {} + + periscopic@3.1.0: + dependencies: + '@types/estree': 1.0.5 + estree-walker: 3.0.3 + is-reference: 3.0.2 + + picocolors@1.0.1: {} + + postcss@8.4.31: + dependencies: + nanoid: 3.3.7 + picocolors: 1.0.1 + 
source-map-js: 1.2.0 + + property-information@6.5.0: {} + + protocols@2.0.1: {} + + pseudomap@1.0.2: {} + + react-dom@18.3.1(react@18.3.1): + dependencies: + loose-envify: 1.4.0 + react: 18.3.1 + scheduler: 0.23.2 + + react@18.3.1: + dependencies: + loose-envify: 1.4.0 + + reading-time@1.5.0: {} + + regenerator-runtime@0.14.1: {} + + rehype-katex@7.0.0: + dependencies: + '@types/hast': 3.0.4 + '@types/katex': 0.16.7 + hast-util-from-html-isomorphic: 2.0.0 + hast-util-to-text: 4.0.2 + katex: 0.16.10 + unist-util-visit-parents: 6.0.1 + vfile: 6.0.1 + + rehype-pretty-code@0.9.11(shiki@0.14.7): + dependencies: + '@types/hast': 2.3.10 + hash-obj: 4.0.0 + parse-numeric-range: 1.3.0 + shiki: 0.14.7 + + rehype-raw@7.0.0: + dependencies: + '@types/hast': 3.0.4 + hast-util-raw: 9.0.3 + vfile: 6.0.1 + + remark-gfm@3.0.1: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-gfm: 2.0.2 + micromark-extension-gfm: 2.0.3 + unified: 10.1.2 + transitivePeerDependencies: + - supports-color + + remark-math@5.1.1: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-math: 2.0.2 + micromark-extension-math: 2.1.2 + unified: 10.1.2 + + remark-mdx@2.3.0: + dependencies: + mdast-util-mdx: 2.0.1 + micromark-extension-mdxjs: 1.0.1 + transitivePeerDependencies: + - supports-color + + remark-parse@10.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-from-markdown: 1.3.1 + unified: 10.1.2 + transitivePeerDependencies: + - supports-color + + remark-reading-time@2.0.1: + dependencies: + estree-util-is-identifier-name: 2.1.0 + estree-util-value-to-estree: 1.3.0 + reading-time: 1.5.0 + unist-util-visit: 3.1.0 + + remark-rehype@10.1.0: + dependencies: + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-to-hast: 12.3.0 + unified: 10.1.2 + + remove-accents@0.5.0: {} + + robust-predicates@3.0.2: {} + + rw@1.3.3: {} + + sade@1.8.1: + dependencies: + mri: 1.2.0 + + safer-buffer@2.1.2: {} + + scheduler@0.23.2: + dependencies: + loose-envify: 1.4.0 + + scroll-into-view-if-needed@3.1.0: 
+ dependencies: + compute-scroll-into-view: 3.1.0 + + section-matter@1.0.0: + dependencies: + extend-shallow: 2.0.1 + kind-of: 6.0.3 + + shebang-command@1.2.0: + dependencies: + shebang-regex: 1.0.0 + + shebang-regex@1.0.0: {} + + shiki@0.14.7: + dependencies: + ansi-sequence-parser: 1.1.1 + jsonc-parser: 3.2.1 + vscode-oniguruma: 1.7.0 + vscode-textmate: 8.0.0 + + signal-exit@3.0.7: {} + + slash@3.0.0: {} + + sort-keys@5.0.0: + dependencies: + is-plain-obj: 4.1.0 + + source-map-js@1.2.0: {} + + source-map@0.7.4: {} + + space-separated-tokens@2.0.2: {} + + sprintf-js@1.0.3: {} + + streamsearch@1.1.0: {} + + stringify-entities@4.0.4: + dependencies: + character-entities-html4: 2.1.0 + character-entities-legacy: 3.0.0 + + strip-bom-string@1.0.0: {} + + strip-eof@1.0.0: {} + + style-to-object@0.4.4: + dependencies: + inline-style-parser: 0.1.1 + + styled-jsx@5.1.1(react@18.3.1): + dependencies: + client-only: 0.0.1 + react: 18.3.1 + + stylis@4.3.2: {} + + supports-color@4.5.0: + dependencies: + has-flag: 2.0.0 + + title@3.5.3: + dependencies: + arg: 1.0.0 + chalk: 2.3.0 + clipboardy: 1.2.2 + titleize: 1.0.0 + + titleize@1.0.0: {} + + trim-lines@3.0.1: {} + + trough@2.2.0: {} + + ts-dedent@2.2.0: {} + + tslib@2.6.3: {} + + type-fest@1.4.0: {} + + typescript@5.4.5: {} + + undici-types@5.26.5: {} + + unified@10.1.2: + dependencies: + '@types/unist': 2.0.10 + bail: 2.0.2 + extend: 3.0.2 + is-buffer: 2.0.5 + is-plain-obj: 4.1.0 + trough: 2.2.0 + vfile: 5.3.7 + + unist-util-find-after@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + + unist-util-generated@2.0.1: {} + + unist-util-is@5.2.1: + dependencies: + '@types/unist': 2.0.10 + + unist-util-is@6.0.0: + dependencies: + '@types/unist': 3.0.2 + + unist-util-position-from-estree@1.1.2: + dependencies: + '@types/unist': 2.0.10 + + unist-util-position@4.0.4: + dependencies: + '@types/unist': 2.0.10 + + unist-util-position@5.0.0: + dependencies: + '@types/unist': 3.0.2 + + 
unist-util-remove-position@4.0.2: + dependencies: + '@types/unist': 2.0.10 + unist-util-visit: 4.1.2 + + unist-util-remove-position@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-visit: 5.0.0 + + unist-util-remove@4.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + unist-util-visit-parents: 6.0.1 + + unist-util-stringify-position@3.0.3: + dependencies: + '@types/unist': 2.0.10 + + unist-util-stringify-position@4.0.0: + dependencies: + '@types/unist': 3.0.2 + + unist-util-visit-parents@4.1.1: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + + unist-util-visit-parents@5.1.3: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + + unist-util-visit-parents@6.0.1: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + + unist-util-visit@3.1.0: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + unist-util-visit-parents: 4.1.1 + + unist-util-visit@4.1.2: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + unist-util-visit-parents: 5.1.3 + + unist-util-visit@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + unist-util-visit-parents: 6.0.1 + + uuid@9.0.1: {} + + uvu@0.5.6: + dependencies: + dequal: 2.0.3 + diff: 5.2.0 + kleur: 4.1.5 + sade: 1.8.1 + + vfile-location@5.0.2: + dependencies: + '@types/unist': 3.0.2 + vfile: 6.0.1 + + vfile-matter@3.0.1: + dependencies: + '@types/js-yaml': 4.0.9 + is-buffer: 2.0.5 + js-yaml: 4.1.0 + + vfile-message@3.1.4: + dependencies: + '@types/unist': 2.0.10 + unist-util-stringify-position: 3.0.3 + + vfile-message@4.0.2: + dependencies: + '@types/unist': 3.0.2 + unist-util-stringify-position: 4.0.0 + + vfile@5.3.7: + dependencies: + '@types/unist': 2.0.10 + is-buffer: 2.0.5 + unist-util-stringify-position: 3.0.3 + vfile-message: 3.1.4 + + vfile@6.0.1: + dependencies: + '@types/unist': 3.0.2 + unist-util-stringify-position: 4.0.0 + vfile-message: 4.0.2 + + vscode-oniguruma@1.7.0: {} + + vscode-textmate@8.0.0: {} + + 
web-namespaces@2.0.1: {} + + web-worker@1.3.0: {} + + which@1.3.1: + dependencies: + isexe: 2.0.0 + + yallist@2.1.2: {} + + yocto-queue@0.1.0: {} + + zod@3.23.8: {} + + zwitch@2.0.4: {} diff --git a/docs/assets/banner.png b/docs/public/assets/banner.png similarity index 100% rename from docs/assets/banner.png rename to docs/public/assets/banner.png diff --git a/docs/assets/icon.png b/docs/public/assets/icon.png similarity index 100% rename from docs/assets/icon.png rename to docs/public/assets/icon.png diff --git a/docs/assets/sample-onnx-graph.png b/docs/public/assets/sample-onnx-graph.png similarity index 100% rename from docs/assets/sample-onnx-graph.png rename to docs/public/assets/sample-onnx-graph.png diff --git a/docs/assets/trend-banner.png b/docs/public/assets/trend-banner.png similarity index 100% rename from docs/assets/trend-banner.png rename to docs/public/assets/trend-banner.png diff --git a/docs/setup/linking.mdx b/docs/setup/linking.mdx deleted file mode 100644 index ecba449f..00000000 --- a/docs/setup/linking.mdx +++ /dev/null @@ -1,106 +0,0 @@ ---- -title: Linking -description: Here's how `ort` links to ONNX Runtime, and how to configure its behavior. ---- - -In some cases, you'll want to use a custom build of ONNX Runtime with `ort`. Luckily, we make this very easy by handling all of the linking configuration automagically. Just point `ort` to the output of ONNX Runtime's build pipeline and it'll Just Work™. - -## Static linking -Most ONNX Runtime compile configurations will support static linking - just run `build.sh` without the `--build_shared_lib` argument. You should prefer static linking if your execution providers support it, as it avoids many issues and follows de facto Rust practices. If you compile both static libraries and dynamic libraries, `ort` will prefer linking to the static libraries. - -To direct `ort` to your statically built binaries, use the `ORT_LIB_LOCATION` environment variable when running `cargo build`. 
Point it to the location where the static libraries (`.a`/`.lib` files) are compiled to. This will typically be `onnxruntime/build/`. For example: -```shell -$ ORT_LIB_LOCATION=~/onnxruntime/build/Linux cargo build -``` - -For iOS (or for other platforms if you are compiling multiple profiles at once), you'll need to manually specify the profile with the `ORT_LIB_PROFILE` environment variable. If not specified, `ort` will prefer `Release` over `RelWithDebInfo` over `MinSizeRel` over `Debug`. - -## Dynamic linking -Some execution providers unfortunately only support dynamic linking. Dynamic linking doesn't play well with the Rust ecosystem, though `ort` tries to alleviate the pain as much as possible. - -When it comes to dynamic linking, there are two options: `load-dynamic`, or standard compile-time dynamic linking. We recommend `load-dynamic` as it gives more control and is often far less troublesome to work with. - -### Runtime loading with `load-dynamic` -The `load-dynamic` Cargo feature solves a few of the issues with dynamic linking by **loading the library at runtime** rather than **linking at compile time**. This means that the path to the ONNX Runtime library can be configured at runtime, and the executable will not just completely fail to start if the binary couldn't be found. - -To use `load-dynamic`: - - - ```toml Cargo.toml - [dependencies] - ort = { version = "2", features = [ "load-dynamic" ] } - ``` - - - - - ```rust main.rs - fn main() -> anyhow::Result<()> { - // Find our custom ONNX Runtime dylib path somehow - // (i.e. resolving it from the root of our program's install folder) - let dylib_path = crate::internal::find_onnxruntime_dylib()?; - // The path should point to the `libonnxruntime` binary, which looks like: - // - on Unix: /etc/.../libonnxruntime.so - // - on Windows: C:\Program Files\...\onnxruntime.dll - - // Initialize ort with the path to the dylib. This **must** be called before any usage of `ort`! 
- // `init_from` returns an `EnvironmentBuilder` which you can use to further configure the environment - // before `.commit()`ing; see the Environment docs for more information on what you can configure. - ort::init_from(dylib_path).commit()?; - - Ok(()) - } - ``` - - - Set the `ORT_DYLIB_PATH` environment variable to the path to `libonnxruntime.so`/`onnxruntime.dll`. - - ```shell - $ ORT_DYLIB_PATH=../onnxruntime-build/linux-x64/libonnxruntime.so ./mirai - ``` - - - - - -`ORT_DYLIB_PATH` is relative to the executable. Cargo examples and tests are compiled to a different directory than binary crates: `target//examples` and `target//deps` respectively. Keep this in mind if you're going to use relative paths. - -### Compile-time dynamic linking -For compile-time dynamic linking, you'll need to configure your environment in the exact same way as if you were [statically linking](#static-linking). - -Note that the dylibs then have to be placed in a certain location for them to be found by the executable. For Windows, this is either somewhere on the `PATH`, or in the same folder as the executable. On macOS and Linux, they have to be placed somewhere in the `LD_LIBRARY_PATH`, or you can use rpath to configure the executable to search for dylibs in its parent folder. We've had the least issues with rpath, but YMMV. 
- -To configure rpath, you'll need to: - - - ```toml - [profile.dev] - rpath = true - - [profile.release] - rpath = true - - # do this for any other profiles - ``` - - - - - ```toml - [target.x86_64-unknown-linux-gnu] - rustflags = [ "-Clink-args=-Wl,-rpath,\\$ORIGIN" ] - - # do this for any other Linux targets as well - ``` - - - ```toml - [target.x86_64-apple-darwin] - rustflags = [ "-Clink-args=-Wl,-rpath,@loader_path" ] - - # do this for any other macOS targets as well - ``` - - - - diff --git a/docs/theme.config.jsx b/docs/theme.config.jsx new file mode 100644 index 00000000..ef71c4eb --- /dev/null +++ b/docs/theme.config.jsx @@ -0,0 +1,33 @@ +import Image from 'next/image'; + +/** @type {import('nextra-theme-docs').DocsThemeConfig} */ +const config = { + project: { + link: 'https://github.com/pykeio/ort' + }, + chat: { + link: 'https://discord.gg/uQtsNu2xMa' + }, + docsRepositoryBase: 'https://github.com/pykeio/ort/blob/main/docs', + useNextSeoProps() { + return { + titleTemplate: '%s | ort' + } + }, + logo: , + darkMode: true, + nextThemes: { + defaultTheme: 'system' + }, + footer: { + text:
+

made with 💜 by pykesponsor

+
+ }, + primaryHue: 20, + primarySaturation: 100, + toc: { + float: true + } +}; +export default config; diff --git a/docs/tsconfig.json b/docs/tsconfig.json new file mode 100644 index 00000000..19deeffc --- /dev/null +++ b/docs/tsconfig.json @@ -0,0 +1,28 @@ +{ + "compilerOptions": { + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], + "allowJs": true, + "skipLibCheck": true, + "strict": false, + "noEmit": true, + "incremental": true, + "module": "esnext", + "esModuleInterop": true, + "moduleResolution": "node", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "preserve" + }, + "include": [ + "next-env.d.ts", + "**/*.ts", + "**/*.tsx" +, "pages/_app.mdx" ], + "exclude": [ + "node_modules" + ] +} From 812fdb056c8ed2e81cf20b8b60739b2b16ca1f2b Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 8 Jun 2024 13:20:03 -0500 Subject: [PATCH 22/49] config: add python venv to gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index bf1af90f..b00a624d 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,5 @@ WixTools/ # Glassbench results /glassbench*.db + +.venv* From 23fce788b6ddb7edf2675cfcda6cf04c352fee60 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Fri, 14 Jun 2024 18:07:13 -0500 Subject: [PATCH 23/49] feat: option to error out session builder if EP registration fails --- docs/pages/perf/execution-providers.mdx | 17 +++- examples/cudarc/src/main.rs | 2 +- src/execution_providers/acl.rs | 2 +- src/execution_providers/armnn.rs | 2 +- src/execution_providers/cann.rs | 2 +- src/execution_providers/coreml.rs | 2 +- src/execution_providers/cpu.rs | 2 +- src/execution_providers/cuda.rs | 2 +- src/execution_providers/directml.rs | 2 +- src/execution_providers/mod.rs | 108 ++++++++++++------------ src/execution_providers/nnapi.rs | 2 +- src/execution_providers/onednn.rs | 2 +- src/execution_providers/openvino.rs | 2 +- src/execution_providers/qnn.rs | 2 +- src/execution_providers/rocm.rs | 2 +- src/execution_providers/tensorrt.rs | 2 +- src/execution_providers/tvm.rs | 2 +- src/execution_providers/xnnpack.rs | 2 +- src/session/builder.rs | 6 +- 19 files changed, 89 insertions(+), 74 deletions(-) diff --git a/docs/pages/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx index f2e7b9e0..923590b2 100644 --- a/docs/pages/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -107,7 +107,22 @@ fn main() -> anyhow::Result<()> { ## Fallback behavior `ort` will silently fail and fall back to executing on the CPU if all execution providers fail to register. In many cases, though, you'll want to show the user an error message when an EP fails to register, or outright abort the process. -To receive these registration errors, instead use `ExecutionProvider::register` to register an execution provider: +You can configure an EP to return an error on failure by adding `.error_on_failure()` after you `.build()` it. In this example, if CUDA doesn't register successfully, the program will exit with an error at `with_execution_providers`: +```rust +use ort::{CoreMLExecutionProvider, Session}; + +fn main() -> anyhow::Result<()> { + let session = Session::builder()? 
+ .with_execution_providers([ + CUDAExecutionProvider::default().build().error_on_failure() + ])? + .commit_from_file("model.onnx")?; + + Ok(()) +} +``` + +If you require more complex error handling, you can also manually register execution providers via the `ExecutionProvider::register` method: ```rust use ort::{CUDAExecutionProvider, ExecutionProvider, Session}; diff --git a/examples/cudarc/src/main.rs b/examples/cudarc/src/main.rs index 1ffc01f0..20013a9f 100644 --- a/examples/cudarc/src/main.rs +++ b/examples/cudarc/src/main.rs @@ -11,7 +11,7 @@ fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); ort::init() - .with_execution_providers([CUDAExecutionProvider::default().build()]) + .with_execution_providers([CUDAExecutionProvider::default().build().error_on_failure()]) .commit()?; let model = diff --git a/src/execution_providers/acl.rs b/src/execution_providers/acl.rs index a8e3bdb3..1f15ac70 100644 --- a/src/execution_providers/acl.rs +++ b/src/execution_providers/acl.rs @@ -26,7 +26,7 @@ impl ACLExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: ACLExecutionProvider) -> Self { - ExecutionProviderDispatch::ACL(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/armnn.rs b/src/execution_providers/armnn.rs index 53a38795..86332f01 100644 --- a/src/execution_providers/armnn.rs +++ b/src/execution_providers/armnn.rs @@ -26,7 +26,7 @@ impl ArmNNExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: ArmNNExecutionProvider) -> Self { - ExecutionProviderDispatch::ArmNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cann.rs b/src/execution_providers/cann.rs index c43e8e06..f37a2f1b 100644 --- a/src/execution_providers/cann.rs +++ b/src/execution_providers/cann.rs @@ -109,7 +109,7 @@ impl CANNExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: CANNExecutionProvider) -> Self { - 
ExecutionProviderDispatch::CANN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/coreml.rs b/src/execution_providers/coreml.rs index 94971e8a..256de1e5 100644 --- a/src/execution_providers/coreml.rs +++ b/src/execution_providers/coreml.rs @@ -46,7 +46,7 @@ impl CoreMLExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: CoreMLExecutionProvider) -> Self { - ExecutionProviderDispatch::CoreML(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cpu.rs b/src/execution_providers/cpu.rs index 2f98095d..eb4be919 100644 --- a/src/execution_providers/cpu.rs +++ b/src/execution_providers/cpu.rs @@ -21,7 +21,7 @@ impl CPUExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: CPUExecutionProvider) -> Self { - ExecutionProviderDispatch::CPU(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/cuda.rs b/src/execution_providers/cuda.rs index 77200e3c..17fbe825 100644 --- a/src/execution_providers/cuda.rs +++ b/src/execution_providers/cuda.rs @@ -161,7 +161,7 @@ impl CUDAExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: CUDAExecutionProvider) -> Self { - ExecutionProviderDispatch::CUDA(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/directml.rs b/src/execution_providers/directml.rs index 71802553..38556f11 100644 --- a/src/execution_providers/directml.rs +++ b/src/execution_providers/directml.rs @@ -26,7 +26,7 @@ impl DirectMLExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: DirectMLExecutionProvider) -> Self { - ExecutionProviderDispatch::DirectML(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/mod.rs b/src/execution_providers/mod.rs index 24ec6acf..8c237041 100644 --- a/src/execution_providers/mod.rs +++ b/src/execution_providers/mod.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, os::raw::c_char}; 
+use std::{fmt::Debug, os::raw::c_char, sync::Arc}; use crate::{char_p_to_string, ortsys, Error, Result, SessionBuilder}; @@ -60,16 +60,17 @@ pub trait ExecutionProvider { true } - /// Returns `Ok(true)` if ONNX Runtime was compiled with support for this execution provider, and `Ok(false)` + /// Returns `Ok(true)` if ONNX Runtime was *compiled with support* for this execution provider, and `Ok(false)` /// otherwise. /// /// An `Err` may be returned if a serious internal error occurs, in which case your application should probably /// just abort. /// - /// Note that this does not always mean the execution provider is *usable* for a specific model. A model may use - /// operators not supported by an execution provider, or the EP may encounter an error while attempting to load a - /// dynamic library during registration. In most cases (i.e. showing the user an error message if CUDA could not be - /// enabled), you'll instead want to detect and handle errors from [`ExecutionProvider::register`]. + /// **Note that this does not always mean the execution provider is *usable* for a specific session.** A model may + /// use operators not supported by an execution provider, or the EP may encounter an error while attempting to load + /// dependencies during session creation. In most cases (i.e. showing the user an error message if CUDA could not be + /// enabled), you'll instead want to manually register this EP via [`ExecutionProvider::register`] and detect + /// and handle any errors returned by that function. fn is_available(&self) -> Result { let mut providers: *mut *mut c_char = std::ptr::null_mut(); let mut num_providers = 0; @@ -110,56 +111,50 @@ pub enum ArenaExtendStrategy { SameAsRequested } -/// Execution provider container. See [the ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/) for more -/// info on execution providers. 
Execution providers are actually registered via the functions [`crate::SessionBuilder`] -/// (per-session) or [`EnvironmentBuilder`](crate::environment::EnvironmentBuilder) (default for all sessions in an -/// environment). -#[derive(Debug, Clone)] +/// Dynamic execution provider container, used to provide a list of multiple types of execution providers when +/// configuring execution providers for a [`SessionBuilder`](crate::SessionBuilder) or +/// [`EnvironmentBuilder`](crate::environment::EnvironmentBuilder). +/// +/// See [`ExecutionProvider`] for more info on execution providers. +#[derive(Clone)] #[allow(missing_docs)] #[non_exhaustive] -pub enum ExecutionProviderDispatch { - CPU(CPUExecutionProvider), - CUDA(CUDAExecutionProvider), - TensorRT(TensorRTExecutionProvider), - OpenVINO(OpenVINOExecutionProvider), - ACL(ACLExecutionProvider), - OneDNN(OneDNNExecutionProvider), - CoreML(CoreMLExecutionProvider), - DirectML(DirectMLExecutionProvider), - ROCm(ROCmExecutionProvider), - NNAPI(NNAPIExecutionProvider), - QNN(QNNExecutionProvider), - TVM(TVMExecutionProvider), - CANN(CANNExecutionProvider), - XNNPACK(XNNPACKExecutionProvider), - ArmNN(ArmNNExecutionProvider) +pub struct ExecutionProviderDispatch { + pub(crate) inner: Arc, + error_on_failure: bool } -macro_rules! impl_dispatch { - ($($variant:ident),*) => { - impl ExecutionProvider for ExecutionProviderDispatch { - fn as_str(&self) -> &'static str { - match self { - $(Self::$variant(inner) => inner.as_str(),)* - } - } +impl ExecutionProviderDispatch { + pub(crate) fn new(ep: E) -> Self { + ExecutionProviderDispatch { + inner: Arc::new(ep) as Arc, + error_on_failure: false + } + } - fn is_available(&self) -> $crate::Result { - match self { - $(Self::$variant(inner) => inner.is_available(),)* - } - } + /// Configures this execution provider to silently log an error if registration of the EP fails. + /// This is the default behavior; it can be overridden with [`ExecutionProviderDispatch::error_on_failure`]. 
+ pub fn fail_silently(mut self) -> Self { + self.error_on_failure = false; + self + } - fn register(&self, session_builder: &$crate::SessionBuilder) -> $crate::Result<()> { - match self { - $(Self::$variant(inner) => inner.register(session_builder),)* - } - } - } - }; + /// Configures this execution provider to return an error upon EP registration if registration of this EP fails. + /// The default behavior is to silently fail and fall back to the next execution provider, or the CPU provider if no + /// registrations succeed. + pub fn error_on_failure(mut self) -> Self { + self.error_on_failure = true; + self + } } -impl_dispatch!(CPU, CUDA, TensorRT, ACL, OneDNN, OpenVINO, CoreML, CANN, ROCm, DirectML, TVM, NNAPI, QNN, XNNPACK, ArmNN); +impl Debug for ExecutionProviderDispatch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(self.inner.as_str()) + .field("error_on_failure", &self.error_on_failure) + .finish() + } +} #[allow(unused)] macro_rules! map_keys { @@ -207,26 +202,31 @@ macro_rules! 
get_ep_register { pub(crate) use get_ep_register; #[tracing::instrument(skip_all)] -pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator) { +pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator) -> Result<()> { let execution_providers: Vec<_> = execution_providers.collect(); let mut fallback_to_cpu = !execution_providers.is_empty(); for ex in execution_providers { - if let Err(e) = ex.register(session_builder) { + if let Err(e) = ex.inner.register(session_builder) { + if ex.error_on_failure { + return Err(e); + } + if let &Error::ExecutionProviderNotRegistered(ep_name) = &e { - if ex.supported_by_platform() { + if ex.inner.supported_by_platform() { tracing::warn!("{e}"); } else { - tracing::debug!("{e} (additionally, `{ep_name}` is not supported on this platform)"); + tracing::debug!("{e} (note: additionally, `{ep_name}` is not supported on this platform)"); } } else { - tracing::warn!("An error occurred when attempting to register `{}`: {e}", ex.as_str()); + tracing::error!("An error occurred when attempting to register `{}`: {e}", ex.inner.as_str()); } } else { - tracing::info!("Successfully registered `{}`", ex.as_str()); + tracing::info!("Successfully registered `{}`", ex.inner.as_str()); fallback_to_cpu = false; } } if fallback_to_cpu { tracing::warn!("No execution providers registered successfully. 
Falling back to CPU."); } + Ok(()) } diff --git a/src/execution_providers/nnapi.rs b/src/execution_providers/nnapi.rs index 472db339..9f1951ef 100644 --- a/src/execution_providers/nnapi.rs +++ b/src/execution_providers/nnapi.rs @@ -59,7 +59,7 @@ impl NNAPIExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: NNAPIExecutionProvider) -> Self { - ExecutionProviderDispatch::NNAPI(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/onednn.rs b/src/execution_providers/onednn.rs index 04166757..795d0e66 100644 --- a/src/execution_providers/onednn.rs +++ b/src/execution_providers/onednn.rs @@ -29,7 +29,7 @@ impl OneDNNExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: OneDNNExecutionProvider) -> Self { - ExecutionProviderDispatch::OneDNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/openvino.rs b/src/execution_providers/openvino.rs index fb8f932b..95dc8e26 100644 --- a/src/execution_providers/openvino.rs +++ b/src/execution_providers/openvino.rs @@ -103,7 +103,7 @@ impl OpenVINOExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: OpenVINOExecutionProvider) -> Self { - ExecutionProviderDispatch::OpenVINO(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/qnn.rs b/src/execution_providers/qnn.rs index 6262aac3..eb7075d5 100644 --- a/src/execution_providers/qnn.rs +++ b/src/execution_providers/qnn.rs @@ -110,7 +110,7 @@ impl QNNExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: QNNExecutionProvider) -> Self { - ExecutionProviderDispatch::QNN(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/rocm.rs b/src/execution_providers/rocm.rs index c2c28857..be4cfdea 100644 --- a/src/execution_providers/rocm.rs +++ b/src/execution_providers/rocm.rs @@ -114,7 +114,7 @@ impl ROCmExecutionProvider { impl From for ExecutionProviderDispatch { fn 
from(value: ROCmExecutionProvider) -> Self { - ExecutionProviderDispatch::ROCm(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/tensorrt.rs b/src/execution_providers/tensorrt.rs index fe581c34..e60e16f0 100644 --- a/src/execution_providers/tensorrt.rs +++ b/src/execution_providers/tensorrt.rs @@ -210,7 +210,7 @@ impl TensorRTExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: TensorRTExecutionProvider) -> Self { - ExecutionProviderDispatch::TensorRT(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/tvm.rs b/src/execution_providers/tvm.rs index a054a704..19c8ea7a 100644 --- a/src/execution_providers/tvm.rs +++ b/src/execution_providers/tvm.rs @@ -54,7 +54,7 @@ impl TVMExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: TVMExecutionProvider) -> Self { - ExecutionProviderDispatch::TVM(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/execution_providers/xnnpack.rs b/src/execution_providers/xnnpack.rs index b344cc3b..87933260 100644 --- a/src/execution_providers/xnnpack.rs +++ b/src/execution_providers/xnnpack.rs @@ -23,7 +23,7 @@ impl XNNPACKExecutionProvider { impl From for ExecutionProviderDispatch { fn from(value: XNNPACKExecutionProvider) -> Self { - ExecutionProviderDispatch::XNNPACK(value) + ExecutionProviderDispatch::new(value) } } diff --git a/src/session/builder.rs b/src/session/builder.rs index 60f716e0..7d654c2a 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -112,7 +112,7 @@ impl SessionBuilder { /// `CUDAExecutionProvider`) **is discouraged** unless you allow the user to configure the execution providers by /// providing a `Vec` of [`ExecutionProviderDispatch`]es. 
pub fn with_execution_providers(self, execution_providers: impl IntoIterator) -> Result { - apply_execution_providers(&self, execution_providers.into_iter()); + apply_execution_providers(&self, execution_providers.into_iter())?; Ok(self) } @@ -329,7 +329,7 @@ impl SessionBuilder { .collect(); let env = get_environment()?; - apply_execution_providers(&self, env.execution_providers.iter().cloned()); + apply_execution_providers(&self, env.execution_providers.iter().cloned())?; if env.has_global_threadpool { ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; @@ -406,7 +406,7 @@ impl SessionBuilder { let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut(); let env = get_environment()?; - apply_execution_providers(&self, env.execution_providers.iter().cloned()); + apply_execution_providers(&self, env.execution_providers.iter().cloned())?; if env.has_global_threadpool { ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions]; From a92dd3022eb169280ce1d37c2a92b94124ce515a Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Mon, 17 Jun 2024 07:30:54 -0500 Subject: [PATCH 24/49] fix(sys): enable SOCKS proxy support, closes #210 --- ort-sys/Cargo.toml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index e638ae2f..26e2aa81 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -38,10 +38,9 @@ vitis = [] cann = [] qnn = [] - [build-dependencies] -ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] } +ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls", "socks-proxy" ] } tar = { version = "0.4", optional = true } flate2 = { version = "1.0", optional = true } sha2 = { version = "0.10", optional = true } -pkg-config = "0.3.30" \ No newline at end of file +pkg-config = "0.3.30" From 19d66de302fa7bfc065555fec3815e81c6d7e28e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Thu, 20 Jun 2024 09:45:00 -0500 Subject: [PATCH 25/49] feat: MIGraphX execution provider, ref #212 --- src/execution_providers/migraphx.rs | 79 +++++++++++++++++++++++++++++ src/execution_providers/mod.rs | 2 + 2 files changed, 81 insertions(+) create mode 100644 src/execution_providers/migraphx.rs diff --git a/src/execution_providers/migraphx.rs b/src/execution_providers/migraphx.rs new file mode 100644 index 00000000..ebcb50db --- /dev/null +++ b/src/execution_providers/migraphx.rs @@ -0,0 +1,79 @@ +use std::{ffi::CString, ptr}; + +use super::ExecutionProvider; +use crate::{ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder}; + +#[derive(Debug, Default, Clone)] +pub struct MIGraphXExecutionProvider { + device_id: i32, + enable_fp16: bool, + enable_int8: bool, + use_native_calibration_table: bool, + int8_calibration_table_name: Option +} + +impl MIGraphXExecutionProvider { + #[must_use] + pub fn with_device_id(mut self, device_id: i32) -> Self { + self.device_id = device_id; + self + } + + #[must_use] + pub fn with_fp16(mut self) -> Self { + self.enable_fp16 = 
true; + self + } + + #[must_use] + pub fn with_int8(mut self) -> Self { + self.enable_int8 = true; + self + } + + #[must_use] + pub fn with_native_calibration_table(mut self, table_name: Option>) -> Self { + self.use_native_calibration_table = true; + self.int8_calibration_table_name = table_name.map(|c| CString::new(c.as_ref()).expect("invalid string")); + self + } + + #[must_use] + pub fn build(self) -> ExecutionProviderDispatch { + self.into() + } +} + +impl From for ExecutionProviderDispatch { + fn from(value: MIGraphXExecutionProvider) -> Self { + ExecutionProviderDispatch::new(value) + } +} + +impl ExecutionProvider for MIGraphXExecutionProvider { + fn as_str(&self) -> &'static str { + "MIGraphXExecutionProvider" + } + + fn supported_by_platform(&self) -> bool { + cfg!(any(all(target_os = "linux", target_arch = "x86_64"), all(target_os = "windows", target_arch = "x86_64"))) + } + + #[allow(unused, unreachable_code)] + fn register(&self, session_builder: &SessionBuilder) -> Result<()> { + #[cfg(any(feature = "load-dynamic", feature = "migraphx"))] + { + let options = ort_sys::OrtMIGraphXProviderOptions { + device_id: self.device_id, + migraphx_fp16_enable: self.enable_fp16.into(), + migraphx_int8_enable: self.enable_int8.into(), + migraphx_use_native_calibration_table: self.use_native_calibration_table.into(), + migraphx_int8_calibration_table_name: self.int8_calibration_table_name.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null) + }; + ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider]; + return Ok(()); + } + + Err(Error::ExecutionProviderNotRegistered(self.as_str())) + } +} diff --git a/src/execution_providers/mod.rs b/src/execution_providers/mod.rs index 8c237041..c7f49370 100644 --- a/src/execution_providers/mod.rs +++ b/src/execution_providers/mod.rs @@ -32,6 +32,8 @@ mod xnnpack; pub use self::xnnpack::XNNPACKExecutionProvider; mod armnn; pub use 
self::armnn::ArmNNExecutionProvider; +mod migraphx; +pub use self::migraphx::MIGraphXExecutionProvider; /// ONNX Runtime works with different hardware acceleration libraries through its extensible **Execution Providers** /// (EP) framework to optimally execute the ONNX models on the hardware platform. This interface enables flexibility for From c64b8ea990a0c1a25cde6e5e0ff41e396bced631 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Fri, 21 Jun 2024 15:37:39 -0500 Subject: [PATCH 26/49] refactor: usability aka The Cleanening, part 2 - Add clearer documentation and examples for more things. - Rework string tensors by introducing `PrimitiveTensorElementType` for primitive (i.e. f32) types, and again re-implementing `IntoTensorElementType` for `String`. This allows string tensors to be used via `Tensor` instead of exclusively via `DynTensor`. Additionally, string tensors no longer require an `Allocator` to be created (which didn't make sense, since string data in Rust can only ever be stored on the CPU anyway). This also now applies to `Map`s, since their data also needed to be on the CPU anyway. (`Sequence`s are currently unaffected because I think a custom allocator could be useful for them?) - Rework the `IoBinding` interface, and add an example clarifying the intended usage of it (ref #209). Thanks to AAce from the pyke Discord for pointing out the mutability issue in the old interface, which should be addressed now. - Refactor `OperatorDomain::add` from the slightly-nicer-looking-but-more-confusing `fn(t: T)` to just `fn()` to further enforce the fact that `Operator`s are zero-sized. - Maps can now have `String` keys. - Remove some unused errors. 
--- examples/custom-ops/examples/custom-ops.rs | 2 +- src/environment.rs | 80 +++- src/error.rs | 61 +-- src/io_binding.rs | 159 ++++++- src/lib.rs | 39 +- src/memory.rs | 206 ++++++--- src/operator/bound.rs | 8 +- src/operator/kernel.rs | 9 + src/operator/mod.rs | 18 +- src/session/input.rs | 7 +- src/tensor/mod.rs | 2 +- src/tensor/types.rs | 20 + src/value/impl_map.rs | 168 +++++-- src/value/impl_sequence.rs | 31 +- src/value/impl_tensor/create.rs | 97 ++-- src/value/impl_tensor/extract.rs | 499 ++++++++++----------- src/value/impl_tensor/mod.rs | 192 +++++++- src/value/mod.rs | 62 ++- src/wasm.rs | 14 +- tests/vectorizer.rs | 4 +- 20 files changed, 1112 insertions(+), 566 deletions(-) diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index 2d590f0c..1206860c 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -82,7 +82,7 @@ impl Kernel for CustomOpTwoKernel { fn main() -> ort::Result<()> { let session = Session::builder()? - .with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)? + .with_operators(OperatorDomain::new("test.customop")?.add::()?.add::()?)? .commit_from_file("tests/data/custom_op_test.onnx")?; let values = session.run(ort::inputs![Array2::::zeros((3, 5)), Array2::::ones((3, 5))]?)?; diff --git a/src/environment.rs b/src/environment.rs index 343cdc95..810f9e9e 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -34,7 +34,7 @@ pub struct Environment { } impl Environment { - /// Loads the underlying [`ort_sys::OrtEnv`] pointer. + /// Returns the underlying [`ort_sys::OrtEnv`] pointer. pub fn ptr(&self) -> *mut ort_sys::OrtEnv { self.env_ptr.load(Ordering::Relaxed) } @@ -52,13 +52,14 @@ impl Drop for Environment { } } -/// Gets a reference to the global environment, creating one if an environment has been committed yet. 
+/// Gets a reference to the global environment, creating one if an environment has not been +/// [`commit`](EnvironmentBuilder::commit)ted yet. pub fn get_environment() -> Result<&'static Arc> { if let Some(c) = unsafe { &*G_ENV.cell.get() } { Ok(c) } else { debug!("Environment not yet initialized, creating a new one"); - EnvironmentBuilder::default().commit()?; + EnvironmentBuilder::new().commit()?; Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() }) } @@ -72,7 +73,7 @@ pub struct EnvironmentGlobalThreadPoolOptions { pub intra_op_thread_affinity: Option } -/// Struct used to build an `Environment`. +/// Struct used to build an [`Environment`]; see [`crate::init`]. pub struct EnvironmentBuilder { name: String, telemetry: bool, @@ -80,8 +81,8 @@ pub struct EnvironmentBuilder { global_thread_pool_options: Option } -impl Default for EnvironmentBuilder { - fn default() -> Self { +impl EnvironmentBuilder { + pub(crate) fn new() -> Self { EnvironmentBuilder { name: "default".to_string(), telemetry: true, @@ -89,11 +90,9 @@ impl Default for EnvironmentBuilder { global_thread_pool_options: None } } -} -impl EnvironmentBuilder { /// Configure the environment with a given name for logging purposes. - #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_name(mut self, name: S) -> Self where S: Into @@ -102,7 +101,17 @@ impl EnvironmentBuilder { self } - #[must_use] + /// Enable or disable sending telemetry events to Microsoft. + /// + /// Typically, only Windows builds of ONNX Runtime provided by Microsoft will have telemetry enabled. + /// Pre-built binaries provided by pyke, or binaries compiled from source, won't have telemetry enabled. + /// + /// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.18.0/onnxruntime/core/platform/windows/telemetry.cc). 
+ /// Currently, this includes (but is not limited to): ONNX graph version, model producer name & version, whether or + /// not FP16 is used, operator domains & versions, model graph name & custom metadata, execution provider names, + /// error messages, and the total number & time of session inference runs. The ONNX Runtime team uses this data to + /// better understand how customers use ONNX Runtime and where performance can be improved. + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_telemetry(mut self, enable: bool) -> Self { self.telemetry = enable; self @@ -116,14 +125,14 @@ impl EnvironmentBuilder { /// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built /// with support for the corresponding execution provider. Execution providers that do not have their corresponding /// feature enabled will emit a warning. - #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self { self.execution_providers = execution_providers.as_ref().to_vec(); self } /// Enables the global thread pool for this environment. 
- #[must_use] + #[must_use = "commit() must be called in order for the environment to take effect"] pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self { self.global_thread_pool_options = Some(options); self @@ -158,14 +167,17 @@ impl EnvironmentBuilder { ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment]; } - ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( + ortsys![ + unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), thread_options, &mut env_ptr - ) -> Error::CreateEnvironment; nonNull(env_ptr)]; + ) -> Error::CreateEnvironment; + nonNull(env_ptr) + ]; ortsys![unsafe ReleaseThreadingOptions(thread_options)]; (env_ptr, true) } else { @@ -174,13 +186,16 @@ impl EnvironmentBuilder { // FIXME: What should go here? let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!()); - ortsys![unsafe CreateEnvWithCustomLogger( + ortsys![ + unsafe CreateEnvWithCustomLogger( logging_function, logger_param, ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), &mut env_ptr - ) -> Error::CreateEnvironment; nonNull(env_ptr)]; + ) -> Error::CreateEnvironment; + nonNull(env_ptr) + ]; (env_ptr, false) }; debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created"); @@ -205,15 +220,25 @@ impl EnvironmentBuilder { /// Creates an ONNX Runtime environment. /// +/// ``` +/// # use ort::CUDAExecutionProvider; +/// # fn main() -> ort::Result<()> { +/// ort::init() +/// .with_execution_providers([CUDAExecutionProvider::default().build()]) +/// .commit()?; +/// # Ok(()) +/// # } +/// ``` +/// /// # Notes /// - It is not required to call this function. If this is not called by the time any other `ort` APIs are used, a /// default environment will be created. 
-/// - Library crates that use `ort` shouldn't create their own environment. Let downstream applications create it. +/// - **Library crates that use `ort` shouldn't create their own environment.** Let downstream applications create it. /// - In order for environment settings to apply, this must be called **before** you use other APIs like /// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function. -#[must_use] +#[must_use = "commit() must be called in order for the environment to take effect"] pub fn init() -> EnvironmentBuilder { - EnvironmentBuilder::default() + EnvironmentBuilder::new() } /// Creates an ONNX Runtime environment, dynamically loading ONNX Runtime from the library file (`.dll`/`.so`/`.dylib`) @@ -221,15 +246,26 @@ pub fn init() -> EnvironmentBuilder { /// /// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded. /// +/// ```no_run +/// # use ort::CUDAExecutionProvider; +/// # fn main() -> ort::Result<()> { +/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib"); +/// ort::init_from(lib_path.join("onnxruntime.dll")) +/// .with_execution_providers([CUDAExecutionProvider::default().build()]) +/// .commit()?; +/// # Ok(()) +/// # } +/// ``` +/// /// # Notes /// - In order for environment settings to apply, this must be called **before** you use other APIs like /// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function. #[cfg(feature = "load-dynamic")] #[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))] -#[must_use] +#[must_use = "commit() must be called in order for the environment to take effect"] pub fn init_from(path: impl ToString) -> EnvironmentBuilder { let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string())); - EnvironmentBuilder::default() + EnvironmentBuilder::new() } /// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct. 
@@ -325,7 +361,7 @@ mod tests { assert!(!is_env_initialized()); assert_eq!(env_ptr(), None); - EnvironmentBuilder::default().with_name("env_is_initialized").commit()?; + EnvironmentBuilder::new().with_name("env_is_initialized").commit()?; assert!(is_env_initialized()); assert_ne!(env_ptr(), None); Ok(()) diff --git a/src/error.rs b/src/error.rs index 7bdb2ba7..fb25f204 100644 --- a/src/error.rs +++ b/src/error.rs @@ -121,9 +121,6 @@ pub enum Error { /// Error occurred when filling a tensor with string data #[error("Failed to fill string tensor: {0}")] FillStringTensor(ErrorInternal), - /// Error occurred when checking if a value is a tensor - #[error("Failed to check if value is a tensor: {0}")] - FailedTensorCheck(ErrorInternal), /// Error occurred when getting tensor type and shape #[error("Failed to get tensor type and shape: {0}")] GetTensorTypeAndShape(ErrorInternal), @@ -159,12 +156,6 @@ pub enum Error { /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models). #[error("Failed to download ONNX model: {0}")] DownloadError(#[from] FetchModelError), - /// Type of input data and the ONNX model do not match. - #[error("Data types do not match: expected {model:?}, got {input:?}")] - NonMatchingDataTypes { input: TensorElementType, model: TensorElementType }, - /// Dimensions of input data and the ONNX model do not match. - #[error("Dimensions do not match: {0:?}")] - NonMatchingDimensions(NonMatchingDimensionsError), /// File does not exist #[error("File `{filename:?}` does not exist")] FileDoesNotExist { @@ -186,9 +177,6 @@ pub enum Error { /// ORT pointer should not have been null #[error("`{0}` should not be a null pointer")] PointerShouldNotBeNull(&'static str), - /// The runtime type was undefined. - #[error("Undefined tensor element type")] - UndefinedTensorElementType, /// Could not retrieve model metadata. 
#[error("Failed to retrieve model metadata: {0}")] GetModelMetadata(ErrorInternal), @@ -208,8 +196,8 @@ pub enum Error { ExecutionProviderNotRegistered(&'static str), #[error("Expected tensor to be on CPU in order to get data, but had allocation device `{0}`.")] TensorNotOnCpu(&'static str), - #[error("String tensors require the session's allocator to be provided through `Value::from_array`.")] - StringTensorRequiresAllocator, + #[error("Cannot extract scalar value from a {0}-dimensional tensor")] + TensorNot0Dimensional(usize), #[error("Failed to create memory info: {0}")] CreateMemoryInfo(ErrorInternal), #[error("Could not get allocation device from `MemoryInfo`: {0}")] @@ -222,10 +210,10 @@ pub enum Error { BindInput(ErrorInternal), #[error("Error when binding output: {0}")] BindOutput(ErrorInternal), - #[error("Failed to clear IO binding: {0}")] - ClearBinding(ErrorInternal), #[error("Error when retrieving session outputs from `IoBinding`: {0}")] GetBoundOutputs(ErrorInternal), + #[error("Cannot use `extract_tensor` on a value that is {0:?}")] + NotTensor(ValueType), #[error("Cannot use `extract_sequence` on a value that is {0:?}")] NotSequence(ValueType), #[error("Cannot use `extract_map` on a value that is {0:?}")] @@ -252,6 +240,8 @@ pub enum Error { GetOperatorInput(ErrorInternal), #[error("Failed to get operator output: {0}")] GetOperatorOutput(ErrorInternal), + #[error("Failed to retrieve GPU compute stream from kernel context: {0}")] + GetOperatorGPUComputeStream(ErrorInternal), #[error("{0}")] CustomError(#[from] Box), #[error("String tensors cannot be borrowed as mutable")] @@ -266,37 +256,20 @@ pub enum Error { GetDeviceId(ErrorInternal) } -impl From for Error { - fn from(_: Infallible) -> Self { - Error::Infallible +impl Error { + /// Wrap a custom, user-provided error in an [`ort::Error`](Error). The resulting error will be the + /// [`Error::CustomError`] variant. + /// + /// This can be used to return custom errors from e.g. 
training dataloaders or custom operators if a non-`ort` + /// related operation fails. + pub fn wrap(err: T) -> Self { + Error::CustomError(Box::new(err) as Box) } } -/// Error used when the input dimensions defined in the model and passed from an inference call do not match. -#[non_exhaustive] -#[derive(Error, Debug)] -pub enum NonMatchingDimensionsError { - /// Number of inputs from model does not match the number of inputs from inference call. - #[error( - "Non-matching number of inputs: {inference_input_count:?} provided vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})" - )] - InputsCount { - /// Number of input dimensions used by inference call - inference_input_count: usize, - /// Number of input dimensions defined in model - model_input_count: usize, - /// Input dimensions used by inference call - inference_input: Vec>, - /// Input dimensions defined in model - model_input: Vec>> - }, - /// Inputs length from model does not match the expected input from inference call - #[error("Different input lengths; expected input: {model_input:?}, received input: {inference_input:?}")] - InputsLength { - /// Input dimensions used by inference call - inference_input: Vec>, - /// Input dimensions defined in model - model_input: Vec>> +impl From for Error { + fn from(_: Infallible) -> Self { + Error::Infallible } } diff --git a/src/io_binding.rs b/src/io_binding.rs index a94f2913..36b11e61 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -1,6 +1,8 @@ use std::{ + collections::HashMap, ffi::CString, fmt::Debug, + marker::PhantomData, ptr::{self, NonNull}, sync::Arc }; @@ -9,24 +11,87 @@ use crate::{ memory::MemoryInfo, ortsys, session::{output::SessionOutputs, RunOptions}, - value::{Value, ValueRefMut}, - Error, Result, Session, ValueTypeMarker + value::{Value, ValueInner}, + DynValue, Error, Result, Session, ValueTypeMarker }; /// Enables binding of session inputs and/or outputs to pre-allocated memory. 
/// -/// Note that this arrangement is designed to minimize data copies, and to that effect, your memory allocations must -/// match what is expected by the model, whether you run on CPU or GPU. Data will still be copied if the -/// pre-allocated memory location does not match the one expected by the model. However, copies with `IoBinding`s are -/// only done once, at the time of the binding, not at run time. This means, that if your input data required a copy, -/// your further input modifications would not be seen by ONNX Runtime unless you rebind it, even if it is the same -/// buffer. If your scenario requires that the data is copied, `IoBinding` may not be the best match for your use case. -/// The fact that data copy is not made during runtime may also have performance implications. +/// [`IoBinding`] minimizes copies between a device (like a GPU) and the host (CPU) by allowing the user to bind a +/// certain input/output to a pre-allocated value on a specific device. +/// +/// [`IoBinding`] is most suitable for: +/// - An ensemble of models in which the output from one model is the input of another and does not need to pass through +/// the CPU to perform additional processing. +/// - Situations where the output should stay on a device (e.g. to perform additional processing with CUDA). +/// - Diffusion models, for instance, that accept an unchanging embedding for conditioning. +/// +/// [`IoBinding`] will not provide any meaningful benefit for: +/// - Models where every input changes with each invocation, such as a causal language model or object recognition +/// model. +/// - Pipelines that go straight from CPU -> GPU -> CPU. +/// +/// # Example +/// A diffusion model which takes a text condition input. +/// +/// ```no_run +/// # use ort::{Allocator, AllocatorType, AllocationDevice, CUDAExecutionProvider, MemoryInfo, MemoryType, Session, Tensor, IoBinding}; +/// # fn main() -> ort::Result<()> { +/// let text_encoder = Session::builder()? 
+/// .with_execution_providers([CUDAExecutionProvider::default().build()])? +/// .commit_from_file("text_encoder.onnx")?; +/// let unet = Session::builder()? +/// .with_execution_providers([CUDAExecutionProvider::default().build()])? +/// .commit_from_file("unet.onnx")?; +/// +/// let text_condition = text_encoder +/// .run(ort::inputs![Tensor::::from_array(( +/// vec![27], +/// vec![ +/// 23763, 15460, 473, 68, 312, 265, 17463, 4098, 304, 1077, 283, 198, 7676, 5976, 272, 285, 3609, 435, +/// 21680, 321, 265, 300, 1689, 64, 285, 4763, 64 +/// ] +/// ))?]?)? +/// .remove("output0") +/// .unwrap(); +/// +/// let input_allocator = Allocator::new( +/// &unet, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? +/// )?; +/// let mut latents = Tensor::::new(&input_allocator, [1, 4, 64, 64])?; +/// +/// let mut io_binding = unet.create_binding()?; +/// io_binding.bind_input("condition", &text_condition)?; +/// +/// let output_allocator = Allocator::new( +/// &unet, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUOutput)? +/// )?; +/// io_binding.bind_output("noise_pred", Tensor::::new(&output_allocator, [1, 4, 64, 64])?)?; +/// +/// for _ in 0..20 { +/// io_binding.bind_input("latents", &latents)?; +/// let noise_pred = io_binding.run()?.remove("noise_pred").unwrap(); +/// +/// let mut latents = latents.extract_tensor_mut(); +/// latents += &noise_pred.try_extract_tensor::()?; +/// } +/// # Ok(()) +/// # } +/// ``` +/// +/// [`IoBinding`] may provide a decent speedup in this example since the `condition` tensor is unchanging between runs. +/// If we were to use normal session inference, the `condition` tensor would be needlessly copied with each invocation +/// of `unet.run()`, and this copying can come with significant latency & overhead. With [`IoBinding`], the `condition` +/// tensor is only copied to the device once instead of 20 times. 
#[derive(Debug)] pub struct IoBinding<'s> { pub(crate) ptr: NonNull, session: &'s Session, - output_names: Vec + held_inputs: HashMap>, + output_names: Vec, + output_values: HashMap } impl<'s> IoBinding<'s> { @@ -36,25 +101,47 @@ impl<'s> IoBinding<'s> { Ok(Self { ptr: unsafe { NonNull::new_unchecked(ptr) }, session, - output_names: Vec::new() + held_inputs: HashMap::new(), + output_names: Vec::new(), + output_values: HashMap::new() }) } /// Bind a [`Value`] to a session input. - pub fn bind_input<'i: 's, T: ValueTypeMarker, S: AsRef>(&mut self, name: S, ort_value: &'i mut Value) -> Result> { + /// + /// Upon invocation, the value's data will be copied to the device the session is allocated on. The copied data will + /// be used as an input (specified by `name`) in all future invocations of [`IoBinding::run`] until the input is + /// overridden (by calling [`IoBinding::bind_input`] again) or until all inputs are cleared (via + /// [`IoBinding::clear_inputs`] or [`IoBinding::clear`]). + /// + /// The data is only copied **once**, immediately upon invocation of this function. Any changes to the given + /// value afterwards will not affect the data seen by the session until the value is re-bound. Subsequent re-binds + /// will still copy data, hence why [`IoBinding`] is really only suitable when one or more inputs do not change + /// between runs. + pub fn bind_input>(&mut self, name: S, ort_value: &Value) -> Result<()> { let name = name.as_ref(); let cname = CString::new(name)?; ortsys![unsafe BindInput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindInput]; - Ok(ort_value.view_mut()) + self.held_inputs.insert(name.to_string(), Arc::clone(&ort_value.inner)); + Ok(()) } /// Bind a session output to a pre-allocated [`Value`]. 
- pub fn bind_output<'o: 's, T: ValueTypeMarker, S: AsRef>(&mut self, name: S, ort_value: &'o mut Value) -> Result> { + /// + /// This allows for the pre-allocation and reuse of memory in the session output (see [`crate::Tensor::new`]). Any + /// subsequent runs via [`IoBinding::run`] will reuse the same tensor to store the output instead of creating a new + /// one each time. + /// + /// The output will be accessible in the value returned by [`IoBinding::run`], under the name specified by `name`. + pub fn bind_output>(&mut self, name: S, ort_value: Value) -> Result<()> { let name = name.as_ref(); let cname = CString::new(name)?; ortsys![unsafe BindOutput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput]; self.output_names.push(name.to_string()); - Ok(ort_value.view_mut()) + // Clear the old bound output if we have any. + drop(self.output_values.remove(name)); + self.output_values.insert(name.to_string(), ort_value.into_dyn()); + Ok(()) } /// Bind a session output to a device which is specified by `mem_info`. @@ -66,15 +153,35 @@ impl<'s> IoBinding<'s> { Ok(()) } - pub fn run<'i: 's>(&'i self) -> Result> { + /// Clears all bound inputs specified by [`IoBinding::bind_input`]. + pub fn clear_inputs(&mut self) { + ortsys![unsafe ClearBoundInputs(self.ptr.as_ptr())]; + drop(self.held_inputs.drain()); + } + /// Clears all bound outputs specified by [`IoBinding::bind_output`] or [`IoBinding::bind_output_to_device`]. + pub fn clear_outputs(&mut self) { + ortsys![unsafe ClearBoundOutputs(self.ptr.as_ptr())]; + drop(self.output_names.drain(..)); + drop(self.output_values.drain()); + } + /// Clears both the bound inputs & outputs; equivalent to [`IoBinding::clear_inputs`] followed by + /// [`IoBinding::clear_outputs`]. + pub fn clear(&mut self) { + self.clear_inputs(); + self.clear_outputs(); + } + + /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. 
+ pub fn run(&mut self) -> Result> { self.run_inner(None) } - pub fn run_with_options<'i: 's>(&'i self, run_options: Arc) -> Result> { + /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. + pub fn run_with_options(&mut self, run_options: Arc) -> Result> { self.run_inner(Some(run_options)) } - fn run_inner<'i: 's>(&'i self, run_options: Option>) -> Result> { + fn run_inner(&mut self, run_options: Option>) -> Result> { let run_options_ptr = if let Some(run_options) = run_options { run_options.run_options_ptr.as_ptr() } else { @@ -82,6 +189,7 @@ impl<'s> IoBinding<'s> { }; ortsys![unsafe RunWithBinding(self.session.inner.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr()) -> Error::SessionRunWithIoBinding]; + let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc> = self.output_values.values().map(|c| (c.ptr(), &c.inner)).collect(); let mut count = self.output_names.len() as ort_sys::size_t; if count > 0 { let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut(); @@ -91,10 +199,17 @@ impl<'s> IoBinding<'s> { let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() } .into_iter() .map(|v| unsafe { - Value::from_ptr( - NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), - Some(Arc::clone(&self.session.inner)) - ) + if let Some(inner) = owned_ptrs.get(&v) { + DynValue { + inner: Arc::clone(*inner), + _markers: PhantomData + } + } else { + DynValue::from_ptr( + NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), + Some(Arc::clone(&self.session.inner)) + ) + } }); // output values will be freed when the `Value`s in `SessionOutputs` drop diff --git a/src/lib.rs b/src/lib.rs index ed5a2adf..a071b867 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,7 +65,7 @@ pub use self::session::{ #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use 
self::tensor::ArrayExtensions; -pub use self::tensor::{IntoTensorElementType, TensorElementType}; +pub use self::tensor::{IntoTensorElementType, Utf8Data, PrimitiveTensorElementType, TensorElementType}; pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, @@ -143,6 +143,23 @@ pub(crate) static G_ORT_API: OnceLock> = OnceLock::ne /// May panic if: /// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime. /// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled. +/// +/// # Examples +/// The primary (public-facing) use case for this function is accessing APIs that do not have a corresponding safe +/// implementation in `ort`. For example, [`GetBuildInfoString`](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0a7dba37b0017c0ef3a0ab4e266a967d): +/// +/// ``` +/// # use std::ffi::CStr; +/// # fn main() -> ort::Result<()> { +/// let api = ort::api().as_ptr(); +/// let build_info = unsafe { CStr::from_ptr((*api).GetBuildInfoString.unwrap()()) }; +/// println!("{}", build_info.to_string_lossy()); +/// // ORT Build Info: git-branch=HEAD, git-commit-id=4573740, build type=Release, cmake cxx flags: /DWIN32 /D_WINDOWS /EHsc /EHsc /wd26812 -DEIGEN_HAS_C99_MATH -DCPUINFO_SUPPORTED +/// # Ok(()) +/// # } +/// ``` +/// +/// For the full list of ONNX Runtime APIs, consult the [`ort_sys::OrtApi`] struct and the [ONNX Runtime C API](https://onnxruntime.ai/docs/api/c/struct_ort_api.html). pub fn api() -> NonNull { unsafe { NonNull::new_unchecked( @@ -252,6 +269,26 @@ pub(crate) fn char_p_to_string(raw: *const c_char) -> Result { .map_err(Error::FfiStringConversion) } +pub(crate) struct PrivateTraitMarker; + +macro_rules! 
private_trait { + () => { + #[doc(hidden)] + #[allow(private_interfaces)] + fn _private() -> crate::PrivateTraitMarker; + }; +} +macro_rules! private_impl { + () => { + #[allow(private_interfaces)] + fn _private() -> crate::PrivateTraitMarker { + crate::PrivateTraitMarker + } + }; +} +pub(crate) use private_impl; +pub(crate) use private_trait; + #[cfg(test)] mod test { use std::ffi::CString; diff --git a/src/memory.rs b/src/memory.rs index 00464f27..bc7644ca 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,20 +1,75 @@ use std::{ ffi::{c_char, c_int, CString}, - ptr::NonNull + ptr::NonNull, + sync::Arc }; use super::{ error::{Error, Result}, ortsys }; -use crate::{char_p_to_string, error::status_to_result, Session}; +use crate::{char_p_to_string, error::status_to_result, Session, SharedSessionInner}; -/// An ONNX Runtime allocator, used to manage the allocation of [`crate::Value`]s. +/// A device allocator used to manage the allocation of [`crate::Value`]s. +/// +/// # Direct allocation +/// [`Allocator`] can be used to directly allocate device memory. This can be useful if you have a +/// postprocessing step that runs on the GPU. +/// ```no_run +/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; +/// # fn main() -> ort::Result<()> { +/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let allocator = Allocator::new( +/// &session, +/// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? +/// )?; +/// +/// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; +/// // Here, `data_ptr` is a pointer to **device memory** inaccessible to the CPU; we'll need another crate, like +/// // `cudarc`, to access it. 
+/// let data_ptr = tensor.data_ptr_mut()?; +/// # Ok(()) +/// # } +/// ``` +/// +/// Note that `ort` does not facilitate the transfer of data between host & device outside of session inputs & +/// outputs; you'll need to use a separate crate for that, like [`cudarc`](https://crates.io/crates/cudarc) for CUDA. +/// +/// # Pinned allocation +/// Memory allocated on the host CPU is often *pageable* and may reside on the disk (swap memory). Transferring +/// pageable memory to another device is slow because the device has to go through the CPU to access the +/// memory. Many execution providers thus provide a "pinned" allocator type, which allocates *unpaged* CPU memory +/// that the device is able to access directly, bypassing the CPU and allowing for faster host-to-device data +/// transfer. +/// +/// If you create a session with a device allocator that supports pinned memory, like CUDA or ROCm, you can create +/// an allocator for that session, and use it to allocate tensors with faster pinned memory: +/// ```no_run +/// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; +/// # fn main() -> ort::Result<()> { +/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let allocator = Allocator::new( +/// &session, +/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? +/// )?; +/// +/// // Create a tensor with our pinned allocator. +/// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; +/// let data = tensor.extract_tensor_mut(); +/// // ...fill `data` with data... +/// # Ok(()) +/// # } +/// ``` #[derive(Debug)] pub struct Allocator { pub(crate) ptr: NonNull, + /// The 'default' CPU allocator, provided by `GetAllocatorWithDefaultOptions` and implemented by + /// [`Allocator::default`], should **not** be released, so this field marks whether or not we should call + /// `ReleaseAllocator` on drop. 
is_default: bool, - _info: Option + _info: Option, + /// Hold a reference to the session if this allocator is tied to one. + _session_inner: Option> } impl Allocator { @@ -22,47 +77,46 @@ impl Allocator { Allocator { ptr: NonNull::new_unchecked(ptr), is_default: false, + // currently, this function is only ever used in session creation, where we call `CreateAllocator` manually and store the allocator resulting from + // this function in the `SharedSessionInner` - we don't need to hold onto the session, because the session is holding onto us. + _session_inner: None, _info: None } } + /// Frees an object allocated by this allocator, given the object's C pointer. pub(crate) unsafe fn free(&self, ptr: *mut T) { self.ptr.as_ref().Free.unwrap_or_else(|| unreachable!("Allocator method `Free` is null"))(self.ptr.as_ptr(), ptr.cast()); } /// Creates a new [`Allocator`] for the given session, to allocate memory on the device described in the /// [`MemoryInfo`]. - /// - /// For example, to create an allocator to allocate pinned memory for CUDA: - /// ```no_run - /// # use ort::{Allocator, Session, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; - /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; - /// let allocator = Allocator::new( - /// &session, - /// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? 
- /// )?; - /// # Ok(()) - /// # } - /// ``` pub fn new(session: &Session, memory_info: MemoryInfo) -> Result { let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); ortsys![unsafe CreateAllocator(session.ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)]; Ok(Self { ptr: unsafe { NonNull::new_unchecked(allocator_ptr) }, is_default: false, + _session_inner: Some(session.inner()), _info: Some(memory_info) }) } } impl Default for Allocator { + /// Returns the default CPU allocator; equivalent to `MemoryInfo::new(AllocationDevice::CPU, 0, + /// AllocatorType::Device, MemoryType::Default)`. + /// + /// The allocator returned by this function is actually shared across all invocations (though this behavior is + /// transparent to the user). fn default() -> Self { let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut(); status_to_result(ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr); nonNull(allocator_ptr)]).expect("Failed to get default allocator"); Self { ptr: unsafe { NonNull::new_unchecked(allocator_ptr) }, is_default: true, + // The default allocator isn't tied to a session. + _session_inner: None, _info: None } } @@ -70,8 +124,6 @@ impl Default for Allocator { impl Drop for Allocator { fn drop(&mut self) { - // per GetAllocatorWithDefaultOptions docs: Returned value should NOT be freed - // https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a8dec797ae52ee1a681e4f88be1fb4bb3 if !self.is_default { ortsys![unsafe ReleaseAllocator(self.ptr.as_ptr())]; } @@ -81,7 +133,8 @@ impl Drop for Allocator { /// Represents possible devices that have their own device allocator. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AllocationDevice { - // https://github.com/microsoft/onnxruntime/blob/v1.17.0/include/onnxruntime/core/framework/allocator.h#L43-L53 + // https://github.com/microsoft/onnxruntime/blob/v1.18.0/include/onnxruntime/core/framework/allocator.h#L43-L53 + // ort will likely never support WebGPU, so I think it's best to leave `WebGPU_Buffer` out entirely to reduce confusion CPU, CUDA, CUDAPinned, @@ -91,12 +144,10 @@ pub enum AllocationDevice { HIP, HIPPinned, OpenVINOCPU, - OpenVINOGPU, - WebGPUBuffer + OpenVINOGPU } impl AllocationDevice { - #[must_use] pub fn as_str(&self) -> &'static str { match self { Self::CPU => "Cpu", @@ -108,10 +159,15 @@ impl AllocationDevice { Self::HIP => "Hip", Self::HIPPinned => "HipPinned", Self::OpenVINOCPU => "OpenVINO_CPU", - Self::OpenVINOGPU => "OpenVINO_GPU", - Self::WebGPUBuffer => "WebGPU_Buffer" + Self::OpenVINOGPU => "OpenVINO_GPU" } } + + /// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device, + /// it could be extracted to an `ndarray` or slice. + pub fn is_cpu_accessible(&self) -> bool { + matches!(self, Self::CPU | Self::CUDAPinned | Self::CANNPinned | Self::HIPPinned | Self::OpenVINOCPU) + } } impl TryFrom for AllocationDevice { @@ -129,14 +185,13 @@ impl TryFrom for AllocationDevice { "HipPinned" => Ok(AllocationDevice::HIPPinned), "OpenVINO_CPU" => Ok(AllocationDevice::OpenVINOCPU), "OpenVINO_GPU" => Ok(AllocationDevice::OpenVINOGPU), - "WebGPUBuffer" => Ok(AllocationDevice::WebGPUBuffer), _ => Err(value) } } } /// Execution provider allocator type. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AllocatorType { /// Default device-specific allocator. Device, @@ -154,11 +209,11 @@ impl From for ort_sys::OrtAllocatorType { } /// Memory types for allocated memory. 
-#[derive(Default, Debug, Copy, Clone)] +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] pub enum MemoryType { /// Any CPU memory used by non-CPU execution provider. CPUInput, - /// CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED. + /// CPU-accessible memory output by a non-CPU execution provider, i.e. [`AllocatorDevice::CUDAPinned`]. CPUOutput, /// The default allocator for an execution provider. #[default] @@ -190,6 +245,12 @@ impl From for MemoryType { } } +/// Structure describing a memory location - the device on which the memory resides, the type of allocator (device +/// default, or arena) used, and the type of memory allocated (device-only, or CPU accessible). +/// +/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which +/// device value data should reside, and how that data should be accessible with regard to the CPU (if a non-CPU device +/// is requested). #[derive(Debug)] pub struct MemoryInfo { pub(crate) ptr: NonNull, @@ -197,24 +258,24 @@ pub struct MemoryInfo { } impl MemoryInfo { - pub(crate) fn from_raw(ptr: NonNull, should_release: bool) -> Self { - MemoryInfo { ptr, should_release } - } - - #[tracing::instrument] - pub fn new_cpu(allocator: AllocatorType, memory_type: MemoryType) -> Result { - let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut(); - ortsys![ - unsafe CreateCpuMemoryInfo(allocator.into(), memory_type.into(), &mut memory_info_ptr) -> Error::CreateMemoryInfo; - nonNull(memory_info_ptr) - ]; - Ok(Self { - ptr: unsafe { NonNull::new_unchecked(memory_info_ptr) }, - should_release: true - }) - } - - #[tracing::instrument] + /// Creates a [`MemoryInfo`], describing a memory location on a device allocator. + /// + /// # Examples + /// `MemoryInfo` can be used to specify the device & memory type used by an [`Allocator`] to allocate tensors. + /// See [`Allocator`] for more information & potential applications. 
+ /// ```no_run + /// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let allocator = Allocator::new( + /// &session, + /// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? + /// )?; + /// + /// let mut tensor = Tensor::::new(&allocator, [1, 3, 224, 224])?; + /// # Ok(()) + /// # } + /// ``` pub fn new(allocation_device: AllocationDevice, device_id: c_int, allocator_type: AllocatorType, memory_type: MemoryType) -> Result { let mut memory_info_ptr: *mut ort_sys::OrtMemoryInfo = std::ptr::null_mut(); let allocator_name = CString::new(allocation_device.as_str()).unwrap_or_else(|_| unreachable!()); @@ -229,7 +290,19 @@ impl MemoryInfo { }) } + pub(crate) fn from_raw(ptr: NonNull, should_release: bool) -> Self { + MemoryInfo { ptr, should_release } + } + /// Returns the [`MemoryType`] described by this struct. + /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.memory_type()?, MemoryType::Default); + /// # Ok(()) + /// # } + /// ``` pub fn memory_type(&self) -> Result { let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault; ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetMemoryType]; @@ -237,6 +310,14 @@ impl MemoryInfo { } /// Returns the [`AllocatorType`] described by this struct. 
+ /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.allocator_type()?, AllocatorType::Device); + /// # Ok(()) + /// # } + /// ``` pub fn allocator_type(&self) -> Result { let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator; ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetAllocatorType]; @@ -248,6 +329,14 @@ impl MemoryInfo { } /// Returns the [`AllocationDevice`] this struct was created with. + /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.allocation_device()?, AllocationDevice::CPU); + /// # Ok(()) + /// # } + /// ``` pub fn allocation_device(&self) -> Result { let mut name_ptr: *const c_char = std::ptr::null_mut(); ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr) -> Error::GetAllocationDevice; nonNull(name_ptr)]; @@ -258,6 +347,14 @@ impl MemoryInfo { } /// Returns the ID of the [`AllocationDevice`] described by this struct. 
+ /// ``` + /// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # fn main() -> ort::Result<()> { + /// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?; + /// assert_eq!(mem.device_id()?, 0); + /// # Ok(()) + /// # } + /// ``` pub fn device_id(&self) -> Result { let mut raw: ort_sys::c_int = 0; ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw) -> Error::GetDeviceId]; @@ -266,24 +363,9 @@ impl MemoryInfo { } impl Drop for MemoryInfo { - #[tracing::instrument] fn drop(&mut self) { if self.should_release { ortsys![unsafe ReleaseMemoryInfo(self.ptr.as_ptr())]; } } } - -#[cfg(test)] -mod tests { - use test_log::test; - - use super::*; - - #[test] - fn create_memory_info() -> crate::Result<()> { - let memory_info = MemoryInfo::new_cpu(AllocatorType::Device, MemoryType::Default)?; - std::mem::drop(memory_info); - Ok(()) - } -} diff --git a/src/operator/bound.rs b/src/operator/bound.rs index f46c6c15..58219c47 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -11,8 +11,7 @@ use super::{ }; use crate::error::IntoStatus; -#[repr(C)] -#[derive(Clone)] +#[repr(C)] // <- important! 
a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later pub(crate) struct BoundOperator { implementation: ort_sys::OrtCustomOp, name: CString, @@ -184,7 +183,10 @@ unsafe impl Send for ErasedBoundOperator {} impl ErasedBoundOperator { pub(crate) fn new(bound: BoundOperator) -> Self { - ErasedBoundOperator(NonNull::from(unsafe { &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) })) + ErasedBoundOperator(NonNull::from(unsafe { + // horrible horrible horrible horrible horrible horrible horrible horrible horrible + &mut *(Box::leak(Box::new(bound)) as *mut _ as *mut ()) + })) } pub(crate) fn op_ptr(&self) -> *mut ort_sys::OrtCustomOp { diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 61758f6e..db09aa28 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -138,4 +138,13 @@ impl KernelContext { ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx as ort_sys::size_t, shape.as_ptr(), shape.len() as _, &mut value_ptr) -> Error::GetOperatorOutput]; Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) }))) } + + /// Returns a pointer to the GPU compute stream (i.e. `cudaStream_t`) used by the execution provider, if this + /// kernel's operator was configured to use said execution provider (see + /// [`super::Operator::execution_provider_type`]). 
+ pub fn compute_stream(&self) -> Result>> { + let mut stream_ptr: *mut ort_sys::c_void = ptr::null_mut(); + ortsys![unsafe KernelContext_GetGPUComputeStream(self.ptr.as_ptr(), &mut stream_ptr) -> Error::GetOperatorGPUComputeStream]; + Ok(NonNull::new(stream_ptr)) + } } diff --git a/src/operator/mod.rs b/src/operator/mod.rs index 7e873516..ad361f29 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -16,11 +16,26 @@ use crate::{operator::bound::BoundOperator, ortsys, Error, Result}; pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>; +/// A custom operator descriptor, which describes the expected inputs & outputs of a graph operator. +/// +/// [`Operator`]s are bound to [`OperatorDomain`]s. Multiple operators can have the same name as long as they have +/// different input/output types, in which case the exact operator will be picked depending on the input/output +/// types. If you want to, for example, define a `Sort` operator that can accept either a single `f32` or `i64` tensor +/// input, you'll need to define 2 separate operators (which can be done via a macro); but both of these +/// [`Operator`] structs can return the same name in [`Operator::name`] so that they are usable as simply +/// `my.domain:Sort` in the graph. pub trait Operator: Send { type Kernel: Kernel; + /// Returns the name of the operator. fn name() -> &'static str; + /// Returns the execution provider this operator runs on, e.g. `CUDAExecutionProvider`. + /// + /// If the returned type is not `None`, and the execution provider used by the session matches this operator's + /// EP type, the value will not be copied to the CPU and you may use functions like [`crate::Tensor::data_ptr`] to + /// access the underlying device memory, and [`super::KernelContext::compute_stream`] to access the GPU compute + /// stream. 
fn execution_provider_type() -> Option<&'static str> { None } @@ -42,6 +57,7 @@ pub trait Operator: Send { } } +/// Dummy type implementing [`Operator`] used by [`ErasedBoundOperator`] to cheat the type system. struct DummyOperator; impl Operator for DummyOperator { @@ -84,7 +100,7 @@ impl OperatorDomain { } #[allow(clippy::should_implement_trait)] - pub fn add(mut self, _operator: O) -> Result { + pub fn add(mut self) -> Result { let name = O::name(); let bound = BoundOperator::::new(CString::new(name)?, O::execution_provider_type().map(CString::new).transpose()?); diff --git a/src/session/input.rs b/src/session/input.rs index ce33f006..61d55e5c 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -92,16 +92,15 @@ impl<'i, 'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs< /// # } /// ``` /// -/// Note that string tensors must be created manually with [`Value::from_string_array`]. +/// Note that string tensors must be created manually with [`crate::Tensor::from_string_array`]. 
/// /// ```no_run /// # use std::{error::Error, sync::Arc}; /// # use ndarray::Array1; -/// # use ort::{GraphOptimizationLevel, Session, Value}; +/// # use ort::{GraphOptimizationLevel, Session, Tensor}; /// # fn main() -> Result<(), Box> { /// # let mut session = Session::builder()?.commit_from_file("model.onnx")?; -/// let _ = session -/// .run(ort::inputs![Value::from_string_array(session.allocator(), Array1::from_vec(vec!["hello", "world"]))?]?); +/// let _ = session.run(ort::inputs![Tensor::from_string_array(Array1::from_vec(vec!["hello", "world"]))?]?); /// # Ok(()) /// # } /// ``` diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 98cf2ae2..a1a5440e 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -22,4 +22,4 @@ mod types; pub use self::ndarray::ArrayExtensions; #[cfg(feature = "ndarray")] pub(crate) use self::types::{extract_primitive_array, extract_primitive_array_mut}; -pub use self::types::{IntoTensorElementType, TensorElementType, Utf8Data}; +pub use self::types::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; diff --git a/src/tensor/types.rs b/src/tensor/types.rs index 08a57d8f..aabe6839 100644 --- a/src/tensor/types.rs +++ b/src/tensor/types.rs @@ -91,6 +91,12 @@ impl From for TensorElementType { pub trait IntoTensorElementType { /// Returns the ONNX tensor element data type corresponding to the given Rust type. fn into_tensor_element_type() -> TensorElementType; + + crate::private_trait!(); +} + +pub trait PrimitiveTensorElementType: IntoTensorElementType { + crate::private_trait!(); } macro_rules! impl_type_trait { @@ -99,6 +105,12 @@ macro_rules! 
impl_type_trait { fn into_tensor_element_type() -> TensorElementType { TensorElementType::$variant } + + crate::private_impl!(); + } + + impl PrimitiveTensorElementType for $type_ { + crate::private_impl!(); } }; } @@ -121,6 +133,14 @@ impl_type_trait!(u64, Uint64); #[cfg_attr(docsrs, doc(cfg(feature = "half")))] impl_type_trait!(half::bf16, Bfloat16); +impl IntoTensorElementType for String { + fn into_tensor_element_type() -> TensorElementType { + TensorElementType::String + } + + crate::private_impl!(); +} + /// Adapter for common Rust string types to ONNX strings. pub trait Utf8Data { /// Returns the contents of this value as a slice of UTF-8 bytes. diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index 1421e5a5..f6387876 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -3,25 +3,39 @@ use std::{ fmt::Debug, hash::Hash, marker::PhantomData, - ptr::{self, NonNull} + ptr::{self, NonNull}, + sync::Arc }; use super::{ValueInner, ValueTypeMarker}; use crate::{ - memory::Allocator, ortsys, value::impl_tensor::DynTensor, DynValue, Error, IntoTensorElementType, Result, Tensor, Value, ValueRef, ValueRefMut, ValueType + memory::Allocator, + ortsys, + value::impl_tensor::{calculate_tensor_size, DynTensor}, + DynValue, Error, IntoTensorElementType, PrimitiveTensorElementType, Result, Tensor, TensorElementType, Value, ValueRef, ValueRefMut, ValueType }; -pub trait MapValueTypeMarker: ValueTypeMarker {} +pub trait MapValueTypeMarker: ValueTypeMarker { + crate::private_trait!(); +} #[derive(Debug)] pub struct DynMapValueType; -impl ValueTypeMarker for DynMapValueType {} -impl MapValueTypeMarker for DynMapValueType {} +impl ValueTypeMarker for DynMapValueType { + crate::private_impl!(); +} +impl MapValueTypeMarker for DynMapValueType { + crate::private_impl!(); +} #[derive(Debug)] pub struct MapValueType(PhantomData<(K, V)>); -impl ValueTypeMarker for MapValueType {} -impl MapValueTypeMarker for MapValueType {} +impl ValueTypeMarker for MapValueType { 
+ crate::private_impl!(); +} +impl MapValueTypeMarker for MapValueType { + crate::private_impl!(); +} pub type DynMap = Value; pub type Map = Value>; @@ -32,10 +46,7 @@ pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType>; pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType>; impl Value { - pub fn try_extract_map( - &self, - allocator: &Allocator - ) -> Result> { + pub fn try_extract_map(&self) -> Result> { match self.dtype()? { ValueType::Map { key, value } => { let k_type = K::into_tensor_element_type(); @@ -47,47 +58,95 @@ impl Value { return Err(Error::InvalidMapValueType { expected: v_type, actual: value }); } + let allocator = Allocator::default(); + let mut key_tensor_ptr = ptr::null_mut(); ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr.as_ptr(), &mut key_tensor_ptr) -> Error::ExtractMap; nonNull(key_tensor_ptr)]; let key_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(key_tensor_ptr), None) }; - let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_tensor::()?; - - let mut value_tensor_ptr = ptr::null_mut(); - ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; - let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; - let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; - - assert_eq!(key_tensor_shape.len(), 1); - assert_eq!(value_tensor_shape.len(), 1); - assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); - - let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); - for i in 0..key_tensor_shape[0] as usize { - vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + if K::into_tensor_element_type() != TensorElementType::String { + let dtype = key_value.dtype()?; + let (key_tensor_shape, key_tensor) = match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = key_value.memory_info()?.allocation_device()?; + if 
!device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == K::into_tensor_element_type() { + let mut output_array_ptr: *mut K = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut K = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(key_tensor_ptr, output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + let len = calculate_tensor_size(&dimensions); + (dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) }) + } else { + return Err(Error::DataTypeMismatch { + actual: ty, + requested: K::into_tensor_element_type() + }); + } + } + _ => unreachable!() + }; + + let mut value_tensor_ptr = ptr::null_mut(); + ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; + let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; + let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; + + assert_eq!(key_tensor_shape.len(), 1); + assert_eq!(value_tensor_shape.len(), 1); + assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); + + let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); + for i in 0..key_tensor_shape[0] as usize { + vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + } + Ok(vec.into_iter().collect()) + } else { + let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_string_tensor()?; + // SAFETY: `IntoTensorElementType` is a private trait, and we only map the `String` type to `TensorElementType::String`, + // so at this point, `K` is **always** the `String` type, and this transmute really does nothing but please the type + // checker. 
+ let key_tensor: Vec = unsafe { std::mem::transmute(key_tensor) }; + + let mut value_tensor_ptr = ptr::null_mut(); + ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; + let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; + let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; + + assert_eq!(key_tensor_shape.len(), 1); + assert_eq!(value_tensor_shape.len(), 1); + assert_eq!(key_tensor_shape[0], value_tensor_shape[0]); + + let mut vec = Vec::with_capacity(key_tensor_shape[0] as _); + for i in 0..key_tensor_shape[0] as usize { + vec.push((key_tensor[i].clone(), value_tensor[i].clone())); + } + Ok(vec.into_iter().collect()) } - Ok(vec.into_iter().collect()) } t => Err(Error::NotMap(t)) } } } -impl Value> { +impl Value> { /// Creates a [`Map`] from an iterable emitting `K` and `V`. /// /// ``` /// # use std::collections::HashMap; - /// # use ort::{Allocator, Map}; + /// # use ort::Map; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let mut map = HashMap::::new(); /// map.insert(0, 1.0); /// map.insert(1, 2.0); /// map.insert(2, 3.0); /// - /// let value = Map::new(map)?; + /// let value = Map::::new(map)?; /// - /// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0); + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); /// # Ok(()) /// # } /// ``` @@ -95,20 +154,45 @@ impl, Vec) = data.into_iter().unzip(); Self::new_kv(Tensor::from_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?) } +} + +impl Value> { + /// Creates a [`Map`] from an iterable emitting `K` and `V`. 
+ /// + /// ``` + /// # use std::collections::HashMap; + /// # use ort::Map; + /// # fn main() -> ort::Result<()> { + /// let mut map = HashMap::::new(); + /// map.insert(0, 1.0); + /// map.insert(1, 2.0); + /// map.insert(2, 3.0); + /// + /// let value = Map::::new(map)?; + /// + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); + /// # Ok(()) + /// # } + /// ``` + pub fn new(data: impl IntoIterator) -> Result { + let (keys, values): (Vec, Vec) = data.into_iter().unzip(); + Self::new_kv(Tensor::from_string_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?) + } +} +impl Value> { /// Creates a [`Map`] from two tensors of keys & values respectively. /// /// ``` /// # use std::collections::HashMap; - /// # use ort::{Allocator, Map, Tensor}; + /// # use ort::{Map, Tensor}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let keys = Tensor::::from_array(([4], vec![0, 1, 2, 3]))?; /// let values = Tensor::::from_array(([4], vec![1., 2., 3., 4.]))?; /// /// let value = Map::new_kv(keys, values)?; /// - /// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0); + /// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0); /// # Ok(()) /// # } /// ``` @@ -122,21 +206,23 @@ impl Value> { - pub fn extract_map(&self, allocator: &Allocator) -> HashMap { - self.try_extract_map(allocator).expect("Failed to extract map") +impl Value> { + pub fn extract_map(&self) -> HashMap { + self.try_extract_map().expect("Failed to extract map") } +} +impl Value> { /// Converts from a strongly-typed [`Map`] to a type-erased [`DynMap`]. 
#[inline] pub fn upcast(self) -> DynMap { @@ -149,7 +235,7 @@ impl(PhantomData); -impl ValueTypeMarker for SequenceValueType {} -impl SequenceValueTypeMarker for SequenceValueType {} +impl ValueTypeMarker for SequenceValueType { + crate::private_impl!(); +} +impl SequenceValueTypeMarker for SequenceValueType { + crate::private_impl!(); +} pub type DynSequence = Value; pub type Sequence = Value>; @@ -89,11 +100,11 @@ impl Value Value Value { + /// Construct a [`DynTensor`] from an array of strings. /// - /// Just like numeric tensors, string tensor `Value`s can be created from: + /// Just like numeric tensors, string tensors can be created from: /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); /// - (with feature `ndarray`) an owned [`ndarray::Array`]; @@ -36,26 +36,19 @@ impl DynTensor { /// ``` /// # use ort::{Session, Value}; /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/vectorizer.onnx")?; - /// // You'll need to obtain an `Allocator` from a session in order to create string tensors. 
- /// let allocator = session.allocator(); - /// /// // Create a string tensor from a raw data vector /// let data = vec!["hello", "world"]; - /// let value = Value::from_string_array(allocator, ([data.len()], data.into_boxed_slice()))?; + /// let value = Value::from_string_array(([data.len()], data.into_boxed_slice()))?; /// /// // Create a string tensor from an `ndarray::Array` /// #[cfg(feature = "ndarray")] - /// let value = Value::from_string_array( - /// allocator, - /// ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap() - /// )?; + /// let value = Value::from_string_array(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap())?; /// # Ok(()) /// # } /// ``` /// /// Note that string data will *always* be copied, no matter what form the data is provided in. - pub fn from_string_array(allocator: &Allocator, input: impl IntoValueTensor) -> Result { + pub fn from_string_array(input: impl IntoValueTensor) -> Result> { let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); let (shape, data) = input.ref_parts()?; @@ -64,7 +57,7 @@ impl DynTensor { // create tensor without data -- data is filled in later ortsys![ - unsafe CreateTensorAsOrtValue(allocator.ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr) + unsafe CreateTensorAsOrtValue(Allocator::default().ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr) -> Error::CreateTensor; nonNull(value_ptr) ]; @@ -84,18 +77,18 @@ impl DynTensor { ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> Error::FillStringTensor]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: Box::new(()), _memory_info: None - }, + }), _markers: PhantomData }) } } -impl Tensor { - /// Construct a tensor [`Value`] in a given allocator with a given shape and datatype. 
The data contained in the +impl Tensor { + /// Construct a tensor in a given allocator with a given shape and datatype. The data contained in the /// value will be zero-allocated on the allocation device. /// /// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned @@ -132,18 +125,18 @@ impl Tensor { ]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: Box::new(()), _memory_info: None - }, + }), _markers: PhantomData }) } - /// Construct a tensor [`Value`] from an array of data. + /// Construct a tensor from an array of data. /// - /// Tensor `Value`s can be created from: + /// Tensors can be created from: /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); /// - (with feature `ndarray`) an owned [`ndarray::Array`]; @@ -154,19 +147,19 @@ impl Tensor { /// * and `data` is one of `Vec`, `Box<[T]>`, `Arc>`, or `&[T]`. /// /// ``` - /// # use ort::Value; + /// # use ort::Tensor; /// # fn main() -> ort::Result<()> { /// // Create a tensor from a raw data vector - /// let value = Value::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?; + /// let tensor = Tensor::from_array(([1usize, 2, 3], vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0].into_boxed_slice()))?; /// /// // Create a tensor from an `ndarray::Array` /// #[cfg(feature = "ndarray")] - /// let value = Value::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; + /// let tensor = Tensor::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; /// # Ok(()) /// # } /// ``` /// - /// Creating string tensors requires a separate method; see [`Value::from_string_array`]. + /// Creating string tensors requires a separate method; see [`DynTensor::from_string_array`]. 
/// /// Note that data provided in an `ndarray` may be copied in some circumstances: /// - `&CowArray<'_, T, D>` will always be copied regardless of whether it is uniquely owned or borrowed. @@ -177,7 +170,7 @@ impl Tensor { /// Raw data provided as a `Arc>`, `Box<[T]>`, or `Vec` will never be copied. Raw data is expected to be /// in standard, contigous layout. pub fn from_array(input: impl IntoValueTensor) -> Result> { - let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemoryType::Default)?; + let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::Default)?; let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); @@ -203,17 +196,17 @@ impl Tensor { ]; Ok(Value { - inner: ValueInner::RustOwned { + inner: Arc::new(ValueInner::RustOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: guard, _memory_info: Some(memory_info) - }, + }), _markers: PhantomData }) } } -impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> { +impl<'a, T: PrimitiveTensorElementType + Debug> TensorRefMut<'a, T> { /// Create a mutable tensor view from a raw pointer and shape. /// /// The length of data is determined by `T` and the given shape, so the given buffer must be at least @@ -260,11 +253,11 @@ impl<'a, T: IntoTensorElementType + Debug> TensorRefMut<'a, T> { ]; Ok(TensorRefMut::new(Value { - inner: ValueInner::CppOwned { + inner: Arc::new(ValueInner::CppOwned { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, drop: true, _session: None - }, + }), _markers: PhantomData })) } @@ -290,7 +283,7 @@ macro_rules! impl_to_dimensions { .enumerate() .map(|(i, c)| if *c >= 1 { Ok(*c as i64) } else { Err(Error::InvalidDimension(i)) }) .collect::>()?; - let sum = v.iter().product::() as usize; + let sum = calculate_tensor_size(&v); if let Some(expected_size) = expected_size { if sum != expected_size { Err(Error::TensorShapeMismatch { @@ -318,6 +311,14 @@ macro_rules! 
impl_to_dimensions { }; } +impl ToDimensions for () { + fn to_dimensions(&self, expected_size: Option) -> Result> { + match expected_size { + Some(1) | None => Ok(vec![]), + Some(x) => Err(Error::TensorShapeMismatch { input: vec![], total: 1, expected: x }) + } + } +} impl_to_dimensions!(for &[usize], for &[i32], for &[i64], for Vec, for Vec, for Vec); impl_to_dimensions!( for [usize; N], for [i32; N], for [i64; N]); @@ -500,7 +501,7 @@ impl IntoValueTensor for (D, Arc TryFrom<&'i CowArray<'v, T, D>> for Tensor +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Tensor where 'i: 'v { @@ -512,7 +513,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Tensor { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Tensor { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr) @@ -521,7 +522,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor where 'i: 'v { @@ -533,7 +534,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr).map(|c| 
c.upcast()) @@ -542,7 +543,7 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue +impl<'i, 'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue where 'i: 'v { @@ -554,7 +555,7 @@ where #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynValue { +impl<'v, T: PrimitiveTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynValue { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { Tensor::from_array(arr).map(|c| c.into_dyn()) @@ -564,19 +565,19 @@ impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'sta macro_rules! impl_try_from { (@T,I $($t:ty),+) => { $( - impl TryFrom<$t> for Tensor { + impl TryFrom<$t> for Tensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value) } } - impl TryFrom<$t> for DynTensor { + impl TryFrom<$t> for DynTensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.upcast()) } } - impl TryFrom<$t> for crate::DynValue { + impl TryFrom<$t> for crate::DynValue { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.into_dyn()) @@ -587,21 +588,21 @@ macro_rules! 
impl_try_from { (@T,D $($t:ty),+) => { $( #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for Tensor { + impl TryFrom<$t> for Tensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value) } } #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for DynTensor { + impl TryFrom<$t> for DynTensor { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.upcast()) } } #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for crate::DynValue { + impl TryFrom<$t> for crate::DynValue { type Error = Error; fn try_from(value: $t) -> Result { Tensor::from_array(value).map(|c| c.into_dyn()) diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index a4859e73..22a52edf 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, os::raw::c_char, ptr, string::FromUtf8Error}; +use std::{fmt::Debug, ptr, string::FromUtf8Error}; #[cfg(feature = "ndarray")] use ndarray::IxDyn; @@ -7,9 +7,7 @@ use super::TensorValueTypeMarker; #[cfg(feature = "ndarray")] use crate::tensor::{extract_primitive_array, extract_primitive_array_mut}; use crate::{ - ortsys, - tensor::{IntoTensorElementType, TensorElementType}, - Error, Result, Tensor, Value + ortsys, tensor::TensorElementType, value::impl_tensor::calculate_tensor_size, Error, PrimitiveTensorElementType, Result, Tensor, Value, ValueType }; impl Value { @@ -38,38 +36,81 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_tensor`] instead)* /// - The provided type `T` does not match the tensor's element type. + /// - The tensor's data is not allocated in CPU memory. 
#[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn try_extract_tensor(&self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + pub fn try_extract_tensor(&self) -> Result> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == T::into_tensor_element_type() { + Ok(extract_primitive_array(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } + } + t => Err(Error::NotTensor(t)) + } + } - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok(extract_primitive_array(shape, self.ptr())?) 
- } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + /// Attempt to extract the scalar from a tensor of type `T`. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Value}; + /// # fn main() -> ort::Result<()> { + /// let value = Value::from_array(((), vec![3.14_f32]))?; + /// + /// let extracted = value.try_extract_scalar::()?; + /// assert_eq!(extracted, 3.14); + /// # Ok(()) + /// # } + /// ``` + /// + /// # Errors + /// May return an error if: + /// - The tensor is not 0-dimensional. + /// - The provided type `T` does not match the tensor's element type. + /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the + /// infallible [`Tensor::extract_tensor`] instead)* + /// - The tensor's data is not allocated in CPU memory. + pub fn try_extract_scalar(&self) -> Result { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if !dimensions.is_empty() { + return Err(Error::TensorNot0Dimensional(dimensions.len())); + } + + if ty == T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + Ok(unsafe { *output_array_ptr }) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data of type `T` into a mutable read-only 
[`ndarray::ArrayViewMut`]. @@ -101,36 +142,26 @@ impl Value { /// - The provided type `T` does not match the tensor's element type. #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn try_extract_tensor_mut(&mut self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok(extract_primitive_array_mut(shape, self.ptr())?) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + pub fn try_extract_tensor_mut(&mut self) -> Result> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == T::into_tensor_element_type() { + Ok(extract_primitive_array_mut(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), self.ptr())?) 
+ } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an @@ -159,40 +190,32 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_raw_tensor`] instead)* /// - The provided type `T` does not match the tensor's element type. - pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - - let mut output_array_ptr: *mut T = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - - let mut len = 0; - ortsys![unsafe 
GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok((node_dims, unsafe { std::slice::from_raw_parts(output_array_ptr, len as _) })) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + let len = calculate_tensor_size(&dimensions); + Ok((dimensions, unsafe { std::slice::from_raw_parts(output_array_ptr, len) })) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a @@ -218,50 +241,41 @@ impl Value { /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the /// infallible [`Tensor::extract_raw_tensor_mut`] instead)* /// - The provided type `T` does not match the tensor's element type. 
- pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == T::into_tensor_element_type() { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - - let mut output_array_ptr: *mut T = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - - let mut len = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - Ok((node_dims, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len as _) })) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: T::into_tensor_element_type() - }) + pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == 
T::into_tensor_element_type() { + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); + ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; + + let len = calculate_tensor_size(&dimensions); + Ok((dimensions, unsafe { std::slice::from_raw_parts_mut(output_array_ptr, len) })) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: T::into_tensor_element_type() + }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying data into a Rust `ndarray`. /// /// ``` - /// # use ort::{Allocator, Session, DynTensor, TensorElementType}; + /// # use ort::{Session, Tensor, TensorElementType}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let array = ndarray::Array1::from_vec(vec!["hello", "world"]); - /// let tensor = DynTensor::from_string_array(&allocator, array.clone())?; + /// let tensor = Tensor::from_string_array(array.clone())?; /// /// let extracted = tensor.try_extract_string_tensor()?; /// assert_eq!(array.into_dyn(), extracted); @@ -271,78 +285,68 @@ impl Value { #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub fn try_extract_string_tensor(&self) -> Result> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, 
ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == TensorElementType::String { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - let shape = IxDyn(&node_dims.iter().map(|&n| n as usize).collect::>()); - - let mut len: ort_sys::size_t = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - - // Total length of string data, not including \0 suffix - let mut total_length: ort_sys::size_t = 0; - ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; - - // In the JNI impl of this, tensor_element_len was included in addition to total_length, - // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes - // don't seem to be written to in practice either. - // If the string data actually did go farther, it would panic below when using the offset - // data to get slices for each string. 
- let mut string_contents = vec![0u8; total_length as _]; - // one extra slot so that the total length can go in the last one, making all per-string - // length calculations easy - let mut offsets = vec![0; (len + 1) as _]; - - ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; - - // final offset = overall length so that per-string length calculations work for the last string - debug_assert_eq!(0, offsets[len as usize]); - offsets[len as usize] = total_length; - - let strings = offsets - // offsets has 1 extra offset past the end so that all windows work - .windows(2) - .map(|w| { - let slice = &string_contents[w[0] as _..w[1] as _]; - String::from_utf8(slice.into()) + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == TensorElementType::String { + let len = calculate_tensor_size(&dimensions); + + // Total length of string data, not including \0 suffix + let mut total_length: ort_sys::size_t = 0; + ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; + + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. 
+ let mut string_contents = vec![0u8; total_length as _]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0; (len + 1) as _]; + + ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; + + // final offset = overall length so that per-string length calculations work for the last string + debug_assert_eq!(0, offsets[len]); + offsets[len] = total_length; + + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let slice = &string_contents[w[0] as _..w[1] as _]; + String::from_utf8(slice.into()) + }) + .collect::, FromUtf8Error>>() + .map_err(Error::StringFromUtf8Error)?; + + Ok(ndarray::Array::from_shape_vec(IxDyn(&dimensions.iter().map(|&n| n as usize).collect::>()), strings) + .expect("Shape extracted from tensor didn't match tensor contents")) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: TensorElementType::String }) - .collect::, FromUtf8Error>>() - .map_err(Error::StringFromUtf8Error)?; - - Ok(ndarray::Array::from_shape_vec(shape, strings) - .expect("Shape extracted from tensor didn't match tensor contents") - .into_dyn()) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: TensorElementType::String - }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Attempt to extract the underlying string data into a "raw" data tuple, consisting of the tensor's dimensions and /// an owned `Vec` of its data. 
/// /// ``` - /// # use ort::{Allocator, Session, DynTensor, TensorElementType}; + /// # use ort::{Session, Tensor, TensorElementType}; /// # fn main() -> ort::Result<()> { - /// # let allocator = Allocator::default(); /// let array = vec!["hello", "world"]; - /// let tensor = DynTensor::from_string_array(&allocator, ([array.len()], array.clone().into_boxed_slice()))?; + /// let tensor = Tensor::from_string_array(([array.len()], array.clone().into_boxed_slice()))?; /// /// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?; /// assert_eq!(extracted_data, array); @@ -351,68 +355,57 @@ impl Value { /// # } /// ``` pub fn try_extract_raw_string_tensor(&self) -> Result<(Vec, Vec)> { - let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; - - let res = { - let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ortsys![unsafe GetTensorElementType(tensor_info_ptr, &mut type_sys) -> Error::GetTensorElementType]; - assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); - let data_type: TensorElementType = type_sys.into(); - if data_type == TensorElementType::String { - let mut num_dims = 0; - ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; - - let mut node_dims: Vec = vec![0; num_dims as _]; - ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; - - let mut output_array_ptr: *mut c_char = ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut c_char = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr.cast(); - ortsys![unsafe GetTensorMutableData(self.ptr(), output_array_ptr_ptr_void) -> Error::GetTensorMutableData; nonNull(output_array_ptr)]; - - let mut len: 
ort_sys::size_t = 0; - ortsys![unsafe GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> Error::GetTensorShapeElementCount]; - // Total length of string data, not including \0 suffix - let mut total_length = 0; - ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut total_length) -> Error::GetStringTensorDataLength]; - - // In the JNI impl of this, tensor_element_len was included in addition to total_length, - // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes - // don't seem to be written to in practice either. - // If the string data actually did go farther, it would panic below when using the offset - // data to get slices for each string. - let mut string_contents = vec![0u8; total_length as _]; - // one extra slot so that the total length can go in the last one, making all per-string - // length calculations easy - let mut offsets = vec![0; len as usize + 1]; - - ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length as _, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent]; - - // final offset = overall length so that per-string length calculations work for the last string - debug_assert_eq!(0, offsets[len as usize]); - offsets[len as usize] = total_length; - - let strings = offsets - // offsets has 1 extra offset past the end so that all windows work - .windows(2) - .map(|w| { - let slice = &string_contents[w[0] as _..w[1] as _]; - String::from_utf8(slice.into()) + let dtype = self.dtype()?; + match dtype { + ValueType::Tensor { ty, dimensions } => { + let device = self.memory_info()?.allocation_device()?; + if !device.is_cpu_accessible() { + return Err(Error::TensorNotOnCpu(device.as_str())); + } + + if ty == TensorElementType::String { + let len = calculate_tensor_size(&dimensions); + + // Total length of string data, not including \0 suffix + let mut total_length: ort_sys::size_t = 0; + ortsys![unsafe GetStringTensorDataLength(self.ptr(), &mut 
total_length) -> Error::GetStringTensorDataLength]; + + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. + let mut string_contents = vec![0u8; total_length as _]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0; (len + 1) as _]; + + ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; + + // final offset = overall length so that per-string length calculations work for the last string + debug_assert_eq!(0, offsets[len]); + offsets[len] = total_length; + + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let slice = &string_contents[w[0] as _..w[1] as _]; + String::from_utf8(slice.into()) + }) + .collect::, FromUtf8Error>>() + .map_err(Error::StringFromUtf8Error)?; + + Ok((dimensions, strings)) + } else { + Err(Error::DataTypeMismatch { + actual: ty, + requested: TensorElementType::String }) - .collect::, FromUtf8Error>>() - .map_err(Error::StringFromUtf8Error)?; - - Ok((node_dims, strings)) - } else { - Err(Error::DataTypeMismatch { - actual: data_type, - requested: TensorElementType::String - }) + } } - }; - ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; - res + t => Err(Error::NotTensor(t)) + } } /// Returns the shape of the tensor. @@ -445,7 +438,7 @@ impl Value { } } -impl Tensor { +impl Tensor { /// Extracts the underlying data into a read-only [`ndarray::ArrayView`]. 
/// /// ``` diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index d7f1db3b..a4c7ba69 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -11,41 +11,96 @@ use std::{ use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker}; use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType}; -pub trait TensorValueTypeMarker: ValueTypeMarker {} +pub trait TensorValueTypeMarker: ValueTypeMarker { + crate::private_trait!(); +} #[derive(Debug)] pub struct DynTensorValueType; -impl ValueTypeMarker for DynTensorValueType {} -impl TensorValueTypeMarker for DynTensorValueType {} +impl ValueTypeMarker for DynTensorValueType { + crate::private_impl!(); +} +impl TensorValueTypeMarker for DynTensorValueType { + crate::private_impl!(); +} #[derive(Debug)] pub struct TensorValueType(PhantomData); -impl ValueTypeMarker for TensorValueType {} -impl TensorValueTypeMarker for TensorValueType {} +impl ValueTypeMarker for TensorValueType { + crate::private_impl!(); +} +impl TensorValueTypeMarker for TensorValueType { + crate::private_impl!(); +} +/// A tensor [`Value`] whose data type is unknown. pub type DynTensor = Value; +/// A strongly-typed tensor [`Value`]. pub type Tensor = Value>; +/// A reference to a tensor [`Value`] whose data type is unknown. pub type DynTensorRef<'v> = ValueRef<'v, DynTensorValueType>; +/// A mutable reference to a tensor [`Value`] whose data type is unknown. pub type DynTensorRefMut<'v> = ValueRefMut<'v, DynTensorValueType>; +/// A reference to a strongly-typed tensor [`Value`]. pub type TensorRef<'v, T> = ValueRef<'v, TensorValueType>; +/// A mutable reference to a strongly-typed tensor [`Value`]. pub type TensorRefMut<'v, T> = ValueRefMut<'v, TensorValueType>; impl DowncastableTarget for DynTensorValueType { fn can_downcast(dtype: &ValueType) -> bool { matches!(dtype, ValueType::Tensor { .. 
}) } + + crate::private_impl!(); } impl Value { /// Returns a mutable pointer to the tensor's data. + /// + /// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a + /// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be + /// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access. + /// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before + /// accessing it. + /// + /// ``` + /// # use ort::{Allocator, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let mut tensor = Tensor::::from_array((vec![5], vec![0, 1, 2, 3, 4]))?; + /// let ptr = tensor.data_ptr_mut()?.cast::(); + /// unsafe { + /// *ptr.add(3) = 42; + /// }; + /// + /// let (_, extracted) = tensor.extract_raw_tensor(); + /// assert_eq!(&extracted, &[0, 1, 2, 42, 4]); + /// # Ok(()) + /// # } + /// ``` pub fn data_ptr_mut(&mut self) -> Result<*mut ort_sys::c_void> { let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)]; Ok(buffer_ptr) } - /// Returns a pointer to the tensor's data. + /// Returns an immutable pointer to the tensor's underlying data. + /// + /// It's important to note that the resulting pointer may not point to CPU-accessible memory. In the case of a + /// tensor created on a different EP device, e.g. via [`Tensor::new`], the pointer returned by this function may be + /// a CUDA pointer, which would require a separate crate (like [`cudarc`](https://crates.io/crates/cudarc)) to access. + /// Use [`Tensor::memory_info`] & [`MemoryInfo::allocation_device`] to check which device the data resides on before + /// accessing it. 
+ /// + /// ``` + /// # use ort::{Allocator, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::from_array((vec![5], vec![0, 1, 2, 3, 4]))?; + /// let ptr = tensor.data_ptr()?.cast::(); + /// assert_eq!(unsafe { *ptr.add(3) }, 3); + /// # Ok(()) + /// # } + /// ``` pub fn data_ptr(&self) -> Result<*const ort_sys::c_void> { let mut buffer_ptr: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe GetTensorMutableData(self.ptr(), &mut buffer_ptr) -> Error::GetTensorMutableData; nonNull(buffer_ptr)]; @@ -53,6 +108,26 @@ impl Value { } /// Returns information about the device this tensor is allocated on. + /// + /// ``` + /// # use ort::{Allocator, AllocatorType, AllocationDevice, MemoryInfo, MemoryType, Session, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// // Tensors are allocated on CPU by default. + /// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CPU); + /// + /// # if false { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let cuda_allocator = Allocator::new( + /// &session, + /// MemoryInfo::new(AllocationDevice::CUDA, 0, AllocatorType::Device, MemoryType::Default)? + /// )?; + /// let tensor = Tensor::::new(&cuda_allocator, [1, 3, 224, 224])?; + /// assert_eq!(tensor.memory_info()?.allocation_device()?, AllocationDevice::CUDA); + /// # } + /// # Ok(()) + /// # } + /// ``` pub fn memory_info(&self) -> Result { let mut memory_info_ptr: *const ort_sys::OrtMemoryInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorMemoryInfo(self.ptr(), &mut memory_info_ptr) -> Error::GetTensorMemoryInfo; nonNull(memory_info_ptr)]; @@ -62,29 +137,68 @@ impl Value { impl Tensor { /// Converts from a strongly-typed [`Tensor`] to a type-erased [`DynTensor`]. 
+ /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// let tensor_dyn = tensor.upcast(); + /// assert!(tensor_dyn.try_extract_raw_tensor::().is_ok()); + /// assert!(tensor_dyn.try_extract_raw_tensor::().is_err()); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast(self) -> DynTensor { unsafe { std::mem::transmute(self) } } - /// Converts from a strongly-typed [`Tensor`] to a reference to a type-erased [`DynTensor`]. + /// Creates a type-erased [`DynTensorRef`] from a strongly-typed [`Tensor`]. + /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let tensor = Tensor::::new(&Allocator::default(), [1, 3, 224, 224])?; + /// let tensor_dyn = tensor.upcast_ref(); + /// + /// let (_, original_extract) = tensor.extract_raw_tensor(); + /// let (_, ref_extract) = tensor_dyn.try_extract_raw_tensor::()?; + /// assert_eq!(original_extract, ref_extract); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast_ref(&self) -> DynTensorRef { DynTensorRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } /// Converts from a strongly-typed [`Tensor`] to a mutable reference to a type-erased [`DynTensor`]. 
+ /// + /// ``` + /// # use ort::{Allocator, DynTensor, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let mut tensor = Tensor::::from_array((vec![5], vec![1, 2, 3, 4, 5]))?; + /// let mut tensor_dyn = tensor.upcast_mut(); + /// + /// let (_, mut_view) = tensor_dyn.try_extract_raw_tensor_mut::()?; + /// mut_view[3] = 0; + /// + /// let (_, original_view) = tensor.extract_raw_tensor(); + /// assert_eq!(original_view, &[1, 2, 3, 0, 5]); + /// # Ok(()) + /// # } + /// ``` #[inline] pub fn upcast_mut(&mut self) -> DynTensorRefMut { DynTensorRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -97,6 +211,8 @@ impl DowncastableTarget for TensorValueType _ => false } } + + crate::private_impl!(); } impl From>> for DynValue { @@ -113,6 +229,17 @@ impl From> for DynValue { impl Index<[i64; N]> for Tensor { type Output = T; fn index(&self, index: [i64; N]) -> &Self::Output { + // Interestingly, the `TensorAt` API doesn't check if the tensor is on CPU, so we have to perform the check ourselves. 
+ if !self + .memory_info() + .expect("could not retrieve tensor memory info") + .allocation_device() + .expect("could not retrieve tensor allocation device") + .is_cpu_accessible() + { + panic!("Cannot directly index a tensor which is not allocated on the CPU."); + } + let mut out: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")]; unsafe { &*out.cast::() } @@ -120,25 +247,46 @@ impl Index<[i64; N]> f } impl IndexMut<[i64; N]> for Tensor { fn index_mut(&mut self, index: [i64; N]) -> &mut Self::Output { + if !self + .memory_info() + .expect("could not retrieve tensor memory info") + .allocation_device() + .expect("could not retrieve tensor allocation device") + .is_cpu_accessible() + { + panic!("Cannot directly index a tensor which is not allocated on the CPU."); + } + let mut out: *mut ort_sys::c_void = std::ptr::null_mut(); ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).expect("Failed to index tensor")]; unsafe { &mut *out.cast::() } } } +pub(crate) fn calculate_tensor_size(shape: &[i64]) -> usize { + let mut size = 1usize; + for dim in shape { + if *dim < 0 { + return 0; + } + size *= *dim as usize; + } + size +} + #[cfg(test)] mod tests { use std::sync::Arc; use ndarray::{ArcArray1, Array1, CowArray}; - use crate::{Allocator, DynTensor, TensorElementType, Value, ValueType}; + use crate::{Tensor, TensorElementType, ValueType}; #[test] #[cfg(feature = "ndarray")] fn test_tensor_value() -> crate::Result<()> { let v: Vec = vec![1., 2., 3., 4., 5.]; - let value = Value::from_array(Array1::from_vec(v.clone()))?; + let value = Tensor::from_array(Array1::from_vec(v.clone()))?; assert!(value.is_tensor()?); assert_eq!(value.dtype()?.tensor_type(), Some(TensorElementType::Float32)); assert_eq!( @@ -163,17 +311,17 @@ mod tests { let arc1 = ArcArray1::from_vec(v.clone()); let mut arc2 = ArcArray1::clone(&arc1); - let value = Value::from_array(&mut 
arc2)?; + let value = Tensor::from_array(&mut arc2)?; drop((arc1, arc2)); assert_eq!(value.extract_raw_tensor().1, &v); let cow = CowArray::from(Array1::from_vec(v.clone())); - let value = Value::from_array(&cow)?; + let value = Tensor::from_array(&cow)?; assert_eq!(value.extract_raw_tensor().1, &v); let owned = Array1::from_vec(v.clone()); - let value = Value::from_array(owned.view())?; + let value = Tensor::from_array(owned.view())?; drop(owned); assert_eq!(value.extract_raw_tensor().1, &v); @@ -186,7 +334,7 @@ mod tests { let arc = Arc::new(v.clone().into_boxed_slice()); let shape = vec![v.len() as i64]; - let value = Value::from_array((shape, Arc::clone(&arc)))?; + let value = Tensor::from_array((shape, Arc::clone(&arc)))?; drop(arc); assert_eq!(value.try_extract_raw_tensor::()?.1, &v); @@ -196,10 +344,9 @@ mod tests { #[test] #[cfg(feature = "ndarray")] fn test_string_tensor_ndarray() -> crate::Result<()> { - let allocator = Allocator::default(); let v = Array1::from_vec(vec!["hello world".to_string(), "こんにちは世界".to_string()]); - let value = DynTensor::from_string_array(&allocator, v.view())?; + let value = Tensor::from_string_array(v.view())?; let extracted = value.try_extract_string_tensor()?; assert_eq!(extracted, v.into_dyn()); @@ -208,10 +355,9 @@ mod tests { #[test] fn test_string_tensor_raw() -> crate::Result<()> { - let allocator = Allocator::default(); let v = vec!["hello world".to_string(), "こんにちは世界".to_string()]; - let value = DynTensor::from_string_array(&allocator, (vec![v.len() as i64], v.clone().into_boxed_slice()))?; + let value = Tensor::from_string_array((vec![v.len() as i64], v.clone().into_boxed_slice()))?; let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?; assert_eq!(extracted_shape, [v.len() as i64]); assert_eq!(extracted_view, v); @@ -224,10 +370,10 @@ mod tests { let v: Vec = vec![1., 2., 3., 4., 5.]; let shape = [v.len()]; - let value_arc_box = Value::from_array((shape, 
Arc::new(v.clone().into_boxed_slice())))?; - let value_box = Value::from_array((shape, v.clone().into_boxed_slice()))?; - let value_vec = Value::from_array((shape, v.clone()))?; - let value_slice = Value::from_array((shape, &v[..]))?; + let value_arc_box = Tensor::from_array((shape, Arc::new(v.clone().into_boxed_slice())))?; + let value_box = Tensor::from_array((shape, v.clone().into_boxed_slice()))?; + let value_vec = Tensor::from_array((shape, v.clone()))?; + let value_slice = Tensor::from_array((shape, &v[..]))?; assert_eq!(value_arc_box.extract_raw_tensor().1, &v); assert_eq!(value_box.extract_raw_tensor().1, &v); diff --git a/src/value/mod.rs b/src/value/mod.rs index 868403f1..cf506009 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -166,6 +166,14 @@ pub(crate) enum ValueInner { } } +impl ValueInner { + pub(crate) fn ptr(&self) -> *mut ort_sys::OrtValue { + match self { + ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr() + } + } +} + /// A temporary version of a [`Value`] with a lifetime specifier. #[derive(Debug)] pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> { @@ -277,8 +285,8 @@ impl<'v, Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'v, Type> { /// - [`Tensor::extract_tensor`], [`Tensor::extract_raw_tensor`] #[derive(Debug)] pub struct Value { - inner: ValueInner, - _markers: PhantomData + pub(crate) inner: Arc, + pub(crate) _markers: PhantomData } /// A dynamic value, which could be a [`Tensor`], [`Sequence`], or [`Map`]. @@ -291,11 +299,15 @@ pub type DynValue = Value; /// /// For example, [`Tensor::try_extract_tensor`] can only be used on [`Value`]s with the [`TensorValueTypeMarker`] (which /// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s. -pub trait ValueTypeMarker: Debug {} +pub trait ValueTypeMarker: Debug { + crate::private_trait!(); +} /// Represents a type that a [`DynValue`] can be downcast to. 
pub trait DowncastableTarget: ValueTypeMarker { fn can_downcast(dtype: &ValueType) -> bool; + + crate::private_trait!(); } // this implementation is used in case we want to extract `DynValue`s from a [`Sequence`]; see `try_extract_sequence` @@ -303,15 +315,25 @@ impl DowncastableTarget for DynValueTypeMarker { fn can_downcast(_: &ValueType) -> bool { true } + + crate::private_impl!(); } /// The dynamic type marker, used for values which can be of any type. #[derive(Debug)] pub struct DynValueTypeMarker; -impl ValueTypeMarker for DynValueTypeMarker {} -impl MapValueTypeMarker for DynValueTypeMarker {} -impl SequenceValueTypeMarker for DynValueTypeMarker {} -impl TensorValueTypeMarker for DynValueTypeMarker {} +impl ValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl MapValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl SequenceValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} +impl TensorValueTypeMarker for DynValueTypeMarker { + crate::private_impl!(); +} unsafe impl Send for Value {} @@ -350,7 +372,7 @@ impl Value { /// /// If the value belongs to a session (i.e. if it is returned from [`crate::Session::run`] or /// [`crate::IoBinding::run`]), you must provide the [`SharedSessionInner`] (acquired from - /// [`crate::Session::inner`]). This ensures the session is not dropped until the value is. + /// [`crate::Session::inner`]). This ensures the session is not dropped until any values owned by it is. 
/// /// # Safety /// @@ -359,7 +381,7 @@ impl Value { #[must_use] pub unsafe fn from_ptr(ptr: NonNull, session: Option>) -> Value { Value { - inner: ValueInner::CppOwned { ptr, drop: true, _session: session }, + inner: Arc::new(ValueInner::CppOwned { ptr, drop: true, _session: session }), _markers: PhantomData } } @@ -369,16 +391,14 @@ impl Value { #[must_use] pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull, session: Option>) -> Value { Value { - inner: ValueInner::CppOwned { ptr, drop: false, _session: session }, + inner: Arc::new(ValueInner::CppOwned { ptr, drop: false, _session: session }), _markers: PhantomData } } /// Returns the underlying [`ort_sys::OrtValue`] pointer. pub fn ptr(&self) -> *mut ort_sys::OrtValue { - match &self.inner { - ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr() - } + self.inner.ptr() } /// Create a view of this value's data. @@ -386,7 +406,7 @@ impl Value { ValueRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -396,7 +416,7 @@ impl Value { ValueRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) }) } @@ -442,7 +462,7 @@ impl Value { Ok(ValueRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. 
} = &*self.inner { _session.clone() } else { None } ) })) } else { @@ -450,7 +470,7 @@ impl Value { } } - /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed + /// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed /// mutable-reference variant, like [`TensorRefMut`]. #[inline] pub fn downcast_mut(&mut self) -> Result> { @@ -459,7 +479,7 @@ impl Value { Ok(ValueRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), - if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None } ) })) } else { @@ -468,17 +488,17 @@ impl Value { } } -impl Drop for Value { +impl Drop for ValueInner { fn drop(&mut self) { let ptr = self.ptr(); tracing::trace!( "dropping {} value at {ptr:p}", - match &self.inner { + match self { ValueInner::RustOwned { .. } => "rust-owned", ValueInner::CppOwned { .. } => "cpp-owned" } ); - if !matches!(&self.inner, ValueInner::CppOwned { drop: false, .. }) { + if !matches!(self, ValueInner::CppOwned { drop: false, .. }) { ortsys![unsafe ReleaseValue(ptr)]; } } diff --git a/src/wasm.rs b/src/wasm.rs index 51cf82f3..235a2198 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -1,6 +1,6 @@ //! Utilities for using `ort` in WebAssembly. //! -//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs: +//! You **must** call `ort::wasm::initialize()` before using any `ort` APIs in WASM: //! ``` //! # use ort::Session; //! 
# static MODEL_BYTES: &[u8] = include_bytes!("../tests/data/upsample.ort"); @@ -223,12 +223,12 @@ mod emscripten_shims { #[no_mangle] #[export_name = "_initialize"] pub fn initialize() { - // No idea what the hell this does, but the presence of an `_initialize` function prevents the linker from calling - // `__wasm_call_ctors` at the top of every function - including the functions `wasm-bindgen` interprets to generate - // JS glue code. The `__wasm_call_ctors` call was calling complex functions that the interpreter isn't equipped to - // handle, which was preventing wbg from outputting anything. I don't know what specific constructors this is calling, - // and most basic ONNX Runtime APIs *do* work without calling this, but we encourage the user to perform this - // initialization at program start anyways to be safe. + // The presence of an `_initialize` function prevents the linker from calling `__wasm_call_ctors` at the top of every + // function - including the functions `wasm-bindgen` interprets to generate JS glue code. `__wasm_call_ctors` calls + // complex functions that wbg's interpreter isn't equipped to handle, which was preventing wbg from outputting + // anything. + // I'm not entirely sure what `__wasm_call_ctors` is initializing, but it seems to have something to do with C++ + // vtables, and it's crucial for proper operation. 
extern "C" { fn __wasm_call_ctors(); } diff --git a/tests/vectorizer.rs b/tests/vectorizer.rs index bef1a572..f3af20a7 100644 --- a/tests/vectorizer.rs +++ b/tests/vectorizer.rs @@ -3,7 +3,7 @@ use std::path::Path; use ndarray::{ArrayD, IxDyn}; -use ort::{inputs, DynTensor, GraphOptimizationLevel, Session}; +use ort::{inputs, GraphOptimizationLevel, Session, Tensor}; use test_log::test; #[test] @@ -22,7 +22,7 @@ fn vectorizer() -> ort::Result<()> { let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap()); // Just one input - let input_tensor_values = inputs![DynTensor::from_string_array(session.allocator(), &array)?]?; + let input_tensor_values = inputs![Tensor::from_string_array(&array)?]?; // Perform the inference let outputs = session.run(input_tensor_values)?; From 882f657599b7703ead97f56aacfd6f7ed2cab244 Mon Sep 17 00:00:00 2001 From: cagnolone Date: Fri, 21 Jun 2024 22:39:03 +0200 Subject: [PATCH 27/49] fix: bundle libonnxruntime in rlib/staticlib builds (#215) Fixes #214. --- ort-sys/build.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ort-sys/build.rs b/ort-sys/build.rs index ccbd5684..da15be08 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -409,7 +409,15 @@ fn real_main(link: bool) { if link { if needs_link { - println!("cargo:rustc-link-lib=onnxruntime"); + let target_os = env::var("CARGO_CFG_TARGET_OS").unwrap(); + let static_lib_file_name = if target_os.contains("windows") { "onnxruntime.lib" } else { "libonnxruntime.a" }; + + let static_lib_path = lib_dir.join(static_lib_file_name); + if static_lib_path.exists() { + println!("cargo:rustc-link-lib=static=onnxruntime"); + } else { + println!("cargo:rustc-link-lib=onnxruntime"); + } println!("cargo:rustc-link-search=native={}", lib_dir.display()); } From 668260c54f318e922dd512f4dbe44986db8a748d Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Fri, 21 Jun 2024 15:44:32 -0500 Subject: [PATCH 28/49] fix: ARM build This whole thing where usize != size_t specifically on aarch64 is so bad. --- src/lib.rs | 2 +- src/value/impl_tensor/extract.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a071b867..08e5321c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,7 +65,7 @@ pub use self::session::{ #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use self::tensor::ArrayExtensions; -pub use self::tensor::{IntoTensorElementType, Utf8Data, PrimitiveTensorElementType, TensorElementType}; +pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index 22a52edf..bec573f4 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -310,7 +310,7 @@ impl Value { // length calculations easy let mut offsets = vec![0; (len + 1) as _]; - ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; + ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent]; // final offset = overall length so that per-string length calculations work for the last string debug_assert_eq!(0, offsets[len]); @@ -380,7 +380,7 @@ impl Value { // length calculations easy let mut offsets = vec![0; (len + 1) as _]; - ortsys![unsafe GetStringTensorContent(self.ptr(), 
string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len) -> Error::GetStringTensorContent]; + ortsys![unsafe GetStringTensorContent(self.ptr(), string_contents.as_mut_ptr().cast(), total_length, offsets.as_mut_ptr(), len as _) -> Error::GetStringTensorContent]; // final offset = overall length so that per-string length calculations work for the last string debug_assert_eq!(0, offsets[len]); From 860e4496ad99279aab3d5d2755a4d0984f28db76 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 25 Jun 2024 04:22:04 +0900 Subject: [PATCH 29/49] fix: `i686-pc-windows-msvc` build (#218) --- src/operator/bound.rs | 252 ++++++++++++++++++++++++------------------ 1 file changed, 142 insertions(+), 110 deletions(-) diff --git a/src/operator/bound.rs b/src/operator/bound.rs index 58219c47..452736fa 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -9,7 +9,7 @@ use super::{ kernel::{Kernel, KernelAttributes, KernelContext}, DummyOperator, Operator }; -use crate::error::IntoStatus; +use crate::{error::IntoStatus, extern_system_fn}; #[repr(C)] // <- important! 
a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later pub(crate) struct BoundOperator { @@ -65,115 +65,147 @@ impl BoundOperator { &*op.cast() } - pub(crate) unsafe extern "C" fn CreateKernelV2( - _: *const ort_sys::OrtCustomOp, - _: *const ort_sys::OrtApi, - info: *const ort_sys::OrtKernelInfo, - kernel_ptr: *mut *mut ort_sys::c_void - ) -> *mut ort_sys::OrtStatus { - let kernel = match O::create_kernel(&KernelAttributes::new(info)) { - Ok(kernel) => kernel, - e => return e.into_status() - }; - *kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast(); - Ok(()).into_status() - } - - pub(crate) unsafe extern "C" fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus { - let context = KernelContext::new(context); - O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::() }, &context).into_status() - } - - pub(crate) unsafe extern "C" fn KernelDestroy(op_kernel: *mut ort_sys::c_void) { - drop(Box::from_raw(op_kernel.cast::())); - } - - pub(crate) unsafe extern "C" fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { - let safe = Self::safe(op); - safe.name.as_ptr() - } - pub(crate) unsafe extern "C" fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { - let safe = Self::safe(op); - safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null) - } - - pub(crate) unsafe extern "C" fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::min_version() - } - pub(crate) unsafe extern "C" fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::max_version() - } - - pub(crate) unsafe extern "C" fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType { - O::inputs()[index as usize].memory_type.into() - } - pub(crate) unsafe extern "C" fn GetInputCharacteristic( - _: *const ort_sys::OrtCustomOp, - index: 
ort_sys::size_t - ) -> ort_sys::OrtCustomOpInputOutputCharacteristic { - O::inputs()[index as usize].characteristic.into() - } - pub(crate) unsafe extern "C" fn GetOutputCharacteristic( - _: *const ort_sys::OrtCustomOp, - index: ort_sys::size_t - ) -> ort_sys::OrtCustomOpInputOutputCharacteristic { - O::outputs()[index as usize].characteristic.into() - } - pub(crate) unsafe extern "C" fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t { - O::inputs().len() as _ - } - pub(crate) unsafe extern "C" fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t { - O::outputs().len() as _ - } - pub(crate) unsafe extern "C" fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType { - O::inputs()[index as usize] - .r#type - .map(|c| c.into()) - .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) - } - pub(crate) unsafe extern "C" fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType { - O::outputs()[index as usize] - .r#type - .map(|c| c.into()) - .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) - } - pub(crate) unsafe extern "C" fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::inputs() - .into_iter() - .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) - .and_then(|c| c.variadic_min_arity) - .unwrap_or(1) - .try_into() - .expect("input minimum arity overflows i32") - } - pub(crate) unsafe extern "C" fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::inputs() - .into_iter() - .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) - .and_then(|c| c.variadic_homogeneity) - .unwrap_or(false) - .into() - } - pub(crate) unsafe extern "C" fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::outputs() - .into_iter() - .find(|c| 
c.characteristic == InputOutputCharacteristic::Variadic) - .and_then(|c| c.variadic_min_arity) - .unwrap_or(1) - .try_into() - .expect("output minimum arity overflows i32") - } - pub(crate) unsafe extern "C" fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::outputs() - .into_iter() - .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) - .and_then(|c| c.variadic_homogeneity) - .unwrap_or(false) - .into() - } - - pub(crate) unsafe extern "C" fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { - O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status() + extern_system_fn! { + pub(crate) unsafe fn CreateKernelV2( + _: *const ort_sys::OrtCustomOp, + _: *const ort_sys::OrtApi, + info: *const ort_sys::OrtKernelInfo, + kernel_ptr: *mut *mut ort_sys::c_void + ) -> *mut ort_sys::OrtStatus { + let kernel = match O::create_kernel(&KernelAttributes::new(info)) { + Ok(kernel) => kernel, + e => return e.into_status() + }; + *kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast(); + Ok(()).into_status() + } + } + + extern_system_fn! { + pub(crate) unsafe fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus { + let context = KernelContext::new(context); + O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::() }, &context).into_status() + } + } + + extern_system_fn! { + pub(crate) unsafe fn KernelDestroy(op_kernel: *mut ort_sys::c_void) { + drop(Box::from_raw(op_kernel.cast::())); + } + } + + extern_system_fn! { + pub(crate) unsafe fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { + let safe = Self::safe(op); + safe.name.as_ptr() + } + } + extern_system_fn! 
{ + pub(crate) unsafe fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { + let safe = Self::safe(op); + safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null) + } + } + + extern_system_fn! { + pub(crate) unsafe fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::min_version() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::max_version() + } + } + + extern_system_fn! { + pub(crate) unsafe fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType { + O::inputs()[index as usize].memory_type.into() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetInputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic { + O::inputs()[index as usize].characteristic.into() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetOutputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic { + O::outputs()[index as usize].characteristic.into() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t { + O::inputs().len() as _ + } + } + extern_system_fn! { + pub(crate) unsafe fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t { + O::outputs().len() as _ + } + } + extern_system_fn! { + pub(crate) unsafe fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType { + O::inputs()[index as usize] + .r#type + .map(|c| c.into()) + .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) + } + } + extern_system_fn! 
{ + pub(crate) unsafe fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType { + O::outputs()[index as usize] + .r#type + .map(|c| c.into()) + .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) + } + } + extern_system_fn! { + pub(crate) unsafe fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::inputs() + .into_iter() + .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) + .and_then(|c| c.variadic_min_arity) + .unwrap_or(1) + .try_into() + .expect("input minimum arity overflows i32") + } + } + extern_system_fn! { + pub(crate) unsafe fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::inputs() + .into_iter() + .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) + .and_then(|c| c.variadic_homogeneity) + .unwrap_or(false) + .into() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::outputs() + .into_iter() + .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) + .and_then(|c| c.variadic_min_arity) + .unwrap_or(1) + .try_into() + .expect("output minimum arity overflows i32") + } + } + extern_system_fn! { + pub(crate) unsafe fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::outputs() + .into_iter() + .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) + .and_then(|c| c.variadic_homogeneity) + .unwrap_or(false) + .into() + } + } + + extern_system_fn! { + pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { + O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status() + } } } From d59ac432418d0dbc68078d5d8b988e87e002c0b6 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sun, 30 Jun 2024 11:49:08 -0500 Subject: [PATCH 30/49] refactor: take `RunOptions` by reference, add `run_async_with_options` --- src/session/async.rs | 41 ++++++++-- src/session/mod.rs | 152 +++++++++++-------------------------- src/session/run_options.rs | 98 ++++++++++++++++++++++++ 3 files changed, 179 insertions(+), 112 deletions(-) create mode 100644 src/session/run_options.rs diff --git a/src/session/async.rs b/src/session/async.rs index 104618b1..ab162b12 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -3,6 +3,7 @@ use std::{ ffi::{c_char, CString}, future::Future, mem::MaybeUninit, + ops::Deref, pin::Pin, ptr::NonNull, sync::{ @@ -85,14 +86,42 @@ impl<'s> Drop for InferenceFutInner<'s> { unsafe impl<'s> Send for InferenceFutInner<'s> {} unsafe impl<'s> Sync for InferenceFutInner<'s> {} -pub struct InferenceFut<'s> { +pub enum RunOptionsRef<'r> { + Arc(Arc), + Ref(&'r RunOptions) +} + +impl<'r> From<&Arc> for RunOptionsRef<'r> { + fn from(value: &Arc) -> Self { + Self::Arc(Arc::clone(value)) + } +} + +impl<'r> From<&'r RunOptions> for RunOptionsRef<'r> { + fn from(value: &'r RunOptions) -> Self { + Self::Ref(value) + } +} + +impl<'r> Deref for RunOptionsRef<'r> { + type Target = RunOptions; + + fn deref(&self) -> &Self::Target { + match self { + Self::Arc(r) => r, + Self::Ref(r) => r + } + } +} + +pub struct InferenceFut<'s, 'r> { inner: Arc>, - run_options: Arc, + run_options: RunOptionsRef<'r>, did_receive: bool } -impl<'s> InferenceFut<'s> { - pub(crate) fn new(inner: Arc>, run_options: Arc) -> Self { +impl<'s, 'r> InferenceFut<'s, 'r> { + pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r>) -> Self { Self { inner, run_options, @@ -101,7 +130,7 @@ impl<'s> InferenceFut<'s> { } } -impl<'s> Future for InferenceFut<'s> { +impl<'s, 'r> Future for InferenceFut<'s, 'r> { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -122,7 +151,7 @@ impl<'s> Future for InferenceFut<'s> { } } -impl<'s> 
Drop for InferenceFut<'s> { +impl<'s, 'r> Drop for InferenceFut<'s, 'r> { fn drop(&mut self) { if !self.did_receive && self.inner.close() { let _ = self.run_options.terminate(); diff --git a/src/session/mod.rs b/src/session/mod.rs index 7e91a908..33c20d9f 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -2,6 +2,8 @@ use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc}; +use r#async::RunOptionsRef; + use super::{ char_p_to_string, environment::Environment, @@ -18,12 +20,14 @@ mod r#async; pub(crate) mod builder; pub(crate) mod input; pub(crate) mod output; +mod run_options; use self::r#async::{AsyncInferenceContext, InferenceFutInner}; pub use self::{ r#async::InferenceFut, builder::{GraphOptimizationLevel, SessionBuilder}, input::{SessionInputValue, SessionInputs}, - output::SessionOutputs + output::SessionOutputs, + run_options::RunOptions }; /// Holds onto an [`ort_sys::OrtSession`] pointer and its associated allocator. @@ -112,101 +116,6 @@ pub struct Output { pub output_type: ValueType } -/// A structure which can be passed to [`Session::run_with_options`] to allow terminating/unterminating a session -/// inference run from a different thread. -#[derive(Debug)] -pub struct RunOptions { - pub(crate) run_options_ptr: NonNull -} - -// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 -unsafe impl Send for RunOptions {} -unsafe impl Sync for RunOptions {} - -impl RunOptions { - /// Creates a new [`RunOptions`] struct. - pub fn new() -> Result { - let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut(); - ortsys![unsafe CreateRunOptions(&mut run_options_ptr) -> Error::CreateRunOptions; nonNull(run_options_ptr)]; - Ok(Self { - run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) } - }) - } - - /// Sets a tag to identify this run in logs. 
- pub fn set_tag(&mut self, tag: impl AsRef) -> Result<()> { - let tag = CString::new(tag.as_ref())?; - ortsys![unsafe RunOptionsSetRunTag(self.run_options_ptr.as_ptr(), tag.as_ptr()) -> Error::RunOptionsSetTag]; - Ok(()) - } - - /// Sets the termination flag for the runs associated with this [`RunOptions`]. - /// - /// This function returns immediately (it does not wait for the session run to terminate). The run will terminate as - /// soon as it is able to. - /// - /// ```no_run - /// # // no_run because upsample.onnx is too simple of a model for the termination signal to be reliable enough - /// # use std::sync::Arc; - /// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType}; - /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; - /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; - /// let run_options = Arc::new(RunOptions::new()?); - /// - /// let run_options_ = Arc::clone(&run_options); - /// std::thread::spawn(move || { - /// let _ = run_options_.terminate(); - /// }); - /// - /// let res = session.run_with_options(ort::inputs![input]?, run_options); - /// // upon termination, the session will return an `Error::SessionRun` error.` - /// assert_eq!( - /// &res.unwrap_err().to_string(), - /// "Failed to run inference on model: Exiting due to terminate flag being set to true." - /// ); - /// # Ok(()) - /// # } - /// ``` - pub fn terminate(&self) -> Result<()> { - ortsys![unsafe RunOptionsSetTerminate(self.run_options_ptr.as_ptr()) -> Error::RunOptionsSetTerminate]; - Ok(()) - } - - /// Resets the termination flag for the runs associated with [`RunOptions`]. 
- /// - /// ```no_run - /// # use std::sync::Arc; - /// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType}; - /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; - /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; - /// let run_options = Arc::new(RunOptions::new()?); - /// - /// let run_options_ = Arc::clone(&run_options); - /// std::thread::spawn(move || { - /// let _ = run_options_.terminate(); - /// // ...oops, didn't mean to do that - /// let _ = run_options_.unterminate(); - /// }); - /// - /// let res = session.run_with_options(ort::inputs![input]?, run_options); - /// assert!(res.is_ok()); - /// # Ok(()) - /// # } - /// ``` - pub fn unterminate(&self) -> Result<()> { - ortsys![unsafe RunOptionsUnsetTerminate(self.run_options_ptr.as_ptr()) -> Error::RunOptionsUnsetTerminate]; - Ok(()) - } -} - -impl Drop for RunOptions { - fn drop(&mut self) { - ortsys![unsafe ReleaseRunOptions(self.run_options_ptr.as_ptr())]; - } -} - impl Session { /// Creates a new [`SessionBuilder`]. 
pub fn builder() -> Result { @@ -283,7 +192,7 @@ impl Session { /// let _ = run_options_.terminate(); /// }); /// - /// let res = session.run_with_options(ort::inputs![input]?, run_options); + /// let res = session.run_with_options(ort::inputs![input]?, &*run_options); /// // upon termination, the session will return an `Error::SessionRun` error.` /// assert_eq!( /// &res.unwrap_err().to_string(), @@ -295,7 +204,7 @@ impl Session { pub fn run_with_options<'s, 'i, 'v: 'i, const N: usize>( &'s self, input_values: impl Into>, - run_options: Arc + run_options: &RunOptions ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { @@ -314,7 +223,7 @@ impl Session { &self, input_names: &[&str], input_values: impl Iterator>, - run_options: Option> + run_options: Option<&RunOptions> ) -> Result> { let input_names_ptr: Vec<*const c_char> = input_names .iter() @@ -393,22 +302,53 @@ impl Session { /// # Ok(()) /// # }) } /// ``` - pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>(&'s self, input_values: impl Into> + 'static) -> Result> { + pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>( + &'s self, + input_values: impl Into> + 'static + ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), SessionInputs::ValueArray(input_values) => { - self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::>(), input_values.into_iter()) + self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::>(), input_values.into_iter(), None) } SessionInputs::ValueMap(input_values) => { - self.run_inner_async(&input_values.iter().map(|(k, _)| k.to_string()).collect::>(), input_values.into_iter().map(|(_, v)| v)) + self.run_inner_async(&input_values.iter().map(|(k, _)| k.to_string()).collect::>(), input_values.into_iter().map(|(_, v)| v), None) } } } - fn run_inner_async<'s, 'v: 's>(&'s self, input_names: &[String], input_values: 
impl Iterator>) -> Result> { - // create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial - // (performance-wise) for routines involving `tokio::select!` or timeouts - let run_options = Arc::new(RunOptions::new()?); + /// Asynchronously run input data through the ONNX graph, performing inference, with the given [`RunOptions`]. + /// See [`Session::run_with_options`] and [`Session::run_async`] for more details. + pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, const N: usize>( + &'s self, + input_values: impl Into> + 'static, + run_options: &'r RunOptions + ) -> Result> { + match input_values.into() { + SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), + SessionInputs::ValueArray(input_values) => { + self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::>(), input_values.into_iter(), Some(run_options)) + } + SessionInputs::ValueMap(input_values) => self.run_inner_async( + &input_values.iter().map(|(k, _)| k.to_string()).collect::>(), + input_values.into_iter().map(|(_, v)| v), + Some(run_options) + ) + } + } + + fn run_inner_async<'s, 'v: 's, 'r>( + &'s self, + input_names: &[String], + input_values: impl Iterator>, + run_options: Option<&'r RunOptions> + ) -> Result> { + let run_options = match run_options { + Some(r) => RunOptionsRef::Ref(r), + // create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial + // (performance-wise) for routines involving `tokio::select!` or timeouts + None => RunOptionsRef::Arc(Arc::new(RunOptions::new()?)) + }; let input_name_ptrs: Vec<*const c_char> = input_names .iter() diff --git a/src/session/run_options.rs b/src/session/run_options.rs new file mode 100644 index 00000000..92e85e8e --- /dev/null +++ b/src/session/run_options.rs @@ -0,0 +1,98 @@ +use std::{ffi::CString, ptr::NonNull}; + +use crate::{ortsys, Error, Result}; + +/// A structure which can be passed to 
[`crate::Session::run_with_options`] to allow terminating/unterminating a session +/// inference run from a different thread. +#[derive(Debug)] +pub struct RunOptions { + pub(crate) run_options_ptr: NonNull +} + +// https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 +unsafe impl Send for RunOptions {} +unsafe impl Sync for RunOptions {} + +impl RunOptions { + /// Creates a new [`RunOptions`] struct. + pub fn new() -> Result { + let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut(); + ortsys![unsafe CreateRunOptions(&mut run_options_ptr) -> Error::CreateRunOptions; nonNull(run_options_ptr)]; + Ok(Self { + run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) } + }) + } + + /// Sets a tag to identify this run in logs. + pub fn set_tag(&mut self, tag: impl AsRef) -> Result<()> { + let tag = CString::new(tag.as_ref())?; + ortsys![unsafe RunOptionsSetRunTag(self.run_options_ptr.as_ptr(), tag.as_ptr()) -> Error::RunOptionsSetTag]; + Ok(()) + } + + /// Sets the termination flag for the runs associated with this [`RunOptions`]. + /// + /// This function returns immediately (it does not wait for the session run to terminate). The run will terminate as + /// soon as it is able to. 
+ /// + /// ```no_run + /// # // no_run because upsample.onnx is too simple of a model for the termination signal to be reliable enough + /// # use std::sync::Arc; + /// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType}; + /// # fn main() -> ort::Result<()> { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; + /// let run_options = Arc::new(RunOptions::new()?); + /// + /// let run_options_ = Arc::clone(&run_options); + /// std::thread::spawn(move || { + /// let _ = run_options_.terminate(); + /// }); + /// + /// let res = session.run_with_options(ort::inputs![input]?, &*run_options); + /// // upon termination, the session will return an `Error::SessionRun` error.` + /// assert_eq!( + /// &res.unwrap_err().to_string(), + /// "Failed to run inference on model: Exiting due to terminate flag being set to true." + /// ); + /// # Ok(()) + /// # } + /// ``` + pub fn terminate(&self) -> Result<()> { + ortsys![unsafe RunOptionsSetTerminate(self.run_options_ptr.as_ptr()) -> Error::RunOptionsSetTerminate]; + Ok(()) + } + + /// Resets the termination flag for the runs associated with [`RunOptions`]. 
+ /// + /// ```no_run + /// # use std::sync::Arc; + /// # use ort::{Session, RunOptions, Value, ValueType, TensorElementType}; + /// # fn main() -> ort::Result<()> { + /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// # let input = Value::from_array(ndarray::Array4::::zeros((1, 64, 64, 3)))?; + /// let run_options = Arc::new(RunOptions::new()?); + /// + /// let run_options_ = Arc::clone(&run_options); + /// std::thread::spawn(move || { + /// let _ = run_options_.terminate(); + /// // ...oops, didn't mean to do that + /// let _ = run_options_.unterminate(); + /// }); + /// + /// let res = session.run_with_options(ort::inputs![input]?, &*run_options); + /// assert!(res.is_ok()); + /// # Ok(()) + /// # } + /// ``` + pub fn unterminate(&self) -> Result<()> { + ortsys![unsafe RunOptionsUnsetTerminate(self.run_options_ptr.as_ptr()) -> Error::RunOptionsUnsetTerminate]; + Ok(()) + } +} + +impl Drop for RunOptions { + fn drop(&mut self) { + ortsys![unsafe ReleaseRunOptions(self.run_options_ptr.as_ptr())]; + } +} From 8d1f6b6566a67d33c82d33157dfbab75397f73ef Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 12:25:53 -0500 Subject: [PATCH 31/49] feat: `OutputSelector` --- src/io_binding.rs | 6 +-- src/lib.rs | 4 +- src/session/async.rs | 28 +++++++------- src/session/mod.rs | 52 +++++++++++++++++--------- src/session/output.rs | 28 +++++++------- src/session/run_options.rs | 75 ++++++++++++++++++++++++++++++++++++-- 6 files changed, 138 insertions(+), 55 deletions(-) diff --git a/src/io_binding.rs b/src/io_binding.rs index 36b11e61..0467cbb7 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -172,16 +172,16 @@ impl<'s> IoBinding<'s> { } /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. 
- pub fn run(&mut self) -> Result> { + pub fn run(&mut self) -> Result> { self.run_inner(None) } /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. - pub fn run_with_options(&mut self, run_options: Arc) -> Result> { + pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result> { self.run_inner(Some(run_options)) } - fn run_inner(&mut self, run_options: Option>) -> Result> { + fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result> { let run_options_ptr = if let Some(run_options) = run_options { run_options.run_options_ptr.as_ptr() } else { diff --git a/src/lib.rs b/src/lib.rs index 08e5321c..b8c9176f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,8 +59,8 @@ pub use self::operator::{ InferShapeFn, Operator, OperatorDomain }; pub use self::session::{ - GraphOptimizationLevel, InMemorySession, Input, Output, RunOptions, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, - SharedSessionInner + GraphOptimizationLevel, InMemorySession, Input, Output, OutputSelector, RunOptions, Session, SessionBuilder, SessionInputValue, SessionInputs, + SessionOutputs, SharedSessionInner }; #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] diff --git a/src/session/async.rs b/src/session/async.rs index ab162b12..4bc338f7 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -27,13 +27,13 @@ const VALUE_PRESENT: usize = 1 << 0; const CHANNEL_CLOSED: usize = 1 << 1; #[derive(Debug)] -pub(crate) struct InferenceFutInner<'s> { +pub(crate) struct InferenceFutInner<'r, 's> { presence: AtomicUsize, - value: UnsafeCell>>>, + value: UnsafeCell>>>, waker: Mutex> } -impl<'s> InferenceFutInner<'s> { +impl<'r, 's> InferenceFutInner<'r, 's> { pub(crate) fn new() -> Self { InferenceFutInner { presence: AtomicUsize::new(0), @@ -42,7 +42,7 @@ impl<'s> InferenceFutInner<'s> { } } - pub(crate) fn try_take(&self) -> InnerValue>> { + pub(crate) fn try_take(&self) -> 
InnerValue>> { let state_snapshot = self.presence.fetch_and(!VALUE_PRESENT, Ordering::Acquire); if state_snapshot & VALUE_PRESENT == 0 { if self.presence.load(Ordering::Acquire) & CHANNEL_CLOSED != 0 { @@ -55,7 +55,7 @@ impl<'s> InferenceFutInner<'s> { } } - pub(crate) fn emplace_value(&self, value: Result>) { + pub(crate) fn emplace_value(&self, value: Result>) { unsafe { (*self.value.get()).write(value) }; self.presence.fetch_or(VALUE_PRESENT, Ordering::Release); } @@ -75,7 +75,7 @@ impl<'s> InferenceFutInner<'s> { } } -impl<'s> Drop for InferenceFutInner<'s> { +impl<'r, 's> Drop for InferenceFutInner<'r, 's> { fn drop(&mut self) { if self.presence.load(Ordering::Acquire) & VALUE_PRESENT != 0 { unsafe { (*self.value.get()).assume_init_drop() }; @@ -83,8 +83,8 @@ impl<'s> Drop for InferenceFutInner<'s> { } } -unsafe impl<'s> Send for InferenceFutInner<'s> {} -unsafe impl<'s> Sync for InferenceFutInner<'s> {} +unsafe impl<'r, 's> Send for InferenceFutInner<'r, 's> {} +unsafe impl<'r, 's> Sync for InferenceFutInner<'r, 's> {} pub enum RunOptionsRef<'r> { Arc(Arc), @@ -115,13 +115,13 @@ impl<'r> Deref for RunOptionsRef<'r> { } pub struct InferenceFut<'s, 'r> { - inner: Arc>, + inner: Arc>, run_options: RunOptionsRef<'r>, did_receive: bool } impl<'s, 'r> InferenceFut<'s, 'r> { - pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r>) -> Self { + pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r>) -> Self { Self { inner, run_options, @@ -131,7 +131,7 @@ impl<'s, 'r> InferenceFut<'s, 'r> { } impl<'s, 'r> Future for InferenceFut<'s, 'r> { - type Output = Result>; + type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = Pin::into_inner(self); @@ -160,8 +160,8 @@ impl<'s, 'r> Drop for InferenceFut<'s, 'r> { } } -pub(crate) struct AsyncInferenceContext<'s> { - pub(crate) inner: Arc>, +pub(crate) struct AsyncInferenceContext<'r, 's> { + pub(crate) inner: Arc>, pub(crate) _input_values: Vec>, pub(crate) 
input_ort_values: Vec<*const ort_sys::OrtValue>, pub(crate) input_name_ptrs: Vec<*const c_char>, @@ -173,7 +173,7 @@ pub(crate) struct AsyncInferenceContext<'s> { crate::extern_system_fn! { pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: ort_sys::size_t, status: *mut OrtStatus) { - let ctx = unsafe { Box::from_raw(user_data.cast::>()) }; + let ctx = unsafe { Box::from_raw(user_data.cast::>()) }; // Reconvert name ptrs to CString so drop impl is called and memory is freed drop( diff --git a/src/session/mod.rs b/src/session/mod.rs index 33c20d9f..865e09ba 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -27,7 +27,7 @@ pub use self::{ builder::{GraphOptimizationLevel, SessionBuilder}, input::{SessionInputValue, SessionInputs}, output::SessionOutputs, - run_options::RunOptions + run_options::{OutputSelector, RunOptions} }; /// Holds onto an [`ort_sys::OrtSession`] pointer and its associated allocator. @@ -161,7 +161,7 @@ impl Session { /// # Ok(()) /// # } /// ``` - pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into>) -> Result> { + pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into>) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) @@ -201,11 +201,11 @@ impl Session { /// # Ok(()) /// # } /// ``` - pub fn run_with_options<'s, 'i, 'v: 'i, const N: usize>( + pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, const N: usize>( &'s self, input_values: impl Into>, - run_options: &RunOptions - ) -> Result> { + run_options: &'r RunOptions + ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), Some(run_options)) @@ -219,25 +219,34 @@ impl Session { } } - fn run_inner<'i, 'v: 'i>( - &self, + fn 
run_inner<'i, 'r, 's: 'r, 'v: 'i>( + &'s self, input_names: &[&str], input_values: impl Iterator>, - run_options: Option<&RunOptions> - ) -> Result> { + run_options: Option<&'r RunOptions> + ) -> Result> { let input_names_ptr: Vec<*const c_char> = input_names .iter() .map(|n| CString::new(n.as_bytes()).unwrap_or_else(|_| unreachable!())) .map(|n| n.into_raw().cast_const()) .collect(); - let output_names_ptr: Vec<*const c_char> = self - .outputs + + let (output_names, output_tensors) = match run_options { + Some(r) => r.outputs.resolve_outputs(&self.outputs), + None => (self.outputs.iter().map(|o| o.name.as_str()).collect(), std::iter::repeat_with(|| None).take(self.outputs.len()).collect()) + }; + let output_names_ptr: Vec<*const c_char> = output_names .iter() - .map(|output| CString::new(output.name.as_str()).unwrap_or_else(|_| unreachable!())) + .map(|n| CString::new(*n).unwrap_or_else(|_| unreachable!())) .map(|n| n.into_raw().cast_const()) .collect(); - - let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = output_tensors + .iter() + .map(|c| match c { + Some(v) => v.ptr(), + None => std::ptr::null_mut() + }) + .collect(); // The C API expects pointers for the arrays (pointers to C-arrays) let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); @@ -261,10 +270,17 @@ impl Session { ) -> Error::SessionRun ]; - let outputs: Vec = output_tensor_ptrs + let outputs: Vec = output_tensors .into_iter() - .map(|tensor_ptr| unsafe { - Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), Some(Arc::clone(&self.inner))) + .enumerate() + .map(|(i, v)| match v { + Some(value) => value, + None => unsafe { + Value::from_ptr( + NonNull::new(output_tensor_ptrs[i]).expect("OrtValue ptr returned from session Run should not be null"), + 
Some(Arc::clone(&self.inner)) + ) + } }) .collect(); @@ -280,7 +296,7 @@ impl Session { .collect::>>()? ); - Ok(SessionOutputs::new(self.outputs.iter().map(|o| o.name.as_str()), outputs)) + Ok(SessionOutputs::new(output_names.into_iter(), outputs)) } /// Asynchronously run input data through the ONNX graph, performing inference. diff --git a/src/session/output.rs b/src/session/output.rs index c0fed437..74e13324 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -25,16 +25,16 @@ use crate::{Allocator, DynValue}; /// # } /// ``` #[derive(Debug)] -pub struct SessionOutputs<'s> { - map: BTreeMap<&'s str, DynValue>, - idxs: Vec<&'s str>, +pub struct SessionOutputs<'r, 's> { + map: BTreeMap<&'r str, DynValue>, + idxs: Vec<&'r str>, backing_ptr: Option<(&'s Allocator, *mut c_void)> } -unsafe impl<'s> Send for SessionOutputs<'s> {} +unsafe impl<'r, 's> Send for SessionOutputs<'r, 's> {} -impl<'s> SessionOutputs<'s> { - pub(crate) fn new(output_names: impl Iterator + Clone, output_values: impl IntoIterator) -> Self { +impl<'r, 's> SessionOutputs<'r, 's> { + pub(crate) fn new(output_names: impl Iterator + Clone, output_values: impl IntoIterator) -> Self { let map = output_names.clone().zip(output_values).collect(); Self { map, @@ -44,7 +44,7 @@ impl<'s> SessionOutputs<'s> { } pub(crate) fn new_backed( - output_names: impl Iterator + Clone, + output_names: impl Iterator + Clone, output_values: impl IntoIterator, allocator: &'s Allocator, backing_ptr: *mut c_void @@ -66,7 +66,7 @@ impl<'s> SessionOutputs<'s> { } } -impl<'s> Drop for SessionOutputs<'s> { +impl<'r, 's> Drop for SessionOutputs<'r, 's> { fn drop(&mut self) { if let Some((allocator, ptr)) = self.backing_ptr { unsafe { allocator.free(ptr) }; @@ -74,35 +74,35 @@ impl<'s> Drop for SessionOutputs<'s> { } } -impl<'s> Deref for SessionOutputs<'s> { - type Target = BTreeMap<&'s str, DynValue>; +impl<'r, 's> Deref for SessionOutputs<'r, 's> { + type Target = BTreeMap<&'r str, DynValue>; fn deref(&self) -> 
&Self::Target { &self.map } } -impl<'s> DerefMut for SessionOutputs<'s> { +impl<'r, 's> DerefMut for SessionOutputs<'r, 's> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.map } } -impl<'s> Index<&str> for SessionOutputs<'s> { +impl<'r, 's> Index<&str> for SessionOutputs<'r, 's> { type Output = DynValue; fn index(&self, index: &str) -> &Self::Output { self.map.get(index).expect("no entry found for key") } } -impl<'s> Index for SessionOutputs<'s> { +impl<'r, 's> Index for SessionOutputs<'r, 's> { type Output = DynValue; fn index(&self, index: String) -> &Self::Output { self.map.get(index.as_str()).expect("no entry found for key") } } -impl<'s> Index for SessionOutputs<'s> { +impl<'r, 's> Index for SessionOutputs<'r, 's> { type Output = DynValue; fn index(&self, index: usize) -> &Self::Output { self.map.get(&self.idxs[index]).expect("no entry found for key") diff --git a/src/session/run_options.rs b/src/session/run_options.rs index 92e85e8e..fa5ef214 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -1,12 +1,73 @@ -use std::{ffi::CString, ptr::NonNull}; +use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, sync::Arc}; -use crate::{ortsys, Error, Result}; +use crate::{ortsys, DynValue, Error, Output, Result, Value, ValueTypeMarker}; + +#[derive(Debug)] +pub struct OutputSelector { + use_defaults: bool, + default_blocklist: Vec, + allowlist: Vec, + preallocated_outputs: HashMap +} + +impl Default for OutputSelector { + fn default() -> Self { + Self { + use_defaults: true, + allowlist: Vec::new(), + default_blocklist: Vec::new(), + preallocated_outputs: HashMap::new() + } + } +} + +impl OutputSelector { + pub fn no_default() -> Self { + Self { + use_defaults: false, + ..Default::default() + } + } + + pub fn with(mut self, name: impl Into) -> Self { + self.allowlist.push(name.into()); + self + } + + pub fn without(mut self, name: impl Into) -> Self { + self.default_blocklist.push(name.into()); + self + } + 
+ pub fn preallocate(mut self, name: impl Into, value: Value) -> Self { + self.preallocated_outputs.insert(name.into(), value.into_dyn()); + self + } + + pub(crate) fn resolve_outputs<'a, 's: 'a>(&'a self, outputs: &'s [Output]) -> (Vec<&'a str>, Vec>) { + if self.use_defaults { outputs.iter() } else { [].iter() } + .map(|o| &o.name) + .filter(|n| !self.default_blocklist.contains(n)) + .chain(self.allowlist.iter()) + .map(|n| { + ( + n.as_str(), + self.preallocated_outputs.get(n).map(|v| DynValue { + inner: Arc::clone(&v.inner), + _markers: PhantomData + }) + ) + }) + .unzip() + } +} /// A structure which can be passed to [`crate::Session::run_with_options`] to allow terminating/unterminating a session /// inference run from a different thread. #[derive(Debug)] pub struct RunOptions { - pub(crate) run_options_ptr: NonNull + pub(crate) run_options_ptr: NonNull, + pub(crate) outputs: OutputSelector } // https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 @@ -19,10 +80,16 @@ impl RunOptions { let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut(); ortsys![unsafe CreateRunOptions(&mut run_options_ptr) -> Error::CreateRunOptions; nonNull(run_options_ptr)]; Ok(Self { - run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) } + run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) }, + outputs: OutputSelector::default() }) } + pub fn with_outputs(mut self, outputs: OutputSelector) -> Self { + self.outputs = outputs; + self + } + /// Sets a tag to identify this run in logs. pub fn set_tag(&mut self, tag: impl AsRef) -> Result<()> { let tag = CString::new(tag.as_ref())?; From 8f8bbfb4247a8a6091b821f5b80f5299080705d6 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sun, 30 Jun 2024 13:11:57 -0500 Subject: [PATCH 32/49] chore: update to ONNX Runtime v1.18.1 --- README.md | 2 +- docs/pages/migrating/version-mapping.mdx | 4 ++-- docs/pages/perf/execution-providers.mdx | 5 +++++ ort-sys/dist.txt | 24 ++++++++++++------------ src/environment.rs | 2 +- src/memory.rs | 2 +- 6 files changed, 22 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 55cc4b8e..0de9db92 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@
Coverage Results Crates.io Open Collective backers and sponsors
- Crates.io ONNX Runtime + Crates.io ONNX Runtime `ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.18 wrapper for Rust based on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference on both CPU & GPU. diff --git a/docs/pages/migrating/version-mapping.mdx b/docs/pages/migrating/version-mapping.mdx index e238b5a8..c4ac5d43 100644 --- a/docs/pages/migrating/version-mapping.mdx +++ b/docs/pages/migrating/version-mapping.mdx @@ -6,7 +6,7 @@ description: Information about `ort`'s versioning and relation to ONNX Runtime v ## A note on SemVer `ort` versions pre-2.0 were not SemVer compatible. From v2.0 onwards, breaking API changes are accompanied by a **major version update**. -Updates to the version of ONNX Runtime used by `ort` may occur on **minor** version updates, i.e. 2.0 ships with ONNX Runtime 1.18.0, but 2.1 may ship with 1.19.0. ONNX Runtime is generally forward compatible, but in case you require a specific version of ONNX Runtime, you should pin the minor version in your `Cargo.toml` using a [tilde requirement](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#tilde-requirements): +Updates to the version of ONNX Runtime used by `ort` may occur on **minor** version updates, i.e. 2.0 ships with ONNX Runtime 1.18.1, but 2.1 may ship with 1.19.0. ONNX Runtime is generally forward compatible, but in case you require a specific version of ONNX Runtime, you should pin the minor version in your `Cargo.toml` using a [tilde requirement](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#tilde-requirements): ```toml [dependencies] ort = { version = "~2.0", ... } @@ -16,7 +16,7 @@ ort = { version = "~2.0", ... 
} | **ort** | **ONNX Runtime** | | -------- | ----------------:| -| v2.0.0+ | v1.18.0 | +| v2.0.0+ | v1.18.1 | | v1.16.0-v1.16.2 | v1.16.0 | | v1.15.0-v1.15.5 | v1.15.1 | | v1.14.2-v1.14.8 | v1.14.1 | diff --git a/docs/pages/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx index 923590b2..b084463e 100644 --- a/docs/pages/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -195,6 +195,11 @@ If it seems like the execution provider is not registering properly, or you are ## Notes +### CUDA +`ort` provides binaries for both CUDA 11 and CUDA 12; `ort` will automatically choose which binary to install based on whether CUDA 12 is installed. + +CUDA 11 requires cuDNN 8.x. CUDA 12 requires cuDNN 9.x. Make sure the correct version of cuDNN is installed and available on the `PATH`. + ### CoreML Statically linking to CoreML (the default behavior when using downloaded binaries + the `coreml` Cargo feature) requires an additional Rust flag in order to link properly. You'll need to provide the flag `-C link-arg=-fapple-link-rtlib` to `rustc`. You can do this via an entry in [`.cargo/config.toml`](https://doc.rust-lang.org/cargo/reference/config.html#hierarchical-structure), in a build script, or in an environment variable. 
diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index b50dc69e..5ba397a3 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -1,15 +1,15 @@ -none aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-unknown-linux-gnu.tgz 5337059CE144C2ACBEE4744E0E59644ED03196AF6423062C82567240DE7BE235 -cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu12-v1.18.0-x86_64-unknown-linux-gnu.tgz D37C85BB1CE639135B4C168DEC12120FDDC223D4F33193C11B7CFDAF755D4C92 -cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu11-v1.18.0-x86_64-unknown-linux-gnu.tgz 168478A99F4C514B1BD8A8C142ED502501AEEDA038497389BDDAC37B9F12ED77 -rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_rocm-v1.18.0-x86_64-unknown-linux-gnu.tgz D6113A895DEB0BCBC28FD7E23A201DE4C5FBA6BADEB49F3190A084A36C24B43D -none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-unknown-linux-gnu.tgz F486F4B9F040FF533DCD6B26E074BEB5F9092E8E4C67F72D08696D9EB4C9C082 +none aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-unknown-linux-gnu.tgz 221BFD9E8D0D5B31A3E7DD5290EAB8D677B31221DAC0DD2AD8D63DF216D411D8 +cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12-v1.18.1-x86_64-unknown-linux-gnu.tgz 74F4EC918F00B6517BC71AA23DEA0E0809694D5FAB10A494A7AE571F06AEA0BD +cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu11-v1.18.1-x86_64-unknown-linux-gnu.tgz 8DFD96ABB8AAE7F66DD415BCA65E9EE98033E11D3FDB2A178F31E8B1B61C1FD9 +rocm x86_64-unknown-linux-gnu 
https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz 84F74428E0BEC68C55B8E1E91B9282E984CD2866148A2584382B8CB3284214A3 +none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-unknown-linux-gnu.tgz 0A193706A95286853D792D7D9B2271CBEA35C57F249943FE811CED97E0E24862 -none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-pc-windows-msvc.tgz D9807ACE93E87CC45286A9B1892138FE4F28D1C764E30E9FC0B20DBE300063BA -cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu12-v1.18.0-x86_64-pc-windows-msvc.tgz D40FBAA7C4348A4CB5E38F59BB172A93829C6D457B1F97D68472D551A6E961E4 -cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu11-v1.18.0-x86_64-pc-windows-msvc.tgz 3BDA7C8BCB97DFB58114A391ECA7A1C1395A47CB88490103153B36CFF8C1CD48 -none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-pc-windows-msvc.tgz 6D9CFE125807CB9EC4C37D903457B82D505BB77CE7DEF4F750467613BBC2702A +none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-pc-windows-msvc.tgz B2F962F0E75F17F3D657B3504CE891BAA6461B26AF65FBD9244B3CCA17FD79D4 +cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12-v1.18.1-x86_64-pc-windows-msvc.tgz CDBC2D87B202E1847900E94796D102EE4D5C19A9568BBD014838ECD1F5D5350B +cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu11-v1.18.1-x86_64-pc-windows-msvc.tgz B514FC25453F955F8592100448B27F5E1762A344E8C2D57D41B908978EF2A126 +none x86_64-pc-windows-msvc 
https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-pc-windows-msvc.tgz EB2BCD1778C5934437D4C5B17F67DEAF5F67D2E3C18C7298973EACD41113DC01 -none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-apple-darwin.tgz F8DB068DFACFE3B00B9F0181B79780C6971CD1A6EAEB9D9A7FC2129CEB8413A5 -none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-apple-darwin.tgz E6E0457CB9C727DBA818D10245D3A2A29203CB037546B39C217E4CC9FB61ABE8 +none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-apple-darwin.tgz B42BE76AFB9495983A6D5D498D56D5E685B018F1011EF4C5B8C56124B192FD37 +none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-apple-darwin.tgz 247F73A5B3665A6660DFB35213E6FEAAC6ED6CAC5816DD85A348DF790F60A30B -none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-pkort_static-v1.18.0-wasm32-unknown-unknown.tgz 8AB76874E977961A1CFA9714973521AA1B85F0F40D31EF38492CCA659BE58BF5 +none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-pkort_static-v1.18.1-wasm32-unknown-unknown.tgz D1BF756F02A53C3BC254E3C2048BE617082905A89182A6B1BD18C229920228EF diff --git a/src/environment.rs b/src/environment.rs index 810f9e9e..1a7860e5 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -106,7 +106,7 @@ impl EnvironmentBuilder { /// Typically, only Windows builds of ONNX Runtime provided by Microsoft will have telemetry enabled. /// Pre-built binaries provided by pyke, or binaries compiled from source, won't have telemetry enabled. 
/// - /// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.18.0/onnxruntime/core/platform/windows/telemetry.cc). + /// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.18.1/onnxruntime/core/platform/windows/telemetry.cc). /// Currently, this includes (but is not limited to): ONNX graph version, model producer name & version, whether or /// not FP16 is used, operator domains & versions, model graph name & custom metadata, execution provider names, /// error messages, and the total number & time of session inference runs. The ONNX Runtime team uses this data to diff --git a/src/memory.rs b/src/memory.rs index bc7644ca..5ffc85b0 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -133,7 +133,7 @@ impl Drop for Allocator { /// Represents possible devices that have their own device allocator. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AllocationDevice { - // https://github.com/microsoft/onnxruntime/blob/v1.18.0/include/onnxruntime/core/framework/allocator.h#L43-L53 + // https://github.com/microsoft/onnxruntime/blob/v1.18.1/include/onnxruntime/core/framework/allocator.h#L43-L53 // ort will likely never support WebGPU, so I think it's best to leave `WebGPU_Buffer` out entirely to reduce confusion CPU, CUDA, From e1d77b499eb35656194880d6b27168b3e7a7c3e0 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sun, 30 Jun 2024 22:19:40 -0500 Subject: [PATCH 33/49] refactor: simplify `run_async` --- src/session/async.rs | 64 ++++++++------------------------------------ 1 file changed, 11 insertions(+), 53 deletions(-) diff --git a/src/session/async.rs b/src/session/async.rs index 4bc338f7..a63ea483 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -2,14 +2,10 @@ use std::{ cell::UnsafeCell, ffi::{c_char, CString}, future::Future, - mem::MaybeUninit, ops::Deref, pin::Pin, ptr::NonNull, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex - }, + sync::{Arc, Mutex}, task::{Context, Poll, Waker} }; @@ -17,47 +13,26 @@ use ort_sys::{c_void, OrtStatus}; use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; -pub(crate) enum InnerValue { - Present(T), - Pending, - Closed -} - -const VALUE_PRESENT: usize = 1 << 0; -const CHANNEL_CLOSED: usize = 1 << 1; - #[derive(Debug)] pub(crate) struct InferenceFutInner<'r, 's> { - presence: AtomicUsize, - value: UnsafeCell>>>, + value: UnsafeCell>>>, waker: Mutex> } impl<'r, 's> InferenceFutInner<'r, 's> { pub(crate) fn new() -> Self { InferenceFutInner { - presence: AtomicUsize::new(0), waker: Mutex::new(None), - value: UnsafeCell::new(MaybeUninit::uninit()) + value: UnsafeCell::new(None) } } - pub(crate) fn try_take(&self) -> InnerValue>> { - let state_snapshot = self.presence.fetch_and(!VALUE_PRESENT, Ordering::Acquire); - if state_snapshot & VALUE_PRESENT == 0 { - if self.presence.load(Ordering::Acquire) & CHANNEL_CLOSED != 0 { - InnerValue::Closed - } else { - InnerValue::Pending - } - } else { - InnerValue::Present(unsafe { (*self.value.get()).assume_init_read() }) - } + pub(crate) fn try_take(&self) -> Option>> { + unsafe { &mut *self.value.get() }.take() } pub(crate) fn emplace_value(&self, value: Result>) { - unsafe { (*self.value.get()).write(value) }; - self.presence.fetch_or(VALUE_PRESENT, Ordering::Release); + unsafe { &mut 
*self.value.get() }.replace(value); } pub(crate) fn set_waker(&self, waker: Option<&Waker>) { @@ -69,18 +44,6 @@ impl<'r, 's> InferenceFutInner<'r, 's> { waker.wake(); } } - - pub(crate) fn close(&self) -> bool { - self.presence.fetch_or(CHANNEL_CLOSED, Ordering::Acquire) & CHANNEL_CLOSED == 0 - } -} - -impl<'r, 's> Drop for InferenceFutInner<'r, 's> { - fn drop(&mut self) { - if self.presence.load(Ordering::Acquire) & VALUE_PRESENT != 0 { - unsafe { (*self.value.get()).assume_init_drop() }; - } - } } unsafe impl<'r, 's> Send for InferenceFutInner<'r, 's> {} @@ -136,24 +99,19 @@ impl<'s, 'r> Future for InferenceFut<'s, 'r> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = Pin::into_inner(self); - match this.inner.try_take() { - InnerValue::Present(v) => { - this.did_receive = true; - return Poll::Ready(v); - } - InnerValue::Pending => {} - InnerValue::Closed => panic!() - }; + if let Some(v) = this.inner.try_take() { + this.did_receive = true; + return Poll::Ready(v); + } this.inner.set_waker(Some(cx.waker())); - Poll::Pending } } impl<'s, 'r> Drop for InferenceFut<'s, 'r> { fn drop(&mut self) { - if !self.did_receive && self.inner.close() { + if !self.did_receive { let _ = self.run_options.terminate(); self.inner.set_waker(None); } From 920cee9427c1e3eaf00a517b43df19104b0dbedb Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 23:10:57 -0500 Subject: [PATCH 34/49] refactor: simplify logging function --- src/environment.rs | 36 ++++++------------------------------ 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/src/environment.rs b/src/environment.rs index 1a7860e5..6bc60380 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -268,45 +268,21 @@ pub fn init_from(path: impl ToString) -> EnvironmentBuilder { EnvironmentBuilder::new() } -/// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct. 
-#[derive(Debug)] -struct CodeLocation<'a> { - file: &'a str, - line: &'a str, - function: &'a str -} - -impl<'a> From<&'a str> for CodeLocation<'a> { - fn from(code_location: &'a str) -> Self { - let mut splitter = code_location.split(' '); - let file_and_line = splitter.next().unwrap_or(":"); - let function = splitter.next().unwrap_or(""); - let mut file_and_line_splitter = file_and_line.split(':'); - let file = file_and_line_splitter.next().unwrap_or(""); - let line = file_and_line_splitter.next().unwrap_or(""); - - CodeLocation { file, line, function } - } -} - extern_system_fn! { /// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate. - pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, category: *const c_char, _: *const c_char, code_location: *const c_char, message: *const c_char) { - assert_ne!(category, ptr::null()); - let category = unsafe { CStr::from_ptr(category) }.to_str().unwrap_or(""); + pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, _: *const c_char, id: *const c_char, code_location: *const c_char, message: *const c_char) { assert_ne!(code_location, ptr::null()); - let code_location_str = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or(""); + let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or(""); assert_ne!(message, ptr::null()); let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or(""); + assert_ne!(id, ptr::null()); + let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or(""); - let code_location = CodeLocation::from(code_location_str); let span = tracing::span!( Level::TRACE, "ort", - category = category, - file = code_location.file, - line = code_location.line, - function = code_location.function + id = id, + location = code_location ); match severity { From 3b93e73b278b96ec765ae131520eca76d68eadf9 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sun, 30 Jun 2024 23:11:58 -0500 Subject: [PATCH 35/49] chore: remove unused imports --- src/execution_providers/migraphx.rs | 12 ++++++++---- src/session/builder.rs | 6 ++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/execution_providers/migraphx.rs b/src/execution_providers/migraphx.rs index ebcb50db..2eb02ff5 100644 --- a/src/execution_providers/migraphx.rs +++ b/src/execution_providers/migraphx.rs @@ -1,7 +1,7 @@ -use std::{ffi::CString, ptr}; +use std::ffi::CString; use super::ExecutionProvider; -use crate::{ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; #[derive(Debug, Default, Clone)] pub struct MIGraphXExecutionProvider { @@ -68,9 +68,13 @@ impl ExecutionProvider for MIGraphXExecutionProvider { migraphx_fp16_enable: self.enable_fp16.into(), migraphx_int8_enable: self.enable_int8.into(), migraphx_use_native_calibration_table: self.use_native_calibration_table.into(), - migraphx_int8_calibration_table_name: self.int8_calibration_table_name.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null) + migraphx_int8_calibration_table_name: self + .int8_calibration_table_name + .as_ref() + .map(|c| c.as_ptr()) + .unwrap_or_else(std::ptr::null) }; - ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider]; + crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider]; return Ok(()); } diff --git a/src/session/builder.rs b/src/session/builder.rs index 7d654c2a..e105fd7d 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -1,7 +1,11 @@ +#[cfg(any(feature = "operator-libraries", not(windows)))] +use std::ffi::CString; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(target_family = "windows")] use std::os::windows::ffi::OsStrExt; 
+#[cfg(not(target_arch = "wasm32"))] +use std::path::Path; #[cfg(feature = "fetch-models")] use std::path::PathBuf; use std::{ @@ -11,8 +15,6 @@ use std::{ rc::Rc, sync::{atomic::Ordering, Arc} }; -#[cfg(not(target_arch = "wasm32"))] -use std::{ffi::CString, path::Path}; use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner}; #[cfg(feature = "fetch-models")] From a127d0f372262357650d82065ff4af1ea4bc80b3 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 23:50:20 -0500 Subject: [PATCH 36/49] refactor: typestate for `RunOptions` that have selected outputs --- src/io_binding.rs | 6 +- src/lib.rs | 8 +- src/session/async.rs | 32 ++++---- src/session/mod.rs | 45 ++++++------ src/session/run_options.rs | 146 ++++++++++++++++++++++++++++++++++--- src/value/mod.rs | 2 +- 6 files changed, 182 insertions(+), 57 deletions(-) diff --git a/src/io_binding.rs b/src/io_binding.rs index 0467cbb7..f0a704a8 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -12,7 +12,7 @@ use crate::{ ortsys, session::{output::SessionOutputs, RunOptions}, value::{Value, ValueInner}, - DynValue, Error, Result, Session, ValueTypeMarker + DynValue, Error, NoSelectedOutputs, Result, Session, ValueTypeMarker }; /// Enables binding of session inputs and/or outputs to pre-allocated memory. @@ -177,11 +177,11 @@ impl<'s> IoBinding<'s> { } /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. 
- pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result> { + pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result> { self.run_inner(Some(run_options)) } - fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result> { + fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result> { let run_options_ptr = if let Some(run_options) = run_options { run_options.run_options_ptr.as_ptr() } else { diff --git a/src/lib.rs b/src/lib.rs index b8c9176f..1dd8204b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,8 +59,8 @@ pub use self::operator::{ InferShapeFn, Operator, OperatorDomain }; pub use self::session::{ - GraphOptimizationLevel, InMemorySession, Input, Output, OutputSelector, RunOptions, Session, SessionBuilder, SessionInputValue, SessionInputs, - SessionOutputs, SharedSessionInner + GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, RunOptions, + SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner }; #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] @@ -69,8 +69,8 @@ pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, Tensor pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, - SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, Value, ValueRef, - ValueRefMut, ValueType, ValueTypeMarker + SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, Value, + ValueRef, ValueRefMut, ValueType, ValueTypeMarker }; 
#[cfg(not(all(target_arch = "x86", target_os = "windows")))] diff --git a/src/session/async.rs b/src/session/async.rs index a63ea483..db13a84e 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -11,7 +11,7 @@ use std::{ use ort_sys::{c_void, OrtStatus}; -use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; +use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; #[derive(Debug)] pub(crate) struct InferenceFutInner<'r, 's> { @@ -49,25 +49,25 @@ impl<'r, 's> InferenceFutInner<'r, 's> { unsafe impl<'r, 's> Send for InferenceFutInner<'r, 's> {} unsafe impl<'r, 's> Sync for InferenceFutInner<'r, 's> {} -pub enum RunOptionsRef<'r> { - Arc(Arc), - Ref(&'r RunOptions) +pub enum RunOptionsRef<'r, O: SelectedOutputMarker> { + Arc(Arc>), + Ref(&'r RunOptions) } -impl<'r> From<&Arc> for RunOptionsRef<'r> { - fn from(value: &Arc) -> Self { +impl<'r, O: SelectedOutputMarker> From<&Arc>> for RunOptionsRef<'r, O> { + fn from(value: &Arc>) -> Self { Self::Arc(Arc::clone(value)) } } -impl<'r> From<&'r RunOptions> for RunOptionsRef<'r> { - fn from(value: &'r RunOptions) -> Self { +impl<'r, O: SelectedOutputMarker> From<&'r RunOptions> for RunOptionsRef<'r, O> { + fn from(value: &'r RunOptions) -> Self { Self::Ref(value) } } -impl<'r> Deref for RunOptionsRef<'r> { - type Target = RunOptions; +impl<'r, O: SelectedOutputMarker> Deref for RunOptionsRef<'r, O> { + type Target = RunOptions; fn deref(&self) -> &Self::Target { match self { @@ -77,14 +77,14 @@ impl<'r> Deref for RunOptionsRef<'r> { } } -pub struct InferenceFut<'s, 'r> { +pub struct InferenceFut<'s, 'r, O: SelectedOutputMarker> { inner: Arc>, - run_options: RunOptionsRef<'r>, + run_options: RunOptionsRef<'r, O>, did_receive: bool } -impl<'s, 'r> InferenceFut<'s, 'r> { - pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r>) -> 
Self { +impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, O> { + pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r, O>) -> Self { Self { inner, run_options, @@ -93,7 +93,7 @@ impl<'s, 'r> InferenceFut<'s, 'r> { } } -impl<'s, 'r> Future for InferenceFut<'s, 'r> { +impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, O> { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -109,7 +109,7 @@ impl<'s, 'r> Future for InferenceFut<'s, 'r> { } } -impl<'s, 'r> Drop for InferenceFut<'s, 'r> { +impl<'s, 'r, O: SelectedOutputMarker> Drop for InferenceFut<'s, 'r, O> { fn drop(&mut self) { if !self.did_receive { let _ = self.run_options.terminate(); diff --git a/src/session/mod.rs b/src/session/mod.rs index 865e09ba..fd78a116 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -2,8 +2,6 @@ use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc}; -use r#async::RunOptionsRef; - use super::{ char_p_to_string, environment::Environment, @@ -21,13 +19,13 @@ pub(crate) mod builder; pub(crate) mod input; pub(crate) mod output; mod run_options; -use self::r#async::{AsyncInferenceContext, InferenceFutInner}; +use self::r#async::{AsyncInferenceContext, InferenceFutInner, RunOptionsRef}; pub use self::{ r#async::InferenceFut, builder::{GraphOptimizationLevel, SessionBuilder}, input::{SessionInputValue, SessionInputs}, output::SessionOutputs, - run_options::{OutputSelector, RunOptions} + run_options::{HasSelectedOutputs, NoSelectedOutputs, OutputSelector, RunOptions, SelectedOutputMarker} }; /// Holds onto an [`ort_sys::OrtSession`] pointer and its associated allocator. 
@@ -164,14 +162,16 @@ impl Session { pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into>) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { - self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) + self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) } SessionInputs::ValueArray(input_values) => { - self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) - } - SessionInputs::ValueMap(input_values) => { - self.run_inner(&input_values.iter().map(|(k, _)| k.as_ref()).collect::>(), input_values.iter().map(|(_, v)| v), None) + self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) } + SessionInputs::ValueMap(input_values) => self.run_inner::( + &input_values.iter().map(|(k, _)| k.as_ref()).collect::>(), + input_values.iter().map(|(_, v)| v), + None + ) } } @@ -201,10 +201,10 @@ impl Session { /// # Ok(()) /// # } /// ``` - pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, const N: usize>( + pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, O: SelectedOutputMarker, const N: usize>( &'s self, input_values: impl Into>, - run_options: &'r RunOptions + run_options: &'r RunOptions ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { @@ -219,11 +219,11 @@ impl Session { } } - fn run_inner<'i, 'r, 's: 'r, 'v: 'i>( + fn run_inner<'i, 'r, 's: 'r, 'v: 'i, O: SelectedOutputMarker>( &'s self, input_names: &[&str], input_values: impl Iterator>, - run_options: Option<&'r RunOptions> + run_options: Option<&'r RunOptions> ) -> Result> { let input_names_ptr: Vec<*const c_char> = input_names .iter() @@ -321,7 +321,7 @@ impl Session { pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>( &'s self, input_values: impl Into> + 'static - ) -> Result> { + ) -> Result> { match 
input_values.into() { SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), SessionInputs::ValueArray(input_values) => { @@ -335,11 +335,11 @@ impl Session { /// Asynchronously run input data through the ONNX graph, performing inference, with the given [`RunOptions`]. /// See [`Session::run_with_options`] and [`Session::run_async`] for more details. - pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, const N: usize>( + pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>( &'s self, input_values: impl Into> + 'static, - run_options: &'r RunOptions - ) -> Result> { + run_options: &'r RunOptions + ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), SessionInputs::ValueArray(input_values) => { @@ -353,17 +353,20 @@ impl Session { } } - fn run_inner_async<'s, 'v: 's, 'r>( + fn run_inner_async<'s, 'v: 's, 'r, O: SelectedOutputMarker>( &'s self, input_names: &[String], input_values: impl Iterator>, - run_options: Option<&'r RunOptions> - ) -> Result> { + run_options: Option<&'r RunOptions> + ) -> Result> { let run_options = match run_options { Some(r) => RunOptionsRef::Ref(r), // create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial // (performance-wise) for routines involving `tokio::select!` or timeouts - None => RunOptionsRef::Arc(Arc::new(RunOptions::new()?)) + None => RunOptionsRef::Arc(Arc::new(unsafe { + // SAFETY: transmuting from `RunOptions` to `RunOptions`; safe because its just a marker + std::mem::transmute(RunOptions::new()?) 
+ })) }; let input_name_ptrs: Vec<*const c_char> = input_names diff --git a/src/session/run_options.rs b/src/session/run_options.rs index fa5ef214..ae222cb6 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -2,6 +2,30 @@ use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, use crate::{ortsys, DynValue, Error, Output, Result, Value, ValueTypeMarker}; +/// Allows selecting/deselecting/preallocating the outputs of a [`crate::Session`] inference call. +/// +/// ``` +/// # use std::sync::Arc; +/// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; +/// # fn main() -> ort::Result<()> { +/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; +/// +/// let output0 = session.outputs[0].name.as_str(); +/// let options = RunOptions::new()?.with_outputs( +/// // Disable all outputs... +/// OutputSelector::no_default() +/// // except for the first one... +/// .with(output0) +/// // and since this is a 2x upsampler model, pre-allocate the output to be twice as large. +/// .preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) +/// ); +/// +/// // `outputs[0]` will be the tensor we just pre-allocated. +/// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; +/// # Ok(()) +/// # } +/// ``` #[derive(Debug)] pub struct OutputSelector { use_defaults: bool, @@ -11,6 +35,8 @@ pub struct OutputSelector { } impl Default for OutputSelector { + /// Creates an [`OutputSelector`] that enables all outputs by default. Use [`OutputSelector::without`] to disable a + /// specific output. fn default() -> Self { Self { use_defaults: true, @@ -22,6 +48,8 @@ impl Default for OutputSelector { } impl OutputSelector { + /// Creates an [`OutputSelector`] that does not enable any outputs. Use [`OutputSelector::with`] to enable a + /// specific output. 
pub fn no_default() -> Self { Self { use_defaults: false, @@ -29,16 +57,46 @@ impl OutputSelector { } } + /// Mark the output specified by the `name` for inclusion. pub fn with(mut self, name: impl Into) -> Self { self.allowlist.push(name.into()); self } + /// Mark the output specified by `name` to be **excluded**. ONNX Runtime may prune some of the output node's + /// ancestor nodes. pub fn without(mut self, name: impl Into) -> Self { self.default_blocklist.push(name.into()); self } + /// Pre-allocates an output. Assuming the type & shape of the value matches what is expected by the model, the + /// output value corresponding to `name` returned by the inference call will be the exact same value as the + /// pre-allocated value. + /// + /// **The same value will be reused as long as this [`OutputSelector`] and its parent [`RunOptions`] is used**, so + /// if you use the same `RunOptions` across multiple runs with a preallocated value, the preallocated value will be + /// overwritten upon each run. + /// + /// This can improve performance if the size and type of the output is known, and does not change between runs, i.e. + /// for an ODE or embeddings model. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; + /// + /// let output0 = session.outputs[0].name.as_str(); + /// let options = RunOptions::new()?.with_outputs( + /// OutputSelector::default().preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) 
+ /// ); + /// + /// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; + /// # Ok(()) + /// # } + /// ``` pub fn preallocate(mut self, name: impl Into, value: Value) -> Self { self.preallocated_outputs.insert(name.into(), value.into_dyn()); self @@ -62,32 +120,96 @@ impl OutputSelector { } } -/// A structure which can be passed to [`crate::Session::run_with_options`] to allow terminating/unterminating a session -/// inference run from a different thread. +/// Types that specify whether a [`RunOptions`] was configured with an [`OutputSelector`]. +pub trait SelectedOutputMarker {} +/// Marks that a [`RunOptions`] was not configured with an [`OutputSelector`]. +pub struct NoSelectedOutputs; +impl SelectedOutputMarker for NoSelectedOutputs {} +/// Marks that a [`RunOptions`] was configured with an [`OutputSelector`]. +pub struct HasSelectedOutputs; +impl SelectedOutputMarker for HasSelectedOutputs {} + +/// Allows for finer control over session inference. +/// +/// [`RunOptions`] provides three main features: +/// - **Run tagging**: Each individual session run can have a uniquely identifiable tag attached with +/// [`RunOptions::set_tag`], which will show up in logs. This can be especially useful for debugging +/// performance/errors in inference servers. +/// - **Termination**: Allows for terminating an inference call from another thread; when [`RunOptions::terminate`] is +/// called, any sessions currently running under that [`RunOptions`] instance will halt graph execution as soon as the +/// termination signal is received. This allows for [`crate::Session::run_async`]'s cancel-safety. +/// - **Output specification**: Certain session outputs can be [disabled](`OutputSelector::without`) or +/// [pre-allocated](`OutputSelector::preallocate`). Disabling an output might mean ONNX Runtime will not execute parts +/// of the graph that are only used by that output. 
Pre-allocation can reduce expensive re-allocations by allowing you +/// to use the same memory across runs. +/// +/// [`RunOptions`] can be passed to most places where a session can be inferred, e.g. +/// [`crate::Session::run_with_options`], [`crate::Session::run_async_with_options`], +/// [`crate::IoBinding::run_with_options`]. Some of these patterns (notably `IoBinding`) do not accept +/// [`OutputSelector`], hence [`RunOptions`] contains an additional type parameter that marks whether or not outputs +/// have been selected. #[derive(Debug)] -pub struct RunOptions { +pub struct RunOptions { pub(crate) run_options_ptr: NonNull, - pub(crate) outputs: OutputSelector + pub(crate) outputs: OutputSelector, + _marker: PhantomData } // https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 -unsafe impl Send for RunOptions {} -unsafe impl Sync for RunOptions {} +unsafe impl Send for RunOptions {} +// Only allow `Sync` if we don't have (potentially pre-allocated) outputs selected. +// Allowing `Sync` here would mean a single pre-allocated `Value` could be mutated simultaneously in different threads - +// a brazen crime against crabkind. +unsafe impl Sync for RunOptions {} impl RunOptions { /// Creates a new [`RunOptions`] struct. - pub fn new() -> Result { + pub fn new() -> Result> { let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut(); ortsys![unsafe CreateRunOptions(&mut run_options_ptr) -> Error::CreateRunOptions; nonNull(run_options_ptr)]; - Ok(Self { + Ok(RunOptions { run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) }, - outputs: OutputSelector::default() + outputs: OutputSelector::default(), + _marker: PhantomData }) } +} - pub fn with_outputs(mut self, outputs: OutputSelector) -> Self { +impl RunOptions { + /// Select/deselect/preallocate outputs for this run. + /// + /// See [`OutputSelector`] for more details. 
+ /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; + /// + /// let output0 = session.outputs[0].name.as_str(); + /// let options = RunOptions::new()?.with_outputs( + /// // Disable all outputs... + /// OutputSelector::no_default() + /// // except for the first one... + /// .with(output0) + /// // and since this is a 2x upsampler model, pre-allocate the output to be twice as large. + /// .preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) + /// ); + /// + /// // `outputs[0]` will be the tensor we just pre-allocated. + /// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; + /// # Ok(()) + /// # } + /// ``` + pub fn with_outputs(mut self, outputs: OutputSelector) -> RunOptions { self.outputs = outputs; - self + unsafe { std::mem::transmute(self) } + } + + /// Sets a tag to identify this run in logs. + pub fn with_tag(mut self, tag: impl AsRef) -> Result { + self.set_tag(tag).map(|_| self) } /// Sets a tag to identify this run in logs. 
@@ -158,7 +280,7 @@ impl RunOptions { } } -impl Drop for RunOptions { +impl Drop for RunOptions { fn drop(&mut self) { ortsys![unsafe ReleaseRunOptions(self.run_options_ptr.as_ptr())]; } diff --git a/src/value/mod.rs b/src/value/mod.rs index cf506009..b68b3014 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -16,7 +16,7 @@ pub use self::{ impl_sequence::{ DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker }, - impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker} + impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker} }; use crate::{error::status_to_result, memory::MemoryInfo, ortsys, session::SharedSessionInner, tensor::TensorElementType, Error, Result}; From dc79ade39d3221238b0251319f8f9b8988a5c121 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 1 Jul 2024 00:05:35 -0500 Subject: [PATCH 37/49] chore: more readable imports when auto-importing, rust-analyzer seems to only do absolute imports on odd-numbered days and Tuesdays, falling back to crate:: imports otherwise. this commit adds some consistency. 
--- src/environment.rs | 3 ++- src/error.rs | 2 +- src/execution_providers/acl.rs | 7 +++++-- src/execution_providers/armnn.rs | 7 +++++-- src/execution_providers/cann.rs | 7 +++++-- src/execution_providers/coreml.rs | 7 +++++-- src/execution_providers/cpu.rs | 8 ++++++-- src/execution_providers/cuda.rs | 7 +++++-- src/execution_providers/directml.rs | 7 +++++-- src/execution_providers/migraphx.rs | 7 +++++-- src/execution_providers/mod.rs | 7 ++++++- src/execution_providers/nnapi.rs | 7 +++++-- src/execution_providers/onednn.rs | 7 +++++-- src/execution_providers/openvino.rs | 7 +++++-- src/execution_providers/qnn.rs | 7 +++++-- src/execution_providers/rocm.rs | 7 +++++-- src/execution_providers/tensorrt.rs | 7 +++++-- src/execution_providers/tvm.rs | 7 +++++-- src/execution_providers/xnnpack.rs | 7 +++++-- src/io_binding.rs | 6 +++--- src/memory.rs | 9 +++++---- src/metadata.rs | 8 ++++++-- src/operator/io.rs | 2 +- src/operator/kernel.rs | 7 ++++++- src/operator/mod.rs | 7 +++++-- src/session/async.rs | 6 +++++- src/session/builder.rs | 5 +++-- src/session/input.rs | 8 +++----- src/session/mod.rs | 2 +- src/session/output.rs | 2 +- src/session/run_options.rs | 7 ++++++- src/tensor/types.rs | 5 ++++- src/value/impl_map.rs | 9 ++++++--- src/value/impl_sequence.rs | 8 ++++++-- src/value/impl_tensor/create.rs | 11 +++++------ src/value/impl_tensor/extract.rs | 7 +++++-- src/value/impl_tensor/mod.rs | 9 +++++++-- src/value/mod.rs | 8 +++++++- 38 files changed, 175 insertions(+), 76 deletions(-) diff --git a/src/environment.rs b/src/environment.rs index 6bc60380..9b945326 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -15,7 +15,8 @@ use tracing::{debug, Level}; use crate::G_ORT_DYLIB_PATH; use crate::{ error::{Error, Result}, - extern_system_fn, ortsys, ExecutionProviderDispatch + execution_providers::ExecutionProviderDispatch, + extern_system_fn, ortsys }; struct EnvironmentSingleton { diff --git a/src/error.rs b/src/error.rs index fb25f204..c7ef7d22 
100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,7 +4,7 @@ use std::{convert::Infallible, ffi::CString, io, path::PathBuf, ptr, string}; use thiserror::Error; -use super::{char_p_to_string, ortsys, tensor::TensorElementType, ValueType}; +use crate::{char_p_to_string, ortsys, tensor::TensorElementType, value::ValueType}; /// Type alias for the Result type returned by ORT functions. pub type Result = std::result::Result; diff --git a/src/execution_providers/acl.rs b/src/execution_providers/acl.rs index 1f15ac70..ddce6299 100644 --- a/src/execution_providers/acl.rs +++ b/src/execution_providers/acl.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "acl"))] extern "C" { diff --git a/src/execution_providers/armnn.rs b/src/execution_providers/armnn.rs index 86332f01..c428feb8 100644 --- a/src/execution_providers/armnn.rs +++ b/src/execution_providers/armnn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "armnn"))] extern "C" { diff --git a/src/execution_providers/cann.rs b/src/execution_providers/cann.rs index f37a2f1b..91895681 100644 --- a/src/execution_providers/cann.rs +++ b/src/execution_providers/cann.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone, PartialEq, Eq)] 
#[non_exhaustive] diff --git a/src/execution_providers/coreml.rs b/src/execution_providers/coreml.rs index 256de1e5..2fa4aa77 100644 --- a/src/execution_providers/coreml.rs +++ b/src/execution_providers/coreml.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "coreml"))] extern "C" { diff --git a/src/execution_providers/cpu.rs b/src/execution_providers/cpu.rs index eb4be919..06e031b8 100644 --- a/src/execution_providers/cpu.rs +++ b/src/execution_providers/cpu.rs @@ -1,5 +1,9 @@ -use super::ExecutionProvider; -use crate::{error::status_to_result, ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{status_to_result, Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + ortsys, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct CPUExecutionProvider { diff --git a/src/execution_providers/cuda.rs b/src/execution_providers/cuda.rs index 17fbe825..67cad84c 100644 --- a/src/execution_providers/cuda.rs +++ b/src/execution_providers/cuda.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; /// The type of search done for cuDNN convolution algorithms. 
#[derive(Debug, Clone)] diff --git a/src/execution_providers/directml.rs b/src/execution_providers/directml.rs index 38556f11..085e68f0 100644 --- a/src/execution_providers/directml.rs +++ b/src/execution_providers/directml.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "directml"))] extern "C" { diff --git a/src/execution_providers/migraphx.rs b/src/execution_providers/migraphx.rs index 2eb02ff5..d3cc62aa 100644 --- a/src/execution_providers/migraphx.rs +++ b/src/execution_providers/migraphx.rs @@ -1,7 +1,10 @@ use std::ffi::CString; -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct MIGraphXExecutionProvider { diff --git a/src/execution_providers/mod.rs b/src/execution_providers/mod.rs index c7f49370..47fb855f 100644 --- a/src/execution_providers/mod.rs +++ b/src/execution_providers/mod.rs @@ -1,6 +1,11 @@ use std::{fmt::Debug, os::raw::c_char, sync::Arc}; -use crate::{char_p_to_string, ortsys, Error, Result, SessionBuilder}; +use crate::{ + char_p_to_string, + error::{Error, Result}, + ortsys, + session::SessionBuilder +}; mod cpu; pub use self::cpu::CPUExecutionProvider; diff --git a/src/execution_providers/nnapi.rs b/src/execution_providers/nnapi.rs index 9f1951ef..68d275af 100644 --- a/src/execution_providers/nnapi.rs +++ b/src/execution_providers/nnapi.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, 
ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "nnapi"))] extern "C" { diff --git a/src/execution_providers/onednn.rs b/src/execution_providers/onednn.rs index 795d0e66..45dec270 100644 --- a/src/execution_providers/onednn.rs +++ b/src/execution_providers/onednn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "onednn"))] extern "C" { diff --git a/src/execution_providers/openvino.rs b/src/execution_providers/openvino.rs index 95dc8e26..61924c53 100644 --- a/src/execution_providers/openvino.rs +++ b/src/execution_providers/openvino.rs @@ -1,7 +1,10 @@ use std::os::raw::c_void; -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub struct OpenVINOExecutionProvider { diff --git a/src/execution_providers/qnn.rs b/src/execution_providers/qnn.rs index eb7075d5..54ee71b4 100644 --- a/src/execution_providers/qnn.rs +++ b/src/execution_providers/qnn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub enum QNNExecutionProviderPerformanceMode { diff --git a/src/execution_providers/rocm.rs b/src/execution_providers/rocm.rs index be4cfdea..3c3553be 100644 --- a/src/execution_providers/rocm.rs +++ b/src/execution_providers/rocm.rs @@ -1,7 +1,10 @@ use std::os::raw::c_void; -use super::ExecutionProvider; -use 
crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub struct ROCmExecutionProvider { diff --git a/src/execution_providers/tensorrt.rs b/src/execution_providers/tensorrt.rs index e60e16f0..1ea8dd8d 100644 --- a/src/execution_providers/tensorrt.rs +++ b/src/execution_providers/tensorrt.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct TensorRTExecutionProvider { diff --git a/src/execution_providers/tvm.rs b/src/execution_providers/tvm.rs index 19c8ea7a..6a04d940 100644 --- a/src/execution_providers/tvm.rs +++ b/src/execution_providers/tvm.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "tvm"))] extern "C" { diff --git a/src/execution_providers/xnnpack.rs b/src/execution_providers/xnnpack.rs index 87933260..bd3763e0 100644 --- a/src/execution_providers/xnnpack.rs +++ b/src/execution_providers/xnnpack.rs @@ -1,7 +1,10 @@ use std::num::NonZeroUsize; -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct XNNPACKExecutionProvider { diff --git a/src/io_binding.rs b/src/io_binding.rs index f0a704a8..e3b9b76e 
100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -8,11 +8,11 @@ use std::{ }; use crate::{ + error::{Error, Result}, memory::MemoryInfo, ortsys, - session::{output::SessionOutputs, RunOptions}, - value::{Value, ValueInner}, - DynValue, Error, NoSelectedOutputs, Result, Session, ValueTypeMarker + session::{output::SessionOutputs, NoSelectedOutputs, RunOptions, Session}, + value::{DynValue, Value, ValueInner, ValueTypeMarker} }; /// Enables binding of session inputs and/or outputs to pre-allocated memory. diff --git a/src/memory.rs b/src/memory.rs index 5ffc85b0..74eb7703 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -4,11 +4,12 @@ use std::{ sync::Arc }; -use super::{ - error::{Error, Result}, - ortsys +use crate::{ + char_p_to_string, + error::{status_to_result, Error, Result}, + ortsys, + session::{Session, SharedSessionInner} }; -use crate::{char_p_to_string, error::status_to_result, Session, SharedSessionInner}; /// A device allocator used to manage the allocation of [`crate::Value`]s. /// diff --git a/src/metadata.rs b/src/metadata.rs index 5464e5f9..84fc69e2 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -1,7 +1,11 @@ use std::{ffi::CString, os::raw::c_char, ptr::NonNull}; -use super::{char_p_to_string, error::Result, ortsys, Error}; -use crate::Allocator; +use crate::{ + char_p_to_string, + error::{Error, Result}, + memory::Allocator, + ortsys +}; /// Container for model metadata, including name & producer information. 
pub struct ModelMetadata<'s> { diff --git a/src/operator/io.rs b/src/operator/io.rs index 5a7507a8..16d0e93e 100644 --- a/src/operator/io.rs +++ b/src/operator/io.rs @@ -1,4 +1,4 @@ -use crate::{MemoryType, TensorElementType}; +use crate::{memory::MemoryType, tensor::TensorElementType}; #[repr(i32)] #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index db09aa28..8c9280c7 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -3,7 +3,12 @@ use std::{ ptr::{self, NonNull} }; -use crate::{error::status_to_result, ortsys, value::ValueRefMut, Allocator, DowncastableTarget, DynValue, Error, Result, Value, ValueRef}; +use crate::{ + error::{status_to_result, Error, Result}, + memory::Allocator, + ortsys, + value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut} +}; pub trait Kernel { fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()>; diff --git a/src/operator/mod.rs b/src/operator/mod.rs index ad361f29..207a74d9 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -8,11 +8,14 @@ pub(crate) mod io; pub(crate) mod kernel; use self::{ - bound::ErasedBoundOperator, + bound::{BoundOperator, ErasedBoundOperator}, io::{OperatorInput, OperatorOutput}, kernel::{DummyKernel, Kernel, KernelAttributes} }; -use crate::{operator::bound::BoundOperator, ortsys, Error, Result}; +use crate::{ + error::{Error, Result}, + ortsys +}; pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>; diff --git a/src/session/async.rs b/src/session/async.rs index db13a84e..c02a8eb2 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -11,7 +11,11 @@ use std::{ use ort_sys::{c_void, OrtStatus}; -use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; +use crate::{ + error::{assert_non_null_pointer, Error, Result}, + session::{RunOptions, 
SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner}, + value::Value +}; #[derive(Debug)] pub(crate) struct InferenceFutInner<'r, 's> { diff --git a/src/session/builder.rs b/src/session/builder.rs index e105fd7d..8632c57d 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -23,8 +23,9 @@ use crate::{ environment::get_environment, error::{assert_non_null_pointer, status_to_result, Error, Result}, execution_providers::{apply_execution_providers, ExecutionProviderDispatch}, - memory::Allocator, - ortsys, MemoryInfo, OperatorDomain + memory::{Allocator, MemoryInfo}, + operator::OperatorDomain, + ortsys }; /// Creates a session using the builder pattern. diff --git a/src/session/input.rs b/src/session/input.rs index 61d55e5c..31a1433f 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -1,9 +1,6 @@ use std::{borrow::Cow, collections::HashMap, ops::Deref}; -use crate::{ - value::{DynValueTypeMarker, ValueTypeMarker}, - Value, ValueRef, ValueRefMut -}; +use crate::value::{DynValueTypeMarker, Value, ValueRef, ValueRefMut, ValueTypeMarker}; pub enum SessionInputValue<'v> { ViewMut(ValueRefMut<'v, DynValueTypeMarker>), @@ -140,7 +137,8 @@ macro_rules! 
inputs { mod tests { use std::{collections::HashMap, sync::Arc}; - use crate::{DynTensor, SessionInputs}; + use super::SessionInputs; + use crate::value::DynTensor; #[test] fn test_hashmap_static_keys() -> crate::Result<()> { diff --git a/src/session/mod.rs b/src/session/mod.rs index fd78a116..53d9518e 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -2,7 +2,7 @@ use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc}; -use super::{ +use crate::{ char_p_to_string, environment::Environment, error::{assert_non_null_pointer, assert_null_pointer, status_to_result, Error, ErrorInternal, Result}, diff --git a/src/session/output.rs b/src/session/output.rs index 74e13324..2409d6b6 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -4,7 +4,7 @@ use std::{ ops::{Deref, DerefMut, Index} }; -use crate::{Allocator, DynValue}; +use crate::{memory::Allocator, value::DynValue}; /// The outputs returned by a [`crate::Session`] inference call. /// diff --git a/src/session/run_options.rs b/src/session/run_options.rs index ae222cb6..42368923 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -1,6 +1,11 @@ use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, sync::Arc}; -use crate::{ortsys, DynValue, Error, Output, Result, Value, ValueTypeMarker}; +use crate::{ + error::{Error, Result}, + ortsys, + session::Output, + value::{DynValue, Value, ValueTypeMarker} +}; /// Allows selecting/deselecting/preallocating the outputs of a [`crate::Session`] inference call. /// diff --git a/src/tensor/types.rs b/src/tensor/types.rs index aabe6839..f5f0a1ab 100644 --- a/src/tensor/types.rs +++ b/src/tensor/types.rs @@ -2,7 +2,10 @@ use std::ptr; #[cfg(feature = "ndarray")] -use crate::{ortsys, Error, Result}; +use crate::{ + error::{Error, Result}, + ortsys +}; /// Enum mapping ONNX Runtime's supported tensor data types. 
#[derive(Debug, PartialEq, Eq, Clone, Copy)] diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index f6387876..87d653f4 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -7,12 +7,15 @@ use std::{ sync::Arc }; -use super::{ValueInner, ValueTypeMarker}; +use super::{ + impl_tensor::{calculate_tensor_size, DynTensor, Tensor}, + DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker +}; use crate::{ + error::{Error, Result}, memory::Allocator, ortsys, - value::impl_tensor::{calculate_tensor_size, DynTensor}, - DynValue, Error, IntoTensorElementType, PrimitiveTensorElementType, Result, Tensor, TensorElementType, Value, ValueRef, ValueRefMut, ValueType + tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType} }; pub trait MapValueTypeMarker: ValueTypeMarker { diff --git a/src/value/impl_sequence.rs b/src/value/impl_sequence.rs index 2923e27d..8b209e15 100644 --- a/src/value/impl_sequence.rs +++ b/src/value/impl_sequence.rs @@ -5,8 +5,12 @@ use std::{ sync::Arc }; -use super::{DowncastableTarget, ValueInner, ValueTypeMarker}; -use crate::{memory::Allocator, ortsys, Error, Result, Value, ValueRef, ValueRefMut, ValueType}; +use super::{DowncastableTarget, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker}; +use crate::{ + error::{Error, Result}, + memory::Allocator, + ortsys +}; pub trait SequenceValueTypeMarker: ValueTypeMarker { crate::private_trait!(); diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index 8a2e9502..391f4480 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -10,14 +10,13 @@ use std::{ #[cfg(feature = "ndarray")] use ndarray::{ArcArray, Array, ArrayView, CowArray, Dimension}; -use super::{DynTensor, Tensor}; +use super::{calculate_tensor_size, DynTensor, Tensor, TensorRefMut}; use crate::{ - error::assert_non_null_pointer, - memory::{Allocator, MemoryInfo}, + error::{assert_non_null_pointer, Error, 
Result}, + memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType}, ortsys, - tensor::{TensorElementType, Utf8Data}, - value::{impl_tensor::calculate_tensor_size, ValueInner}, - AllocationDevice, AllocatorType, DynValue, Error, MemoryType, PrimitiveTensorElementType, Result, TensorRefMut, Value + tensor::{PrimitiveTensorElementType, TensorElementType, Utf8Data}, + value::{DynValue, Value, ValueInner} }; impl Tensor { diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index bec573f4..1b27d8ab 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -3,11 +3,14 @@ use std::{fmt::Debug, ptr, string::FromUtf8Error}; #[cfg(feature = "ndarray")] use ndarray::IxDyn; -use super::TensorValueTypeMarker; +use super::{calculate_tensor_size, Tensor, TensorValueTypeMarker}; #[cfg(feature = "ndarray")] use crate::tensor::{extract_primitive_array, extract_primitive_array_mut}; use crate::{ - ortsys, tensor::TensorElementType, value::impl_tensor::calculate_tensor_size, Error, PrimitiveTensorElementType, Result, Tensor, Value, ValueType + error::{Error, Result}, + ortsys, + tensor::{PrimitiveTensorElementType, TensorElementType}, + value::{Value, ValueType} }; impl Value { diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index a4c7ba69..92a08c96 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -8,8 +8,13 @@ use std::{ ptr::NonNull }; -use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker}; -use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType}; +use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker}; +use crate::{ + error::{Error, Result}, + memory::MemoryInfo, + ortsys, + tensor::IntoTensorElementType +}; pub trait TensorValueTypeMarker: ValueTypeMarker { crate::private_trait!(); diff --git a/src/value/mod.rs 
b/src/value/mod.rs index b68b3014..3aa57ef9 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -18,7 +18,13 @@ pub use self::{ }, impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker} }; -use crate::{error::status_to_result, memory::MemoryInfo, ortsys, session::SharedSessionInner, tensor::TensorElementType, Error, Result}; +use crate::{ + error::{status_to_result, Error, Result}, + memory::MemoryInfo, + ortsys, + session::SharedSessionInner, + tensor::TensorElementType +}; /// The type of a [`Value`], or a session input/output. /// From b69f41d2ffa477b9286cbc15fccf97ab83139732 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Tue, 2 Jul 2024 00:39:02 -0500 Subject: [PATCH 38/49] feat: most missing custom operator methods --- src/error.rs | 14 ++++- src/operator/kernel.rs | 139 +++++++++++++++++++++++++++++++++++++++-- src/session/mod.rs | 26 +------- src/value/mod.rs | 49 ++++++++------- 4 files changed, 172 insertions(+), 56 deletions(-) diff --git a/src/error.rs b/src/error.rs index c7ef7d22..661e3325 100644 --- a/src/error.rs +++ b/src/error.rs @@ -240,8 +240,16 @@ pub enum Error { GetOperatorInput(ErrorInternal), #[error("Failed to get operator output: {0}")] GetOperatorOutput(ErrorInternal), + #[error("Failed to get operator node name: {0}")] + GetOperatorNodeName(ErrorInternal), #[error("Failed to retrieve GPU compute stream from kernel context: {0}")] - GetOperatorGPUComputeStream(ErrorInternal), + GetKernelGPUComputeStream(ErrorInternal), + #[error("Failed to retrieve EP resource from kernel context: {0}")] + GetKernelResource(ErrorInternal), + #[error("Failed to create allocator in kernel context: {0}")] + GetKernelAllocator(ErrorInternal), + #[error("Failed to allocate temporary buffer in kernel context: {0}")] + GetKernelBuffer(ErrorInternal), #[error("{0}")] CustomError(#[from] Box), #[error("String tensors cannot be borrowed as mutable")] @@ -280,8 +288,8 
@@ pub enum ErrorInternal { /// Details about the error. #[error("{0}")] Msg(String), - /// Converting the ONNX error message to UTF-8 failed. - #[error("an error occurred, but ort failed to convert the error message to UTF-8")] + /// Converting an FFI string to UTF-8 failed. + #[error("failed to convert string to UTF-8: {0}")] IntoStringError(std::ffi::IntoStringError) } diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index 8c9280c7..eda4be14 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -1,13 +1,15 @@ use std::{ ffi::{c_char, CString}, + ops::{Deref, DerefMut}, ptr::{self, NonNull} }; use crate::{ - error::{status_to_result, Error, Result}, - memory::Allocator, + error::{status_to_result, Error, ErrorInternal, Result}, + memory::{Allocator, MemoryInfo}, ortsys, - value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut} + session::{Input, Output}, + value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType} }; pub trait Kernel { @@ -34,6 +36,61 @@ impl KernelAttributes { let name = CString::new(name.as_ref()).ok()?; T::get_from(self.0.as_ptr(), name.as_ptr()) } + + pub fn inputs(&self) -> Result> { + let mut num_inputs: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetInputCount(self.0.as_ptr(), &mut num_inputs) -> Error::GetOperatorInput]; + + let mut inputs = Vec::with_capacity(num_inputs as _); + for idx in 0..num_inputs as usize { + let mut name_len: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx, ptr::null_mut(), &mut name_len) -> Error::GetOperatorInput]; + let mut name = vec![0u8; name_len as _]; + ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx, name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorInput]; + let name = CString::from_vec_with_nul(name) + .map_err(|e| Error::FfiStringConversion(ErrorInternal::Msg(e.to_string())))? 
+ .into_string() + .map_err(|e| Error::FfiStringConversion(ErrorInternal::IntoStringError(e)))?; + let mut type_info = ptr::null_mut(); + ortsys![unsafe KernelInfo_GetInputTypeInfo(self.0.as_ptr(), idx, &mut type_info) -> Error::GetOperatorInput; nonNull(type_info)]; + let input_type = ValueType::from_type_info(type_info)?; + inputs.push(Input { name, input_type }) + } + Ok(inputs) + } + + pub fn outputs(&self) -> Result> { + let mut num_outputs: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetOutputCount(self.0.as_ptr(), &mut num_outputs) -> Error::GetOperatorOutput]; + + let mut outputs = Vec::with_capacity(num_outputs as _); + for idx in 0..num_outputs as usize { + let mut name_len: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx, ptr::null_mut(), &mut name_len) -> Error::GetOperatorOutput]; + let mut name = vec![0u8; name_len as _]; + ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx, name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorOutput]; + let name = CString::from_vec_with_nul(name) + .map_err(|e| Error::FfiStringConversion(ErrorInternal::Msg(e.to_string())))? 
+ .into_string() + .map_err(|e| Error::FfiStringConversion(ErrorInternal::IntoStringError(e)))?; + let mut type_info = ptr::null_mut(); + ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.0.as_ptr(), idx, &mut type_info) -> Error::GetOperatorOutput; nonNull(type_info)]; + let output_type = ValueType::from_type_info(type_info)?; + outputs.push(Output { name, output_type }) + } + Ok(outputs) + } + + pub fn node_name(&self) -> Result { + let mut name_len: ort_sys::size_t = 0; + ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), ptr::null_mut(), &mut name_len) -> Error::GetOperatorNodeName]; + let mut name = vec![0u8; name_len as _]; + ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorNodeName]; + CString::from_vec_with_nul(name) + .map_err(|e| Error::FfiStringConversion(ErrorInternal::Msg(e.to_string())))? + .into_string() + .map_err(|e| Error::FfiStringConversion(ErrorInternal::IntoStringError(e))) + } } pub trait GetKernelAttribute<'s> { @@ -120,6 +177,33 @@ impl<'s, T: DowncastableTarget> GetKernelAttribute<'s> for ValueRef<'s, T> { } } +pub struct ScratchBuffer { + allocator: Allocator, + buffer: *mut T, + size: usize +} + +impl Deref for ScratchBuffer { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + unsafe { std::slice::from_raw_parts(self.buffer.cast_const(), self.size) } + } +} +impl DerefMut for ScratchBuffer { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { std::slice::from_raw_parts_mut(self.buffer, self.size) } + } +} + +impl Drop for ScratchBuffer { + fn drop(&mut self) { + unsafe { + self.allocator.free(self.buffer); + } + } +} + pub struct KernelContext { ptr: NonNull } @@ -144,12 +228,59 @@ impl KernelContext { Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) }))) } + pub fn num_inputs(&self) -> Result { + let mut num: ort_sys::size_t = 0; + ortsys![unsafe KernelContext_GetInputCount(self.ptr.as_ptr(), &mut num) 
-> Error::GetOperatorInput]; + Ok(num as _) + } + + pub fn num_outputs(&self) -> Result { + let mut num: ort_sys::size_t = 0; + ortsys![unsafe KernelContext_GetOutputCount(self.ptr.as_ptr(), &mut num) -> Error::GetOperatorOutput]; + Ok(num as _) + } + + pub fn allocator(&self, memory_info: &MemoryInfo) -> Result { + let mut allocator_ptr = ptr::null_mut(); + ortsys![unsafe KernelContext_GetAllocator(self.ptr.as_ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr) -> Error::GetKernelAllocator]; + println!("allocator ptr {allocator_ptr:?}"); + Ok(unsafe { Allocator::from_raw_unchecked(allocator_ptr) }) + } + + pub fn get_resource(&self, id: ort_sys::c_int, version: ort_sys::c_int) -> Result>> { + let mut resource_ptr: *mut ort_sys::c_void = ptr::null_mut(); + ortsys![unsafe KernelContext_GetResource(self.ptr.as_ptr(), version, id, &mut resource_ptr) -> Error::GetKernelResource]; + Ok(NonNull::new(resource_ptr)) + } + + // TODO: STATUS_ACCESS_VIOLATION inside `KernelContext_GetScratchBuffer`. gonna assume this one is just an internal ONNX + // Runtime bug. + // + // pub fn allocate(&self, memory_info: &MemoryInfo, len: usize) -> Result> { + // let mut buffer = ptr::null_mut(); + // let allocator = self.allocator(memory_info)?; + // ortsys![ + // unsafe KernelContext_GetScratchBuffer( + // self.ptr.as_ptr(), + // memory_info.ptr.as_ptr(), + // (len * std::mem::size_of::()) as ort_sys::size_t, + // &mut buffer + // ) -> Error::GetKernelBuffer; + // nonNull(buffer) + // ]; + // Ok(ScratchBuffer { + // allocator, + // buffer: buffer.cast::(), + // size: len + // }) + // } + /// Returns a pointer to the GPU compute stream (i.e. `cudaStream_t`) used by the execution provider, if this /// kernel's operator was configured to use said execution provider (see /// [`super::Operator::execution_provider_type`]). 
pub fn compute_stream(&self) -> Result>> { let mut stream_ptr: *mut ort_sys::c_void = ptr::null_mut(); - ortsys![unsafe KernelContext_GetGPUComputeStream(self.ptr.as_ptr(), &mut stream_ptr) -> Error::GetOperatorGPUComputeStream]; + ortsys![unsafe KernelContext_GetGPUComputeStream(self.ptr.as_ptr(), &mut stream_ptr) -> Error::GetKernelGPUComputeStream]; Ok(NonNull::new(stream_ptr)) } } diff --git a/src/session/mod.rs b/src/session/mod.rs index 53d9518e..e31793e3 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -448,7 +448,6 @@ unsafe impl Sync for Session {} mod dangerous { use super::*; - use crate::value::{extract_data_type_from_map_info, extract_data_type_from_sequence_info, extract_data_type_from_tensor_info}; pub(super) fn extract_inputs_count(session_ptr: NonNull) -> Result { let f = ortsys![unsafe SessionGetInputCount]; @@ -545,29 +544,6 @@ mod dangerous { status_to_result(status).map_err(Error::GetTypeInfo)?; assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?; - let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - let status = ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; - status_to_result(status).map_err(Error::GetOnnxTypeFromTypeInfo)?; - let io_type = match ty { - ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { - let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToTensorInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_tensor_info(info_ptr)? } - } - ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { - let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToSequenceTypeInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_sequence_info(info_ptr)? 
} - } - ort_sys::ONNXType::ONNX_TYPE_MAP => { - let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToMapTypeInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_map_info(info_ptr)? } - } - _ => unreachable!() - }; - - ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; - Ok(io_type) + ValueType::from_type_info(typeinfo_ptr) } } diff --git a/src/value/mod.rs b/src/value/mod.rs index 3aa57ef9..2dad7ffb 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -91,6 +91,30 @@ pub enum ValueType { } impl ValueType { + pub(crate) fn from_type_info(typeinfo_ptr: *mut ort_sys::OrtTypeInfo) -> Result { + let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; + ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty) -> Error::GetOnnxTypeFromTypeInfo]; + let io_type = match ty { + ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { + let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToTensorInfo; nonNull(info_ptr)]; + unsafe { extract_data_type_from_tensor_info(info_ptr)? } + } + ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { + let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToSequenceTypeInfo; nonNull(info_ptr)]; + unsafe { extract_data_type_from_sequence_info(info_ptr)? } + } + ort_sys::ONNXType::ONNX_TYPE_MAP => { + let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); + ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToMapTypeInfo; nonNull(info_ptr)]; + unsafe { extract_data_type_from_map_info(info_ptr)? 
} + } + _ => unreachable!() + }; + ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; + Ok(io_type) + } /// Returns the dimensions of this value type if it is a tensor, or `None` if it is a sequence or map. /// /// ``` @@ -348,30 +372,7 @@ impl Value { pub fn dtype(&self) -> Result { let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); ortsys![unsafe GetTypeInfo(self.ptr(), &mut typeinfo_ptr) -> Error::GetTypeInfo; nonNull(typeinfo_ptr)]; - - let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty) -> Error::GetOnnxTypeFromTypeInfo]; - let io_type = match ty { - ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { - let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToTensorInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_tensor_info(info_ptr)? } - } - ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => { - let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToSequenceTypeInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_sequence_info(info_ptr)? } - } - ort_sys::ONNXType::ONNX_TYPE_MAP => { - let mut info_ptr: *const ort_sys::OrtMapTypeInfo = std::ptr::null_mut(); - ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr, &mut info_ptr) -> Error::CastTypeInfoToMapTypeInfo; nonNull(info_ptr)]; - unsafe { extract_data_type_from_map_info(info_ptr)? } - } - _ => unreachable!() - }; - - ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)]; - Ok(io_type) + ValueType::from_type_info(typeinfo_ptr) } /// Construct a [`Value`] from a C++ [`ort_sys::OrtValue`] pointer. From e59d0c74db0e3797cf7dc082fd3546d3a665c3f7 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Tue, 2 Jul 2024 00:43:52 -0500 Subject: [PATCH 39/49] fix arm build --- src/operator/kernel.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index eda4be14..15c8a9f8 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -44,15 +44,15 @@ impl KernelAttributes { let mut inputs = Vec::with_capacity(num_inputs as _); for idx in 0..num_inputs as usize { let mut name_len: ort_sys::size_t = 0; - ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx, ptr::null_mut(), &mut name_len) -> Error::GetOperatorInput]; + ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx as _, ptr::null_mut(), &mut name_len) -> Error::GetOperatorInput]; let mut name = vec![0u8; name_len as _]; - ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx, name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorInput]; + ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx as _, name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorInput]; let name = CString::from_vec_with_nul(name) .map_err(|e| Error::FfiStringConversion(ErrorInternal::Msg(e.to_string())))? 
.into_string() .map_err(|e| Error::FfiStringConversion(ErrorInternal::IntoStringError(e)))?; let mut type_info = ptr::null_mut(); - ortsys![unsafe KernelInfo_GetInputTypeInfo(self.0.as_ptr(), idx, &mut type_info) -> Error::GetOperatorInput; nonNull(type_info)]; + ortsys![unsafe KernelInfo_GetInputTypeInfo(self.0.as_ptr(), idx as _, &mut type_info) -> Error::GetOperatorInput; nonNull(type_info)]; let input_type = ValueType::from_type_info(type_info)?; inputs.push(Input { name, input_type }) } @@ -66,15 +66,15 @@ impl KernelAttributes { let mut outputs = Vec::with_capacity(num_outputs as _); for idx in 0..num_outputs as usize { let mut name_len: ort_sys::size_t = 0; - ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx, ptr::null_mut(), &mut name_len) -> Error::GetOperatorOutput]; + ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx as _, ptr::null_mut(), &mut name_len) -> Error::GetOperatorOutput]; let mut name = vec![0u8; name_len as _]; - ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx, name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorOutput]; + ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx as _, name.as_mut_ptr().cast::(), &mut name_len) -> Error::GetOperatorOutput]; let name = CString::from_vec_with_nul(name) .map_err(|e| Error::FfiStringConversion(ErrorInternal::Msg(e.to_string())))? .into_string() .map_err(|e| Error::FfiStringConversion(ErrorInternal::IntoStringError(e)))?; let mut type_info = ptr::null_mut(); - ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.0.as_ptr(), idx, &mut type_info) -> Error::GetOperatorOutput; nonNull(type_info)]; + ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.0.as_ptr(), idx as _, &mut type_info) -> Error::GetOperatorOutput; nonNull(type_info)]; let output_type = ValueType::from_type_info(type_info)?; outputs.push(Output { name, output_type }) } From 66b0cb2a23523f2ba89ca22e352f876aa8ed3e06 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Tue, 2 Jul 2024 08:50:32 -0500 Subject: [PATCH 40/49] fix(sys): only link nsync if required, closes #223 --- ort-sys/build.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ort-sys/build.rs b/ort-sys/build.rs index da15be08..dbf8cdba 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -257,8 +257,12 @@ fn prepare_libort_dir() -> (PathBuf, bool) { println!("cargo:rustc-link-lib=static=onnx"); println!("cargo:rustc-link-lib=static=onnx_proto"); - add_search_dir(transform_dep(external_lib_dir.join("google_nsync-build"), &profile)); - println!("cargo:rustc-link-lib=static=nsync_cpp"); + let nsync_path = transform_dep(external_lib_dir.join("google_nsync-build"), &profile); + // some builds of ONNX Runtime, particularly the default no-EP windows build, don't require nsync + if nsync_path.exists() { + add_search_dir(nsync_path); + println!("cargo:rustc-link-lib=static=nsync_cpp"); + } if target_arch != "wasm32" { add_search_dir(transform_dep(external_lib_dir.join("pytorch_cpuinfo-build"), &profile)); From bf10c1834962ec25e856afdde65d13e74b939f93 Mon Sep 17 00:00:00 2001 From: Julien Cretin Date: Wed, 3 Jul 2024 15:58:34 +0200 Subject: [PATCH 41/49] docs: Add Magika to projects using ort (#224) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 0de9db92..128f921d 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ - **[Ortex](https://github.com/relaypro-open/ortex)** uses `ort` for safe ONNX Runtime bindings in Elixir. - **[Supabase](https://supabase.com/)** uses `ort` to remove cold starts for their edge functions. - **[Lantern](https://github.com/lanterndata/lantern_extras)** uses `ort` to provide embedding model inference inside Postgres. +- **[Magika](https://github.com/google/magika)** uses `ort` for content type detection. ## 🌠 Sponsor `ort` From 1bff72d9780490ff7856528e6bc3997528fd1714 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sat, 6 Jul 2024 09:11:22 -0500 Subject: [PATCH 42/49] fix: TVM EP register function definition for non-`load-dynamic`, closes #227 --- src/execution_providers/tvm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/execution_providers/tvm.rs b/src/execution_providers/tvm.rs index 6a04d940..6e43601f 100644 --- a/src/execution_providers/tvm.rs +++ b/src/execution_providers/tvm.rs @@ -6,7 +6,7 @@ use crate::{ #[cfg(all(not(feature = "load-dynamic"), feature = "tvm"))] extern "C" { - fn OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr; + fn OrtSessionOptionsAppendExecutionProvider_Tvm(options: *mut ort_sys::OrtSessionOptions, opt_str: *const std::os::raw::c_char) -> ort_sys::OrtStatusPtr; } #[derive(Debug, Clone, Copy, PartialEq, Eq)] From 0407adb5cc8e1c31e6da344f22e937cda22f31fc Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 6 Jul 2024 09:46:12 -0500 Subject: [PATCH 43/49] fix: link missing absl libraries, closes #228 --- ort-sys/build.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ort-sys/build.rs b/ort-sys/build.rs index dbf8cdba..d1c5995a 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -281,6 +281,9 @@ fn prepare_libort_dir() -> (PathBuf, bool) { add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("base"), &profile)); println!("cargo:rustc-link-lib=static=absl_base"); + println!("cargo:rustc-link-lib=static=absl_spinlock_wait"); + println!("cargo:rustc-link-lib=static=absl_malloc_internal"); + println!("cargo:rustc-link-lib=static=absl_raw_logging_internal"); println!("cargo:rustc-link-lib=static=absl_throw_delegate"); add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("hash"), &profile)); println!("cargo:rustc-link-lib=static=absl_hash"); @@ -288,6 +291,23 @@ fn prepare_libort_dir() -> (PathBuf, bool) { 
println!("cargo:rustc-link-lib=static=absl_low_level_hash"); add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("container"), &profile)); println!("cargo:rustc-link-lib=static=absl_raw_hash_set"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("synchronization"), &profile)); + println!("cargo:rustc-link-lib=static=absl_kernel_timeout_internal"); + println!("cargo:rustc-link-lib=static=absl_graphcycles_internal"); + println!("cargo:rustc-link-lib=static=absl_synchronization"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("time"), &profile)); + println!("cargo:rustc-link-lib=static=absl_time_zone"); + println!("cargo:rustc-link-lib=static=absl_time"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("numeric"), &profile)); + println!("cargo:rustc-link-lib=static=absl_int128"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("strings"), &profile)); + println!("cargo:rustc-link-lib=static=absl_str_format_internal"); + println!("cargo:rustc-link-lib=static=absl_strings"); + println!("cargo:rustc-link-lib=static=absl_string_view"); + println!("cargo:rustc-link-lib=static=absl_strings_internal"); + add_search_dir(transform_dep(external_lib_dir.join("abseil_cpp-build").join("absl").join("debugging"), &profile)); + println!("cargo:rustc-link-lib=static=absl_symbolize"); + println!("cargo:rustc-link-lib=static=absl_stacktrace"); if cfg!(feature = "coreml") && (target_os == "macos" || target_os == "ios") { println!("cargo:rustc-link-lib=framework=CoreML"); From 0a43482d031f0cea7239642669fb6ca03cde4010 Mon Sep 17 00:00:00 2001 From: Carson M Date: Sat, 6 Jul 2024 11:07:27 -0500 Subject: [PATCH 44/49] feat: training (#202) --- .gitignore | 5 + Cargo.toml | 5 +- examples/training/Cargo.toml | 18 ++ examples/training/README.md | 26 ++ examples/training/build.rs | 5 + 
examples/training/examples/pretokenize.rs | 44 ++++ .../training/examples/train-clm-simple.rs | 118 +++++++++ examples/training/examples/train-clm.rs | 133 ++++++++++ ort-sys/Cargo.toml | 1 + ort-sys/build.rs | 32 ++- ort-sys/dist.txt | 14 + ort-sys/src/lib.rs | 110 +++++++- src/error.rs | 4 +- src/lib.rs | 6 + src/session/builder.rs | 19 +- src/training/mod.rs | 142 +++++++++++ src/training/simple.rs | 240 ++++++++++++++++++ src/training/trainer.rs | 235 +++++++++++++++++ src/util.rs | 26 ++ tools/requirements.txt | 4 + tools/train-data/mini-clm.py | 140 ++++++++++ 21 files changed, 1294 insertions(+), 33 deletions(-) create mode 100644 examples/training/Cargo.toml create mode 100644 examples/training/README.md create mode 100644 examples/training/build.rs create mode 100644 examples/training/examples/pretokenize.rs create mode 100644 examples/training/examples/train-clm-simple.rs create mode 100644 examples/training/examples/train-clm.rs create mode 100644 src/training/mod.rs create mode 100644 src/training/simple.rs create mode 100644 src/training/trainer.rs create mode 100644 src/util.rs create mode 100644 tools/requirements.txt create mode 100644 tools/train-data/mini-clm.py diff --git a/.gitignore b/.gitignore index b00a624d..6ab71818 100644 --- a/.gitignore +++ b/.gitignore @@ -186,6 +186,7 @@ WixTools/ # ONNX Runtime downloaded models **/*.onnx **/*.ort +**/*.pbseq !examples/webassembly/**/*.ort !tests/data/*.onnx !tests/data/*.ort @@ -196,4 +197,8 @@ WixTools/ # Glassbench results /glassbench*.db +# Python virtual environment .venv* + +# Training checkpoints +tools/train-data/**/checkpoint diff --git a/Cargo.toml b/Cargo.toml index e9c3ccf5..c1a70bb0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ 'examples/model-info', 'examples/yolov8', 'examples/modnet', + 'examples/training', 'examples/webassembly' ] default-members = [ @@ -45,13 +46,15 @@ strip = true codegen-units = 1 [package.metadata.docs.rs] -features = [ "ndarray", "half", 
"operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ] +features = [ "ndarray", "half", "training", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ] targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"] rustdoc-args = [ "--cfg", "docsrs" ] [features] default = [ "ndarray", "half", "download-binaries", "copy-dylibs" ] +training = [ "ort-sys/training" ] + operator-libraries = [ "libc", "winapi" ] fetch-models = [ "ureq" ] diff --git a/examples/training/Cargo.toml b/examples/training/Cargo.toml new file mode 100644 index 00000000..945f62ea --- /dev/null +++ b/examples/training/Cargo.toml @@ -0,0 +1,18 @@ +[package] +publish = false +name = "example-training" +version = "0.0.0" +edition = "2021" + +[dependencies] +ort = { path = "../../", features = [ "training" ] } +ndarray = "0.15" +tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] } +rand = "0.8" +simd-json = "0.13" +kdam = "0.5" +tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } + +[features] +load-dynamic = [ "ort/load-dynamic" ] +cuda = [ "ort/cuda" ] diff --git a/examples/training/README.md b/examples/training/README.md new file mode 100644 index 00000000..7c99d643 --- /dev/null +++ b/examples/training/README.md @@ -0,0 +1,26 @@ +# Training Examples + +## `train-clm` +This example trains a tiny causal language model on a small subset of pyke's [**OshiChats v2**](https://huggingface.co/datasets/pykeio/oshichats-v2), a dataset of live text chat messages collected from various [VTuber](https://en.wikipedia.org/wiki/VTuber) live streams. The model is not particularly useful or interesting (due to both the low-quality dataset and small model size), but it showcases that entire language models can be trained from scratch entirely in Rust on (almost) any device. 
+ +To get started, create a Python virtual environment and install the following packages: +``` +pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 onnx~=1.17 torch~=2.3 +``` + +We're installing the CPU version of the `onnxruntime-training` & `torch` packages because we only need to use Python to *create* the initial graph which will be used for training. Run `python tools/train-data/mini-clm.py` from the root directory of the `ort` repo to create the training artifacts. + +Next, we need to convert our dataset into tokens to feed the model. This can be achieved by downloading the `oshichats-v2.jsonl` file from the OshiChats v2 dataset and running `cargo run -p example-training --example pretokenize -- ~/oshichats-v2.jsonl`, or if you (rightfully) don't wish to waste 30 GB worth of disk space and bandwidth on brainrot, you may download a [1 MB pre-tokenized subset of the dataset](https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_data/dataset.bin). Make sure `dataset.bin` is in the root of the `ort` repo. + +Finally, we can train our model! Run `cargo run -p example-training --example train-clm` to start training. If you have an NVIDIA GPU, add `--features cuda` to enable CUDA, though it's not required and you can train directly on CPU instead. **This will use ~8 GB of (V)RAM!** You can lower the memory usage by adjusting the `BATCH_SIZE` and `SEQUENCE_LENGTH` constants in `train-clm.rs`, though note that changing the batch size may require adjustments to the learning rate. + +While training, the progress bar will show the cross-entropy loss at each training step.
At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text: +``` +100%|██████████████████████████████████████| 5000/5000 [06:29<00:00, 12.83it/s, loss=3.611] +I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext|>I think you can't see it<|endoftext|>I think so it's a new game<|endoftext|>I think I'm sure you can't see what you can't see it<|endoftext|> +``` + +Not bad, considering the model & dataset size! This example can easily be scaled up to pre-train or fine-tune (both full-parameter and PEFT) larger language models like Llama/Phi, so long as you have enough compute. + +## `train-clm-simple` +This example is functionally identical to `train-clm`, except it uses ort's "simple" Trainer API instead of implementing a manual training loop. The simple API is more akin to 🤗 Transformer's [`Trainer`](https://huggingface.co/docs/transformers/en/main_classes/trainer) API or [PyTorch Lightning](https://lightning.ai/pytorch-lightning). With the simple API, all you have to do is pass a data loader & parameters, and let `ort` handle training for you! diff --git a/examples/training/build.rs b/examples/training/build.rs new file mode 100644 index 00000000..79d3a0bb --- /dev/null +++ b/examples/training/build.rs @@ -0,0 +1,5 @@ +fn main() { + // Need this for CoreML. 
See: https://ort.pyke.io/perf/execution-providers#coreml + #[cfg(target_os = "macos")] + println!("cargo:rustc-link-arg=-fapple-link-rtlib"); +} diff --git a/examples/training/examples/pretokenize.rs b/examples/training/examples/pretokenize.rs new file mode 100644 index 00000000..79eee195 --- /dev/null +++ b/examples/training/examples/pretokenize.rs @@ -0,0 +1,44 @@ +use std::{ + env, + fs::File, + io::{BufRead, BufReader, BufWriter, Write}, + path::Path +}; + +use simd_json::derived::ValueObjectAccessAsScalar; +use tokenizers::Tokenizer; + +const MAX_TOKENS: usize = 500_000; + +fn main() { + let input = env::args().nth(1).expect("provide input jsonl"); + let output = env::args().nth(2).unwrap_or_else(|| "dataset.bin".into()); + + let input = BufReader::new(File::open(input).unwrap()); + let mut output = BufWriter::new(File::create(output).unwrap()); + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + let mut bytes_written = 0; + + for line in input.lines() { + let line: simd_json::OwnedValue = unsafe { simd_json::from_str(&mut line.unwrap()).unwrap() }; + let tokenized = tokenizer + .encode(format!("<|endoftext|>{}", line.get_str("message").unwrap()), false) + .unwrap(); + let id_bytes: Vec = tokenized.get_ids().iter().flat_map(|c| (*c as u16).to_le_bytes()).collect(); + output.write_all(&id_bytes).unwrap(); + bytes_written += id_bytes.len(); + if bytes_written >= MAX_TOKENS * 2 { + output.flush().unwrap(); + break; + } + } +} diff --git a/examples/training/examples/train-clm-simple.rs b/examples/training/examples/train-clm-simple.rs new file mode 100644 index 00000000..0c3ac326 --- /dev/null +++ b/examples/training/examples/train-clm-simple.rs @@ -0,0 +1,118 @@ +use std::{ + fs::File, + io::{Read, Seek, SeekFrom, Write}, + path::Path +}; + +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis}; +use ort::{Allocator, 
CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments}; +use rand::RngCore; +use tokenizers::Tokenizer; + +const BATCH_SIZE: usize = 16; +const SEQUENCE_LENGTH: usize = 256; + +fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + + ort::init().commit()?; + + let trainer = Trainer::new_from_artifacts( + SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + Allocator::default(), + "tools/train-data/mini-clm", + None + )?; + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + + let mut dataset = File::open("dataset.bin").unwrap(); + let file_size = dataset.metadata().unwrap().len(); + let num_tokens = (file_size / 2) as usize; // 16-bit tokens + let mut rng = rand::thread_rng(); + let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let dataloader = move |_: usize| { + for batch in 0..BATCH_SIZE { + let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; + dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + } + + Ok(( + ort::inputs![Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?, + ort::inputs![Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as 
i64).collect()).unwrap()]? + )) + }; + + trainer.train( + TrainingArguments::new(dataloader) + .with_lr(7e-5) + .with_max_steps(5000) + .with_ckpt_strategy(CheckpointStrategy::Steps(500)) + )?; + + trainer.export("trained-clm.onnx", ["probs"])?; + + let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + + let mut stdout = std::io::stdout(); + + let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); + let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); + + let mut tokens = Array1::from_iter(tokens.iter().cloned()); + + for _ in 0..50 { + let array = tokens.view().insert_axis(Axis(0)); + let outputs = session.run(ort::inputs![array]?)?; + let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; + + let probabilities = &mut generated_tokens + .slice(s![-1, ..]) + .to_owned() + .iter() + .cloned() + .enumerate() + .collect::>(); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); + + let token = probabilities[0].0; + tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + + let token_str = tokenizer.decode(&[token as _], false).unwrap(); + print!("{}", token_str); + stdout.flush().unwrap(); + } + + println!(); + Ok(()) +} diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs new file mode 100644 index 00000000..9e46bf44 --- /dev/null +++ b/examples/training/examples/train-clm.rs @@ -0,0 +1,133 @@ +use std::{ + fs::File, + io::{Read, Seek, SeekFrom, Write}, + path::Path +}; + +use kdam::BarExt; +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis}; +use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer}; +use rand::RngCore; +use tokenizers::Tokenizer; + +const BATCH_SIZE: usize = 16; +const SEQUENCE_LENGTH: usize = 256; + +fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + + ort::init().commit()?; + + kdam::term::init(true); + 
let _ = kdam::term::hide_cursor(); + + let trainer = Trainer::new( + SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + Allocator::default(), + Checkpoint::load("tools/train-data/mini-clm/checkpoint")?, + "tools/train-data/mini-clm/training_model.onnx", + "tools/train-data/mini-clm/eval_model.onnx", + "tools/train-data/mini-clm/optimizer_model.onnx" + )?; + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + + let optimizer = trainer.optimizer(); + optimizer.set_lr(7e-5)?; + + let mut dataset = File::open("dataset.bin").unwrap(); + let file_size = dataset.metadata().unwrap().len(); + let num_tokens = (file_size / 2) as usize; // 16-bit tokens + let mut rng = rand::thread_rng(); + + let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut pb = kdam::tqdm!(total = 5000); + for _ in 0..5000 { + for batch in 0..BATCH_SIZE { + let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; + dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + } + + let inputs = Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap(); + let labels = Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap(); + + let outputs = 
trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?; + let loss = outputs[0].try_extract_scalar::()?; + pb.set_postfix(format!("loss={loss:.3}")); + pb.update(1).unwrap(); + if loss.is_nan() { + return Ok(()); + } + optimizer.step()?; + optimizer.reset_grad()?; + } + + eprintln!(); + let _ = kdam::term::show_cursor(); + + trainer.export("trained-clm.onnx", ["probs"])?; + + let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + + let mut stdout = std::io::stdout(); + + let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); + let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); + + let mut tokens = Array1::from_iter(tokens.iter().cloned()); + + for _ in 0..50 { + let array = tokens.view().insert_axis(Axis(0)); + let outputs = session.run(ort::inputs![array]?)?; + let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; + + let probabilities = &mut generated_tokens + .slice(s![-1, ..]) + .to_owned() + .iter() + .cloned() + .enumerate() + .collect::>(); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); + + let token = probabilities[0].0; + tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + + let token_str = tokenizer.decode(&[token as _], false).unwrap(); + print!("{}", token_str); + stdout.flush().unwrap(); + } + + println!(); + Ok(()) +} diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index 26e2aa81..015a4656 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -16,6 +16,7 @@ include = [ "src/", "dist.txt", "build.rs", "LICENSE-APACHE", "LICENSE-MIT" ] [features] default = [] +training = [] download-binaries = [ "ureq", "tar", "flate2", "sha2" ] load-dynamic = [] copy-dylibs = [] diff --git a/ort-sys/build.rs b/ort-sys/build.rs index d1c5995a..719a2059 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -37,12 +37,12 @@ fn fetch_file(source_url: &str) -> Vec { buffer } -fn 
find_dist(target: &str, designator: &str) -> Option<(&'static str, &'static str)> { +fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static str)> { DIST_TABLE .split('\n') .filter(|c| !c.is_empty() && !c.starts_with('#')) .map(|c| c.split('\t').collect::>()) - .find(|c| c[0] == designator && c[1] == target) + .find(|c| c[0] == feature_set && c[1] == target) .map(|c| (c[2], c[3])) } @@ -341,23 +341,31 @@ fn prepare_libort_dir() -> (PathBuf, bool) { #[cfg(feature = "download-binaries")] { let target = env::var("TARGET").unwrap().to_string(); - let designator = if cfg!(any(feature = "cuda", feature = "tensorrt")) { - if lib_exists("cudart64_12.dll") || lib_exists("libcudart.so.12") { "cu12" } else { "cu11" } + + let mut feature_set = Vec::new(); + if cfg!(feature = "training") { + feature_set.push("train"); + } + if cfg!(any(feature = "cuda", feature = "tensorrt")) { + if lib_exists("cudart64_11.dll") || lib_exists("libcudart.so.11") || env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() == Ok("11") { + feature_set.push("cu11"); + } else { + feature_set.push("cu12"); + } } else if cfg!(feature = "rocm") { - "rocm" - } else { - "none" - }; - let mut dist = find_dist(&target, designator); - if dist.is_none() && designator != "none" { + feature_set.push("rocm"); + } + let feature_set = if !feature_set.is_empty() { feature_set.join(",") } else { "none".to_owned() }; + let mut dist = find_dist(&target, &feature_set); + if dist.is_none() && feature_set != "none" { dist = find_dist(&target, "none"); } if dist.is_none() { panic!( "downloaded binaries not available for target {target}{}\nyou may have to compile ONNX Runtime from source", - if designator != "none" { - format!(" (note: also requested `{designator}`)") + if feature_set != "none" { + format!(" (note: also requested features `{feature_set}`)") } else { String::new() } diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 5ba397a3..98e3f3f0 100644 --- a/ort-sys/dist.txt +++ 
b/ort-sys/dist.txt @@ -4,12 +4,26 @@ cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/ rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz 84F74428E0BEC68C55B8E1E91B9282E984CD2866148A2584382B8CB3284214A3 none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-unknown-linux-gnu.tgz 0A193706A95286853D792D7D9B2271CBEA35C57F249943FE811CED97E0E24862 +train aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-unknown-linux-gnu.tgz C04DBEAF19F2BCD3643F8F7D7FA01110A1AF429DFDD1C1DC7C5EDA2B1A8AA324 +train,cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12-v1.18.1-x86_64-unknown-linux-gnu.tgz A139D8AD8930930F5A61DF112C8275AAD1F0415FAFD08CE3031CEFFFC30F2445 +train,cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu11-v1.18.1-x86_64-unknown-linux-gnu.tgz 2DAA2E2CF44E9B9A96AB2E9C4271C35189C96BF264D1797DABCF1D6711730DE7 +train,rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz DD373BA6B251D21953223B2FBB64F4DF34CFE98A63C26D16607BEAC6BC788466 +train x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-unknown-linux-gnu.tgz 0E617970AE83ABE5FB9A3D5D69AAC9A67ED4C9494AD527B14A84FDC98CA9B924 + none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-pc-windows-msvc.tgz B2F962F0E75F17F3D657B3504CE891BAA6461B26AF65FBD9244B3CCA17FD79D4 cu12 x86_64-pc-windows-msvc 
https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12-v1.18.1-x86_64-pc-windows-msvc.tgz CDBC2D87B202E1847900E94796D102EE4D5C19A9568BBD014838ECD1F5D5350B cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu11-v1.18.1-x86_64-pc-windows-msvc.tgz B514FC25453F955F8592100448B27F5E1762A344E8C2D57D41B908978EF2A126 none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-pc-windows-msvc.tgz EB2BCD1778C5934437D4C5B17F67DEAF5F67D2E3C18C7298973EACD41113DC01 +train aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-pc-windows-msvc.tgz 8CC1FFFD8AB5E526A076C29A767A650C436E31179D0C6E52C2EA936067B72566 +train,cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12-v1.18.1-x86_64-pc-windows-msvc.tgz 6AF64567E25B59AD1196D4953EF8C6A65795E8A4B864E10D8303A027AC50B2D0 +train,cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu11-v1.18.1-x86_64-pc-windows-msvc.tgz E14AA0F4FBBBCAF925AD4DB4F76B06402F654B36C5F221E00010D1005F47AE56 +train x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-pc-windows-msvc.tgz 84728438E5A950027EBBDC51463F4E5B99B4979087F0F127EA18BC604507E979 + none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-apple-darwin.tgz B42BE76AFB9495983A6D5D498D56D5E685B018F1011EF4C5B8C56124B192FD37 none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-apple-darwin.tgz 247F73A5B3665A6660DFB35213E6FEAAC6ED6CAC5816DD85A348DF790F60A30B +train 
aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-apple-darwin.tgz 29DC09AFA5C3619CF3125F3D55DD64E5EE64451D6BD0044527776849AADEE344 +train x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-apple-darwin.tgz 898EC9E3F852843ECDB618CF8E317F4C92BDEB33FC773038960857BCB37CB347 + none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-pkort_static-v1.18.1-wasm32-unknown-unknown.tgz D1BF756F02A53C3BC254E3C2048BE617082905A89182A6B1BD18C229920228EF diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index f5d3b130..f7cb853e 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -823,9 +823,117 @@ fn bindgen_test_layout_OrtOpenVINOProviderOptions() { } #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct OrtTrainingApi { +pub struct OrtTrainingSession { _unused: [u8; 0] } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtCheckpointState { + _unused: [u8; 0] +} +#[repr(i32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum OrtPropertyType { + OrtIntProperty = 0, + OrtFloatProperty = 1, + OrtStringProperty = 2 +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtTrainingApi { + pub LoadCheckpoint: + ::std::option::Option<_system!(unsafe fn(checkpoint_path: *const ortchar, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr)>, + pub SaveCheckpoint: ::std::option::Option< + _system!(unsafe fn(checkpoint_state: *mut OrtCheckpointState, checkpoint_path: *const ortchar, include_optimizer_state: bool) -> OrtStatusPtr) + >, + pub CreateTrainingSession: ::std::option::Option< + _system!( + unsafe fn( + env: *const OrtEnv, + options: *const OrtSessionOptions, + checkpoint_state: *mut OrtCheckpointState, + train_model_path: *const ortchar, + eval_model_path: *const ortchar, + optimizer_model_path: *const ortchar, + out: *mut *mut 
OrtTrainingSession + ) -> OrtStatusPtr + ) + >, + pub CreateTrainingSessionFromBuffer: ::std::option::Option< + _system!( + unsafe fn( + env: *const OrtEnv, + options: *const OrtSessionOptions, + checkpoint_state: *mut OrtCheckpointState, + train_model_data: *const (), + train_data_length: size_t, + eval_model_data: *const (), + eval_data_length: size_t, + optimizer_model_data: *const (), + optimizer_data_length: size_t, + out: *mut *mut OrtTrainingSession + ) -> OrtStatusPtr + ) + >, + pub TrainingSessionGetTrainingModelOutputCount: + ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, + pub TrainingSessionGetEvalModelOutputCount: ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, + pub TrainingSessionGetTrainingModelOutputName: ::std::option::Option< + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr) + >, + pub TrainingSessionGetEvalModelOutputName: ::std::option::Option< + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr) + >, + pub LazyResetGrad: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, + pub TrainStep: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + run_options: *const OrtRunOptions, + inputs_len: size_t, + inputs: *const *const OrtValue, + outputs_len: size_t, + outputs: *mut *mut OrtValue + ) -> OrtStatusPtr + ) + >, + pub EvalStep: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + run_options: *const OrtRunOptions, + inputs_len: size_t, + inputs: *const *const OrtValue, + outputs_len: size_t, + outputs: *mut *mut OrtValue + ) -> OrtStatusPtr + ) + >, + pub SetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut 
OrtTrainingSession, learning_rate: f32) -> OrtStatusPtr)>, + pub GetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: *mut f32) -> OrtStatusPtr)>, + pub OptimizerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, run_options: *const OrtRunOptions) -> OrtStatusPtr)>, + pub RegisterLinearLRScheduler: ::std::option::Option< + _system!(unsafe fn(session: *mut OrtTrainingSession, warmup_step_count: i64, total_step_count: i64, initial_lr: f32) -> OrtStatusPtr) + >, + pub SchedulerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, + pub GetParametersSize: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, out: *mut size_t, trainable_only: bool) -> OrtStatusPtr)>, + pub CopyParametersToBuffer: + ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, + pub CopyBufferToParameters: + ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, + pub ReleaseTrainingSession: ::std::option::Option<_system!(unsafe fn(input: *mut OrtTrainingSession))>, + pub ReleaseCheckpointState: ::std::option::Option<_system!(unsafe fn(input: *mut OrtCheckpointState))>, + pub ExportModelForInferencing: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + inference_model_path: *const ortchar, + graph_outputs_len: usize, + graph_output_names: *const *const c_char + ) -> OrtStatusPtr + ) + > +} #[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"] #[repr(C)] #[derive(Debug, Copy, Clone)] diff --git a/src/error.rs b/src/error.rs index 661e3325..2e84580a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -261,7 +261,9 @@ pub enum Error { #[error("Could't 
get `AllocatorType` from memory info: {0}")] GetAllocatorType(ErrorInternal), #[error("Could't get device ID from memory info: {0}")] - GetDeviceId(ErrorInternal) + GetDeviceId(ErrorInternal), + #[error("Training API is not enabled in this build of ONNX Runtime.")] + TrainingNotEnabled } impl Error { diff --git a/src/lib.rs b/src/lib.rs index 1dd8204b..3455a60f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,9 @@ pub(crate) mod metadata; pub(crate) mod operator; pub(crate) mod session; pub(crate) mod tensor; +#[cfg(feature = "training")] +pub(crate) mod training; +pub(crate) mod util; pub(crate) mod value; #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))] #[cfg(target_arch = "wasm32")] @@ -66,6 +69,9 @@ pub use self::session::{ #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use self::tensor::ArrayExtensions; pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; +#[cfg(feature = "training")] +#[cfg_attr(docsrs, doc(cfg(feature = "training")))] +pub use self::training::*; pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, diff --git a/src/session/builder.rs b/src/session/builder.rs index 8632c57d..458c6ade 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -1,9 +1,5 @@ #[cfg(any(feature = "operator-libraries", not(windows)))] use std::ffi::CString; -#[cfg(unix)] -use std::os::unix::ffi::OsStrExt; -#[cfg(target_family = "windows")] -use std::os::windows::ffi::OsStrExt; #[cfg(not(target_arch = "wasm32"))] use std::path::Path; #[cfg(feature = "fetch-models")] @@ -316,20 +312,7 @@ impl SessionBuilder { }); } - // Build an OsString, then a vector of bytes to pass to C - let model_path = 
std::ffi::OsString::from(model_filepath); - #[cfg(target_family = "windows")] - let model_path: Vec = model_path - .encode_wide() - .chain(std::iter::once(0)) // Make sure we have a null terminated string - .collect(); - #[cfg(not(target_family = "windows"))] - let model_path: Vec = model_path - .as_encoded_bytes() - .iter() - .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string - .map(|b| *b as std::os::raw::c_char) - .collect(); + let model_path = crate::util::path_to_os_char(model_filepath); let env = get_environment()?; apply_execution_providers(&self, env.execution_providers.iter().cloned())?; diff --git a/src/training/mod.rs b/src/training/mod.rs new file mode 100644 index 00000000..d66db11d --- /dev/null +++ b/src/training/mod.rs @@ -0,0 +1,142 @@ +use std::{ + path::Path, + ptr::{self, NonNull}, + sync::{ + atomic::{AtomicPtr, Ordering}, + OnceLock + } +}; + +use crate::{ortsys, Error, Result, RunOptions}; + +mod simple; +mod trainer; + +pub use self::{ + simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments}, + trainer::Trainer +}; + +pub(crate) static TRAINING_API: OnceLock> = OnceLock::new(); + +/// Returns a pointer to the global [`ort_sys::OrtTrainingApi`] object, or errors if the Training API is not enabled. +/// +/// # Panics +/// May panic if: +/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime. +/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled. +pub fn training_api() -> Result> { + NonNull::new( + TRAINING_API + .get_or_init(|| { + let training_api = ortsys![unsafe GetTrainingApi(ort_sys::ORT_API_VERSION)]; + AtomicPtr::new(training_api.cast_mut()) + }) + .load(Ordering::Relaxed) + ) + .ok_or(Error::TrainingNotEnabled) +} + +macro_rules! 
trainsys { + ($method:ident) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) + }; + (unsafe $method:ident) => { + unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) } + }; + ($method:ident($($n:expr),+ $(,)?)) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) + }; + (unsafe $method:ident($($n:expr),+ $(,)?)) => { + unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) } + }; + ($method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { + $crate::error::status_to_result($crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e) + }; + (unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { + $crate::error::status_to_result(unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e) + }; + ($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+); + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }; + (unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + let _x = unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }; + $($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+ + _x + }}; + ($method:ident($($n:expr),+ $(,)?) 
-> $err:expr$(;)?) => { + $crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?; + }; + (unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => { + $crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?; + }; + ($method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => { + $crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?; + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }; + (unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + $crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?; + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }}; +} +pub(crate) use trainsys; + +#[derive(Debug)] +pub struct Checkpoint { + pub(crate) ptr: NonNull +} + +impl Checkpoint { + pub fn load(path: impl AsRef) -> Result { + let path = crate::util::path_to_os_char(path); + let mut ptr: *mut ort_sys::OrtCheckpointState = ptr::null_mut(); + trainsys![unsafe LoadCheckpoint(path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + Ok(Checkpoint { + ptr: unsafe { NonNull::new_unchecked(ptr) } + }) + } + + pub fn save(&self, path: impl AsRef, include_optimizer_state: bool) -> Result<()> { + let path = crate::util::path_to_os_char(path); + trainsys![unsafe SaveCheckpoint(self.ptr.as_ptr(), path.as_ptr(), include_optimizer_state) -> Error::CreateSession]; + Ok(()) + } +} + +impl Drop for Checkpoint { + fn 
drop(&mut self) { + tracing::trace!("dropping checkpoint"); + trainsys![unsafe ReleaseCheckpointState(self.ptr.as_ptr())]; + } +} + +#[derive(Debug)] +pub struct Optimizer(NonNull); + +impl Optimizer { + pub fn reset_grad(&self) -> Result<()> { + trainsys![unsafe LazyResetGrad(self.0.as_ptr()) -> Error::CreateSession]; + Ok(()) + } + + pub fn lr(&self) -> Result { + let mut lr = f32::NAN; + trainsys![unsafe GetLearningRate(self.0.as_ptr(), &mut lr) -> Error::CreateSession]; + Ok(lr) + } + + pub fn set_lr(&self, lr: f32) -> Result<()> { + trainsys![unsafe SetLearningRate(self.0.as_ptr(), lr) -> Error::CreateSession]; + Ok(()) + } + + pub fn step(&self) -> Result<()> { + self.step_with_options(RunOptions::new()?) + } + + pub fn step_with_options(&self, options: RunOptions) -> Result<()> { + trainsys![unsafe OptimizerStep(self.0.as_ptr(), options.run_options_ptr.as_ptr()) -> Error::CreateSession]; + Ok(()) + } +} diff --git a/src/training/simple.rs b/src/training/simple.rs new file mode 100644 index 00000000..267f3c64 --- /dev/null +++ b/src/training/simple.rs @@ -0,0 +1,240 @@ +use std::{collections::VecDeque, fs, path::PathBuf}; + +use crate::{Result, SessionInputs}; + +#[allow(clippy::len_without_is_empty)] +pub trait DataLoader { + fn load(&mut self, idx: usize) -> Result<(I, L)>; + + fn len(&self) -> Option { + None + } +} + +pub struct IterableDataLoader Result<(I, L)>> { + items: Box<[T]>, + collator: C +} + +impl Result<(I, L)>> DataLoader for IterableDataLoader { + fn load(&mut self, idx: usize) -> Result<(I, L)> { + (self.collator)(&self.items[idx]) + } + + fn len(&self) -> Option { + Some(self.items.len()) + } +} + +pub fn iterable_data_loader Result<(I, L)>>(iterable: impl Iterator, collator: C) -> IterableDataLoader { + IterableDataLoader { + items: iterable.collect::>().into_boxed_slice(), + collator + } +} + +impl Result<(I, L)>> DataLoader for F { + fn load(&mut self, idx: usize) -> Result<(I, L)> { + (self)(idx) + } + + fn len(&self) -> Option { + 
None + } +} + +pub enum EvaluationStrategy { + None, + Steps(usize), + Epochs(usize) +} + +impl EvaluationStrategy { + pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option) -> bool { + match self { + Self::None => false, + Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0, + Self::Epochs(epochs) => { + if let Some(dataloader_size) = dataloader_size { + iter_step > 0 && iter_step % (dataloader_size * epochs) == 0 + } else { + false + } + } + } + } +} + +pub enum CheckpointStrategy { + None, + Steps(usize), + Epochs(usize) +} + +impl CheckpointStrategy { + pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option) -> bool { + match self { + Self::None => false, + Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0, + Self::Epochs(epochs) => { + if let Some(dataloader_size) = dataloader_size { + iter_step > 0 && iter_step % (dataloader_size * epochs) == 0 + } else { + false + } + } + } + } +} + +pub struct TrainingArguments>, L: Into>, const NI: usize, const NL: usize> { + loader: Box>, + eval_loader: Option>>, + eval_strategy: EvaluationStrategy, + ckpt_strategy: CheckpointStrategy, + ckpt_path: PathBuf, + lr: f32, + max_saved_ckpts: usize, + gradient_accumulation_steps: usize, + max_steps: usize, + max_eval_steps: usize +} + +impl>, L: Into>, const NI: usize, const NL: usize> + TrainingArguments +{ + pub fn new + 'static>(train_loader: D) -> Self { + Self { + loader: Box::new(train_loader), + eval_loader: None, + eval_strategy: EvaluationStrategy::None, + ckpt_strategy: CheckpointStrategy::Epochs(1), + ckpt_path: PathBuf::from("checkpoints"), + lr: 1e-4, + gradient_accumulation_steps: 1, + max_saved_ckpts: 1, + max_steps: usize::MAX, + max_eval_steps: usize::MAX + } + } + + pub fn with_lr(mut self, lr: f32) -> Self { + self.lr = lr; + self + } + + pub fn with_max_steps(mut self, steps: usize) -> Self { + self.max_steps = steps; + self + } + + pub fn 
with_max_eval_steps(mut self, steps: usize) -> Self { + self.max_eval_steps = steps; + self + } + + pub fn with_gradient_accumulation(mut self, steps: usize) -> Self { + self.gradient_accumulation_steps = steps; + self + } + + pub fn with_ckpt_path(mut self, path: impl Into) -> Self { + self.ckpt_path = path.into(); + self + } + + pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self { + self.ckpt_strategy = strategy; + self + } + + pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self { + self.max_saved_ckpts = max_ckpts; + self + } + + pub fn with_eval_loader + 'static>(mut self, eval_loader: D) -> Self { + self.eval_loader = Some(Box::new(eval_loader)); + self + } + + pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self { + self.eval_strategy = strategy; + self + } +} + +impl super::Trainer { + pub fn train>, L: Into>, const NI: usize, const NL: usize>( + &self, + mut args: TrainingArguments + ) -> crate::Result<()> { + let optimizer = self.optimizer(); + optimizer.set_lr(args.lr)?; + + let mut saved_ckpts = VecDeque::new(); + let mut global_step = 0; + for (iter_step, _) in (0..args.max_steps).enumerate() { + let epoch = iter_step / args.loader.len().unwrap_or(usize::MAX); + let (inputs, labels) = args.loader.load(iter_step)?; + let (inputs, labels) = (inputs.into(), labels.into()); + + let outputs = self.step(inputs, labels)?; + let loss = outputs[0].try_extract_scalar::()?; + println!("epoch={epoch} step={global_step} loss={loss}"); + + if iter_step % args.gradient_accumulation_steps == 0 { + optimizer.step()?; + optimizer.reset_grad()?; + global_step += 1; + } + + if args.ckpt_strategy.should_fire(global_step, iter_step, args.loader.len()) { + if !args.ckpt_path.exists() { + let _ = fs::create_dir_all(&args.ckpt_path); + } + + let ckpt_path = args.ckpt_path.join(format!("epoch={epoch},step={global_step}.ortckpt")); + self.checkpoint().save(&ckpt_path, true)?; + + saved_ckpts.push_front(ckpt_path.clone()); + 
while saved_ckpts.len() > args.max_saved_ckpts { + let Some(old_ckpt) = saved_ckpts.pop_back() else { + break; + }; + let _ = fs::remove_file(old_ckpt); + } + } + + if args + .eval_strategy + .should_fire(global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len())) + { + let eval_loss = self.eval_inner(&mut args)?; + println!("eval_loss={eval_loss}"); + } + } + Ok(()) + } + + pub(crate) fn eval_inner>, L: Into>, const NI: usize, const NL: usize>( + &self, + args: &mut TrainingArguments + ) -> crate::Result { + let Some(eval_loader) = &mut args.eval_loader else { + return Ok(0.0); + }; + + let mut total_loss = 0.0; + for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) { + let (inputs, labels) = eval_loader.load(step)?; + let (inputs, labels) = (inputs.into(), labels.into()); + + let outputs = self.eval_step(inputs, labels)?; + let loss = outputs[0].try_extract_scalar::()?; + total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.); + } + + Ok(total_loss) + } +} diff --git a/src/training/trainer.rs b/src/training/trainer.rs new file mode 100644 index 00000000..f7c7cb38 --- /dev/null +++ b/src/training/trainer.rs @@ -0,0 +1,235 @@ +use std::{ + ffi::CString, + path::Path, + ptr::{self, NonNull}, + sync::Arc +}; + +use ort_sys::c_char; + +use super::{trainsys, Checkpoint, Optimizer}; +use crate::{ + char_p_to_string, + error::{assert_non_null_pointer, status_to_result}, + Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value +}; + +#[derive(Debug)] +pub struct Trainer { + pub(crate) ptr: NonNull, + train_output_names: Vec, + optimizer: Optimizer, + ckpt: Checkpoint, + _allocator: Allocator +} + +impl Trainer { + pub fn new( + session_options: SessionBuilder, + allocator: Allocator, + ckpt: Checkpoint, + training_model_path: impl AsRef, + eval_model_path: impl AsRef, + optimizer_model_path: impl AsRef + ) -> Result { + let training_model_path = 
crate::util::path_to_os_char(training_model_path); + let eval_model_path = crate::util::path_to_os_char(eval_model_path); + let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path); + + let env = crate::get_environment()?; + + let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); + trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + + let mut train_output_len = 0; + trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len) -> Error::CreateSession]; + let train_output_names = (0..train_output_len) + .map(|i| { + let mut name_bytes: *mut c_char = std::ptr::null_mut(); + trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr.as_ptr(), &mut name_bytes) -> Error::CreateSession]; + let name = match char_p_to_string(name_bytes) { + Ok(name) => name, + Err(e) => { + unsafe { allocator.free(name_bytes) }; + return Err(e); + } + }; + unsafe { allocator.free(name_bytes) }; + Ok(name) + }) + .collect::>>()?; + + Ok(Self { + ptr, + _allocator: allocator, + train_output_names, + optimizer: Optimizer(ptr), + ckpt + }) + } + + pub fn new_from_artifacts( + session_options: SessionBuilder, + allocator: Allocator, + base_dir: impl AsRef, + override_ckpt: Option + ) -> Result { + let base_dir = base_dir.as_ref(); + let ckpt = if let Some(ckpt) = override_ckpt { + ckpt + } else { + Checkpoint::load(base_dir.join("checkpoint"))? 
+ }; + Self::new( + session_options, + allocator, + ckpt, + base_dir.join("training_model.onnx"), + base_dir.join("eval_model.onnx"), + base_dir.join("optimizer_model.onnx") + ) + } + + pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option<&'r RunOptions> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + 
.map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn eval_step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option<&'r RunOptions> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + 
run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + .map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn export>(&self, out_path: impl AsRef, output_names: impl AsRef<[O]>) -> Result<()> { + let out_path = crate::util::path_to_os_char(out_path); + + let output_names_ptr: Vec<*const c_char> = output_names + .as_ref() + .iter() + .map(|output| CString::new(output.as_ref()).unwrap_or_else(|_| unreachable!())) + .map(|n| n.into_raw().cast_const()) + .collect(); + + let res = trainsys![unsafe ExportModelForInferencing(self.ptr.as_ptr(), out_path.as_ptr(), output_names_ptr.len(), output_names_ptr.as_ptr())]; + + // Reconvert name ptrs to CString so drop impl is called and memory is freed + drop( + output_names_ptr + .into_iter() + .map(|p| { + assert_non_null_pointer(p, "c_char for CString")?; + unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } + }) + .collect::>>()? 
+ ); + + status_to_result(res).map_err(Error::CreateSession)?; + + Ok(()) + } + + pub fn optimizer(&self) -> &Optimizer { + &self.optimizer + } + + pub fn checkpoint(&self) -> &Checkpoint { + &self.ckpt + } +} + +impl Drop for Trainer { + fn drop(&mut self) { + tracing::trace!("dropping trainer"); + trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())]; + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 00000000..bfa11d98 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,26 @@ +#[cfg(not(target_family = "windows"))] +use std::os::raw::c_char; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(target_family = "windows")] +use std::os::windows::ffi::OsStrExt; +use std::{ffi::OsString, path::Path}; + +#[cfg(target_family = "windows")] +type OsCharArray = Vec; +#[cfg(not(target_family = "windows"))] +type OsCharArray = Vec; + +pub fn path_to_os_char(path: impl AsRef) -> OsCharArray { + let model_path = OsString::from(path.as_ref()); + #[cfg(target_family = "windows")] + let model_path: Vec = model_path.encode_wide().chain(std::iter::once(0)).collect(); + #[cfg(not(target_family = "windows"))] + let model_path: Vec = model_path + .as_encoded_bytes() + .iter() + .chain(std::iter::once(&b'\0')) + .map(|b| *b as c_char) + .collect(); + model_path +} diff --git a/tools/requirements.txt b/tools/requirements.txt new file mode 100644 index 00000000..d49cd910 --- /dev/null +++ b/tools/requirements.txt @@ -0,0 +1,4 @@ +torch~=2.3 +torch-ort~=1.17 +onnx~=1.16 +--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 diff --git a/tools/train-data/mini-clm.py b/tools/train-data/mini-clm.py new file mode 100644 index 00000000..6f06a70f --- /dev/null +++ b/tools/train-data/mini-clm.py @@ -0,0 +1,140 @@ +import math + +import onnx +from onnxruntime.training import artifacts +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +class RMSNorm(nn.Module): + 
def __init__(self, dim: int, *, eps: float = 1e-6): + super().__init__() + + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + if x.dtype != torch.float32: + xf = x.to(dtype=torch.float32) + else: + xf = x + output = (xf * torch.sqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)) + if x.dtype != torch.float32: + output = output.to(dtype=x.dtype) + return output * self.weight + +class RoPE(nn.Module): + def __init__(self, embedding_dim: int, *, max_seq_length: int = 2048, base: float = 10000.0): + super().__init__() + + pe = torch.zeros(max_seq_length, embedding_dim) + position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embedding_dim, step=2).float() * (-math.log(base) / embedding_dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe, persistent=False) + + @torch.no_grad() + def forward(self, x: Tensor) -> Tensor: + return x + self.pe[:, :x.shape[1], :] + +class Attention(nn.Module): + def __init__(self, embedding_dim: int, *, rope: RoPE, max_seq_length: int = 2048, n_heads: int = 4): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_heads = n_heads + self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=False) + self.proj = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.rope = rope + self.register_buffer('bias', torch.tril(torch.ones(max_seq_length, max_seq_length))[None, None, :, :], persistent=False) + + def forward(self, x: Tensor) -> Tensor: + b, t, c = x.size() + + x = self.rope(x) + + q, k, v = self.qkv(x).split(self.embedding_dim, dim=2) + q = q.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + k = k.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + v = v.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + + att = (q @ k.transpose(-2, -1)) * (1.0 / 
math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + y = att @ v + y = y.transpose(1, 2).contiguous().view(b, t, c) + + return self.proj(y) + +class FFN(nn.Module): + def __init__(self, embedding_dim: int, intermediate_dim: int | None = None): + super().__init__() + + intermediate_dim = intermediate_dim or embedding_dim * 4 + + self.w1 = nn.Linear(embedding_dim, intermediate_dim * 2, bias=False) + self.w2 = nn.Linear(intermediate_dim, embedding_dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x, gate = self.w1(x).chunk(2, dim=-1) + return self.w2(F.gelu(gate) * x) + +class Layer(nn.Module): + def __init__(self, embedding_dim: int, rope: RoPE): + super().__init__() + + self.attn = Attention(embedding_dim, rope=rope) + self.norm1 = RMSNorm(embedding_dim) + self.ffn = FFN(embedding_dim) + self.norm2 = RMSNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x + +class CLM(nn.Module): + def __init__(self, embedding_dim: int, n_layers: int, *, vocab_size: int): + super().__init__() + + rope = RoPE(embedding_dim) + self.layers = nn.ModuleList([Layer(embedding_dim, rope=rope) for _ in range(n_layers)]) + self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) + self.norm = RMSNorm(embedding_dim) + self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x = self.word_embeddings(x) + for layer in self.layers: + x = layer(x) + logits = self.lm_head(self.norm(x)) + return logits.view(-1, logits.size(-1)) + +lm = CLM(256, 4, vocab_size=50257) +torch.onnx.export( + lm, + torch.randint(0, 50256, (1, 64)), + f'tools/train-data/mini-clm/model.onnx', + input_names=['input_ids'], + output_names=['probs'], + dynamic_axes={ + 'input_ids': {0: 'batch', 1: 'seq'}, + 'probs': {0: 'batch_seq'} + }, + opset_version=14 +) + +onnx_model = 
onnx.load('tools/train-data/mini-clm/model.onnx') +requires_grad = [param.name for param in onnx_model.graph.initializer] + +artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + frozen_params=[], + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory='tools/train-data/mini-clm' +) From 3dec0173e8a46eaab2a12555ca4ff3d8a44aba93 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 6 Jul 2024 12:16:29 -0500 Subject: [PATCH 45/49] 2.0.0-rc.3 --- Cargo.toml | 6 +++--- README.md | 2 +- docs/pages/_meta.json | 2 +- docs/pages/index.mdx | 4 ++-- docs/pages/migrating/v2.mdx | 2 +- docs/pages/perf/execution-providers.mdx | 2 +- docs/pages/setup/cargo-features.mdx | 4 ++-- docs/pages/setup/platforms.mdx | 2 +- ort-sys/Cargo.toml | 4 ++-- 9 files changed, 14 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c1a70bb0..173eff42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,8 +23,8 @@ exclude = [ 'examples/cudarc' ] [package] name = "ort" -description = "A safe Rust wrapper for ONNX Runtime 1.17 - Optimize and Accelerate Machine Learning Inferencing" -version = "2.0.0-rc.2" +description = "A safe Rust wrapper for ONNX Runtime 1.18 - Optimize and accelerate machine learning inference & training" +version = "2.0.0-rc.3" edition = "2021" rust-version = "1.70" license = "MIT OR Apache-2.0" @@ -83,7 +83,7 @@ qnn = [ "ort-sys/qnn" ] [dependencies] ndarray = { version = "0.15", optional = true } thiserror = "1.0" -ort-sys = { version = "2.0.0-rc.2", path = "ort-sys" } +ort-sys = { version = "2.0.0-rc.3", path = "ort-sys" } libloading = { version = "0.8", optional = true } ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] } diff --git a/README.md b/README.md index 128f921d..e3bdcc7d 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Crates.io ONNX Runtime -`ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.18 wrapper for Rust based 
on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference on both CPU & GPU. +`ort` is an (unofficial) [ONNX Runtime](https://onnxruntime.ai/) 1.18 wrapper for Rust based on the now inactive [`onnxruntime-rs`](https://github.com/nbigaouette/onnxruntime-rs). ONNX Runtime accelerates ML inference and training on both CPU & GPU. ## 📖 Documentation - [Guide](https://ort.pyke.io/) diff --git a/docs/pages/_meta.json b/docs/pages/_meta.json index 14840b87..5fe05b28 100644 --- a/docs/pages/_meta.json +++ b/docs/pages/_meta.json @@ -10,7 +10,7 @@ }, "link-api": { "title": "API Reference ↗", - "href": "https://docs.rs/ort/2.0.0-rc.2/ort" + "href": "https://docs.rs/ort/2.0.0-rc.3/ort" }, "link-crates": { "title": "Crates.io ↗", diff --git a/docs/pages/index.mdx b/docs/pages/index.mdx index 97ab9f5e..d6add90e 100644 --- a/docs/pages/index.mdx +++ b/docs/pages/index.mdx @@ -11,7 +11,7 @@ import { Callout, Card, Cards, Steps } from 'nextra/components'; - These docs are for the latest alpha version of `ort`, `2.0.0-rc.2`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. + These docs are for the latest alpha version of `ort`, `2.0.0-rc.3`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. `ort` makes it easy to deploy your machine learning models to production via [ONNX Runtime](https://onnxruntime.ai/), a hardware-accelerated inference engine. With `ort` + ONNX Runtime, you can run almost any ML model (including ResNet, YOLOv8, BERT, LLaMA) on almost any hardware, often far faster than PyTorch, and with the added bonus of Rust's efficiency. @@ -37,7 +37,7 @@ Converting a neural network to a graph representation like ONNX opens the door t If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! 
Just add it to your Cargo dependencies: ```toml [dependencies] -ort = "2.0.0-rc.2" +ort = "2.0.0-rc.3" ``` ### Convert your model diff --git a/docs/pages/migrating/v2.mdx b/docs/pages/migrating/v2.mdx index a3c1afdf..c18776b8 100644 --- a/docs/pages/migrating/v2.mdx +++ b/docs/pages/migrating/v2.mdx @@ -173,7 +173,7 @@ let l = outputs["latents"].try_extract_tensor::()?; ``` ## Execution providers -Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.2/ort/index.html?search=ExecutionProvider) and the [execution providers reference](/perf/execution-providers) for more information. +Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.3/ort/index.html?search=ExecutionProvider) and the [execution providers reference](/perf/execution-providers) for more information. ```diff -// v1.x diff --git a/docs/pages/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx index b084463e..fbc59759 100644 --- a/docs/pages/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -83,7 +83,7 @@ fn main() -> anyhow::Result<()> { ``` ## Configuring EPs -EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.2/ort/index.html?search=ExecutionProvider) for the EP structs for more information on which options are supported and what they do. +EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.3/ort/index.html?search=ExecutionProvider) for the EP structs for more information on which options are supported and what they do. 
```rust use ort::{CoreMLExecutionProvider, Session}; diff --git a/docs/pages/setup/cargo-features.mdx b/docs/pages/setup/cargo-features.mdx index 9fc92416..9e9f72ef 100644 --- a/docs/pages/setup/cargo-features.mdx +++ b/docs/pages/setup/cargo-features.mdx @@ -9,8 +9,8 @@ title: Cargo features - ✅ **`half`**: Enables support for float16 & bfloat16 tensors via the [`half`](https://crates.io/crates/half) crate. ONNX models that are converted to 16-bit precision will typically convert to/from 32-bit floats at the input/output, so you will likely never actually need to interact with a 16-bit tensor on the Rust side. Though, `half` isn't a heavy enough crate to worry about it affecting compile times. - ✅ **`copy-dylibs`**: In case dynamic libraries are used (like with the CUDA execution provider), creates a symlink to them in the relevant places in the `target` folder to make [compile-time dynamic linking](/setup/linking#compile-time-dynamic-linking) work. - ⚒️ **`load-dynamic`**: Enables [runtime dynamic linking](/setup/linking#runtime-loading-with-load-dynamic), which alleviates many of the troubles with compile-time dynamic linking and offers greater flexibility. -- ⚒️ **`fetch-models`**: Enables the [`SessionBuilder::commit_from_url`](https://docs.rs/ort/2.0.0-rc.2/ort/struct.SessionBuilder.html#method.commit_from_url) method, allowing you to quickly download & run a model from a URL. This should only be used for quick testing. -- ⚒️ **`operator-libraries`**: Allows for sessions to load custom operators from dynamic C++ libraries via [`SessionBuilder::with_operator_library`](https://docs.rs/ort/2.0.0-rc.2/ort/struct.SessionBuilder.html#method.with_operator_library). If possible, we recommend [writing your custom ops in Rust instead](/perf/custom-operators). 
+- ⚒️ **`fetch-models`**: Enables the [`SessionBuilder::commit_from_url`](https://docs.rs/ort/2.0.0-rc.3/ort/struct.SessionBuilder.html#method.commit_from_url) method, allowing you to quickly download & run a model from a URL. This should only be used for quick testing. +- ⚒️ **`operator-libraries`**: Allows for sessions to load custom operators from dynamic C++ libraries via [`SessionBuilder::with_operator_library`](https://docs.rs/ort/2.0.0-rc.3/ort/struct.SessionBuilder.html#method.with_operator_library). If possible, we recommend [writing your custom ops in Rust instead](/perf/custom-operators). ## Execution providers Each [execution provider](/perf/execution-providers) is also gated behind a Cargo feature. diff --git a/docs/pages/setup/platforms.mdx b/docs/pages/setup/platforms.mdx index f83d131b..6fbec097 100644 --- a/docs/pages/setup/platforms.mdx +++ b/docs/pages/setup/platforms.mdx @@ -5,7 +5,7 @@ description: ONNX Runtime, and by extension `ort`, supports a wide variety of pl import { Callout } from 'nextra/components'; -Here are the supported platforms and binary availability status, as of v2.0.0-rc.2. +Here are the supported platforms and binary availability status, as of v2.0.0-rc.3. * 🟢 - Supported. Dynamic & static binaries provided by pyke. * 🔷 - Supported. Static binaries provided by pyke. diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index 015a4656..e07f561b 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "ort-sys" -description = "Unsafe Rust bindings for ONNX Runtime 1.17 - Optimize and Accelerate Machine Learning Inferencing" -version = "2.0.0-rc.2" +description = "Unsafe Rust bindings for ONNX Runtime 1.18 - Optimize and Accelerate Machine Learning Inferencing" +version = "2.0.0-rc.3" edition = "2021" rust-version = "1.70" license = "MIT OR Apache-2.0" From bb57252f1087a75abbf5bd442b171d8f6d48533a Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sun, 7 Jul 2024 12:22:06 -0500 Subject: [PATCH 46/49] fix(sys): more robust CUDA version check. fixes #234 --- ort-sys/build.rs | 69 ++++++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 719a2059..465ed2a4 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -1,8 +1,12 @@ use std::{ env, fs, - path::{Path, PathBuf} + path::{Path, PathBuf}, + process::Command }; +#[allow(unused)] +const ONNXRUNTIME_VERSION: &str = "1.18.1"; + const ORT_ENV_SYSTEM_LIB_LOCATION: &str = "ORT_LIB_LOCATION"; const ORT_ENV_SYSTEM_LIB_PROFILE: &str = "ORT_LIB_PROFILE"; #[cfg(feature = "download-binaries")] @@ -46,29 +50,6 @@ fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static .map(|c| (c[2], c[3])) } -fn lib_exists(name: &str) -> bool { - #[cfg(any(target_family = "windows", unix))] - let lib_str = std::ffi::CString::new(name).unwrap(); - // note that we're not performing any cleanup here because this is a short lived build script; the OS will clean it up - // for us when we finish - #[cfg(target_family = "windows")] - return unsafe { - extern "C" { - fn LoadLibraryA(lplibfilename: *const std::ffi::c_char) -> isize; - } - LoadLibraryA(lib_str.as_ptr()) != 0 - }; - #[cfg(unix)] - return unsafe { - extern "C" { - fn dlopen(file: *const std::ffi::c_char, mode: std::ffi::c_int) -> *const std::ffi::c_void; - } - !dlopen(lib_str.as_ptr(), 1).is_null() - }; - #[cfg(not(any(target_family = "windows", unix)))] - return false; -} - #[cfg(feature = "download-binaries")] fn hex_str_to_bytes(c: impl AsRef<[u8]>) -> Vec { fn nibble(c: u8) -> u8 { @@ -136,7 +117,7 @@ fn copy_libraries(lib_dir: &Path, out_dir: &Path) { #[cfg(target_os = "linux")] { let main_dy = lib_dir.join("libonnxruntime.so"); - let versioned_dy = out_dir.join("libonnxruntime.so.1.17.3"); + let versioned_dy = out_dir.join(format!("libonnxruntime.so.{}", ONNXRUNTIME_VERSION)); if main_dy.exists() && 
!versioned_dy.exists() { if versioned_dy.is_symlink() { fs::remove_file(&versioned_dy).unwrap(); @@ -347,15 +328,38 @@ fn prepare_libort_dir() -> (PathBuf, bool) { feature_set.push("train"); } if cfg!(any(feature = "cuda", feature = "tensorrt")) { - if lib_exists("cudart64_11.dll") || lib_exists("libcudart.so.11") || env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() == Ok("11") { - feature_set.push("cu11"); - } else { - feature_set.push("cu12"); + match env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() { + Ok("11") => feature_set.push("cu11"), + Ok("12") => feature_set.push("cu12"), + _ => { + let mut success = false; + if let Ok(nvcc_output) = Command::new("nvcc").arg("--version").output() { + if nvcc_output.status.success() { + let stdout = String::from_utf8_lossy(&nvcc_output.stdout); + let version_line = stdout.lines().nth(3).unwrap(); + let release_section = version_line.split(", ").nth(1).unwrap(); + let version_number = release_section.split(' ').nth(1).unwrap(); + if version_number.starts_with("12") { + feature_set.push("cu12"); + } else { + feature_set.push("cu11"); + } + success = true; + } + } + + if !success { + println!("cargo:warning=nvcc call did not succeed. falling back to CUDA 12"); + // fallback to CUDA 12. 
+ feature_set.push("cu12"); + } + } } } else if cfg!(feature = "rocm") { feature_set.push("rocm"); } let feature_set = if !feature_set.is_empty() { feature_set.join(",") } else { "none".to_owned() }; + println!("selected feature set: {feature_set}"); let mut dist = find_dist(&target, &feature_set); if dist.is_none() && feature_set != "none" { dist = find_dist(&target, "none"); @@ -411,6 +415,13 @@ fn prepare_libort_dir() -> (PathBuf, bool) { fn try_setup_with_pkg_config() -> bool { match pkg_config::Config::new().probe("libonnxruntime") { Ok(lib) => { + let expected_minor = ONNXRUNTIME_VERSION.split('.').nth(1).unwrap().parse::().unwrap(); + let got_minor = lib.version.split('.').nth(1).unwrap().parse::().unwrap(); + if got_minor < expected_minor { + println!("libonnxruntime provided by pkg-config is out of date, so it will be ignored - expected {}, got {}", ONNXRUNTIME_VERSION, lib.version); + return false; + } + // Setting the link paths for path in lib.link_paths { println!("cargo:rustc-link-search=native={}", path.display()); From b2f4e1f7722b7d9a9c5948f4b7d3e4cb32c1525b Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 7 Jul 2024 12:26:55 -0500 Subject: [PATCH 47/49] fix: create an environment if one does not exist when registering TensorRT EP. ref #236 --- src/execution_providers/tensorrt.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/execution_providers/tensorrt.rs b/src/execution_providers/tensorrt.rs index 1ea8dd8d..599eaf09 100644 --- a/src/execution_providers/tensorrt.rs +++ b/src/execution_providers/tensorrt.rs @@ -230,6 +230,11 @@ impl ExecutionProvider for TensorRTExecutionProvider { fn register(&self, session_builder: &SessionBuilder) -> Result<()> { #[cfg(any(feature = "load-dynamic", feature = "tensorrt"))] { + // The TensorRT execution provider specifically is pretty picky about requiring an environment to be initialized by the + // time we register it. 
This isn't always the case in `ort`, so if we get to this point, let's make sure we have an + // environment initialized. + let _ = crate::get_environment(); + let mut trt_options: *mut ort_sys::OrtTensorRTProviderOptionsV2 = std::ptr::null_mut(); crate::error::status_to_result(crate::ortsys![unsafe CreateTensorRTProviderOptions(&mut trt_options)]).map_err(Error::ExecutionProvider)?; let (key_ptrs, value_ptrs, len, keys, values) = super::map_keys! { From bc764d388f213a631db934029af6463f19cacd5e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 7 Jul 2024 14:54:40 -0500 Subject: [PATCH 48/49] feat(sys): CUDA 12 + cuDNN 8 builds, ref #235 --- ort-sys/build.rs | 16 ++++++++++++++-- ort-sys/dist.txt | 5 +++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 465ed2a4..a7bbcf0b 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -328,6 +328,18 @@ fn prepare_libort_dir() -> (PathBuf, bool) { feature_set.push("train"); } if cfg!(any(feature = "cuda", feature = "tensorrt")) { + // pytorch's CUDA docker images set `NV_CUDNN_VERSION` + let cu12_tag = match env::var("NV_CUDNN_VERSION").or_else(|_| env::var("ORT_CUDNN_VERSION")).as_deref() { + Ok(v) => { + if v.starts_with("8") { + "cu12+cudnn8" + } else { + "cu12" + } + } + Err(_) => "cu12" + }; + match env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() { Ok("11") => feature_set.push("cu11"), Ok("12") => feature_set.push("cu12"), @@ -340,7 +352,7 @@ fn prepare_libort_dir() -> (PathBuf, bool) { let release_section = version_line.split(", ").nth(1).unwrap(); let version_number = release_section.split(' ').nth(1).unwrap(); if version_number.starts_with("12") { - feature_set.push("cu12"); + feature_set.push(cu12_tag); } else { feature_set.push("cu11"); } @@ -351,7 +363,7 @@ fn prepare_libort_dir() -> (PathBuf, bool) { if !success { println!("cargo:warning=nvcc call did not succeed. falling back to CUDA 12"); // fallback to CUDA 12. 
- feature_set.push("cu12"); + feature_set.push(cu12_tag); } } } diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 98e3f3f0..9910a5e4 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -27,3 +27,8 @@ train aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/mso train x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-apple-darwin.tgz 898EC9E3F852843ECDB618CF8E317F4C92BDEB33FC773038960857BCB37CB347 none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-pkort_static-v1.18.1-wasm32-unknown-unknown.tgz D1BF756F02A53C3BC254E3C2048BE617082905A89182A6B1BD18C229920228EF + +train,cu12+cudnn8 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12+cudnn8-v1.18.1-x86_64-pc-windows-msvc.tgz 52F02DBF276409DC49533373DE89B17FDE0CCB31F9974CCF31F250DC51258971 +train,cu12+cudnn8 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12+cudnn8-v1.18.1-x86_64-unknown-linux-gnu.tgz EE0580CA961CE512ECF7C1087FB081E74C780A494EAC95596CEF1089AB573242 +cu12+cudnn8 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12+cudnn8-v1.18.1-x86_64-unknown-linux-gnu.tgz F8D72E825F744A7A7BF2036591CBE6D1F30352DBE108BEEAD1745BD571566819 +cu12+cudnn8 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12+cudnn8-v1.18.1-x86_64-pc-windows-msvc.tgz D41121A6489B52EB7AF9614D7924AE984F4F10BF49F15E0C4FC2655649A978ED From 04da381f4da0127ce26e2f6ccab1e7e5a8cd1985 Mon Sep 17 00:00:00 2001 From: "Carson M." 
Date: Sun, 7 Jul 2024 15:50:08 -0500 Subject: [PATCH 49/49] 2.0.0-rc.4 --- Cargo.toml | 4 ++-- docs/pages/_meta.json | 2 +- docs/pages/index.mdx | 4 ++-- docs/pages/migrating/v2.mdx | 2 +- docs/pages/perf/execution-providers.mdx | 2 +- docs/pages/setup/cargo-features.mdx | 4 ++-- docs/pages/setup/platforms.mdx | 2 +- ort-sys/Cargo.toml | 2 +- 8 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 173eff42..28ed81c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ exclude = [ 'examples/cudarc' ] [package] name = "ort" description = "A safe Rust wrapper for ONNX Runtime 1.18 - Optimize and accelerate machine learning inference & training" -version = "2.0.0-rc.3" +version = "2.0.0-rc.4" edition = "2021" rust-version = "1.70" license = "MIT OR Apache-2.0" @@ -83,7 +83,7 @@ qnn = [ "ort-sys/qnn" ] [dependencies] ndarray = { version = "0.15", optional = true } thiserror = "1.0" -ort-sys = { version = "2.0.0-rc.3", path = "ort-sys" } +ort-sys = { version = "2.0.0-rc.4", path = "ort-sys" } libloading = { version = "0.8", optional = true } ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] } diff --git a/docs/pages/_meta.json b/docs/pages/_meta.json index 5fe05b28..a58afd92 100644 --- a/docs/pages/_meta.json +++ b/docs/pages/_meta.json @@ -10,7 +10,7 @@ }, "link-api": { "title": "API Reference ↗", - "href": "https://docs.rs/ort/2.0.0-rc.3/ort" + "href": "https://docs.rs/ort/2.0.0-rc.4/ort" }, "link-crates": { "title": "Crates.io ↗", diff --git a/docs/pages/index.mdx b/docs/pages/index.mdx index d6add90e..8034d46c 100644 --- a/docs/pages/index.mdx +++ b/docs/pages/index.mdx @@ -11,7 +11,7 @@ import { Callout, Card, Cards, Steps } from 'nextra/components'; - These docs are for the latest alpha version of `ort`, `2.0.0-rc.3`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. 
+ These docs are for the latest alpha version of `ort`, `2.0.0-rc.4`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. `ort` makes it easy to deploy your machine learning models to production via [ONNX Runtime](https://onnxruntime.ai/), a hardware-accelerated inference engine. With `ort` + ONNX Runtime, you can run almost any ML model (including ResNet, YOLOv8, BERT, LLaMA) on almost any hardware, often far faster than PyTorch, and with the added bonus of Rust's efficiency. @@ -37,7 +37,7 @@ Converting a neural network to a graph representation like ONNX opens the door t If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! Just add it to your Cargo dependencies: ```toml [dependencies] -ort = "2.0.0-rc.3" +ort = "2.0.0-rc.4" ``` ### Convert your model diff --git a/docs/pages/migrating/v2.mdx b/docs/pages/migrating/v2.mdx index c18776b8..f1a6af3f 100644 --- a/docs/pages/migrating/v2.mdx +++ b/docs/pages/migrating/v2.mdx @@ -173,7 +173,7 @@ let l = outputs["latents"].try_extract_tensor::()?; ``` ## Execution providers -Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.3/ort/index.html?search=ExecutionProvider) and the [execution providers reference](/perf/execution-providers) for more information. +Execution provider structs with public fields have been replaced with builder pattern structs. See the [API reference](https://docs.rs/ort/2.0.0-rc.4/ort/index.html?search=ExecutionProvider) and the [execution providers reference](/perf/execution-providers) for more information. 
```diff -// v1.x diff --git a/docs/pages/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx index fbc59759..f7f20aaa 100644 --- a/docs/pages/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -83,7 +83,7 @@ fn main() -> anyhow::Result<()> { ``` ## Configuring EPs -EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.3/ort/index.html?search=ExecutionProvider) for the EP structs for more information on which options are supported and what they do. +EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.4/ort/index.html?search=ExecutionProvider) for the EP structs for more information on which options are supported and what they do. ```rust use ort::{CoreMLExecutionProvider, Session}; diff --git a/docs/pages/setup/cargo-features.mdx b/docs/pages/setup/cargo-features.mdx index 9e9f72ef..ad495f72 100644 --- a/docs/pages/setup/cargo-features.mdx +++ b/docs/pages/setup/cargo-features.mdx @@ -9,8 +9,8 @@ title: Cargo features - ✅ **`half`**: Enables support for float16 & bfloat16 tensors via the [`half`](https://crates.io/crates/half) crate. ONNX models that are converted to 16-bit precision will typically convert to/from 32-bit floats at the input/output, so you will likely never actually need to interact with a 16-bit tensor on the Rust side. Though, `half` isn't a heavy enough crate to worry about it affecting compile times. - ✅ **`copy-dylibs`**: In case dynamic libraries are used (like with the CUDA execution provider), creates a symlink to them in the relevant places in the `target` folder to make [compile-time dynamic linking](/setup/linking#compile-time-dynamic-linking) work. 
- ⚒️ **`load-dynamic`**: Enables [runtime dynamic linking](/setup/linking#runtime-loading-with-load-dynamic), which alleviates many of the troubles with compile-time dynamic linking and offers greater flexibility. -- ⚒️ **`fetch-models`**: Enables the [`SessionBuilder::commit_from_url`](https://docs.rs/ort/2.0.0-rc.3/ort/struct.SessionBuilder.html#method.commit_from_url) method, allowing you to quickly download & run a model from a URL. This should only be used for quick testing. -- ⚒️ **`operator-libraries`**: Allows for sessions to load custom operators from dynamic C++ libraries via [`SessionBuilder::with_operator_library`](https://docs.rs/ort/2.0.0-rc.3/ort/struct.SessionBuilder.html#method.with_operator_library). If possible, we recommend [writing your custom ops in Rust instead](/perf/custom-operators). +- ⚒️ **`fetch-models`**: Enables the [`SessionBuilder::commit_from_url`](https://docs.rs/ort/2.0.0-rc.4/ort/struct.SessionBuilder.html#method.commit_from_url) method, allowing you to quickly download & run a model from a URL. This should only be used for quick testing. +- ⚒️ **`operator-libraries`**: Allows for sessions to load custom operators from dynamic C++ libraries via [`SessionBuilder::with_operator_library`](https://docs.rs/ort/2.0.0-rc.4/ort/struct.SessionBuilder.html#method.with_operator_library). If possible, we recommend [writing your custom ops in Rust instead](/perf/custom-operators). ## Execution providers Each [execution provider](/perf/execution-providers) is also gated behind a Cargo feature. diff --git a/docs/pages/setup/platforms.mdx b/docs/pages/setup/platforms.mdx index 6fbec097..c7a55c9d 100644 --- a/docs/pages/setup/platforms.mdx +++ b/docs/pages/setup/platforms.mdx @@ -5,7 +5,7 @@ description: ONNX Runtime, and by extension `ort`, supports a wide variety of pl import { Callout } from 'nextra/components'; -Here are the supported platforms and binary availability status, as of v2.0.0-rc.3. 
+Here are the supported platforms and binary availability status, as of v2.0.0-rc.4. * 🟢 - Supported. Dynamic & static binaries provided by pyke. * 🔷 - Supported. Static binaries provided by pyke. diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index e07f561b..4907b6d9 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "ort-sys" description = "Unsafe Rust bindings for ONNX Runtime 1.18 - Optimize and Accelerate Machine Learning Inferencing" -version = "2.0.0-rc.3" +version = "2.0.0-rc.4" edition = "2021" rust-version = "1.70" license = "MIT OR Apache-2.0"