我正在尝试在Python中运行一个用于使用TensorFlow进行自组织地图(SOM)的代码。我从这里获得了代码,但是当我运行它时,出现了一个错误:
错误:参数必须为密集张量:range(2,3)-形状为1,但需要[]
我认为相关的代码是:
s = SOM( (3,), 30, num_training, sess )
接```
着:
```js
class SOM:
def __init__(self, input_shape, map_size_n, num_expected_iterations, session):
input_shape = tuple([i for i in input_shape if i is not None])
要么:
def initialize_graph(self):
self.weights = tf.Variable( tf.random_uniform((self.n*self.n, )+self.input_shape, 0.0, 1.0) )
self.input_placeholder = tf.placeholder(tf.float32, (None,)+self.input_shape)
self.current_iteration = tf.placeholder(tf.float32)
## Compute the current iteration's neighborhood sigma and learning rate alpha:
self.sigma_tmp = self.sigma * tf.exp( - self.current_iteration/self.timeconst_sigma )
self.sigma2 = 2.0*tf.multiply(self.sigma_tmp, self.sigma_tmp)
self.alpha_tmp = self.alpha * tf.exp( - self.current_iteration/self.timeconst_alpha )
self.input_placeholder_ = tf.expand_dims(self.input_placeholder, 1)
self.input_placeholder_ = tf.tile(self.input_placeholder_, (1,self.n*self.n,1) )
self.diff = self.input_placeholder_ - self.weights
self.diff_sq = tf.square(self.diff)
self.diff_sum = tf.reduce_sum( self.diff_sq, axis=range(2, 2+len(self.input_shape)) )
# Get the index of the best matching unit
self.bmu_index = tf.argmin(self.diff_sum, 1)
self.bmu_dist = tf.reduce_min(self.diff_sum, 1)
self.bmu_activity = tf.exp( -self.bmu_dist/self.sigma_act )
self.diff = tf.squeeze(self.diff)
self.diff_2 = tf.placeholder(tf.float32, (self.n*self.n,)+self.input_shape)
self.dist_sliced = tf.placeholder(tf.float32, (self.n*self.n,))
self.distances = tf.exp(-self.dist_sliced / self.sigma2 )
self.lr_times_neigh = tf.multiply( self.alpha_tmp, self.distances )
for i in range(len(self.input_shape)):
self.lr_times_neigh = tf.expand_dims(self.lr_times_neigh, -1)
self.lr_times_neigh = tf.tile(self.lr_times_neigh, (1,)+self.input_shape )
self.delta_w = self.lr_times_neigh * self.diff_2
self.update_weights = tf.assign_add(self.weights, self.delta_w)
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。