
# Modeling the Dynamics of SGD by Stochastic Differential Equation



1. Modeling the Dynamics of SGD by Stochastic Differential Equation. Mark Chang, 2020/09/18.
2. Outlines
   • Stochastic Gradient Descent (SGD)
   • Random Walk, Diffusion and Wiener process
   • Stochastic Differential Equation (SDE)
   • Continuous-time SGD & Controlled SGD
   • Effects of SGD on Generalization
3. Outlines (repeated; next section: Stochastic Gradient Descent (SGD))
4. Stochastic Gradient Descent (SGD)
   • The machine learning problem (diagram): training data (input data and labels) feed a model that produces predictions; a loss function compares predictions with labels; gradient descent updates the weights in the model.
5. Stochastic Gradient Descent (SGD)
   • Gradient Descent vs. Stochastic Gradient Descent: gradient descent uses the whole dataset as its batch (batch size = dataset size), while SGD samples a small or large batch from it; smaller batches give noisier paths toward the minimum.
   • Update rule: $w \leftarrow w - \eta \frac{1}{|B|} \sum_{x \in B} \frac{\partial L(x, w)}{\partial w}$, where $w$ are the weights, $\eta$ is the learning rate, $B$ is the batch, and $L(x, w)$ is the loss function.
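As a concrete sketch, the update rule on this slide can be written in NumPy. The per-example loss $L(x, w) = \frac{1}{2}(w - x)^2$, the synthetic data, and all names below are illustrative assumptions, not from the slides:

```python
import numpy as np

def sgd_step(w, batch, grad_fn, lr):
    """One SGD step: w <- w - lr * (1/|B|) * sum_{x in B} dL(x, w)/dw."""
    g = np.mean([grad_fn(x, w) for x in batch], axis=0)
    return w - lr * g

# Toy loss L(x, w) = 0.5 * (w - x)^2, so dL/dw = w - x; the minimizer of
# the average loss is the mean of the data.
rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=0.1, size=1000)

w = 0.0
for _ in range(500):
    batch = rng.choice(data, size=16)
    w = sgd_step(w, batch, lambda x, w: w - x, lr=0.1)
# w ends near the mean of the data (about 3.0), up to SGD noise
```

Shrinking the batch size makes the per-step gradient noisier, which is exactly the effect the rest of the deck models with an SDE.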
6. Stochastic Gradient Descent (SGD)
   • Convergence of SGD: assume the loss function is convex. Then $E[L(x, \bar{w}) - L(x, w^*)] \le O(1/\sqrt{T})$, where $T$ is the step count, $\bar{w}$ is $w$ after $T$ steps, and $w^*$ is $w$ at the minimum of $L$. The distance to the optimum is guaranteed to be small.
7. Stochastic Gradient Descent (SGD)
   • Dynamics of SGD: the process between the starting point and the final solution, i.e. the path SGD takes toward the minimum.
8. Outlines (repeated; next section: Random Walk, Diffusion and Wiener process)
9. Random Walk
   • Figure: SGD trajectories from different random seeds (seed 0, 1, 2) toward the minimum, compared with the paths of random-walk particles 0, 1 and 2.
10. Random Walk
   • Start a particle at $x = 0$ at $t = 0$. After one time step, at $t = \Delta t$, it moves by $+\Delta x$ or $-\Delta x$ with equal probability. The position at $t$ is a random variable $X_t$ such that $P(X_{\Delta t} = \Delta x) = \frac{1}{2}$ and $P(X_{\Delta t} = -\Delta x) = \frac{1}{2}$.
11. Random Walk
   • At $t = 2\Delta t$, $X_{2\Delta t}$ takes the values $-2\Delta x$, $0$, $2\Delta x$ with probabilities $\frac{1}{4}$, $\frac{1}{2}$, $\frac{1}{4}$.
   • At $t = 3\Delta t$, $X_{3\Delta t}$ takes the values $-3\Delta x$, $-\Delta x$, $\Delta x$, $3\Delta x$ with probabilities $\frac{1}{8}$, $\frac{3}{8}$, $\frac{3}{8}$, $\frac{1}{8}$.
12. Random Walk
   • In the limit $t = n\Delta t$ with $n \to \infty$ and $\Delta t \to 0$, the position converges to a normal distribution: $X_t \sim N(0, Dt)$, where $D = \frac{(\Delta x)^2}{\Delta t}$ is the diffusion coefficient.
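This limit can be checked numerically: simulate many random-walk particles and compare the empirical variance of $X_T$ with $DT$. The step sizes and particle count below are arbitrary assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
dt, dx = 1e-3, 1e-2                  # assumed step sizes
D = dx**2 / dt                       # diffusion coefficient, here 0.1
n_steps, n_particles = 1000, 50_000
T = n_steps * dt                     # total time, here 1.0

# Each particle takes n_steps moves of +dx or -dx with probability 1/2 each.
steps = rng.choice([-dx, dx], size=(n_particles, n_steps))
X_T = steps.sum(axis=1)

print(X_T.mean())   # close to 0
print(X_T.var())    # close to D * T = 0.1
```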
13. Diffusion
   • The probability density function $p(x, t)$ of $X_t$ satisfies the diffusion equation: $\frac{\partial p(x,t)}{\partial t} = \frac{D}{2} \frac{\partial^2 p(x,t)}{\partial x^2}$.
   • $p(x, t = 0) = N(0, 0)$ (a point mass at the origin) and $p(x, t = T) = N(0, DT)$.
14. Wiener process
   • A stochastic process $W(\cdot)$ is called a Wiener process if: (1) $W(0) = 0$ almost surely; (2) $W(t) - W(s) \sim N(0, t - s)$ for all $t \ge s \ge 0$; (3) the increments $W(t_1), W(t_2) - W(t_1), \dots, W(t_n) - W(t_{n-1})$ are independent random variables for all $t_n > t_{n-1} > \cdots > t_2 > t_1 > 0$.
   • The random walk $W(t) = X_{n\Delta t}$ becomes a Wiener process in the limit $t = n\Delta t$, $n \to \infty$, $\Delta t \to 0$ (taking $D = \frac{(\Delta x)^2}{\Delta t} = 1$).
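Property (2) can be verified by simulation: build Wiener paths as cumulative sums of independent $N(0, \Delta t)$ increments and check that $W(t) - W(s) \sim N(0, t - s)$. The times and step size below are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(1)
dt, n_steps, n_paths = 1e-3, 1000, 50_000

# Independent increments dW ~ N(0, dt); a Wiener path is their running sum.
dW = rng.normal(0.0, np.sqrt(dt), size=(n_paths, n_steps))
W = np.cumsum(dW, axis=1)

s, t = 0.3, 0.8
inc = W[:, int(t / dt) - 1] - W[:, int(s / dt) - 1]
print(inc.mean())   # close to 0
print(inc.var())    # close to t - s = 0.5
```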
15. Wiener process
   • Random walk (figure): the distributions of $X_{2\Delta t}$ and $X_{3\Delta t}$ from slide 11, illustrating how refining the time and space steps turns the random walk into a Wiener process.
16. Outlines (repeated; next section: Stochastic Differential Equation (SDE))
17. Stochastic Differential Equation (SDE)
   • Ordinary Differential Equation: $\frac{dx(t)}{dt} = b(x(t))$ for $t > 0$, with $x(0) = x_0$. The solution is a single deterministic trajectory of $x$ starting from $x_0$.
18. Stochastic Differential Equation (SDE)
   • Stochastic Differential Equation: $\frac{dx(t)}{dt} = b(x(t)) + B(x(t))\frac{dW(t)}{dt}$ for $t > 0$, where $W(t)$ is a Wiener process, with $x(0) = x_0$. Here $b(x(t))$ is the deterministic part and $B(x(t))\frac{dW(t)}{dt}$ is the stochastic part; the solutions are trajectory samples of $x$.
19. Stochastic Differential Equation (SDE)
   • Solving a Stochastic Differential Equation: multiply both sides by $dt$ to get $dx(t) = b(x(t))\,dt + B(x(t))\,dW(t)$ with $x(0) = x_0$, then integrate from $0$ to $t$: $x(t) = x_0 + \int_0^t b(x(s))\,ds + \int_0^t B(x(s))\,dW(s)$. The last term is a stochastic integral.
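The integral form above is exactly what the simplest numerical scheme discretizes. Below is a minimal Euler-Maruyama sketch (a standard method, though the slides do not name it); the drift $b(x) = -x$ and diffusion $B(x) = 0.5$ are illustrative choices of mine:

```python
import numpy as np

def euler_maruyama(b, B, x0, T, n_steps, n_paths, rng):
    """Simulate dx = b(x) dt + B(x) dW: in each step, add b(x) dt plus
    B(x) times a N(0, dt) increment of the Wiener process."""
    dt = T / n_steps
    x = np.full(n_paths, float(x0))
    for _ in range(n_steps):
        dW = rng.normal(0.0, np.sqrt(dt), size=n_paths)
        x = x + b(x) * dt + B(x) * dW
    return x

# Example SDE dx = -x dt + 0.5 dW: its stationary law is N(0, 0.5**2 / 2).
rng = np.random.default_rng(0)
x = euler_maruyama(lambda x: -x, lambda x: 0.5, x0=1.0, T=5.0,
                   n_steps=500, n_paths=20_000, rng=rng)
print(x.mean())  # close to 0
print(x.var())   # close to 0.125
```

The same integrator shape reappears later in the deck, since continuous-time SGD is itself an SDE of this form.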
20. Stochastic Differential Equation (SDE)
   • The solution of a stochastic integral is a random variable. If $g : [0, 1] \to \mathbb{R}$ is a deterministic function, then the mean is $E\big[\int_0^1 g\,dW\big] = 0$ and the variance is $E\big[(\int_0^1 g\,dW)^2\big] = \int_0^1 g^2\,dt$.
   • If $G$ is a stochastic process such that $E\big[\int_0^T G^2\,dt\big] < \infty$, then the mean is $E\big[\int_0^T G\,dW\big] = 0$ and the variance is $E\big[(\int_0^T G\,dW)^2\big] = E\big[\int_0^T G^2\,dt\big]$.
21. Outlines (repeated; next section: Continuous-time SGD & Controlled SGD)
22. Continuous-time SGD & Controlled SGD
23. Continuous-time SGD & Controlled SGD
   • Notation conventions:
   • Gradient Descent: $x_{k+1} = x_k - \eta \nabla f(x_k)$
   • Stochastic Gradient Descent: $x_{k+1} = x_k - \eta \nabla f_{\gamma_k}(x_k)$
   • $f$: loss function; $x_k$: weights at step $k$; $\gamma_k$: index of the training sample at step $k$ (assume the batch size is 1); $f_i$: loss function calculated on batch $i$, where $f(x) = \frac{1}{n}\sum_{i=1}^n f_i(x)$.
24. Continuous-time SGD & Controlled SGD
   • Rewrite the SGD update $x_{k+1} - x_k = -\eta \nabla f_{\gamma_k}(x_k)$ as a deterministic part plus a stochastic part: $x_{k+1} - x_k = -\eta \nabla f(x_k) + \sqrt{\eta}\,V_k$, where $V_k = \sqrt{\eta}\,(\nabla f(x_k) - \nabla f_{\gamma_k}(x_k))$.
   • The mean of $V_k$ is $0$ and the covariance of $V_k$ is $\eta\Sigma(x_k)$, where $\Sigma(x) = \frac{1}{n}\sum_{i=1}^n (\nabla f(x) - \nabla f_i(x))(\nabla f(x) - \nabla f_i(x))^T$.
25. Continuous-time SGD & Controlled SGD
   • Continuous-time SGD: convert $x_{k+1} - x_k = -\eta \nabla f(x_k) + \sqrt{\eta}\,V_k$ to the continuous time domain.
   • Order 1 weak approximation: $dX_t = -\nabla f(X_t)\,dt + \sqrt{\eta\Sigma(X_t)}\,dW_t$
   • Order 2 weak approximation: $dX_t = -\nabla\big(f(X_t) + \frac{\eta}{4}|\nabla f(X_t)|^2\big)\,dt + \sqrt{\eta\Sigma(X_t)}\,dW_t$
26. Continuous-time SGD & Controlled SGD
   • A toy example: $f(x) = x^2$, $f_1(x) = (x-1)^2 - 1$, $f_2(x) = (x+1)^2 - 1$.
   • Continuous-time SGD (order 2 weak approximation): $dX_t = -2(1+\eta)X_t\,dt + 2\sqrt{\eta}\,dW_t$
   • Solution: $X_t \sim N\big(x_0 e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}(1 - e^{-4(1+\eta)t})\big)$
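The toy example can be checked numerically: simulate the SDE on this slide with Euler-Maruyama steps and compare the long-run variance with $\frac{\eta}{1+\eta}$. The values of $\eta$, $x_0$, and the discretization are assumptions:

```python
import numpy as np

eta = 0.1
rng = np.random.default_rng(0)
dt, n_steps, n_paths = 1e-3, 5000, 50_000
x = np.full(n_paths, 1.0)                # x0 = 1

# dX_t = -2(1 + eta) X_t dt + 2 sqrt(eta) dW_t
for _ in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt), size=n_paths)
    x += -2.0 * (1.0 + eta) * x * dt + 2.0 * np.sqrt(eta) * dW

print(x.mean())  # close to x0 * exp(-2 * (1 + eta) * t), i.e. ~0 at t = 5
print(x.var())   # close to eta / (1 + eta), here ~0.0909
```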
27. Continuous-time SGD & Controlled SGD
   • From the solution $X_t \sim N\big(x_0 e^{-2(1+\eta)t}, \frac{\eta}{1+\eta}(1 - e^{-4(1+\eta)t})\big)$: $E[X_t] = x_0$ at $t = 0$ and $E[X_t] \to 0$ as $t \to \infty$; $\mathrm{Var}[X_t] = 0$ at $t = 0$ and $\mathrm{Var}[X_t] \to \frac{\eta}{1+\eta}$ as $t \to \infty$.
   • The dynamics split into a descent phase followed by a fluctuations phase; the transition time $t^*$ is where $E[X_{t^*}] = \sqrt{\mathrm{Var}[X_{t^*}]}$, with fluctuation scale $\sqrt{\frac{\eta}{1+\eta}}$.
28. Continuous-time SGD & Controlled SGD
29. Continuous-time SGD & Controlled SGD
   • Controlled SGD, adaptive hyper-parameter adjustment: $x_{k+1} = x_k - \eta u_k f'(x_k)$, where $u_k \in [0, 1]$ is an adjustment factor.
   • Optimal control formulation: $\min_{u_t} E f(X_t)$ subject to $dX_t = -u_t f'(X_t)\,dt + u_t\sqrt{\eta\Sigma(X_t)}\,dW_t$
30. Continuous-time SGD & Controlled SGD
   • Quadratic objective function: $f(x) = \frac{1}{2}a(x-b)^2$, assuming the covariance of $f'_i$ is $\eta\Sigma(x)$.
   • Continuous-time SGD: $dX_t = -a u_t (X_t - b)\,dt + u_t\sqrt{\eta\Sigma}\,dW_t$
   • Optimal control policy: $u^*_t = 1$ if $a \le 0$ or $t \le t^*$; $u^*_t = \frac{1}{1 + a(t - t^*)}$ if $a > 0$ and $t > t^*$.
31. Continuous-time SGD & Controlled SGD
   • Optimal control policy for $f(x) = \frac{1}{2}a(x-b)^2$ (assuming the covariance of $f'_i$ is $\eta\Sigma(x)$): $u^*_t = 1$ if $a \le 0$ or $t \le t^*$ ($t \le t^*$ is the descent phase); $u^*_t = \frac{1}{1 + a(t - t^*)}$ if $a > 0$ and $t > t^*$ ($t > t^*$ is the fluctuations phase).
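The policy above is simple enough to code directly; this is a transcription of the slide's two cases (the function name is mine):

```python
def optimal_u(a, t, t_star):
    """Optimal control factor u*_t for f(x) = 0.5 * a * (x - b)**2.

    Descent phase (a <= 0 or t <= t*): keep the full step, u = 1.
    Fluctuations phase (a > 0 and t > t*): anneal as 1 / (1 + a * (t - t*)).
    """
    if a <= 0 or t <= t_star:
        return 1.0
    return 1.0 / (1.0 + a * (t - t_star))

print(optimal_u(2.0, 0.5, 1.0))   # 1.0 (descent phase)
print(optimal_u(2.0, 2.0, 1.0))   # 1/3 (fluctuations phase)
print(optimal_u(-1.0, 5.0, 1.0))  # 1.0 (a <= 0: never anneal)
```

Note the 1/(1 + a(t - t*)) decay reproduces the classic O(1/t) learning-rate schedule, but only after the fluctuations phase begins.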
32. Continuous-time SGD & Controlled SGD
   • General objective function: $f(x)$ and $f_i(x)$ are not necessarily quadratic, and $x \in \mathbb{R}^d$. Assume $f(x) \approx \frac{1}{2}\sum_{i=1}^d a_{(i)}(x_{(i)} - b_{(i)})^2$ holds locally in $x$, and $\Sigma \approx \mathrm{diag}\{\Sigma_{(1)}, \dots, \Sigma_{(d)}\}$ where each $\Sigma_{(i)}$ is locally constant (each dimension is independent).
33. Continuous-time SGD & Controlled SGD
   • Controlled SGD algorithm: at each step $k$, estimate $a_{k,(i)}, b_{k,(i)}$ for the local quadratic $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$. Since $\nabla f_{(i)} \approx a_{(i)}(x_{(i)} - b_{(i)})$, use linear regression on running statistics: $a_{k,(i)} = \frac{\overline{gx}_{k,(i)} - \bar{g}_{k,(i)}\bar{x}_{k,(i)}}{\overline{x^2}_{k,(i)} - \bar{x}_{k,(i)}^2}$ and $b_{k,(i)} = \bar{x}_{k,(i)} - \frac{\bar{g}_{k,(i)}}{a_{k,(i)}}$, where $g_{k,(i)} = \nabla f_{\gamma_k}(x_k)_{(i)}$ and the bars denote exponential moving averages, e.g. $\bar{g}_{k+1,(i)} = \beta_{k,(i)}\bar{g}_{k,(i)} + (1 - \beta_{k,(i)})g_{k,(i)}$.
34. Continuous-time SGD & Controlled SGD
   • Controlled SGD algorithm: solve the optimal control policy $u^*_{k,(i)}$ for $\frac{1}{2}a_{k,(i)}(x_{k,(i)} - b_{k,(i)})^2$: $u^*_{k,(i)} = 1$ if $a_{k,(i)} \le 0$, and $u^*_{k,(i)} = \min\big(1, \frac{a_{k,(i)}(\bar{x}_{k,(i)} - b_{k,(i)})^2}{\eta\Sigma_{k,(i)}}\big)$ if $a_{k,(i)} > 0$, where $\Sigma_{k,(i)} = \overline{g^2}_{k,(i)} - \bar{g}_{k,(i)}^2$.
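Slides 33-34 together suggest a per-coordinate cSGD loop: maintain exponential moving averages, regress $a$ and $b$, and scale each step by $u^*$. The sketch below follows that recipe under assumptions of mine (a fixed decay rate beta, arbitrary guards against near-zero variances); the reference implementation is the repository linked on slide 36:

```python
import numpy as np

class ControlledSGD1D:
    """One-coordinate sketch of controlled SGD: EMA statistics -> linear
    regression for (a, b) -> optimal factor u* -> scaled SGD step."""

    def __init__(self, eta, beta=0.9):
        self.eta, self.beta = eta, beta
        # EMAs of x, g, x^2, g*x, g^2
        self.mx = self.mg = self.mxx = self.mgx = self.mgg = 0.0

    def step(self, x, g):
        m = self.beta
        self.mx = m * self.mx + (1 - m) * x
        self.mg = m * self.mg + (1 - m) * g
        self.mxx = m * self.mxx + (1 - m) * x * x
        self.mgx = m * self.mgx + (1 - m) * g * x
        self.mgg = m * self.mgg + (1 - m) * g * g

        # Regress grad f ~ a * (x - b) from the running statistics.
        var_x = self.mxx - self.mx ** 2
        a = (self.mgx - self.mg * self.mx) / var_x if var_x > 1e-12 else 0.0
        sigma = self.mgg - self.mg ** 2          # gradient-noise variance
        if a <= 0.0 or sigma <= 1e-12:
            u = 1.0                              # descent phase: full step
        else:
            b = self.mx - self.mg / a
            u = min(1.0, a * (self.mx - b) ** 2 / (self.eta * sigma))
        return x - self.eta * u * g

# Noisy gradients of f(x) = (x - 1)^2: g = 2 * (x - 1) + noise.
rng = np.random.default_rng(0)
opt = ControlledSGD1D(eta=0.1)
x = 5.0
for _ in range(2000):
    x = opt.step(x, 2.0 * (x - 1.0) + rng.normal(0.0, 0.5))
# x ends near the minimizer b = 1, with the step annealed automatically
```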
35. Experimental results (figures): MNIST with a fully connected NN, CIFAR-10 with a fully connected NN, and CIFAR-10 with a CNN.
36. Continuous-time SGD & Controlled SGD
   • Implementation of cSGD: https://github.com/LiQianxiao/cSGD-cMSGD
37. Outlines (repeated; next section: Effects of SGD on Generalization)
38. Effects of SGD on Generalization (ICANN 2018)
39. Effects of SGD on Generalization
   • Notation conventions:
   • Loss function: $L(\theta) = \frac{1}{N}\sum_{n=1}^N l(\theta, x_n)$, where $N$ is the size of the dataset
   • GD: $\theta_{k+1} = \theta_k - \eta g(\theta_k)$, where $g(\theta) = \frac{\partial L}{\partial \theta}$
   • SGD: $\theta_{k+1} = \theta_k - \eta g^{(S)}(\theta_k)$, where $g^{(S)}(\theta) = \frac{1}{S}\sum_{n \in B} \frac{\partial}{\partial\theta} l(\theta, x_n)$, $B$ is the batch and $S$ is the batch size
40. Effects of SGD on Generalization
   • Continuous-time SGD: $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}R(\theta)\,dW(t)$, where $R(\theta)R(\theta)^T = C(\theta)$ and $\frac{C(\theta)}{S}$ is the covariance of $g^{(S)}(\theta) - g(\theta)$.
41. Effects of SGD on Generalization
   • Effects of different learning rates and batch sizes (figures): with a small $\eta$ or a large $S$, the trajectory stays close to the minimum; with a large $\eta$ or a small $S$, it fluctuates more widely, since the noise term in $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}R(\theta)\,dW(t)$ scales with $\sqrt{\eta/S}$.
42. Effects of SGD on Generalization
   • Flat minimum vs. sharp minimum (https://arxiv.org/abs/1609.04836): when the loss function is evaluated on testing data, the surface shifts; a flat minimum keeps a low testing error, while a sharp minimum gets a high testing error.
43. Effects of SGD on Generalization
   • Effects of learning rate and batch size on generalization: in $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}R(\theta)\,dW(t)$, a small $\sqrt{\frac{\eta}{S}}$ (small $\eta$, large $S$) lets SGD settle in either a flat or a sharp minimum, while a large $\sqrt{\frac{\eta}{S}}$ (large $\eta$, small $S$) pushes SGD out of sharp minima and toward flat ones.
44. Effects of SGD on Generalization
   • Intuition: a larger learning rate or a smaller batch size -> a flatter minimum -> less over-fitting.
45. Effects of SGD on Generalization
   • Theoretical explanation: change variables in $d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}R(\theta)\,dW(t)$ to $z = V^T(\theta - \theta^*)$, where $\theta^*$ is the parameters at the minimum, $H$ is the Hessian of $L(\theta)$, and $V$ is the orthogonal matrix of the eigendecomposition $H = V\Lambda V^T$. Near the minimum this gives $dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\sqrt{\Lambda}\,dW(t)$.
46. Effects of SGD on Generalization
   • Theoretical explanation: $dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\sqrt{\Lambda}\,dW(t)$ is an Ornstein-Uhlenbeck process for $z$, with solution $E[z] = 0$ and $\mathrm{cov}[z] = \frac{\eta}{2S}I$.
   • Expectation of the loss function: $E(L) = \frac{1}{2}\sum_{i=1}^q \lambda_i E(z_i^2) = \frac{\eta}{4S}\mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\mathrm{Tr}(H)$
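The Ornstein-Uhlenbeck claim can be checked by simulating $dz = -\Lambda z\,dt + \sqrt{\eta/S}\sqrt{\Lambda}\,dW$ coordinate-wise: every coordinate should reach variance $\frac{\eta}{2S}$ regardless of its eigenvalue $\lambda_i$. The eigenvalues, $\eta$, $S$, and discretization below are assumptions:

```python
import numpy as np

eta, S = 0.1, 32
lam = np.array([0.5, 2.0, 10.0])     # assumed Hessian eigenvalues
rng = np.random.default_rng(0)
dt, n_steps, n_paths = 2e-3, 10_000, 5_000

z = np.zeros((n_paths, lam.size))
for _ in range(n_steps):
    dW = rng.normal(0.0, np.sqrt(dt), size=z.shape)
    z += -lam * z * dt + np.sqrt(eta / S) * np.sqrt(lam) * dW

# Each coordinate's variance approaches eta / (2S), independent of lambda_i,
# so E[L] = 0.5 * sum_i lam_i * E[z_i^2] = (eta / (4S)) * Tr(H).
print(z.var(axis=0))               # each close to eta / (2S) = 0.0015625
print(eta / (4 * S) * lam.sum())   # predicted E[L]
```

That the stationary variance does not depend on $\lambda_i$ is the key step: it makes $E(L)$ proportional to $\mathrm{Tr}(H)$, which the next slide uses.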
47. Effects of SGD on Generalization
   • Theoretical explanation: $E(L) = \frac{1}{2}\sum_{i=1}^q \lambda_i E(z_i^2) = \frac{\eta}{4S}\mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\mathrm{Tr}(H)$ implies $\frac{E(L)}{\mathrm{Tr}(H)} \propto \frac{\eta}{S}$. Among minima with similar loss values, a low $\frac{\eta}{S}$ allows a sharp minimum (low $\frac{E(L)}{\mathrm{Tr}(H)}$), while a high $\frac{\eta}{S}$ forces a flat minimum (high $\frac{E(L)}{\mathrm{Tr}(H)}$).
48. Experimental results (figure): increasing LR/BS gives increasing accuracy; similar LR/BS gives similar accuracy.
49. Experimental results (figure, continued): increasing LR/BS gives increasing accuracy; similar LR/BS gives similar accuracy.
50. Tips for Tuning Batch Size and Learning Rate
   • The learning rate can be decayed as the epoch count increases.
   • The learning rate should not be initialized to a small value.
   • To keep the validation accuracy, LR/BS should remain constant when changing the batch size.
   • To achieve higher validation accuracy, increase the learning rate or reduce the batch size.
51. Further Readings
   • An Introduction to Stochastic Differential Equations: http://ft-sipil.unila.ac.id/dbooks/AN%20INTRODUCTION%20TO%20STOCHASTIC%20DIFFERENTIAL%20EQUATIONS%20VERSION%201.2.pdf
   • Stochastic Modified Equations and Adaptive Stochastic Gradient Algorithms: https://arxiv.org/abs/1511.06251
   • Three Factors Influencing Minima in SGD: https://arxiv.org/abs/1710.11029