diff --git a/device.go b/device.go index 2b77ba4..f6f4f5c 100644 --- a/device.go +++ b/device.go @@ -6,6 +6,11 @@ import ( giDevice "github.com/electricbubble/gidevice" ) +const ( + defaultPort = 8100 + defaultMjpegPort = 9100 +) + type Device struct { deviceID int serialNumber string @@ -15,6 +20,61 @@ type Device struct { d giDevice.Device } +type DeviceOption func(d *Device) + +func WithSerialNumber(serialNumber string) DeviceOption { + return func(d *Device) { + d.serialNumber = serialNumber + } +} + +func WithPort(port int) DeviceOption { + return func(d *Device) { + d.Port = port + } +} + +func WithMjpegPort(port int) DeviceOption { + return func(d *Device) { + d.MjpegPort = port + } +} + +func NewDevice(options ...DeviceOption) (device *Device, err error) { + var usbmux giDevice.Usbmux + if usbmux, err = giDevice.NewUsbmux(); err != nil { + return nil, fmt.Errorf("init usbmux failed: %v", err) + } + + var deviceList []giDevice.Device + if deviceList, err = usbmux.Devices(); err != nil { + return nil, fmt.Errorf("get attached devices failed: %v", err) + } + + device = &Device{ + Port: defaultPort, + MjpegPort: defaultMjpegPort, + } + for _, option := range options { + option(device) + } + + serialNumber := device.serialNumber + for _, d := range deviceList { + // find device by serial number if specified + if serialNumber != "" && d.Properties().SerialNumber != serialNumber { + continue + } + + device.deviceID = d.Properties().DeviceID + device.serialNumber = d.Properties().SerialNumber + device.d = d + return device, nil + } + + return nil, fmt.Errorf("device %s not found", device.serialNumber) +} + func DeviceList() (devices []Device, err error) { var usbmux giDevice.Usbmux if usbmux, err = giDevice.NewUsbmux(); err != nil { @@ -31,8 +91,8 @@ func DeviceList() (devices []Device, err error) { for i := range devices { devices[i].deviceID = deviceList[i].Properties().DeviceID devices[i].serialNumber = deviceList[i].Properties().SerialNumber - devices[i].Port = 8100 - devices[i].MjpegPort = 9100 + devices[i].Port = defaultPort + devices[i].MjpegPort = defaultMjpegPort devices[i].d = deviceList[i] } diff --git a/driver.go b/driver.go index c7f6457..e1246af 100644 --- a/driver.go +++ b/driver.go @@ -20,7 +20,7 @@ import ( // NewDriver creates new remote client, this will also start a new session. func NewDriver(capabilities Capabilities, urlPrefix string, mjpegPort ...int) (driver WebDriver, err error) { if len(mjpegPort) == 0 { - mjpegPort = []int{9100} + mjpegPort = []int{defaultMjpegPort} } wd := new(remoteWD) if wd.urlPrefix, err = url.Parse(urlPrefix); err != nil { diff --git a/driver_test.go b/driver_test.go index 6a234ba..6f380dc 100644 --- a/driver_test.go +++ b/driver_test.go @@ -41,6 +41,28 @@ func TestViaUSB(t *testing.T) { } } +func TestNewDevice(t *testing.T) { + device, _ := NewDevice() + if device != nil { + t.Log(device) + } + + device, _ = NewDevice(WithSerialNumber("xxxx")) + if device != nil { + t.Log(device) + } + + device, _ = NewDevice(WithPort(8700), WithMjpegPort(8800)) + if device != nil { + t.Log(device) + } + + device, _ = NewDevice(WithSerialNumber("xxxx"), WithPort(8700), WithMjpegPort(8800)) + if device != nil { + t.Log(device) + } +} + func TestNewDriver(t *testing.T) { var err error driver, err = NewDriver(nil, urlPrefix)