diff --git a/application/application.go b/application/application.go index 68adbed..042d9ea 100644 --- a/application/application.go +++ b/application/application.go @@ -93,6 +93,9 @@ type Application struct { conn cast.Conn debug bool + // Device name override (originating e.g. from mdns lookup). + deviceNameOverride string + // Internal mapping of request id to result channel resultChanMap map[int]chan *pb.CastMessage @@ -187,6 +190,12 @@ func WithSkipadRetries(retries int) ApplicationOption { } } +func WithDeviceNameOverride(deviceName string) ApplicationOption { + return func(a *Application) { + a.SetDeviceNameOverride(deviceName) + } +} + func NewApplication(opts ...ApplicationOption) *Application { a := &Application{ conn: cast.NewConnection(), @@ -223,6 +232,10 @@ func (a *Application) SetIface(iface *net.Interface) { a.iface = iface } func (a *Application) SetSkipadSleep(sleep time.Duration) { a.skipadSleep = sleep } func (a *Application) SetSkipadRetries(retries int) { a.skipadRetries = retries } +func (a *Application) SetDeviceNameOverride(deviceName string) { + a.deviceNameOverride = deviceName +} + func (a *Application) App() *cast.Application { return a.application } func (a *Application) Media() *cast.Media { return a.media } func (a *Application) Volume() *cast.Volume { return a.volumeReceiver } @@ -439,11 +452,18 @@ func (a *Application) Status() (*cast.Application, *cast.Media, *cast.Volume) { func (a *Application) Info() (*cast.DeviceInfo, error) { addr, err := a.conn.RemoteAddr() - if err != nil { return nil, err } - return GetInfo(addr) + info, err := GetInfo(addr) + if err != nil { + return nil, err + } + log.Printf("deviceNameOverride: %s", a.deviceNameOverride) + if len(a.deviceNameOverride) > 0 { + info.Name = a.deviceNameOverride + } + return info, err } func (a *Application) Pause() error { diff --git a/application/info.go b/application/info.go index 6463237..267cf71 100644 --- a/application/info.go +++ b/application/info.go @@ -14,7 +14,10 @@ import ( // information about the cast-device. // OBS: The 8008 seems to be pure http, whereas 8009 is typically the port // to use for protobuf-communication, + func GetInfo(ip string) (info *cast.DeviceInfo, err error) { + // Note: Services exposed not on 8009 port are "Google Cast Group"s + // The only way to find the true device (group) name, is using mDNS outside of this function. url := fmt.Sprintf("http://%v:8008/setup/eureka_info", ip) log.Printf("Fetching: %s", url) resp, err := http.Get(url) diff --git a/cast/connection.go b/cast/connection.go index 9f06f8a..5e044d8 100644 --- a/cast/connection.go +++ b/cast/connection.go @@ -32,6 +32,7 @@ type Conn interface { SetDebug(debug bool) LocalAddr() (addr string, err error) RemoteAddr() (addr string, err error) + RemotePort() (addr string, err error) Send(requestID int, payload Payload, sourceID, destinationID, namespace string) error } @@ -91,8 +92,13 @@ func (c *Connection) LocalAddr() (addr string, err error) { } func (c *Connection) RemoteAddr() (addr string, err error) { - host, _, err := net.SplitHostPort(c.conn.RemoteAddr().String()) - return host, err + addr, _, err = net.SplitHostPort(c.conn.RemoteAddr().String()) + return addr, err +} + +func (c *Connection) RemotePort() (port string, err error) { + _, port, err = net.SplitHostPort(c.conn.RemoteAddr().String()) + return port, err } func (c *Connection) log(message string, args ...interface{}) { diff --git a/cast/mocks/Conn.go b/cast/mocks/Conn.go index 50120b6..f73946b 100644 --- a/cast/mocks/Conn.go +++ b/cast/mocks/Conn.go @@ -108,6 +108,34 @@ func (_m *Conn) RemoteAddr() (string, error) { return r0, r1 } +// RemotePort provides a mock function with given fields: +func (_m *Conn) RemotePort() (string, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for RemotePort") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func() (string, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Send provides a mock function with given fields: requestID, payload, sourceID, destinationID, namespace func (_m *Conn) Send(requestID int, payload cast.Payload, sourceID string, destinationID string, namespace string) error { ret := _m.Called(requestID, payload, sourceID, destinationID, namespace) diff --git a/http/handlers.go b/http/handlers.go index a630caf..e0ffc74 100644 --- a/http/handlers.go +++ b/http/handlers.go @@ -125,6 +125,7 @@ func (h *Handler) discoverDnsEntries(ctx context.Context, iface string, waitq st } for d := range devicesChan { + log.Printf("Device: %#v", d) devices = append(devices, device{ Addr: d.AddrV4.String(), Port: d.Port, @@ -196,10 +197,11 @@ func (h *Handler) connect(w http.ResponseWriter, r *http.Request) { deviceAddr := q.Get("addr") devicePort := q.Get("port") + deviceName := q.Get("name") iface := q.Get("interface") wait := q.Get("wait") - if deviceAddr == "" || devicePort == "" { + if deviceAddr == "" || devicePort == "" || (deviceName == "" && devicePort != "8009") { h.log("device addr and/or port are missing, trying to lookup address for uuid %q", deviceUUID) devices := h.discoverDnsEntries(context.Background(), iface, wait) @@ -210,6 +212,7 @@ func (h *Handler) connect(w http.ResponseWriter, r *http.Request) { // TODO: This is an unnessecary conversion since // we cast back to int a bit later. devicePort = strconv.Itoa(device.Port) + deviceName = device.DeviceName } } } @@ -228,7 +231,7 @@ func (h *Handler) connect(w http.ResponseWriter, r *http.Request) { return } - app, err := h.connectInternal(deviceAddr, devicePortI) + app, err := h.connectInternal(deviceAddr, devicePortI, deviceName) if err != nil { h.log("unable to start application: %v", err) httpError(w, fmt.Errorf("unable to start application: %v", err)) @@ -246,11 +249,14 @@ func (h *Handler) connect(w http.ResponseWriter, r *http.Request) { } } -func (h *Handler) connectInternal(deviceAddr string, devicePort int) (application.App, error) { +func (h *Handler) connectInternal(deviceAddr string, devicePort int, deviceName string) (application.App, error) { applicationOptions := []application.ApplicationOption{ application.WithDebug(h.verbose), application.WithCacheDisabled(true), } + if deviceName != "" { + applicationOptions = append(applicationOptions, application.WithDeviceNameOverride(deviceName)) + } app := application.NewApplication(applicationOptions...) if err := app.Start(deviceAddr, devicePort); err != nil { @@ -289,7 +295,7 @@ func (h *Handler) connectAllInternal(iface string, waitSec string) error { devices := h.discoverDnsEntries(context.Background(), iface, waitSec) uuidMap := map[string]application.App{} for _, device := range devices { - app, err := h.connectInternal(device.Addr, device.Port) + app, err := h.connectInternal(device.Addr, device.Port, device.DeviceName) if err != nil { return err }