梯度下降法解决线性模型

1,梯度下降法

给定一个数据集,x_data、y_data。寻找y=wx模型的w最优解。 代码如下:

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

scope_list = []
w_list = []

w = 60

#学习率
k = 0.01
# 

for i in range(200):
    # 计算cost(loss的和)
    loss_sum = 0
    for x_val, y_val in zip(x_data,y_data):
        loss_sum += 2 * x_val * (w*x_val - y_val)
    cost = loss_sum / 3
    
    # 计算本轮w
    w = w - k * cost
    print(w)
    scope_list.append(i)
    w_list.append(w)

plt.plot(scope_list,w_list)
plt.xlabel("scope")
plt.ylabel("W")
plt.show()

2,batch(随机梯度下降法)

代码如下:


import random
import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

scope_list = []
w_list = []

w = 60

#学习率
k = 0.01

for i in range(200):
    # 计算cost(即随机一个loss当cost用)
    rand = random.randint(0,2)
    cost = 2 * x_data[rand] * (w*x_data[rand] - y_data[rand])
    
    # 计算本轮w
    w = w - k * cost
    print(w)
    scope_list.append(i)
    w_list.append(w)

plt.plot(scope_list,w_list)
plt.xlabel("scope")
plt.ylabel("W")
plt.show()

全部评论

相关推荐

06-26 19:47
中南大学 Java
悲,毕业了!这是个坏事儿啊!
爱睡觉的冰箱哥:《这是个好事啊》---峰哥浪走天涯
毕业后不工作的日子里我在...
点赞 评论 收藏
分享
每晚夜里独自颤抖:你cet6就cet6,cet4就cet4,你写个cet证书等是什么意思。专业技能快赶上项目行数,你做的这2个项目哪里能提现你有这么多技能呢
点赞 评论 收藏
分享
06-25 21:00
门头沟学院 Java
多拆解背记一下当前的高频场景面试题,结合自己的项目经历去作答,面试通过率原来真的不会低!
牛客96559368...:小公司不就是这样的吗,面试要么是点击就送,要么就是往死里拷打,没有一个统一的标准。这个不能代表所有公司
点赞 评论 收藏
分享
评论
2
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务