-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb.go
160 lines (142 loc) · 4.24 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
package main
import (
"database/sql"
"fmt"
"log"
"net/http"
"time"
_ "github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
)
// DB is the package-wide database handle. It is assigned by SetupDB and used
// by every query helper in this file, so SetupDB must run before any of them.
var DB *sql.DB
// SetupDB opens the SQLite database named by DB_FILE and stores the handle in
// the package-level DB variable. It then creates the files, users and sessions
// tables if they do not exist yet and, when the users table is empty, seeds a
// default account with username "admin" and password "admin".
//
// Any failure along the way terminates the process through log.Fatal.
//
// NOTE(review): the seeded credentials are a fixed, well-known pair; confirm
// that operators are expected to change them after the first start.
func SetupDB() {
	// All three tables are created by a single multi-statement Exec.
	const schema = `CREATE TABLE IF NOT EXISTS files
(id INTEGER NOT NULL PRIMARY KEY, hash BLOB NOT NULL UNIQUE, data BLOB, filename TEXT, user_id INTEGER NOT NULL, created_at INTEGER NOT NULL);
CREATE TABLE IF NOT EXISTS users
(id INTEGER NOT NULL PRIMARY KEY, username TEXT NOT NULL UNIQUE, email TEXT NOT NULL, password TEXT NOT NULL, role TEXT NOT NULL, created_at INTEGER NOT NULL);
CREATE TABLE IF NOT EXISTS sessions
(id INTEGER NOT NULL PRIMARY KEY, session_hash BLOB NOT NULL UNIQUE, user_id INTEGER NOT NULL, created_at INTEGER NOT NULL)`

	handle, openErr := sql.Open("sqlite3", DB_FILE)
	DB = handle
	if openErr != nil {
		log.Fatal(openErr)
	}

	if _, schemaErr := DB.Exec(schema); schemaErr != nil {
		log.Fatal(schemaErr)
	}

	// Seed the default account only when no user exists at all.
	var userCount int
	if countErr := DB.QueryRow("SELECT COUNT(*) FROM users").Scan(&userCount); countErr != nil {
		log.Fatal(countErr)
	}
	if userCount != 0 {
		return
	}

	adminHash, hashErr := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost)
	if hashErr != nil {
		log.Fatal(hashErr)
	}

	_, insertErr := DB.Exec(
		"INSERT INTO users (username, email, password, role, created_at) VALUES (?, ?, ?, ?, ?)",
		"admin", "", adminHash, "admin", time.Now().Unix(),
	)
	if insertErr != nil {
		log.Fatal(insertErr)
	}
}
// CreateUser inserts a new row into the users table using u.Name, u.Email and
// a bcrypt hash (default cost) of u.Password. The created_at column is set to
// the current Unix time. It returns the hashing error or the insert error, if
// either occurs.
//
// NOTE(review): the role column is always written as the literal "admin";
// u.Role is not consulted. Confirm that every created user is meant to be an
// administrator.
func CreateUser(u *User) error {
	const insertUser = "INSERT INTO users (username, email, password, role, created_at) VALUES (?, ?, ?, ?, ?)"

	digest, hashErr := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
	if hashErr != nil {
		return hashErr
	}

	_, execErr := DB.Exec(insertUser, u.Name, u.Email, digest, "admin", time.Now().Unix())
	return execErr
}
// UpdateUser overwrites the username, email and password of the users row
// whose id equals u.ID. The password column receives a fresh bcrypt hash
// (default cost) of u.Password; the role column is left untouched. It returns
// the hashing error or the update error, if either occurs.
//
// NOTE(review): u.Password is always hashed. If a caller passes a User loaded
// by GetUserByName (whose Password field holds the stored hash), that hash
// would be hashed again — verify that callers always supply the plaintext.
func UpdateUser(u *User) error {
	const updateUser = "UPDATE users SET username=?, email=?, password=? WHERE id=?"

	digest, hashErr := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
	if hashErr != nil {
		return hashErr
	}

	_, execErr := DB.Exec(updateUser, u.Name, u.Email, digest, u.ID)
	return execErr
}
// GetUserByName looks up the users row whose username matches and copies its
// id, username, email, password (the stored hash) and role into a new User.
//
// The returned pointer is never nil, even when the error is non-nil (for
// example when no row matches); callers must check the error before trusting
// the User's fields.
func GetUserByName(username string) (*User, error) {
	const selectUser = "SELECT id, username, email, password, role FROM users WHERE username=?"

	var found User
	row := DB.QueryRow(selectUser, username)
	scanErr := row.Scan(&found.ID, &found.Name, &found.Email, &found.Password, &found.Role)
	return &found, scanErr
}
// CreateEntry inserts a row into the files table holding the given hash, data
// and filename, owned by u (via u.ID) and stamped with the current Unix time.
// It returns the error reported by the insert, if any.
func CreateEntry(hash []byte, data []byte, filename string, u *User) error {
	const insertFile = "INSERT INTO files (hash, data, filename, user_id, created_at) VALUES (?, ?, ?, ?, ?)"

	if _, execErr := DB.Exec(insertFile, hash, data, filename, u.ID, time.Now().Unix()); execErr != nil {
		return execErr
	}
	return nil
}
// DeleteEntryByID removes the files row with the given id and returns the
// error reported by the DELETE statement, if any. The number of affected rows
// is not inspected.
func DeleteEntryByID(id uint64) error {
	if _, execErr := DB.Exec("DELETE FROM files WHERE id=?", id); execErr != nil {
		return execErr
	}
	return nil
}
// GetSessionUser resolves a session hash to its owning user: it selects the
// users row whose id equals the user_id of the matching sessions row and fills
// a new User with its id, username, email and role (the password is not read).
//
// The returned pointer is never nil, even when the error is non-nil; callers
// must check the error before using the User.
func GetSessionUser(sessionHash []byte) (*User, error) {
	const selectSessionUser = "SELECT id, username, email, role FROM users WHERE id=(SELECT user_id FROM SESSIONS WHERE session_hash=?)"

	var owner User
	row := DB.QueryRow(selectSessionUser, sessionHash)
	scanErr := row.Scan(&owner.ID, &owner.Name, &owner.Email, &owner.Role)
	return &owner, scanErr
}
// GetFile loads the files row identified by hash into a new Entry (id, hash,
// data and filename) and sets Entry.URL to "http://<r.Host>/<hex of hash>".
//
// The URL is filled in and a non-nil Entry is returned even when the lookup
// fails, so callers must check the returned error first.
func GetFile(r *http.Request, hash []byte) (*Entry, error) {
	const selectFile = "SELECT id, hash, data, filename FROM files WHERE hash=?"

	var found Entry
	row := DB.QueryRow(selectFile, hash)
	scanErr := row.Scan(&found.ID, &found.Hash, &found.Data, &found.Filename)

	// The link is built from the requested hash, not from the scanned column.
	found.URL = fmt.Sprintf("http://%s/%x", r.Host, hash)
	return &found, scanErr
}
// GetUserFilesList returns the files owned by u, ordered by id descending.
// Each Entry carries the row's id, filename, hash and created_at, plus a URL
// of the form "http://<r.Host>/<hex of hash>".
//
// The function has no error return, so failures are logged instead: if the
// query itself fails an empty (non-nil) list is returned, rows that cannot be
// scanned are skipped, and an iteration error ends the list early.
func GetUserFilesList(r *http.Request, u *User) []Entry {
	list := []Entry{}

	rows, err := DB.Query(
		"SELECT id, filename, hash, created_at FROM files WHERE user_id=? ORDER BY id DESC",
		u.ID,
	)
	if err != nil {
		log.Println(err)
		// rows is nil when Query fails; touching it would panic.
		return list
	}
	defer rows.Close()

	for rows.Next() {
		// A fresh Entry per row keeps a failed Scan from leaking values of
		// the previous row into the result.
		var entry Entry
		if err := rows.Scan(&entry.ID, &entry.Filename, &entry.Hash, &entry.CreatedAt); err != nil {
			log.Println(err)
			continue
		}
		entry.URL = fmt.Sprintf("http://%s/%x", r.Host, entry.Hash)
		list = append(list, entry)
	}
	// Next returns false both at the end of the result set and on error;
	// Err tells the two apart.
	if err := rows.Err(); err != nil {
		log.Println(err)
	}
	return list
}
// DeleteSessionFromDB removes the sessions row whose session_hash matches and
// returns the error reported by the DELETE statement, if any. The number of
// affected rows is not inspected.
func DeleteSessionFromDB(sessionHash []byte) error {
	if _, execErr := DB.Exec("DELETE FROM sessions WHERE session_hash=?", sessionHash); execErr != nil {
		return execErr
	}
	return nil
}
// CreateSessionInDB records a new session for user: it inserts a sessions row
// with the given hash, user.ID and the current Unix time as created_at. It
// returns the error reported by the insert, if any.
func CreateSessionInDB(sessionHash []byte, user *User) error {
	const insertSession = "INSERT INTO sessions (session_hash, user_id, created_at) VALUES (?, ?, ?)"

	if _, execErr := DB.Exec(insertSession, sessionHash, user.ID, time.Now().Unix()); execErr != nil {
		return execErr
	}
	return nil
}
// Sessions clean up goroutine
func SessionCleaner() {
for {
_, err := DB.Exec("DELETE FROM sessions WHERE created_at < ?", time.Now().Unix()-SESSION_MAX_AGE)
if err != nil {
log.Println(err)
}
time.Sleep(time.Second * 300)
}
}