CS231n assignment3 Q4 Style Transfer
Published: 2019-06-12


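This post reproduces the code from this paper.
The loss has three parts: a content loss, a style loss, and a (total-variation) regularization loss. The style loss is computed from Gram matrices.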
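
For reference, the three terms can be written out as below. This is a sketch read off the code in this post, not copied from the paper: the symbols (F for current features, P for content-target features, G and A for the current and target Gram matrices, x for the image) are my own notation, and I am assuming the three terms are simply added together.

```latex
\begin{aligned}
L &= L_c + L_s + L_{tv} \\
L_c &= w_c \sum_{i,j,k} \left(F_{ijk} - P_{ijk}\right)^2 \\
G^{\ell}_{ab} &= \frac{1}{H_\ell W_\ell C_\ell} \sum_{i,j} F^{\ell}_{ija}\, F^{\ell}_{ijb} \\
L_s &= \sum_{\ell} w_\ell \sum_{a,b} \left(G^{\ell}_{ab} - A^{\ell}_{ab}\right)^2 \\
L_{tv} &= w_t \left( \sum_{i,j,c} \left(x_{i+1,j,c} - x_{i,j,c}\right)^2
        + \sum_{i,j,c} \left(x_{i,j+1,c} - x_{i,j,c}\right)^2 \right)
\end{aligned}
```

Here i, j index spatial positions, a, b, c, k index channels, and the sum over layers runs over the chosen style layers.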

Content loss

import tensorflow as tf  # written against the TensorFlow 1.x API

def content_loss(content_weight, content_current, content_original):
    """
    Compute the content loss for style transfer.

    Inputs:
    - content_weight: scalar constant we multiply the content_loss by.
    - content_current: features of the current image, Tensor with shape [1, height, width, channels]
    - content_original: features of the content image, Tensor with shape [1, height, width, channels]

    Returns:
    - scalar content loss
    """
    # tf.squared_difference(x, y, name=None) returns (x - y) * (x - y) elementwise
    return content_weight * tf.reduce_sum(tf.squared_difference(content_current, content_original))
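
A quick way to sanity-check the content loss by hand is a NumPy version of the same formula. This is only a sketch: `content_loss_np` and the tiny feature maps are hypothetical names and values of mine, not part of the assignment.

```python
import numpy as np

def content_loss_np(content_weight, content_current, content_original):
    """NumPy reference: weighted sum of squared feature differences."""
    diff = (np.asarray(content_current, dtype=np.float64)
            - np.asarray(content_original, dtype=np.float64))
    return content_weight * np.sum(diff ** 2)

# Hypothetical feature maps with shape (1, 1, 2, 2).
current = np.array([[[[1.0, 2.0], [3.0, 4.0]]]])
original = np.array([[[[0.0, 2.0], [1.0, 4.0]]]])

# Differences are 1, 0, 2, 0, so the squared sum is 5 and the loss is 0.5 * 5.
loss = content_loss_np(0.5, current, original)
print(loss)  # 2.5
```

Evaluating the TensorFlow function on the same arrays should give the same number.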

Style loss

def gram_matrix(features, normalize=True):
    """
    Compute the Gram matrix from features.

    Inputs:
    - features: Tensor of shape (1, H, W, C) giving features for
      a single image.
    - normalize: optional, whether to normalize the Gram matrix
        If True, divide the Gram matrix by the number of neurons (H * W * C)

    Returns:
    - gram: Tensor of shape (C, C) giving the (optionally normalized)
      Gram matrices for the input image.
    """
    # Move channels forward: (1, H, W, C) -> (1, C, H, W)
    features = tf.transpose(features, [0, 3, 1, 2])
    shape = tf.shape(features)
    # Flatten the spatial dimensions: (1, C, H * W)
    features = tf.reshape(features, (shape[0], shape[1], -1))
    transpose_features = tf.transpose(features, [0, 2, 1])
    # Batched matmul; note the result keeps the batch axis, i.e. shape (1, C, C)
    result = tf.matmul(features, transpose_features)
    if normalize:
        # With a batch size of 1 this product equals H * W * C
        result = tf.div(result, tf.cast(shape[0] * shape[1] * shape[2] * shape[3], tf.float32))
    return result


def style_loss(feats, style_layers, style_targets, style_weights):
    """
    Computes the style loss at a set of layers.

    Inputs:
    - feats: list of the features at every layer of the current image, as produced by
      the extract_features function.
    - style_layers: List of layer indices into feats giving the layers to include in the
      style loss.
    - style_targets: List of the same length as style_layers, where style_targets[i] is
      a Tensor giving the Gram matrix of the source style image computed at
      layer style_layers[i].
    - style_weights: List of the same length as style_layers, where style_weights[i]
      is a scalar giving the weight for the style loss at layer style_layers[i].

    Returns:
    - style_loss: A Tensor containing the scalar style loss.
    """
    # Hint: you can do this with one for loop over the style layers, and should
    # not be very much code (~5 lines). You will need to use your gram_matrix function.
    style_losses = 0
    for i in range(len(style_layers)):
        cur_index = style_layers[i]
        cur_feat = feats[cur_index]
        cur_weight = style_weights[i]
        cur_style_target = style_targets[i]  # already a Gram matrix
        grammatrix = gram_matrix(cur_feat)   # Gram matrix of the current layer's feature map
        style_losses += cur_weight * tf.reduce_sum(tf.squared_difference(grammatrix, cur_style_target))
    return style_losses
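
To make the Gram matrix concrete, here is a small NumPy sketch of the same computation. `gram_matrix_np` is a hypothetical helper of mine; unlike the TensorFlow version above, which as far as I can tell returns a tensor of shape (1, C, C), it drops the batch axis and returns (C, C).

```python
import numpy as np

def gram_matrix_np(features, normalize=True):
    """NumPy reference for a (1, H, W, C) feature map; returns a (C, C) matrix."""
    _, H, W, C = features.shape
    flat = features.reshape(H * W, C)   # one row per spatial position
    gram = flat.T @ flat                # channel-by-channel inner products
    if normalize:
        gram = gram / (H * W * C)
    return gram

# Hand-checkable case: two spatial positions with channel vectors [1, 2] and [3, 4].
tiny = np.array([[[[1.0, 2.0], [3.0, 4.0]]]])      # shape (1, 1, 2, 2)
tiny_gram = gram_matrix_np(tiny, normalize=False)  # [[10, 14], [14, 20]]

# Random features: the Gram matrix is symmetric with shape (C, C).
rng = np.random.default_rng(0)
feats = rng.standard_normal((1, 4, 5, 3))
gram = gram_matrix_np(feats)
print(tiny_gram)
print(gram.shape)
```

Entry (a, b) is the inner product of channels a and b over all spatial positions, which is why the matrix captures which channels fire together while discarding where they fire.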

Total-variation regularization

def tv_loss(img, tv_weight):
    """
    Compute total variation loss.

    Inputs:
    - img: Tensor of shape (1, H, W, 3) holding an input image.
    - tv_weight: Scalar giving the weight w_t to use for the TV loss.

    Returns:
    - loss: Tensor holding a scalar giving the total variation loss
      for img weighted by tv_weight.
    """
    # Your implementation should be vectorized and not require any loops!
    shape = tf.shape(img)
    # Rows 0..H-2 and rows 1..H-1: vertically adjacent pixel pairs
    img_row_before = tf.slice(img, [0, 0, 0, 0], [-1, shape[1] - 1, -1, -1])
    img_row_after = tf.slice(img, [0, 1, 0, 0], [-1, shape[1] - 1, -1, -1])
    # Columns 0..W-2 and columns 1..W-1: horizontally adjacent pixel pairs
    img_col_before = tf.slice(img, [0, 0, 0, 0], [-1, -1, shape[2] - 1, -1])
    img_col_after = tf.slice(img, [0, 0, 1, 0], [-1, -1, shape[2] - 1, -1])
    result = tv_weight * (tf.reduce_sum(tf.squared_difference(img_row_before, img_row_after)) +
                          tf.reduce_sum(tf.squared_difference(img_col_before, img_col_after)))
    return result
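
The same total-variation sum can be written with plain NumPy slicing, which makes it easy to check on a tiny image. This is a sketch only; `tv_loss_np` and the 2 x 2 test image are hypothetical and not part of the assignment.

```python
import numpy as np

def tv_loss_np(img, tv_weight):
    """NumPy reference: weighted sum of squared differences between neighbouring pixels."""
    img = np.asarray(img, dtype=np.float64)
    down = img[:, 1:, :, :] - img[:, :-1, :, :]    # vertically adjacent pairs
    right = img[:, :, 1:, :] - img[:, :, :-1, :]   # horizontally adjacent pairs
    return tv_weight * (np.sum(down ** 2) + np.sum(right ** 2))

# Single-channel 2 x 2 image with pixels [[0, 1], [2, 3]].
img = np.arange(4, dtype=np.float64).reshape(1, 2, 2, 1)

# Vertical differences are 2 and 2 (squares sum to 8); horizontal are 1 and 1 (sum 2).
loss = tv_loss_np(img, 2.0)
flat_loss = tv_loss_np(np.ones((1, 3, 3, 3)), 2.0)
print(loss)       # 20.0
print(flat_loss)  # 0.0
```

A constant image has zero loss, which is the point of the regularizer: it penalizes pixel-to-pixel jumps and so pushes the generated image toward smoothness.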

[Four images from the original post: 1250085-20190105140329760-2032866622.png, 1250085-20190105140344771-804122792.png, 1250085-20190105140359823-1452842850.png, 1250085-20190105140425283-595895402.png]

Reposted from: https://www.cnblogs.com/bernieloveslife/p/10224313.html
