# Modeling the Dynamics of SGD by Stochastic Differential Equation


1. Modeling the Dynamics of SGD by Stochastic Differential Equation
2. Outlines
   • Stochastic Gradient Descent (SGD)
   • Stochastic Differential Equation (SDE)
   • Continuous-time SGD & Controlled SGD
   • Effects of SGD on Generalization
3. Outlines
   • Stochastic Gradient Descent (SGD)
   • Stochastic Differential Equation (SDE)
   • Continuous-time SGD & Controlled SGD
   • Effects of SGD on Generalization
4. Stochastic Gradient Descent (SGD)
   • Gradient Descent vs. Stochastic Gradient Descent
   • Update rule: $w \leftarrow w - \eta \frac{1}{|B|} \sum_{x \in B} \frac{\partial L(x, w)}{\partial w}$
   • Gradient Descent uses batch size = dataset size; SGD uses a small or large batch
   • $w$: weights, $\eta$: learning rate, $B$: batch, $L(x, w)$: loss function
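The update rule on slide 4 can be sketched in a few lines of NumPy. This is a minimal illustration, not code from the slides: the helper name `sgd_step` and the toy quadratic loss are made up for the example. With `batch_size` equal to the dataset size the same function performs full-batch gradient descent.

```python
import numpy as np

def sgd_step(w, X, grad_fn, lr, batch_size, rng):
    """One SGD step: average the per-example gradients over a random batch.

    With batch_size == len(X) this reduces to full-batch gradient descent.
    """
    batch = X[rng.choice(len(X), size=batch_size, replace=False)]
    g = np.mean([grad_fn(x, w) for x in batch], axis=0)
    return w - lr * g

# Toy quadratic loss L(x, w) = 0.5 * (w - x)^2, so dL/dw = w - x;
# the minimizer of the average loss is the data mean.
grad = lambda x, w: w - x

rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=1.0, size=256)  # data centered at 3
w = 0.0
for _ in range(500):
    w = sgd_step(w, X, grad, lr=0.1, batch_size=16, rng=rng)
```

After 500 steps `w` fluctuates around the data mean, which is the "small batch size" behavior drawn on the slide.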
5. Stochastic Gradient Descent (SGD)
   • Convergence of SGD, assuming the loss function is convex: $\mathbb{E}[L(x, \bar{w}) - L(x, w^*)] \le O(1/\sqrt{T})$
   • The distance to the optimal loss is guaranteed to shrink as training proceeds
   • $T$: step count, $\bar{w}$: $w$ after $T$ steps, $w^*$: $w$ at the minimum of $L$
6. Stochastic Gradient Descent (SGD)
   • Dynamics of SGD: the process between the starting point and the final solution
7. Outlines
   • Stochastic Gradient Descent (SGD)
   • Stochastic Differential Equation (SDE)
   • Continuous-time SGD & Controlled SGD
   • Effects of SGD on Generalization
8. Random Walk
   • A particle starts at $x = 0$ at $t = 0$; in each time step $\Delta t$ it moves by $+\Delta x$ or $-\Delta x$, each with probability 1/2
   • The position at $t = \Delta t$ is a random variable $X_{\Delta t}$ with $P(X_{\Delta t} = \Delta x) = \frac{1}{2}$ and $P(X_{\Delta t} = -\Delta x) = \frac{1}{2}$
9. Random Walk
   • At $t = 2\Delta t$: $X_{2\Delta t}$ takes values $-2\Delta x, 0, 2\Delta x$ with probabilities 1/4, 1/2, 1/4
   • At $t = 3\Delta t$: $X_{3\Delta t}$ takes values $-3\Delta x, -\Delta x, \Delta x, 3\Delta x$ with probabilities 1/8, 3/8, 3/8, 1/8
10. Random Walk
   • In the limit $t = n\Delta t$, $n \to \infty$, $\Delta t \to 0$, the position follows a normal distribution: $X_t \sim N(0, Dt)$
   • $D = \frac{(\Delta x)^2}{\Delta t}$ is the diffusion coefficient
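The limit on slide 10 is easy to check numerically. The sketch below (parameter values are arbitrary choices, not from the slides) simulates many independent $\pm\Delta x$ walks and compares the empirical mean and variance of $X_t$ against $N(0, Dt)$ with $D = (\Delta x)^2/\Delta t$:

```python
import numpy as np

# Simulate many independent random walks that step +/- dx every dt,
# then check that X_t approaches N(0, D*t) with D = dx**2 / dt.
rng = np.random.default_rng(1)
dt, dx = 0.001, 0.05
D = dx**2 / dt            # diffusion coefficient, here 2.5
t = 1.0
n_steps = int(t / dt)
n_walks = 20000

steps = rng.choice([-dx, dx], size=(n_walks, n_steps))
X_t = steps.sum(axis=1)   # particle positions at time t

mean, var = X_t.mean(), X_t.var()
# Theory: mean 0, variance D*t = 2.5
```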
11. Stochastic Differential Equation (SDE)
   • Ordinary Differential Equation: $\frac{dx(t)}{dt} = b(x(t))$ for $t > 0$, with $x(0) = x_0$
   • The solution is a single deterministic trajectory of $x$
12. Stochastic Differential Equation (SDE)
   • Stochastic Differential Equation: $\frac{dx(t)}{dt} = b(x(t)) + B(x(t))\frac{dW(t)}{dt}$ for $t > 0$, with $x(0) = x_0$, where $W(t)$ is a Wiener process
   • $b(x(t))$ is the deterministic part; $B(x(t))\frac{dW(t)}{dt}$ is the stochastic part (a random walk with infinitely small steps)
   • The solution is a distribution over trajectory samples of $x$
13. Stochastic Differential Equation (SDE)
   • Solving a Stochastic Differential Equation: multiply both sides by $dt$ to get $dx(t) = b(x(t))\,dt + B(x(t))\,dW(t)$, with $x(0) = x_0$
   • Integrate both sides: $x(t) = x_0 + \int_0^t b(x(s))\,ds + \int_0^t B(x(s))\,dW(s)$, where the last term is a stochastic integral
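In practice the integral form on slide 13 is approximated numerically. A standard scheme for this is Euler–Maruyama, which discretizes $dx = b(x)\,dt + B(x)\,dW$ by drawing each Wiener increment as $N(0, dt)$. The sketch below (function name and the Ornstein–Uhlenbeck test case are illustrative choices, not from the slides) shows the idea:

```python
import numpy as np

def euler_maruyama(b, B, x0, t_max, dt, rng):
    """Simulate dx = b(x) dt + B(x) dW with the Euler-Maruyama scheme.

    Each Wiener increment dW is drawn independently as N(0, dt).
    """
    n = int(t_max / dt)
    x = np.empty(n + 1)
    x[0] = x0
    for k in range(n):
        dW = rng.normal(0.0, np.sqrt(dt))
        x[k + 1] = x[k] + b(x[k]) * dt + B(x[k]) * dW
    return x

# Ornstein-Uhlenbeck example: dx = -x dt + 0.5 dW, started at x0 = 2.
# The drift pulls the path toward 0 while the noise keeps it fluctuating.
rng = np.random.default_rng(2)
path = euler_maruyama(b=lambda x: -x, B=lambda x: 0.5,
                      x0=2.0, t_max=5.0, dt=0.01, rng=rng)
```

Running it several times with different seeds produces the "trajectory samples" pictured on slide 12.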
14. Stochastic Differential Equation (SDE)
   • The solution of a stochastic integral is a random variable
   • If $g : [0, 1] \to \mathbb{R}$ is a deterministic function: mean $\mathbb{E}\big[\int_0^1 g\,dW\big] = 0$ and variance $\mathbb{E}\big[\big(\int_0^1 g\,dW\big)^2\big] = \int_0^1 g^2\,dt$
   • If $G$ is a stochastic process such that $\mathbb{E}\big[\int_0^T G^2\,dt\big] < \infty$: mean $\mathbb{E}\big[\int_0^T G\,dW\big] = 0$ and variance $\mathbb{E}\big[\big(\int_0^T G\,dW\big)^2\big] = \mathbb{E}\big[\int_0^T G^2\,dt\big]$
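The two identities on slide 14 (the Itô isometry) can be verified by Monte Carlo. The sketch below picks the deterministic integrand $g(t) = t$ on $[0, 1]$ (an arbitrary illustrative choice), for which the theory predicts mean 0 and variance $\int_0^1 t^2\,dt = 1/3$:

```python
import numpy as np

# Monte Carlo check of the Ito isometry for g(t) = t on [0, 1]:
# the stochastic integral should have mean 0 and variance 1/3.
rng = np.random.default_rng(3)
n_steps, n_samples = 1000, 20000
dt = 1.0 / n_steps
t = np.arange(n_steps) * dt           # left endpoints, as in the Ito integral
dW = rng.normal(0.0, np.sqrt(dt), size=(n_samples, n_steps))
integrals = (t * dW).sum(axis=1)      # 20000 samples of int_0^1 g dW

mean, var = integrals.mean(), integrals.var()
```

Evaluating the integrand at the left endpoint of each interval matters here: that convention is what makes the integral an Itô integral.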
15. Outlines
   • Stochastic Gradient Descent (SGD)
   • Stochastic Differential Equation (SDE)
   • Continuous-time SGD & Controlled SGD
   • Effects of SGD on Generalization
16. Continuous-time SGD & Controlled SGD
17. Continuous-time SGD & Controlled SGD
   • Notation conventions:
   • Gradient Descent: $x_{k+1} = x_k - \eta\nabla f(x_k)$
   • Stochastic Gradient Descent: $x_{k+1} = x_k - \eta\nabla f_{\gamma_k}(x_k)$
   • $f$: loss function; $x_k$: weights at step $k$; $\gamma_k$: index of the training sample at step $k$ (assume batch size is 1); $f_i$: loss function calculated by batch $i$, where $f(x) = \frac{1}{n}\sum_{i=1}^n f_i(x)$
18. Continuous-time SGD & Controlled SGD
   • Rewrite the SGD step: $x_{k+1} - x_k = -\eta\nabla f_{\gamma_k}(x_k) = -\eta\nabla f(x_k) + \sqrt{\eta}\,V_k$
   • $-\eta\nabla f(x_k)$ is the deterministic part; $\sqrt{\eta}\,V_k$ is the stochastic part, where $V_k = \sqrt{\eta}\,(\nabla f(x_k) - \nabla f_{\gamma_k}(x_k))$
   • Mean of $V_k$: 0; covariance of $V_k$: $\eta\Sigma(x_k)$, where $\Sigma(x_k) = \frac{1}{n}\sum_{i=1}^n (\nabla f(x_k) - \nabla f_i(x_k))(\nabla f(x_k) - \nabla f_i(x_k))^T$
19. Continuous-time SGD & Controlled SGD
   • Continuous-time SGD: convert $x_{k+1} - x_k = -\eta\nabla f(x_k) + \sqrt{\eta}\,V_k$ to the continuous time domain
   • Order 1 weak approximation: $dX_t = -\nabla f(X_t)\,dt + \sqrt{\eta\Sigma(X_t)}\,dW_t$
   • Order 2 weak approximation: $dX_t = -\nabla\big(f(X_t) + \frac{\eta}{4}|\nabla f(X_t)|^2\big)\,dt + \sqrt{\eta\Sigma(X_t)}\,dW_t$
20. Continuous-time SGD & Controlled SGD
   • A toy example: $f(x) = x^2$, $f_1(x) = (x-1)^2 - 1$, $f_2(x) = (x+1)^2 - 1$
   • Continuous-time SGD (order 2 weak approximation): $dX_t = -2(1+\eta)X_t\,dt + 2\sqrt{\eta}\,dW_t$
   • Solution: $X_t \sim N\big(x_0 e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}\big(1 - e^{-4(1+\eta)t}\big)\big)$
21. Continuous-time SGD & Controlled SGD
   • $X_t \sim N\big(x_0 e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}\big(1 - e^{-4(1+\eta)t}\big)\big)$
   • $\mathbb{E}[X_t] = x_0$ at $t = 0$ and $\to 0$ as $t \to \infty$; $\mathrm{Var}[X_t] = 0$ at $t = 0$ and $\to \frac{\eta}{1+\eta}$ as $t \to \infty$
   • Descent phase for $t \le t^*$, where $t^*$ satisfies $\mathbb{E}[X_{t^*}] = \sqrt{\mathrm{Var}[X_{t^*}]}$; fluctuations phase afterwards, with fluctuation scale $\sqrt{\frac{\eta}{1+\eta}}$
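The closed-form solution on slides 20–21 can be checked by simulating the toy SDE directly. The sketch below runs Euler–Maruyama on many paths of $dX_t = -2(1+\eta)X_t\,dt + 2\sqrt{\eta}\,dW_t$ (the parameter values $\eta = 0.1$, $x_0 = 1$ are arbitrary illustrative choices) and compares the empirical mean and variance at time $t$ with the stated normal distribution:

```python
import numpy as np

# Verify the closed-form solution of the toy SDE
#   dX_t = -2(1+eta) X_t dt + 2 sqrt(eta) dW_t
# by Euler-Maruyama simulation over many independent paths.
rng = np.random.default_rng(4)
eta, x0, t_max, dt, n_paths = 0.1, 1.0, 2.0, 0.001, 5000
n_steps = int(t_max / dt)

x = np.full(n_paths, x0)
for _ in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt), size=n_paths)
    x += -2.0 * (1.0 + eta) * x * dt + 2.0 * np.sqrt(eta) * dW

# Predictions from the slide's solution at t = t_max.
pred_mean = x0 * np.exp(-2.0 * (1.0 + eta) * t_max)
pred_var = eta / (1.0 + eta) * (1.0 - np.exp(-4.0 * (1.0 + eta) * t_max))
```

By $t = 2$ the mean has essentially decayed to 0 while the variance has saturated near $\frac{\eta}{1+\eta}$: the paths have entered the fluctuations phase.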
22. Continuous-time SGD & Controlled SGD
   • Controlled SGD: adaptive hyper-parameter adjustment
   • $x_{k+1} = x_k - \eta u_k f'(x_k)$, where $u_k \in [0, 1]$ is the adjustment factor
   • Optimal control formulation: $\min_{u_t} \mathbb{E} f(X_t)$ subject to $dX_t = -u_t f'(X_t)\,dt + u_t\sqrt{\eta\Sigma(X_t)}\,dW_t$
23. Continuous-time SGD & Controlled SGD
   • Quadratic objective function: $f(x) = \frac{1}{2}a(x-b)^2$; assume the covariance of $f'_i$ is $\eta\Sigma(x)$
   • Continuous-time SGD: $dX_t = -a u_t (X_t - b)\,dt + u_t\sqrt{\eta\Sigma}\,dW_t$
   • Optimal control policy: $u^*_t = 1$ if $a \le 0$ or $t \le t^*$; $u^*_t = \frac{1}{1 + a(t - t^*)}$ if $a > 0$ and $t > t^*$
24. Continuous-time SGD & Controlled SGD
   • Optimal control policy: $u^*_t = 1$ if $a \le 0$ or $t \le t^*$ ($t \le t^*$ is the descent phase); $u^*_t = \frac{1}{1 + a(t - t^*)}$ if $a > 0$ and $t > t^*$ ($t > t^*$ is the fluctuations phase)
   • For $f(x) = \frac{1}{2}a(x-b)^2$, assuming the covariance of $f'_i$ is $\eta\Sigma(x)$
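The piecewise policy on slides 23–24 is simple enough to write down directly. A minimal sketch (the function name `u_star` is illustrative):

```python
def u_star(a, t, t_star):
    """Optimal adjustment factor for f(x) = 0.5*a*(x-b)^2.

    u = 1 throughout the descent phase (and whenever a <= 0); in the
    fluctuations phase the factor decays like 1/(1 + a*(t - t_star)),
    which corresponds to a 1/t learning-rate annealing schedule.
    """
    if a <= 0 or t <= t_star:
        return 1.0
    return 1.0 / (1.0 + a * (t - t_star))
```

The interesting design point is that annealing only begins at $t^*$, the switch from descent to fluctuations, rather than from the start of training.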
25. Continuous-time SGD & Controlled SGD
   • General objective function: $f(x)$ and $f_i(x)$ are not necessarily quadratic, and $x \in \mathbb{R}^d$
   • Assume $f(x) \approx \frac{1}{2}\sum_{i=1}^d a_{(i)}(x_{(i)} - b_{(i)})^2$ holds locally in $x$, and $\Sigma \approx \mathrm{diag}\{\Sigma_{(1)}, \ldots, \Sigma_{(d)}\}$ where each $\Sigma_{(i)}$ is locally constant (each dimension is treated as independent)
26. Continuous-time SGD & Controlled SGD
   • Controlled SGD algorithm: at each step $k$, estimate $a_{k,(i)}, b_{k,(i)}$ for the local model $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$
   • Since $\nabla f_{(i)} \approx a_{(i)}(x_{(i)} - b_{(i)})$, use linear regression of the gradients on the iterates: $a_{k,(i)} = \frac{\overline{gx}_{k,(i)} - \bar{g}_{k,(i)}\,\bar{x}_{k,(i)}}{\overline{x^2}_{k,(i)} - \bar{x}_{k,(i)}^2}$ and $b_{k,(i)} = \bar{x}_{k,(i)} - \frac{\bar{g}_{k,(i)}}{a_{k,(i)}}$
   • where $g_{k,(i)} = \nabla f_{\gamma_k}(x_k)_{(i)}$ and the barred quantities are exponential moving averages, e.g. $\bar{g}_{k+1,(i)} = \beta_{k,(i)}\bar{g}_{k,(i)} + (1 - \beta_{k,(i)})g_{k,(i)}$
27. Continuous-time SGD & Controlled SGD
   • Controlled SGD algorithm: solve the optimal control policy $u^*_{k,(i)}$ for $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$
   • $u^*_{k,(i)} = 1$ if $a_{k,(i)} \le 0$; $u^*_{k,(i)} = \min\big(1,\ \frac{a_{k,(i)}(\bar{x}_{k,(i)} - b_{k,(i)})^2}{\eta\Sigma_{k,(i)}}\big)$ if $a_{k,(i)} > 0$
   • where $\Sigma_{k,(i)} = \overline{g^2}_{k,(i)} - \bar{g}_{k,(i)}^2$
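Slides 26–27 together describe one per-coordinate update. The sketch below is a loose one-dimensional reading of that recipe, not the reference implementation (see the cSGD repository for that): the class name, the fixed EMA rate `beta`, and the numerical floors are all illustrative choices, and the paper additionally adapts the averaging rate.

```python
import numpy as np

# One-dimensional sketch of the cSGD estimation step: maintain
# exponential moving averages of x, g, x*g, x^2 and g^2, fit the local
# quadratic by linear regression, then compute the control u in [0, 1].

class ControlledSGD1D:
    def __init__(self, eta, beta=0.9):
        self.eta, self.beta = eta, beta
        self.avg = {k: 0.0 for k in ("x", "g", "xg", "x2", "g2")}

    def _update_avgs(self, x, g):
        for k, v in (("x", x), ("g", g), ("xg", x * g),
                     ("x2", x * x), ("g2", g * g)):
            self.avg[k] = self.beta * self.avg[k] + (1.0 - self.beta) * v

    def step(self, x, g):
        self._update_avgs(x, g)
        m = self.avg
        denom = m["x2"] - m["x"] ** 2
        a = (m["xg"] - m["g"] * m["x"]) / denom if denom > 1e-12 else 0.0
        u = 1.0
        if a > 0:
            b = m["x"] - m["g"] / a                  # estimated minimum
            sigma = max(m["g2"] - m["g"] ** 2, 1e-12)  # gradient variance
            u = min(1.0, a * (m["x"] - b) ** 2 / (self.eta * sigma))
        return x - self.eta * u * g

# Toy problem from slide 20: f'_1(x) = 2(x-1), f'_2(x) = 2(x+1),
# so the stochastic gradient is 2x +/- 2 with equal probability.
rng = np.random.default_rng(5)
opt, x = ControlledSGD1D(eta=0.1), 3.0
for _ in range(2000):
    g = 2.0 * x + 2.0 * rng.choice([-1.0, 1.0])
    x = opt.step(x, g)
```

On the toy problem the control stays near 1 while the iterate is descending, then shrinks once the gradient signal is dominated by noise, annealing the fluctuations around the minimum at 0.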
28. Experiments: MNIST fully connected NN, CIFAR-10 fully connected NN, CIFAR-10 CNN (result figures)
29. Continuous-time SGD & Controlled SGD
   • Implementation of cSGD: https://github.com/LiQianxiao/cSGD-cMSGD
30. Outlines
   • Stochastic Gradient Descent (SGD)
   • Random Walk, Diffusion and Wiener process
   • Stochastic Differential Equation (SDE)
   • Effects of SGD on Generalization
31. Effects of SGD on Generalization (ICANN 2018)
32. Effects of SGD on Generalization
   • Notation conventions:
   • Loss function: $L(\theta) = \frac{1}{N}\sum_{n=1}^N l(\theta, x_n)$, where $N$ is the size of the dataset
   • GD: $\theta_{k+1} = \theta_k - \eta g(\theta_k)$, where $g(\theta) = \frac{\partial L}{\partial \theta}$
   • SGD: $\theta_{k+1} = \theta_k - \eta g^{(S)}(\theta_k)$, where $g^{(S)}(\theta) = \frac{1}{S}\sum_{n \in B} \frac{\partial}{\partial\theta} l(\theta, x_n)$, $B$ is the batch and $S$ is the batch size
33. Effects of SGD on Generalization
   • Continuous-time SGD: $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$, where $R(\theta)R(\theta)^T = C(\theta)$ and $\frac{C(\theta)}{S}$ is the covariance of $g^{(S)}(\theta) - g(\theta)$
34. Effects of SGD on Generalization
   • Effects of different learning rate and batch size: a larger $\eta$ or a smaller $S$ increases the noise scale $\sqrt{\eta/S}$, producing larger fluctuations around the minimum
   • $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$, where $R(\theta)R(\theta)^T = C(\theta)$ and $C(\theta)$ is the covariance of $g(\theta)$
35. Effects of SGD on Generalization
   • Flat minimum vs. sharp minimum (https://arxiv.org/abs/1609.04836)
   • When the loss function is evaluated on testing data, a flat minimum keeps a low testing error while a sharp minimum yields a high testing error
36. Effects of SGD on Generalization
   • Effects of learning rate / batch size on generalization, under $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$ with $R(\theta)R(\theta)^T = C(\theta)$
   • Small $\sqrt{\eta/S}$ (small $\eta$, large $S$): SGD can settle into either a flat minimum or a sharp minimum
   • Large $\sqrt{\eta/S}$ (large $\eta$, small $S$): the noise drives SGD out of sharp minima, so it ends up in flat minima
37. Effects of SGD on Generalization
   • Theoretical explanation
   • Assumption 1: the loss surface can be approximated by a quadratic bowl, with minimum at zero loss
   • Assumption 2: the covariance of the gradients and the Hessian of the loss approximation are approximately equal, i.e. $C = H$
38. Effects of SGD on Generalization
   • Theoretical explanation: a change of variables in $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\,R(\theta)\,dW(t)$ gives $dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\sqrt{\Lambda}\,dW(t)$
   • $z = V^T(\theta - \theta^*)$: the new variable; $\theta^*$: the parameters at the minimum; $V$: the orthogonal matrix of the eigendecomposition $H = V\Lambda V^T$; $H$: the Hessian of $L(\theta)$
39. Effects of SGD on Generalization
   • $dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\sqrt{\Lambda}\,dW(t)$ is an Ornstein–Uhlenbeck process for $z$, with stationary solution $\mathbb{E}[z] = 0$ and $\mathrm{cov}[z] = \frac{\eta}{2S}I$
   • Expectation of the loss function: $\mathbb{E}(L) = \frac{1}{2}\sum_{i=1}^q \lambda_i\,\mathbb{E}(z_i^2) = \frac{\eta}{4S}\mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\mathrm{Tr}(H)$
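Both claims on slide 39 can be checked numerically for a small diagonal $\Lambda$. The sketch below (the eigenvalues and the values of $\eta$ and $S$ are arbitrary illustrative choices) simulates the OU process and compares the stationary statistics with $\mathrm{cov}[z] = \frac{\eta}{2S}I$ and $\mathbb{E}(L) = \frac{\eta}{4S}\mathrm{Tr}(H)$:

```python
import numpy as np

# Simulate dz = -Lambda z dt + sqrt(eta/S) sqrt(Lambda) dW for a
# diagonal Lambda and check the stationary covariance and loss.
rng = np.random.default_rng(6)
eta, S, dt = 0.1, 10.0, 0.001
lam = np.array([1.0, 4.0])        # eigenvalues of the Hessian H
n_paths, n_steps = 4000, 5000     # t_max = 5, well past the transient

z = np.zeros((n_paths, 2))
for _ in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt), size=(n_paths, 2))
    z += -lam * z * dt + np.sqrt(eta / S) * np.sqrt(lam) * dW

var_z = z.var(axis=0)               # theory: eta/(2S) = 0.005, per coordinate
E_L = 0.5 * (lam * (z ** 2).mean(axis=0)).sum()
pred = eta / (4.0 * S) * lam.sum()  # eta/(4S) * Tr(H) = 0.0125
```

Note that the stationary variance is the same for both coordinates even though their eigenvalues differ: the $\sqrt{\Lambda}$ noise term exactly cancels the $\Lambda$ in the drift, which is what makes $\mathbb{E}(L)$ proportional to $\mathrm{Tr}(H)$.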
40. Effects of SGD on Generalization
   • Theoretical explanation: from $\mathbb{E}(L) = \frac{1}{2}\sum_{i=1}^q \lambda_i\,\mathbb{E}(z_i^2) = \frac{\eta}{4S}\mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\mathrm{Tr}(H)$ it follows that $\frac{\mathbb{E}(L)}{\mathrm{Tr}(H)} \propto \frac{\eta}{S}$
   • Among minima with similar loss values: a low $\frac{\eta}{S}$ permits a sharp minimum (low $\frac{\mathbb{E}(L)}{\mathrm{Tr}(H)}$), while a high $\frac{\eta}{S}$ selects a flat minimum (high $\frac{\mathbb{E}(L)}{\mathrm{Tr}(H)}$)
41. Experiments: increasing LR/BS gives increasing accuracy; similar LR/BS gives similar accuracy (result figures)
42. Tips for Tuning Batch Size and Learning Rate
   • The learning rate can be decayed as the number of epochs increases.
   • The learning rate should not be initialized to a small value.
   • To keep the validation accuracy when changing the batch size, the ratio LR/BS should remain constant.
   • To achieve higher validation accuracy, increase the learning rate or reduce the batch size.
43. Further Readings
   • An Introduction to Stochastic Differential Equations: http://ft-sipil.unila.ac.id/dbooks/AN%20INTRODUCTION%20TO%20STOCHASTIC%20DIFFERENTIAL%20EQUATIONS%20VERSION%201.2.pdf
   • Stochastic Modified Equations and Adaptive Stochastic Gradient Algorithms: https://arxiv.org/abs/1511.06251
   • Three Factors Influencing Minima in SGD: https://arxiv.org/abs/1710.11029