Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
postmannen committed Nov 4, 2022
2 parents d947557 + 93983a1 commit 9e144fc
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 63 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
certs/*.pem
certs/*.srl
make.sh
43 changes: 43 additions & 0 deletions DockerFile
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# build stage
FROM golang:1.19.3-alpine AS build-env
RUN apk --no-cache add build-base git gcc

RUN mkdir -p /build
COPY ./ /build/

WORKDIR /build/client/
RUN go version
#RUN git checkout main && go build -o steward
RUN go build -o usbtcpclient
RUN pwd
RUN ls -l

# final stage
FROM alpine

RUN apk update && apk add curl

WORKDIR /app
COPY --from=build-env /build/client/usbtcpclient /app/

# If MTLS is enabled, give the path to the CA cert
ENV CA_CERT ""
# The path to the certificate
ENV CERT ""
# 127.0.0.1:45000
ENV IP_PORT ""
# If MTLS is enabled, give the path to the key
ENV KEY ""
# Set to 1 to enable MTLS
ENV MTLS ""
# The directory path for where to store the port.info file
ENV PORT_INFO_FILE_DIR ""

CMD ["ash","-c","/app/usbtcpclient\
-caCert=$CA_CERT\
-cert=$CERT\
-ipPort=$IP_PORT\
-key=$KEY\
-mtls=$MTLS\
-portInfoFileDir=$PORT_INFO_FILE_DIR\
"]
24 changes: 21 additions & 3 deletions client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"path/filepath"

"github.com/creack/pty"
)
Expand All @@ -32,7 +32,7 @@ func newTLSConfig(nc netConfig) (*tls.Config, error) {
}

certPool := x509.NewCertPool()
pemCABytes, err := ioutil.ReadFile(nc.caCert)
pemCABytes, err := os.ReadFile(nc.caCert)
if err != nil {
return nil, fmt.Errorf("error: failed to read ca cert: %v", err)
}
Expand Down Expand Up @@ -84,6 +84,8 @@ func main() {
cert := flag.String("cert", "../certs/client-cert.pem", "the path to the server certificate")
key := flag.String("key", "../certs/client-key.pem", "the path to the private key")
ipPort := flag.String("ipPort", "127.0.0.1:45000", "ip:port of the host to connec to")
portInfoFileDir := flag.String("portInfoFileDir", "./", "the directory path of where to store the port.info file")

flag.Parse()

nConf := netConfig{
Expand All @@ -106,6 +108,20 @@ func main() {
log.Printf("pty: %v\n", pt.Name())
log.Printf("tty: %v\n", tt.Name())

portInfoPath := filepath.Join(*portInfoFileDir, "port.info")
fh, err := os.Create(portInfoPath)
if err != nil {
log.Printf("error: os.Create failed: %v\n", err)
os.Exit(1)
}
defer fh.Close()

_, err = fh.Write([]byte(tt.Name()))
if err != nil {
log.Printf("error: writing to file failed: %v\n", err)
os.Exit(1)
}

// --- Client: Open dial network

conn, err := getNetConn(nConf)
Expand All @@ -129,14 +145,16 @@ func main() {
log.Printf("error: conn.Read err != nil || err != io.EOF: characters=%v, %v\n", n, err)
continue
}

if err == io.EOF && n == 0 {
log.Printf("error: conn.Read err == io.EOF && n == 0: characters=%v, %v\n", n, err)
os.Exit(1)
}

{
n, err := pt.Write(b)
if err != nil || n == 0 {
//if err != nil || n == 0 {
if err != nil {
log.Printf("error: pt.Write: characters=%v, %v\n", n, err)
return
}
Expand Down
187 changes: 127 additions & 60 deletions server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"time"

"github.com/pkg/term"
"go.bug.st/serial/enumerator"
Expand All @@ -20,8 +21,9 @@ type netConfig struct {
cert string
key string

baud int
ipPort string
baud int
ipPort string
ttyReadTimeout int
}

// getTTY will get the path of the tty.
Expand All @@ -46,73 +48,136 @@ func getTTY(vid string, pid string) (string, error) {

// relay will start relaying the data between the TTY and the network connection.
func relay(ttyName string, nConf netConfig) error {
// --- Server: Open tty

tty, err := term.Open(ttyName)
// The for loop will initiate both the network listener and the TTY.
// If the connection is dropped for either network or tty, then all
// connection both to the TTY and the Network is closed, and all go
// routines for reading and writing are exited, and new connection
// are made for the next iteration.
for {
err := func() error {

if err != nil {
log.Printf("error: tty OpenFile: %v\n", err)
}
defer tty.Close()
defer tty.Restore()
term.RawMode(tty)
// --- Server: Open tty

err = tty.SetSpeed(9600)
if err != nil {
return fmt.Errorf("error: failed to set baud: %v", err)
}
tty, err := term.Open(ttyName)

nl, err := getNetListener(nConf)
if err != nil {
return fmt.Errorf("error: opening network listener failed: %v", err)
}
defer nl.Close()
if err != nil {
log.Printf("error: tty OpenFile: %v\n", err)
}
defer tty.Close()
defer tty.Restore()
term.RawMode(tty)

for {
conn, err := nl.Accept()
if err != nil {
log.Printf("error: opening out endpoint failed: %v\n", err)
continue
}
err = tty.SetSpeed(9600)
if err != nil {
return fmt.Errorf("error: failed to set baud: %v", err)
}

nl, err := getNetListener(nConf)
if err != nil {
return fmt.Errorf("error: opening network listener failed: %v", err)
}
defer nl.Close()

// Read tty -> write net.Conn
go func() {
for {
b := make([]byte, 1)
_, err := tty.Read(b)
if err != nil && err != io.EOF {
log.Printf("error: fh, failed to read : %v\n", err)
return

err := tty.SetReadTimeout(time.Second * time.Duration(nConf.ttyReadTimeout))
if err != nil {
return fmt.Errorf("error: setReadTimeoutFailed: %v", err)
}

_, err = conn.Write(b)
conn, err := nl.Accept()
if err != nil {
log.Printf("error: pt.Write: %v\n", err)
return
log.Printf("error: opening out endpoint failed: %v\n", err)
continue
}
}
}()

// Read net.Conn -> write tty
for {
b := make([]byte, 1)
connOK := true

errCh := make(chan error)

// Read tty -> write net.Conn
go func() {
log.Printf(" * starting go routine for Read tty -> write net.Conn\n")
defer log.Printf(" ** ending go routine for Read tty -> write net.Conn\n")

for {

b := make([]byte, 1)
_, err := tty.Read(b)
if err != nil {
if connOK {
// fmt.Printf("connOK = %v\n", connOK)
continue
}

er := fmt.Errorf("error: tty.Read failed: %v", err)
select {
case errCh <- er:
default:
log.Printf("connection marked as down, exiting reader for TTY: %v\n", er)
}

return
}

// fmt.Printf(" tty read nr = %v\n", n)

_, err = conn.Write(b)
if err != nil {
errCh <- fmt.Errorf("error: conn.Write failed: %v", err)
return
}
}
}()

// Read net.Conn -> write tty
go func() {
log.Printf(" * starting go routine for Read net.Conn -> write tty\n")
defer log.Printf(" ** ending go routine for Read net.Conn -> write tty\n")
defer func() { connOK = false }()

for {
b := make([]byte, 1)

_, err := conn.Read(b)
if err != nil && err != io.EOF {
errCh <- fmt.Errorf("error: conn.Read failed : %v", err)
return
}
if err == io.EOF {
errCh <- fmt.Errorf("error: conn.Read failed, got io.EOF: %v", err)
return
}

_, err = tty.Write(b)
if err != nil {
er := fmt.Errorf("error: tty.Write failed : %v", err)
select {
case errCh <- er:
default:
log.Printf("%v\n", er)
}
return
}
}
}()

err = <-errCh
if err != nil {
log.Printf("%v\n", err)
}
tty.Close()
conn.Close()

_, err := conn.Read(b)
if err != nil && err != io.EOF {
log.Printf("error: failed to read pt : %v\n", err)
continue
}
if err == io.EOF {
return fmt.Errorf("error: pt.Read, got io.EOF: %v", err)
}
}()

_, err = tty.Write(b)
if err != nil {
return fmt.Errorf("error: fh.Write : %v", err)
}
if err != nil {
log.Printf("%v\n", err)
}

}

}

// getNetListener will return either an normal or TLS encryptet net.Listener.
Expand All @@ -127,7 +192,7 @@ func getNetListener(nConf netConfig) (net.Listener, error) {
}

certPool := x509.NewCertPool()
pemCABytes, err := ioutil.ReadFile(nConf.caCert)
pemCABytes, err := os.ReadFile(nConf.caCert)
if err != nil {
return nil, fmt.Errorf("error: failed to read ca cert: %v", err)
}
Expand Down Expand Up @@ -172,16 +237,18 @@ func main() {
key := flag.String("key", "../certs/server-key.pem", "the path to the private key")
baud := flag.Int("baud", 9600, "baud rate")
ipPort := flag.String("ipPort", "127.0.0.1:45000", "ip:port for where to start the network listener")
ttyReadTimeout := flag.Int("ttyReadTimeout", 1, "The timeout for TTY read given in seconds")

flag.Parse()

nConf := netConfig{
mtls: *mtls,
caCert: *caCert,
cert: *cert,
key: *key,
baud: *baud,
ipPort: *ipPort,
mtls: *mtls,
caCert: *caCert,
cert: *cert,
key: *key,
baud: *baud,
ipPort: *ipPort,
ttyReadTimeout: *ttyReadTimeout,
}

ttyName, err := getTTY(*vid, *pid)
Expand Down

0 comments on commit 9e144fc

Please sign in to comment.