Desample and Deslice

This is a hyper-technical post, only for despicable freaks like me. If somebody is looking for similar functions, feel free to copy and paste. For a project I'm working on, I needed to sample a tensor but keep it inside the original TensorFlow graph. So I had to sample the tensor, keep the remainder, and then put the original tensor back together, as if nothing had happened. I found a way to do it, and I call it sample and desample. For another part of the project I needed a random slice of a multidimensional tensor, and then had to put it back into its original form to keep it in the graph. I call it slice and deslice. I'm going to explain how I did it and how it works. You can find the code here, and you can try the package that contains it, which I'm building here. I'd love to get some feedback.

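The sample_axis function randomly selects max_dim items along the chosen axis, and returns the sample, the remainder, and the indices needed to undo the sampling. If the axis already has max_dim items or fewer, the tensor is returned untouched, with None for the remainder and the indices: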

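import numpy as np
import tensorflow as tf

def sample_axis(tensor, max_dim=1024, axis=1):
    # Nothing to sample if the axis is already small enough.
    if tensor.shape[axis] <= max_dim:
        return tensor, None, None

    # Random indices are drawn in NumPy: no gradient needs to flow through them.
    newdim_inp = sorted(np.random.choice(tensor.shape[axis], max_dim, replace=False))
    out_tensor = tf.gather(tensor, indices=newdim_inp, axis=axis)

    # Everything that was not sampled becomes the remainder.
    remaining_indices = sorted(set(range(tensor.shape[axis])).difference(newdim_inp))
    remainder = tf.gather(tensor, indices=remaining_indices, axis=axis)

    # Concatenating [sample, remainder] puts original index shuffled_indices[k] at
    # position k; argsort gives the inverse permutation that restores the order.
    shuffled_indices = list(newdim_inp) + list(remaining_indices)
    deshuffle_indices = np.array(shuffled_indices).argsort()

    return out_tensor, remainder, deshuffle_indices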
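The whole trick rests on one fact: the argsort of a permutation is its inverse. Here is a small NumPy-only sketch of that bookkeeping, assuming a 1-D array and a fixed choice of sampled indices instead of random ones; np.take stands in for tf.gather so it runs without TensorFlow.

```python
import numpy as np

# A 1-D "tensor" of length 6 and a fixed (non-random) choice of sampled indices.
x = np.array([10, 11, 12, 13, 14, 15])
sampled = [1, 4]                                            # stands in for the random choice
remaining = sorted(set(range(len(x))).difference(sampled))  # [0, 2, 3, 5]

# [sample, remainder] is x permuted by shuffled_indices.
shuffled_indices = np.array(sampled + remaining)            # [1, 4, 0, 2, 3, 5]
concat = np.concatenate([np.take(x, sampled), np.take(x, remaining)])

# argsort of a permutation is its inverse: position k of deshuffle_indices
# holds the location in `concat` of original index k.
deshuffle_indices = shuffled_indices.argsort()              # [2, 0, 3, 4, 1, 5]
restored = np.take(concat, deshuffle_indices)
print(restored)  # [10 11 12 13 14 15]
```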

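Luckily, the random indices can be created with NumPy, since the gradient only has to flow through tf.gather into the tensor values, not through the indices. It might still be a good idea to write those functions in TF: inside a tf.function, the NumPy call runs only once, at tracing time, so the same indices would be reused on every call. To undo the sampling, we just need to concatenate the sample and the remainder along the same axis, and gather the result with the deshuffling indices: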

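def desample_axis(sample, remainder, deshuffle_indices, axis=1):
    # Nothing was sampled, so there is nothing to undo.
    if remainder is None:
        return sample

    # [sample, remainder] is the original tensor permuted along `axis`;
    # gathering with deshuffle_indices puts every slice back in place.
    concat = tf.concat([sample, remainder], axis=axis)
    return tf.gather(concat, indices=deshuffle_indices, axis=axis)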
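If you want to check the round trip without a TensorFlow install, here is a NumPy-only mirror of the two functions. It is a sketch under my own assumptions: the names np_sample_axis and np_desample_axis and the rng argument are hypothetical, np.take replaces tf.gather, and np.setdiff1d replaces the set difference (it already returns sorted indices).

```python
import numpy as np

def np_sample_axis(array, max_dim=1024, axis=1, rng=None):
    """Hypothetical NumPy mirror of sample_axis; np.take stands in for tf.gather."""
    rng = np.random.default_rng() if rng is None else rng
    n = array.shape[axis]
    if n <= max_dim:
        return array, None, None
    sampled = np.sort(rng.choice(n, max_dim, replace=False))
    remaining = np.setdiff1d(np.arange(n), sampled)      # sorted, unique
    deshuffle = np.concatenate([sampled, remaining]).argsort()
    return (np.take(array, sampled, axis=axis),
            np.take(array, remaining, axis=axis),
            deshuffle)

def np_desample_axis(sample, remainder, deshuffle, axis=1):
    """Hypothetical NumPy mirror of desample_axis: concatenate, then un-permute."""
    if remainder is None:
        return sample
    concat = np.concatenate([sample, remainder], axis=axis)
    return np.take(concat, deshuffle, axis=axis)

rng = np.random.default_rng(0)
x = rng.uniform(size=(2, 3, 4))
for axis in range(x.ndim):
    s, r, d = np_sample_axis(x, max_dim=1, axis=axis, rng=rng)
    # Prints True for every axis.
    print(axis, s.shape, r.shape, np.array_equal(np_desample_axis(s, r, d, axis=axis), x))
```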

To randomly slice and deslice, we just need to sample a single item (max_dim=1) along each of the desired axes, save the remainders and the deshuffling indices at every step, and then desample the axes in reverse order. The following test shows that sampling and then desampling gives back the initial tensor, and so does slicing and then deslicing:


def test_sampling_desampling():

    test_several_samples = True
    test_choosing_axis = True
    test_deslice = True

    if test_several_samples:
        print('-' * 20)
        t = tf.random.uniform((2, 34))
        st, remainder, deshuffle_indices = sample_axis(t, max_dim=4)
        print('original shape:', t.shape)
        print('sample shape:  ', st.shape)
        print('remainder shape:', remainder.shape)
        print(deshuffle_indices)
        dst = desample_axis(st, remainder, deshuffle_indices)
        print('Is the desampled tensor equal to how it was at the beginning?', np.all(dst == t))

    if test_choosing_axis:
        for axis in [0, 1, 2]:
            print('-' * 20)

            t = tf.random.uniform((2, 3, 4))
            st, remainder, deshuffle_indices = sample_axis(t, max_dim=1, axis=axis)
            print('original shape:', t.shape)
            print('sample shape:  ', st.shape)
            print('remainder shape:', remainder.shape)
            print(deshuffle_indices)
            dst = desample_axis(st, remainder, deshuffle_indices, axis=axis)
            print('desampled shape:', dst.shape)
            print('Is the desampled tensor equal to how it was at the beginning?', np.all(dst==t))

    if test_deslice:
        print('-' * 20)

        deslice_axis = [1, 2]
        t = tf.random.uniform((2, 3, 4, 5))
        st = t
        remainders = []
        deshuffles = []
        for axis in deslice_axis:
            st, remainder, deshuffle_indices = sample_axis(st, max_dim=1, axis=axis)
            remainders.append(remainder)
            deshuffles.append(deshuffle_indices)

            print('original shape:', t.shape)
            print('sample shape:  ', st.shape)
            print('remainder shape:', remainder.shape)
            print(deshuffle_indices)

        # Undo the slicing in reverse order: the last axis sampled is desampled first.
        for i in reversed(range(len(deslice_axis))):
            st = desample_axis(st, remainders[i], deshuffles[i], axis=deslice_axis[i])
            print('desampled shape:', st.shape)
        print('Is the desampled tensor equal to how it was at the beginning?', np.all(st==t))
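
The slicing loop in the test can be wrapped into two helpers. Below is a NumPy-only sketch under my own assumptions: random_slice and random_deslice are hypothetical names, np.take stands in for tf.gather, and instead of calling sample_axis with max_dim=1 it keeps one random index per axis directly, which amounts to the same thing.

```python
import numpy as np

def random_slice(array, axes, rng):
    """Hypothetical helper: keep one random index along each axis in `axes`."""
    undo = []
    for axis in axes:
        n = array.shape[axis]
        keep = rng.integers(n)                        # the single sampled index
        rest = np.delete(np.arange(n), keep)          # everything else, in order
        perm = np.concatenate([[keep], rest]).argsort()
        undo.append((axis, np.take(array, rest, axis=axis), perm))
        array = np.take(array, [keep], axis=axis)     # keeps the axis, with size 1
    return array, undo

def random_deslice(sliced, undo):
    """Hypothetical helper: undo random_slice, last axis first."""
    for axis, remainder, perm in reversed(undo):
        concat = np.concatenate([sliced, remainder], axis=axis)
        sliced = np.take(concat, perm, axis=axis)
    return sliced

rng = np.random.default_rng(0)
x = rng.uniform(size=(2, 3, 4, 5))
s, undo = random_slice(x, axes=[1, 2], rng=rng)
print(s.shape)                                      # (2, 1, 1, 5)
print(np.array_equal(random_deslice(s, undo), x))   # True
```

The reversed loop matters: the remainder for the second axis was cut from a tensor that was already sliced along the first one, so the undo steps have to run last-in, first-out.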
Written on December 10, 2022