计算数学精解【15】- flux机器学习精解【1】
flux
概述
Flux是一个用于机器学习的库。它“内置电池”,即包含了许多有用的工具,但同时也允许你在需要时充分利用Julia语言的强大功能。我们遵循以下几个关键原则:
- 做显而易见的事情。Flux的明确API相对较少。相反,只需写下数学形式,它就能工作——而且速度很快。
- 默认可扩展。Flux在保持高性能的同时,也被设计成高度灵活的。扩展Flux就像使用你自己的代码作为你想要模型的一部分一样简单——它全是高级Julia代码。
- 与其他库和谐共处。Flux与从图像处理到微分方程求解器等无关的Julia库都能很好地协同工作,而不是重复它们的功能。
异或
# This will prompt if neccessary to install everything, including CUDA:
using Flux,Cuda,Statistics, ProgressMeter
# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
truth = [xor(col[1]>0.5, col[2]>0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
# Define our model, a multi-layer perceptron with one hidden layer of size 3:
model = Chain(
Dense(2 => 3, tanh), # activation function inside layer
BatchNorm(3),
Dense(3 => 2)) |> gpu # move model to GPU, if available
# The model encapsulates parameters, randomly initialised. Its initial output is:
out1 = model(noisy |> gpu) |> cpu # 2×1000 Matrix{Float32}
probs1 = softmax(out1) # normalise to get probabilities
# To train the model, we use batches of 64 samples, and one-hot encoding:
target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true);
# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc.
# Training loop, using the whole data set 1000 times:
losses = []
@showprogress for epoch in 1:1_000
for (x, y) in loader
loss, grads = Flux.withgradient(model) do m
# Evaluate model and loss inside gradient context:
y_hat = m(x)
Flux.logitcrossentropy(y_hat, y)
end
Flux.update!(optim, model, grads[1])
push!(losses, loss) # logging, outside gradient context
end
end
optim # parameters, momenta and output have all changed
out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false)
probs2 = softmax(out2) # normalise to get probabilities
mean((probs2[1,:] .> 0.5) .== truth) # accuracy 94% so far!
参考文献
- https://fluxml.ai/
- 文心一言
原文地址:https://blog.csdn.net/sakura_sea/article/details/142407669
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!