diff --git a/common/forward_warp/forward_warp.py b/common/forward_warp/forward_warp.py index 17a4f19..1487586 100644 --- a/common/forward_warp/forward_warp.py +++ b/common/forward_warp/forward_warp.py @@ -1,10 +1,46 @@ import tensorflow as tf +from common.utils.tf import load_op_library +from tensorflow.python.framework import ops -DISOCC_THRESH = 0.8 +DISOCC_THRESH = 0.5 -def forward_warp(features, flow): +# Load op library. +mod = load_op_library('forward_warp_op', 'build') + + +def is_forward_warp_cuda(): + """ + :return: Bool. Whether the forward_warp is using a custom CUDA op. + """ + return mod is not None + + +def forward_warp(features, flow, splat_variance=0.5): + """ + For an algorithm that gives the same end result, see section 3 in https://arxiv.org/pdf/1711.05890.pdf. + Note that the actual implementation here is not n^2, and should be linear in GPU memory. + :param features: A Tensor. Features to be warped, of shape [batch_size, H, W, C]. + :param flow: A Tensor. Un-normalized flow in image pixel units, of shape [batch_size, H, W, 2]. + Flow vectors should have (x, y) ordering. + :param splat_variance: Float. Variance of the splat. Only used for the CUDA op. + """ + if is_forward_warp_cuda(): + return mod.forward_warp(features, flow, variance=splat_variance) + else: + return forward_warp_tf(features, flow) + + +if is_forward_warp_cuda(): + @ops.RegisterGradient('ForwardWarp') + def _ForwardWarpGrad(op, grad): + image_grad, flow_grad = mod.forward_warp_grad( + grad, op.inputs[0], op.inputs[1], variance=op.get_attr('variance')) + return [image_grad, flow_grad] + + +def forward_warp_tf(features, flow): """ For an algorithm that gives the same end result, see section 3 in https://arxiv.org/pdf/1711.05890.pdf. Note that the actual implementation here is not n^2, and should be linear in GPU memory. @@ -88,7 +124,7 @@ def _get_translated_pixels(features, translations): return all_indices, all_vals -def create_disocclusion_mask(flow): +def create_disocclusion_mask(flow, splat_variance=1.0): """ Creates a disocclusion mask representing areas that were previously occluded and will become visible. This is done by forward warping some ones and thresholding them for visibility. @@ -97,10 +133,11 @@ def create_disocclusion_mask(flow): https://github.com/simonmeister/UnFlow/blob/8bff4939963c7d0adb9435880dc506fb3f988080/src/e2eflow/core/losses.py#L28 This isn't mentioned in the paper anywhere, but clearly enough, it is in the code. :param flow: Tensor of shape [B, H, W, 2]. + :param splat_variance: Float. Variance of the splat. Only used for the CUDA op. :return: Tensor of shape [B, H, W, 1]. """ with tf.name_scope('disocclusion_mask'): batch, height, width, _ = tf.unstack(tf.shape(flow)) prewarp_mask = tf.ones([batch, height, width, 1], dtype=tf.float32) - forward_warped_mask = forward_warp(prewarp_mask, flow) + forward_warped_mask = forward_warp(prewarp_mask, flow, splat_variance=splat_variance) return tf.cast(forward_warped_mask < DISOCC_THRESH, dtype=tf.float32) diff --git a/common/forward_warp/forward_warp_profile.py b/common/forward_warp/forward_warp_profile.py new file mode 100644 index 0000000..dd0fa09 --- /dev/null +++ b/common/forward_warp/forward_warp_profile.py @@ -0,0 +1,29 @@ +import numpy as np +import tensorflow as tf +from common.forward_warp.forward_warp import forward_warp +from common.utils.profile import run_profiler + +if __name__ == '__main__': + height = 512 + width = 512 + im_channels = 3 + batch_size = 8 + + # Create the graph. + image_shape = [batch_size, height, width, im_channels] + flow_shape = [batch_size, height, width, 2] + image_placeholder = tf.placeholder(shape=image_shape, dtype=tf.float32) + flow_placeholder = tf.placeholder(shape=flow_shape, dtype=tf.float32) + warped = forward_warp(image_placeholder, flow_placeholder, splat_variance=0.5) + grads = tf.gradients(warped, [image_placeholder, flow_placeholder]) + + # Create dummy images. + image = np.zeros(shape=[batch_size, height, width, im_channels], dtype=np.float32) + flow = np.zeros(shape=[batch_size, height, width, 2], dtype=np.float32) + image[:, 2:height - 2, 2:width - 2, :] = 1.0 + flow[:, 4:height - 4, 5:width - 5, :] = 1.0 + + query = [warped, grads] + feed_dict = {image_placeholder: image, + flow_placeholder: flow} + run_profiler(query, feed_dict, name='forward-warp') diff --git a/common/forward_warp/forward_warp_test.py b/common/forward_warp/forward_warp_test.py index 714a8f7..c8e81a7 100644 --- a/common/forward_warp/forward_warp_test.py +++ b/common/forward_warp/forward_warp_test.py @@ -3,14 +3,16 @@ import numpy as np import tensorflow as tf from common.utils.img import read_image, show_image -from common.forward_warp.forward_warp import forward_warp, create_disocclusion_mask +from common.forward_warp.forward_warp import forward_warp, create_disocclusion_mask, is_forward_warp_cuda, forward_warp_tf from common.utils.flow import read_flow_file +from tensorflow.python.ops import gradient_checker + VISUALIZE = False WRITE_TO_VIDEO = False -class TestForwardWarp(unittest.TestCase): +class TestForwardWarpTF(unittest.TestCase): def setUp(self): config = tf.ConfigProto() config.gpu_options.allow_growth = True @@ -37,7 +39,7 @@ def test_forward_warp_whole_1(self): flow_tensor = tf.placeholder(tf.float32, (1, height, width, 2)) features_tensor = tf.placeholder(tf.float32, (1, height, width, 2)) - warp_tensor = forward_warp(features_tensor, flow_tensor) + warp_tensor = forward_warp_tf(features_tensor, flow_tensor) warp = self.sess.run(warp_tensor, feed_dict={flow_tensor: flow, features_tensor: features}) self.assertEqual(warp.tolist(), expected_warp) @@ -62,7 +64,7 @@ def test_forward_warp_partial_1(self): flow_tensor = tf.placeholder(tf.float32, (1, height, width, 2)) features_tensor = tf.placeholder(tf.float32, (1, height, width, 2)) - warp_tensor = forward_warp(features_tensor, flow_tensor) + warp_tensor = forward_warp_tf(features_tensor, flow_tensor) warp = self.sess.run(warp_tensor, feed_dict={flow_tensor: flow, features_tensor: features}) self.assertEqual(warp.tolist(), expected_warp) @@ -90,7 +92,7 @@ def test_forward_warp_partial_2(self): flow_tensor = tf.placeholder(tf.float32, (1, height, width, 2)) features_tensor = tf.placeholder(tf.float32, (1, height, width, 2)) - warp_tensor = forward_warp(features_tensor, flow_tensor) + warp_tensor = forward_warp_tf(features_tensor, flow_tensor) warp = self.sess.run(warp_tensor, feed_dict={flow_tensor: flow, features_tensor: features}) self.assertEqual(warp.tolist(), expected_warp) @@ -131,7 +133,7 @@ def test_forward_warp_oob(self): flow_tensor = tf.placeholder(tf.float32, (1, height, width, 2)) features_tensor = tf.placeholder(tf.float32, (1, height, width, 2)) - warp_tensor = forward_warp(features_tensor, flow_tensor) + warp_tensor = forward_warp_tf(features_tensor, flow_tensor) warp = self.sess.run(warp_tensor, feed_dict={flow_tensor: flow, features_tensor: features}) self.assertEqual(warp.tolist(), expected_warp) @@ -174,7 +176,7 @@ def test_forward_warp_batch(self): flow_tensor = tf.placeholder(tf.float32, (2, height, width, 2)) features_tensor = tf.placeholder(tf.float32, (2, height, width, 2)) - warp_tensor = forward_warp(features_tensor, flow_tensor) + warp_tensor = forward_warp_tf(features_tensor, flow_tensor) warp = self.sess.run(warp_tensor, feed_dict={flow_tensor: flow, features_tensor: features}) self.assertEqual(warp[0].tolist(), expected_warp[0]) self.assertEqual(warp[1].tolist(), expected_warp[1]) @@ -189,17 +191,25 @@ def test_forward_warp_batch(self): self.assertNotEqual(np.sum(flow_grads), 0.0) self.assertNotEqual(np.sum(feature_grads), 0.0) + +class TestForwardWarpCommon(unittest.TestCase): + def setUp(self): + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + self.sess = tf.Session(config=config) + + self.flow_path = os.path.join('pwcnet', 'warp', 'test_data', 'flow_ab.flo') + self.image_path_a = os.path.join('pwcnet', 'warp', 'test_data', 'image_a.png') + self.image_path_b = os.path.join('pwcnet', 'warp', 'test_data', 'image_b.png') + + self.max_allowable_grad_err = 5e-4 + def test_visualization(self): if not VISUALIZE: return - cur_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.join(cur_dir, '..') - flow_path = os.path.join(root_dir, 'pwcnet', 'warp', 'test_data', 'flow_ab.flo') - image_path = os.path.join(root_dir, 'pwcnet', 'warp', 'test_data', 'image_a.png') - - flow_ab = [read_flow_file(flow_path)] - img_a = [read_image(image_path, as_float=True)] + flow_ab = [read_flow_file(self.flow_path)] + img_a = [read_image(self.image_path_a, as_float=True)] t_tensor = tf.placeholder(tf.float32, None) flow_ab_tensor = tf.placeholder(tf.float32, np.shape(flow_ab)) img_a_tensor = tf.placeholder(tf.float32, np.shape(img_a)) @@ -207,15 +217,21 @@ def test_visualization(self): warp = self.sess.run(warp_tensor, feed_dict={flow_ab_tensor: flow_ab, img_a_tensor: img_a, t_tensor: 1.0}) warp = np.clip(warp[0], 0.0, 1.0) - show_image(warp) + try: + show_image(warp) + except: + print('show_image(warp) failed.') # For writing to video. if WRITE_TO_VIDEO: + if not os.path.exists('outputs'): + os.makedirs('outputs') + import cv2 - import mpimg + import matplotlib.image as mpimg height = img_a[0].shape[0] width = img_a[0].shape[1] - writer = cv2.VideoWriter(cur_dir + '/outputs/warped.avi', + writer = cv2.VideoWriter('outputs/warped.avi', cv2.VideoWriter_fourcc(*'MJPG'), 20, (width, height)) steps = 60 for i in range(steps): @@ -225,17 +241,31 @@ def test_visualization(self): feed_dict={flow_ab_tensor: flow_ab, img_a_tensor: img_a, t_tensor: t}) warped = warped[0] warped = np.clip(warped, 0.0, 1.0) - output_path = cur_dir + '/outputs/out-%.2f.png' % t + output_path = 'outputs/out-%.2f.png' % t mpimg.imsave(output_path, warped) writer.write(cv2.imread(output_path)) writer.release() + def test_warp_error(self): + flow_ab = [read_flow_file(self.flow_path)] + img_a = [read_image(self.image_path_a, as_float=True)] + img_b = read_image(self.image_path_b, as_float=True) + flow_ab_tensor = tf.placeholder(tf.float32, np.shape(flow_ab)) + img_a_tensor = tf.placeholder(tf.float32, np.shape(img_a)) + warp_tensor = forward_warp(img_a_tensor, flow_ab_tensor, splat_variance=0.3) + mask = 1.0 - create_disocclusion_mask(flow_ab_tensor) + + warp, mask = self.sess.run([warp_tensor, mask], feed_dict={flow_ab_tensor: flow_ab, img_a_tensor: img_a}) + warp = np.clip(warp[0], 0.0, 1.0) + + self.assertLess(np.average(np.abs(warp - img_b) * mask[0]), 0.0212) + def test_create_disocclusion_map(self): height = 3 width = 3 flow_tensor = tf.placeholder(shape=(None, height, width, 2), dtype=tf.float32) - mask_tensor = create_disocclusion_mask(flow_tensor) + mask_tensor = create_disocclusion_mask(flow_tensor, splat_variance=0.2) flow = np.asarray([ [ @@ -261,7 +291,7 @@ def test_create_disocclusion_map_batched(self): width = 3 flow_tensor = tf.placeholder(shape=(None, height, width, 2), dtype=tf.float32) - mask_tensor = create_disocclusion_mask(flow_tensor) + mask_tensor = create_disocclusion_mask(flow_tensor, splat_variance=0.2) flow = np.asarray([ [ @@ -302,6 +332,64 @@ def test_create_disocclusion_map_no_gradient(self): grad = tf.gradients(mask_tensor, flow_tensor)[0] self.assertEqual(None, grad) + def test_gradients_errors(self): + self.gradient_errors_helper(splat_variance=1.0) + + def test_gradients_errors_low_splat(self): + if not is_forward_warp_cuda(): + return + self.gradient_errors_helper(splat_variance=0.2) + + def test_gradients_errors_high_splat(self): + if not is_forward_warp_cuda(): + return + self.gradient_errors_helper(splat_variance=0.4) + + def gradient_errors_helper(self, splat_variance): + with self.sess: + # This test is flaky, so retry if fail. + num_tries = 2 if is_forward_warp_cuda() else 4 + error1 = 0 + error2 = 0 + for i in range(num_tries): + img_shape = (16, 3, 4, 4) + flow_shape = (16, 3, 4, 2) + img_a = np.random.rand(*img_shape) + flow_ab = (np.random.rand(*flow_shape) - 0.5) * 3 + input = tf.placeholder(shape=img_a.shape, dtype=tf.float32) + flow_tensor = tf.placeholder(shape=flow_ab.shape, dtype=tf.float32) + warped_tensor = forward_warp(input, flow_tensor, splat_variance=splat_variance) + + error1 = gradient_checker.compute_gradient_error(input, img_a.shape, warped_tensor, img_a.shape, + extra_feed_dict={flow_tensor: flow_ab}, + x_init_value=img_a) + error2 = gradient_checker.compute_gradient_error(flow_tensor, flow_ab.shape, warped_tensor, + img_a.shape, extra_feed_dict={input: img_a}, + x_init_value=flow_ab) + if error1 <= self.max_allowable_grad_err and error2 <= self.max_allowable_grad_err: + return + self.assertLessEqual(max(error1, error2), self.max_allowable_grad_err, + 'Exceeded the error threshold. Note that this test may be flaky.') + + def test_gradient_errors_simultaneous(self): + with self.sess: + # This test is flaky, so retry if fail. + num_tries = 2 if is_forward_warp_cuda() else 4 + error = 0 + for i in range(num_tries): + img_shape = (16, 3, 4, 4) + flow_shape = (16, 3, 4, 2) + input = tf.ones(shape=img_shape, dtype=tf.float32) + flow_tensor = tf.ones(shape=flow_shape, dtype=tf.float32) + warped_tensor = forward_warp(input, flow_tensor, splat_variance=0.2) + + error = gradient_checker.compute_gradient_error([input, flow_tensor], [img_shape, flow_shape], + warped_tensor, img_shape) + if error <= self.max_allowable_grad_err: + return + self.assertLessEqual(error, self.max_allowable_grad_err, + 'Exceeded the error threshold. Note that this test may be flaky.') + if __name__ == '__main__': unittest.main() diff --git a/common/forward_warp/native/CMakeLists.txt b/common/forward_warp/native/CMakeLists.txt new file mode 100644 index 0000000..48ce78e --- /dev/null +++ b/common/forward_warp/native/CMakeLists.txt @@ -0,0 +1,6 @@ +cmake_minimum_required(VERSION 3.5) + +add_op_library(NAME forward_warp_op SOURCES + "forward_warp_op.cc" + "forward_warp_op.cc.cu" +) diff --git a/common/forward_warp/native/forward_warp_op.cc b/common/forward_warp/native/forward_warp_op.cc new file mode 100644 index 0000000..d42ed72 --- /dev/null +++ b/common/forward_warp/native/forward_warp_op.cc @@ -0,0 +1,139 @@ +// Taken from https://github.com/simonmeister/UnFlow/blob/master/ops/forward_warp_op.cc. +// Commit bac9bbaf49be44b9e1c1f004fce4fb04b247763d. +#define EIGEN_USE_THREADS + +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/common_shape_fns.h" + +// TODO assert input flow channel count = 2, assert matching numbers in all other dims + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +using namespace tensorflow; + +void ForwardWarp(const GPUDevice& d, + typename TTypes::ConstTensor images, + typename TTypes::ConstTensor flows, + typename TTypes::Tensor output, + float variance); + +void ForwardWarpGrad(const GPUDevice& d, + typename TTypes::ConstTensor input_grad, + typename TTypes::ConstTensor original_images, + typename TTypes::ConstTensor original_flows, + typename TTypes::Tensor output_image_grad, + typename TTypes::Tensor output_flow_grad, + float variance); + +class ForwardWarpOp : public OpKernel { +public: + explicit ForwardWarpOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("variance", &variance_)); + // Check that variance_ is positive + OP_REQUIRES(context, variance_ >= 0.0f, + errors::InvalidArgument("Need variance_ >= 0, got ", variance_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& image = context->input(0); + const Tensor& flow = context->input(1); + + typename TTypes::ConstTensor image_data = image.tensor(); + typename TTypes::ConstTensor flow_data = flow.tensor(); + + Tensor* output = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, image.shape(), &output)); + typename TTypes::Tensor output_data = output->tensor(); + + ForwardWarp(context->eigen_device(), + image_data, flow_data, output_data, variance_); + } + +private: + float variance_; +}; + +class ForwardWarpOpGrad : public OpKernel { +public: + explicit ForwardWarpOpGrad(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("variance", &variance_)); + // Check that variance_ is positive + OP_REQUIRES(context, variance_ >= 0.0f, + errors::InvalidArgument("Need variance_ >= 0, got ", variance_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& original_images = context->input(1); + const Tensor& original_flows = context->input(2); + + Tensor* output_image_grads = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, original_images.shape(), + &output_image_grads)); + Tensor* output_flow_grads = NULL; + OP_REQUIRES_OK(context, context->allocate_output(1, original_flows.shape(), + &output_flow_grads)); + + typename TTypes::ConstTensor input_data = input.tensor(); + typename TTypes::ConstTensor original_images_data = original_images.tensor(); + typename TTypes::ConstTensor original_flows_data = original_flows.tensor(); + typename TTypes::Tensor output_image_grads_data = output_image_grads->tensor(); + typename TTypes::Tensor output_flow_grads_data = output_flow_grads->tensor(); + + ForwardWarpGrad(context->eigen_device(), + input_data, original_images_data, original_flows_data, + output_image_grads_data, output_flow_grads_data, variance_); + } + +private: + float variance_; +}; + +using shape_inference::DimensionHandle; +using shape_inference::ShapeHandle; + +REGISTER_OP("ForwardWarp") + .Attr("variance: float = 1.0") + .Input("images: float") + .Input("flows: float") + .Output("output: float") + .SetShapeFn([](shape_inference::InferenceContext* c) { + ShapeHandle in = c->input(0); + DimensionHandle batch = c->Dim(in, 0); + DimensionHandle height = c->Dim(in, 1); + DimensionHandle width = c->Dim(in, 2); + DimensionHandle channels = c->Dim(in, 3); + c->set_output(0, c->MakeShape({batch, height, width, channels})); + return Status::OK(); + }); + +REGISTER_OP("ForwardWarpGrad") + .Attr("variance: float = 1.0") + .Input("grads: float") + .Input("original_images: float") + .Input("original_flows: float") + .Output("output_image_grad: float") + .Output("output_flow_grad: float") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(1)); + c->set_output(1, c->input(2)); + return Status::OK(); + }); + +#if GOOGLE_CUDA + +REGISTER_KERNEL_BUILDER(Name("ForwardWarp").Device(DEVICE_GPU), ForwardWarpOp); +REGISTER_KERNEL_BUILDER(Name("ForwardWarpGrad").Device(DEVICE_GPU), ForwardWarpOpGrad); + +#endif // GOOGLE_CUDA diff --git a/common/forward_warp/native/forward_warp_op.cc.cu b/common/forward_warp/native/forward_warp_op.cc.cu new file mode 100644 index 0000000..844d639 --- /dev/null +++ b/common/forward_warp/native/forward_warp_op.cc.cu @@ -0,0 +1,208 @@ +// Taken from https://github.com/simonmeister/UnFlow/blob/master/ops/forward_warp_op.cu.cc. +// Commit bac9bbaf49be44b9e1c1f004fce4fb04b247763d. +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#define _USE_MATH_DEFINES +#include + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +using namespace tensorflow; + +#define gauss(x, y, std) + +typedef Eigen::GpuDevice GPUDevice; + +__global__ void ForwardWarpKernel(const int32 nthreads, + const float* images, const float* flows, + int batch, int height, int width, int channels, + float variance, float* output) { + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { + // out_idx = c + channels * (src_x + width * (src_y + height * b)). + int idx = out_idx; + const int c = idx % channels; + idx /= channels; + const int pixel_index = idx; + const int flow_index = pixel_index * 2; + const int src_x = idx % width; + idx /= width; + const int src_y = idx % height; + const int b = idx / height; + + const float target_x = src_x + flows[flow_index]; + const float target_y = src_y + flows[flow_index + 1]; + + const float std = sqrtf(variance); + const float dist = std * 2.0; + const int k = int(dist) + 2; + + // center pixel closest to mapping location. +#define IMG_OFFSET(iy, ix) (c + channels * (ix + width * (iy + height * b))) + const float image_value = images[out_idx]; + const float x_m_k = target_x - k; + const float x_p_k = target_x + k; + const float y_m_k = target_y - k; + const float y_p_k = target_y + k; + const int floor_x_m_k = (int)floorf(x_m_k); + const int floor_x_p_k = (int)floorf(x_p_k); + const int floor_y_m_k = (int)floorf(y_m_k); + const int floor_y_p_k = (int)floorf(y_p_k); + if (floor_x_m_k < width && floor_x_p_k >= 0 + && floor_y_m_k < height && floor_y_p_k >= 0) { + const int min_n_x = x_m_k > 0 ? floor_x_m_k : 0; + const int min_n_y = y_m_k > 0 ? floor_y_m_k : 0; + const int max_n_x = x_p_k < width? floor_x_p_k : width - 1; + const int max_n_y = y_p_k < height? floor_y_p_k : height - 1; + + const float gauss_divisor = 2.0f * variance; + const float gauss_normalizer = gauss_divisor * float(M_PI); + for (int n_x = min_n_x; n_x <= max_n_x; ++n_x) { + for (int n_y = min_n_y; n_y <= max_n_y; ++n_y) { + const float x = n_x - target_x; + const float y = n_y - target_y; + const float weight = expf(-(x * x + y * y) / gauss_divisor) / gauss_normalizer; + CudaAtomicAdd(output + IMG_OFFSET(n_y, n_x), weight * image_value); + } + } + } +#undef IMG_OFFSET + } +} + +__global__ void ForwardWarpGradKernel(const int32 nthreads, + const float* input_grad, const float* images, const float* flows, + int batch, int height, int width, int channels, float variance, + float* output_image_grad, float* output_flow_grad) { + CUDA_1D_KERNEL_LOOP(in_idx, nthreads) { + // in_idx = c + channels * (src_x + width * (src_y + height * b)). + int idx = in_idx; + const int c = idx % channels; + idx /= channels; + const int pixel_index = idx; + const int flow_index = pixel_index * 2; + const int src_x = idx % width; + idx /= width; + const int src_y = idx % height; + const int b = idx / height; + + const float target_x = src_x + flows[flow_index]; + const float target_y = src_y + flows[flow_index + 1]; + + const float std = sqrtf(variance); + const float dist = std * 2.0; + const int k = int(dist) + 2; + + float du = 0.0; + float dv = 0.0; + +#define IMG_OFFSET(iy, ix) (c + channels * (ix + width * (iy + height * b))) + float d_img = 0.0; + const float image_value = images[in_idx]; + const float x_m_k = target_x - k; + const float x_p_k = target_x + k; + const float y_m_k = target_y - k; + const float y_p_k = target_y + k; + const int floor_x_m_k = (int)floorf(x_m_k); + const int floor_x_p_k = (int)floorf(x_p_k); + const int floor_y_m_k = (int)floorf(y_m_k); + const int floor_y_p_k = (int)floorf(y_p_k); + if (floor_x_m_k < width && floor_x_p_k >= 0 + && floor_y_m_k < height && floor_y_p_k >= 0) { + const int min_n_x = x_m_k > 0? floor_x_m_k : 0; + const int min_n_y = y_m_k > 0? floor_y_m_k : 0; + const int max_n_x = x_p_k < width? floor_x_p_k : width - 1; + const int max_n_y = y_p_k < height? floor_y_p_k : height - 1; + + const float gauss_divisor = 2.0f * variance; + const float gauss_normalizer = gauss_divisor * float(M_PI); + for (int n_x = min_n_x; n_x <= max_n_x; ++n_x) { + for (int n_y = min_n_y; n_y <= max_n_y; ++n_y) { + const float x = n_x - target_x; + const float y = n_y - target_y; + const float weight = expf(-(x * x + y * y) / gauss_divisor) / gauss_normalizer; + const float weighted_din = input_grad[IMG_OFFSET(n_y, n_x)] * weight; + const float factor = 2 * weighted_din / gauss_divisor * image_value; + du += factor * x; + dv += factor * y; + d_img += weighted_din; + } + } + } + + output_image_grad[in_idx] = d_img; + CudaAtomicAdd(output_flow_grad + flow_index, du); + CudaAtomicAdd(output_flow_grad + flow_index + 1, dv); + } +#undef IMG_OFFSET +} + +void ForwardWarp(const GPUDevice& d, + typename TTypes::ConstTensor images, + typename TTypes::ConstTensor flows, + typename TTypes::Tensor output, + float variance) { + const int batch = images.dimension(0); + const int height = images.dimension(1); + const int width = images.dimension(2); + const int channels = images.dimension(3); + + const int total_count = batch * height * width * channels; + if (total_count == 0) return; + + CudaLaunchConfig config; + + // Initialize output with all zeros. + config = GetCudaLaunchConfig(total_count, d); + SetZero<<>>( + config.virtual_thread_count, output.data()); + + config = GetCudaLaunchConfig(total_count, d); + ForwardWarpKernel + <<>>( + config.virtual_thread_count, images.data(), flows.data(), + batch, height, width, channels, + variance, output.data()); +} + +void ForwardWarpGrad(const GPUDevice& d, + typename TTypes::ConstTensor input_grad, + typename TTypes::ConstTensor original_images, + typename TTypes::ConstTensor original_flows, + typename TTypes::Tensor output_image_grad, + typename TTypes::Tensor output_flow_grad, + float variance) { + const int batch = input_grad.dimension(0); + const int height = input_grad.dimension(1); + const int width = input_grad.dimension(2); + const int channels = input_grad.dimension(3); + + int total_count = batch * height * width * 2; + if (total_count == 0) return; + + // Initialize output_flow_grad with all zeros. + CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + SetZero<<>>( + config.virtual_thread_count, output_flow_grad.data()); + + // Initialize output_image_grad with all zeros. + total_count = batch * height * width * channels; + config = GetCudaLaunchConfig(total_count, d); + SetZero<<>>( + config.virtual_thread_count, output_image_grad.data()); + + // Accumulate. + config = GetCudaLaunchConfig(total_count, d); + ForwardWarpGradKernel + <<>>( + config.virtual_thread_count, input_grad.data(), + original_images.data(), original_flows.data(), + batch, height, width, channels, variance, + output_image_grad.data(), output_flow_grad.data()); +} + +#endif // GOOGLE_CUDA diff --git a/common/utils/tf.py b/common/utils/tf.py index 049ce04..d9a5534 100644 --- a/common/utils/tf.py +++ b/common/utils/tf.py @@ -6,7 +6,9 @@ matplotlib.use('TkAgg') from matplotlib import pyplot as plt import io +import os.path import tensorflow as tf +from sys import platform from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -228,3 +230,22 @@ def leaky_relu(features, alpha=0.1, name=None): :return: Tensor. The activated value. """ return tf.nn.leaky_relu(features, alpha=alpha, name=name) + + +def load_op_library(op_name, directory='build'): + """ + Loads a Tensorflow native op, or returns None if not found. + :param op_name: Str. Name of the op. + :param directory: Str. Directory to search in. + :return: Tensorflow op module, or None if the op was not found. + """ + if platform == 'win32': + lib_path = os.path.join(directory, op_name + '.dll') + else: + lib_path = os.path.join(directory, 'lib' + op_name + '.so') + if os.path.isfile(lib_path): + mod = tf.load_op_library(lib_path) + else: + print('Warning: No native implementation of', op_name, 'found. Falling back to the Tensorflow version.') + mod = None + return mod diff --git a/mains/configs/train_pwcnet.json b/mains/configs/train_pwcnet.json index 868bc1c..50f08b6 100644 --- a/mains/configs/train_pwcnet.json +++ b/mains/configs/train_pwcnet.json @@ -102,7 +102,7 @@ "args": { "directory": { "var_ref": "sintel_data" }, "checkpoint_directory": { "var_ref": "checkpoint_directory" }, - "config": "mains/configs/pwcnet_schedule/train_pwcnet_sintel_10.json", + "config": "mains/configs/pwcnet_schedule/train_pwcnet_sintel6_10.json", "iterations": 200000 } }, diff --git a/mains/train_pwcnet.py b/mains/train_pwcnet.py index e5e26d2..85cbe52 100644 --- a/mains/train_pwcnet.py +++ b/mains/train_pwcnet.py @@ -15,7 +15,6 @@ def main(): config_proto = tf.ConfigProto() config_proto.gpu_options.allow_growth = True - config_proto.allow_soft_placement = True session = tf.Session(config=config_proto) # Read the JSON config. diff --git a/pwcnet/cost_volume/cost_volume.py b/pwcnet/cost_volume/cost_volume.py index 63bcda3..c65a8c2 100644 --- a/pwcnet/cost_volume/cost_volume.py +++ b/pwcnet/cost_volume/cost_volume.py @@ -1,21 +1,12 @@ # Mostly copied from https://github.com/nameless-Chatoyant/PWC-Net_pytorch/blob/master/modules.py. # Master branch commit 2225ad2082371126cc9c8e57a8b962a88933a8c0. import tensorflow as tf -import os.path -from sys import platform +from common.utils.tf import load_op_library from tensorflow.python.framework import ops # Load op library. -if platform == 'win32': - lib_path = os.path.join('build', 'correlation_op.dll') -else: - lib_path = os.path.join('build', 'libcorrelation_op.so') -if os.path.isfile(lib_path): - mod = tf.load_op_library(lib_path) -else: - print('Warning: No CUDA implementation of cost_volume found. Falling back to the Tensorflow version.') - mod = None +mod = load_op_library('correlation_op', 'build') def cost_volume(c1, c2, search_range=4): diff --git a/pwcnet/warp/warp.py b/pwcnet/warp/warp.py index fe040f3..6bc4a71 100644 --- a/pwcnet/warp/warp.py +++ b/pwcnet/warp/warp.py @@ -1,20 +1,11 @@ import tensorflow as tf -import os.path +from common.utils.tf import load_op_library from pwcnet.warp.spacial_transformer_network.transformer import spatial_transformer_network -from sys import platform from tensorflow.python.framework import ops # Load op library. -if platform == 'win32': - lib_path = os.path.join('build', 'backward_warp_op.dll') -else: - lib_path = os.path.join('build', 'libbackward_warp_op.so') -if os.path.isfile(lib_path): - mod = tf.load_op_library(lib_path) -else: - print('Warning: No CUDA implementation of backward_warp found. Falling back to the Tensorflow version.') - mod = None +mod = load_op_library('backward_warp_op', 'build') def backward_warp(images, optical_flows, bilinear_sample=True):