Reproduce Results on Galaxy Dataset #2

Open
mmubeen-6 opened this issue Mar 18, 2021 · 4 comments

Comments


mmubeen-6 commented Mar 18, 2021

Hi @tbepler, I am trying to reproduce the results of your paper on the galaxy dataset but am unable to match them exactly. Could you please share the exact training parameters? I am currently using the following command to train:

python3 train_galaxy.py galaxy_zoo/galaxy_zoo_train.npy galaxy_zoo/galaxy_zoo_test.npy -d 0 --num-epochs 300 --save-prefix galaxy_zoo_models/testing -z 100 --minibatch-size 100 --dx-scale 0.125

Moreover, in order to visualize the reconstructed images, I am using the following code snippet. Please have a look at it.

import numpy as np
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt


def get_reconstruction(iterator, x_coord, p_net, q_net, img_size=64, rotate=True, translate=True,
                       dx_scale=0.1, theta_prior=np.pi, augment_rotation=False, z_scale=1, use_cuda=False):
    
    def decode_tensor(input_tensor, img_size):
        input_tensor = input_tensor.view(input_tensor.shape[0], img_size, img_size, 3)
        input_tensor = input_tensor.cpu().detach().numpy()

        input_tensor = input_tensor.clip(0., 1.)
        input_tensor = input_tensor * 255.
        # drop the batch dimension (the loop below asserts batch size 1)
        input_tensor = input_tensor.reshape(img_size, img_size, 3)
        input_tensor = input_tensor.astype("uint8")

        print(input_tensor.shape, input_tensor.dtype)
        return input_tensor
    
    for y, in iterator:
        b = y.size(0)
        assert b == 1
        x = Variable(x_coord)
        y = Variable(y)

        x = x.expand(b, x.size(0), x.size(1))
        n = int(np.sqrt(y.size(1)))

        if use_cuda:
            y = y.cuda()
            
        # first do inference on the latent variables
        z_mu, z_logstd = q_net(y.view(b, -1))
        z_std = torch.exp(z_logstd)
        z_dim = z_mu.size(1)

        # draw samples from variational posterior to calculate
        # E[p(x|z)]
        r = Variable(x.data.new(b,z_dim).normal_())
        z = z_std*r + z_mu
        
        if rotate:
            # z[0] is the rotation
            theta_mu = z_mu[:,0]
            theta_std = z_std[:,0]
            theta_logstd = z_logstd[:,0]
            theta = z[:,0]
            z = z[:,1:]
            z_mu = z_mu[:,1:]
            z_std = z_std[:,1:]
            z_logstd = z_logstd[:,1:]

            # calculate rotation matrix
            rot = Variable(theta.data.new(b,2,2).zero_())
            rot[:,0,0] = torch.cos(theta)
            rot[:,0,1] = torch.sin(theta)
            rot[:,1,0] = -torch.sin(theta)
            rot[:,1,1] = torch.cos(theta)
            x = torch.bmm(x, rot) # rotate coordinates by theta

            # use modified KL for rotation with no penalty on mean
            sigma = theta_prior

        if translate:
            # z[0,1] are the translations
            dx_mu = z_mu[:,:2]
            dx_std = z_std[:,:2]
            dx_logstd = z_logstd[:,:2]
            dx = z[:,:2]*dx_scale # scale dx by standard deviation
            dx = dx.unsqueeze(1)
            z = z[:,2:]

            x = x + dx # translate coordinates

        z = z*z_scale

        # reconstruct
        y_hat = p_net(x.contiguous(), z)
        y_hat = y_hat.view(b, -1, 3)

        input_image = decode_tensor(y, img_size)
        recon_image = decode_tensor(y_hat, img_size)

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 6))
        ax1.imshow(input_image)
        ax2.imshow(recon_image)
        fig.savefig('foo.png')
        plt.show()

        break
@tomouellette

Hi @mmubeen-6, any luck reproducing this work? I re-implemented the model from the ground up and have yet to generate results comparable to the original paper. I'll try a few more hyperparameter settings and train longer than the original paper (which may help?), but I'd be interested to hear if you've solved any of your pre-existing issues.


tbepler commented Jan 25, 2023

It's been a long time since I ran those experiments, but a few things that will improve the generated images are:

  • Make the spatial generator bigger (more layers, more units per layer)
  • Use a feature expansion of the coordinates going into the model. Sinusoidal features are popular now (e.g., random Fourier features) but tuning the scaling parameter is important. Polynomial features can also improve performance. (See the sketch after this comment.)

It's also worth noting that the encoder in spatial-VAE can sometimes get stuck in bad local optima (especially regarding rotation inference), which then leads to bad generator performance. There are a few tricks implemented here to try to avoid those, but they only work so-so (e.g., including rotated images as input, but decoding the unrotated image by shifting the predicted rotation by the known augmentation rotation). You might want to take a look at some newer work (https://arxiv.org/abs/2210.12918, https://github.com/SMLC-NYSBC/TARGET-VAE) where we improved the encoder to address some of these issues. We lightly tested that on galaxy zoo but didn't push it as far as it should be able to go with a larger spatial generator and/or better initial featurization of the coordinates.
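
For reference, a minimal sketch (not from this repo) of the random Fourier feature expansion mentioned above; the module name, n_features, and sigma are illustrative and would need tuning per dataset.

import math
import torch
import torch.nn as nn

class FourierFeatures(nn.Module):
    # Sketch: expand 2D pixel coordinates into random Fourier features
    # before feeding them to the spatial generator.
    def __init__(self, in_dim=2, n_features=64, sigma=1.0):
        super().__init__()
        # fixed random projection; sigma controls the frequency bandwidth
        self.register_buffer('B', torch.randn(in_dim, n_features) * sigma)

    def forward(self, x):
        # x: (batch, num_pixels, 2) coordinate grid, roughly in [-1, 1]
        proj = 2 * math.pi * (x @ self.B)  # (batch, num_pixels, n_features)
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)

# usage sketch: expand the (rotated/translated) coordinates and widen the
# generator's coordinate input to 2 * n_features accordingly
# feats = FourierFeatures()(x)
# y_hat = p_net(feats.contiguous(), z)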

@tomouellette

Hi @tbepler, thanks for your comments! I spent the evening tinkering with my code a bit and I did notice some of the local optima issues on a few runs, so I will play with it a bit more, taking your suggestions into account. I also saw your TARGET-VAE paper pop up, congrats on that! I will try playing around with the group convolutions if I get a chance.

If you don't mind, I might create another GitHub repository with a somewhat refactored/re-engineered version of spatial-VAE (with proper attributions and references, of course). I think this could be a nice architecture for some of the applied bio work I'm doing, so I will probably run some additional experiments with more expressive encoders, different ways to condition the decoding on the latent variables, cyclical annealing of the KLD, and maybe swapping out linear layers for 1x1 convolutions, etc. I may also add in reflection if I have time.

I also wonder if semi-supervised learning could help with convergence, since it seems reasonably straightforward to aggregate some ground-truth rotations/translations, either through augmentation or by extracting them via fits (e.g., the major axis or something).
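
As an aside on the cyclical KLD annealing mentioned above, a minimal sketch of the kind of schedule typically used; the cycle length, ramp fraction, and where beta multiplies the KL term are assumptions, not something from this repo.

def cyclical_beta(epoch, cycle_len=50, ramp_frac=0.5, beta_max=1.0):
    # Cyclical annealing weight for the KL term: within each cycle, beta ramps
    # linearly from 0 to beta_max over the first ramp_frac of the cycle and
    # then stays at beta_max for the remainder.
    pos = (epoch % cycle_len) / cycle_len
    return beta_max * min(pos / ramp_frac, 1.0)

# usage sketch inside the training loop (loss names are illustrative):
# beta = cyclical_beta(epoch)
# loss = recon_loss + beta * kl_divergence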


tbepler commented Feb 1, 2023

@tomouellette you're welcome to fork the code and use it however you like. I agree semi-supervised learning could help with convergence if you have labeled angles for a subset of images. It should be pretty straightforward to include in the objective.
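
A minimal sketch of how such a supervised angle term could be folded into the objective (not from this repo; the labeled-angle mask, the loss weight, and the wrapped cosine penalty are assumptions):

import torch

def supervised_angle_loss(theta_mu, theta_label, labeled_mask, weight=1.0):
    # theta_mu:     (batch,) rotation means predicted by the encoder
    # theta_label:  (batch,) known angles (e.g., from augmentation or fits)
    # labeled_mask: (batch,) bool tensor, True where a label exists
    if labeled_mask.sum() == 0:
        return theta_mu.new_zeros(())
    diff = theta_mu[labeled_mask] - theta_label[labeled_mask]
    # wrapped penalty: 1 - cos(diff) is zero when angles match and handles periodicity
    return weight * (1 - torch.cos(diff)).mean()

# usage sketch: add to the ELBO during training
# loss = -elbo + supervised_angle_loss(theta_mu, theta_label, labeled_mask)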
