匿名希望のおでんFortranツヴァイさん太郎

生き物、Fortran、川について書く

ifとcaseと分岐処理削除

Fortranでは分岐処理にifとcaseがあります.条件分岐が三つ以上ある場合,if文では条件を満たすまで上から順にelse ifをたどるのに対し,case文では値の評価をした後,条件を満たすところにジャンプするため,分岐が多い場合にはcaseの方が速いそうです.私は統計処理をする際,文字列を用いた分岐に対してはcaseを使うことが多いです(一文字目が"A"の時に処理aをするなど).

チューニングの際は,まず条件分岐自体を無くせないかを考えます.意外とifやcaseを使わなくても書けるものです.やたらに条件分岐が多いコードを書いている間は,アルゴリズム自体の理解が浅いことが多いです.

今回はdo loopのカウンターの値に応じて処理を変える場合の計算速度をif文,case文,条件分岐を使わない場合で比較します.実用上このような場面は多く,例えば,

  • 移流方程式を時間幅dt=0.1でt=0からt=10まで解き,t=1ごとに解を出力したいとき
  • マルコフ連鎖モンテカルロ法で10回ごとにサンプリングを行いたいとき

などです.

以下がソースコードです.

program main
implicit none
!    Parameter
integer, parameter::iter = 10**8
integer, parameter::thin1 = 10
integer, parameter::thin2 = 5
integer, parameter::thin3 = 2
integer, parameter::thin4 = 1
!    Loop
integer i, j
!    Temporary
integer s
integer m
! Time
integer t0
integer t1 
integer t_rate
integer t_max


s = 0
call system_clock(t0)
do i = 1, iter
	m = mod(i, thin1)
	if (m == 0) then
		s = 0
	else if (m == thin4) then
		s = s + thin4
	else if (m == thin3) then
		s = s + thin3
	else if (m == thin2) then
		s = s + thin2
	else
		s = s * 2
	end if
end do
call system_clock(t1, t_rate, t_max)
print*, calc_time(t0, t1, t_rate, t_max)


s = 0
call system_clock(t0)
do i = 1, iter
	m = mod(i, thin1)
	select case(m)
		case (0)
			s = 0
		case (thin4)
			s = s + thin4
		case (thin3)
			s = s + thin3
		case (thin2)
			s = s + thin2
		case default
			s = s * 2
	end select
end do
call system_clock(t1, t_rate, t_max)
print*, calc_time(t0, t1, t_rate, t_max)


s = 0
i = 1
call system_clock(t0)
do while (i <= iter)
	! mod(i, thin1) = thin4 = 1
	s = s + thin4
	i = i + 1

	! mod(i, thin1) = thin3 = 2
	s = s + thin3
	i = i + 1

	! mod(i, thin1) = 3--4
	do j = thin3+1, thin2-1
		s = s * 2
		i = i + 1
	end do

	! mod(i, thin1) = thin2 = 5
	s = s + thin2
	i = i + 1

	! mod(i, thin1) =  6--9
	do j = thin2+1, thin1-1
		s = s * 2
		i = i + 1
	end do

	! mod(i, thin1) = 0
	s = 0
	i = i + 1
end do
call system_clock(t1, t_rate, t_max)
print*, calc_time(t0, t1, t_rate, t_max)


contains 

function calc_time(t0, t1, t_rate, t_max) result(rslt)
	integer, intent(in)::t0
	integer, intent(in)::t1
	integer, intent(in)::t_rate
	integer, intent(in)::t_max
	!    Temporary
	double precision rslt
	integer diff

	if (t1 < t0) then
		diff = t_max - t0 + t1 + 1
	else
		diff = t1 - t0
	end if
	
	rslt = dble(diff)/dble(t_rate)
end function

end program

結果

  0.39000000000000001     
  0.34399999999999997     
  0.18800000000000000     

若干caseの方が速いですが,あまり変わりませんね.もう少し差が出ると思っていました.条件分岐がない場合は二倍くらいの速度で計算が終わりました.

計算結果の出力を伴う場合,どれくらい差が出るでしょうか.計算部分を以下のように書き直します.三つ目の計算では,write文を書く位置はiのカウントを増やす前であることに注意しましょう.

open(17, file="ex.csv", status="replace")
s = 0
call system_clock(t0)
do i = 1, iter
	m = mod(i, thin1)
	if (m == 0) then
		s = 0
		write(17, *) i, ",", s
	else if (m == thin4) then
		s = s + thin4
		write(17, *) i, ",", s
	else if (m == thin3) then
		s = s + thin3
		write(17, *) i, ",", s
	else if (m == thin2) then
		s = s + thin2
		write(17, *) i, ",", s
	else
		s = s * 2
	end if
end do
call system_clock(t1, t_rate, t_max)
print*, calc_time(t0, t1, t_rate, t_max)

write(17, *) 

s = 0
call system_clock(t0)
do i = 1, iter
	m = mod(i, thin1)
	select case(m)
		case (0)
			s = 0
			write(17, *) i, ",", s
		case (thin4)
			s = s + thin4
			write(17, *) i, ",", s
		case (thin3)
			s = s + thin3
			write(17, *) i, ",", s
		case (thin2)
			s = s + thin2
			write(17, *) i, ",", s
		case default
			s = s * 2
	end select
end do
call system_clock(t1, t_rate, t_max)
print*, calc_time(t0, t1, t_rate, t_max)

write(17, *) 

s = 0
i = 1
call system_clock(t0)
do while (i <= iter)
	! mod(i, thin1) = thin4 = 1
	s = s + thin4
	write(17, *) i, ",", s
	i = i + 1

	! mod(i, thin1) = thin3 = 2
	s = s + thin3
	write(17, *) i, ",", s
	i = i + 1

	! mod(i, thin1) = 3--4
	do j = thin3+1, thin2-1
		s = s * 2
		i = i + 1
	end do

	! mod(i, thin1) = thin2 = 5
	s = s + thin2
	write(17, *) i, ",", s
	i = i + 1

	! mod(i, thin1) = 6--9
	do j = thin2+1, thin1-1
		s = s * 2
		i = i + 1
	end do

	! mod(i, thin1) = 0
	s = 0
	write(17, *) i, ",", s
	i = i + 1
end do
call system_clock(t1, t_rate, t_max)
print*, calc_time(t0, t1, t_rate, t_max)

close(17)


結果

  0.93700000000000006     
  0.93799999999999994     
  0.92200000000000004     

理由はよくわかりませんが,三者ともほとんど同じ速度になりました.条件分岐のない計算が遅くなったのはなぜでしょうか.今回はここまでとします.