Desample and Deslice
This is a hyper-technical post, only for despicable freaks like me. Maybe somebody is looking for similar functions, so if that’s the case, feel free to copy and paste. For a project I’m working on, I needed to sample a tensor while keeping it inside the original TensorFlow graph. So I had to sample the tensor, keep the remainder, and then put the original tensor back together, as if nothing had happened. I found a way to do it, which I call sample and desample. For another part of the project, I needed to take a random slice of a multidimensional tensor and then restore it to its original form, so that it stays in the graph. I call that slice and deslice. I’m going to explain how I did it and how it works. You can find the code here, and try out the package that contains it, which I’m building here. I’d love to get some feedback.
The sample_axis function randomly selects items along the chosen axis, and returns the sample, the remainder, and the indices needed to undo the sampling:
import numpy as np
import tensorflow as tf


def sample_axis(tensor, max_dim=1024, axis=1):
    # Nothing to sample if the axis is already small enough.
    if tensor.shape[axis] <= max_dim:
        return tensor, None, None

    # Pick max_dim random positions along the axis, kept in ascending order.
    newdim_inp = sorted(np.random.choice(tensor.shape[axis], max_dim, replace=False))
    out_tensor = tf.gather(tensor, indices=newdim_inp, axis=axis)

    # Every position that was not picked goes into the remainder.
    remaining_indices = sorted(set(range(tensor.shape[axis])).difference(set(newdim_inp)))
    remainder = tf.gather(tensor, indices=remaining_indices, axis=axis)

    # argsort of the [sample, remainder] ordering is the permutation
    # that puts every item back in its original position.
    shuffled_indices = newdim_inp + remaining_indices
    deshuffle_indices = np.array(shuffled_indices).argsort()
    return out_tensor, remainder, deshuffle_indices
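To see why the argsort of the shuffled indices undoes the sampling, here is a small NumPy-only sketch. The array and the hand-picked indices are illustrative assumptions standing in for what np.random.choice might return; they are not taken from the project.

```python
import numpy as np

data = np.array([10, 11, 12, 13, 14])

# Hand-picked for reproducibility (an assumption, normally random).
sampled_idx = [1, 4]        # positions that end up in the sample
remaining_idx = [0, 2, 3]   # positions that end up in the remainder

# Order of the original positions once sample and remainder are concatenated.
shuffled_idx = sampled_idx + remaining_idx

# For each original position, argsort gives its location in the shuffled order.
deshuffle_idx = np.argsort(shuffled_idx)

shuffled = np.concatenate([data[sampled_idx], data[remaining_idx]])
restored = shuffled[deshuffle_idx]

print(shuffled)       # [11 14 10 12 13]
print(deshuffle_idx)  # [2 0 3 4 1]
print(restored)       # [10 11 12 13 14]
```

Because argsort returns the inverse of a permutation, the order of the remaining indices does not matter, as long as the same order is used for the gather and for the argsort.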
Luckily, the random indices can be created with NumPy, since the gradient doesn’t need to pass through them, although it might still be a good idea to build them with TensorFlow ops. To undo the sampling, we just need to concatenate the sample and the remainder, and gather the result with the deshuffling indices:
def desample_axis(sample, remainder, deshuffle_indices, axis=1):
    # If nothing was sampled away, the sample already is the full tensor.
    if remainder is None:
        return sample

    # Put sample and remainder side by side, then restore the original order.
    concat = tf.concat([sample, remainder], axis=axis)
    deshuffled = tf.gather(concat, indices=deshuffle_indices, axis=axis)
    return deshuffled
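The point of staying inside the graph is that gradients still reach the original tensor. The following is a minimal sketch of that, assuming TensorFlow 2 in eager mode. The fixed indices and the `2.0 * sample` step are hypothetical stand-ins for the random sampling and for whatever is actually done to the sample.

```python
import numpy as np
import tensorflow as tf

axis = 1
keep = [1, 4]      # sampled positions (fixed here, normally random)
rest = [0, 2, 3]   # remaining positions
deshuffle = np.argsort(keep + rest)

x = tf.reshape(tf.range(10, dtype=tf.float32), (2, 5))

with tf.GradientTape() as tape:
    tape.watch(x)
    sample = tf.gather(x, keep, axis=axis)
    remainder = tf.gather(x, rest, axis=axis)
    processed = 2.0 * sample  # hypothetical operation applied only to the sample
    concat = tf.concat([processed, remainder], axis=axis)
    restored = tf.gather(concat, deshuffle, axis=axis)
    loss = tf.reduce_sum(restored)

# Gradient is 2 at the sampled positions and 1 everywhere else.
grad = tf.convert_to_tensor(tape.gradient(loss, x))
print(restored.shape)
print(grad.numpy())
```

The gradient comes back with the shape of the original tensor, with the sampled columns scaled by the operation applied to them, which is what one would hope for if the sample and the remainder are both still connected to the input.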
To randomly slice and deslice, we just need to sample a single item along each of the desired axes, saving the remainders and the indices needed for the deshuffling, and then desample the axes in reverse order. The following code shows that sampling and desampling recovers the initial tensor, and that slicing and deslicing does too:
def test_sampling_desampling():
    test_several_samples = True
    test_choosing_axis = True
    test_deslice = True

    if test_several_samples:
        # Sample 4 items out of 34 along the default axis, then undo it.
        print('-' * 20)
        t = tf.random.uniform((2, 34))
        st, remainder, deshuffle_indices = sample_axis(t, max_dim=4)
        print('original shape: ', t.shape)
        print('sample shape:   ', st.shape)
        print('remainder shape:', remainder.shape)
        print(deshuffle_indices)
        dst = desample_axis(st, remainder, deshuffle_indices)
        print('Is the desampled tensor equal to how it was at the beginning?', np.all(dst == t))

    if test_choosing_axis:
        # Same check, sampling a single item along each axis in turn.
        for axis in [0, 1, 2]:
            print('-' * 20)
            t = tf.random.uniform((2, 3, 4))
            st, remainder, deshuffle_indices = sample_axis(t, max_dim=1, axis=axis)
            print('original shape: ', t.shape)
            print('sample shape:   ', st.shape)
            print('remainder shape:', remainder.shape)
            print(deshuffle_indices)
            dst = desample_axis(st, remainder, deshuffle_indices, axis=axis)
            print('desampled shape:', dst.shape)
            print('Is the desampled tensor equal to how it was at the beginning?', np.all(dst == t))

    if test_deslice:
        # Slice: sample one item along each chosen axis, keeping what is needed to undo it.
        print('-' * 20)
        deslice_axis = [1, 2]
        t = tf.random.uniform((2, 3, 4, 5))
        st = t
        remainders = []
        deshuffles = []
        for axis in deslice_axis:
            st, remainder, deshuffle_indices = sample_axis(st, max_dim=1, axis=axis)
            remainders.append(remainder)
            deshuffles.append(deshuffle_indices)
        print('original shape: ', t.shape)
        print('sample shape:   ', st.shape)
        print('remainder shape:', remainder.shape)
        print(deshuffle_indices)

        # Deslice: undo the axes in reverse order, last sampled axis first.
        for j, _ in enumerate(deslice_axis):
            i = -j - 1
            st = desample_axis(st, remainders[i], deshuffles[i], axis=deslice_axis[i])
        print('desampled shape:', st.shape)
        print('Is the desampled tensor equal to how it was at the beginning?', np.all(st == t))


if __name__ == '__main__':
    test_sampling_desampling()
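The slice and deslice loops from the test can also be wrapped into two small helpers. This is only a sketch of how that might look, written with NumPy arrays so it runs without TensorFlow. The names `sample_axis_np`, `desample_axis_np`, `slice_axes` and `deslice_axes` are hypothetical and not part of the package.

```python
import numpy as np


def sample_axis_np(array, max_dim, axis, rng):
    # NumPy analogue of the sampling step: returns sample, remainder, deshuffle indices.
    n = array.shape[axis]
    if n <= max_dim:
        return array, None, None
    picked = np.sort(rng.choice(n, max_dim, replace=False))
    rest = np.setdiff1d(np.arange(n), picked)
    deshuffle = np.argsort(np.concatenate([picked, rest]))
    return np.take(array, picked, axis=axis), np.take(array, rest, axis=axis), deshuffle


def desample_axis_np(sample, remainder, deshuffle, axis):
    # Concatenate sample and remainder, then restore the original order.
    if remainder is None:
        return sample
    concat = np.concatenate([sample, remainder], axis=axis)
    return np.take(concat, deshuffle, axis=axis)


def slice_axes(array, axes, rng):
    # Keep one random item along each axis and record how to undo each step.
    out, undo = array, []
    for axis in axes:
        out, remainder, deshuffle = sample_axis_np(out, 1, axis, rng)
        undo.append((axis, remainder, deshuffle))
    return out, undo


def deslice_axes(sliced, undo):
    # Undo in reverse: each remainder has the shape left by the steps before it.
    out = sliced
    for axis, remainder, deshuffle in reversed(undo):
        out = desample_axis_np(out, remainder, deshuffle, axis)
    return out


rng = np.random.default_rng(0)
t = rng.uniform(size=(2, 3, 4, 5))
st, undo = slice_axes(t, axes=[1, 2], rng=rng)
restored = deslice_axes(st, undo)
print(st.shape)                     # (2, 1, 1, 5)
print(np.array_equal(restored, t))  # True
```

The reverse order matters because the remainder saved for axis 2 was taken after axis 1 had already been reduced to a single item, so its shape only matches the slice before axis 1 is restored.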