-
Notifications
You must be signed in to change notification settings - Fork 0
/
2022-05-09--conda-test.R
53 lines (38 loc) · 1.1 KB
/
2022-05-09--conda-test.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env Rscript
library(reticulate)
library(keras)
# Bind reticulate to the "pyrstudio" conda environment before any Python code
# (TensorFlow/Keras) is initialised. `required = TRUE` makes the script stop
# with an error if that environment cannot be found, instead of treating the
# request as a hint and silently running on some other Python installation.
reticulate::use_condaenv(condaenv = "pyrstudio", required = TRUE)
# One-time setup if the environment does not have TensorFlow/Keras yet.
# Note: `conda =` is the path to the conda executable, so the environment name
# must be passed as `envname =`.
# keras::install_keras(method = "conda", envname = "pyrstudio")
# Flatten an image array to a matrix and scale pixel values to [0, 1].
#
# images: array whose first dimension indexes samples; for MNIST each sample
#   is a 28 x 28 image, giving the 784 columns the input layer expects.
# Returns a numeric matrix with one row per sample and
#   prod(dim(images)[-1]) columns, divided by 255.
flatten_and_scale <- function(images) {
  n_features <- prod(dim(images)[-1])
  # array_reshape() fills in row-major (C) order by default, unlike base
  # R's dim<-, so the pixel layout matches what NumPy/Keras would produce.
  array_reshape(images, c(nrow(images), n_features)) / 255
}

mnist <- dataset_mnist()
# The same helper is used for both splits so train and test inputs are
# guaranteed to be preprocessed identically.
x_train <- flatten_and_scale(mnist$train$x)
x_test <- flatten_and_scale(mnist$test$x)
# One-hot encode the digit labels into 10 columns for categorical
# cross-entropy.
y_train <- to_categorical(mnist$train$y, 10)
y_test <- to_categorical(mnist$test$y, 10)
# Multilayer perceptron on the flattened 784-pixel input:
#   dense(256, GELU) -> dropout(0.4) -> dense(128, GELU) -> dropout(0.3)
#   -> dense(10, softmax).
# Each layer_*() call adds a layer to the sequential model and returns that
# model, so the whole stack is built and assigned in a single pipe.
model <- keras_model_sequential() |>
  layer_dense(
    units = 256,
    activation = "gelu",
    input_shape = c(784)
  ) |>
  layer_dropout(rate = 0.4) |>
  layer_dense(units = 128, activation = "gelu") |>
  layer_dropout(rate = 0.3) |>
  layer_dense(units = 10, activation = "softmax")
# Print the layer-by-layer architecture and parameter counts.
summary(model)

# Configure training: cross-entropy loss against the one-hot targets, RMSprop
# with its default settings, and accuracy tracked as a metric.
# compile() configures `model` in place.
compile(
  model,
  optimizer = optimizer_rmsprop(),
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)
# Train for 30 epochs with mini-batches of 128 samples, holding out 20% of
# the training rows as a validation set, then plot the recorded history.
history <- fit(
  model,
  x = x_train,
  y = y_train,
  batch_size = 128,
  epochs = 30,
  validation_split = 0.2
)
plot(history)
# Score the trained model on the held-out test set (loss plus the compiled
# metrics). The result is printed because this is a visible top-level value.
model |> evaluate(x_test, y_test)