Descarga la aplicación para disfrutar aún más
Vista previa del material en texto
Algoritmo de aprendizaje reforzado utilizando la biblioteca OpenAI Gym y el algoritmo DQN (Deep Q-Network): import gym import tensorflow as tf from tensorflow.keras import layers # Crear el entorno de gym env = gym.make('CartPole-v1') # Definir la red neuronal model = tf.keras.Sequential([ layers.Dense(32, activation='relu', input_shape=(4,)), layers.Dense(32, activation='relu'), layers.Dense(env.action_space.n) ]) # Definir el algoritmo de optimización optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) # Función de pérdida y actualización de la red neuronal def compute_loss(logits, actions, rewards): ce_loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True) actions = tf.cast(actions, tf.int32) loss = ce_loss(actions, logits, sample_weight=rewards) return loss # Entrenamiento del agente @tf.function def train_step(states, actions, rewards, next_states, done): with tf.GradientTape() as tape: logits = model(states, training=True) loss = compute_loss(logits, actions, rewards) grads = tape.gradient(loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) # Bucle principal de entrenamiento for episode in range(num_episodes): state = env.reset() episode_reward = 0 done = False while not done: # Seleccionar una acción utilizando la política epsilon-greedy epsilon = max(0.1, 0.9 * (1 - episode / num_episodes)) if np.random.rand() < epsilon: action = env.action_space.sample() else: logits = model(tf.expand_dims(state, 0)) action = tf.argmax(logits, axis=1)[0].numpy() # Tomar la acción y obtener la siguiente observación next_state, reward, done, _ = env.step(action) # Actualizar la recompensa acumulada episode_reward += reward # Guardar la transición en el buffer de memoria replay_buffer.append((state, action, reward, next_state, done)) # Actualizar la red neuronal if len(replay_buffer) >= batch_size: batch = random.sample(replay_buffer, batch_size) train_step(*zip(*batch)) # Actualizar el estado actual state = next_state
Compartir