There could be an issue with sampling due to (my) confusion about standard deviation and variance.
The samples are drawn using NumPy like so (documentation; line 238 of __init__.py):
sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
But the outputs of the mixture density layer are treated as scale parameters by tfp.distributions.MultivariateNormalDiag, whose documentation notes that:
covariance = scale @ scale.T
Thus, since np.random.multivariate_normal expects a covariance matrix (variances on the diagonal), it seems we should have been squaring the sigma values before building the cov_matrix passed to the sampling call. This could explain why we end up having to scale down the sigma variable so much in real-world applications.
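A quick sketch of the suspected mismatch (hypothetical mus_vector and sigmas stand in for the actual network outputs): if the diagonal entries hold scales (std devs) but are fed to np.random.multivariate_normal as a covariance, the empirical spread comes out as sqrt(sigma) instead of sigma.

```python
import numpy as np

rng = np.random.default_rng(0)

mus_vector = np.zeros(2)       # hypothetical means from the mixture density layer
sigmas = np.array([0.5, 2.0])  # hypothetical outputs, treated as scale (std dev)

# tfp.distributions.MultivariateNormalDiag defines covariance = scale @ scale.T,
# so for a diagonal scale the covariance diagonal is sigma**2.
cov_as_scale = np.diag(sigmas)       # what the current code effectively passes
cov_squared = np.diag(sigmas ** 2)   # covariance actually implied by the scales

n = 200_000
samples_wrong = rng.multivariate_normal(mus_vector, cov_as_scale, n)
samples_right = rng.multivariate_normal(mus_vector, cov_squared, n)

# Empirical std devs: the unsquared version yields sqrt(sigma), not sigma.
print(samples_wrong.std(axis=0))  # ≈ [0.71, 1.41], i.e. sqrt([0.5, 2.0])
print(samples_right.std(axis=0))  # ≈ [0.5, 2.0]
```

Note that for sigma < 1 the unsquared version actually samples *wider* than intended, while for sigma > 1 it samples narrower, so the visible symptom depends on the magnitude of the learned scales.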
A todo here is to get a definitive answer and run some tests to check what's going on.