Skip to content

Commit

Permalink
catch some uncaught exceptions in ws transport
Browse files Browse the repository at this point in the history
  • Loading branch information
nobody committed Jan 12, 2025
1 parent 2f94256 commit 3f793eb
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 105 deletions.
182 changes: 96 additions & 86 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ const BAD_REQUEST = new Response(null, {
statusText: 'Bad Request',
})

function get_length(o) {
return (o && (o.byteLength || o.length)) || 0
}

function validate_uuid(left, right) {
for (let i = 0; i < 16; i++) {
if (left[i] !== right[i]) {
Expand Down Expand Up @@ -130,7 +126,7 @@ function random_uuid() {
}

function random_padding(range_str) {
if (!range_str || typeof range_str !== 'string') {
if (!range_str || range_str === '0' || typeof range_str !== 'string') {
return null
}
const range = range_str
Expand Down Expand Up @@ -161,23 +157,26 @@ function parse_uuid(uuid) {
}

async function read_vless_header(reader, cfg_uuid_str) {
let buff = await read_atleast(reader, 1 + 16 + 1)
let readed_len = buff.value.length
let header = buff.value
let readed_len = 0
let header = new Uint8Array()

// prevent inner_read_until() throw error
let read_result = { value: header, done: false }
async function inner_read_until(offset) {
if (buff.done) {
if (read_result.done) {
throw new Error('header length too short')
}
const len = offset - readed_len
if (len < 1) {
return
}
buff = await read_atleast(reader, len)
readed_len += buff.value.length
header = concat_typed_arrays(header, buff.value)
read_result = await read_atleast(reader, len)
readed_len += read_result.value.length
header = concat_typed_arrays(header, read_result.value)
}

await inner_read_until(1 + 16 + 1)

const version = header[0]
const uuid = header.slice(1, 1 + 16)
const cfg_uuid = parse_uuid(cfg_uuid_str)
Expand Down Expand Up @@ -243,13 +242,19 @@ async function read_vless_header(reader, cfg_uuid_str) {
}
}

async function pump(readable, writable, first_packet) {
if (get_length(first_packet) > 0) {
const writer = writable.getWriter()
await writer.write(first_packet)
writer.releaseLock()
async function pump(log, tag, src, dest, first_packet) {
try {
if (first_packet.length > 0) {
const writer = dest.writable.getWriter()
await writer.write(first_packet)
writer.releaseLock()
}
await src.readable.pipeTo(dest.writable, src.pipe_to_options)
} catch (err) {
if (!(err instanceof AbortError)) {
log.info(`${tag} error: ${err.message}`)
}
}
await readable.pipeTo(writable)
}

function pick_random_proxy(cfg_proxy) {
Expand All @@ -264,16 +269,15 @@ function pick_random_proxy(cfg_proxy) {
async function connect_remote(log, hostname, port, cfg_proxy) {
async function inner_connect(remote) {
const conn = connect({ hostname: remote, port })
const info = await conn.opened
log.debug(`connection opened: ${info.remoteAddress}`)
await conn.opened
return conn
}

try {
log.info(`direct connect [${hostname}]:${port}`)
return await inner_connect(hostname)
} catch (err) {
log.debug(`direct connect failed: ${err}`)
log.debug(`direct connect failed: ${err.message}`)
}

const proxy = pick_random_proxy(cfg_proxy)
Expand All @@ -285,16 +289,13 @@ async function connect_remote(log, hostname, port, cfg_proxy) {
throw new Error('all attempts failed')
}

async function parse_header(log, uuid_str, client_readable) {
const reader = client_readable.getReader()
async function parse_header(uuid_str, client) {
try {
const reader = client.readable.getReader()
const vless = await read_vless_header(reader, uuid_str)
reader.releaseLock()
return vless
} catch (err) {
drain_connection(log, reader).catch((err) =>
log.info(`drain error: ${err}`),
)
throw new Error(`read vless header error: ${err.message}`)
}
}
Expand All @@ -307,29 +308,33 @@ async function read_atleast(reader, n) {
if (r.value) {
const b = new Uint8Array(r.value)
buffs.push(b)
n -= get_length(b)
n -= b.length
}
done = r.done
}
if (n > 0) {
throw new Error(`not enough data to read`)
}

return {
value: concat_typed_arrays(...buffs),
done,
}
}

function create_xhttp_client(cfg, client_readable) {
const abort_ctrl = new AbortController()

const buff_stream = new TransformStream(
{
transform(chunk, controller) {
controller.enqueue(chunk)
try {
controller.enqueue(chunk)
} catch {
abort_ctrl.abort()
}
},
},
new ByteLengthQueuingStrategy({ highWaterMark: BUFFER_SIZE }),
new ByteLengthQueuingStrategy({ highWaterMark: BUFFER_SIZE }),
)

const headers = {
Expand All @@ -350,25 +355,32 @@ function create_xhttp_client(cfg, client_readable) {
return {
readable: client_readable,
writable: buff_stream.writable,
pipe_to_options: { signal: abort_ctrl.signal },
resp,
}
}

function create_ws_client() {
const [ws_client, ws_server] = Object.values(new WebSocketPair())
ws_server.accept()
function create_ws_client(ws_client, ws_server) {
const abort_ctrl = new AbortController()

const readable = new ReadableStream(
{
start(controller) {
ws_server.addEventListener('message', ({ data }) => {
controller.enqueue(data)
try {
controller.enqueue(data)
} catch {}
})
ws_server.addEventListener('error', (err) => {
controller.error(err)
try {
controller.error(err)
} catch {}
})
ws_server.addEventListener('close', () => {
controller.close()
abort_ctrl.abort()
try {
controller.close()
} catch {}
})
},
},
Expand All @@ -378,18 +390,16 @@ function create_ws_client() {
const writable = new WritableStream(
{
write(chunk) {
ws_server.send(chunk)
try {
ws_server.send(chunk)
} catch {
abort_ctrl.abort()
}
},
},
new ByteLengthQueuingStrategy({ highWaterMark: BUFFER_SIZE }),
)

function on_closed() {
try {
ws_server.close()
} catch {}
}

const resp = new Response(null, {
status: 101,
webSocket: ws_client,
Expand All @@ -398,32 +408,32 @@ function create_ws_client() {
return {
readable,
writable,
on_closed,
resp,
pipe_to_options: { signal: abort_ctrl.signal },
}
}

async function handle_client(cfg, log, client) {
const vless = await parse_header(log, cfg.UUID, client.readable)

const remote = await connect_remote(
log,
vless.hostname,
vless.port,
cfg.PROXY,
)
try {
const vless = await parse_header(cfg.UUID, client)
const remote = await connect_remote(
log,
vless.hostname,
vless.port,
cfg.PROXY,
)

const upload_done = pump(client.readable, remote.writable, vless.data)
const download_done = pump(remote.readable, client.writable, vless.resp)
const uploader = pump(log, 'upload', client, remote, vless.data)
const downloader = pump(log, 'download', remote, client, vless.resp)

download_done
.catch((err) => log.error(`download error: ${err}`))
.finally(() => upload_done)
.catch((err) => log.debug(`upload error: ${err}`))
.finally(() => {
client.on_closed && client.on_closed()
log.info('connection closed')
})
Promise.all([uploader, downloader]).finally(() =>
log.info('connection closed'),
)
return true
} catch (err) {
log.error(`handle client error: ${err.message}`)
}
return false
}

function append_slash(path) {
Expand Down Expand Up @@ -528,16 +538,6 @@ const config_template = `{
]
}`

async function drain_connection(log, reader) {
log.info(`drain connection`)
while (true) {
const r = await reader.read()
if (r.done) {
break
}
}
}

async function handle_doh(log, request, url, upstream) {
const mime_dnsmsg = 'application/dns-message'
const method = request.method
Expand Down Expand Up @@ -664,13 +664,23 @@ Refresh this page to re-generate a random settings example.`

async function main(request, env) {
const cfg = load_settings(env, SETTINGS)
const log = new Logger(cfg.LOG_LEVEL, cfg.TIME_ZONE)
try {
const resp = await handle_request(cfg, log, request)
return resp
} catch (err) {
log.error(`unhandled error: ${err}`)
}
return BAD_REQUEST
}

async function handle_request(cfg, log, request) {
const url = new URL(request.url)
if (!cfg.UUID) {
const text = example(url)
return new Response(text)
}

const log = new Logger(cfg.LOG_LEVEL, cfg.TIME_ZONE)
const path = url.pathname

if (
Expand All @@ -679,11 +689,17 @@ async function main(request, env) {
path.endsWith(cfg.WS_PATH)
) {
log.info('handle ws client')
const client = create_ws_client()
const [ws_client, ws_server] = new WebSocketPair()
const client = create_ws_client(ws_client, ws_server)
// Do not block here. Client is waiting for upgrade-response.
handle_client(cfg, log, client).catch((err) =>
log.error(`handle ws client error: ${err}`),
)
setTimeout(() => {
try {
ws_server.accept()
} catch (err) {
log.error(`accept ws client error: ${err.message}`)
}
handle_client(cfg, log, client)
}, 0)
return client.resp
}

Expand All @@ -693,14 +709,9 @@ async function main(request, env) {
path.endsWith(cfg.XHTTP_PATH)
) {
log.info('handle xhttp client')
try {
const client = create_xhttp_client(cfg, request.body)
await handle_client(cfg, log, client)
return client.resp
} catch (err) {
log.error(`handle xhttp client error: ${err}`)
}
return BAD_REQUEST
const client = create_xhttp_client(cfg, request.body)
const ok = await handle_client(cfg, log, client)
return ok ? client.resp : BAD_REQUEST
}

if (cfg.DOH_QUERY_PATH && append_slash(path).endsWith(cfg.DOH_QUERY_PATH)) {
Expand All @@ -727,7 +738,6 @@ export default {

// for unit testing
concat_typed_arrays,
get_length,
parse_uuid,
pick_random_proxy,
random_id,
Expand Down
Loading

0 comments on commit 3f793eb

Please sign in to comment.