[BUG] ProbabilisticActor for discrete actions used with ParallelEnv creates wrong action dim #2572

Open · jasorsi13 opened this issue Nov 15, 2024 · 0 comments · Labels: bug
## Bug

```python
import torch
from torch import nn
from torch.distributions import Categorical
from tensordict.nn import TensorDictModule
from torchrl.envs import Compose, GymEnv, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.utils import check_env_specs
from torchrl.modules import ProbabilisticActor

device = torch.device("cpu")  # device used throughout; adjust as needed


def make_env():
    base_env = GymEnv('CartPole-v1', device=device)

    # Create a transformed environment with observation normalization, float conversion, and step counter
    transformed_env = TransformedEnv(
        base_env,
        Compose(
            # ObservationNorm(in_keys=["observation"]),  # Normalize observations
            # DoubleToFloat(),  # Ensure observations are float32
            StepCounter()  # Track steps in the environment
        )
    )

    # Initialize stats specifically for ObservationNorm (transform[0])
    # transformed_env.transform[0].init_stats(1024)  # Collect stats from 1024 frames for normalization

    return transformed_env


parallel_env = ParallelEnv(num_workers=7, create_env_fn=make_env)
check_env_specs(parallel_env)

# Get observation and action dimensions from the environment specs
observation_dim = parallel_env.observation_spec["observation"].shape[-1]
action_dim = parallel_env.action_spec.shape[-1]
```
Define the actor network with `nn.Sequential` (it outputs logits for each action in the discrete action space):
```python
class ActorNet(nn.Module):
    def __init__(self, observation_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(observation_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)  # Outputs logits for each action in the discrete space
        )

    def forward(self, x):
        print("actor net input shape ", x.shape)
        temp = self.net(x)
        print("actor net output shape ", temp.shape)
        return temp


# Instantiate the actor network
actor_net = ActorNet(observation_dim, action_dim).to(device)
```

Wrap the actor network in a `TensorDictModule` so it works with TorchRL's TensorDict system, then build the `ProbabilisticActor`:
```python
tensordict_module = TensorDictModule(actor_net, in_keys=["observation"], out_keys=["logits"])

actor_module = ProbabilisticActor(
    module=tensordict_module,
    spec=parallel_env.action_spec,
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=Categorical,
    return_log_prob=True  # Return log probability of sampled actions (required for PPO loss)
)
```

The error occurs on rollout:

```python
parallel_env.rollout(1000, actor_module)
```

The issue: the produced `tensordict["action"]` has shape `torch.Size([7])`, but it should have shape `torch.Size([7, 1])`.
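
A minimal way to see the mismatch without running a full rollout (a quick sketch assuming the setup above; the printed shapes reflect what this report describes and may differ on other configurations):

```python
# Hypothetical check of the shape mismatch, not part of the original report
td = parallel_env.reset()               # batched observations from the 7 workers
td = actor_module(td)                   # sample one action per worker
print(parallel_env.action_spec.shape)   # the spec expects a trailing action dimension
print(td["action"].shape)               # reported as torch.Size([7]) -- trailing dim missing
```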
## Fix
```python
# Subclass ProbabilisticActor so the sampled action carries an explicit trailing dimension
class CustomProbabilisticActor(ProbabilisticActor):
    def forward(self, tensordict):
        # Call the parent class's forward method to get actions and log probabilities
        tensordict = super().forward(tensordict)

        # Reshape the action tensor to have shape [batch_size, 1]
        tensordict.set("action", tensordict.get("action").unsqueeze(-1))

        return tensordict


# Use the custom actor to sample actions from a Categorical distribution over the logits
actor_module = CustomProbabilisticActor(
    module=tensordict_module,
    spec=parallel_env.action_spec,
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=Categorical,
    return_log_prob=True  # Return log probability of sampled actions (required for PPO loss)
)
```

The rollout now works:

```python
parallel_env.rollout(1000, actor_module)
```
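
As a sanity check (not part of the original report), the rollout output can be inspected to confirm the trailing dimension is present:

```python
# Hypothetical verification that the workaround yields the expected trailing dimension
rollout = parallel_env.rollout(100, actor_module)
print(rollout["action"].shape)  # expected to end in 1, e.g. torch.Size([7, T, 1])
```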

Versions:
- torch 2.5.0 (pypi)
- torchrl 0.5.0 (pypi)
- tensordict 0.5.0 (pypi)
