Background
Sometimes it would be useful to trace intermediate values computed as part of a `JointDistribution*` model. The current solution is to use `tfd.Deterministic`. For example, suppose we would like to trace the mean of a simple linear regression on a single `feature`. A user might write:
A call to `model.sample()` will return a named tuple that includes the value of `mean`, conditional on `intercept` and `slope` (and `feature`). However, if we wish to compute the log probability density of the model given `response`, `intercept`, and `slope`, we also have to pass a value of `mean` into `model.log_prob`. Here, `mean` must be consistent with `intercept` and `slope`, which requires the user to duplicate the expression for `mean` outside the model object, e.g.:
```python
intercept = 0.1
slope = 0.2
feature = 0.5
response = 0.21
# Since `mean` is deterministic, we should not have to re-compute it outside of `model`
mean = intercept + slope * feature
lp = model.log_prob(intercept=intercept, slope=slope, mean=mean, response=response)
```
This is not only wasteful in keystrokes but also error-prone: if `model` changes, the duplicated expression for `mean` must be updated to match.
Suggested solution
A potential solution would be to add a subclass similar to `JointDistribution.Root`, called `JointDistribution.Trace`, which would flag an expression for tracing in the forward generating process (i.e. `model.sample()`) but exclude the associated variable from the CDF/CMF- and PDF/PMF-related methods. Thus we could write:
```python
@tfd.JointDistributionCoroutine
def model():
    intercept = yield tfd.Normal(loc=0.0, scale=1.0, name="intercept")
    slope = yield tfd.Normal(loc=0.0, scale=1.0, name="slope")
    # Proposed API (`Trace` does not exist yet); yielded so the model can see it
    mean = yield Trace(intercept + slope * feature, name="mean")
    yield tfd.Normal(loc=mean, scale=1.0, name="response")

draw = model.sample(seed=[0, 0])

# `mean` is simply ignored
model.log_prob(draw)

# `mean` does not have to be supplied
model.log_prob(intercept=draw.intercept, slope=draw.slope, response=draw.response)
```
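To make the proposed semantics concrete, here is a toy, pure-Python sketch. `Normal`, `Trace`, and `ToyJointDistribution` are hypothetical stand-ins, not TFP classes, and the sketch assumes the traced value is yielded like any other node, since a coroutine-based joint distribution only sees what the model yields. Traced values are recorded by `sample`, while `log_prob` recomputes them from their parents and adds nothing to the density for them.

```python
import math
import random


class Normal:
    """Minimal univariate normal (hypothetical stand-in for tfd.Normal)."""

    def __init__(self, loc, scale, name):
        self.loc, self.scale, self.name = loc, scale, name

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


class Trace:
    """Hypothetical marker: a named deterministic value with no density."""

    def __init__(self, value, name):
        self.value, self.name = value, name


class ToyJointDistribution:
    """Toy coroutine-driven joint distribution (not the TFP implementation)."""

    def __init__(self, model_fn):
        self.model_fn = model_fn

    def _run(self, values=None, rng=None):
        gen = self.model_fn()
        out, lp, sent = {}, 0.0, None
        try:
            while True:
                node = gen.send(sent)
                if isinstance(node, Trace):
                    # Traced values are recorded but add nothing to lp;
                    # any user-supplied value under this name is ignored.
                    sent = node.value
                else:
                    sent = values[node.name] if values is not None else node.sample(rng)
                    lp += node.log_prob(sent)
                out[node.name] = sent
        except StopIteration:
            pass
        return out, lp

    def sample(self, seed=None):
        return self._run(rng=random.Random(seed))[0]

    def log_prob(self, values):
        return self._run(values=values)[1]


feature = 0.5  # hypothetical single feature value


def linreg():
    intercept = yield Normal(0.0, 1.0, name="intercept")
    slope = yield Normal(0.0, 1.0, name="slope")
    mean = yield Trace(intercept + slope * feature, name="mean")
    yield Normal(mean, 1.0, name="response")


model = ToyJointDistribution(linreg)
draw = model.sample(seed=0)      # dict with intercept, slope, mean, response
lp_full = model.log_prob(draw)   # the "mean" entry is ignored
lp_partial = model.log_prob({k: v for k, v in draw.items() if k != "mean"})
```

Here `lp_full` and `lp_partial` are equal, which is the behaviour the proposal asks for.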
Does this seem like a feasible addition? (I may have some resources to devote to it.)