1、首先对数据进行读取和预处理
2、读取数据后,对x数据进行标准化处理,以便于后续训练的稳定性,并转换为tensor格式
3、接下来设置训练参数和模型
这里采用回归模型,既y=x*weight1+bias1,设置的学习率为0.0006,损失函数采用了MSE(均方误差)
4、绘制图像
由于数据量较少,所以将整个训练集作为测试集,观察生成的图像
完整代码
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim
import warnings
warnings.filterwarnings("ignore")
# In[4]:
features = pd.read_csv('房价预测.csv')
features
# In[26]:
year = []
price = []
for i in range(0,12):
year.append([features['Year'][i]])
price.append([features['Price'][i]])
# In[27]:
year = np.array(year)
price = np.array(price)
year,price
# In[53]:
from sklearn import preprocessing
# 特征标准化处理
year = preprocessing.StandardScaler().fit_transform(year)
year[0]
# In[54]:
x = torch.tensor(year,dtype=float)
y = torch.tensor(price,dtype=float)
x,y
# In[62]:
learning_rate = 0.0001
weights1 = torch.randn((1,1),dtype=float,requires_grad=True)
bias1 = torch.randn(1,dtype=float,requires_grad=True)
losses = []
for i in range(0, 5000):
ans = x.mm(weights1) + bias1
#计算损失
criterion = torch.nn.MSELoss() # 使用适当的损失函数
loss = criterion(ans, y)
losses.append(loss)
if i%100==0:
print(f'loss={loss},epoch={i},w={weights1}')
#反向传播
loss.backward()
#更新参数
weights1.data.add_(-learning_rate*weights1.grad.data)
bias1.data.add_(-learning_rate*bias1.grad.data)
#清空
weights1.grad.data.zero_()
bias1.grad.data.zero_()
# 使用 features['Year'] 和 features['Price'] 创建日期和价格的列表
year = features['Year']
price = features['Price']
# 将 ans 转换为 Python 列表
ans_list = ans.tolist()
# 提取列表中的每个元素(确保是单个的标量值)
predictions = [item[0] for item in ans_list]
# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data={'date': year, 'actual': price})
predictions_data = pd.DataFrame(data={'date': year, 'prediction': predictions})
# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label='actual')
# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label='prediction')
plt.xticks(rotation='60')
plt.legend()
# 图名
plt.xlabel('Date')
plt.ylabel('Price') # 注意修改为你的标签
plt.title('Actual and Predicted Values')
plt.show()
本文由博客一文多发平台 OpenWrite 发布!
1.本站内容仅供参考,不作为任何法律依据。用户在使用本站内容时,应自行判断其真实性、准确性和完整性,并承担相应风险。
2.本站部分内容来源于互联网,仅用于交流学习研究知识,若侵犯了您的合法权益,请及时邮件或站内私信与本站联系,我们将尽快予以处理。
3.本文采用知识共享 署名4.0国际许可协议 [BY-NC-SA] 进行授权
4.根据《计算机软件保护条例》第十七条规定“为了学习和研究软件内含的设计思想和原理,通过安装、显示、传输或者存储软件等方式使用软件的,可以不经软件著作权人许可,不向其支付报酬。”您需知晓本站所有内容资源均来源于网络,仅供用户交流学习与研究使用,版权归属原版权方所有,版权争议与本站无关,用户本人下载后不能用作商业或非法用途,需在24个小时之内从您的电脑中彻底删除上述内容,否则后果均由用户承担责任;如果您访问和下载此文件,表示您同意只将此文件用于参考、学习而非其他用途,否则一切后果请您自行承担,如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。
5.本站是非经营性个人站点,所有软件信息均来自网络,所有资源仅供学习参考研究目的,并不贩卖软件,不存在任何商业目的及用途
暂无评论内容