Modeling the Dynamics of SGD by Stochastic Differential Equation

  1. Modeling the Dynamics of SGD by Stochastic Differential Equation (Mark Chang, 2020/09/18)
  2. Outline • Stochastic Gradient Descent (SGD) • Random Walk, Diffusion and Wiener Process • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
  3. Outline • Stochastic Gradient Descent (SGD) • Random Walk, Diffusion and Wiener Process • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
  4. Stochastic Gradient Descent (SGD) • Machine learning problem [diagram: input data and labels form the training data; the model maps the input data to a prediction; the loss function compares the prediction with the label; gradient descent updates the weights in the model]
  5. Stochastic Gradient Descent (SGD) • Gradient descent vs. stochastic gradient descent [figure: paths toward the minimum for gradient descent (batch size = dataset size), SGD with a large batch size, and SGD with a small batch size] • Update rule: $w \leftarrow w - \eta \frac{1}{|B|} \sum_{x \in B} \frac{\partial L(x, w)}{\partial w}$, where $w$ denotes the weights, $\eta$ the learning rate, $B$ the batch, and $L(x, w)$ the loss function
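To make the update rule concrete, here is a minimal NumPy sketch of mini-batch SGD on a least-squares problem; the toy data, learning rate, and batch size are illustrative assumptions, not from the slides.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))                    # toy inputs (illustrative)
y = X @ np.array([1.0, -2.0, 0.5, 0.0, 3.0])      # toy labels (illustrative)

w = np.zeros(5)        # weights w
eta = 0.1              # learning rate eta
batch_size = 32        # |B|

for step in range(500):
    idx = rng.choice(len(X), size=batch_size, replace=False)  # sample a batch B
    xb, yb = X[idx], y[idx]
    # (1/|B|) * sum_{x in B} dL(x, w)/dw for the squared loss L(x, w) = (x.w - y)^2
    grad = 2.0 * xb.T @ (xb @ w - yb) / batch_size
    w = w - eta * grad                             # w <- w - eta * grad
```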
  6. Stochastic Gradient Descent (SGD) • Convergence of SGD • Assume that the loss function is convex: $E[L(x, \bar{w}) - L(x, w^*)] \le o(\frac{1}{\sqrt{T}})$, where $T$ is the step count, $\bar{w}$ is $w$ after $T$ steps, and $w^*$ is $w$ at the minimum of $L$, so the distance to the optimum is guaranteed to be small
  7. Stochastic Gradient Descent (SGD) • Dynamics of SGD: the process between the starting point and the final solution [figure: SGD path toward the minimum]
  8. Outline • Stochastic Gradient Descent (SGD) • Random Walk, Diffusion and Wiener Process • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
  9. Random Walk [figure: SGD paths toward the minimum for seeds 0, 1, 2, compared with the paths of random-walk particles 0, 1, 2]
  10. Random Walk • The particle starts at $0$ at $t = 0$ and moves to $\Delta x$ or $-\Delta x$ at $t = \Delta t$ with probability $1/2$ each • The position of the particle at time $t$ is a random variable $X_t$ such that $P(X_{\Delta t} = \Delta x) = \frac{1}{2}$ and $P(X_{\Delta t} = -\Delta x) = \frac{1}{2}$
  11. Random Walk [figure: distributions of $X_{\Delta t}$, $X_{2\Delta t}$, $X_{3\Delta t}$; at $t = \Delta t$ the positions $-\Delta x, \Delta x$ have probabilities $1/2, 1/2$; at $t = 2\Delta t$ the positions $-2\Delta x, 0, 2\Delta x$ have probabilities $1/4, 1/2, 1/4$; at $t = 3\Delta t$ the positions $-3\Delta x, -\Delta x, \Delta x, 3\Delta x$ have probabilities $1/8, 3/8, 3/8, 1/8$]
  12. Random Walk • With $t = n\Delta t$, $n \to \infty$, $\Delta t \to 0$, the distribution of the walk approaches a normal distribution: $X_t \sim N(0, Dt)$, where $D = \frac{(\Delta x)^2}{\Delta t}$ is the diffusion coefficient
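A quick simulation sketch (the step sizes and particle count are arbitrary assumptions) showing that the random-walk variance indeed grows like $Dt$ with $D = (\Delta x)^2/\Delta t$.

```python
import numpy as np

rng = np.random.default_rng(0)
dt, dx = 0.01, 0.1                 # time step and space step (illustrative)
D = dx**2 / dt                     # diffusion coefficient D = (dx)^2 / dt
n_steps, n_particles = 500, 10_000

# each particle moves +dx or -dx with probability 1/2 at every time step
steps = rng.choice([-dx, dx], size=(n_particles, n_steps))
X_t = steps.sum(axis=1)            # positions after t = n_steps * dt

t = n_steps * dt
print("empirical Var[X_t]:", X_t.var())   # approximately D * t
print("D * t             :", D * t)
```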
  13. Diffusion • Probability density function of $X_t$: $p(x, t)$, with $p(x, t = 0) = N(0, 0)$ and $p(x, t = T) = N(0, DT)$ • Diffusion equation: $\frac{\partial p(x, t)}{\partial t} = \frac{D}{2} \frac{\partial^2 p(x, t)}{\partial x^2}$
  14. Wiener process • A stochastic process $W(\cdot)$ is called a Wiener process if: (1) $W(0) = 0$ almost surely; (2) $W(t) - W(s) \sim N(0, t - s)$ for all $t \ge s \ge 0$; (3) $W(t_1), W(t_2) - W(t_1), \ldots, W(t_n) - W(t_{n-1})$ are independent random variables for all $t_n > t_{n-1} > \cdots > t_2 > t_1 > 0$ • $W(t) = X_{n\Delta t}$ is a Wiener process when $t = n\Delta t$, $n \to \infty$, $\Delta t \to 0$
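A sketch of sampling Wiener-process paths from independent $N(0, \Delta t)$ increments, consistent with properties (1)-(3) above; the time horizon, grid, and path count are assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
T, n_steps, n_paths = 1.0, 500, 10_000
dt = T / n_steps

# independent increments W(t_{k+1}) - W(t_k) ~ N(0, dt), with W(0) = 0
dW = rng.normal(0.0, np.sqrt(dt), size=(n_paths, n_steps))
W = np.concatenate([np.zeros((n_paths, 1)), dW.cumsum(axis=1)], axis=1)

# property (2): W(t) - W(s) ~ N(0, t - s); check it at t = T, s = 0
print("Var[W(T)] over paths:", W[:, -1].var(), "(should be close to T =", T, ")")
```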
  15. Wiener process • Random walk [figure: distributions of $X_{2\Delta t}$ and of the increment $X_{3\Delta t} - X_{2\Delta t}$, illustrating that the increment is distributed like $X_{\Delta t}$ and is independent of $X_{2\Delta t}$]
  16. Outline • Stochastic Gradient Descent (SGD) • Random Walk, Diffusion and Wiener Process • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
  17. Stochastic Differential Equation (SDE) • Ordinary differential equation: $\frac{dx(t)}{dt} = b(x(t))$ for $t > 0$, with $x(0) = x_0$ [figure: trajectory of $x$ starting from $x_0$]
  18. Stochastic Differential Equation (SDE) • Stochastic differential equation: $\frac{dx(t)}{dt} = b(x(t)) + B(x(t))\frac{dW(t)}{dt}$ for $t > 0$, where $W(t)$ is a Wiener process, with $x(0) = x_0$; $b(x(t))$ is the deterministic part and $B(x(t))\frac{dW(t)}{dt}$ is the stochastic part [figure: trajectory samples of $x$ starting from $x_0$]
  19. Stochastic Differential Equation (SDE) • Solving a stochastic differential equation: multiply both sides by $dt$ to get $dx(t) = b(x(t))dt + B(x(t))dW(t)$ with $x(0) = x_0$, then integrate both sides to get $x(t) = x_0 + \int_0^t b(x(s))ds + \int_0^t B(x(s))dW(s)$, where the last term is a stochastic integral
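In practice the integral form is approximated on a time grid; below is a minimal Euler-Maruyama sketch for $dx = b(x)dt + B(x)dW$. The particular drift $b$ and diffusion $B$ in the example are illustrative assumptions.

```python
import numpy as np

def euler_maruyama(b, B, x0, T=1.0, n_steps=1_000, rng=None):
    """Simulate dx = b(x) dt + B(x) dW on [0, T] with n_steps Euler-Maruyama steps."""
    if rng is None:
        rng = np.random.default_rng(0)
    dt = T / n_steps
    x = np.empty(n_steps + 1)
    x[0] = x0
    for k in range(n_steps):
        dW = rng.normal(0.0, np.sqrt(dt))          # Wiener increment ~ N(0, dt)
        x[k + 1] = x[k] + b(x[k]) * dt + B(x[k]) * dW
    return x

# example: mean-reverting drift with constant noise (illustrative choice of b and B)
path = euler_maruyama(b=lambda x: -2.0 * x, B=lambda x: 0.5, x0=1.0)
print(path[-1])
```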
  20. Stochastic Differential Equation (SDE) • The solution of a stochastic integral is a random variable • If $g: [0, 1] \to \mathbb{R}$ is a deterministic function: $E[\int_0^1 g\,dW] = 0$ (mean) and $E[(\int_0^1 g\,dW)^2] = \int_0^1 g^2\,dt$ (variance) • If $G$ is a stochastic process such that $E[\int_0^T G^2\,dt] < \infty$: $E[\int_0^T G\,dW] = 0$ (mean) and $E[(\int_0^T G\,dW)^2] = E[\int_0^T G^2\,dt]$ (variance)
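A Monte Carlo check of the statements above for a deterministic integrand (a sketch; the choice $g(t) = t$ and the grid are assumptions): samples of $\int_0^1 g\,dW$ should have mean close to $0$ and second moment close to $\int_0^1 g^2\,dt = 1/3$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_steps, n_samples = 500, 20_000
dt = 1.0 / n_steps
t = np.arange(n_steps) * dt
g = t                                           # deterministic integrand g(t) = t

dW = rng.normal(0.0, np.sqrt(dt), size=(n_samples, n_steps))
I = (g * dW).sum(axis=1)                        # one sample of int_0^1 g dW per row

print("E[ int g dW ]   :", I.mean())            # ~ 0
print("E[(int g dW)^2] :", (I**2).mean())       # ~ int_0^1 g^2 dt
print("int_0^1 g^2 dt  :", (g**2).sum() * dt)   # ~ 1/3
```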
  21. Outline • Stochastic Gradient Descent (SGD) • Random Walk, Diffusion and Wiener Process • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
  22. Continuous-time SGD & Controlled SGD
  23. Continuous-time SGD & Controlled SGD • Notation conventions: gradient descent: $x_{k+1} = x_k - \eta\nabla f(x_k)$; stochastic gradient descent: $x_{k+1} = x_k - \eta\nabla f_{\gamma_k}(x_k)$, where $f$ is the loss function, $x_k$ are the weights at step $k$, $\gamma_k$ is the index of the training sample at step $k$ (assuming batch size 1), and $f_i$ is the loss function calculated on batch $i$, with $f(x) = \frac{1}{n}\sum_{i=1}^n f_i(x)$
  24. Continuous-time SGD & Controlled SGD • The SGD step $x_{k+1} - x_k = -\eta\nabla f_{\gamma_k}(x_k)$ can be written as $x_{k+1} - x_k = -\eta\nabla f(x_k) + \sqrt{\eta}V_k$ (a deterministic part plus a stochastic part), with $V_k = \sqrt{\eta}(\nabla f(x_k) - \nabla f_{\gamma_k}(x_k))$ • Mean of $V_k$: $0$; covariance of $V_k$: $\eta\Sigma(x_k)$, where $\Sigma(x_k) = \frac{1}{n}\sum_{i=1}^n(\nabla f(x_k) - \nabla f_i(x_k))(\nabla f(x_k) - \nabla f_i(x_k))^T$
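A small sketch of the gradient-noise covariance $\Sigma(x)$ defined above, evaluated on an illustrative set of per-sample gradients (the 1-D example matches the toy losses used two slides later).

```python
import numpy as np

def gradient_noise_covariance(grads_fi):
    """Sigma(x) = (1/n) sum_i (grad f(x) - grad f_i(x)) (grad f(x) - grad f_i(x))^T,
    given the per-sample gradients grad f_i(x) as rows of grads_fi (shape (n, d))."""
    grad_f = grads_fi.mean(axis=0)            # grad f = (1/n) sum_i grad f_i
    diffs = grad_f - grads_fi                 # shape (n, d)
    return diffs.T @ diffs / len(grads_fi)

# illustrative 1-D example: f_1(x) = (x - 1)^2 - 1, f_2(x) = (x + 1)^2 - 1
x = 0.3
grads = np.array([[2.0 * (x - 1.0)], [2.0 * (x + 1.0)]])   # grad f_i(x), shape (2, 1)
print(gradient_noise_covariance(grads))                    # [[4.]] for this toy case
```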
  25. Continuous-time SGD & Controlled SGD • Continuous-time SGD: converting $x_{k+1} - x_k = -\eta\nabla f(x_k) + \sqrt{\eta}V_k$ to the continuous time domain gives the order-1 weak approximation $dX_t = -\nabla f(X_t)dt + \sqrt{\eta\Sigma(X_t)}\,dW_t$ and the order-2 weak approximation $dX_t = -\nabla\big(f(X_t) + \frac{\eta}{4}|\nabla f(X_t)|^2\big)dt + \sqrt{\eta\Sigma(X_t)}\,dW_t$
  26. Continuous-time SGD & Controlled SGD • A toy example: $f(x) = x^2$, $f_1(x) = (x - 1)^2 - 1$, $f_2(x) = (x + 1)^2 - 1$ • Continuous-time SGD (order-2 weak approximation): $dX_t = -2(1 + \eta)X_t\,dt + 2\sqrt{\eta}\,dW_t$ • Solution: $X_t \sim N\big(x_0 e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}(1 - e^{-4(1+\eta)t})\big)$
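A simulation sketch comparing discrete SGD on this toy example with the Gaussian predicted by the SDE; the learning rate, starting point, and horizon are illustrative assumptions, and the match is only approximate (the weak approximation carries an $O(\eta^2)$ error).

```python
import numpy as np

rng = np.random.default_rng(0)
eta, x0, n_steps, n_runs = 0.05, 2.0, 400, 20_000

# SGD with batch size 1: each step uses f_1'(x) = 2(x - 1) or f_2'(x) = 2(x + 1)
x = np.full(n_runs, x0)
for _ in range(n_steps):
    c = rng.choice([1.0, -1.0], size=n_runs)   # pick f_1 or f_2 uniformly
    x = x - eta * 2.0 * (x - c)

t = n_steps * eta                              # continuous time t ~ k * eta
mean_pred = x0 * np.exp(-2.0 * (1.0 + eta) * t)
var_pred = eta / (1.0 + eta) * (1.0 - np.exp(-4.0 * (1.0 + eta) * t))
print("empirical mean, var:", x.mean(), x.var())
print("SDE prediction     :", mean_pred, var_pred)
```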
  27. Continuous-time SGD & Controlled SGD • From $X_t \sim N\big(x_0 e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}(1 - e^{-4(1+\eta)t})\big)$: $E[X_t] = x_0$ at $t = 0$ and $\to 0$ as $t \to \infty$; $Var[X_t] = 0$ at $t = 0$ and $\to \frac{\eta}{1+\eta}$ as $t \to \infty$ • The transition time $t^*$ between the descent phase and the fluctuations phase is where $E[X_{t^*}] = \sqrt{Var[X_{t^*}]}$ [figure: the mean decays from $x_0$ while the standard deviation grows toward $\sqrt{\frac{\eta}{1+\eta}}$]
  28. Continuous-time SGD & Controlled SGD
  29. Continuous-time SGD & Controlled SGD • Controlled SGD: adaptive hyper-parameter adjustment $x_{k+1} = x_k - \eta u_k f'(x_k)$, where $u_k \in [0, 1]$ is the adjustment factor • Optimal control formulation: $\min_{u_t} E\,f(X_t)$ subject to $dX_t = -u_t f'(X_t)dt + u_t\sqrt{\eta\Sigma(X_t)}\,dW_t$
  30. Continuous-time SGD & Controlled SGD • Quadratic objective function: $f(x) = \frac{1}{2}a(x - b)^2$, assuming the covariance of $f'_i$ is $\eta\Sigma(x)$ • Continuous-time SGD: $dX_t = -a u_t (X_t - b)dt + u_t\sqrt{\eta\Sigma}\,dW_t$ • Optimal control policy: $u^*_t = 1$ if $a \le 0$ or $t \le t^*$, and $u^*_t = \frac{1}{1 + a(t - t^*)}$ if $a > 0$ and $t > t^*$
  31. Continuous-time SGD & Controlled SGD • Optimal control policy for $f(x) = \frac{1}{2}a(x - b)^2$ (with the covariance of $f'_i$ assumed to be $\eta\Sigma(x)$): $u^*_t = 1$ if $a \le 0$ or $t \le t^*$ ($t \le t^*$ is the descent phase), and $u^*_t = \frac{1}{1 + a(t - t^*)}$ if $a > 0$ and $t > t^*$ ($t > t^*$ is the fluctuations phase) [figure: trajectories for $a \le 0$ and $a > 0$ with the transition at $t^*$]
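A direct transcription of the piecewise policy above into a small helper (a sketch; the example values of $a$ and $t^*$ are assumptions).

```python
import numpy as np

def optimal_u(t, a, t_star):
    """u*_t = 1 in the descent phase (a <= 0 or t <= t*),
    and 1 / (1 + a (t - t*)) in the fluctuations phase (a > 0 and t > t*)."""
    if a <= 0 or t <= t_star:
        return 1.0
    return 1.0 / (1.0 + a * (t - t_star))

# illustrative schedule with a = 2.0 and t* = 5.0: constant 1, then a ~1/t decay
print([round(optimal_u(t, a=2.0, t_star=5.0), 3) for t in np.linspace(0.0, 20.0, 9)])
```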
  32. Continuous-time SGD & Controlled SGD • General objective function: $f(x)$ and $f_i(x)$ are not necessarily quadratic, and $x \in \mathbb{R}^d$ • Assume $f(x) \approx \frac{1}{2}\sum_{i=1}^d a_{(i)}(x_{(i)} - b_{(i)})^2$ holds locally in $x$, and $\Sigma \approx \mathrm{diag}\{\Sigma_{(1)}, \ldots, \Sigma_{(d)}\}$ where each $\Sigma_{(i)}$ is locally constant (each dimension is treated as independent)
  33. Continuous-time SGD & Controlled SGD • Controlled SGD algorithm: at each step $k$, estimate $a_{k,(i)}, b_{k,(i)}$ for $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$ • Since $\nabla f_{(i)} \approx a_{(i)}(x_{(i)} - b_{(i)})$, use linear regression: $a_{k,(i)} = \frac{\overline{gx}_{k,(i)} - \bar{g}_{k,(i)}\bar{x}_{k,(i)}}{\overline{x^2}_{k,(i)} - \bar{x}_{k,(i)}^2}$ and $b_{k,(i)} = \bar{x}_{k,(i)} - \frac{\bar{g}_{k,(i)}}{a_{k,(i)}}$, where $g_{k,(i)} = \nabla f_{\gamma_k}(x_k)_{(i)}$ and the barred quantities are exponential moving averages, e.g. $\bar{g}_{k+1,(i)} = \beta_{k,(i)}\bar{g}_{k,(i)} + (1 - \beta_{k,(i)})g_{k,(i)}$
  34. Continuous-time SGD & Controlled SGD • Controlled SGD algorithm: solve the optimal control policy $u^*_{k,(i)}$ for $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$: $u^*_{k,(i)} = 1$ if $a_{k,(i)} \le 0$, and $u^*_{k,(i)} = \min\big(1, \frac{a_{k,(i)}(\bar{x}_{k,(i)} - b_{k,(i)})^2}{\eta\Sigma_{k,(i)}}\big)$ if $a_{k,(i)} > 0$, where $\Sigma_{k,(i)} = \overline{g^2}_{k,(i)} - \bar{g}_{k,(i)}^2$
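Putting the last two slides together, here is a per-dimension sketch of the controlled-SGD statistics and control factor. The EMA rate beta, the epsilon safeguards, and the class interface are my assumptions; the authors' implementation is in the repository linked on the next slide.

```python
import numpy as np

class ControlledSGDStats:
    """Per-dimension running statistics and control factor u*_k (a sketch of cSGD)."""

    def __init__(self, dim, eta, beta=0.9, eps=1e-12):
        self.eta, self.beta, self.eps = eta, beta, eps
        # exponential moving averages of x, x^2, g, g^2 and g*x per dimension
        self.x_bar = np.zeros(dim)
        self.x2_bar = np.zeros(dim)
        self.g_bar = np.zeros(dim)
        self.g2_bar = np.zeros(dim)
        self.gx_bar = np.zeros(dim)

    def update(self, x, g):
        b = self.beta
        self.x_bar = b * self.x_bar + (1 - b) * x
        self.x2_bar = b * self.x2_bar + (1 - b) * x**2
        self.g_bar = b * self.g_bar + (1 - b) * g
        self.g2_bar = b * self.g2_bar + (1 - b) * g**2
        self.gx_bar = b * self.gx_bar + (1 - b) * g * x

    def control(self):
        # linear regression of grad f_(i) ~ a_(i) (x_(i) - b_(i)) from the EMAs
        a = (self.gx_bar - self.g_bar * self.x_bar) / (self.x2_bar - self.x_bar**2 + self.eps)
        b_coef = self.x_bar - self.g_bar / (a + self.eps)
        sigma = self.g2_bar - self.g_bar**2              # per-dimension gradient variance
        u = np.where(a <= 0.0, 1.0,
                     np.minimum(1.0, a * (self.x_bar - b_coef)**2 / (self.eta * sigma + self.eps)))
        return u

# inside an SGD loop (sketch): stats.update(x, g); x = x - eta * stats.control() * g
```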
  35. [figure: experimental results on MNIST with a fully connected NN, on CIFAR-10 with a fully connected NN, and on CIFAR-10 with a CNN]
  36. Continuous-time SGD & Controlled SGD • Implementation of cSGD • https://github.com/LiQianxiao/cSGD-cMSGD
  37. Outline • Stochastic Gradient Descent (SGD) • Random Walk, Diffusion and Wiener Process • Stochastic Differential Equation (SDE) • Continuous-time SGD & Controlled SGD • Effects of SGD on Generalization
  38. Effects of SGD on Generalization (ICANN 2018)
  39. Effects of SGD on Generalization • Notation conventions: loss function $L(\theta) = \frac{1}{N}\sum_{n=1}^N l(\theta, x_n)$, where $N$ is the size of the dataset • GD: $\theta_{k+1} = \theta_k - \eta g(\theta_k)$, where $g(\theta) = \frac{\partial L}{\partial\theta}$ • SGD: $\theta_{k+1} = \theta_k - \eta g^{(S)}(\theta_k)$, where $g^{(S)}(\theta) = \frac{1}{S}\sum_{n\in B}\frac{\partial}{\partial\theta}l(\theta, x_n)$, $B$ is the batch and $S$ is the batch size
  40. Effects of SGD on Generalization • Continuous-time SGD: $d\theta = -g(\theta)dt + \sqrt{\frac{\eta}{S}}R(\theta)dW(t)$, where $R(\theta)R(\theta)^T = C(\theta)$ and $\frac{C(\theta)}{S}$ is the covariance of $\big(g^{(S)}(\theta) - g(\theta)\big)$
  41. Effects of SGD on Generalization • Effects of different learning rates and batch sizes in $d\theta = -g(\theta)dt + \sqrt{\frac{\eta}{S}}R(\theta)dW(t)$, where $R(\theta)R(\theta)^T = C(\theta)$ and $C(\theta)$ is the covariance of $g(\theta)$ [figure: paths toward the minimum for small vs. large $\eta$ and for small vs. large $S$]
  42. Effects of SGD on Generalization • Flat minimum vs. sharp minimum (https://arxiv.org/abs/1609.04836) [figure: training loss function and the loss function evaluated on testing data; the flat minimum gives low testing error, the sharp minimum gives high testing error]
  43. Effects of SGD on Generalization • Effects of learning rate / batch size on generalization, via $d\theta = -g(\theta)dt + \sqrt{\frac{\eta}{S}}R(\theta)dW(t)$ with $R(\theta)R(\theta)^T = C(\theta)$, the covariance of $g(\theta)$ [figure: flat vs. sharp minima under a small $\sqrt{\frac{\eta}{S}}$ (small $\eta$, large $S$) and under a large $\sqrt{\frac{\eta}{S}}$ (large $\eta$, small $S$)]
  44. Effects of SGD on Generalization • Intuition • Larger learning rate or smaller batch size → flatter minimum → less over-fitting
  45. Effects of SGD on Generalization • Theoretical explanation: starting from $d\theta = -g(\theta)dt + \sqrt{\frac{\eta}{S}}R(\theta)dW(t)$ with $R(\theta)R(\theta)^T = C(\theta)$, the covariance of $g(\theta)$, a change of variables gives $dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\sqrt{\Lambda}\,dW(t)$, where $z = V^T(\theta - \theta^*)$ is the new variable, $\theta^*$ are the parameters at the minimum, $V$ is the orthogonal matrix of the eigendecomposition $H = V\Lambda V^T$, and $H$ is the Hessian of $L(\theta)$
  46. Effects of SGD on Generalization • Theoretical explanation • Expectation of the loss function: $dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\sqrt{\Lambda}\,dW(t)$ is an Ornstein-Uhlenbeck process for $z$, with solution $E[z] = 0$ and $\mathrm{cov}[z] = \frac{\eta}{2S}I$, so $E(L) = \frac{1}{2}\sum_{i=1}^q \lambda_i E(z_i^2) = \frac{\eta}{4S}\mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\mathrm{Tr}(H)$
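A rough numerical check of the result above (a sketch under my own simplifying assumptions: the dynamics are simulated directly in the eigenbasis $z$ with Euler-Maruyama, and the Hessian eigenvalues, $\eta$, and $S$ are illustrative): the stationary $E(L)$ should be close to $\frac{\eta}{4S}\mathrm{Tr}(H)$.

```python
import numpy as np

rng = np.random.default_rng(0)
eta, S = 0.1, 32
lam = np.array([1.0, 4.0, 10.0])        # eigenvalues of the Hessian H (illustrative)
dt, n_steps, n_runs = 0.01, 3_000, 2_000

z = np.zeros((n_runs, len(lam)))
for _ in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt), size=z.shape)
    # Euler-Maruyama for dz = -Lambda z dt + sqrt(eta/S) sqrt(Lambda) dW
    z = z - lam * z * dt + np.sqrt(eta / S) * np.sqrt(lam) * dW

E_L = 0.5 * (lam * (z**2).mean(axis=0)).sum()   # E(L) = (1/2) sum_i lambda_i E[z_i^2]
print("simulated E(L)   :", E_L)
print("eta/(4S) * Tr(H) :", eta / (4 * S) * lam.sum())
```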
  47. Effects of SGD on Generalization • Theoretical explanation: $E(L) = \frac{1}{2}\sum_{i=1}^q \lambda_i E(z_i^2) = \frac{\eta}{4S}\mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\mathrm{Tr}(H) \implies \frac{E(L)}{\mathrm{Tr}(H)} \propto \frac{\eta}{S}$ • For minima with similar loss values, a sharp minimum has a low $\frac{E(L)}{\mathrm{Tr}(H)}$ (reached with a low $\frac{\eta}{S}$) and a flat minimum has a high $\frac{E(L)}{\mathrm{Tr}(H)}$ (reached with a high $\frac{\eta}{S}$)
  48. [figure: increasing LR/BS gives increasing accuracy; similar LR/BS gives similar accuracy]
  49. [figure: increasing LR/BS gives increasing accuracy; similar LR/BS gives similar accuracy]
  50. Tips for Tuning Batch Size and Learning Rate • The learning rate can be decayed as the number of epochs increases. • The learning rate should not be initialized to a small value. • To keep the validation accuracy unchanged, the ratio LR/BS should be kept constant when the batch size changes. • To achieve higher validation accuracy, increase the learning rate or reduce the batch size.
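A tiny sketch of the last two tips: scale the learning rate linearly with the batch size so that $\eta/S$ stays fixed; the base values here are assumptions.

```python
def rescale_lr(base_lr, base_bs, new_bs):
    """Keep eta / S constant when the batch size changes (linear LR scaling)."""
    return base_lr * new_bs / base_bs

# e.g. moving from batch size 32 at lr 0.1 to batch size 256 keeps eta/S = 0.1/32 fixed
print(rescale_lr(base_lr=0.1, base_bs=32, new_bs=256))   # 0.8
```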
  51. Further Readings • An Introduction to Stochastic Differential Equations: http://ft-sipil.unila.ac.id/dbooks/AN%20INTRODUCTION%20TO%20STOCHASTIC%20DIFFERENTIAL%20EQUATIONS%20VERSION%201.2.pdf • Stochastic Modified Equations and Adaptive Stochastic Gradient Algorithms: https://arxiv.org/abs/1511.06251 • Three Factors Influencing Minima in SGD: https://arxiv.org/abs/1710.11029
