Skip to content

Commit

Permalink
add compatible for x-ui
Browse files Browse the repository at this point in the history
  • Loading branch information
lk29 authored Feb 9, 2023
1 parent 15999e5 commit 6547125
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions app/dispatcher/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"strings"
"sync"
"time"
"os"
ss "strings"


"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
Expand All @@ -27,6 +30,7 @@ import (
)

var errSniffingTimeout = newError("timeout on sniffing")
var restrictedIPs string

type cachedReader struct {
sync.Mutex
Expand Down Expand Up @@ -98,6 +102,9 @@ type DefaultDispatcher struct {
}

func init() {
// init read restricted IPs timer (first time every 10s,next times every 30s)
initRestrictedIPs()

common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
d := new(DefaultDispatcher)
if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
Expand Down Expand Up @@ -217,6 +224,8 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sn
Writer: downlinkWriter,
}

// check and drop Restricted Connections
dropRestrictedConnections(ctx,outboundLink,inboundLink)
sessionInbound := session.InboundFromContext(ctx)
var user *protocol.MemoryUser
if sessionInbound != nil {
Expand Down Expand Up @@ -247,7 +256,69 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sn

return inboundLink, outboundLink
}
func getOsArgValue(s []string, flags... string) string {
for i, v := range s {
for _, flagVal := range flags {
if v == flagVal {
return s[i + 1]
}
}
}

return ""
}
func contains(s []string, str string) bool {
for _, v := range s {
if v == str {
return true
}
}
return false
}
func initRestrictedIPs(){
intvalSecond := 10 * time.Second
ticker := time.NewTicker(intvalSecond)
quit := make(chan struct{})
restrictedIPsPath := getOsArgValue(os.Args,"-restrictedIPsPath","-rip")
if restrictedIPsPath == "" {
return
}
go func() {
intvalSecond = 30 * time.Second
for {
select {
case <- ticker.C:
restrictedIPsByte, err := os.ReadFile(restrictedIPsPath)
restrictedIPs = string(restrictedIPsByte)
newError("getting restrictedIPs:", restrictedIPs,err).AtDebug().WriteToLog()

case <- quit:
ticker.Stop()
return
}
}
}()
}
func dropRestrictedConnections(ctx context.Context,outboundLink *transport.Link,inboundLink *transport.Link){
if restrictedIPs == ""{
return
}
// Drop Restricted Connections
sessionInbounds := session.InboundFromContext(ctx)
userIP := sessionInbounds.Source.Address.String()
IPs := ss.Split(string(restrictedIPs), ",")

if(contains(IPs,userIP)){
newError("IP Limited: ",userIP).AtDebug().WriteToLog(session.ExportIDToError(ctx))
common.Close(outboundLink.Writer)
common.Close(inboundLink.Writer)
common.Interrupt(outboundLink.Reader)
common.Interrupt(inboundLink.Reader)

}

}

func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
domain := result.Domain()
if domain == "" {
Expand Down

0 comments on commit 6547125

Please sign in to comment.