paddlepaddle 分支语句

条件分支和循环:
某些场景下,用户需要根据当前的网络状态,来具体决定后续使用哪一种操作,或者根据需要的网络状态来重复执行某些操作。针对这类需求, 提供了 两个API来实现条件分支的操作,以及 来实现循环操作。

lr = fluid.layers.tensor.create_global_var(
        shape=[1],
        value=0.0,
        dtype='float32',
        persistable=True,
        name="learning_rate")

one_var = fluid.layers.fill_constant(
        shape=[1], dtype='float32', value=1.0)
two_var = fluid.layers.fill_constant(
        shape=[1], dtype='float32', value=2.0)

with fluid.layers.control_flow.Switch() as switch:
    with switch.case(global_step == zero_var):
        fluid.layers.tensor.assign(input=one_var, output=lr)
    with switch.default():
        fluid.layers.tensor.assign(input=two_var, output=lr)

If else的例子

import numpy as np
import paddle.fluid as fluid
x = fluid.layers.data(name='x',shape=(4,1),dtype='float32',append_batch_size=False)
y = fluid.layers.data(name='y',shape=(4,1),dtype='float32',append_batch_size=False)
x_d = np.array([[3], [1], [-2], [-3]]).astype(np.float32)
y_d = np.zeros((4, 1)).astype(np.float32)
cond = fluid.layers.greater_than(x,y)
ie = fluid.layers.IfElse(cond)
with ie.true_block():
    out_1 = ie.input(x)
    out_1 = out_1 - 10
    ie.output(out_1)
with ie.false_block():
    out_2 = ie.input(x)
    out_2 = out_2 + 10
    ie.output(out_2)

output = ie()
out = fluid.layers.reduce_sum(output[0])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

res = exe.run(fluid.default_main_program(), feed={"x":x_d, "y":y_d}, fetch_list=[out])
print res

while的例子

import paddle.fluid as fluid
import numpy as np

i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)           # 循环计数器

loop_len = fluid.layers.fill_constant(shape=[1],dtype='int64', value=10)    # 循环次数

cond = fluid.layers.less_than(x=i, y=loop_len)              # 循环条件
while_op = fluid.layers.While(cond=cond)
with while_op.block():  # 循环体
    i = fluid.layers.increment(x=i, value=1, in_place=True)
    fluid.layers.less_than(x=i, y=loop_len, cond=cond)      # 更新循环条件

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[i])
print(res) # [array([10])]
全部评论

相关推荐

点赞 收藏 评论
分享
牛客网
牛客企业服务