accelbrainbase.observabledata._mxnet.functionapproximator package

Submodules

accelbrainbase.observabledata._mxnet.functionapproximator.function_approximator module

class accelbrainbase.observabledata._mxnet.functionapproximator.function_approximator.FunctionApproximator(model, initializer=None, learning_rate=1e-05, optimizer_name='SGD', hybridize_flag=True, scale=1.0, ctx=gpu(0), **kwargs)

Bases: mxnet.gluon.block.HybridBlock, accelbrainbase.observabledata._mxnet.function_approximator.FunctionApproximator

The function approximator for Deep Q-Learning.

Convolutional neural networks (CNNs) are hierarchical models whose convolutional layers alternate with subsampling layers, reminiscent of simple and complex cells in the primary visual cortex.

Mainly, this class demonstrates that a CNN can solve generalisation problems and learn successful control policies from observed data points in complex Reinforcement Learning environments. The network is trained with a variant of the Q-learning algorithm, using stochastic gradient descent to update the weights.
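As a rough illustration of the training rule described above, the following sketch applies one Q-learning update with a stochastic gradient step. It is not this class's implementation: it assumes a linear approximator Q(s, a) = w_a · s in place of a CNN, and the names q_values, q_learning_sgd_step and the discount factor gamma are hypothetical. Only the default learning_rate mirrors the constructor signature.

```python
def q_values(weights, state):
    """Return one Q-value per action for a linear approximator Q(s, a) = w_a . s."""
    return [sum(w * s for w, s in zip(w_a, state)) for w_a in weights]


def q_learning_sgd_step(weights, state, action, reward, next_state, done,
                        learning_rate=1e-05, gamma=0.99):
    """Apply one semi-gradient Q-learning update in place; return the TD error.

    gamma (discount factor) is an assumption; it does not appear in the
    documented constructor.
    """
    # Bootstrapped target: r + gamma * max_a' Q(s', a'), or just r at episode end.
    target = reward
    if not done:
        target += gamma * max(q_values(weights, next_state))
    td_error = target - q_values(weights, state)[action]
    # SGD on 0.5 * td_error**2 with the target held fixed:
    # the gradient w.r.t. w_action is -td_error * state.
    for i, s in enumerate(state):
        weights[action][i] += learning_rate * td_error * s
    return td_error


# Toy usage: two actions, two-dimensional state, terminal transition.
weights = [[0.0, 0.0], [0.0, 0.0]]
state, next_state = [1.0, 0.0], [0.0, 1.0]
first_error = q_learning_sgd_step(weights, state, 1, 1.0, next_state,
                                  done=True, learning_rate=0.1)
second_error = q_learning_sgd_step(weights, state, 1, 1.0, next_state,
                                   done=True, learning_rate=0.1)
```

Repeating the same transition shrinks the TD error, which is the behaviour one would expect from any approximator trained this way, whatever the model behind it.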

However, the function approximator need not be a CNN. This interface supports various models as function approximators, not only CNNs.
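The idea of a model-agnostic interface can be sketched in plain Python. ToyApproximator, linear_model and TableModel below are hypothetical stand-ins, not part of accelbrainbase; the sketch only assumes that the wrapped model is a callable mapping observed data points to outputs.

```python
class ToyApproximator:
    """Hypothetical stand-in: delegates inference to whatever model it is given."""

    def __init__(self, model):
        self.model = model

    def inference(self, observed):
        # The caller never needs to know which kind of model is inside.
        return self.model(observed)


def linear_model(observed):
    """A function used as a model: y = 2x + 1 for each observation."""
    return [2.0 * x + 1.0 for x in observed]


class TableModel:
    """A lookup table used as a model; unseen observations map to 0.0."""

    def __init__(self, table):
        self.table = table

    def __call__(self, observed):
        return [self.table.get(x, 0.0) for x in observed]


linear_out = ToyApproximator(linear_model).inference([0.0, 1.0])
table_out = ToyApproximator(TableModel({1.0: 5.0})).inference([0.0, 1.0])
```

Both models are driven through the same inference call, which is the property the paragraph above describes.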

References

  • Dumoulin, V., & Visin, F. (2016). A guide to convolution arithmetic for deep learning. arXiv preprint arXiv:1603.07285.
  • Masci, J., Meier, U., Cireşan, D., & Schmidhuber, J. (2011, June). Stacked convolutional auto-encoders for hierarchical feature extraction. In International Conference on Artificial Neural Networks (pp. 52-59). Springer, Berlin, Heidelberg.
  • Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D., & Riedmiller, M. (2013). Playing atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602.

collect_params(select=None)

Overrides collect_params in mxnet.gluon.HybridBlock.

forward_propagation(F, x)

Hybrid forward with Gluon API.

Parameters:
  • F: mxnet.ndarray or mxnet.symbol.
  • x: mxnet.ndarray of observed data points.
Returns:

mxnet.ndarray or mxnet.symbol of inferred feature points.

get_model()

Getter for the mxnet.gluon.hybrid.hybridblock.HybridBlock model.

hybrid_forward(F, x)

Hybrid forward with Gluon API.

Parameters:
  • F: mxnet.ndarray or mxnet.symbol.
  • x: mxnet.ndarray of observed data points.
Returns:

mxnet.ndarray or mxnet.symbol of inferred feature points.

inference(observed_arr)

Draw samples from the fake distribution.

Parameters:
  • observed_arr: mxnet.ndarray or mxnet.symbol of observed data points.
Returns:

Tuple of `mxnet.ndarray`s.

model

Getter for the mxnet.gluon.hybrid.hybridblock.HybridBlock model.

set_model(value)

Setter for the mxnet.gluon.hybrid.hybridblock.HybridBlock model.
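The get_model, set_model and model entries above suggest the usual Python pattern of a property built from an explicit getter and setter. A minimal sketch of that pattern follows. ModelHolder is a hypothetical class, and the callable check in the setter is an assumption added for illustration; nothing in this reference says the library performs it.

```python
class ModelHolder:
    """Hypothetical class exposing a model through get/set methods and a property."""

    def __init__(self, model):
        self.__model = None
        self.set_model(model)

    def get_model(self):
        """Getter for the held model."""
        return self.__model

    def set_model(self, value):
        """Setter for the held model (the callable check is an assumption)."""
        if not callable(value):
            raise TypeError("model must be callable")
        self.__model = value

    # Attribute-style access routed through the getter and setter above.
    model = property(get_model, set_model)


holder = ModelHolder(abs)
initial_model = holder.model      # goes through get_model
holder.model = len                # goes through set_model
replaced_model = holder.get_model()
```

With this layout, holder.model and holder.get_model() always return the same object, and every assignment passes through a single validation point.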

Module contents