Branching vs calculation

cyboryxmen
Posts: 190
Joined: November 14th, 2014, 2:03 am

Branching vs calculation

Post by cyboryxmen » June 27th, 2017, 1:25 pm

Anyone who has dabbled in very simple physics will have seen something like this before.

Code:

void Update ( const float delta_time )
{
	if ( alive_ )
	{
		position_ += speed_ * delta_time;
	}
}
However, I have come to the realisation that this may be better.

Code:

void Update ( const float delta_time )
{
	position_ += speed_ * delta_time;
}
If the entity is "dead", simply set its speed to zero.

Now, the reason this may be the case: if you read up on modern cpu architecture, you will realise that cpus keep getting better at doing calculations in a matter of nanoseconds, while branching hasn't improved at the same pace. Your if statements, switch statements and virtual function calls are becoming relatively more expensive than simply doing x + y. Their actual cost hasn't really increased at all, but the cpu has a lot of optimisations built into it (such as deep instruction pipelines and executing several instructions at once) that rely on knowing ahead of time which code is going to run. So even though branching itself hasn't become more expensive, it gets in the way of many of the optimisations the cpu uses to increase efficiency.
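One concrete example of such an optimisation is SIMD: a loop whose body has no data-dependent branch is the easiest case for a compiler to auto-vectorise, processing several floats per instruction. The sketch below uses a hypothetical structure-of-arrays layout (`ParticlesSoA` and `UpdateAll` are illustrative names, not part of the benchmark further down). Dead entities simply carry zero velocity, so every element goes through the same arithmetic. Whether your compiler actually vectorises it depends on the compiler, flags and target, so treat this as an assumption to check, not a guarantee.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical structure-of-arrays layout: one contiguous array per component.
struct ParticlesSoA
{
	std::vector<float> px, py, pz; // positions
	std::vector<float> vx, vy, vz; // velocities (all zero for dead entities)

	void Add ( const float x, const float y, const float z,
		const float sx, const float sy, const float sz, const bool alive )
	{
		px.push_back ( x ); py.push_back ( y ); pz.push_back ( z );
		// The alive check happens once here, not on every update.
		vx.push_back ( alive ? sx : 0.0f );
		vy.push_back ( alive ? sy : 0.0f );
		vz.push_back ( alive ? sz : 0.0f );
	}
};

// Every element goes through the same arithmetic with no data-dependent branch,
// which makes this loop an easy candidate for auto-vectorisation.
void UpdateAll ( ParticlesSoA& p, const float delta_time )
{
	const std::size_t n = p.px.size ( );
	assert ( p.py.size ( ) == n && p.pz.size ( ) == n );
	assert ( p.vx.size ( ) == n && p.vy.size ( ) == n && p.vz.size ( ) == n );
	for ( std::size_t i = 0; i < n; ++i )
	{
		p.px [ i ] += p.vx [ i ] * delta_time;
		p.py [ i ] += p.vy [ i ] * delta_time;
		p.pz [ i ] += p.vz [ i ] * delta_time;
	}
}
```

GCC (`-fopt-info-vec`) and Clang (`-Rpass=loop-vectorize`) can report whether a loop was vectorised, which is a quick way to test the assumption on your own setup.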

The advantage of the first example I gave (the branching code) is that you can skip the unnecessary calculations when the entity is dead. In my tests, if you have a whole lot of entities and 90% of them are dead, the branching code is actually much faster. However, that's not representative of real-world scenarios: in my own games, the proportion of dead entities in my buffer is never higher than 50%. For me, checking each and every entity to see if it's dead is usually slower than just doing the calculations for all of them.

Don't get me wrong: you can't make a real program with no branching in it. You still need if statements, switch statements and virtual function calls for even the most basic user input handling. Just watch out for unnecessary ones, where simply doing the calculation is faster than checking first.

I made a benchmark to help prove my point. Try running it on your system and see if I'm right or just spouting nonsense.

Code:

#include <iostream>
#include <string>
#include <cstdint>
#include <vector>
#include <chrono>
#include <random>
#include <memory>
#include <algorithm>
#include <limits>
#include <fstream>
#include <functional>
#include <future>
#include <type_traits>
#include <array>
#include <unordered_map>
#include <cmath>   // std::sqrt
#include <cstdlib> // std::system

class Vector3
{
	using Float = float;

	Float x_ { };
	Float y_ { };
	Float z_ { };

public:
	Vector3 ( ) = default;

	constexpr explicit Vector3 ( const Float x, const Float y, const Float z ) noexcept : x_ { x }, y_ { y }, z_ { z }
	{
	}

	Vector3& operator+=( const Vector3 vector ) noexcept
	{
		x_ += vector.x_;
		y_ += vector.y_;
		z_ += vector.z_;
		return *this;
	}
	constexpr Vector3 operator+( const Vector3 vector ) const noexcept
	{
		return Vector3 { x_ + vector.x_, y_ + vector.y_, z_ + vector.z_ };
	}

	Vector3& operator-=( const Vector3 vector ) noexcept
	{
		x_ -= vector.x_;
		y_ -= vector.y_;
		z_ -= vector.z_;
		return *this;
	}
	constexpr Vector3 operator-( const Vector3 vector ) const noexcept
	{
		return Vector3 { x_ - vector.x_, y_ - vector.y_, z_ - vector.z_ };
	}

	constexpr Vector3 operator-( ) const noexcept
	{
		return Vector3 { -x_, -y_, -z_ };
	}

	Vector3& operator*=( const Float scalar ) noexcept
	{
		x_ *= scalar;
		y_ *= scalar;
		z_ *= scalar;
		return *this;
	}
	constexpr Vector3 operator*( const Float scalar ) const noexcept
	{
		return Vector3 { x_ * scalar, y_ * scalar, z_ * scalar };
	}


	Vector3& operator/=( const Float scalar ) noexcept
	{
		x_ /= scalar;
		y_ /= scalar;
		z_ /= scalar;
		return *this;
	}
	constexpr Vector3 operator/( const Float scalar ) const noexcept
	{
		return Vector3 { x_ / scalar, y_ / scalar, z_ / scalar };
	}

	Float Length ( ) const noexcept
	{
		return std::sqrt ( LengthSquared ( ) );
	}
	Float LengthSquared ( ) const noexcept
	{
		return x_ * x_ + y_ * y_ + z_ * z_;
	}
	Vector3 Normalized ( ) const noexcept
	{
		const Float length = Length ( );

		return Vector3 { x_ / length, y_ / length, z_ / length };
	}

	constexpr Float Dot ( const Vector3 vector ) const noexcept
	{
		return x_ * vector.x_ + y_ * vector.y_ + z_ * vector.z_;
	}
	constexpr Vector3 Cross ( const Vector3 vector ) const noexcept
	{
		return Vector3 { y_ * vector.z_ - z_ * vector.y_, z_ * vector.x_ - x_ * vector.z_, x_ * vector.y_ - y_ * vector.x_ };
	}

	constexpr Float X ( ) const noexcept
	{
		return x_;
	}
	constexpr Float Y ( ) const noexcept
	{
		return y_;
	}
	constexpr Float Z ( ) const noexcept
	{
		return z_;
	}
};

class EntityWithoutBranch
{
public:
	EntityWithoutBranch ( const bool alive, const Vector3 position, const Vector3 speed ) : position_ { position }
	{
		if ( alive )
		{
			speed_ = speed;
		}
	}

	void Update ( const float delta_time )
	{
		position_ += speed_ * delta_time;
	}

private:
	Vector3 position_;
	Vector3 speed_;
};

class EntityWithBranch
{
public:
	EntityWithBranch ( const bool alive, const Vector3 position, const Vector3 speed ) : alive_ { alive }, position_ { position }, speed_ { speed }
	{

	}

	void Update ( const float delta_time )
	{
		if ( alive_ )
		{
			position_ += speed_ * delta_time;
		}
	}

private:
	bool alive_;
	Vector3 position_;
	Vector3 speed_;
};

// Swap in EntityWithoutBranch here to benchmark the other version.
using Entity = EntityWithBranch;

int main ( )
{
	std::mt19937 engine;
	std::uniform_int_distribution<int> rand_alive ( 0, 1 );
	std::uniform_real_distribution<float> rand_pos( -5.0f, 5.0f );

	constexpr std::size_t num_tests = 4000;
	constexpr std::size_t num_entities = 100000;
	constexpr std::size_t expected_active_entities = 50000;
	std::size_t active_entities = 0;

	std::vector<Entity> entities;

	for ( std::size_t i = 0; i < num_entities; ++i )
	{
		bool alive = false;
		if ( active_entities < expected_active_entities && rand_alive ( engine ) > 0 )
		{
			alive = true;
			++active_entities;
		}
		entities.emplace_back ( alive, Vector3 { rand_pos ( engine ), rand_pos ( engine ), rand_pos ( engine ) }, Vector3 { rand_pos ( engine ), rand_pos ( engine ), rand_pos ( engine ) } );
	}

	// It's supposed to be called mean but I don't give a fuck!
	unsigned long long lowest_time_taken = std::numeric_limits<unsigned long long>::max ( );
	unsigned long long highest_time_taken = std::numeric_limits<unsigned long long>::min ( );
	long double average = 0.0;
	std::vector<unsigned long long> data_points;
	data_points.resize ( num_tests );

	for ( std::size_t i = 0; i < num_tests; ++i )
	{
		const auto start = std::chrono::steady_clock::now ( );
		for ( auto& entity : entities )
		{
			entity.Update ( 0.5f );
		}
		const auto end = std::chrono::steady_clock::now ( );
		const auto duration = end - start;
		const auto time_taken = std::chrono::duration_cast< std::chrono::nanoseconds >( duration ).count ( );

		data_points [ i ] = time_taken;
	}

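	std::sort ( data_points.begin ( ), data_points.end ( ) );

	// Median: the middle element for an odd count, the mean of the two middle elements for an even count.
	const long double median = num_tests % 2 == 1
		? static_cast<long double>( data_points [ num_tests / 2 ] )
		: ( static_cast<long double>( data_points [ num_tests / 2 - 1 ] ) + data_points [ num_tests / 2 ] ) / 2.0l;

	// First pass: sum for the mean, plus the lowest and highest times.
	for ( std::size_t i = 0; i < num_tests; ++i )
	{
		const auto time_taken = data_points [ i ];
		average += time_taken;
		if ( time_taken < lowest_time_taken )
		{
			lowest_time_taken = time_taken;
		}
		if ( time_taken > highest_time_taken )
		{
			highest_time_taken = time_taken;
		}
	}
	average = average / num_tests;

	// Second pass: deviations must be measured from the finished mean,
	// not from the running sum that 'average' holds during the first pass.
	long double standard_deviation = 0.0l;
	for ( std::size_t i = 0; i < num_tests; ++i )
	{
		const long double distance_to_average = static_cast<long double>( data_points [ i ] ) - average;
		standard_deviation += distance_to_average * distance_to_average;
	}
	standard_deviation = std::sqrt ( standard_deviation / num_tests );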

	std::cout << "Num entities: " << num_entities << std::endl;
	std::cout << "Active entities: " << active_entities << std::endl;
	std::cout << "Num tests: " << num_tests << std::endl;
	std::cout << "Lowest time taken: " << lowest_time_taken << std::endl;
	std::cout << "Highest time taken: " << highest_time_taken << std::endl;
	std::cout << "Average: " << average << std::endl;
	std::cout << "Median: " << median << std::endl;
	std::cout << "Standard deviation: " << standard_deviation << std::endl;

	// Windows-only: keeps the console window open until a key is pressed.
	std::system ( "pause" );
	return 0;
}
Zekilk

albinopapa
Posts: 4373
Joined: February 28th, 2013, 3:23 am
Location: Oklahoma, United States

Re: Branching vs calculation

Post by albinopapa » June 27th, 2017, 6:26 pm

Well, I ran the tests. Times are in nanoseconds, of course, over 100,000 entities.

EntityWithoutBranch: average 680,951.83225000
EntityWithoutBranch (modified): average 713,395.49150000
EntityWithBranch: average 1,110,782.99575000
EntityWithBranch (modified): average 1,750,109.87600000
Polymorphic: average 2,082,753.00725000
Polymorphic (modified): average 3,262,963.41275000

The modified versions set up the entities vector with exactly 50,000 alive and 50,000 dead entities, then shuffle the order. This ensures the alive and dead counts are constant. In all the tests I ran, the alive/dead counts were 50,000 anyway; I just wanted to make it explicit from run to run.
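A sketch of how that setup might look (albinopapa's modified code wasn't posted, so this is an assumption about the approach): fill exactly `num_alive` flags, then `std::shuffle` them with the same `std::mt19937` engine the benchmark already uses. `MakeAliveFlags` is a hypothetical helper name, and it returns `char` flags rather than `bool` to avoid the `std::vector<bool>` proxy specialisation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical helper: exactly num_alive non-zero flags out of num_entities,
// in a random order determined by engine.
std::vector<char> MakeAliveFlags ( const std::size_t num_entities, const std::size_t num_alive, std::mt19937& engine )
{
	assert ( num_alive <= num_entities );
	std::vector<char> flags ( num_entities, 0 );
	std::fill_n ( flags.begin ( ), num_alive, 1 );          // first num_alive are alive
	std::shuffle ( flags.begin ( ), flags.end ( ), engine ); // then scatter them randomly
	return flags;
}
```

Each flag can then be passed as the `alive` argument to `entities.emplace_back`, so the count is fixed while the order stays random.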

The polymorphic classes were Living, whose Update calculates the new position, and Dead, whose Update function is empty.
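For reference, a polymorphic setup along those lines might look roughly like this. The common base class (`EntityBase`), the minimal `Vec3` struct standing in for the benchmark's `Vector3`, and `UpdateAll` are hypothetical names; storage as a `std::vector<std::unique_ptr<...>>` matches what albinopapa mentions using later in the thread. Every `Update` is an indirect call, and each object lives in its own heap allocation.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct Vec3 { float x = 0.0f, y = 0.0f, z = 0.0f; };

class EntityBase
{
public:
	virtual ~EntityBase ( ) = default;
	virtual void Update ( float delta_time ) = 0;
};

// Living: integrates its position every update.
class Living : public EntityBase
{
public:
	Living ( const Vec3 position, const Vec3 speed ) : position_ { position }, speed_ { speed } { }
	void Update ( const float delta_time ) override
	{
		position_.x += speed_.x * delta_time;
		position_.y += speed_.y * delta_time;
		position_.z += speed_.z * delta_time;
	}
	Vec3 Position ( ) const { return position_; }
private:
	Vec3 position_;
	Vec3 speed_;
};

// Dead: the update does nothing, but the virtual call is still made.
class Dead : public EntityBase
{
public:
	void Update ( float ) override { }
};

void UpdateAll ( std::vector<std::unique_ptr<EntityBase>>& entities, const float delta_time )
{
	for ( auto& entity : entities )
	{
		assert ( entity != nullptr );
		entity->Update ( delta_time ); // dispatched through the vtable
	}
}
```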

With branching is about 1.6 to 2.5 times slower than without, and about 1.9 times faster than using polymorphism, at least in these tests.

Good times as always cybor, thanks for sharing.

[DISCLAIMER TO NEWCOMERS]
While this has some merit, in the larger scheme of things it may not reflect real-world results. Usually, by the time you factor in all the rest of your code (audio, graphics, physics, etc.), using if blocks, switch statements and the like won't noticeably affect performance. This test has only 3 adds and 3 multiplies per loop iteration, which is very few instructions and very fast, and nothing else is going on during the timed portion. Also, this test doesn't take into account that you'd still need an if check to set speed_ to 0.f before the calculations can be done, so I really don't know whether this is a valid test either.
If you think paging some data from disk into RAM is slow, try paging it into a simian cerebrum over a pair of optical nerves. - gameprogrammingpatterns.com

cyboryxmen
Posts: 190
Joined: November 14th, 2014, 2:03 am

Re: Branching vs calculation

Post by cyboryxmen » June 27th, 2017, 6:34 pm

Thanks. I've learnt to benchmark a lot better since last time. For one thing, the random number engine is not seeded, so it produces the same sequence on every run. I've also learnt to do better statistics on these benchmarks and included everything you need for proper statistics: range, mean, median, standard deviation; the works.

I have to say, though, the main reason the polymorphic entities are slower is most likely that you aren't storing them contiguously but allocating each one with its own new. A guy on Discord asked me why I couldn't just delete the objects from the buffer. He was using std::unique_ptrs, so he didn't understand the pain of managing a contiguous array of objects (especially when you need their addresses).
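One common way to remove an object from a contiguous buffer (assuming a `std::vector` here) is swap-and-pop: move the last element into the removed element's slot and shrink the vector by one. It is O(1) and keeps storage contiguous, but it shows exactly where the address pain comes from: the moved element's index and address change, and any reallocation on insert invalidates every pointer into the vector. `SwapAndPop` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Removes element `index` in O(1) by moving the last element into its slot.
// Element order is NOT preserved, and the element that used to be last now
// lives at `index`, so any stored pointer or index to it is stale.
template <typename T>
void SwapAndPop ( std::vector<T>& items, const std::size_t index )
{
	assert ( index < items.size ( ) );
	if ( index != items.size ( ) - 1 ) // avoid self-move-assignment
	{
		items [ index ] = std::move ( items.back ( ) );
	}
	items.pop_back ( );
}
```

When element order matters and you want to remove all dead entities in one pass, `std::remove_if` followed by `erase` is the usual alternative.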

_Java
Posts: 4
Joined: February 16th, 2017, 12:45 am
Location: Indiana, United States

Re: Branching vs calculation

Post by _Java » June 27th, 2017, 10:13 pm

Pretty interesting how some things like this defy our own logic. :shock:
But it made sense after thinking about it for a minute, and it's always fun to consider this kind of stuff.
Thanks for sharing ;)
No, I do not write code in Java!

Okay, maybe a little......

albinopapa
Posts: 4373
Joined: February 28th, 2013, 3:23 am
Location: Oklahoma, United States

Re: Branching vs calculation

Post by albinopapa » June 27th, 2017, 11:07 pm

Yeah, I didn't think about cache locality when I threw in polymorphism. I also used a vector of unique_ptrs. I hadn't noticed the RNG wasn't seeded; that explains it. The standard deviations are coming out way larger than the min/max values, though. I'm not sure it should be like that, but I don't know enough about statistics to "put in my $0.02".
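That suspicion is right: a population standard deviation can never exceed the range (max minus min), so those figures can't be correct. The statistics loop in the posted benchmark subtracts `average` while it still holds the running sum of the times (it is only divided by `num_tests` after the loop), which inflates the deviation; the mean has to be finished before the deviations are computed. The posted median index is also off for an even count. Below is a sketch of a two-pass summary; `Summary` and `Summarize` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Summary
{
	std::uint64_t lowest;
	std::uint64_t highest;
	long double mean;
	long double median;
	long double standard_deviation; // population standard deviation
};

// Takes the samples by value because it sorts them.
Summary Summarize ( std::vector<std::uint64_t> samples )
{
	assert ( !samples.empty ( ) );
	std::sort ( samples.begin ( ), samples.end ( ) );
	const std::size_t n = samples.size ( );

	Summary s { };
	s.lowest = samples.front ( );
	s.highest = samples.back ( );

	// Median: middle element, or the mean of the two middle elements.
	s.median = n % 2 == 1
		? static_cast<long double>( samples [ n / 2 ] )
		: ( static_cast<long double>( samples [ n / 2 - 1 ] ) + samples [ n / 2 ] ) / 2.0l;

	// Pass 1: finish the mean before using it.
	long double sum = 0.0l;
	for ( const auto t : samples ) sum += t;
	s.mean = sum / n;

	// Pass 2: squared distances from the finished mean.
	long double squares = 0.0l;
	for ( const auto t : samples )
	{
		const long double d = static_cast<long double>( t ) - s.mean;
		squares += d * d;
	}
	s.standard_deviation = std::sqrt ( squares / n );
	return s;
}
```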

albinopapa
Posts: 4373
Joined: February 28th, 2013, 3:23 am
Location: Oklahoma, United States

Re: Branching vs calculation

Post by albinopapa » June 28th, 2017, 12:07 am

So adding a few more lines of code, of the kind that might also be in a game's Update function, reduces the gap between the two cases (branching and non-branching).

Code:

// code addition
	void Update( const float delta_time, const Vector3 &OtherPos )
	{
		const auto speed = speed_.Length();
		speed_ = ( position_ - OtherPos ).Normalized() * speed;
		position_ += speed_ * delta_time;
	}


EntityWithBranch
------------------------------------
Num entities: 100,000
Active entities: 50,000
Num tests: 4,000
Lowest time taken: 2,324,614
Highest time taken: 6,853,044
Average: 2,560,914.06250000
Median: 2,552,259
Standard deviation: 91,528,832.14215240

Code:

// code addition
	void Update( const float delta_time, const Vector3 &OtherPos )
	{
		if( alive_ )
		{
			const auto speed = speed_.Length();
			speed_ = ( position_ - OtherPos ).Normalized() * speed;

			position_ += speed_ * delta_time;
		}
	}

EntityWithoutBranch
------------------------------------
Num entities: 100,000
Active entities: 50,000
Num tests: 4,000
Lowest time taken: 2,633,843
Highest time taken: 4,656,860
Average: 2,932,901.68450000
Median: 2,926,229
Standard deviation: 104,398,618.52052431

chili
Site Admin
Posts: 3948
Joined: December 31st, 2011, 4:53 pm
Location: Japan

Re: Branching vs calculation

Post by chili » July 1st, 2017, 5:21 am

Nice post.

The results of branching vs non-branching code depend greatly on the type of data being processed. Take, for example, a routine for drawing a sprite with a chroma key. The two alternatives are: branch on the chroma test and skip processing when it passes, or use the chroma test result as a mask and process every pixel. In this case we generally see that the branching code performs faster. The reason is that the pattern of chroma and non-chroma pixels is generally not random, and there are often long runs of both, so the branch predictor handles them well and lets the cpu do less work. The same is found with hatched patterns (alternating chroma/non-chroma). With a random distribution of chroma and non-chroma pixels, though, a branchless solution will be greatly superior, but that is not a common real-life scenario.
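To make the two alternatives concrete, here is a sketch of both for one row of 32-bit pixels, assuming the chroma key is an exact colour match. `BlitRowBranch` and `BlitRowMask` are hypothetical names, not chili's actual code. The first skips chroma pixels with a branch; the second turns the comparison into an all-ones or all-zeros mask and blends unconditionally, so every pixel does the same work.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Branching version: skip chroma pixels entirely.
void BlitRowBranch ( const std::uint32_t* src, std::uint32_t* dst, const std::size_t width, const std::uint32_t chroma )
{
	assert ( src != nullptr && dst != nullptr );
	for ( std::size_t x = 0; x < width; ++x )
	{
		if ( src [ x ] != chroma )
		{
			dst [ x ] = src [ x ];
		}
	}
}

// Masked version: mask is 0xFFFFFFFF for opaque pixels and 0 for chroma pixels,
// so every pixel is written, either with src or with its own unchanged dst value.
void BlitRowMask ( const std::uint32_t* src, std::uint32_t* dst, const std::size_t width, const std::uint32_t chroma )
{
	assert ( src != nullptr && dst != nullptr );
	for ( std::size_t x = 0; x < width; ++x )
	{
		const std::uint32_t mask = 0u - static_cast<std::uint32_t>( src [ x ] != chroma );
		dst [ x ] = ( src [ x ] & mask ) | ( dst [ x ] & ~mask );
	}
}
```

Both produce the same image; which one is faster depends on the sprite data, as described above. An optimising compiler is also free to turn either loop into the other form, so measure with real sprites.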

The moral of this story is that you really should understand your data well before making such design decisions.
Chili

cyboryxmen
Posts: 190
Joined: November 14th, 2014, 2:03 am

Re: Branching vs calculation

Post by cyboryxmen » July 1st, 2017, 6:05 am

Yeah, I didn't mention the cpu's branch predictor since I don't have any information on how that component performs. I suppose that makes sense: if your inputs are really predictable, the cpu can easily predict which way each branch will go and apply the appropriate optimisations.
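One way to see the branch predictor at work (a rough sketch, not a rigorous benchmark) is to run the same branching loop over the same alive flags twice: once sorted, so the branch goes one way for a long run and then the other, and once shuffled. The number of alive entries is identical in both runs; only the predictability of the branch changes. `SumAlive`, `TimeSumAlive` and `RunPredictorDemo` are hypothetical names. Timings will vary by cpu and compiler, a single run is noisy, and an optimiser may replace the branch with a conditional move or vectorise it, in which case the difference disappears.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

// Branching loop: only "alive" entries contribute to the sum.
std::uint64_t SumAlive ( const std::vector<char>& alive, const std::vector<std::uint32_t>& values )
{
	assert ( alive.size ( ) == values.size ( ) );
	std::uint64_t sum = 0;
	for ( std::size_t i = 0; i < values.size ( ); ++i )
	{
		if ( alive [ i ] )
		{
			sum += values [ i ];
		}
	}
	return sum;
}

// Times one call of SumAlive and prints the elapsed nanoseconds.
std::uint64_t TimeSumAlive ( const char* label, const std::vector<char>& alive, const std::vector<std::uint32_t>& values )
{
	const auto start = std::chrono::steady_clock::now ( );
	const std::uint64_t sum = SumAlive ( alive, values );
	const auto end = std::chrono::steady_clock::now ( );
	std::cout << label << ": " << std::chrono::duration_cast<std::chrono::nanoseconds>( end - start ).count ( ) << " ns\n";
	return sum;
}

void RunPredictorDemo ( )
{
	std::mt19937 engine;
	std::vector<char> alive ( 1000000 );
	std::vector<std::uint32_t> values ( alive.size ( ) );
	for ( std::size_t i = 0; i < alive.size ( ); ++i )
	{
		alive [ i ] = static_cast<char>( i % 2 ); // exactly half alive
		values [ i ] = static_cast<std::uint32_t>( engine ( ) % 100 );
	}
	std::sort ( alive.begin ( ), alive.end ( ) );            // long runs: easy to predict
	TimeSumAlive ( "sorted", alive, values );
	std::shuffle ( alive.begin ( ), alive.end ( ), engine ); // random order: hard to predict
	TimeSumAlive ( "shuffled", alive, values );
}
```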