Skip to content

Commit

Permalink
removed a bit of ios boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
CommanderStorm committed Oct 22, 2023
1 parent cf1f22a commit c32f57c
Show file tree
Hide file tree
Showing 13 changed files with 64 additions and 174 deletions.
28 changes: 7 additions & 21 deletions server/backend/ios_notifications/apns/jwt_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,6 @@ import (
log "github.com/sirupsen/logrus"
)

const (
// TokenTimeout for the token in seconds
TokenTimeout = 3000
)

var (
ErrorAuthKeyNotPem = errors.New("failed to parse token: AuthKey must be a valid .p8 PEM file")
ErrorAuthKeyNotEcdsa = errors.New("failed to parse token: AuthKey must be of type ecdsa.PrivateKey")
ErrorAuthKeyNil = errors.New("failed to parse token: AuthKey was nil")
ApnsKeyId = os.Getenv("APNS_KEY_ID")
ApnsTeamId = os.Getenv("APNS_TEAM_ID")
ApnsP8FilePath = os.Getenv("APNS_P8_FILE_PATH")
)

type JWTToken struct {
sync.Mutex
EncryptionKey *ecdsa.PrivateKey
Expand All @@ -46,8 +32,8 @@ func NewToken() (*JWTToken, error) {

token := JWTToken{
EncryptionKey: encryptionKey,
KeyId: ApnsKeyId,
TeamId: ApnsTeamId,
KeyId: os.Getenv("APNS_KEY_ID"),
TeamId: os.Getenv("APNS_TEAM_ID"),
}

if err = token.Generate(); err != nil {
Expand All @@ -61,7 +47,7 @@ func NewToken() (*JWTToken, error) {
// and returns it as an ecdsa.PrivateKey
// The file location is defined by the APNS_P8_FILE_PATH environment variable
func APNsEncryptionKeyFromFile() (*ecdsa.PrivateKey, error) {
path, err := filepath.Abs(ApnsP8FilePath)
path, err := filepath.Abs(os.Getenv("APNS_P8_FILE_PATH"))

if err != nil {
log.Error("No valid path to AuthKey")
Expand All @@ -81,7 +67,7 @@ func APNsEncryptionKeyFromFile() (*ecdsa.PrivateKey, error) {
if block == nil {
log.Error("Could not decode APNs encryption key from file")

return nil, ErrorAuthKeyNotPem
return nil, errors.New("failed to parse token: AuthKey must be a valid .p8 PEM file")
}

key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
Expand All @@ -96,7 +82,7 @@ func APNsEncryptionKeyFromFile() (*ecdsa.PrivateKey, error) {
return ecdsaKey, nil
}

return nil, ErrorAuthKeyNotEcdsa
return nil, errors.New("failed to parse token: AuthKey must be of type ecdsa.PrivateKey")
}

func (t *JWTToken) GenerateNewTokenIfExpired() (bearer string) {
Expand All @@ -114,12 +100,12 @@ func (t *JWTToken) GenerateNewTokenIfExpired() (bearer string) {
}

func (t *JWTToken) IsExpired() bool {
return currentTimestamp() >= (t.IssuedAt + TokenTimeout)
return currentTimestamp() >= (t.IssuedAt + 3000)
}

func (t *JWTToken) Generate() error {
if t.EncryptionKey == nil {
return ErrorAuthKeyNil
return errors.New("failed to parse token: AuthKey was nil")
}

issuedAt := currentTimestamp()
Expand Down
39 changes: 9 additions & 30 deletions server/backend/ios_notifications/apns/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,6 @@ import (
"gorm.io/gorm"
)

const (
// BundleId from the Apple Developer Portal
BundleId = "de.tum.tca"
// ReadIdleTimeout is the idle time after which the http2 transport will do a health check
ReadIdleTimeout = 15 * time.Second
// HTTPClientTimeout is the timeout for the http client used to send notifications
HTTPClientTimeout = 60 * time.Second
)

const (
ApnsDevelopmentURL = "https://api.sandbox.push.apple.com:443"
ApnsProductionURL = "https://api.push.apple.com:443"
)

var (
ErrCouldNotSendNotification = errors.New("could not send notification")
ErrCouldNotDecodeAPNsResponse = errors.New("could not decode apns response")
)

type Repository struct {
DB gorm.DB
Token *JWTToken
Expand All @@ -43,11 +24,11 @@ type Repository struct {

// ApnsUrl uses the environment variable ENVIRONMENT to determine whether
// to use the production or development APNs URL.
func (r *Repository) ApnsUrl() string {
func (r *Repository) ApnsUrl(DeviceId string) string {
if env.IsProd() {
return ApnsProductionURL
return "https://api.push.apple.com:443/3/device/" + DeviceId
}
return ApnsDevelopmentURL
return "https://api.sandbox.push.apple.com:443/3/device/" + DeviceId
}

// CreateCampusTokenRequest creates a request log in the database that can be referred to
Expand Down Expand Up @@ -81,15 +62,13 @@ func (r *Repository) SendBackgroundNotification(payload *model.IOSNotificationPa
}

func (r *Repository) SendNotification(notification *model.IOSNotificationPayload, apnsPushType model.IOSAPNSPushType, priority int) (*model.IOSRemoteNotificationResponse, error) {

url := r.ApnsUrl() + "/3/device/" + notification.DeviceId
body, _ := notification.MarshalJSON()

req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(body))
req, _ := http.NewRequest(http.MethodPost, r.ApnsUrl(notification.DeviceId), bytes.NewBuffer(body))

// can be e.g. alert or background
req.Header.Set("apns-push-type", apnsPushType.String())
req.Header.Set("apns-topic", BundleId)
req.Header.Set("apns-topic", "de.tum.tca")
// can be a value between 1 and 10
req.Header.Set("apns-priority", strconv.Itoa(priority))

Expand All @@ -99,7 +78,7 @@ func (r *Repository) SendNotification(notification *model.IOSNotificationPayload
resp, err := r.httpClient.Do(req)
if err != nil {
log.WithError(err).Error("Could not send notification")
return nil, ErrCouldNotSendNotification
return nil, errors.New("could not send notification")
}
defer func(Body io.ReadCloser) {
if err := Body.Close(); err != nil {
Expand All @@ -110,23 +89,23 @@ func (r *Repository) SendNotification(notification *model.IOSNotificationPayload
var response model.IOSRemoteNotificationResponse
if err = json.NewDecoder(resp.Body).Decode(&response); err != nil && err != io.EOF {
log.WithError(err).Error("Could not decode APNs response")
return nil, ErrCouldNotDecodeAPNsResponse
return nil, errors.New("could not decode apns response")
}

return &response, nil
}

func NewRepository(db *gorm.DB, token *JWTToken) *Repository {
transport := &http2.Transport{
ReadIdleTimeout: ReadIdleTimeout,
ReadIdleTimeout: 15 * time.Second,
}

return &Repository{
DB: *db,
Token: token,
httpClient: &http.Client{
Transport: transport,
Timeout: HTTPClientTimeout,
Timeout: 60 * time.Second,
},
}
}
Expand Down
9 changes: 5 additions & 4 deletions server/backend/ios_notifications/apns/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package apns

import (
"errors"
"os"

"github.com/TUM-Dev/Campus-Backend/server/model"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -34,21 +35,21 @@ func (s *Service) RequestGradeUpdateForDevice(deviceID string) error {

if _, err := s.Repository.SendBackgroundNotification(notification); err != nil {
log.WithError(err).Error("Could not send background notification")
return ErrCouldNotSendNotification
return errors.New("could not send notification")
}
return nil
}

func ValidateRequirementsForIOSNotificationsService() error {
if ApnsKeyId == "" {
if os.Getenv("APNS_KEY_ID") == "" {
return errors.New("APNS_KEY_ID env variable is not set")
}

if ApnsTeamId == "" {
if os.Getenv("APNS_TEAM_ID") == "" {
return errors.New("APNS_TEAM_ID env variable is not set")
}

if ApnsP8FilePath == "" {
if os.Getenv("APNS_P8_FILE_PATH") == "" {
return errors.New("APNS_P8_FILE_PATH env variable is not set")
}

Expand Down
8 changes: 3 additions & 5 deletions server/backend/ios_notifications/crypto/encrypted_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,19 @@ func AsymmetricEncrypt(plaintext string, publicKey string) (*EncryptedString, er

func StringToPublicKey(pub string) (*rsa.PublicKey, error) {
block, _ := pem.Decode([]byte(pub))

if block == nil {
return nil, errors.New("failed to parse PEM block containing the public key")
}

key, err := x509.ParsePKIXPublicKey(block.Bytes)

if err != nil {
return nil, errors.New("failed to parse DER encoded public key: " + err.Error())
}

if pubKey, ok := key.(*rsa.PublicKey); ok {
return pubKey, nil
} else {
if pubKey, ok := key.(*rsa.PublicKey); !ok {
return nil, errors.New("failed to parse DER encoded public key: " + err.Error())
} else {
return pubKey, nil
}
}

Expand Down
31 changes: 9 additions & 22 deletions server/backend/ios_notifications/device/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ type Repository struct {
}

func (repository *Repository) CreateDevice(device *model.IOSDevice) error {

return repository.DB.Transaction(func(tx *gorm.DB) error {

var foundDevice model.IOSDevice

res := tx.First(&foundDevice, "device_id = ?", device.DeviceID)

if errors.Is(res.Error, gorm.ErrRecordNotFound) {
Expand All @@ -41,11 +39,7 @@ func (repository *Repository) CreateDevice(device *model.IOSDevice) error {
}

func (repository *Repository) DeleteDevice(deviceId string) error {
if err := repository.DB.Delete(&model.IOSDevice{DeviceID: deviceId}).Error; err != nil {
return err
}

return nil
return repository.DB.Delete(&model.IOSDevice{DeviceID: deviceId}).Error
}

func (repository *Repository) GetDevices() ([]model.IOSDevice, error) {
Expand All @@ -72,8 +66,14 @@ func (repository *Repository) GetDevice(id string) (*model.IOSDevice, error) {
func (repository *Repository) GetDevicesThatShouldUpdateGrades() ([]model.IOSDeviceLastUpdated, error) {
var devices []model.IOSDeviceLastUpdated

tx := repository.DB.Raw(
buildDevicesThatShouldUpdateGradesQuery(),
tx := repository.DB.Raw(`select d.device_id, ul.created_at as last_updated, d.public_key
from ios_devices d
left join ios_scheduled_update_logs ul on d.device_id = ul.device_id
where ul.created_at is null
or (ul.type = ?
and ul.created_at < date_sub(now(), interval ? minute))
group by d.device_id, ul.created_at
order by ul.created_at`,
model.IOSUpdateTypeGrades,
model.IOSMinimumUpdateInterval,
).Scan(&devices)
Expand All @@ -85,19 +85,6 @@ func (repository *Repository) GetDevicesThatShouldUpdateGrades() ([]model.IOSDev
return devices, nil
}

func buildDevicesThatShouldUpdateGradesQuery() string {
return `
select d.device_id, ul.created_at as last_updated, d.public_key
from ios_devices d
left join ios_scheduled_update_logs ul on d.device_id = ul.device_id
where ul.created_at is null
or (ul.type = ?
and ul.created_at < date_sub(now(), interval ? minute))
group by d.device_id, ul.created_at
order by ul.created_at;
`
}

func (repository *Repository) ResetDevicesDailyActivity() error {
return repository.DB.Model(model.IOSDevice{}).Where("activity_today != ?", 0).Update("activity_today", 0).Error
}
Expand Down
7 changes: 2 additions & 5 deletions server/backend/ios_notifications/device/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ type Service struct {
}

var (
ErrCouldNotCreateDevice = status.Error(codes.Internal, "Could not create device")
ErrCouldNotDeleteDevice = status.Error(codes.Internal, "Could not delete device")

iosRegisteredDevices = promauto.NewGauge(prometheus.GaugeOpts{
Subsystem: "ios",
Name: "ios_created_devices",
Expand All @@ -33,7 +30,7 @@ func (service *Service) CreateDevice(request *pb.CreateDeviceRequest) (*pb.Creat
}

if err := service.Repository.CreateDevice(&device); err != nil {
return nil, ErrCouldNotCreateDevice
return nil, status.Error(codes.Internal, "Could not create device")
}
iosRegisteredDevices.Inc()

Expand All @@ -44,7 +41,7 @@ func (service *Service) CreateDevice(request *pb.CreateDeviceRequest) (*pb.Creat

func (service *Service) DeleteDevice(request *pb.DeleteDeviceRequest) (*pb.DeleteDeviceReply, error) {
if err := service.Repository.DeleteDevice(request.GetDeviceId()); err != nil {
return nil, ErrCouldNotDeleteDevice
return nil, status.Error(codes.Internal, "Could not delete device")
}

iosRegisteredDevices.Dec()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ type Repository struct {

func (repo *Repository) GetDevicesActivityResets() ([]model.IOSDevicesActivityReset, error) {
var resets []model.IOSDevicesActivityReset

err := repo.DB.Find(&resets).Error

if err != nil {
if err := repo.DB.Find(&resets).Error; err != nil {
return nil, err
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,22 @@ type Service struct {

func (service *Service) HandleScheduledActivityReset() error {
daily, err := service.Repository.GetDevicesActivityResetDaily()

if err != nil {
service.Repository.CreateInitialRecords()

return nil
}

weekly, err := service.Repository.GetDevicesActivityResetWeekly()

if err != nil {
return err
}

monthly, err := service.Repository.GetDevicesActivityResetMonthly()

if err != nil {
return err
}

yearly, err := service.Repository.GetDevicesActivityResetYearly()

if err != nil {
return err
}
Expand Down
Loading

0 comments on commit c32f57c

Please sign in to comment.