accelbrainbase.observabledata._mxnet.functionapproximator package
Submodules
accelbrainbase.observabledata._mxnet.functionapproximator.function_approximator module
class accelbrainbase.observabledata._mxnet.functionapproximator.function_approximator.FunctionApproximator(model, initializer=None, learning_rate=1e-05, optimizer_name='SGD', hybridize_flag=True, scale=1.0, ctx=gpu(0), **kwargs)
Bases: mxnet.gluon.block.HybridBlock, accelbrainbase.observabledata._mxnet.function_approximator.FunctionApproximator
The function approximator for Deep Q-Learning.
Convolutional neural networks (CNNs) are hierarchical models whose convolutional layers alternate with subsampling layers, reminiscent of simple and complex cells in the primary visual cortex.
Mainly, this class demonstrates that a CNN can solve generalisation problems to learn successful control policies from observed data points in complex Reinforcement Learning environments. The network is trained with a variant of the Q-learning algorithm, using stochastic gradient descent to update the weights.
However, the function approximator does not need to be a CNN. This interface lets various models serve as function approximators, not only CNNs.
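To make the training idea concrete, here is a minimal sketch of Q-learning with a function approximator updated by stochastic gradient descent. It is written in plain Python and does not use this package or MXNet. The class `LinearQApproximator`, its `learn` method, and the helper `q_learning_target` are hypothetical names introduced for illustration only, and a linear model stands in for the network, in line with the point that the approximator need not be a CNN.

```python
class LinearQApproximator:
    """Toy linear stand-in for a neural approximator: Q(s, a) = w[a] . s"""

    def __init__(self, state_dim, n_actions, learning_rate=1e-2):
        # One weight vector per action, all starting at zero.
        self.weights = [[0.0] * state_dim for _ in range(n_actions)]
        self.learning_rate = learning_rate

    def inference(self, state):
        # Return one Q-value per action for the given state vector.
        return [sum(w * s for w, s in zip(row, state)) for row in self.weights]

    def learn(self, state, action, target):
        # One SGD step on the squared TD error 0.5 * (target - Q(s, a))^2.
        td_error = target - self.inference(state)[action]
        row = self.weights[action]
        for i, s in enumerate(state):
            row[i] += self.learning_rate * td_error * s
        return td_error


def q_learning_target(approximator, reward, next_state, gamma=0.9, terminal=False):
    # Bootstrapped target: r + gamma * max_a' Q(s', a'); just r at episode end.
    if terminal:
        return reward
    return reward + gamma * max(approximator.inference(next_state))


# Usage: repeatedly learn from one terminal transition with reward 1.0.
approximator = LinearQApproximator(state_dim=2, n_actions=2, learning_rate=0.1)
for _ in range(200):
    target = q_learning_target(approximator, 1.0, [0.0, 0.0], terminal=True)
    approximator.learn([1.0, 0.0], 0, target)
```

After these updates the estimate for action 0 in state `[1.0, 0.0]` approaches the observed reward, while the untouched action keeps its initial value of zero.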
References
- Dumoulin, V., & Visin, F. (2016). A guide to convolution arithmetic for deep learning. arXiv preprint arXiv:1603.07285.
- Masci, J., Meier, U., Cireşan, D., & Schmidhuber, J. (2011, June). Stacked convolutional auto-encoders for hierarchical feature extraction. In International Conference on Artificial Neural Networks (pp. 52-59). Springer, Berlin, Heidelberg.
- Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D., & Riedmiller, M. (2013). Playing atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602.
collect_params(select=None)
Overrides collect_params in mxnet.gluon.HybridBlock.
forward_propagation(F, x)
Hybrid forward with Gluon API.
Parameters: - F – mxnet.ndarray or mxnet.symbol.
- x – mxnet.ndarray of observed data points.
Returns: mxnet.ndarray or mxnet.symbol of inferred feature points.
get_model()
Getter for the model (mxnet.gluon.block.HybridBlock).
hybrid_forward(F, x)
Hybrid forward with Gluon API.
Parameters: - F – mxnet.ndarray or mxnet.symbol.
- x – mxnet.ndarray of observed data points.
Returns: mxnet.ndarray or mxnet.symbol of inferred feature points.
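The role of the `F` argument can be illustrated without MXNet. The sketch below is a toy analogy, not the Gluon implementation: `EagerBackend` and `SymbolicBackend` are hypothetical stand-ins for an imperative namespace and a symbolic namespace, and the operations `relu` and `scale` are made up for the example. The point it shows is that a forward function written only in terms of `F` can either compute values immediately or build a deferred expression.

```python
class EagerBackend:
    """Hypothetical imperative backend: computes on plain lists of floats."""

    @staticmethod
    def relu(x):
        return [max(0.0, v) for v in x]

    @staticmethod
    def scale(x, factor):
        return [v * factor for v in x]


class SymbolicBackend:
    """Hypothetical symbolic backend: builds an expression string, computes nothing."""

    @staticmethod
    def relu(x):
        return "relu({})".format(x)

    @staticmethod
    def scale(x, factor):
        return "scale({}, {})".format(x, factor)


def hybrid_forward(F, x):
    # The same definition works on either backend because it only calls through F.
    return F.relu(F.scale(x, 2.0))


# Usage: identical code, two execution styles.
eager_result = hybrid_forward(EagerBackend, [-1.0, 0.5])
symbolic_result = hybrid_forward(SymbolicBackend, "x")
```

With the eager backend the call returns concrete numbers; with the symbolic backend it returns a description of the computation that could be optimised or executed later.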
inference(observed_arr)
Draw samples from the fake distribution.
Parameters: observed_arr – mxnet.ndarray or mxnet.symbol of observed data points.
Returns: Tuple of `mxnet.ndarray`s.
model
Getter for the model (mxnet.gluon.block.HybridBlock).
set_model(value)
Setter for the model (mxnet.gluon.block.HybridBlock).
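The trio `get_model`, `set_model`, and `model` suggests a Python property built from an explicit getter and setter. The following sketch shows that pattern under stated assumptions: `ApproximatorSketch` is a hypothetical class, and the `callable` check in the setter is an illustrative placeholder for whatever validation the library actually performs, which this page does not describe.

```python
class ApproximatorSketch:
    """Hypothetical wrapper exposing its network through a `model` property."""

    def __init__(self, model):
        self.__model = model

    def get_model(self):
        # Getter: return the wrapped model.
        return self.__model

    def set_model(self, value):
        # Setter: the callable check is an assumption made for this example.
        if not callable(value):
            raise TypeError("model must be callable")
        self.__model = value

    # Attribute access and assignment are routed through the getter and setter.
    model = property(get_model, set_model)


# Usage: read and replace the model through the property.
approximator_sketch = ApproximatorSketch(abs)
first_model = approximator_sketch.model
approximator_sketch.model = len
```

Routing assignment through a setter gives one place to validate a replacement model before it is used.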