| import tensorflow as tf | |
def window_partition(x, window_size):
    """Split a batch of feature maps into non-overlapping square windows.

    Args:
        x: Rank-4 tensor, treated as (batch, H, W, C). H and W must be
            divisible by ``window_size``, otherwise the reshape fails.
        window_size: Side length of each square window.

    Returns:
        Tensor of shape
        (batch * (H // window_size) * (W // window_size), window_size, window_size, C).
        Within each batch element, windows are ordered row-major over the window
        grid, which is the layout ``window_reverse`` expects.
    """
    # num=4 keeps the original rank-4 requirement on the input.
    _, H, W, dynamic_C = tf.unstack(tf.shape(x), num=4)
    # Prefer the static channel count when known: a tensor-valued C would make
    # the result's last dimension unknown in graph mode, which breaks layers
    # that need a defined last dim. Fall back to the runtime value otherwise.
    C = tf.compat.dimension_value(x.shape[-1])
    if C is None:
        C = dynamic_C
    # (B, H, W, C) -> (B, nH, ws, nW, ws, C)
    x = tf.reshape(x, shape=[-1, H // window_size, window_size, W // window_size, window_size, C])
    # -> (B, nH, nW, ws, ws, C) so each window's pixels are contiguous.
    x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
    windows = tf.reshape(x, shape=[-1, window_size, window_size, C])
    return windows
def window_reverse(windows, window_size, H, W, C):
    """Reassemble square windows back into full feature maps.

    Inverse of ``window_partition``: the windows are read as a
    (batch, H // window_size, W // window_size, window_size, window_size, C)
    grid and stitched into a (batch, H, W, C) tensor.

    Args:
        windows: Tensor of windows in row-major window-grid order.
        window_size: Side length of each square window.
        H: Height of the reconstructed feature map.
        W: Width of the reconstructed feature map.
        C: Channel count.

    Returns:
        Tensor of shape (batch, H, W, C).
    """
    rows = H // window_size
    cols = W // window_size
    grid = tf.reshape(windows, shape=[-1, rows, cols, window_size, window_size, C])
    # Interleave grid and in-window axes: (B, nH, ws, nW, ws, C).
    interleaved = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])
    return tf.reshape(interleaved, shape=[-1, H, W, C])