4. Stochastic Gradient Descent (SGD)
• Gradient Descent vs. Stochastic Gradient Descent
(Figure: trajectories toward the minimum for gradient descent (batch size = dataset size), SGD with a small batch size, and SGD with a large batch size.)
$w$ : weights
$\eta$ : learning rate
$B$ : batch
$L(x, w)$ : loss function

$$w \leftarrow w - \eta\, \frac{1}{|B|} \sum_{x \in B} \frac{\partial L(x, w)}{\partial w}$$
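A minimal NumPy sketch of this update rule, on an assumed least-squares loss (the dataset, loss, and hyper-parameters below are illustrative choices, not from the slides):

```python
import numpy as np

# Minibatch SGD on the illustrative loss L(x, w) = 0.5 * (w . features - target)^2.
rng = np.random.default_rng(0)
X_data = rng.normal(size=(1000, 5))          # 1000 samples, 5 features (assumed data)
y_data = X_data @ np.array([1., -2., 0.5, 3., -1.]) + 0.1 * rng.normal(size=1000)

w = np.zeros(5)        # weights w
eta = 0.1              # learning rate eta
batch_size = 32        # |B|

for step in range(500):
    idx = rng.choice(len(X_data), size=batch_size, replace=False)  # sample batch B
    Xb, yb = X_data[idx], y_data[idx]
    # (1/|B|) * sum over x in B of dL(x, w)/dw; for least squares: (w.x - y) x
    grad = Xb.T @ (Xb @ w - yb) / batch_size
    w = w - eta * grad  # w <- w - eta * averaged minibatch gradient
```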
5. Stochastic Gradient Descent (SGD)
• Convergence of SGD
• Assume that the loss function is convex:

$$\mathbb{E}\left[L(x, \bar{w}) - L(x, w^*)\right] \le O\!\left(\frac{1}{\sqrt{T}}\right)$$

(Figure: an SGD trajectory ending at $\bar{w}$ near the minimum $w^*$; the distance between them is guaranteed to be small.)

$T$ : step count
$\bar{w}$ : $w$ after $T$ steps
$w^*$ : $w$ at the minimum of $L$
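This rate can be checked empirically. Below is a toy illustration (not from the slides): averaged SGD on the convex loss $L(x, w) = \frac{1}{2}(w - x)^2$ with $x \sim \mathcal{N}(0, 1)$, using an assumed step size $\eta_t = 1/\sqrt{t}$; the excess loss of the averaged iterate shrinks roughly like $1/\sqrt{T}$.

```python
import numpy as np

# Toy check of the O(1/sqrt(T)) rate. Here the minimizer is w* = 0 and the
# excess expected loss of the averaged iterate w_bar is 0.5 * w_bar^2.
rng = np.random.default_rng(1)

def averaged_sgd(T):
    w, w_sum = 5.0, 0.0
    for t in range(1, T + 1):
        x = rng.normal()
        w -= (1.0 / np.sqrt(t)) * (w - x)   # eta_t = 1/sqrt(t); gradient is (w - x)
        w_sum += w
    return w_sum / T                         # w_bar: average of the iterates

for T in [100, 1000, 10000]:
    excess = np.mean([0.5 * averaged_sgd(T) ** 2 for _ in range(100)])
    print(T, excess, 1 / np.sqrt(T))         # excess loss shrinks with 1/sqrt(T)
```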
6. Stochastic Gradient Descent (SGD)
• Dynamics of SGD
(Figure: an SGD trajectory from the starting point to the minimum; the dynamics of SGD describe the process between the starting point and the final solution.)
8. Random Walk
(Figure: a particle starting at position $0$ at $t = 0$; after one time step $\Delta t$ it moves to $-\Delta x$ or $+\Delta x$, each with probability $1/2$.)

The position of the particle at time $\Delta t$ is a random variable $X_{\Delta t}$ such that

$$P(X_{\Delta t} = \Delta x) = \frac{1}{2}, \qquad P(X_{\Delta t} = -\Delta x) = \frac{1}{2}$$
9. Random Walk
$t = \Delta t$ : $X_{\Delta t} = -\Delta x$ or $\Delta x$, with probabilities $\frac{1}{2}, \frac{1}{2}$
$t = 2\Delta t$ : $X_{2\Delta t} = -2\Delta x$, $0$, or $2\Delta x$, with probabilities $\frac{1}{4}, \frac{1}{2}, \frac{1}{4}$
$t = 3\Delta t$ : $X_{3\Delta t} = -3\Delta x$, $-\Delta x$, $\Delta x$, or $3\Delta x$, with probabilities $\frac{1}{8}, \frac{3}{8}, \frac{3}{8}, \frac{1}{8}$
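A quick Monte Carlo check of these probabilities (the step count and sample size are arbitrary choices):

```python
import numpy as np

# After 3 steps of size dx, positions -3dx, -dx, +dx, +3dx should occur with
# probabilities 1/8, 3/8, 3/8, 1/8.
rng = np.random.default_rng(0)
dx, n_walks = 1.0, 100_000
steps = rng.choice([-dx, dx], size=(n_walks, 3))   # +-dx with prob 1/2 each
positions = steps.sum(axis=1)                      # X_{3 dt} for each walk
values, counts = np.unique(positions, return_counts=True)
print(dict(zip(values, counts / n_walks)))         # ~ {-3: .125, -1: .375, 1: .375, 3: .125}
```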
11. Stochastic Differential Equation (SDE)
• Ordinary Differential Equation

$$\begin{cases} \dfrac{dx(t)}{dt} = b(x(t)), & \text{where } t > 0 \\ x(0) = x_0 \end{cases}$$

(Figure: a single deterministic trajectory of $x$ starting from $x_0$.)
12. Stochastic Differential Equation (SDE)
• Stochastic Differential Equation

$$\begin{cases} \dfrac{dx(t)}{dt} = \underbrace{b(x(t))}_{\text{deterministic part}} + \underbrace{B(x(t))\dfrac{dW(t)}{dt}}_{\text{stochastic part}}, & \text{where } t > 0 \text{ and } W(t) \text{ is a Wiener process} \\ x(0) = x_0 \end{cases}$$

(Figure: several trajectory samples of $x$ starting from $x_0$.)

The stochastic part is a random walk with infinitely small step.
13. Stochastic Differential Equation (SDE)
• Solving a Stochastic Differential Equation

$$\begin{cases} \dfrac{dx(t)}{dt} = b(x(t)) + B(x(t))\dfrac{dW(t)}{dt} \\ x(0) = x_0 \end{cases}$$

Multiply both sides by $dt$:

$$\begin{cases} dx(t) = b(x(t))\,dt + B(x(t))\,dW(t) \\ x(0) = x_0 \end{cases}$$

Integrate both sides from $0$ to $t$:

$$x(t) = x_0 + \int_0^t b(x(s))\,ds + \underbrace{\int_0^t B(x(s))\,dW(s)}_{\text{stochastic integral}}$$
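Numerically, this integral form is usually approximated with the Euler-Maruyama scheme; here is a minimal sketch, with an assumed drift and diffusion for illustration:

```python
import numpy as np

# Euler-Maruyama for dx = b(x) dt + B(x) dW: the stochastic integral is built
# up step by step from Gaussian increments dW ~ N(0, dt).
rng = np.random.default_rng(0)

def euler_maruyama(b, B, x0, T=1.0, n_steps=1000):
    dt = T / n_steps
    x = x0
    path = [x]
    for _ in range(n_steps):
        dW = rng.normal(scale=np.sqrt(dt))     # Wiener increment
        x = x + b(x) * dt + B(x) * dW          # dx = b dt + B dW
        path.append(x)
    return np.array(path)

# Example (assumed): Ornstein-Uhlenbeck process dx = -x dt + 0.5 dW, x(0) = 1
path = euler_maruyama(b=lambda x: -x, B=lambda x: 0.5, x0=1.0)
```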
14. Stochastic Differential Equation (SDE)
• Solution of a Stochastic Integral is a Random Variable

If $g : [0, 1] \to \mathbb{R}$ is a deterministic function:

$$\underbrace{\mathbb{E}\left[\int_0^1 g\,dW\right] = 0}_{\text{mean}}, \quad \text{and} \quad \underbrace{\mathbb{E}\left[\left(\int_0^1 g\,dW\right)^{\!2}\right] = \int_0^1 g^2\,dt}_{\text{variance}}$$

If $G$ is a stochastic process such that $\mathbb{E}\left[\int_0^T G^2\,dt\right] < \infty$:

$$\underbrace{\mathbb{E}\left[\int_0^T G\,dW\right] = 0}_{\text{mean}}, \quad \text{and} \quad \underbrace{\mathbb{E}\left[\left(\int_0^T G\,dW\right)^{\!2}\right] = \mathbb{E}\left[\int_0^T G^2\,dt\right]}_{\text{variance}}$$
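Both identities can be sanity-checked by Monte Carlo; the sketch below uses the assumed deterministic integrand $g(t) = t$, for which the variance should be $\int_0^1 t^2\,dt = 1/3$:

```python
import numpy as np

# Monte Carlo check of the mean/variance identities for g(t) = t on [0, 1].
rng = np.random.default_rng(0)
n_paths, n_steps = 20_000, 500
dt = 1.0 / n_steps
t = np.arange(n_steps) * dt                       # left endpoints (Ito convention)
dW = rng.normal(scale=np.sqrt(dt), size=(n_paths, n_steps))
integrals = (t * dW).sum(axis=1)                  # int_0^1 g dW, one value per path
print(integrals.mean(), integrals.var(), 1 / 3)   # ~0, ~0.333, 0.333...
```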
17. Continuous-time SGD & Controlled SGD
• Notation Conventions:

Gradient Descent : $x_{k+1} = x_k - \eta \nabla f(x_k)$
Stochastic Gradient Descent : $x_{k+1} = x_k - \eta \nabla f_{\gamma_k}(x_k)$

$f$ : loss function
$x_k$ : weights at step $k$
$\gamma_k$ : index of the training sample at step $k$ (assume the batch size is 1)
$f_i$ : loss function calculated on batch $i$, where $f(x) = \frac{1}{n} \sum_{i=1}^{n} f_i(x)$
18. Continuous-time SGD & Controlled SGD
$$x_{k+1} - x_k = -\eta \nabla f_{\gamma_k}(x_k)$$

$$x_{k+1} - x_k = \underbrace{-\eta \nabla f(x_k)}_{\text{deterministic part}} + \underbrace{\sqrt{\eta}\, V_k}_{\text{stochastic part}}$$

(Figure: near the minimum, each update decomposes into a deterministic part pointing toward the minimum and a stochastic part.)

$$V_k = \sqrt{\eta}\left(\nabla f(x_k) - \nabla f_{\gamma_k}(x_k)\right)$$

mean of $V_k$ : $0$
covariance of $V_k$ : $\eta \Sigma(x_k)$, where $\Sigma(x) = \frac{1}{n} \sum_{i=1}^{n} \left(\nabla f(x) - \nabla f_i(x)\right)\left(\nabla f(x) - \nabla f_i(x)\right)^{T}$
19. Continuous-time SGD & Controlled SGD
• Continuous-time SGD

$$x_{k+1} - x_k = -\eta \nabla f(x_k) + \sqrt{\eta}\, V_k$$

Convert to the continuous time domain:

$$dX_t = -\nabla f(X_t)\,dt + \sqrt{\eta \Sigma(X_t)}\,dW_t \qquad \text{(order 1 weak approximation)}$$

$$dX_t = -\nabla\!\left(f(X_t) + \frac{\eta}{4}\,|\nabla f(X_t)|^2\right)dt + \sqrt{\eta \Sigma(X_t)}\,dW_t \qquad \text{(order 2 weak approximation)}$$

These are continuous-time SGD, as order 1 and order 2 weak approximations.
21. Continuous-time SGD & Controlled SGD
$$X_t \sim \mathcal{N}\!\left(x_0\, e^{-2(1+\eta)t},\ \frac{\eta}{1+\eta}\left(1 - e^{-4(1+\eta)t}\right)\right)$$

(Figure: sample paths of $X_t$ starting from $x_0$: a descent phase in which the mean decays toward $0$, followed by a fluctuations phase with standard deviation $\sqrt{\eta/(1+\eta)}$; the transition time $t^*$ satisfies $\mathbb{E}[X_{t^*}] = \sqrt{\mathrm{Var}[X_{t^*}]}$.)

$$\mathbb{E}[X_t] = \begin{cases} x_0, & \text{when } t = 0 \\ 0, & \text{when } t \to \infty \end{cases} \qquad \mathrm{Var}[X_t] = \begin{cases} 0, & \text{when } t = 0 \\ \dfrac{\eta}{1+\eta}, & \text{when } t \to \infty \end{cases}$$
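The two phases can be reproduced by simulating an SDE consistent with the stated distribution; the sketch below assumes $dX = -2(1+\eta)X\,dt + 2\sqrt{\eta}\,dW$, whose solution has exactly the mean decay and stationary variance given above (step size and horizon are arbitrary choices):

```python
import numpy as np

# Euler-Maruyama simulation of the assumed SDE dX = -2(1+eta) X dt + 2 sqrt(eta) dW.
rng = np.random.default_rng(0)
eta, x0, dt, n_steps, n_paths = 0.1, 2.0, 1e-3, 5000, 2000
X = np.full(n_paths, x0)
for _ in range(n_steps):
    X += -2 * (1 + eta) * X * dt \
         + 2 * np.sqrt(eta) * rng.normal(scale=np.sqrt(dt), size=n_paths)
# By t = 5 the mean has decayed to ~0 (descent phase over) and the paths
# fluctuate with standard deviation ~ sqrt(eta / (1 + eta)).
print(X.mean(), X.std(), np.sqrt(eta / (1 + eta)))
```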
22. Continuous-time SGD & Controlled SGD
• Controlled SGD : Adaptive Hyper-parameter Adjustment

$$x_{k+1} = x_k - \eta\, u_k f'(x_k), \quad \text{where } u_k \in [0, 1] \text{ is the adjustment factor}$$

Optimal control formulation:

$$\min_{u_t}\ \mathbb{E} f(X_t) \quad \text{subject to:} \quad dX_t = -u_t f'(X_t)\,dt + u_t \sqrt{\eta \Sigma(X_t)}\,dW_t$$
23. Continuous-time SGD & Controlled SGD
• Quadratic Objective Function

$$f(x) = \frac{1}{2}\, a (x - b)^2, \quad \text{assume the covariance of } f'_i \text{ is } \Sigma(x)$$

• Continuous-time SGD:

$$dX_t = -a\, u_t (X_t - b)\,dt + u_t \sqrt{\eta \Sigma}\,dW_t$$

• Optimal control policy:

$$u^*_t = \begin{cases} 1 & \text{if } a \le 0 \text{ or } t \le t^* \\ \dfrac{1}{1 + a(t - t^*)} & \text{if } a > 0 \text{ and } t > t^* \end{cases}$$
24. Continuous-time SGD & Controlled SGD
• Optimal control policy

$$u^*_t = \begin{cases} 1 & \text{if } a \le 0 \text{ or } t \le t^* \quad (t \le t^* \text{ is the descent phase}) \\ \dfrac{1}{1 + a(t - t^*)} & \text{if } a > 0 \text{ and } t > t^* \quad (t > t^* \text{ is the fluctuations phase}) \end{cases}$$

$$f(x) = \frac{1}{2}\, a (x - b)^2, \quad \text{assume the covariance of } f'_i \text{ is } \Sigma(x)$$

(Figure: $x$ versus $t$, split at $t^*$ into the descent phase ($a \le 0$) and the fluctuations phase ($a > 0$).)
25. Continuous-time SGD & Controlled SGD
• General Objective Function

$f(x)$ and $f_i(x)$ are not necessarily quadratic, and $x \in \mathbb{R}^d$.

Assume

$$f(x) \approx \frac{1}{2} \sum_{i=1}^{d} a^{(i)} \left(x^{(i)} - b^{(i)}\right)^2$$

holds locally in $x$, and $\Sigma \approx \mathrm{diag}\{\Sigma^{(1)}, \ldots, \Sigma^{(d)}\}$ where each $\Sigma^{(i)}$ is locally constant (each dimension is independent).
26. Continuous-time SGD & Controlled SGD
• Controlled SGD Algorithms

At each step $k$, estimate $a_{k,(i)}, b_{k,(i)}$ for the local quadratic model $\frac{1}{2} a_{k,(i)} \left(x_{k,(i)} - b_{k,(i)}\right)^2$.

Since $\nabla f^{(i)} \approx a^{(i)} \left(x^{(i)} - b^{(i)}\right)$, we use linear regression to estimate $a_{k,(i)}, b_{k,(i)}$:

$$a_{k,(i)} = \frac{\overline{gx}_{k,(i)} - \bar{g}_{k,(i)}\, \bar{x}_{k,(i)}}{\overline{x^2}_{k,(i)} - \bar{x}_{k,(i)}^2}, \quad \text{and} \quad b_{k,(i)} = \bar{x}_{k,(i)} - \frac{\bar{g}_{k,(i)}}{a_{k,(i)}}$$

where $g_{k,(i)} = \nabla f_{\gamma_k}(x_k)^{(i)}$, and the barred quantities are exponential moving averages, e.g. $\bar{g}_{k+1,(i)} = \beta_{k,(i)}\, \bar{g}_{k,(i)} + \left(1 - \beta_{k,(i)}\right) g_{k,(i)}$.
27. Continuous-time SGD & Controlled SGD
• Controlled SGD Algorithms

Solve the optimal control policy $u^*_{k,(i)}$ for $\frac{1}{2} a_{k,(i)} \left(x_{k,(i)} - b_{k,(i)}\right)^2$:

$$u^*_{k,(i)} = \begin{cases} 1 & \text{if } a \le 0, \\ \min\!\left(1,\ \dfrac{a_{k,(i)} \left(\bar{x}_{k,(i)} - b_{k,(i)}\right)^2}{\eta \Sigma_{k,(i)}}\right) & \text{if } a > 0, \end{cases}$$

where $\Sigma_{k,(i)} = \overline{g^2}_{k,(i)} - \bar{g}_{k,(i)}^2$.
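Putting slides 26 and 27 together, here is a per-dimension sketch of the algorithm on an assumed toy quadratic problem (the EMA decay $\beta$, the noise level, and the numerical guards are illustrative choices, not from the slides):

```python
import numpy as np

# Controlled SGD sketch: EMAs of g, x, g*x, x^2, g^2 feed the regression
# estimates of a, b and the adjustment factor u*.
rng = np.random.default_rng(0)
d, eta, beta = 2, 0.05, 0.9
a_true, b_true = np.array([4.0, 1.0]), np.array([1.0, -2.0])  # assumed problem

ema = lambda avg, new: beta * avg + (1 - beta) * new
x = np.zeros(d)
g_bar = x_bar = gx_bar = x2_bar = g2_bar = np.zeros(d)

for k in range(2000):
    # Noisy gradient of f(x) = 0.5 * sum_i a_i (x_i - b_i)^2
    g = a_true * (x - b_true) + rng.normal(size=d)
    g_bar, x_bar = ema(g_bar, g), ema(x_bar, x)
    gx_bar, x2_bar, g2_bar = ema(gx_bar, g * x), ema(x2_bar, x * x), ema(g2_bar, g * g)
    # Linear regression estimates for the model g ~ a (x - b)
    a_hat = (gx_bar - g_bar * x_bar) / np.maximum(x2_bar - x_bar**2, 1e-8)
    b_hat = x_bar - g_bar / np.where(np.abs(a_hat) < 1e-8, 1e-8, a_hat)
    sigma = np.maximum(g2_bar - g_bar**2, 1e-8)   # Sigma_k = mean(g^2) - mean(g)^2
    # Optimal control policy: u = 1 while descending, shrinks in fluctuations
    u = np.where(a_hat <= 0, 1.0,
                 np.minimum(1.0, a_hat * (x_bar - b_hat)**2 / (eta * sigma)))
    x = x - eta * u * g                            # controlled SGD update

print(x, b_true)   # x should end near the minimum b_true
```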
30. Outlines
• Stochastic Gradient Descent (SGD)
• Random Walk, Diffusion and Wiener process
• Stochastic Differential Equation (SDE)
• Effects of SGD on Generalization
32. Effects of SGD on Generalization
• Notation Conventions:
Loss function : $L(\theta) = \frac{1}{N} \sum_{n=1}^{N} l(\theta, x_n)$, where $N$ is the size of the dataset

GD : $\theta_{k+1} = \theta_k - \eta\, g(\theta_k)$, where $g(\theta) = \frac{\partial L}{\partial \theta}$

SGD : $\theta_{k+1} = \theta_k - \eta\, g^{(S)}(\theta_k)$, where $g^{(S)}(\theta) = \frac{1}{S} \sum_{n \in B} \frac{\partial}{\partial \theta} l(\theta, x_n)$, $B$ is the batch and $S$ is the batch size
33. Effects of SGD on Generalization
• Continuous-time SGD

$$d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\, R(\theta)\,dW(t),$$

where $R(\theta) R(\theta)^T = C(\theta)$ and $\frac{C(\theta)}{S}$ is the covariance of $\left(g^{(S)}(\theta) - g(\theta)\right)$.
34. Effects of SGD on Generalization
• Effects of different learning rates and batch sizes

(Figure: trajectories around the minimum for small $\eta$ versus large $\eta$, and for small $S$ versus large $S$.)

$$d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\, R(\theta)\,dW(t),$$

where $R(\theta) R(\theta)^T = C(\theta)$ and $C(\theta)$ is the covariance of $g(\theta)$.
35. Effects of SGD on Generalization
• Flat minimum vs. sharp minimum (https://arxiv.org/abs/1609.04836)

(Figure: the training loss function and the loss function evaluated on testing data. At a flat minimum the two curves nearly agree, giving low testing error; at a sharp minimum the testing-data curve shifts away from the training curve, giving high testing error.)
36. Effects of SGD on Generalization
• Effects of learning rate / batch size on generalization

$$d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\, R(\theta)\,dW(t),$$

where $R(\theta) R(\theta)^T = C(\theta)$ and $C(\theta)$ is the covariance of $g(\theta)$.

(Figure: with small $\sqrt{\eta/S}$ (small $\eta$, large $S$), SGD can settle in either a flat minimum or a sharp minimum; with large $\sqrt{\eta/S}$ (large $\eta$, small $S$), the noise pushes SGD out of sharp minima and toward flat minima.)
37. Effects of SGD on Generalization
• Theoretical Explanation
• Assumption 1 : the loss surface can be approximated by a quadratic bowl, with
minimum at zero loss.
• Assumption 2 : The covariance of the gradients and the Hessian of the loss
approximation are approximately equal, i.e. C = H.
38. Effects of SGD on Generalization
• Theoretical Explanation
$$d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\, R(\theta)\,dW(t),$$

where $R(\theta) R(\theta)^T = C(\theta)$ and $C(\theta)$ is the covariance of $g(\theta)$.

Change of variables:

$$dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\, \sqrt{\Lambda}\,dW(t)$$

$z$ : new variable, where $z = V^T (\theta - \theta^*)$
$\theta^*$ : the parameters at the minimum
$V$ : orthogonal matrix of the eigendecomposition $H = V \Lambda V^T$
$H$ : the Hessian of $L(\theta)$
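To see where the $z$ equation comes from, here is the substitution spelled out, a sketch relying on the slides' assumptions $C = H$ and a locally quadratic $L$ with minimum $\theta^*$:

$$
\begin{aligned}
g(\theta) &\approx H(\theta - \theta^*) = V \Lambda V^T (\theta - \theta^*) = V \Lambda z, \\
dz = V^T\, d\theta &= -V^T g(\theta)\,dt + \sqrt{\tfrac{\eta}{S}}\, V^T R(\theta)\,dW(t)
= -\Lambda z\,dt + \sqrt{\tfrac{\eta}{S}}\, \sqrt{\Lambda}\,dW(t),
\end{aligned}
$$

where the last step uses $R R^T = C = H$, so $\left(V^T R\right)\left(V^T R\right)^T = V^T H V = \Lambda$: the noise term $V^T R\,dW$ has the same law as $\sqrt{\Lambda}\,dW$, since the law of $dW$ is rotation invariant.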
39. Effects of SGD on Generalization
• Theoretical Explanation
• Expectation of the loss function

$$dz = -\Lambda z\,dt + \sqrt{\frac{\eta}{S}}\, \sqrt{\Lambda}\,dW(t)$$

This is an Ornstein-Uhlenbeck process for $z$, with solution $\mathbb{E}[z] = 0$ and $\mathrm{cov}[z] = \frac{\eta}{2S} I$, so

$$\mathbb{E}(L) = \frac{1}{2} \sum_{i=1}^{q} \lambda_i\, \mathbb{E}(z_i^2) = \frac{\eta}{4S}\, \mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\, \mathrm{Tr}(H)$$
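A numerical sanity check of the stationary covariance for a single eigendirection (parameter values are arbitrary choices): simulating $dz = -\lambda z\,dt + \sqrt{\eta/S}\,\sqrt{\lambda}\,dW$ should give variance $\eta/2S$ regardless of $\lambda$.

```python
import numpy as np

# One eigendirection of the OU process above; stationary variance = eta/(2S).
rng = np.random.default_rng(0)
lam, eta, S, dt = 3.0, 0.1, 32, 1e-3
z = np.zeros(10_000)
for _ in range(5_000):                       # run to (approximate) stationarity
    dW = rng.normal(scale=np.sqrt(dt), size=z.shape)
    z += -lam * z * dt + np.sqrt(eta / S) * np.sqrt(lam) * dW
print(z.var(), eta / (2 * S))                # both ~0.0016, independent of lam
```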
40. Effects of SGD on Generalization
• Theoretical Explanation

$$\mathbb{E}(L) = \frac{1}{2} \sum_{i=1}^{q} \lambda_i\, \mathbb{E}(z_i^2) = \frac{\eta}{4S}\, \mathrm{Tr}(\Lambda) = \frac{\eta}{4S}\, \mathrm{Tr}(H) \implies \frac{\mathbb{E}(L)}{\mathrm{Tr}(H)} \propto \frac{\eta}{S}$$

For minima with similar loss values:

low $\frac{\eta}{S}$ : SGD can end in a sharp minimum, which has low $\frac{\mathbb{E}(L)}{\mathrm{Tr}(H)}$
high $\frac{\eta}{S}$ : SGD ends in a flat minimum, which has high $\frac{\mathbb{E}(L)}{\mathrm{Tr}(H)}$
42. Tips for Tuning Batch Size and Learning Rate
• The learning rate can be decayed as the epoch count increases.
• The learning rate should not be initialized to a small value.
• To keep the validation accuracy when changing the batch size, the ratio LR/BS (learning rate to batch size) should remain constant.
• To achieve higher validation accuracy, increase the learning rate or reduce the batch size.
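For the constant LR/BS rule, a minimal sketch (the names and base values are illustrative assumptions):

```python
# Keep eta / S constant: when the batch size is scaled, scale the learning
# rate by the same factor.
base_lr, base_batch_size = 0.1, 256

def scaled_lr(batch_size):
    # eta_new = eta_base * (S_new / S_base), so eta/S is unchanged
    return base_lr * batch_size / base_batch_size

print(scaled_lr(512))   # 0.2: doubling the batch size doubles the learning rate
```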
43. Further Readings
• An Introduction to Stochastic Differential Equations
• http://ft-sipil.unila.ac.id/dbooks/AN%20INTRODUCTION%20TO%20STOCHASTIC%20DIFFERENTIAL%20EQUATIONS%20VERSION%201.2.pdf
• Stochastic Modified Equations and Adaptive Stochastic Gradient Algorithms
• https://arxiv.org/abs/1511.06251
• Three Factors Influencing Minima in SGD
• https://arxiv.org/abs/1710.11029