18
18
import os
19
19
import signal
20
20
import socket
21
+ import sys
21
22
import time
22
23
from multiprocessing import Process
23
24
from typing import Callable , Dict , List , Optional , Tuple
@@ -42,8 +43,31 @@ def _start_triton_server(
42
43
model_path : Optional [str ] = None ,
43
44
) -> List [tuple ]:
44
45
"""Task to start Triton server process on a Spark executor."""
45
- sig = inspect .signature (triton_server_fn )
46
- params = sig .parameters
46
+
47
+ def _prepare_pytriton_env ():
48
+ """Expose PyTriton to correct libpython3.11.so and Triton bundled libraries."""
49
+ ld_library_paths = []
50
+
51
+ # Add nvidia_pytriton.libs to LD_LIBRARY_PATH
52
+ for path in sys .path :
53
+ if os .path .isdir (path ) and "site-packages" in path :
54
+ libs_path = os .path .join (path , "nvidia_pytriton.libs" )
55
+ if os .path .isdir (libs_path ):
56
+ ld_library_paths .append (libs_path )
57
+ break
58
+
59
+ # Add ${CONDA_PREFIX}/lib to LD_LIBRARY_PATH for conda environments
60
+ if os .path .exists (os .path .join (sys .prefix , "conda-meta" )):
61
+ conda_lib = os .path .join (sys .prefix , "lib" )
62
+ if os .path .isdir (conda_lib ):
63
+ ld_library_paths .append (conda_lib )
64
+
65
+ if "LD_LIBRARY_PATH" in os .environ :
66
+ ld_library_paths .append (os .environ ["LD_LIBRARY_PATH" ])
67
+
68
+ os .environ ["LD_LIBRARY_PATH" ] = ":" .join (ld_library_paths )
69
+
70
+ return None
47
71
48
72
def _find_ports (start_port : int = 7000 ) -> List [int ]:
49
73
"""Find available ports for Triton's HTTP, gRPC, and metrics services."""
@@ -59,6 +83,8 @@ def _find_ports(start_port: int = 7000) -> List[int]:
59
83
return ports
60
84
61
85
ports = _find_ports ()
86
+ sig = inspect .signature (triton_server_fn )
87
+ params = sig .parameters
62
88
63
89
if model_path is not None :
64
90
assert (
@@ -69,6 +95,7 @@ def _find_ports(start_port: int = 7000) -> List[int]:
69
95
assert len (params ) == 1 , "Server function must accept (ports) argument"
70
96
args = (ports ,)
71
97
98
+ _prepare_pytriton_env ()
72
99
hostname = socket .gethostname ()
73
100
process = Process (target = triton_server_fn , args = args )
74
101
process .start ()
@@ -83,6 +110,11 @@ def _find_ports(start_port: int = 7000) -> List[int]:
83
110
except Exception :
84
111
pass
85
112
113
+ client .close ()
114
+ if process .is_alive ():
115
+ # Terminate if timeout is exceeded to avoid dangling server processes
116
+ process .terminate ()
117
+
86
118
raise TimeoutError (
87
119
"Failure: server startup timeout exceeded. Check the executor logs for more info."
88
120
)
@@ -98,14 +130,19 @@ def _stop_triton_server(
98
130
pid , _ = server_pids_ports .get (hostname )
99
131
assert pid is not None , f"No server PID found for host { hostname } "
100
132
101
- for _ in range (wait_retries ):
133
+ try :
134
+ process = psutil .Process (pid )
135
+ process .terminate ()
136
+ process .wait (timeout = wait_timeout * wait_retries )
137
+ return [True ]
138
+ except psutil .NoSuchProcess :
139
+ return [True ]
140
+ except psutil .TimeoutExpired :
102
141
try :
103
- os .kill (pid , signal .SIGTERM )
104
- except OSError :
142
+ process .kill ()
105
143
return [True ]
106
- time .sleep (wait_timeout )
107
-
108
- return [False ] # Failed to terminate or timed out
144
+ except :
145
+ return [False ]
109
146
110
147
111
148
class TritonServerManager :
0 commit comments