DNN 多分类

使用 tfestimators 的 DNN 分类器对 iris 数据集进行多分类。


library(tfestimators)

# Name of the target (label) column in `iris`.
response <- function() {
  "Species"
}

# Names of the predictor columns: every `iris` column except the response.
features <- function() {
  all_columns <- colnames(iris)
  all_columns[!(all_columns %in% response())]
}

# Split iris into a reproducible 80% train / 20% test partition ----
# The fixed seed makes the random row assignment repeatable across runs.
set.seed(123)
partitions <- modelr::resample_partition(iris, c(test = 0.2, train = 0.8))

# Materialise each resample object as a plain data.frame.
iris_train <- as.data.frame(partitions[["train"]])
iris_test  <- as.data.frame(partitions[["test"]])

# Construct feature columns ----
# One numeric feature column per predictor returned by features().
# Stored under its own name so it does not shadow the tfestimators
# function feature_columns().
iris_columns <- feature_columns(
  column_numeric(features())
)

# Construct classifier ----
# A fully connected network with three hidden layers (10, 20, 10 units).
# The number of output classes is taken from the levels of the response
# factor rather than hard-coded, so it stays correct if the data changes.
classifier <- dnn_classifier(
  feature_columns = iris_columns,
  hidden_units = c(10, 20, 10),
  n_classes = nlevels(iris_train[[response()]])
)

# Construct input function ----
# Build a tfestimators input function over `data`, using the predictor and
# response columns given by features() and response().
#
# `...` is forwarded to tfestimators' input_fn(), so callers can tune
# options such as `num_epochs`, `batch_size` or `shuffle`. Called with only
# `data`, it behaves exactly as before.
iris_input_fn <- function(data, ...) {
  input_fn(data, features = features(), response = response(), ...)
}

# Train classifier with training dataset ----
# With a single pass over the data in one large batch, training stopped
# after just one step: the logged run shows "step: 1", and test accuracy was
# ~0.35, i.e. chance level for 3 classes. Feeding many epochs in smaller
# batches gives the network enough gradient steps to actually learn.
train(
  classifier,
  input_fn = iris_input_fn(iris_train, num_epochs = 200, batch_size = 16)
)
The following factor levels of 'Species' have been encoded:
- 'setosa' => 0
- 'versicolor' => 1
- 'virginica' => 2
2018-02-18 16:13:55.750304: E tensorflow/core/util/events_writer.cc:162] The events file /var/folders/jz/qf7zhsc97f71slzzf59mvs2w0000gn/T/tmpfg11dt9y/events.out.tfevents.1518941432.MideMacBook-Air.local has disappeared.
2018-02-18 16:13:55.750358: E tensorflow/core/util/events_writer.cc:131] Failed to flush 1 events to /var/folders/jz/qf7zhsc97f71slzzf59mvs2w0000gn/T/tmpfg11dt9y/events.out.tfevents.1518941432.MideMacBook-Air.local
[-] Training -- loss: 143.24, step: 1
2018-02-18 16:13:56.963146: E tensorflow/core/util/events_writer.cc:162] The events file /var/folders/jz/qf7zhsc97f71slzzf59mvs2w0000gn/T/tmpfg11dt9y/events.out.tfevents.1518941432.MideMacBook-Air.local has disappeared.
2018-02-18 16:13:56.963224: E tensorflow/core/util/events_writer.cc:131] Failed to flush 5 events to /var/folders/jz/qf7zhsc97f71slzzf59mvs2w0000gn/T/tmpfg11dt9y/events.out.tfevents.1518941432.MideMacBook-Air.local
# Predict classes and probabilities for the held-out test rows ----

predictions <- predict(
  classifier,
  input_fn = iris_input_fn(iris_test)
)
print(predictions)
# A tibble: 29 x 4
   logits                probabilities         classes class_ids
                                        
 1 <-1.44, 0.525, -1.15> <0.105, 0.753, 0.142> <1>     <1>      
 2 <-1.37, 0.516, -1.15> <0.113, 0.746, 0.141> <1>     <1>      
 3 <-1.29, 0.475, -1.05> <0.123, 0.72, 0.157>  <1>     <1>      
 4 <-1.28, 0.469, -1.03> <0.125, 0.716, 0.159> <1>     <1>      
 5 <-1.28, 0.464, -1.02> <0.125, 0.713, 0.162> <1>     <1>      
 6 <-1.28, 0.485, -1.09> <0.124, 0.726, 0.15>  <1>     <1>      
 7 <-1.32, 0.486, -1.07> <0.119, 0.728, 0.153> <1>     <1>      
 8 <-1.25, 0.462, -1.03> <0.128, 0.711, 0.161> <1>     <1>      
 9 <-1.45, 0.526, -1.14> <0.104, 0.754, 0.142> <1>     <1>      
10 <-1.22, 0.44, -0.963> <0.132, 0.697, 0.171> <1>     <1>      
# ... with 19 more rows
# Evaluate the trained classifier on the held-out test rows.
evaluation <- evaluate(
  classifier,
  input_fn = iris_input_fn(iris_test)
)
The following factor levels of 'Species' have been encoded:
- 'setosa' => 0
- 'versicolor' => 1
- 'virginica' => 2
[-] Evaluating -- loss: 43.15, step: 1
> evaluation
# A tibble: 1 x 4
  average_loss accuracy  loss global_step
                     
1         1.49    0.345  43.2        2.00

你可能感兴趣的:(DNN 多分类)