# Source code for pygan.gansvaluefunction.mini_max
# -*- coding: utf-8 -*-
import numpy as np
from pygan.gans_value_function import GANsValueFunction
class MiniMax(GANsValueFunction):
    '''
    Minimax value function of the original GAN formulation.

    The discriminator D and generator G play the game

        min_G max_D  V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]

    Both methods return the element-wise terms of V; no expectation or
    averaging over the batch is taken here. A small ``eps`` is added
    inside every logarithm so that posteriors of exactly 0 or 1 yield a
    large finite negative value instead of ``-inf``.
    '''

    def compute_discriminator_reward(
        self,
        true_posterior_arr,
        generated_posterior_arr,
        *,
        eps=1e-08
    ):
        '''
        Compute discriminator's reward.

        Evaluates ``log(D(x) + eps) + log(1 - D(G(z)) + eps)`` element-wise,
        i.e. the quantity the discriminator tries to maximize.

        Args:
            true_posterior_arr:         Array-like of posteriors the discriminator
                                        inferred for `true` samples, D(x).
            generated_posterior_arr:    Array-like of posteriors the discriminator
                                        inferred for `fake` samples, D(G(z)).
                                        Must be broadcast-compatible with
                                        `true_posterior_arr`.
            eps:                        Keyword-only. Small constant added inside
                                        each logarithm for numerical stability.
                                        Defaults to 1e-08.

        Returns:
            `np.ndarray` of element-wise reward values (a NumPy scalar for
            scalar input). This is the value itself, not its derivative;
            earlier documentation called it "Gradients", presumably because
            callers feed it into back-propagation.
            NOTE(review): confirm against the caller.
        '''
        # np.asarray lets plain lists/tuples work; ndarrays pass through uncopied.
        true_posterior_arr = np.asarray(true_posterior_arr)
        generated_posterior_arr = np.asarray(generated_posterior_arr)
        # Posteriors are assumed to lie in [0, 1]; a fake posterior above
        # 1 + eps makes the second log argument negative and yields nan.
        reward_arr = (
            np.log(true_posterior_arr + eps)
            + np.log(1 - generated_posterior_arr + eps)
        )
        return reward_arr

    def compute_generator_reward(
        self,
        generated_posterior_arr,
        *,
        eps=1e-08
    ):
        '''
        Compute generator's reward.

        Evaluates ``log(1 - D(G(z)) + eps)`` element-wise. This is the
        generator's term of the minimax value function, which the generator
        tries to minimize (the original "saturating" objective).

        Args:
            generated_posterior_arr:    Array-like of posteriors the discriminator
                                        inferred for `fake` samples, D(G(z)).
            eps:                        Keyword-only. Small constant added inside
                                        the logarithm for numerical stability.
                                        Defaults to 1e-08.

        Returns:
            `np.ndarray` of element-wise reward values (a NumPy scalar for
            scalar input). This is the value itself, not its derivative;
            earlier documentation called it "Gradients", presumably because
            callers feed it into back-propagation.
            NOTE(review): confirm against the caller.
        '''
        # np.asarray lets plain lists/tuples work; ndarrays pass through uncopied.
        generated_posterior_arr = np.asarray(generated_posterior_arr)
        reward_arr = np.log(1 - generated_posterior_arr + eps)
        return reward_arr