Slides of https://www.youtube.com/watch?v=JCoI6iCc6Cs

License: CC Attribution License

- 1. Modeling the Dynamics of SGD by Stochastic Differential Equation
- 2. Outline • Stochastic Gradient Descent (SGD) • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
- 3. Outline • Stochastic Gradient Descent (SGD) • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
- 4. Stochastic Gradient Descent (SGD) • Gradient Descent vs. Stochastic Gradient Descent • Update rule: $w \leftarrow w - \eta \frac{1}{|B|} \sum_{x \in B} \frac{\partial L(x, w)}{\partial w}$, where $w$ are the weights, $\eta$ is the learning rate, $B$ is the batch, and $L(x, w)$ is the loss function • [Figure: trajectories toward the minimum for gradient descent (batch size = dataset size), SGD with a small batch size, and SGD with a large batch size]
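The update rule above can be sketched directly in code. A minimal illustration, assuming a least-squares loss and a hypothetical `sgd_step` helper (neither is from the slides):

```python
import numpy as np

def sgd_step(w, batch_x, batch_y, grad_loss, lr=0.1):
    """One SGD update: average per-example gradients over the
    mini-batch B, then step against that average."""
    g = np.mean([grad_loss(w, x, y) for x, y in zip(batch_x, batch_y)], axis=0)
    return w - lr * g

# Example loss l(w, x, y) = 0.5 * (w @ x - y)^2, so the
# per-example gradient is (w @ x - y) * x.
def grad_ls(w, x, y):
    return (w @ x - y) * x

rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0])
w = np.zeros(2)
for _ in range(500):
    X = rng.normal(size=(8, 2))   # mini-batch of size |B| = 8
    y = X @ true_w                # noise-free targets
    w = sgd_step(w, X, y, grad_ls, lr=0.1)
```

With noise-free targets the iterates contract toward `true_w`; with label noise they would instead fluctuate around it, which is exactly the behavior the SDE view formalizes later in the deck.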
- 5. Stochastic Gradient Descent (SGD) • Convergence of SGD • Assume the loss function is convex; then $\mathbb{E}[L(x, \bar{w}) - L(x, w^*)] \le o(1/\sqrt{T})$, where $T$ is the step count, $\bar{w}$ is $w$ after $T$ steps, and $w^*$ is $w$ at the minimum of $L$ • [Figure: the distance between the SGD solution $\bar{w}$ and the minimum $w^*$ is guaranteed to be small]
- 6. Stochastic Gradient Descent (SGD) • Dynamics of SGD: the process between the starting point and the final solution • [Figure: an SGD trajectory approaching the minimum]
- 7. Outline • Stochastic Gradient Descent (SGD) • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
- 8. Random Walk • Starting from $x = 0$ at $t = 0$, at each time step $\Delta t$ the particle moves $\pm\Delta x$ with probability $1/2$ each: the position at $t = \Delta t$ is a random variable $X_{\Delta t}$ with $P(X_{\Delta t} = \Delta x) = 1/2$ and $P(X_{\Delta t} = -\Delta x) = 1/2$
- 9. Random Walk • After more steps the position follows a symmetric binomial distribution: at $t = \Delta t$, probabilities $1/2, 1/2$ at $\mp\Delta x$; at $t = 2\Delta t$, probabilities $1/4, 1/2, 1/4$ at $-2\Delta x, 0, 2\Delta x$; at $t = 3\Delta t$, probabilities $1/8, 3/8, 3/8, 1/8$ at $-3\Delta x, -\Delta x, \Delta x, 3\Delta x$
- 10. Random Walk • In the limit $t = n\Delta t$ with $n \to \infty$ and $\Delta t \to 0$, the position converges to a normal distribution: $X_t \sim N(0, Dt)$, where $D = (\Delta x)^2 / \Delta t$ is the diffusion coefficient
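This diffusion limit is easy to check numerically. A quick Monte Carlo sketch (the step sizes and walker counts here are arbitrary choices, not from the slides):

```python
import numpy as np

# Sum n = 1000 steps of size +/- dx taken every dt, and compare the
# resulting position with N(0, D t), where D = dx^2 / dt.
rng = np.random.default_rng(42)
dt, dx = 0.01, 0.1                      # diffusion coefficient D = dx**2 / dt = 1.0
n_steps, n_walkers = 1000, 5000
steps = rng.choice([-dx, dx], size=(n_walkers, n_steps))
X_t = steps.sum(axis=1)                 # walker positions at t = n_steps * dt = 10
D, t = dx**2 / dt, n_steps * dt
print(X_t.mean(), X_t.var(), D * t)     # sample mean ~ 0, sample variance ~ D t = 10
```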
- 11. Stochastic Differential Equation (SDE) • Ordinary Differential Equation: $\frac{dx(t)}{dt} = b(x(t))$ for $t > 0$, with $x(0) = x_0$ • [Figure: the deterministic trajectory of $x$]
- 12. Stochastic Differential Equation (SDE) • Stochastic Differential Equation: $\frac{dx(t)}{dt} = b(x(t)) + B(x(t))\frac{dW(t)}{dt}$ for $t > 0$, with $x(0) = x_0$, where $W(t)$ is a Wiener process; $b$ is the deterministic part and $B\,dW/dt$ is the stochastic part (a random walk with infinitely small steps) • [Figure: trajectory samples of $x$]
- 13. Stochastic Differential Equation (SDE) • Solving a Stochastic Differential Equation • Multiply both sides by $dt$: $dx(t) = b(x(t))\,dt + B(x(t))\,dW(t)$, with $x(0) = x_0$ • Integrate both sides from $0$ to $t$: $x(t) = x_0 + \int_0^t b(x(s))\,ds + \int_0^t B(x(s))\,dW(s)$, where the last term is a stochastic integral
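In practice such equations are usually simulated rather than solved in closed form. A minimal Euler-Maruyama sketch (the scheme itself is not on the slides, and the Ornstein-Uhlenbeck test coefficients are arbitrary):

```python
import numpy as np

def euler_maruyama(b, B, x0, T=1.0, n=500, rng=None):
    """Simulate dx = b(x) dt + B(x) dW with the Euler-Maruyama scheme:
    x_{k+1} = x_k + b(x_k) * dt + B(x_k) * sqrt(dt) * N(0, 1)."""
    rng = rng or np.random.default_rng()
    dt = T / n
    x = x0
    for _ in range(n):
        x = x + b(x) * dt + B(x) * np.sqrt(dt) * rng.normal()
    return x

# Test case: dx = -x dt + 0.5 dW with x0 = 2, whose exact mean at
# T = 1 is 2 * exp(-1) ~ 0.736 (an Ornstein-Uhlenbeck process).
rng = np.random.default_rng(1)
ends = [euler_maruyama(lambda x: -x, lambda x: 0.5, 2.0, rng=rng) for _ in range(1000)]
print(np.mean(ends))
```

Averaging many simulated endpoints recovers the deterministic part of the solution, while their spread reflects the stochastic integral term.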
- 14. Stochastic Differential Equation (SDE) • The Solution of a Stochastic Integral is a Random Variable • If $g : [0, 1] \to \mathbb{R}$ is a deterministic function: $\mathbb{E}\big[\int_0^1 g\,dW\big] = 0$ (mean) and $\mathbb{E}\big[(\int_0^1 g\,dW)^2\big] = \int_0^1 g^2\,dt$ (variance) • If $G$ is a stochastic process such that $\mathbb{E}\big[\int_0^T G^2\,dt\big] < \infty$: $\mathbb{E}\big[\int_0^T G\,dW\big] = 0$ (mean) and $\mathbb{E}\big[(\int_0^T G\,dW)^2\big] = \mathbb{E}\big[\int_0^T G^2\,dt\big]$ (variance)
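Both identities can be spot-checked by Monte Carlo. A sketch using the deterministic integrand $g(t) = t$ (chosen purely for illustration), for which $\int_0^1 g^2\,dt = 1/3$:

```python
import numpy as np

# Approximate I = integral_0^1 t dW as a left-endpoint Ito sum
# sum_k g(t_k) * dW_k, and check E[I] = 0 and E[I^2] = 1/3.
rng = np.random.default_rng(7)
n, n_samples = 200, 20000
t = np.linspace(0.0, 1.0, n, endpoint=False)            # left endpoints t_k
dW = np.sqrt(1.0 / n) * rng.normal(size=(n_samples, n)) # Wiener increments
I = (t * dW).sum(axis=1)                                # one sample of I per row
print(I.mean(), (I**2).mean())                          # ~ 0 and ~ 1/3
```

Note the left-endpoint evaluation: using midpoints or right endpoints would approximate a different (Stratonovich-like) integral, which matters once the integrand is stochastic.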
- 15. Outline • Stochastic Gradient Descent (SGD) • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
- 16. Continuous-time SGD & Controlled SGD
- 17. Continuous-time SGD & Controlled SGD • Notation Conventions: Gradient Descent: $x_{k+1} = x_k - \eta \nabla f(x_k)$; Stochastic Gradient Descent: $x_{k+1} = x_k - \eta \nabla f_{\gamma_k}(x_k)$, where $f$ is the loss function, $x_k$ are the weights at step $k$, $\gamma_k$ is the index of the training sample at step $k$ (assume batch size is 1), and $f_i$ is the loss function calculated on batch $i$, with $f(x) = (1/n)\sum_{i=1}^n f_i(x)$
- 18. Continuous-time SGD & Controlled SGD • Rewrite $x_{k+1} - x_k = -\eta \nabla f_{\gamma_k}(x_k)$ as $x_{k+1} - x_k = -\eta \nabla f(x_k) + \sqrt{\eta}\,V_k$ (deterministic part plus stochastic part), where $V_k = \sqrt{\eta}\,(\nabla f(x_k) - \nabla f_{\gamma_k}(x_k))$ • Mean of $V_k$: $0$; covariance of $V_k$: $\eta\,\Sigma(x_k)$, where $\Sigma(x) = (1/n)\sum_{i=1}^n (\nabla f(x) - \nabla f_i(x))(\nabla f(x) - \nabla f_i(x))^T$
- 19. Continuous-time SGD & Controlled SGD • Continuous-time SGD: convert $x_{k+1} - x_k = -\eta \nabla f(x_k) + \sqrt{\eta}\,V_k$ to the continuous time domain • Order 1 weak approximation: $dX_t = -\nabla f(X_t)\,dt + \sqrt{\eta \Sigma(X_t)}\,dW_t$ • Order 2 weak approximation: $dX_t = -\nabla\big(f(X_t) + \frac{\eta}{4}|\nabla f(X_t)|^2\big)\,dt + \sqrt{\eta \Sigma(X_t)}\,dW_t$
- 20. Continuous-time SGD & Controlled SGD • A Toy Example: $f(x) = x^2$, $f_1(x) = (x - 1)^2 - 1$, $f_2(x) = (x + 1)^2 - 1$ • Continuous-time SGD (order 2 weak approximation): $dX_t = -2(1 + \eta)X_t\,dt + 2\sqrt{\eta}\,dW_t$ • Solution: $X_t \sim N\big(x_0 e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}(1 - e^{-4(1+\eta)t})\big)$
- 21. Continuous-time SGD & Controlled SGD • From $X_t \sim N\big(x_0 e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}(1 - e^{-4(1+\eta)t})\big)$: $\mathbb{E}[X_t] = x_0$ at $t = 0$ and $\to 0$ as $t \to \infty$; $\mathrm{Var}[X_t] = 0$ at $t = 0$ and $\to \frac{\eta}{1+\eta}$ as $t \to \infty$ • The trajectory has a descent phase followed by a fluctuations phase of scale $\sqrt{\frac{\eta}{1+\eta}}$; the transition time $t^*$ satisfies $\mathbb{E}[X_{t^*}] = \sqrt{\mathrm{Var}[X_{t^*}]}$
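The closed-form solution for the toy example can be verified by simulating the SDE directly. A sketch, assuming $\eta = 0.1$ and a simple Euler-Maruyama discretization (both choices are mine, not from the slides):

```python
import numpy as np

# Simulate dX = -2(1+eta) X dt + 2 sqrt(eta) dW from x0 = 1 and
# compare the long-run statistics with the closed-form solution:
# E[X_t] -> 0 and Var[X_t] -> eta / (1 + eta) as t -> infinity.
rng = np.random.default_rng(3)
eta, dt = 0.1, 1e-3
n_paths, n_steps = 5000, 5000           # simulate up to t = 5
x = np.ones(n_paths)                    # x0 = 1
for _ in range(n_steps):
    x += -2 * (1 + eta) * x * dt \
         + 2 * np.sqrt(eta * dt) * rng.normal(size=n_paths)
print(x.mean(), x.var(), eta / (1 + eta))
```

By $t = 5$ the descent phase is over ($e^{-2(1+\eta)t}$ is negligible), so the samples sit in the fluctuations phase and their variance matches $\eta/(1+\eta) \approx 0.091$.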
- 22. Continuous-time SGD & Controlled SGD • Controlled SGD: Adaptive Hyper-parameter Adjustment • $x_{k+1} = x_k - \eta\,u_k f'(x_k)$, where $u_k \in [0, 1]$ is the adjustment factor • Optimal control formulation: $\min_{u_t} \mathbb{E} f(X_t)$ subject to $dX_t = -u_t f'(X_t)\,dt + u_t \sqrt{\eta \Sigma(X_t)}\,dW_t$
- 23. Continuous-time SGD & Controlled SGD • Quadratic Objective Function: $f(x) = \frac{1}{2}a(x - b)^2$, assume the covariance of $f_i'$ is $\eta\Sigma(x)$ • Continuous-time SGD: $dX_t = -a\,u_t (X_t - b)\,dt + u_t \sqrt{\eta \Sigma}\,dW_t$ • Optimal control policy: $u_t^* = 1$ if $a \le 0$ or $t \le t^*$; $u_t^* = \frac{1}{1 + a(t - t^*)}$ if $a > 0$ and $t > t^*$
- 24. Continuous-time SGD & Controlled SGD • Optimal control policy for $f(x) = \frac{1}{2}a(x - b)^2$: $u_t^* = 1$ if $a \le 0$ or $t \le t^*$ ($t \le t^*$ is the descent phase); $u_t^* = \frac{1}{1 + a(t - t^*)}$ if $a > 0$ and $t > t^*$ ($t > t^*$ is the fluctuations phase) • [Figure: trajectories for $a \le 0$ and $a > 0$, showing the descent phase up to $t^*$ and the fluctuations phase after it]
- 25. Continuous-time SGD & Controlled SGD • General Objective Function: $f(x)$ and $f_i(x)$ are not necessarily quadratic, and $x \in \mathbb{R}^d$ • Assume $f(x) \approx \frac{1}{2}\sum_{i=1}^d a_{(i)}(x_{(i)} - b_{(i)})^2$ holds locally in $x$, and $\Sigma \approx \mathrm{diag}\{\Sigma_{(1)}, \ldots, \Sigma_{(d)}\}$ where each $\Sigma_{(i)}$ is locally constant (each dimension is independent)
- 26. Continuous-time SGD & Controlled SGD • Controlled SGD Algorithm: at each step $k$, estimate $a_{k,(i)}, b_{k,(i)}$ for the local model $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$ • Since $\nabla f_{(i)} \approx a_{(i)}(x_{(i)} - b_{(i)})$, use linear regression: $a_{k,(i)} = \frac{\overline{gx}_{k,(i)} - \bar{g}_{k,(i)}\bar{x}_{k,(i)}}{\overline{x^2}_{k,(i)} - \bar{x}_{k,(i)}^2}$ and $b_{k,(i)} = \bar{x}_{k,(i)} - \frac{\bar{g}_{k,(i)}}{a_{k,(i)}}$, where $g_{k,(i)} = \nabla f_{\gamma_k}(x_k)_{(i)}$ and the bars denote exponential moving averages, e.g. $\bar{g}_{k+1,(i)} = \delta_{k,(i)}\bar{g}_{k,(i)} + (1 - \delta_{k,(i)})\,g_{k,(i)}$
- 27. Continuous-time SGD & Controlled SGD • Controlled SGD Algorithm: solve the optimal control policy $u_{k,(i)}^*$ for $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$ • $u_{k,(i)}^* = 1$ if $a \le 0$, and $u_{k,(i)}^* = \min\big(1,\ \frac{a_{k,(i)}(\bar{x}_{k,(i)} - b_{k,(i)})^2}{\eta\,\Sigma_{k,(i)}}\big)$ if $a > 0$, where $\Sigma_{k,(i)} = \overline{g^2}_{k,(i)} - \bar{g}_{k,(i)}^2$
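The regression step can be illustrated in isolation. A sketch that recovers $(a, b)$ from noisy gradients of a hypothetical 1-D quadratic, using plain least squares rather than the slides' running-moment estimates (the data and noise level are invented for illustration):

```python
import numpy as np

# For f(x) = 0.5 * a * (x - b)^2 the gradient is g = a * (x - b),
# i.e. a line in x with slope a and intercept -a * b, so regressing
# observed (noisy) gradients on the iterates recovers a and b.
rng = np.random.default_rng(5)
a_true, b_true = 2.0, -1.0
xs = rng.normal(size=200)                       # stand-in for iterates x_k
gs = a_true * (xs - b_true) + 0.1 * rng.normal(size=200)  # noisy gradients
a_hat, c_hat = np.polyfit(xs, gs, 1)            # fit g ~ a_hat * x + c_hat
b_hat = -c_hat / a_hat
print(a_hat, b_hat)                             # ~ 2.0 and ~ -1.0
```

The cSGD algorithm computes the same slope and intercept incrementally, replacing the batch averages here with exponential moving averages so the local quadratic fit can track a drifting iterate.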
- 28. [Figure: experimental results on MNIST (fully connected NN), CIFAR-10 (fully connected NN), and CIFAR-10 (CNN)]
- 29. Continuous-time SGD & Controlled SGD • Implementation of cSGD • https://github.com/LiQianxiao/cSGD-cMSGD
- 30. Outline • Stochastic Gradient Descent (SGD) • Random Walk, Diffusion and Wiener Process • Stochastic Differential Equation (SDE) • Effects of SGD on Generalization
- 31. Effects of SGD on Generalization (ICANN 2018)
- 32. Effects of SGD on Generalization • Notation Conventions: Loss function: $L(\theta) = \frac{1}{N}\sum_{n=1}^N l(\theta, x_n)$, where $N$ is the size of the dataset • GD: $\theta_{k+1} = \theta_k - \eta\,g(\theta_k)$, where $g(\theta) = \frac{\partial L}{\partial \theta}$ • SGD: $\theta_{k+1} = \theta_k - \eta\,g^{(S)}(\theta_k)$, where $g^{(S)}(\theta) = \frac{1}{S}\sum_{n \in B} \frac{\partial}{\partial \theta} l(\theta, x_n)$, $B$ is the batch and $S$ is the batch size
- 33. Effects of SGD on Generalization • Continuous-time SGD: $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$, where $R(\theta)R(\theta)^T = C(\theta)$ and $\frac{C(\theta)}{S}$ is the covariance of $g^{(S)}(\theta) - g(\theta)$
- 34. Effects of SGD on Generalization • Effects of different learning rates and batch sizes under $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$, where $R(\theta)R(\theta)^T = C(\theta)$ and $C(\theta)$ is the covariance of $g(\theta)$ • [Figure: trajectories around the minimum for small vs. large $\eta$ and for small vs. large $S$]
- 35. Effects of SGD on Generalization • Flat minimum vs. sharp minimum (https://arxiv.org/abs/1609.04836) • [Figure: the training loss function and the loss function evaluated on testing data; under the shift between the two, a flat minimum has low testing error and a sharp minimum has high testing error]
- 36. Effects of SGD on Generalization • Effects of learning rate / batch size on generalization under $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$ • [Figure: with small $\sqrt{\eta/S}$ (small $\eta$, large $S$) SGD can settle in either a flat or a sharp minimum; with large $\sqrt{\eta/S}$ (large $\eta$, small $S$) SGD escapes sharp minima and settles in flat ones]
- 37. Effects of SGD on Generalization • Theoretical Explanation • Assumption 1: the loss surface can be approximated by a quadratic bowl, with its minimum at zero loss • Assumption 2: the covariance of the gradients and the Hessian of the loss approximation are approximately equal, i.e. $C = H$
- 38. Effects of SGD on Generalization • Theoretical Explanation • Change of variables in $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$: let $z = V^T(\theta - \theta^*)$, where $\theta^*$ are the parameters at the minimum, $H$ is the Hessian of $L(\theta)$, and $V$ is the orthogonal matrix of the eigendecomposition $H = V\Lambda V^T$ • Result: $dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\sqrt{\Lambda}\,dW(t)$
- 39. Effects of SGD on Generalization • Theoretical Explanation • $dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\sqrt{\Lambda}\,dW(t)$ is an Ornstein-Uhlenbeck process for $z$, with stationary solution $\mathbb{E}[z] = 0$ and $\mathrm{cov}[z] = \frac{\eta}{2S}I$ • Expectation of the loss function: $\mathbb{E}(L) = \frac{1}{2}\sum_{i=1}^q \lambda_i\,\mathbb{E}(z_i^2) = \frac{\eta}{4S}\mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\mathrm{Tr}(H)$
- 40. Effects of SGD on Generalization • Theoretical Explanation • $\mathbb{E}(L) = \frac{1}{2}\sum_{i=1}^q \lambda_i\,\mathbb{E}(z_i^2) = \frac{\eta}{4S}\mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\mathrm{Tr}(H) \implies \mathbb{E}(L) \propto \frac{\eta}{S}\,\mathrm{Tr}(H)$ • Among minima with similar loss values, $\mathrm{Tr}(H)$ is low at a flat minimum and high at a sharp minimum, so with high $\frac{\eta}{S}$ a sharp minimum has high $\mathbb{E}(L)$ and SGD tends to leave it for a flat minimum, while with low $\frac{\eta}{S}$ the $\mathbb{E}(L)$ stays low even at sharp minima
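The stationary prediction $\mathbb{E}(L) = \frac{\eta}{4S}\mathrm{Tr}(H)$ can be checked by simulating the diagonalized process coordinate-wise. A sketch, with $\eta$, $S$, and the eigenvalues of $H$ chosen arbitrarily for illustration:

```python
import numpy as np

# Simulate dz_i = -lambda_i z_i dt + sqrt(eta/S * lambda_i) dW per
# coordinate; each z_i is an OU process with stationary variance
# (eta/S * lambda_i) / (2 * lambda_i) = eta / (2S), independent of
# lambda_i, so E(L) = 0.5 * sum_i lambda_i E[z_i^2] = eta/(4S) * Tr(H).
rng = np.random.default_rng(9)
eta, S = 0.1, 10
lam = np.array([0.5, 1.0, 4.0])         # eigenvalues of H (arbitrary)
dt, n_paths, n_steps = 1e-3, 2000, 6000 # simulate up to t = 6
z = np.zeros((n_paths, lam.size))
for _ in range(n_steps):
    z += -lam * z * dt \
         + np.sqrt(eta / S * lam * dt) * rng.normal(size=z.shape)
EL = 0.5 * (lam * (z**2).mean(axis=0)).sum()
print(EL, eta / (4 * S) * lam.sum())    # both ~ 0.0138
```

Note that sharper directions (larger $\lambda_i$) contribute more to $\mathbb{E}(L)$ even though every coordinate fluctuates with the same variance $\eta/(2S)$, which is exactly why $\mathbb{E}(L)$ scales with $\mathrm{Tr}(H)$.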
- 41. [Figure: experiments showing that increasing LR/BS increases accuracy, while similar LR/BS gives similar accuracy]
- 42. Tips for Tuning Batch Size and Learning Rate • The learning rate can be decayed as the epoch count increases • The learning rate should not be initialized to a small value • To keep the validation accuracy when changing the batch size, LR/BS should remain constant • To achieve higher validation accuracy, increase the learning rate or reduce the batch size
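The constant-LR/BS tip amounts to a linear scaling rule. A tiny helper as a sketch (this reading of the $\eta/S$ analysis is a heuristic, not a guarantee from the slides):

```python
def scaled_lr(base_lr, base_bs, new_bs):
    """Keep the ratio eta / S constant: when the batch size changes,
    scale the learning rate by the same factor, so the predicted
    stationary loss E(L) = eta/(4S) * Tr(H) is unchanged (heuristic
    'linear scaling rule')."""
    return base_lr * new_bs / base_bs

# e.g. moving from batch size 256 to 1024 scales lr 0.1 -> 0.4
new_lr = scaled_lr(0.1, 256, 1024)
```

In practice this rule is usually paired with a warmup period at large batch sizes, since the SDE approximation behind it assumes a small effective step.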
- 43. Further Readings • An Introduction to Stochastic Differential Equations: http://ft-sipil.unila.ac.id/dbooks/AN%20INTRODUCTION%20TO%20STOCHASTIC%20DIFFERENTIAL%20EQUATIONS%20VERSION%201.2.pdf • Stochastic Modified Equations and Adaptive Stochastic Gradient Algorithms: https://arxiv.org/abs/1511.06251 • Three Factors Influencing Minima in SGD: https://arxiv.org/abs/1710.11029
