# Note carried over from the original text: running the code above converts
# the test data into the standard format.
import csv


def regroup_csv(src_path="mnist_test.csv", dst_path="new.csv",
                row_width=784, max_values=7840000):
    """Rewrite a comma-separated stream of values as fixed-width CSV rows.

    The whole of ``src_path`` is read, split on commas, truncated to the
    first ``max_values`` values and written to ``dst_path`` in rows of
    ``row_width`` values each (csv "excel" dialect). Values left over after
    the last complete row are dropped.

    Args:
        src_path: File to read; it is only read, never modified.
        dst_path: File to create or overwrite with the regrouped rows.
        row_width: Number of values per output row; must be positive.
        max_values: Upper bound on how many values are consumed from the
            input; ``None`` means no limit.

    Returns:
        The number of rows written.

    Raises:
        ValueError: If ``row_width`` is not positive.
    """
    if row_width <= 0:
        raise ValueError("row_width must be positive")

    # Plain read mode is enough: the input is never written back.
    with open(src_path, "r") as src:
        # Strip only the trailing line terminator; otherwise the final value
        # keeps its "\n" and the csv writer emits it as a quoted field that
        # spans two lines.
        text = src.read().rstrip("\r\n")

    values = text.split(",")[:max_values]
    full_rows = len(values) // row_width

    # newline="" lets the csv module control line endings itself.
    with open(dst_path, "w", newline="") as dst:
        writer = csv.writer(dst, dialect="excel")
        for start in range(0, full_rows * row_width, row_width):
            writer.writerow(values[start:start + row_width])
    return full_rows


if __name__ == "__main__":
    regroup_csv()
# Number of leading columns of each row that are used as features.
# NOTE(review): the CSV conversion elsewhere in this file uses 784 values per
# sample; 764 looks like a typo that silently ignores the last columns.
# Kept as-is because the layout of "train"/"test" is not visible here - confirm.
N_FEATURES = 764


def load_dataset(path, n_features=N_FEATURES):
    """Read a comma-separated file of integer rows into features and labels.

    Each line is split on commas; the last field is the label and the first
    ``n_features`` fields are the feature vector.

    Args:
        path: File to read, one sample per line.
        n_features: How many leading fields of each line to keep as features.

    Returns:
        A ``(features, labels)`` tuple: a list of int lists and a list of
        ints, in file order.

    Raises:
        IndexError: If a line has fewer than ``n_features`` fields.
        ValueError: If a used field is not an integer (this includes blank
            lines).
    """
    features = []
    labels = []
    with open(path) as handle:
        for raw_line in handle:
            fields = raw_line.strip("\n").split(",")
            labels.append(int(fields[-1]))
            # Index explicitly (rather than slice) so a short row fails
            # loudly instead of yielding a ragged feature matrix.
            features.append([int(fields[i]) for i in range(n_features)])
    return features, labels


def count_correct(classifier, samples, labels):
    """Predict each sample on its own, print the prediction, count the hits.

    Args:
        classifier: Object with a ``predict(rows)`` method returning one
            prediction per row.
        samples: Feature vectors to classify.
        labels: Expected label for each sample, in the same order.

    Returns:
        The number of samples whose prediction equals its label (a float
        once at least one prediction matched, ``0`` otherwise).
    """
    correct = 0
    for sample, label in zip(samples, labels):
        res = classifier.predict([sample])
        print(res)
        if res[0] == label:
            correct += 1.0
    return correct


if __name__ == "__main__":
    from sklearn import tree

    # Train a decision tree on the training file.
    x, y = load_dataset("train")
    clf = tree.DecisionTreeClassifier()
    model = clf.fit(x, y)
    print(model)

    # Evaluate on the test file; `count` holds the number of correct
    # predictions and is not reported anywhere in this block.
    x_test, y_test = load_dataset("test")
    count = count_correct(clf, x_test, y_test)