Hyperparameter tuning for deep neural networks (DNNs) tends to be more involved than for other ML models because of the number of hyperparameters that can (and should) be assessed and the dependencies between them. To automate hyperparameter tuning for Keras and TensorFlow, we use the tfruns package (https://github.com/rstudio/tfruns).

This notebook shows an example of performing a grid search on a densely connected feedforward neural network for the IMDB movie review classifier.

library(tfruns)
library(dplyr)

tfruns provides added flexibility for tracking, visualizing, and managing training runs. The most common way to use tfruns is to create an R script that contains the training code to be executed. For this example, I created the imdb-grid-search.R script (https://github.com/misk-data-science/misk-dl/blob/master/materials/99-extras/imdb-grid-search.R).

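Within this script, we create “flags”, which identify the hyperparameters of interest. For this example, I assess:

- batch sizes of 128 and 512
- layers of 1, 2, and 3
- number of units per hidden layer of 16, 32, and 64
- learning rates of 0.001 and 0.0001
- dropout rates of 0, 0.3, and 0.5
- weight decay of 0, 0.01, and 0.001
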
This equates to 324 models (2 × 3 × 3 × 2 × 3 × 3 combinations). Since these models train relatively quickly, and since I ran this at the end of the day, I performed a full Cartesian grid search, meaning I trained and assessed every one of the 324 models. If you are pressed for time, you can run a stochastic (random) grid search instead by setting the sample argument of tuning_run() below, which runs only a random fraction of the combinations.
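As a quick check of that arithmetic, the grid size is simply the product of the number of values per flag. The base R sketch below reproduces the 324 count with expand.grid() and draws a 10% random subset of the combinations, which is similar in spirit to what tuning_run(sample = 0.1) does. The 10% rate and the seed are arbitrary choices for illustration.

```r
# Same hyperparameter values as the grid_search list passed to tuning_run()
grid_search <- list(
  batch_size = c(128, 512),
  layers = c(1, 2, 3),
  units = c(16, 32, 64),
  learning_rate = c(0.001, 0.0001),
  dropout = c(0, 0.3, 0.5),
  weight_decay = c(0, 0.01, 0.001)
)

# Full Cartesian grid: one row per hyperparameter combination
full_grid <- expand.grid(grid_search)
nrow(full_grid)
#> [1] 324

# Equivalent count: product of the number of values for each flag
prod(lengths(grid_search))
#> [1] 324

# Random subset of roughly 10% of the combinations (illustrative seed and rate)
set.seed(123)
random_grid <- full_grid[sample(nrow(full_grid), size = round(0.1 * nrow(full_grid))), ]
nrow(random_grid)
#> [1] 32

# With tfruns, the same idea is expressed through the sample argument, e.g.:
# tuning_run("imdb-grid-search.R", runs_dir = "imdb_runs",
#            flags = grid_search, sample = 0.1)
```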

Next, to execute this grid search, first specify the values to try for each flag you created in your .R script, as below. Then run tuning_run() in place of source() to execute your .R script once for every combination in the supplied hyperparameter grid.
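The contents of imdb-grid-search.R are not reproduced in this notebook. As a rough illustration only, the flag declarations near the top of such a script might look like the sketch below, which uses tfruns' flags() with flag_integer() and flag_numeric(). The flag names match the grid used here, but the default values are hypothetical: they apply when the script is sourced directly, while tuning_run() overrides them with one grid combination per run.

```r
library(tfruns)

# Hypothetical defaults; tuning_run() replaces them with the grid values
FLAGS <- flags(
  flag_integer("batch_size", 128),
  flag_integer("layers", 1),
  flag_integer("units", 16),
  flag_numeric("learning_rate", 0.001),
  flag_numeric("dropout", 0),
  flag_numeric("weight_decay", 0)
)

# The model-building and fitting code would then read these values,
# e.g. FLAGS$units for the hidden-layer width and FLAGS$batch_size for fit()
FLAGS$units
```

Declaring every tuned hyperparameter as a flag is what allows tuning_run() to vary it without editing the script, and it is also why each flag later appears as its own flag_* column in ls_runs().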

Note: this takes over 3 hours to run on a machine without a GPU.
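To put that runtime in perspective, here is a back-of-the-envelope calculation. It assumes roughly 3 hours in total and equal time per model, both simplifications, and estimates the average cost per model and the approximate wall time if only 10% of the grid were sampled.

```r
total_hours <- 3    # from the note above; the actual run took longer
n_models <- 324     # size of the full Cartesian grid

# Average seconds per model under the equal-time assumption
sec_per_model <- total_hours * 3600 / n_models
sec_per_model
#> [1] 33.33333

# Approximate minutes for a 10% random sample of the grid (32 models)
n_sampled <- round(0.1 * n_models)
sampled_minutes <- n_sampled * sec_per_model / 60
sampled_minutes
#> [1] 17.77778
```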

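# Values to try for each flag defined in imdb-grid-search.R
grid_search <- list(
  batch_size = c(128, 512),
  layers = c(1, 2, 3),
  units = c(16, 32, 64),
  learning_rate = c(0.001, 0.0001),
  dropout = c(0, 0.3, 0.5),
  weight_decay = c(0, 0.01, 0.001)
)

# Train one model per combination (324 runs) and record each under imdb_runs/
tuning_run(
  "imdb-grid-search.R",
  runs_dir = "imdb_runs",
  flags = grid_search,
  confirm = FALSE, # skip the interactive confirmation prompt
  echo = FALSE     # do not echo the script's expressions to the console
)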

This grid search execution creates an “imdb_runs” subdirectory (the name given by runs_dir; tfruns uses “runs” by default) within your working directory. This folder contains information for every training run executed during the grid search.

To list the results, run ls_runs():

data.frame(ls_runs(runs_dir = "imdb_runs"))

You can also filter and order the results. Ordering by eval_best_loss in increasing order shows that a few models tied for the lowest loss of 0.262.

ls_runs(runs_dir = "imdb_runs", order = eval_best_loss, decreasing = FALSE)
Data frame: 324 x 31 
# ... with 314 more rows
# ... with 23 more columns:
#   flag_batch_size, flag_layers, flag_units, flag_learning_rate, flag_dropout,
#   flag_weight_decay, samples, batch_size, epochs, epochs_completed, metrics, model,
#   loss_function, optimizer, learning_rate, script, start, end, completed, output,
#   source_code, context, type
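Because ls_runs() returns a data frame with one flag_* column per flag alongside eval_best_loss, the results can also be summarized with dplyr. The helper below, loss_by_flag(), is a hypothetical function (not part of tfruns) that groups runs by a single flag and reports the number of runs plus the mean and minimum eval_best_loss for each value of that flag.

```r
library(dplyr)

# Hypothetical helper: summarize eval_best_loss by the values of one flag.
# `runs` is expected to be the data frame returned by tfruns::ls_runs().
loss_by_flag <- function(runs, flag) {
  runs %>%
    as.data.frame() %>%   # drop the tfruns print class, keep the data
    group_by({{ flag }}) %>%
    summarise(
      n_runs = n(),
      mean_loss = mean(eval_best_loss, na.rm = TRUE),
      best_loss = min(eval_best_loss, na.rm = TRUE),
      .groups = "drop"
    ) %>%
    arrange(mean_loss)    # lowest average loss first
}
```

For example, loss_by_flag(ls_runs(runs_dir = "imdb_runs"), flag_dropout) would summarize the runs by dropout rate. Averaging over the other flags can hide interactions between hyperparameters, so this is a first look rather than a replacement for inspecting the top individual runs.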

To see details about any one of these models, run view_run(). In this example, I take the first of the top models from above. When you execute this, a pop-up window appears with that model's summary information, as illustrated below.

# Run directory of the model with the lowest eval_best_loss
best_run <- ls_runs(
  runs_dir = "imdb_runs",
  order = eval_best_loss,
  decreasing = FALSE
  ) %>%
  slice(1) %>%
  pull(run_dir)

# Open the report for that run
view_run(best_run)

There are many other handy features in the tfruns package. I suggest checking it out at https://tensorflow.rstudio.com/tools/tfruns/overview/ and taking it for a test drive.
